Unexpected behavior with sampling of repeated character sequence.
#904 opened on Aug 14, 2023
Description
I added some sequences of repeated characters as user-defined symbols to a Unigram model. When I tokenize with sampling enabled, I get unexpected behavior as I increase nbest_size. I believe this is a bug. Can you please confirm and let me know of a workaround?
Below, my string is '+' repeated 16 times. With nbest_size=2, I get the two most plausible tokenizations: a single piece of 16 '+' characters (which is a token), and two pieces of 8 '+' characters each, the first with the leading meta space and the second without, as shown below.
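import pprint
from collections import Counter

# our_tokenizer is assumed to be an already loaded SentencePieceProcessor (Unigram model).
LONG_STR = '++++++++++++++++'

def print_nbest(mystr, n):
    # Sample 100 tokenizations and count how often each one occurs.
    possible_tokenization = []
    for i in range(100):
        tokenization = our_tokenizer.encode_as_pieces(mystr, enable_sampling=True, nbest_size=n)
        possible_tokenization.append(" ".join(tokenization))
    pprint.pprint(Counter(possible_tokenization))

print_nbest(LONG_STR, 2) # Result: Counter({'▁++++++++ ++++++++': 54, '▁++++++++++++++++': 46})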
Now, when I call the same function with a larger nbest_size, these "top two" tokenizations end up much further down the list. This does not make sense to me: if every user-defined symbol is added with the same high probability, then tokenizations with more pieces should be less likely, because the piece probabilities multiply.
print_nbest(LONG_STR, 50)
# Result: Counter({'▁++++++++ ++++ + ++ +': 7,
# '▁++++++++ ++ ++++ ++': 5,
# '▁++++ ++ + + ++++++++': 5,
# '▁++++ ++++++++ + ++ +': 4,
# '▁++ + ++++++++ + ++++': 4,
# '▁++++++++ ++ + ++++ +': 4,
# '▁+ ++ ++++ + ++++++++': 4,
# '▁++++++++ ++ ++ ++ ++': 4,
# '▁++++++++ ++++++++': 4,
# ... }
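To make the expectation above concrete, here is a minimal sketch of the reasoning, not of SentencePiece's actual scoring. It assumes every user-defined piece has the same log-probability (the value `PIECE_LOGPROB` and the helper `segmentation_logprob` are hypothetical, chosen only for illustration) and that a tokenization's score is the sum of its piece log-probabilities.

```python
import math

# Hypothetical log-probability shared by every user-defined piece.
# This value is an assumption for illustration, not read from the model.
PIECE_LOGPROB = math.log(0.1)

def segmentation_logprob(pieces, piece_logprob=PIECE_LOGPROB):
    """Score a tokenization as the sum of identical per-piece log-probabilities."""
    return piece_logprob * len(pieces)

# Three tokenizations taken from the outputs in this report.
candidates = [
    "▁++++++++++++++++",
    "▁++++++++ ++++++++",
    "▁++++++++ ++++ + ++ +",
]
for text in candidates:
    pieces = text.split(" ")
    # More pieces means a lower (more negative) total log-probability.
    print(len(pieces), round(segmentation_logprob(pieces), 3), text)
```

Under these assumptions the one-piece and two-piece tokenizations always score higher than any five-piece one, which is why I expected them to stay at the top of the sampled counts.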
Finally, if I request an nbest_size of 1000, sampling fails:
print_nbest(LONG_STR, 1000) # Result: Counter({'': 100})
Furthermore, varying alpha between 0.01 and 0.99 does not change the peakiness of the distribution in any predictable way. Below are the results for alpha=0.01 and alpha=0.99, with four runs of 100 samples each. The only difference between the two sets of function calls is alpha.
def print_nbest(mystr, n, alpha):
    # Same as before, but with the smoothing parameter alpha passed through.
    possible_tokenization = []
    for i in range(100):
        tokenization = our_tokenizer.encode_as_pieces(mystr, enable_sampling=True, alpha=alpha, nbest_size=n)
        possible_tokenization.append(" ".join(tokenization))
    pprint.pprint(Counter(possible_tokenization))
print_nbest(LONG_STR, 2, 0.01)
print_nbest(LONG_STR, 2, 0.01)
print_nbest(LONG_STR, 2, 0.01)
print_nbest(LONG_STR, 2, 0.01)
# Result:
# Counter({'▁++++++++ ++++++++': 54, '▁++++++++++++++++': 46})
# Counter({'▁++++++++ ++++++++': 58, '▁++++++++++++++++': 42})
# Counter({'▁++++++++ ++++++++': 51, '▁++++++++++++++++': 49})
# Counter({'▁++++++++ ++++++++': 50, '▁++++++++++++++++': 50})
print_nbest(LONG_STR, 2, 0.99)
print_nbest(LONG_STR, 2, 0.99)
print_nbest(LONG_STR, 2, 0.99)
print_nbest(LONG_STR, 2, 0.99)
# Result
# Counter({'▁++++++++++++++++': 56, '▁++++++++ ++++++++': 44})
# Counter({'▁++++++++++++++++': 54, '▁++++++++ ++++++++': 46})
# Counter({'▁++++++++++++++++': 50, '▁++++++++ ++++++++': 50})
# Counter({'▁++++++++ ++++++++': 58, '▁++++++++++++++++': 42})
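For reference, this is a sketch of the behavior I expected from alpha. It assumes that n-best sampling draws candidate i with probability proportional to exp(alpha * score_i), that is, P(x_i)^alpha normalized over the n-best list, as in the subword regularization formulation. The function `nbest_sampling_probs` and the log-scores below are hypothetical and are not read from my model.

```python
import math

def nbest_sampling_probs(log_scores, alpha):
    """Normalize exp(alpha * score) over the n-best candidates (a softmax)."""
    weights = [alpha * s for s in log_scores]
    m = max(weights)  # subtract the max for numerical stability
    exps = [math.exp(w - m) for w in weights]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical log-scores for two candidates (assumptions for illustration).
wide_gap = [-1.0, -6.0]    # candidates with clearly different scores
narrow_gap = [-1.0, -1.1]  # candidates with nearly identical scores

for alpha in (0.01, 0.99):
    print("alpha =", alpha)
    print("  wide gap:  ", nbest_sampling_probs(wide_gap, alpha))
    print("  narrow gap:", nbest_sampling_probs(narrow_gap, alpha))
```

Under this assumption, alpha sharpens the distribution noticeably only when the candidate scores differ by a wide margin. If the two candidates score almost the same, the split stays close to 50/50 for any alpha in this range, which would resemble the counts above. I have not verified that this is what happens internally.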