"""PostgreSQL helpers backed by a small, fixed-size gevent connection pool."""

import contextlib
import sys

import gevent.queue
from gevent.socket import wait_read, wait_write
import psycopg2
import psycopg2.extensions
import psycopg2.extras

import config

# Number of connections opened at import time and kept in ``pool``.
POOL_SIZE = 4


def gevent_wait_callback(conn, timeout=None):
    """Drive an asynchronous psycopg2 operation on *conn* to completion.

    Polls *conn* and waits for socket readiness with gevent's
    ``wait_read``/``wait_write`` (passing *timeout* through) until psycopg2
    reports ``POLL_OK``. Installed below via
    ``psycopg2.extensions.set_wait_callback``.

    Raises:
        psycopg2.OperationalError: if ``conn.poll()`` returns a state not
            handled here.
    """
    # https://github.com/zacharyvoase/gevent-psycopg2/blob/master/lib/gevent_psycopg2.py
    while True:
        state = conn.poll()
        if state == psycopg2.extensions.POLL_OK:
            break
        elif state == psycopg2.extensions.POLL_READ:
            wait_read(conn.fileno(), timeout=timeout)
        elif state == psycopg2.extensions.POLL_WRITE:
            wait_write(conn.fileno(), timeout=timeout)
        else:
            raise psycopg2.OperationalError('unhandled state: %r' % state)


psycopg2.extensions.set_wait_callback(gevent_wait_callback)


def _connect():
    """Open a new connection to ``config.database`` as ``config.db_user``."""
    return psycopg2.connect('dbname=%s user=%s' % (config.database, config.db_user))


# Pool of open connections; ``cursor()`` checks one out for each use.
pool = gevent.queue.Queue(maxsize=POOL_SIZE)
for _ in range(POOL_SIZE):
    pool.put(_connect())


def _checkout():
    """Take a connection from ``pool``, reopening it first if it is closed.

    If reopening fails, the closed connection is put back before the error
    propagates, so the pool never loses a slot and a later checkout retries.

    Raises:
        gevent.queue.Empty: if no connection is available within one second.
    """
    conn = pool.get(timeout=1)
    if conn.closed:
        try:
            conn = _connect()
        except BaseException:
            # ``conn`` is still the closed connection here; keep its slot.
            pool.put_nowait(conn)
            raise
    return conn


def _rollback(conn):
    """Roll back *conn*, reporting any failure to the gevent hub instead of raising.

    This is a separate function on purpose: on Python 2 an exception handled
    inline would replace the caller's in-flight exception for a bare ``raise``,
    while one handled in a called function does not.
    """
    try:
        conn.rollback()
    except BaseException:
        gevent.get_hub().handle_error(conn, *sys.exc_info())


@contextlib.contextmanager
def cursor():
    """Yield a ``DictCursor`` on a pooled connection, wrapped in a transaction.

    Commits when the ``with`` body finishes normally. If the body raises, the
    transaction is rolled back (when the connection is still open) and the
    original exception is re-raised. The connection is returned to the pool in
    every case.

    Raises:
        gevent.queue.Empty: if no pooled connection is available within one
            second.
    """
    # https://code.google.com/p/gevent/source/browse/examples/psycopg2_pool.py?name=1.0b4#88
    conn = _checkout()
    try:
        yield conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
    except BaseException:
        if not conn.closed:
            _rollback(conn)
        raise
    else:
        conn.commit()
    finally:
        # Return the connection even if it was closed. Dropping it would
        # permanently shrink the fixed-size pool; ``_checkout`` reopens it on
        # next use.
        pool.put_nowait(conn)


def query(cur, sql, *args):
    """Execute *sql* with positional parameters *args* and return all rows."""
    cur.execute(sql, args)
    return cur.fetchall()


def query_iter(cur, sql, *args):
    """Execute *sql* with parameters *args* and lazily yield rows one at a time.

    As a generator, nothing is executed until the first row is requested.
    """
    cur.execute(sql, args)
    for row in iter(cur.fetchone, None):
        yield row


def query_one(cur, sql, *args):
    """Execute *sql* and return its single row, or ``None`` if it returned none.

    Raises:
        Exception: if the query returned more than one row.
    """
    cur.execute(sql, args)
    rval = cur.fetchone()
    if cur.fetchone() is not None:
        raise Exception('got more than one value for query', sql, args)
    return rval


def get_api_key(group_id):
    """Return the ``api_key`` column of the ``groups`` row with id *group_id*.

    NOTE(review): an unknown *group_id* surfaces as ``TypeError``, because
    ``query_one`` returns ``None``, which is then indexed.
    """
    with cursor() as cur:
        api_key = query_one(cur, 'SELECT api_key FROM groups WHERE id = %s', group_id)[0]
    return api_key


def create_server(group_id, hostname):
    """Insert a server row for *group_id*/*hostname* and return its new ``id``."""
    with cursor() as cur:
        server_id = query_one(
            cur,
            'INSERT INTO servers (group_id, hostname) VALUES(%s, %s) RETURNING id',
            group_id, hostname)[0]
    return server_id


def get_servers(group_id):
    """Return the ``id``/``hostname`` rows of every server in *group_id*.

    Each row is converted with ``DictRow.copy()``, so the results are mappings
    that are independent of the cursor.
    """
    with cursor() as cur:
        servers = [
            server.copy()
            for server in query(cur, 'SELECT id, hostname FROM servers WHERE group_id = %s',
                                group_id)
        ]
    return servers