博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
简单的贝叶斯分类器的python实现
阅读量:6244 次
发布时间:2019-06-22

本文共 3860 字,大约阅读时间需要 12 分钟。

 

 

# -*- coding: utf-8 -*-
'''
A small multinomial naive Bayes text classifier with add-one smoothing.

>>> c = Classy()
>>> c.train(['cpu', 'RAM', 'ALU', 'io', 'bridge', 'disk'], 'architecture')
True
>>> c.train(['monitor', 'mouse', 'keyboard', 'microphone', 'headphones'], 'input_devices')
True
>>> c.train(['desk', 'chair', 'cabinet', 'lamp'], 'office furniture')
True
>>> my_office = ['cpu', 'monitor', 'mouse', 'chair']
>>> label, score = c.classify(my_office)
>>> label
'input_devices'
>>> round(score, 4)
-11.6952
>>> c.classify(['cpu', 'RAM'])[0]
'architecture'
'''

from collections import Counter
import math


class ClassifierNotTrainedException(Exception):
    '''Raised when the classifier is used before any document was trained.'''

    def __str__(self):
        return "Classifier is not trained."


class Classy(object):
    '''
    Multinomial naive Bayes classifier over documents given as term sequences.

    Attributes:
        term_count_store: maps class_id -> {term: occurrences in that class}.
        data: dict with three per-class tables:
            'class_term_count' (total terms trained into the class),
            'class_doc_count' (documents trained into the class) and
            'beta_priors' (log of the class's share of all documents).
        total_term_count: number of terms seen across all training documents.
        total_doc_count: number of training documents seen.
    '''

    def __init__(self):
        self.term_count_store = {}
        self.data = {
            'class_term_count': {},
            'beta_priors': {},
            'class_doc_count': {},
        }
        self.total_term_count = 0
        self.total_doc_count = 0

    def train(self, document_source, class_id):
        '''
        Add one training document to the class ``class_id``.

        Args:
            document_source: iterable of hashable terms making up the document.
            class_id: hashable label of the class the document belongs to.

        Returns:
            Always True.
        '''
        count = Counter(document_source)
        # Take the length from the counter rather than from the input, so
        # one-shot iterables (generators) are supported as well as lists.
        term_total = sum(count.values())

        class_terms = self.term_count_store.setdefault(class_id, {})
        for term, occurrences in count.items():
            class_terms[term] = class_terms.get(term, 0) + occurrences

        class_term_count = self.data['class_term_count']
        class_term_count[class_id] = class_term_count.get(class_id, 0) + term_total
        class_doc_count = self.data['class_doc_count']
        class_doc_count[class_id] = class_doc_count.get(class_id, 0) + 1

        self.total_term_count += term_total
        self.total_doc_count += 1
        # Priors depend on the share of documents per class, so every class's
        # prior changes whenever a document is added.
        self.compute_beta_priors()
        return True

    def classify(self, document_input):
        '''
        Return the most probable class for a document.

        Each class is scored with its log prior plus, for every distinct
        term of the document, ``frequency * log((count + 1) / (N + V))``
        where ``count`` is the term's count in the class, ``N`` the class's
        total term count and ``V`` the size of the training vocabulary.
        Terms that never appeared in any training document are ignored.

        Args:
            document_input: iterable of hashable terms to classify.

        Returns:
            A ``(class_id, score)`` tuple for the highest-scoring class; on
            equal scores the class reached first while iterating the trained
            classes is kept.

        Raises:
            ClassifierNotTrainedException: if no document has been trained.
        '''
        if not self.total_doc_count:
            raise ClassifierNotTrainedException()

        term_freq = Counter(document_input)

        vocabulary = set()
        for class_terms in self.term_count_store.values():
            vocabulary.update(class_terms)
        vocabulary_size = len(vocabulary)

        # Iterate distinct terms once; the frequency factor below already
        # accounts for repetitions.
        known_terms = [
            (term, frequency)
            for term, frequency in term_freq.items()
            if term in vocabulary
        ]

        best = None
        for class_id in self.data['class_doc_count']:
            class_terms = self.term_count_store[class_id]
            denominator = self.data['class_term_count'][class_id] + vocabulary_size
            score = self.data['beta_priors'][class_id]
            for term, frequency in known_terms:
                # Add-one smoothing: a term unseen in this class still gets a
                # small non-zero probability instead of aborting the scoring.
                likelihood = (class_terms.get(term, 0) + 1) / denominator
                score += frequency * math.log(likelihood)
            if best is None or score > best[1]:
                best = (class_id, score)
        return best

    def compute_beta_priors(self):
        '''
        Recompute the log prior of every trained class.

        The prior of a class is the log of its document count divided by the
        total document count; results are stored in ``data['beta_priors']``.

        Raises:
            ClassifierNotTrainedException: if no document has been trained.
        '''
        if not self.total_doc_count:
            raise ClassifierNotTrainedException()

        for class_id, doc_count in self.data['class_doc_count'].items():
            self.data['beta_priors'][class_id] = math.log(
                doc_count / self.total_doc_count
            )

 

转载地址:http://gtsia.baihongyu.com/

你可能感兴趣的文章
去掉JSON中值为null的
查看>>
我的友情链接
查看>>
职业考试的安排-2
查看>>
40个迹象表明你还是PHP菜鸟
查看>>
把程序员这条路走下去 .
查看>>
[Zephir官方文档翻译之四] 安装Zephir
查看>>
每天学一点Scala之内部类
查看>>
BWidget部件
查看>>
JavaScript强化教程 - 六步实现贪食蛇
查看>>
在oracle中恢复一个表的数据到某个时点
查看>>
我的友情链接
查看>>
maven环境快速搭建
查看>>
我的友情链接
查看>>
半导体产业的根基:晶圆是什么
查看>>
PHP页面刷新
查看>>
数据库之变迁
查看>>
DICOM协议中有关打印的内容
查看>>
lsmod
查看>>
server 2003 IIS无法访问asp页面,但是可以访问html静态页面
查看>>
totem成为万能播放器
查看>>