[Mlir-commits] [mlir] (WIP) [MLIR] Add a new interface for "IR parameterization" (PR #78544)
Mehdi Amini
llvmlistbot at llvm.org
Wed Jan 17 22:52:19 PST 2024
https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/78544
This implements the ability to define "meta programs": a mechanism similar to C++ templates.
So as an example, this input IR:
```
testparametric.func @callee(%arg0: !testparametric.param<"A"> ) attributes { metaParams = ["A", "B"]} {
%value = testparametric.add %arg0, %arg0 : (!testparametric.param<"A">, !testparametric.param<"A">) -> !testparametric.param<"A">
testparametric.print_attr #testparametric.param<"B">
return
}
func.func @caller() {
%cst0 = arith.constant 0 : i32
%cst1 = arith.constant 1. : f32
%cst2 = arith.constant 2. : f64
testparametric.call @callee(%cst0) meta = {"A" = i32, "B" = 32 : i64 } : (i32) -> ()
testparametric.call @callee(%cst1) meta = {"A" = f32, "B" = 64 : i64 } : (f32) -> ()
testparametric.call @callee(%cst2) meta = {"A" = f64, "B" = 128 : i64 } : (f64) -> ()
return
}
```
The @callee parametric function is then instantiated once for each call site:
```
func.func @caller() {
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 1.000000e+00 : f32
%cst_0 = arith.constant 2.000000e+00 : f64
testparametric.call @callee$__mlir_instance__$A$i32$B$32(%c0_i32) meta = {} : (i32) -> ()
testparametric.call @callee$__mlir_instance__$A$f32$B$64(%cst) meta = {} : (f32) -> ()
testparametric.call @callee$__mlir_instance__$A$f64$B$128(%cst_0) meta = {} : (f64) -> ()
return
}
testparametric.func @callee$__mlir_instance__$A$f32$B$64(%arg0: f32) {
%0 = add %arg0, %arg0 : (f32, f32) -> f32
print_attr 64 : i64
return
}
testparametric.func @callee$__mlir_instance__$A$f64$B$128(%arg0: f64) {
%0 = add %arg0, %arg0 : (f64, f64) -> f64
print_attr 128 : i64
return
}
testparametric.func @callee$__mlir_instance__$A$i32$B$32(%arg0: i32) {
%0 = add %arg0, %arg0 : (i32, i32) -> i32
print_attr 32 : i64
return
}
```
>From 69c09a66d055c937c0815fb368ec4629cf5cc65f Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Wed, 17 Jan 2024 19:41:21 -0800
Subject: [PATCH] [MLIR] Add a new interface for "IR parameterization"
This implements the ability to define "meta programs": a mechanism
similar to C++ templates.
So as an example, this input IR:
```
testparametric.func @callee(%arg0: !testparametric.param<"A"> ) attributes { metaParams = ["A", "B"]} {
%value = testparametric.add %arg0, %arg0 : (!testparametric.param<"A">, !testparametric.param<"A">) -> !testparametric.param<"A">
testparametric.print_attr #testparametric.param<"B">
return
}
func.func @caller() {
%cst0 = arith.constant 0 : i32
%cst1 = arith.constant 1. : f32
%cst2 = arith.constant 2. : f64
testparametric.call @callee(%cst0) meta = {"A" = i32, "B" = 32 : i64 } : (i32) -> ()
testparametric.call @callee(%cst1) meta = {"A" = f32, "B" = 64 : i64 } : (f32) -> ()
testparametric.call @callee(%cst2) meta = {"A" = f64, "B" = 128 : i64 } : (f64) -> ()
return
}
```
The @callee parametric function is then instantiated once for each call site:
```
func.func @caller() {
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 1.000000e+00 : f32
%cst_0 = arith.constant 2.000000e+00 : f64
testparametric.call @callee$__mlir_instance__$A$i32$B$32(%c0_i32) meta = {} : (i32) -> ()
testparametric.call @callee$__mlir_instance__$A$f32$B$64(%cst) meta = {} : (f32) -> ()
testparametric.call @callee$__mlir_instance__$A$f64$B$128(%cst_0) meta = {} : (f64) -> ()
return
}
testparametric.func @callee$__mlir_instance__$A$f32$B$64(%arg0: f32) {
%0 = add %arg0, %arg0 : (f32, f32) -> f32
print_attr 64 : i64
return
}
testparametric.func @callee$__mlir_instance__$A$f64$B$128(%arg0: f64) {
%0 = add %arg0, %arg0 : (f64, f64) -> f64
print_attr 128 : i64
return
}
testparametric.func @callee$__mlir_instance__$A$i32$B$32(%arg0: i32) {
%0 = add %arg0, %arg0 : (i32, i32) -> i32
print_attr 32 : i64
return
}
```
---
mlir/include/mlir/IR/SymbolInterfaces.td | 2 +-
mlir/include/mlir/Interfaces/CMakeLists.txt | 1 +
.../ParametricSpecializationOpInterface.h | 25 ++
.../ParametricSpecializationOpInterface.td | 46 +++
.../Transforms/ParametricSpecialization.h | 11 +
mlir/lib/Interfaces/CMakeLists.txt | 2 +
.../ParametricSpecializationOpInterface.cpp | 13 +
mlir/lib/Transforms/CMakeLists.txt | 2 +
.../Transforms/ParametricSpecialization.cpp | 13 +
mlir/test/Parametric/ops.mlir | 18 ++
mlir/test/lib/Dialect/CMakeLists.txt | 1 +
.../lib/Dialect/TestParametric/CMakeLists.txt | 68 ++++
.../TestParametric/TestParametricAttrDefs.td | 38 +++
.../TestParametricAttributes.cpp | 42 +++
.../TestParametric/TestParametricAttributes.h | 34 ++
.../TestParametric/TestParametricDialect.cpp | 297 ++++++++++++++++++
.../TestParametric/TestParametricDialect.h | 45 +++
.../TestParametric/TestParametricDialect.td | 27 ++
.../TestParametricInterfaces.cpp | 11 +
.../TestParametric/TestParametricInterfaces.h | 33 ++
.../TestParametricInterfaces.td | 34 ++
.../TestParametric/TestParametricOps.td | 202 ++++++++++++
.../TestParametric/TestParametricTypeDefs.td | 37 +++
.../TestParametric/TestParametricTypes.cpp | 42 +++
.../TestParametric/TestParametricTypes.h | 154 +++++++++
.../lib/Dialect/TestParametric/lit.local.cfg | 1 +
mlir/test/lib/Transforms/CMakeLists.txt | 1 +
.../TestParametricSpecialization.cpp | 191 +++++++++++
mlir/tools/mlir-lsp-server/CMakeLists.txt | 1 +
.../tools/mlir-lsp-server/mlir-lsp-server.cpp | 4 +
mlir/tools/mlir-opt/CMakeLists.txt | 1 +
mlir/tools/mlir-opt/mlir-opt.cpp | 6 +
32 files changed, 1402 insertions(+), 1 deletion(-)
create mode 100644 mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.h
create mode 100644 mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.td
create mode 100644 mlir/include/mlir/Transforms/ParametricSpecialization.h
create mode 100644 mlir/lib/Interfaces/ParametricSpecializationOpInterface.cpp
create mode 100644 mlir/lib/Transforms/ParametricSpecialization.cpp
create mode 100644 mlir/test/Parametric/ops.mlir
create mode 100644 mlir/test/lib/Dialect/TestParametric/CMakeLists.txt
create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricAttrDefs.td
create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.cpp
create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.h
create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricDialect.cpp
create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricDialect.h
create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricDialect.td
create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.cpp
create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.h
create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.td
create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricOps.td
create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricTypeDefs.td
create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricTypes.cpp
create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricTypes.h
create mode 100644 mlir/test/lib/Dialect/TestParametric/lit.local.cfg
create mode 100644 mlir/test/lib/Transforms/TestParametricSpecialization.cpp
diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index 844601f8f6837c..68bae5d3f991da 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -154,7 +154,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
"bool", "isDeclaration", (ins), [{}],
/*defaultImplementation=*/[{
// By default, assume that the operation defines a symbol.
- return false;
+ return false;
}]
>,
];
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index d81298bb4daf01..2f3e34e266e3fb 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
add_mlir_interface(ParallelCombiningOpInterface)
+add_mlir_interface(ParametricSpecializationOpInterface)
add_mlir_interface(RuntimeVerifiableOpInterface)
add_mlir_interface(ShapedOpInterfaces)
add_mlir_interface(SideEffectInterfaces)
diff --git a/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.h b/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.h
new file mode 100644
index 00000000000000..88770e7239ac0f
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.h
@@ -0,0 +1,25 @@
+//===- ParametricSpecializationOpInterface.h - Parametric op interfaces
+//---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the operation interfaces used for parametric
+// specialization: parametric ops and the ops that specialize them.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES_H_
+#define MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/ParametricSpecializationOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES_H_
diff --git a/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.td b/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.td
new file mode 100644
index 00000000000000..e3c12d6b4b60f9
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.td
@@ -0,0 +1,46 @@
+//===-- ParametricSpecializationOpInterface.td -------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES
+#define MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def ParametricOpInterface : OpInterface<"ParametricOpInterface"> {
+  let description = [{
+    Interface for operations that depend on meta parameters and that can be
+    specialized in place once concrete meta arguments are provided.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<[{
+        Specialize this operation in place using the meta arguments in
+        `params`. Returns failure if the operation cannot be specialized.
+      }],
+      "::mlir::LogicalResult", "specialize", (ins
+        "::mlir::DictionaryAttr":$params)>,
+    InterfaceMethod<[{
+        Check whether the value used by `operand` may have its type replaced
+        by `concreteType`. Returns failure if this operation does not accept
+        the new type for that operand.
+      }],
+      "::mlir::LogicalResult", "checkOperand", (ins
+        "::mlir::OpOperand &":$operand,
+        "::mlir::Type":$concreteType)>,
+    InterfaceMethod<[{
+        Only for symbol operations that will be cloned: return the mangled
+        symbol name to use for the instance specialized with `metaArgs`.
+        The default implementation returns failure.
+      }],
+      "::mlir::FailureOr<::mlir::StringAttr>", "getMangledName", (ins
+        "::mlir::DictionaryAttr":$metaArgs), "", [{
+        return failure();
+      }]
+    >,
+  ];
+}
+
+def SpecializingOpInterface : OpInterface<"SpecializingOpInterface"> {
+  let description = [{
+    Interface for operations that reference a parametric symbol and provide
+    the meta arguments used to specialize it.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<[{
+        Return the symbol reference to the operation targeted by this
+        operation.
+      }],
+      "::mlir::SymbolRefAttr", "getTarget", (ins)>,
+    InterfaceMethod<[{
+        Return the meta arguments this operation provides for its target.
+      }],
+      "::mlir::DictionaryAttr", "getMetaArgs", (ins)>,
+    InterfaceMethod<[{
+        Redirect this operation to `target`, the specialized symbol.
+        Returns failure if the operation cannot be retargeted.
+      }],
+      "::mlir::LogicalResult", "setSpecializedTarget", (ins
+        "::mlir::SymbolOpInterface":$target)>,
+  ];
+}
+
+#endif // MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES
diff --git a/mlir/include/mlir/Transforms/ParametricSpecialization.h b/mlir/include/mlir/Transforms/ParametricSpecialization.h
new file mode 100644
index 00000000000000..1bbe3e2a557ef1
--- /dev/null
+++ b/mlir/include/mlir/Transforms/ParametricSpecialization.h
@@ -0,0 +1,11 @@
+//===- ParametricSpecialization.h - Specialize Meta Program ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Operation.h"
+
+namespace mlir {}
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index e7c76e70ed6b5d..1998b66f168f36 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -13,6 +13,7 @@ set(LLVM_OPTIONAL_SOURCES
LoopLikeInterface.cpp
MemorySlotInterfaces.cpp
ParallelCombiningOpInterface.cpp
+ ParametricSpecializationOpInterface.cpp
RuntimeVerifiableOpInterface.cpp
ShapedOpInterfaces.cpp
SideEffectInterfaces.cpp
@@ -80,6 +81,7 @@ add_mlir_library(MLIRLoopLikeInterface
add_mlir_interface_library(MemorySlotInterfaces)
add_mlir_interface_library(ParallelCombiningOpInterface)
+add_mlir_interface_library(ParametricSpecializationOpInterface)
add_mlir_interface_library(RuntimeVerifiableOpInterface)
add_mlir_interface_library(ShapedOpInterfaces)
add_mlir_interface_library(SideEffectInterfaces)
diff --git a/mlir/lib/Interfaces/ParametricSpecializationOpInterface.cpp b/mlir/lib/Interfaces/ParametricSpecializationOpInterface.cpp
new file mode 100644
index 00000000000000..80fc2caf0d12ac
--- /dev/null
+++ b/mlir/lib/Interfaces/ParametricSpecializationOpInterface.cpp
@@ -0,0 +1,13 @@
+//===- ParametricSpecializationOpInterface.cpp ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/ParametricSpecializationOpInterface.h"
+#include "mlir/Support/LogicalResult.h"
+
+/// Include the definitions of the interface.
+#include "mlir/Interfaces/ParametricSpecializationOpInterface.cpp.inc"
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index af51a4ab1157f1..8254f9d212c603 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_library(MLIRTransforms
LoopInvariantCodeMotion.cpp
Mem2Reg.cpp
OpStats.cpp
+ ParametricSpecialization.cpp
PrintIR.cpp
RemoveDeadValues.cpp
SCCP.cpp
@@ -32,6 +33,7 @@ add_mlir_library(MLIRTransforms
MLIRFunctionInterfaces
MLIRLoopLikeInterface
MLIRMemorySlotInterfaces
+ MLIRParametricSpecializationOpInterface
MLIRPass
MLIRRuntimeVerifiableOpInterface
MLIRSideEffectInterfaces
diff --git a/mlir/lib/Transforms/ParametricSpecialization.cpp b/mlir/lib/Transforms/ParametricSpecialization.cpp
new file mode 100644
index 00000000000000..fcc3daacad447d
--- /dev/null
+++ b/mlir/lib/Transforms/ParametricSpecialization.cpp
@@ -0,0 +1,13 @@
+//===- ParametricSpecialization.cpp - Specialize Meta Program -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/ParametricSpecialization.h"
+
+using namespace mlir;
+
+/// Placeholder for the generic specialization entry point: currently a no-op
+/// in this work-in-progress patch. `op` is unused.
+void specialize(Operation *op) {}
\ No newline at end of file
diff --git a/mlir/test/Parametric/ops.mlir b/mlir/test/Parametric/ops.mlir
new file mode 100644
index 00000000000000..ed8c87cd48ccee
--- /dev/null
+++ b/mlir/test/Parametric/ops.mlir
@@ -0,0 +1,18 @@
+
+
+testparametric.func @callee(%arg0: !testparametric.param<"A"> ) attributes { metaParams = ["A", "B"]} {
+ %value = testparametric.add %arg0, %arg0 : (!testparametric.param<"A">, !testparametric.param<"A">) -> !testparametric.param<"A">
+ testparametric.print_attr #testparametric.param<"B">
+ return
+}
+
+func.func @caller() {
+ %cst0 = arith.constant 0 : i32
+ %cst1 = arith.constant 1. : f32
+ %cst2 = arith.constant 2. : f64
+ testparametric.call @callee(%cst0) meta = {"A" = i32, "B" = 32 : i64 } : (i32) -> ()
+ testparametric.call @callee(%cst0) meta = {"A" = i32, "B" = 32 : i64 } : (i32) -> ()
+ testparametric.call @callee(%cst1) meta = {"A" = f32, "B" = 64 : i64 } : (f32) -> ()
+ testparametric.call @callee(%cst2) meta = {"A" = f64, "B" = 128 : i64 } : (f64) -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 30a17c201ff763..8c1be74f15899e 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -17,6 +17,7 @@ add_subdirectory(SPIRV)
add_subdirectory(Tensor)
add_subdirectory(Test)
add_subdirectory(TestDyn)
+add_subdirectory(TestParametric)
add_subdirectory(Tosa)
add_subdirectory(Transform)
add_subdirectory(Vector)
diff --git a/mlir/test/lib/Dialect/TestParametric/CMakeLists.txt b/mlir/test/lib/Dialect/TestParametric/CMakeLists.txt
new file mode 100644
index 00000000000000..dcc79f15993b7e
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/CMakeLists.txt
@@ -0,0 +1,68 @@
+set(LLVM_OPTIONAL_SOURCES
+ TestParametricDialect.cpp
+)
+
+set(LLVM_TARGET_DEFINITIONS TestParametricInterfaces.td)
+mlir_tablegen(TestParametricAttrInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(TestParametricAttrInterfaces.cpp.inc -gen-attr-interface-defs)
+mlir_tablegen(TestParametricTypeInterfaces.h.inc -gen-type-interface-decls)
+mlir_tablegen(TestParametricTypeInterfaces.cpp.inc -gen-type-interface-defs)
+mlir_tablegen(TestParametricOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(TestParametricOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRTestParametricInterfaceIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TestParametricOps.td)
+mlir_tablegen(TestParametricAttrDefs.h.inc -gen-attrdef-decls)
+mlir_tablegen(TestParametricAttrDefs.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRTestParametricAttrDefIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TestParametricTypeDefs.td)
+mlir_tablegen(TestParametricTypeDefs.h.inc -gen-typedef-decls -typedefs-dialect=testparametric)
+mlir_tablegen(TestParametricTypeDefs.cpp.inc -gen-typedef-defs -typedefs-dialect=testparametric)
+add_public_tablegen_target(MLIRTestParametricTypeDefIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TestParametricOps.td)
+mlir_tablegen(TestParametricOps.h.inc -gen-op-decls)
+mlir_tablegen(TestParametricOps.cpp.inc -gen-op-defs)
+mlir_tablegen(TestParametricOpsDialect.h.inc -gen-dialect-decls -dialect=testparametric)
+mlir_tablegen(TestParametricOpsDialect.cpp.inc -gen-dialect-defs -dialect=testparametric)
+add_public_tablegen_target(MLIRTestParametricOpsIncGen)
+
+# Exclude testparametrics from libMLIR.so
+add_mlir_library(MLIRTestParametricDialect
+ TestParametricAttributes.cpp
+ TestParametricDialect.cpp
+ TestParametricInterfaces.cpp
+ TestParametricTypes.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ DEPENDS
+ MLIRTestParametricAttrDefIncGen
+ MLIRTestParametricInterfaceIncGen
+ MLIRTestParametricTypeDefIncGen
+ MLIRTestParametricOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRControlFlowInterfaces
+ MLIRDataLayoutInterfaces
+ MLIRDerivedAttributeOpInterface
+ MLIRDestinationStyleOpInterface
+ MLIRDialect
+ MLIRDLTIDialect
+ MLIRFuncDialect
+ MLIRFunctionInterfaces
+ MLIRFuncTransforms
+ MLIRIR
+ MLIRInferIntRangeInterface
+ MLIRInferTypeOpInterface
+ MLIRLinalgDialect
+ MLIRLinalgTransforms
+ MLIRLLVMDialect
+ MLIRPass
+ MLIRReduce
+ MLIRTensorDialect
+ MLIRTransformUtils
+ MLIRTransforms
+)
+
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricAttrDefs.td b/mlir/test/lib/Dialect/TestParametric/TestParametricAttrDefs.td
new file mode 100644
index 00000000000000..c9133a99a34413
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricAttrDefs.td
@@ -0,0 +1,38 @@
+//===-- TestParametricAttrDefs.td - Attr definitions -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TableGen data attribute definitions for the TestParametric dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TESTPARAMETRIC_ATTRDEFS
+#define TESTPARAMETRIC_ATTRDEFS
+
+// To get the test dialect definition.
+include "TestParametricDialect.td"
+include "mlir/Dialect/Utils/StructuredOpsUtils.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/IR/OpAsmInterface.td"
+
+// All of the attributes will extend this class.
+class TestParametric_Attr<string name, list<Trait> traits = []>
+ : AttrDef<TestParametric_Dialect, name, traits>;
+
+def TestParametric_ParamAttr : TestParametric_Attr<"Param"> {
+ let mnemonic = "param";
+ // List of type parameters.
+ let parameters = (
+ ins
+ "::mlir::StringAttr":$ref
+ );
+ let assemblyFormat = "`<` $ref `>`";
+}
+
+#endif // TESTPARAMETRIC_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.cpp b/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.cpp
new file mode 100644
index 00000000000000..a5dde555b3e891
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.cpp
@@ -0,0 +1,42 @@
+//===- TestParametricAttributes.cpp - Test dialect attributes ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains attributes defined by the TestParametricDialect for
+// testing various features of MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestParametricAttributes.h"
+#include "TestParametricDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/bit.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace testparametric;
+
+//===----------------------------------------------------------------------===//
+// TestParametricDialect
+//===----------------------------------------------------------------------===//
+
+#define GET_ATTRDEF_CLASSES
+#include "TestParametricAttrDefs.cpp.inc"
+
+void TestParametricDialect::registerAttributes() {
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "TestParametricAttrDefs.cpp.inc"
+ >();
+}
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.h b/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.h
new file mode 100644
index 00000000000000..054fc0f598d0a4
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.h
@@ -0,0 +1,34 @@
+//===- TestParametricAttributes.h - Test dialect attributes -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains attributes defined by the TestParametricDialect for
+// testing various features of MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTPARAMETRICATTRIBUTES_H
+#define MLIR_TESTPARAMETRICATTRIBUTES_H
+
+#include <tuple>
+
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+
+#include "TestParametricAttrInterfaces.h.inc"
+#include "TestParametricOpEnums.h.inc"
+#include "mlir/IR/DialectResourceBlobManager.h"
+
+namespace testparametric {} // namespace testparametric
+
+#define GET_ATTRDEF_CLASSES
+#include "TestParametricAttrDefs.h.inc"
+
+#endif // MLIR_TESTPARAMETRICATTRIBUTES_H
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.cpp b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.cpp
new file mode 100644
index 00000000000000..693d93910e8188
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.cpp
@@ -0,0 +1,297 @@
+//===- TestParametricDialect.cpp - MLIR Dialect for Testing
+//----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestParametricDialect.h"
+#include "TestParametricAttributes.h"
+#include "TestParametricInterfaces.h"
+#include "TestParametricTypes.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/ODSSupport.h"
+#include "mlir/IR/OperationSupport.h"
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/Base64.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <cstdint>
+#include <numeric>
+#include <optional>
+
+// Include this before the using namespace lines below to
+// test that we don't have namespace dependencies.
+#include "TestParametricOpsDialect.cpp.inc"
+
+using namespace mlir;
+using namespace testparametric;
+
+void TestParametricDialect::initialize() {
+ registerAttributes();
+ registerTypes();
+ addOperations<
+#define GET_OP_LIST
+#include "TestParametricOps.cpp.inc"
+ >();
+}
+/// Registers the TestParametric dialect with `registry`.
+void testparametric::registerTestParametricDialect(DialectRegistry &registry) {
+  registry.insert<TestParametricDialect>();
+}
+
+#include "TestParametricOpInterfaces.cpp.inc"
+#include "TestParametricTypeInterfaces.cpp.inc"
+
+#define GET_OP_CLASSES
+#include "TestParametricOps.cpp.inc"
+
+/// Parses a parametric function by delegating to the shared function-op
+/// parser; variadic signatures are not accepted.
+::mlir::ParseResult ParametricFuncOp::parse(mlir::OpAsmParser &parser,
+                                            mlir::OperationState &result) {
+  // Callback handed to the shared parser to build the function type. The
+  // variadic flag and the error string are not used here.
+  auto makeFunctionType = [](Builder &b, ArrayRef<Type> inputs,
+                             ArrayRef<Type> outputs,
+                             function_interface_impl::VariadicFlag,
+                             std::string &) {
+    return b.getFunctionType(inputs, outputs);
+  };
+
+  // Attribute names are looked up from the operation being parsed.
+  auto typeAttrName = getFunctionTypeAttrName(result.name);
+  auto argAttrsName = getArgAttrsAttrName(result.name);
+  auto resAttrsName = getResAttrsAttrName(result.name);
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, typeAttrName, makeFunctionType,
+      argAttrsName, resAttrsName);
+}
+
+/// Prints this function through the shared function-op printer, as a
+/// non-variadic function.
+void ParametricFuncOp::print(mlir::OpAsmPrinter &p) {
+  auto typeAttrName = getFunctionTypeAttrName();
+  auto argAttrsName = getArgAttrsAttrName();
+  auto resAttrsName = getResAttrsAttrName();
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
+                                           typeAttrName, argAttrsName,
+                                           resAttrsName);
+}
+
+/// Verifies that this call references a `ParametricFuncOp` and that its
+/// operands, meta arguments and results are compatible with that callee.
+LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // The callee must be provided as a flat symbol reference attribute.
+  auto calleeRef = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!calleeRef)
+    return emitOpError("requires a 'callee' symbol reference attribute");
+  auto callee =
+      symbolTable.lookupNearestSymbolFrom<ParametricFuncOp>(*this, calleeRef);
+  if (!callee)
+    return emitOpError() << "'" << calleeRef.getValue()
+                         << "' does not reference a valid function";
+
+  // Operand and meta-argument counts must agree with the callee.
+  auto calleeType = callee.getFunctionType();
+  const unsigned numInputs = calleeType.getNumInputs();
+  if (numInputs != getNumOperands())
+    return emitOpError("incorrect number of operands for callee");
+
+  DictionaryAttr metaParams = callee.getMetaParamsAttr();
+  DictionaryAttr metaArgs = getMetaArgs();
+  if (metaParams && metaArgs.size() != metaParams.size())
+    return emitOpError("incorrect number of meta operands for callee");
+
+  for (unsigned idx = 0; idx != numInputs; ++idx) {
+    auto operandType = getOperand(idx).getType();
+    auto paramType = calleeType.getInput(idx);
+    auto metaParamType = dyn_cast<ParamType>(paramType);
+    if (!metaParamType) {
+      // Concrete parameter: the operand type has to match exactly.
+      if (operandType != paramType)
+        return emitOpError("operand type mismatch: expected operand type ")
+               << paramType << ", but provided " << operandType
+               << " for operand number " << idx;
+      continue;
+    }
+
+    // Parametric parameter: the matching meta argument must be a TypeAttr
+    // holding the type of the operand.
+    auto metaArg = metaArgs.get(metaParamType.getRef());
+    if (!metaArg)
+      return emitOpError("Missing meta args for type operand ")
+             << metaParamType.getRef();
+    auto metaArgType = dyn_cast<TypeAttr>(metaArg);
+    if (!metaArgType)
+      return emitOpError("Expected TypeAttr for meta args ")
+             << metaParamType.getRef() << ", got " << metaArg;
+    if (metaArgType.getValue() != operandType)
+      return emitOpError("Mismatch between operand type and meta args type: ")
+             << operandType << " vs " << metaArgType;
+  }
+
+  // Result count and types must match the callee signature.
+  const unsigned numResults = calleeType.getNumResults();
+  if (numResults != getNumResults())
+    return emitOpError("incorrect number of results for callee");
+
+  for (unsigned idx = 0; idx != numResults; ++idx) {
+    if (getResult(idx).getType() == calleeType.getResult(idx))
+      continue;
+    auto diag = emitOpError("result type mismatch at index ") << idx;
+    diag.attachNote() << " op result types: " << getResultTypes();
+    diag.attachNote() << "function result types: " << calleeType.getResults();
+    return diag;
+  }
+  return success();
+}
+
+/// Specialization Interface Implementation
+
+/// Sets the type of `value` to `newType`, after asking every user that
+/// implements `ParametricOpInterface` whether it accepts the new type for the
+/// corresponding operand. Emits an error on the first rejecting user and
+/// returns failure without changing the type.
+static LogicalResult replaceValueType(Value value, Type newType) {
+  for (OpOperand &operandUse : value.getUses()) {
+    auto user = dyn_cast<ParametricOpInterface>(operandUse.getOwner());
+    // Users that do not implement the interface are not consulted.
+    if (!user || succeeded(user.checkOperand(operandUse, newType)))
+      continue;
+    user.emitOpError() << "fails to replace operand type for operand #"
+                       << operandUse.getOperandNumber() << " with " << newType;
+    return failure();
+  }
+  value.setType(newType);
+  return success();
+}
+/// Specializes the type of `value` if it is a `ParamType`: looks up the
+/// parameter name in `metaArgs` and, when the entry is a `TypeAttr`, replaces
+/// the value type with the type it holds. Values whose type is not a
+/// `ParamType` are left untouched and reported as success.
+///
+/// Returns failure, with a diagnostic, when `metaArgs` has no `TypeAttr`
+/// entry for the parameter.
+static LogicalResult replaceValueType(Value value, DictionaryAttr metaArgs) {
+  auto paramType = dyn_cast<ParamType>(value.getType());
+  if (!paramType)
+    return success();
+  auto metaArg =
+      llvm::dyn_cast_or_null<TypeAttr>(metaArgs.get(paramType.getRef()));
+  if (!metaArg) {
+    // Report on the defining op when there is one. Block arguments have no
+    // defining op, so report at the value location instead of failing
+    // silently.
+    if (Operation *definingOp = value.getDefiningOp())
+      definingOp->emitError()
+          << "expected TypeAttr for specializing meta arg " << paramType
+          << ", got " << metaArgs;
+    else
+      mlir::emitError(value.getLoc())
+          << "expected TypeAttr for specializing meta arg " << paramType
+          << ", got " << metaArgs;
+    return failure();
+  }
+  return replaceValueType(value, metaArg.getValue());
+}
+
+/// Specializes this function in place for `metaArgs`: renames it to its
+/// mangled name, drops the meta parameter list, substitutes every `ParamType`
+/// in the signature and entry block arguments with the matching `TypeAttr`
+/// entry of `metaArgs`, then specializes every nested op implementing
+/// `ParametricOpInterface`.
+///
+/// The mangled name and the specialized signature are computed before the
+/// operation is modified, so a failure in either leaves the function
+/// untouched. Later failures (a user rejecting a new argument type, or a
+/// nested op failing to specialize) can leave it partially specialized.
+LogicalResult ParametricFuncOp::specialize(DictionaryAttr metaArgs) {
+  auto mangledName = getMangledName(metaArgs);
+  if (failed(mangledName))
+    return failure();
+
+  // Maps each type in `typeRange` to its specialized form; types that are
+  // not `ParamType` are kept as is.
+  auto specializeTypes = [&](auto typeRange,
+                             SmallVector<Type> &specialized) -> LogicalResult {
+    for (Type ty : typeRange) {
+      auto paramType = dyn_cast<ParamType>(ty);
+      if (!paramType) {
+        specialized.push_back(ty);
+        continue;
+      }
+      auto metaArg =
+          llvm::dyn_cast_or_null<TypeAttr>(metaArgs.get(paramType.getRef()));
+      if (!metaArg) {
+        emitOpError() << "expected TypeAttr for specializing meta arg "
+                      << paramType << ", got " << metaArgs;
+        return failure();
+      }
+      specialized.push_back(metaArg.getValue());
+    }
+    return success();
+  };
+  auto fnType = getFunctionType();
+  SmallVector<Type> argTypes, resTypes;
+  if (failed(specializeTypes(fnType.getInputs(), argTypes)))
+    return failure();
+  if (failed(specializeTypes(fnType.getResults(), resTypes)))
+    return failure();
+
+  // Validation of the signature succeeded: start modifying the operation.
+  setSymNameAttr(*mangledName);
+  removeMetaParamsAttr();
+
+  for (auto [newType, blockArg] : llvm::zip(argTypes, this->getArguments())) {
+    if (failed(replaceValueType(blockArg, newType)))
+      return failure();
+  }
+  setFunctionType(FunctionType::get(getContext(), argTypes, resTypes));
+
+  // Specialize the body; stop at the first nested op that fails.
+  WalkResult walkResult = getFunctionBody().walk([&](Operation *op) {
+    if (auto parametricOp = dyn_cast<ParametricOpInterface>(op)) {
+      if (failed(parametricOp.specialize(metaArgs)))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  return failure(walkResult.wasInterrupted());
+}
+
+/// Accepts any replacement type for any operand: always returns success.
+LogicalResult ParametricFuncOp::checkOperand(mlir::OpOperand &, mlir::Type) {
+  return success();
+}
+
+/// Builds the symbol name of the instance specialized with `metaArgs`:
+/// `<name>$__mlir_instance__` followed by `$<key>$<value>` for each entry of
+/// `metaArgs`. Integer attributes contribute only their value; any other
+/// attribute is printed as is. Returns failure if this op has no name.
+FailureOr<StringAttr>
+ParametricFuncOp::getMangledName(DictionaryAttr metaArgs) {
+  auto baseName = getNameAttr();
+  if (!baseName)
+    return failure();
+
+  std::string buffer;
+  llvm::raw_string_ostream stream(buffer);
+  stream << baseName.getValue() << "$__mlir_instance__";
+  for (NamedAttribute entry : metaArgs) {
+    stream << "$" << entry.getName().getValue() << "$";
+    Attribute argValue = entry.getValue();
+    if (auto intArg = dyn_cast<IntegerAttr>(argValue))
+      stream << intArg.getValue();
+    else
+      stream << argValue;
+  }
+
+  return StringAttr::get(getContext(), stream.str());
+}
+
+/// Specializes the result type of this op from `metaArgs`; succeeds exactly
+/// when the result type replacement succeeds.
+LogicalResult AddOp::specialize(DictionaryAttr metaArgs) {
+  return replaceValueType(getResult(), metaArgs);
+}
+
+/// Accepts any replacement type for any operand: always returns success.
+LogicalResult AddOp::checkOperand(mlir::OpOperand &, mlir::Type) {
+  return success();
+}
+
+/// Returns the callee attribute of this call as the specialization target.
+SymbolRefAttr CallOp::getTarget() { return getCalleeAttr(); }
+
+/// Redirects this call to `target` and replaces its meta arguments with an
+/// empty dictionary. Always returns success.
+LogicalResult CallOp::setSpecializedTarget(SymbolOpInterface target) {
+  // TODO: check validity first.
+  auto specializedCallee = SymbolRefAttr::get(target.getNameAttr());
+  setCalleeAttr(specializedCallee);
+
+  // The meta arguments have been consumed by the specialization.
+  auto noMetaArgs = DictionaryAttr::get(getContext());
+  setMetaArgsAttr(noMetaArgs);
+  return success();
+}
+
+/// Replaces the value attribute of this op with the entry of `metaArgs` it
+/// refers to, when that attribute is a `ParamAttr`. Ops whose value is not a
+/// `ParamAttr` are left unchanged. Fails if `metaArgs` has no such entry.
+LogicalResult PrintAttrOp::specialize(DictionaryAttr metaArgs) {
+  auto param = dyn_cast_or_null<ParamAttr>(getValueAttr());
+  if (!param)
+    return success();
+
+  if (auto replacement = metaArgs.get(param.getRef())) {
+    setValueAttr(replacement);
+    return success();
+  }
+
+  emitOpError() << "failed to specialize, missing " << param.getRef()
+                << " entry in " << metaArgs;
+  return failure();
+}
+
+/// Accepts any replacement type for any operand: always returns success.
+LogicalResult PrintAttrOp::checkOperand(mlir::OpOperand &, mlir::Type) {
+  return success();
+}
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.h b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.h
new file mode 100644
index 00000000000000..510c94c1100dc1
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.h
@@ -0,0 +1,45 @@
+//===- TestParametricDialect.h - MLIR parametric test dialect ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the 'testparametric' dialect, a test-only dialect used to
+// exercise the parametric specialization interfaces.
+//
+//===----------------------------------------------------------------------===//
+
+// The guard is specific to this header. The previous name, MLIR_TESTDIALECT_H,
+// was copied from the Test dialect's TestDialect.h, so including both headers
+// in one translation unit would silently drop whichever came second.
+#ifndef MLIR_TESTPARAMETRICDIALECT_H
+#define MLIR_TESTPARAMETRICDIALECT_H
+
+#include "TestParametricAttributes.h"
+#include "TestParametricInterfaces.h"
+#include "TestParametricTypes.h"
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/ParametricSpecializationOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include <memory>
+
+//===----------------------------------------------------------------------===//
+// TestParametricDialect
+//===----------------------------------------------------------------------===//
+
+#include "TestParametricOpInterfaces.h.inc"
+#include "TestParametricOpsDialect.h.inc"
+
+#define GET_OP_CLASSES
+#include "TestParametricOps.h.inc"
+
+namespace testparametric {
+/// Registers the `testparametric` dialect into `registry`.
+void registerTestParametricDialect(::mlir::DialectRegistry &registry);
+} // namespace testparametric
+
+#endif // MLIR_TESTPARAMETRICDIALECT_H
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.td b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.td
new file mode 100644
index 00000000000000..54575cdc723eec
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.td
@@ -0,0 +1,27 @@
+//===-- TestDialect.td - Test dialect definition -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TESTPARAMETRIC_DIALECT
+#define TESTPARAMETRIC_DIALECT
+
+include "mlir/IR/OpBase.td"
+
+// Definition of the `testparametric` dialect. Generated C++ lives in the
+// `::testparametric` namespace, and both attributes and types use the
+// default (generated) printer/parser hooks.
+def TestParametric_Dialect : Dialect {
+ let name = "testparametric";
+ let cppNamespace = "::testparametric";
+ let useDefaultAttributePrinterParser = 1;
+ let useDefaultTypePrinterParser = 1;
+
+ // Declares the hooks that register the dialect's attributes and types.
+ let extraClassDeclaration = [{
+ void registerAttributes();
+ void registerTypes();
+ private:
+ }];
+}
+
+#endif // TESTPARAMETRIC_DIALECT
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.cpp b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.cpp
new file mode 100644
index 00000000000000..597654c639e81d
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.cpp
@@ -0,0 +1,11 @@
+//===- TestInterfaces.cpp - MLIR interfaces for testing ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestParametricInterfaces.h"
+
+using namespace mlir;
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.h b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.h
new file mode 100644
index 00000000000000..092c24058c0176
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.h
@@ -0,0 +1,33 @@
+//===- TestInterfaces.h - MLIR interfaces for testing -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares interfaces for the 'test' dialect that can be used for
+// testing the interface infrastructure.
+//
+//===----------------------------------------------------------------------===//
+
+// The guard is specific to this header. The previous name was copied from the
+// Test dialect's TestInterfaces.h, so including both headers in one
+// translation unit would silently drop whichever came second.
+#ifndef MLIR_TEST_LIB_DIALECT_TESTPARAMETRIC_TESTPARAMETRICINTERFACES_H
+#define MLIR_TEST_LIB_DIALECT_TESTPARAMETRIC_TESTPARAMETRICINTERFACES_H
+
+#include "mlir/IR/BuiltinAttributes.h"
+
+#include "llvm/ADT/DenseMap.h"
+
+namespace mlir {
+
+/// Holds a mapping from parameter name to the attribute bound to it.
+/// NOTE(review): the class currently exposes no accessors, so the map can
+/// neither be filled nor queried from outside.
+class SpecializationParams {
+public:
+  SpecializationParams() = default;
+
+private:
+  DenseMap<StringAttr, Attribute> params;
+};
+
+} // namespace mlir
+
+#endif // MLIR_TEST_LIB_DIALECT_TESTPARAMETRIC_TESTPARAMETRICINTERFACES_H
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.td b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.td
new file mode 100644
index 00000000000000..954ab0cac1fcf1
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.td
@@ -0,0 +1,34 @@
+//===-- TestInterfaces.td - Test dialect interfaces --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TEST_DIALECT_TESTPARAMETRIC_INTERFACES
+#define MLIR_TEST_DIALECT_TESTPARAMETRIC_INTERFACES
+
+include "mlir/IR/OpBase.td"
+
+// Op interface generated into the `::mlir` namespace with three hooks, all
+// returning LogicalResult and all currently without a description string:
+//  - specializeAttr(parameterName, concreteAttr)
+//  - specializeType(parameterName, concreteType)
+//  - checkOperand(operand, concreteType)
+// NOTE(review): the ops visible in TestParametricOps.td implement
+// ParametricOpInterface instead; confirm whether this interface is still
+// needed.
+def TestParametricOpInterface : OpInterface<"TestParametricOpInterface"> {
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<"",
+ "LogicalResult", "specializeAttr", (ins
+ "::mlir::StringAttr":$parameterName,
+ "::mlir::Attribute":$concreteAttr)>,
+ InterfaceMethod<"",
+ "LogicalResult", "specializeType", (ins
+ "::mlir::StringAttr":$parameterName,
+ "::mlir::Type":$concreteType)>,
+ InterfaceMethod<"",
+ "LogicalResult", "checkOperand", (ins
+ "::mlir::OpOperand &":$operand,
+ "::mlir::Type":$concreteType)>,
+ ];
+}
+
+
+
+#endif // MLIR_TEST_DIALECT_TESTPARAMETRIC_INTERFACES
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricOps.td b/mlir/test/lib/Dialect/TestParametric/TestParametricOps.td
new file mode 100644
index 00000000000000..43c04c763957cd
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricOps.td
@@ -0,0 +1,202 @@
+//===-- TestOps.td - Test dialect operation definitions ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TESTPARAMETRIC_OPS
+#define TESTPARAMETRIC_OPS
+
+include "TestParametricDialect.td"
+include "mlir/Dialect/DLTI/DLTIBase.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/Interfaces/FunctionInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/PatternBase.td"
+include "mlir/IR/RegionKindInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/CopyOpInterface.td"
+include "mlir/Interfaces/DataLayoutInterfaces.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
+include "mlir/Interfaces/ParametricSpecializationOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+
+// Include the attribute definitions.
+include "TestParametricAttrDefs.td"
+// Include the type definitions.
+include "TestParametricTypeDefs.td"
+
+
+// Base class for every operation of the `testparametric` dialect: binds the
+// op to TestParametric_Dialect and forwards the mnemonic and trait list.
+class TESTParametric_Op<string mnemonic, list<Trait> traits = []> :
+ Op<TestParametric_Dialect, mnemonic, traits>;
+
+// `testparametric.func`: a function-like op (FunctionOpInterface, isolated
+// from above) that also implements ParametricOpInterface. It uses a custom
+// assembly format, and its `description` is currently empty.
+def TESTParametric_ParametricFuncOp : TESTParametric_Op<"func", [
+ AutomaticAllocationScope, FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface,
+ DeclareOpInterfaceMethods<ParametricOpInterface>
+ ]> {
+ let summary = "Parametric function.";
+ let description = [{
+ }];
+
+ // Symbol name and function type are required; every other attribute is
+ // optional.
+ // NOTE(review): `metaParams` is declared as a DictionaryAttr, but the example
+ // in the PR description writes it as an array (`metaParams = ["A", "B"]`);
+ // confirm which form is intended.
+ let arguments = (ins SymbolNameAttr:$sym_name,
+ TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<DictionaryAttr>:$metaParams,
+ OptionalAttr<StrAttr>:$sym_visibility,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs);
+ let regions = (region AnyRegion:$body);
+
+ // Extra C++ members: the FunctionOpInterface/OpAsmOpInterface/
+ // SymbolOpInterface helpers plus the declaration of getMangledName, which is
+ // defined out of line.
+ let extraClassDeclaration = [{
+ //===------------------------------------------------------------------===//
+ // FunctionOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ /// Returns the region on the current operation that is callable. This may
+ /// return null in the case of an external callable object, e.g. an external
+ /// function.
+ ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); }
+
+ /// Returns the argument types of this function.
+ ::llvm::ArrayRef<::mlir::Type> getArgumentTypes() { return getFunctionType().getInputs(); }
+
+ /// Returns the result types of this function.
+ ::llvm::ArrayRef<::mlir::Type> getResultTypes() { return getFunctionType().getResults(); }
+
+ //===------------------------------------------------------------------===//
+ // OpAsmOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ /// Allow the dialect prefix to be omitted.
+ static ::llvm::StringRef getDefaultDialect() { return "testparametric"; }
+
+ //===------------------------------------------------------------------===//
+ // SymbolOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ bool isDeclaration() { return isExternal(); }
+
+ //===------------------------------------------------------------------===//
+ // ParametricOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ ::mlir::FailureOr<::mlir::StringAttr> getMangledName(::mlir::DictionaryAttr);
+ }];
+ let hasCustomAssemblyFormat = 1;
+}
+
+
+// `testparametric.return`: terminator that is only valid directly inside a
+// ParametricFuncOp (HasParent). It takes any number of operands of any type
+// and has no side effects (Pure). Its `description` is currently empty.
+def ReturnOp : TESTParametric_Op<"return", [Pure, HasParent<"ParametricFuncOp">,
+ ReturnLike, Terminator]> {
+ let summary = "Function return operation";
+ let description = [{
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+
+ // The `: types` suffix is printed only when there is at least one operand.
+ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+}
+
+// `testparametric.call`: calls the symbol named by `callee`, passing
+// `operands` and a dictionary of meta arguments (`metaArgs`). It implements
+// SpecializingOpInterface, so it can be retargeted to a specialized callee.
+// Its `description` is currently empty.
+def CallOp : TESTParametric_Op<"call",
+    [CallOpInterface,
+     DeclareOpInterfaceMethods<SpecializingOpInterface>,
+     DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "call operation";
+  let description = [{
+  }];
+
+  // `metaArgs` is a required attribute: every call carries a (possibly empty)
+  // dictionary of meta arguments.
+  let arguments = (ins
+    FlatSymbolRefAttr:$callee,
+    Variadic<AnyType>:$operands,
+    DictionaryAttr:$metaArgs
+  );
+  let results = (outs Variadic<AnyType>);
+
+  // Convenience builders taking the callee as a SymbolRefAttr, StringAttr or
+  // StringRef. The last two forward to the first. None of them takes meta
+  // arguments, so the first one installs an empty `metaArgs` dictionary;
+  // without it the built op would lack a required attribute and fail
+  // verification.
+  let builders = [
+    OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::TypeRange":$results,
+      CArg<"mlir::ValueRange", "{}">:$operands), [{
+      $_state.addOperands(operands);
+      $_state.addAttribute("callee", callee);
+      $_state.addAttribute("metaArgs",
+                           mlir::DictionaryAttr::get($_builder.getContext()));
+      $_state.addTypes(results);
+    }]>,
+    OpBuilder<(ins "mlir::StringAttr":$callee, "mlir::TypeRange":$results,
+      CArg<"mlir::ValueRange", "{}">:$operands), [{
+      build($_builder, $_state, mlir::SymbolRefAttr::get(callee), results, operands);
+    }]>,
+    OpBuilder<(ins "llvm::StringRef":$callee, "mlir::TypeRange":$results,
+      CArg<"mlir::ValueRange", "{}">:$operands), [{
+      build($_builder, $_state, mlir::StringAttr::get($_builder.getContext(), callee),
+            results, operands);
+    }]>];
+
+  let extraClassDeclaration = [{
+    mlir::FunctionType getCalleeType();
+
+    /// Get the argument operands to the called function.
+    operand_range getArgOperands() {
+      return {arg_operand_begin(), arg_operand_end()};
+    }
+
+    mlir::MutableOperandRange getArgOperandsMutable() {
+      return getOperandsMutable();
+    }
+
+    operand_iterator arg_operand_begin() { return operand_begin(); }
+    operand_iterator arg_operand_end() { return operand_end(); }
+
+    /// Return the callee of this operation.
+    mlir::CallInterfaceCallable getCallableForCallee() {
+      return (*this)->getAttrOfType<mlir::SymbolRefAttr>("callee");
+    }
+
+    /// Set the callee for this operation.
+    void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<mlir::SymbolRefAttr>());
+    }
+  }];
+
+  let assemblyFormat = [{
+    $callee `(` $operands `)` `meta` `=` $metaArgs attr-dict `:` functional-type($operands, results)
+  }];
+}
+
+// `testparametric.add`: side-effect-free (Pure) binary op taking two operands
+// of any type and producing one result of any type. It implements
+// ParametricOpInterface. No constraint ties the operand and result types
+// together at the ODS level. Its `description` is currently empty.
+def AddOp : TESTParametric_Op<"add", [Pure, DeclareOpInterfaceMethods<ParametricOpInterface>]> {
+ let summary = "Add operation";
+ let description = [{
+ }];
+
+ let arguments = (ins
+ AnyType:$lhs,
+ AnyType:$rhs
+ );
+ let results = (outs
+ AnyType:$result
+ );
+
+ let assemblyFormat = "$lhs `` `,` $rhs attr-dict `:` functional-type(operands, results)";
+}
+
+// `testparametric.print_attr`: carries a single attribute of any kind in
+// `value` and implements ParametricOpInterface. The op has no operands and no
+// results. Its `description` is currently empty.
+def PrintAttrOp : TESTParametric_Op<"print_attr", [DeclareOpInterfaceMethods<ParametricOpInterface>]> {
+ let summary = "Print operation";
+ let description = [{
+ }];
+
+ let arguments = (ins
+ AnyAttr:$value
+ );
+
+ let assemblyFormat = "$value attr-dict";
+}
+
+
+
+#endif // TESTPARAMETRIC_OPS
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricTypeDefs.td b/mlir/test/lib/Dialect/TestParametric/TestParametricTypeDefs.td
new file mode 100644
index 00000000000000..d3f3d0327db919
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricTypeDefs.td
@@ -0,0 +1,37 @@
+//===-- TestTypeDefs.td - Test dialect type definitions ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TableGen data type definitions for Test dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TESTPARAMETRIC_TYPEDEFS
+#define TESTPARAMETRIC_TYPEDEFS
+
+// To get the test dialect def.
+include "TestParametricDialect.td"
+include "TestParametricAttrDefs.td"
+include "TestParametricInterfaces.td"
+include "mlir/IR/BuiltinTypes.td"
+include "mlir/Interfaces/DataLayoutInterfaces.td"
+
+// Base class for every type of the `testparametric` dialect: binds the type
+// to TestParametric_Dialect and forwards the C++ name and trait list.
+class TestParametric_Type<string name, list<Trait> traits = []>
+ : TypeDef<TestParametric_Dialect, name, traits>;
+
+// `!testparametric.param<"NAME">`: a type carrying a single StringAttr `ref`.
+// In the PR description this spells a reference to a meta parameter, e.g.
+// `!testparametric.param<"A">`.
+def TestParametric_ParamType : TestParametric_Type<"Param"> {
+ let mnemonic = "param";
+ // List of type parameters.
+ let parameters = (
+ ins
+ "::mlir::StringAttr":$ref
+ );
+ let assemblyFormat = "`<` $ref `>`";
+}
+
+#endif // TESTPARAMETRIC_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.cpp b/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.cpp
new file mode 100644
index 00000000000000..aaad6b880361b3
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.cpp
@@ -0,0 +1,42 @@
+//===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains types defined by the TestDialect for testing various
+// features of MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestParametricTypes.h"
+#include "TestParametricDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/TypeSize.h"
+#include <optional>
+
+using namespace mlir;
+using namespace testparametric;
+
+#define GET_TYPEDEF_CLASSES
+#include "TestParametricTypeDefs.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// TestDialect
+//===----------------------------------------------------------------------===//
+
+/// Registers with the dialect every type listed by the TableGen-generated
+/// type list (the GET_TYPEDEF_LIST section of TestParametricTypeDefs.cpp.inc).
+void TestParametricDialect::registerTypes() {
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "TestParametricTypeDefs.cpp.inc"
+ >();
+}
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.h b/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.h
new file mode 100644
index 00000000000000..0e397d2bfb1ceb
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.h
@@ -0,0 +1,154 @@
+//===- TestTypes.h - MLIR Test Dialect Types --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains types defined by the TestDialect for testing various
+// features of MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTPARAMETRICTYPES_H
+#define MLIR_TESTPARAMETRICTYPES_H
+
+#include <optional>
+#include <tuple>
+
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+
+namespace test {
+class TestAttrWithFormatAttr;
+
+/// FieldInfo represents a field in the StructType data type. It is used as a
+/// parameter in TestTypeDefs.td.
+struct FieldInfo {
+  ::llvm::StringRef name;
+  ::mlir::Type type;
+
+  /// Custom allocation called from generated constructor code: returns a copy
+  /// of this field whose name has been copied into `alloc`.
+  FieldInfo allocateInto(::mlir::TypeStorageAllocator &alloc) const {
+    return FieldInfo{alloc.copyInto(name), type};
+  }
+};
+
+/// A custom type for a test type parameter. Two values are equal when their
+/// `value` members are equal.
+struct CustomParam {
+  int value;
+
+  bool operator==(const CustomParam &other) const {
+    return other.value == value;
+  }
+};
+
+/// Hashes a CustomParam by hashing its `value` member.
+inline llvm::hash_code hash_value(const test::CustomParam &param) {
+  return llvm::hash_value(param.value);
+}
+
+} // namespace test
+
+namespace mlir {
+/// Parses a test::CustomParam as a plain integer.
+template <>
+struct FieldParser<test::CustomParam> {
+  static FailureOr<test::CustomParam> parse(AsmParser &parser) {
+    auto parsedValue = FieldParser<int>::parse(parser);
+    if (succeeded(parsedValue))
+      return test::CustomParam{*parsedValue};
+    return failure();
+  }
+};
+
+/// Prints a test::CustomParam as its integer value.
+inline mlir::AsmPrinter &operator<<(mlir::AsmPrinter &printer,
+                                    test::CustomParam param) {
+  return printer << param.value;
+}
+
+/// Overload the attribute parameter parser for optional integers: yields an
+/// empty optional when `parseOptionalInteger` reports that no integer is
+/// present, the parsed value when it succeeds, and failure otherwise.
+template <>
+struct FieldParser<std::optional<int>> {
+  static FailureOr<std::optional<int>> parse(AsmParser &parser) {
+    int parsed = 0;
+    OptionalParseResult result = parser.parseOptionalInteger(parsed);
+    if (!result.has_value())
+      return std::optional<int>();
+    if (failed(*result))
+      return failure();
+    return std::optional<int>(parsed);
+  }
+};
+} // namespace mlir
+
+#include "TestParametricTypeInterfaces.h.inc"
+
+namespace test {
+
+/// Storage for a named recursive type. The name is the uniquing key; the
+/// body is mutable state that is set after creation and may refer back to
+/// the type itself.
+struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
+  using KeyTy = ::llvm::StringRef;
+
+  explicit TestRecursiveTypeStorage(::llvm::StringRef key) : name(key) {}
+
+  bool operator==(const KeyTy &other) const { return other == name; }
+
+  /// Allocates the storage in `allocator` and copies the key into it.
+  static TestRecursiveTypeStorage *
+  construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {
+    auto *storage = allocator.allocate<TestRecursiveTypeStorage>();
+    return new (storage) TestRecursiveTypeStorage(allocator.copyInto(key));
+  }
+
+  /// Sets the body. Fails if a different body has already been set; setting
+  /// the same body again succeeds.
+  ::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator,
+                               ::mlir::Type newBody) {
+    const bool conflictsWithExisting = body && body != newBody;
+    if (conflictsWithExisting)
+      return ::mlir::failure();
+
+    body = newBody;
+    return ::mlir::success();
+  }
+
+  ::llvm::StringRef name;
+  ::mlir::Type body;
+};
+
+/// Recursive type identified by name whose body is another type, possibly the
+/// type itself. The body must be set through `setBody` after the type has been
+/// created.
+class TestRecursiveType
+    : public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
+                                    TestRecursiveTypeStorage,
+                                    ::mlir::TypeTrait::IsMutable> {
+public:
+  using Base::Base;
+
+  static constexpr ::mlir::StringLiteral name = "test.recursive";
+
+  /// Returns the type uniqued in `ctx` under `name`.
+  static TestRecursiveType get(::mlir::MLIRContext *ctx,
+                               ::llvm::StringRef name) {
+    return Base::get(ctx, name);
+  }
+
+  /// Sets the body; see TestRecursiveTypeStorage::mutate for failure cases.
+  ::mlir::LogicalResult setBody(Type body) { return Base::mutate(body); }
+
+  /// Returns the body stored for this type.
+  ::mlir::Type getBody() const { return getImpl()->body; }
+
+  /// Returns the name this type is keyed on.
+  ::llvm::StringRef getName() { return getImpl()->name; }
+};
+
+} // namespace test
+
+#define GET_TYPEDEF_CLASSES
+#include "TestParametricTypeDefs.h.inc"
+
+#endif // MLIR_TESTPARAMETRICTYPES_H
diff --git a/mlir/test/lib/Dialect/TestParametric/lit.local.cfg b/mlir/test/lib/Dialect/TestParametric/lit.local.cfg
new file mode 100644
index 00000000000000..65a7f202dc82a9
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/lit.local.cfg
@@ -0,0 +1 @@
+config.suffixes.remove(".td")
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 2a3a8608db5442..038064417bf52d 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_library(MLIRTestTransforms
TestControlFlowSink.cpp
TestInlining.cpp
TestIntRangeInference.cpp
+ TestParametricSpecialization.cpp
TestMakeIsolatedFromAbove.cpp
TestTopologicalSort.cpp
${MLIRTestTransformsPDLSrc}
diff --git a/mlir/test/lib/Transforms/TestParametricSpecialization.cpp b/mlir/test/lib/Transforms/TestParametricSpecialization.cpp
new file mode 100644
index 00000000000000..157d4b044c5038
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestParametricSpecialization.cpp
@@ -0,0 +1,191 @@
+//===- TestParametricSpecialization.cpp - Test parametric specialization --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Threading.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/ParametricSpecializationOpInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/ParametricSpecialization.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include <mutex>
+#include <utility>
+
+#define DEBUG_TYPE "parametric-specialization"
+
+using namespace mlir;
+
+namespace {
+
+/// Everything needed to create one specialized instance of a parametric op
+/// and to redirect its users. There is one request per mangled name.
+struct SpecializingRequest {
+  /// Op to specialize
+  ParametricOpInterface targetOp;
+  /// The arguments to specialize it with.
+  DictionaryAttr metaArgs;
+  /// The "callers" to update
+  SmallVector<SpecializingOpInterface, 0> callers;
+  /// The operation post-specialization
+  OwningOpRef<ParametricOpInterface> specialized;
+};
+
+/// Test pass that instantiates parametric ops for each distinct set of meta
+/// arguments requested by the non-parametric ops of a symbol table.
+struct TestParametricSpecializationPass
+    : public PassWrapper<TestParametricSpecializationPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestParametricSpecializationPass)
+
+  StringRef getArgument() const final {
+    return "test-parametric-specialization";
+  }
+  StringRef getDescription() const final {
+    return "Test the parametric specialization of parametric programs.";
+  }
+
+  /// Runs in three phases:
+  ///  1. walk every non-parametric top-level op ("root") in parallel and
+  ///     record one request per mangled target name;
+  ///  2. clone and specialize each requested target in parallel;
+  ///  3. insert each specialized op into the symbol table and retarget its
+  ///     callers.
+  /// Any failure is reported as a diagnostic and fails the pass.
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    if (!op->hasTrait<OpTrait::SymbolTable>()) {
+      op->emitOpError()
+          << getArgument()
+          << " pass can only run on an operation that defines a SymbolTable";
+      signalPassFailure();
+      // Stop here: building a SymbolTable below requires the trait.
+      return;
+    }
+    SymbolTable symTab(op);
+
+    MLIRContext &ctx = getContext();
+
+    // Walk the body of the module, and find "roots": operations that are
+    // already specialized. We'll use these as "roots" to specialize the
+    // parametric ones.
+    SmallVector<Operation *> rootOps;
+    for (Operation &nestedOp : op->getRegion(0).getOps()) {
+      if (!isa<ParametricOpInterface>(nestedOp))
+        rootOps.push_back(&nestedOp);
+    }
+
+    // Requests keyed by mangled name; `tasksMutex` guards this map and the
+    // shared pass state touched from worker threads.
+    std::map<StringRef, SpecializingRequest> specializationRequests;
+    std::mutex tasksMutex;
+    LogicalResult result = success();
+
+    // Run in parallel on every root, and for each, walk the body and find
+    // "calls" to functions that need specialization.
+    result = failableParallelForEach(&ctx, rootOps, [&](Operation *root) {
+      auto result = root->walk([&](Operation *innerOp) {
+        auto specializingOp = dyn_cast<SpecializingOpInterface>(innerOp);
+        if (!specializingOp)
+          return WalkResult::advance();
+        auto targetNameAttr = specializingOp.getTarget();
+        auto targetOp = symTab.lookup(targetNameAttr.getRootReference());
+        if (!targetOp) {
+          innerOp->emitOpError()
+              << "can't find target '" << targetNameAttr << "' in SymbolTable";
+          return WalkResult::interrupt();
+        }
+        auto parametricTargetOp = dyn_cast<ParametricOpInterface>(targetOp);
+        if (!parametricTargetOp) {
+          auto diag = targetOp->emitOpError();
+          diag << "expected target to implement 'ParametricOpInterface'";
+          diag.attachNote() << "while specializing " << *innerOp;
+          return WalkResult::interrupt();
+        }
+        auto metaArgs = specializingOp.getMetaArgs();
+        auto failureOrMangledName = parametricTargetOp.getMangledName(metaArgs);
+        if (failed(failureOrMangledName)) {
+          parametricTargetOp->emitOpError()
+              << "failed to mangled with meta args " << metaArgs;
+          return WalkResult::interrupt();
+        }
+        StringAttr mangledName = *failureOrMangledName;
+        std::unique_lock<std::mutex> lock(tasksMutex);
+        auto &request = specializationRequests[mangledName.getValue()];
+        // Two different targets must never map to the same mangled name.
+        if (request.targetOp && request.targetOp != targetOp) {
+          auto diag = targetOp->emitOpError();
+          diag << "unexpected mangling collision while specializing with "
+                  "meta args "
+               << metaArgs << ", mangled name " << mangledName;
+          diag.attachNote() << "while specializing " << *innerOp;
+          return WalkResult::interrupt();
+        }
+        request.targetOp = parametricTargetOp;
+        request.metaArgs = metaArgs;
+        request.callers.push_back(specializingOp);
+        LLVM_DEBUG({ llvm::errs() << "Request for " << mangledName << "\n"; });
+        return WalkResult::advance();
+      });
+      return success(!result.wasInterrupted());
+    });
+    if (failed(result)) {
+      signalPassFailure();
+      return;
+    }
+    LLVM_DEBUG({
+      llvm::errs() << "Got " << specializationRequests.size() << " requests\n";
+    });
+
+    // Clone and specialize every requested target. The clones stay owned by
+    // their request until they are inserted into the symbol table below.
+    result = failableParallelForEach(
+        &ctx, specializationRequests,
+        [&](std::pair<const llvm::StringRef, SpecializingRequest> &request) {
+          ParametricOpInterface targetOp = request.second.targetOp;
+          DictionaryAttr metaArgs = request.second.metaArgs;
+          OwningOpRef<ParametricOpInterface> specializedOp(targetOp.clone());
+          if (failed(specializedOp->specialize(metaArgs))) {
+            std::unique_lock<std::mutex> lock(tasksMutex);
+            targetOp->emitOpError() << "failed to specialize with " << metaArgs;
+            return failure();
+          }
+          request.second.specialized = std::move(specializedOp);
+          return success();
+        });
+    if (failed(result)) {
+      signalPassFailure();
+      return;
+    }
+
+    llvm::ThreadPool &threadPool = ctx.getThreadPool();
+    llvm::ThreadPoolTaskGroup tasksGroup(threadPool);
+    for (auto &request : specializationRequests) {
+      // Give up ownership before inserting: from here on the specialized op
+      // belongs to the IR it is inserted into.
+      Operation *specialized = request.second.specialized.release();
+      symTab.insert(specialized);
+      auto specializedSymbol = cast<SymbolOpInterface>(specialized);
+      LLVM_DEBUG({
+        llvm::errs() << "Inserted " << specializedSymbol.getName() << "\n";
+      });
+      // Retarget the callers of this instance asynchronously. Everything the
+      // task needs is captured explicitly so that it does not depend on
+      // loop-local variables.
+      tasksGroup.async([this, &tasksMutex, specializedSymbol,
+                        callers = &request.second.callers] {
+        for (SpecializingOpInterface caller : *callers) {
+          if (failed(caller.setSpecializedTarget(specializedSymbol))) {
+            std::unique_lock<std::mutex> lock(tasksMutex);
+            caller->emitOpError() << "failed to specialize\n";
+            signalPassFailure();
+          }
+        }
+      });
+    }
+    tasksGroup.wait();
+  }
+};
+} // namespace
+
+namespace mlir::test {
+/// Registers TestParametricSpecializationPass with the global pass registry.
+void registerTestParametricSpecializationPass() {
+  PassRegistration<TestParametricSpecializationPass>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-lsp-server/CMakeLists.txt b/mlir/tools/mlir-lsp-server/CMakeLists.txt
index 0134b54eef1b07..6480056c66e2b8 100644
--- a/mlir/tools/mlir-lsp-server/CMakeLists.txt
+++ b/mlir/tools/mlir-lsp-server/CMakeLists.txt
@@ -18,6 +18,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRTestAnalysis
MLIRTestDialect
MLIRTestDynDialect
+ MLIRTestParametricDialect
MLIRTestIR
MLIRTestPass
MLIRTestReducer
diff --git a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp
index f0ecc5adc68b36..f64e278237d593 100644
--- a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp
+++ b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp
@@ -20,6 +20,9 @@ void registerTestDialect(DialectRegistry &);
void registerTestDynDialect(DialectRegistry &);
void registerTestTransformDialectExtension(DialectRegistry &);
} // namespace test
+namespace testparametric {
+void registerTestParametricDialect(DialectRegistry &);
+} // namespace testparametric
#endif
int main(int argc, char **argv) {
@@ -31,6 +34,7 @@ int main(int argc, char **argv) {
::test::registerTestDialect(registry);
::test::registerTestTransformDialectExtension(registry);
::test::registerTestDynDialect(registry);
+ ::testparametric::registerTestParametricDialect(registry);
#endif
return failed(MlirLspServerMain(argc, argv, registry));
}
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 9ad5b32c24f9de..176152f4fd0f4b 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -35,6 +35,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRTestAnalysis
MLIRTestDialect
MLIRTestDynDialect
+ MLIRTestParametricDialect
MLIRTestIR
MLIRTestOneToNTypeConversionPass
MLIRTestPass
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 428bdd9691e095..54aaf1072ba163 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -125,6 +125,7 @@ void registerTestNextAccessPass();
void registerTestOneToNTypeConversionPass();
void registerTestOpaqueLoc();
void registerTestPadFusion();
+void registerTestParametricSpecializationPass();
void registerTestPreparationPassWithAllowedMemrefResults();
void registerTestRecursiveTypesPass();
void registerTestSCFUtilsPass();
@@ -154,6 +155,9 @@ void registerTestDynDialect(DialectRegistry &);
void registerTestTilingInterfaceTransformDialectExtension(DialectRegistry &);
void registerTestTransformDialectExtension(DialectRegistry &);
} // namespace test
+namespace testparametric {
+void registerTestParametricDialect(DialectRegistry &);
+} // namespace testparametric
#ifdef MLIR_INCLUDE_TESTS
void registerTestPasses() {
@@ -248,6 +252,7 @@ void registerTestPasses() {
mlir::test::registerTestOneToNTypeConversionPass();
mlir::test::registerTestOpaqueLoc();
mlir::test::registerTestPadFusion();
+ mlir::test::registerTestParametricSpecializationPass();
mlir::test::registerTestRecursiveTypesPass();
mlir::test::registerTestSCFUtilsPass();
mlir::test::registerTestSCFWhileOpBuilderPass();
@@ -293,6 +298,7 @@ int main(int argc, char **argv) {
::test::registerTestTransformDialectExtension(registry);
::test::registerTestTilingInterfaceTransformDialectExtension(registry);
::test::registerTestDynDialect(registry);
+ ::testparametric::registerTestParametricDialect(registry);
#endif
return mlir::asMainReturnCode(mlir::MlirOptMain(
argc, argv, "MLIR modular optimizer driver\n", registry));
More information about the Mlir-commits
mailing list