|
| 1 | +# encoding: utf-8 |
| 2 | +__author__ = 'zhanghe' |
| 3 | + |
| 4 | + |
| 5 | +from pymongo import MongoClient |
| 6 | +from pymongo import errors |
| 7 | +import json |
| 8 | +from datetime import date, datetime |
| 9 | + |
| 10 | +# 数据库日志专用配置 |
| 11 | +from log import Logger |
| 12 | +my_logger = Logger('mongodb', 'mongodb.log', 'DEBUG') |
| 13 | +my_logger.set_file_level('DEBUG') |
| 14 | +my_logger.set_stream_level('WARNING') # WARNING DEBUG |
| 15 | +my_logger.set_stream_handler_fmt('%(message)s') |
| 16 | +my_logger.load() |
| 17 | +logger = my_logger.logger |
| 18 | +# my_logger.get_memory_usage() |
| 19 | + |
| 20 | + |
| 21 | +class Mongodb(object): |
| 22 | + """ |
| 23 | + 自定义mongodb工具 |
| 24 | + """ |
| 25 | + def __init__(self, db_config, db_name=None): |
| 26 | + self.db_config = db_config |
| 27 | + if db_name is not None: |
| 28 | + self.db_config['database'] = db_name |
| 29 | + try: |
| 30 | + # 实例化mongodb |
| 31 | + self.conn = MongoClient(self.db_config['host'], self.db_config['port']) |
| 32 | + # 获取数据库对象(选择/切换) |
| 33 | + self.db = self.conn.get_database(self.db_config['database']) |
| 34 | + except errors.ServerSelectionTimeoutError, e: |
| 35 | + logger.error('连接超时:%s' % e) |
| 36 | + except Exception, e: |
| 37 | + logger.error(e) |
| 38 | + |
| 39 | + @staticmethod |
| 40 | + def __default(obj): |
| 41 | + """ |
| 42 | + 支持datetime的json encode |
| 43 | + TypeError: datetime.datetime(2015, 10, 21, 8, 42, 54) is not JSON serializable |
| 44 | + :param obj: |
| 45 | + :return: |
| 46 | + """ |
| 47 | + if isinstance(obj, datetime): |
| 48 | + return obj.strftime('%Y-%m-%d %H:%M:%S') |
| 49 | + elif isinstance(obj, date): |
| 50 | + return obj.strftime('%Y-%m-%d') |
| 51 | + else: |
| 52 | + raise TypeError('%r is not JSON serializable' % obj) |
| 53 | + |
| 54 | + def close_conn(self): |
| 55 | + """ |
| 56 | + 关闭连接 |
| 57 | + 关闭所有套接字的连接池和停止监控线程。 |
| 58 | + 如果这个实例再次使用它将自动重启和重新启动线程 |
| 59 | + """ |
| 60 | + self.conn.close() |
| 61 | + |
| 62 | + def find_one(self, table_name, condition=None): |
| 63 | + """ |
| 64 | + 查询单条记录 |
| 65 | + :param table_name: |
| 66 | + :param condition: |
| 67 | + :return: |
| 68 | + """ |
| 69 | + return self.db.get_collection(table_name).find_one(condition) |
| 70 | + |
| 71 | + def find_all(self, table_name, condition=None): |
| 72 | + """ |
| 73 | + 查询多条记录 |
| 74 | + :param table_name: |
| 75 | + :param condition: |
| 76 | + :return: |
| 77 | + """ |
| 78 | + return self.db.get_collection(table_name).find(condition) |
| 79 | + |
| 80 | + def count(self, table_name, condition=None): |
| 81 | + """ |
| 82 | + 查询记录总数 |
| 83 | + :param table_name: |
| 84 | + :param condition: |
| 85 | + :return: |
| 86 | + """ |
| 87 | + return self.db.get_collection(table_name).count(condition) |
| 88 | + |
| 89 | + def distinct(self, table_name, field_name): |
| 90 | + """ |
| 91 | + 查询某字段去重后值的范围 |
| 92 | + :param table_name: |
| 93 | + :param field_name: |
| 94 | + :return: |
| 95 | + """ |
| 96 | + return self.db.get_collection(table_name).distinct(field_name) |
| 97 | + |
| 98 | + def insert(self, table_name, data): |
| 99 | + """ |
| 100 | + 插入数据 |
| 101 | + :param table_name: |
| 102 | + :param data: |
| 103 | + :return: |
| 104 | + """ |
| 105 | + try: |
| 106 | + ids = self.db.get_collection(table_name).insert(data) |
| 107 | + return ids |
| 108 | + except Exception, e: |
| 109 | + logger.error('插入错误:%s' % e) |
| 110 | + |
| 111 | + def update(self, table_name, condition, update_data): |
| 112 | + """ |
| 113 | + 批量更新数据 |
| 114 | + upsert : 如果不存在update的记录,是否插入;true为插入,默认是false,不插入。 |
| 115 | + :param table_name: |
| 116 | + :param condition: |
| 117 | + :param update_data: |
| 118 | + :return: |
| 119 | + """ |
| 120 | + return self.db.get_collection(table_name).update_many(condition, update_data) |
| 121 | + |
| 122 | + def remove(self, table_name, condition=None): |
| 123 | + """ |
| 124 | + 删除文档记录 |
| 125 | + :param table_name: |
| 126 | + :param condition: |
| 127 | + :return: |
| 128 | + """ |
| 129 | + result = self.db.get_collection(table_name).remove(condition) |
| 130 | + if result.get('err') is None: |
| 131 | + logger.info('删除成功,删除行数%s' % result.get('n')) |
| 132 | + else: |
| 133 | + logger.error('删除失败:%s' % result.get('err')) |
| 134 | + return result |
| 135 | + |
| 136 | + def output_row(self, table_name, condition=None, style=0): |
| 137 | + """ |
| 138 | + 格式化输出单个记录 |
| 139 | + style=0 键值对齐风格 |
| 140 | + style=1 JSON缩进风格 |
| 141 | + :param table_name: |
| 142 | + :param condition: |
| 143 | + :param style: |
| 144 | + :return: |
| 145 | + """ |
| 146 | + row = self.find_one(table_name, condition) |
| 147 | + if style == 0: |
| 148 | + # 获取KEY最大的长度作为缩进依据 |
| 149 | + max_len_key = max([len(each_key) for each_key in row.keys()]) |
| 150 | + str_format = '{0: >%s}' % max_len_key |
| 151 | + keys = [str_format.format(each_key) for each_key in row.keys()] |
| 152 | + result = dict(zip(keys, row.values())) |
| 153 | + print '********** 表名[%s] **********' % table_name |
| 154 | + for key, item in result.items(): |
| 155 | + print key, ':', item |
| 156 | + else: |
| 157 | + print json.dumps(row, indent=4, ensure_ascii=False, default=self.__default) |
| 158 | + |
| 159 | + def output_rows(self, table_name, condition=None, style=0): |
| 160 | + """ |
| 161 | + 格式化输出批量记录 |
| 162 | + style=0 键值对齐风格 |
| 163 | + style=1 JSON缩进风格 |
| 164 | + :param table_name: |
| 165 | + :param condition: |
| 166 | + :param style: |
| 167 | + :return: |
| 168 | + """ |
| 169 | + rows = self.find_all(table_name, condition) |
| 170 | + total = self.count(table_name, condition) |
| 171 | + if style == 0: |
| 172 | + count = 0 |
| 173 | + for row in rows: |
| 174 | + # 获取KEY最大的长度作为缩进依据 |
| 175 | + max_len_key = max([len(each_key) for each_key in row.keys()]) |
| 176 | + str_format = '{0: >%s}' % max_len_key |
| 177 | + keys = [str_format.format(each_key) for each_key in row.keys()] |
| 178 | + result = dict(zip(keys, row.values())) |
| 179 | + count += 1 |
| 180 | + print '********** 表名[%s] [%d/%d] **********' % (table_name, count, total) |
| 181 | + for key, item in result.items(): |
| 182 | + print key, ':', item |
| 183 | + else: |
| 184 | + for row in rows: |
| 185 | + print json.dumps(row, indent=4, ensure_ascii=False, default=self.__default) |
0 commit comments