Skip to content

Commit 5bf6b32

Browse files
committed
Add atomic_var, string_replace, string_template.
1 parent 38ba9ad commit 5bf6b32

File tree

6 files changed

+363
-0
lines changed

6 files changed

+363
-0
lines changed

‎src/strif/__init__.py‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,4 @@
2929
)
3030

3131
from .strif import * # noqa: F403
32+
from .string_template import * # noqa: F403

‎src/strif/atomic_var.py‎

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import copy as cpy
2+
import threading
3+
from collections.abc import Callable
4+
from contextlib import contextmanager
5+
from typing import Any, Generic, TypeVar
6+
7+
T = TypeVar("T")
8+
9+
10+
def value_is_immutable(obj: Any) -> bool:
11+
"""
12+
Check if a value is of a known immutable type. Just a heuristic for common
13+
cases and not perfect.
14+
"""
15+
immutable_types = (int, float, bool, str, tuple, frozenset, type(None), bytes, complex)
16+
if isinstance(obj, immutable_types):
17+
return True
18+
if hasattr(obj, "__dataclass_params__") and getattr(obj.__dataclass_params__, "frozen", False):
19+
return True
20+
return False
21+
22+
23+
class AtomicVar(Generic[T]):
24+
"""
25+
`AtomicVar` is a simple zero-dependency thread-safe variable that works
26+
for any type.
27+
28+
Often the standard "Pythonic" approach is to use locks directly, but for
29+
some common use cases, `AtomicVar` may be simpler and more readable.
30+
Works on any type, including lists and dicts.
31+
32+
Other options include `threading.Event` (for shared booleans),
33+
`threading.Queue` (for producer-consumer queues), and `multiprocessing.Value`
34+
(for process-safe primitives).
35+
36+
Examples:
37+
38+
```python
39+
# Immutable types are always safe:
40+
count = AtomicVar(0)
41+
count.update(lambda x: x + 5) # In any thread.
42+
count.set(0) # In any thread.
43+
current_count = count.value # In any thread.
44+
45+
# Useful for flags:
46+
global_flag = AtomicVar(False)
47+
global_flag.set(True) # In any thread.
48+
if global_flag: # In any thread.
49+
print("Flag is set")
50+
51+
52+
# For mutable types,consider using `copy` or `deepcopy` to access the value:
53+
my_list = AtomicVar([1, 2, 3])
54+
my_list_copy = my_list.copy() # In any thread.
55+
my_list_deepcopy = my_list.deepcopy() # In any thread.
56+
57+
# For mutable types, the `updates()` context manager gives a simple way to
58+
# lock on updates:
59+
with my_list.updates() as value:
60+
value.append(5)
61+
62+
# Or if you prefer, via a function:
63+
my_list.update(lambda x: x.append(4)) # In any thread.
64+
65+
# You can also use the var's lock directly. In particular, this encapsulates
66+
# locked one-time initialization:
67+
initialized = AtomicVar(False)
68+
with initialized.lock:
69+
if not initialized: # checks truthiness of underlying value
70+
expensive_setup()
71+
initialized.set(True)
72+
73+
# Or:
74+
lazy_var: AtomicVar[list[str] | None] = AtomicVar(None)
75+
with lazy_var.lock:
76+
if not lazy_var:
77+
lazy_var.set(expensive_calculation())
78+
```
79+
"""
80+
81+
def __init__(self, initial_value: T, is_immutable: bool | None = None):
82+
self._value: T = initial_value
83+
# Use an RLock just in case we read from the var while in an update().
84+
self.lock: threading.RLock = threading.RLock()
85+
self.is_immutable: bool
86+
if is_immutable is None:
87+
self.is_immutable = value_is_immutable(initial_value)
88+
else:
89+
self.is_immutable = is_immutable
90+
91+
@property
92+
def value(self) -> T:
93+
"""
94+
Current value. For immutable types, this is thread safe. For mutable types,
95+
this gives direct access to the value, so you should consider using `copy` or
96+
`deepcopy` instead.
97+
"""
98+
with self.lock:
99+
return self._value
100+
101+
def copy(self) -> T:
102+
"""
103+
Shallow copy of the current value.
104+
"""
105+
with self.lock:
106+
return cpy.copy(self._value)
107+
108+
def deepcopy(self) -> T:
109+
"""
110+
Deep copy of the current value.
111+
"""
112+
with self.lock:
113+
return cpy.deepcopy(self._value)
114+
115+
def set(self, new_value: T) -> None:
116+
with self.lock:
117+
self._value = new_value
118+
119+
def swap(self, new_value: T) -> T:
120+
"""
121+
Set to new value and return the old value.
122+
"""
123+
with self.lock:
124+
old_value = self._value
125+
self._value = new_value
126+
return old_value
127+
128+
def update(self, update_func: Callable[[T], T | None]) -> T:
129+
"""
130+
Update value with a function and return the new value.
131+
132+
The `update_func` can either return a new value or update a mutable type in place,
133+
in which case it should return None. Always returns the final value of the
134+
variable after the update.
135+
"""
136+
with self.lock:
137+
result = update_func(self._value)
138+
if result is not None:
139+
self._value = result
140+
# Always return the potentially updated self._value
141+
return self._value
142+
143+
@contextmanager
144+
def updates(self):
145+
"""
146+
Context manager for convenient thread-safe updates. Only applicable to
147+
mutable types.
148+
149+
Example usage:
150+
```
151+
my_list = AtomicVar([1, 2, 3])
152+
with my_list.updates() as value:
153+
value.append(4)
154+
```
155+
"""
156+
# Sanity check to avoid accidental use with atomic/immutable types.
157+
if self.is_immutable:
158+
raise ValueError("Cannot use AtomicVar.updates() context manager on an immutable value")
159+
with self.lock:
160+
yield self._value
161+
162+
def __bool__(self) -> bool:
163+
"""
164+
Truthiness matches that of the underlying value.
165+
"""
166+
return bool(self.value)
167+
168+
def __repr__(self) -> str:
169+
return f"{self.__class__.__name__}({self.value!r})"
170+
171+
def __str__(self) -> str:
172+
return str(self.value)

‎src/strif/string_replace.py‎

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from typing import TypeAlias
2+
3+
Insertion = tuple[int, str]
4+
5+
6+
def insert_multiple(text: str, insertions: list[Insertion]) -> str:
7+
"""
8+
Insert multiple strings into `text` at the given offsets, at once.
9+
"""
10+
chunks: list[str] = []
11+
last_end = 0
12+
for offset, insertion in sorted(insertions, key=lambda x: x[0]):
13+
chunks.append(text[last_end:offset])
14+
chunks.append(insertion)
15+
last_end = offset
16+
chunks.append(text[last_end:])
17+
return "".join(chunks)
18+
19+
20+
Replacement: TypeAlias = tuple[int, int, str]
21+
22+
23+
def replace_multiple(text: str, replacements: list[Replacement]) -> str:
24+
"""
25+
Replace multiple substrings in `text` with new strings, simultaneously.
26+
The replacements are a list of tuples (start_offset, end_offset, new_string).
27+
"""
28+
replacements = sorted(replacements, key=lambda x: x[0])
29+
chunks: list[str] = []
30+
last_end = 0
31+
for start, end, new_text in replacements:
32+
if start < last_end:
33+
raise ValueError("Overlapping replacements are not allowed.")
34+
chunks.append(text[last_end:start])
35+
chunks.append(new_text)
36+
last_end = end
37+
chunks.append(text[last_end:])
38+
return "".join(chunks)

‎src/strif/string_template.py‎

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from collections.abc import Sequence
2+
from dataclasses import dataclass
3+
from typing import Any
4+
5+
6+
@dataclass(frozen=True)
7+
class StringTemplate:
8+
"""
9+
A validated template string that supports only specified fields.
10+
Can subclass to have a type with a given set of `allowed_fields`.
11+
Provide a type with a field name to allow validation of int/float format strings.
12+
13+
Examples:
14+
>>> t = StringTemplate("{name} is {age} years old", ["name", "age"])
15+
>>> t.format(name="Alice", age=30)
16+
'Alice is 30 years old'
17+
18+
>>> t = StringTemplate("{count:3d}@{price:.2f}", [("count", int), ("price", float)])
19+
>>> t.format(count=10, price=19.99)
20+
' 10@19.99'
21+
"""
22+
23+
template: str
24+
25+
allowed_fields: Sequence[str | tuple[str, type | None]]
26+
"""List of allowed field names. If `d` or `f` formats are used, give tuple with the type."""
27+
# Sequence is covariant so compatible with List[str]
28+
29+
strict: bool = False
30+
"""If True, raise a ValueError if the template is missing an allowed field."""
31+
32+
def __post_init__(self):
33+
if not isinstance(self.template, str): # pyright: ignore[reportUnnecessaryIsInstance]
34+
raise ValueError("Template must be a string")
35+
36+
# Confirm only the allowed fields are in the template.
37+
field_types = self._field_types()
38+
try:
39+
placeholder_values = {field: (type or str)(123) for field, type in field_types.items()}
40+
self.template.format(**placeholder_values)
41+
except KeyError as e:
42+
raise ValueError(f"Template contains unsupported variable: {e}") from None
43+
except ValueError as e:
44+
raise ValueError(
45+
f"Invalid template (forgot to provide a type when using non-str format strings?): {e}"
46+
) from None
47+
48+
def _field_types(self) -> dict[str, type | None]:
49+
return {
50+
field[0] if isinstance(field, tuple) else field: (
51+
field[1] if isinstance(field, tuple) else None
52+
)
53+
for field in self.allowed_fields
54+
}
55+
56+
def format(self, **kwargs: Any) -> str:
57+
field_types = self._field_types()
58+
allowed_keys = field_types.keys()
59+
unexpected_keys = set(kwargs.keys()) - allowed_keys
60+
if self.strict and unexpected_keys:
61+
raise ValueError(f"Unexpected keyword arguments: {', '.join(unexpected_keys)}")
62+
63+
# Type check the values, if types were provided.
64+
for f, expected_type in field_types.items():
65+
if f in kwargs and expected_type:
66+
if not isinstance(kwargs[f], expected_type):
67+
raise ValueError(
68+
f"Invalid type for '{f}': expected {expected_type.__name__} but got {repr(kwargs[f])} ({type(kwargs[f]).__name__})"
69+
)
70+
71+
return self.template.format(**kwargs)
72+
73+
def __bool__(self) -> bool:
74+
return bool(self.template)

‎tests/test_string_replace.py‎

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from strif.string_replace import Insertion, Replacement, insert_multiple, replace_multiple
2+
3+
4+
def test_insert_multiple():
5+
text = "hello world"
6+
insertions: list[Insertion] = [(5, ",")]
7+
expected = "hello, world"
8+
assert insert_multiple(text, insertions) == expected, "Single insertion failed"
9+
10+
text = "hello world"
11+
insertions = [(0, "Start "), (11, " End")]
12+
expected = "Start hello world End"
13+
assert insert_multiple(text, insertions) == expected, "Multiple insertions failed"
14+
15+
text = "short"
16+
insertions = [(10, " end")]
17+
expected = "short end"
18+
assert insert_multiple(text, insertions) == expected, "Out of bounds insertion failed"
19+
20+
text = "negative test"
21+
insertions = [(-1, "ss")]
22+
expected = "negative tessst"
23+
assert insert_multiple(text, insertions) == expected, "Negative offset insertion failed"
24+
25+
text = "no change"
26+
insertions = []
27+
expected = "no change"
28+
assert insert_multiple(text, insertions) == expected, "Empty insertions failed"
29+
30+
31+
def test_replace_multiple():
32+
text = "The quick brown fox"
33+
replacements: list[Replacement] = [(4, 9, "slow"), (16, 19, "dog")]
34+
expected = "The slow brown dog"
35+
assert replace_multiple(text, replacements) == expected, "Multiple replacements failed"
36+
37+
text = "overlap test"
38+
replacements = [(0, 6, "start"), (5, 10, "end")]
39+
try:
40+
replace_multiple(text, replacements)
41+
raise AssertionError("Overlapping replacements did not raise ValueError")
42+
except ValueError:
43+
pass # Expected exception
44+
45+
text = "short text"
46+
replacements = [(5, 10, " longer text")]
47+
expected = "short longer text"
48+
assert replace_multiple(text, replacements) == expected, "Out of bounds replacement failed"
49+
50+
text = "no change"
51+
replacements = []
52+
expected = "no change"
53+
assert replace_multiple(text, replacements) == expected, "Empty replacements failed"

‎tests/test_string_template.py‎

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from strif.string_template import StringTemplate
2+
3+
4+
def test_string_template():
5+
t = StringTemplate("{name} is {age} years old", ["name", "age"])
6+
assert t.format(name="Alice", age=30) == "Alice is 30 years old"
7+
8+
t = StringTemplate("{count:3d}@{price:.2f}", [("count", int), ("price", float)])
9+
assert t.format(count=10, price=19.99) == " 10@19.99"
10+
11+
try:
12+
StringTemplate("{name} {age}", ["name"])
13+
raise AssertionError("Should have raised ValueError")
14+
except ValueError as e:
15+
assert "Template contains unsupported variable: 'age'" in str(e)
16+
17+
t = StringTemplate("{count:d}", [("count", int)])
18+
try:
19+
t.format(count="not an int")
20+
raise AssertionError("Should have raised ValueError")
21+
except ValueError as e:
22+
assert "Invalid type for 'count': expected int but got 'not an int' (str)" in str(e)
23+
24+
assert not StringTemplate("", [])
25+
assert StringTemplate("not empty", [])

0 commit comments

Comments
 (0)