[Mlir-commits] [mlir] b5f3a12 - [mlir][Python][Linalg] Add support for captures in body builder.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Apr 16 01:48:01 PDT 2021


Author: Nicolas Vasilache
Date: 2021-04-16T08:47:26Z
New Revision: b5f3a128bf8cae46ccf0616477a4775fd168fd7c

URL: https://github.com/llvm/llvm-project/commit/b5f3a128bf8cae46ccf0616477a4775fd168fd7c
DIFF: https://github.com/llvm/llvm-project/commit/b5f3a128bf8cae46ccf0616477a4775fd168fd7c.diff

LOG: [mlir][Python][Linalg] Add support for captures in body builder.

When Linalg named ops support was added, captures were omitted
from the body builder. This revision adds support for captures,
which allows us to write FillOp in a more idiomatic fashion using
the `_linalg_ops_ext` mixin support.

This exposes an issue in the generated `_linalg_ops_gen.py`, where the
`result` property is emitted as:
```
  @property
  def result(self):
    return self.operation.results[0] if len(self.operation.results) > 1 else None
```
The condition should be `== 1`; as written, `result` is None for an op
with exactly one result.

This will be fixed in a separate commit.

Differential Revision: https://reviews.llvm.org/D100363

Added: 
    

Modified: 
    mlir/include/mlir-c/Dialect/Linalg.h
    mlir/lib/Bindings/Python/DialectLinalg.cpp
    mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
    mlir/lib/CAPI/Dialect/Linalg.cpp
    mlir/test/Bindings/Python/dialects/linalg/ops.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index 06f15f062c333..6e20eec16481a 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -1,11 +1,11 @@
-//===-- mlir-c/Dialect/Linalg.h - C API for Linalg dialect --------*- C -*-===//
+//===-- mlir-c/Dialect/Linalg.h - C API for Linalg dialect -------*- C -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
 // Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-//===----------------------------------------------------------------------===//
+//===---------------------------------------------------------------------===//
 
 #ifndef MLIR_C_DIALECT_LINALG_H
 #define MLIR_C_DIALECT_LINALG_H
@@ -18,9 +18,11 @@ extern "C" {
 #endif
 
 /// Apply the special region builder for the builtin named Linalg op.
+/// The list of `capture` MlirValue is passed as-is to the region builder.
 /// Assert that `op` is a builtin named Linalg op.
 MLIR_CAPI_EXPORTED void
-mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op);
+mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op,
+                                   intptr_t n, MlirValue const *mlirCaptures);
 
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
 

diff  --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index e4ef69411be8a..849a0039a3ccb 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -22,10 +22,15 @@ namespace python {
 void populateDialectLinalgSubmodule(py::module &m) {
   m.def(
       "fill_builtin_region",
-      [](PyDialectDescriptor &dialect, PyOperation &op) {
-        return mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get());
+      [](PyDialectDescriptor &dialect, PyOperation &op, py::list captures) {
+        llvm::SmallVector<MlirValue, 4> mlirOperands;
+        mlirOperands.reserve(captures.size());
+        for (auto v : captures)
+          mlirOperands.push_back(py::cast<PyValue *>(v)->get());
+        mlirLinalgFillBuiltinNamedOpRegion(
+            dialect.get(), op.get(), mlirOperands.size(), mlirOperands.data());
       },
-      py::arg("dialect"), py::arg("op"),
+      py::arg("dialect"), py::arg("op"), py::arg("captures") = py::list(),
       "Fill the region for `op`, which is assumed to be a builtin named Linalg "
       "op.");
 }

diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
index d787943d16372..4714e69b3e403 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
@@ -5,6 +5,47 @@
 from typing import Optional, Sequence, Union
 from ..ir import *
 from ._ods_common import get_default_loc_context
+# TODO: resolve name collision for Linalg functionality that is injected inside
+# the _mlir.dialects.linalg directly via pybind.
+from _mlir.dialects.linalg import fill_builtin_region
+
+
+def isa(cls : Type, ty : Type):
+  try:
+    cls(ty)
+    return True
+  except ValueError:
+    return False
+
+
+class FillOp:
+  """Extends the linalg.fill op."""
+
+  def __init__(self,
+               output: Value,
+               value: Value,
+               *,
+               loc=None,
+               ip=None):
+    results = []
+    if isa(RankedTensorType, output.type):
+      results = [output.type]
+    op = self.build_generic(results=results,
+                            operands=[output, value],
+                            attributes=None,
+                            loc=loc,
+                            ip=ip)
+    OpView.__init__(self, op)
+    linalgDialect = Context.current.get_dialect_descriptor("linalg")
+    fill_builtin_region(linalgDialect, self.operation, [value])
+    # TODO: self.result is None. When len(results) == 1 we expect it to be
+    # results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug
+    # in the generator of _linalg_ops_gen.py where we have:
+    # ```
+    # def result(self):
+    #   return self.operation.results[0] \
+    #     if len(self.operation.results) > 1 else None
+    # ```
 
 
 class InitTensorOp:

diff  --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 1c50aa612cd31..6f6e090d737a8 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir-c/Dialect/Linalg.h"
 #include "mlir/CAPI/Registration.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 
 using namespace mlir;
@@ -16,8 +17,14 @@ using namespace mlir::linalg;
 /// Apply the special region builder for the builtin named Linalg op.
 /// Assert that `op` is a builtin named Linalg op.
 void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect,
-                                        MlirOperation mlirOp) {
+                                        MlirOperation mlirOp, intptr_t n,
+                                        MlirValue const *mlirCaptures) {
   Operation *op = unwrap(mlirOp);
+  SmallVector<Value> captures;
+  captures.reserve(n);
+  for (unsigned idx = 0; idx < n; ++idx)
+    captures.push_back(unwrap(mlirCaptures[idx]));
+
   LinalgDialect::RegionBuilderFunType fun =
       static_cast<LinalgDialect *>(unwrap(linalgDialect))
           ->getRegionBuilder(op->getName().getStringRef());
@@ -25,15 +32,18 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect,
   assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region");
   assert(op->getRegion(0).getBlocks().empty() &&
          "Expected Linalg op with 0 blocks");
+
   SmallVector<Type, 8> argTypes;
   auto linalgOp = cast<LinalgOp>(op);
   for (auto t : linalgOp.getShapedOperandTypes())
     argTypes.push_back(getElementTypeOrSelf(t));
+
   OpBuilder b(op->getContext());
   Region &region = op->getRegion(0);
   Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes);
-  // TODO: allow captures.
-  fun(*body, ValueRange{});
+  b.setInsertionPointToStart(body);
+  mlir::edsc::ScopedContext scope(b, op->getLoc());
+  fun(*body, captures);
 }
 
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)

diff  --git a/mlir/test/Bindings/Python/dialects/linalg/ops.py b/mlir/test/Bindings/Python/dialects/linalg/ops.py
index afcb5820a2216..f153ecbb28768 100644
--- a/mlir/test/Bindings/Python/dialects/linalg/ops.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/ops.py
@@ -38,6 +38,40 @@ def zero_d():
 
   print(module)
 
+# CHECK-LABEL: TEST: testFill
+@run
+def testFill():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+      # CHECK-LABEL: func @fill_tensor
+      #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
+      #  CHECK-NEXT: %[[CST:.*]] = constant 0.0{{.*}} : f32
+      #  CHECK-NEXT: %[[RES:.*]] = linalg.fill(%[[OUT]], %[[CST]]) : tensor<12x?xf32>, f32 -> tensor<12x?xf32>
+      #  CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
+      @builtin.FuncOp.from_py_func(
+          RankedTensorType.get((12, -1), f32))
+      def fill_tensor(out):
+        zero = std.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
+        # TODO: FillOp.result is None. When len(results) == 1 we expect it to
+        # be results[0] as per _linalg_ops_gen.py. This seems like an
+        # orthogonal bug in the generator of _linalg_ops_gen.py.
+        return linalg.FillOp(output=out, value=zero).results[0]
+
+      # CHECK-LABEL: func @fill_buffer
+      #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
+      #  CHECK-NEXT: %[[CST:.*]] = constant 0.0{{.*}} : f32
+      #  CHECK-NEXT: linalg.fill(%[[OUT]], %[[CST]]) : memref<12x?xf32>, f32
+      #  CHECK-NEXT: return
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((12, -1), f32))
+      def fill_buffer(out):
+        zero = std.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
+        linalg.FillOp(output=out, value=zero)
+
+  print(module)
+
 
 # CHECK-LABEL: TEST: testStructuredOpOnTensors
 @run


        


More information about the Mlir-commits mailing list