工厂方法模式
工厂方法模式(Factory Method Pattern)定义了一个创建对象的接口,但由子类决定实例化哪个类。它让类将实例化延迟到子类。
问题定义
场景1:多种日志系统
# ❌ 问题:客户端直接依赖具体类
class FileLogger:
def log(self, message):
print(f"文件日志: {message}")
class DatabaseLogger:
def log(self, message):
print(f"数据库日志: {message}")
class ConsoleLogger:
def log(self, message):
print(f"控制台日志: {message}")
# 客户端代码
if log_type == "file":
logger = FileLogger()
elif log_type == "database":
logger = DatabaseLogger()
elif log_type == "console":
logger = ConsoleLogger()
logger.log("系统启动")
问题: - 客户端代码需要知道所有具体类 - 添加新类型需要修改客户端代码(违反开闭原则) - 对象创建逻辑分散在客户端
场景2:多种支付方式
class CreditCardPayment:
def pay(self, amount):
print(f"信用卡支付: {amount}")
class AlipayPayment:
def pay(self, amount):
print(f"支付宝支付: {amount}")
class WeChatPayment:
def pay(self, amount):
print(f"微信支付: {amount}")
# 每次添加新支付方式都要修改这个函数
def process_payment(payment_type, amount):
if payment_type == "credit_card":
payment = CreditCardPayment()
elif payment_type == "alipay":
payment = AlipayPayment()
elif payment_type == "wechat":
payment = WeChatPayment()
payment.pay(amount)
解决方案
工厂方法模式将对象创建逻辑封装在工厂类中,客户端通过工厂获取对象,无需知道具体创建细节。
classDiagram
class Creator {
<>
+factoryMethod() Product
}
class ConcreteCreatorA {
+factoryMethod() Product
}
class ConcreteCreatorB {
+factoryMethod() Product
}
class Product {
<>
+operation()
}
class ConcreteProductA {
+operation()
}
class ConcreteProductB {
+operation()
}
Creator <|-- ConcreteCreatorA
Creator <|-- ConcreteCreatorB
Product <|-- ConcreteProductA
Product <|-- ConcreteProductB
Creator ..> Product : creates
标准实现
抽象产品和具体产品
from abc import ABC, abstractmethod
class Logger(ABC):
"""日志记录器抽象类"""
@abstractmethod
def log(self, message):
pass
class FileLogger(Logger):
"""文件日志记录器"""
def log(self, message):
print(f"[文件] {message}")
class DatabaseLogger(Logger):
"""数据库日志记录器"""
def log(self, message):
print(f"[数据库] {message}")
class ConsoleLogger(Logger):
"""控制台日志记录器"""
def log(self, message):
print(f"[控制台] {message}")
抽象工厂和具体工厂
class LoggerFactory(ABC):
"""日志工厂抽象类"""
@abstractmethod
def create_logger(self) -> Logger:
pass
class FileLoggerFactory(LoggerFactory):
"""文件日志工厂"""
def create_logger(self) -> Logger:
return FileLogger()
class DatabaseLoggerFactory(LoggerFactory):
"""数据库日志工厂"""
def create_logger(self) -> Logger:
return DatabaseLogger()
class ConsoleLoggerFactory(LoggerFactory):
"""控制台日志工厂"""
def create_logger(self) -> Logger:
return ConsoleLogger()
客户端使用
def use_logger(factory: LoggerFactory):
"""使用日志记录器 - 不依赖具体类"""
logger = factory.create_logger()
logger.log("系统启动")
logger.log("处理请求")
logger.log("系统关闭")
# 使用不同工厂
file_factory = FileLoggerFactory()
use_logger(file_factory)
# 输出:
# [文件] 系统启动
# [文件] 处理请求
# [文件] 系统关闭
db_factory = DatabaseLoggerFactory()
use_logger(db_factory)
# 输出:
# [数据库] 系统启动
# [数据库] 处理请求
# [数据库] 系统关闭
实战应用
应用1:支付系统
from abc import ABC, abstractmethod
# 产品
class PaymentMethod(ABC):
"""支付方式抽象类"""
@abstractmethod
def pay(self, amount):
pass
@abstractmethod
def refund(self, transaction_id):
pass
class CreditCardPayment(PaymentMethod):
"""信用卡支付"""
def pay(self, amount):
print(f"信用卡支付 ${amount}")
return "TXN_001"
def refund(self, transaction_id):
print(f"信用卡退款: {transaction_id}")
class AlipayPayment(PaymentMethod):
"""支付宝支付"""
def pay(self, amount):
print(f"支付宝支付 ¥{amount}")
return "ALI_001"
def refund(self, transaction_id):
print(f"支付宝退款: {transaction_id}")
class WeChatPayment(PaymentMethod):
"""微信支付"""
def pay(self, amount):
print(f"微信支付 ¥{amount}")
return "WX_001"
def refund(self, transaction_id):
print(f"微信退款: {transaction_id}")
# 工厂
class PaymentFactory(ABC):
"""支付工厂抽象类"""
@abstractmethod
def create_payment(self) -> PaymentMethod:
pass
class CreditCardFactory(PaymentFactory):
def create_payment(self) -> PaymentMethod:
return CreditCardPayment()
class AlipayFactory(PaymentFactory):
def create_payment(self) -> PaymentMethod:
return AlipayPayment()
class WeChatFactory(PaymentFactory):
def create_payment(self) -> PaymentMethod:
return WeChatPayment()
# 使用
def process_order(factory: PaymentFactory, amount):
"""处理订单"""
payment = factory.create_payment()
txn_id = payment.pay(amount)
print(f"交易完成,ID: {txn_id}")
# 用户选择不同支付方式
process_order(CreditCardFactory(), 100)
process_order(AlipayFactory(), 100)
process_order(WeChatFactory(), 100)
应用2:图形编辑器
from abc import ABC, abstractmethod
# 产品:形状
class Shape(ABC):
"""形状抽象类"""
@abstractmethod
def draw(self):
pass
class Circle(Shape):
"""圆形"""
def __init__(self, radius):
self.radius = radius
def draw(self):
print(f"绘制圆形,半径: {self.radius}")
class Rectangle(Shape):
"""矩形"""
def __init__(self, width, height):
self.width = width
self.height = height
def draw(self):
print(f"绘制矩形,宽: {self.width}, 高: {self.height}")
class Triangle(Shape):
"""三角形"""
def __init__(self, base, height):
self.base = base
self.height = height
def draw(self):
print(f"绘制三角形,底: {self.base}, 高: {self.height}")
# 工厂
class ShapeFactory(ABC):
"""形状工厂抽象类"""
@abstractmethod
def create_shape(self, *args) -> Shape:
pass
class CircleFactory(ShapeFactory):
def create_shape(self, radius) -> Shape:
return Circle(radius)
class RectangleFactory(ShapeFactory):
def create_shape(self, width, height) -> Shape:
return Rectangle(width, height)
class TriangleFactory(ShapeFactory):
def create_shape(self, base, height) -> Shape:
return Triangle(base, height)
# 使用
def draw_shape(factory: ShapeFactory, *args):
shape = factory.create_shape(*args)
shape.draw()
draw_shape(CircleFactory(), 5)
draw_shape(RectangleFactory(), 10, 20)
draw_shape(TriangleFactory(), 8, 6)
应用3:数据库连接
from abc import ABC, abstractmethod
# 产品:数据库连接
class DatabaseConnection(ABC):
"""数据库连接抽象类"""
@abstractmethod
def connect(self):
pass
@abstractmethod
def query(self, sql):
pass
class MySQLConnection(DatabaseConnection):
"""MySQL连接"""
def __init__(self, host, port, database):
self.host = host
self.port = port
self.database = database
def connect(self):
print(f"连接MySQL: {self.host}:{self.port}/{self.database}")
def query(self, sql):
print(f"MySQL查询: {sql}")
class PostgreSQLConnection(DatabaseConnection):
"""PostgreSQL连接"""
def __init__(self, host, port, database):
self.host = host
self.port = port
self.database = database
def connect(self):
print(f"连接PostgreSQL: {self.host}:{self.port}/{self.database}")
def query(self, sql):
print(f"PostgreSQL查询: {sql}")
class MongoDBConnection(DatabaseConnection):
"""MongoDB连接"""
def __init__(self, host, port, database):
self.host = host
self.port = port
self.database = database
def connect(self):
print(f"连接MongoDB: {self.host}:{self.port}/{self.database}")
def query(self, sql):
print(f"MongoDB查询: {sql}")
# 工厂
class DatabaseFactory(ABC):
"""数据库工厂抽象类"""
@abstractmethod
def create_connection(self, config) -> DatabaseConnection:
pass
class MySQLFactory(DatabaseFactory):
def create_connection(self, config) -> DatabaseConnection:
return MySQLConnection(
config['host'],
config['port'],
config['database']
)
class PostgreSQLFactory(DatabaseFactory):
def create_connection(self, config) -> DatabaseConnection:
return PostgreSQLConnection(
config['host'],
config['port'],
config['database']
)
class MongoDBFactory(DatabaseFactory):
def create_connection(self, config) -> DatabaseConnection:
return MongoDBConnection(
config['host'],
config['port'],
config['database']
)
# 使用
def connect_database(factory: DatabaseFactory, config):
db = factory.create_connection(config)
db.connect()
db.query("SELECT * FROM users")
# 配置
mysql_config = {"host": "localhost", "port": 3306, "database": "mydb"}
postgres_config = {"host": "localhost", "port": 5432, "database": "mydb"}
connect_database(MySQLFactory(), mysql_config)
connect_database(PostgreSQLFactory(), postgres_config)
简化版工厂
如果工厂很简单,可以不用抽象工厂:
class ShapeFactory:
"""简单形状工厂"""
@staticmethod
def create_shape(shape_type, *args):
if shape_type == "circle":
return Circle(*args)
elif shape_type == "rectangle":
return Rectangle(*args)
elif shape_type == "triangle":
return Triangle(*args)
else:
raise ValueError(f"未知的形状类型: {shape_type}")
# 使用
circle = ShapeFactory.create_shape("circle", 5)
circle.draw()
rectangle = ShapeFactory.create_shape("rectangle", 10, 20)
rectangle.draw()
参数化工厂
通过参数决定创建哪种产品:
class LoggerFactory:
"""参数化日志工厂"""
@staticmethod
def create_logger(logger_type):
if logger_type == "file":
return FileLogger()
elif logger_type == "database":
return DatabaseLogger()
elif logger_type == "console":
return ConsoleLogger()
else:
raise ValueError(f"未知的日志类型: {logger_type}")
# 使用
logger = LoggerFactory.create_logger("file")
logger.log("消息")
与简单工厂的区别
| 维度 | 简单工厂 | 工厂方法 |
|---|---|---|
| 工厂类数量 | 1个工厂类 | 多个工厂类 |
| 创建逻辑 | 集中在一个工厂 | 分散在子工厂 |
| 扩展性 | 需要修改工厂类 | 添加新工厂类 |
| 符合OCP | ❌ 不符合 | ✅ 符合 |
# 简单工厂(不符合OCP)
class SimpleFactory:
def create(self, type):
if type == "A":
return ProductA()
elif type == "B":
return ProductB()
# 添加新类型需要修改这里!
# 工厂方法(符合OCP)
class FactoryA:
def create(self):
return ProductA()
class FactoryB:
def create(self):
return ProductB()
class FactoryC: # 新增类型,无需修改原有代码
def create(self):
return ProductC()
优缺点
✅ 优点
| 优点 | 说明 |
|---|---|
| 解耦创建和使用 | 客户端不依赖具体类 |
| 符合开闭原则 | 新增类型无需修改现有代码 |
| 代码复用 | 创建逻辑在工厂中复用 |
| 易于测试 | 可以注入Mock工厂 |
❌ 缺点
| 缺点 | 说明 |
|---|---|
| 类数量增加 | 每个产品需要一个工厂类 |
| 复杂度增加 | 简单场景可能过度设计 |
| 引用复杂 | 客户端需要知道工厂类 |
与其他模式的关系
graph TB
A[工厂方法] --> B[抽象工厂]
A --> C[模板方法]
A --> D[原型]
B --> B1[创建产品族]
C --> C1[工厂方法可以是模板方法]
D --> D1[结合使用创建对象]
style A fill:#ede7f6,stroke:#5e35b1,stroke-width:3px
本章要点
- ✅ 工厂方法模式将对象创建延迟到子类
- ✅ 解耦对象创建和使用
- ✅ 符合开闭原则,易于扩展
- ✅ 简单场景可以用简单工厂
- ✅ 复杂场景推荐使用抽象工厂
- ✅ 可以参数化工厂,通过参数决定创建类型
下一步:建造者模式 🚀