6c7e8590273ae9cb3b432ca60c8929d4c390b0cc
[wrw.git] / wrw / env.py
1 import threading, weakref
2
# Names exported by ``from ... import *``; ``stack`` and ``context`` are not
# listed here.
__all__ = [
    "environment",
    "root",
    "get",
    "binding",
    "var",
]
4
class stack(object):
    """One saved frame of a thread's environment stack.

    ``env`` is the environment that was current when a ``with`` block was
    entered, and ``prev`` points at the frame beneath this one (``None`` at
    the bottom), so the frames form a singly linked list.
    """
    __slots__ = ["env", "prev"]

    def __init__(self, env, prev):
        self.env, self.prev = env, prev
10
class environment(object):
    """A scope of variable bindings, chained to an optional parent scope.

    Keys (normally ``var`` instances) are held in a WeakKeyDictionary, so a
    binding disappears once its key object is garbage-collected.  Lookups
    that miss in this scope fall through to ``parent`` and so on up the
    chain.

    An environment is also a context manager: ``with env:`` makes it the
    current environment for the calling thread and restores the previous
    one on exit.
    """
    __slots__ = ["parent", "map"]

    def __init__(self, parent = None):
        # Keep the parent link: lookups that miss here must fall through to
        # it (the argument was previously discarded, which cut child scopes
        # off from their ancestors and from the defaults stored in ``root``).
        self.parent = parent
        self.map = weakref.WeakKeyDictionary()

    def get(self, var):
        """Return the value bound to *var* here or in the nearest ancestor.

        Returns ``None`` if no environment in the chain binds *var*.
        """
        # Iterate rather than recurse so a deep chain of nested scopes cannot
        # exhaust the recursion limit.  The ``in`` test is kept (instead of
        # catching KeyError) because WeakKeyDictionary reports a
        # non-weakrefable key as absent rather than raising.
        env = self
        while env is not None:
            if var in env.map:
                return env.map[var]
            env = env.parent
        return None

    def set(self, var, val):
        """Bind *var* to *val* in this environment only."""
        self.map[var] = val

    def __enter__(self):
        # Save the current environment on this thread's frame stack, then
        # make this environment current.
        cur = context.env
        context.prev = stack(cur, context.prev)
        context.env = self
        return None

    def __exit__(self, *excinfo):
        # Pop the most recent frame and restore the environment it saved.
        prev = context.prev
        if prev is None:
            raise Exception("Unbalanced __enter__/__exit__")
        context.env = prev.env
        context.prev = prev.prev
        # Never suppress an exception raised inside the ``with`` body.
        return False
40
# The outermost environment: every thread starts out in it, and ``var``
# stores non-None defaults in its map.
root = environment()


class context(threading.local):
    """Per-thread record of the current environment and its saved frames.

    Attributes assigned on a ``threading.local`` instance are private to
    each thread; until a thread assigns its own, reads fall back to these
    class-level values, so every thread begins at ``root`` with no frames.
    """
    prev = None  # top ``stack`` frame, or None when nothing has been entered
    env = root   # environment currently in effect for this thread


# The module keeps one shared instance, bound to the class's own name.
context = context()
47
def get():
    """Return the environment currently in effect for the calling thread."""
    current = context.env
    return current
50
class binding(object):
    """Context manager that runs its body in a freshly built environment.

    *bindings* is either a dict or an iterable of ``(var, value)`` pairs.
    On entry it builds ``environment(current)``, fills it with those pairs,
    and makes it current for the calling thread; on exit the previously
    current environment is restored.  The same object may be entered more
    than once.
    """
    __slots__ = ["bindings"]

    def __init__(self, bindings):
        if isinstance(bindings, dict):
            bindings = bindings.items()
        # Snapshot the pairs: a generator would otherwise be exhausted by the
        # first ``with`` (re-entering would bind nothing), and a Python 3
        # ``dict.items()`` view would track later changes to the caller's
        # dict.
        self.bindings = list(bindings)

    def __enter__(self):
        cur = context.env
        new = environment(cur)
        for var, val in self.bindings:
            new.map[var] = val
        # Push only after the new scope is fully populated, so a failure
        # while filling it leaves this thread's state untouched.
        context.prev = stack(cur, context.prev)
        context.env = new
        return None

    def __exit__(self, *excinfo):
        # Pop the most recent frame and restore the environment it saved.
        prev = context.prev
        if prev is None:
            raise Exception("Unbalanced __enter__/__exit__")
        context.env = prev.env
        context.prev = prev.prev
        # Never suppress an exception raised inside the ``with`` body.
        return False
72         context.prev = prev.prev
73         return False
74
class var(object):
    """A dynamically scoped variable.

    A ``var`` acts only as a key.  ``val`` looks it up in the calling
    thread's current environment, and ``with v.binding(value): ...``
    rebinds it for the duration of the block.  A non-``None`` *default* is
    stored directly in ``root``.
    """
    # No per-instance attributes; the slot only makes instances
    # weak-referenceable so they can key an environment's WeakKeyDictionary.
    __slots__ = ["__weakref__"]

    def __init__(self, default = None):
        if default is None:
            return
        root.map[self] = default

    @property
    def val(self):
        """Value of this variable in the current environment.

        ``None`` when no environment in the lookup chain binds it.
        """
        env = context.env
        return env.get(self)

    def binding(self, val):
        """Return a ``binding`` context manager that binds this var to *val*."""
        pair = (self, val)
        return binding([pair])