组合模式
组合模式(Composite Pattern)将对象组合成树形结构以表示"部分-整体"的层次结构,使得用户对单个对象和组合对象的使用具有一致性。
问题定义
场景1:文件系统
# ❌ 问题:没有统一接口
class File:
def __init__(self, name, size):
self.name = name
self.size = size
def get_size(self):
return self.size
class Directory:
def __init__(self, name):
self.name = name
self.children = []
def add(self, item):
self.children.append(item)
def get_size(self):
# 问题:需要处理不同类型
total = 0
for child in self.children:
if isinstance(child, File):
total += child.get_size()
elif isinstance(child, Directory):
total += child.get_size()
return total
# 问题:客户端需要区分File和Directory
file1 = File("file1.txt", 100)
file2 = File("file2.txt", 200)
dir1 = Directory("dir1")
dir1.add(file1)
dir1.add(file2)
print(dir1.get_size()) # 300
# 如果要添加Directory到Directory呢?
# 需要修改Directory的get_size方法
场景2:组织架构
# ❌ 问题:层级管理复杂
class Employee:
def __init__(self, name, salary):
self.name = name
self.salary = salary
class Manager:
def __init__(self, name, salary):
self.name = name
self.salary = salary
self.subordinates = []
def add_subordinate(self, emp):
self.subordinates.append(emp)
def calculate_total_salary(self):
# 问题:需要处理Employee和Manager
total = self.salary
for sub in self.subordinates:
if isinstance(sub, Manager):
total += sub.calculate_total_salary()
else:
total += sub.salary
return total
# 问题:Manager不能是Manager的下属?
# 需要修改代码
解决方案
组合模式让单个对象和组合对象具有统一接口,客户端可以一致地使用。
classDiagram
class Component {
<>
+operation()
+add()
+remove()
+getChild()
}
class Leaf {
+operation()
}
class Composite {
-children: Component[]
+operation()
+add()
+remove()
+getChild()
}
Component <|-- Leaf
Component <|-- Composite
Composite o-- Component : contains
标准实现
抽象组件
from abc import ABC, abstractmethod
class FileSystemNode(ABC):
"""文件系统节点抽象类"""
@abstractmethod
def get_size(self):
"""获取大小"""
pass
@abstractmethod
def print_structure(self, indent=0):
"""打印结构"""
pass
def add(self, node):
"""添加子节点 - 默认实现(叶子节点不支持)"""
raise NotImplementedError("叶子节点不支持添加子节点")
def remove(self, node):
"""移除子节点 - 默认实现"""
raise NotImplementedError("叶子节点不支持移除子节点")
叶子节点(文件)
class File(FileSystemNode):
"""文件 - 叶子节点"""
def __init__(self, name, size):
self.name = name
self.size = size
def get_size(self):
return self.size
def print_structure(self, indent=0):
print(" " * indent + f"📄 {self.name} ({self.size} bytes)")
组合节点(目录)
class Directory(FileSystemNode):
"""目录 - 组合节点"""
def __init__(self, name):
self.name = name
self.children = []
def add(self, node):
"""添加子节点"""
self.children.append(node)
def remove(self, node):
"""移除子节点"""
self.children.remove(node)
def get_size(self):
"""计算总大小"""
total = 0
for child in self.children:
total += child.get_size()
return total
def print_structure(self, indent=0):
print(" " * indent + f"📁 {self.name}/")
for child in self.children:
child.print_structure(indent + 1)
客户端使用
# 创建文件结构
file1 = File("readme.txt", 1024)
file2 = File("license.txt", 2048)
file3 = File("image.png", 51200)
docs_dir = Directory("docs")
docs_dir.add(file1)
docs_dir.add(file2)
images_dir = Directory("images")
images_dir.add(file3)
root = Directory("root")
root.add(docs_dir)
root.add(images_dir)
root.add(File("main.py", 8192))
# 客户端可以一致地处理文件和目录
print(f"总大小: {root.get_size()} bytes")
# 总大小: 64264 bytes
print("\n文件结构:")
root.print_structure()
# 📁 root/
# 📁 docs/
# 📄 readme.txt (1024 bytes)
# 📄 license.txt (2048 bytes)
# 📁 images/
# 📄 image.png (51200 bytes)
# 📄 main.py (8192 bytes)
实战应用
应用1:组织架构管理
from abc import ABC, abstractmethod
class Employee(ABC):
"""员工抽象类"""
def __init__(self, name, salary):
self.name = name
self.salary = salary
@abstractmethod
def get_total_salary(self):
"""获取总薪资"""
pass
@abstractmethod
def print_info(self, indent=0):
"""打印信息"""
pass
class Worker(Employee):
"""普通员工 - 叶子节点"""
def get_total_salary(self):
return self.salary
def print_info(self, indent=0):
print(" " * indent + f"👤 {self.name} (${self.salary})")
class Manager(Employee):
"""经理 - 组合节点"""
def __init__(self, name, salary):
super().__init__(name, salary)
self.subordinates = []
def add_subordinate(self, emp):
"""添加下属"""
self.subordinates.append(emp)
def remove_subordinate(self, emp):
"""移除下属"""
self.subordinates.remove(emp)
def get_total_salary(self):
"""计算团队总薪资"""
total = self.salary
for emp in self.subordinates:
total += emp.get_total_salary()
return total
def print_info(self, indent=0):
print(" " * indent + f"👔 {self.name} (${self.salary})")
for emp in self.subordinates:
emp.print_info(indent + 1)
# 创建组织架构
alice = Worker("Alice", 5000)
bob = Worker("Bob", 6000)
charlie = Worker("Charlie", 5500)
dev_team_lead = Manager("David", 8000)
dev_team_lead.add_subordinate(alice)
dev_team_lead.add_subordinate(bob)
dev_team_lead.add_subordinate(charlie)
diana = Worker("Diana", 7000)
eric = Worker("Eric", 7500)
qa_team_lead = Manager("Frank", 9000)
qa_team_lead.add_subordinate(diana)
qa_team_lead.add_subordinate(eric)
cto = Manager("George", 15000)
cto.add_subordinate(dev_team_lead)
cto.add_subordinate(qa_team_lead)
# 使用
print(f"公司总薪资: ${cto.get_total_salary()}")
print("\n组织架构:")
cto.print_info()
# 👔 George ($15000)
# 👔 David ($8000)
# 👤 Alice ($5000)
# 👤 Bob ($6000)
# 👤 Charlie ($5500)
# 👔 Frank ($9000)
# 👤 Diana ($7000)
# 👤 Eric ($7500)
应用2:图形编辑器
from abc import ABC, abstractmethod
class Graphic(ABC):
"""图形抽象类"""
@abstractmethod
def draw(self):
"""绘制"""
pass
@abstractmethod
def move(self, x, y):
"""移动"""
pass
class Circle(Graphic):
"""圆形 - 叶子节点"""
def __init__(self, x, y, radius):
self.x = x
self.y = y
self.radius = radius
def draw(self):
print(f"绘制圆形: 中心({self.x}, {self.y}), 半径{self.radius}")
def move(self, x, y):
self.x = x
self.y = y
class Rectangle(Graphic):
"""矩形 - 叶子节点"""
def __init__(self, x, y, width, height):
self.x = x
self.y = y
self.width = width
self.height = height
def draw(self):
print(f"绘制矩形: 左上角({self.x}, {self.y}), {self.width}x{self.height}")
def move(self, x, y):
self.x = x
self.y = y
class GraphicGroup(Graphic):
"""图形组 - 组合节点"""
def __init__(self):
self.graphics = []
def add(self, graphic):
"""添加图形"""
self.graphics.append(graphic)
def remove(self, graphic):
"""移除图形"""
self.graphics.remove(graphic)
def draw(self):
"""绘制所有图形"""
for graphic in self.graphics:
graphic.draw()
def move(self, x, y):
"""移动所有图形"""
for graphic in self.graphics:
graphic.move(x, y)
# 创建图形
circle1 = Circle(10, 10, 5)
circle2 = Circle(20, 20, 3)
rect1 = Rectangle(0, 0, 10, 10)
rect2 = Rectangle(15, 15, 20, 20)
# 创建组合
group1 = GraphicGroup()
group1.add(circle1)
group1.add(rect1)
group2 = GraphicGroup()
group2.add(circle2)
group2.add(rect2)
# 创建更大的组合
main_group = GraphicGroup()
main_group.add(group1)
main_group.add(group2)
# 统一操作
print("绘制所有图形:")
main_group.draw()
# 绘制圆形: 中心(10, 10), 半径5
# 绘制矩形: 左上角(0, 0), 10x10
# 绘制圆形: 中心(20, 20), 半径3
# 绘制矩形: 左上角(15, 15), 20x20
print("\n移动所有图形:")
main_group.move(5, 5)
main_group.draw()
# 绘制圆形: 中心(5, 5), 半径5
# 绘制矩形: 左上角(5, 5), 10x10
# 绘制圆形: 中心(5, 5), 半径3
# 绘制矩形: 左上角(5, 5), 20x20
应用3:菜单系统
from abc import ABC, abstractmethod
class MenuItem(ABC):
"""菜单项抽象类"""
@abstractmethod
def print_menu(self, indent=0):
"""打印菜单"""
pass
@abstractmethod
def is_vegetarian(self):
"""是否素食"""
pass
class Dish(MenuItem):
"""菜品 - 叶子节点"""
def __init__(self, name, price, vegetarian=False):
self.name = name
self.price = price
self.vegetarian = vegetarian
def print_menu(self, indent=0):
v_mark = "(V)" if self.vegetarian else ""
print(" " * indent + f"🍽️ {self.name} ${self.price} {v_mark}")
def is_vegetarian(self):
return self.vegetarian
class MenuCategory(MenuItem):
"""菜单分类 - 组合节点"""
def __init__(self, name):
self.name = name
self.items = []
def add(self, item):
"""添加菜品或分类"""
self.items.append(item)
def remove(self, item):
"""移除菜品或分类"""
self.items.remove(item)
def print_menu(self, indent=0):
print(" " * indent + f"📂 {self.name}/")
for item in self.items:
item.print_menu(indent + 1)
def is_vegetarian(self):
"""至少包含一个素食菜品"""
for item in self.items:
if item.is_vegetarian():
return True
return False
# 创建菜单
menu = MenuCategory("主菜单")
# 添加分类
appetizers = MenuCategory("开胃菜")
main_courses = MenuCategory("主菜")
desserts = MenuCategory("甜点")
menu.add(appetizers)
menu.add(main_courses)
menu.add(desserts)
# 添加菜品
appetizers.add(Dish("沙拉", 8.99, True))
appetizers.add(Dish("汤", 6.99, True))
main_courses.add(Dish("牛排", 29.99))
main_courses.add(Dish("意大利面", 18.99, True))
main_courses.add(Dish("鸡肉", 22.99))
desserts.add(Dish("蛋糕", 9.99, True))
desserts.add(Dish("冰淇淋", 6.99, True))
# 打印菜单
print("=== 菜单 ===")
menu.print_menu()
# 📂 主菜单/
# 📂 开胃菜/
# 🍽️ 沙拉 $8.99 (V)
# 🍽️ 汤 $6.99 (V)
# 📂 主菜/
# 🍽️ 牛排 $29.99
# 🍽️ 意大利面 $18.99 (V)
# 🍽️ 鸡肉 $22.99
# 📂 甜点/
# 🍽️ 蛋糕 $9.99 (V)
# 🍽️ 冰淇淋 $6.99 (V)
# 查询素食选项
print(f"\n主菜有素食选项: {main_courses.is_vegetarian()}") # True
安全性 vs 透明性
透明性(Transparency)
# 透明性:所有组件都有相同接口
class FileSystemNode(ABC):
@abstractmethod
def add(self, node):
pass
@abstractmethod
def remove(self, node):
pass
# 叶子节点必须实现,但抛出异常
class File(FileSystemNode):
def add(self, node):
raise NotImplementedError
def remove(self, node):
raise NotImplementedError
安全性(Security)
# 安全性:只在组合类中有add/remove方法
class FileSystemNode(ABC):
@abstractmethod
def get_size(self):
pass
# 叶子节点不需要实现add/remove
class File(FileSystemNode):
def get_size(self):
return self.size
class Directory(FileSystemNode):
def add(self, node):
pass
def remove(self, node):
pass
| 特性 | 透明性 | 安全性 |
|---|---|---|
| 接口统一 | ✅ 全部相同 | ❌ 组合类有额外方法 |
| 类型安全 | ❌ 运行时错误 | ✅ 编译时检查 |
| 客户端代码 | 简单 | 需要类型检查 |
优缺点
✅ 优点
| 优点 | 说明 |
|---|---|
| 统一接口 | 一致地处理单个和组合对象 |
| 层次清晰 | 树形结构表达部分-整体 |
| 易于扩展 | 新增组件类型不影响客户端 |
| 灵活 | 可以自由组合对象 |
❌ 缺点
| 缺点 | 说明 |
|---|---|
| 复杂度 | 需要维护层级关系 |
| 限制性 | 叶子节点的限制 |
| 类型安全 | 透明性模式下可能类型不安全 |
与其他模式的关系
| 模式 | 关系 |
|---|---|
| 装饰器 | 装饰器模式经常与组合模式一起使用 |
| 迭代器 | 可以使用迭代器遍历组合结构 |
| 访问者 | 访问者模式可以操作组合结构 |
本章要点
- ✅ 组合模式让单个对象和组合对象统一
- ✅ 构建树形结构表达部分-整体关系
- ✅ 客户端无需区分处理单个和组合对象
- ✅ 适用于文件系统、组织架构、图形编辑等场景
- ✅ 透明性 vs 安全性两种设计选择
- ✅ 与装饰器、迭代器模式结合使用
恭喜完成结构型模式部分!
下一步:观察者模式 开始学习行为型模式 🚀