这是我的代码:
class ReviewCategoryClassifier(object):
    """Classifier whose tables are built from the CategoryPredictor job's output.

    NOTE(review): this is an excerpt. ``normalize_counts`` (called in
    ``__init__``) and ``classify`` (called from the ``__main__`` block) are
    not shown, so their exact behavior cannot be documented here.
    """

    @classmethod
    def load_data(cls, input_file):
        """Parse each line of *input_file* with the CategoryPredictor job.

        ``__init__`` unpacks the result as ``(category_counts, word_counts)``.

        NOTE(review): as shown, the loop discards every parsed
        ``(category, counts)`` pair and the method has no ``return``, so it
        would return None and the unpacking in ``__init__`` would fail. The
        rest of the method looks omitted from this excerpt (the traceback
        places this parse call at line 44 and ``__init__``'s call at line
        65) -- confirm against the full predict.py.
        """
        # In the code shown, the job object is used only for parse_output_line().
        job = category_predictor.CategoryPredictor()
        category_counts = None
        word_counts = {}
        # BUG (Python 3): in text mode every ``line`` is a str, but the mrjob
        # protocol behind parse_output_line() runs ``line.split(b'\t', 1)``,
        # which needs bytes and raises "TypeError: Can't convert 'bytes'
        # object to str implicitly". Open the file with mode 'rb' so lines
        # are bytes, or pass ``line.encode()`` instead.
        with open(input_file) as src:
            for line in src:
                category, counts = job.parse_output_line(line)

    def __init__(self, input_file):
        """Build ``word_given_cat_prob`` and ``category_prob``.

        input_file: the output of the CategoryPredictor job.
        """
        category_counts, word_counts = self.load_data(input_file)
        # Per-category table: category -> normalize_counts(its word counts).
        # NOTE(review): dict.iteritems() exists only in Python 2; the
        # traceback shows Python 3.5, where this raises AttributeError --
        # use .items() instead.
        self.word_given_cat_prob = {}
        for cat, counts in word_counts.iteritems():
            self.word_given_cat_prob[cat] = self.normalize_counts(counts)
        # filter out categories which have no words
        # (keep only categories that are also keys of word_counts; this uses
        # the same Python 2-only iteritems() call as above)
        seen_categories = set(word_counts)
        seen_category_counts = dict((cat, count) for cat, count in
                                    category_counts.iteritems() \
                                    if cat in seen_categories)
        # Category table: normalize_counts() over the filtered category counts.
        self.category_prob= self.normalize_counts(
            seen_category_counts)
if __name__ == "__main__":
    # Usage: predict.py <CategoryPredictor job output file> <text to classify>
    argv = sys.argv
    input_file = argv[1]
    text = argv[2]
    classifier = ReviewCategoryClassifier(input_file)
    guesses = classifier.classify(text)
另外,CategoryPredictor 是一个用 mrjob 编写的 MapReduce 任务(MRJob 子类)。
每当我在命令行中运行
python predict.py yelp_academic_dataset_review.json '我喜欢甜甜圈'
就会出现以下错误:
TypeError: Can't convert 'bytes' object to str implicitly(无法将 'bytes' 对象隐式转换为 str)
但 line 是一个字符串(str),而不是 bytes 对象。我哪里做错了?
以下是完整的回溯信息(traceback):
Traceback (most recent call last):
  File "predict.py", line 116, in <module>
    guesses = ReviewCategoryClassifier(input_file).classify(text)
  File "predict.py", line 65, in __init__
    category_counts, word_counts = self.load_data(input_file)
  File "predict.py", line 44, in load_data
    category, counts = job.parse_output_line(line)
  File "//anaconda/lib/python3.5/site-packages/mrjob/job.py", line 961, in parse_output_line
    return self.output_protocol().read(line)
  File "//anaconda/lib/python3.5/site-packages/mrjob/protocol.py", line 84, in read
    raw_key, raw_value = line.split(b'\t', 1)
TypeError: Can't convert 'bytes' object to str implicitly
回答(得分:1)
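在 Python 3 中,以文本模式打开的文件在迭代时得到的每一行都是 str。mrjob 的输出协议却会执行 line.split(b'\t', 1),要求传入的是 bytes(见回溯的最后两行),所以才会报这个 TypeError。因此,您需要向 MRJob.parse_output_line 传入 bytes。一种做法是以二进制模式打开 input_file: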
# Binary mode: iterating the file yields bytes, which is what the protocol
# behind parse_output_line() expects (it splits each line on b'\t').
with open(input_file, mode='rb') as output_file:
    for raw_line in output_file:
        category, counts = job.parse_output_line(raw_line)
或者,在把每一行传给该方法之前,先将其编码为 bytes:
# Text mode decodes each line, so it must be encoded back to bytes before
# parse_output_line(), whose protocol splits on b'\t'. Use the same explicit
# codec in both directions: without encoding=, open() would decode with the
# locale's preferred encoding while str.encode() always uses UTF-8, and the
# mismatch could corrupt non-ASCII bytes or fail on non-UTF-8 locales.
with open(input_file, encoding='utf-8') as src:
    for line in src:
        category, counts = job.parse_output_line(line.encode('utf-8'))