我正在研究算术编码和解码算法的自适应实现。我已经用 Python 实现了它,对于某些字符串能得到正确的结果,但对于另一些字符串却得不到正确的结果。
程序启动时会接收一个参数,用来确定符号概率更新的频率。例如,如果参数是 10,则每发送/接收 10 个符号之后,就根据到目前为止发送/接收的所有符号重新计算概率表,各符号对应的区间分配也随之改变。最初使用 [a-z] 上的均匀分布,每个字母的概率为 1/26。
它对 "heloworldheloworld" 以及许多其他输入都无法得到正确结果。
另外,我已经了解到存在下溢问题,但应该如何解决这个问题呢?
import sys
import random
import string
def encode(encode_str, N):
    """Arithmetic-encode a lowercase string with a periodically refreshed model.

    The model starts as a uniform distribution over ``a``-``z``.  Every
    letter carries a pseudo-count of 1, and after every ``N`` symbols the
    cumulative table is rebuilt from the counts seen so far.  A value of
    ``N`` that the running counter never reaches (for example 0) leaves the
    table uniform for the whole message.

    All arithmetic uses Python floats, so the interval eventually becomes
    too narrow to represent; long messages cannot be recovered reliably.

    Args:
        encode_str: Text to encode; every character must be in
            ``string.ascii_lowercase``.
        N: Number of symbols between two refreshes of the probability table.

    Returns:
        The lower end of the final interval (``0`` for an empty string).

    Raises:
        KeyError: If ``encode_str`` contains a character outside ``a``-``z``.
    """
    alphabet = string.ascii_lowercase  # already in sorted order
    count = dict.fromkeys(alphabet, 1)  # per-letter occurrence counts

    # Uniform start: 26 adjacent slots of width 1/26.  The slots are built by
    # repeated addition so the table matches the one decode() builds.
    step = float(1) / float(26)
    cdf_range = {}
    low = 0
    high = step
    for key in alphabet:
        cdf_range[key] = [low, high]
        low = high
        high += step

    total = 26  # sum of all counts (26 pseudo-counts + symbols read)
    lower_bound = 0
    upper_bound = 1
    since_update = 0
    for sym in encode_str:
        total += 1
        since_update += 1
        count[sym] += 1

        # Narrow the interval to the slot of the current symbol.  The upper
        # bound is computed first because it still needs the old lower bound.
        curr_range = upper_bound - lower_bound
        upper_bound = lower_bound + (curr_range * cdf_range[sym][1])
        lower_bound = lower_bound + (curr_range * cdf_range[sym][0])

        # Rebuild the cumulative table once N symbols have been read.
        if since_update == N:
            since_update = 0
            low = 0
            for key in alphabet:
                high = float(count[key]) / float(total) + low
                cdf_range[key] = [low, high]
                low = high
    return lower_bound
def decode(encoded, strlen, every):
    """Decode a number produced by the matching adaptive arithmetic encoder.

    The decoder mirrors the encoder's model: a uniform table over ``a``-``z``
    that is rebuilt from the running counts after every ``every`` decoded
    symbols.  The decoded text is printed and also returned.

    Args:
        encoded: Number inside the final interval chosen by the encoder.
        strlen: Number of symbols to recover.
        every: Number of symbols between two refreshes of the table.

    Returns:
        The decoded string.

    Raises:
        ValueError: If ``strlen`` is negative, or if no symbol interval
            contains ``encoded`` (value outside ``[0, 1)`` or float precision
            exhausted).  The loop would otherwise never terminate.
    """
    if strlen < 0:
        raise ValueError("strlen must be non-negative")

    alphabet = string.ascii_lowercase  # already in sorted order
    count = dict.fromkeys(alphabet, 1)  # per-letter occurrence counts

    # Same uniform start table as the encoder, built by repeated addition.
    step = float(1) / float(26)
    cdf_range = {}
    low = 0
    high = step
    for key in alphabet:
        cdf_range[key] = [low, high]
        low = high
        high += step

    lower_bound = 0
    upper_bound = 1
    since_update = 0
    symbols = []
    while len(symbols) != strlen:
        curr_range = upper_bound - lower_bound
        # Adjacent slots share their end points, so at most one can match.
        for key in alphabet:
            upper_cand = lower_bound + (curr_range * cdf_range[key][1])
            lower_cand = lower_bound + (curr_range * cdf_range[key][0])
            if lower_cand <= encoded < upper_cand:
                break
        else:
            # Nothing matched and the state cannot change any more.
            raise ValueError(
                "cannot decode symbol %d: value is outside every symbol "
                "interval" % (len(symbols) + 1)
            )

        symbols.append(key)
        since_update += 1
        if len(symbols) == strlen:
            break

        upper_bound = upper_cand
        lower_bound = lower_cand
        count[key] += 1

        # Rebuild the cumulative table once `every` symbols were decoded.
        if since_update == every:
            since_update = 0
            total = 26 + len(symbols)
            low = 0
            for letter in alphabet:
                high = float(count[letter]) / float(total) + low
                cdf_range[letter] = [low, high]
                low = high

    decoded_str = "".join(symbols)
    print(decoded_str)
    return decoded_str
def main():
    """Round-trip a sample string through encode() and decode()."""
    encode_str = "yyyyuuuuyyyy"
    every = 3  # refresh the probability table after every 3 symbols
    encoded = encode(encode_str, every)
    # decode() prints the recovered text itself.
    decode(encoded, len(encode_str), every)


if __name__ == '__main__':
    main()
答案 0 :(得分:1)
错误出现在长度约为 12 个字符的字符串上。此时所需的精度已接近 Python 浮点数(双精度)的精度极限,这很可能就是问题的原因。
我使用BigFloat库(具有任意精度)进行了快速测试,得到了正确答案:
import sys
import random
import string
from bigfloat import *
# One twenty-sixth as a BigFloat: the initial width of each letter's slot.
# NOTE(review): this is evaluated at import time, outside the
# precision(200) context used by encode()/decode() -- presumably at
# bigfloat's default precision; confirm that this is intended.
factor = BigFloat(1)/BigFloat(26)
def encode(encode_str, N):
    """Arithmetic-encode a lowercase string using bigfloat arithmetic.

    The model starts as a uniform distribution over ``a``-``z``.  Every
    letter carries a pseudo-count of 1, and after every ``N`` symbols the
    cumulative table is rebuilt from the counts seen so far.  All interval
    arithmetic runs inside the ``precision(200) + RoundTowardZero`` context.

    Args:
        encode_str: Text to encode; every character must be in
            ``string.ascii_lowercase``.
        N: Number of symbols between two refreshes of the probability table.

    Returns:
        The lower end of the final interval (``0`` for an empty string).

    Raises:
        KeyError: If ``encode_str`` contains a character outside ``a``-``z``.
    """
    alphabet = string.ascii_lowercase  # already in sorted order
    count = dict.fromkeys(alphabet, 1)  # per-letter occurrence counts
    cdf_range = {}

    with precision(200) + RoundTowardZero:
        # Uniform start: 26 adjacent slots of width `factor`, built by
        # repeated addition so the table matches the one decode() builds.
        low = 0
        high = factor
        for key in alphabet:
            cdf_range[key] = [low, high]
            low = high
            high += factor

        total = 26  # sum of all counts (26 pseudo-counts + symbols read)
        lower_bound = 0
        upper_bound = 1
        since_update = 0
        for sym in encode_str:
            total += 1
            since_update += 1
            count[sym] += 1

            # Narrow the interval to the slot of the current symbol.  The
            # upper bound comes first: it still needs the old lower bound.
            curr_range = upper_bound - lower_bound
            upper_bound = lower_bound + (curr_range * cdf_range[sym][1])
            lower_bound = lower_bound + (curr_range * cdf_range[sym][0])

            # Rebuild the cumulative table once N symbols have been read.
            if since_update == N:
                since_update = 0
                low = 0
                for key in alphabet:
                    high = BigFloat(count[key]) / BigFloat(total) + low
                    cdf_range[key] = [low, high]
                    low = high
        return lower_bound
def decode(encoded, strlen, every):
    """Decode a number produced by the matching bigfloat encoder.

    The decoder mirrors the encoder's model: a uniform table over ``a``-``z``
    that is rebuilt from the running counts after every ``every`` decoded
    symbols.  All interval arithmetic runs inside the
    ``precision(200) + RoundTowardZero`` context.  The decoded text is
    printed and also returned.

    Args:
        encoded: Number inside the final interval chosen by the encoder.
        strlen: Number of symbols to recover.
        every: Number of symbols between two refreshes of the table.

    Returns:
        The decoded string.

    Raises:
        ValueError: If ``strlen`` is negative, or if no symbol interval
            contains ``encoded``.  The loop would otherwise never terminate.
    """
    if strlen < 0:
        raise ValueError("strlen must be non-negative")

    alphabet = string.ascii_lowercase  # already in sorted order
    count = dict.fromkeys(alphabet, 1)  # per-letter occurrence counts
    cdf_range = {}
    symbols = []

    with precision(200) + RoundTowardZero:
        # Same uniform start table as the encoder.
        low = 0
        high = factor
        for key in alphabet:
            cdf_range[key] = [low, high]
            low = high
            high += factor

        lower_bound = BigFloat(0)
        upper_bound = BigFloat(1)
        since_update = 0
        while len(symbols) != strlen:
            curr_range = upper_bound - lower_bound
            # Adjacent slots share their end points, so at most one matches.
            for key in alphabet:
                upper_cand = lower_bound + (curr_range * cdf_range[key][1])
                lower_cand = lower_bound + (curr_range * cdf_range[key][0])
                if lower_cand <= encoded < upper_cand:
                    break
            else:
                # Nothing matched and the state cannot change any more.
                raise ValueError(
                    "cannot decode symbol %d: value is outside every symbol "
                    "interval" % (len(symbols) + 1)
                )

            symbols.append(key)
            since_update += 1
            if len(symbols) == strlen:
                break

            upper_bound = upper_cand
            lower_bound = lower_cand
            count[key] += 1

            # Rebuild the cumulative table once `every` symbols were decoded.
            if since_update == every:
                since_update = 0
                total = 26 + len(symbols)
                low = 0
                for letter in alphabet:
                    high = BigFloat(count[letter]) / BigFloat(total) + low
                    cdf_range[letter] = [low, high]
                    low = high

    decoded_str = "".join(symbols)
    print(decoded_str)
    return decoded_str
def main():
    """Round-trip a sample string through encode() and decode()."""
    encode_str = "heloworldheloworld"
    every = 3  # refresh the probability table after every 3 symbols
    encoded = encode(encode_str, every)
    # decode() prints the recovered text itself.
    decode(encoded, len(encode_str), every)


if __name__ == '__main__':
    main()
答案 1 :(得分:1)
出现这种情况是因为 Python 的 float 只有 53 位精度,因此无法对很长的字符串进行编码。
您可以使用 decimal 代替 float 来获得任意精度:
import sys
import random
import string
import decimal
from decimal import Decimal
# Work with 100 significant digits in the current decimal context so the
# shrinking coding interval keeps enough digits for longer messages.
decimal.getcontext().prec=100
def encode(encode_str, N):
    """Arithmetic-encode a lowercase string using ``Decimal`` arithmetic.

    The model starts as a uniform distribution over ``a``-``z``.  Every
    letter carries a pseudo-count of 1, and after every ``N`` symbols the
    cumulative table is rebuilt from the counts seen so far.  Results are
    rounded to the precision of the current decimal context.

    Args:
        encode_str: Text to encode; every character must be in
            ``string.ascii_lowercase``.
        N: Number of symbols between two refreshes of the probability table.

    Returns:
        The lower end of the final interval (``0`` for an empty string).

    Raises:
        KeyError: If ``encode_str`` contains a character outside ``a``-``z``.
    """
    alphabet = string.ascii_lowercase  # already in sorted order
    count = dict.fromkeys(alphabet, 1)  # per-letter occurrence counts

    # Uniform start: 26 adjacent slots of width 1/26.  The slots are built by
    # repeated addition so the table matches the one decode() builds.
    step = Decimal(1) / Decimal(26)
    cdf_range = {}
    low = 0
    high = step
    for key in alphabet:
        cdf_range[key] = [low, high]
        low = high
        high += step

    total = 26  # sum of all counts (26 pseudo-counts + symbols read)
    lower_bound = 0
    upper_bound = 1
    since_update = 0
    for sym in encode_str:
        total += 1
        since_update += 1
        count[sym] += 1

        # Narrow the interval to the slot of the current symbol.  The upper
        # bound is computed first because it still needs the old lower bound.
        curr_range = upper_bound - lower_bound
        upper_bound = lower_bound + (curr_range * cdf_range[sym][1])
        lower_bound = lower_bound + (curr_range * cdf_range[sym][0])

        # Rebuild the cumulative table once N symbols have been read.
        if since_update == N:
            since_update = 0
            low = 0
            for key in alphabet:
                high = Decimal(count[key]) / Decimal(total) + low
                cdf_range[key] = [low, high]
                low = high
    return lower_bound
def decode(encoded, strlen, every):
    """Decode a number produced by the matching ``Decimal`` encoder.

    The decoder mirrors the encoder's model: a uniform table over ``a``-``z``
    that is rebuilt from the running counts after every ``every`` decoded
    symbols.  Results are rounded to the precision of the current decimal
    context.  The decoded text is printed and also returned.

    Args:
        encoded: Number inside the final interval chosen by the encoder.
        strlen: Number of symbols to recover.
        every: Number of symbols between two refreshes of the table.

    Returns:
        The decoded string.

    Raises:
        ValueError: If ``strlen`` is negative, or if no symbol interval
            contains ``encoded`` (value outside ``[0, 1)`` or context
            precision exhausted).  The loop would otherwise never terminate.
    """
    if strlen < 0:
        raise ValueError("strlen must be non-negative")

    alphabet = string.ascii_lowercase  # already in sorted order
    count = dict.fromkeys(alphabet, 1)  # per-letter occurrence counts

    # Same uniform start table as the encoder, built by repeated addition.
    step = Decimal(1) / Decimal(26)
    cdf_range = {}
    low = 0
    high = step
    for key in alphabet:
        cdf_range[key] = [low, high]
        low = high
        high += step

    lower_bound = 0
    upper_bound = 1
    since_update = 0
    symbols = []
    while len(symbols) != strlen:
        curr_range = upper_bound - lower_bound
        # Adjacent slots share their end points, so at most one can match.
        for key in alphabet:
            upper_cand = lower_bound + (curr_range * cdf_range[key][1])
            lower_cand = lower_bound + (curr_range * cdf_range[key][0])
            if lower_cand <= encoded < upper_cand:
                break
        else:
            # Nothing matched and the state cannot change any more.
            raise ValueError(
                "cannot decode symbol %d: value is outside every symbol "
                "interval" % (len(symbols) + 1)
            )

        symbols.append(key)
        since_update += 1
        if len(symbols) == strlen:
            break

        upper_bound = upper_cand
        lower_bound = lower_cand
        count[key] += 1

        # Rebuild the cumulative table once `every` symbols were decoded.
        if since_update == every:
            since_update = 0
            total = 26 + len(symbols)
            low = 0
            for letter in alphabet:
                high = Decimal(count[letter]) / Decimal(total) + low
                cdf_range[letter] = [low, high]
                low = high

    decoded_str = "".join(symbols)
    print(decoded_str)
    return decoded_str
def main():
    """Round-trip a sample string through encode() and decode()."""
    encode_str = "heloworldheloworld"
    every = 3  # refresh the probability table after every 3 symbols
    encoded = encode(encode_str, every)
    # decode() prints the recovered text itself.
    decode(encoded, len(encode_str), every)


if __name__ == '__main__':
    main()