jingyaogong/minimind

Question about YaRN

Open

#602 opened on Dec 25, 2025

good first issue, question

Description

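Thanks for open-sourcing this! While reading the source I ran into a question. Comparing two YaRN implementations, why is the final value of `high` clamped differently? This project's code uses `min(…, dim // 2 - 1)`, while DeepSeek's uses `min(high, dim-1)`. I ran both versions on some inputs, and the resulting `high` does differ. Are both approaches valid?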
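To make the difference concrete, here is a small pure-Python sketch that evaluates both clamps on top of the same correction-dimension formula (`inv_dim` in this project, `find_correction_dim` in DeepSeek). The helper names `correction_dim`, `high_minimind` and `high_deepseek`, and both parameter sets, are hypothetical and chosen only for illustration.

```python
import math


def correction_dim(num_rotations: float, dim: int, base: float, orig_max: int) -> float:
    # Pair index i whose wavelength 2*pi*base**(2i/dim) completes
    # `num_rotations` full turns within orig_max tokens (same formula in both versions).
    return dim * math.log(orig_max / (num_rotations * 2 * math.pi)) / (2 * math.log(base))


def high_minimind(dim: int, base: float, orig_max: int, beta_slow: float = 1.0) -> int:
    # This project: upper bound clamped to dim // 2 - 1
    return min(math.ceil(correction_dim(beta_slow, dim, base, orig_max)), dim // 2 - 1)


def high_deepseek(dim: int, base: float, orig_max: int, beta_slow: float = 1.0) -> int:
    # DeepSeek: upper bound clamped to dim - 1
    return min(math.ceil(correction_dim(beta_slow, dim, base, orig_max)), dim - 1)


# Hypothetical (dim, base, orig_max) settings, not taken from either project's config
for dim, base, orig_max in [(64, 1e6, 2048), (64, 1e4, 65536)]:
    print(dim, base, orig_max, high_minimind(dim, base, orig_max), high_deepseek(dim, base, orig_max))
```

With these inputs the two values coincide for the first setting and differ for the second: the clamps only disagree when the unclamped `ceil(...)` exceeds `dim // 2 - 1`.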

The code in this project:

```python
import math
from typing import Optional

import torch


def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
                         rope_scaling: Optional[dict] = None):
    freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
    if rope_scaling is not None:
        orig_max, factor, beta_fast, beta_slow, attn_factor = (
            rope_scaling.get("original_max_position_embeddings", 2048),
            rope_scaling.get("factor", 16),
            rope_scaling.get("beta_fast", 32.0),
            rope_scaling.get("beta_slow", 1.0),
            rope_scaling.get("attention_factor", 1.0)
        )
        if end / orig_max > 1.0:
            # YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is linear ramp
            inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
            low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
            ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
            freqs = freqs * (1 - ramp + ramp / factor)

    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
    return freqs_cos, freqs_sin
```
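The scaling step in this project applies f'(i) = f(i)((1-γ) + γ/s) to the `dim // 2` frequencies, where γ ramps linearly from 0 at `low` to 1 at `high`. The pure-Python sketch below re-implements that step without torch so the ramp can be inspected under either upper clamp. The function name `yarn_scaled_freqs`, its `high_cap` parameter and the example settings are hypothetical.

```python
import math


def yarn_scaled_freqs(dim, rope_base, orig_max, factor, beta_fast, beta_slow, high_cap):
    """Pure-Python sketch of the YaRN frequency scaling above.

    high_cap is a hypothetical knob: dim // 2 - 1 mirrors this project, dim - 1 mirrors DeepSeek.
    """
    # Unscaled RoPE frequencies, one per pair: rope_base ** (-2i / dim)
    freqs = [1.0 / (rope_base ** (2 * i / dim)) for i in range(dim // 2)]

    def inv_dim(b):
        return (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))

    low = max(math.floor(inv_dim(beta_fast)), 0)
    high = min(math.ceil(inv_dim(beta_slow)), high_cap)
    # gamma_i = 0 for i <= low (frequency kept), 1 for i >= high (frequency divided by factor)
    ramp = [min(max((i - low) / max(high - low, 0.001), 0.0), 1.0) for i in range(dim // 2)]
    scaled = [f * (1 - g + g / factor) for f, g in zip(freqs, ramp)]
    return scaled, ramp


# Hypothetical settings under which the two clamps give different `high`
dim = 64
scaled_a, ramp_a = yarn_scaled_freqs(dim, 1e4, 65536, 16, 32.0, 1.0, high_cap=dim // 2 - 1)
scaled_b, ramp_b = yarn_scaled_freqs(dim, 1e4, 65536, 16, 32.0, 1.0, high_cap=dim - 1)
print(ramp_a[-1], ramp_b[-1])
```

With these hypothetical inputs `low` is 20 in both runs. Under the `dim // 2 - 1` cap the ramp reaches 1 at the last pair; under the `dim - 1` cap `high` is 33, beyond the last pair index, so the ramp stops short of 1 and the highest-index pair is only partially divided by `factor`.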

The DeepSeek version:

```python
import math

import torch


# ModelArgs is the model config dataclass from the DeepSeek source (not shown here).
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    """
    Precomputes frequency-based complex exponential values for rotary positional embeddings.

    Args:
        args (ModelArgs): Model arguments containing positional embedding parameters.

    Returns:
        torch.Tensor: Precomputed complex exponential values for positional embeddings.
    """
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        """
        Computes the correction dimension for a given number of rotations in the rotary positional embedding.

        Args:
            num_rotations (float): Number of rotations to compute the correction for.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            float: The correction dimension based on the input parameters.
        """
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        """
        Computes the range of correction dimensions for rotary positional embeddings.

        Args:
            low_rot (float): Lower bound for the number of rotations.
            high_rot (float): Upper bound for the number of rotations.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
        """
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim-1)

    def linear_ramp_factor(min, max, dim):
        """
        Computes a linear ramp function used to smooth values between a minimum and maximum range.

        Args:
            min (float): Minimum value for the ramp function.
            max (float): Maximum value for the ramp function.
            dim (int): Dimensionality of the ramp tensor.

        Returns:
            torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
                clamped to the range [0, 1].
        """
        if min == max:
            max += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if seqlen > args.original_seq_len:
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis
```
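Apart from the `high` clamp, the two versions blend frequencies the same way. DeepSeek sets `smooth = 1 - ramp` and computes `freqs / factor * (1 - smooth) + freqs * smooth`, which expands algebraically to `freqs * (1 - ramp + ramp / factor)`, the expression used in this project. A quick numeric check of that identity, with hypothetical helper names:

```python
import math


def blend_minimind(f: float, ramp: float, factor: float) -> float:
    # This project: freqs * (1 - ramp + ramp / factor)
    return f * (1 - ramp + ramp / factor)


def blend_deepseek(f: float, ramp: float, factor: float) -> float:
    # DeepSeek: smooth = 1 - linear_ramp_factor(...), then interpolate
    smooth = 1 - ramp
    return f / factor * (1 - smooth) + f * smooth


for f in (1.0, 0.1, 1e-4):
    for ramp in (0.0, 0.25, 0.5, 1.0):
        assert math.isclose(blend_minimind(f, ramp, 16.0), blend_deepseek(f, ramp, 16.0))
print("blends agree")
```

So for the same `low` and `high` (with `high > low`), both should yield the same scaled frequencies. The remaining differences are in the output form: cos/sin tensors multiplied by `attn_factor` here, a complex tensor built with `torch.polar` in DeepSeek.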

I'd really appreciate it if you could reply when you have time and help clear up my confusion!
