from collections import defaultdict
import hashlib
import hmac
import binascii
import os

import tornado.gen
import psycopg2
# psycopg2.extras is a submodule: ``import psycopg2`` alone does not
# guarantee it is loaded, and DictCursor is used in MomokoDB.execute.
import psycopg2.extras
import momoko

import config


def hash_pw(password, salt=None):
    """Hash *password* with HMAC-SHA256 keyed by *salt*.

    :param password: plaintext password (str); encoded as UTF-8 before
        hashing.
    :param salt: raw salt bytes. When None, 16 fresh bytes are drawn from
        ``os.urandom``.
    :returns: ``(hex_digest, hex_salt)`` -- both as str.
    """
    # NOTE(review): a single HMAC-SHA256 round is fast, which makes offline
    # brute force of leaked hashes cheap. Consider hashlib.pbkdf2_hmac or
    # scrypt; switching needs a migration path for hashes already stored.
    if salt is None:
        salt = os.urandom(16)
    h = hmac.new(salt, password.encode('utf-8'), hashlib.sha256)
    hashed = h.hexdigest()
    salt_hex = binascii.hexlify(salt).decode()
    return hashed, salt_hex


class MomokoDB:
    """Tornado-coroutine helpers for the users/groups/servers tables."""

    # A single momoko pool (size 2) is created when this class body is
    # executed, i.e. at import time, and is shared by every instance.
    # config must provide ``database`` and ``db_user``.
    db = momoko.Pool(dsn='dbname=%s user=%s' % (config.database, config.db_user),
                     size=2)

    @tornado.gen.coroutine
    def execute(self, query, *args):
        """Run *query* with the positional *args* as bound parameters.

        Parameters are passed separately from the SQL text, never
        interpolated, and rows come back as psycopg2 ``DictCursor`` rows
        so columns can be read by name.

        :returns: the cursor resulting from the pool's ``execute``.
        """
        result = yield momoko.Op(self.db.execute, query, args,
                                 cursor_factory=psycopg2.extras.DictCursor)
        return result

    @tornado.gen.coroutine
    def create_user(self, username, password):
        """Insert a user with a freshly salted password hash.

        :returns: the ``id`` of the new ``users`` row.
        """
        hashed_password, salt = hash_pw(password)
        query = 'INSERT INTO users (username, password, salt) VALUES (%s, %s, %s) RETURNING id;'
        cursor = yield self.execute(query, username, hashed_password, salt)
        return cursor.fetchone()['id']

    @tornado.gen.coroutine
    def check_user(self, username, password):
        """Verify *password* for *username*.

        :returns: the user row (id, username, password, salt) when the
            password matches; None when no such user exists or the
            password is wrong.
        """
        query = 'SELECT id, username, password, salt FROM users WHERE username=%s;'
        cursor = yield self.execute(query, username)
        user = cursor.fetchone()
        if not user:
            return None
        # The salt is stored hex-encoded (see hash_pw); recover the raw bytes.
        salt = binascii.unhexlify(bytes(user['salt'], 'ascii'))
        hashed, _ = hash_pw(password, salt)
        # Treat a NULL stored hash as a mismatch, as ``==`` did before.
        stored = user['password'] or ''
        # Constant-time comparison so response timing does not reveal how
        # much of the digest matched. Compare bytes because compare_digest
        # rejects str arguments containing non-ASCII characters.
        if hmac.compare_digest(hashed.encode('utf-8'), stored.encode('utf-8')):
            return user
        return None

    @tornado.gen.coroutine
    def create_group(self, user_id, group_name):
        """Create a group named *group_name* and add *user_id* to it.

        :returns: the ``id`` of the new ``groups`` row.
        """
        # NOTE(review): the two INSERTs are issued as separate statements,
        # not inside a transaction; if the second fails, the group row is
        # left without a member. Consider wrapping them in a transaction.
        cursor = yield self.execute('INSERT INTO groups (name) VALUES(%s) RETURNING id;', group_name)
        group_id = cursor.fetchone()['id']
        yield self.execute('INSERT INTO user_groups (user_id, group_id) VALUES(%s, %s);', user_id, group_id)
        return group_id

    @tornado.gen.coroutine
    def get_groups(self, user_id):
        """Return all (id, name) group rows that *user_id* belongs to."""
        cursor = yield self.execute('''
            SELECT groups.id, groups.name FROM user_groups
            JOIN groups ON user_groups.group_id = groups.id
            WHERE user_id = %s;
        ''', user_id)
        return cursor.fetchall()

    @tornado.gen.coroutine
    def get_servers(self, user_id):
        """Return the servers in every group *user_id* belongs to.

        :returns: a ``defaultdict(list)`` mapping each ``group_id`` to the
            list of server rows (id, group_id, hostname) in that group.
        """
        cursor = yield self.execute('''
            SELECT servers.id, servers.group_id, servers.hostname FROM user_groups
            JOIN servers ON user_groups.group_id = servers.group_id
            WHERE user_id = %s;
        ''', user_id)
        servers = defaultdict(list)
        for row in cursor.fetchall():
            servers[row['group_id']].append(row)
        return servers