[Mlir-commits] [mlir] af66cd1 - [mlir][Python] Context managers for Context, InsertionPoint, Location.

Stella Laurenzo llvmlistbot at llvm.org
Sun Nov 1 19:02:09 PST 2020


Author: Stella Laurenzo
Date: 2020-11-01T19:00:39-08:00
New Revision: af66cd173fe00b86d1c399b585738eaa01a65042

URL: https://github.com/llvm/llvm-project/commit/af66cd173fe00b86d1c399b585738eaa01a65042
DIFF: https://github.com/llvm/llvm-project/commit/af66cd173fe00b86d1c399b585738eaa01a65042.diff

LOG: [mlir][Python] Context managers for Context, InsertionPoint, Location.

* Finishes support for Context, InsertionPoint and Location to be carried by the thread using context managers.
* Introduces type casters and utilities so that DefaultingPyMlirContext and DefaultingPyLocation in method signatures do the right thing (accept an explicit value or fall back to the one carried by the thread context).
* Extends the rules for the thread context stack to handle nesting, inheriting or clearing the insertion point and location depending on whether the context is the same.
* Refactors all method signatures to follow the new convention on trailing parameters for defaulting parameters (loc, ip, context). When the objects are carried in the thread context, this allows most explicit uses of these values to be elided.
* Removes the style guide section on putting accessors that construct global objects on the PyMlirContext: this style makes poor use of the new facility, since constructing such objects is often the only remaining thing that needs an MlirContext.
* Moves Module parse/creation from mlir.ir.Context to static methods on mlir.ir.Module.
* Moves Context.create_operation to a static Operation.create method.
* Moves Type parsing from mlir.ir.Context to static methods on mlir.ir.Type.
* Moves Attribute parsing from mlir.ir.Context to static methods on mlir.ir.Attribute.
* Moves Location factory methods from mlir.ir.Context to static methods on mlir.ir.Location.
* Refactors the std dialect fake "ODS" generated code to take advantage of the new scheme.
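The nesting rule in the third bullet can be modelled in plain Python. The sketch below is hypothetical and is not the binding's C++ implementation: `Frame`, `ContextStack`, `push`, `pop` and `top` are illustrative names. A pushed frame whose context is the same object as the frame below it inherits any missing insertion point and location; a frame with a different context keeps only what it was given. `pop` refuses an exit that does not match both the kind and the object of the most recent enter.

```python
import threading
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class Frame:
    """Hypothetical stand-in for one entry of the thread context stack."""
    kind: str  # "context", "insertion_point" or "location"
    context: Any
    insertion_point: Any = None
    location: Any = None


class ContextStack(threading.local):
    """Hypothetical per-thread stack; each thread sees its own `frames` list."""

    def __init__(self) -> None:
        self.frames: List[Frame] = []

    def push(self, frame: Frame) -> None:
        # Inherit missing values only if the context is the very same object.
        if self.frames and self.frames[-1].context is frame.context:
            prev = self.frames[-1]
            if frame.insertion_point is None:
                frame.insertion_point = prev.insertion_point
            if frame.location is None:
                frame.location = prev.location
        self.frames.append(frame)

    def pop(self, kind: str, obj: Any) -> None:
        # An exit must match both the kind and the object of the top frame.
        if not self.frames:
            raise RuntimeError(f"Unbalanced {kind} enter/exit")
        top = self.frames[-1]
        current = {
            "context": top.context,
            "insertion_point": top.insertion_point,
            "location": top.location,
        }[kind]
        if top.kind != kind or current is not obj:
            raise RuntimeError(f"Unbalanced {kind} enter/exit")
        self.frames.pop()

    def top(self) -> Optional[Frame]:
        return self.frames[-1] if self.frames else None


# Usage: a location entered under a context, then an insertion point.
stack = ContextStack()
ctx_a, ctx_b = object(), object()
loc_a, ip_a = object(), object()
stack.push(Frame("context", ctx_a))
stack.push(Frame("location", ctx_a, location=loc_a))
stack.push(Frame("insertion_point", ctx_a, insertion_point=ip_a))
assert stack.top().location is loc_a  # inherited: same context
stack.push(Frame("context", ctx_b))
assert stack.top().location is None  # not inherited: different context
```

Subclassing `threading.local` gives each thread its own `frames` list, which is one simple way to get the per-thread behaviour the commit describes.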

Differential Revision: https://reviews.llvm.org/D90547

Added: 
    mlir/test/Bindings/Python/context_managers.py

Modified: 
    mlir/docs/Bindings/Python.md
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/lib/Bindings/Python/PybindUtils.h
    mlir/lib/Bindings/Python/mlir/dialects/std.py
    mlir/test/Bindings/Python/dialects.py
    mlir/test/Bindings/Python/insertion_point.py
    mlir/test/Bindings/Python/ir_array_attributes.py
    mlir/test/Bindings/Python/ir_attributes.py
    mlir/test/Bindings/Python/ir_location.py
    mlir/test/Bindings/Python/ir_module.py
    mlir/test/Bindings/Python/ir_operation.py
    mlir/test/Bindings/Python/ir_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index 782b46f503ea..e3aaae1f902a 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -110,47 +110,6 @@ future need:
 from _mlir import *
 ```
 
-### Limited use of globals
-
-For normal operations, parent-child constructor relationships are realized with
-constructor methods on a parent class as opposed to requiring
-invocation/creation from a global symbol.
-
-For example, consider two code fragments:
-
-```python
-
-op = build_my_op()
-
-region = mlir.Region(op)
-
-```
-
-vs
-
-```python
-
-op = build_my_op()
-
-region = op.new_region()
-
-```
-
-For tightly coupled data structures like `Operation`, the latter is generally
-preferred because:
-
-* It is syntactically less possible to create something that is going to access
-  illegal memory (less error handling in the bindings, less testing, etc).
-
-* It reduces the global-API surface area for creating related entities. This
-  makes it more likely that if constructing IR based on an Operation instance of
-  unknown providence, receiving code can just call methods on it to do what they
-  want versus needing to reach back into the global namespace and find the right
-  `Region` class.
-
-* It leaks fewer things that are in place for C++ convenience (i.e. default
-  constructors to invalid instances).
-
 ### Use the C-API
 
 The Python APIs should seek to layer on top of the C-API to the degree possible.
@@ -171,6 +130,20 @@ There are several top-level types in the core IR that are strongly owned by thei
 
 All other objects are dependent. All objects maintain a back-reference (keep-alive) to their closest containing top-level object. Further, dependent objects fall into two categories: a) uniqued (which live for the life-time of the context) and b) mutable. Mutable objects need additional machinery for keeping track of when the C++ instance that backs their Python object is no longer valid (typically due to some specific mutation of the IR, deletion, or bulk operation).
 
+### Optionality and argument ordering in the Core IR
+
+The following types support being bound to the current thread as a context manager:
+
+* `PyLocation` (`loc: mlir.ir.Location = None`)
+* `PyInsertionPoint` (`ip: mlir.ir.InsertionPoint = None`)
+* `PyMlirContext` (`context: mlir.ir.Context = None`)
+
+In order to support composability of function arguments, when these types appear as arguments they should always come last, appear in the above order (including only those that are needed) and use the given names; this is generally the order in which they are expected to need explicit values in special cases. Each should carry a default value of `py::none()` and use either a manual or automatic conversion that resolves to the explicit value when one is given, or else to the value from the thread context manager (e.g. `DefaultingPyMlirContext` or `DefaultingPyLocation`).
+
+The rationale for this is that in Python, trailing keyword arguments to the *right* are the most composable, enabling a variety of strategies such as kwarg passthrough, default values, etc. Keeping function signatures composable increases the chances that interesting DSLs and higher level APIs can be constructed without a lot of exotic boilerplate.
+
+Used consistently, this enables a style of IR construction that rarely needs to use explicit contexts, locations, or insertion points but is free to do so when extra control is needed.
+
 #### Operation hierarchy
 
 As mentioned above, `PyOperation` is special because it can exist in either a top-level or dependent state. The life-cycle is unidirectional: operations can be created detached (top-level) and once added to another operation, they are then dependent for the remainder of their lifetime. The situation is more complicated when considering construction scenarios where an operation is added to a transitive parent that is still detached, necessitating further accounting at such transition points (i.e. all such added children are initially added to the IR with a parent of their outer-most detached operation, but then once it is added to an attached operation, they need to be re-parented to the containing module).

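The argument-ordering convention described in the Python.md hunk above can be illustrated without the bindings themselves. In this hypothetical pure-Python sketch, `using_context`, `resolve_context`, `f32_attr` and `scaled_f32_attr` are illustrative names, loosely modelled on the `get_f32(value, context=None)` binding in the `PyFloatAttribute` hunk below; a plain dict stands in for an attribute. Because the defaulting `context` parameter is last and has a fixed name, a higher-level helper can forward it with `**kwargs` without knowing about it.

```python
import threading
from contextlib import contextmanager
from typing import Any, Dict, Iterator, Optional

_state = threading.local()  # hypothetical per-thread defaults


@contextmanager
def using_context(ctx: Any) -> Iterator[Any]:
    """Hypothetical stand-in for `with Context():`; restores the outer value."""
    outer = getattr(_state, "context", None)
    _state.context = ctx
    try:
        yield ctx
    finally:
        _state.context = outer


def resolve_context(context: Optional[Any]) -> Any:
    # An explicit argument wins; otherwise fall back to the thread default.
    if context is not None:
        return context
    ctx = getattr(_state, "context", None)
    if ctx is None:
        raise RuntimeError("A Context is required: pass context= or use 'with'")
    return ctx


def f32_attr(value: float, context: Optional[Any] = None) -> Dict[str, Any]:
    # The defaulting parameter comes last and is always named `context`.
    return {"type": "f32", "value": value, "context": resolve_context(context)}


def scaled_f32_attr(value: float, factor: float, **kwargs: Any) -> Dict[str, Any]:
    # A higher-level helper that forwards context= (or nothing) unchanged.
    return f32_attr(value * factor, **kwargs)


with using_context("ctx0"):
    attr = scaled_f32_attr(1.5, 2.0)  # context comes from the `with` block
```

Outside any `with` block and without an explicit `context=`, `resolve_context` raises, mirroring the `RuntimeError` that `DefaultingPyMlirContext::resolve()` reports in the C++ hunk.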
diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 5e0e45d0784d..8c17e8e6d933 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -28,23 +28,18 @@ using llvm::SmallVector;
 // Docstrings (trivial, non-duplicated docstrings are included inline).
 //------------------------------------------------------------------------------
 
-static const char kContextCreateOperationDocstring[] =
-    R"(Creates a new operation.
+static const char kContextParseTypeDocstring[] =
+    R"(Parses the assembly form of a type.
 
-Args:
-  name: Operation name (e.g. "dialect.operation").
-  location: A Location object.
-  results: Sequence of Type representing op result types.
-  attributes: Dict of str:Attribute.
-  successors: List of Block for the operation's successors.
-  regions: Number of regions to create.
+Returns a Type object or raises a ValueError if the type cannot be parsed.
 
-Returns:
-  A new "detached" Operation object. Detached operations can be added
-  to blocks, which causes them to become "attached."
+See also: https://mlir.llvm.org/docs/LangRef/#type-system
 )";
 
-static const char kContextParseDocstring[] =
+static const char kContextGetFileLocationDocstring[] =
+    R"(Gets a Location representing a file, line and column)";
+
+static const char kModuleParseDocstring[] =
     R"(Parses a module's assembly format from a string.
 
 Returns a new MlirModule or raises a ValueError if the parsing fails.
@@ -52,20 +47,24 @@ Returns a new MlirModule or raises a ValueError if the parsing fails.
 See also: https://mlir.llvm.org/docs/LangRef/
 )";
 
-static const char kContextParseTypeDocstring[] =
-    R"(Parses the assembly form of a type.
-
-Returns a Type object or raises a ValueError if the type cannot be parsed.
+static const char kOperationCreateDocstring[] =
+    R"(Creates a new operation.
 
-See also: https://mlir.llvm.org/docs/LangRef/#type-system
+Args:
+  name: Operation name (e.g. "dialect.operation").
+  results: Sequence of Type representing op result types.
+  attributes: Dict of str:Attribute.
+  successors: List of Block for the operation's successors.
+  regions: Number of regions to create.
+  location: A Location object (defaults to resolve from context manager).
+  ip: An InsertionPoint (defaults to resolve from context manager or set to
+    False to disable insertion, even with an insertion point set in the
+    context manager).
+Returns:
+  A new "detached" Operation object. Detached operations can be added
+  to blocks, which causes them to become "attached."
 )";
 
-static const char kContextGetUnknownLocationDocstring[] =
-    R"(Gets a Location representing an unknown location)";
-
-static const char kContextGetFileLocationDocstring[] =
-    R"(Gets a Location representing a file, line and column)";
-
 static const char kOperationPrintDocstring[] =
     R"(Prints the assembly form of the operation to a file like object.
 
@@ -545,108 +544,26 @@ size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
 
 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
 
-py::object PyMlirContext::createOperation(
-    std::string name, PyLocation location,
-    llvm::Optional<std::vector<PyValue *>> operands,
-    llvm::Optional<std::vector<PyType *>> results,
-    llvm::Optional<py::dict> attributes,
-    llvm::Optional<std::vector<PyBlock *>> successors, int regions) {
-  llvm::SmallVector<MlirValue, 4> mlirOperands;
-  llvm::SmallVector<MlirType, 4> mlirResults;
-  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
-  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
-
-  // General parameter validation.
-  if (regions < 0)
-    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
-
-  // Unpack/validate operands.
-  if (operands) {
-    mlirOperands.reserve(operands->size());
-    for (PyValue *operand : *operands) {
-      if (!operand)
-        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
-      mlirOperands.push_back(operand->get());
-    }
-  }
-
-  // Unpack/validate results.
-  if (results) {
-    mlirResults.reserve(results->size());
-    for (PyType *result : *results) {
-      // TODO: Verify result type originate from the same context.
-      if (!result)
-        throw SetPyError(PyExc_ValueError, "result type cannot be None");
-      mlirResults.push_back(result->type);
-    }
-  }
-  // Unpack/validate attributes.
-  if (attributes) {
-    mlirAttributes.reserve(attributes->size());
-    for (auto &it : *attributes) {
+pybind11::object PyMlirContext::contextEnter() {
+  return PyThreadContextEntry::pushContext(*this);
+}
 
-      auto name = it.first.cast<std::string>();
-      auto &attribute = it.second.cast<PyAttribute &>();
-      // TODO: Verify attribute originates from the same context.
-      mlirAttributes.emplace_back(std::move(name), attribute.attr);
-    }
-  }
-  // Unpack/validate successors.
-  if (successors) {
-    llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
-    mlirSuccessors.reserve(successors->size());
-    for (auto *successor : *successors) {
-      // TODO: Verify successor originate from the same context.
-      if (!successor)
-        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
-      mlirSuccessors.push_back(successor->get());
-    }
-  }
+void PyMlirContext::contextExit(pybind11::object excType,
+                                pybind11::object excVal,
+                                pybind11::object excTb) {
+  PyThreadContextEntry::popContext(*this);
+}
 
-  // Apply unpacked/validated to the operation state. Beyond this
-  // point, exceptions cannot be thrown or else the state will leak.
-  MlirOperationState state = mlirOperationStateGet(name.c_str(), location.loc);
-  if (!mlirOperands.empty())
-    mlirOperationStateAddOperands(&state, mlirOperands.size(),
-                                  mlirOperands.data());
-  if (!mlirResults.empty())
-    mlirOperationStateAddResults(&state, mlirResults.size(),
-                                 mlirResults.data());
-  if (!mlirAttributes.empty()) {
-    // Note that the attribute names directly reference bytes in
-    // mlirAttributes, so that vector must not be changed from here
-    // on.
-    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
-    mlirNamedAttributes.reserve(mlirAttributes.size());
-    for (auto &it : mlirAttributes)
-      mlirNamedAttributes.push_back(
-          mlirNamedAttributeGet(it.first.c_str(), it.second));
-    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
-                                    mlirNamedAttributes.data());
-  }
-  if (!mlirSuccessors.empty())
-    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
-                                    mlirSuccessors.data());
-  if (regions) {
-    llvm::SmallVector<MlirRegion, 4> mlirRegions;
-    mlirRegions.resize(regions);
-    for (int i = 0; i < regions; ++i)
-      mlirRegions[i] = mlirRegionCreate();
-    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
-                                      mlirRegions.data());
+PyMlirContext &DefaultingPyMlirContext::resolve() {
+  PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
+  if (!context) {
+    throw SetPyError(
+        PyExc_RuntimeError,
+        "An MLIR function requires a Context but none was provided in the call "
+        "or from the surrounding environment. Either pass to the function with "
+        "a 'context=' argument or establish a default using 'with Context():'");
   }
-
-  // Construct the operation.
-  MlirOperation operation = mlirOperationCreate(&state);
-  PyOperationRef created = PyOperation::createDetached(getRef(), operation);
-
-  // InsertPoint active?
-  PyInsertionPoint *ip =
-      PyThreadContextEntry::getDefaultInsertionPoint(/*required=*/false);
-  if (ip)
-    ip->insert(*created.get());
-
-  return created.releaseObject();
+  return *context;
 }
 
 //------------------------------------------------------------------------------
@@ -658,17 +575,33 @@ std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
   return stack;
 }
 
-PyThreadContextEntry *PyThreadContextEntry::getTos() {
+PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
   auto &stack = getStack();
   if (stack.empty())
     return nullptr;
   return &stack.back();
 }
 
-void PyThreadContextEntry::push(pybind11::object context,
-                                pybind11::object insertionPoint) {
+void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
+                                py::object insertionPoint,
+                                py::object location) {
   auto &stack = getStack();
-  stack.emplace_back(std::move(context), std::move(insertionPoint));
+  stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
+                     std::move(location));
+  // If the new stack has more than one entry and the context of the new top
+  // entry matches the previous, copy the insertionPoint and location from the
+  // previous entry if missing from the new top entry.
+  if (stack.size() > 1) {
+    auto &prev = *(stack.rbegin() + 1);
+    auto &current = stack.back();
+    if (current.context.is(prev.context)) {
+      // Default non-context objects from the previous entry.
+      if (!current.insertionPoint)
+        current.insertionPoint = prev.insertionPoint;
+      if (!current.location)
+        current.location = prev.location;
+    }
+  }
 }
 
 PyMlirContext *PyThreadContextEntry::getContext() {
@@ -683,30 +616,87 @@ PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
   return py::cast<PyInsertionPoint *>(insertionPoint);
 }
 
-PyMlirContext *PyThreadContextEntry::getDefaultContext(bool required) {
-  auto *tos = getTos();
-  PyMlirContext *context = tos ? tos->getContext() : nullptr;
-  if (required && !context) {
-    throw SetPyError(
-        PyExc_RuntimeError,
-        "A default context is required for this call but is not provided. "
-        "Establish a default by surrounding the code with "
-        "'with context:'");
-  }
-  return context;
+PyLocation *PyThreadContextEntry::getLocation() {
+  if (!location)
+    return nullptr;
+  return py::cast<PyLocation *>(location);
+}
+
+PyMlirContext *PyThreadContextEntry::getDefaultContext() {
+  auto *tos = getTopOfStack();
+  return tos ? tos->getContext() : nullptr;
+}
+
+PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
+  auto *tos = getTopOfStack();
+  return tos ? tos->getInsertionPoint() : nullptr;
+}
+
+PyLocation *PyThreadContextEntry::getDefaultLocation() {
+  auto *tos = getTopOfStack();
+  return tos ? tos->getLocation() : nullptr;
+}
+
+py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
+  py::object contextObj = py::cast(context);
+  push(FrameKind::Context, /*context=*/contextObj,
+       /*insertionPoint=*/py::object(),
+       /*location=*/py::object());
+  return contextObj;
+}
+
+void PyThreadContextEntry::popContext(PyMlirContext &context) {
+  auto &stack = getStack();
+  if (stack.empty())
+    throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
+  auto &tos = stack.back();
+  if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
+    throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
+  stack.pop_back();
 }
 
-PyInsertionPoint *
-PyThreadContextEntry::getDefaultInsertionPoint(bool required) {
-  auto *tos = getTos();
-  PyInsertionPoint *ip = tos ? tos->getInsertionPoint() : nullptr;
-  if (required && !ip)
+py::object
+PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
+  py::object contextObj =
+      insertionPoint.getBlock().getParentOperation()->getContext().getObject();
+  py::object insertionPointObj = py::cast(insertionPoint);
+  push(FrameKind::InsertionPoint,
+       /*context=*/contextObj,
+       /*insertionPoint=*/insertionPointObj,
+       /*location=*/py::object());
+  return insertionPointObj;
+}
+
+void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
+  auto &stack = getStack();
+  if (stack.empty())
+    throw SetPyError(PyExc_RuntimeError,
+                     "Unbalanced InsertionPoint enter/exit");
+  auto &tos = stack.back();
+  if (tos.frameKind != FrameKind::InsertionPoint &&
+      tos.getInsertionPoint() != &insertionPoint)
     throw SetPyError(PyExc_RuntimeError,
-                     "A default insertion point is required for this call but "
-                     "is not provided. "
-                     "Establish a default by surrounding the code with "
-                     "'with InsertionPoint(...):'");
-  return ip;
+                     "Unbalanced InsertionPoint enter/exit");
+  stack.pop_back();
+}
+
+py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
+  py::object contextObj = location.getContext().getObject();
+  py::object locationObj = py::cast(location);
+  push(FrameKind::Location, /*context=*/contextObj,
+       /*insertionPoint=*/py::object(),
+       /*location=*/locationObj);
+  return locationObj;
+}
+
+void PyThreadContextEntry::popLocation(PyLocation &location) {
+  auto &stack = getStack();
+  if (stack.empty())
+    throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
+  auto &tos = stack.back();
+  if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
+    throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
+  stack.pop_back();
 }
 
 //------------------------------------------------------------------------------
@@ -727,6 +717,31 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key,
   return dialect;
 }
 
+//------------------------------------------------------------------------------
+// PyLocation
+//------------------------------------------------------------------------------
+
+py::object PyLocation::contextEnter() {
+  return PyThreadContextEntry::pushLocation(*this);
+}
+
+void PyLocation::contextExit(py::object excType, py::object excVal,
+                             py::object excTb) {
+  PyThreadContextEntry::popLocation(*this);
+}
+
+PyLocation &DefaultingPyLocation::resolve() {
+  auto *location = PyThreadContextEntry::getDefaultLocation();
+  if (!location) {
+    throw SetPyError(
+        PyExc_RuntimeError,
+        "An MLIR function requires a Location but none was provided in the "
+        "call or from the surrounding environment. Either pass to the function "
+        "with a 'loc=' argument or establish a default using 'with loc:'");
+  }
+  return *location;
+}
+
 //------------------------------------------------------------------------------
 // PyModule
 //------------------------------------------------------------------------------
@@ -911,6 +926,117 @@ PyBlock PyOperation::getBlock() {
   return PyBlock{std::move(parentOperation), block};
 }
 
+py::object PyOperation::create(
+    std::string name, llvm::Optional<std::vector<PyValue *>> operands,
+    llvm::Optional<std::vector<PyType *>> results,
+    llvm::Optional<py::dict> attributes,
+    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
+    DefaultingPyLocation location, py::object maybeIp) {
+  llvm::SmallVector<MlirValue, 4> mlirOperands;
+  llvm::SmallVector<MlirType, 4> mlirResults;
+  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
+  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
+
+  // General parameter validation.
+  if (regions < 0)
+    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
+
+  // Unpack/validate operands.
+  if (operands) {
+    mlirOperands.reserve(operands->size());
+    for (PyValue *operand : *operands) {
+      if (!operand)
+        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
+      mlirOperands.push_back(operand->get());
+    }
+  }
+
+  // Unpack/validate results.
+  if (results) {
+    mlirResults.reserve(results->size());
+    for (PyType *result : *results) {
+      // TODO: Verify result type originate from the same context.
+      if (!result)
+        throw SetPyError(PyExc_ValueError, "result type cannot be None");
+      mlirResults.push_back(result->type);
+    }
+  }
+  // Unpack/validate attributes.
+  if (attributes) {
+    mlirAttributes.reserve(attributes->size());
+    for (auto &it : *attributes) {
+
+      auto name = it.first.cast<std::string>();
+      auto &attribute = it.second.cast<PyAttribute &>();
+      // TODO: Verify attribute originates from the same context.
+      mlirAttributes.emplace_back(std::move(name), attribute.attr);
+    }
+  }
+  // Unpack/validate successors.
+  if (successors) {
+    llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
+    mlirSuccessors.reserve(successors->size());
+    for (auto *successor : *successors) {
+      // TODO: Verify successor originate from the same context.
+      if (!successor)
+        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
+      mlirSuccessors.push_back(successor->get());
+    }
+  }
+
+  // Apply unpacked/validated to the operation state. Beyond this
+  // point, exceptions cannot be thrown or else the state will leak.
+  MlirOperationState state = mlirOperationStateGet(name.c_str(), location->loc);
+  if (!mlirOperands.empty())
+    mlirOperationStateAddOperands(&state, mlirOperands.size(),
+                                  mlirOperands.data());
+  if (!mlirResults.empty())
+    mlirOperationStateAddResults(&state, mlirResults.size(),
+                                 mlirResults.data());
+  if (!mlirAttributes.empty()) {
+    // Note that the attribute names directly reference bytes in
+    // mlirAttributes, so that vector must not be changed from here
+    // on.
+    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
+    mlirNamedAttributes.reserve(mlirAttributes.size());
+    for (auto &it : mlirAttributes)
+      mlirNamedAttributes.push_back(
+          mlirNamedAttributeGet(it.first.c_str(), it.second));
+    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
+                                    mlirNamedAttributes.data());
+  }
+  if (!mlirSuccessors.empty())
+    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
+                                    mlirSuccessors.data());
+  if (regions) {
+    llvm::SmallVector<MlirRegion, 4> mlirRegions;
+    mlirRegions.resize(regions);
+    for (int i = 0; i < regions; ++i)
+      mlirRegions[i] = mlirRegionCreate();
+    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
+                                      mlirRegions.data());
+  }
+
+  // Construct the operation.
+  MlirOperation operation = mlirOperationCreate(&state);
+  PyOperationRef created =
+      PyOperation::createDetached(location->getContext(), operation);
+
+  // InsertPoint active?
+  if (!maybeIp.is(py::cast(false))) {
+    PyInsertionPoint *ip;
+    if (maybeIp.is_none()) {
+      ip = PyThreadContextEntry::getDefaultInsertionPoint();
+    } else {
+      ip = py::cast<PyInsertionPoint *>(maybeIp);
+    }
+    if (ip)
+      ip->insert(*created.get());
+  }
+
+  return created.releaseObject();
+}
+
 PyOpView::PyOpView(py::object operation)
     : operationObject(std::move(operation)),
       operation(py::cast<PyOperation *>(this->operationObject)) {}
@@ -998,26 +1124,13 @@ PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
 }
 
 py::object PyInsertionPoint::contextEnter() {
-  auto context = block.getParentOperation()->getContext().getObject();
-  py::object self = py::cast(this);
-  PyThreadContextEntry::push(/*context=*/std::move(context),
-                             /*insertionPoint=*/self);
-  return self;
+  return PyThreadContextEntry::pushInsertionPoint(*this);
 }
 
 void PyInsertionPoint::contextExit(pybind11::object excType,
                                    pybind11::object excVal,
                                    pybind11::object excTb) {
-  auto &stack = PyThreadContextEntry::getStack();
-  if (stack.empty())
-    throw SetPyError(PyExc_RuntimeError,
-                     "Unbalanced insertion point enter/exit");
-  auto &tos = stack.back();
-  PyInsertionPoint *current = tos.getInsertionPoint();
-  if (current != this)
-    throw SetPyError(PyExc_RuntimeError,
-                     "Unbalanced insertion point enter/exit");
-  stack.pop_back();
+  PyThreadContextEntry::popInsertionPoint(*this);
 }
 
 //------------------------------------------------------------------------------
@@ -1299,10 +1412,9 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        // TODO: Make the location optional and create a default location.
-        [](PyType &type, double value, PyLocation &loc) {
+        [](PyType &type, double value, DefaultingPyLocation loc) {
           MlirAttribute attr =
-              mlirFloatAttrDoubleGetChecked(type.type, value, loc.loc);
+              mlirFloatAttrDoubleGetChecked(type.type, value, loc->loc);
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirAttributeIsNull(attr)) {
@@ -1313,25 +1425,25 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
           }
           return PyFloatAttribute(type.getContext(), attr);
         },
-        py::arg("type"), py::arg("value"), py::arg("loc"),
+        py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
         "Gets an uniqued float point attribute associated to a type");
     c.def_static(
         "get_f32",
-        [](PyMlirContext &context, double value) {
+        [](double value, DefaultingPyMlirContext context) {
           MlirAttribute attr = mlirFloatAttrDoubleGet(
-              context.get(), mlirF32TypeGet(context.get()), value);
-          return PyFloatAttribute(context.getRef(), attr);
+              context->get(), mlirF32TypeGet(context->get()), value);
+          return PyFloatAttribute(context->getRef(), attr);
         },
-        py::arg("context"), py::arg("value"),
+        py::arg("value"), py::arg("context") = py::none(),
         "Gets an uniqued float point attribute associated to a f32 type");
     c.def_static(
         "get_f64",
-        [](PyMlirContext &context, double value) {
+        [](double value, DefaultingPyMlirContext context) {
           MlirAttribute attr = mlirFloatAttrDoubleGet(
-              context.get(), mlirF64TypeGet(context.get()), value);
-          return PyFloatAttribute(context.getRef(), attr);
+              context->get(), mlirF64TypeGet(context->get()), value);
+          return PyFloatAttribute(context->getRef(), attr);
         },
-        py::arg("context"), py::arg("value"),
+        py::arg("value"), py::arg("context") = py::none(),
         "Gets an uniqued float point attribute associated to a f64 type");
     c.def_property_readonly(
         "value",
@@ -1377,11 +1489,12 @@ class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](PyMlirContext &context, bool value) {
-          MlirAttribute attr = mlirBoolAttrGet(context.get(), value);
-          return PyBoolAttribute(context.getRef(), attr);
+        [](bool value, DefaultingPyMlirContext context) {
+          MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
+          return PyBoolAttribute(context->getRef(), attr);
         },
-        py::arg("context"), py::arg("value"), "Gets an uniqued bool attribute");
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets an uniqued bool attribute");
     c.def_property_readonly(
         "value",
         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self.attr); },
@@ -1398,11 +1511,12 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](PyMlirContext &context, std::string value) {
+        [](std::string value, DefaultingPyMlirContext context) {
           MlirAttribute attr =
-              mlirStringAttrGet(context.get(), value.size(), &value[0]);
-          return PyStringAttribute(context.getRef(), attr);
+              mlirStringAttrGet(context->get(), value.size(), &value[0]);
+          return PyStringAttribute(context->getRef(), attr);
         },
+        py::arg("value"), py::arg("context") = py::none(),
         "Gets a uniqued string attribute");
     c.def_static(
         "get_typed",
@@ -1432,9 +1546,9 @@ class PyDenseElementsAttribute
   static constexpr const char *pyClassName = "DenseElementsAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
 
-  static PyDenseElementsAttribute getFromBuffer(PyMlirContext &contextWrapper,
-                                                py::buffer array,
-                                                bool signless) {
+  static PyDenseElementsAttribute
+  getFromBuffer(py::buffer array, bool signless,
+                DefaultingPyMlirContext contextWrapper) {
     // Request a contiguous view. In exotic cases, this will cause a copy.
     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
     Py_buffer *view = new Py_buffer();
@@ -1444,21 +1558,21 @@ class PyDenseElementsAttribute
     }
     py::buffer_info arrayInfo(view);
 
-    MlirContext context = contextWrapper.get();
+    MlirContext context = contextWrapper->get();
     // Switch on the types that can be bulk loaded between the Python and
     // MLIR-C APIs.
     if (arrayInfo.format == "f") {
       // f32
       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
       return PyDenseElementsAttribute(
-          contextWrapper.getRef(),
+          contextWrapper->getRef(),
           bulkLoad(context, mlirDenseElementsAttrFloatGet,
                    mlirF32TypeGet(context), arrayInfo));
     } else if (arrayInfo.format == "d") {
       // f64
       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
       return PyDenseElementsAttribute(
-          contextWrapper.getRef(),
+          contextWrapper->getRef(),
           bulkLoad(context, mlirDenseElementsAttrDoubleGet,
                    mlirF64TypeGet(context), arrayInfo));
     } else if (arrayInfo.format == "i") {
@@ -1466,7 +1580,7 @@ class PyDenseElementsAttribute
       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
       MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
                                       : mlirIntegerTypeSignedGet(context, 32);
-      return PyDenseElementsAttribute(contextWrapper.getRef(),
+      return PyDenseElementsAttribute(contextWrapper->getRef(),
                                       bulkLoad(context,
                                                mlirDenseElementsAttrInt32Get,
                                                elementType, arrayInfo));
@@ -1475,7 +1589,7 @@ class PyDenseElementsAttribute
       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
       MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
                                       : mlirIntegerTypeUnsignedGet(context, 32);
-      return PyDenseElementsAttribute(contextWrapper.getRef(),
+      return PyDenseElementsAttribute(contextWrapper->getRef(),
                                       bulkLoad(context,
                                                mlirDenseElementsAttrUInt32Get,
                                                elementType, arrayInfo));
@@ -1484,7 +1598,7 @@ class PyDenseElementsAttribute
       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
       MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
                                       : mlirIntegerTypeSignedGet(context, 64);
-      return PyDenseElementsAttribute(contextWrapper.getRef(),
+      return PyDenseElementsAttribute(contextWrapper->getRef(),
                                       bulkLoad(context,
                                                mlirDenseElementsAttrInt64Get,
                                                elementType, arrayInfo));
@@ -1493,7 +1607,7 @@ class PyDenseElementsAttribute
       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
       MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
                                       : mlirIntegerTypeUnsignedGet(context, 64);
-      return PyDenseElementsAttribute(contextWrapper.getRef(),
+      return PyDenseElementsAttribute(contextWrapper->getRef(),
                                       bulkLoad(context,
                                                mlirDenseElementsAttrUInt64Get,
                                                elementType, arrayInfo));
@@ -1540,8 +1654,9 @@ class PyDenseElementsAttribute
 
   static void bindDerived(ClassTy &c) {
     c.def_static("get", PyDenseElementsAttribute::getFromBuffer,
-                 py::arg("context"), py::arg("array"),
-                 py::arg("signless") = true, "Gets from a buffer or ndarray")
+                 py::arg("array"), py::arg("signless") = true,
+                 py::arg("context") = py::none(),
+                 "Gets from a buffer or ndarray")
         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
                     py::arg("shaped_type"), py::arg("element_attr"),
                     "Gets a DenseElementsAttr where all values are the same")
@@ -1624,24 +1739,27 @@ class PyIntegerType : public PyConcreteType<PyIntegerType> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get_signless",
-        [](PyMlirContext &context, unsigned width) {
-          MlirType t = mlirIntegerTypeGet(context.get(), width);
-          return PyIntegerType(context.getRef(), t);
+        [](unsigned width, DefaultingPyMlirContext context) {
+          MlirType t = mlirIntegerTypeGet(context->get(), width);
+          return PyIntegerType(context->getRef(), t);
         },
+        py::arg("width"), py::arg("context") = py::none(),
         "Create a signless integer type");
     c.def_static(
         "get_signed",
-        [](PyMlirContext &context, unsigned width) {
-          MlirType t = mlirIntegerTypeSignedGet(context.get(), width);
-          return PyIntegerType(context.getRef(), t);
+        [](unsigned width, DefaultingPyMlirContext context) {
+          MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
+          return PyIntegerType(context->getRef(), t);
         },
+        py::arg("width"), py::arg("context") = py::none(),
         "Create a signed integer type");
     c.def_static(
         "get_unsigned",
-        [](PyMlirContext &context, unsigned width) {
-          MlirType t = mlirIntegerTypeUnsignedGet(context.get(), width);
-          return PyIntegerType(context.getRef(), t);
+        [](unsigned width, DefaultingPyMlirContext context) {
+          MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
+          return PyIntegerType(context->getRef(), t);
         },
+        py::arg("width"), py::arg("context") = py::none(),
         "Create an unsigned integer type");
     c.def_property_readonly(
         "width",
@@ -1678,11 +1796,11 @@ class PyIndexType : public PyConcreteType<PyIndexType> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](PyMlirContext &context) {
-          MlirType t = mlirIndexTypeGet(context.get());
-          return PyIndexType(context.getRef(), t);
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirIndexTypeGet(context->get());
+          return PyIndexType(context->getRef(), t);
         },
-        "Create a index type.");
+        py::arg("context") = py::none(), "Create a index type.");
   }
 };
 
@@ -1696,11 +1814,11 @@ class PyBF16Type : public PyConcreteType<PyBF16Type> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](PyMlirContext &context) {
-          MlirType t = mlirBF16TypeGet(context.get());
-          return PyBF16Type(context.getRef(), t);
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirBF16TypeGet(context->get());
+          return PyBF16Type(context->getRef(), t);
         },
-        "Create a bf16 type.");
+        py::arg("context") = py::none(), "Create a bf16 type.");
   }
 };
 
@@ -1714,11 +1832,11 @@ class PyF16Type : public PyConcreteType<PyF16Type> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](PyMlirContext &context) {
-          MlirType t = mlirF16TypeGet(context.get());
-          return PyF16Type(context.getRef(), t);
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirF16TypeGet(context->get());
+          return PyF16Type(context->getRef(), t);
         },
-        "Create a f16 type.");
+        py::arg("context") = py::none(), "Create a f16 type.");
   }
 };
 
@@ -1732,11 +1850,11 @@ class PyF32Type : public PyConcreteType<PyF32Type> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](PyMlirContext &context) {
-          MlirType t = mlirF32TypeGet(context.get());
-          return PyF32Type(context.getRef(), t);
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirF32TypeGet(context->get());
+          return PyF32Type(context->getRef(), t);
         },
-        "Create a f32 type.");
+        py::arg("context") = py::none(), "Create a f32 type.");
   }
 };
 
@@ -1750,11 +1868,11 @@ class PyF64Type : public PyConcreteType<PyF64Type> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](PyMlirContext &context) {
-          MlirType t = mlirF64TypeGet(context.get());
-          return PyF64Type(context.getRef(), t);
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirF64TypeGet(context->get());
+          return PyF64Type(context->getRef(), t);
         },
-        "Create a f64 type.");
+        py::arg("context") = py::none(), "Create a f64 type.");
   }
 };
 
@@ -1768,11 +1886,11 @@ class PyNoneType : public PyConcreteType<PyNoneType> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](PyMlirContext &context) {
-          MlirType t = mlirNoneTypeGet(context.get());
-          return PyNoneType(context.getRef(), t);
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirNoneTypeGet(context->get());
+          return PyNoneType(context->getRef(), t);
         },
-        "Create a none type.");
+        py::arg("context") = py::none(), "Create a none type.");
   }
 };
 
@@ -1892,10 +2010,10 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        // TODO: Make the location optional and create a default location.
-        [](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) {
+        [](std::vector<int64_t> shape, PyType &elementType,
+           DefaultingPyLocation loc) {
           MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
-                                                elementType.type, loc.loc);
+                                                elementType.type, loc->loc);
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirTypeIsNull(t)) {
@@ -1907,6 +2025,7 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
           }
           return PyVectorType(elementType.getContext(), t);
         },
+        py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
         "Create a vector type");
   }
 };
@@ -1922,10 +2041,10 @@ class PyRankedTensorType
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        // TODO: Make the location optional and create a default location.
-        [](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) {
+        [](std::vector<int64_t> shape, PyType &elementType,
+           DefaultingPyLocation loc) {
           MlirType t = mlirRankedTensorTypeGetChecked(
-              shape.size(), shape.data(), elementType.type, loc.loc);
+              shape.size(), shape.data(), elementType.type, loc->loc);
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirTypeIsNull(t)) {
@@ -1939,6 +2058,7 @@ class PyRankedTensorType
           }
           return PyRankedTensorType(elementType.getContext(), t);
         },
+        py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(),
         "Create a ranked tensor type");
   }
 };
@@ -1954,10 +2074,9 @@ class PyUnrankedTensorType
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        // TODO: Make the location optional and create a default location.
-        [](PyType &elementType, PyLocation &loc) {
+        [](PyType &elementType, DefaultingPyLocation loc) {
           MlirType t =
-              mlirUnrankedTensorTypeGetChecked(elementType.type, loc.loc);
+              mlirUnrankedTensorTypeGetChecked(elementType.type, loc->loc);
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirTypeIsNull(t)) {
@@ -1971,6 +2090,7 @@ class PyUnrankedTensorType
           }
           return PyUnrankedTensorType(elementType.getContext(), t);
         },
+        py::arg("element_type"), py::arg("loc") = py::none(),
         "Create a unranked tensor type");
   }
 };
@@ -1989,10 +2109,10 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
          "get_contiguous_memref",
          // TODO: Make the location optional and create a default location.
          [](PyType &elementType, std::vector<int64_t> shape,
-            unsigned memorySpace, PyLocation &loc) {
+            unsigned memorySpace, DefaultingPyLocation loc) {
            MlirType t = mlirMemRefTypeContiguousGetChecked(
                elementType.type, shape.size(), shape.data(), memorySpace,
-               loc.loc);
+               loc->loc);
            // TODO: Rework error reporting once diagnostic engine is exposed
            // in C API.
            if (mlirTypeIsNull(t)) {
@@ -2006,7 +2126,8 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
            }
            return PyMemRefType(elementType.getContext(), t);
          },
-         "Create a memref type")
+         py::arg("element_type"), py::arg("shape"), py::arg("memory_space"),
+         py::arg("loc") = py::none(), "Create a memref type")
         .def_property_readonly(
             "num_affine_maps",
             [](PyMemRefType &self) -> intptr_t {
@@ -2034,10 +2155,10 @@ class PyUnrankedMemRefType
   static void bindDerived(ClassTy &c) {
     c.def_static(
          "get",
-         // TODO: Make the location optional and create a default location.
-         [](PyType &elementType, unsigned memorySpace, PyLocation &loc) {
+         [](PyType &elementType, unsigned memorySpace,
+            DefaultingPyLocation loc) {
            MlirType t = mlirUnrankedMemRefTypeGetChecked(elementType.type,
-                                                         memorySpace, loc.loc);
+                                                         memorySpace, loc->loc);
            // TODO: Rework error reporting once diagnostic engine is exposed
            // in C API.
            if (mlirTypeIsNull(t)) {
@@ -2051,7 +2172,8 @@ class PyUnrankedMemRefType
            }
            return PyUnrankedMemRefType(elementType.getContext(), t);
          },
-         "Create a unranked memref type")
+         py::arg("element_type"), py::arg("memory_space"),
+         py::arg("loc") = py::none(), "Create a unranked memref type")
         .def_property_readonly(
             "memory_space",
             [](PyUnrankedMemRefType &self) -> unsigned {
@@ -2071,15 +2193,16 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get_tuple",
-        [](PyMlirContext &context, py::list elementList) {
+        [](py::list elementList, DefaultingPyMlirContext context) {
           intptr_t num = py::len(elementList);
           // Mapping py::list to SmallVector.
           SmallVector<MlirType, 4> elements;
           for (auto element : elementList)
             elements.push_back(element.cast<PyType>().type);
-          MlirType t = mlirTupleTypeGet(context.get(), num, elements.data());
-          return PyTupleType(context.getRef(), t);
+          MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
+          return PyTupleType(context->getRef(), t);
         },
+        py::arg("elements"), py::arg("context") = py::none(),
         "Create a tuple type");
     c.def(
         "get_type",
@@ -2107,16 +2230,16 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](PyMlirContext &context, std::vector<PyType> inputs,
-           std::vector<PyType> results) {
+        [](std::vector<PyType> inputs, std::vector<PyType> results,
+           DefaultingPyMlirContext context) {
           SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
           SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
-          MlirType t = mlirFunctionTypeGet(context.get(), inputsRaw.size(),
+          MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(),
                                            inputsRaw.data(), resultsRaw.size(),
                                            resultsRaw.data());
-          return PyFunctionType(context.getRef(), t);
+          return PyFunctionType(context->getRef(), t);
         },
-        py::arg("context"), py::arg("inputs"), py::arg("results"),
+        py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
         "Gets a FunctionType from a list of input and result types");
     c.def_property_readonly(
         "inputs",
@@ -2170,6 +2293,17 @@ void mlir::python::populateIRSubmodule(py::module &m) {
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyMlirContext::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
+      .def("__enter__", &PyMlirContext::contextEnter)
+      .def("__exit__", &PyMlirContext::contextExit)
+      .def_property_readonly_static(
+          "current",
+          [](py::object & /*class*/) {
+            auto *context = PyThreadContextEntry::getDefaultContext();
+            if (!context)
+              throw SetPyError(PyExc_ValueError, "No current Context");
+            return context;
+          },
+          "Gets the Context bound to the current thread or raises ValueError")
       .def_property_readonly(
           "dialects",
           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
@@ -2196,79 +2330,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           },
           [](PyMlirContext &self, bool value) {
             mlirContextSetAllowUnregisteredDialects(self.get(), value);
-          })
-      .def("create_operation", &PyMlirContext::createOperation, py::arg("name"),
-           py::arg("location"), py::arg("operands") = py::none(),
-           py::arg("results") = py::none(), py::arg("attributes") = py::none(),
-           py::arg("successors") = py::none(), py::arg("regions") = 0,
-           kContextCreateOperationDocstring)
-      .def(
-          "parse_module",
-          [](PyMlirContext &self, const std::string moduleAsm) {
-            MlirModule module =
-                mlirModuleCreateParse(self.get(), moduleAsm.c_str());
-            // TODO: Rework error reporting once diagnostic engine is exposed
-            // in C API.
-            if (mlirModuleIsNull(module)) {
-              throw SetPyError(
-                  PyExc_ValueError,
-                  "Unable to parse module assembly (see diagnostics)");
-            }
-            return PyModule::forModule(module).releaseObject();
-          },
-          kContextParseDocstring)
-      .def(
-          "create_module",
-          [](PyMlirContext &self, PyLocation &loc) {
-            MlirModule module = mlirModuleCreateEmpty(loc.loc);
-            return PyModule::forModule(module).releaseObject();
-          },
-          py::arg("loc"), "Creates an empty module")
-      .def(
-          "parse_attr",
-          [](PyMlirContext &self, std::string attrSpec) {
-            MlirAttribute type =
-                mlirAttributeParseGet(self.get(), attrSpec.c_str());
-            // TODO: Rework error reporting once diagnostic engine is exposed
-            // in C API.
-            if (mlirAttributeIsNull(type)) {
-              throw SetPyError(PyExc_ValueError,
-                               llvm::Twine("Unable to parse attribute: '") +
-                                   attrSpec + "'");
-            }
-            return PyAttribute(self.getRef(), type);
-          },
-          py::keep_alive<0, 1>())
-      .def(
-          "parse_type",
-          [](PyMlirContext &self, std::string typeSpec) {
-            MlirType type = mlirTypeParseGet(self.get(), typeSpec.c_str());
-            // TODO: Rework error reporting once diagnostic engine is exposed
-            // in C API.
-            if (mlirTypeIsNull(type)) {
-              throw SetPyError(PyExc_ValueError,
-                               llvm::Twine("Unable to parse type: '") +
-                                   typeSpec + "'");
-            }
-            return PyType(self.getRef(), type);
-          },
-          kContextParseTypeDocstring)
-      .def(
-          "get_unknown_location",
-          [](PyMlirContext &self) {
-            return PyLocation(self.getRef(),
-                              mlirLocationUnknownGet(self.get()));
-          },
-          kContextGetUnknownLocationDocstring)
-      .def(
-          "get_file_location",
-          [](PyMlirContext &self, std::string filename, int line, int col) {
-            return PyLocation(self.getRef(),
-                              mlirLocationFileLineColGet(
-                                  self.get(), filename.c_str(), line, col));
-          },
-          kContextGetFileLocationDocstring, py::arg("filename"),
-          py::arg("line"), py::arg("col"));
+          });
 
   //----------------------------------------------------------------------------
   // Mapping of PyDialectDescriptor
@@ -2327,6 +2389,35 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   // Mapping of Location
   //----------------------------------------------------------------------------
   py::class_<PyLocation>(m, "Location")
+      .def("__enter__", &PyLocation::contextEnter)
+      .def("__exit__", &PyLocation::contextExit)
+      .def_property_readonly_static(
+          "current",
+          [](py::object & /*class*/) {
+            auto *loc = PyThreadContextEntry::getDefaultLocation();
+            if (!loc)
+              throw SetPyError(PyExc_ValueError, "No current Location");
+            return loc;
+          },
+          "Gets the Location bound to the current thread or raises ValueError")
+      .def_static(
+          "unknown",
+          [](DefaultingPyMlirContext context) {
+            return PyLocation(context->getRef(),
+                              mlirLocationUnknownGet(context->get()));
+          },
+          py::arg("context") = py::none(),
+          "Gets a Location representing an unknown location")
+      .def_static(
+          "file",
+          [](std::string filename, int line, int col,
+             DefaultingPyMlirContext context) {
+            return PyLocation(context->getRef(),
+                              mlirLocationFileLineColGet(
+                                  context->get(), filename.c_str(), line, col));
+          },
+          py::arg("filename"), py::arg("line"), py::arg("col"),
+          py::arg("context") = py::none(), kContextGetFileLocationDocstring)
       .def_property_readonly(
           "context",
           [](PyLocation &self) { return self.getContext().getObject(); },
@@ -2344,6 +2435,29 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   py::class_<PyModule>(m, "Module")
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
+      .def_static(
+          "parse",
+          [](const std::string moduleAsm, DefaultingPyMlirContext context) {
+            MlirModule module =
+                mlirModuleCreateParse(context->get(), moduleAsm.c_str());
+            // TODO: Rework error reporting once diagnostic engine is exposed
+            // in C API.
+            if (mlirModuleIsNull(module)) {
+              throw SetPyError(
+                  PyExc_ValueError,
+                  "Unable to parse module assembly (see diagnostics)");
+            }
+            return PyModule::forModule(module).releaseObject();
+          },
+          py::arg("asm"), py::arg("context") = py::none(),
+          kModuleParseDocstring)
+      .def_static(
+          "create",
+          [](DefaultingPyLocation loc) {
+            MlirModule module = mlirModuleCreateEmpty(loc->loc);
+            return PyModule::forModule(module).releaseObject();
+          },
+          py::arg("loc") = py::none(), "Creates an empty module")
       .def_property_readonly(
           "context",
           [](PyModule &self) { return self.getContext().getObject(); },
@@ -2388,6 +2502,13 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   // Mapping of Operation.
   //----------------------------------------------------------------------------
   py::class_<PyOperation>(m, "Operation")
+      .def_static("create", &PyOperation::create, py::arg("name"),
+                  py::arg("operands") = py::none(),
+                  py::arg("results") = py::none(),
+                  py::arg("attributes") = py::none(),
+                  py::arg("successors") = py::none(), py::arg("regions") = 0,
+                  py::arg("loc") = py::none(), py::arg("ip") = py::none(),
+                  kOperationCreateDocstring)
       .def_property_readonly(
           "context",
           [](PyOperation &self) { return self.getContext().getObject(); },
@@ -2520,6 +2641,16 @@ void mlir::python::populateIRSubmodule(py::module &m) {
            "Inserts after the last operation but still inside the block.")
       .def("__enter__", &PyInsertionPoint::contextEnter)
       .def("__exit__", &PyInsertionPoint::contextExit)
+      .def_property_readonly_static(
+          "current",
+          [](py::object & /*class*/) {
+            auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
+            if (!ip)
+              throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
+            return ip;
+          },
+          "Gets the InsertionPoint bound to the current thread or raises "
+          "ValueError if none has been set")
       .def(py::init<PyOperation &>(), py::arg("beforeOperation"),
            "Inserts before a referenced operation.")
       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
@@ -2533,6 +2664,22 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   // Mapping of PyAttribute.
   //----------------------------------------------------------------------------
   py::class_<PyAttribute>(m, "Attribute")
+      .def_static(
+          "parse",
+          [](std::string attrSpec, DefaultingPyMlirContext context) {
+            MlirAttribute type =
+                mlirAttributeParseGet(context->get(), attrSpec.c_str());
+            // TODO: Rework error reporting once diagnostic engine is exposed
+            // in C API.
+            if (mlirAttributeIsNull(type)) {
+              throw SetPyError(PyExc_ValueError,
+                               llvm::Twine("Unable to parse attribute: '") +
+                                   attrSpec + "'");
+            }
+            return PyAttribute(context->getRef(), type);
+          },
+          py::arg("asm"), py::arg("context") = py::none(),
+          "Parses an attribute from an assembly form")
       .def_property_readonly(
           "context",
           [](PyAttribute &self) { return self.getContext().getObject(); },
@@ -2628,6 +2775,21 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   // Mapping of PyType.
   //----------------------------------------------------------------------------
   py::class_<PyType>(m, "Type")
+      .def_static(
+          "parse",
+          [](std::string typeSpec, DefaultingPyMlirContext context) {
+            MlirType type = mlirTypeParseGet(context->get(), typeSpec.c_str());
+            // TODO: Rework error reporting once diagnostic engine is exposed
+            // in C API.
+            if (mlirTypeIsNull(type)) {
+              throw SetPyError(PyExc_ValueError,
+                               llvm::Twine("Unable to parse type: '") +
+                                   typeSpec + "'");
+            }
+            return PyType(context->getRef(), type);
+          },
+          py::arg("asm"), py::arg("context") = py::none(),
+          kContextParseTypeDocstring)
       .def_property_readonly(
           "context", [](PyType &self) { return self.getContext().getObject(); },
           "Context that owns the Type")

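Every rewritten signature in IRModules.cpp above follows one convention. It takes a trailing `context` or `loc` argument that defaults to `None`, and that argument is resolved from the thread's context stack. The sketch below models this resolution in plain C++ so it can be read without pybind11.

Everything in the sketch is a hypothetical stand-in, not the `PyThreadContextEntry` or `Defaulting` code from this commit: `Context`, `Location`, `Scope`, `Defaulting` and `describeIntegerType`. It omits InsertionPoint. It assumes the nesting rule summarized in the commit message:

- Entering a Location also enters its Context.
- Entering a Context keeps the enclosing Location only when both belong to the same context.

```cpp
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for PyMlirContext.
struct Context {
  std::string name;
};

// Hypothetical stand-in for PyLocation: a location is owned by one context.
struct Location {
  Context *context;
};

// One entry of the per-thread stack (loosely modeled on PyThreadContextEntry).
struct Frame {
  Context *context;
  Location *location;
};

// thread_local: scopes opened on one thread are invisible to other threads.
inline std::vector<Frame> &threadStack() {
  static thread_local std::vector<Frame> stack;
  return stack;
}

inline Context *currentContext() {
  auto &stack = threadStack();
  return stack.empty() ? nullptr : stack.back().context;
}

inline Location *currentLocation() {
  auto &stack = threadStack();
  return stack.empty() ? nullptr : stack.back().location;
}

// RAII equivalent of a Python `with` block; the destructor plays __exit__.
class Scope {
public:
  // Entering a Context inherits the enclosing Location only if that Location
  // belongs to the same Context; otherwise the Location is cleared.
  explicit Scope(Context &ctx) {
    Location *loc = currentLocation();
    if (loc && loc->context != &ctx)
      loc = nullptr;
    threadStack().push_back({&ctx, loc});
  }
  // Entering a Location also makes its owning Context current.
  explicit Scope(Location &loc) {
    threadStack().push_back({loc.context, &loc});
  }
  ~Scope() { threadStack().pop_back(); }
  Scope(const Scope &) = delete;
  Scope &operator=(const Scope &) = delete;
};

// Uses an explicit value when given, else the thread's current one. Throws
// when neither exists (analogous to passing context=None outside any `with`).
template <typename T, T *(*getCurrent)()>
class Defaulting {
public:
  explicit Defaulting(T *explicitValue = nullptr)
      : value(explicitValue ? explicitValue : getCurrent()) {
    if (!value)
      throw std::invalid_argument("no current value bound to this thread");
  }
  T *operator->() const { return value; }

private:
  T *value;
};

using DefaultingContext = Defaulting<Context, currentContext>;
using DefaultingLocation = Defaulting<Location, currentLocation>;

// Hypothetical factory shaped like the new bindings: payload arguments
// first, the defaulting context last. The default argument is evaluated at
// each call, so it sees whichever Scope is active at that moment.
inline std::string describeIntegerType(
    unsigned width, DefaultingContext context = DefaultingContext()) {
  return "i" + std::to_string(width) + " in " + context->name;
}
```

Here `Scope` plays the role of a Python `with` block. Inside `Scope s(ctx);`, a call to `describeIntegerType(32)` picks up `ctx` implicitly. Passing `DefaultingContext(&other)` overrides it, much as passing `context=` explicitly does for `IntegerType.get_signed`. Resolving the default at call time, rather than binding it when the function is defined, is what lets nested scopes take effect.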
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 6b1b69941958..e7fdbb9e7a5c 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -11,7 +11,7 @@
 
 #include <vector>
 
-#include <pybind11/pybind11.h>
+#include "PybindUtils.h"
 
 #include "mlir-c/IR.h"
 #include "llvm/ADT/DenseMap.h"
@@ -22,7 +22,9 @@ namespace python {
 class PyBlock;
 class PyInsertionPoint;
 class PyLocation;
+class DefaultingPyLocation;
 class PyMlirContext;
+class DefaultingPyMlirContext;
 class PyModule;
 class PyOperation;
 class PyType;
@@ -81,43 +83,65 @@ class PyObjectRef {
 };
 
 /// Tracks an entry in the thread context stack. New entries are pushed onto
-/// here for each with block that activates a new InsertionPoint or Context.
-/// Pushing either a context or an insertion point resets the other:
-///   - a new context activates a new entry with a null insertion point.
-///   - a new insertion point activates a new entry with the context that the
-///     insertion point is bound to.
+/// here for each with block that activates a new InsertionPoint, Context or
+/// Location.
+///
+/// Pushing either a Location or InsertionPoint also pushes its associated
+/// Context. Pushing a Context will not modify the Location or InsertionPoint
+/// unless if they are from a different context, in which case, they are
+/// cleared.
 class PyThreadContextEntry {
 public:
-  PyThreadContextEntry(pybind11::object context,
-                       pybind11::object insertionPoint)
-      : context(std::move(context)), insertionPoint(std::move(insertionPoint)) {
-  }
+  enum class FrameKind {
+    Context,
+    InsertionPoint,
+    Location,
+  };
+
+  PyThreadContextEntry(FrameKind frameKind, pybind11::object context,
+                       pybind11::object insertionPoint,
+                       pybind11::object location)
+      : context(std::move(context)), insertionPoint(std::move(insertionPoint)),
+        location(std::move(location)), frameKind(frameKind) {}
 
   /// Gets the top of stack context and return nullptr if not defined.
-  /// If required is true and there is no default, a nice user-facing exception
-  /// is raised.
-  static PyMlirContext *getDefaultContext(bool required);
+  static PyMlirContext *getDefaultContext();
 
   /// Gets the top of stack insertion point and return nullptr if not defined.
-  /// If required is true and there is no default, a nice user-facing exception
-  /// is raised.
-  static PyInsertionPoint *getDefaultInsertionPoint(bool required);
+  static PyInsertionPoint *getDefaultInsertionPoint();
+
+  /// Gets the top of stack location and returns nullptr if not defined.
+  static PyLocation *getDefaultLocation();
 
   PyMlirContext *getContext();
   PyInsertionPoint *getInsertionPoint();
+  PyLocation *getLocation();
+  FrameKind getFrameKind() { return frameKind; }
 
   /// Stack management.
-  static PyThreadContextEntry *getTos();
-  static void push(pybind11::object context, pybind11::object insertionPoint);
+  static PyThreadContextEntry *getTopOfStack();
+  static pybind11::object pushContext(PyMlirContext &context);
+  static void popContext(PyMlirContext &context);
+  static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint);
+  static void popInsertionPoint(PyInsertionPoint &insertionPoint);
+  static pybind11::object pushLocation(PyLocation &location);
+  static void popLocation(PyLocation &location);
 
   /// Gets the thread local stack.
   static std::vector<PyThreadContextEntry> &getStack();
 
 private:
+  static void push(FrameKind frameKind, pybind11::object context,
+                   pybind11::object insertionPoint, pybind11::object location);
+
   /// An object reference to the PyContext.
   pybind11::object context;
   /// An object reference to the current insertion point.
   pybind11::object insertionPoint;
+  /// An object reference to the current location.
+  pybind11::object location;
+  // The kind of push that was performed.
+  FrameKind frameKind;
 };
 
 /// Wrapper around MlirContext.
@@ -172,14 +196,10 @@ class PyMlirContext {
   /// Used for testing.
   size_t getLiveModuleCount();
 
-  /// Creates an operation. See corresponding python docstring.
-  pybind11::object
-  createOperation(std::string name, PyLocation location,
-                  llvm::Optional<std::vector<PyValue *>> operands,
-                  llvm::Optional<std::vector<PyType *>> results,
-                  llvm::Optional<pybind11::dict> attributes,
-                  llvm::Optional<std::vector<PyBlock *>> successors,
-                  int regions);
+  /// Enter and exit the context manager.
+  pybind11::object contextEnter();
+  void contextExit(pybind11::object excType, pybind11::object excVal,
+                   pybind11::object excTb);
 
 private:
   PyMlirContext(MlirContext context);
@@ -213,6 +233,17 @@ class PyMlirContext {
   friend class PyOperation;
 };
 
+/// Used in function arguments when None should resolve to the current context
+/// manager set instance.
+class DefaultingPyMlirContext
+    : public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
+public:
+  using Defaulting::Defaulting;
+  static constexpr const char kTypeDescription[] =
+      "[ThreadContextAware] mlir.ir.Context";
+  static PyMlirContext &resolve();
+};
+
 /// Base class for all objects that directly or indirectly depend on an
 /// MlirContext. The lifetime of the context will extend at least to the
 /// lifetime of these instances.
@@ -275,9 +306,26 @@ class PyLocation : public BaseContextObject {
 public:
   PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
       : BaseContextObject(std::move(contextRef)), loc(loc) {}
+
+  /// Enter and exit the context manager.
+  pybind11::object contextEnter();
+  void contextExit(pybind11::object excType, pybind11::object excVal,
+                   pybind11::object excTb);
+
   MlirLocation loc;
 };
 
+/// Used in function arguments when None should resolve to the current context
+/// manager set instance.
+class DefaultingPyLocation
+    : public Defaulting<DefaultingPyLocation, PyLocation> {
+public:
+  using Defaulting::Defaulting;
+  static constexpr const char kTypeDescription[] =
+      "[ThreadContextAware] mlir.ir.Location";
+  static PyLocation &resolve();
+};
+
 /// Wrapper around MlirModule.
 /// This is the top-level, user-owned object that contains regions/ops/blocks.
 class PyModule;
@@ -376,6 +424,14 @@ class PyOperation : public BaseContextObject {
   /// no parent.
   PyOperationRef getParentOperation();
 
+  /// Creates an operation. See corresponding python docstring.
+  static pybind11::object
+  create(std::string name, llvm::Optional<std::vector<PyValue *>> operands,
+         llvm::Optional<std::vector<PyType *>> results,
+         llvm::Optional<pybind11::dict> attributes,
+         llvm::Optional<std::vector<PyBlock *>> successors, int regions,
+         DefaultingPyLocation location, pybind11::object ip);
+
 private:
   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
   static PyOperationRef createInstance(PyMlirContextRef contextRef,
@@ -478,6 +534,8 @@ class PyInsertionPoint {
   void contextExit(pybind11::object excType, pybind11::object excVal,
                    pybind11::object excTb);
 
+  PyBlock &getBlock() { return block; }
+
 private:
   // Trampoline constructor that avoids null initializing members while
   // looking up parents.
@@ -560,4 +618,17 @@ void populateIRSubmodule(pybind11::module &m);
 } // namespace python
 } // namespace mlir
 
+namespace pybind11 {
+namespace detail {
+
+template <>
+struct type_caster<mlir::python::DefaultingPyMlirContext>
+    : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {};
+template <>
+struct type_caster<mlir::python::DefaultingPyLocation>
+    : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};
+
+} // namespace detail
+} // namespace pybind11
+
 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H

diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 0435aa461809..c97b87173e87 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -15,13 +15,6 @@
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/Twine.h"
 
-namespace pybind11 {
-namespace detail {
-template <typename T>
-struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
-} // namespace detail
-} // namespace pybind11
-
 namespace mlir {
 namespace python {
 
@@ -32,7 +25,78 @@ namespace python {
 pybind11::error_already_set SetPyError(PyObject *excClass,
                                        const llvm::Twine &message);
 
+/// CRTP template for special wrapper types that are allowed to be passed in as
+/// 'None' function arguments and can be resolved by some global mechanic if
+/// so. Such types will raise an error if this global resolution fails, and
+/// it is actually illegal for them to ever be unresolved. From a user
+/// perspective, they behave like a smart ptr to the underlying type (i.e.
+/// 'get' method and operator-> overloaded).
+///
+/// Derived types must provide a method, which is called when an environmental
+/// resolution is required. It must raise an exception if resolution fails:
+///   static ReferrentTy &resolve()
+///
+/// They must also provide a parameter description that will be used in
+/// error messages about mismatched types:
+///   static constexpr const char kTypeDescription[] = "<Description>";
+
+template <typename DerivedTy, typename T>
+class Defaulting {
+public:
+  using ReferrentTy = T;
+  /// Type casters require the type to be default constructible, but using
+  /// such an instance is illegal.
+  Defaulting() = default;
+  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}
+
+  ReferrentTy *get() { return referrent; }
+  ReferrentTy *operator->() { return referrent; }
+
+private:
+  ReferrentTy *referrent = nullptr;
+};
+
 } // namespace python
 } // namespace mlir
 
+namespace pybind11 {
+namespace detail {
+
+template <typename DefaultingTy>
+struct MlirDefaultingCaster {
+  PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription));
+
+  bool load(pybind11::handle src, bool) {
+    if (src.is_none()) {
+      // Note that we do want an exception to propagate from here as it will be
+      // the most informative.
+      value = DefaultingTy{DefaultingTy::resolve()};
+      return true;
+    }
+
+    // Unlike many casters that chain, these casters are expected to always
+    // succeed, so instead of doing an isinstance check followed by a cast,
+    // just cast in one step and handle the exception. Returning false (vs
+    // letting the exception propagate) causes higher level signature parsing
+    // code to produce nice error messages (other than "Cannot cast...").
+    try {
+      value = DefaultingTy{
+          pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)};
+      return true;
+    } catch (std::exception &e) {
+      return false;
+    }
+  }
+
+  static handle cast(DefaultingTy src, return_value_policy policy,
+                     handle parent) {
+    return pybind11::cast(src, policy);
+  }
+};
+
+template <typename T>
+struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
+} // namespace detail
+} // namespace pybind11
+
 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H

diff --git a/mlir/lib/Bindings/Python/mlir/dialects/std.py b/mlir/lib/Bindings/Python/mlir/dialects/std.py
index 2afc642e0e3d..74f990cdb5ed 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/std.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/std.py
@@ -5,20 +5,22 @@
 # TODO: This file should be auto-generated.
 
 from . import _cext
+_ir = _cext.ir
 
 @_cext.register_dialect
-class _Dialect(_cext.ir.Dialect):
+class _Dialect(_ir.Dialect):
   # Special case: 'std' namespace aliases to the empty namespace.
   DIALECT_NAMESPACE = "std"
   pass
 
 @_cext.register_operation(_Dialect)
-class AddFOp(_cext.ir.OpView):
+class AddFOp(_ir.OpView):
   OPERATION_NAME = "std.addf"
 
-  def __init__(self, loc, lhs, rhs):
-    super().__init__(loc.context.create_operation(
-      "std.addf", loc, operands=[lhs, rhs], results=[lhs.type]))
+  def __init__(self, lhs, rhs, loc=None, ip=None):
+    super().__init__(_ir.Operation.create(
+      "std.addf", operands=[lhs, rhs], results=[lhs.type],
+      loc=loc, ip=ip))
 
   @property
   def lhs(self):
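
The new `AddFOp.__init__(self, lhs, rhs, loc=None, ip=None)` puts the defaultable parameters last, so callers inside `with` blocks can omit them. Below is a minimal sketch of that resolution order. It assumes, as the `insertion_point.py` tests in this commit suggest, that an operation created with neither an explicit nor an ambient insertion point is left detached. `create_operation`, `using`, and the list-as-block representation are hypothetical. The real `Operation.create` also takes results, attributes, successors and regions, which are omitted here.

```python
import contextlib
import threading

_tls = threading.local()


def _stack(kind: str) -> list:
    # Separate per-thread stacks for locations and insertion points.
    if not hasattr(_tls, "stacks"):
        _tls.stacks = {"loc": [], "ip": []}
    return _tls.stacks[kind]


@contextlib.contextmanager
def using(kind: str, value):
    """Push `value` for the duration of a with-block, popping even on error."""
    _stack(kind).append(value)
    try:
        yield value
    finally:
        _stack(kind).pop()


def create_operation(name, operands=(), loc=None, ip=None):
    if loc is None:
        locs = _stack("loc")
        if not locs:
            raise ValueError("No current Location")
        loc = locs[-1]
    if ip is None:
        ips = _stack("ip")
        ip = ips[-1] if ips else None  # nothing ambient: stay detached
    op = {"name": name, "operands": list(operands), "loc": loc}
    if ip is not None:
        ip.append(op)  # in this model a "block" is just a Python list
    return op


class AddFOp:
    OPERATION_NAME = "std.addf"

    def __init__(self, lhs, rhs, loc=None, ip=None):
        # Trailing loc/ip default to whatever the enclosing with-blocks set.
        self.operation = create_operation(
            self.OPERATION_NAME, [lhs, rhs], loc=loc, ip=ip)
```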

diff --git a/mlir/test/Bindings/Python/context_managers.py b/mlir/test/Bindings/Python/context_managers.py
new file mode 100644
index 000000000000..33a89381a416
--- /dev/null
+++ b/mlir/test/Bindings/Python/context_managers.py
@@ -0,0 +1,99 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+from mlir.ir import *
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert Context._get_live_count() == 0
+
+
+# CHECK-LABEL: TEST: testContextEnterExit
+def testContextEnterExit():
+  with Context() as ctx:
+    assert Context.current is ctx
+  try:
+    _ = Context.current
+  except ValueError as e:
+    # CHECK: No current Context
+    print(e)
+  else: assert False, "Expected exception"
+
+run(testContextEnterExit)
+
+
+# CHECK-LABEL: TEST: testLocationEnterExit
+def testLocationEnterExit():
+  ctx1 = Context()
+  with Location.unknown(ctx1) as loc1:
+    assert Context.current is ctx1
+    assert Location.current is loc1
+
+    # Re-asserting the same context should not change the location.
+    with ctx1:
+      assert Context.current is ctx1
+      assert Location.current is loc1
+      # Asserting a different context should clear it.
+      with Context() as ctx2:
+        assert Context.current is ctx2
+        try:
+          _ = Location.current
+        except ValueError: pass
+        else: assert False, "Expected exception"
+
+      # And should restore.
+      assert Context.current is ctx1
+      assert Location.current is loc1
+
+  # All should clear.
+  try:
+    _ = Location.current
+  except ValueError as e:
+    # CHECK: No current Location
+    print(e)
+  else: assert False, "Expected exception"
+
+run(testLocationEnterExit)
+
+
+# CHECK-LABEL: TEST: testInsertionPointEnterExit
+def testInsertionPointEnterExit():
+  ctx1 = Context()
+  m = Module.create(Location.unknown(ctx1))
+  ip = InsertionPoint.at_block_terminator(m.body)
+
+  with ip:
+    assert InsertionPoint.current is ip
+    # Asserting a location from the same context should preserve.
+    with Location.unknown(ctx1) as loc1:
+      assert InsertionPoint.current is ip
+      assert Location.current is loc1
+    # Location should clear.
+    try:
+      _ = Location.current
+    except ValueError: pass
+    else: assert False, "Expected exception"
+
+    # Asserting the same Context should preserve.
+    with ctx1:
+      assert InsertionPoint.current is ip
+
+    # Asserting a different context should clear it.
+    with Context() as ctx2:
+      assert Context.current is ctx2
+      try:
+        _ = InsertionPoint.current
+      except ValueError: pass
+      else: assert False, "Expected exception"
+
+  # All should clear.
+  try:
+    _ = InsertionPoint.current
+  except ValueError as e:
+    # CHECK: No current InsertionPoint
+    print(e)
+  else: assert False, "Expected exception"
+
+run(testInsertionPointEnterExit)

diff --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py
index 172258dd7840..e66c67f08095 100644
--- a/mlir/test/Bindings/Python/dialects.py
+++ b/mlir/test/Bindings/Python/dialects.py
@@ -1,18 +1,18 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 import gc
-import mlir
+from mlir.ir import *
 
 def run(f):
   print("\nTEST:", f.__name__)
   f()
   gc.collect()
-  assert mlir.ir.Context._get_live_count() == 0
+  assert Context._get_live_count() == 0
 
 
 # CHECK-LABEL: TEST: testDialectDescriptor
 def testDialectDescriptor():
-  ctx = mlir.ir.Context()
+  ctx = Context()
   d = ctx.get_dialect_descriptor("std")
   # CHECK: <DialectDescriptor std>
   print(d)
@@ -30,7 +30,7 @@ def testDialectDescriptor():
 
 # CHECK-LABEL: TEST: testUserDialectClass
 def testUserDialectClass():
-  ctx = mlir.ir.Context()
+  ctx = Context()
   # Access using attribute.
   d = ctx.dialects.std
   # Note that the standard dialect namespace prints as ''. Others will print
@@ -68,26 +68,25 @@ def testUserDialectClass():
 # TODO: Op creation and access is still quite verbose: simplify this test as
 # additional capabilities come online.
 def testCustomOpView():
-  ctx = mlir.ir.Context()
-  ctx.allow_unregistered_dialects = True
-  f32 = mlir.ir.F32Type.get(ctx)
-  loc = ctx.get_unknown_location()
-  m = ctx.create_module(loc)
-
   def createInput():
-    op = ctx.create_operation("pytest_dummy.intinput", loc, results=[f32])
+    op = Operation.create("pytest_dummy.intinput", results=[f32])
     # TODO: Auto result cast from operation
     return op.results[0]
 
-  with mlir.ir.InsertionPoint.at_block_terminator(m.body):
-    # Create via dialects context collection.
-    input1 = createInput()
-    input2 = createInput()
-    op1 = ctx.dialects.std.AddFOp(loc, input1, input2)
+  with Context() as ctx, Location.unknown():
+    ctx.allow_unregistered_dialects = True
+    m = Module.create()
+
+    with InsertionPoint.at_block_terminator(m.body):
+      f32 = F32Type.get()
+      # Create via dialects context collection.
+      input1 = createInput()
+      input2 = createInput()
+      op1 = ctx.dialects.std.AddFOp(input1, input2)
 
-    # Create via an import
-    from mlir.dialects.std import AddFOp
-    AddFOp(loc, input1, op1.result)
+      # Create via an import
+      from mlir.dialects.std import AddFOp
+      AddFOp(input1, op1.result)
 
   # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
   # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"

diff --git a/mlir/test/Bindings/Python/insertion_point.py b/mlir/test/Bindings/Python/insertion_point.py
index bbdd670e6aa7..d2b05b807978 100644
--- a/mlir/test/Bindings/Python/insertion_point.py
+++ b/mlir/test/Bindings/Python/insertion_point.py
@@ -16,18 +16,18 @@ def run(f):
 def test_insert_at_block_end():
   ctx = Context()
   ctx.allow_unregistered_dialects = True
-  loc = ctx.get_unknown_location()
-  module = ctx.parse_module(r"""
-    func @foo() -> () {
-      "custom.op1"() : () -> ()
-    }
-  """)
-  entry_block = module.body.operations[0].regions[0].blocks[0]
-  ip = InsertionPoint(entry_block)
-  ip.insert(ctx.create_operation("custom.op2", loc))
-  # CHECK: "custom.op1"
-  # CHECK: "custom.op2"
-  module.operation.print()
+  with Location.unknown(ctx):
+    module = Module.parse(r"""
+      func @foo() -> () {
+        "custom.op1"() : () -> ()
+      }
+    """)
+    entry_block = module.body.operations[0].regions[0].blocks[0]
+    ip = InsertionPoint(entry_block)
+    ip.insert(Operation.create("custom.op2"))
+    # CHECK: "custom.op1"
+    # CHECK: "custom.op2"
+    module.operation.print()
 
 run(test_insert_at_block_end)
 
@@ -36,20 +36,20 @@ def test_insert_at_block_end():
 def test_insert_before_operation():
   ctx = Context()
   ctx.allow_unregistered_dialects = True
-  loc = ctx.get_unknown_location()
-  module = ctx.parse_module(r"""
-    func @foo() -> () {
-      "custom.op1"() : () -> ()
-      "custom.op2"() : () -> ()
-    }
-  """)
-  entry_block = module.body.operations[0].regions[0].blocks[0]
-  ip = InsertionPoint(entry_block.operations[1])
-  ip.insert(ctx.create_operation("custom.op3", loc))
-  # CHECK: "custom.op1"
-  # CHECK: "custom.op3"
-  # CHECK: "custom.op2"
-  module.operation.print()
+  with Location.unknown(ctx):
+    module = Module.parse(r"""
+      func @foo() -> () {
+        "custom.op1"() : () -> ()
+        "custom.op2"() : () -> ()
+      }
+    """)
+    entry_block = module.body.operations[0].regions[0].blocks[0]
+    ip = InsertionPoint(entry_block.operations[1])
+    ip.insert(Operation.create("custom.op3"))
+    # CHECK: "custom.op1"
+    # CHECK: "custom.op3"
+    # CHECK: "custom.op2"
+    module.operation.print()
 
 run(test_insert_before_operation)
 
@@ -58,18 +58,18 @@ def test_insert_before_operation():
 def test_insert_at_block_begin():
   ctx = Context()
   ctx.allow_unregistered_dialects = True
-  loc = ctx.get_unknown_location()
-  module = ctx.parse_module(r"""
-    func @foo() -> () {
-      "custom.op2"() : () -> ()
-    }
-  """)
-  entry_block = module.body.operations[0].regions[0].blocks[0]
-  ip = InsertionPoint.at_block_begin(entry_block)
-  ip.insert(ctx.create_operation("custom.op1", loc))
-  # CHECK: "custom.op1"
-  # CHECK: "custom.op2"
-  module.operation.print()
+  with Location.unknown(ctx):
+    module = Module.parse(r"""
+      func @foo() -> () {
+        "custom.op2"() : () -> ()
+      }
+    """)
+    entry_block = module.body.operations[0].regions[0].blocks[0]
+    ip = InsertionPoint.at_block_begin(entry_block)
+    ip.insert(Operation.create("custom.op1"))
+    # CHECK: "custom.op1"
+    # CHECK: "custom.op2"
+    module.operation.print()
 
 run(test_insert_at_block_begin)
 
@@ -86,19 +86,19 @@ def test_insert_at_block_begin_empty():
 def test_insert_at_terminator():
   ctx = Context()
   ctx.allow_unregistered_dialects = True
-  loc = ctx.get_unknown_location()
-  module = ctx.parse_module(r"""
-    func @foo() -> () {
-      "custom.op1"() : () -> ()
-      return
-    }
-  """)
-  entry_block = module.body.operations[0].regions[0].blocks[0]
-  ip = InsertionPoint.at_block_terminator(entry_block)
-  ip.insert(ctx.create_operation("custom.op2", loc))
-  # CHECK: "custom.op1"
-  # CHECK: "custom.op2"
-  module.operation.print()
+  with Location.unknown(ctx):
+    module = Module.parse(r"""
+      func @foo() -> () {
+        "custom.op1"() : () -> ()
+        return
+      }
+    """)
+    entry_block = module.body.operations[0].regions[0].blocks[0]
+    ip = InsertionPoint.at_block_terminator(entry_block)
+    ip.insert(Operation.create("custom.op2"))
+    # CHECK: "custom.op1"
+    # CHECK: "custom.op2"
+    module.operation.print()
 
 run(test_insert_at_terminator)
 
@@ -107,20 +107,20 @@ def test_insert_at_terminator():
 def test_insert_at_block_terminator_missing():
   ctx = Context()
   ctx.allow_unregistered_dialects = True
-  loc = ctx.get_unknown_location()
-  module = ctx.parse_module(r"""
-    func @foo() -> () {
-      "custom.op1"() : () -> ()
-    }
-  """)
-  entry_block = module.body.operations[0].regions[0].blocks[0]
-  try:
-    ip = InsertionPoint.at_block_terminator(entry_block)
-  except ValueError as e:
-    # CHECK: Block has no terminator
-    print(e)
-  else:
-    assert False, "Expected exception"
+  with ctx:
+    module = Module.parse(r"""
+      func @foo() -> () {
+        "custom.op1"() : () -> ()
+      }
+    """)
+    entry_block = module.body.operations[0].regions[0].blocks[0]
+    try:
+      ip = InsertionPoint.at_block_terminator(entry_block)
+    except ValueError as e:
+      # CHECK: Block has no terminator
+      print(e)
+    else:
+      assert False, "Expected exception"
 
 run(test_insert_at_block_terminator_missing)
 
@@ -129,24 +129,24 @@ def test_insert_at_block_terminator_missing():
 def test_insertion_point_context():
   ctx = Context()
   ctx.allow_unregistered_dialects = True
-  loc = ctx.get_unknown_location()
-  module = ctx.parse_module(r"""
-    func @foo() -> () {
-      "custom.op1"() : () -> ()
-    }
-  """)
-  entry_block = module.body.operations[0].regions[0].blocks[0]
-  with InsertionPoint(entry_block):
-    ctx.create_operation("custom.op2", loc)
-    with InsertionPoint.at_block_begin(entry_block):
-      ctx.create_operation("custom.opa", loc)
-      ctx.create_operation("custom.opb", loc)
-    ctx.create_operation("custom.op3", loc)
-  # CHECK: "custom.opa"
-  # CHECK: "custom.opb"
-  # CHECK: "custom.op1"
-  # CHECK: "custom.op2"
-  # CHECK: "custom.op3"
-  module.operation.print()
+  with Location.unknown(ctx):
+    module = Module.parse(r"""
+      func @foo() -> () {
+        "custom.op1"() : () -> ()
+      }
+    """)
+    entry_block = module.body.operations[0].regions[0].blocks[0]
+    with InsertionPoint(entry_block):
+      Operation.create("custom.op2")
+      with InsertionPoint.at_block_begin(entry_block):
+        Operation.create("custom.opa")
+        Operation.create("custom.opb")
+      Operation.create("custom.op3")
+    # CHECK: "custom.opa"
+    # CHECK: "custom.opb"
+    # CHECK: "custom.op1"
+    # CHECK: "custom.op2"
+    # CHECK: "custom.op3"
+    module.operation.print()
 
 run(test_insertion_point_context)
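
`test_insertion_point_context` expects the order `opa, opb, op1, op2, op3`. That order follows if an insertion point is anchored before a reference operation, or at the block end when there is none, rather than at a fixed index. Under that rule, `at_block_begin` captures `custom.op1` as its anchor, so `opa` and `opb` both land in front of it in creation order. Below is a small Python model of that behaviour. `ToyBlock` and `ToyInsertionPoint` are hypothetical classes, and operations are unique name strings.

```python
from typing import List, Optional


class ToyBlock:
    def __init__(self, operations: List[str]):
        self.operations = list(operations)


class ToyInsertionPoint:
    """Inserts before `ref`, or appends at the block end if `ref` is None."""

    def __init__(self, block: ToyBlock, ref: Optional[str] = None):
        self.block = block
        self.ref = ref

    @classmethod
    def at_block_begin(cls, block: ToyBlock) -> "ToyInsertionPoint":
        # Anchor on the current first op; an empty block degenerates to "end".
        first = block.operations[0] if block.operations else None
        return cls(block, first)

    def insert(self, op: str) -> None:
        ops = self.block.operations
        if self.ref is None:
            ops.append(op)
        else:
            # Assumes op names are unique so index() finds the anchor.
            ops.insert(ops.index(self.ref), op)


def replay_nested_test() -> List[str]:
    # Mirrors the nesting in test_insertion_point_context.
    block = ToyBlock(["custom.op1"])
    outer = ToyInsertionPoint(block)                 # block end
    outer.insert("custom.op2")
    inner = ToyInsertionPoint.at_block_begin(block)  # before custom.op1
    inner.insert("custom.opa")
    inner.insert("custom.opb")
    outer.insert("custom.op3")
    return block.operations
```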

diff --git a/mlir/test/Bindings/Python/ir_array_attributes.py b/mlir/test/Bindings/Python/ir_array_attributes.py
index 97a9802ae148..74b5451aafe0 100644
--- a/mlir/test/Bindings/Python/ir_array_attributes.py
+++ b/mlir/test/Bindings/Python/ir_array_attributes.py
@@ -3,27 +3,27 @@
 # and we may want to disable if not available.
 
 import gc
-import mlir
+from mlir.ir import *
 import numpy as np
 
 def run(f):
   print("\nTEST:", f.__name__)
   f()
   gc.collect()
-  assert mlir.ir.Context._get_live_count() == 0
+  assert Context._get_live_count() == 0
 
 ################################################################################
 # Tests of the array/buffer .get() factory method on unsupported dtype.
 ################################################################################
 
 def testGetDenseElementsUnsupported():
-  ctx = mlir.ir.Context()
-  array = np.array([["hello", "goodbye"]])
-  try:
-    attr = mlir.ir.DenseElementsAttr.get(ctx, array)
-  except ValueError as e:
-    # CHECK: unimplemented array format conversion from format:
-    print(e)
+  with Context():
+    array = np.array([["hello", "goodbye"]])
+    try:
+      attr = DenseElementsAttr.get(array)
+    except ValueError as e:
+      # CHECK: unimplemented array format conversion from format:
+      print(e)
 
 run(testGetDenseElementsUnsupported)
 
@@ -33,63 +33,60 @@ def testGetDenseElementsUnsupported():
 
 # CHECK-LABEL: TEST: testGetDenseElementsSplatInt
 def testGetDenseElementsSplatInt():
-  ctx = mlir.ir.Context()
-  loc = ctx.get_unknown_location()
-  t = mlir.ir.IntegerType.get_signless(ctx, 32)
-  element = mlir.ir.IntegerAttr.get(t, 555)
-  shaped_type = mlir.ir.RankedTensorType.get((2, 3, 4), t, loc)
-  attr = mlir.ir.DenseElementsAttr.get_splat(shaped_type, element)
-  # CHECK: dense<555> : tensor<2x3x4xi32>
-  print(attr)
-  # CHECK: is_splat: True
-  print("is_splat:", attr.is_splat)
+  with Context(), Location.unknown():
+    t = IntegerType.get_signless(32)
+    element = IntegerAttr.get(t, 555)
+    shaped_type = RankedTensorType.get((2, 3, 4), t)
+    attr = DenseElementsAttr.get_splat(shaped_type, element)
+    # CHECK: dense<555> : tensor<2x3x4xi32>
+    print(attr)
+    # CHECK: is_splat: True
+    print("is_splat:", attr.is_splat)
 
 run(testGetDenseElementsSplatInt)
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
 def testGetDenseElementsSplatFloat():
-  ctx = mlir.ir.Context()
-  loc = ctx.get_unknown_location()
-  t = mlir.ir.F32Type.get(ctx)
-  element = mlir.ir.FloatAttr.get(t, 1.2, loc)
-  shaped_type = mlir.ir.RankedTensorType.get((2, 3, 4), t, loc)
-  attr = mlir.ir.DenseElementsAttr.get_splat(shaped_type, element)
-  # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
-  print(attr)
+  with Context(), Location.unknown():
+    t = F32Type.get()
+    element = FloatAttr.get(t, 1.2)
+    shaped_type = RankedTensorType.get((2, 3, 4), t)
+    attr = DenseElementsAttr.get_splat(shaped_type, element)
+    # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
+    print(attr)
 
 run(testGetDenseElementsSplatFloat)
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
 def testGetDenseElementsSplatErrors():
-  ctx = mlir.ir.Context()
-  loc = ctx.get_unknown_location()
-  t = mlir.ir.F32Type.get(ctx)
-  other_t = mlir.ir.F64Type.get(ctx)
-  element = mlir.ir.FloatAttr.get(t, 1.2, loc)
-  other_element = mlir.ir.FloatAttr.get(other_t, 1.2, loc)
-  shaped_type = mlir.ir.RankedTensorType.get((2, 3, 4), t, loc)
-  dynamic_shaped_type = mlir.ir.UnrankedTensorType.get(t, loc)
-  non_shaped_type = t
-
-  try:
-    attr = mlir.ir.DenseElementsAttr.get_splat(non_shaped_type, element)
-  except ValueError as e:
-    # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
-    print(e)
-
-  try:
-    attr = mlir.ir.DenseElementsAttr.get_splat(dynamic_shaped_type, element)
-  except ValueError as e:
-    # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
-    print(e)
-
-  try:
-    attr = mlir.ir.DenseElementsAttr.get_splat(shaped_type, other_element)
-  except ValueError as e:
-    # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
-    print(e)
+  with Context(), Location.unknown():
+    t = F32Type.get()
+    other_t = F64Type.get()
+    element = FloatAttr.get(t, 1.2)
+    other_element = FloatAttr.get(other_t, 1.2)
+    shaped_type = RankedTensorType.get((2, 3, 4), t)
+    dynamic_shaped_type = UnrankedTensorType.get(t)
+    non_shaped_type = t
+
+    try:
+      attr = DenseElementsAttr.get_splat(non_shaped_type, element)
+    except ValueError as e:
+      # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
+      print(e)
+
+    try:
+      attr = DenseElementsAttr.get_splat(dynamic_shaped_type, element)
+    except ValueError as e:
+      # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
+      print(e)
+
+    try:
+      attr = DenseElementsAttr.get_splat(shaped_type, other_element)
+    except ValueError as e:
+      # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
+      print(e)
 
 run(testGetDenseElementsSplatErrors)
 
@@ -102,24 +99,24 @@ def testGetDenseElementsSplatErrors():
 
 # CHECK-LABEL: TEST: testGetDenseElementsF32
 def testGetDenseElementsF32():
-  ctx = mlir.ir.Context()
-  array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)
-  attr = mlir.ir.DenseElementsAttr.get(ctx, array)
-  # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
-  print(attr)
-  # CHECK: is_splat: False
-  print("is_splat:", attr.is_splat)
+  with Context():
+    array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)
+    attr = DenseElementsAttr.get(array)
+    # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
+    print(attr)
+    # CHECK: is_splat: False
+    print("is_splat:", attr.is_splat)
 
 run(testGetDenseElementsF32)
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsF64
 def testGetDenseElementsF64():
-  ctx = mlir.ir.Context()
-  array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
-  attr = mlir.ir.DenseElementsAttr.get(ctx, array)
-  # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
-  print(attr)
+  with Context():
+    array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
+    attr = DenseElementsAttr.get(array)
+    # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
+    print(attr)
 
 run(testGetDenseElementsF64)
 
@@ -127,43 +124,43 @@ def testGetDenseElementsF64():
 ### 32 bit integer arrays
 # CHECK-LABEL: TEST: testGetDenseElementsI32Signless
 def testGetDenseElementsI32Signless():
-  ctx = mlir.ir.Context()
-  array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
-  attr = mlir.ir.DenseElementsAttr.get(ctx, array)
-  # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
-  print(attr)
+  with Context():
+    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
+    attr = DenseElementsAttr.get(array)
+    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
+    print(attr)
 
 run(testGetDenseElementsI32Signless)
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsUI32Signless
 def testGetDenseElementsUI32Signless():
-  ctx = mlir.ir.Context()
-  array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
-  attr = mlir.ir.DenseElementsAttr.get(ctx, array)
-  # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
-  print(attr)
+  with Context():
+    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
+    attr = DenseElementsAttr.get(array)
+    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
+    print(attr)
 
 run(testGetDenseElementsUI32Signless)
 
 # CHECK-LABEL: TEST: testGetDenseElementsI32
 def testGetDenseElementsI32():
-  ctx = mlir.ir.Context()
-  array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
-  attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False)
-  # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
-  print(attr)
+  with Context():
+    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
+    attr = DenseElementsAttr.get(array, signless=False)
+    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
+    print(attr)
 
 run(testGetDenseElementsI32)
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsUI32
 def testGetDenseElementsUI32():
-  ctx = mlir.ir.Context()
-  array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
-  attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False)
-  # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
-  print(attr)
+  with Context():
+    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
+    attr = DenseElementsAttr.get(array, signless=False)
+    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
+    print(attr)
 
 run(testGetDenseElementsUI32)
 
@@ -171,43 +168,43 @@ def testGetDenseElementsUI32():
 ## 64bit integer arrays
 # CHECK-LABEL: TEST: testGetDenseElementsI64Signless
 def testGetDenseElementsI64Signless():
-  ctx = mlir.ir.Context()
-  array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
-  attr = mlir.ir.DenseElementsAttr.get(ctx, array)
-  # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
-  print(attr)
+  with Context():
+    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
+    attr = DenseElementsAttr.get(array)
+    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
+    print(attr)
 
 run(testGetDenseElementsI64Signless)
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsUI64Signless
 def testGetDenseElementsUI64Signless():
-  ctx = mlir.ir.Context()
-  array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
-  attr = mlir.ir.DenseElementsAttr.get(ctx, array)
-  # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
-  print(attr)
+  with Context():
+    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
+    attr = DenseElementsAttr.get(array)
+    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
+    print(attr)
 
 run(testGetDenseElementsUI64Signless)
 
 # CHECK-LABEL: TEST: testGetDenseElementsI64
 def testGetDenseElementsI64():
-  ctx = mlir.ir.Context()
-  array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
-  attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False)
-  # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
-  print(attr)
+  with Context():
+    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
+    attr = DenseElementsAttr.get(array, signless=False)
+    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
+    print(attr)
 
 run(testGetDenseElementsI64)
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsUI64
 def testGetDenseElementsUI64():
-  ctx = mlir.ir.Context()
-  array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
-  attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False)
-  # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
-  print(attr)
+  with Context():
+    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
+    attr = DenseElementsAttr.get(array, signless=False)
+    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
+    print(attr)
 
 run(testGetDenseElementsUI64)
 

diff  --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
index d1f3e6b4a61a..39d69483d1b7 100644
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -1,19 +1,19 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 import gc
-import mlir
+from mlir.ir import *
 
 def run(f):
   print("\nTEST:", f.__name__)
   f()
   gc.collect()
-  assert mlir.ir.Context._get_live_count() == 0
+  assert Context._get_live_count() == 0
 
 
 # CHECK-LABEL: TEST: testParsePrint
 def testParsePrint():
-  ctx = mlir.ir.Context()
-  t = ctx.parse_attr('"hello"')
+  with Context() as ctx:
+    t = Attribute.parse('"hello"')
   assert t.context is ctx
   ctx = None
   gc.collect()
@@ -29,156 +29,155 @@ def testParsePrint():
 # TODO: Hook the diagnostic manager to capture a more meaningful error
 # message.
 def testParseError():
-  ctx = mlir.ir.Context()
-  try:
-    t = ctx.parse_attr("BAD_ATTR_DOES_NOT_EXIST")
-  except ValueError as e:
-    # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
-    print("testParseError:", e)
-  else:
-    print("Exception not produced")
+  with Context():
+    try:
+      t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
+    except ValueError as e:
+      # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
+      print("testParseError:", e)
+    else:
+      print("Exception not produced")
 
 run(testParseError)
 
 
 # CHECK-LABEL: TEST: testAttrEq
 def testAttrEq():
-  ctx = mlir.ir.Context()
-  a1 = ctx.parse_attr('"attr1"')
-  a2 = ctx.parse_attr('"attr2"')
-  a3 = ctx.parse_attr('"attr1"')
-  # CHECK: a1 == a1: True
-  print("a1 == a1:", a1 == a1)
-  # CHECK: a1 == a2: False
-  print("a1 == a2:", a1 == a2)
-  # CHECK: a1 == a3: True
-  print("a1 == a3:", a1 == a3)
-  # CHECK: a1 == None: False
-  print("a1 == None:", a1 == None)
+  with Context():
+    a1 = Attribute.parse('"attr1"')
+    a2 = Attribute.parse('"attr2"')
+    a3 = Attribute.parse('"attr1"')
+    # CHECK: a1 == a1: True
+    print("a1 == a1:", a1 == a1)
+    # CHECK: a1 == a2: False
+    print("a1 == a2:", a1 == a2)
+    # CHECK: a1 == a3: True
+    print("a1 == a3:", a1 == a3)
+    # CHECK: a1 == None: False
+    print("a1 == None:", a1 == None)
 
 run(testAttrEq)
 
 
 # CHECK-LABEL: TEST: testAttrEqDoesNotRaise
 def testAttrEqDoesNotRaise():
-  ctx = mlir.ir.Context()
-  a1 = ctx.parse_attr('"attr1"')
-  not_an_attr = "foo"
-  # CHECK: False
-  print(a1 == not_an_attr)
-  # CHECK: False
-  print(a1 == None)
-  # CHECK: True
-  print(a1 != None)
+  with Context():
+    a1 = Attribute.parse('"attr1"')
+    not_an_attr = "foo"
+    # CHECK: False
+    print(a1 == not_an_attr)
+    # CHECK: False
+    print(a1 == None)
+    # CHECK: True
+    print(a1 != None)
 
 run(testAttrEqDoesNotRaise)
 
 
 # CHECK-LABEL: TEST: testStandardAttrCasts
 def testStandardAttrCasts():
-  ctx = mlir.ir.Context()
-  a1 = ctx.parse_attr('"attr1"')
-  astr = mlir.ir.StringAttr(a1)
-  aself = mlir.ir.StringAttr(astr)
-  # CHECK: Attribute("attr1")
-  print(repr(astr))
-  try:
-    tillegal = mlir.ir.StringAttr(ctx.parse_attr("1.0"))
-  except ValueError as e:
-    # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
-    print("ValueError:", e)
-  else:
-    print("Exception not produced")
+  with Context():
+    a1 = Attribute.parse('"attr1"')
+    astr = StringAttr(a1)
+    aself = StringAttr(astr)
+    # CHECK: Attribute("attr1")
+    print(repr(astr))
+    try:
+      tillegal = StringAttr(Attribute.parse("1.0"))
+    except ValueError as e:
+      # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
+      print("ValueError:", e)
+    else:
+      print("Exception not produced")
 
 run(testStandardAttrCasts)
 
 
 # CHECK-LABEL: TEST: testFloatAttr
 def testFloatAttr():
-  ctx = mlir.ir.Context()
-  fattr = mlir.ir.FloatAttr(ctx.parse_attr("42.0 : f32"))
-  # CHECK: fattr value: 42.0
-  print("fattr value:", fattr.value)
-
-  # Test factory methods.
-  loc = ctx.get_unknown_location()
-  # CHECK: default_get: 4.200000e+01 : f32
-  print("default_get:", mlir.ir.FloatAttr.get(
-      mlir.ir.F32Type.get(ctx), 42.0, loc))
-  # CHECK: f32_get: 4.200000e+01 : f32
-  print("f32_get:", mlir.ir.FloatAttr.get_f32(ctx, 42.0))
-  # CHECK: f64_get: 4.200000e+01 : f64
-  print("f64_get:", mlir.ir.FloatAttr.get_f64(ctx, 42.0))
-  try:
-    fattr_invalid = mlir.ir.FloatAttr.get(
-        mlir.ir.IntegerType.get_signless(ctx, 32), 42, loc)
-  except ValueError as e:
-    # CHECK: invalid 'Type(i32)' and expected floating point type.
-    print(e)
-  else:
-    print("Exception not produced")
+  with Context(), Location.unknown():
+    fattr = FloatAttr(Attribute.parse("42.0 : f32"))
+    # CHECK: fattr value: 42.0
+    print("fattr value:", fattr.value)
+
+    # Test factory methods.
+    # CHECK: default_get: 4.200000e+01 : f32
+    print("default_get:", FloatAttr.get(
+        F32Type.get(), 42.0))
+    # CHECK: f32_get: 4.200000e+01 : f32
+    print("f32_get:", FloatAttr.get_f32(42.0))
+    # CHECK: f64_get: 4.200000e+01 : f64
+    print("f64_get:", FloatAttr.get_f64(42.0))
+    try:
+      fattr_invalid = FloatAttr.get(
+          IntegerType.get_signless(32), 42)
+    except ValueError as e:
+      # CHECK: invalid 'Type(i32)' and expected floating point type.
+      print(e)
+    else:
+      print("Exception not produced")
 
 run(testFloatAttr)
 
 
 # CHECK-LABEL: TEST: testIntegerAttr
 def testIntegerAttr():
-  ctx = mlir.ir.Context()
-  iattr = mlir.ir.IntegerAttr(ctx.parse_attr("42"))
-  # CHECK: iattr value: 42
-  print("iattr value:", iattr.value)
-  # CHECK: iattr type: i64
-  print("iattr type:", iattr.type)
-
-  # Test factory methods.
-  # CHECK: default_get: 42 : i32
-  print("default_get:", mlir.ir.IntegerAttr.get(
-      mlir.ir.IntegerType.get_signless(ctx, 32), 42))
+  with Context() as ctx:
+    iattr = IntegerAttr(Attribute.parse("42"))
+    # CHECK: iattr value: 42
+    print("iattr value:", iattr.value)
+    # CHECK: iattr type: i64
+    print("iattr type:", iattr.type)
+
+    # Test factory methods.
+    # CHECK: default_get: 42 : i32
+    print("default_get:", IntegerAttr.get(
+        IntegerType.get_signless(32), 42))
 
 run(testIntegerAttr)
 
 
 # CHECK-LABEL: TEST: testBoolAttr
 def testBoolAttr():
-  ctx = mlir.ir.Context()
-  battr = mlir.ir.BoolAttr(ctx.parse_attr("true"))
-  # CHECK: iattr value: 1
-  print("iattr value:", battr.value)
+  with Context() as ctx:
+    battr = BoolAttr(Attribute.parse("true"))
+    # CHECK: iattr value: 1
+    print("iattr value:", battr.value)
 
-  # Test factory methods.
-  # CHECK: default_get: true
-  print("default_get:", mlir.ir.BoolAttr.get(ctx, True))
+    # Test factory methods.
+    # CHECK: default_get: true
+    print("default_get:", BoolAttr.get(True))
 
 run(testBoolAttr)
 
 
 # CHECK-LABEL: TEST: testStringAttr
 def testStringAttr():
-  ctx = mlir.ir.Context()
-  sattr = mlir.ir.StringAttr(ctx.parse_attr('"stringattr"'))
-  # CHECK: sattr value: stringattr
-  print("sattr value:", sattr.value)
-
-  # Test factory methods.
-  # CHECK: default_get: "foobar"
-  print("default_get:", mlir.ir.StringAttr.get(ctx, "foobar"))
-  # CHECK: typed_get: "12345" : i32
-  print("typed_get:", mlir.ir.StringAttr.get_typed(
-      mlir.ir.IntegerType.get_signless(ctx, 32), "12345"))
+  with Context() as ctx:
+    sattr = StringAttr(Attribute.parse('"stringattr"'))
+    # CHECK: sattr value: stringattr
+    print("sattr value:", sattr.value)
+
+    # Test factory methods.
+    # CHECK: default_get: "foobar"
+    print("default_get:", StringAttr.get("foobar"))
+    # CHECK: typed_get: "12345" : i32
+    print("typed_get:", StringAttr.get_typed(
+        IntegerType.get_signless(32), "12345"))
 
 run(testStringAttr)
 
 
 # CHECK-LABEL: TEST: testNamedAttr
 def testNamedAttr():
-  ctx = mlir.ir.Context()
-  a = ctx.parse_attr('"stringattr"')
-  named = a.get_named("foobar")  # Note: under the small object threshold
-  # CHECK: attr: "stringattr"
-  print("attr:", named.attr)
-  # CHECK: name: foobar
-  print("name:", named.name)
-  # CHECK: named: NamedAttribute(foobar="stringattr")
-  print("named:", named)
+  with Context():
+    a = Attribute.parse('"stringattr"')
+    named = a.get_named("foobar")  # Note: under the small object threshold
+    # CHECK: attr: "stringattr"
+    print("attr:", named.attr)
+    # CHECK: name: foobar
+    print("name:", named.name)
+    # CHECK: named: NamedAttribute(foobar="stringattr")
+    print("named:", named)
 
 run(testNamedAttr)

diff  --git a/mlir/test/Bindings/Python/ir_location.py b/mlir/test/Bindings/Python/ir_location.py
index f7e99242b8ad..e0d1bf299f5b 100644
--- a/mlir/test/Bindings/Python/ir_location.py
+++ b/mlir/test/Bindings/Python/ir_location.py
@@ -1,19 +1,19 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 import gc
-import mlir
+from mlir.ir import *
 
 def run(f):
   print("\nTEST:", f.__name__)
   f()
   gc.collect()
-  assert mlir.ir.Context._get_live_count() == 0
+  assert Context._get_live_count() == 0
 
 
 # CHECK-LABEL: TEST: testUnknown
 def testUnknown():
-  ctx = mlir.ir.Context()
-  loc = ctx.get_unknown_location()
+  with Context() as ctx:
+    loc = Location.unknown()
   assert loc.context is ctx
   ctx = None
   gc.collect()
@@ -27,8 +27,8 @@ def testUnknown():
 
 # CHECK-LABEL: TEST: testFileLineCol
 def testFileLineCol():
-  ctx = mlir.ir.Context()
-  loc = ctx.get_file_location("foo.txt", 123, 56)
+  with Context() as ctx:
+    loc = Location.file("foo.txt", 123, 56)
   ctx = None
   gc.collect()
   # CHECK: file str: loc("foo.txt":123:56)
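Note: forms like `with Context(), Location.unknown():` and `with Location.unknown(ctx):` work because entering a Location also establishes that location's context. The commit message describes the nesting rule as inheriting or clearing values depending on whether the context is the same. The following is a simplified model of that rule. It assumes that each `with` pushes a frame of (context, insertion point, location) onto a thread-local stack. _Frame, _push, current and the toy Context/Location classes are hypothetical names, not the bindings' internals.

```python
import threading

_tls = threading.local()


def _stack():
    if not hasattr(_tls, "frames"):
        _tls.frames = []
    return _tls.frames


class _Frame:
    def __init__(self, context, insertion_point=None, location=None):
        self.context = context
        self.insertion_point = insertion_point
        self.location = location


def _push(frame):
    frames = _stack()
    if frames and frames[-1].context is frame.context:
        prev = frames[-1]
        # Same context: inherit whatever the new frame did not set itself.
        if frame.insertion_point is None:
            frame.insertion_point = prev.insertion_point
        if frame.location is None:
            frame.location = prev.location
    # Different context: nothing is inherited, so stale values are "cleared".
    frames.append(frame)


def current():
    frames = _stack()
    return frames[-1] if frames else None


class Context:
    def __enter__(self):
        _push(_Frame(self))
        return self

    def __exit__(self, exc_type, exc, tb):
        _stack().pop()
        return False


class Location:
    def __init__(self, name, context):
        self.name = name
        self.context = context

    def __enter__(self):
        # Entering a location also makes its context current.
        _push(_Frame(self.context, location=self))
        return self

    def __exit__(self, exc_type, exc, tb):
        _stack().pop()
        return False


ctx = Context()
with ctx, Location("unknown", ctx) as loc:
    assert current().location is loc
    with ctx:                    # re-entering the same context keeps loc
        assert current().location is loc
    with Context():              # a different context starts clean
        assert current().location is None
```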

diff  --git a/mlir/test/Bindings/Python/ir_module.py b/mlir/test/Bindings/Python/ir_module.py
index 5f3403809e83..7a270b8bc33c 100644
--- a/mlir/test/Bindings/Python/ir_module.py
+++ b/mlir/test/Bindings/Python/ir_module.py
@@ -1,21 +1,21 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 import gc
-import mlir
+from mlir.ir import *
 
 def run(f):
   print("\nTEST:", f.__name__)
   f()
   gc.collect()
-  assert mlir.ir.Context._get_live_count() == 0
+  assert Context._get_live_count() == 0
 
 
 # Verify successful parse.
 # CHECK-LABEL: TEST: testParseSuccess
 # CHECK: module @successfulParse
 def testParseSuccess():
-  ctx = mlir.ir.Context()
-  module = ctx.parse_module(r"""module @successfulParse {}""")
+  ctx = Context()
+  module = Module.parse(r"""module @successfulParse {}""", ctx)
   assert module.context is ctx
   print("CLEAR CONTEXT")
   ctx = None  # Ensure that module captures the context.
@@ -30,9 +30,9 @@ def testParseSuccess():
 # CHECK-LABEL: TEST: testParseError
 # CHECK: testParseError: Unable to parse module assembly (see diagnostics)
 def testParseError():
-  ctx = mlir.ir.Context()
+  ctx = Context()
   try:
-    module = ctx.parse_module(r"""}SYNTAX ERROR{""")
+    module = Module.parse(r"""}SYNTAX ERROR{""", ctx)
   except ValueError as e:
     print("testParseError:", e)
   else:
@@ -45,9 +45,9 @@ def testParseError():
 # CHECK-LABEL: TEST: testCreateEmpty
 # CHECK: module {
 def testCreateEmpty():
-  ctx = mlir.ir.Context()
-  loc = ctx.get_unknown_location()
-  module = ctx.create_module(loc)
+  ctx = Context()
+  loc = Location.unknown(ctx)
+  module = Module.create(loc)
   print("CLEAR CONTEXT")
   ctx = None  # Ensure that module captures the context.
   gc.collect()
@@ -63,10 +63,10 @@ def testCreateEmpty():
 # CHECK: func @roundtripUnicode()
 # CHECK: foo = "\F0\9F\98\8A"
 def testRoundtripUnicode():
-  ctx = mlir.ir.Context()
-  module = ctx.parse_module(r"""
+  ctx = Context()
+  module = Module.parse(r"""
     func @roundtripUnicode() attributes { foo = "😊" }
-  """)
+  """, ctx)
   print(str(module))
 
 run(testRoundtripUnicode)
@@ -75,8 +75,8 @@ def testRoundtripUnicode():
 # Tests that module.operation works and correctly interns instances.
 # CHECK-LABEL: TEST: testModuleOperation
 def testModuleOperation():
-  ctx = mlir.ir.Context()
-  module = ctx.parse_module(r"""module @successfulParse {}""")
+  ctx = Context()
+  module = Module.parse(r"""module @successfulParse {}""", ctx)
   assert ctx._get_live_module_count() == 1
   op1 = module.operation
   assert ctx._get_live_operation_count() == 1
@@ -106,13 +106,13 @@ def testModuleOperation():
 
 # CHECK-LABEL: TEST: testModuleCapsule
 def testModuleCapsule():
-  ctx = mlir.ir.Context()
-  module = ctx.parse_module(r"""module @successfulParse {}""")
+  ctx = Context()
+  module = Module.parse(r"""module @successfulParse {}""", ctx)
   assert ctx._get_live_module_count() == 1
   # CHECK: "mlir.ir.Module._CAPIPtr"
   module_capsule = module._CAPIPtr
   print(module_capsule)
-  module_dup = mlir.ir.Module._CAPICreate(module_capsule)
+  module_dup = Module._CAPICreate(module_capsule)
   assert module is module_dup
   assert module_dup.context is ctx
   # Gc and verify destructed.

diff  --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index 8bc2ced60dca..9e0ba3071073 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -3,26 +3,26 @@
 import gc
 import io
 import itertools
-import mlir
+from mlir.ir import *
 
 def run(f):
   print("\nTEST:", f.__name__)
   f()
   gc.collect()
-  assert mlir.ir.Context._get_live_count() == 0
+  assert Context._get_live_count() == 0
 
 
 # Verify iterator based traversal of the op/region/block hierarchy.
 # CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
 def testTraverseOpRegionBlockIterators():
-  ctx = mlir.ir.Context()
+  ctx = Context()
   ctx.allow_unregistered_dialects = True
-  module = ctx.parse_module(r"""
+  module = Module.parse(r"""
     func @f1(%arg0: i32) -> i32 {
       %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
       return %1 : i32
     }
-  """)
+  """, ctx)
   op = module.operation
   assert op.context is ctx
   # Get the block using iterators off of the named collections.
@@ -69,14 +69,14 @@ def walk_operations(indent, op):
 # Verify index based traversal of the op/region/block hierarchy.
 # CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
 def testTraverseOpRegionBlockIndices():
-  ctx = mlir.ir.Context()
+  ctx = Context()
   ctx.allow_unregistered_dialects = True
-  module = ctx.parse_module(r"""
+  module = Module.parse(r"""
     func @f1(%arg0: i32) -> i32 {
       %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
       return %1 : i32
     }
-  """)
+  """, ctx)
 
   def walk_operations(indent, op):
     for i in range(len(op.regions)):
@@ -105,28 +105,28 @@ def walk_operations(indent, op):
 
 # CHECK-LABEL: TEST: testBlockArgumentList
 def testBlockArgumentList():
-  ctx = mlir.ir.Context()
-  module = ctx.parse_module(r"""
-    func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
-      return
-    }
-  """)
-  func = module.body.operations[0]
-  entry_block = func.regions[0].blocks[0]
-  assert len(entry_block.arguments) == 3
-  # CHECK: Argument 0, type i32
-  # CHECK: Argument 1, type f64
-  # CHECK: Argument 2, type index
-  for arg in entry_block.arguments:
-    print(f"Argument {arg.arg_number}, type {arg.type}")
-    new_type = mlir.ir.IntegerType.get_signless(ctx, 8 * (arg.arg_number + 1))
-    arg.set_type(new_type)
-
-  # CHECK: Argument 0, type i8
-  # CHECK: Argument 1, type i16
-  # CHECK: Argument 2, type i24
-  for arg in entry_block.arguments:
-    print(f"Argument {arg.arg_number}, type {arg.type}")
+  with Context() as ctx:
+    module = Module.parse(r"""
+      func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
+        return
+      }
+    """, ctx)
+    func = module.body.operations[0]
+    entry_block = func.regions[0].blocks[0]
+    assert len(entry_block.arguments) == 3
+    # CHECK: Argument 0, type i32
+    # CHECK: Argument 1, type f64
+    # CHECK: Argument 2, type index
+    for arg in entry_block.arguments:
+      print(f"Argument {arg.arg_number}, type {arg.type}")
+      new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
+      arg.set_type(new_type)
+
+    # CHECK: Argument 0, type i8
+    # CHECK: Argument 1, type i16
+    # CHECK: Argument 2, type i24
+    for arg in entry_block.arguments:
+      print(f"Argument {arg.arg_number}, type {arg.type}")
 
 
 run(testBlockArgumentList)
@@ -134,18 +134,18 @@ def testBlockArgumentList():
 
 # CHECK-LABEL: TEST: testDetachedOperation
 def testDetachedOperation():
-  ctx = mlir.ir.Context()
+  ctx = Context()
   ctx.allow_unregistered_dialects = True
-  loc = ctx.get_unknown_location()
-  i32 = mlir.ir.IntegerType.get_signed(ctx, 32)
-  op1 = ctx.create_operation(
-      "custom.op1", loc, results=[i32, i32], regions=1, attributes={
-          "foo": mlir.ir.StringAttr.get(ctx, "foo_value"),
-          "bar": mlir.ir.StringAttr.get(ctx, "bar_value"),
-      })
-  # CHECK: %0:2 = "custom.op1"() ( {
-  # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
-  print(op1)
+  with Location.unknown(ctx):
+    i32 = IntegerType.get_signed(32)
+    op1 = Operation.create(
+        "custom.op1", results=[i32, i32], regions=1, attributes={
+            "foo": StringAttr.get("foo_value"),
+            "bar": StringAttr.get("bar_value"),
+        })
+    # CHECK: %0:2 = "custom.op1"() ( {
+    # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
+    print(op1)
 
   # TODO: Check successors once enough infra exists to do it properly.
 
@@ -154,30 +154,30 @@ def testDetachedOperation():
 
 # CHECK-LABEL: TEST: testOperationInsertionPoint
 def testOperationInsertionPoint():
-  ctx = mlir.ir.Context()
+  ctx = Context()
   ctx.allow_unregistered_dialects = True
-  module = ctx.parse_module(r"""
+  module = Module.parse(r"""
     func @f1(%arg0: i32) -> i32 {
       %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
       return %1 : i32
     }
-  """)
+  """, ctx)
 
   # Create test op.
-  loc = ctx.get_unknown_location()
-  op1 = ctx.create_operation("custom.op1", loc)
-  op2 = ctx.create_operation("custom.op2", loc)
-
-  func = module.body.operations[0]
-  entry_block = func.regions[0].blocks[0]
-  ip = mlir.ir.InsertionPoint.at_block_begin(entry_block)
-  ip.insert(op1)
-  ip.insert(op2)
-  # CHECK: func @f1
-  # CHECK: "custom.op1"()
-  # CHECK: "custom.op2"()
-  # CHECK: %0 = "custom.addi"
-  print(module)
+  with Location.unknown(ctx):
+    op1 = Operation.create("custom.op1")
+    op2 = Operation.create("custom.op2")
+
+    func = module.body.operations[0]
+    entry_block = func.regions[0].blocks[0]
+    ip = InsertionPoint.at_block_begin(entry_block)
+    ip.insert(op1)
+    ip.insert(op2)
+    # CHECK: func @f1
+    # CHECK: "custom.op1"()
+    # CHECK: "custom.op2"()
+    # CHECK: %0 = "custom.addi"
+    print(module)
 
   # Trying to add a previously added op should raise.
   try:
@@ -192,55 +192,55 @@ def testOperationInsertionPoint():
 
 # CHECK-LABEL: TEST: testOperationWithRegion
 def testOperationWithRegion():
-  ctx = mlir.ir.Context()
+  ctx = Context()
   ctx.allow_unregistered_dialects = True
-  loc = ctx.get_unknown_location()
-  i32 = mlir.ir.IntegerType.get_signed(ctx, 32)
-  op1 = ctx.create_operation("custom.op1", loc, regions=1)
-  block = op1.regions[0].blocks.append(i32, i32)
-  # CHECK: "custom.op1"() ( {
-  # CHECK: ^bb0(%arg0: si32, %arg1: si32):  // no predecessors
-  # CHECK:   "custom.terminator"() : () -> ()
-  # CHECK: }) : () -> ()
-  terminator = ctx.create_operation("custom.terminator", loc)
-  ip = mlir.ir.InsertionPoint(block)
-  ip.insert(terminator)
-  print(op1)
-
-  # Now add the whole operation to another op.
-  # TODO: Verify lifetime hazard by nulling out the new owning module and
-  # accessing op1.
-  # TODO: Also verify accessing the terminator once both parents are nulled
-  # out.
-  module = ctx.parse_module(r"""
-    func @f1(%arg0: i32) -> i32 {
-      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
-      return %1 : i32
-    }
-  """)
-  func = module.body.operations[0]
-  entry_block = func.regions[0].blocks[0]
-  ip = mlir.ir.InsertionPoint.at_block_begin(entry_block)
-  ip.insert(op1)
-  # CHECK: func @f1
-  # CHECK: "custom.op1"()
-  # CHECK:   "custom.terminator"
-  # CHECK: %0 = "custom.addi"
-  print(module)
+  with Location.unknown(ctx):
+    i32 = IntegerType.get_signed(32)
+    op1 = Operation.create("custom.op1", regions=1)
+    block = op1.regions[0].blocks.append(i32, i32)
+    # CHECK: "custom.op1"() ( {
+    # CHECK: ^bb0(%arg0: si32, %arg1: si32):  // no predecessors
+    # CHECK:   "custom.terminator"() : () -> ()
+    # CHECK: }) : () -> ()
+    terminator = Operation.create("custom.terminator")
+    ip = InsertionPoint(block)
+    ip.insert(terminator)
+    print(op1)
+
+    # Now add the whole operation to another op.
+    # TODO: Verify lifetime hazard by nulling out the new owning module and
+    # accessing op1.
+    # TODO: Also verify accessing the terminator once both parents are nulled
+    # out.
+    module = Module.parse(r"""
+      func @f1(%arg0: i32) -> i32 {
+        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
+        return %1 : i32
+      }
+    """)
+    func = module.body.operations[0]
+    entry_block = func.regions[0].blocks[0]
+    ip = InsertionPoint.at_block_begin(entry_block)
+    ip.insert(op1)
+    # CHECK: func @f1
+    # CHECK: "custom.op1"()
+    # CHECK:   "custom.terminator"
+    # CHECK: %0 = "custom.addi"
+    print(module)
 
 run(testOperationWithRegion)
 
 
 # CHECK-LABEL: TEST: testOperationResultList
 def testOperationResultList():
-  ctx = mlir.ir.Context()
-  module = ctx.parse_module(r"""
+  ctx = Context()
+  module = Module.parse(r"""
     func @f1() {
       %0:3 = call @f2() : () -> (i32, f64, index)
       return
     }
     func @f2() -> (i32, f64, index)
-  """)
+  """, ctx)
   caller = module.body.operations[0]
   call = caller.regions[0].blocks[0].operations[0]
   assert len(call.results) == 3
@@ -256,13 +256,13 @@ def testOperationResultList():
 
 # CHECK-LABEL: TEST: testOperationPrint
 def testOperationPrint():
-  ctx = mlir.ir.Context()
-  module = ctx.parse_module(r"""
+  ctx = Context()
+  module = Module.parse(r"""
     func @f1(%arg0: i32) -> i32 {
       %0 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
       return %arg0 : i32
     }
-  """)
+  """, ctx)
 
   # Test print to stdout.
   # CHECK: return %arg0 : i32

diff  --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index 151a4679bd8c..ff058cb3bf93 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -1,19 +1,19 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 import gc
-import mlir
+from mlir.ir import *
 
 def run(f):
   print("\nTEST:", f.__name__)
   f()
   gc.collect()
-  assert mlir.ir.Context._get_live_count() == 0
+  assert Context._get_live_count() == 0
 
 
 # CHECK-LABEL: TEST: testParsePrint
 def testParsePrint():
-  ctx = mlir.ir.Context()
-  t = ctx.parse_type("i32")
+  ctx = Context()
+  t = Type.parse("i32", ctx)
   assert t.context is ctx
   ctx = None
   gc.collect()
@@ -29,9 +29,9 @@ def testParsePrint():
 # TODO: Hook the diagnostic manager to capture a more meaningful error
 # message.
 def testParseError():
-  ctx = mlir.ir.Context()
+  ctx = Context()
   try:
-    t = ctx.parse_type("BAD_TYPE_DOES_NOT_EXIST")
+    t = Type.parse("BAD_TYPE_DOES_NOT_EXIST", ctx)
   except ValueError as e:
     # CHECK: Unable to parse type: 'BAD_TYPE_DOES_NOT_EXIST'
     print("testParseError:", e)
@@ -43,10 +43,10 @@ def testParseError():
 
 # CHECK-LABEL: TEST: testTypeEq
 def testTypeEq():
-  ctx = mlir.ir.Context()
-  t1 = ctx.parse_type("i32")
-  t2 = ctx.parse_type("f32")
-  t3 = ctx.parse_type("i32")
+  ctx = Context()
+  t1 = Type.parse("i32", ctx)
+  t2 = Type.parse("f32", ctx)
+  t3 = Type.parse("i32", ctx)
   # CHECK: t1 == t1: True
   print("t1 == t1:", t1 == t1)
   # CHECK: t1 == t2: False
@@ -61,8 +61,8 @@ def testTypeEq():
 
 # CHECK-LABEL: TEST: testTypeEqDoesNotRaise
 def testTypeEqDoesNotRaise():
-  ctx = mlir.ir.Context()
-  t1 = ctx.parse_type("i32")
+  ctx = Context()
+  t1 = Type.parse("i32", ctx)
   not_a_type = "foo"
   # CHECK: False
   print(t1 == not_a_type)
@@ -76,14 +76,14 @@ def testTypeEqDoesNotRaise():
 
 # CHECK-LABEL: TEST: testStandardTypeCasts
 def testStandardTypeCasts():
-  ctx = mlir.ir.Context()
-  t1 = ctx.parse_type("i32")
-  tint = mlir.ir.IntegerType(t1)
-  tself = mlir.ir.IntegerType(tint)
+  ctx = Context()
+  t1 = Type.parse("i32", ctx)
+  tint = IntegerType(t1)
+  tself = IntegerType(tint)
   # CHECK: Type(i32)
   print(repr(tint))
   try:
-    tillegal = mlir.ir.IntegerType(ctx.parse_type("f32"))
+    tillegal = IntegerType(Type.parse("f32", ctx))
   except ValueError as e:
     # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
     print("ValueError:", e)
@@ -95,91 +95,91 @@ def testStandardTypeCasts():
 
 # CHECK-LABEL: TEST: testIntegerType
 def testIntegerType():
-  ctx = mlir.ir.Context()
-  i32 = mlir.ir.IntegerType(ctx.parse_type("i32"))
-  # CHECK: i32 width: 32
-  print("i32 width:", i32.width)
-  # CHECK: i32 signless: True
-  print("i32 signless:", i32.is_signless)
-  # CHECK: i32 signed: False
-  print("i32 signed:", i32.is_signed)
-  # CHECK: i32 unsigned: False
-  print("i32 unsigned:", i32.is_unsigned)
-
-  s32 = mlir.ir.IntegerType(ctx.parse_type("si32"))
-  # CHECK: s32 signless: False
-  print("s32 signless:", s32.is_signless)
-  # CHECK: s32 signed: True
-  print("s32 signed:", s32.is_signed)
-  # CHECK: s32 unsigned: False
-  print("s32 unsigned:", s32.is_unsigned)
-
-  u32 = mlir.ir.IntegerType(ctx.parse_type("ui32"))
-  # CHECK: u32 signless: False
-  print("u32 signless:", u32.is_signless)
-  # CHECK: u32 signed: False
-  print("u32 signed:", u32.is_signed)
-  # CHECK: u32 unsigned: True
-  print("u32 unsigned:", u32.is_unsigned)
-
-  # CHECK: signless: i16
-  print("signless:", mlir.ir.IntegerType.get_signless(ctx, 16))
-  # CHECK: signed: si8
-  print("signed:", mlir.ir.IntegerType.get_signed(ctx, 8))
-  # CHECK: unsigned: ui64
-  print("unsigned:", mlir.ir.IntegerType.get_unsigned(ctx, 64))
+  with Context() as ctx:
+    i32 = IntegerType(Type.parse("i32"))
+    # CHECK: i32 width: 32
+    print("i32 width:", i32.width)
+    # CHECK: i32 signless: True
+    print("i32 signless:", i32.is_signless)
+    # CHECK: i32 signed: False
+    print("i32 signed:", i32.is_signed)
+    # CHECK: i32 unsigned: False
+    print("i32 unsigned:", i32.is_unsigned)
+
+    s32 = IntegerType(Type.parse("si32"))
+    # CHECK: s32 signless: False
+    print("s32 signless:", s32.is_signless)
+    # CHECK: s32 signed: True
+    print("s32 signed:", s32.is_signed)
+    # CHECK: s32 unsigned: False
+    print("s32 unsigned:", s32.is_unsigned)
+
+    u32 = IntegerType(Type.parse("ui32"))
+    # CHECK: u32 signless: False
+    print("u32 signless:", u32.is_signless)
+    # CHECK: u32 signed: False
+    print("u32 signed:", u32.is_signed)
+    # CHECK: u32 unsigned: True
+    print("u32 unsigned:", u32.is_unsigned)
+
+    # CHECK: signless: i16
+    print("signless:", IntegerType.get_signless(16))
+    # CHECK: signed: si8
+    print("signed:", IntegerType.get_signed(8))
+    # CHECK: unsigned: ui64
+    print("unsigned:", IntegerType.get_unsigned(64))
 
 run(testIntegerType)
 
 # CHECK-LABEL: TEST: testIndexType
 def testIndexType():
-  ctx = mlir.ir.Context()
-  # CHECK: index type: index
-  print("index type:", mlir.ir.IndexType.get(ctx))
+  with Context() as ctx:
+    # CHECK: index type: index
+    print("index type:", IndexType.get())
 
 run(testIndexType)
 
 # CHECK-LABEL: TEST: testFloatType
 def testFloatType():
-  ctx = mlir.ir.Context()
-  # CHECK: float: bf16
-  print("float:", mlir.ir.BF16Type.get(ctx))
-  # CHECK: float: f16
-  print("float:", mlir.ir.F16Type.get(ctx))
-  # CHECK: float: f32
-  print("float:", mlir.ir.F32Type.get(ctx))
-  # CHECK: float: f64
-  print("float:", mlir.ir.F64Type.get(ctx))
+  with Context():
+    # CHECK: float: bf16
+    print("float:", BF16Type.get())
+    # CHECK: float: f16
+    print("float:", F16Type.get())
+    # CHECK: float: f32
+    print("float:", F32Type.get())
+    # CHECK: float: f64
+    print("float:", F64Type.get())
 
 run(testFloatType)
 
 # CHECK-LABEL: TEST: testNoneType
 def testNoneType():
-  ctx = mlir.ir.Context()
-  # CHECK: none type: none
-  print("none type:", mlir.ir.NoneType.get(ctx))
+  with Context():
+    # CHECK: none type: none
+    print("none type:", NoneType.get())
 
 run(testNoneType)
 
 # CHECK-LABEL: TEST: testComplexType
 def testComplexType():
-  ctx = mlir.ir.Context()
-  complex_i32 = mlir.ir.ComplexType(ctx.parse_type("complex<i32>"))
-  # CHECK: complex type element: i32
-  print("complex type element:", complex_i32.element_type)
-
-  f32 = mlir.ir.F32Type.get(ctx)
-  # CHECK: complex type: complex<f32>
-  print("complex type:", mlir.ir.ComplexType.get(f32))
-
-  index = mlir.ir.IndexType.get(ctx)
-  try:
-    complex_invalid = mlir.ir.ComplexType.get(index)
-  except ValueError as e:
-    # CHECK: invalid 'Type(index)' and expected floating point or integer type.
-    print(e)
-  else:
-    print("Exception not produced")
+  with Context() as ctx:
+    complex_i32 = ComplexType(Type.parse("complex<i32>"))
+    # CHECK: complex type element: i32
+    print("complex type element:", complex_i32.element_type)
+
+    f32 = F32Type.get()
+    # CHECK: complex type: complex<f32>
+    print("complex type:", ComplexType.get(f32))
+
+    index = IndexType.get()
+    try:
+      complex_invalid = ComplexType.get(index)
+    except ValueError as e:
+      # CHECK: invalid 'Type(index)' and expected floating point or integer type.
+      print(e)
+    else:
+      print("Exception not produced")
 
 run(testComplexType)
 
@@ -188,26 +188,26 @@ def testComplexType():
 # vectors, memrefs and tensors, so this test case uses an instance of vector
 # to test the shaped type. The class hierarchy is preserved on the python side.
 def testConcreteShapedType():
-  ctx = mlir.ir.Context()
-  vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>"))
-  # CHECK: element type: f32
-  print("element type:", vector.element_type)
-  # CHECK: whether the given shaped type is ranked: True
-  print("whether the given shaped type is ranked:", vector.has_rank)
-  # CHECK: rank: 2
-  print("rank:", vector.rank)
-  # CHECK: whether the shaped type has a static shape: True
-  print("whether the shaped type has a static shape:", vector.has_static_shape)
-  # CHECK: whether the dim-th dimension is dynamic: False
-  print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
-  # CHECK: dim size: 3
-  print("dim size:", vector.get_dim_size(1))
-  # CHECK: is_dynamic_size: False
-  print("is_dynamic_size:", vector.is_dynamic_size(3))
-  # CHECK: is_dynamic_stride_or_offset: False
-  print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1))
-  # CHECK: isinstance(ShapedType): True
-  print("isinstance(ShapedType):", isinstance(vector, mlir.ir.ShapedType))
+  with Context() as ctx:
+    vector = VectorType(Type.parse("vector<2x3xf32>"))
+    # CHECK: element type: f32
+    print("element type:", vector.element_type)
+    # CHECK: whether the given shaped type is ranked: True
+    print("whether the given shaped type is ranked:", vector.has_rank)
+    # CHECK: rank: 2
+    print("rank:", vector.rank)
+    # CHECK: whether the shaped type has a static shape: True
+    print("whether the shaped type has a static shape:", vector.has_static_shape)
+    # CHECK: whether the dim-th dimension is dynamic: False
+    print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
+    # CHECK: dim size: 3
+    print("dim size:", vector.get_dim_size(1))
+    # CHECK: is_dynamic_size: False
+    print("is_dynamic_size:", vector.is_dynamic_size(3))
+    # CHECK: is_dynamic_stride_or_offset: False
+    print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1))
+    # CHECK: isinstance(ShapedType): True
+    print("isinstance(ShapedType):", isinstance(vector, ShapedType))
 
 run(testConcreteShapedType)
 
@@ -215,8 +215,8 @@ def testConcreteShapedType():
 # Tests that ShapedType operates as an abstract base class of a concrete
 # shaped type (using vector as an example).
 def testAbstractShapedType():
-  ctx = mlir.ir.Context()
-  vector = mlir.ir.ShapedType(ctx.parse_type("vector<2x3xf32>"))
+  ctx = Context()
+  vector = ShapedType(Type.parse("vector<2x3xf32>", ctx))
   # CHECK: element type: f32
   print("element type:", vector.element_type)
 
@@ -224,186 +224,184 @@ def testAbstractShapedType():
 
 # CHECK-LABEL: TEST: testVectorType
 def testVectorType():
-  ctx = mlir.ir.Context()
-  f32 = mlir.ir.F32Type.get(ctx)
-  shape = [2, 3]
-  loc = ctx.get_unknown_location()
-  # CHECK: vector type: vector<2x3xf32>
-  print("vector type:", mlir.ir.VectorType.get(shape, f32, loc))
-
-  none = mlir.ir.NoneType.get(ctx)
-  try:
-    vector_invalid = mlir.ir.VectorType.get(shape, none, loc)
-  except ValueError as e:
-    # CHECK: invalid 'Type(none)' and expected floating point or integer type.
-    print(e)
-  else:
-    print("Exception not produced")
+  with Context(), Location.unknown():
+    f32 = F32Type.get()
+    shape = [2, 3]
+    # CHECK: vector type: vector<2x3xf32>
+    print("vector type:", VectorType.get(shape, f32))
+
+    none = NoneType.get()
+    try:
+      vector_invalid = VectorType.get(shape, none)
+    except ValueError as e:
+      # CHECK: invalid 'Type(none)' and expected floating point or integer type.
+      print(e)
+    else:
+      print("Exception not produced")
 
 run(testVectorType)
 
 # CHECK-LABEL: TEST: testRankedTensorType
 def testRankedTensorType():
-  ctx = mlir.ir.Context()
-  f32 = mlir.ir.F32Type.get(ctx)
-  shape = [2, 3]
-  loc = ctx.get_unknown_location()
-  # CHECK: ranked tensor type: tensor<2x3xf32>
-  print("ranked tensor type:",
-        mlir.ir.RankedTensorType.get(shape, f32, loc))
-
-  none = mlir.ir.NoneType.get(ctx)
-  try:
-    tensor_invalid = mlir.ir.RankedTensorType.get(shape, none, loc)
-  except ValueError as e:
-    # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
-    # CHECK: or complex type.
-    print(e)
-  else:
-    print("Exception not produced")
+  with Context(), Location.unknown():
+    f32 = F32Type.get()
+    shape = [2, 3]
+    loc = Location.unknown()
+    # CHECK: ranked tensor type: tensor<2x3xf32>
+    print("ranked tensor type:",
+          RankedTensorType.get(shape, f32))
+
+    none = NoneType.get()
+    try:
+      tensor_invalid = RankedTensorType.get(shape, none)
+    except ValueError as e:
+      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
+      # CHECK: or complex type.
+      print(e)
+    else:
+      print("Exception not produced")
 
 run(testRankedTensorType)
 
 # CHECK-LABEL: TEST: testUnrankedTensorType
 def testUnrankedTensorType():
-  ctx = mlir.ir.Context()
-  f32 = mlir.ir.F32Type.get(ctx)
-  loc = ctx.get_unknown_location()
-  unranked_tensor = mlir.ir.UnrankedTensorType.get(f32, loc)
-  # CHECK: unranked tensor type: tensor<*xf32>
-  print("unranked tensor type:", unranked_tensor)
-  try:
-    invalid_rank = unranked_tensor.rank
-  except ValueError as e:
-    # CHECK: calling this method requires that the type has a rank.
-    print(e)
-  else:
-    print("Exception not produced")
-  try:
-    invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0)
-  except ValueError as e:
-    # CHECK: calling this method requires that the type has a rank.
-    print(e)
-  else:
-    print("Exception not produced")
-  try:
-    invalid_get_dim_size = unranked_tensor.get_dim_size(1)
-  except ValueError as e:
-    # CHECK: calling this method requires that the type has a rank.
-    print(e)
-  else:
-    print("Exception not produced")
-
-  none = mlir.ir.NoneType.get(ctx)
-  try:
-    tensor_invalid = mlir.ir.UnrankedTensorType.get(none, loc)
-  except ValueError as e:
-    # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
-    # CHECK: or complex type.
-    print(e)
-  else:
-    print("Exception not produced")
+  with Context(), Location.unknown():
+    f32 = F32Type.get()
+    loc = Location.unknown()
+    unranked_tensor = UnrankedTensorType.get(f32)
+    # CHECK: unranked tensor type: tensor<*xf32>
+    print("unranked tensor type:", unranked_tensor)
+    try:
+      invalid_rank = unranked_tensor.rank
+    except ValueError as e:
+      # CHECK: calling this method requires that the type has a rank.
+      print(e)
+    else:
+      print("Exception not produced")
+    try:
+      invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0)
+    except ValueError as e:
+      # CHECK: calling this method requires that the type has a rank.
+      print(e)
+    else:
+      print("Exception not produced")
+    try:
+      invalid_get_dim_size = unranked_tensor.get_dim_size(1)
+    except ValueError as e:
+      # CHECK: calling this method requires that the type has a rank.
+      print(e)
+    else:
+      print("Exception not produced")
+
+    none = NoneType.get()
+    try:
+      tensor_invalid = UnrankedTensorType.get(none)
+    except ValueError as e:
+      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
+      # CHECK: or complex type.
+      print(e)
+    else:
+      print("Exception not produced")
 
 run(testUnrankedTensorType)
 
 # CHECK-LABEL: TEST: testMemRefType
 def testMemRefType():
-  ctx = mlir.ir.Context()
-  f32 = mlir.ir.F32Type.get(ctx)
-  shape = [2, 3]
-  loc = ctx.get_unknown_location()
-  memref = mlir.ir.MemRefType.get_contiguous_memref(f32, shape, 2, loc)
-  # CHECK: memref type: memref<2x3xf32, 2>
-  print("memref type:", memref)
-  # CHECK: number of affine layout maps: 0
-  print("number of affine layout maps:", memref.num_affine_maps)
-  # CHECK: memory space: 2
-  print("memory space:", memref.memory_space)
-
-  none = mlir.ir.NoneType.get(ctx)
-  try:
-    memref_invalid = mlir.ir.MemRefType.get_contiguous_memref(none, shape, 2,
-                                                              loc)
-  except ValueError as e:
-    # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
-    # CHECK: or complex type.
-    print(e)
-  else:
-    print("Exception not produced")
+  with Context(), Location.unknown():
+    f32 = F32Type.get()
+    shape = [2, 3]
+    loc = Location.unknown()
+    memref = MemRefType.get_contiguous_memref(f32, shape, 2)
+    # CHECK: memref type: memref<2x3xf32, 2>
+    print("memref type:", memref)
+    # CHECK: number of affine layout maps: 0
+    print("number of affine layout maps:", memref.num_affine_maps)
+    # CHECK: memory space: 2
+    print("memory space:", memref.memory_space)
+
+    none = NoneType.get()
+    try:
+      memref_invalid = MemRefType.get_contiguous_memref(none, shape, 2)
+    except ValueError as e:
+      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
+      # CHECK: or complex type.
+      print(e)
+    else:
+      print("Exception not produced")
 
 run(testMemRefType)
 
 # CHECK-LABEL: TEST: testUnrankedMemRefType
 def testUnrankedMemRefType():
-  ctx = mlir.ir.Context()
-  f32 = mlir.ir.F32Type.get(ctx)
-  loc = ctx.get_unknown_location()
-  unranked_memref = mlir.ir.UnrankedMemRefType.get(f32, 2, loc)
-  # CHECK: unranked memref type: memref<*xf32, 2>
-  print("unranked memref type:", unranked_memref)
-  try:
-    invalid_rank = unranked_memref.rank
-  except ValueError as e:
-    # CHECK: calling this method requires that the type has a rank.
-    print(e)
-  else:
-    print("Exception not produced")
-  try:
-    invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0)
-  except ValueError as e:
-    # CHECK: calling this method requires that the type has a rank.
-    print(e)
-  else:
-    print("Exception not produced")
-  try:
-    invalid_get_dim_size = unranked_memref.get_dim_size(1)
-  except ValueError as e:
-    # CHECK: calling this method requires that the type has a rank.
-    print(e)
-  else:
-    print("Exception not produced")
-
-  none = mlir.ir.NoneType.get(ctx)
-  try:
-    memref_invalid = mlir.ir.UnrankedMemRefType.get(none, 2, loc)
-  except ValueError as e:
-    # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
-    # CHECK: or complex type.
-    print(e)
-  else:
-    print("Exception not produced")
+  with Context(), Location.unknown():
+    f32 = F32Type.get()
+    loc = Location.unknown()
+    unranked_memref = UnrankedMemRefType.get(f32, 2)
+    # CHECK: unranked memref type: memref<*xf32, 2>
+    print("unranked memref type:", unranked_memref)
+    try:
+      invalid_rank = unranked_memref.rank
+    except ValueError as e:
+      # CHECK: calling this method requires that the type has a rank.
+      print(e)
+    else:
+      print("Exception not produced")
+    try:
+      invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0)
+    except ValueError as e:
+      # CHECK: calling this method requires that the type has a rank.
+      print(e)
+    else:
+      print("Exception not produced")
+    try:
+      invalid_get_dim_size = unranked_memref.get_dim_size(1)
+    except ValueError as e:
+      # CHECK: calling this method requires that the type has a rank.
+      print(e)
+    else:
+      print("Exception not produced")
+
+    none = NoneType.get()
+    try:
+      memref_invalid = UnrankedMemRefType.get(none, 2)
+    except ValueError as e:
+      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
+      # CHECK: or complex type.
+      print(e)
+    else:
+      print("Exception not produced")
 
 run(testUnrankedMemRefType)
 
 # CHECK-LABEL: TEST: testTupleType
 def testTupleType():
-  ctx = mlir.ir.Context()
-  i32 = mlir.ir.IntegerType(ctx.parse_type("i32"))
-  f32 = mlir.ir.F32Type.get(ctx)
-  vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>"))
-  l = [i32, f32, vector]
-  tuple_type = mlir.ir.TupleType.get_tuple(ctx, l)
-  # CHECK: tuple type: tuple<i32, f32, vector<2x3xf32>>
-  print("tuple type:", tuple_type)
-  # CHECK: number of types: 3
-  print("number of types:", tuple_type.num_types)
-  # CHECK: pos-th type in the tuple type: f32
-  print("pos-th type in the tuple type:", tuple_type.get_type(1))
+  with Context() as ctx:
+    i32 = IntegerType(Type.parse("i32"))
+    f32 = F32Type.get()
+    vector = VectorType(Type.parse("vector<2x3xf32>"))
+    l = [i32, f32, vector]
+    tuple_type = TupleType.get_tuple(l)
+    # CHECK: tuple type: tuple<i32, f32, vector<2x3xf32>>
+    print("tuple type:", tuple_type)
+    # CHECK: number of types: 3
+    print("number of types:", tuple_type.num_types)
+    # CHECK: pos-th type in the tuple type: f32
+    print("pos-th type in the tuple type:", tuple_type.get_type(1))
 
 run(testTupleType)
 
 
 # CHECK-LABEL: TEST: testFunctionType
 def testFunctionType():
-  ctx = mlir.ir.Context()
-  input_types = [mlir.ir.IntegerType.get_signless(ctx, 32),
-                 mlir.ir.IntegerType.get_signless(ctx, 16)]
-  result_types = [mlir.ir.IndexType.get(ctx)]
-  func = mlir.ir.FunctionType.get(ctx, input_types, result_types)
-  # CHECK: INPUTS: [Type(i32), Type(i16)]
-  print("INPUTS:", func.inputs)
-  # CHECK: RESULTS: [Type(index)]
-  print("RESULTS:", func.results)
+  with Context() as ctx:
+    input_types = [IntegerType.get_signless(32),
+                  IntegerType.get_signless(16)]
+    result_types = [IndexType.get()]
+    func = FunctionType.get(input_types, result_types)
+    # CHECK: INPUTS: [Type(i32), Type(i16)]
+    print("INPUTS:", func.inputs)
+    # CHECK: RESULTS: [Type(index)]
+    print("RESULTS:", func.results)
 
 
 run(testFunctionType)
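The updated tests drop explicit `ctx` and `loc` arguments. Instead they rely on `Context()` and `Location.unknown()` being entered as `with` blocks, and trailing `loc`/`context` parameters default to whatever the current thread carries. The model below is a rough, pure-Python sketch of that defaulting rule under stated assumptions. The real bindings implement it in C++ and differ in detail, and insertion points are left out. `vector_type_get`, `_resolve_context`, `_resolve_location`, `Context.current`, `Location.current` and the string return value are all hypothetical, not the bindings' API.

```python
import threading

# Hypothetical pure-Python model of a thread-bound context stack.
# Not the MLIR bindings' implementation; names are illustrative only.
_state = threading.local()


def _frames():
    # Each thread lazily gets its own stack of (context, location) frames.
    if not hasattr(_state, "frames"):
        _state.frames = []
    return _state.frames


class Context:
    def __init__(self, name="ctx"):
        self.name = name

    def __enter__(self):
        frames = _frames()
        # Re-entering the same context inherits the current location;
        # a different context clears it, since that location belongs elsewhere.
        inherited = frames[-1][1] if frames and frames[-1][0] is self else None
        frames.append((self, inherited))
        return self

    def __exit__(self, *exc_info):
        _frames().pop()
        return False

    @staticmethod
    def current():
        frames = _frames()
        return frames[-1][0] if frames else None


class Location:
    def __init__(self, context, desc):
        self.context = context
        self.desc = desc

    @staticmethod
    def unknown(context=None):
        # Trailing `context=None`: explicit argument, else the thread's context.
        return Location(_resolve_context(context), "unknown")

    def __enter__(self):
        # Entering a location also makes its owning context current.
        _frames().append((self.context, self))
        return self

    def __exit__(self, *exc_info):
        _frames().pop()
        return False

    @staticmethod
    def current():
        frames = _frames()
        return frames[-1][1] if frames else None


def _resolve_context(context=None):
    if context is not None:
        return context
    current = Context.current()
    if current is None:
        raise ValueError("no context passed and none on the thread stack")
    return current


def _resolve_location(loc=None):
    if loc is not None:
        return loc
    current = Location.current()
    if current is None:
        raise ValueError("no location passed and none on the thread stack")
    return current


def vector_type_get(shape, element, loc=None):
    # Hypothetical factory following the trailing-defaulted-parameter convention.
    loc = _resolve_location(loc)
    dims = "x".join(str(d) for d in shape)
    return f"vector<{dims}x{element}>@{loc.context.name}"
```

In `with Context("a"), Location.unknown():` Python enters the context before evaluating `Location.unknown()`, which is why the location can pick the context up from the thread. Inside that block, `vector_type_get([2, 3], "f32")` needs no explicit location. Because the stack lives in a `threading.local`, other threads start with nothing current.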