forked from jmwample/mp-quic
-
Notifications
You must be signed in to change notification settings - Fork 0
/
streams_map.go
333 lines (285 loc) · 8.32 KB
/
streams_map.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
package quic
import (
"errors"
"fmt"
"sync"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/qerr"
)
// streamsMap holds every stream of a connection together with the bookkeeping
// needed to open, accept, schedule and remove them. Its state is accessed under
// mutex; both condition variables use mutex as their Locker (wired up in
// newStreamsMap).
type streamsMap struct {
	mutex sync.RWMutex

	perspective          protocol.Perspective
	connectionParameters handshake.ConnectionParametersManager

	streams map[protocol.StreamID]*stream

	// openStreams (in opening order) and roundRobinIndex drive round-robin scheduling.
	openStreams     []protocol.StreamID
	roundRobinIndex uint32

	nextStream                protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
	highestStreamOpenedByPeer protocol.StreamID

	nextStreamOrErrCond sync.Cond // signalled when peer streams appear or the map is closed
	openStreamOrErrCond sync.Cond // signalled when a stream is removed or the map is closed

	closeErr           error
	nextStreamToAccept protocol.StreamID

	newStream newStreamLambda

	// Stream counters keyed by stream ID parity (see RemoveStream):
	// even IDs are counted in numOutgoingStreams, odd IDs in numIncomingStreams.
	numOutgoingStreams uint32
	numIncomingStreams uint32
}

// streamLambda is applied to streams by Iterate and RoundRobinIterate;
// returning false stops the iteration.
type streamLambda func(*stream) (bool, error)

// newStreamLambda constructs the stream with the given ID.
type newStreamLambda func(protocol.StreamID) *stream

// errMapAccess is returned by iterateFunc when a stream ID has no map entry.
var errMapAccess = errors.New("streamsMap: Error accessing the streams map")
// newStreamsMap returns an empty streamsMap. A client opens odd stream IDs
// (starting at 1) and accepts even ones (starting at 2); any other perspective
// (i.e. the server) opens even IDs starting at 2 and accepts odd ones from 1.
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connectionParameters handshake.ConnectionParametersManager) *streamsMap {
	firstOwn, firstPeer := protocol.StreamID(2), protocol.StreamID(1)
	if pers == protocol.PerspectiveClient {
		firstOwn, firstPeer = 1, 2
	}

	m := &streamsMap{
		perspective:          pers,
		connectionParameters: connectionParameters,
		streams:              make(map[protocol.StreamID]*stream),
		openStreams:          []protocol.StreamID{},
		newStream:            newStream,
		nextStream:           firstOwn,
		nextStreamToAccept:   firstPeer,
	}
	// Both condition variables share the map's mutex.
	m.nextStreamOrErrCond.L = &m.mutex
	m.openStreamOrErrCond.L = &m.mutex
	return m
}
// GetOrOpenStream looks up the stream with the given ID. If it is not present
// and id is a peer-initiated ID above any the peer has opened so far, every
// peer stream up to and including id is opened and the requested one returned.
// It returns (nil, nil) for IDs that are treated as already closed, and an
// InvalidStreamID error when the peer references one of our IDs that we have
// not handed out yet.
// Streams initiated by this endpoint must be created with OpenStream instead.
func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
	// Fast path: a shared lock is enough to find an existing entry.
	m.mutex.RLock()
	str, exists := m.streams[id]
	m.mutex.RUnlock()
	if exists {
		return str, nil // str may be nil
	}

	m.mutex.Lock()
	defer m.mutex.Unlock()

	// Another goroutine may have created the stream between RUnlock and Lock.
	if str, exists = m.streams[id]; exists {
		return str, nil
	}

	switch m.perspective {
	case protocol.PerspectiveServer:
		if id%2 == 0 {
			// Even IDs belong to the server; IDs up to nextStream are treated as already closed.
			if id > m.nextStream {
				return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id))
			}
			return nil, nil
		}
		if id <= m.highestStreamOpenedByPeer {
			// A client stream that no longer exists: it was closed already.
			return nil, nil
		}
	case protocol.PerspectiveClient:
		if id%2 == 1 {
			// Odd IDs belong to the client; IDs up to nextStream are treated as already closed.
			if id > m.nextStream {
				return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id))
			}
			return nil, nil
		}
		if id <= m.highestStreamOpenedByPeer {
			// A server stream that no longer exists: it was closed already.
			return nil, nil
		}
	}

	// Open every peer stream after the highest one seen so far, up to id.
	// A server that has not seen any client stream yet starts at stream 1.
	first := m.highestStreamOpenedByPeer + 2
	if m.perspective == protocol.PerspectiveServer && first == 2 {
		first = 1
	}
	for next := first; next <= id; next += 2 {
		if _, err := m.openRemoteStream(next); err != nil {
			return nil, err
		}
	}

	m.nextStreamOrErrCond.Broadcast()
	return m.streams[id], nil
}
// openRemoteStream creates and registers the peer-initiated stream id and
// records it as the highest peer stream if it is. The caller holds m.mutex
// for writing (GetOrOpenStream does).
//
// The stream counters are keyed by ID parity, matching RemoveStream: even IDs
// are counted in numOutgoingStreams, odd IDs in numIncomingStreams. The
// incoming-stream limit is therefore checked against the counter this id is
// added to. Checking numIncomingStreams unconditionally meant a client compared
// its own (odd) stream count against the incoming limit while counting the
// server's streams somewhere else.
func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) {
	counter := &m.numIncomingStreams
	if id%2 == 0 {
		counter = &m.numOutgoingStreams
	}
	if *counter >= m.connectionParameters.GetMaxIncomingStreams() {
		return nil, qerr.TooManyOpenStreams
	}
	if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer {
		return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer))
	}

	*counter++
	if id > m.highestStreamOpenedByPeer {
		m.highestStreamOpenedByPeer = id
	}

	s := m.newStream(id)
	m.putStream(s)
	return s, nil
}
// openStreamImpl opens the next stream initiated by this endpoint (ID
// m.nextStream) and advances nextStream by 2. It returns
// qerr.TooManyOpenStreams when the outgoing-stream limit is reached. The caller
// must hold m.mutex for writing.
//
// The stream counters are keyed by ID parity, matching RemoveStream: even IDs
// are counted in numOutgoingStreams, odd IDs in numIncomingStreams. The limit
// is checked against the counter this stream is added to. Previously a client,
// whose own streams are odd and increment numIncomingStreams, was checked
// against numOutgoingStreams, so its outgoing limit was never enforced.
func (m *streamsMap) openStreamImpl() (*stream, error) {
	id := m.nextStream
	counter := &m.numIncomingStreams
	if id%2 == 0 {
		counter = &m.numOutgoingStreams
	}
	if *counter >= m.connectionParameters.GetMaxOutgoingStreams() {
		return nil, qerr.TooManyOpenStreams
	}

	*counter++
	m.nextStream += 2

	s := m.newStream(id)
	m.putStream(s)
	return s, nil
}
// OpenStream opens the next available stream without blocking. Once
// CloseWithError has been called it returns the recorded close error instead;
// it also fails with qerr.TooManyOpenStreams when the outgoing limit is reached.
func (m *streamsMap) OpenStream() (*stream, error) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	if closeErr := m.closeErr; closeErr != nil {
		return nil, closeErr
	}
	return m.openStreamImpl()
}
// OpenStreamSync opens the next available stream, blocking while the
// outgoing-stream limit is reached. It waits on openStreamOrErrCond, which is
// signalled when a stream is removed or the map is closed. It returns the close
// error once CloseWithError has been called, and any error from opening the
// stream other than qerr.TooManyOpenStreams.
func (m *streamsMap) OpenStreamSync() (*stream, error) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	for {
		if m.closeErr != nil {
			return nil, m.closeErr
		}
		str, err := m.openStreamImpl()
		switch err {
		case nil:
			return str, nil
		case qerr.TooManyOpenStreams:
			// Wait releases the mutex until a slot may have freed up.
			m.openStreamOrErrCond.Wait()
		default:
			return nil, err
		}
	}
}
// AcceptStream returns the next stream opened by the peer, in stream ID order,
// blocking until it is present in the map. It returns the close error once
// CloseWithError has been called.
func (m *streamsMap) AcceptStream() (*stream, error) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	for {
		if m.closeErr != nil {
			return nil, m.closeErr
		}
		if str, found := m.streams[m.nextStreamToAccept]; found {
			// Peer stream IDs share one parity, so the next one is two higher.
			m.nextStreamToAccept += 2
			return str, nil
		}
		m.nextStreamOrErrCond.Wait()
	}
}
// Iterate calls fn on every open stream in the order the streams were opened,
// stopping when fn returns false or an error. The loop runs over a copy of the
// open stream IDs, so changes fn makes to openStreams do not disturb it. If a
// stream has disappeared from the map before it is visited, Iterate returns
// errMapAccess.
func (m *streamsMap) Iterate(fn streamLambda) error {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	snapshot := make([]protocol.StreamID, len(m.openStreams))
	copy(snapshot, m.openStreams)

	for _, id := range snapshot {
		cont, err := m.iterateFunc(id, fn)
		if err != nil {
			return err
		}
		if !cont {
			return nil
		}
	}
	return nil
}
// RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false.
// It uses a round-robin-like scheduling to ensure that every stream is considered fairly.
// The crypto and header streams (StreamIDs 1 and 3) are always visited first,
// and their absence from the map is not an error. Every other stream is visited
// starting at roundRobinIndex, which advances by one for each stream served.
func (m *streamsMap) RoundRobinIterate(fn streamLambda) error {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	// The scheduling loop indexes openStreams, so its bound must be
	// len(openStreams). Using len(streams) would index out of range, or skip
	// streams, whenever the map and the slice disagree in size.
	numStreams := uint32(len(m.openStreams))
	startIndex := m.roundRobinIndex

	for _, id := range []protocol.StreamID{1, 3} {
		cont, err := m.iterateFunc(id, fn)
		if err != nil && err != errMapAccess {
			return err
		}
		if !cont {
			return nil
		}
	}

	for i := uint32(0); i < numStreams; i++ {
		streamID := m.openStreams[(i+startIndex)%numStreams]
		if streamID == 1 || streamID == 3 {
			continue // already served above
		}
		cont, err := m.iterateFunc(streamID, fn)
		if err != nil {
			return err
		}
		m.roundRobinIndex = (m.roundRobinIndex + 1) % numStreams
		if !cont {
			break
		}
	}
	return nil
}
// iterateFunc applies fn to the stream registered under streamID and returns
// fn's result. When the ID has no map entry, fn is not called and
// (true, errMapAccess) is returned.
func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) {
	if str, found := m.streams[streamID]; found {
		return fn(str)
	}
	return true, errMapAccess
}
// putStream registers s under its stream ID and appends the ID to openStreams.
// If the ID is already registered, nothing is changed and an error is returned.
func (m *streamsMap) putStream(s *stream) error {
	id := s.StreamID()
	_, duplicate := m.streams[id]
	if duplicate {
		return fmt.Errorf("a stream with ID %d already exists", id)
	}

	m.streams[id] = s
	m.openStreams = append(m.openStreams, id)
	return nil
}
// RemoveStream deletes the stream with the given ID. It decrements the stream
// counter for the ID's parity, drops the ID from openStreams (keeping the
// round-robin cursor on the same next stream), and wakes one waiter on
// openStreamOrErrCond (OpenStreamSync waits there). It returns an error if no
// non-nil stream is registered under id.
// Attention: this function must only be called if a mutex has been acquired previously
func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
	if str, ok := m.streams[id]; !ok || str == nil {
		return fmt.Errorf("attempted to remove non-existing stream: %d", id)
	}

	// Counters are keyed by ID parity: even IDs are opened by the server, odd IDs by the client.
	switch id % 2 {
	case 0:
		m.numOutgoingStreams--
	default:
		m.numIncomingStreams--
	}

	for idx, openID := range m.openStreams {
		if openID != id {
			continue
		}
		m.openStreams = append(m.openStreams[:idx], m.openStreams[idx+1:]...)
		// Entries before the cursor shifted left by one; move the cursor with them.
		if uint32(idx) < m.roundRobinIndex {
			m.roundRobinIndex--
		}
		break
	}

	delete(m.streams, id)
	m.openStreamOrErrCond.Signal()
	return nil
}
// CloseWithError records err as the map's close error and wakes all goroutines
// blocked in AcceptStream or OpenStreamSync so they can return it. It then
// cancels every open stream with err.
func (m *streamsMap) CloseWithError(err error) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	m.closeErr = err
	m.nextStreamOrErrCond.Broadcast()
	m.openStreamOrErrCond.Broadcast()

	for _, id := range m.openStreams {
		str := m.streams[id]
		str.Cancel(err)
	}
}