[Mlir-commits] [mlir] (WIP) [MLIR] Add a new interface for "IR parameterization" (PR #78544)

Mehdi Amini llvmlistbot at llvm.org
Wed Jan 17 22:52:19 PST 2024


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/78544

This implements the ability to define "meta programs": a mechanism similar to C++ templates.

So as an example, this input IR:

```
testparametric.func @callee(%arg0: !testparametric.param<"A"> ) attributes { metaParams = ["A", "B"]} {
    %value = testparametric.add %arg0, %arg0 : (!testparametric.param<"A">, !testparametric.param<"A">) -> !testparametric.param<"A">
    testparametric.print_attr #testparametric.param<"B">
    return
}

func.func @caller() {
    %cst0 = arith.constant 0 : i32
    %cst1 = arith.constant 1. : f32
    %cst2 = arith.constant 2. : f64
    testparametric.call @callee(%cst0) meta = {"A" = i32, "B" = 32 : i64 } : (i32) -> ()
    testparametric.call @callee(%cst1) meta = {"A" = f32, "B" = 64 : i64 } : (f32) -> ()
    testparametric.call @callee(%cst2) meta = {"A" = f64, "B" = 128 : i64 } : (f64) -> ()
    return
}
```

will have the parametric function @callee instantiated once for each call site:

```
  func.func @caller() {
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant 1.000000e+00 : f32
    %cst_0 = arith.constant 2.000000e+00 : f64
    testparametric.call @callee$__mlir_instance__$A$i32$B$32(%c0_i32) meta = {} : (i32) -> ()
    testparametric.call @callee$__mlir_instance__$A$f32$B$64(%cst) meta = {} : (f32) -> ()
    testparametric.call @callee$__mlir_instance__$A$f64$B$128(%cst_0) meta = {} : (f64) -> ()
    return
  }
  testparametric.func @callee$__mlir_instance__$A$f32$B$64(%arg0: f32) {
    %0 = add %arg0, %arg0 : (f32, f32) -> f32
    print_attr 64 : i64
    return
  }
  testparametric.func @callee$__mlir_instance__$A$f64$B$128(%arg0: f64) {
    %0 = add %arg0, %arg0 : (f64, f64) -> f64
    print_attr 128 : i64
    return
  }
  testparametric.func @callee$__mlir_instance__$A$i32$B$32(%arg0: i32) {
    %0 = add %arg0, %arg0 : (i32, i32) -> i32
    print_attr 32 : i64
    return
  }
```

>From 69c09a66d055c937c0815fb368ec4629cf5cc65f Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Wed, 17 Jan 2024 19:41:21 -0800
Subject: [PATCH] [MLIR] Add a new interface for "IR parameterization"

This implements the ability to define "meta programs": a mechanism
similar to C++ templates.

So as an example, this input IR:

```
testparametric.func @callee(%arg0: !testparametric.param<"A"> ) attributes { metaParams = ["A", "B"]} {
    %value = testparametric.add %arg0, %arg0 : (!testparametric.param<"A">, !testparametric.param<"A">) -> !testparametric.param<"A">
    testparametric.print_attr #testparametric.param<"B">
    return
}

func.func @caller() {
    %cst0 = arith.constant 0 : i32
    %cst1 = arith.constant 1. : f32
    %cst2 = arith.constant 2. : f64
    testparametric.call @callee(%cst0) meta = {"A" = i32, "B" = 32 : i64 } : (i32) -> ()
    testparametric.call @callee(%cst1) meta = {"A" = f32, "B" = 64 : i64 } : (f32) -> ()
    testparametric.call @callee(%cst2) meta = {"A" = f64, "B" = 128 : i64 } : (f64) -> ()
    return
}
```

will have the parametric function @callee instantiated once for each call site:

```
  func.func @caller() {
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant 1.000000e+00 : f32
    %cst_0 = arith.constant 2.000000e+00 : f64
    testparametric.call @callee$__mlir_instance__$A$i32$B$32(%c0_i32) meta = {} : (i32) -> ()
    testparametric.call @callee$__mlir_instance__$A$f32$B$64(%cst) meta = {} : (f32) -> ()
    testparametric.call @callee$__mlir_instance__$A$f64$B$128(%cst_0) meta = {} : (f64) -> ()
    return
  }
  testparametric.func @callee$__mlir_instance__$A$f32$B$64(%arg0: f32) {
    %0 = add %arg0, %arg0 : (f32, f32) -> f32
    print_attr 64 : i64
    return
  }
  testparametric.func @callee$__mlir_instance__$A$f64$B$128(%arg0: f64) {
    %0 = add %arg0, %arg0 : (f64, f64) -> f64
    print_attr 128 : i64
    return
  }
  testparametric.func @callee$__mlir_instance__$A$i32$B$32(%arg0: i32) {
    %0 = add %arg0, %arg0 : (i32, i32) -> i32
    print_attr 32 : i64
    return
  }
```
---
 mlir/include/mlir/IR/SymbolInterfaces.td      |   2 +-
 mlir/include/mlir/Interfaces/CMakeLists.txt   |   1 +
 .../ParametricSpecializationOpInterface.h     |  25 ++
 .../ParametricSpecializationOpInterface.td    |  46 +++
 .../Transforms/ParametricSpecialization.h     |  11 +
 mlir/lib/Interfaces/CMakeLists.txt            |   2 +
 .../ParametricSpecializationOpInterface.cpp   |  13 +
 mlir/lib/Transforms/CMakeLists.txt            |   2 +
 .../Transforms/ParametricSpecialization.cpp   |  13 +
 mlir/test/Parametric/ops.mlir                 |  18 ++
 mlir/test/lib/Dialect/CMakeLists.txt          |   1 +
 .../lib/Dialect/TestParametric/CMakeLists.txt |  68 ++++
 .../TestParametric/TestParametricAttrDefs.td  |  38 +++
 .../TestParametricAttributes.cpp              |  42 +++
 .../TestParametric/TestParametricAttributes.h |  34 ++
 .../TestParametric/TestParametricDialect.cpp  | 297 ++++++++++++++++++
 .../TestParametric/TestParametricDialect.h    |  45 +++
 .../TestParametric/TestParametricDialect.td   |  27 ++
 .../TestParametricInterfaces.cpp              |  11 +
 .../TestParametric/TestParametricInterfaces.h |  33 ++
 .../TestParametricInterfaces.td               |  34 ++
 .../TestParametric/TestParametricOps.td       | 202 ++++++++++++
 .../TestParametric/TestParametricTypeDefs.td  |  37 +++
 .../TestParametric/TestParametricTypes.cpp    |  42 +++
 .../TestParametric/TestParametricTypes.h      | 154 +++++++++
 .../lib/Dialect/TestParametric/lit.local.cfg  |   1 +
 mlir/test/lib/Transforms/CMakeLists.txt       |   1 +
 .../TestParametricSpecialization.cpp          | 191 +++++++++++
 mlir/tools/mlir-lsp-server/CMakeLists.txt     |   1 +
 .../tools/mlir-lsp-server/mlir-lsp-server.cpp |   4 +
 mlir/tools/mlir-opt/CMakeLists.txt            |   1 +
 mlir/tools/mlir-opt/mlir-opt.cpp              |   6 +
 32 files changed, 1402 insertions(+), 1 deletion(-)
 create mode 100644 mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.h
 create mode 100644 mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.td
 create mode 100644 mlir/include/mlir/Transforms/ParametricSpecialization.h
 create mode 100644 mlir/lib/Interfaces/ParametricSpecializationOpInterface.cpp
 create mode 100644 mlir/lib/Transforms/ParametricSpecialization.cpp
 create mode 100644 mlir/test/Parametric/ops.mlir
 create mode 100644 mlir/test/lib/Dialect/TestParametric/CMakeLists.txt
 create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricAttrDefs.td
 create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.cpp
 create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.h
 create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricDialect.cpp
 create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricDialect.h
 create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricDialect.td
 create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.cpp
 create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.h
 create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.td
 create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricOps.td
 create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricTypeDefs.td
 create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricTypes.cpp
 create mode 100644 mlir/test/lib/Dialect/TestParametric/TestParametricTypes.h
 create mode 100644 mlir/test/lib/Dialect/TestParametric/lit.local.cfg
 create mode 100644 mlir/test/lib/Transforms/TestParametricSpecialization.cpp

diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index 844601f8f6837c..68bae5d3f991da 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -154,7 +154,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
       "bool", "isDeclaration", (ins), [{}],
       /*defaultImplementation=*/[{
         // By default, assume that the operation defines a symbol.
-        return false;
+        return false; 
       }]
     >,
   ];
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index d81298bb4daf01..2f3e34e266e3fb 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
 add_mlir_interface(ParallelCombiningOpInterface)
+add_mlir_interface(ParametricSpecializationOpInterface)
 add_mlir_interface(RuntimeVerifiableOpInterface)
 add_mlir_interface(ShapedOpInterfaces)
 add_mlir_interface(SideEffectInterfaces)
diff --git a/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.h b/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.h
new file mode 100644
index 00000000000000..88770e7239ac0f
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.h
@@ -0,0 +1,25 @@
+//===- ParametricSpecializationOpInterface.h --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the operation interfaces used to specialize parametric
+// ("meta") operations: ParametricOpInterface and SpecializingOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES_H_
+#define MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/ParametricSpecializationOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES_H_
diff --git a/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.td b/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.td
new file mode 100644
index 00000000000000..e3c12d6b4b60f9
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ParametricSpecializationOpInterface.td
@@ -0,0 +1,46 @@
+//===-- ParametricSpecializationOpInterface.td -------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES
+#define MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def ParametricOpInterface : OpInterface<"ParametricOpInterface"> {
+  let description = [{
+    Interface for operations that can be specialized for a concrete set of
+    meta-arguments: a dictionary mapping meta-parameter names to attributes
+    (types are passed as TypeAttr).
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<[{
+        Specialize this operation in place for the meta-arguments in `params`.
+      }],
+      "::mlir::LogicalResult", "specialize", (ins
+        "::mlir::DictionaryAttr":$params)>,
+    InterfaceMethod<[{
+        Return success if `operand` may have its type replaced by
+        `concreteType` while a parametric value is being specialized.
+      }],
+      "::mlir::LogicalResult", "checkOperand", (ins
+        "::mlir::OpOperand &":$operand,
+        "::mlir::Type":$concreteType)>,
+    InterfaceMethod<[{
+        Return the mangled symbol name for the specialization described by
+        `metaArgs`. Only for symbol operations which will be cloned and then
+        mangled in-place; the default implementation returns failure.
+      }],
+      "::mlir::FailureOr<::mlir::StringAttr>", "getMangledName", (ins
+        "::mlir::DictionaryAttr":$metaArgs), "", [{
+        return failure();
+      }]
+    >,
+  ];
+}
+
+def SpecializingOpInterface : OpInterface<"SpecializingOpInterface"> {
+  let description = [{
+    Interface for operations that reference a parametric symbol together with
+    the meta-arguments it should be specialized with (e.g. a call).
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<"Return the symbol reference to the parametric target.",
+      "::mlir::SymbolRefAttr", "getTarget", (ins)>,
+    InterfaceMethod<"Return the meta-arguments used to specialize the target.",
+      "::mlir::DictionaryAttr", "getMetaArgs", (ins)>,
+    InterfaceMethod<"Retarget this op to the specialized symbol `target`.",
+      "::mlir::LogicalResult", "setSpecializedTarget", (ins
+        "::mlir::SymbolOpInterface":$target)>,
+  ];
+}
+
+#endif // MLIR_INTERFACES_PARAMETRICSPECIALIZATIONOPINTERFACES
diff --git a/mlir/include/mlir/Transforms/ParametricSpecialization.h b/mlir/include/mlir/Transforms/ParametricSpecialization.h
new file mode 100644
index 00000000000000..1bbe3e2a557ef1
--- /dev/null
+++ b/mlir/include/mlir/Transforms/ParametricSpecialization.h
@@ -0,0 +1,11 @@
+//===- ParametricSpecialization.h - Specialize Meta Program ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry points for specializing parametric ("meta") programs. No declarations
+// are provided yet.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_PARAMETRICSPECIALIZATION_H
+#define MLIR_TRANSFORMS_PARAMETRICSPECIALIZATION_H
+
+#include "mlir/IR/Operation.h"
+
+namespace mlir {} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_PARAMETRICSPECIALIZATION_H
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index e7c76e70ed6b5d..1998b66f168f36 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -13,6 +13,7 @@ set(LLVM_OPTIONAL_SOURCES
   LoopLikeInterface.cpp
   MemorySlotInterfaces.cpp
   ParallelCombiningOpInterface.cpp
+  ParametricSpecializationOpInterface.cpp
   RuntimeVerifiableOpInterface.cpp
   ShapedOpInterfaces.cpp
   SideEffectInterfaces.cpp
@@ -80,6 +81,7 @@ add_mlir_library(MLIRLoopLikeInterface
 
 add_mlir_interface_library(MemorySlotInterfaces)
 add_mlir_interface_library(ParallelCombiningOpInterface)
+add_mlir_interface_library(ParametricSpecializationOpInterface)
 add_mlir_interface_library(RuntimeVerifiableOpInterface)
 add_mlir_interface_library(ShapedOpInterfaces)
 add_mlir_interface_library(SideEffectInterfaces)
diff --git a/mlir/lib/Interfaces/ParametricSpecializationOpInterface.cpp b/mlir/lib/Interfaces/ParametricSpecializationOpInterface.cpp
new file mode 100644
index 00000000000000..80fc2caf0d12ac
--- /dev/null
+++ b/mlir/lib/Interfaces/ParametricSpecializationOpInterface.cpp
@@ -0,0 +1,13 @@
+//===- ParametricSpecializationOpInterface.cpp ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/ParametricSpecializationOpInterface.h"
+#include "mlir/Support/LogicalResult.h"
+
+/// Include the definitions of the interface.
+#include "mlir/Interfaces/ParametricSpecializationOpInterface.cpp.inc"
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index af51a4ab1157f1..8254f9d212c603 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_library(MLIRTransforms
   LoopInvariantCodeMotion.cpp
   Mem2Reg.cpp
   OpStats.cpp
+  ParametricSpecialization.cpp
   PrintIR.cpp
   RemoveDeadValues.cpp
   SCCP.cpp
@@ -32,6 +33,7 @@ add_mlir_library(MLIRTransforms
   MLIRFunctionInterfaces
   MLIRLoopLikeInterface
   MLIRMemorySlotInterfaces
+  MLIRParametricSpecializationOpInterface
   MLIRPass
   MLIRRuntimeVerifiableOpInterface
   MLIRSideEffectInterfaces
diff --git a/mlir/lib/Transforms/ParametricSpecialization.cpp b/mlir/lib/Transforms/ParametricSpecialization.cpp
new file mode 100644
index 00000000000000..fcc3daacad447d
--- /dev/null
+++ b/mlir/lib/Transforms/ParametricSpecialization.cpp
@@ -0,0 +1,13 @@
+//===- ParametricSpecialization.cpp - Specialize Meta Program -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/ParametricSpecialization.h"
+
+using namespace mlir;
+
+/// Placeholder for the parametric-specialization entry point; currently a
+/// no-op that ignores `op`.
+/// NOTE(review): this is defined at global scope with external linkage and is
+/// not declared in ParametricSpecialization.h (whose `namespace mlir` is
+/// empty). Consider moving it into `namespace mlir` with a matching
+/// declaration, or making it `static`, before it gains callers.
+void specialize(Operation *op) {}
\ No newline at end of file
diff --git a/mlir/test/Parametric/ops.mlir b/mlir/test/Parametric/ops.mlir
new file mode 100644
index 00000000000000..ed8c87cd48ccee
--- /dev/null
+++ b/mlir/test/Parametric/ops.mlir
@@ -0,0 +1,18 @@
+
+
+// NOTE(review): this file has no `// RUN:` line, so lit will not execute it
+// (presumably reported as UNRESOLVED). Add a RUN line invoking the parametric
+// specialization test pass, plus CHECK lines for the instantiated functions.
+
+// Parametric function: `!testparametric.param<"A">` and
+// `#testparametric.param<"B">` are placeholders named in `metaParams`, bound
+// per call site by the caller's `meta` dictionary.
+testparametric.func @callee(%arg0: !testparametric.param<"A"> ) attributes { metaParams = ["A", "B"]} {
+    %value = testparametric.add %arg0, %arg0 : (!testparametric.param<"A">, !testparametric.param<"A">) -> !testparametric.param<"A">
+    testparametric.print_attr #testparametric.param<"B">
+    return
+}
+
+func.func @caller() {
+    %cst0 = arith.constant 0 : i32
+    %cst1 = arith.constant 1. : f32
+    %cst2 = arith.constant 2. : f64
+    testparametric.call @callee(%cst0) meta = {"A" = i32, "B" = 32 : i64 } : (i32) -> ()
+    // NOTE(review): repeats the previous call with identical meta-arguments;
+    // presumably meant to check that identical instantiations are shared.
+    testparametric.call @callee(%cst0) meta = {"A" = i32, "B" = 32 : i64 } : (i32) -> ()
+    testparametric.call @callee(%cst1) meta = {"A" = f32, "B" = 64 : i64 } : (f32) -> ()
+    testparametric.call @callee(%cst2) meta = {"A" = f64, "B" = 128 : i64 } : (f64) -> ()
+    return
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 30a17c201ff763..8c1be74f15899e 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -17,6 +17,7 @@ add_subdirectory(SPIRV)
 add_subdirectory(Tensor)
 add_subdirectory(Test)
 add_subdirectory(TestDyn)
+add_subdirectory(TestParametric)
 add_subdirectory(Tosa)
 add_subdirectory(Transform)
 add_subdirectory(Vector)
diff --git a/mlir/test/lib/Dialect/TestParametric/CMakeLists.txt b/mlir/test/lib/Dialect/TestParametric/CMakeLists.txt
new file mode 100644
index 00000000000000..dcc79f15993b7e
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/CMakeLists.txt
@@ -0,0 +1,68 @@
+# NOTE(review): TestParametricDialect.cpp is also compiled by
+# MLIRTestParametricDialect below, so listing it here looks unnecessary.
+set(LLVM_OPTIONAL_SOURCES
+  TestParametricDialect.cpp
+)
+
+# Attribute, type and op interfaces declared in TestParametricInterfaces.td.
+set(LLVM_TARGET_DEFINITIONS TestParametricInterfaces.td)
+mlir_tablegen(TestParametricAttrInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(TestParametricAttrInterfaces.cpp.inc -gen-attr-interface-defs)
+mlir_tablegen(TestParametricTypeInterfaces.h.inc -gen-type-interface-decls)
+mlir_tablegen(TestParametricTypeInterfaces.cpp.inc -gen-type-interface-defs)
+mlir_tablegen(TestParametricOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(TestParametricOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRTestParametricInterfaceIncGen)
+
+# Attribute definitions are generated from TestParametricOps.td (presumably it
+# includes TestParametricAttrDefs.td -- confirm).
+# NOTE(review): no rule here generates TestParametricOpEnums.h.inc, which
+# TestParametricAttributes.h includes; add a -gen-enum-decls rule or drop that
+# include.
+set(LLVM_TARGET_DEFINITIONS TestParametricOps.td)
+mlir_tablegen(TestParametricAttrDefs.h.inc -gen-attrdef-decls)
+mlir_tablegen(TestParametricAttrDefs.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRTestParametricAttrDefIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TestParametricTypeDefs.td)
+mlir_tablegen(TestParametricTypeDefs.h.inc -gen-typedef-decls -typedefs-dialect=testparametric)
+mlir_tablegen(TestParametricTypeDefs.cpp.inc -gen-typedef-defs -typedefs-dialect=testparametric)
+add_public_tablegen_target(MLIRTestParametricTypeDefIncGen)
+
+# Ops and the dialect class itself.
+set(LLVM_TARGET_DEFINITIONS TestParametricOps.td)
+mlir_tablegen(TestParametricOps.h.inc -gen-op-decls)
+mlir_tablegen(TestParametricOps.cpp.inc -gen-op-defs)
+mlir_tablegen(TestParametricOpsDialect.h.inc -gen-dialect-decls -dialect=testparametric)
+mlir_tablegen(TestParametricOpsDialect.cpp.inc -gen-dialect-defs -dialect=testparametric)
+add_public_tablegen_target(MLIRTestParametricOpsIncGen)
+
+# Exclude testparametrics from libMLIR.so
+add_mlir_library(MLIRTestParametricDialect
+  TestParametricAttributes.cpp
+  TestParametricDialect.cpp
+  TestParametricInterfaces.cpp
+  TestParametricTypes.cpp
+  
+  EXCLUDE_FROM_LIBMLIR
+
+  DEPENDS
+  MLIRTestParametricAttrDefIncGen
+  MLIRTestParametricInterfaceIncGen
+  MLIRTestParametricTypeDefIncGen
+  MLIRTestParametricOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRControlFlowInterfaces
+  MLIRDataLayoutInterfaces
+  MLIRDerivedAttributeOpInterface
+  MLIRDestinationStyleOpInterface
+  MLIRDialect
+  MLIRDLTIDialect
+  MLIRFuncDialect
+  MLIRFunctionInterfaces
+  MLIRFuncTransforms
+  MLIRIR
+  MLIRInferIntRangeInterface
+  MLIRInferTypeOpInterface
+  MLIRLinalgDialect
+  MLIRLinalgTransforms
+  MLIRLLVMDialect
+  MLIRPass
+  MLIRReduce
+  MLIRTensorDialect
+  MLIRTransformUtils
+  MLIRTransforms
+)
+
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricAttrDefs.td b/mlir/test/lib/Dialect/TestParametric/TestParametricAttrDefs.td
new file mode 100644
index 00000000000000..c9133a99a34413
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricAttrDefs.td
@@ -0,0 +1,38 @@
+//===-- TestParametricAttrDefs.td - TestParametric attrs ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TableGen attribute definitions for the TestParametric dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TESTPARAMETRIC_ATTRDEFS
+#define TESTPARAMETRIC_ATTRDEFS
+
+// To get the test dialect definition.
+include "TestParametricDialect.td"
+include "mlir/Dialect/Utils/StructuredOpsUtils.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/IR/OpAsmInterface.td"
+
+// All of the attributes will extend this class.
+class TestParametric_Attr<string name, list<Trait> traits = []>
+    : AttrDef<TestParametric_Dialect, name, traits>;
+
+// `#testparametric.param<"name">`: placeholder attribute that refers, by name,
+// to a meta-parameter of the enclosing parametric function. It is substituted
+// by the meta-argument bound to that name when the function is specialized.
+def TestParametric_ParamAttr : TestParametric_Attr<"Param"> {
+  let mnemonic = "param";
+  // Attribute parameters: the name of the referenced meta-parameter.
+  let parameters = (
+    ins
+    "::mlir::StringAttr":$ref
+  );
+  let assemblyFormat = "`<` $ref `>`";
+}
+
+#endif // TESTPARAMETRIC_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.cpp b/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.cpp
new file mode 100644
index 00000000000000..a5dde555b3e891
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.cpp
@@ -0,0 +1,42 @@
+//===- TestParametricAttributes.cpp - TestParametric attributes -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains attributes defined by the TestParametric dialect for
+// testing parametric specialization in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestParametricAttributes.h"
+#include "TestParametricDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/bit.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace testparametric;
+
+//===----------------------------------------------------------------------===//
+// TestParametricDialect
+//===----------------------------------------------------------------------===//
+
+#define GET_ATTRDEF_CLASSES
+#include "TestParametricAttrDefs.cpp.inc"
+
+/// Register every attribute class from the generated attribute list
+/// (GET_ATTRDEF_LIST in TestParametricAttrDefs.cpp.inc) with this dialect.
+void TestParametricDialect::registerAttributes() {
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "TestParametricAttrDefs.cpp.inc"
+      >();
+}
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.h b/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.h
new file mode 100644
index 00000000000000..054fc0f598d0a4
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricAttributes.h
@@ -0,0 +1,34 @@
+//===- TestParametricAttributes.h - TestParametric attributes ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the attributes defined by the TestParametric dialect for
+// testing parametric specialization in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTPARAMETRICATTRIBUTES_H
+#define MLIR_TESTPARAMETRICATTRIBUTES_H
+
+#include <tuple>
+
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+
+#include "TestParametricAttrInterfaces.h.inc"
+// NOTE(review): the TestParametric CMakeLists.txt has no -gen-enum-decls rule,
+// so nothing there generates TestParametricOpEnums.h.inc; a clean build may
+// fail on this include. Drop it or add the tablegen rule.
+#include "TestParametricOpEnums.h.inc"
+#include "mlir/IR/DialectResourceBlobManager.h"
+
+namespace testparametric {} // namespace testparametric
+
+// Generated attribute class declarations (e.g. ParamAttr).
+#define GET_ATTRDEF_CLASSES
+#include "TestParametricAttrDefs.h.inc"
+
+#endif // MLIR_TESTPARAMETRICATTRIBUTES_H
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.cpp b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.cpp
new file mode 100644
index 00000000000000..693d93910e8188
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.cpp
@@ -0,0 +1,297 @@
+//===- TestParametricDialect.cpp - MLIR Dialect for Testing ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestParametricDialect.h"
+#include "TestParametricAttributes.h"
+#include "TestParametricInterfaces.h"
+#include "TestParametricTypes.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/ODSSupport.h"
+#include "mlir/IR/OperationSupport.h"
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/Base64.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <cstdint>
+#include <numeric>
+#include <optional>
+
+// Include this before the using namespace lines below to
+// test that we don't have namespace dependencies.
+#include "TestParametricOpsDialect.cpp.inc"
+
+using namespace mlir;
+using namespace testparametric;
+
+/// Dialect initialization hook: registers the dialect's attributes and types,
+/// then every operation from the generated op list (GET_OP_LIST).
+void TestParametricDialect::initialize() {
+  registerAttributes();
+  registerTypes();
+  addOperations<
+#define GET_OP_LIST
+#include "TestParametricOps.cpp.inc"
+      >();
+}
+/// Add the TestParametric dialect to `registry`.
+void testparametric::registerTestParametricDialect(DialectRegistry &registry) {
+  registry.insert<TestParametricDialect>();
+}
+
+#include "TestParametricOpInterfaces.cpp.inc"
+#include "TestParametricTypeInterfaces.cpp.inc"
+
+#define GET_OP_CLASSES
+#include "TestParametricOps.cpp.inc"
+
+/// Parse a `testparametric.func` using the shared function-op syntax.
+/// Variadic signatures are not accepted.
+::mlir::ParseResult ParametricFuncOp::parse(mlir::OpAsmParser &parser,
+                                            mlir::OperationState &result) {
+  // Variadic signatures are disallowed, so the variadic flag and the error
+  // string are ignored: the signature is just inputs -> outputs.
+  auto functionTypeBuilder = [](Builder &b, ArrayRef<Type> inputs,
+                                ArrayRef<Type> outputs,
+                                function_interface_impl::VariadicFlag,
+                                std::string &) -> Type {
+    return b.getFunctionType(inputs, outputs);
+  };
+
+  StringAttr typeAttrName = getFunctionTypeAttrName(result.name);
+  StringAttr argAttrsName = getArgAttrsAttrName(result.name);
+  StringAttr resAttrsName = getResAttrsAttrName(result.name);
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, typeAttrName,
+      functionTypeBuilder, argAttrsName, resAttrsName);
+}
+
+/// Print a `testparametric.func` in the shared function-op syntax, mirroring
+/// parse(): never variadic, same attribute names.
+void ParametricFuncOp::print(mlir::OpAsmPrinter &p) {
+  const bool isVariadic = false;
+  function_interface_impl::printFunctionOp(p, *this, isVariadic,
+                                           getFunctionTypeAttrName(),
+                                           getArgAttrsAttrName(),
+                                           getResAttrsAttrName());
+}
+
+/// Verify that this call references an existing ParametricFuncOp and that its
+/// operand and result types agree with the callee's signature. Callee types of
+/// the form `!testparametric.param<"X">` are resolved through this call's
+/// meta-arguments, which must bind "X" to a TypeAttr. This applies to results
+/// as well as operands, matching how ParametricFuncOp::specialize rewrites
+/// both.
+LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // Check that the callee attribute was specified.
+  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!fnAttr)
+    return emitOpError("requires a 'callee' symbol reference attribute");
+  ParametricFuncOp fn =
+      symbolTable.lookupNearestSymbolFrom<ParametricFuncOp>(*this, fnAttr);
+  if (!fn)
+    return emitOpError() << "'" << fnAttr.getValue()
+                         << "' does not reference a valid function";
+
+  // Verify that the operand and result types match the callee.
+  auto fnType = fn.getFunctionType();
+  if (fnType.getNumInputs() != getNumOperands())
+    return emitOpError("incorrect number of operands for callee");
+
+  DictionaryAttr metaParams = fn.getMetaParamsAttr();
+  DictionaryAttr metaArgs = getMetaArgs();
+  if (metaParams && metaArgs.size() != metaParams.size())
+    return emitOpError("incorrect number of meta operands for callee");
+
+  // Map a callee signature type to the concrete type expected at this call
+  // site. A ParamType is replaced by the TypeAttr bound to its name in
+  // `metaArgs`; any other type is returned unchanged. Emits an error and
+  // returns failure if the binding is missing or is not a TypeAttr.
+  auto resolveType = [&](Type type) -> FailureOr<Type> {
+    auto metaParamType = dyn_cast<ParamType>(type);
+    if (!metaParamType)
+      return type;
+    Attribute metaArg = metaArgs.get(metaParamType.getRef());
+    if (!metaArg) {
+      emitOpError("Missing meta args for type operand ")
+          << metaParamType.getRef();
+      return failure();
+    }
+    auto metaArgType = dyn_cast<TypeAttr>(metaArg);
+    if (!metaArgType) {
+      emitOpError("Expected TypeAttr for meta args ")
+          << metaParamType.getRef() << ", got " << metaArg;
+      return failure();
+    }
+    return metaArgType.getValue();
+  };
+
+  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
+    Type operandType = getOperand(i).getType();
+    Type paramType = fnType.getInput(i);
+    if (isa<ParamType>(paramType)) {
+      FailureOr<Type> concreteType = resolveType(paramType);
+      if (failed(concreteType))
+        return failure();
+      if (*concreteType != operandType)
+        return emitOpError("Mismatch between operand type and meta args type: ")
+               << operandType << " vs " << *concreteType;
+      continue;
+    }
+    if (operandType != paramType) {
+      return emitOpError("operand type mismatch: expected operand type ")
+             << fnType.getInput(i) << ", but provided "
+             << getOperand(i).getType() << " for operand number " << i;
+    }
+  }
+  if (fnType.getNumResults() != getNumResults())
+    return emitOpError("incorrect number of results for callee");
+
+  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
+    // Parametric result types are resolved like parametric operands, so a
+    // callee returning `!testparametric.param<"A">` accepts the type bound to
+    // "A" at this call site instead of always reporting a mismatch.
+    FailureOr<Type> expectedType = resolveType(fnType.getResult(i));
+    if (failed(expectedType))
+      return failure();
+    if (getResult(i).getType() != *expectedType) {
+      auto diag = emitOpError("result type mismatch at index ") << i;
+      diag.attachNote() << "      op result types: " << getResultTypes();
+      diag.attachNote() << "function result types: " << fnType.getResults();
+      return diag;
+    }
+  }
+  return success();
+}
+
+/// Specialization Interface Implementation
+
+/// Change the type of `value` to `newType`, after asking every
+/// ParametricOpInterface user whether it accepts the new operand type. On the
+/// first rejection an error is emitted on that user and `value` keeps its
+/// current type.
+static LogicalResult replaceValueType(Value value, Type newType) {
+  for (OpOperand &use : value.getUses()) {
+    auto paramOp = dyn_cast<ParametricOpInterface>(use.getOwner());
+    // Users that do not implement the interface are not consulted.
+    if (!paramOp || succeeded(paramOp.checkOperand(use, newType)))
+      continue;
+    paramOp.emitOpError() << "fails to replace operand type for operand #"
+                          << use.getOperandNumber() << " with " << newType;
+    return failure();
+  }
+  value.setType(newType);
+  return success();
+}
+/// If `value` has a ParamType, retypes it to the concrete type bound to that
+/// parameter's name in `metaArgs`; values of any other type are left as is.
+/// Fails, with a diagnostic, when the binding is missing or is not a TypeAttr.
+static LogicalResult replaceValueType(Value value, DictionaryAttr metaArgs) {
+  auto paramType = dyn_cast<ParamType>(value.getType());
+  if (!paramType)
+    return success();
+  auto metaArg =
+      llvm::dyn_cast_or_null<TypeAttr>(metaArgs.get(paramType.getRef()));
+  if (!metaArg) {
+    // Report on the defining op when there is one. Block arguments have no
+    // defining op: fall back to the value's own location instead of failing
+    // without any diagnostic.
+    Operation *definingOp = value.getDefiningOp();
+    InFlightDiagnostic diag = definingOp ? definingOp->emitError()
+                                         : mlir::emitError(value.getLoc());
+    diag << "expected TypeAttr for specializing meta arg " << paramType
+         << ", got " << metaArgs;
+    return failure();
+  }
+  return replaceValueType(value, metaArg.getValue());
+}
+
+/// Turns this parametric function into the instance selected by `metaArgs`:
+/// renames it to its mangled name, drops `metaParams`, replaces every
+/// ParamType in the signature (and on the entry block arguments) with the
+/// concrete type bound in `metaArgs`, then specializes every nested op that
+/// implements ParametricOpInterface. Stops at the first failure.
+LogicalResult ParametricFuncOp::specialize(DictionaryAttr metaArgs) {
+  auto mangledName = getMangledName(metaArgs);
+  if (failed(mangledName))
+    return failure();
+  setSymNameAttr(*mangledName);
+  removeMetaParamsAttr();
+
+  // Appends each type of `typeRange` to `specialized`, replacing ParamTypes
+  // with the TypeAttr bound to their name in `metaArgs`.
+  auto specializeTypes = [&](auto typeRange, SmallVector<Type> &specialized) {
+    for (Type ty : typeRange) {
+      auto paramType = dyn_cast<ParamType>(ty);
+      if (!paramType) {
+        specialized.push_back(ty);
+        continue;
+      }
+      auto metaArg =
+          llvm::dyn_cast_or_null<TypeAttr>(metaArgs.get(paramType.getRef()));
+      if (!metaArg) {
+        emitOpError() << "expected TypeAttr for specializing meta arg "
+                      << paramType << ", got " << metaArgs;
+        return failure();
+      }
+      specialized.push_back(metaArg.getValue());
+    }
+    return success();
+  };
+  FunctionType fnType = getFunctionType();
+  SmallVector<Type> argTypes, resTypes;
+  if (failed(specializeTypes(fnType.getInputs(), argTypes)) ||
+      failed(specializeTypes(fnType.getResults(), resTypes)))
+    return failure();
+
+  // Retype the entry block arguments to match the specialized signature.
+  for (auto [newType, blockArg] : llvm::zip(argTypes, getArguments()))
+    if (failed(replaceValueType(blockArg, newType)))
+      return failure();
+
+  setFunctionType(FunctionType::get(getContext(), argTypes, resTypes));
+
+  WalkResult walkResult = getFunctionBody().walk([&](Operation *op) {
+    auto parametricOp = dyn_cast<ParametricOpInterface>(op);
+    if (parametricOp && failed(parametricOp.specialize(metaArgs)))
+      return WalkResult::interrupt();
+    return WalkResult::advance();
+  });
+  return failure(walkResult.wasInterrupted());
+}
+
+/// Every concrete operand type is accepted: this check never fails.
+LogicalResult ParametricFuncOp::checkOperand(mlir::OpOperand & /*operand*/,
+                                              mlir::Type /*concreteType*/) {
+  return success();
+}
+
+/// Builds the symbol name of the instance of this function specialized with
+/// `metaArgs`: `<name>$__mlir_instance__`, followed by `$<param>$<value>` for
+/// each meta argument in the dictionary's (name-sorted) order. Integer values
+/// are printed without their type and other attributes in their textual form,
+/// e.g. `callee$__mlir_instance__$A$i32$B$32`. Fails if the op has no name.
+///
+/// NOTE: since integer types are dropped, `32 : i32` and `32 : i64` produce
+/// the same name; callers relying on uniqueness must compare the meta args.
+FailureOr<StringAttr>
+ParametricFuncOp::getMangledName(DictionaryAttr metaArgs) {
+  auto baseName = getNameAttr();
+  if (!baseName)
+    return failure();
+  std::string mangledName;
+  llvm::raw_string_ostream os(mangledName);
+  os << baseName.getValue() << "$__mlir_instance__";
+  for (NamedAttribute metaArg : metaArgs) {
+    os << "$" << metaArg.getName().getValue() << "$";
+    Attribute value = metaArg.getValue();
+    if (auto intAttr = dyn_cast<IntegerAttr>(value))
+      os << intAttr.getValue();
+    else
+      os << value;
+  }
+
+  return StringAttr::get(getContext(), os.str());
+}
+
+/// Retypes the result when it is a ParamType bound in `metaArgs`; only the
+/// result is handled here.
+LogicalResult AddOp::specialize(DictionaryAttr metaArgs) {
+  return replaceValueType(getResult(), metaArgs);
+}
+
+/// Every concrete operand type is accepted: this check never fails.
+LogicalResult AddOp::checkOperand(mlir::OpOperand & /*operand*/,
+                                  mlir::Type /*concreteType*/) {
+  return success();
+}
+
+/// The symbol this call refers to, i.e. the op to specialize for it.
+SymbolRefAttr CallOp::getTarget() {
+  return getCalleeAttr();
+}
+
+/// Retargets this call to `target`, a specialized instance of its callee, and
+/// clears the meta arguments since the new callee takes none. When `target`
+/// is a function whose argument/result types do not match this call's operand
+/// and result types, fails without modifying the call.
+LogicalResult CallOp::setSpecializedTarget(SymbolOpInterface target) {
+  if (auto fnOp = dyn_cast<FunctionOpInterface>(target.getOperation())) {
+    if (!llvm::equal(fnOp.getArgumentTypes(), (*this)->getOperandTypes()) ||
+        !llvm::equal(fnOp.getResultTypes(), (*this)->getResultTypes()))
+      return failure();
+  }
+  setCalleeAttr(SymbolRefAttr::get(target.getNameAttr()));
+  setMetaArgsAttr(DictionaryAttr::get(getContext()));
+  return success();
+}
+
+/// Replaces a `#testparametric.param<...>` placeholder value with the meta
+/// argument bound to its name; any other attribute is left unchanged.
+LogicalResult PrintAttrOp::specialize(DictionaryAttr metaArgs) {
+  auto placeholder = dyn_cast_or_null<ParamAttr>(getValueAttr());
+  if (!placeholder)
+    return success();
+  Attribute concrete = metaArgs.get(placeholder.getRef());
+  if (concrete) {
+    setValueAttr(concrete);
+    return success();
+  }
+  emitOpError() << "failed to specialize, missing " << placeholder.getRef()
+                << " entry in " << metaArgs;
+  return failure();
+}
+
+/// Every concrete operand type is accepted: this check never fails.
+LogicalResult PrintAttrOp::checkOperand(mlir::OpOperand & /*operand*/,
+                                        mlir::Type /*concreteType*/) {
+  return success();
+}
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.h b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.h
new file mode 100644
index 00000000000000..510c94c1100dc1
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.h
@@ -0,0 +1,45 @@
+//===- TestParametricDialect.h - Parametric test dialect --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the 'testparametric' dialect, used to test parametric
+// specialization of IR (see the test-parametric-specialization pass).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTPARAMETRICDIALECT_H
+#define MLIR_TESTPARAMETRICDIALECT_H
+// The guard must not be MLIR_TESTDIALECT_H, the guard used by the 'test'
+// dialect's TestDialect.h this file was derived from: otherwise whichever of
+// the two headers is included second in a translation unit is skipped.
+
+#include "TestParametricAttributes.h"
+#include "TestParametricInterfaces.h"
+#include "TestParametricTypes.h"
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/ParametricSpecializationOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include <memory>
+
+//===----------------------------------------------------------------------===//
+// TestParametricDialect
+//===----------------------------------------------------------------------===//
+
+#include "TestParametricOpInterfaces.h.inc"
+#include "TestParametricOpsDialect.h.inc"
+
+#define GET_OP_CLASSES
+#include "TestParametricOps.h.inc"
+
+namespace testparametric {
+/// Hook used by tools (e.g. mlir-lsp-server) to make this test dialect
+/// available through `registry`.
+void registerTestParametricDialect(::mlir::DialectRegistry &registry);
+} // namespace testparametric
+
+#endif // MLIR_TESTPARAMETRICDIALECT_H
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.td b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.td
new file mode 100644
index 00000000000000..54575cdc723eec
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricDialect.td
@@ -0,0 +1,27 @@
+//===-- TestParametricDialect.td - Dialect definition ------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TESTPARAMETRIC_DIALECT
+#define TESTPARAMETRIC_DIALECT
+
+include "mlir/IR/OpBase.td"
+
+// The 'testparametric' dialect, exercising parametric specialization.
+// Attributes and types use the default generated printer/parser hooks;
+// registerAttributes()/registerTypes() are implemented by hand (types in
+// TestParametricTypes.cpp).
+def TestParametric_Dialect : Dialect {
+  let name = "testparametric";
+  let cppNamespace = "::testparametric";
+  let useDefaultAttributePrinterParser = 1;
+  let useDefaultTypePrinterParser = 1;
+
+  let extraClassDeclaration = [{
+    void registerAttributes();
+    void registerTypes();
+  }];
+}
+
+#endif // TESTPARAMETRIC_DIALECT
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.cpp b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.cpp
new file mode 100644
index 00000000000000..597654c639e81d
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.cpp
@@ -0,0 +1,11 @@
+//===- TestParametricInterfaces.cpp - Test interfaces -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestParametricInterfaces.h"
+
+using namespace mlir;
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.h b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.h
new file mode 100644
index 00000000000000..092c24058c0176
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.h
@@ -0,0 +1,33 @@
+//===- TestParametricInterfaces.h - Test interfaces -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares interfaces for the 'testparametric' dialect that can be
+// used for testing the interface infrastructure.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TEST_LIB_DIALECT_TESTPARAMETRIC_TESTPARAMETRICINTERFACES_H
+#define MLIR_TEST_LIB_DIALECT_TESTPARAMETRIC_TESTPARAMETRICINTERFACES_H
+// The guard must differ from MLIR_TEST_LIB_DIALECT_TEST_TESTINTERFACES_H, the
+// guard of the 'test' dialect's TestInterfaces.h this file was derived from;
+// otherwise whichever header is included second in a TU is silently skipped.
+
+#include "mlir/IR/BuiltinAttributes.h"
+
+#include "llvm/ADT/DenseMap.h"
+
+namespace mlir {
+
+/// Holds a StringAttr -> Attribute map, presumably meta-parameter names to the
+/// attributes they are specialized with (TODO confirm).
+/// NOTE(review): nothing in the files shown here uses this class yet.
+class SpecializationParams {
+public:
+  SpecializationParams() = default;
+
+private:
+  DenseMap<StringAttr, Attribute> params;
+};
+
+} // namespace mlir
+
+#endif // MLIR_TEST_LIB_DIALECT_TESTPARAMETRIC_TESTPARAMETRICINTERFACES_H
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.td b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.td
new file mode 100644
index 00000000000000..954ab0cac1fcf1
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricInterfaces.td
@@ -0,0 +1,34 @@
+//===-- TestParametricInterfaces.td - Test interfaces ------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TEST_DIALECT_TESTPARAMETRIC_INTERFACES
+#define MLIR_TEST_DIALECT_TESTPARAMETRIC_INTERFACES
+
+include "mlir/IR/OpBase.td"
+
+// Test-only op interface. Judging by the method names, it exposes hooks to
+// specialize a named meta parameter to a concrete attribute or type, and to
+// check an operand against a concrete type (TODO confirm; descriptions empty).
+// NOTE(review): the ops in TestParametricOps.td attach ParametricOpInterface /
+// SpecializingOpInterface instead; confirm this interface is still needed.
+def TestParametricOpInterface : OpInterface<"TestParametricOpInterface"> {
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<"",
+      "LogicalResult", "specializeAttr", (ins
+        "::mlir::StringAttr":$parameterName,
+        "::mlir::Attribute":$concreteAttr)>,
+    InterfaceMethod<"",
+      "LogicalResult", "specializeType", (ins
+        "::mlir::StringAttr":$parameterName,
+        "::mlir::Type":$concreteType)>,
+    InterfaceMethod<"",
+      "LogicalResult", "checkOperand", (ins
+        "::mlir::OpOperand &":$operand,
+        "::mlir::Type":$concreteType)>,
+  ];
+}
+
+
+
+#endif // MLIR_TEST_DIALECT_TESTPARAMETRIC_INTERFACES
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricOps.td b/mlir/test/lib/Dialect/TestParametric/TestParametricOps.td
new file mode 100644
index 00000000000000..43c04c763957cd
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricOps.td
@@ -0,0 +1,202 @@
+//===-- TestParametricOps.td - Operation definitions -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TESTPARAMETRIC_OPS
+#define TESTPARAMETRIC_OPS
+
+include "TestParametricDialect.td"
+include "mlir/Dialect/DLTI/DLTIBase.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/Interfaces/FunctionInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/PatternBase.td"
+include "mlir/IR/RegionKindInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/CopyOpInterface.td"
+include "mlir/Interfaces/DataLayoutInterfaces.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
+include "mlir/Interfaces/ParametricSpecializationOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+
+// Include the attribute definitions.
+include "TestParametricAttrDefs.td"
+// Include the type definitions.
+include "TestParametricTypeDefs.td"
+
+
+// Base class for every operation of the TestParametric dialect.
+class TESTParametric_Op<string mnemonic, list<Trait> traits = []> :
+    Op<TestParametric_Dialect, mnemonic, traits>;
+
+// Function-like op whose signature and body may refer to meta parameters,
+// e.g. `!testparametric.param<"A">` types or `#testparametric.param<"B">`
+// attributes. Specializing it through ParametricOpInterface renames it to a
+// mangled name and rewrites those placeholders with the concrete meta args.
+def TESTParametric_ParametricFuncOp : TESTParametric_Op<"func", [
+  AutomaticAllocationScope, FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface,
+  DeclareOpInterfaceMethods<ParametricOpInterface>
+  ]> {
+  let summary = "Parametric function.";
+  let description = [{
+  }];
+
+  // NOTE(review): `metaParams` is constrained to a DictionaryAttr, but the PR
+  // description's example writes `metaParams = ["A", "B"]` (an array); confirm
+  // which form is intended.
+  let arguments = (ins SymbolNameAttr:$sym_name,
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictionaryAttr>:$metaParams,
+    OptionalAttr<StrAttr>:$sym_visibility,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs);
+  let regions = (region AnyRegion:$body);
+
+  let extraClassDeclaration = [{
+    //===------------------------------------------------------------------===//
+    // FunctionOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Returns the region on the current operation that is callable. This may
+    /// return null in the case of an external callable object, e.g. an external
+    /// function.
+    ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); }
+
+    /// Returns the argument types of this function.
+    ::llvm::ArrayRef<::mlir::Type> getArgumentTypes() { return getFunctionType().getInputs(); }
+
+    /// Returns the result types of this function.
+    ::llvm::ArrayRef<::mlir::Type> getResultTypes() { return getFunctionType().getResults(); }
+
+    //===------------------------------------------------------------------===//
+    // OpAsmOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Allow the dialect prefix to be omitted.
+    static ::llvm::StringRef getDefaultDialect() { return "testparametric"; }
+
+    //===------------------------------------------------------------------===//
+    // SymbolOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    bool isDeclaration() { return isExternal(); }
+
+    //===------------------------------------------------------------------===//
+    // ParametricOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    ::mlir::FailureOr<::mlir::StringAttr> getMangledName(::mlir::DictionaryAttr);
+  }];
+  let hasCustomAssemblyFormat = 1;
+}
+
+
+// Terminator of a ParametricFuncOp, returning `operands`.
+// NOTE(review): nothing here checks the operand types against the enclosing
+// function's result types (no verifier); confirm whether that is intended.
+def ReturnOp : TESTParametric_Op<"return", [Pure, HasParent<"ParametricFuncOp">,
+                                ReturnLike, Terminator]> {
+  let summary = "Function return operation";
+  let description = [{
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+
+  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+}
+
+// Call to a symbol that may be parametric. `metaArgs` binds the callee's meta
+// parameters for this call site (e.g. `meta = {"A" = i32, "B" = 32 : i64}`);
+// once the call is retargeted to a specialized callee it becomes `{}`.
+def CallOp : TESTParametric_Op<"call",
+    [CallOpInterface,
+     DeclareOpInterfaceMethods<SpecializingOpInterface>,
+     DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "call operation";
+  let description = [{
+  }];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$callee,
+    Variadic<AnyType>:$operands,
+    DictionaryAttr:$metaArgs
+  );
+  let results = (outs Variadic<AnyType>);
+
+  // These builders take no meta arguments. They default the required
+  // `metaArgs` attribute to an empty dictionary (a non-parametric call) so
+  // that the ops they build pass verification.
+  let builders = [
+    OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::TypeRange":$results,
+      CArg<"mlir::ValueRange", "{}">:$operands), [{
+      $_state.addOperands(operands);
+      $_state.addAttribute("callee", callee);
+      $_state.addAttribute("metaArgs",
+                           mlir::DictionaryAttr::get($_builder.getContext()));
+      $_state.addTypes(results);
+    }]>,
+    OpBuilder<(ins "mlir::StringAttr":$callee, "mlir::TypeRange":$results,
+      CArg<"mlir::ValueRange", "{}">:$operands), [{
+      build($_builder, $_state, mlir::SymbolRefAttr::get(callee), results, operands);
+    }]>,
+    OpBuilder<(ins "llvm::StringRef":$callee, "mlir::TypeRange":$results,
+      CArg<"mlir::ValueRange", "{}">:$operands), [{
+      build($_builder, $_state, mlir::StringAttr::get($_builder.getContext(), callee),
+            results, operands);
+    }]>];
+
+  let extraClassDeclaration = [{
+    mlir::FunctionType getCalleeType();
+
+    /// Get the argument operands to the called function.
+    operand_range getArgOperands() {
+      return {arg_operand_begin(), arg_operand_end()};
+    }
+
+    mlir::MutableOperandRange getArgOperandsMutable() {
+      return getOperandsMutable();
+    }
+
+    operand_iterator arg_operand_begin() { return operand_begin(); }
+    operand_iterator arg_operand_end() { return operand_end(); }
+
+    /// Return the callee of this operation.
+    mlir::CallInterfaceCallable getCallableForCallee() {
+      return (*this)->getAttrOfType<mlir::SymbolRefAttr>("callee");
+    }
+
+    /// Set the callee for this operation.
+    void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<mlir::SymbolRefAttr>());
+    }
+  }];
+
+  let assemblyFormat = [{
+    $callee `(` $operands `)` `meta` `=` $metaArgs attr-dict `:` functional-type($operands, results)
+  }];
+}
+
+// Binary op on two values of any type. Its result may be a ParamType, which
+// AddOp::specialize replaces with the bound concrete type.
+def AddOp : TESTParametric_Op<"add", [Pure, DeclareOpInterfaceMethods<ParametricOpInterface>]> {
+  let summary = "Add operation";
+  let description = [{
+  }];
+
+  let arguments = (ins
+    AnyType:$lhs,
+    AnyType:$rhs
+  );
+  let results = (outs
+    AnyType:$result
+  );
+
+  let assemblyFormat = "$lhs `` `,` $rhs attr-dict `:` functional-type(operands, results)";
+}
+
+// Carries a single attribute. A `#testparametric.param<...>` placeholder value
+// is replaced by the matching meta argument when the op is specialized.
+def PrintAttrOp : TESTParametric_Op<"print_attr", [DeclareOpInterfaceMethods<ParametricOpInterface>]> {
+  let summary = "Print operation";
+  let description = [{
+  }];
+
+  let arguments = (ins
+    AnyAttr:$value
+  );
+
+  let assemblyFormat = "$value attr-dict";
+}
+
+
+
+#endif // TESTPARAMETRIC_OPS
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricTypeDefs.td b/mlir/test/lib/Dialect/TestParametric/TestParametricTypeDefs.td
new file mode 100644
index 00000000000000..d3f3d0327db919
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricTypeDefs.td
@@ -0,0 +1,37 @@
+//===-- TestParametricTypeDefs.td - Type definitions -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TableGen data type definitions for the TestParametric dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TESTPARAMETRIC_TYPEDEFS
+#define TESTPARAMETRIC_TYPEDEFS
+
+// To get the test dialect def.
+include "TestParametricDialect.td"
+include "TestParametricAttrDefs.td"
+include "TestParametricInterfaces.td"
+include "mlir/IR/BuiltinTypes.td"
+include "mlir/Interfaces/DataLayoutInterfaces.td"
+
+// All of the types will extend this class.
+class TestParametric_Type<string name, list<Trait> traits = []>
+    : TypeDef<TestParametric_Dialect, name, traits>;
+
+// `!testparametric.param<"A">`: placeholder for the type bound to the meta
+// parameter named by `ref`; replaced by a concrete type on specialization.
+def TestParametric_ParamType : TestParametric_Type<"Param"> {
+  let mnemonic = "param";
+  // List of type parameters.
+  let parameters = (
+    ins
+    "::mlir::StringAttr":$ref
+  );
+  let assemblyFormat = "`<` $ref `>`";
+}
+
+#endif // TESTPARAMETRIC_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.cpp b/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.cpp
new file mode 100644
index 00000000000000..aaad6b880361b3
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.cpp
@@ -0,0 +1,42 @@
+//===- TestParametricTypes.cpp - TestParametric dialect types ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains types defined by the TestParametric dialect for testing
+// various features of MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestParametricTypes.h"
+#include "TestParametricDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/TypeSize.h"
+#include <optional>
+
+using namespace mlir;
+using namespace testparametric;
+
+#define GET_TYPEDEF_CLASSES
+#include "TestParametricTypeDefs.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// TestParametricDialect
+//===----------------------------------------------------------------------===//
+
+/// Registers with the dialect every type generated from
+/// TestParametricTypeDefs.td; the GET_TYPEDEF_LIST include expands to the
+/// comma-separated list of those type classes.
+void TestParametricDialect::registerTypes() {
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "TestParametricTypeDefs.cpp.inc"
+      >();
+}
diff --git a/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.h b/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.h
new file mode 100644
index 00000000000000..0e397d2bfb1ceb
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/TestParametricTypes.h
@@ -0,0 +1,154 @@
+//===- TestParametricTypes.h - TestParametric dialect types -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains types defined by the TestParametric dialect for testing
+// various features of MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTPARAMETRICTYPES_H
+#define MLIR_TESTPARAMETRICTYPES_H
+
+#include <optional>
+#include <tuple>
+
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+
+namespace test {
+// NOTE(review): the declarations in this namespace appear to be carried over
+// from the 'test' dialect's TestTypes.h; the only type defined in
+// TestParametricTypeDefs.td (ParamType) uses none of them. If both headers
+// end up in one binary, these definitions must stay token-identical (ODR), so
+// prefer deleting them over editing them. TODO confirm they are unused.
+// Forward declaration; not referenced elsewhere in this header.
+class TestAttrWithFormatAttr;
+
+/// FieldInfo represents a field in the StructType data type. It is used as a
+/// parameter in TestTypeDefs.td.
+/// (NOTE(review): that .td belongs to the 'test' dialect; no StructType is
+/// defined in TestParametricTypeDefs.td.)
+struct FieldInfo {
+  ::llvm::StringRef name;
+  ::mlir::Type type;
+
+  // Custom allocation called from generated constructor code
+  // Copies `name` into `alloc`; `type` is copied as is.
+  FieldInfo allocateInto(::mlir::TypeStorageAllocator &alloc) const {
+    return FieldInfo{alloc.copyInto(name), type};
+  }
+};
+
+/// A custom type for a test type parameter.
+/// Two CustomParams compare equal when their `value`s are equal.
+struct CustomParam {
+  int value;
+
+  bool operator==(const CustomParam &other) const {
+    return other.value == value;
+  }
+};
+
+/// Hashes a CustomParam by its `value`, consistently with operator==.
+inline llvm::hash_code hash_value(const test::CustomParam &param) {
+  return llvm::hash_value(param.value);
+}
+
+} // namespace test
+
+namespace mlir {
+// NOTE(review): these specializations appear to duplicate ones in the 'test'
+// dialect's TestTypes.h; a TU including both headers would see them twice.
+// TODO confirm whether they are needed here.
+
+/// Parses a test::CustomParam from a plain integer.
+template <>
+struct FieldParser<test::CustomParam> {
+  static FailureOr<test::CustomParam> parse(AsmParser &parser) {
+    auto value = FieldParser<int>::parse(parser);
+    if (failed(value))
+      return failure();
+    return test::CustomParam{*value};
+  }
+};
+
+/// Prints a test::CustomParam as its wrapped integer.
+inline mlir::AsmPrinter &operator<<(mlir::AsmPrinter &printer,
+                                    test::CustomParam param) {
+  return printer << param.value;
+}
+
+/// Overload the attribute parameter parser for optional integers.
+/// Yields std::nullopt when no integer is present, the parsed value when one
+/// is, and fails only if an integer is present but cannot be parsed.
+template <>
+struct FieldParser<std::optional<int>> {
+  static FailureOr<std::optional<int>> parse(AsmParser &parser) {
+    std::optional<int> value;
+    value.emplace();
+    OptionalParseResult result = parser.parseOptionalInteger(*value);
+    if (result.has_value()) {
+      if (succeeded(*result))
+        return value;
+      return failure();
+    }
+    value.reset();
+    return value;
+  }
+};
+} // namespace mlir
+
+#include "TestParametricTypeInterfaces.h.inc"
+
+namespace test {
+// NOTE(review): TestRecursiveType ("test.recursive") appears to be copied from
+// the 'test' dialect and is not among the types registered from
+// TestParametricTypeDefs.td; TODO confirm it is needed here.
+
+/// Storage for simple named recursive types, where the type is identified by
+/// its name and can "contain" another type, including itself.
+struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
+  using KeyTy = ::llvm::StringRef;
+
+  explicit TestRecursiveTypeStorage(::llvm::StringRef key) : name(key) {}
+
+  bool operator==(const KeyTy &other) const { return name == other; }
+
+  /// Allocates the storage in `allocator`, copying the name into it.
+  static TestRecursiveTypeStorage *
+  construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {
+    return new (allocator.allocate<TestRecursiveTypeStorage>())
+        TestRecursiveTypeStorage(allocator.copyInto(key));
+  }
+
+  ::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator,
+                               ::mlir::Type newBody) {
+    // Cannot set a different body than before.
+    if (body && body != newBody)
+      return ::mlir::failure();
+
+    body = newBody;
+    return ::mlir::success();
+  }
+
+  // Uniquing key of the type.
+  ::llvm::StringRef name;
+  // Null until set through mutate(); afterwards it can only be re-set to the
+  // same type.
+  ::mlir::Type body;
+};
+
+/// Simple recursive type identified by its name and pointing to another named
+/// type, potentially itself. This requires the body to be mutated separately
+/// from type creation.
+class TestRecursiveType
+    : public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
+                                    TestRecursiveTypeStorage,
+                                    ::mlir::TypeTrait::IsMutable> {
+public:
+  using Base::Base;
+
+  static constexpr ::mlir::StringLiteral name = "test.recursive";
+
+  static TestRecursiveType get(::mlir::MLIRContext *ctx,
+                               ::llvm::StringRef name) {
+    return Base::get(ctx, name);
+  }
+
+  /// Body getter and setter.
+  ::mlir::LogicalResult setBody(Type body) { return Base::mutate(body); }
+  ::mlir::Type getBody() const { return getImpl()->body; }
+
+  /// Name/key getter.
+  ::llvm::StringRef getName() { return getImpl()->name; }
+};
+
+} // namespace test
+
+#define GET_TYPEDEF_CLASSES
+#include "TestParametricTypeDefs.h.inc"
+
+#endif // MLIR_TESTPARAMETRICTYPES_H
diff --git a/mlir/test/lib/Dialect/TestParametric/lit.local.cfg b/mlir/test/lib/Dialect/TestParametric/lit.local.cfg
new file mode 100644
index 00000000000000..65a7f202dc82a9
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestParametric/lit.local.cfg
@@ -0,0 +1 @@
+config.suffixes.remove(".td")
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 2a3a8608db5442..038064417bf52d 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_library(MLIRTestTransforms
   TestControlFlowSink.cpp
   TestInlining.cpp
   TestIntRangeInference.cpp
+  TestParametricSpecialization.cpp
   TestMakeIsolatedFromAbove.cpp
   TestTopologicalSort.cpp
   ${MLIRTestTransformsPDLSrc}
diff --git a/mlir/test/lib/Transforms/TestParametricSpecialization.cpp b/mlir/test/lib/Transforms/TestParametricSpecialization.cpp
new file mode 100644
index 00000000000000..157d4b044c5038
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestParametricSpecialization.cpp
@@ -0,0 +1,191 @@
+//===- TestParametricSpecialization.cpp - Specialization pass ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Threading.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/ParametricSpecializationOpInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/ParametricSpecialization.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include <mutex>
+#include <utility>
+
+#define DEBUG_TYPE "parametric-specialization"
+
+using namespace mlir;
+
+namespace {
+
+/// One unit of work for the specialization pass: a single (target, meta args)
+/// pair, keyed in the pass by the target's mangled name, plus every call site
+/// that asked for it.
+struct SpecializingRequest {
+  /// Op to specialize
+  ParametricOpInterface targetOp;
+  /// The arguments to specialize it with.
+  DictionaryAttr metaArgs;
+  /// The "callers" to update
+  SmallVector<SpecializingOpInterface, 0> callers;
+  /// The operation post-specialization (a detached clone of `targetOp`, owned
+  /// here until the pass inserts it into the symbol table).
+  OwningOpRef<ParametricOpInterface> specialized;
+};
+
+struct TestParametricSpecializationPass
+    : public PassWrapper<TestParametricSpecializationPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestParametricSpecializationPass)
+
+  StringRef getArgument() const final {
+    return "test-parametric-specialization";
+  }
+  StringRef getDescription() const final {
+    return "Test the parametric specialization of parametric programs.";
+  }
+
+  /// Specializes every parametric symbol referenced by a
+  /// SpecializingOpInterface op nested under a non-parametric top-level op:
+  ///  1. collect one request per mangled name (in parallel over the roots),
+  ///  2. clone and specialize each requested target (in parallel),
+  ///  3. insert the specializations into the symbol table (sequentially),
+  ///  4. retarget the callers to them (in parallel).
+  void runOnOperation() override {
+    Operation *symbolTableOp = getOperation();
+    if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
+      symbolTableOp->emitOpError()
+          << getArgument()
+          << " pass can only run on an operation that defines a SymbolTable";
+      // Constructing the SymbolTable below requires this trait: stop here.
+      signalPassFailure();
+      return;
+    }
+    SymbolTable symTab(symbolTableOp);
+    MLIRContext &ctx = getContext();
+
+    // Walk the body of the symbol table op, and find "roots": operations that
+    // are already specialized. We'll use these as "roots" to specialize the
+    // parametric ones.
+    SmallVector<Operation *> rootOps;
+    for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps())
+      if (!isa<ParametricOpInterface>(nestedOp))
+        rootOps.push_back(&nestedOp);
+
+    // Keyed by mangled name; the ordered map makes the insertion order of the
+    // specializations in phase 3 deterministic.
+    std::map<StringRef, SpecializingRequest> specializationRequests;
+    std::mutex tasksMutex;
+
+    // Phase 1: run in parallel on every root, and for each, walk the body and
+    // find "calls" to functions that need specialization.
+    LogicalResult result =
+        failableParallelForEach(&ctx, rootOps, [&](Operation *root) {
+          WalkResult walkResult = root->walk([&](Operation *innerOp) {
+            auto specializingOp = dyn_cast<SpecializingOpInterface>(innerOp);
+            if (!specializingOp)
+              return WalkResult::advance();
+            auto targetNameAttr = specializingOp.getTarget();
+            auto targetOp = symTab.lookup(targetNameAttr.getRootReference());
+            if (!targetOp) {
+              innerOp->emitOpError() << "can't find target '" << targetNameAttr
+                                     << "' in SymbolTable";
+              return WalkResult::interrupt();
+            }
+            auto parametricTargetOp = dyn_cast<ParametricOpInterface>(targetOp);
+            if (!parametricTargetOp) {
+              auto diag = targetOp->emitOpError();
+              diag << "expected target to implement 'ParametricOpInterface'";
+              diag.attachNote() << "while specializing " << *innerOp;
+              return WalkResult::interrupt();
+            }
+            auto metaArgs = specializingOp.getMetaArgs();
+            auto failureOrMangledName =
+                parametricTargetOp.getMangledName(metaArgs);
+            if (failed(failureOrMangledName)) {
+              parametricTargetOp->emitOpError()
+                  << "failed to mangled with meta args " << metaArgs;
+              return WalkResult::interrupt();
+            }
+            StringAttr mangledName = *failureOrMangledName;
+            std::unique_lock<std::mutex> lock(tasksMutex);
+            SpecializingRequest &request =
+                specializationRequests[mangledName.getValue()];
+            // A mangled name must denote a single (target, meta args) pair.
+            // Comparing the meta args too catches distinct args that mangle
+            // alike (e.g. `32 : i32` vs `32 : i64`), which would otherwise
+            // silently overwrite the earlier request's meta args.
+            bool targetClash = request.targetOp &&
+                               request.targetOp.getOperation() != targetOp;
+            bool metaArgsClash =
+                request.metaArgs && request.metaArgs != metaArgs;
+            if (targetClash || metaArgsClash) {
+              auto diag = targetOp->emitOpError();
+              diag << "unexpected mangling collision while specializing with "
+                      "meta args "
+                   << metaArgs << ", mangled name " << mangledName;
+              diag.attachNote() << "while specializing " << *innerOp;
+              return WalkResult::interrupt();
+            }
+            request.targetOp = parametricTargetOp;
+            request.metaArgs = metaArgs;
+            request.callers.push_back(specializingOp);
+            LLVM_DEBUG(llvm::dbgs() << "Request for " << mangledName << "\n");
+            return WalkResult::advance();
+          });
+          return success(!walkResult.wasInterrupted());
+        });
+    if (failed(result)) {
+      signalPassFailure();
+      return;
+    }
+    LLVM_DEBUG(llvm::dbgs() << "Got " << specializationRequests.size()
+                            << " requests\n");
+
+    // Phase 2: clone and specialize every requested target in parallel. The
+    // clones stay detached from the IR until phase 3.
+    result = failableParallelForEach(
+        &ctx, specializationRequests,
+        [&](std::pair<const llvm::StringRef, SpecializingRequest> &entry) {
+          SpecializingRequest &request = entry.second;
+          ParametricOpInterface targetOp = request.targetOp;
+          DictionaryAttr metaArgs = request.metaArgs;
+          OwningOpRef<ParametricOpInterface> specializedOp(targetOp.clone());
+          if (failed(specializedOp->specialize(metaArgs))) {
+            std::unique_lock<std::mutex> lock(tasksMutex);
+            targetOp->emitOpError() << "failed to specialize with " << metaArgs;
+            return failure();
+          }
+          request.specialized = std::move(specializedOp);
+          return success();
+        });
+    if (failed(result)) {
+      signalPassFailure();
+      return;
+    }
+
+    // Phase 3: insert the specializations. This mutates the symbol table
+    // region, so it runs sequentially. SymbolTable::insert renames an op whose
+    // name is already taken, so callers are retargeted to the final name.
+    SmallVector<std::pair<SpecializingRequest *, SymbolOpInterface>> inserted;
+    inserted.reserve(specializationRequests.size());
+    for (auto &entry : specializationRequests) {
+      SpecializingRequest &request = entry.second;
+      // Hand ownership over to the symbol table region before inserting.
+      Operation *specialized = request.specialized.release();
+      symTab.insert(specialized);
+      auto specializedSym = cast<SymbolOpInterface>(specialized);
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Inserted " << specializedSym.getName() << "\n");
+      inserted.emplace_back(&request, specializedSym);
+    }
+
+    // Phase 4: point every caller at its specialization. Unlike using
+    // MLIRContext::getThreadPool() directly, failableParallelForEach also
+    // works when multithreading is disabled on the context.
+    result = failableParallelForEach(&ctx, inserted, [&](auto &item) {
+      LogicalResult status = success();
+      for (SpecializingOpInterface caller : item.first->callers) {
+        if (failed(caller.setSpecializedTarget(item.second))) {
+          caller->emitOpError() << "failed to specialize\n";
+          status = failure();
+        }
+      }
+      return status;
+    });
+    if (failed(result))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+namespace mlir::test {
+/// Makes the `test-parametric-specialization` pass available in the global
+/// pass registry (called from mlir-opt's test-pass registration).
+void registerTestParametricSpecializationPass() {
+  PassRegistration<TestParametricSpecializationPass> registration;
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-lsp-server/CMakeLists.txt b/mlir/tools/mlir-lsp-server/CMakeLists.txt
index 0134b54eef1b07..6480056c66e2b8 100644
--- a/mlir/tools/mlir-lsp-server/CMakeLists.txt
+++ b/mlir/tools/mlir-lsp-server/CMakeLists.txt
@@ -18,6 +18,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTestAnalysis
     MLIRTestDialect
     MLIRTestDynDialect
+    MLIRTestParametricDialect
     MLIRTestIR
     MLIRTestPass
     MLIRTestReducer
diff --git a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp
index f0ecc5adc68b36..f64e278237d593 100644
--- a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp
+++ b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp
@@ -20,6 +20,9 @@ void registerTestDialect(DialectRegistry &);
 void registerTestDynDialect(DialectRegistry &);
 void registerTestTransformDialectExtension(DialectRegistry &);
 } // namespace test
+namespace testparametric {
+void registerTestParametricDialect(DialectRegistry &);
+} // namespace testparametric
 #endif
 
 int main(int argc, char **argv) {
@@ -31,6 +34,7 @@ int main(int argc, char **argv) {
   ::test::registerTestDialect(registry);
   ::test::registerTestTransformDialectExtension(registry);
   ::test::registerTestDynDialect(registry);
+  ::testparametric::registerTestParametricDialect(registry);
 #endif
   return failed(MlirLspServerMain(argc, argv, registry));
 }
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 9ad5b32c24f9de..176152f4fd0f4b 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -35,6 +35,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTestAnalysis
     MLIRTestDialect
     MLIRTestDynDialect
+    MLIRTestParametricDialect
     MLIRTestIR
     MLIRTestOneToNTypeConversionPass
     MLIRTestPass
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 428bdd9691e095..54aaf1072ba163 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -125,6 +125,7 @@ void registerTestNextAccessPass();
 void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
 void registerTestPadFusion();
+void registerTestParametricSpecializationPass();
 void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
@@ -154,6 +155,9 @@ void registerTestDynDialect(DialectRegistry &);
 void registerTestTilingInterfaceTransformDialectExtension(DialectRegistry &);
 void registerTestTransformDialectExtension(DialectRegistry &);
 } // namespace test
+namespace testparametric {
+void registerTestParametricDialect(DialectRegistry &);
+} // namespace testparametric
 
 #ifdef MLIR_INCLUDE_TESTS
 void registerTestPasses() {
@@ -248,6 +252,7 @@ void registerTestPasses() {
   mlir::test::registerTestOneToNTypeConversionPass();
   mlir::test::registerTestOpaqueLoc();
   mlir::test::registerTestPadFusion();
+  mlir::test::registerTestParametricSpecializationPass();
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSCFWhileOpBuilderPass();
@@ -293,6 +298,7 @@ int main(int argc, char **argv) {
   ::test::registerTestTransformDialectExtension(registry);
   ::test::registerTestTilingInterfaceTransformDialectExtension(registry);
   ::test::registerTestDynDialect(registry);
+  ::testparametric::registerTestParametricDialect(registry);
 #endif
   return mlir::asMainReturnCode(mlir::MlirOptMain(
       argc, argv, "MLIR modular optimizer driver\n", registry));



More information about the Mlir-commits mailing list