Skip to content

Commit

Permalink
fix(spx-backend): correct user_relationship table joins in list use…
Browse files Browse the repository at this point in the history
…rs API (#1006)

Signed-off-by: Aofei Sheng <[email protected]>
  • Loading branch information
aofei authored Oct 22, 2024
1 parent 3ab8066 commit a334e66
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 23 deletions.
22 changes: 12 additions & 10 deletions spx-backend/internal/controller/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,28 +148,30 @@ func (p *ListUsersParams) Validate() (ok bool, msg string) {
// ListUsers retrieves a paginated list of users with optional filtering and ordering.
func (ctrl *Controller) ListUsers(ctx context.Context, params *ListUsersParams) (*ByPage[UserDTO], error) {
query := ctrl.db.WithContext(ctx).Model(&model.User{})
joinedTables := map[string]struct{}{}
if params.Follower != nil {
query = query.Joins("JOIN user_relationship ON user_relationship.target_user_id = user.id").
Where("user_relationship.user_id = ?", *params.Follower)
joinedTables["user_relationship"] = struct{}{}
query = query.Joins("JOIN user AS follower ON follower.username = ?", *params.Follower).
Joins("JOIN user_relationship AS follower_relationship ON follower_relationship.user_id = follower.id AND follower_relationship.target_user_id = user.id").
Where("follower_relationship.followed_at IS NOT NULL")
}
if params.Followee != nil {
query = query.Joins("JOIN user_relationship ON user_relationship.user_id = user.id").
Where("user_relationship.target_user_id = ?", *params.Followee)
joinedTables["user_relationship"] = struct{}{}
query = query.Joins("JOIN user AS followee ON followee.username = ?", *params.Followee).
Joins("JOIN user_relationship AS followee_relationship ON followee_relationship.target_user_id = followee.id AND followee_relationship.user_id = user.id").
Where("followee_relationship.followed_at IS NOT NULL")
}
switch params.OrderBy {
case ListUsersOrderByCreatedAt:
query = query.Order(fmt.Sprintf("user.created_at %s", params.SortOrder))
case ListUsersOrderByUpdatedAt:
query = query.Order(fmt.Sprintf("user.updated_at %s", params.SortOrder))
case ListUsersOrderByFollowedAt:
if _, ok := joinedTables["user_relationship"]; !ok {
switch {
case params.Follower != nil:
query = query.Order(fmt.Sprintf("follower_relationship.followed_at %s", params.SortOrder))
case params.Followee != nil:
query = query.Order(fmt.Sprintf("followee_relationship.followed_at %s", params.SortOrder))
default:
query = query.Order(fmt.Sprintf("user.created_at %s", params.SortOrder))
break
}
query = query.Order(fmt.Sprintf("user_relationship.followed_at %s", params.SortOrder))
}

var total int64
Expand Down
32 changes: 19 additions & 13 deletions spx-backend/internal/controller/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,9 @@ func TestControllerListUsers(t *testing.T) {

dbMockStmt := ctrl.db.Session(&gorm.Session{DryRun: true}).
Model(&model.User{}).
Joins("JOIN user_relationship ON user_relationship.target_user_id = user.id").
Where("user_relationship.user_id = ?", *params.Follower).
Joins("JOIN user AS follower ON follower.username = ?", *params.Follower).
Joins("JOIN user_relationship AS follower_relationship ON follower_relationship.user_id = follower.id AND follower_relationship.target_user_id = user.id").
Where("follower_relationship.followed_at IS NOT NULL").
Count(new(int64)).
Statement
dbMockArgs := modeltest.ToDriverValueSlice(dbMockStmt.Vars...)
Expand All @@ -270,8 +271,9 @@ func TestControllerListUsers(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))

dbMockStmt = ctrl.db.Session(&gorm.Session{DryRun: true}).
Joins("JOIN user_relationship ON user_relationship.target_user_id = user.id").
Where("user_relationship.user_id = ?", *params.Follower).
Joins("JOIN user AS follower ON follower.username = ?", *params.Follower).
Joins("JOIN user_relationship AS follower_relationship ON follower_relationship.user_id = follower.id AND follower_relationship.target_user_id = user.id").
Where("follower_relationship.followed_at IS NOT NULL").
Order("user.created_at desc").
Limit(params.Pagination.Size).
Find(&[]model.User{}).
Expand Down Expand Up @@ -308,8 +310,9 @@ func TestControllerListUsers(t *testing.T) {

dbMockStmt := ctrl.db.Session(&gorm.Session{DryRun: true}).
Model(&model.User{}).
Joins("JOIN user_relationship ON user_relationship.user_id = user.id").
Where("user_relationship.target_user_id = ?", *params.Followee).
Joins("JOIN user AS followee ON followee.username = ?", *params.Followee).
Joins("JOIN user_relationship AS followee_relationship ON followee_relationship.target_user_id = followee.id AND followee_relationship.user_id = user.id").
Where("followee_relationship.followed_at IS NOT NULL").
Count(new(int64)).
Statement
dbMockArgs := modeltest.ToDriverValueSlice(dbMockStmt.Vars...)
Expand All @@ -318,8 +321,9 @@ func TestControllerListUsers(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))

dbMockStmt = ctrl.db.Session(&gorm.Session{DryRun: true}).
Joins("JOIN user_relationship ON user_relationship.user_id = user.id").
Where("user_relationship.target_user_id = ?", *params.Followee).
Joins("JOIN user AS followee ON followee.username = ?", *params.Followee).
Joins("JOIN user_relationship AS followee_relationship ON followee_relationship.target_user_id = followee.id AND followee_relationship.user_id = user.id").
Where("followee_relationship.followed_at IS NOT NULL").
Order("user.created_at desc").
Limit(params.Pagination.Size).
Find(&[]model.User{}).
Expand Down Expand Up @@ -361,8 +365,9 @@ func TestControllerListUsers(t *testing.T) {

dbMockStmt := ctrl.db.Session(&gorm.Session{DryRun: true}).
Model(&model.User{}).
Joins("JOIN user_relationship ON user_relationship.target_user_id = user.id").
Where("user_relationship.user_id = ?", *params.Follower).
Joins("JOIN user AS follower ON follower.username = ?", *params.Follower).
Joins("JOIN user_relationship AS follower_relationship ON follower_relationship.user_id = follower.id AND follower_relationship.target_user_id = user.id").
Where("follower_relationship.followed_at IS NOT NULL").
Count(new(int64)).
Statement
dbMockArgs := modeltest.ToDriverValueSlice(dbMockStmt.Vars...)
Expand All @@ -371,9 +376,10 @@ func TestControllerListUsers(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(2))

dbMockStmt = ctrl.db.Session(&gorm.Session{DryRun: true}).
Joins("JOIN user_relationship ON user_relationship.target_user_id = user.id").
Where("user_relationship.user_id = ?", *params.Follower).
Order("user_relationship.followed_at desc").
Joins("JOIN user AS follower ON follower.username = ?", *params.Follower).
Joins("JOIN user_relationship AS follower_relationship ON follower_relationship.user_id = follower.id AND follower_relationship.target_user_id = user.id").
Where("follower_relationship.followed_at IS NOT NULL").
Order("follower_relationship.followed_at desc").
Limit(params.Pagination.Size).
Find(&[]model.User{}).
Statement
Expand Down

0 comments on commit a334e66

Please sign in to comment.