victory的博客

长安一片月,万户捣衣声

0%

Python定时任务调度

定时任务调度(Scheduled Task Execution)是一种在预定时间或周期性地自动执行特定操作的机制。它广泛应用于自动化运维、数据处理、日志清理、监控告警、定时推送等场景。


🕰️ 定时任务调度的核心概念

概念 说明
任务(Task) 要执行的具体操作,通常是一个函数或命令。
调度器(Scheduler) 管理和触发任务执行的组件。
触发条件(Trigger) 决定任务何时执行的规则,如:固定间隔、延迟执行、cron 表达式等。
执行方式 可同步或异步执行任务。

📌 定时任务的常见类型

  1. 一次性任务(One-time Task)
    • 执行一次后自动结束。
    • 示例:5 秒后发送一条通知。
  2. 周期性任务(Recurring Task)
    • 按照固定间隔重复执行。
    • 示例:每 5 分钟检查服务器状态。
  3. 延迟首次执行任务(Delayed First Execution)
    • 首次执行有延迟,之后按周期执行。
    • 示例:启动后等待 10 秒再开始心跳检测。
  4. 基于时间表达式的任务(Cron-like Task)
    • 使用类似 cron 的语法定义执行时间点。
    • 示例:每天凌晨 2:00 执行数据库备份。

✅ 定时任务调度的应用场景

场景 描述
数据同步 每小时从远程服务器拉取最新数据
日志清理 每天凌晨删除过期日志文件
报表生成 每月第一天自动生成统计报表
健康检查 每隔一段时间检测服务可用性
自动提醒 每天上午9点推送待办事项

🧩 定时任务调度的实现方式(Python 中)

1. 简单轮询 + sleep

  • 如你提供的 SimpleCronJob 类所示。
  • 利用 while 循环 + time.sleep() 实现秒级轮询。
  • 适用于轻量级任务。

2. 使用第三方库

  • APScheduler:功能强大的任务调度器,支持 cron、日期、间隔三种触发器。
  • Celery Beat:分布式任务调度系统,适合大型项目。
  • schedule:简洁易用的库,适合脚本化任务。

3. 操作系统级定时任务

  • Linux 下使用 crontab
  • Windows 下使用“任务计划程序”。

🔁 示例对比:不同方式实现每 5 秒打印一次

方式一:原生 Python(当前类)

1
cron.add_task("periodic_task", lambda: print("🔁 周期性任务执行"), interval_seconds=5)

方式二:使用 schedule

1
2
3
4
5
6
7
8
9
10
11
import schedule
import time

def job():
print("Scheduled task executed")

schedule.every(5).seconds.do(job)

while True:
schedule.run_pending()
time.sleep(1)

方式三:使用 APScheduler

1
2
3
4
5
6
7
8
from apscheduler.schedulers.blocking import BlockingScheduler

def job():
print("APScheduled task executed")

sched = BlockingScheduler()
sched.add_job(job, 'interval', seconds=5)
sched.start()

🧠 小结

特性 当前类 SimpleCronJob APScheduler Celery Beat crontab
易用性 ✅ 简单直接 ✅ 支持多种触发器 ⚠️ 分布式复杂 ✅ 系统级配置
功能丰富度 ❌ 较基础 ✅ 强大 ✅ 极其强大 ⚠️ 仅基本调度
多线程支持 ✅ 已封装 ✅ 支持 ✅ 支持 ❌ 单进程
持久化 ❌ 不支持 ✅ 可持久化 ✅ 支持 ❌ 不支持
适用场景 本地测试、小型脚本 中型应用 大型分布式系统 系统维护任务

📎 总结一句话:

定时任务调度就是让程序在指定时间或条件下自动执行某些操作,从而实现自动化流程管理。

python定时任务调度实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import time
import datetime
import threading
from typing import Dict, Optional, Callable, Tuple

class SimpleCronJob:
def __init__(self):
self.tasks: Dict[str, Tuple[datetime.datetime, datetime.timedelta, bool, bool, Callable]] = {}
self.lock = threading.Lock()
self.stop_flag = False

def add_task(self, task_id: str, func: Callable, delay_seconds: int = 0, interval_seconds: int = 0, once: bool = False):
"""
添加一个定时任务。

:param task_id: 任务唯一标识符
:param func: 要执行的函数
:param delay_seconds: 首次执行延迟时间(秒)
:param interval_seconds: 周期执行间隔时间(秒)
:param once: 是否为一次性任务
"""
with self.lock:
self.tasks[task_id] = (
None, # last_time
datetime.timedelta(seconds=delay_seconds),
False, # first_executed
datetime.timedelta(seconds=interval_seconds),
once,
func
)
print(f"任务 {task_id} 已注册")

def start(self):
print("开始执行定时任务服务...")
while not self.stop_flag:
now = datetime.datetime.now()

with self.lock:
for task_id, (last_time, delay_time, first_executed, interval_time, once, func) in list(self.tasks.items()):
if delay_time.total_seconds() > 0 and not first_executed:
if last_time is None:
self.tasks[task_id] = (now, delay_time, first_executed, interval_time, once, func)
continue
elif (now - last_time) < delay_time:
continue
else:
self.tasks[task_id] = (now, delay_time, True, interval_time, once, func)

elif last_time is None:
self.tasks[task_id] = (now, delay_time, first_executed, interval_time, once, func)
continue
elif (now - last_time) < interval_time:
continue

# 执行任务
try:
print(f"执行任务 {task_id}")
func()
except Exception as e:
print(f"任务 {task_id} 执行出错: {e}")

# 更新最后执行时间
self.tasks[task_id] = (now, delay_time, True, interval_time, once, func)

# 如果是一次性任务,执行后移除
if once:
del self.tasks[task_id]
print(f"任务 {task_id} 已完成并移除")

time.sleep(1) # 模拟 CRON_TIME_INTERVAL

def stop(self):
self.stop_flag = True
print("定时任务服务已停止")


# 示例任务函数
def example_task():
print("✅ 示例任务正在执行...")

if __name__ == "__main__":
cron = SimpleCronJob()

# 1. 周期性任务:每5秒执行一次
cron.add_task("periodic_task", lambda: print("🔁 周期性任务执行"), interval_seconds=5)

# 2. 延迟性任务:首次延迟3秒后开始,之后每6秒执行一次
cron.add_task("delayed_periodic_task",
lambda: print("🕒 延迟+周期任务执行"),
delay_seconds=3,
interval_seconds=6)

# 3. 一次性任务:立即执行一次(无延迟)
cron.add_task("one_time_task",
lambda: print("📎 一次性任务执行"),
once=True)

# 4. 一次性延迟任务:延迟5秒后执行一次
cron.add_task("one_time_delayed_task",
lambda: print("⏳ 一次性延迟任务执行"),
delay_seconds=5,
once=True)

try:
cron.start()
except KeyboardInterrupt:
cron.stop()


python协程

  1. 协程是什么?

    协程是一种用户级别的线程,由程序(而不是操作系统)调度。在Python种,由async def定义,使用await暂停和恢复执行。适合处理大量IO阻塞任务(如Redis、HTTP、文件、数据库等)。

  2. 为什么要用协程?
    目前主流语言基本上都选择了多线程作为并发设施,与线程相关的概念就是抢占式多任务(Preemptive multitasking),而与协程相关的是协作式多任务。

    其实不管是进程还是线程,每次阻塞、切换都需要陷入系统调用(system call),先让CPU跑操作系统的调度程序,然后再由调度程序决定该跑哪一个进程(线程)。
    而且由于抢占式调度执行顺序无法确定的特点,使用线程时需要非常小心地处理同步问题,而协程完全不存在这个问题(事件驱动和异步程序也有同样的优点)。

    因为协程是用户自己来编写调度逻辑的,对于我们的CPU来说,协程其实是单线程,所以CPU不用去考虑怎么调度、切换上下文,这就省去了CPU的切换开销,所以协程在一定程度上又好于多线程。

  3. 简单协程示例

    1
    2
    3
    4
    5
    6
    7
    8
    import asyncio

    async def say_hello():
    print("Hello ...")
    await asyncio.sleep(1)
    print("... World!")

    asyncio.run(say_hello())
  4. 运行Redis订阅监听、HTTP并发请求、异步文件写入、异步数据库查询的协程驱动系统代码示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    import asyncio
    import aiohttp
    import aiofiles
    import aiomysql
    import redis.asyncio as aioredis

    # === Redis 订阅监听 ===
    async def redis_listener():
    redis = aioredis.Redis()
    pubsub = redis.pubsub()
    await pubsub.subscribe("demo")
    print("[Redis] Subscribed to 'demo'")

    async for msg in pubsub.listen():
    if msg["type"] == "message":
    print(f"[Redis] Received: {msg['data'].decode()}")

    # === 并发 HTTP 请求 ===
    async def fetch(session, url, i):
    async with session.get(url) as resp:
    text = await resp.text()
    print(f"[HTTP] #{i}: {len(text)} bytes from {url}")

    async def http_worker():
    urls = ["https://example.com"] * 3
    async with aiohttp.ClientSession() as session:
    tasks = [fetch(session, url, i) for i, url in enumerate(urls)]
    await asyncio.gather(*tasks)

    # === 异步写文件 ===
    async def file_writer():
    for i in range(3):
    async with aiofiles.open(f"output_{i}.txt", "w") as f:
    await f.write(f"[File] Written by task {i}\n")
    print(f"[File] Written file {i}")
    await asyncio.sleep(0.5)

    # === 异步 MySQL 查询 ===
    async def mysql_worker():
    try:
    conn = await aiomysql.connect(
    host='localhost', port=3306,
    user='your_user', password='your_pass',
    db='your_db', autocommit=True
    )
    async with conn.cursor() as cur:
    await cur.execute("SELECT SLEEP(1), 'Hello from DB'")
    result = await cur.fetchone()
    print(f"[MySQL] Result: {result}")
    conn.close()
    except Exception as e:
    print("[MySQL] Error:", e)

    # === 主函数:并发运行所有任务 ===
    async def main():
    await asyncio.gather(
    redis_listener(),
    http_worker(),
    file_writer(),
    mysql_worker(),
    )

    if __name__ == "__main__":
    asyncio.run(main())

Python多进程与多线程

  1. Python多进程是并行执行的吗?

    答:Python多线程不是并行执行的。由于CPython解释器中有一个全局解释器锁(GIL),它会导致:

    • 同一时间内只有一个县城能执行Python字节码
    • 对于CPU密集型任务,多线程不会并行执行,线程是轮流执行的

    结论:Python在CPU密集型任务(图像处理、大量计算等)中不能并行,但在I/O密集型任务(网络请求、文件读写、数据库访问等)中非常有效,可以“并发”运行多个任务。

  2. 那么在Python中如何实现真正的并行?

    可使用multiprocessing多进程实现多个任务的并行,每个进程有自己独立的Python解释器和内存空间,不会收到GIL的限制。

  3. 多线程、多进程Python示例

    3.1 多线程

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    import threading
    import time

    def worker(n):
    print(f"[Thread] Start {n}")
    time.sleep(1) # 模拟IO任务
    print(f"[Thread] Done {n} -> {n * n}")

    threads = []

    start = time.time()
    for i in range(5):
    t = threading.Thread(target=worker, args=(i,))
    threads.append(t)
    t.start()

    for t in threads:
    t.join()
    end = time.time()

    print(f"Threading total time: {end - start:.2f} seconds")

    3.2 多进程

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    from multiprocessing import Process
    import time
    import os

    def worker(n):
    print(f"[Process {os.getpid()}] Start {n}")
    time.sleep(1) # 模拟工作
    print(f"[Process {os.getpid()}] Done {n} -> {n * n}")

    processes = []

    start = time.time()
    for i in range(5):
    p = Process(target=worker, args=(i,))
    processes.append(p)
    p.start()

    for p in processes:
    p.join()
    end = time.time()

    print(f"Multiprocessing total time: {end - start:.2f} seconds")

Python Redis 发布订阅

  1. Redis发布/订阅是一种广播式消息系统:
    • 发布者:向一个频道(channel)发送消息
    • 订阅者:订阅一个或多个频道,等待接受消息
    • 当有消息发布到一个频道,所有订阅该频道的客户端都会立即收到消息
  2. 应用场景示例
    • 实时聊天系统:用户订阅频道,别人发言就能立刻收到
    • 消息推送/通知中心:后台发布消息,前端订阅实时显示
    • 分布式服务通信:多个服务通过Redis通信协调
  3. 特点
    • 高性能:消息实时传递
    • 非持久化:消息不会存储
    • 多对多支持:一个频道支持多个订阅者,一个客户端可订阅多个频道
    • 无确认机制:不想Kafka/RabbitMQ,由消息丢失风险
  4. Redis 发布订阅Python类实现(redis_pubsub.py)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import redis
import threading
import time

class RedisPubSub:
def __init__(self, host='localhost', port=6379, db=0, channel='default'):
self.channel = channel # 频道
self.redis = redis.Redis(host=host, port=port, db=db, decode_responses=True) # redis客户端
self.pubsub = self.redis.pubsub() # redis发布订阅对象
self._running = False # 是否开启订阅

# 向self.channel频道发送消息message
def publish(self, message: str):
"""发布消息"""
self.redis.publish(self.channel, message)

# 订阅self.channel消息
def subscribe(self, callback):
"""
订阅并启动监听线程。
参数 callback: 接收一个函数,在收到消息时被调用。
"""
def listen():
# 订阅self.channel频道
self.pubsub.subscribe(self.channel)
print(f"Subscribed to {self.channel}")

# 监听
for message in self.pubsub.listen():
if not self._running:
break
if message['type'] == 'message':
callback(message['data'])

# 开启订阅
self._running = True

# 创建监听线程
self.listen_thread = threading.Thread(target=listen, daemon=True)
# 启动监听线程
self.listen_thread.start()

def stop(self):
"""停止订阅"""
self._running = False
self.pubsub.unsubscribe()
print(f"Unsubscribed from {self.channel}")
  1. 发布端(publisher_demo.py)
1
2
3
4
5
6
7
8
from redis_pubsub import RedisPubSub
import time

pub = RedisPubSub(channel='chat')
for i in range(5):
pub.publish(f"Message {i}")
print(f"Sent: Message {i}")
time.sleep(1)
  1. 订阅端(subscriber_demo.py)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from redis_pubsub import RedisPubSub
import time

def handle_message(msg):
print(f"Received: {msg}")

sub = RedisPubSub(channel='chat')
sub.subscribe(callback=handle_message)

try:
while True:
time.sleep(1)
except KeyboardInterrupt:
sub.stop()

python第三方库/python-daemon/将当前进程变为守护进行

  • 守护进程

    • 什么是守护进程?

      守护进程是一种在后台运行的进程,通常用于执行周期性或长时间运行的任务。

    • 特点:

      • 在后台运行的进程
      • 没有控制终端(不受键盘、tty影响)
  • 守护进程的原始实现

    • 守护进程实现核心步骤
      • fork() 一个子进程,父进程退出(脱离原始终端)
      • setsid() 创建新会话,脱离控制终端
      • 第二次 fork() 防止重新获得终端
      • chdir(‘/‘) 切换到根目录(避免锁定工作目录)
      • umask(0) 重设文件权限掩码
      • 重定向 stdin、stdout、stderr 到 /dev/null
  • 实现守护进程的方法

    • 方法1

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      import os
      import sys

      def daemonize():
      # 第一次 fork
      pid = os.fork()
      if pid > 0:
      sys.exit(0)

      # 脱离终端,创建新会话
      os.setsid()
      os.umask(0)

      # 第二次 fork,避免重新打开终端
      pid = os.fork()
      if pid > 0:
      sys.exit(0)

      # 关闭标准文件描述符并重定向到 /dev/null
      sys.stdout.flush()
      sys.stderr.flush()
      with open('/dev/null', 'rb', 0) as f:
      os.dup2(f.fileno(), sys.stdin.fileno())
      with open('/dev/null', 'ab', 0) as f:
      os.dup2(f.fileno(), sys.stdout.fileno())
      os.dup2(f.fileno(), sys.stderr.fileno())
    • 方法2(推荐使用)

      1
      2
      3
      4
      5
      6
      7
      8
      9
      import daemon

      def run():
      while True:
      # 你的后台任务代码
      ...

      with daemon.DaemonContext():
      run()
    • 方法3:使用multiprocessing模块

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      from multiprocessing import Process, current_process

      def task():
      process = current_process()
      print(f"Daemon process: {process.daemon}")

      if __name__ == "__main__":
      process = Process(target=task, daemon=True)
      process.start()
      process.join()
  • 示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    import time
    import daemon
    from daemon import pidfile

    LOG_PATH = "/tmp/my_daemon.log"
    PID_PATH = "/tmp/my_daemon.pid"

    def run():
    while True:
    with open(LOG_PATH, "a") as f:
    f.write(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Daemon is running...\n")
    time.sleep(5)

    if __name__ == "__main__":
    with daemon.DaemonContext(
    working_directory=".",
    umask=0o002,
    pidfile=pidfile.TimeoutPIDLockFile(PID_PATH)
    ):
    run()

  • Transformer

    Transformer是一种基于自注意力机制的深度学习模型,由Google在2017年的论文“Attention is All You Need”提出。Transformer由编码器(Encoder)和解码器(Decoder)组成,结构如下图所示:

  • Transformer Pytorch代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)

def forward(self, x):
return x + self.pe[:, :x.size(1)].detach()


class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
assert d_model % self.num_heads == 0 # 确保 d_model 能被 num_heads 整除

self.depth = d_model // self.num_heads

self.wq = nn.Linear(d_model, d_model)
self.wk = nn.Linear(d_model, d_model)
self.wv = nn.Linear(d_model, d_model)

self.dense = nn.Linear(d_model, d_model)

def split_heads(self, x, batch_size):
x = x.view(batch_size, -1, self.num_heads, self.depth)
return x.permute(0, 2, 1, 3)

def attention(self, query, key, value, mask=None, dropout=None):
matmul_qk = torch.matmul(query, key.transpose(-2, -1)) # QK^T
dk = query.size(-1)
scaled_attention_logits = matmul_qk / math.sqrt(dk)

if mask is not None:
scaled_attention_logits += (mask * -1e9) # 避免pad部分被注意到

attention_weights = torch.softmax(scaled_attention_logits, dim=-1)
if dropout is not None:
attention_weights = dropout(attention_weights)

output = torch.matmul(attention_weights, value)
return output, attention_weights

def forward(self, query, key, value, mask=None, dropout=None):
batch_size = query.size(0)

query = self.split_heads(self.wq(query), batch_size)
key = self.split_heads(self.wk(key), batch_size)
value = self.split_heads(self.wv(value), batch_size)

output, attention_weights = self.attention(query, key, value, mask, dropout)
output = output.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.d_model)

return self.dense(output)


class FeedForwardNetwork(nn.Module):
def __init__(self, d_model, d_ff=2048, dropout=0.1):
super(FeedForwardNetwork, self).__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)

def forward(self, x):
x = F.relu(self.linear1(x))
x = self.dropout(x)
return self.linear2(x)


class EncoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1):
super(EncoderLayer, self).__init__()
self.attention = MultiHeadAttention(d_model, num_heads)
self.ffn = FeedForwardNetwork(d_model, d_ff, dropout)
self.layernorm1 = nn.LayerNorm(d_model)
self.layernorm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)

def forward(self, x, mask=None):
# 自注意力层
attn_output = self.attention(x, x, x, mask, self.dropout)
x = self.layernorm1(x + attn_output) # 残差连接 + LayerNorm

# 前馈网络层
ffn_output = self.ffn(x)
x = self.layernorm2(x + ffn_output) # 残差连接 + LayerNorm

return x


class DecoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1):
super(DecoderLayer, self).__init__()
self.attention1 = MultiHeadAttention(d_model, num_heads)
self.attention2 = MultiHeadAttention(d_model, num_heads)
self.ffn = FeedForwardNetwork(d_model, d_ff, dropout)
self.layernorm1 = nn.LayerNorm(d_model)
self.layernorm2 = nn.LayerNorm(d_model)
self.layernorm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)

def forward(self, x, enc_output, look_ahead_mask=None, padding_mask=None):
# 解码器中的自注意力层
attn1_output = self.attention1(x, x, x, look_ahead_mask, self.dropout)
x = self.layernorm1(attn1_output + x)

# 编码器-解码器注意力层
attn2_output = self.attention2(x, enc_output, enc_output, padding_mask, self.dropout)
x = self.layernorm2(attn2_output + x)

# 前馈网络层
ffn_output = self.ffn(x)
x = self.layernorm3(ffn_output + x)

return x


class TransformerEncoder(nn.Module):
def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff=2048, dropout=0.1):
super(TransformerEncoder, self).__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.positional_encoding = PositionalEncoding(d_model)
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
])
self.d_model = d_model

def forward(self, x, mask=None):
x = self.embedding(x) * math.sqrt(self.d_model) # 嵌入 + 缩放
x = self.positional_encoding(x)

for layer in self.layers:
x = layer(x, mask)

return x


class TransformerDecoder(nn.Module):
def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff=2048, dropout=0.1):
super(TransformerDecoder, self).__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.positional_encoding = PositionalEncoding(d_model)
self.layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
])
self.d_model = d_model

def forward(self, x, enc_output, look_ahead_mask=None, padding_mask=None):
x = self.embedding(x) * math.sqrt(self.d_model)
x = self.positional_encoding(x)

for layer in self.layers:
x = layer(x, enc_output, look_ahead_mask, padding_mask)

return x


class Transformer(nn.Module):
def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff=2048, dropout=0.1):
super(Transformer, self).__init__()

self.encoder = TransformerEncoder(vocab_size, d_model, num_heads, num_layers, d_ff, dropout)
self.decoder = TransformerDecoder(vocab_size, d_model, num_heads, num_layers, d_ff, dropout)
self.output_layer = nn.Linear(d_model, vocab_size)

def forward(self, src, tgt, src_mask=None, tgt_mask=None):
# 编码器部分
enc_output = self.encoder(src, src_mask)

# 解码器部分
dec_output = self.decoder(tgt, enc_output, tgt_mask, src_mask)

# 输出层
return self.output_layer(dec_output)


vocab_size = 10000 # 词汇表大小
d_model = 512 # 特征维度
num_heads = 8 # 注意力头数
num_layers = 6 # 编码器和解码器层数
dropout = 0.1 # Dropout 比例

# 初始化 Transformer 模型
transformer = Transformer(vocab_size, d_model, num_heads, num_layers, dropout=dropout)

# 输入张量(batch_size, sequence_length)
src = torch.randint(0, vocab_size, (32, 100)) # 假设 source 语言输入 batch_size=32,序列长度=100
tgt = torch.randint(0, vocab_size, (32, 100)) # 假设 target 语言输入 batch_size=32,序列长度=100

# 创建遮罩(假设没有 padding)
src_mask = None
tgt_mask = None

# 前向传播
output = transformer(src, tgt, src_mask, tgt_mask)

print("Output shape:", output.shape) # 输出的形状 (batch_size, tgt_sequence_length, vocab_size)

  • 自注意力机制(Self-Attention)

    自注意力机制是Transfromer中的重要组件,它通过计算Query (Q)、Key (K)、Value (V)获取token之间的相关性。Q、K、V矩阵是通过输入嵌入(或前一层的输出)与权重矩阵进行线性变换得到的。

    自注意力机制的输入格式为(batch_size, seq_len, d_model),batch_size是批次大小,seq_len是序列长度,d_model是嵌入维度。

    Q、K、V的计算:Q=X x W_Q, K=X x W_K, V=X x W_V。

    自注意力输出的计算:output = (QxK_T) x V。

  • 自注意力机制pytorch代码实现

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    import torch
    import torch.nn.functional as F
    from torch import nn


    class SingleHeadAttention(nn.Module):
    def __init__(self, embed_size):
    super(SingleHeadAttention, self).__init__()

    # 输入的embedding维度
    self.embed_size = embed_size

    # 定义查询、键和值的线性变换
    self.query_fc = nn.Linear(embed_size, embed_size)
    self.key_fc = nn.Linear(embed_size, embed_size)
    self.value_fc = nn.Linear(embed_size, embed_size)

    # 输出的线性变换
    self.out_fc = nn.Linear(embed_size, embed_size)

    def forward(self, X, mask=None):
    print("X.shape: ", X.shape)
    # Step1: 通过线性层生成查询、键和值的向量
    Q = self.query_fc(X) # (batch_size, seq_len, embed_size)
    print("Q.shape: ", Q.shape)
    K = self.key_fc(X) # (batch_size, seq_len, embed_size)
    print("K.shape: ", K.shape)
    V = self.value_fc(X) # (batch_size, seq_len, embed_size)
    print("V.shape: ", V.shape)

    # Step2: 计算注意力得分
    attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embed_size ** 0.5)
    print("attention_scores.shape: ", attention_scores.shape)

    # 如果有mask,应用mask
    if mask is not None:
    attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))

    # Step3: 计算注意力权重(softmax)
    attention_weights = F.softmax(attention_scores, dim=-1) # (batch_size, seq_len, seq_len)

    # Step4: 加权求和得到输出
    output = torch.matmul(attention_weights, V) # (batch_size, seq_len, embed_size)

    return output


    if __name__ == "__main__":
    batch_size = 2
    seq_len = 4
    embed_size = 8

    # 随机生成输入数据
    X = torch.randn(batch_size, seq_len, embed_size)

    # 创建自注意力模型
    attention_layer = SingleHeadAttention(embed_size)

    # 前向传播
    output = attention_layer(X)

    print(f"Output: {output.shape}") # (batch_size, seq_len, embed_size)

  • 多头注意力机制(Multi-Head Attention)

    多头注意力机制是Transformer模型中的一个核心组成部分,它通过并行计算多个注意力头来捕捉输入序列的不同信息,每个注意力头都有独立的Q、K、V,能够关注输入的不同子空间,从而增强模型对不同特征的表达能力。

    多头注意力计算过程:

    1. 线性变换:输入的向量首先会通过不同的线性变换(权重矩阵)生成多个查询(Q)、键(K)和值(V)向量。
    2. 计算注意力:每个注意力头根据查询、键和值计算注意力权重,并通过加权求和得到一个输出。
    3. 拼接:所有头的输出会被拼接在一起。
    4. 线性变换:拼接后的结果通过一个线性变换,最终输出。
  • 多头注意力机制pytorch代码实现

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.embed_size = embed_size
    self.num_heads = num_heads
    self.head_dim = embed_size // num_heads

    assert self.head_dim * num_heads == embed_size, "Embedding size must be divisible by num_heads"

    # 定义查询、键、值的线性变换
    self.query_fc = nn.Linear(embed_size, embed_size)
    self.key_fc = nn.Linear(embed_size, embed_size)
    self.value_fc = nn.Linear(embed_size, embed_size)

    # 定义输出的线性变换
    self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, X):
    batch_size = X.shape[1]

    # 通过线性变换得到 Q, K, V
    Q = self.query_fc(X) # (seq_len, batch_size, embed_size)
    print("Q.shape: ", Q.shape)
    K = self.key_fc(X)
    print("K.shape: ", K.shape)
    V = self.value_fc(X)
    print("V.shape: ", V.shape)

    # 将Q, K, V 切分成多个头
    Q = Q.view(X.shape[0], batch_size, self.num_heads, self.head_dim).transpose(1,
    2) # (seq_len, batch_size, num_heads, head_dim)
    print("Q_multi_head.shape: ", Q.shape)
    K = K.view(X.shape[0], batch_size, self.num_heads, self.head_dim).transpose(1, 2)
    print("K_multi_head.shape: ", K.shape)
    V = V.view(X.shape[0], batch_size, self.num_heads, self.head_dim).transpose(1, 2)
    print("V_multi_head.shape: ", V.shape)

    # 计算注意力得分
    energy = torch.matmul(Q, K.transpose(-2, -1)) # (seq_len, batch_size, num_heads, seq_len)
    attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=-1) # 注意力得分
    print("Q*K.shape", attention.shape)

    # 计算加权求和的输出
    out = torch.matmul(attention, V) # (seq_len, batch_size, num_heads, head_dim)

    # 将多个头合并
    out = out.transpose(1, 2).contiguous().view(X.shape[0], batch_size, self.num_heads * self.head_dim)

    # 通过输出的线性层
    out = self.fc_out(out)

    return out


    # 测试
    embed_size = 64
    num_heads = 8
    seq_len = 10
    batch_size = 32

    multihead_attention = MultiHeadAttention(embed_size, num_heads)

    # 输入张量,shape: (seq_len, batch_size, embed_size)
    X = torch.rand(seq_len, batch_size, embed_size)

    out = multihead_attention(X)
    print(out.shape) # (seq_len, batch_size, embed_size)

Netron是一款开源的深度学习模型可视化工具,支持多种深度学习框架生成的模型(例如,PyTorch、TensorFlow、ONNX等)的可视化。

网页版工具地址:Netron