Skip to content

Commit 0af7f30

Browse files
committed
新增mongo工具类
1 parent 9dc1637 commit 0af7f30

1 file changed

Lines changed: 185 additions & 0 deletions

File tree

tools/mongo.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
# encoding: utf-8
2+
__author__ = 'zhanghe'
3+
4+
5+
from pymongo import MongoClient
6+
from pymongo import errors
7+
import json
8+
from datetime import date, datetime
9+
10+
# 数据库日志专用配置
11+
from log import Logger
12+
my_logger = Logger('mongodb', 'mongodb.log', 'DEBUG')
13+
my_logger.set_file_level('DEBUG')
14+
my_logger.set_stream_level('WARNING') # WARNING DEBUG
15+
my_logger.set_stream_handler_fmt('%(message)s')
16+
my_logger.load()
17+
logger = my_logger.logger
18+
# my_logger.get_memory_usage()
19+
20+
21+
class Mongodb(object):
22+
"""
23+
自定义mongodb工具
24+
"""
25+
def __init__(self, db_config, db_name=None):
26+
self.db_config = db_config
27+
if db_name is not None:
28+
self.db_config['database'] = db_name
29+
try:
30+
# 实例化mongodb
31+
self.conn = MongoClient(self.db_config['host'], self.db_config['port'])
32+
# 获取数据库对象(选择/切换)
33+
self.db = self.conn.get_database(self.db_config['database'])
34+
except errors.ServerSelectionTimeoutError, e:
35+
logger.error('连接超时:%s' % e)
36+
except Exception, e:
37+
logger.error(e)
38+
39+
@staticmethod
40+
def __default(obj):
41+
"""
42+
支持datetime的json encode
43+
TypeError: datetime.datetime(2015, 10, 21, 8, 42, 54) is not JSON serializable
44+
:param obj:
45+
:return:
46+
"""
47+
if isinstance(obj, datetime):
48+
return obj.strftime('%Y-%m-%d %H:%M:%S')
49+
elif isinstance(obj, date):
50+
return obj.strftime('%Y-%m-%d')
51+
else:
52+
raise TypeError('%r is not JSON serializable' % obj)
53+
54+
def close_conn(self):
55+
"""
56+
关闭连接
57+
关闭所有套接字的连接池和停止监控线程。
58+
如果这个实例再次使用它将自动重启和重新启动线程
59+
"""
60+
self.conn.close()
61+
62+
def find_one(self, table_name, condition=None):
63+
"""
64+
查询单条记录
65+
:param table_name:
66+
:param condition:
67+
:return:
68+
"""
69+
return self.db.get_collection(table_name).find_one(condition)
70+
71+
def find_all(self, table_name, condition=None):
72+
"""
73+
查询多条记录
74+
:param table_name:
75+
:param condition:
76+
:return:
77+
"""
78+
return self.db.get_collection(table_name).find(condition)
79+
80+
def count(self, table_name, condition=None):
81+
"""
82+
查询记录总数
83+
:param table_name:
84+
:param condition:
85+
:return:
86+
"""
87+
return self.db.get_collection(table_name).count(condition)
88+
89+
def distinct(self, table_name, field_name):
90+
"""
91+
查询某字段去重后值的范围
92+
:param table_name:
93+
:param field_name:
94+
:return:
95+
"""
96+
return self.db.get_collection(table_name).distinct(field_name)
97+
98+
def insert(self, table_name, data):
99+
"""
100+
插入数据
101+
:param table_name:
102+
:param data:
103+
:return:
104+
"""
105+
try:
106+
ids = self.db.get_collection(table_name).insert(data)
107+
return ids
108+
except Exception, e:
109+
logger.error('插入错误:%s' % e)
110+
111+
def update(self, table_name, condition, update_data):
112+
"""
113+
批量更新数据
114+
upsert : 如果不存在update的记录,是否插入;true为插入,默认是false,不插入。
115+
:param table_name:
116+
:param condition:
117+
:param update_data:
118+
:return:
119+
"""
120+
return self.db.get_collection(table_name).update_many(condition, update_data)
121+
122+
def remove(self, table_name, condition=None):
123+
"""
124+
删除文档记录
125+
:param table_name:
126+
:param condition:
127+
:return:
128+
"""
129+
result = self.db.get_collection(table_name).remove(condition)
130+
if result.get('err') is None:
131+
logger.info('删除成功,删除行数%s' % result.get('n'))
132+
else:
133+
logger.error('删除失败:%s' % result.get('err'))
134+
return result
135+
136+
def output_row(self, table_name, condition=None, style=0):
137+
"""
138+
格式化输出单个记录
139+
style=0 键值对齐风格
140+
style=1 JSON缩进风格
141+
:param table_name:
142+
:param condition:
143+
:param style:
144+
:return:
145+
"""
146+
row = self.find_one(table_name, condition)
147+
if style == 0:
148+
# 获取KEY最大的长度作为缩进依据
149+
max_len_key = max([len(each_key) for each_key in row.keys()])
150+
str_format = '{0: >%s}' % max_len_key
151+
keys = [str_format.format(each_key) for each_key in row.keys()]
152+
result = dict(zip(keys, row.values()))
153+
print '********** 表名[%s] **********' % table_name
154+
for key, item in result.items():
155+
print key, ':', item
156+
else:
157+
print json.dumps(row, indent=4, ensure_ascii=False, default=self.__default)
158+
159+
def output_rows(self, table_name, condition=None, style=0):
160+
"""
161+
格式化输出批量记录
162+
style=0 键值对齐风格
163+
style=1 JSON缩进风格
164+
:param table_name:
165+
:param condition:
166+
:param style:
167+
:return:
168+
"""
169+
rows = self.find_all(table_name, condition)
170+
total = self.count(table_name, condition)
171+
if style == 0:
172+
count = 0
173+
for row in rows:
174+
# 获取KEY最大的长度作为缩进依据
175+
max_len_key = max([len(each_key) for each_key in row.keys()])
176+
str_format = '{0: >%s}' % max_len_key
177+
keys = [str_format.format(each_key) for each_key in row.keys()]
178+
result = dict(zip(keys, row.values()))
179+
count += 1
180+
print '********** 表名[%s] [%d/%d] **********' % (table_name, count, total)
181+
for key, item in result.items():
182+
print key, ':', item
183+
else:
184+
for row in rows:
185+
print json.dumps(row, indent=4, ensure_ascii=False, default=self.__default)

0 commit comments

Comments
 (0)