db.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
import contextlib
import sys

import gevent.queue
from gevent.socket import wait_read, wait_write
import psycopg2
import psycopg2.extensions
import psycopg2.extras
  7. def gevent_wait_callback(conn, timeout=None):
  8. # https://github.com/zacharyvoase/gevent-psycopg2/blob/master/lib/gevent_psycopg2.py
  9. while True:
  10. state = conn.poll()
  11. if state == psycopg2.extensions.POLL_OK:
  12. break
  13. elif state == psycopg2.extensions.POLL_READ:
  14. wait_read(conn.fileno(), timeout=timeout)
  15. elif state == psycopg2.extensions.POLL_WRITE:
  16. wait_write(conn.fileno(), timeout=timeout)
  17. else:
  18. raise psycopg2.OperationalError('unhandled state: %r' % state)
  19. psycopg2.extensions.set_wait_callback(gevent_wait_callback)
  20. pool = gevent.queue.Queue(maxsize=4)
  21. for _ in xrange(4):
  22. pool.put(psycopg2.connect('dbname=%s user=%s' % ('sysvitals', 'sysvitals')))
  23. @contextlib.contextmanager
  24. def cursor(): # https://code.google.com/p/gevent/source/browse/examples/psycopg2_pool.py?name=1.0b4#88
  25. conn = pool.get(timeout=1)
  26. try:
  27. yield conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
  28. except:
  29. if not conn.closed:
  30. try:
  31. conn.rollback()
  32. except:
  33. gevent.get_hub().handle_error(conn, *sys.exc_info())
  34. raise
  35. else:
  36. conn.commit()
  37. finally:
  38. if conn.closed:
  39. raise Exception('cursor context manager got back closed connection')
  40. pool.put_nowait(conn)
  41. def query_one(cur, sql, *args):
  42. cur.execute(sql, args)
  43. rval = cur.fetchone()
  44. if cur.fetchone() is not None:
  45. raise Exception('got more than one value for query', sql, args)
  46. return rval
  47. def create_server(group_id, hostname):
  48. with cursor() as cur:
  49. server_id = query_one(cur, 'INSERT INTO servers (group_id, hostname) VALUES(%s, %s) RETURNING id',
  50. group_id, hostname)[0]
  51. return server_id