summaryrefslogtreecommitdiff
path: root/buildscripts/util/oauth.py
blob: 7a6fabe9fdec3d625a398022ca9447d3c1894c0e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
"""Helper tools to get OAuth credentials using the PKCE flow."""
from __future__ import annotations

from datetime import datetime, timedelta
from http.server import BaseHTTPRequestHandler, HTTPServer
from random import choice
from string import ascii_lowercase
from typing import Any, Callable, Optional, Tuple
from urllib.parse import parse_qs, urlsplit
from webbrowser import open as web_open

import requests
from oauthlib.oauth2 import BackendApplicationClient
from pkce import generate_pkce_pair
from pydantic import ValidationError
from pydantic.main import BaseModel
from requests_oauthlib import OAuth2Session
from buildscripts.util.fileops import read_yaml_file

# HTML page that the redirect handler writes back to the browser after the credentials have
# been stored. Its onload script asks the browser to close the window. The page is encoded to
# UTF-8 bytes here so it can be handed directly to the handler's ``wfile.write``.
AUTH_HANDLER_RESPONSE = """\
<html>
  <head>
    <title>Authentication Status</title>
    <script>
    window.onload = function() {
      window.close();
    }
    </script>
  </head>
  <body>
    <p>The authentication flow has completed.</p>
  </body>
</html>
""".encode("utf-8")


class Configs:
    """
    Collect configurations necessary for authentication process.

    The class-level attributes are the defaults. Any constructor argument that is falsy
    (``None``, an empty string, ``0``) leaves the corresponding default in place.
    """

    # pylint: disable=invalid-name

    AUTH_DOMAIN = "corp.mongodb.com/oauth2/aus4k4jv00hWjNnps297"
    CLIENT_ID = "0oa5zf9ps4N3JKWIJ297"
    REDIRECT_PORT = 8989
    SCOPE = "kanopy+openid+profile"

    def __init__(self, client_credentials_scope: Optional[str] = None,
                 client_credentials_user_name: Optional[str] = None,
                 auth_domain: Optional[str] = None, client_id: Optional[str] = None,
                 redirect_port: Optional[int] = None, scope: Optional[str] = None):
        """
        Initialize configs instance.

        :param client_credentials_scope: Scope requested by the client-credentials flow.
        :param client_credentials_user_name: User name recorded on credentials obtained
            through the client-credentials flow.
        :param auth_domain: Host and path prefix of the auth server, without the scheme.
        :param client_id: Client id used for the PKCE flow.
        :param redirect_port: Local port the PKCE redirect server listens on.
        :param scope: Scope string placed verbatim in the PKCE authorization URL.
        """
        # `or` is deliberate: falsy overrides fall back to the class-level defaults.
        self.AUTH_DOMAIN = auth_domain or self.AUTH_DOMAIN
        self.CLIENT_ID = client_id or self.CLIENT_ID
        self.REDIRECT_PORT = redirect_port or self.REDIRECT_PORT
        self.SCOPE = scope or self.SCOPE
        # These two have no class-level default and stay None unless supplied.
        self.CLIENT_CREDENTIALS_SCOPE = client_credentials_scope
        self.CLIENT_CREDENTIALS_USER_NAME = client_credentials_user_name


class OAuthCredentials(BaseModel):
    """OAuth access token and its associated metadata."""

    # Token lifetime in seconds, counted from ``created_time``.
    expires_in: int
    access_token: str
    # Moment the token was obtained; compared against a naive ``datetime.now()``.
    created_time: datetime
    user_name: str

    def are_expired(self) -> bool:
        """
        Check whether the current OAuth credentials are expired or not.

        :return: True once ``created_time + expires_in`` seconds lies in the past.
        """
        return self.created_time + timedelta(seconds=self.expires_in) < datetime.now()

    @classmethod
    def get_existing_credentials_from_file(cls, file_path: str) -> Optional[OAuthCredentials]:
        """
        Try to get OAuth credentials from a file location.

        Will return None if credentials either don't exist, are malformed or are expired.

        :param file_path: Location to check for OAuth credentials.
        :return: Valid OAuth credentials or None if valid credentials don't exist.
        """
        try:
            raw = read_yaml_file(file_path)
        except OSError:
            # Missing or unreadable file: treat as "no stored credentials".
            return None

        # A file that does not parse to a mapping (for example an empty file, which
        # presumably loads as None -- confirm against read_yaml_file) cannot be unpacked
        # with ** and holds no credentials.
        if not isinstance(raw, dict):
            return None

        try:
            creds = cls(**raw)
        except (ValidationError, TypeError):
            # ValidationError: missing or ill-typed fields.
            # TypeError: keys that are not valid keyword names.
            return None

        if (creds.access_token and creds.created_time and creds.expires_in and creds.user_name
                and not creds.are_expired()):
            return creds
        return None


class _RedirectServer(HTTPServer):
    """
    HTTP server to use when fetching OAuth credentials using the PKCE flow.

    Besides serving requests, the instance carries the values a request handler needs
    for the token exchange and offers a slot where the handler stores the result.
    """

    code_verifier: str
    redirect_uri: str
    client_id: str
    auth_domain: str
    # Stays None until a handler assigns the credentials it obtained.
    pkce_credentials: Optional[OAuthCredentials] = None

    def __init__(
            self,
            server_address: Tuple[str, int],
            handler: Callable[..., BaseHTTPRequestHandler],
            redirect_uri: str,
            auth_domain: str,
            client_id: str,
            code_verifier: str,
    ):
        """
        Store the PKCE parameters, then start listening on the given address.

        :param server_address: ``(host, port)`` pair to bind to.
        :param handler: Request handler class used for incoming requests.
        :param redirect_uri: Redirect URI that was sent in the authorization request.
        :param auth_domain: Host and path prefix of the auth server, without the scheme.
        :param client_id: Id of the OAuth client being authenticated.
        :param code_verifier: PKCE code verifier matching the challenge that was sent.
        """
        # Set the attributes before binding, mirroring the original ordering.
        self.code_verifier, self.client_id = code_verifier, client_id
        self.auth_domain, self.redirect_uri = auth_domain, redirect_uri
        super().__init__(server_address, handler)


class _Handler(BaseHTTPRequestHandler):
    """Request handler class to use when trying to get OAuth credentials."""

    # pylint: disable=invalid-name

    server: _RedirectServer

    # Upper bound, in seconds, on each outbound call to the auth server. Without it a
    # stalled connection would block the login flow indefinitely.
    _REQUEST_TIMEOUT_SECS = 30

    def _set_response(self) -> None:
        """Send a 200 status line and an HTML content-type header to the client."""
        self.send_response(200)
        self.send_header("Content-type", "text/html")
        self.end_headers()

    def log_message(self, log_format: Any, *args: Any) -> None:  # pylint: disable=unused-argument,arguments-differ
        """
        Discard HTTP server internal log messages instead of emitting them.

        :param log_format: The format string of the message (ignored).
        :param args: Positional values for the format string (ignored).
        """
        return None

    def do_GET(self) -> None:
        """
        Handle the callback response from the auth server.

        Exchanges the authorization code in the query string for an access token, looks up
        the signed-in user and stores the resulting credentials on the server object.

        :raises ValueError: If the authorization code, the token data or a usable user
            name could not be obtained.
        """
        params = parse_qs(urlsplit(self.path).query)
        code_values = params.get("code")

        if not code_values:
            raise ValueError("Could not get authorization code when signing in to Okta")

        # parse_qs maps every key to a list of values; the token request takes a single code.
        code = code_values[0]

        url = f"https://{self.server.auth_domain}/v1/token"
        body = {
            "grant_type": "authorization_code",
            "client_id": self.server.client_id,
            "code_verifier": self.server.code_verifier,
            "code": code,
            "redirect_uri": self.server.redirect_uri,
        }

        token_data = requests.post(url, data=body, timeout=self._REQUEST_TIMEOUT_SECS).json()

        access_token = token_data.get("access_token")
        expires_in = token_data.get("expires_in")

        if not access_token or not expires_in:
            raise ValueError("Could not get access token or expires_in data about access token")

        headers = {"Authorization": f"Bearer {access_token}"}
        user_info = requests.get(f"https://{self.server.auth_domain}/v1/userinfo",
                                 headers=headers, timeout=self._REQUEST_TIMEOUT_SECS).json()

        # A missing claim takes the same ValueError path as a malformed one, not a KeyError.
        preferred_username = user_info.get("preferred_username") or ""
        split_username = preferred_username.split("@")

        # Only a "<name>@<domain>" shaped value is accepted.
        if len(split_username) != 2:
            raise ValueError("Could not get user_name of current user")

        self.server.pkce_credentials = OAuthCredentials(
            access_token=access_token,
            expires_in=expires_in,
            created_time=datetime.now(),
            user_name=split_username[0],
        )
        self._set_response()
        self.wfile.write(AUTH_HANDLER_RESPONSE)


class PKCEOauthTools:
    """Basic toolset to get OAuth credentials using the PKCE flow."""

    auth_domain: str
    client_id: str
    redirect_port: int
    redirect_uri: str
    scope: str

    def __init__(self, auth_domain: str, client_id: str, redirect_port: int, scope: str):
        """
        Create a new PKCEOauth tools instance.

        :param auth_domain: The uri of the auth server to get the credentials from.
        :param client_id: The id of the client that you are using to authenticate.
        :param redirect_port: Port to use when setting up the local server for the auth redirect.
        :param scope: The OAuth scopes to request access for.
        """
        self.auth_domain = auth_domain
        self.client_id = client_id
        self.redirect_port = redirect_port
        self.redirect_uri = f"http://localhost:{redirect_port}/"
        self.scope = scope

    def get_pkce_credentials(self, print_auth_url: bool = False) -> OAuthCredentials:
        """
        Try to get an OAuth access token and its associated metadata.

        Starts a local redirect server, sends the user to the authorization URL and waits
        for a single callback request.

        :param print_auth_url: Whether to print the auth url to the console instead of opening it.
        :return: OAuth credentials and some associated metadata to check if they have expired.
        :raises ValueError: If the callback did not leave credentials on the server.
        """
        # Imported here so this method does not depend on a module-level import being added.
        import secrets  # pylint: disable=import-outside-toplevel

        code_verifier, code_challenge = generate_pkce_pair()

        # `state` is an anti-forgery value, so draw it from the OS CSPRNG rather than `random`.
        # NOTE(review): the callback handler in this file only reads `code` and never compares
        # the returned `state`, so it currently offers no protection -- worth confirming/fixing.
        state = "".join(secrets.choice(ascii_lowercase) for _ in range(10))

        # The values are interpolated verbatim (no percent-encoding); `scope` is expected to
        # already be in its URL form.
        authorization_url = (f"https://{self.auth_domain}/v1/authorize?"
                             f"scope={self.scope}&"
                             f"response_type=code&"
                             f"response_mode=query&"
                             f"client_id={self.client_id}&"
                             f"code_challenge={code_challenge}&"
                             f"state={state}&"
                             f"code_challenge_method=S256&"
                             f"redirect_uri={self.redirect_uri}")

        httpd = _RedirectServer(
            ("", self.redirect_port),
            _Handler,
            self.redirect_uri,
            self.auth_domain,
            self.client_id,
            code_verifier,
        )
        # The context manager closes the listening socket even when the flow fails, so the
        # port is free again for a later attempt.
        with httpd:
            if print_auth_url:
                print("Please open the below url in a browser and sign in if necessary")
                print(authorization_url)
            else:
                web_open(authorization_url)
            # Serve exactly one request: the redirect coming back from the auth server.
            httpd.handle_request()

        if not httpd.pkce_credentials:
            raise ValueError(
                "Could not retrieve Okta credentials to talk to Kanopy with. "
                "Please sign out of Okta in your browser and try runnning this script again")

        return httpd.pkce_credentials


def get_oauth_credentials(configs: Configs, print_auth_url: bool = False) -> OAuthCredentials:
    """
    Obtain credentials for a human user through the PKCE flow.

    :param configs: Configs instance supplying auth domain, client id, redirect port and scope.
    :param print_auth_url: Whether to print the auth url to the console instead of opening it.
    :return: OAuth credentials for the given user.
    """
    pkce_tools = PKCEOauthTools(
        auth_domain=configs.AUTH_DOMAIN,
        client_id=configs.CLIENT_ID,
        redirect_port=configs.REDIRECT_PORT,
        scope=configs.SCOPE,
    )
    return pkce_tools.get_pkce_credentials(print_auth_url)


def get_client_cred_oauth_credentials(client_id: str, client_secret: str,
                                      configs: Configs) -> OAuthCredentials:
    """
    Obtain credentials for a machine user through the client-credentials flow.

    :param client_id: The client_id of the machine user to authenticate as.
    :param client_secret: The client_secret of the machine user to authenticate as.
    :param configs: Configs instance supplying the auth domain, scope and user name.
    :return: OAuth credentials for the given machine user.
    :raises ValueError: If the token response lacks an access token or its lifetime.
    """
    session = OAuth2Session(client=BackendApplicationClient(client_id=client_id))
    token_endpoint = f"https://{configs.AUTH_DOMAIN}/v1/token"
    token_response = session.fetch_token(
        token_url=token_endpoint,
        client_id=client_id,
        client_secret=client_secret,
        scope=configs.CLIENT_CREDENTIALS_SCOPE,
    )

    access_token = token_response.get("access_token")
    expires_in = token_response.get("expires_in")

    # Both values are required to build usable credentials.
    if not (access_token and expires_in):
        raise ValueError("Could not get access token or expires_in data about access token")

    return OAuthCredentials(
        access_token=access_token,
        expires_in=expires_in,
        created_time=datetime.now(),
        user_name=configs.CLIENT_CREDENTIALS_USER_NAME,
    )