Source file
src/net/rpc/server.go
1
2
3
4
5
127 package rpc
128
129 import (
130 "bufio"
131 "encoding/gob"
132 "errors"
133 "go/token"
134 "io"
135 "log"
136 "net"
137 "net/http"
138 "reflect"
139 "strings"
140 "sync"
141 )
142
const (
	// DefaultRPCPath and DefaultDebugPath are the URL paths that the
	// package-level HandleHTTP passes to DefaultServer.HandleHTTP for the
	// RPC handler and the debugging handler respectively.
	DefaultRPCPath   = "/_goRPC_"
	DefaultDebugPath = "/debug/rpc"
)
148
149
150
// typeOfError is the reflect.Type of the error interface. suitableMethods
// compares each candidate method's single return type against it.
var typeOfError = reflect.TypeOf((*error)(nil)).Elem()
152
// methodType describes one registered RPC method: the reflected method
// itself, the types of its argument and reply parameters, and a count of
// how many times it has been invoked.
type methodType struct {
	sync.Mutex // protects numCalls
	method     reflect.Method
	ArgType    reflect.Type
	ReplyType  reflect.Type
	numCalls   uint
}
160
// service is one registered receiver together with the methods of it that
// suitableMethods accepted.
type service struct {
	name   string                 // name the service is registered under
	rcvr   reflect.Value          // receiver passed as first argument to each method
	typ    reflect.Type           // type of the receiver
	method map[string]*methodType // suitable methods, keyed by method name
}
167
168
169
170
// Request is the header of an RPC call. A ServerCodec fills it in through
// ReadRequestHeader before the request body is read.
type Request struct {
	ServiceMethod string   // format: "Service.Method"; split at the last dot
	Seq           uint64   // copied into the matching Response by sendResponse
	next          *Request // link for the Server's free list
}
176
177
178
179
// Response is the header of an RPC reply. sendResponse fills it in and
// hands it to ServerCodec.WriteResponse together with the reply body.
type Response struct {
	ServiceMethod string    // echoes that of the Request
	Seq           uint64    // echoes that of the Request
	Error         string    // error message, if any; left empty otherwise
	next          *Response // link for the Server's free list
}
186
187
// Server represents an RPC server: a set of registered services plus free
// lists of Request and Response headers that are reused between calls.
type Server struct {
	serviceMap sync.Map   // map[string]*service
	reqLock    sync.Mutex // protects freeReq
	freeReq    *Request   // head of the free list of Request values
	respLock   sync.Mutex // protects freeResp
	freeResp   *Response  // head of the free list of Response values
}
195
196
197 func NewServer() *Server {
198 return &Server{}
199 }
200
201
// DefaultServer is the Server instance used by the package-level functions
// (Register, RegisterName, ServeConn, ServeCodec, ServeRequest, Accept and
// HandleHTTP).
var DefaultServer = NewServer()
203
204
// isExportedOrBuiltinType reports whether t, after stripping any number of
// pointer indirections, is either a type with an empty package path (such
// as a builtin or unnamed type) or a type whose name is exported.
func isExportedOrBuiltinType(t reflect.Type) bool {
	base := t
	for base.Kind() == reflect.Pointer {
		base = base.Elem()
	}
	// A type with no package path is acceptable regardless of its name.
	if base.PkgPath() == "" {
		return true
	}
	// The package path is non-empty even for exported types, so the name
	// decides.
	return token.IsExported(base.Name())
}
213
214
215
216
217
218
219
220
221
222
223
224 func (server *Server) Register(rcvr any) error {
225 return server.register(rcvr, "", false)
226 }
227
228
229
230 func (server *Server) RegisterName(name string, rcvr any) error {
231 return server.register(rcvr, name, true)
232 }
233
234
235
// logRegisterError is passed by register to suitableMethods as its logErr
// argument. It is false, so methods rejected during registration are not
// logged individually.
const logRegisterError = false
237
238 func (server *Server) register(rcvr any, name string, useName bool) error {
239 s := new(service)
240 s.typ = reflect.TypeOf(rcvr)
241 s.rcvr = reflect.ValueOf(rcvr)
242 sname := reflect.Indirect(s.rcvr).Type().Name()
243 if useName {
244 sname = name
245 }
246 if sname == "" {
247 s := "rpc.Register: no service name for type " + s.typ.String()
248 log.Print(s)
249 return errors.New(s)
250 }
251 if !token.IsExported(sname) && !useName {
252 s := "rpc.Register: type " + sname + " is not exported"
253 log.Print(s)
254 return errors.New(s)
255 }
256 s.name = sname
257
258
259 s.method = suitableMethods(s.typ, logRegisterError)
260
261 if len(s.method) == 0 {
262 str := ""
263
264
265 method := suitableMethods(reflect.PointerTo(s.typ), false)
266 if len(method) != 0 {
267 str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
268 } else {
269 str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
270 }
271 log.Print(str)
272 return errors.New(str)
273 }
274
275 if _, dup := server.serviceMap.LoadOrStore(sname, s); dup {
276 return errors.New("rpc: service already defined: " + sname)
277 }
278 return nil
279 }
280
281
282
283 func suitableMethods(typ reflect.Type, logErr bool) map[string]*methodType {
284 methods := make(map[string]*methodType)
285 for m := 0; m < typ.NumMethod(); m++ {
286 method := typ.Method(m)
287 mtype := method.Type
288 mname := method.Name
289
290 if !method.IsExported() {
291 continue
292 }
293
294 if mtype.NumIn() != 3 {
295 if logErr {
296 log.Printf("rpc.Register: method %q has %d input parameters; needs exactly three\n", mname, mtype.NumIn())
297 }
298 continue
299 }
300
301 argType := mtype.In(1)
302 if !isExportedOrBuiltinType(argType) {
303 if logErr {
304 log.Printf("rpc.Register: argument type of method %q is not exported: %q\n", mname, argType)
305 }
306 continue
307 }
308
309 replyType := mtype.In(2)
310 if replyType.Kind() != reflect.Pointer {
311 if logErr {
312 log.Printf("rpc.Register: reply type of method %q is not a pointer: %q\n", mname, replyType)
313 }
314 continue
315 }
316
317 if !isExportedOrBuiltinType(replyType) {
318 if logErr {
319 log.Printf("rpc.Register: reply type of method %q is not exported: %q\n", mname, replyType)
320 }
321 continue
322 }
323
324 if mtype.NumOut() != 1 {
325 if logErr {
326 log.Printf("rpc.Register: method %q has %d output parameters; needs exactly one\n", mname, mtype.NumOut())
327 }
328 continue
329 }
330
331 if returnType := mtype.Out(0); returnType != typeOfError {
332 if logErr {
333 log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, returnType)
334 }
335 continue
336 }
337 methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
338 }
339 return methods
340 }
341
342
343
344
// invalidRequest is the placeholder reply body that sendResponse writes
// whenever the response carries an error message.
var invalidRequest = struct{}{}
346
347 func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply any, codec ServerCodec, errmsg string) {
348 resp := server.getResponse()
349
350 resp.ServiceMethod = req.ServiceMethod
351 if errmsg != "" {
352 resp.Error = errmsg
353 reply = invalidRequest
354 }
355 resp.Seq = req.Seq
356 sending.Lock()
357 err := codec.WriteResponse(resp, reply)
358 if debugLog && err != nil {
359 log.Println("rpc: writing response:", err)
360 }
361 sending.Unlock()
362 server.freeResponse(resp)
363 }
364
365 func (m *methodType) NumCalls() (n uint) {
366 m.Lock()
367 n = m.numCalls
368 m.Unlock()
369 return n
370 }
371
372 func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
373 if wg != nil {
374 defer wg.Done()
375 }
376 mtype.Lock()
377 mtype.numCalls++
378 mtype.Unlock()
379 function := mtype.method.Func
380
381 returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
382
383 errInter := returnValues[0].Interface()
384 errmsg := ""
385 if errInter != nil {
386 errmsg = errInter.(error).Error()
387 }
388 server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
389 server.freeRequest(req)
390 }
391
// gobServerCodec is the ServerCodec that ServeConn builds around a
// connection. It decodes requests and encodes responses with encoding/gob.
type gobServerCodec struct {
	// rwc is the underlying connection; Close closes it.
	rwc io.ReadWriteCloser
	// dec decodes request headers and bodies.
	dec *gob.Decoder
	// enc encodes response headers and bodies; ServeConn points it at encBuf.
	enc *gob.Encoder
	// encBuf buffers encoded output and is flushed by WriteResponse.
	encBuf *bufio.Writer
	// closed records that Close has already closed rwc.
	closed bool
}
399
// ReadRequestHeader decodes the next gob value from the connection into r
// and returns the decoder's error, if any.
func (c *gobServerCodec) ReadRequestHeader(r *Request) error {
	return c.dec.Decode(r)
}
403
// ReadRequestBody decodes the next gob value from the connection into body
// and returns the decoder's error, if any. readRequest passes a nil body
// when it only wants the value consumed.
func (c *gobServerCodec) ReadRequestBody(body any) error {
	return c.dec.Decode(body)
}
407
408 func (c *gobServerCodec) WriteResponse(r *Response, body any) (err error) {
409 if err = c.enc.Encode(r); err != nil {
410 if c.encBuf.Flush() == nil {
411
412
413 log.Println("rpc: gob error encoding response:", err)
414 c.Close()
415 }
416 return
417 }
418 if err = c.enc.Encode(body); err != nil {
419 if c.encBuf.Flush() == nil {
420
421
422 log.Println("rpc: gob error encoding body:", err)
423 c.Close()
424 }
425 return
426 }
427 return c.encBuf.Flush()
428 }
429
430 func (c *gobServerCodec) Close() error {
431 if c.closed {
432
433 return nil
434 }
435 c.closed = true
436 return c.rwc.Close()
437 }
438
439
440
441
442
443
444
445 func (server *Server) ServeConn(conn io.ReadWriteCloser) {
446 buf := bufio.NewWriter(conn)
447 srv := &gobServerCodec{
448 rwc: conn,
449 dec: gob.NewDecoder(conn),
450 enc: gob.NewEncoder(buf),
451 encBuf: buf,
452 }
453 server.ServeCodec(srv)
454 }
455
456
457
458 func (server *Server) ServeCodec(codec ServerCodec) {
459 sending := new(sync.Mutex)
460 wg := new(sync.WaitGroup)
461 for {
462 service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
463 if err != nil {
464 if debugLog && err != io.EOF {
465 log.Println("rpc:", err)
466 }
467 if !keepReading {
468 break
469 }
470
471 if req != nil {
472 server.sendResponse(sending, req, invalidRequest, codec, err.Error())
473 server.freeRequest(req)
474 }
475 continue
476 }
477 wg.Add(1)
478 go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
479 }
480
481
482 wg.Wait()
483 codec.Close()
484 }
485
486
487
488 func (server *Server) ServeRequest(codec ServerCodec) error {
489 sending := new(sync.Mutex)
490 service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
491 if err != nil {
492 if !keepReading {
493 return err
494 }
495
496 if req != nil {
497 server.sendResponse(sending, req, invalidRequest, codec, err.Error())
498 server.freeRequest(req)
499 }
500 return err
501 }
502 service.call(server, sending, nil, mtype, req, argv, replyv, codec)
503 return nil
504 }
505
506 func (server *Server) getRequest() *Request {
507 server.reqLock.Lock()
508 req := server.freeReq
509 if req == nil {
510 req = new(Request)
511 } else {
512 server.freeReq = req.next
513 *req = Request{}
514 }
515 server.reqLock.Unlock()
516 return req
517 }
518
519 func (server *Server) freeRequest(req *Request) {
520 server.reqLock.Lock()
521 req.next = server.freeReq
522 server.freeReq = req
523 server.reqLock.Unlock()
524 }
525
526 func (server *Server) getResponse() *Response {
527 server.respLock.Lock()
528 resp := server.freeResp
529 if resp == nil {
530 resp = new(Response)
531 } else {
532 server.freeResp = resp.next
533 *resp = Response{}
534 }
535 server.respLock.Unlock()
536 return resp
537 }
538
539 func (server *Server) freeResponse(resp *Response) {
540 server.respLock.Lock()
541 resp.next = server.freeResp
542 server.freeResp = resp
543 server.respLock.Unlock()
544 }
545
546 func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
547 service, mtype, req, keepReading, err = server.readRequestHeader(codec)
548 if err != nil {
549 if !keepReading {
550 return
551 }
552
553 codec.ReadRequestBody(nil)
554 return
555 }
556
557
558 argIsValue := false
559 if mtype.ArgType.Kind() == reflect.Pointer {
560 argv = reflect.New(mtype.ArgType.Elem())
561 } else {
562 argv = reflect.New(mtype.ArgType)
563 argIsValue = true
564 }
565
566 if err = codec.ReadRequestBody(argv.Interface()); err != nil {
567 return
568 }
569 if argIsValue {
570 argv = argv.Elem()
571 }
572
573 replyv = reflect.New(mtype.ReplyType.Elem())
574
575 switch mtype.ReplyType.Elem().Kind() {
576 case reflect.Map:
577 replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem()))
578 case reflect.Slice:
579 replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0))
580 }
581 return
582 }
583
584 func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) {
585
586 req = server.getRequest()
587 err = codec.ReadRequestHeader(req)
588 if err != nil {
589 req = nil
590 if err == io.EOF || err == io.ErrUnexpectedEOF {
591 return
592 }
593 err = errors.New("rpc: server cannot decode request: " + err.Error())
594 return
595 }
596
597
598
599 keepReading = true
600
601 dot := strings.LastIndex(req.ServiceMethod, ".")
602 if dot < 0 {
603 err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
604 return
605 }
606 serviceName := req.ServiceMethod[:dot]
607 methodName := req.ServiceMethod[dot+1:]
608
609
610 svci, ok := server.serviceMap.Load(serviceName)
611 if !ok {
612 err = errors.New("rpc: can't find service " + req.ServiceMethod)
613 return
614 }
615 svc = svci.(*service)
616 mtype = svc.method[methodName]
617 if mtype == nil {
618 err = errors.New("rpc: can't find method " + req.ServiceMethod)
619 }
620 return
621 }
622
623
624
625
626
627 func (server *Server) Accept(lis net.Listener) {
628 for {
629 conn, err := lis.Accept()
630 if err != nil {
631 log.Print("rpc.Serve: accept:", err.Error())
632 return
633 }
634 go server.ServeConn(conn)
635 }
636 }
637
638
// Register publishes rcvr's methods on DefaultServer.
func Register(rcvr any) error { return DefaultServer.Register(rcvr) }
640
641
642
// RegisterName is like Register but registers rcvr on DefaultServer under
// the supplied name instead of the name of its concrete type.
func RegisterName(name string, rcvr any) error {
	return DefaultServer.RegisterName(name, rcvr)
}
646
647
648
649
650
651
652
653
654
// A ServerCodec implements reading of RPC requests and writing of RPC
// responses for the server side of a connection. The server reads a request
// by calling ReadRequestHeader and then ReadRequestBody, and replies with
// WriteResponse. ReadRequestBody may be called with a nil argument when the
// server only needs the body consumed. ServeCodec calls Close when it has
// finished with the connection.
type ServerCodec interface {
	ReadRequestHeader(*Request) error
	ReadRequestBody(any) error
	WriteResponse(*Response, any) error

	// Close is called by ServeCodec after all outstanding calls have
	// completed; gobServerCodec's implementation also calls it itself
	// after an encoding failure.
	Close() error
}
663
664
665
666
667
668
669
// ServeConn runs DefaultServer on a single connection. It does not return
// until DefaultServer.ServeConn returns.
func ServeConn(conn io.ReadWriteCloser) {
	DefaultServer.ServeConn(conn)
}
673
674
675
// ServeCodec is like ServeConn but uses the given codec to decode requests
// and encode responses, serving them with DefaultServer.
func ServeCodec(codec ServerCodec) {
	DefaultServer.ServeCodec(codec)
}
679
680
681
// ServeRequest serves a single request from codec synchronously using
// DefaultServer and returns the resulting error, if any.
func ServeRequest(codec ServerCodec) error {
	return DefaultServer.ServeRequest(codec)
}
685
686
687
688
// Accept accepts connections on lis and serves them with DefaultServer. It
// blocks for as long as DefaultServer.Accept does.
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
690
691
// connected is the status text ServeHTTP writes, after "HTTP/1.0 ", on a
// hijacked connection before handing it to ServeConn.
var connected = "200 Connected to Go RPC"
693
694
695 func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
696 if req.Method != "CONNECT" {
697 w.Header().Set("Content-Type", "text/plain; charset=utf-8")
698 w.WriteHeader(http.StatusMethodNotAllowed)
699 io.WriteString(w, "405 must CONNECT\n")
700 return
701 }
702 conn, _, err := w.(http.Hijacker).Hijack()
703 if err != nil {
704 log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
705 return
706 }
707 io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
708 server.ServeConn(conn)
709 }
710
711
712
713
// HandleHTTP registers, with the net/http default handlers, the server
// itself on rpcPath and a debugHTTP handler wrapping the server on
// debugPath.
func (server *Server) HandleHTTP(rpcPath, debugPath string) {
	http.Handle(rpcPath, server)
	http.Handle(debugPath, debugHTTP{server})
}
718
719
720
721
// HandleHTTP registers DefaultServer's HTTP handlers on DefaultRPCPath and
// DefaultDebugPath.
func HandleHTTP() {
	DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
}
725
View as plain text