[Mlir-commits] [mlir] 43b9fa3 - [mlir][Linalg][Python] Create the body of builtin named Linalg ops
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Mar 31 01:01:06 PDT 2021
Author: Nicolas Vasilache
Date: 2021-03-31T07:58:32Z
New Revision: 43b9fa3ce0ddfa673158af1596c3aac613b258b3
URL: https://github.com/llvm/llvm-project/commit/43b9fa3ce0ddfa673158af1596c3aac613b258b3
DIFF: https://github.com/llvm/llvm-project/commit/43b9fa3ce0ddfa673158af1596c3aac613b258b3.diff
LOG: [mlir][Linalg][Python] Create the body of builtin named Linalg ops
This revision adds support for properly creating the body of registered
builtin named Linalg ops.
At this time, indexing_map and iterator_type support is still
missing, so the op is not yet executable.
Differential Revision: https://reviews.llvm.org/D99578
Added:
mlir/lib/Bindings/Python/DialectLinalg.cpp
mlir/lib/Bindings/Python/DialectLinalg.h
Modified:
mlir/include/mlir-c/Dialect/Linalg.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
mlir/lib/Bindings/Python/CMakeLists.txt
mlir/lib/Bindings/Python/MainModule.cpp
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/lib/CAPI/Dialect/Linalg.cpp
mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
mlir/test/Bindings/Python/dialects/linalg/ops.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index be73a5c8c207a..06f15f062c333 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -17,6 +17,11 @@
extern "C" {
#endif
+/// Create the body of the builtin named Linalg op `op` and fill it with the
+/// region builder registered in `linalgDialect`. Asserts `op` is such an op.
+MLIR_CAPI_EXPORTED void
+mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op);
+
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
#ifdef __cplusplus
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 5a906ff2dafdf..007cb6de12f60 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -37,6 +37,14 @@ def Linalg_Dialect : Dialect {
let dependentDialects = [
"AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
];
+ let extraClassDeclaration = [{
+ using RegionBuilderFunType = llvm::function_ref<void(Block &, ValueRange)>;
+ RegionBuilderFunType getRegionBuilder(StringRef name) {
+ return namedStructuredOpRegionBuilders.lookup(name);
+ }
+ private:
+ llvm::StringMap<RegionBuilderFunType> namedStructuredOpRegionBuilders;
+ }];
}
// Whether a type is a RangeType.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
index 71ac601977fa6..d94e43b78edf1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Types.h"
+#include "llvm/ADT/StringMap.h"
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
index 43d6275d4d20c..39192cc54d3c2 100644
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -69,6 +69,7 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir
INSTALL_DIR
python
SOURCES
+ DialectLinalg.cpp
MainModule.cpp
IRAffine.cpp
IRAttributes.cpp
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
new file mode 100644
index 0000000000000..e4ef69411be8a
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -0,0 +1,34 @@
+//===- DialectLinalg.cpp - Pybind module for Linalg dialect API support --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRModule.h"
+#include "mlir-c/Dialect/Linalg.h"
+#include "mlir-c/IR.h"
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+
+namespace mlir {
+namespace python {
+
+void populateDialectLinalgSubmodule(py::module &m) {
+  // Python-visible `fill_builtin_region(dialect, op)`; forwards to the C API.
+  auto fillBuiltinRegion = [](PyDialectDescriptor &dialect, PyOperation &op) {
+    mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get());
+  };
+  m.def("fill_builtin_region", fillBuiltinRegion, py::arg("dialect"),
+        py::arg("op"),
+        "Fill the region for `op`, which is assumed to be a builtin named Linalg "
+        "op.");
+}
+
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.h b/mlir/lib/Bindings/Python/DialectLinalg.h
new file mode 100644
index 0000000000000..3735dbf6f6286
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectLinalg.h
@@ -0,0 +1,22 @@
+//===- DialectLinalg.h - Linalg dialect submodule of pybind module --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
+#define MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
+
+#include "PybindUtils.h"
+
+namespace mlir {
+namespace python {
+
+void populateDialectLinalgSubmodule(pybind11::module &m);
+
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 5fe0401afaeb6..79128f2677a5c 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -10,6 +10,7 @@
#include "PybindUtils.h"
+#include "DialectLinalg.h"
#include "ExecutionEngine.h"
#include "Globals.h"
#include "IRModule.h"
@@ -225,4 +226,9 @@ PYBIND11_MODULE(_mlir, m) {
auto executionEngineModule =
m.def_submodule("execution_engine", "MLIR JIT Execution Engine");
populateExecutionEngineSubmodule(executionEngineModule);
+
+ // Define and populate Linalg submodule.
+ auto dialectsModule = m.def_submodule("dialects");
+ auto linalgModule = dialectsModule.def_submodule("linalg");
+ populateDialectLinalgSubmodule(linalgModule);
}
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
index d6dc9895f89a8..002ae51ba1b04 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -61,11 +61,10 @@ def __call__(self, *args, emit_generic: bool = False, **kwargs):
raise NotImplementedError(
f"Emission of composite linalg ops not supported: {op_configs}")
- # TODO: this file should probably not be called dsl.py but rather is a client
- # of the dsl.py.
- from .... import linalg as linalg_ops
- emit_generic = (emit_generic or
- (not self.model.metadata.cpp_class_name in linalg_ops.__dict__.keys()))
+ ctx = ir.Context.current
+ linalgDialect = ctx.get_dialect_descriptor("linalg")
+ fully_qualified_name = 'linalg.' + self.op_name
+ emit_generic = (emit_generic or not ctx.is_registered_operation(fully_qualified_name))
op_config = op_configs[0]
if op_config.structured_op:
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
index e8e7eb5c3463e..2395a422e354b 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -7,6 +7,9 @@
from mlir.ir import *
from mlir.dialects import linalg
from mlir.dialects import std
+# TODO: resolve name collision for Linalg functionality that is injected inside
+# the _mlir.dialects.linalg directly via pybind.
+from _mlir.dialects.linalg import fill_builtin_region
from .scalar_expr import *
from .config import *
@@ -16,7 +19,6 @@
"emit_named_structured_op",
]
-
def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
*ins: Value,
outs: Value):
@@ -97,11 +99,18 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
type_mapping, indexing_maps_attr, iterator_types_attr = \
prepare_common_structured_op(op_config, *ins, outs = outs)
- if not op_class_name in linalg.__dict__.keys():
+ # If we get here, there must exist a builtin class `op_class_name`.
+ ctx = Context.current
+ fully_qualified_name = 'linalg.' + op_name
+ if (not ctx.is_registered_operation(fully_qualified_name) or
+ not op_class_name in linalg.__dict__.keys()):
raise NotImplementedError(
f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")
named_op = getattr(linalg, op_class_name)(ins, outs, out_types)
+ linalgDialect = ctx.get_dialect_descriptor("linalg")
+ fill_builtin_region(linalgDialect, named_op.operation)
+
if len(out_arg_defs) == 1:
return named_op.result
else:
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index da6fd4846bd67..1c50aa612cd31 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -10,5 +10,30 @@
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg,
- mlir::linalg::LinalgDialect)
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Create the entry block of builtin named Linalg op `op` and fill it with the
+/// region builder registered in `linalgDialect`; asserts `op` is such an op.
+void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect,
+                                        MlirOperation mlirOp) {
+  Operation *op = unwrap(mlirOp);
+  auto *dialect = static_cast<LinalgDialect *>(unwrap(linalgDialect));
+  auto fun = dialect->getRegionBuilder(op->getName().getStringRef());
+  assert(fun && "Expected a builtin named Linalg op.");
+  assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region");
+  assert(op->getRegion(0).getBlocks().empty() &&
+         "Expected Linalg op with 0 blocks");
+  // One block argument per shaped operand, typed by its element type.
+  SmallVector<Type, 8> argTypes;
+  auto linalgOp = cast<LinalgOp>(op);
+  for (auto t : linalgOp.getShapedOperandTypes())
+    argTypes.push_back(getElementTypeOrSelf(t));
+  OpBuilder b(op->getContext());
+  Region &region = op->getRegion(0);
+  Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes);
+  // TODO: allow captures; no captured values are passed for now.
+  fun(*body, ValueRange{});
+}
+
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index 8cd2d4f833a7d..2288f734dcb5b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -57,6 +57,38 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {
// LinalgDialect
//===----------------------------------------------------------------------===//
+/// Trait detecting whether `T::regionBuilder` names a valid member of `T`.
+template <typename T>
+using has_region_builder = decltype(T::regionBuilder);
+template <typename T>
+using detect_has_region_builder = llvm::is_detected<has_region_builder, T>;
+
+/// No-op overload for an `OpType` without a `regionBuilder` (e.g. an
+/// OpInterface): nothing is registered for it.
+template <typename OpType, typename = std::enable_if_t<
+                               !detect_has_region_builder<OpType>::value>>
+void addNamedOpBuilderImpl(
+    llvm::StringMap<LinalgDialect::RegionBuilderFunType> &) {}
+
+/// Registers `OpType::regionBuilder` in `map` under the op's operation name.
+template <typename OpType,
+          typename = std::enable_if_t<detect_has_region_builder<OpType>::value>,
+          typename = void> // Keeps this signature distinct from the no-op one.
+void addNamedOpBuilderImpl(
+    llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
+  map.insert(std::make_pair(
+      OpType::getOperationName(),
+      static_cast<LinalgDialect::RegionBuilderFunType>(OpType::regionBuilder)));
+}
+
+/// Calls addNamedOpBuilderImpl for each of `OpTypes`, in order.
+template <typename... OpTypes>
+void addNamedOpBuilders(
+    llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
+  (void)std::initializer_list<int>{0,
+                                   (addNamedOpBuilderImpl<OpTypes>(map), 0)...};
+}
+
void mlir::linalg::LinalgDialect::initialize() {
addTypes<RangeType>();
addOperations<
@@ -72,6 +104,12 @@ void mlir::linalg::LinalgDialect::initialize() {
#include "mlir/Dialect/Linalg/IR/LinalgSparseOps.cpp.inc"
>();
+ // Fill the Linalg-specific OpName to RegionBuilder map.
+ addNamedOpBuilders<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+ >(namedStructuredOpRegionBuilders);
+
addInterfaces<LinalgInlinerInterface>();
}
diff --git a/mlir/test/Bindings/Python/dialects/linalg/ops.py b/mlir/test/Bindings/Python/dialects/linalg/ops.py
index 8f2eb06004cee..489aa63f47fdf 100644
--- a/mlir/test/Bindings/Python/dialects/linalg/ops.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/ops.py
@@ -5,7 +5,6 @@
from mlir.dialects import linalg
from mlir.dialects import std
-
def run(f):
print("\nTEST:", f.__name__)
f()
@@ -82,9 +81,9 @@ def testStructuredOpOnBuffers():
# CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
print(module)
-# CHECK-LABEL: TEST: testNamedStructuredOp
+# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
@run
-def testNamedStructuredOp():
+def testNamedStructuredOpCustomForm():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
@@ -93,10 +92,45 @@ def testNamedStructuredOp():
RankedTensorType.get((16, 8), f32))
def named_form(lhs, rhs):
init_result = linalg.InitTensorOp([4, 8], f32)
- # CHECK: linalg.matmul
- # TODO: prperly hook up the region.
+ # First check the named form with custom format
+ # CHECK: linalg.matmul
+ # CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
+ # CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>)
+ # CHECK-SAME: -> tensor<4x8xf32>
+ # CHECK-NEXT: return
return linalg.matmul(lhs, rhs, outs=[init_result.result])
+ print(module)
+
+# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
+@run
+def testNamedStructuredOpGenericForm():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+      @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
+                                   RankedTensorType.get((16, 8), f32))
+      def named_form(lhs, rhs):
+        init_result = linalg.InitTensorOp([4, 8], f32)
+        # CHECK: "linalg.matmul"(%{{.*}})
+        # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
+        # CHECK-NEXT: std.mulf{{.*}} (f32, f32) -> f32
+        # CHECK-NEXT: std.addf{{.*}} (f32, f32) -> f32
+        # CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
+        # CHECK-NEXT: {operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
+        # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
+        return linalg.matmul(lhs, rhs, outs=[init_result.result])
+
+    module.operation.print(print_generic_op_form=True)
+
+# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
+@run
+def testNamedStructuredAsGenericOp():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
RankedTensorType.get((16, 8), f32))
def generic_form(lhs, rhs):
More information about the Mlir-commits
mailing list