How do I detect the usage of a ContextManager in Python?

47 Views Asked by At

I am trying to refactor the following:

# Current style: functions are called on the object bound by `as ctx`.
with MyContext() as ctx:
    ctx.some_function()

...into something more like this:

# Desired style: some_function() finds the active MyContext() on its own,
# without a `ctx` object being passed around.
with MyContext():
    some_function()

How can I detect, inside the body of some_function, that it was called within a MyContext() block? Ideally the solution should be thread-safe.

This must be possible, because the built-in decimal module does exactly this:

from decimal import localcontext

# localcontext() makes a copy of the current decimal context the active one
# for the duration of the block and restores the previous context on exit.
# Decimal arithmetic done by code called inside the block (e.g. presumably
# in calculate_something()) therefore uses prec = 42 without receiving ctx.
with localcontext() as ctx:
    ctx.prec = 42   # Perform a high precision calculation
    s = calculate_something()
2

There are 2 best solutions below

0
J_H On

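It sounds like you want to crawl up the call stack, looking for evidence of a context manager. Note that this inspects the caller's source code and bytecode. It therefore only shows that the caller contains a `with` statement, not that the current call is actually executing inside one.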

#! /usr/bin/env python

from io import StringIO
import dis
import inspect


class MyManager:
    """Demo context manager that prints "enter" and "exit" around its block."""

    def __enter__(self) -> "MyManager":
        print("enter")
        # Return self so `with MyManager() as m:` binds the manager, not None.
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        print("exit")
        # Returning None (falsy) lets any exception raised in the block propagate.


def app() -> None:
    """Call report_on_ctx_mgr() from inside a MyManager() block.

    report_on_ctx_mgr() inspects this function's own source text and
    bytecode, so edits here (even to this docstring) can change its output.
    The ``as m`` target is never read.
    """
    with MyManager() as m:
        report_on_ctx_mgr()


def report_on_ctx_mgr() -> None:
    """Print the calling function's source lines that contain "with ", plus
    its disassembly when that bytecode mentions ``MyManager``.

    The caller is taken from the live stack frame instead of being looked up
    by name in ``globals()``. This works for any Python-level caller whose
    source file is available, not just a module-level function named ``app``.

    This is a static inspection of the caller's code. It shows that the
    caller *contains* a ``with`` statement, not that this call is currently
    executing inside one.

    Raises:
        RuntimeError: if the interpreter provides no stack-frame support.
    """
    frame = inspect.currentframe()
    if frame is None:  # interpreters without Python stack-frame support
        raise RuntimeError("report_on_ctx_mgr() requires stack-frame introspection")
    try:
        caller_code = frame.f_back.f_code
    finally:
        # Holding frame objects creates reference cycles; drop ours promptly.
        del frame

    src = inspect.getsource(caller_code)
    print(list(filter(_contains_with, src.splitlines())))

    out = StringIO()
    dis.dis(caller_code, file=out)
    disasm = out.getvalue()
    if "MyManager" in disasm:
        print(disasm)


def _contains_with(s: str) -> bool:
    return "with " in s


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    app()

output:

['    with MyManager() as m:']
 16           0 RESUME                   0

 17           2 LOAD_GLOBAL              1 (NULL + MyManager)
...
0
Andrej Kesely On

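Another solution uses threading.local to store the current context per thread. The context manager sets the value on entry, and any function called inside the block reads it. For asyncio code, where several tasks share one thread, contextvars.ContextVar is the equivalent tool; it is also what the decimal module uses internally.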

import threading
from contextlib import contextmanager

thread_local_data = threading.local()


@contextmanager
def MyContext():
    """Install a context for the current thread for the duration of a ``with`` block.

    Code called inside the block reads it as ``thread_local_data.ctx``.
    ``threading.local`` attributes are per-thread, so other threads never see
    it. On exit, whether normal or via an exception, the previous value is
    restored, or the attribute is removed if there was none. This means the
    context neither leaks past the block nor clobbers an enclosing
    ``MyContext()``.

    Code that also needs isolation between asyncio tasks sharing one thread
    should use ``contextvars.ContextVar`` instead.
    """
    missing = object()  # sentinel: "no context was active before this block"
    previous = getattr(thread_local_data, "ctx", missing)
    # every thread has different `ctx` value
    thread_local_data.ctx = "[some context]"
    try:
        yield
    finally:
        if previous is missing:
            del thread_local_data.ctx
        else:
            thread_local_data.ctx = previous


def some_function():
    """Print the context stored in ``thread_local_data.ctx`` for this thread.

    If no context has been set in the current thread, the attribute lookup
    raises ``AttributeError``.
    """
    current_ctx = thread_local_data.ctx
    message = f"My context is {current_ctx}"
    print(message)


# Run the demo only when executed as a script, not as an import-time side effect.
if __name__ == "__main__":
    # some_function() finds the context without receiving it as an argument.
    with MyContext():
        some_function()

Prints:

My context is [some context]