diff --git a/cmd/replicatrd/replicatr/alias.go b/cmd/replicatrd/replicatr/alias.go index 8783ca81..502f6a74 100644 --- a/cmd/replicatrd/replicatr/alias.go +++ b/cmd/replicatrd/replicatr/alias.go @@ -5,6 +5,7 @@ import ( "net/http" "sync" + "github.com/Hubmakerlabs/replicatr/pkg/nostr/nip1" "github.com/fasthttp/websocket" "github.com/nbd-wtf/go-nostr" "github.com/nbd-wtf/go-nostr/nip11" @@ -21,6 +22,7 @@ type ( TagMap = nostr.TagMap EventEnvelope = nostr.EventEnvelope OKEnvelope = nostr.OKEnvelope + EventID = nip1.EventID CountEnvelope = nostr.CountEnvelope ClosedEnvelope = nostr.ClosedEnvelope ReqEnvelope = nostr.ReqEnvelope diff --git a/cmd/replicatrd/replicatr/handlecount.go b/cmd/replicatrd/replicatr/handlecount.go new file mode 100644 index 00000000..7e243a82 --- /dev/null +++ b/cmd/replicatrd/replicatr/handlecount.go @@ -0,0 +1,28 @@ +package replicatr + +func (rl *Relay) handleCountRequest(ctx Ctx, ws *WebSocket, + filter *Filter) (subtotal int64) { + + // overwrite the filter (for example, to eliminate some kinds or tags that + // we know we don't support) + for _, ovw := range rl.OverwriteCountFilter { + ovw(ctx, filter) + } + // then check if we'll reject this filter + for _, reject := range rl.RejectCountFilter { + if rej, msg := reject(ctx, filter); rej { + rl.E.Chk(ws.WriteJSON(NoticeEnvelope(msg))) + return 0 + } + } + // run the functions to count (generally it will be just one) + var e error + var res int64 + for _, count := range rl.CountEvents { + if res, e = count(ctx, filter); rl.E.Chk(e) { + rl.E.Chk(ws.WriteJSON(NoticeEnvelope(e.Error()))) + } + subtotal += res + } + return +} diff --git a/cmd/replicatrd/replicatr/deleting.go b/cmd/replicatrd/replicatr/handledelete.go similarity index 100% rename from cmd/replicatrd/replicatr/deleting.go rename to cmd/replicatrd/replicatr/handledelete.go diff --git a/cmd/replicatrd/replicatr/responding.go b/cmd/replicatrd/replicatr/handlefilter.go similarity index 62% rename from cmd/replicatrd/replicatr/responding.go rename to cmd/replicatrd/replicatr/handlefilter.go index a9a70f15..c48ef778 100644 --- a/cmd/replicatrd/replicatr/responding.go +++ b/cmd/replicatrd/replicatr/handlefilter.go @@ -4,10 +4,9 @@ import ( err "errors" "github.com/Hubmakerlabs/replicatr/pkg/nostr/normalize" - "github.com/pkg/errors" ) -func (rl *Relay) handleRequest(ctx Ctx, id string, +func (rl *Relay) handleFilter(ctx Ctx, id string, eose *WaitGroup, ws *WebSocket, f *Filter) (e error) { defer eose.Done() @@ -17,7 +16,7 @@ func (rl *Relay) handleRequest(ctx Ctx, id string, ovw(ctx, f) } if f.Limit < 0 { - e = errors.New("blocked: filter invalidated") + e = err.New("blocked: filter invalidated") rl.E.Chk(e) return } @@ -57,29 +56,3 @@ func (rl *Relay) handleRequest(ctx Ctx, id string, return nil } -func (rl *Relay) handleCountRequest(ctx Ctx, ws *WebSocket, - filter *Filter) (subtotal int64) { - - // overwrite the filter (for example, to eliminate some kinds or tags that - // we know we don't support) - for _, ovw := range rl.OverwriteCountFilter { - ovw(ctx, filter) - } - // then check if we'll reject this filter - for _, reject := range rl.RejectCountFilter { - if rej, msg := reject(ctx, filter); rej { - rl.E.Chk(ws.WriteJSON(NoticeEnvelope(msg))) - return 0 - } - } - // run the functions to count (generally it will be just one) - var e error - var res int64 - for _, count := range rl.CountEvents { - if res, e = count(ctx, filter); rl.E.Chk(e) { - rl.E.Chk(ws.WriteJSON(NoticeEnvelope(e.Error()))) - } - subtotal += res - } - return -} diff --git a/cmd/replicatrd/replicatr/handlenip11.go b/cmd/replicatrd/replicatr/handlenip11.go new file mode 100644 index 00000000..e1aa85b0 --- /dev/null +++ b/cmd/replicatrd/replicatr/handlenip11.go @@ -0,0 +1,15 @@ +package replicatr + +import ( + "encoding/json" + "net/http" +) + +func (rl *Relay) HandleNIP11(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/nostr+json") + info := rl.Info + for _, ovw := range rl.OverwriteRelayInfo { + info = ovw(r.Context(), r, info) + } + rl.E.Chk(json.NewEncoder(w).Encode(info)) +} diff --git a/cmd/replicatrd/replicatr/helpers.go b/cmd/replicatrd/replicatr/helpers.go index cf176012..5b0f5538 100644 --- a/cmd/replicatrd/replicatr/helpers.go +++ b/cmd/replicatrd/replicatr/helpers.go @@ -7,15 +7,52 @@ import ( "strings" "unsafe" + "github.com/nbd-wtf/go-nostr" + "github.com/sebest/xff" log2 "mleku.online/git/log" ) +const ( + wsKey = iota + subscriptionIdKey +) + var ( log = log2.GetLogger() fails = log.D.Chk hexDecode, encodeToHex = hex.DecodeString, hex.EncodeToString ) +func RequestAuth(ctx Ctx) { + ws := GetConnection(ctx) + ws.authLock.Lock() + if ws.Authed == nil { + ws.Authed = make(chan struct{}) + } + ws.authLock.Unlock() + log.E.Chk(ws.WriteJSON(nostr.AuthEnvelope{Challenge: &ws.Challenge})) +} + +func GetConnection(ctx Ctx) *WebSocket { return ctx.Value(wsKey).(*WebSocket) } + +func GetAuthed(ctx Ctx) string { return GetConnection(ctx).AuthedPublicKey } + +func GetIP(ctx Ctx) string { return xff.GetRemoteAddr(GetConnection(ctx).Request) } + +func GetSubscriptionID(ctx Ctx) string { return ctx.Value(subscriptionIdKey).(string) } + +func GetOpenSubscriptions(ctx Ctx) Filters { + if subs, ok := listeners.Load(GetConnection(ctx)); ok { + res := make(Filters, 0, listeners.Size()*2) + subs.Range(func(_ string, sub *Listener) bool { + res = append(res, sub.filters...) + return true + }) + return res + } + return nil +} + func pointerHasher[V any](_ maphash.Seed, k *V) uint64 { return uint64(uintptr(unsafe.Pointer(k))) } diff --git a/cmd/replicatrd/replicatr/http.go b/cmd/replicatrd/replicatr/http.go new file mode 100644 index 00000000..8af6fba5 --- /dev/null +++ b/cmd/replicatrd/replicatr/http.go @@ -0,0 +1,22 @@ +package replicatr + +import ( + "net/http" + + "github.com/rs/cors" +) + +// ServeHTTP implements http.Handler interface. +func (rl *Relay) ServeHTTP(w ResponseWriter, r *Request) { + if rl.ServiceURL == "" { + rl.ServiceURL = getServiceBaseURL(r) + } + if r.Header.Get("Upgrade") == "websocket" { + rl.HandleWebsocket(w, r) + } else if r.Header.Get("Accept") == "application/nostr+json" { + cors.AllowAll().Handler(http.HandlerFunc(rl.HandleNIP11)).ServeHTTP(w, r) + } else { + rl.serveMux.ServeHTTP(w, r) + } +} + diff --git a/cmd/replicatrd/replicatr/nip04.go b/cmd/replicatrd/replicatr/nip4.go similarity index 88% rename from cmd/replicatrd/replicatr/nip04.go rename to cmd/replicatrd/replicatr/nip4.go index 0c501f16..6c5d5891 100644 --- a/cmd/replicatrd/replicatr/nip04.go +++ b/cmd/replicatrd/replicatr/nip4.go @@ -4,9 +4,9 @@ import ( "golang.org/x/exp/slices" ) -// RejectKind04Snoopers prevents reading NIP-04 messages from people not +// RejectKind4Snoopers prevents reading NIP-04 messages from people not // involved in the conversation. -func RejectKind04Snoopers(ctx Ctx, filter *Filter) (bool, string) { +func RejectKind4Snoopers(ctx Ctx, filter *Filter) (bool, string) { // prevent kind-4 events from being returned to unauthed users, only when // authentication is a thing if !slices.Contains(filter.Kinds, 4) { diff --git a/cmd/replicatrd/replicatr/events.go b/cmd/replicatrd/replicatr/policiesevents.go similarity index 99% rename from cmd/replicatrd/replicatr/events.go rename to cmd/replicatrd/replicatr/policiesevents.go index eec7cd67..a3e3bf6f 100644 --- a/cmd/replicatrd/replicatr/events.go +++ b/cmd/replicatrd/replicatr/policiesevents.go @@ -28,12 +28,10 @@ func PreventTooManyIndexableTags(max int, ignoreKinds []int, return !isApplicable } } - return func(ctx Ctx, event *Event) (reject bool, msg string) { if ignore(event.Kind) { return false, "" } - ntags := 0 for _, tag := range event.Tags { if len(tag) > 0 && len(tag[0]) == 1 { diff --git a/cmd/replicatrd/replicatr/filters.go b/cmd/replicatrd/replicatr/policiesfilters.go similarity index 95% rename from cmd/replicatrd/replicatr/filters.go rename to cmd/replicatrd/replicatr/policiesfilters.go index 58c23274..a632d811 100644 --- a/cmd/replicatrd/replicatr/filters.go +++ b/cmd/replicatrd/replicatr/policiesfilters.go @@ -3,7 +3,6 @@ package replicatr import ( "context" - "github.com/nbd-wtf/go-nostr" "golang.org/x/exp/slices" ) @@ -47,7 +46,7 @@ func RemoveSearchQueries(ctx Ctx, filter *Filter) { filter.Search = "" } -func RemoveAllButKinds(kinds ...uint16) func(Ctx, *nostr.Filter) { +func RemoveAllButKinds(kinds ...uint16) func(Ctx, *Filter) { return func(ctx Ctx, filter *Filter) { if n := len(filter.Kinds); n > 0 { newKinds := make([]int, 0, n) diff --git a/cmd/replicatrd/replicatr/relay.go b/cmd/replicatrd/replicatr/relay.go index 817725d0..4e149b8c 100644 --- a/cmd/replicatrd/replicatr/relay.go +++ b/cmd/replicatrd/replicatr/relay.go @@ -35,29 +35,6 @@ type ( OnEventSaved func(ctx Ctx, event *Event) ) -func NewRelay(appName string) (r *Relay) { - r = &Relay{ - Log: log2.New(os.Stderr, appName, 0), - Info: &Info{ - Software: "https://github.com/Hubmakerlabs/replicatr/cmd/khatru", - Version: "n/a", - SupportedNIPs: make([]int, 0), - }, - upgrader: websocket.Upgrader{ - ReadBufferSize: ReadBufferSize, - WriteBufferSize: WriteBufferSize, - CheckOrigin: func(r *http.Request) bool { return true }, - }, - clients: xsync.NewTypedMapOf[*websocket.Conn, struct{}](pointerHasher[websocket.Conn]), - serveMux: &http.ServeMux{}, - WriteWait: WriteWait, - PongWait: PongWait, - PingPeriod: PingPeriod, - MaxMessageSize: MaxMessageSize, - } - return -} - type Relay struct { ServiceURL string RejectEvent []RejectEvent @@ -92,3 +69,26 @@ type Relay struct { PingPeriod time.Duration // Send pings to peer with this period. Must be less than pongWait. MaxMessageSize int64 // Maximum message size allowed from peer. } + +func NewRelay(appName string) (r *Relay) { + r = &Relay{ + Log: log2.New(os.Stderr, appName, 0), + Info: &Info{ + Software: "https://github.com/Hubmakerlabs/replicatr/cmd/khatru", + Version: "n/a", + SupportedNIPs: make([]int, 0), + }, + upgrader: websocket.Upgrader{ + ReadBufferSize: ReadBufferSize, + WriteBufferSize: WriteBufferSize, + CheckOrigin: func(r *http.Request) bool { return true }, + }, + clients: xsync.NewTypedMapOf[*websocket.Conn, struct{}](pointerHasher[websocket.Conn]), + serveMux: &http.ServeMux{}, + WriteWait: WriteWait, + PongWait: PongWait, + PingPeriod: PingPeriod, + MaxMessageSize: MaxMessageSize, + } + return +} diff --git a/cmd/replicatrd/replicatr/utils.go b/cmd/replicatrd/replicatr/utils.go deleted file mode 100644 index 88b14f54..00000000 --- a/cmd/replicatrd/replicatr/utils.go +++ /dev/null @@ -1,41 +0,0 @@ -package replicatr - -import ( - "github.com/nbd-wtf/go-nostr" - "github.com/sebest/xff" -) - -const ( - wsKey = iota - subscriptionIdKey -) - -func RequestAuth(ctx Ctx) { - ws := GetConnection(ctx) - ws.authLock.Lock() - if ws.Authed == nil { - ws.Authed = make(chan struct{}) - } - ws.authLock.Unlock() - log.E.Chk(ws.WriteJSON(nostr.AuthEnvelope{Challenge: &ws.Challenge})) -} - -func GetConnection(ctx Ctx) *WebSocket { return ctx.Value(wsKey).(*WebSocket) } - -func GetAuthed(ctx Ctx) string { return GetConnection(ctx).AuthedPublicKey } - -func GetIP(ctx Ctx) string { return xff.GetRemoteAddr(GetConnection(ctx).Request) } - -func GetSubscriptionID(ctx Ctx) string { return ctx.Value(subscriptionIdKey).(string) } - -func GetOpenSubscriptions(ctx Ctx) Filters { - if subs, ok := listeners.Load(GetConnection(ctx)); ok { - res := make(Filters, 0, listeners.Size()*2) - subs.Range(func(_ string, sub *Listener) bool { - res = append(res, sub.filters...) - return true - }) - return res - } - return nil -} diff --git a/cmd/replicatrd/replicatr/handlers.go b/cmd/replicatrd/replicatr/websockethandler.go similarity index 81% rename from cmd/replicatrd/replicatr/handlers.go rename to cmd/replicatrd/replicatr/websockethandler.go index 5aa53b8e..547f37ff 100644 --- a/cmd/replicatrd/replicatr/handlers.go +++ b/cmd/replicatrd/replicatr/websockethandler.go @@ -5,33 +5,15 @@ import ( "crypto/rand" "crypto/sha256" "encoding/hex" - "encoding/json" - err "errors" - "net/http" + "errors" "strings" - "sync" "time" "github.com/fasthttp/websocket" "github.com/nbd-wtf/go-nostr" "github.com/nbd-wtf/go-nostr/nip42" - "github.com/rs/cors" ) -// ServeHTTP implements http.Handler interface. -func (rl *Relay) ServeHTTP(w ResponseWriter, r *Request) { - if rl.ServiceURL == "" { - rl.ServiceURL = getServiceBaseURL(r) - } - if r.Header.Get("Upgrade") == "websocket" { - rl.HandleWebsocket(w, r) - } else if r.Header.Get("Accept") == "application/nostr+json" { - cors.AllowAll().Handler(http.HandlerFunc(rl.HandleNIP11)).ServeHTTP(w, r) - } else { - rl.serveMux.ServeHTTP(w, r) - } -} - func (rl *Relay) HandleWebsocket(w ResponseWriter, r *Request) { var e error var conn *Conn @@ -69,11 +51,11 @@ func (rl *Relay) HandleWebsocket(w ResponseWriter, r *Request) { removeListener(ws) } } - go rl.readMessages(ctx, kill, ws, conn, r) - go rl.watcher(ctx, kill, ticker, ws) + go rl.websocketReadMessages(ctx, kill, ws, conn, r) + go rl.websocketWatcher(ctx, kill, ticker, ws) } -func (rl *Relay) processMessages(message []byte, ctx Ctx, ws *WebSocket) { +func (rl *Relay) websocketProcessMessages(message []byte, ctx Ctx, ws *WebSocket) { var e error envelope := nostr.ParseMessage(message) if envelope == nil { @@ -136,8 +118,8 @@ func (rl *Relay) processMessages(message []byte, ctx Ctx, ws *WebSocket) { if rl.CountEvents == nil { rl.E.Chk(ws.WriteJSON(ClosedEnvelope{ SubscriptionID: env.SubscriptionID, - Reason: "unsupported: this relay does not support NIP-45"}, - )) + Reason: "unsupported: this relay does not support NIP-45", + })) return } var total int64 @@ -149,7 +131,7 @@ func (rl *Relay) processMessages(message []byte, ctx Ctx, ws *WebSocket) { Count: &total, })) case *ReqEnvelope: - eose := sync.WaitGroup{} + eose := WaitGroup{} eose.Add(len(env.Filters)) // a context just for the "stored events" request handler reqCtx, cancelReqCtx := context.WithCancelCause(ctx) @@ -157,7 +139,7 @@ func (rl *Relay) processMessages(message []byte, ctx Ctx, ws *WebSocket) { reqCtx = context.WithValue(reqCtx, subscriptionIdKey, env.SubscriptionID) // handle each filter separately -- dispatching events as they're loaded from databases for _, filter := range env.Filters { - e = rl.handleRequest(reqCtx, env.SubscriptionID, &eose, ws, &filter) + e = rl.handleFilter(reqCtx, env.SubscriptionID, &eose, ws, &filter) if rl.E.Chk(e) { // fail everything if any filter is rejected reason := e.Error() @@ -168,7 +150,7 @@ func (rl *Relay) processMessages(message []byte, ctx Ctx, ws *WebSocket) { SubscriptionID: env.SubscriptionID, Reason: reason}, )) - cancelReqCtx(err.New("filter rejected")) + cancelReqCtx(errors.New("filter rejected")) return } } @@ -205,7 +187,10 @@ func (rl *Relay) processMessages(message []byte, ctx Ctx, ws *WebSocket) { } } } -func (rl *Relay) readMessages(ctx Ctx, kill func(), ws *WebSocket, conn *Conn, r *Request) { + +func (rl *Relay) websocketReadMessages(ctx Ctx, kill func(), + ws *WebSocket, conn *Conn, r *Request) { + defer kill() conn.SetReadLimit(rl.MaxMessageSize) rl.E.Chk(conn.SetReadDeadline(time.Now().Add(rl.PongWait))) @@ -239,11 +224,11 @@ func (rl *Relay) readMessages(ctx Ctx, kill func(), ws *WebSocket, conn *Conn, r rl.E.Chk(ws.WriteMessage(websocket.PongMessage, nil)) continue } - go rl.processMessages(message, ctx, ws) + go rl.websocketProcessMessages(message, ctx, ws) } } -func (rl *Relay) watcher(ctx Ctx, kill func(), ticker *time.Ticker, ws *WebSocket) { +func (rl *Relay) websocketWatcher(ctx Ctx, kill func(), ticker *time.Ticker, ws *WebSocket) { var e error defer kill() for { @@ -260,12 +245,3 @@ func (rl *Relay) watcher(ctx Ctx, kill func(), ticker *time.Ticker, ws *WebSocke } } } - -func (rl *Relay) HandleNIP11(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/nostr+json") - info := rl.Info - for _, ovw := range rl.OverwriteRelayInfo { - info = ovw(r.Context(), r, info) - } - rl.E.Chk(json.NewEncoder(w).Encode(info)) -}