Contents

设计模式-以深度学习为例-创建型模式

Design Pattern - A Deep Learning Perspective

设计模式 - 以深度学习为例 - 创建型模式

常用设计模式 中,介绍了一系列创建型模式、结构型模式和行为模式的设计模式,对日常代码质量提升有很大帮助。但它们往往以 JAVA 语言的工程项目为例,我们在这里以深度学习的模型搭建、训练、推理、部署为例,解释面向对象编程中 22 中设计模式的基本原理。

设计模式分为

  • 创建型模式:工厂方法,抽象工厂,生成器,原型,单例
  • 结构型模式:适配器,桥接,组合,装饰,外观,享元,代理
  • 行为模式:责任链,命令,迭代器,中介者,备忘录,观察者,状态,策略,模板方法,访问者

本文介绍创建型模式在深度学习中的应用。创建型模式提供创建对象的机制,能够提升已有代码的灵活性和可复用性。


工厂方法 Factory Method

工厂方法在父类中提供一个创建对象的方法,允许子类决定实例化对象的类型。 https://lemonzzy.oss-cn-hangzhou.aliyuncs.com/typora/202301252046498.png

  • 产品 Product 对接口进行声明,对于创建者或者由其子类构建的对象,接口是通用的
  • 创建者类声明返回 Product 对象的工厂方法,返回类型和产品接口相匹配

使用场景

  1. 将产品的创建与实际的使用分离 → 只需要开发新的 ConcreteCreator,重写其 createProduct 方法。
  2. 利用继承扩展软件库和框架默认行为

效果

  1. 避免创建者和具体产品之间的耦合
  2. 单一职责原则,将产品创建代码放在程序的单一位置,使代码更易维护
  3. 开闭原则,无需修改现有客户端代码,就可以引入新的产品类型
  4. 但是工厂方法会引入很多新的子类,代码变得复杂

深度学习例子

例如我们在整合模型到同一个框架中,其一是自回归模型 Transformer,其二是对抗网络 GAN,二者的结构不同,所以训练过程差异很大。Transformer 只有一个模型,而 GAN 包含了 Generator 和 Discriminator。但同时,两个任务有相同的超参初始化,checkpoint 保存,数据集读取等任务。

在模型构建中,我们有一个公共 Product nn.Module,基于此构建了 TransformerGAN 两个具体产品。

1
2
3
4
5
6
7
class Transformer(nn.Module):
    def forward(x):
        return "Transformer"

class GAN(nn.Module):
    def forward(x):
        return "GAN"

在模型训练中,我们使用一个公共方法 Trainer,包含了 build_model 方法来创建相应的模型。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
class Trainer(object):
    def build_model():
        raise NotImplementedError
    def train_model(data):
        # train the model
        pass

class GANTrainer(Trainer):
    def build_model():
        return GAN()

class TransformerTrainer(Trainer):
    def build_model():
        return Transformer()

在实际使用中,通过调用 trainer 的模型训练接口,就可以实现训练。

1
2
def train_my_model(trainer:Trainer, data):
    trainer.train_model(data)

抽象工厂 Abstract Factory

https://lemonzzy.oss-cn-hangzhou.aliyuncs.com/typora/202301252108049.png

  • 抽象产品 Abstract Product 是构成系列产品的不同但相关的产品声明接口
  • 抽象工厂 Abstract Factory 接口声明一组创建抽象产品的方法
  • 具体工厂会对具体产品初始化,但构造方法的签名必须返回相应的抽象产品。客户端只需要调用抽象接口就能返回相应的抽象产品,客户端代码就不会和工厂创建的特定产品变体耦合。

使用场景

  • 代码需要和多个不同序列的相关产品交互,不希望代码基于产品的具体类进行构建
  • 有一个基于 一组抽象方法 的类抽象工厂提供一个接口来创建每个系列产品的对象
  • 以不同的产品类型变体维度绘制矩阵 → 为所有产品声明抽象产品接口,让具体产品实现具体接口 → 声明抽象工厂接口,每个具体工厂实现接口 → 将代码中产品构造函数替换成工厂方法的构造函数

效果

  • 单一职责原则 & 开闭原则
  • 引入多种接口和类,代码变复杂

深度学习例子

抽象工厂可以理解为多维的工厂方法,一系列相互依赖的对象。例如我们实现了 GAN 和 Transformer 两种模型,为了测试我们需要在两种数据集 translation 和 paraphrase 上测试,两个 trainer 就变成了 translation+transformer,paraphrase+GAN 这样两组相互依赖的对象。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
class Dataset():
    def build_dataset():
        raise NotImplementedError

class Translation():
    def build_dataset():
        return translation_pairs

class Paraphrase():
    def build_dataset():
        return paraphrase_lists

对于网络,如上定义两个深度学习模型 Transformer 和 GAN

1
2
3
4
5
6
7
class Transformer(nn.Module):
    def forward(x):
        return "Transformer"

class GAN(nn.Module):
    def forward(x):
        return "GAN"

之后,对于训练模型的 Trainer,我们需要同时实现数据集和网络两个不同维度模型的组合

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
class Trainer():
    def build_network():
        raise NotImplementedError
    def build_dataset():
        raise NotImplementedError

class TransformerTranslationTrainer():
    def build_network():
        return Transformer()
    def build_dataset():
        return Translation()

class GANParaphraseTrainer():
    def build_network():
        return GAN()
    def build_dataset():
        return Paraphrase()

在实际的使用中,我们只需要创建一个相应的 Trainer,调用对应的训练 api 即可。

例如在 [[Fairseq]] 框架中,对应的 Task 就是一个很好的抽象工厂模式,它调度了模型和数据集的搭建。

生成器 Builder

https://lemonzzy.oss-cn-hangzhou.aliyuncs.com/typora/202301272107100.png

  • 生成器 Builder 接口声明所有类型生成器中通用的产品构造步骤
  • 产品是最终生成的对象
  • 主管 Director 定义调用构造步骤的顺序
  • 客户端 Client 将某个生成器对象与 Director 关联

使用场景

  1. 重叠的多种构造函数出现
  2. 使用代码创建不同形式的产品,过程相似细节不同
  3. 使用生成器构造组合树或其他复杂对象

效果

  1. 分步创建对象,暂缓创建步骤
  2. 生成不同形式的产品,复用相同制造代码
  3. 单一直则原则,将复杂代码从业务逻辑分离
  4. 会增加多个类,代码整体复杂度增加

深度学习例子

例如普通分类模型,包括多种模型结构,如 Linear,MLP,CNN 等,都由 Linear 层和 Conv 层这两个基本单元构成。可以使用生成器模式来实现模型层级的搭建。

首先定义一个 Network 类和 Builder 类,来处理向基类中添加层

1
2
3
4
5
6
7
8
9
class Network():
    def __init__():
        this.network = nn.Sequential()

class Builder():
    def addLinearLayer():
        ...
    def addConvLayer():
        ...

原型 Prototype

使用场景

效果

深度学习例子

单例 Singleton

使用场景

效果

深度学习例子