Description
感谢大佬的开源!阅读源码时有个疑惑:对比了两个版本的YaRN代码实现,为什么high最后的取值不一样?一个是`min(high, dim // 2 - 1)`,一个是`min(high, dim - 1)`。我用一些数据跑了一下两份代码,最后high的取值确实会不一样,这两种情况都可以吗?
本项目代码如下:
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6, rope_scaling: Optional[dict] = None):
    """Build RoPE cosine/sine tables, optionally rescaling frequencies with YaRN.

    Args:
        dim: Rotary dimension; ``dim // 2`` distinct frequencies are generated.
        end: Number of positions (rows) to precompute.
        rope_base: RoPE base (theta).
        rope_scaling: Optional YaRN settings. Keys read, with the default used
            when a key is absent: ``original_max_position_embeddings`` (2048),
            ``factor`` (16), ``beta_fast`` (32.0), ``beta_slow`` (1.0),
            ``attention_factor`` (1.0).

    Returns:
        ``(freqs_cos, freqs_sin)``, each of shape ``(end, 2 * (dim // 2))``: the
        per-frequency cos/sin values repeated twice along the last axis and
        multiplied by the attention factor.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    freqs = 1.0 / (rope_base ** exponents)
    attn_factor = 1.0

    if rope_scaling is not None:
        orig_max = rope_scaling.get("original_max_position_embeddings", 2048)
        factor = rope_scaling.get("factor", 16)
        beta_fast = rope_scaling.get("beta_fast", 32.0)
        beta_slow = rope_scaling.get("beta_slow", 1.0)
        # The attention factor is taken from the config (and applied below) even
        # when `end` does not exceed `orig_max` and no frequency is rescaled.
        attn_factor = rope_scaling.get("attention_factor", 1.0)

        if end / orig_max > 1.0:
            # YaRN: f'(i) = f(i) * ((1 - gamma) + gamma / factor), gamma a linear ramp in [0, 1].
            def correction_index(num_rotations):
                # Frequency index i that completes `num_rotations` full turns over
                # `orig_max` positions, i.e. the solution of
                # orig_max * rope_base ** (-2i / dim) / (2 * pi) == num_rotations.
                return (dim * math.log(orig_max / (num_rotations * 2 * math.pi))) / (2 * math.log(rope_base))

            n_freqs = dim // 2
            low = max(math.floor(correction_index(beta_fast)), 0)
            # `high` is clamped to the last valid index of the n_freqs ramp slots.
            # A looser `dim - 1` clamp would keep a larger value whenever
            # ceil(correction_index(beta_slow)) > n_freqs - 1, and the ramp would then
            # stay below 1 at the last slot instead of reaching full interpolation.
            high = min(math.ceil(correction_index(beta_slow)), n_freqs - 1)

            span = max(high - low, 0.001)  # guards against a zero or negative span
            slot = torch.arange(n_freqs, device=freqs.device).float()
            ramp = torch.clamp((slot - low) / span, 0, 1)
            # ramp == 0 keeps a frequency unchanged; ramp == 1 divides it by `factor`.
            freqs = freqs * (1 - ramp + ramp / factor)

    positions = torch.arange(end, device=freqs.device)
    angles = torch.outer(positions, freqs).float()
    cos_half = torch.cos(angles)
    sin_half = torch.sin(angles)
    freqs_cos = torch.cat([cos_half, cos_half], dim=-1) * attn_factor
    freqs_sin = torch.cat([sin_half, sin_half], dim=-1) * attn_factor
    return freqs_cos, freqs_sin
# DeepSeek version of the same function, quoted for comparison:
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    """
    Precompute the complex rotary-embedding table, applying YaRN frequency
    correction when the target sequence length exceeds the original one.

    Args:
        args (ModelArgs): Reads ``qk_rope_head_dim``, ``max_seq_len``,
            ``beta_fast``, ``beta_slow``, ``rope_theta``, ``rope_factor`` and
            ``original_seq_len``.

    Returns:
        torch.Tensor: Complex tensor of unit-magnitude values
        exp(i * position * frequency), shaped
        (max_seq_len, qk_rope_head_dim // 2) for an even head dim.
    """
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def correction_dim(num_rotations, dim, base, max_seq_len):
        """Return the fractional frequency index whose wavelength completes
        ``num_rotations`` full turns over ``max_seq_len`` positions."""
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def correction_range(low_rot, high_rot, dim, base, max_seq_len):
        """Return integer ramp bounds ``(low, high)``: the floor/ceil of the two
        correction dims, with ``low`` clamped at 0 and ``high`` at ``dim - 1``."""
        lo = math.floor(correction_dim(low_rot, dim, base, max_seq_len))
        hi = math.ceil(correction_dim(high_rot, dim, base, max_seq_len))
        # NOTE: the caller builds a ramp with only dim // 2 entries, so `hi` can
        # exceed the last ramp index (dim // 2 - 1). When it does, the ramp stays
        # below 1 at the last entry; clamping to dim // 2 - 1 instead would give a
        # steeper ramp and different frequencies.
        return max(lo, 0), min(hi, dim - 1)

    def ramp_between(lo, hi, size):
        """Return a length-``size`` float32 ramp rising linearly from 0 at ``lo``
        to 1 at ``hi``, clamped to [0, 1]."""
        if lo == hi:
            hi += 0.001  # avoid division by zero
        linear = (torch.arange(size, dtype=torch.float32) - lo) / (hi - lo)
        return torch.clamp(linear, 0, 1)

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if seqlen > args.original_seq_len:
        lo, hi = correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
        # keep == 1 leaves a frequency unchanged; keep == 0 divides it by `factor`.
        keep = 1 - ramp_between(lo, hi, dim // 2)
        freqs = freqs / factor * (1 - keep) + freqs * keep

    positions = torch.arange(seqlen)
    angles = torch.outer(positions, freqs)
    return torch.polar(torch.ones_like(angles), angles)
希望大佬有空回下,解答下我的困惑!