diff --git a/cmd/gateway-http/gateway-http.go b/cmd/gateway-http/gateway-http.go index 6abde2c..b19601a 100644 --- a/cmd/gateway-http/gateway-http.go +++ b/cmd/gateway-http/gateway-http.go @@ -25,6 +25,7 @@ import ( "github.com/tapglue/snaas/platform/redis" "github.com/tapglue/snaas/service/app" "github.com/tapglue/snaas/service/connection" + "github.com/tapglue/snaas/service/counter" "github.com/tapglue/snaas/service/device" "github.com/tapglue/snaas/service/event" "github.com/tapglue/snaas/service/invite" @@ -342,6 +343,9 @@ func main() { // Combine connection service and source. connections = connection.SourcingServiceMiddleware(conSource)(connections) + var counters counter.Service + counters = counter.PostgresService(pgClient) + var devices device.Service devices = device.PostgresService(pgClient) devices = device.InstrumentServiceMiddleware( @@ -505,6 +509,20 @@ func main() { ), ) + current.Methods("PUT").Path(`/me/counters/{counterName:[a-z_\-]+}`).Name("counterGetMe").HandlerFunc( + handler.Wrap( + withUser, + handler.CounterSet(core.CounterSet(counters)), + ), + ) + + current.Methods("GET").Path(`/counters/{counterName:[a-z_\-]+}`).Name("counterGetAll").HandlerFunc( + handler.Wrap( + withUser, + handler.CounterGetAll(core.CounterGetAll(counters)), + ), + ) + current.Methods("GET").Path(`/me/followers`).Name("connectionFollowersMe").HandlerFunc( handler.Wrap( withUser, diff --git a/core/counter.go b/core/counter.go new file mode 100644 index 0000000..fa02100 --- /dev/null +++ b/core/counter.go @@ -0,0 +1,38 @@ +package core + +import ( + "github.com/tapglue/snaas/service/app" + "github.com/tapglue/snaas/service/counter" +) + +// CounterGetAllFunc returns the sum of all counter for a coutner name. +type CounterGetAllFunc func(currentApp *app.App, name string) (uint64, error) + +// CounterGetAll returns the sum of all counter for a coutner name. +func CounterGetAll(counters counter.Service) CounterGetAllFunc { + return func(currentApp *app.App, name string) (uint64, error) { + return counters.CountAll(currentApp.Namespace(), name) + } +} + +// CounterSetFunc sets the counter for the current user and the given counter +// name to the new value. +type CounterSetFunc func( + currentApp *app.App, + origin uint64, + name string, + value uint64, +) error + +// CounterSet sets the counter for the current user and the given counter name +// to the new value. +func CounterSet(counters counter.Service) CounterSetFunc { + return func( + currentApp *app.App, + origin uint64, + name string, + value uint64, + ) error { + return counters.Set(currentApp.Namespace(), name, origin, value) + } +} diff --git a/handler/http/counter.go b/handler/http/counter.go new file mode 100644 index 0000000..c2ad3c9 --- /dev/null +++ b/handler/http/counter.go @@ -0,0 +1,69 @@ +package http + +import ( + "encoding/json" + "net/http" + + "golang.org/x/net/context" + + "github.com/tapglue/snaas/core" +) + +// CounterGetAll returns the sum of all counter for a coutner name. +func CounterGetAll(fn core.CounterGetAllFunc) Handler { + return func(ctx context.Context, w http.ResponseWriter, r *http.Request) { + var ( + currentApp = appFromContext(ctx) + ) + + name, err := extractCounterName(r) + if err != nil { + respondError(w, 0, wrapError(ErrBadRequest, err.Error())) + return + } + + v, err := fn(currentApp, name) + if err != nil { + respondError(w, 0, err) + return + } + + respondJSON(w, http.StatusOK, &payloadCounter{Value: v}) + } +} + +// CounterSet sets the counter for the current user and the given counter name +// to the new value. +func CounterSet(fn core.CounterSetFunc) Handler { + return func(ctx context.Context, w http.ResponseWriter, r *http.Request) { + var ( + currentApp = appFromContext(ctx) + currentUser = userFromContext(ctx) + p = payloadCounter{} + ) + + name, err := extractCounterName(r) + if err != nil { + respondError(w, 0, wrapError(ErrBadRequest, err.Error())) + return + } + + err = json.NewDecoder(r.Body).Decode(&p) + if err != nil { + respondError(w, 0, wrapError(ErrBadRequest, err.Error())) + return + } + + err = fn(currentApp, currentUser.ID, name, p.Value) + if err != nil { + respondError(w, 0, err) + return + } + + respondJSON(w, http.StatusNoContent, nil) + } +} + +type payloadCounter struct { + Value uint64 `json:"value"` +} diff --git a/handler/http/query.go b/handler/http/query.go index 591e7fe..5c72b47 100644 --- a/handler/http/query.go +++ b/handler/http/query.go @@ -26,6 +26,7 @@ const ( keyAppID = "appID" keyCommentID = "commentID" + keyCounterName = "counterName" keyCursorAfter = "after" keyCursorBefore = "before" keyInviteConnections = "invite-connections" @@ -151,6 +152,10 @@ func extractConnectionOpts(r *http.Request) (connection.QueryOptions, error) { return connection.QueryOptions{}, nil } +func extractCounterName(r *http.Request) (string, error) { + return mux.Vars(r)[keyCounterName], nil +} + type condition struct { EQ string `json:"eq"` IN []string `json:"in"` diff --git a/service/counter/counter.go b/service/counter/counter.go new file mode 100644 index 0000000..b950fd3 --- /dev/null +++ b/service/counter/counter.go @@ -0,0 +1,17 @@ +package counter + +import ( + "github.com/tapglue/snaas/platform/service" +) + +// Service for counter interactions. +type Service interface { + service.Lifecycle + + Count(namespace, name string, userID uint64) (uint64, error) + CountAll(namespace, name string) (uint64, error) + Set(namespace, name string, userID, value uint64) error +} + +// ServiceMiddleware is a chainable behaviour modifier for Service. +type ServiceMiddleware func(Service) Service diff --git a/service/counter/postgres.go b/service/counter/postgres.go new file mode 100644 index 0000000..873b08c --- /dev/null +++ b/service/counter/postgres.go @@ -0,0 +1,195 @@ +package counter + +import ( + "fmt" + + "github.com/jmoiron/sqlx" + + "github.com/tapglue/snaas/platform/pg" +) + +const ( + pgGetCounter = ` + SELECT + value + FROM + %s.counters + WHERE + deleted = false + AND name = $1 + ANd user_id = $2 + LIMIT + 1` + pgGetCounterAll = ` + SELECT + sum(value) + FROM + %s.counters + WHERE + deleted = false + AND name = $1` + pgSetCounter = ` + INSERT INTO %s.counters(name, user_id, value) + VALUES($1, $2, $3) + ON CONFLICT (name, user_id) DO + UPDATE SET + value = $3` + + pgCreateSchema = `CREATE SCHEMA IF NOT EXISTS %s` + pgCreateTable = ` + CREATE TABLE IF NOT EXISTS %s.counters( + name TEXT NOT NULL, + user_id BIGINT NOT NULL, + value BIGINT NOT NULL, + deleted BOOL DEFAULT false, + created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT (now() AT TIME ZONE 'utc'), + updated_at TIMESTAMP WITHOUT TIME ZONE DEFAULT (now() AT TIME ZONE 'utc'), + + CONSTRAINT counter_id UNIQUE (name, user_id), + PRIMARY KEY (name, user_id) + )` + pgDropTable = `DROP TABLE IF EXISTS %s.counters CASCADE` + + pgIndexCounterID = ` + CREATE UNIQUE INDEX + %s + ON + %s.counters + USING + btree(name, user_id)` + pgIndexCounterName = ` + CREATE INDEX + %s + ON + %s.counters + USING + btree(name)` + + // Extensions. + pgCreateExtensionModdatetime = `CREATE EXTENSION IF NOT EXISTS moddatetime` + + // Trigger to autoamtically set the latest time on updated_at, depends on + // the moddatetime extension: + // * https://www.postgresql.org/docs/current/static/contrib-spi.html + // * https://github.com/postgres/postgres/blob/master/contrib/spi/moddatetime.example + // * https://dba.stackexchange.com/a/158750 + pgAlterTriggerUpdatedAt = ` + ALTER TRIGGER %s ON %s.counters DEPENDS ON EXTENSION moddatetime` + pgCreateTriggerUpdatedAt = ` + CREATE TRIGGER + %s + BEFORE UPDATE ON + %s.counters + FOR EACH ROW EXECUTE PROCEDURE + moddatetime(updated_at)` + pgDropTriggerUpdatedAt = ` + DROP TRIGGER IF EXISTS %s ON %s.counters` +) + +type pgService struct { + db *sqlx.DB +} + +func PostgresService(db *sqlx.DB) Service { + return &pgService{db: db} +} + +func (s *pgService) Count(ns, name string, userID uint64) (uint64, error) { + var ( + args = []interface{}{name, userID} + query = fmt.Sprintf(pgGetCounter, ns) + + value uint64 + ) + + err := s.db.Get(&value, query, args...) + if err != nil && pg.IsRelationNotFound(pg.WrapError(err)) { + if err := s.Setup(ns); err != nil { + return 0, err + } + + err = s.db.Get(&value, query, args...) + } + + return value, err +} + +func (s *pgService) CountAll(ns, name string) (uint64, error) { + var ( + args = []interface{}{name} + query = fmt.Sprintf(pgGetCounterAll, ns) + + value uint64 + ) + + err := s.db.Get(&value, query, args...) + if err != nil && pg.IsRelationNotFound(pg.WrapError(err)) { + if err := s.Setup(ns); err != nil { + return 0, err + } + + err = s.db.Get(&value, query, args...) + } + + return value, err +} + +func (s *pgService) Set(ns, name string, userID, value uint64) error { + var ( + args = []interface{}{ + name, + userID, + value, + } + query = fmt.Sprintf(pgSetCounter, ns) + ) + + _, err := s.db.Exec(query, args...) + if err != nil && pg.IsRelationNotFound(pg.WrapError(err)) { + if err := s.Setup(ns); err != nil { + return err + } + + _, err = s.db.Exec(query, args...) + } + + return err +} + +func (s *pgService) Setup(ns string) error { + for _, q := range []string{ + fmt.Sprintf(pgCreateSchema, ns), + fmt.Sprintf(pgCreateTable, ns), + + // Indexes. + pg.GuardIndex(ns, "counter_counter_id", pgIndexCounterID), + pg.GuardIndex(ns, "counter_counter_name", pgIndexCounterName), + + // FIXME: Re-enable when migrated to Postgres 9.6 + // Setup idempotent updated_at trigger. + // pgCreateExtensionModdatetime, + // fmt.Sprintf(pgDropTriggerUpdatedAt, "counter_updated_at", ns), + // fmt.Sprintf(pgCreateTriggerUpdatedAt, "counter_updated_at", ns), + // fmt.Sprintf(pgAlterTriggerUpdatedAt, "counter_updated_at", ns), + } { + _, err := s.db.Exec(q) + if err != nil { + return fmt.Errorf("setup '%s': %s", q, err) + } + } + + return nil +} + +func (s *pgService) Teardown(ns string) error { + for _, q := range []string{ + fmt.Sprintf(pgDropTable, ns), + } { + _, err := s.db.Exec(q) + if err != nil { + return fmt.Errorf("teardown '%s': %s", q, err) + } + } + + return nil +} diff --git a/service/counter/postgres_test.go b/service/counter/postgres_test.go new file mode 100644 index 0000000..4232296 --- /dev/null +++ b/service/counter/postgres_test.go @@ -0,0 +1,103 @@ +// +build integration + +package counter + +import ( + "flag" + "fmt" + "math/rand" + "os/user" + "testing" + + "github.com/jmoiron/sqlx" + + "github.com/tapglue/snaas/platform/pg" +) + +var pgTestURL string + +func TestPostgresCountAll(t *testing.T) { + var ( + namespace = "service_count_all" + service = preparePostgres(t, namespace) + name = "beers" + + want uint64 + ) + + for i := 0; i < rand.Intn(64); i++ { + userID, value := testUser() + + err := service.Set(namespace, name, userID, value) + if err != nil { + t.Fatal(err) + } + + want += value + } + + have, err := service.CountAll(namespace, name) + if err != nil { + t.Fatal(err) + } + + if have != want { + t.Errorf("have %v, want %v", have, want) + } +} + +func TestPostgresSet(t *testing.T) { + var ( + namespace = "service_set" + service = preparePostgres(t, namespace) + name = "setter" + userID, value = testUser() + ) + + err := service.Set(namespace, name, userID, value) + if err != nil { + t.Fatal(err) + } + + c, err := service.Count(namespace, name, userID) + if err != nil { + t.Fatal(err) + } + + if have, want := c, value; err != nil { + t.Errorf("have %v, want %v", have, want) + } +} + +func preparePostgres(t *testing.T, namespace string) Service { + db, err := sqlx.Connect("postgres", pgTestURL) + if err != nil { + t.Fatal(err) + } + + s := PostgresService(db) + + if err := s.Teardown(namespace); err != nil { + t.Fatal(err) + } + + return s +} + +func testUser() (uint64, uint64) { + return uint64(rand.Int63()), uint64(rand.Int31()) +} + +func init() { + u, err := user.Current() + if err != nil { + panic(err) + } + + d := fmt.Sprintf(pg.URLTest, u.Username) + + url := flag.String("postgres.url", d, "Postgres test connection URL") + flag.Parse() + + pgTestURL = *url +}