策略模式
High Contrast
Dark Mode
Light Mode
Sepia
Forest
3 min read506 words

策略模式

策略模式(Strategy Pattern)定义一系列算法,把它们一个个封装起来,并且使它们可相互替换。本模式使得算法的变化可独立于使用它的客户。

问题定义

场景1:排序算法

# ❌ 问题:大量条件判断
class Sorter:
"""排序器"""
def sort(self, data, algorithm):
if algorithm == "bubble":
self._bubble_sort(data)
elif algorithm == "quick":
self._quick_sort(data)
elif algorithm == "merge":
self._merge_sort(data)
elif algorithm == "heap":
self._heap_sort(data)
# 添加新算法需要修改这里!
def _bubble_sort(self, data):
print("使用冒泡排序")
def _quick_sort(self, data):
print("使用快速排序")
def _merge_sort(self, data):
print("使用归并排序")
def _heap_sort(self, data):
print("使用堆排序")
# 问题:每次添加新算法都要修改Sorter类

场景2:支付方式

# ❌ 问题:扩展困难
class PaymentProcessor:
"""支付处理器"""
def process_payment(self, payment_type, amount):
if payment_type == "credit_card":
self._process_credit_card(amount)
elif payment_type == "paypal":
self._process_paypal(amount)
elif payment_type == "wechat":
self._process_wechat(amount)
# 添加新支付方式需要修改这里
def _process_credit_card(self, amount):
print(f"信用卡支付 ${amount}")
def _process_paypal(self, amount):
print(f"PayPal支付 ${amount}")
def _process_wechat(self, amount):
print(f"微信支付 ¥{amount}")

解决方案

策略模式将算法封装成独立的类,使用时可以灵活替换。

classDiagram class Strategy { <> +algorithm() } class ConcreteStrategyA { +algorithm() } class ConcreteStrategyB { +algorithm() } class Context { -strategy: Strategy +setStrategy() +executeAlgorithm() } Strategy <|-- ConcreteStrategyA Strategy <|-- ConcreteStrategyB Context o-- Strategy : uses

标准实现

策略接口

from abc import ABC, abstractmethod
class SortStrategy(ABC):
"""排序策略接口"""
@abstractmethod
def sort(self, data):
pass

具体策略

class BubbleSortStrategy(SortStrategy):
"""冒泡排序策略"""
def sort(self, data):
print("使用冒泡排序")
n = len(data)
for i in range(n):
for j in range(0, n - i - 1):
if data[j] > data[j + 1]:
data[j], data[j + 1] = data[j + 1], data[j]
return data
class QuickSortStrategy(SortStrategy):
"""快速排序策略"""
def sort(self, data):
print("使用快速排序")
if len(data) <= 1:
return data
pivot = data[len(data) // 2]
left = [x for x in data if x < pivot]
middle = [x for x in data if x == pivot]
right = [x for x in data if x > pivot]
return self.sort(left) + middle + self.sort(right)
class MergeSortStrategy(SortStrategy):
"""归并排序策略"""
def sort(self, data):
print("使用归并排序")
if len(data) <= 1:
return data
mid = len(data) // 2
left = self.sort(data[:mid])
right = self.sort(data[mid:])
return self._merge(left, right)
def _merge(self, left, right):
result = []
i = j = 0
while i < len(left) and j < len(right):
if left[i] <= right[j]:
result.append(left[i])
i += 1
else:
result.append(right[j])
j += 1
result.extend(left[i:])
result.extend(right[j:])
return data

上下文类

class Sorter:
"""排序器 - 上下文"""
def __init__(self, strategy: SortStrategy):
self._strategy = strategy
def set_strategy(self, strategy: SortStrategy):
"""设置排序策略"""
self._strategy = strategy
def sort(self, data):
"""执行排序"""
return self._strategy.sort(data[:])  # 复制数据,不修改原数据

客户端使用

# 数据
data = [64, 34, 25, 12, 22, 11, 90]
# 创建排序器,使用冒泡排序
sorter = Sorter(BubbleSortStrategy())
result1 = sorter.sort(data)
print(f"结果: {result1}\n")
# 切换到快速排序
sorter.set_strategy(QuickSortStrategy())
result2 = sorter.sort(data)
print(f"结果: {result2}\n")
# 切换到归并排序
sorter.set_strategy(MergeSortStrategy())
result3 = sorter.sort(data)
print(f"结果: {result3}")

实战应用

应用1:支付系统

from abc import ABC, abstractmethod
class PaymentStrategy(ABC):
"""支付策略接口"""
@abstractmethod
def pay(self, amount):
pass
class CreditCardPayment(PaymentStrategy):
"""信用卡支付"""
def pay(self, amount):
print(f"💳 信用卡支付: ${amount}")
print("  验证信用卡信息...")
print("  扣款成功")
return f"CC-{amount}"
class PayPalPayment(PaymentStrategy):
"""PayPal支付"""
def pay(self, amount):
print(f"🅿️ PayPal支付: ${amount}")
print("  跳转到PayPal...")
print("  支付成功")
return f"PP-{amount}"
class WeChatPayment(PaymentStrategy):
"""微信支付"""
def pay(self, amount):
print(f"💬 微信支付: ¥{amount}")
print("  打开微信...")
print("  支付成功")
return f"WX-{amount}"
class AlipayPayment(PaymentStrategy):
"""支付宝支付"""
def pay(self, amount):
print(f"🔷 支付宝支付: ¥{amount}")
print("  打开支付宝...")
print("  支付成功")
return f"ALI-{amount}"
class PaymentContext:
"""支付上下文"""
def __init__(self, strategy: PaymentStrategy):
self._strategy = strategy
def set_strategy(self, strategy: PaymentStrategy):
self._strategy = strategy
def pay(self, amount):
return self._strategy.pay(amount)
# 使用
context = PaymentContext(CreditCardPayment())
context.pay(100)
print()
context.set_strategy(WeChatPayment())
context.pay(500)
print()
context.set_strategy(AlipayPayment())
context.pay(200)

应用2:出行路线规划

from abc import ABC, abstractmethod
class RouteStrategy(ABC):
"""路线策略接口"""
@abstractmethod
def build_route(self, origin, destination):
pass
class FastestRouteStrategy(RouteStrategy):
"""最快路线策略"""
def build_route(self, origin, destination):
print(f"🚗 规划从 {origin} 到 {destination} 的最快路线")
print("  使用高速公路")
print("  预计时间: 45分钟")
print("  距离: 80公里")
print("  过路费: $15")
class ShortestRouteStrategy(RouteStrategy):
"""最短路线策略"""
def build_route(self, origin, destination):
print(f"🚗 规划从 {origin} 到 {destination} 的最短路线")
print("  使用省道")
print("  预计时间: 65分钟")
print("  距离: 55公里")
print("  过路费: $5")
class CheapestRouteStrategy(RouteStrategy):
"""最经济路线策略"""
def build_route(self, origin, destination):
print(f"🚗 规划从 {origin} 到 {destination} 的最经济路线")
print("  使用普通道路")
print("  预计时间: 90分钟")
print("  距离: 60公里")
print("  过路费: $0")
class ScenicRouteStrategy(RouteStrategy):
"""风景路线策略"""
def build_route(self, origin, destination):
print(f"🚗 规划从 {origin} 到 {destination} 的风景路线")
print("  沿海岸线行驶")
print("  预计时间: 120分钟")
print("  距离: 100公里")
print("  过路费: $10")
print("  经过景点: 3个")
class Navigation:
"""导航系统"""
def __init__(self, strategy: RouteStrategy):
self._strategy = strategy
def set_strategy(self, strategy: RouteStrategy):
self._strategy = strategy
def navigate(self, origin, destination):
return self._strategy.build_route(origin, destination)
# 使用
nav = Navigation(FastestRouteStrategy())
nav.navigate("北京", "天津")
print()
nav.set_strategy(ScenicRouteStrategy())
nav.navigate("北京", "天津")

应用3:数据压缩

from abc import ABC, abstractmethod
class CompressionStrategy(ABC):
"""压缩策略接口"""
@abstractmethod
def compress(self, data):
pass
@abstractmethod
def decompress(self, compressed_data):
pass
class ZipCompression(CompressionStrategy):
"""ZIP压缩"""
def compress(self, data):
print(f"📦 使用ZIP压缩: {len(data)} bytes")
print("  压缩率: 70%")
return f"ZIP:{data}"
def decompress(self, compressed_data):
print(f"📦 解压ZIP: {compressed_data}")
return compressed_data.replace("ZIP:", "")
class GzipCompression(CompressionStrategy):
"""GZIP压缩"""
def compress(self, data):
print(f"📦 使用GZIP压缩: {len(data)} bytes")
print("  压缩率: 80%")
return f"GZIP:{data}"
def decompress(self, compressed_data):
print(f"📦 解压GZIP: {compressed_data}")
return compressed_data.replace("GZIP:", "")
class RarCompression(CompressionStrategy):
"""RAR压缩"""
def compress(self, data):
print(f"📦 使用RAR压缩: {len(data)} bytes")
print("  压缩率: 75%")
return f"RAR:{data}"
def decompress(self, compressed_data):
print(f"📦 解压RAR: {compressed_data}")
return compressed_data.replace("RAR:", "")
class CompressionContext:
"""压缩上下文"""
def __init__(self, strategy: CompressionStrategy):
self._strategy = strategy
def set_strategy(self, strategy: CompressionStrategy):
self._strategy = strategy
def compress_data(self, data):
return self._strategy.compress(data)
def decompress_data(self, compressed_data):
return self._strategy.decompress(compressed_data)
# 使用
context = CompressionContext(ZipCompression())
compressed = context.compress_data("Hello World! " * 1000)
print()
context.set_strategy(GzipCompression())
compressed = context.compress_data("Hello World! " * 1000)

应用4:AI模型选择

from abc import ABC, abstractmethod
class ModelStrategy(ABC):
"""AI模型策略接口"""
@abstractmethod
def train(self, data):
pass
@abstractmethod
def predict(self, input_data):
pass
class RandomForestModel(ModelStrategy):
"""随机森林模型"""
def train(self, data):
print("🌲 训练随机森林模型")
print("  构建100棵决策树")
print("  训练完成")
def predict(self, input_data):
print("🌲 随机森林预测")
return "随机森林预测结果"
class NeuralNetworkModel(ModelStrategy):
"""神经网络模型"""
def train(self, data):
print("🧠 训练神经网络模型")
print("  构建3层全连接网络")
print("  训练100个epoch")
print("  训练完成")
def predict(self, input_data):
print("🧠 神经网络预测")
return "神经网络预测结果"
class SVMModel(ModelStrategy):
"""支持向量机模型"""
def train(self, data):
print("📊 训练SVM模型")
print("  选择RBF核函数")
print("  优化超参数")
print("  训练完成")
def predict(self, input_data):
print("📊 SVM预测")
return "SVM预测结果"
class MLModel:
"""机器学习模型上下文"""
def __init__(self, strategy: ModelStrategy):
self._strategy = strategy
def set_strategy(self, strategy: ModelStrategy):
self._strategy = strategy
def train(self, data):
return self._strategy.train(data)
def predict(self, input_data):
return self._strategy.predict(input_data)
# 使用
model = MLModel(RandomForestModel())
model.train("训练数据")
result1 = model.predict("测试数据")
print()
model.set_strategy(NeuralNetworkModel())
model.train("训练数据")
result2 = model.predict("测试数据")

策略模式 vs 工厂方法模式

维度 策略模式 工厂方法模式
目的 封装算法 创建对象
关注点 行为 创建
使用场景 算法可互换 对象创建可变
客户端 选择策略 通过工厂获取对象
graph TB A[模式对比] --> B[策略模式] A --> C[工厂方法模式] B --> B1[封装算法] B --> B2[运行时切换] B --> B3[行为变化] C --> C1[创建对象] C --> C2[延迟创建] C --> C3[创建逻辑变化] style A fill:#ede7f6,stroke:#5e35b1,stroke-width:3px style B fill:#e3f2fd,stroke:#1976d2,stroke-width:2px style C fill:#fff9c4,stroke:#f9a825,stroke-width:2px

优缺点

✅ 优点

优点 说明
开闭原则 新增策略无需修改代码
单一职责 每个策略只负责一个算法
可复用 策略可在不同上下文中复用
易于测试 每个策略可独立测试

❌ 缺点

缺点 说明
类数量多 每个策略一个类
客户端知道所有策略 需要了解不同策略
策略通信困难 策略之间难以通信

适用场景

场景 是否适合
多种算法可互换 ✅ 适合
算法有复杂逻辑 ✅ 适合
需要运行时切换 ✅ 适合
客户端不需要知道实现 ✅ 适合

本章要点


下一步模板方法模式 🚀