Files

196 lines
5.8 KiB
Go
Raw Permalink Normal View History

2023-04-21 02:49:06 +08:00
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package web
import (
2026-03-29 12:24:30 +02:00
"bufio"
2023-04-21 02:49:06 +08:00
"fmt"
2026-03-29 12:24:30 +02:00
"net"
2023-04-21 02:49:06 +08:00
"net/http"
"reflect"
"code.gitea.io/gitea/modules/log"
2023-04-21 02:49:06 +08:00
"code.gitea.io/gitea/modules/web/routing"
"code.gitea.io/gitea/modules/web/types"
2023-04-21 02:49:06 +08:00
)
var responseStatusProviders = map[reflect.Type]func(req *http.Request) types.ResponseStatusProvider{}
2023-04-21 02:49:06 +08:00
func RegisterResponseStatusProvider[T any](fn func(req *http.Request) types.ResponseStatusProvider) {
2025-08-28 11:13:31 +08:00
responseStatusProviders[reflect.TypeFor[T]()] = fn
}
2023-04-21 02:49:06 +08:00
// responseWriter is a wrapper of http.ResponseWriter, to check whether the response has been written
type responseWriter struct {
respWriter http.ResponseWriter
status int
}
var _ types.ResponseStatusProvider = (*responseWriter)(nil)
2023-04-21 02:49:06 +08:00
func (r *responseWriter) WrittenStatus() int {
return r.status
2023-04-21 02:49:06 +08:00
}
func (r *responseWriter) Header() http.Header {
return r.respWriter.Header()
}
func (r *responseWriter) Write(bytes []byte) (int, error) {
if r.status == 0 {
r.status = http.StatusOK
}
return r.respWriter.Write(bytes)
}
func (r *responseWriter) WriteHeader(statusCode int) {
r.status = statusCode
r.respWriter.WriteHeader(statusCode)
}
2026-03-29 12:24:30 +02:00
func (r *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := r.respWriter.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, http.ErrNotSupported
}
2023-04-21 02:49:06 +08:00
var (
2025-08-28 11:13:31 +08:00
httpReqType = reflect.TypeFor[*http.Request]()
respWriterType = reflect.TypeFor[http.ResponseWriter]()
2023-04-21 02:49:06 +08:00
)
// preCheckHandler checks whether the handler is valid, developers could get first-time feedback, all mistakes could be found at startup
func preCheckHandler(fn reflect.Value, argsIn []reflect.Value) {
hasStatusProvider := false
for _, argIn := range argsIn {
if _, hasStatusProvider = argIn.Interface().(types.ResponseStatusProvider); hasStatusProvider {
2023-04-21 02:49:06 +08:00
break
}
}
if !hasStatusProvider {
panic(fmt.Sprintf("handler should have at least one ResponseStatusProvider argument, but got %s", fn.Type()))
}
2024-12-24 11:43:57 +08:00
if fn.Type().NumOut() != 0 {
panic(fmt.Sprintf("handler should have no return value other than registered ones, but got %s", fn.Type()))
2023-04-21 02:49:06 +08:00
}
}
func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect.Value, fnInfo *routing.FuncInfo) []reflect.Value {
defer func() {
2026-03-08 17:59:46 +08:00
if recovered := recover(); recovered != nil {
err := fmt.Errorf("%v\n%s", recovered, log.Stack(2))
log.Error("unable to prepare handler arguments for %s: %v", fnInfo.String(), err)
panic(err)
}
}()
2023-04-21 02:49:06 +08:00
isPreCheck := req == nil
argsIn := make([]reflect.Value, fn.Type().NumIn())
for i := 0; i < fn.Type().NumIn(); i++ {
argTyp := fn.Type().In(i)
switch argTyp {
case respWriterType:
argsIn[i] = reflect.ValueOf(resp)
case httpReqType:
argsIn[i] = reflect.ValueOf(req)
default:
if argFn, ok := responseStatusProviders[argTyp]; ok {
2023-04-21 02:49:06 +08:00
if isPreCheck {
argsIn[i] = reflect.ValueOf(&responseWriter{})
} else {
argsIn[i] = reflect.ValueOf(argFn(req))
}
} else {
panic(fmt.Sprintf("unsupported argument type: %s", argTyp))
}
}
}
return argsIn
}
2024-12-24 11:43:57 +08:00
// handleResponse processes the values returned by a reflectively-called handler.
// Handlers may not return anything (preCheckHandler rejects such signatures at
// startup), so a non-empty ret is a programming error and causes a panic.
func handleResponse(fn reflect.Value, ret []reflect.Value) {
	if len(ret) == 0 {
		return
	}
	panic(fmt.Sprintf("unsupported return values: %s", fn.Type()))
}
func hasResponseBeenWritten(argsIn []reflect.Value) bool {
for _, argIn := range argsIn {
if statusProvider, ok := argIn.Interface().(types.ResponseStatusProvider); ok {
if statusProvider.WrittenStatus() != 0 {
2023-04-21 02:49:06 +08:00
return true
}
}
}
return false
}
2026-03-08 17:59:46 +08:00
// middlewareProvider is the standard middleware shape: it receives the next
// handler in the chain and returns a handler that wraps it.
type middlewareProvider = func(next http.Handler) http.Handler

// executeMiddlewaresHandler serves the request through middlewares and then endpoint.
// middlewares[0] is the outermost layer. It runs first and decides whether (and
// when) the remaining middlewares and the endpoint run.
func executeMiddlewaresHandler(w http.ResponseWriter, r *http.Request, middlewares []middlewareProvider, endpoint http.HandlerFunc) {
	chain := endpoint
	// Build from the inside out: the last middleware wraps the endpoint directly.
	for n := len(middlewares); n > 0; n-- {
		wrap := middlewares[n-1]
		chain = wrap(chain).ServeHTTP
	}
	chain(w, r)
}
func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo *routing.FuncInfo) middlewareProvider {
2024-04-21 08:53:45 +08:00
return func(next http.Handler) http.Handler {
h := hp(next) // this handle could be dynamically generated, so we can't use it for debug info
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
defer routing.RecordFuncInfo(req.Context(), funcInfo)()
2024-04-21 08:53:45 +08:00
h.ServeHTTP(resp, req)
})
}
}
2023-04-21 02:49:06 +08:00
// toHandlerProvider converts a handler to a handler provider
// A handler provider is a function that takes a "next" http.Handler, it can be used as a middleware
2026-03-08 17:59:46 +08:00
func toHandlerProvider(handler any) middlewareProvider {
2023-04-21 02:49:06 +08:00
funcInfo := routing.GetFuncInfo(handler)
fn := reflect.ValueOf(handler)
if fn.Type().Kind() != reflect.Func {
panic(fmt.Sprintf("handler must be a function, but got %s", fn.Type()))
}
2026-03-08 17:59:46 +08:00
if hp, ok := handler.(middlewareProvider); ok {
2024-04-21 08:53:45 +08:00
return wrapHandlerProvider(hp, funcInfo)
} else if hp, ok := handler.(func(http.Handler) http.HandlerFunc); ok {
return wrapHandlerProvider(hp, funcInfo)
}
2023-04-21 02:49:06 +08:00
provider := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(respOrig http.ResponseWriter, req *http.Request) {
// wrap the response writer to check whether the response has been written
resp := respOrig
if _, ok := resp.(types.ResponseStatusProvider); !ok {
2023-04-21 02:49:06 +08:00
resp = &responseWriter{respWriter: resp}
}
// prepare the arguments for the handler and do pre-check
argsIn := prepareHandleArgsIn(resp, req, fn, funcInfo)
2023-04-21 02:49:06 +08:00
if req == nil {
preCheckHandler(fn, argsIn)
return // it's doing pre-check, just return
}
defer routing.RecordFuncInfo(req.Context(), funcInfo)()
2023-04-21 02:49:06 +08:00
ret := fn.Call(argsIn)
2024-12-24 11:43:57 +08:00
// handle the return value (no-op at the moment)
handleResponse(fn, ret)
2023-04-21 02:49:06 +08:00
// if the response has not been written, call the next handler
if next != nil && !hasResponseBeenWritten(argsIn) {
next.ServeHTTP(resp, req)
}
})
}
provider(nil).ServeHTTP(nil, nil) // do a pre-check to make sure all arguments and return values are supported
return provider
}