模板方法模式
High Contrast
Dark Mode
Light Mode
Sepia
Forest
2 min read496 words

模板方法模式

模板方法模式(Template Method Pattern)定义一个操作中的算法的骨架,而将一些步骤延迟到子类中。模板方法使得子类可以不改变一个算法的结构即可重定义该算法的某些特定步骤。

问题定义

场景1:数据导出

# ❌ 问题:重复代码
class DataExporter:
"""数据导出器"""
def export_to_csv(self, data):
"""导出CSV"""
print("连接数据库")
print("查询数据")
print(f"转换为CSV格式: {data}")
print("保存文件")
print("关闭连接")
def export_to_excel(self, data):
"""导出Excel"""
print("连接数据库")
print("查询数据")
print(f"转换为Excel格式: {data}")
print("保存文件")
print("关闭连接")
def export_to_json(self, data):
"""导出JSON"""
print("连接数据库")
print("查询数据")
print(f"转换为JSON格式: {data}")
print("保存文件")
print("关闭连接")
# 问题:连接、查询、保存等步骤重复,只有格式转换不同

场景2:考试流程

# ❌ 问题:步骤重复
class Exam:
"""考试"""
def take_math_exam(self):
print("准备考试用品")
print("阅读数学题目")
print("回答数学问题")
print("检查答案")
print("提交试卷")
def take_english_exam(self):
print("准备考试用品")
print("阅读英语题目")
print("回答英语问题")
print("检查答案")
print("提交试卷")
def take_programming_exam(self):
print("准备考试用品")
print("阅读编程题目")
print("编写代码")
print("检查答案")
print("提交试卷")
# 问题:流程相同,只有部分步骤不同

解决方案

模板方法模式在父类中定义算法骨架,将可变步骤延迟到子类实现。

classDiagram class AbstractClass { +templateMethod() #primitiveOperation1() #primitiveOperation2() } class ConcreteClassA { +primitiveOperation1() +primitiveOperation2() } class ConcreteClassB { +primitiveOperation1() +primitiveOperation2() } AbstractClass <|-- ConcreteClassA AbstractClass <|-- ConcreteClassB

标准实现

抽象类

from abc import ABC, abstractmethod
class DataExporter(ABC):
"""数据导出器 - 抽象类"""
def export(self, data):
"""模板方法 - 定义算法骨架"""
print("=" * 50)
self.connect_database()
self.query_data()
self.transform_format(data)
self.save_file()
self.close_connection()
print("=" * 50)
def connect_database(self):
"""连接数据库 - 固定步骤"""
print("1. 连接数据库")
def query_data(self):
"""查询数据 - 固定步骤"""
print("2. 查询数据")
@abstractmethod
def transform_format(self, data):
"""转换格式 - 可变步骤 - 子类实现"""
pass
def save_file(self):
"""保存文件 - 固定步骤"""
print("4. 保存文件")
def close_connection(self):
"""关闭连接 - 固定步骤"""
print("5. 关闭连接")

具体类

class CSVExporter(DataExporter):
"""CSV导出器"""
def transform_format(self, data):
"""转换为CSV格式"""
print(f"3. 转换为CSV格式: {data}")
class ExcelExporter(DataExporter):
"""Excel导出器"""
def transform_format(self, data):
"""转换为Excel格式"""
print(f"3. 转换为Excel格式: {data}")
class JSONExporter(DataExporter):
"""JSON导出器"""
def transform_format(self, data):
"""转换为JSON格式"""
print(f"3. 转换为JSON格式: {data}")

客户端使用

# 创建导出器
csv_exporter = CSVExporter()
excel_exporter = ExcelExporter()
json_exporter = JSONExporter()
# 导出数据
data = {"name": "Alice", "age": 30}
print("导出CSV:")
csv_exporter.export(data)
print()
print("导出Excel:")
excel_exporter.export(data)
print()
print("导出JSON:")
json_exporter.export(data)

钩子方法

钩子方法是在模板方法中提供可选的步骤,子类可以选择性覆盖。

class DataExporterWithHook(ABC):
"""带钩子方法的数据导出器"""
def export(self, data):
"""模板方法"""
print("=" * 50)
self.before_export()  # 钩子方法
self.connect_database()
self.query_data()
self.transform_format(data)
self.save_file()
self.close_connection()
self.after_export()  # 钩子方法
print("=" * 50)
def before_export(self):
"""导出前的钩子方法 - 可选覆盖"""
pass
def after_export(self):
"""导出后的钩子方法 - 可选覆盖"""
pass
def connect_database(self):
print("1. 连接数据库")
def query_data(self):
print("2. 查询数据")
@abstractmethod
def transform_format(self, data):
pass
def save_file(self):
print("4. 保存文件")
def close_connection(self):
print("5. 关闭连接")
class LoggingCSVExporter(DataExporterWithHook):
"""带日志的CSV导出器"""
def transform_format(self, data):
print(f"3. 转换为CSV格式: {data}")
def before_export(self):
"""覆盖钩子方法"""
print("📝 开始导出...")
def after_export(self):
"""覆盖钩子方法"""
print("✅ 导出完成!")
# 使用
logging_exporter = LoggingCSVExporter()
logging_exporter.export({"name": "Alice"})

实战应用

应用1:考试流程

from abc import ABC, abstractmethod
class Exam(ABC):
"""考试 - 抽象类"""
def take_exam(self):
"""模板方法 - 考试流程"""
print("=" * 50)
self.prepare()
self.read_questions()
self.answer_questions()
self.check_answers()
self.submit()
print("=" * 50)
def prepare(self):
"""准备 - 固定步骤"""
print("🎒 准备考试用品(笔、橡皮、计算器)")
def read_questions(self):
"""阅读题目 - 固定步骤"""
print("📖 阅读考试题目")
@abstractmethod
def answer_questions(self):
"""答题 - 可变步骤"""
pass
def check_answers(self):
"""检查答案 - 固定步骤"""
print("🔍 检查答案")
def submit(self):
"""提交 - 固定步骤"""
print("✍️ 提交试卷")
class MathExam(Exam):
"""数学考试"""
def answer_questions(self):
"""回答数学问题"""
print("🧮 回答数学问题")
print("  - 计算函数导数")
print("  - 求解微分方程")
print("  - 证明定理")
class EnglishExam(Exam):
"""英语考试"""
def answer_questions(self):
"""回答英语问题"""
print("📝 回答英语问题")
print("  - 选择题")
print("  - 阅读理解")
print("  - 作文")
class ProgrammingExam(Exam):
"""编程考试"""
def answer_questions(self):
"""编写代码"""
print("💻 编写代码")
print("  - 算法题")
print("  - 数据结构题")
print("  - 系统设计题")
# 使用
print("数学考试:")
math_exam = MathExam()
math_exam.take_exam()
print("\n英语考试:")
english_exam = EnglishExam()
english_exam.take_exam()
print("\n编程考试:")
programming_exam = ProgrammingExam()
programming_exam.take_exam()

应用2:数据处理流程

from abc import ABC, abstractmethod
class DataPipeline(ABC):
"""数据处理管道"""
def process(self, raw_data):
"""模板方法 - 数据处理流程"""
print("=" * 50)
data = raw_data
data = self.validate(data)
if data is None:
print("❌ 数据验证失败,流程终止")
return
data = self.clean(data)
data = self.transform(data)
data = self.enrich(data)
self.save(data)
print("=" * 50)
return data
def validate(self, data):
"""验证数据 - 可选钩子"""
print("✅ 数据验证通过")
return data
def clean(self, data):
"""清洗数据 - 固定步骤"""
print("🧹 清洗数据(去重、去除空值)")
return data
@abstractmethod
def transform(self, data):
"""转换数据 - 可变步骤"""
pass
def enrich(self, data):
"""丰富数据 - 固定步骤"""
print("🔗 丰富数据(添加元数据)")
return data
def save(self, data):
"""保存数据 - 固定步骤"""
print("💾 保存数据")
class CustomerDataPipeline(DataPipeline):
"""客户数据处理"""
def transform(self, data):
"""转换客户数据"""
print("🔄 转换客户数据")
print("  - 格式化电话号码")
print("  - 标准化地址")
print("  - 计算客户评分")
class SalesDataPipeline(DataPipeline):
"""销售数据处理"""
def validate(self, data):
"""验证销售数据"""
print("✅ 验证销售数据(检查金额是否为正数)")
return data
def transform(self, data):
"""转换销售数据"""
print("🔄 转换销售数据")
print("  - 计算税费")
print("  - 转换货币")
print("  - 按地区分组")
# 使用
print("客户数据处理:")
customer_pipeline = CustomerDataPipeline()
customer_pipeline.process("客户原始数据")
print("\n销售数据处理:")
sales_pipeline = SalesDataPipeline()
sales_pipeline.process("销售原始数据")

应用3:网络请求

from abc import ABC, abstractmethod
class APIClient(ABC):
"""API客户端"""
def request(self, url, **kwargs):
"""模板方法 - 发送请求"""
print("=" * 50)
print(f"📡 发送请求: {url}")
# 准备请求
headers = self.prepare_headers(**kwargs)
params = self.prepare_params(**kwargs)
# 发送请求
response = self.send_request(url, headers, params)
# 处理响应
if response:
data = self.parse_response(response)
validated_data = self.validate_data(data)
return validated_data
print("=" * 50)
return None
def prepare_headers(self, **kwargs):
"""准备请求头"""
headers = {
"Content-Type": "application/json",
"User-Agent": "MyApp/1.0"
}
print(f"📋 请求头: {headers}")
return headers
def prepare_params(self, **kwargs):
"""准备请求参数"""
print(f"📋 请求参数: {kwargs}")
return kwargs
@abstractmethod
def send_request(self, url, headers, params):
"""发送请求 - 可变步骤"""
pass
def parse_response(self, response):
"""解析响应"""
print(f"📄 解析响应: {response}")
return response
def validate_data(self, data):
"""验证数据"""
print("✅ 验证数据通过")
return data
class RESTAPIClient(APIClient):
"""REST API客户端"""
def send_request(self, url, headers, params):
"""发送REST请求"""
print("📡 使用GET方法发送REST请求")
return {"status": "success", "data": "响应数据"}
class GraphQLAPIClient(APIClient):
"""GraphQL API客户端"""
def send_request(self, url, headers, params):
"""发送GraphQL请求"""
print("📡 使用POST方法发送GraphQL请求")
print(f"   Query: {params.get('query', '')}")
return {"status": "success", "data": "响应数据"}
# 使用
print("REST API请求:")
rest_client = RESTAPIClient()
rest_client.request("https://api.example.com/users", page=1)
print("\nGraphQL API请求:")
graphql_client = GraphQLAPIClient()
graphql_client.request(
"https://api.example.com/graphql",
query="{ users { name } }"
)

应用4:游戏角色创建

from abc import ABC, abstractmethod
class CharacterCreator(ABC):
"""角色创建器"""
def create_character(self, name):
"""模板方法 - 创建角色流程"""
print("=" * 50)
print(f"🎮 创建角色: {name}")
self.initialize_attributes()
self.choose_class()
self.select_appearance()
self.assign_equipment()
self.finalize()
print("=" * 50)
def initialize_attributes(self):
"""初始化属性 - 固定步骤"""
print("📊 初始化角色属性")
print("  - 生命值: 100")
print("  - 魔力值: 50")
print("  - 速度: 10")
@abstractmethod
def choose_class(self):
"""选择职业 - 可变步骤"""
pass
def select_appearance(self):
"""选择外观 - 固定步骤"""
print("🎨 选择角色外观")
print("  - 发型: 短发")
print("  - 脸型: 方脸")
print("  - 服装: 基础装备")
@abstractmethod
def assign_equipment(self):
"""分配装备 - 可变步骤"""
pass
def finalize(self):
"""完成创建 - 固定步骤"""
print("✅ 角色创建完成!")
class WarriorCreator(CharacterCreator):
"""战士创建器"""
def choose_class(self):
"""选择战士职业"""
print("⚔️ 选择职业: 战士")
print("  - 力量: +20")
print("  - 防御: +15")
print("  - 特技: 盾击")
def assign_equipment(self):
"""分配战士装备"""
print("🛡️ 分配装备")
print("  - 武器: 铁剑")
print("  - 防具: 铁甲")
print("  - 盾牌: 铁盾")
class MageCreator(CharacterCreator):
"""法师创建器"""
def choose_class(self):
"""选择法师职业"""
print("🔮 选择职业: 法师")
print("  - 智力: +25")
print("  - 魔力: +20")
print("  - 特技: 火球术")
def assign_equipment(self):
"""分配法师装备"""
print("🪄 分配装备")
print("  - 武器: 法杖")
print("  - 防具: 法袍")
print("  - 饰品: 魔法戒指")
class ArcherCreator(CharacterCreator):
"""弓箭手创建器"""
def choose_class(self):
"""选择弓箭手职业"""
print("🏹 选择职业: 弓箭手")
print("  - 敏捷: +25")
print("  - 命中: +15")
print("  - 特技: 多重射击")
def assign_equipment(self):
"""分配弓箭手装备"""
print("🎯 分配装备")
print("  - 武器: 长弓")
print("  - 防具: 皮甲")
print("  - 饰品: 箭袋")

优缺点

✅ 优点

优点 说明
代码复用 相同步骤在父类实现一次
扩展性好 子类可扩展可变步骤
控制反转 父类控制流程,子类实现细节
一致性 确保算法结构一致

❌ 缺点

缺点 说明
类数量 每个具体实现一个子类
钩子复杂 钩子方法可能难以理解
继承限制 无法多重继承

与其他模式的关系

模式 关系
策略模式 策略模式替换整个算法,模板方法替换部分步骤
工厂方法 模板方法模式可以使用工厂方法创建对象
组合模式 模板方法常与组合模式一起使用

本章要点


下一步命令模式 🚀