aboutsummaryrefslogtreecommitdiff
path: root/tun2/connection.go
blob: ff03e9695c5764fa06b73850b89a30abb3a6f60f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
package tun2

import (
	"bufio"
	"context"
	"expvar"
	"net"
	"net/http"
	"sync"
	"time"

	failure "github.com/dgryski/go-failure"
	"github.com/pkg/errors"
	"github.com/xtaci/smux"
	"within.website/ln"
	"within.website/ln/opname"
)

// Connection is a single active client -> server connection and session
// containing many streams over TCP+TLS or KCP+TLS. Every stream beyond the
// control stream is assumed to be passed to the underlying backend server.
//
// All Connection methods assume this is locked externally.
type Connection struct {
	// id identifies this connection; it is reported under the "id" key by F.
	id            string
	// conn is the underlying network connection. OpenStream sets and clears
	// deadlines on it, F reports its addresses, and Close closes it.
	conn          net.Conn
	// session is the smux session used to open new streams to the backend.
	session       *smux.Session
	// controlStream is the first thing Close shuts down.
	// NOTE(review): how the control stream is established and used is not
	// visible in this file.
	controlStream *smux.Stream
	// user and domain are reported by F under the "user" and "domain" keys.
	user          string
	domain        string
	// cf is invoked by cancel.
	cf            context.CancelFunc
	// detector is notified with the current time after every successful Ping.
	detector      *failure.Detector
	// NOTE(review): Auth's type is declared elsewhere; presumably the
	// authentication record for this connection -- confirm.
	Auth          *Auth
	// usable gates OpenStream: when false, OpenStream returns
	// ErrNoSuchBackend. cancel sets it to false.
	usable        bool

	// The embedded mutex is not taken by any method in this file; per the type
	// comment above, callers are expected to hold it.
	sync.Mutex
	// counter is incremented once per successful RoundTrip.
	counter *expvar.Int
}

// cancel invokes the connection's stored context.CancelFunc and then flags the
// connection as unusable, after which OpenStream refuses to hand out streams.
func (c *Connection) cancel() {
	stop := c.cf
	stop()

	c.usable = false
}

// F implements ln.Fer: it returns the identifying fields of this connection
// (id, remote and local addresses, network kind, user and domain) so that the
// connection can be passed directly to ln logging calls.
func (c *Connection) F() ln.F {
	fields := make(map[string]interface{}, 6)

	fields["id"] = c.id
	fields["remote"] = c.conn.RemoteAddr()
	fields["local"] = c.conn.LocalAddr()
	// The network name of the local address doubles as the connection "kind".
	fields["kind"] = c.conn.LocalAddr().Network()
	fields["user"] = c.user
	fields["domain"] = c.domain

	return fields
}

// Ping sends a "ping" to the client by issuing a GET for
// http://backend/health through RoundTrip, bounded by a one second timeout.
// If the client doesn't respond or the connection dies, then the connection
// needs to be cleaned up.
//
// On success the failure detector is notified with the current time and nil
// is returned. On failure the error is logged and returned; the detector is
// not notified. The HTTP status of the health response is not inspected.
func (c *Connection) Ping() error {
	const timeout = time.Second

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	// Cancelling the context also lets RoundTrip's watcher goroutine close the
	// stream that was opened for this ping.
	defer cancel()
	ctx = opname.With(ctx, "tun2.Connection.Ping")
	ctx = ln.WithF(ctx, ln.F{"timeout": timeout})

	req, err := http.NewRequest("GET", "http://backend/health", nil)
	if err != nil {
		// Not expected for a constant, well-formed URL; report it to the caller
		// instead of panicking inside library code.
		return err
	}
	req = req.WithContext(ctx)

	resp, err := c.RoundTrip(req)
	if err != nil {
		ln.Error(ctx, err, c, ln.Action("pinging the backend"))
		return err
	}
	// The body is not needed, but net/http requires callers to close it. This
	// deferred Close runs before cancel, while the stream is still open.
	defer resp.Body.Close()

	c.detector.Ping(time.Now())

	return nil
}

// OpenStream creates a new stream (connection) to the backend server.
//
// It returns ErrNoSuchBackend when the connection has been marked unusable.
// Otherwise it puts a one second deadline on the underlying connection while
// the smux stream is opened and clears that deadline again afterwards,
// regardless of whether opening the stream succeeded.
//
// ctx is only used to carry logging fields; it does not cancel the open.
//
// On any error the returned net.Conn is nil. If the stream was opened but the
// deadline could not be cleared, the stream is closed before returning so it
// is not leaked.
func (c *Connection) OpenStream(ctx context.Context) (net.Conn, error) {
	const timeout = time.Second

	ctx = opname.With(ctx, "OpenStream")
	if !c.usable {
		return nil, ErrNoSuchBackend
	}
	ctx = ln.WithF(ctx, ln.F{"timeout": timeout})

	if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil {
		ln.Error(ctx, err, c)
		return nil, err
	}

	stream, err := c.session.OpenStream()

	// Always clear the temporary deadline, even when opening the stream
	// failed; otherwise it would stay armed on the shared connection.
	resetErr := c.conn.SetDeadline(time.Time{})

	if err != nil {
		ln.Error(ctx, err, c)
		return nil, err
	}

	if resetErr != nil {
		ln.Error(ctx, resetErr, c)
		// Callers drop the stream when an error is returned, so release it here.
		stream.Close()
		return nil, resetErr
	}

	return stream, nil
}

// Close destroys resources specific to the connection: the control stream,
// the smux session and the underlying network connection, in that order.
//
// All three are always closed, even if an earlier close fails, so that a
// failure to close the control stream cannot leak the session or the socket.
// The first error encountered (if any) is returned.
func (c *Connection) Close() error {
	err := c.controlStream.Close()

	if sessErr := c.session.Close(); err == nil {
		err = sessErr
	}

	if connErr := c.conn.Close(); err == nil {
		err = connErr
	}

	return err
}

// Connection-specific errors.
//
// RoundTrip passes the message text of these values to errors.Wrap, so the
// error it returns wraps the underlying cause and carries this text as a
// prefix; it is not one of these values itself.
var (
	// ErrCantOpenSessionStream labels a failure to open a stream for a request.
	ErrCantOpenSessionStream = errors.New("tun2: connection can't open session stream")
	// ErrCantWriteRequest labels a failure to write the HTTP request to the stream.
	ErrCantWriteRequest      = errors.New("tun2: connection stream can't write request")
	// ErrCantReadResponse labels a failure to read the HTTP response from the stream.
	ErrCantReadResponse      = errors.New("tun2: connection stream can't read response")
)

// RoundTrip forwards a HTTP request to the remote backend and then returns the
// response, if any.
//
// A new stream is opened for every request. The request is written to it and
// the response is parsed from it. On success the per-connection counter is
// incremented and the response is returned with its body still backed by the
// stream.
//
// A goroutine closes the stream when the request's context is done, so
// callers should supply a context that is eventually cancelled; that is what
// releases the stream (and the goroutine) after a successful round trip.
// When writing the request or reading the response fails, the stream is
// closed immediately instead of waiting for the context.
//
// Errors wrap the underlying cause with the message text of
// ErrCantOpenSessionStream, ErrCantWriteRequest or ErrCantReadResponse.
func (c *Connection) RoundTrip(req *http.Request) (*http.Response, error) {
	ctx := req.Context()
	ctx = opname.With(ctx, "tun2.Connection.RoundTrip")
	stream, err := c.OpenStream(ctx)
	if err != nil {
		return nil, errors.Wrap(err, ErrCantOpenSessionStream.Error())
	}

	// Tie the stream's lifetime to the request context. The stream may already
	// have been closed on an error path below; the result of this Close is
	// ignored.
	go func() {
		<-req.Context().Done()
		stream.Close()
	}()

	err = req.Write(stream)
	if err != nil {
		// Nobody will read from this stream; release it now.
		stream.Close()
		return nil, errors.Wrap(err, ErrCantWriteRequest.Error())
	}

	buf := bufio.NewReader(stream)

	resp, err := http.ReadResponse(buf, req)
	if err != nil {
		// No response body will be handed out, so the stream is of no further use.
		stream.Close()
		return nil, errors.Wrap(err, ErrCantReadResponse.Error())
	}

	c.counter.Add(1)

	return resp, nil
}