Step-by-Step: Building an LLM That Can Code

Training a small LLM focused on Python code generation, from scratch


Have you ever wondered how code generation models like GitHub Copilot and CodeLlama work? This article walks you through implementing a small language model focused on Python code generation, from scratch. Along the way you will build a working understanding of the Transformer architecture, code tokenization, and how a model learns to "write code".

Why implement a code model yourself?

There are already plenty of excellent code generation models, but building one yourself has several distinct benefits:

  1. Deep understanding: reading about the theory only goes so far; implementing each component is the only way to truly understand what it does
  2. Customization: you can optimize for a specific code style or domain
  3. Controllable resources: a small model can be trained and run on a consumer GPU
  4. Learning path: an excellent hands-on project for getting into AI

Our goal is to train a model of roughly 30M parameters that can:

  • Generate a Python function body from a signature and docstring
  • Complete unfinished code snippets
  • Handle basic Python syntax and common libraries

Architecture overview

┌─────────────────────────────────────────────────────────────┐
│                   Code Generation LLM                       │
├─────────────────────────────────────────────────────────────┤
│  1. Data collection & preprocessing                         │
│     └── Python corpus → cleaning → tokenization             │
├─────────────────────────────────────────────────────────────┤
│  2. Model architecture                                      │
│     └── Decoder-only Transformer (GPT-style)                │
├─────────────────────────────────────────────────────────────┤
│  3. Training                                                │
│     └── Next-token prediction (causal language modeling)    │
├─────────────────────────────────────────────────────────────┤
│  4. Inference & generation                                  │
│     └── Temperature sampling + top-k / top-p                │
└─────────────────────────────────────────────────────────────┘

Step 1: Data collection and preprocessing

1.1 Collecting Python code data

High-quality training data is the foundation of a successful model. The helpers below gather Python code from a local directory of repositories (for example, cloned open-source projects), validate syntax, and extract function definitions:

import ast
from pathlib import Path
from typing import List, Optional
from dataclasses import dataclass

@dataclass
class CodeSample:
    """A single code sample."""
    source: str           # raw source code
    file_path: str        # file path
    is_valid: bool        # whether the code is syntactically valid
    functions: List[str]  # extracted function definitions

def collect_python_files(root_dir: str) -> List[Path]:
    """
    Recursively collect all Python files under a directory.

    Args:
        root_dir: root directory path

    Returns:
        List of Python file paths
    """
    python_files = []
    for path in Path(root_dir).rglob("*.py"):
        # Skip test files and virtual environments
        if "test" not in str(path).lower() and "venv" not in str(path):
            python_files.append(path)
    return python_files

def validate_python_syntax(code: str) -> bool:
    """
    Check whether a string is valid Python syntax.

    Args:
        code: Python source string

    Returns:
        True if the syntax is valid
    """
    try:
        ast.parse(code)
        return True
    except SyntaxError:
        return False

def extract_functions(code: str) -> List[str]:
    """
    Extract all function definitions from a piece of code.

    Args:
        code: Python source string

    Returns:
        List of function source snippets
    """
    functions = []
    try:
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                # Recover the function's source text
                func_source = ast.get_source_segment(code, node)
                if func_source:
                    functions.append(func_source)
    except SyntaxError:
        pass
    return functions

def process_code_file(file_path: Path) -> Optional[CodeSample]:
    """
    Process a single source file.

    Args:
        file_path: path to the file

    Returns:
        A CodeSample, or None if the file is skipped
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            code = f.read()

        # Skip files that are too short or too long
        if len(code) < 100 or len(code) > 100000:
            return None

        is_valid = validate_python_syntax(code)
        functions = extract_functions(code) if is_valid else []

        return CodeSample(
            source=code,
            file_path=str(file_path),
            is_valid=is_valid,
            functions=functions
        )
    except Exception:
        return None

# generated by hugo's coding agent

1.2 A code tokenizer

Tokenizing code differs from tokenizing natural language: indentation and punctuation carry semantic weight, so the tokenizer must preserve them.
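Python's standard-library tokenize module performs exactly this kind of whitespace-aware tokenization, emitting explicit INDENT/DEDENT tokens for block structure. A quick standalone look at it (my own illustrative snippet) shows why modeling indentation as dedicated tokens is natural:

```python
import io
import tokenize

src = "def f(x):\n    return x + 1\n"

# The stdlib tokenizer emits explicit INDENT/DEDENT tokens,
# mirroring the <INDENT>/<DEDENT> scheme used below.
names = [tokenize.tok_name[tok.type]
         for tok in tokenize.generate_tokens(io.StringIO(src).readline)]

print(names)  # includes an 'INDENT' before 'return' and a closing 'DEDENT'
```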

import re
from collections import Counter
from typing import Dict, List, Tuple

class CodeTokenizer:
    """
    A tokenizer designed for Python code.

    Uses Byte-Pair Encoding (BPE), with code-specific handling:
    - preserves indentation information
    - recognizes keywords and operators
    - handles string literals and comments
    """

    # Python keywords get dedicated tokens
    PYTHON_KEYWORDS = [
        'def', 'class', 'if', 'else', 'elif', 'for', 'while', 'try',
        'except', 'finally', 'with', 'as', 'import', 'from', 'return',
        'yield', 'raise', 'pass', 'break', 'continue', 'lambda', 'and',
        'or', 'not', 'in', 'is', 'None', 'True', 'False', 'async', 'await'
    ]

    # Special tokens
    SPECIAL_TOKENS = {
        '<PAD>': 0,
        '<UNK>': 1,
        '<BOS>': 2,  # beginning of sequence
        '<EOS>': 3,  # end of sequence
        '<INDENT>': 4,  # indentation increases
        '<DEDENT>': 5,  # indentation decreases
        '<NEWLINE>': 6,
    }

    def __init__(self, vocab_size: int = 8000):
        self.vocab_size = vocab_size
        self.token_to_id: Dict[str, int] = {}
        self.id_to_token: Dict[int, str] = {}
        self.merges: Dict[Tuple[str, str], str] = {}

    def _pre_tokenize(self, code: str) -> List[str]:
        """
        Pre-tokenize: split the code into basic units.

        Strategy:
        1. Keep string literals intact
        2. Separate operators and punctuation
        3. Keep the whitespace needed for indentation handling
        """
        tokens = []

        # Regular expression matching the different token types
        pattern = r'''
            ("[^"]*"|'[^']*')  |  # string literals
            (\#[^\n]*)         |  # comments
            (\d+\.?\d*)        |  # numbers
            ([a-zA-Z_]\w*)     |  # identifiers
            ([ ]{4}|\t)        |  # one indentation unit
            (\n)               |  # newline
            ([^\s\w])             # other symbols
        '''

        for match in re.finditer(pattern, code, re.VERBOSE):
            token = match.group()
            if token.strip() or token in ('\n', '    ', '\t'):
                tokens.append(token)

        return tokens

    def _process_indentation(self, tokens: List[str]) -> List[str]:
        """
        Convert leading whitespace into INDENT/DEDENT tokens.
        """
        processed = []
        indent_stack = [0]
        current_indent = 0
        at_line_start = True

        for token in tokens:
            if token == '\n':
                processed.append('<NEWLINE>')
                at_line_start = True
                current_indent = 0
            elif at_line_start and token in ('    ', '\t'):
                current_indent += 1
            elif at_line_start:
                # Emit tokens for any change in indentation
                while current_indent < indent_stack[-1]:
                    processed.append('<DEDENT>')
                    indent_stack.pop()
                if current_indent > indent_stack[-1]:
                    processed.append('<INDENT>')
                    indent_stack.append(current_indent)
                at_line_start = False
                processed.append(token)
            else:
                processed.append(token)

        # Close any indentation still open at the end of the input
        while indent_stack[-1] > 0:
            processed.append('<DEDENT>')
            indent_stack.pop()

        return processed

    def train(self, code_samples: List[str], num_merges: int = 5000):
        """
        Train the BPE tokenizer.

        Args:
            code_samples: list of code samples
            num_merges: number of BPE merge operations
        """
        # Initialize the vocabulary
        self.token_to_id = dict(self.SPECIAL_TOKENS)
        next_id = len(self.SPECIAL_TOKENS)

        # Add the Python keywords
        for keyword in self.PYTHON_KEYWORDS:
            self.token_to_id[keyword] = next_id
            next_id += 1

        # Pre-tokenize every sample
        all_tokens = []
        for code in code_samples:
            tokens = self._pre_tokenize(code)
            tokens = self._process_indentation(tokens)
            all_tokens.extend(tokens)

        # Split each token into characters, with an end-of-word marker
        words = [list(t) + ['</w>'] for t in all_tokens if t not in self.SPECIAL_TOKENS]

        # Seed the vocabulary with every base character (and the marker),
        # so any token can still be encoded when no merge applies to it
        for word in words:
            for ch in word:
                if ch not in self.token_to_id:
                    self.token_to_id[ch] = next_id
                    next_id += 1

        # BPE training loop
        for _ in range(num_merges):
            pairs = Counter()
            for word in words:
                for i in range(len(word) - 1):
                    pairs[(word[i], word[i + 1])] += 1

            if not pairs:
                break

            best_pair = max(pairs, key=pairs.get)
            new_token = ''.join(best_pair)

            if new_token not in self.token_to_id:
                self.token_to_id[new_token] = next_id
                next_id += 1
                self.merges[best_pair] = new_token

            # Apply the merge
            new_words = []
            for word in words:
                new_word = []
                i = 0
                while i < len(word):
                    if i < len(word) - 1 and (word[i], word[i + 1]) == best_pair:
                        new_word.append(new_token)
                        i += 2
                    else:
                        new_word.append(word[i])
                        i += 1
                new_words.append(new_word)
            words = new_words

        # Build the reverse mapping
        self.id_to_token = {v: k for k, v in self.token_to_id.items()}

    def encode(self, code: str) -> List[int]:
        """Encode source code into a sequence of token IDs."""
        tokens = self._pre_tokenize(code)
        tokens = self._process_indentation(tokens)

        ids = [self.SPECIAL_TOKENS['<BOS>']]

        for token in tokens:
            if token in self.token_to_id:
                ids.append(self.token_to_id[token])
            else:
                # Apply BPE merges (earliest-learned merge first)
                word = list(token) + ['</w>']
                while len(word) > 1:
                    pairs = [(word[i], word[i + 1]) for i in range(len(word) - 1)]
                    mergeable = [p for p in pairs if p in self.merges]
                    if not mergeable:
                        break
                    pair = min(mergeable, key=lambda p: list(self.merges.keys()).index(p))
                    new_word = []
                    i = 0
                    while i < len(word):
                        if i < len(word) - 1 and (word[i], word[i + 1]) == pair:
                            new_word.append(self.merges[pair])
                            i += 2
                        else:
                            new_word.append(word[i])
                            i += 1
                    word = new_word

                for subtoken in word:
                    if subtoken in self.token_to_id:
                        ids.append(self.token_to_id[subtoken])
                    else:
                        ids.append(self.SPECIAL_TOKENS['<UNK>'])

        ids.append(self.SPECIAL_TOKENS['<EOS>'])
        return ids

    def decode(self, ids: List[int]) -> str:
        """Decode a sequence of token IDs back into source code."""
        tokens = []
        for id in ids:
            if id in self.id_to_token:
                token = self.id_to_token[id]
                if token not in ('<PAD>', '<BOS>', '<EOS>'):
                    tokens.append(token)

        # Rebuild the source text. Spacing is approximate: the
        # pre-tokenizer drops single spaces, so we re-insert one at
        # each end-of-word marker and after keywords.
        code = ''
        indent_level = 0

        for token in tokens:
            if token == '<NEWLINE>':
                code += '\n' + '    ' * indent_level
            elif token == '<INDENT>':
                indent_level += 1
                code += '    '
            elif token == '<DEDENT>':
                indent_level = max(0, indent_level - 1)
                # Remove exactly one indentation unit (rstrip would eat them all)
                if code.endswith('    '):
                    code = code[:-4]
            elif token in self.PYTHON_KEYWORDS:
                code += token + ' '
            else:
                code += token.replace('</w>', ' ')

        return code


Step 2: Building the Transformer model

2.1 Model configuration

from dataclasses import dataclass

@dataclass
class CodeLLMConfig:
    """Model configuration."""
    vocab_size: int = 8000       # vocabulary size
    max_seq_len: int = 1024      # maximum sequence length
    d_model: int = 512           # model width
    n_heads: int = 8             # number of attention heads
    n_layers: int = 6            # number of Transformer layers
    d_ff: int = 2048             # feed-forward width
    dropout: float = 0.1         # dropout probability

    @property
    def n_params(self) -> int:
        """Rough parameter count."""
        # Embedding (shared with the output projection via weight tying)
        embed_params = self.vocab_size * self.d_model
        # Attention (Q, K, V, O projections per layer)
        attn_params = 4 * self.d_model * self.d_model * self.n_layers
        # SwiGLU FFN (gate, up and down projections per layer)
        ffn_params = 3 * self.d_model * self.d_ff * self.n_layers
        # Layer norms (two per layer, weight + bias)
        ln_params = 4 * self.d_model * self.n_layers

        return embed_params + attn_params + ffn_params + ln_params

# Default configuration, roughly 30M parameters
config = CodeLLMConfig()
print(f"Estimated parameters: {config.n_params / 1e6:.1f}M")
# Output: Estimated parameters: 29.3M


2.2 Core components

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

class RotaryPositionalEmbedding(nn.Module):
    """
    Rotary positional embedding (RoPE).

    Compared with classic sinusoidal position encodings, RoPE
    extrapolates better to longer sequences and encodes relative
    positions more directly.
    """

    def __init__(self, d_model: int, max_seq_len: int = 2048, base: float = 10000.0):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        # Per-dimension rotation frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute the cos and sin tables
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        self.max_seq_len = seq_len
        t = torch.arange(seq_len, device=self.inv_freq.device)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def forward(self, x: torch.Tensor, offset: int = 0) -> tuple:
        # x: [batch, n_heads, seq_len, head_dim]
        seq_len = x.shape[-2]
        if offset + seq_len > self.max_seq_len:
            self._build_cache(offset + seq_len)

        return (self.cos_cached[offset:offset + seq_len],
                self.sin_cached[offset:offset + seq_len])


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second."""
    x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor,
                         cos: torch.Tensor, sin: torch.Tensor) -> tuple:
    """Apply the rotary embedding to Q and K."""
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention.

    Uses RoPE position encoding and a KV cache to speed up inference.
    """

    def __init__(self, config: CodeLLMConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.head_dim = config.d_model // config.n_heads

        assert self.head_dim * self.n_heads == config.d_model

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.dropout = nn.Dropout(config.dropout)
        self.rope = RotaryPositionalEmbedding(self.head_dim, config.max_seq_len)

    def forward(self, x: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                kv_cache: Optional[tuple] = None) -> tuple:
        batch_size, seq_len, _ = x.shape

        # Linear projections
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape into heads: [batch, n_heads, seq_len, head_dim]
        q = q.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE. With a KV cache, new tokens must be rotated by
        # their absolute position, so offset by the cached length.
        past_len = kv_cache[0].shape[2] if kv_cache is not None else 0
        cos, sin = self.rope(q, offset=past_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Append to the KV cache (keys are cached post-rotation)
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)

        new_kv_cache = (k, v)

        # Scaled dot-product attention scores
        scale = math.sqrt(self.head_dim)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale

        # Apply the causal mask (positions where mask == 0 are blocked)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.dropout(attn_probs)

        # Weighted sum of the values
        attn_output = torch.matmul(attn_probs, v)

        # Merge heads and project
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)
        output = self.o_proj(attn_output)

        return output, new_kv_cache
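A defining property of RoPE is that the dot product between a rotated query at position m and a rotated key at position n depends only on the offset n - m, which is what makes it a relative position encoding. The following standalone check (my own sketch in plain Python, pairing adjacent dimensions rather than the half-split layout used in the code above) verifies this numerically:

```python
import math

def rope(vec, pos, base=10000.0):
    # Rotate consecutive pairs (x0, x1), (x2, x3), ... by position-dependent angles
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        x1, x2 = vec[i], vec[i + 1]
        out += [x1 * c - x2 * s, x1 * s + x2 * c]
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.1, -0.4, 0.7, 0.2]
k = [0.5, 0.3, -0.2, 0.8]
s_near = dot(rope(q, 3), rope(k, 5))      # positions 3 and 5
s_far = dot(rope(q, 103), rope(k, 105))   # positions 103 and 105
# Both pairs are two positions apart, so the two scores agree
print(s_near, s_far)
```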


class FeedForward(nn.Module):
    """
    Position-wise feed-forward network.

    Uses the SwiGLU activation, which tends to outperform ReLU in
    language models.
    """

    def __init__(self, config: CodeLLMConfig):
        super().__init__()
        self.gate_proj = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.up_proj = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.down_proj = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: silu(gate) * up
        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        x = gate * up
        x = self.dropout(x)
        x = self.down_proj(x)
        return x


class TransformerBlock(nn.Module):
    """
    A single Transformer block.

    Uses the Pre-LayerNorm arrangement, which trains more stably.
    """

    def __init__(self, config: CodeLLMConfig):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.d_model)
        self.ln2 = nn.LayerNorm(config.d_model)

    def forward(self, x: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                kv_cache: Optional[tuple] = None) -> tuple:
        # Self-attention with residual
        normed = self.ln1(x)
        attn_out, new_kv_cache = self.attention(normed, mask, kv_cache)
        x = x + attn_out

        # FFN with residual
        normed = self.ln2(x)
        ff_out = self.feed_forward(normed)
        x = x + ff_out

        return x, new_kv_cache


2.3 The full model

class CodeLLM(nn.Module):
    """
    Code generation language model.

    A decoder-only Transformer geared towards Python code generation.
    """

    def __init__(self, config: CodeLLMConfig):
        super().__init__()
        self.config = config

        # Token embedding
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)

        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])

        # Final LayerNorm
        self.ln_f = nn.LayerNorm(config.d_model)

        # Output projection
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights first, then tie the output projection to
        # the embedding (tying after init keeps the embedding's
        # initialization for the shared matrix)
        self.apply(self._init_weights)
        self.lm_head.weight = self.embed_tokens.weight  # weight tying

    def _init_weights(self, module):
        """Xavier/Glorot initialization."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)

    def _create_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Create a causal attention mask (1 = may attend, 0 = blocked)."""
        return torch.tril(torch.ones(seq_len, seq_len, device=device))

    def forward(self,
                input_ids: torch.Tensor,
                labels: Optional[torch.Tensor] = None,
                kv_cache: Optional[list] = None) -> dict:
        """
        Forward pass.

        Args:
            input_ids: input token IDs [batch_size, seq_len]
            labels: labels [batch_size, seq_len] for computing the loss
            kv_cache: KV cache for efficient inference

        Returns:
            Dict with the logits, an optional loss, and the updated KV cache
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Token embedding
        x = self.embed_tokens(input_ids)

        # Causal mask
        if kv_cache is None:
            mask = self._create_causal_mask(seq_len, device)
        else:
            mask = None  # with a cache, the new token may attend to everything

        # Run the Transformer layers
        new_kv_cache = []
        for i, layer in enumerate(self.layers):
            layer_cache = kv_cache[i] if kv_cache else None
            x, layer_kv = layer(x, mask, layer_cache)
            new_kv_cache.append(layer_kv)

        # Final LayerNorm
        x = self.ln_f(x)

        # Logits
        logits = self.lm_head(x)

        # Loss
        loss = None
        if labels is not None:
            # Shift so that each position predicts the next token
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100  # ignore padding positions
            )

        return {
            'logits': logits,
            'loss': loss,
            'kv_cache': new_kv_cache
        }

    @torch.no_grad()
    def generate(self,
                 input_ids: torch.Tensor,
                 max_new_tokens: int = 256,
                 temperature: float = 0.8,
                 top_k: int = 50,
                 top_p: float = 0.95,
                 stop_tokens: Optional[list] = None) -> torch.Tensor:
        """
        Autoregressive code generation.

        Args:
            input_ids: token IDs of the prompt
            max_new_tokens: maximum number of tokens to generate
            temperature: sampling temperature; higher means more random
            top_k: top-k sampling
            top_p: nucleus sampling
            stop_tokens: tokens that end generation

        Returns:
            The full generated token sequence
        """
        self.eval()
        generated = input_ids.clone()
        kv_cache = None

        for _ in range(max_new_tokens):
            # With a cache, only the newest token needs a forward pass
            if kv_cache is not None:
                curr_input = generated[:, -1:]
            else:
                curr_input = generated

            # Forward pass
            outputs = self.forward(curr_input, kv_cache=kv_cache)
            logits = outputs['logits'][:, -1, :]  # last position
            kv_cache = outputs['kv_cache']

            # Temperature
            logits = logits / temperature

            # Top-k filtering
            if top_k > 0:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')

            # Top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float('-inf')

            # Sample
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to the sequence
            generated = torch.cat([generated, next_token], dim=1)

            # Stop condition
            if stop_tokens and next_token.item() in stop_tokens:
                break

        return generated
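To build intuition for the top-p filter above, here is a standalone toy version over a four-token vocabulary (my own plain-Python sketch, independent of the model code):

```python
import math

def top_p_candidates(logits, p):
    """Return the smallest set of top tokens whose softmax mass exceeds p."""
    exps = [math.exp(l) for l in logits]
    z = sum(exps)
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += exps[i] / z
        if cum > p:
            break
    return kept

logits = [2.0, 1.0, 0.1, -1.0]  # softmax ≈ [0.64, 0.23, 0.10, 0.03]
print(top_p_candidates(logits, 0.90))  # the low-probability tail is dropped
```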


Step 3: The training pipeline

3.1 Dataset

from torch.utils.data import Dataset, DataLoader

class CodeDataset(Dataset):
    """Training dataset of tokenized code."""

    def __init__(self,
                 code_samples: List[str],
                 tokenizer: CodeTokenizer,
                 max_length: int = 1024):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Pre-tokenize every sample
        self.examples = []
        for code in code_samples:
            ids = tokenizer.encode(code)
            if len(ids) <= max_length:
                self.examples.append(ids)
            else:
                # Split long sequences into overlapping windows
                for i in range(0, len(ids) - max_length, max_length // 2):
                    self.examples.append(ids[i:i + max_length])

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ids = self.examples[idx]

        # Padding
        padding_length = self.max_length - len(ids)
        input_ids = ids + [0] * padding_length  # 0 is the PAD token

        # Labels: same as the input, but -100 at padding positions
        labels = ids + [-100] * padding_length

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long)
        }


def collate_fn(batch):
    """Stack individual examples into a batch."""
    input_ids = torch.stack([x['input_ids'] for x in batch])
    labels = torch.stack([x['labels'] for x in batch])
    return {'input_ids': input_ids, 'labels': labels}
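Note that the dataset stores labels aligned with input_ids; the shift that turns them into next-token targets happens later, inside the model's loss (logits[..., :-1, :] against labels[..., 1:]). In toy form the resulting (context, target) pairs look like this (illustrative snippet, not part of the pipeline):

```python
tokens = ["<BOS>", "def", "f", "(", ")", ":"]

# Each position is trained to predict the token that follows it
pairs = list(zip(tokens[:-1], tokens[1:]))
for context, target in pairs:
    print(f"{context!r} -> {target!r}")
```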


3.2 Trainer

from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

class CodeLLMTrainer:
    """Training loop wrapper."""

    def __init__(self,
                 model: CodeLLM,
                 train_dataset: CodeDataset,
                 val_dataset: CodeDataset,
                 config: dict):
        self.model = model
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

        # Training configuration
        self.batch_size = config.get('batch_size', 32)
        self.learning_rate = config.get('learning_rate', 3e-4)
        self.epochs = config.get('epochs', 10)
        self.gradient_clip = config.get('gradient_clip', 1.0)
        self.device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')

        # Move the model to the device
        self.model = self.model.to(self.device)

        # Optimizer
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
            betas=(0.9, 0.95),
            weight_decay=0.1
        )

        # Data loaders
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            num_workers=4
        )
        self.val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=4
        )

        # Learning-rate schedule: cosine decay over the whole run
        total_steps = len(self.train_loader) * self.epochs
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=total_steps,
            eta_min=self.learning_rate * 0.1
        )

    def train_epoch(self, epoch: int) -> float:
        """Train for one epoch."""
        self.model.train()
        total_loss = 0
        progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')

        for batch in progress_bar:
            input_ids = batch['input_ids'].to(self.device)
            labels = batch['labels'].to(self.device)

            # Forward pass
            outputs = self.model(input_ids, labels=labels)
            loss = outputs['loss']

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.gradient_clip
            )

            self.optimizer.step()
            self.scheduler.step()

            total_loss += loss.item()
            progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

        return total_loss / len(self.train_loader)

    @torch.no_grad()
    def evaluate(self) -> float:
        """Evaluate on the validation set."""
        self.model.eval()
        total_loss = 0

        for batch in self.val_loader:
            input_ids = batch['input_ids'].to(self.device)
            labels = batch['labels'].to(self.device)

            outputs = self.model(input_ids, labels=labels)
            total_loss += outputs['loss'].item()

        return total_loss / len(self.val_loader)

    def train(self):
        """Full training run."""
        best_val_loss = float('inf')

        for epoch in range(1, self.epochs + 1):
            train_loss = self.train_epoch(epoch)
            val_loss = self.evaluate()

            print(f'Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}')

            # Keep a checkpoint of the best model so far
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_loss': val_loss
                }, 'best_model.pt')
                print(f'  -> Saved best model (val_loss: {val_loss:.4f})')
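The CosineAnnealingLR schedule above follows eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T)) / 2. A standalone sketch using the same eta_min = 0.1 * lr choice as the trainer:

```python
import math

def cosine_lr(step, total_steps, lr_max=3e-4, lr_min=3e-5):
    """Cosine annealing from lr_max at step 0 down to lr_min at the final step."""
    return lr_min + (lr_max - lr_min) * (1 + math.cos(math.pi * step / total_steps)) / 2

# Starts at lr_max, passes through the midpoint, ends at lr_min
print(cosine_lr(0, 1000), cosine_lr(500, 1000), cosine_lr(1000, 1000))
```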


3.3 Running the training

# The full training script
def main():
    # 1. Collect the data
    print("Collecting Python code samples...")
    python_files = collect_python_files("/path/to/python/repos")
    code_samples = []
    for f in tqdm(python_files):
        sample = process_code_file(f)
        if sample and sample.is_valid:
            code_samples.append(sample.source)

    print(f"Collected {len(code_samples)} valid Python files")

    # 2. Train the tokenizer
    print("Training tokenizer...")
    tokenizer = CodeTokenizer(vocab_size=8000)
    tokenizer.train(code_samples[:10000], num_merges=5000)

    # 3. Build the datasets (90/10 train/validation split)
    print("Creating datasets...")
    train_samples = code_samples[:int(len(code_samples) * 0.9)]
    val_samples = code_samples[int(len(code_samples) * 0.9):]

    train_dataset = CodeDataset(train_samples, tokenizer)
    val_dataset = CodeDataset(val_samples, tokenizer)

    print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

    # 4. Build the model
    print("Initializing model...")
    config = CodeLLMConfig()
    model = CodeLLM(config)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

    # 5. Train
    print("Starting training...")
    trainer = CodeLLMTrainer(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        config={
            'batch_size': 32,
            'learning_rate': 3e-4,
            'epochs': 10,
            'device': 'cuda'
        }
    )
    trainer.train()

if __name__ == '__main__':
    main()


Step 4: Inference and code generation

With training complete, let's look at how to generate code with the model.

def generate_code(model: CodeLLM,
                  tokenizer: CodeTokenizer,
                  prompt: str,
                  max_tokens: int = 256) -> str:
    """
    Generate code from a prompt.

    Args:
        model: the trained model
        tokenizer: the tokenizer
        prompt: a code prompt, such as a function signature
        max_tokens: maximum number of tokens to generate

    Returns:
        The generated code
    """
    model.eval()
    device = next(model.parameters()).device

    # Encode the prompt, dropping the trailing <EOS> that encode()
    # appends: generation should continue the prompt, not end it
    prompt_ids = tokenizer.encode(prompt)[:-1]
    input_ids = torch.tensor([prompt_ids], device=device)

    # Generate
    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_tokens,
        temperature=0.7,
        top_k=40,
        top_p=0.92,
        stop_tokens=[tokenizer.token_to_id.get('<EOS>', 3)]
    )

    # Decode
    generated_code = tokenizer.decode(output_ids[0].tolist())
    return generated_code


# Example usage
prompt = '''def fibonacci(n: int) -> int:
    """
    Calculate the nth Fibonacci number.

    Args:
        n: The position in Fibonacci sequence

    Returns:
        The nth Fibonacci number
    """
'''

generated = generate_code(model, tokenizer, prompt)
print(generated)

# A possible output:
# def fibonacci(n: int) -> int:
#     """
#     Calculate the nth Fibonacci number.
#     ...
#     """
#     if n <= 1:
#         return n
#     return fibonacci(n - 1) + fibonacci(n - 2)
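The temperature passed to generate_code simply rescales the logits before the softmax: values below 1 sharpen the distribution toward the most likely token, values above 1 flatten it. A standalone illustration:

```python
import math

def softmax_t(logits, temperature):
    scaled = [l / temperature for l in logits]
    m = max(scaled)                       # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [e / z for e in exps]

logits = [2.0, 1.0, 0.0]
sharp = softmax_t(logits, 0.5)   # low temperature: concentrates on the top token
flat = softmax_t(logits, 2.0)    # high temperature: closer to uniform
print(sharp[0], flat[0])
```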


Evaluation and directions for improvement

Evaluation metrics

  Metric        Description                                    Target
  Perplexity    Language-model perplexity; lower is better     < 10
  Pass@k        Fraction of generations passing unit tests     > 30%
  BLEU          Similarity to reference code                   > 0.3
  Syntax Valid  Fraction of syntactically valid generations    > 95%
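Perplexity relates directly to the validation loss reported by the trainer: it is the exponential of the mean per-token cross-entropy, so the < 10 target corresponds to a loss below ln 10 ≈ 2.30:

```python
import math

def perplexity(mean_cross_entropy):
    """Perplexity is exp of the average per-token cross-entropy (in nats)."""
    return math.exp(mean_cross_entropy)

print(perplexity(2.0))  # a validation loss of 2.0 corresponds to perplexity ≈ 7.39
```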

Directions for improvement

  1. Scale up the data: use more high-quality Python code, such as highly starred GitHub projects
  2. Increase model size: grow from ~30M to 100M-300M parameters
  3. Inject code structure: use AST information to strengthen the model's understanding
  4. Instruction tuning: run SFT on instruction-following data
  5. RLHF: refine generation quality further with human feedback

Summary

This article implemented a small language model that generates Python code, from scratch. The model is tiny compared with CodeLlama, StarCoder and friends, but working through the project should have given you a solid understanding of:

  1. Code tokenization: handling the structure peculiar to code (indentation, keywords, and so on)
  2. The Transformer architecture: including modern pieces such as RoPE, KV caching and SwiGLU
  3. The training pipeline: data preparation, loss computation, optimizer configuration
  4. Inference: temperature, top-k and top-p sampling strategies

Next, you could try to:

  • Train on a larger dataset
  • Support more programming languages
  • Add Fill-in-the-Middle (FIM) capability
  • Deploy the model behind an API

Coding ability is one of the most important capabilities of an LLM, and understanding how it is implemented will help you use and improve these models. Happy coding!

