策略模式
策略模式(Strategy Pattern)定义一系列算法,把它们一个个封装起来,并且使它们可相互替换。本模式使得算法的变化可独立于使用它的客户。
问题定义
场景1:排序算法
# ❌ 问题:大量条件判断
class Sorter:
"""排序器"""
def sort(self, data, algorithm):
if algorithm == "bubble":
self._bubble_sort(data)
elif algorithm == "quick":
self._quick_sort(data)
elif algorithm == "merge":
self._merge_sort(data)
elif algorithm == "heap":
self._heap_sort(data)
# 添加新算法需要修改这里!
def _bubble_sort(self, data):
print("使用冒泡排序")
def _quick_sort(self, data):
print("使用快速排序")
def _merge_sort(self, data):
print("使用归并排序")
def _heap_sort(self, data):
print("使用堆排序")
# 问题:每次添加新算法都要修改Sorter类
场景2:支付方式
# ❌ 问题:扩展困难
class PaymentProcessor:
"""支付处理器"""
def process_payment(self, payment_type, amount):
if payment_type == "credit_card":
self._process_credit_card(amount)
elif payment_type == "paypal":
self._process_paypal(amount)
elif payment_type == "wechat":
self._process_wechat(amount)
# 添加新支付方式需要修改这里
def _process_credit_card(self, amount):
print(f"信用卡支付 ${amount}")
def _process_paypal(self, amount):
print(f"PayPal支付 ${amount}")
def _process_wechat(self, amount):
print(f"微信支付 ¥{amount}")
解决方案
策略模式将算法封装成独立的类,使用时可以灵活替换。
classDiagram
class Strategy {
<>
+algorithm()
}
class ConcreteStrategyA {
+algorithm()
}
class ConcreteStrategyB {
+algorithm()
}
class Context {
-strategy: Strategy
+setStrategy()
+executeAlgorithm()
}
Strategy <|-- ConcreteStrategyA
Strategy <|-- ConcreteStrategyB
Context o-- Strategy : uses
标准实现
策略接口
from abc import ABC, abstractmethod
class SortStrategy(ABC):
"""排序策略接口"""
@abstractmethod
def sort(self, data):
pass
具体策略
class BubbleSortStrategy(SortStrategy):
"""冒泡排序策略"""
def sort(self, data):
print("使用冒泡排序")
n = len(data)
for i in range(n):
for j in range(0, n - i - 1):
if data[j] > data[j + 1]:
data[j], data[j + 1] = data[j + 1], data[j]
return data
class QuickSortStrategy(SortStrategy):
"""快速排序策略"""
def sort(self, data):
print("使用快速排序")
if len(data) <= 1:
return data
pivot = data[len(data) // 2]
left = [x for x in data if x < pivot]
middle = [x for x in data if x == pivot]
right = [x for x in data if x > pivot]
return self.sort(left) + middle + self.sort(right)
class MergeSortStrategy(SortStrategy):
"""归并排序策略"""
def sort(self, data):
print("使用归并排序")
if len(data) <= 1:
return data
mid = len(data) // 2
left = self.sort(data[:mid])
right = self.sort(data[mid:])
return self._merge(left, right)
def _merge(self, left, right):
result = []
i = j = 0
while i < len(left) and j < len(right):
if left[i] <= right[j]:
result.append(left[i])
i += 1
else:
result.append(right[j])
j += 1
result.extend(left[i:])
result.extend(right[j:])
return data
上下文类
class Sorter:
"""排序器 - 上下文"""
def __init__(self, strategy: SortStrategy):
self._strategy = strategy
def set_strategy(self, strategy: SortStrategy):
"""设置排序策略"""
self._strategy = strategy
def sort(self, data):
"""执行排序"""
return self._strategy.sort(data[:]) # 复制数据,不修改原数据
客户端使用
# 数据
data = [64, 34, 25, 12, 22, 11, 90]
# 创建排序器,使用冒泡排序
sorter = Sorter(BubbleSortStrategy())
result1 = sorter.sort(data)
print(f"结果: {result1}\n")
# 切换到快速排序
sorter.set_strategy(QuickSortStrategy())
result2 = sorter.sort(data)
print(f"结果: {result2}\n")
# 切换到归并排序
sorter.set_strategy(MergeSortStrategy())
result3 = sorter.sort(data)
print(f"结果: {result3}")
实战应用
应用1:支付系统
from abc import ABC, abstractmethod
class PaymentStrategy(ABC):
"""支付策略接口"""
@abstractmethod
def pay(self, amount):
pass
class CreditCardPayment(PaymentStrategy):
"""信用卡支付"""
def pay(self, amount):
print(f"💳 信用卡支付: ${amount}")
print(" 验证信用卡信息...")
print(" 扣款成功")
return f"CC-{amount}"
class PayPalPayment(PaymentStrategy):
"""PayPal支付"""
def pay(self, amount):
print(f"🅿️ PayPal支付: ${amount}")
print(" 跳转到PayPal...")
print(" 支付成功")
return f"PP-{amount}"
class WeChatPayment(PaymentStrategy):
"""微信支付"""
def pay(self, amount):
print(f"💬 微信支付: ¥{amount}")
print(" 打开微信...")
print(" 支付成功")
return f"WX-{amount}"
class AlipayPayment(PaymentStrategy):
"""支付宝支付"""
def pay(self, amount):
print(f"🔷 支付宝支付: ¥{amount}")
print(" 打开支付宝...")
print(" 支付成功")
return f"ALI-{amount}"
class PaymentContext:
"""支付上下文"""
def __init__(self, strategy: PaymentStrategy):
self._strategy = strategy
def set_strategy(self, strategy: PaymentStrategy):
self._strategy = strategy
def pay(self, amount):
return self._strategy.pay(amount)
# 使用
context = PaymentContext(CreditCardPayment())
context.pay(100)
print()
context.set_strategy(WeChatPayment())
context.pay(500)
print()
context.set_strategy(AlipayPayment())
context.pay(200)
应用2:出行路线规划
from abc import ABC, abstractmethod
class RouteStrategy(ABC):
"""路线策略接口"""
@abstractmethod
def build_route(self, origin, destination):
pass
class FastestRouteStrategy(RouteStrategy):
"""最快路线策略"""
def build_route(self, origin, destination):
print(f"🚗 规划从 {origin} 到 {destination} 的最快路线")
print(" 使用高速公路")
print(" 预计时间: 45分钟")
print(" 距离: 80公里")
print(" 过路费: $15")
class ShortestRouteStrategy(RouteStrategy):
"""最短路线策略"""
def build_route(self, origin, destination):
print(f"🚗 规划从 {origin} 到 {destination} 的最短路线")
print(" 使用省道")
print(" 预计时间: 65分钟")
print(" 距离: 55公里")
print(" 过路费: $5")
class CheapestRouteStrategy(RouteStrategy):
"""最经济路线策略"""
def build_route(self, origin, destination):
print(f"🚗 规划从 {origin} 到 {destination} 的最经济路线")
print(" 使用普通道路")
print(" 预计时间: 90分钟")
print(" 距离: 60公里")
print(" 过路费: $0")
class ScenicRouteStrategy(RouteStrategy):
"""风景路线策略"""
def build_route(self, origin, destination):
print(f"🚗 规划从 {origin} 到 {destination} 的风景路线")
print(" 沿海岸线行驶")
print(" 预计时间: 120分钟")
print(" 距离: 100公里")
print(" 过路费: $10")
print(" 经过景点: 3个")
class Navigation:
"""导航系统"""
def __init__(self, strategy: RouteStrategy):
self._strategy = strategy
def set_strategy(self, strategy: RouteStrategy):
self._strategy = strategy
def navigate(self, origin, destination):
return self._strategy.build_route(origin, destination)
# 使用
nav = Navigation(FastestRouteStrategy())
nav.navigate("北京", "天津")
print()
nav.set_strategy(ScenicRouteStrategy())
nav.navigate("北京", "天津")
应用3:数据压缩
from abc import ABC, abstractmethod
class CompressionStrategy(ABC):
"""压缩策略接口"""
@abstractmethod
def compress(self, data):
pass
@abstractmethod
def decompress(self, compressed_data):
pass
class ZipCompression(CompressionStrategy):
"""ZIP压缩"""
def compress(self, data):
print(f"📦 使用ZIP压缩: {len(data)} bytes")
print(" 压缩率: 70%")
return f"ZIP:{data}"
def decompress(self, compressed_data):
print(f"📦 解压ZIP: {compressed_data}")
return compressed_data.replace("ZIP:", "")
class GzipCompression(CompressionStrategy):
"""GZIP压缩"""
def compress(self, data):
print(f"📦 使用GZIP压缩: {len(data)} bytes")
print(" 压缩率: 80%")
return f"GZIP:{data}"
def decompress(self, compressed_data):
print(f"📦 解压GZIP: {compressed_data}")
return compressed_data.replace("GZIP:", "")
class RarCompression(CompressionStrategy):
"""RAR压缩"""
def compress(self, data):
print(f"📦 使用RAR压缩: {len(data)} bytes")
print(" 压缩率: 75%")
return f"RAR:{data}"
def decompress(self, compressed_data):
print(f"📦 解压RAR: {compressed_data}")
return compressed_data.replace("RAR:", "")
class CompressionContext:
"""压缩上下文"""
def __init__(self, strategy: CompressionStrategy):
self._strategy = strategy
def set_strategy(self, strategy: CompressionStrategy):
self._strategy = strategy
def compress_data(self, data):
return self._strategy.compress(data)
def decompress_data(self, compressed_data):
return self._strategy.decompress(compressed_data)
# 使用
context = CompressionContext(ZipCompression())
compressed = context.compress_data("Hello World! " * 1000)
print()
context.set_strategy(GzipCompression())
compressed = context.compress_data("Hello World! " * 1000)
应用4:AI模型选择
from abc import ABC, abstractmethod
class ModelStrategy(ABC):
"""AI模型策略接口"""
@abstractmethod
def train(self, data):
pass
@abstractmethod
def predict(self, input_data):
pass
class RandomForestModel(ModelStrategy):
"""随机森林模型"""
def train(self, data):
print("🌲 训练随机森林模型")
print(" 构建100棵决策树")
print(" 训练完成")
def predict(self, input_data):
print("🌲 随机森林预测")
return "随机森林预测结果"
class NeuralNetworkModel(ModelStrategy):
"""神经网络模型"""
def train(self, data):
print("🧠 训练神经网络模型")
print(" 构建3层全连接网络")
print(" 训练100个epoch")
print(" 训练完成")
def predict(self, input_data):
print("🧠 神经网络预测")
return "神经网络预测结果"
class SVMModel(ModelStrategy):
"""支持向量机模型"""
def train(self, data):
print("📊 训练SVM模型")
print(" 选择RBF核函数")
print(" 优化超参数")
print(" 训练完成")
def predict(self, input_data):
print("📊 SVM预测")
return "SVM预测结果"
class MLModel:
"""机器学习模型上下文"""
def __init__(self, strategy: ModelStrategy):
self._strategy = strategy
def set_strategy(self, strategy: ModelStrategy):
self._strategy = strategy
def train(self, data):
return self._strategy.train(data)
def predict(self, input_data):
return self._strategy.predict(input_data)
# 使用
model = MLModel(RandomForestModel())
model.train("训练数据")
result1 = model.predict("测试数据")
print()
model.set_strategy(NeuralNetworkModel())
model.train("训练数据")
result2 = model.predict("测试数据")
策略模式 vs 工厂方法模式
| 维度 | 策略模式 | 工厂方法模式 |
|---|---|---|
| 目的 | 封装算法 | 创建对象 |
| 关注点 | 行为 | 创建 |
| 使用场景 | 算法可互换 | 对象创建可变 |
| 客户端 | 选择策略 | 通过工厂获取对象 |
graph TB
A[模式对比] --> B[策略模式]
A --> C[工厂方法模式]
B --> B1[封装算法]
B --> B2[运行时切换]
B --> B3[行为变化]
C --> C1[创建对象]
C --> C2[延迟创建]
C --> C3[创建逻辑变化]
style A fill:#ede7f6,stroke:#5e35b1,stroke-width:3px
style B fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
style C fill:#fff9c4,stroke:#f9a825,stroke-width:2px
优缺点
✅ 优点
| 优点 | 说明 |
|---|---|
| 开闭原则 | 新增策略无需修改代码 |
| 单一职责 | 每个策略只负责一个算法 |
| 可复用 | 策略可在不同上下文中复用 |
| 易于测试 | 每个策略可独立测试 |
❌ 缺点
| 缺点 | 说明 |
|---|---|
| 类数量多 | 每个策略一个类 |
| 客户端知道所有策略 | 需要了解不同策略 |
| 策略通信困难 | 策略之间难以通信 |
适用场景
| 场景 | 是否适合 |
|---|---|
| 多种算法可互换 | ✅ 适合 |
| 算法有复杂逻辑 | ✅ 适合 |
| 需要运行时切换 | ✅ 适合 |
| 客户端不需要知道实现 | ✅ 适合 |
本章要点
- ✅ 策略模式封装算法,使其可互换
- ✅ 符合开闭原则,易于扩展
- ✅ 运行时可切换算法
- ✅ 适用于排序、支付、路线规划等
- ✅ 每个策略独立,易于测试
- ✅ 策略模式关注行为,工厂方法关注创建
下一步:模板方法模式 🚀