db.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
from collections import defaultdict
import binascii
import hashlib
import hmac
import os

import momoko
import psycopg2
import psycopg2.extras  # provides DictCursor; `import psycopg2` alone does not load this submodule
import tornado.gen

import config
  10. def hash_pw(password, salt=None):
  11. if salt is None:
  12. salt = os.urandom(16)
  13. h = hmac.new(salt, password.encode('utf-8'), hashlib.sha256)
  14. hashed = h.hexdigest()
  15. salt_hex = binascii.hexlify(salt).decode()
  16. return hashed, salt_hex
  17. class MomokoDB:
  18. db = momoko.Pool(dsn='dbname=%s user=%s' % (config.database, config.db_user), size=2)
  19. @tornado.gen.coroutine
  20. def execute(self, query, *args):
  21. result = yield momoko.Op(self.db.execute, query, args, cursor_factory=psycopg2.extras.DictCursor)
  22. return result
  23. @tornado.gen.coroutine
  24. def create_user(self, username, password):
  25. hashed_password, salt = hash_pw(password)
  26. query = 'INSERT INTO users (username, password, salt) VALUES (%s, %s, %s) RETURNING id;'
  27. cursor = yield self.execute(query, username, hashed_password, salt)
  28. return cursor.fetchone()['id']
  29. @tornado.gen.coroutine
  30. def check_user(self, username, password):
  31. query = 'SELECT id, username, password, salt FROM users WHERE username=%s;'
  32. cursor = yield self.execute(query, username)
  33. user = cursor.fetchone()
  34. if not user:
  35. return
  36. salt = binascii.unhexlify(bytes(user['salt'], 'ascii'))
  37. hashed, _ = hash_pw(password, salt)
  38. if hashed == user['password']:
  39. return user
  40. @tornado.gen.coroutine
  41. def create_group(self, user_id, group_name):
  42. cursor = yield self.execute('INSERT INTO groups (name) VALUES(%s) RETURNING id;', group_name)
  43. group_id = cursor.fetchone()['id']
  44. yield self.execute('INSERT INTO user_groups (user_id, group_id) VALUES(%s, %s);', user_id, group_id)
  45. return group_id
  46. @tornado.gen.coroutine
  47. def get_groups(self, user_id):
  48. cursor = yield self.execute('''
  49. SELECT groups.id, groups.name FROM user_groups
  50. JOIN groups ON user_groups.group_id = groups.id
  51. WHERE user_id = %s;
  52. ''', user_id)
  53. return cursor.fetchall()
  54. @tornado.gen.coroutine
  55. def get_servers(self, user_id):
  56. cursor = yield self.execute('''
  57. SELECT servers.id, servers.group_id, servers.hostname FROM user_groups
  58. JOIN servers ON user_groups.group_id = servers.group_id
  59. WHERE user_id = %s;
  60. ''', user_id)
  61. servers = defaultdict(list)
  62. for row in cursor.fetchall():
  63. servers[row['group_id']].append(row)
  64. return servers