diff options
author | Jasur Nurboyev <bluestacks6523@gmail.com> | 2021-11-24 13:08:31 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2021-11-24 13:31:54 +0000 |
commit | d4847ee07709d13e45d219d3c7834e2e99ca3c44 (patch) | |
tree | 67f37f15185bbbc5fa7a5d8837551d7a85ae929f /buildscripts | |
parent | e66c438f1823e75a4be3eafa1730e76c171531cc (diff) | |
download | mongo-d4847ee07709d13e45d219d3c7834e2e99ca3c44.tar.gz |
SERVER-61710 Added client-side authentication for mongosymb.py
Diffstat (limited to 'buildscripts')
-rwxr-xr-x | buildscripts/mongosymb.py | 66 | ||||
-rw-r--r-- | buildscripts/util/oauth.py | 300 |
2 files changed, 360 insertions, 6 deletions
diff --git a/buildscripts/mongosymb.py b/buildscripts/mongosymb.py index 1ea6ef5f6af..1f17c7b7485 100755 --- a/buildscripts/mongosymb.py +++ b/buildscripts/mongosymb.py @@ -17,16 +17,22 @@ You can also pass --output-format=json, to get rich json output. It shows some e but emits json instead of plain text. """ -import json import argparse +import json import os import subprocess import sys +import time from collections import OrderedDict +from pathlib import Path from typing import Dict import requests +# pylint: disable=wrong-import-position +sys.path.append(str(Path(os.getcwd(), __file__).parent.parent)) +from buildscripts.util.oauth import Configs, get_oauth_credentials + class PathDbgFileResolver(object): """PathDbgFileResolver class.""" @@ -150,10 +156,20 @@ class PathResolver(object): Cache size differs according to the situation, system resources and overall decision of development team. """ - default_host = 'http://127.0.0.1:8000' # the main (API) sever that we'll be sending requests to - default_cache_dir = os.path.join(os.getcwd(), 'dl_cache') + # pylint: disable=too-many-instance-attributes + # This amount of attributes are necessary. + + # the main (API) sever that we'll be sending requests to + default_host = 'https://symbolizer-service.server-tig.staging.corp.mongodb.com' + default_cache_dir = os.path.join(os.getcwd(), 'build', 'symbolizer_downloads_cache') + default_creds_file_path = os.path.join(os.getcwd(), '.symbolizer_credentials.json') + default_client_credentials_scope = "servertig-symbolizer-fullaccess" + default_client_credentials_user_name = "client-user" - def __init__(self, host: str = None, cache_size: int = 0, cache_dir: str = None): + def __init__(self, host: str = None, cache_size: int = 0, cache_dir: str = None, + client_credentials_scope: str = None, client_credentials_user_name: str = None, + client_id: str = None, redirect_port: int = None, scope: str = None, + auth_domain: str = None): """ Initialize instance. @@ -165,10 +181,47 @@ class PathResolver(object): self._cached_results = CachedResults(max_cache_size=cache_size) self.cache_dir = cache_dir or self.default_cache_dir self.mci_build_dir = None + self.client_credentials_scope = client_credentials_scope or self.default_client_credentials_scope + self.client_credentials_user_name = client_credentials_user_name or self.default_client_credentials_user_name + self.client_id = client_id + self.redirect_port = redirect_port + self.scope = scope + self.auth_domain = auth_domain + self.configs = Configs(client_credentials_scope=self.client_credentials_scope, + client_credentials_user_name=self.client_credentials_user_name, + client_id=self.client_id, auth_domain=self.auth_domain, + redirect_port=self.redirect_port, scope=self.scope) + self.http_client = requests.Session() # create cache dir if it doesn't exist if not os.path.exists(self.cache_dir): - os.mkdir(self.cache_dir) + os.makedirs(self.cache_dir) + + self.authenticate() + + def authenticate(self): + """Login & get credentials for further requests to web service.""" + + # try to read from file + if os.path.exists(self.default_creds_file_path): + with open(self.default_creds_file_path) as cfile: + data = json.loads(cfile.read()) + access_token, expire_time = data.get("access_token"), data.get("expire_time") + if time.time() < expire_time: + # credentials hasn't expired yet + self.http_client.headers.update({"Authorization": f"Bearer {access_token}"}) + return + + credentials = get_oauth_credentials(configs=self.configs, print_auth_url=True) + self.http_client.headers.update({"Authorization": f"Bearer {credentials.access_token}"}) + + # write credentials to local file for further useage + with open(self.default_creds_file_path, "w") as cfile: + cfile.write( + json.dumps({ + "access_token": credentials.access_token, + "expire_time": time.time() + credentials.expires_in + })) @staticmethod def is_valid_path(path: str) -> bool: @@ -259,7 +312,8 @@ class PathResolver(object): if not path: # path does not exist in cache, so we send request to server try: - response = requests.get(f'{self.host}/find_by_id', params={'build_id': build_id}) + response = self.http_client.get(f'{self.host}/find_by_id', + params={'build_id': build_id}) if response.status_code != 200: sys.stderr.write( f"Server returned unsuccessful status: {response.status_code}, " diff --git a/buildscripts/util/oauth.py b/buildscripts/util/oauth.py new file mode 100644 index 00000000000..7a6fabe9fde --- /dev/null +++ b/buildscripts/util/oauth.py @@ -0,0 +1,300 @@ +"""Helper tools to get OAuth credentials using the PKCE flow.""" +from __future__ import annotations + +from datetime import datetime, timedelta +from http.server import BaseHTTPRequestHandler, HTTPServer +from random import choice +from string import ascii_lowercase +from typing import Any, Callable, Optional, Tuple +from urllib.parse import parse_qs, urlsplit +from webbrowser import open as web_open + +import requests +from oauthlib.oauth2 import BackendApplicationClient +from pkce import generate_pkce_pair +from pydantic import ValidationError +from pydantic.main import BaseModel +from requests_oauthlib import OAuth2Session +from buildscripts.util.fileops import read_yaml_file + +AUTH_HANDLER_RESPONSE = """\ +<html> + <head> + <title>Authentication Status</title> + <script> + window.onload = function() { + window.close(); + } + </script> + </head> + <body> + <p>The authentication flow has completed.</p> + </body> +</html> +""".encode("utf-8") + + +class Configs: + """Collect configurations necessary for authentication process.""" + + # pylint: disable=invalid-name + + AUTH_DOMAIN = "corp.mongodb.com/oauth2/aus4k4jv00hWjNnps297" + CLIENT_ID = "0oa5zf9ps4N3JKWIJ297" + REDIRECT_PORT = 8989 + SCOPE = "kanopy+openid+profile" + + def __init__(self, client_credentials_scope: str = None, + client_credentials_user_name: str = None, auth_domain: str = None, + client_id: str = None, redirect_port: int = None, scope: str = None): + """Initialize configs instance.""" + + self.AUTH_DOMAIN = auth_domain or self.AUTH_DOMAIN + self.CLIENT_ID = client_id or self.CLIENT_ID + self.REDIRECT_PORT = redirect_port or self.REDIRECT_PORT + self.SCOPE = scope or self.SCOPE + self.CLIENT_CREDENTIALS_SCOPE = client_credentials_scope + self.CLIENT_CREDENTIALS_USER_NAME = client_credentials_user_name + + +class OAuthCredentials(BaseModel): + """OAuth access token and its associated metadata.""" + + expires_in: int + access_token: str + created_time: datetime + user_name: str + + def are_expired(self) -> bool: + """ + Check whether the current OAuth credentials are expired or not. + + :return: Whether the credentials are expired or not. + """ + return self.created_time + timedelta(seconds=self.expires_in) < datetime.now() + + @classmethod + def get_existing_credentials_from_file(cls, file_path: str) -> Optional[OAuthCredentials]: + """ + Try to get OAuth credentials from a file location. + + Will return None if credentials either don't exist or are expired. + :param file_path: Location to check for OAuth credentials. + :return: Valid OAuth credentials or None if valid credentials don't exist + """ + try: + creds = OAuthCredentials(**read_yaml_file(file_path)) + if (creds.access_token and creds.created_time and creds.expires_in and creds.user_name + and not creds.are_expired()): + return creds + else: + return None + except ValidationError: + return None + except OSError: + return None + + +class _RedirectServer(HTTPServer): + """HTTP server to use when fetching OAuth credentials using the PKCE flow.""" + + pkce_credentials: Optional[OAuthCredentials] = None + auth_domain: str + client_id: str + redirect_uri: str + code_verifier: str + + def __init__( + self, + server_address: Tuple[str, int], + handler: Callable[..., BaseHTTPRequestHandler], + redirect_uri: str, + auth_domain: str, + client_id: str, + code_verifier: str, + ): + self.redirect_uri = redirect_uri + self.auth_domain = auth_domain + self.client_id = client_id + self.code_verifier = code_verifier + super().__init__(server_address, handler) + + +class _Handler(BaseHTTPRequestHandler): + """Request handler class to use when trying to get OAuth credentials.""" + + # pylint: disable=invalid-name + + server: _RedirectServer + + def _set_response(self) -> None: + """Set the response to the server making a request.""" + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + + def log_message(self, log_format: Any, *args: Any) -> None: # pylint: disable=unused-argument,arguments-differ + """ + Log HTTP Server internal messages. + + :param log_format: The format to use when logging messages. + :param args: Key word args. + """ + return None + + def do_GET(self) -> None: + """Handle the callback response from the auth server.""" + params = parse_qs(urlsplit(self.path).query) + code = params.get("code") + + if not code: + raise ValueError("Could not get authorization code when signing in to Okta") + + url = f"https://{self.server.auth_domain}/v1/token" + body = { + "grant_type": "authorization_code", + "client_id": self.server.client_id, + "code_verifier": self.server.code_verifier, + "code": code, + "redirect_uri": self.server.redirect_uri, + } + + resp = requests.post(url, data=body).json() + + access_token = resp.get("access_token") + expires_in = resp.get("expires_in") + + if not access_token or not expires_in: + raise ValueError("Could not get access token or expires_in data about access token") + + headers = {"Authorization": f"Bearer {access_token}"} + resp = requests.get(f"https://{self.server.auth_domain}/v1/userinfo", + headers=headers).json() + + split_username = resp["preferred_username"].split("@") + + if len(split_username) != 2: + raise ValueError("Could not get user_name of current user") + + self.server.pkce_credentials = OAuthCredentials( + access_token=access_token, + expires_in=expires_in, + created_time=datetime.now(), + user_name=split_username[0], + ) + self._set_response() + self.wfile.write(AUTH_HANDLER_RESPONSE) + + +class PKCEOauthTools: + """Basic toolset to get OAuth credentials using the PKCE flow.""" + + auth_domain: str + client_id: str + redirect_port: int + redirect_uri: str + scope: str + + def __init__(self, auth_domain: str, client_id: str, redirect_port: int, scope: str): + """ + Create a new PKCEOauth tools instance. + + :param auth_domain: The uri of the auth server to get the credentials from. + :param client_id: The id of the client that you are using to authenticate. + :param redirect_port: Port to use when setting up the local server for the auth redirect. + :param scope: The OAuth scopes to request access for. + """ + self.auth_domain = auth_domain + self.client_id = client_id + self.redirect_port = redirect_port + self.redirect_uri = f"http://localhost:{redirect_port}/" + self.scope = scope + + def get_pkce_credentials(self, print_auth_url: bool = False) -> OAuthCredentials: + """ + Try to get an OAuth access token and its associated metadata. + + :param print_auth_url: Whether to print the auth url to the console instead of opening it. + :return: OAuth credentials and some associated metadata to check if they have expired. + """ + code_verifier, code_challenge = generate_pkce_pair() + + state = "".join(choice(ascii_lowercase) for i in range(10)) + + authorization_url = (f"https://{self.auth_domain}/v1/authorize?" + f"scope={self.scope}&" + f"response_type=code&" + f"response_mode=query&" + f"client_id={self.client_id}&" + f"code_challenge={code_challenge}&" + f"state={state}&" + f"code_challenge_method=S256&" + f"redirect_uri={self.redirect_uri}") + + httpd = _RedirectServer( + ("", self.redirect_port), + _Handler, + self.redirect_uri, + self.auth_domain, + self.client_id, + code_verifier, + ) + if print_auth_url: + print("Please open the below url in a browser and sign in if necessary") + print(authorization_url) + else: + web_open(authorization_url) + httpd.handle_request() + + if not httpd.pkce_credentials: + raise ValueError( + "Could not retrieve Okta credentials to talk to Kanopy with. " + "Please sign out of Okta in your browser and try runnning this script again") + + return httpd.pkce_credentials + + +def get_oauth_credentials(configs: Configs, print_auth_url: bool = False) -> OAuthCredentials: + """ + Run the OAuth workflow to get credentials for a human user. + + :param configs: Configs instance. + :param print_auth_url: Whether to print the auth url to the console instead of opening it. + :return: OAuth credentials for the given user. + """ + oauth_tools = PKCEOauthTools(auth_domain=configs.AUTH_DOMAIN, client_id=configs.CLIENT_ID, + redirect_port=configs.REDIRECT_PORT, scope=configs.SCOPE) + credentials = oauth_tools.get_pkce_credentials(print_auth_url) + return credentials + + +def get_client_cred_oauth_credentials(client_id: str, client_secret: str, + configs: Configs) -> OAuthCredentials: + """ + Run the OAuth workflow to get credentials for a machine user. + + :param client_id: The client_id of the machine user to authenticate as. + :param client_secret: The client_secret of the machine user to authenticate as. + :param configs: Configs instance. + :return: OAuth credentials for the given machine user. + """ + client = BackendApplicationClient(client_id=client_id) + oauth = OAuth2Session(client=client) + token = oauth.fetch_token( + token_url=f"https://{configs.AUTH_DOMAIN}/v1/token", + client_id=client_id, + client_secret=client_secret, + scope=configs.CLIENT_CREDENTIALS_SCOPE, + ) + access_token = token.get("access_token") + expires_in = token.get("expires_in") + + if not access_token or not expires_in: + raise ValueError("Could not get access token or expires_in data about access token") + + return OAuthCredentials( + access_token=access_token, + expires_in=expires_in, + created_time=datetime.now(), + user_name=configs.CLIENT_CREDENTIALS_USER_NAME, + ) |