summaryrefslogtreecommitdiff
path: root/daemon/networkdriver/portallocator/portallocator.go
blob: 251ab94473388b85547ffe28f7c85246432ee894 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package portallocator

import (
	"errors"
	"net"
	"sync"
)

// portMap records allocated port numbers; a true value means the port is
// taken.
type portMap map[int]bool

// protocolMap holds one portMap per protocol name.
type protocolMap map[string]portMap

// ipMapping keys a protocolMap by the string form of an IP address.
type ipMapping map[string]protocolMap

// BeginPortRange is the first port (inclusive) handed out when a caller asks
// for any free port rather than a specific one.
const BeginPortRange = 49153

// EndPortRange is the last port (inclusive) of that dynamic range.
const EndPortRange = 65535

// ErrAllPortsAllocated is returned when no port between BeginPortRange and
// EndPortRange is free.
var ErrAllPortsAllocated = errors.New("all ports are allocated")

// ErrPortAlreadyAllocated is returned when a specifically requested port is
// already taken.
var ErrPortAlreadyAllocated = errors.New("port has already been allocated")

// ErrUnknownProtocol is returned for a protocol other than "tcp" or "udp".
var ErrUnknownProtocol = errors.New("unknown protocol")

// mutex guards globalMap; the exported functions in this package lock it for
// their whole body.
var mutex sync.Mutex

// defaultIP is substituted whenever a caller passes a nil IP.
var defaultIP = net.ParseIP("0.0.0.0")

// globalMap is the allocation table: IP string -> protocol -> port -> taken.
var globalMap = ipMapping{}

// RequestPort reserves a port for proto ("tcp" or "udp") on ip. A nil ip is
// replaced by the package default (0.0.0.0).
//
// If port is positive, exactly that port is reserved. ErrPortAlreadyAllocated
// is returned if it is already taken. Otherwise the lowest free port in
// [BeginPortRange, EndPortRange] is reserved, and ErrAllPortsAllocated is
// returned when none is free. Any other protocol yields ErrUnknownProtocol.
// On success the reserved port number is returned.
func RequestPort(ip net.IP, proto string, port int) (int, error) {
	mutex.Lock()
	defer mutex.Unlock()

	if err := validateProto(proto); err != nil {
		return 0, err
	}

	ip = getDefault(ip)

	// A non-positive port means "any free port in the dynamic range".
	// findPort already returns (0, err) on failure, so its result can be
	// passed straight through.
	if port <= 0 {
		return findPort(ip, proto)
	}

	// NOTE(review): explicit ports are not range-checked; values above 65535
	// are accepted and recorded as allocated.
	ports := getOrCreate(ip)[proto]
	if ports[port] {
		return 0, ErrPortAlreadyAllocated
	}
	ports[port] = true

	return port, nil
}

// ReleasePort marks port as free for proto on ip (the package default IP when
// ip is nil), so it can be handed out again. Releasing a port that is not
// allocated is a no-op. An unrecognised protocol is also a no-op: its port map
// is nil, and delete on a nil map does nothing. The returned error is always
// nil.
//
// NOTE(review): getOrCreate adds an entry for ip to the allocation table even
// though nothing is being allocated here.
func ReleasePort(ip net.IP, proto string, port int) error {
	mutex.Lock()
	defer mutex.Unlock()

	ports := getOrCreate(getDefault(ip))[proto]
	delete(ports, port)
	return nil
}

// ReleaseAll drops every allocation for every IP and protocol, making all
// ports available again. The returned error is always nil.
func ReleaseAll() error {
	mutex.Lock()
	defer mutex.Unlock()

	// Replace the whole table with a fresh, empty one rather than clearing it
	// in place.
	globalMap = make(ipMapping)
	return nil
}

// getOrCreate returns the protocolMap stored for ip. If ip has no entry yet,
// it first installs one with empty "tcp" and "udp" port maps.
//
// It does no locking itself; every caller in this file already holds mutex.
func getOrCreate(ip net.IP) protocolMap {
	key := ip.String()
	if existing, found := globalMap[key]; found {
		return existing
	}

	created := protocolMap{
		"tcp": portMap{},
		"udp": portMap{},
	}
	globalMap[key] = created
	return created
}

func findPort(ip net.IP, proto string) (int, error) {
	port := BeginPortRange

	mapping := getOrCreate(ip)

	for mapping[proto][port] {
		port++

		if port > EndPortRange {
			return 0, ErrAllPortsAllocated
		}
	}

	mapping[proto][port] = true

	return port, nil
}

// getDefault returns ip unchanged unless it is nil, in which case the package
// default IP is returned instead.
func getDefault(ip net.IP) net.IP {
	if ip != nil {
		return ip
	}
	return defaultIP
}

// validateProto accepts exactly "tcp" and "udp" (case-sensitive) and returns
// ErrUnknownProtocol for anything else.
func validateProto(proto string) error {
	switch proto {
	case "tcp", "udp":
		return nil
	default:
		return ErrUnknownProtocol
	}
}