Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass request to Render backport (stable-5.21) #14409

Merged
merged 5 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lxd-agent/devlxd.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ var devLxdEventsGet = devLxdHandler{
}

func devlxdEventsGetHandler(d *Daemon, w http.ResponseWriter, r *http.Request) *devLxdResponse {
err := eventsGet(d, r).Render(w)
err := eventsGet(d, r).Render(w, r)
if err != nil {
return smartResponse(err)
}
Expand Down
4 changes: 2 additions & 2 deletions lxd-agent/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ type eventsServe struct {
}

// Render starts event socket.
func (r *eventsServe) Render(w http.ResponseWriter) error {
return eventsSocket(r.d, r.req, w)
func (r *eventsServe) Render(w http.ResponseWriter, request *http.Request) error {
return eventsSocket(r.d, request, w)
}

func (r *eventsServe) String() string {
Expand Down
10 changes: 5 additions & 5 deletions lxd-agent/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func restServer(tlsConfig *tls.Config, cert *x509.Certificate, d *Daemon) *http.

mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = response.SyncResponse(true, []string{"/1.0"}).Render(w)
_ = response.SyncResponse(true, []string{"/1.0"}).Render(w, r)
})

for _, c := range api10 {
Expand All @@ -46,7 +46,7 @@ func createCmd(restAPI *mux.Router, version string, c APIEndpoint, cert *x509.Ce

if !authenticate(r, cert) {
logger.Error("Not authorized")
_ = response.InternalError(fmt.Errorf("Not authorized")).Render(w)
_ = response.InternalError(fmt.Errorf("Not authorized")).Render(w, r)
return
}

Expand All @@ -57,7 +57,7 @@ func createCmd(restAPI *mux.Router, version string, c APIEndpoint, cert *x509.Ce
multiW := io.MultiWriter(newBody, captured)
_, err := io.Copy(multiW, r.Body)
if err != nil {
_ = response.InternalError(err).Render(w)
_ = response.InternalError(err).Render(w, r)
return
}

Expand Down Expand Up @@ -92,9 +92,9 @@ func createCmd(restAPI *mux.Router, version string, c APIEndpoint, cert *x509.Ce
}

// Handle errors
err := resp.Render(w)
err := resp.Render(w, r)
if err != nil {
writeErr := response.InternalError(err).Render(w)
writeErr := response.InternalError(err).Render(w, r)
if writeErr != nil {
logger.Error("Failed writing error for HTTP response", logger.Ctx{"url": uri, "err": err, "writeErr": writeErr})
}
Expand Down
5 changes: 3 additions & 2 deletions lxd-agent/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ func (r *sftpServe) String() string {
return "sftp handler"
}

func (r *sftpServe) Render(w http.ResponseWriter) error {
// Render hijacks the connection and starts a sftp server.
func (r *sftpServe) Render(w http.ResponseWriter, request *http.Request) error {
// Upgrade to sftp.
if r.r.Header.Get("Upgrade") != "sftp" {
if request.Header.Get("Upgrade") != "sftp" {
http.Error(w, "Missing or invalid upgrade header", http.StatusBadRequest)
return nil
}
Expand Down
10 changes: 5 additions & 5 deletions lxd/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func restServer(d *Daemon) *http.Server {
}

// Normal client handling.
_ = response.SyncResponse(true, []string{"/1.0"}).Render(w)
_ = response.SyncResponse(true, []string{"/1.0"}).Render(w, r)
})

for endpoint, f := range d.gateway.HandlerFuncs(d.heartbeatHandler, d.identityCache) {
Expand Down Expand Up @@ -206,7 +206,7 @@ func restServer(d *Daemon) *http.Server {
mux.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Info("Sending top level 404", logger.Ctx{"url": r.URL, "method": r.Method, "remote": r.RemoteAddr})
w.Header().Set("Content-Type", "application/json")
_ = response.NotFound(nil).Render(w)
_ = response.NotFound(nil).Render(w, r)
})

return &http.Server{
Expand All @@ -232,7 +232,7 @@ func hoistReqVM(f func(*Daemon, instance.Instance, http.ResponseWriter, *http.Re
}

resp := f(d, inst, w, r)
_ = resp.Render(w)
_ = resp.Render(w, r)
}
}

Expand All @@ -248,7 +248,7 @@ func metricsServer(d *Daemon) *http.Server {

mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = response.SyncResponse(true, []string{"/1.0"}).Render(w)
_ = response.SyncResponse(true, []string{"/1.0"}).Render(w, r)
})

for endpoint, f := range d.gateway.HandlerFuncs(d.heartbeatHandler, d.identityCache) {
Expand All @@ -261,7 +261,7 @@ func metricsServer(d *Daemon) *http.Server {
mux.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Info("Sending top level 404", logger.Ctx{"url": r.URL, "method": r.Method, "remote": r.RemoteAddr})
w.Header().Set("Content-Type", "application/json")
_ = response.NotFound(nil).Render(w)
_ = response.NotFound(nil).Render(w, r)
})

return &http.Server{Handler: &lxdHTTPServer{r: mux, d: d}}
Expand Down
4 changes: 2 additions & 2 deletions lxd/api_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ func clusterPutDisable(d *Daemon, r *http.Request, req api.ClusterPut) response.
}()

return response.ManualResponse(func(w http.ResponseWriter) error {
err := response.EmptySyncResponse.Render(w)
err := response.EmptySyncResponse.Render(w, r)
if err != nil {
return err
}
Expand Down Expand Up @@ -2039,7 +2039,7 @@ func clusterNodeDelete(d *Daemon, r *http.Request) response.Response {
}

return response.ManualResponse(func(w http.ResponseWriter) error {
err := response.EmptySyncResponse.Render(w)
err := response.EmptySyncResponse.Render(w, r)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion lxd/api_internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func internalShutdown(d *Daemon, r *http.Request) response.Response {

// Run shutdown sequence synchronously.
stopErr := d.Stop(forceCtx, unix.SIGPWR)
err := response.SmartError(stopErr).Render(w)
err := response.SmartError(stopErr).Render(w, r)
if err != nil {
return err
}
Expand Down
10 changes: 5 additions & 5 deletions lxd/auth/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ func (o *Verifier) getGroupsFromClaims(customClaims map[string]any) []string {
func (o *Verifier) Login(w http.ResponseWriter, r *http.Request) {
err := o.ensureConfig(r.Context(), r.Host)
if err != nil {
_ = response.ErrorResponse(http.StatusInternalServerError, fmt.Errorf("Login failed: %w", err).Error()).Render(w)
_ = response.ErrorResponse(http.StatusInternalServerError, fmt.Errorf("Login failed: %w", err).Error()).Render(w, r)
return
}

Expand All @@ -287,7 +287,7 @@ func (o *Verifier) Login(w http.ResponseWriter, r *http.Request) {
func (o *Verifier) Logout(w http.ResponseWriter, r *http.Request) {
err := o.setCookies(w, nil, uuid.UUID{}, "", "", true)
if err != nil {
_ = response.ErrorResponse(http.StatusInternalServerError, fmt.Errorf("Failed to delete login information: %w", err).Error()).Render(w)
_ = response.ErrorResponse(http.StatusInternalServerError, fmt.Errorf("Failed to delete login information: %w", err).Error()).Render(w, r)
return
}

Expand All @@ -298,21 +298,21 @@ func (o *Verifier) Logout(w http.ResponseWriter, r *http.Request) {
func (o *Verifier) Callback(w http.ResponseWriter, r *http.Request) {
err := o.ensureConfig(r.Context(), r.Host)
if err != nil {
_ = response.ErrorResponse(http.StatusInternalServerError, fmt.Errorf("OIDC callback failed: %w", err).Error()).Render(w)
_ = response.ErrorResponse(http.StatusInternalServerError, fmt.Errorf("OIDC callback failed: %w", err).Error()).Render(w, r)
return
}

handler := rp.CodeExchangeHandler(func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty) {
sessionID := uuid.New()
secureCookie, err := o.secureCookieFromSession(sessionID)
if err != nil {
_ = response.ErrorResponse(http.StatusInternalServerError, fmt.Errorf("Failed to start a new session: %w", err).Error()).Render(w)
_ = response.ErrorResponse(http.StatusInternalServerError, fmt.Errorf("Failed to start a new session: %w", err).Error()).Render(w, r)
return
}

err = o.setCookies(w, secureCookie, sessionID, tokens.IDToken, tokens.RefreshToken, false)
if err != nil {
_ = response.ErrorResponse(http.StatusInternalServerError, fmt.Errorf("Failed to set login information: %w", err).Error()).Render(w)
_ = response.ErrorResponse(http.StatusInternalServerError, fmt.Errorf("Failed to set login information: %w", err).Error()).Render(w, r)
return
}

Expand Down
2 changes: 1 addition & 1 deletion lxd/cluster/notify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func (h *notifyFixtures) Unavailable(i int, err error) {
mux.HandleFunc("/1.0/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err := response.Unavailable(err)
_ = err.Render(w)
_ = err.Render(w, r)
})

h.servers[i].Config.Handler = mux
Expand Down
18 changes: 9 additions & 9 deletions lxd/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ func (d *Daemon) createCmd(restAPI *mux.Router, version string, c APIEndpoint) {
case <-d.setupChan:
default:
response := response.Unavailable(fmt.Errorf("LXD daemon setup in progress"))
_ = response.Render(w)
_ = response.Render(w, r)
return
}
}
Expand All @@ -654,11 +654,11 @@ func (d *Daemon) createCmd(restAPI *mux.Router, version string, c APIEndpoint) {

// Return 401 Unauthorized error. This indicates to the client that it needs to use the
// headers we've set above to get an access token and try again.
_ = response.Unauthorized(err).Render(w)
_ = response.Unauthorized(err).Render(w, r)
return
}

_ = response.Forbidden(err).Render(w)
_ = response.Forbidden(err).Render(w, r)
return
}

Expand All @@ -670,7 +670,7 @@ func (d *Daemon) createCmd(restAPI *mux.Router, version string, c APIEndpoint) {
// Except for the initial cluster accept request (done over trusted TLS)
if !trusted || c.Path != "cluster/accept" || protocol != api.AuthenticationMethodTLS {
logger.Warn("Rejecting remote internal API request", logger.Ctx{"ip": r.RemoteAddr})
_ = response.Forbidden(nil).Render(w)
_ = response.Forbidden(nil).Render(w, r)
return
}
}
Expand Down Expand Up @@ -720,7 +720,7 @@ func (d *Daemon) createCmd(restAPI *mux.Router, version string, c APIEndpoint) {
}

logger.Warn("Rejecting request from untrusted client", logger.Ctx{"ip": r.RemoteAddr})
_ = response.Forbidden(nil).Render(w)
_ = response.Forbidden(nil).Render(w, r)
return
}

Expand All @@ -731,7 +731,7 @@ func (d *Daemon) createCmd(restAPI *mux.Router, version string, c APIEndpoint) {
multiW := io.MultiWriter(newBody, captured)
_, err := io.Copy(multiW, r.Body)
if err != nil {
_ = response.InternalError(err).Render(w)
_ = response.InternalError(err).Render(w, r)
return
}

Expand Down Expand Up @@ -766,7 +766,7 @@ func (d *Daemon) createCmd(restAPI *mux.Router, version string, c APIEndpoint) {
}

if d.shutdownCtx.Err() == context.Canceled && !allowedDuringShutdown() {
_ = response.Unavailable(fmt.Errorf("LXD is shutting down")).Render(w)
_ = response.Unavailable(fmt.Errorf("LXD is shutting down")).Render(w, r)
return
}

Expand Down Expand Up @@ -814,9 +814,9 @@ func (d *Daemon) createCmd(restAPI *mux.Router, version string, c APIEndpoint) {
}

// Handle errors
err = resp.Render(w)
err = resp.Render(w, r)
if err != nil {
writeErr := response.SmartError(err).Render(w)
writeErr := response.SmartError(err).Render(w, r)
if writeErr != nil {
logger.Warn("Failed writing error for HTTP response", logger.Ctx{"url": uri, "err": err, "writeErr": writeErr})
}
Expand Down
4 changes: 2 additions & 2 deletions lxd/devlxd.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func devlxdImageExportHandler(d *Daemon, c instance.Instance, w http.ResponseWri

resp := imageExport(d, r)

err := resp.Render(w)
err := resp.Render(w, r)
if err != nil {
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusInternalServerError, "internal server error"), c.Type() == instancetype.VM)
}
Expand Down Expand Up @@ -390,7 +390,7 @@ func hoistReq(f func(*Daemon, instance.Instance, http.ResponseWriter, *http.Requ
}

resp := f(d, c, w, r)
_ = resp.Render(w)
_ = resp.Render(w, r)
}
}

Expand Down
2 changes: 1 addition & 1 deletion lxd/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type eventsServe struct {
}

// Render starts event socket.
func (r *eventsServe) Render(w http.ResponseWriter) error {
func (r *eventsServe) Render(w http.ResponseWriter, req *http.Request) error {
return eventsSocket(r.s, r.req, w)
}

Expand Down
2 changes: 1 addition & 1 deletion lxd/instance_sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (r *sftpServeResponse) String() string {
}

// Render renders the server response.
func (r *sftpServeResponse) Render(w http.ResponseWriter) error {
func (r *sftpServeResponse) Render(w http.ResponseWriter, req *http.Request) error {
defer func() { _ = r.instConn.Close() }()

hijacker, ok := w.(http.Hijacker)
Expand Down
34 changes: 4 additions & 30 deletions lxd/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -983,17 +983,17 @@ func operationWaitGet(d *Daemon, r *http.Request) response.Response {
// Wait for the operation.
err = op.Wait(ctx)
if err != nil {
_ = response.SmartError(err).Render(w)
_ = response.SmartError(err).Render(w, r)
return nil
}

_, body, err := op.Render()
if err != nil {
_ = response.SmartError(err).Render(w)
_ = response.SmartError(err).Render(w, r)
return nil
}

_ = response.SyncResponse(true, body).Render(w)
_ = response.SyncResponse(true, body).Render(w, r)
return nil
}

Expand Down Expand Up @@ -1034,32 +1034,6 @@ func operationWaitGet(d *Daemon, r *http.Request) response.Response {
return response.ForwardedResponse(client, r)
}

type operationWebSocket struct {
req *http.Request
op *operations.Operation
}

// Render implements response.Response for operationWebSocket.
func (r *operationWebSocket) Render(w http.ResponseWriter) error {
chanErr, err := r.op.Connect(r.req, w)
if err != nil {
return err
}

err = <-chanErr
return err
}

// String implements fmt.Stringer for operationWebSocket.
func (r *operationWebSocket) String() string {
_, md, err := r.op.Render()
if err != nil {
return fmt.Sprintf("error: %s", err)
}

return md.ID
}

// swagger:operation GET /1.0/operations/{id}/websocket?public operations operation_websocket_get_untrusted
//
// Get the websocket stream
Expand Down Expand Up @@ -1125,7 +1099,7 @@ func operationWebsocketGet(d *Daemon, r *http.Request) response.Response {
// First check if the query is for a local operation from this node
op, err := operations.OperationGetInternal(id)
if err == nil {
return &operationWebSocket{r, op}
return operations.OperationWebSocket(r, op)
}

// Then check if the query is from an operation on another node, and, if so, forward it
Expand Down
6 changes: 4 additions & 2 deletions lxd/operations/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ func OperationResponse(op *Operation) response.Response {
return &operationResponse{op}
}

func (r *operationResponse) Render(w http.ResponseWriter) error {
// Render builds operationResponse and writes it to http.ResponseWriter.
func (r *operationResponse) Render(w http.ResponseWriter, req *http.Request) error {
err := r.op.Start()
if err != nil {
return err
Expand Down Expand Up @@ -79,7 +80,8 @@ func ForwardedOperationResponse(project string, op *api.Operation) response.Resp
}
}

func (r *forwardedOperationResponse) Render(w http.ResponseWriter) error {
// Render builds forwardedOperationResponse and writes it to http.ResponseWriter.
func (r *forwardedOperationResponse) Render(w http.ResponseWriter, req *http.Request) error {
url := fmt.Sprintf("/%s/operations/%s", version.APIVersion, r.op.ID)
if r.project != "" {
url += fmt.Sprintf("?project=%s", r.project)
Expand Down
Loading
Loading