Advanced
functools.wraps · Stacking · Class decorators · __enter__ · __exit__

Week 2 — Decorators & Context Managers

Master decorators for cross-cutting concerns like logging, caching, and timing. Then implement reusable context managers with both the class-based and contextlib.contextmanager approaches.

Tags: decorator · context manager · functools · contextlib · wraps
Duration: 2.5 hours
Level: 📊 Advanced
Prerequisite: 🎯 Intermediate Week 1
Outcome: write a @retry decorator and a managed_connection context manager

What you'll learn

  1. Write a decorator factory that accepts arguments
  2. Stack multiple decorators correctly
  3. Use functools.wraps to preserve metadata
  4. Implement __enter__ and __exit__ for a context manager class
  5. Use @contextmanager for generator-based context managers

1. Basic Decorators

```python
import functools
import time

def timer(func):
    @functools.wraps(func)  # copy func's __name__, __doc__, etc. onto wrapper
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(f"{func.__name__} took {elapsed:.4f}s")
        return result
    return wrapper

@timer  # same as: slow_sum = timer(slow_sum)
def slow_sum(n):
    return sum(range(n))

print(slow_sum(1_000_000))  # prints the timing line, then 499999500000
```
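The effect of @functools.wraps is easiest to see side by side. The sketch below applies one wrapper without it and one with it, then inspects the metadata. The decorator names plain and wrapped and the greet_* functions are illustrative only.

```python
import functools

def plain(func):
    # No functools.wraps: the returned function keeps the wrapper's own metadata.
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

def wrapped(func):
    # functools.wraps copies __module__, __name__, __qualname__, __doc__ and
    # __annotations__, merges __dict__, and sets __wrapped__ to func.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@plain
def greet_plain(name):
    """Return a greeting."""
    return f"Hello, {name}"

@wrapped
def greet_wrapped(name):
    """Return a greeting."""
    return f"Hello, {name}"

print(greet_plain.__name__, greet_plain.__doc__)      # wrapper None
print(greet_wrapped.__name__, greet_wrapped.__doc__)  # greet_wrapped Return a greeting.
print(greet_wrapped.__wrapped__("Ada"))               # Hello, Ada
```

Without wraps, anything that relies on __name__ or __doc__ sees the generic wrapper instead of the decorated function. This includes log messages, help() and debuggers. The __wrapped__ attribute also lets you reach the undecorated original.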

2. Decorator Factories (with arguments)

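```python
import functools
import random

def retry(max_attempts=3, exceptions=(Exception,)):
    # The outer function takes the arguments and returns the real decorator.
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    if attempt == max_attempts:
                        raise  # out of attempts: re-raise the last error
                    print(f"  Attempt {attempt} failed: {e}. Retrying...")
        return wrapper
    return decorator

# Equivalent to: risky = retry(max_attempts=3, exceptions=(ValueError,))(risky)
@retry(max_attempts=3, exceptions=(ValueError,))
def risky():
    if random.random() < 0.7:  # fails about 70% of the time
        raise ValueError("unlucky")
    return "success"

try:
    print(risky())
except ValueError as e:
    print(f"Gave up: {e}")
```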
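The retry above re-runs the function immediately. A common refinement is to wait between attempts and to lengthen the wait each time. The sketch below assumes two extra parameters, delay and backoff. These are hypothetical names and are not part of the lesson's code. It also replaces the random failure with a deterministic one so the behaviour is reproducible.

```python
import functools
import time

def retry(max_attempts=3, exceptions=(Exception,), delay=0.0, backoff=2.0):
    """Retry func on the listed exceptions, sleeping between attempts.

    delay/backoff are illustrative additions: the first wait is `delay`
    seconds, and each later wait is multiplied by `backoff`.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            wait = delay
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions:
                    if attempt == max_attempts:
                        raise  # out of attempts: re-raise the last error
                    time.sleep(wait)
                    wait *= backoff
        return wrapper
    return decorator

# A deterministic "flaky" function: fails twice, then succeeds.
calls = {"count": 0}

@retry(max_attempts=3, exceptions=(ConnectionError,), delay=0.01)
def flaky_fetch():
    calls["count"] += 1
    if calls["count"] < 3:
        raise ConnectionError("temporary failure")
    return "payload"

print(flaky_fetch(), calls["count"])  # payload 3
```

An exception type that is not listed in exceptions propagates on the first attempt, because the except clause only catches the listed types.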

3. Context Managers

```python
from contextlib import contextmanager
import time

@contextmanager
def timer_ctx(label=""):
    start = time.perf_counter()  # code before yield runs on entry
    try:
        yield  # the body of the with-block runs here
    finally:
        # runs on exit, even if the with-block raised
        elapsed = time.perf_counter() - start
        print(f"{label} {elapsed:.4f}s")

with timer_ctx("sum of products:"):
    result = sum(i * j for i in range(1000) for j in range(100))
```
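The outcome names a managed_connection context manager, and objective 4 asks for __enter__ and __exit__. The snippets above show neither. Below is a minimal class-based sketch. It assumes an SQLite backend through the standard-library sqlite3 module. The class name ManagedConnection and its commit-or-rollback policy are illustrative choices, not the course's reference solution. For comparison, sqlite3's own Connection can also be used in a with statement, where it commits or rolls back, but it does not close the connection.

```python
import os
import sqlite3
import tempfile

class ManagedConnection:
    """Open on enter; commit or roll back on exit; always close."""

    def __init__(self, path):
        self.path = path
        self.conn = None

    def __enter__(self):
        self.conn = sqlite3.connect(self.path)
        return self.conn  # this is what `as conn` binds to

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            if exc_type is None:
                self.conn.commit()    # block finished normally
            else:
                self.conn.rollback()  # block raised: discard its changes
        finally:
            self.conn.close()
        return False  # falsy: do not suppress the exception

db_path = os.path.join(tempfile.mkdtemp(), "demo.db")

with ManagedConnection(db_path) as conn:
    conn.execute("CREATE TABLE users (name TEXT)")
    conn.execute("INSERT INTO users VALUES ('ada')")

try:
    with ManagedConnection(db_path) as conn:
        conn.execute("INSERT INTO users VALUES ('bob')")
        raise ValueError("something went wrong")
except ValueError as e:
    print("rolled back:", e)

with ManagedConnection(db_path) as conn:
    count = conn.execute("SELECT COUNT(*) FROM users").fetchone()[0]
print(count)  # 1
```

The return value of __exit__ decides what happens to an exception raised inside the block. A falsy value lets it propagate, which is what the try/except above relies on. Returning True would silently swallow it.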

💻 Examples

Run these examples and check the output yourself.

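01_decorators.py · Composable utility decorators

```python
import functools
import logging

logging.basicConfig(level=logging.INFO)

def log_call(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # %-style arguments let logging skip formatting when INFO is disabled
        logging.info("Calling %s(%r, %r)", func.__name__, args, kwargs)
        result = func(*args, **kwargs)
        logging.info("%s returned %r", func.__name__, result)
        return result
    return wrapper

# Caches results keyed by the positional-args tuple (args must be hashable).
# The standard library offers the same idea as functools.lru_cache.
def memoize(func):
    cache = {}
    @functools.wraps(func)
    def wrapper(*args):
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]
    return wrapper

# Decorators apply bottom-up: fib = log_call(memoize(fib)).
# The recursive calls use the decorated name, so they are logged too,
# and each fib(n) is computed only once.
@log_call
@memoize
def fib(n):
    if n <= 1:
        return n
    return fib(n - 1) + fib(n - 2)

print(fib(10))  # 55
```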
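The topic list also mentions class decorators, which the examples above do not show. Assume the term means a decorator applied to a class, as in PEP 3129. Such a decorator receives the class object, can inspect, modify or record it, and returns whatever the class name should be bound to, which is usually the same class. The helpers register and auto_repr and the PLUGINS registry are hypothetical.

```python
# Hypothetical registry: maps a short name to the decorated class.
PLUGINS = {}

def register(name):
    """Decorator factory applied to a class: record it, return it unchanged."""
    def decorator(cls):
        PLUGINS[name] = cls
        return cls
    return decorator

def auto_repr(cls):
    """Class decorator that adds a __repr__ built from the instance's attributes."""
    def __repr__(self):
        fields = ", ".join(f"{k}={v!r}" for k, v in vars(self).items())
        return f"{type(self).__name__}({fields})"
    cls.__repr__ = __repr__
    return cls

# Stacking works the same way as for functions: auto_repr runs first.
@register("csv")
@auto_repr
class CsvExporter:
    def __init__(self, delimiter=","):
        self.delimiter = delimiter

print(PLUGINS["csv"](";"))  # CsvExporter(delimiter=';')
```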

📝 Exercises

Try them yourself first, then open the solution to compare.

Exercise 1

@validate_types decorator

Goal: Write a decorator that checks argument types at runtime using function annotations.

Requirements
  • Reads __annotations__ from the wrapped function
  • Raises TypeError with a helpful message if an argument's type doesn't match
  • Handles missing annotations gracefully (skips the check)
SOLUTION
```python
import functools
import inspect

def validate_types(func):
    hints = func.__annotations__   # empty dict if nothing is annotated
    sig = inspect.signature(func)  # computed once, not on every call
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for param, val in bound.arguments.items():
            expected = hints.get(param)
            # Skip unannotated params and annotations that aren't plain
            # classes (e.g. string annotations); this sketch checks classes only.
            if not isinstance(expected, type):
                continue
            if not isinstance(val, expected):
                raise TypeError(
                    f"{param} must be {expected.__name__}, got {type(val).__name__}"
                )
        return func(*args, **kwargs)
    return wrapper

@validate_types
def add(a: int, b: int) -> int:
    return a + b

print(add(3, 4))       # 7
try:
    add(3, "4")
except TypeError as e:
    print(e)           # b must be int, got str
```
▶ Output
7
b must be int, got str
Example code / lecture materials

All lecture materials and example code are openly available on GitHub.
