The Iterator Protocol and Its Toolchain

Understanding the iterator protocol is fundamental to writing efficient Python, and itertools and functools are among the most powerful functional tools in the standard library.

The Iterator Protocol

```mermaid
graph TD
    ITER[Iterator protocol] --> ITERABLE["Iterable: __iter__()"]
    ITER --> ITERATOR["Iterator: __next__()"]
    ITER --> GEN[Generator: yield]
    ITERABLE --> LIST[list / tuple / dict]
    ITERABLE --> FILE[File objects]
    ITERABLE --> CUSTOM[Custom classes]
    GEN --> LAZY[Lazy evaluation]
    GEN --> MEM[Memory-efficient]
    GEN --> PIPE[Pipeline composition]
    style ITER fill:#e3f2fd,stroke:#1565c0,stroke-width:2px
    style GEN fill:#c8e6c9,stroke:#388e3c,stroke-width:2px
```
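The "Custom classes" branch of the diagram can be made concrete with a minimal sketch of a class that implements the protocol by hand (the `Countdown` name is illustrative, not from the original):

```python
class Countdown:
    """A hypothetical example: implementing __iter__ and __next__ manually."""

    def __init__(self, start: int):
        self.current = start

    def __iter__(self):
        # An iterator conventionally returns itself from __iter__
        return self

    def __next__(self):
        if self.current <= 0:
            raise StopIteration  # signals exhaustion to for-loops and list()
        self.current -= 1
        return self.current + 1

print(list(Countdown(3)))  # [3, 2, 1]
```

In practice a generator function achieves the same in far fewer lines, which is why the rest of this chapter focuses on `yield`.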

Generators in Depth

"""
生成器:惰性求值的核心
"""
# 基本生成器
def fibonacci(limit: int):
"""无限斐波那契数列(到 limit 终止)"""
a, b = 0, 1
while a < limit:
yield a
a, b = b, a + b
for n in fibonacci(100):
print(n, end=" ")
# 0 1 1 2 3 5 8 13 21 34 55 89
# 生成器表达式
squares = (x ** 2 for x in range(1_000_000))  # 不占内存
first_10 = [next(squares) for _ in range(10)]
# yield from:委托生成器
def flatten(nested):
"""递归展平嵌套列表"""
for item in nested:
if isinstance(item, (list, tuple)):
yield from flatten(item)
else:
yield item
data = [1, [2, 3], [4, [5, 6]], 7]
print(list(flatten(data)))  # [1, 2, 3, 4, 5, 6, 7]
# 生成器管道
def read_lines(path: str):
"""逐行读取"""
with open(path) as f:
for line in f:
yield line.strip()
def filter_errors(lines):
"""过滤错误行"""
for line in lines:
if "ERROR" in line:
yield line
def extract_message(lines):
"""提取消息"""
for line in lines:
parts = line.split(" ", 3)
if len(parts) >= 4:
yield parts[3]
# 组合:惰性处理大文件
# messages = extract_message(filter_errors(read_lines("app.log")))
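The composed pipeline is commented out because it needs a real log file. The same stages accept any iterable of strings, so a self-contained sketch can feed them an in-memory list instead (the sample log lines below are invented for illustration, in a "date time LEVEL message" format):

```python
def filter_errors(lines):
    # Keep only lines containing ERROR
    for line in lines:
        if "ERROR" in line:
            yield line

def extract_message(lines):
    # The message is everything after the third space
    for line in lines:
        parts = line.split(" ", 3)
        if len(parts) >= 4:
            yield parts[3]

# Invented sample data standing in for a log file
sample = [
    "2024-01-01 10:00:00 INFO service started",
    "2024-01-01 10:00:01 ERROR disk full",
    "2024-01-01 10:00:02 ERROR connection lost",
]

# Nothing is processed until the final list() pulls values through
messages = extract_message(filter_errors(sample))
print(list(messages))  # ['disk full', 'connection lost']
```

Each stage pulls one line at a time from the previous stage, so the memory footprint stays constant regardless of input size.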

itertools in Practice

"""
itertools:组合、过滤、分组一网打尽
"""
import itertools
# === 无限迭代器 ===
# count(10)     → 10, 11, 12, ...
# cycle([1,2])  → 1, 2, 1, 2, ...
# repeat('x')   → 'x', 'x', 'x', ...
# 带序号循环
colors = ["红", "绿", "蓝"]
for i, color in zip(itertools.count(1), colors):
print(f"{i}. {color}")
# === 组合 ===
print(list(itertools.product("AB", "12")))
# [('A','1'), ('A','2'), ('B','1'), ('B','2')]
print(list(itertools.combinations("ABCD", 2)))
# [('A','B'), ('A','C'), ('A','D'), ('B','C'), ('B','D'), ('C','D')]
print(list(itertools.permutations("ABC", 2)))
# [('A','B'), ('A','C'), ('B','A'), ('B','C'), ('C','A'), ('C','B')]
# === 过滤和切片 ===
data = range(20)
print(list(itertools.islice(data, 5, 10)))     # [5, 6, 7, 8, 9]
print(list(itertools.takewhile(lambda x: x < 5, data)))  # [0, 1, 2, 3, 4]
print(list(itertools.dropwhile(lambda x: x < 15, data))) # [15, 16, 17, 18, 19]
# === 分组 ===
records = [
{"dept": "eng", "name": "Alice"},
{"dept": "eng", "name": "Bob"},
{"dept": "sales", "name": "Carol"},
{"dept": "sales", "name": "Dave"},
]
# 注意:groupby 要求数据已按 key 排序
for dept, members in itertools.groupby(records, key=lambda r: r["dept"]):
names = [m["name"] for m in members]
print(f"{dept}: {names}")
# eng: ['Alice', 'Bob']
# sales: ['Carol', 'Dave']
# === chain:合并多个迭代器 ===
a = [1, 2, 3]
b = [4, 5, 6]
c = [7, 8, 9]
print(list(itertools.chain(a, b, c)))  # [1, 2, ..., 9]
# chain.from_iterable:展平一层
matrix = [[1, 2], [3, 4], [5, 6]]
flat = list(itertools.chain.from_iterable(matrix))
print(flat)  # [1, 2, 3, 4, 5, 6]
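The pre-sorting caveat for `groupby` matters in practice: it starts a new group every time the key value changes, so with unsorted input the same key can appear more than once. A short sketch of the pitfall and the fix:

```python
import itertools

# Unsorted input: groupby splits 'a' into two separate groups
unsorted = ["a", "b", "a"]
groups = [(k, list(g)) for k, g in itertools.groupby(unsorted)]
print(groups)  # [('a', ['a']), ('b', ['b']), ('a', ['a'])]

# Sorting first yields exactly one group per key
ordered = sorted(unsorted)
groups = [(k, list(g)) for k, g in itertools.groupby(ordered)]
print(groups)  # [('a', ['a', 'a']), ('b', ['b'])]
```

When grouping dicts as in the example above, sort with the same `key` function you pass to `groupby`.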

functools in Practice

"""
functools:缓存、偏函数、归约
"""
from functools import lru_cache, partial, reduce
# === lru_cache 缓存 ===
@lru_cache(maxsize=256)
def expensive_query(user_id: int) -> dict:
"""模拟耗时数据库查询"""
import time
time.sleep(0.1)  # 模拟延迟
return {"id": user_id, "name": f"User_{user_id}"}
# 第一次慢,之后直接返回缓存
result = expensive_query(42)
result = expensive_query(42)  # 命中缓存,瞬间返回
print(expensive_query.cache_info())
# CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)
# === partial 偏函数 ===
def power(base, exp):
return base ** exp
square = partial(power, exp=2)
cube = partial(power, exp=3)
print(square(5))  # 25
print(cube(3))    # 27
# 实战:预配置的日志函数
import logging
log = logging.getLogger("app")
info = partial(log.log, logging.INFO)
error = partial(log.log, logging.ERROR)
# === reduce 归约 ===
nums = [1, 2, 3, 4, 5]
total = reduce(lambda acc, x: acc + x, nums)
print(total)  # 15
# 实战:深度合并字典
def deep_merge(base: dict, override: dict) -> dict:
result = base.copy()
for k, v in override.items():
if k in result and isinstance(result[k], dict) and isinstance(v, dict):
result[k] = deep_merge(result[k], v)
else:
result[k] = v
return result
configs = [
{"db": {"host": "localhost", "port": 5432}},
{"db": {"port": 3306}, "debug": True},
{"cache": {"ttl": 300}},
]
final = reduce(deep_merge, configs)
print(final)
# {'db': {'host': 'localhost', 'port': 3306}, 'debug': True, 'cache': {'ttl': 300}}
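One caveat worth knowing about `lru_cache`: arguments are used as cache keys, so they must be hashable, and passing a list raises `TypeError`. A minimal sketch (the `totals` function is invented for illustration):

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def totals(values: tuple) -> int:
    # Tuples are hashable, so they work as cache keys
    return sum(values)

print(totals((1, 2, 3)))  # 6

try:
    totals([1, 2, 3])  # lists are unhashable
except TypeError as e:
    print("unhashable argument:", e)
```

This is also why `expensive_query` above takes a plain `int` rather than, say, a dict of query options.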

Quick Reference

| Function | Purpose | Example |
| --- | --- | --- |
| `chain` | Concatenate iterators | `chain([1,2], [3,4])` → 1,2,3,4 |
| `islice` | Slice an iterator | `islice(range(100), 5, 10)` |
| `groupby` | Group (pre-sorting required) | Group employees by department |
| `product` | Cartesian product | Parameter-combination testing |
| `lru_cache` | Automatic memoization | Cache database queries |
| `partial` | Fix some arguments | Pre-configured functions |
| `reduce` | Cumulative fold | Merge multiple dicts |

Chapter Summary

| Topic | Key points |
| --- | --- |
| Generators | `yield` for lazy evaluation, `yield from` for delegation |
| Pipelines | Compose generators for streaming processing |
| itertools | `chain` / `groupby` / `product` / `islice` |
| functools | `lru_cache` / `partial` / `reduce` |
| Performance | Lazy evaluation beats loading everything at once; it saves memory |

Next chapter: concurrency and asynchronous programming, letting a program do several things at once.