jingyaogong/minimind

A question about YaRN

Open

#602 opened on Dec 25, 2025

Labels: good first issue, question

Description

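Thanks for open-sourcing this project. While reading the source I ran into a question: comparing two versions of the YaRN implementation, why is the final value of high clamped differently? One uses min(high, dim // 2 - 1) and the other uses min(high, dim - 1). I ran both versions on some data, and the resulting high values do differ. Are both variants acceptable?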

The code in this project:

def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6, rope_scaling: Optional[dict] = None):
    freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
    if rope_scaling is not None:
        orig_max, factor, beta_fast, beta_slow, attn_factor = (
            rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 16),
            rope_scaling.get("beta_fast", 32.0), rope_scaling.get("beta_slow", 1.0), rope_scaling.get("attention_factor", 1.0)
        )
        if end / orig_max > 1.0:
            # YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is linear ramp
            inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
            low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
            ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
            freqs = freqs * (1 - ramp + ramp / factor)

    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
    return freqs_cos, freqs_sin

The DeepSeek version:

def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    """
    Precomputes frequency-based complex exponential values for rotary positional embeddings.

    Args:
        args (ModelArgs): Model arguments containing positional embedding parameters.

    Returns:
        torch.Tensor: Precomputed complex exponential values for positional embeddings.
    """
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        """
        Computes the correction dimension for a given number of rotations in the rotary positional embedding.

        Args:
            num_rotations (float): Number of rotations to compute the correction for.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            float: The correction dimension based on the input parameters.
        """
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        """
        Computes the range of correction dimensions for rotary positional embeddings.

        Args:
            low_rot (float): Lower bound for the number of rotations.
            high_rot (float): Upper bound for the number of rotations.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
        """
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim-1)

    def linear_ramp_factor(min, max, dim):
        """
        Computes a linear ramp function used to smooth values between a minimum and maximum range.

        Args:
            min (float): Minimum value for the ramp function.
            max (float): Maximum value for the ramp function.
            dim (int): Dimensionality of the ramp tensor.

        Returns:
            torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
                clamped to the range [0, 1].
        """
        if min == max:
            max += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if seqlen > args.original_seq_len:
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis
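
To make the question concrete, here is a minimal sketch that isolates only the clamp on high and compares the two variants. It is not taken from either codebase: the helper names (correction_dim, correction_range, linear_ramp, compare) and the two parameter settings are hypothetical, chosen only for illustration. It uses plain Python instead of torch so it runs without dependencies.

```python
import math


def correction_dim(num_rotations, dim, base, orig_max):
    """Frequency index at which RoPE completes `num_rotations` turns over `orig_max` positions."""
    return dim * math.log(orig_max / (num_rotations * 2 * math.pi)) / (2 * math.log(base))


def correction_range(beta_fast, beta_slow, dim, base, orig_max, upper):
    """Return (low, high), with high clamped to `upper` (the only point where the two versions differ)."""
    low = math.floor(correction_dim(beta_fast, dim, base, orig_max))
    high = math.ceil(correction_dim(beta_slow, dim, base, orig_max))
    return max(low, 0), min(high, upper)


def linear_ramp(low, high, n):
    """Linear ramp over n frequency indices, clamped to [0, 1]."""
    span = max(high - low, 0.001)
    return [min(max((i - low) / span, 0.0), 1.0) for i in range(n)]


def compare(dim, base, orig_max, beta_fast=32.0, beta_slow=1.0):
    """Return the (low, high) pair under each clamp: dim // 2 - 1 first, then dim - 1."""
    half = correction_range(beta_fast, beta_slow, dim, base, orig_max, upper=dim // 2 - 1)
    full = correction_range(beta_fast, beta_slow, dim, base, orig_max, upper=dim - 1)
    return half, full


if __name__ == "__main__":
    # Hypothetical settings, chosen only to show one matching and one diverging case.
    for dim, base, orig_max in [(64, 1e6, 2048), (64, 1e4, 65536)]:
        half, full = compare(dim, base, orig_max)
        print(f"dim={dim} base={base} orig_max={orig_max}: dim//2-1 -> {half}, dim-1 -> {full}")
        # The ramp has dim // 2 entries in both implementations above.
        print("  ramp at last index:", linear_ramp(*half, dim // 2)[-1], linear_ramp(*full, dim // 2)[-1])
```

In this sketch the two clamps agree for the first setting and diverge for the second, where the unclamped high exceeds dim // 2 - 1. Since the ramp in both implementations has only dim // 2 entries, clamping at dim - 1 can leave high beyond the last index, so the ramp never reaches 1 there. Whether that difference is intended is exactly what I would like to understand.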

I would appreciate a reply whenever you have time. Thanks for helping clear up my confusion!
