Skip to content
65 changes: 54 additions & 11 deletions api/queries_pr.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package api

import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -629,17 +631,58 @@ func CreatePullRequest(client *Client, repo *Repository, params map[string]inter
return pr, nil
}

func UpdatePullRequestReviews(client *Client, repo ghrepo.Interface, params githubv4.RequestReviewsInput) error {
var mutation struct {
RequestReviews struct {
PullRequest struct {
ID string
}
} `graphql:"requestReviews(input: $input)"`
}
variables := map[string]interface{}{"input": params}
err := client.Mutate(repo.RepoHost(), "PullRequestUpdateRequestReviews", &mutation, variables)
return err
// AddPullRequestReviews adds the given user and team reviewers to a pull request using the REST API.
func AddPullRequestReviews(client *Client, repo ghrepo.Interface, prNumber int, users, teams []string) error {
if len(users) == 0 && len(teams) == 0 {
return nil
}

path := fmt.Sprintf(
"repos/%s/%s/pulls/%d/requested_reviewers",
url.PathEscape(repo.RepoOwner()),
url.PathEscape(repo.RepoName()),
prNumber,
)
body := struct {
Reviewers []string `json:"reviewers,omitempty"`
TeamReviewers []string `json:"team_reviewers,omitempty"`
}{
Reviewers: users,
TeamReviewers: teams,
}
buf := &bytes.Buffer{}
if err := json.NewEncoder(buf).Encode(body); err != nil {
return err
}
// The endpoint responds with the updated pull request object; we don't need it here.
return client.REST(repo.RepoHost(), "POST", path, buf, nil)
Comment thread
BagToad marked this conversation as resolved.
}

// RemovePullRequestReviews removes requested reviewers from a pull request using the REST API.
func RemovePullRequestReviews(client *Client, repo ghrepo.Interface, prNumber int, users, teams []string) error {
if len(users) == 0 && len(teams) == 0 {
return nil
}

path := fmt.Sprintf(
"repos/%s/%s/pulls/%d/requested_reviewers",
url.PathEscape(repo.RepoOwner()),
url.PathEscape(repo.RepoName()),
prNumber,
)
body := struct {
Reviewers []string `json:"reviewers,omitempty"`
TeamReviewers []string `json:"team_reviewers,omitempty"`
}{
Reviewers: users,
TeamReviewers: teams,
}
buf := &bytes.Buffer{}
if err := json.NewEncoder(buf).Encode(body); err != nil {
return err
}
// The endpoint responds with the updated pull request object; we don't need it here.
return client.REST(repo.RepoHost(), "DELETE", path, buf, nil)
}

func UpdatePullRequestBranch(client *Client, repo ghrepo.Interface, params githubv4.UpdatePullRequestBranchInput) error {
Expand Down
87 changes: 60 additions & 27 deletions pkg/cmd/pr/edit/edit.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package edit
import (
"fmt"
"net/http"
"slices"
"strings"
"time"

"github.com/MakeNowJust/heredoc"
Expand All @@ -13,7 +15,7 @@ import (
shared "github.com/cli/cli/v2/pkg/cmd/pr/shared"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/shurcooL/githubv4"
"github.com/cli/cli/v2/pkg/set"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
)
Expand Down Expand Up @@ -170,7 +172,7 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman
}

if opts.Interactive && !opts.IO.CanPrompt() {
return cmdutil.FlagErrorf("--tile, --body, --reviewer, --assignee, --label, --project, or --milestone required when not running interactively")
return cmdutil.FlagErrorf("--title, --body, --reviewer, --assignee, --label, --project, or --milestone required when not running interactively")
Comment thread
BagToad marked this conversation as resolved.
}

if runF != nil {
Expand Down Expand Up @@ -237,7 +239,7 @@ func editRun(opts *EditOptions) error {

findOptions := shared.FindOptions{
Selector: opts.SelectorArg,
Comment thread
BagToad marked this conversation as resolved.
Fields: []string{"id", "url", "title", "body", "baseRefName", "reviewRequests", "labels", "projectCards", "projectItems", "milestone"},
Fields: []string{"id", "author", "url", "title", "body", "baseRefName", "reviewRequests", "labels", "projectCards", "projectItems", "milestone"},
Detector: opts.Detector,
}

Expand Down Expand Up @@ -298,6 +300,15 @@ func editRun(opts *EditOptions) error {
}

if opts.Interactive {
// Remove PR author from reviewer options;
// REST API errors if author is included (GraphQL silently ignores).
if editable.Reviewers.Edited {
s := set.NewStringSet()
s.AddValues(editable.Reviewers.Options)
s.Remove(pr.Author.Login)
editable.Reviewers.Options = s.ToSlice()
}

editorCommand, err := opts.EditorRetriever.Retrieve()
if err != nil {
return err
Expand All @@ -309,7 +320,7 @@ func editRun(opts *EditOptions) error {
}

opts.IO.StartProgressIndicator()
err = updatePullRequest(httpClient, repo, pr.ID, editable)
err = updatePullRequest(httpClient, repo, pr.ID, pr.Number, editable)
opts.IO.StopProgressIndicator()
if err != nil {
return err
Expand All @@ -320,36 +331,53 @@ func editRun(opts *EditOptions) error {
return nil
}

func updatePullRequest(httpClient *http.Client, repo ghrepo.Interface, id string, editable shared.Editable) error {
func updatePullRequest(httpClient *http.Client, repo ghrepo.Interface, id string, number int, editable shared.Editable) error {
var wg errgroup.Group
wg.Go(func() error {
return shared.UpdateIssue(httpClient, repo, id, true, editable)
})
if editable.Reviewers.Edited {
wg.Go(func() error {
return updatePullRequestReviews(httpClient, repo, id, editable)
return updatePullRequestReviews(httpClient, repo, number, editable)
})
}
return wg.Wait()
}

func updatePullRequestReviews(httpClient *http.Client, repo ghrepo.Interface, id string, editable shared.Editable) error {
userIds, teamIds, err := editable.ReviewerIds()
if err != nil {
return err
}
if userIds == nil && teamIds == nil {
func updatePullRequestReviews(httpClient *http.Client, repo ghrepo.Interface, number int, editable shared.Editable) error {
if !editable.Reviewers.Edited {
return nil
}
union := githubv4.Boolean(false)
reviewsRequestParams := githubv4.RequestReviewsInput{
PullRequestID: id,
Union: &union,
UserIDs: ghIds(userIds),
TeamIDs: ghIds(teamIds),

// Rebuild the Value slice from non-interactive flag input.
if len(editable.Reviewers.Add) != 0 || len(editable.Reviewers.Remove) != 0 {
s := set.NewStringSet()
Comment thread
babakks marked this conversation as resolved.
s.AddValues(editable.Reviewers.Add)
s.AddValues(editable.Reviewers.Default)
s.RemoveValues(editable.Reviewers.Remove)
editable.Reviewers.Value = s.ToSlice()
}

addUsers, addTeams := partitionUsersAndTeams(editable.Reviewers.Value)

// Reviewers in Default but not in the Value have been removed interactively.
var toRemove []string
for _, r := range editable.Reviewers.Default {
if !slices.Contains(editable.Reviewers.Value, r) {
toRemove = append(toRemove, r)
}
}
removeUsers, removeTeams := partitionUsersAndTeams(toRemove)

client := api.NewClientFromHTTP(httpClient)
return api.UpdatePullRequestReviews(client, repo, reviewsRequestParams)
wg := errgroup.Group{}
wg.Go(func() error {
return api.AddPullRequestReviews(client, repo, number, addUsers, addTeams)
})
wg.Go(func() error {
return api.RemovePullRequestReviews(client, repo, number, removeUsers, removeTeams)
})
return wg.Wait()
}

type Surveyor interface {
Expand Down Expand Up @@ -391,13 +419,18 @@ func (e editorRetriever) Retrieve() (string, error) {
return cmdutil.DetermineEditor(e.config)
}

func ghIds(s *[]string) *[]githubv4.ID {
if s == nil {
return nil
}
ids := make([]githubv4.ID, len(*s))
for i, v := range *s {
ids[i] = v
// partitionUsersAndTeams splits reviewer identifiers into user logins and team slugs.
// Team identifiers are in the form "org/slug"; only the slug portion is returned for teams.
func partitionUsersAndTeams(values []string) (users []string, teams []string) {
for _, v := range values {
if strings.ContainsRune(v, '/') {
parts := strings.SplitN(v, "/", 2)
if len(parts) == 2 && parts[1] != "" {
teams = append(teams, parts[1])
}
} else if v != "" {
users = append(users, v)
}
}
return &ids
return
}
Loading