summaryrefslogtreecommitdiff
path: root/buildscripts
diff options
context:
space:
mode:
authorJasur Nurboyev <bluestacks6523@gmail.com>2021-11-24 13:08:31 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2021-11-24 13:31:54 +0000
commitd4847ee07709d13e45d219d3c7834e2e99ca3c44 (patch)
tree67f37f15185bbbc5fa7a5d8837551d7a85ae929f /buildscripts
parente66c438f1823e75a4be3eafa1730e76c171531cc (diff)
downloadmongo-d4847ee07709d13e45d219d3c7834e2e99ca3c44.tar.gz
SERVER-61710 Added client-side authentication for mongosymb.py
Diffstat (limited to 'buildscripts')
-rwxr-xr-xbuildscripts/mongosymb.py66
-rw-r--r--buildscripts/util/oauth.py300
2 files changed, 360 insertions, 6 deletions
diff --git a/buildscripts/mongosymb.py b/buildscripts/mongosymb.py
index 1ea6ef5f6af..1f17c7b7485 100755
--- a/buildscripts/mongosymb.py
+++ b/buildscripts/mongosymb.py
@@ -17,16 +17,22 @@ You can also pass --output-format=json, to get rich json output. It shows some e
but emits json instead of plain text.
"""
-import json
import argparse
+import json
import os
import subprocess
import sys
+import time
from collections import OrderedDict
+from pathlib import Path
from typing import Dict
import requests
+# pylint: disable=wrong-import-position
+sys.path.append(str(Path(os.getcwd(), __file__).parent.parent))
+from buildscripts.util.oauth import Configs, get_oauth_credentials
+
class PathDbgFileResolver(object):
"""PathDbgFileResolver class."""
@@ -150,10 +156,20 @@ class PathResolver(object):
Cache size differs according to the situation, system resources and overall decision of development team.
"""
- default_host = 'http://127.0.0.1:8000' # the main (API) sever that we'll be sending requests to
- default_cache_dir = os.path.join(os.getcwd(), 'dl_cache')
+ # pylint: disable=too-many-instance-attributes
+ # This amount of attributes are necessary.
+
+ # the main (API) sever that we'll be sending requests to
+ default_host = 'https://symbolizer-service.server-tig.staging.corp.mongodb.com'
+ default_cache_dir = os.path.join(os.getcwd(), 'build', 'symbolizer_downloads_cache')
+ default_creds_file_path = os.path.join(os.getcwd(), '.symbolizer_credentials.json')
+ default_client_credentials_scope = "servertig-symbolizer-fullaccess"
+ default_client_credentials_user_name = "client-user"
- def __init__(self, host: str = None, cache_size: int = 0, cache_dir: str = None):
+ def __init__(self, host: str = None, cache_size: int = 0, cache_dir: str = None,
+ client_credentials_scope: str = None, client_credentials_user_name: str = None,
+ client_id: str = None, redirect_port: int = None, scope: str = None,
+ auth_domain: str = None):
"""
Initialize instance.
@@ -165,10 +181,47 @@ class PathResolver(object):
self._cached_results = CachedResults(max_cache_size=cache_size)
self.cache_dir = cache_dir or self.default_cache_dir
self.mci_build_dir = None
+ self.client_credentials_scope = client_credentials_scope or self.default_client_credentials_scope
+ self.client_credentials_user_name = client_credentials_user_name or self.default_client_credentials_user_name
+ self.client_id = client_id
+ self.redirect_port = redirect_port
+ self.scope = scope
+ self.auth_domain = auth_domain
+ self.configs = Configs(client_credentials_scope=self.client_credentials_scope,
+ client_credentials_user_name=self.client_credentials_user_name,
+ client_id=self.client_id, auth_domain=self.auth_domain,
+ redirect_port=self.redirect_port, scope=self.scope)
+ self.http_client = requests.Session()
# create cache dir if it doesn't exist
if not os.path.exists(self.cache_dir):
- os.mkdir(self.cache_dir)
+ os.makedirs(self.cache_dir)
+
+ self.authenticate()
+
+ def authenticate(self):
+ """Login & get credentials for further requests to web service."""
+
+ # try to read from file
+ if os.path.exists(self.default_creds_file_path):
+ with open(self.default_creds_file_path) as cfile:
+ data = json.loads(cfile.read())
+ access_token, expire_time = data.get("access_token"), data.get("expire_time")
+ if time.time() < expire_time:
+ # credentials hasn't expired yet
+ self.http_client.headers.update({"Authorization": f"Bearer {access_token}"})
+ return
+
+ credentials = get_oauth_credentials(configs=self.configs, print_auth_url=True)
+ self.http_client.headers.update({"Authorization": f"Bearer {credentials.access_token}"})
+
+ # write credentials to local file for further useage
+ with open(self.default_creds_file_path, "w") as cfile:
+ cfile.write(
+ json.dumps({
+ "access_token": credentials.access_token,
+ "expire_time": time.time() + credentials.expires_in
+ }))
@staticmethod
def is_valid_path(path: str) -> bool:
@@ -259,7 +312,8 @@ class PathResolver(object):
if not path:
# path does not exist in cache, so we send request to server
try:
- response = requests.get(f'{self.host}/find_by_id', params={'build_id': build_id})
+ response = self.http_client.get(f'{self.host}/find_by_id',
+ params={'build_id': build_id})
if response.status_code != 200:
sys.stderr.write(
f"Server returned unsuccessful status: {response.status_code}, "
diff --git a/buildscripts/util/oauth.py b/buildscripts/util/oauth.py
new file mode 100644
index 00000000000..7a6fabe9fde
--- /dev/null
+++ b/buildscripts/util/oauth.py
@@ -0,0 +1,300 @@
+"""Helper tools to get OAuth credentials using the PKCE flow."""
+from __future__ import annotations
+
+from datetime import datetime, timedelta
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from random import choice
+from string import ascii_lowercase
+from typing import Any, Callable, Optional, Tuple
+from urllib.parse import parse_qs, urlsplit
+from webbrowser import open as web_open
+
+import requests
+from oauthlib.oauth2 import BackendApplicationClient
+from pkce import generate_pkce_pair
+from pydantic import ValidationError
+from pydantic.main import BaseModel
+from requests_oauthlib import OAuth2Session
+from buildscripts.util.fileops import read_yaml_file
+
+AUTH_HANDLER_RESPONSE = """\
+<html>
+ <head>
+ <title>Authentication Status</title>
+ <script>
+ window.onload = function() {
+ window.close();
+ }
+ </script>
+ </head>
+ <body>
+ <p>The authentication flow has completed.</p>
+ </body>
+</html>
+""".encode("utf-8")
+
+
+class Configs:
+ """Collect configurations necessary for authentication process."""
+
+ # pylint: disable=invalid-name
+
+ AUTH_DOMAIN = "corp.mongodb.com/oauth2/aus4k4jv00hWjNnps297"
+ CLIENT_ID = "0oa5zf9ps4N3JKWIJ297"
+ REDIRECT_PORT = 8989
+ SCOPE = "kanopy+openid+profile"
+
+ def __init__(self, client_credentials_scope: str = None,
+ client_credentials_user_name: str = None, auth_domain: str = None,
+ client_id: str = None, redirect_port: int = None, scope: str = None):
+ """Initialize configs instance."""
+
+ self.AUTH_DOMAIN = auth_domain or self.AUTH_DOMAIN
+ self.CLIENT_ID = client_id or self.CLIENT_ID
+ self.REDIRECT_PORT = redirect_port or self.REDIRECT_PORT
+ self.SCOPE = scope or self.SCOPE
+ self.CLIENT_CREDENTIALS_SCOPE = client_credentials_scope
+ self.CLIENT_CREDENTIALS_USER_NAME = client_credentials_user_name
+
+
+class OAuthCredentials(BaseModel):
+ """OAuth access token and its associated metadata."""
+
+ expires_in: int
+ access_token: str
+ created_time: datetime
+ user_name: str
+
+ def are_expired(self) -> bool:
+ """
+ Check whether the current OAuth credentials are expired or not.
+
+ :return: Whether the credentials are expired or not.
+ """
+ return self.created_time + timedelta(seconds=self.expires_in) < datetime.now()
+
+ @classmethod
+ def get_existing_credentials_from_file(cls, file_path: str) -> Optional[OAuthCredentials]:
+ """
+ Try to get OAuth credentials from a file location.
+
+ Will return None if credentials either don't exist or are expired.
+ :param file_path: Location to check for OAuth credentials.
+ :return: Valid OAuth credentials or None if valid credentials don't exist
+ """
+ try:
+ creds = OAuthCredentials(**read_yaml_file(file_path))
+ if (creds.access_token and creds.created_time and creds.expires_in and creds.user_name
+ and not creds.are_expired()):
+ return creds
+ else:
+ return None
+ except ValidationError:
+ return None
+ except OSError:
+ return None
+
+
+class _RedirectServer(HTTPServer):
+ """HTTP server to use when fetching OAuth credentials using the PKCE flow."""
+
+ pkce_credentials: Optional[OAuthCredentials] = None
+ auth_domain: str
+ client_id: str
+ redirect_uri: str
+ code_verifier: str
+
+ def __init__(
+ self,
+ server_address: Tuple[str, int],
+ handler: Callable[..., BaseHTTPRequestHandler],
+ redirect_uri: str,
+ auth_domain: str,
+ client_id: str,
+ code_verifier: str,
+ ):
+ self.redirect_uri = redirect_uri
+ self.auth_domain = auth_domain
+ self.client_id = client_id
+ self.code_verifier = code_verifier
+ super().__init__(server_address, handler)
+
+
+class _Handler(BaseHTTPRequestHandler):
+ """Request handler class to use when trying to get OAuth credentials."""
+
+ # pylint: disable=invalid-name
+
+ server: _RedirectServer
+
+ def _set_response(self) -> None:
+ """Set the response to the server making a request."""
+ self.send_response(200)
+ self.send_header("Content-type", "text/html")
+ self.end_headers()
+
+ def log_message(self, log_format: Any, *args: Any) -> None: # pylint: disable=unused-argument,arguments-differ
+ """
+ Log HTTP Server internal messages.
+
+ :param log_format: The format to use when logging messages.
+ :param args: Key word args.
+ """
+ return None
+
+ def do_GET(self) -> None:
+ """Handle the callback response from the auth server."""
+ params = parse_qs(urlsplit(self.path).query)
+ code = params.get("code")
+
+ if not code:
+ raise ValueError("Could not get authorization code when signing in to Okta")
+
+ url = f"https://{self.server.auth_domain}/v1/token"
+ body = {
+ "grant_type": "authorization_code",
+ "client_id": self.server.client_id,
+ "code_verifier": self.server.code_verifier,
+ "code": code,
+ "redirect_uri": self.server.redirect_uri,
+ }
+
+ resp = requests.post(url, data=body).json()
+
+ access_token = resp.get("access_token")
+ expires_in = resp.get("expires_in")
+
+ if not access_token or not expires_in:
+ raise ValueError("Could not get access token or expires_in data about access token")
+
+ headers = {"Authorization": f"Bearer {access_token}"}
+ resp = requests.get(f"https://{self.server.auth_domain}/v1/userinfo",
+ headers=headers).json()
+
+ split_username = resp["preferred_username"].split("@")
+
+ if len(split_username) != 2:
+ raise ValueError("Could not get user_name of current user")
+
+ self.server.pkce_credentials = OAuthCredentials(
+ access_token=access_token,
+ expires_in=expires_in,
+ created_time=datetime.now(),
+ user_name=split_username[0],
+ )
+ self._set_response()
+ self.wfile.write(AUTH_HANDLER_RESPONSE)
+
+
+class PKCEOauthTools:
+ """Basic toolset to get OAuth credentials using the PKCE flow."""
+
+ auth_domain: str
+ client_id: str
+ redirect_port: int
+ redirect_uri: str
+ scope: str
+
+ def __init__(self, auth_domain: str, client_id: str, redirect_port: int, scope: str):
+ """
+ Create a new PKCEOauth tools instance.
+
+ :param auth_domain: The uri of the auth server to get the credentials from.
+ :param client_id: The id of the client that you are using to authenticate.
+ :param redirect_port: Port to use when setting up the local server for the auth redirect.
+ :param scope: The OAuth scopes to request access for.
+ """
+ self.auth_domain = auth_domain
+ self.client_id = client_id
+ self.redirect_port = redirect_port
+ self.redirect_uri = f"http://localhost:{redirect_port}/"
+ self.scope = scope
+
+ def get_pkce_credentials(self, print_auth_url: bool = False) -> OAuthCredentials:
+ """
+ Try to get an OAuth access token and its associated metadata.
+
+ :param print_auth_url: Whether to print the auth url to the console instead of opening it.
+ :return: OAuth credentials and some associated metadata to check if they have expired.
+ """
+ code_verifier, code_challenge = generate_pkce_pair()
+
+ state = "".join(choice(ascii_lowercase) for i in range(10))
+
+ authorization_url = (f"https://{self.auth_domain}/v1/authorize?"
+ f"scope={self.scope}&"
+ f"response_type=code&"
+ f"response_mode=query&"
+ f"client_id={self.client_id}&"
+ f"code_challenge={code_challenge}&"
+ f"state={state}&"
+ f"code_challenge_method=S256&"
+ f"redirect_uri={self.redirect_uri}")
+
+ httpd = _RedirectServer(
+ ("", self.redirect_port),
+ _Handler,
+ self.redirect_uri,
+ self.auth_domain,
+ self.client_id,
+ code_verifier,
+ )
+ if print_auth_url:
+ print("Please open the below url in a browser and sign in if necessary")
+ print(authorization_url)
+ else:
+ web_open(authorization_url)
+ httpd.handle_request()
+
+ if not httpd.pkce_credentials:
+ raise ValueError(
+ "Could not retrieve Okta credentials to talk to Kanopy with. "
+ "Please sign out of Okta in your browser and try runnning this script again")
+
+ return httpd.pkce_credentials
+
+
+def get_oauth_credentials(configs: Configs, print_auth_url: bool = False) -> OAuthCredentials:
+ """
+ Run the OAuth workflow to get credentials for a human user.
+
+ :param configs: Configs instance.
+ :param print_auth_url: Whether to print the auth url to the console instead of opening it.
+ :return: OAuth credentials for the given user.
+ """
+ oauth_tools = PKCEOauthTools(auth_domain=configs.AUTH_DOMAIN, client_id=configs.CLIENT_ID,
+ redirect_port=configs.REDIRECT_PORT, scope=configs.SCOPE)
+ credentials = oauth_tools.get_pkce_credentials(print_auth_url)
+ return credentials
+
+
+def get_client_cred_oauth_credentials(client_id: str, client_secret: str,
+ configs: Configs) -> OAuthCredentials:
+ """
+ Run the OAuth workflow to get credentials for a machine user.
+
+ :param client_id: The client_id of the machine user to authenticate as.
+ :param client_secret: The client_secret of the machine user to authenticate as.
+ :param configs: Configs instance.
+ :return: OAuth credentials for the given machine user.
+ """
+ client = BackendApplicationClient(client_id=client_id)
+ oauth = OAuth2Session(client=client)
+ token = oauth.fetch_token(
+ token_url=f"https://{configs.AUTH_DOMAIN}/v1/token",
+ client_id=client_id,
+ client_secret=client_secret,
+ scope=configs.CLIENT_CREDENTIALS_SCOPE,
+ )
+ access_token = token.get("access_token")
+ expires_in = token.get("expires_in")
+
+ if not access_token or not expires_in:
+ raise ValueError("Could not get access token or expires_in data about access token")
+
+ return OAuthCredentials(
+ access_token=access_token,
+ expires_in=expires_in,
+ created_time=datetime.now(),
+ user_name=configs.CLIENT_CREDENTIALS_USER_NAME,
+ )