[Mlir-commits] [mlir] 3a3a09f - [mlir][python] Provide more convenient wrappers for std.ConstantOp

Alex Zinenko llvmlistbot at llvm.org
Mon Oct 4 02:45:37 PDT 2021


Author: Alex Zinenko
Date: 2021-10-04T11:45:27+02:00
New Revision: 3a3a09f65412dc38aba6b7370b93f9d2c7fd1c30

URL: https://github.com/llvm/llvm-project/commit/3a3a09f65412dc38aba6b7370b93f9d2c7fd1c30
DIFF: https://github.com/llvm/llvm-project/commit/3a3a09f65412dc38aba6b7370b93f9d2c7fd1c30.diff

LOG: [mlir][python] Provide more convenient wrappers for std.ConstantOp

Constructing a ConstantOp using the default-generated API is verbose and
requires specifying the constant type twice: for the result type of the
operation and for the type of the attribute. It also requires explicitly
constructing the attribute. Provide custom constructors that take the type once
and accept a raw value instead of the attribute. This requires dynamic dispatch
based on type in the constructor. Also provide the corresponding accessors for
raw values.

In addition, provide a "refinement" class ConstantIndexOp similar to what
exists in C++. Unlike other "op view" Python classes, operations cannot be
automatically downcast to this class since it does not correspond to a
specific operation name. It only exists to simplify construction of the
operation. (In the diff below, this is implemented as the
`ConstantOp.create_index` class method rather than as a separate class.)

Depends On D110946

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D110947

Added: 
    mlir/python/mlir/dialects/_std_ops_ext.py
    mlir/test/python/dialects/std.py

Modified: 
    mlir/python/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index eb7e1e40d3f18..4f0d1548e8793 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -136,7 +136,9 @@ declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/StandardOps.td
-  SOURCES dialects/std.py
+  SOURCES
+    dialects/std.py
+    dialects/_std_ops_ext.py
   DIALECT_NAME std)
 
 declare_mlir_dialect_python_bindings(

diff  --git a/mlir/python/mlir/dialects/_std_ops_ext.py b/mlir/python/mlir/dialects/_std_ops_ext.py
new file mode 100644
index 0000000000000..bb67fe44d3b59
--- /dev/null
+++ b/mlir/python/mlir/dialects/_std_ops_ext.py
@@ -0,0 +1,71 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+  from ..ir import *
+  from .builtin import FuncOp
+  from ._ods_common import get_default_loc_context as _get_default_loc_context
+
+  from typing import Any, List, Optional, Union
+except ImportError as e:
+  raise RuntimeError("Error loading imports from extension module") from e
+
+
+def _isa(obj: Any, cls: type):
+  try:
+    cls(obj)
+  except ValueError:
+    return False
+  return True
+
+
+def _is_any_of(obj: Any, classes: List[type]):
+  return any(_isa(obj, cls) for cls in classes)
+
+
+def _is_integer_like_type(type: Type):
+  return _is_any_of(type, [IntegerType, IndexType])
+
+
+def _is_float_type(type: Type):
+  return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
+
+
+class ConstantOp:
+  """Specialization for the constant op class."""
+
+  def __init__(self,
+               result: Type,
+               value: Union[int, float, Attribute],
+               *,
+               loc=None,
+               ip=None):
+    if isinstance(value, int):
+      super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip)
+    elif isinstance(value, float):
+      super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip)
+    else:
+      super().__init__(result, value, loc=loc, ip=ip)
+
+  @classmethod
+  def create_index(cls, value: int, *, loc=None, ip=None):
+    """Create an index-typed constant."""
+    return cls(
+        IndexType.get(context=_get_default_loc_context(loc)),
+        value,
+        loc=loc,
+        ip=ip)
+
+  @property
+  def type(self):
+    return self.results[0].type
+
+  @property
+  def literal_value(self) -> Union[int, float]:
+    if _is_integer_like_type(self.type):
+      return IntegerAttr(self.value).value
+    elif _is_float_type(self.type):
+      return FloatAttr(self.value).value
+    else:
+      raise ValueError("only integer and float constants have literal values")

diff  --git a/mlir/test/python/dialects/std.py b/mlir/test/python/dialects/std.py
new file mode 100644
index 0000000000000..ed507d664de69
--- /dev/null
+++ b/mlir/test/python/dialects/std.py
@@ -0,0 +1,64 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import std
+
+
+def constructAndPrintInModule(f):
+  print("\nTEST:", f.__name__)
+  with Context(), Location.unknown():
+    module = Module.create()
+    with InsertionPoint(module.body):
+      f()
+    print(module)
+  return f
+
+# CHECK-LABEL: TEST: testConstantOp
+
+ at constructAndPrintInModule
+def testConstantOp():
+  c1 = std.ConstantOp(IntegerType.get_signless(32), 42)
+  c2 = std.ConstantOp(IntegerType.get_signless(64), 100)
+  c3 = std.ConstantOp(F32Type.get(), 3.14)
+  c4 = std.ConstantOp(F64Type.get(), 1.23)
+  # CHECK: 42
+  print(c1.literal_value)
+
+  # CHECK: 100
+  print(c2.literal_value)
+
+  # CHECK: 3.140000104904175
+  print(c3.literal_value)
+
+  # CHECK: 1.23
+  print(c4.literal_value)
+
+# CHECK: = constant 42 : i32
+# CHECK: = constant 100 : i64
+# CHECK: = constant 3.140000e+00 : f32
+# CHECK: = constant 1.230000e+00 : f64
+
+# CHECK-LABEL: TEST: testVectorConstantOp
+ at constructAndPrintInModule
+def testVectorConstantOp():
+  int_type = IntegerType.get_signless(32)
+  vec_type = VectorType.get([2, 2], int_type)
+  c1 = std.ConstantOp(vec_type,
+                      DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42)))
+  try:
+    print(c1.literal_value)
+  except ValueError as e:
+    assert "only integer and float constants have literal values" in str(e)
+  else:
+    assert False
+
+# CHECK: = constant dense<42> : vector<2x2xi32>
+
+# CHECK-LABEL: TEST: testConstantIndexOp
+ at constructAndPrintInModule
+def testConstantIndexOp():
+  c1 = std.ConstantOp.create_index(10)
+  # CHECK: 10
+  print(c1.literal_value)
+
+# CHECK: = constant 10 : index


        


More information about the Mlir-commits mailing list