raylu 14 жил өмнө
parent
commit
0164e2515a
3 өөрчлөгдсөн 52 нэмэгдсэн , 26 устгасан
  1. 16 11
      db.go
  2. 29 14
      main.go
  3. 7 1
      updates.go

+ 16 - 11
db.go

@@ -6,20 +6,25 @@ import (
 	"log"
 )
 
-var db *mysql.Client
+var dbPool chan *mysql.Client
 
 func initDb() {
 	log.SetFlags(log.Ltime | log.Lshortfile)
 
-	var err os.Error
-	db, err = mysql.DialTCP("173.228.31.111", "audio", "audio", "audio")
-	if err != nil {
-		log.Panicln(err)
+	const dbPoolSize = 4
+	dbPool = make(chan *mysql.Client, 4)
+
+	for i := 0; i < dbPoolSize; i++ {
+		db, err := mysql.DialTCP("173.228.31.111", "audio", "audio", "audio")
+		if err != nil {
+			log.Panicln(err)
+		}
+		db.Reconnect = true
+		dbPool <- db
 	}
-	db.Reconnect = true
 }
 
-func prepare(sql string, params ...interface{}) (*mysql.Statement, os.Error) {
+func prepare(db *mysql.Client, sql string, params ...interface{}) (*mysql.Statement, os.Error) {
 	query, err := db.Prepare(sql)
 	if err != nil {
 		log.Println(err)
@@ -38,8 +43,8 @@ func prepare(sql string, params ...interface{}) (*mysql.Statement, os.Error) {
 	return query, err
 }
 
-func queryInt(sql string, params ...interface{}) (int, os.Error) {
-	query, err := prepare(sql, params...)
+func queryInt(db *mysql.Client, sql string, params ...interface{}) (int, os.Error) {
+	query, err := prepare(db, sql, params...)
 	if err != nil {
 		return 0, err
 	}
@@ -61,8 +66,8 @@ func queryInt(sql string, params ...interface{}) (int, os.Error) {
 }
 
 // given an id ('abcd1234'), return the pid (1)
-func getpid(id string) int {
-	pid, err := queryInt("SELECT `pid` FROM `playlist` WHERE `id` = ?", id)
+func getpid(db *mysql.Client, id string) int {
+	pid, err := queryInt(db, "SELECT `pid` FROM `playlist` WHERE `id` = ?", id)
 	if err != nil {
 		return -1
 	}

+ 29 - 14
main.go

@@ -56,7 +56,10 @@ func playlist(w http.ResponseWriter, r *http.Request) {
 
 func add(w http.ResponseWriter, r *http.Request) {
 	q := r.URL.Query()
-	pid := getpid(q.Get("pid"))
+	db := <-dbPool
+	defer func () {dbPool <- db}()
+
+	pid := getpid(db, q.Get("pid"))
 	if pid == -1 {
 		http.Error(w, "invalid pid", http.StatusBadRequest)
 		return
@@ -67,14 +70,15 @@ func add(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, err.String(), http.StatusInternalServerError)
 		return
 	}
-	maxOrder, err := queryInt("SELECT MAX(`order`) FROM `song` WHERE pid = ?", pid)
+	maxOrder, err := queryInt(db, "SELECT MAX(`order`) FROM `song` WHERE pid = ?", pid)
 	if err != nil {
 		db.Rollback()
 		http.Error(w, err.String(), http.StatusInternalServerError)
 		return
 	}
-	_, err = prepare("INSERT INTO `song` (`pid`,`yid`,`title`,`user`,`order`) VALUES(?, ?, ?, ?, ?)",
-				pid, q.Get("yid"), q.Get("title"), q.Get("user"), maxOrder + 1)
+	_, err = prepare(db,
+			"INSERT INTO `song` (`pid`,`yid`,`title`,`user`,`order`) VALUES(?, ?, ?, ?, ?)",
+			pid, q.Get("yid"), q.Get("title"), q.Get("user"), maxOrder + 1)
 	if err != nil {
 		db.Rollback()
 		http.Error(w, err.String(), http.StatusInternalServerError)
@@ -93,7 +97,11 @@ func add(w http.ResponseWriter, r *http.Request) {
 
 func remove(w http.ResponseWriter, r *http.Request) {
 	q := r.URL.Query()
-	pid := getpid(q.Get("pid"))
+
+	db := <-dbPool
+	defer func () {dbPool <- db}()
+
+	pid := getpid(db, q.Get("pid"))
 	if pid == -1 {
 		http.Error(w, "invalid pid", http.StatusBadRequest)
 		return
@@ -105,7 +113,7 @@ func remove(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	order, err := queryInt("SELECT `order` FROM `song` WHERE `yid` = ? AND `pid` = ?",
+	order, err := queryInt(db, "SELECT `order` FROM `song` WHERE `yid` = ? AND `pid` = ?",
 			q.Get("yid"), pid)
 	if err != nil {
 		db.Rollback()
@@ -113,7 +121,7 @@ func remove(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	_, err = prepare("DELETE FROM `song` WHERE `pid` = ? AND yid = ?",
+	_, err = prepare(db, "DELETE FROM `song` WHERE `pid` = ? AND yid = ?",
 			pid, q.Get("yid"))
 	if err != nil {
 		db.Rollback()
@@ -121,7 +129,7 @@ func remove(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	_, err = prepare("UPDATE `song` SET `order` = `order`-1 WHERE `order` > ? AND `pid` = ?",
+	_, err = prepare(db, "UPDATE `song` SET `order` = `order`-1 WHERE `order` > ? AND `pid` = ?",
 			order, pid)
 	if err != nil {
 		db.Rollback()
@@ -141,7 +149,11 @@ func remove(w http.ResponseWriter, r *http.Request) {
 
 func move(w http.ResponseWriter, r *http.Request) {
 	q := r.URL.Query()
-	pid := getpid(q.Get("pid"))
+
+	db := <-dbPool
+	defer func () {dbPool <- db}()
+
+	pid := getpid(db, q.Get("pid"))
 	if pid == -1 {
 		http.Error(w, "invalid pid", http.StatusBadRequest)
 		return
@@ -159,7 +171,7 @@ func move(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	order, err := queryInt("SELECT `order` FROM `song` WHERE `yid` = ? AND `pid` = ?",
+	order, err := queryInt(db, "SELECT `order` FROM `song` WHERE `yid` = ? AND `pid` = ?",
 			q.Get("yid"), pid)
 	if err != nil {
 		db.Rollback()
@@ -177,7 +189,7 @@ func move(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	query, err := prepare("UPDATE `song` SET `order` = ? WHERE `order` = ? AND `pid` = ?",
+	query, err := prepare(db, "UPDATE `song` SET `order` = ? WHERE `order` = ? AND `pid` = ?",
 			order, newOrder, pid)
 	if err != nil {
 		db.Rollback()
@@ -189,7 +201,7 @@ func move(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	// there are now two songs with that order, so also check yid
-	_, err = prepare("UPDATE `song` SET `order` = ? WHERE `order` = ? AND `pid` = ? AND `yid` = ?",
+	_, err = prepare(db, "UPDATE `song` SET `order` = ? WHERE `order` = ? AND `pid` = ? AND `yid` = ?",
 			newOrder, order, pid, q.Get("yid"))
 	if err != nil {
 		db.Rollback()
@@ -211,7 +223,10 @@ func poll(w http.ResponseWriter, r *http.Request) {
 	q := r.URL.Query()
 	timestamp := q.Get("timestamp")
 	if timestamp == "0" {
-		query, err := prepare(
+		db := <-dbPool
+		defer func () {dbPool <- db}()
+
+		query, err := prepare(db,
 				"SELECT `yid`,`title`,`user` FROM `playlist` JOIN `song` WHERE `id` = ? ORDER BY `order` ASC",
 				q.Get("pid"))
 
@@ -249,7 +264,7 @@ func poll(w http.ResponseWriter, r *http.Request) {
 		}
 		var update *Update
 		for i := 0; i < 30; i++ {
-			update = getUpdates(getpid(q.Get("pid")), timestamp)
+			update = getUpdates(q.Get("pid"), timestamp)
 			if update != nil {
 				w.Write([]byte("["))
 				for update != nil {

+ 7 - 1
updates.go

@@ -38,7 +38,13 @@ func addUpdate(pid int, action uint, song *Song) {
 	tailUpdates[pid] = update
 }
 
-func getUpdates(pid int, timestamp int64) *Update {
+func getUpdates(id string, timestamp int64) *Update {
+	db := <-dbPool
+	pid := getpid(db, id)
+	dbPool <- db
+	if pid == -1 {
+		return nil
+	}
 	pup, ok := headUpdates[pid]
 	if !ok {
 		return nil