diff options
Diffstat (limited to 'extra')
55 files changed, 8768 insertions, 7419 deletions
diff --git a/extra/cr50_rma_open/cr50_rma_open.py b/extra/cr50_rma_open/cr50_rma_open.py index 42ddbbac2d..dc9c144158 100755 --- a/extra/cr50_rma_open/cr50_rma_open.py +++ b/extra/cr50_rma_open/cr50_rma_open.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright 2018 The Chromium OS Authors. All rights reserved. +# Copyright 2018 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. @@ -50,23 +50,22 @@ import subprocess import sys import time -import serial +import serial # pylint:disable=import-error SCRIPT_VERSION = 5 CCD_IS_UNRESTRICTED = 1 << 0 WP_IS_DISABLED = 1 << 1 TESTLAB_IS_ENABLED = 1 << 2 RMA_OPENED = CCD_IS_UNRESTRICTED | WP_IS_DISABLED -URL = ('https://www.google.com/chromeos/partner/console/cr50reset?' - 'challenge=%s&hwid=%s') -RMA_SUPPORT_PROD = '0.3.3' -RMA_SUPPORT_PREPVT = '0.4.5' -DEV_MODE_OPEN_PROD = '0.3.9' -DEV_MODE_OPEN_PREPVT = '0.4.7' -TESTLAB_PROD = '0.3.10' -CR50_USB = '18d1:5014' -CR50_LSUSB_CMD = ['lsusb', '-vd', CR50_USB] -ERASED_BID = 'ffffffff' +URL = "https://www.google.com/chromeos/partner/console/cr50reset?challenge=%s&hwid=%s" +RMA_SUPPORT_PROD = "0.3.3" +RMA_SUPPORT_PREPVT = "0.4.5" +DEV_MODE_OPEN_PROD = "0.3.9" +DEV_MODE_OPEN_PREPVT = "0.4.7" +TESTLAB_PROD = "0.3.10" +CR50_USB = "18d1:5014" +CR50_LSUSB_CMD = ["lsusb", "-vd", CR50_USB] +ERASED_BID = "ffffffff" DEBUG_MISSING_USB = """ Unable to find Cr50 Device 18d1:5014 @@ -128,13 +127,14 @@ DEBUG_DUT_CONTROL_OSERROR = """ Run from chroot if you are trying to use a /dev/pts ccd servo console """ + class RMAOpen(object): """Used to find the cr50 console and run RMA open""" - ENABLE_TESTLAB_CMD = 'ccd testlab enabled\n' + ENABLE_TESTLAB_CMD = "ccd testlab enabled\n" def __init__(self, device=None, usb_serial=None, servo_port=None, ip=None): - self.servo_port = servo_port if servo_port else '9999' + self.servo_port = servo_port if servo_port else "9999" self.ip = ip if device: self.set_cr50_device(device) @@ -142,18 +142,18 @@ class RMAOpen(object): self.find_cr50_servo_uart() else: self.find_cr50_device(usb_serial) - logging.info('DEVICE: %s', self.device) + logging.info("DEVICE: %s", self.device) self.check_version() self.print_platform_info() - logging.info('Cr50 setup ok') + logging.info("Cr50 setup ok") self.update_ccd_state() self.using_ccd = self.device_is_running_with_servo_ccd() def _dut_control(self, control): """Run dut-control and return the response""" try: - cmd = ['dut-control', '-p', self.servo_port, control] - return subprocess.check_output(cmd, encoding='utf-8').strip() + cmd = ["dut-control", "-p", self.servo_port, control] + return subprocess.check_output(cmd, encoding="utf-8").strip() except OSError: logging.warning(DEBUG_DUT_CONTROL_OSERROR) raise @@ -163,8 +163,8 @@ class RMAOpen(object): Find the console and configure it, so it can be used with this script. """ - self._dut_control('cr50_uart_timestamp:off') - self.device = self._dut_control('cr50_uart_pty').split(':')[-1] + self._dut_control("cr50_uart_timestamp:off") + self.device = self._dut_control("cr50_uart_pty").split(":")[-1] def set_cr50_device(self, device): """Save the device used for the console""" @@ -183,38 +183,38 @@ class RMAOpen(object): try: ser = serial.Serial(self.device, timeout=1) except OSError: - logging.warning('Permission denied %s', self.device) - logging.warning('Try running cr50_rma_open with sudo') + logging.warning("Permission denied %s", self.device) + logging.warning("Try running cr50_rma_open with sudo") raise - write_cmd = cmd + '\n\n' - ser.write(write_cmd.encode('utf-8')) + write_cmd = cmd + "\n\n" + ser.write(write_cmd.encode("utf-8")) if nbytes: output = ser.read(nbytes) else: output = ser.readall() ser.close() - output = output.decode('utf-8').strip() if output else '' + output = output.decode("utf-8").strip() if output else "" # Return only the command output - split_cmd = cmd + '\r' + split_cmd = cmd + "\r" if cmd and split_cmd in output: - return ''.join(output.rpartition(split_cmd)[1::]).split('>')[0] + return "".join(output.rpartition(split_cmd)[1::]).split(">")[0] return output def device_is_running_with_servo_ccd(self): """Return True if the device is a servod ccd console""" # servod uses /dev/pts consoles. Non-servod uses /dev/ttyUSBX - if '/dev/pts' not in self.device: + if "/dev/pts" not in self.device: return False # If cr50 doesn't show rdd is connected, cr50 the device must not be # a ccd device - if 'Rdd: connected' not in self.send_cmd_get_output('ccdstate'): + if "Rdd: connected" not in self.send_cmd_get_output("ccdstate"): return False # Check if the servod is running with ccd. This requires the script # is run in the chroot, so run it last. - if 'ccd_cr50' not in self._dut_control('servo_type'): + if "ccd_cr50" not in self._dut_control("servo_type"): return False - logging.info('running through servod ccd') + logging.info("running through servod ccd") return True def get_rma_challenge(self): @@ -239,14 +239,14 @@ class RMAOpen(object): Returns: The RMA challenge with all whitespace removed. """ - output = self.send_cmd_get_output('rma_auth').strip() - logging.info('rma_auth output:\n%s', output) + output = self.send_cmd_get_output("rma_auth").strip() + logging.info("rma_auth output:\n%s", output) # Extract the challenge from the console output - if 'generated challenge:' in output: - return output.split('generated challenge:')[-1].strip() - challenge = ''.join(re.findall(r' \S{5}' * 4, output)) + if "generated challenge:" in output: + return output.split("generated challenge:")[-1].strip() + challenge = "".join(re.findall(r" \S{5}" * 4, output)) # Remove all whitespace - return re.sub(r'\s', '', challenge) + return re.sub(r"\s", "", challenge) def generate_challenge_url(self, hwid): """Get the rma_auth challenge @@ -257,12 +257,14 @@ class RMAOpen(object): challenge = self.get_rma_challenge() self.print_platform_info() - logging.info('CHALLENGE: %s', challenge) - logging.info('HWID: %s', hwid) + logging.info("CHALLENGE: %s", challenge) + logging.info("HWID: %s", hwid) url = URL % (challenge, hwid) - logging.info('GOTO:\n %s', url) - logging.info('If the server fails to debug the challenge make sure the ' - 'RLZ is allowlisted') + logging.info("GOTO:\n %s", url) + logging.info( + "If the server fails to debug the challenge make sure the " + "RLZ is allowlisted" + ) def try_authcode(self, authcode): """Try opening cr50 with the authcode @@ -272,48 +274,48 @@ class RMAOpen(object): """ # rma_auth may cause the system to reboot. Don't wait to read all that # output. Read the first 300 bytes and call it a day. - output = self.send_cmd_get_output('rma_auth ' + authcode, nbytes=300) - logging.info('CR50 RESPONSE: %s', output) - logging.info('waiting for cr50 reboot') + output = self.send_cmd_get_output("rma_auth " + authcode, nbytes=300) + logging.info("CR50 RESPONSE: %s", output) + logging.info("waiting for cr50 reboot") # Cr50 may be rebooting. Wait a bit time.sleep(5) if self.using_ccd: # After reboot, reset the ccd endpoints - self._dut_control('power_state:ccd_reset') + self._dut_control("power_state:ccd_reset") # Update the ccd state after the authcode attempt self.update_ccd_state() - authcode_match = 'process_response: success!' in output + authcode_match = "process_response: success!" in output if not self.check(CCD_IS_UNRESTRICTED): if not authcode_match: logging.warning(DEBUG_AUTHCODE_MISMATCH) - message = 'Authcode mismatch. Check args and url' + message = "Authcode mismatch. Check args and url" else: - message = 'Could not set all capability privileges to Always' + message = "Could not set all capability privileges to Always" raise ValueError(message) def wp_is_force_disabled(self): """Returns True if write protect is forced disabled""" - output = self.send_cmd_get_output('wp') - wp_state = output.split('Flash WP:', 1)[-1].split('\n', 1)[0].strip() - logging.info('wp: %s', wp_state) - return wp_state == 'forced disabled' + output = self.send_cmd_get_output("wp") + wp_state = output.split("Flash WP:", 1)[-1].split("\n", 1)[0].strip() + logging.info("wp: %s", wp_state) + return wp_state == "forced disabled" def testlab_is_enabled(self): """Returns True if testlab mode is enabled""" - output = self.send_cmd_get_output('ccd testlab') - testlab_state = output.split('mode')[-1].strip().lower() - logging.info('testlab: %s', testlab_state) - return testlab_state == 'enabled' + output = self.send_cmd_get_output("ccd testlab") + testlab_state = output.split("mode")[-1].strip().lower() + logging.info("testlab: %s", testlab_state) + return testlab_state == "enabled" def ccd_is_restricted(self): """Returns True if any of the capabilities are still restricted""" - output = self.send_cmd_get_output('ccd') - if 'Capabilities' not in output: - raise ValueError('Could not get ccd output') - logging.debug('CURRENT CCD SETTINGS:\n%s', output) - restricted = 'IfOpened' in output or 'IfUnlocked' in output - logging.info('ccd: %srestricted', '' if restricted else 'Un') + output = self.send_cmd_get_output("ccd") + if "Capabilities" not in output: + raise ValueError("Could not get ccd output") + logging.debug("CURRENT CCD SETTINGS:\n%s", output) + restricted = "IfOpened" in output or "IfUnlocked" in output + logging.info("ccd: %srestricted", "" if restricted else "Un") return restricted def update_ccd_state(self): @@ -339,9 +341,10 @@ class RMAOpen(object): def _capabilities_allow_open_from_console(self): """Return True if ccd open is Always allowed from usb""" - output = self.send_cmd_get_output('ccd') - return (re.search('OpenNoDevMode.*Always', output) and - re.search('OpenFromUSB.*Always', output)) + output = self.send_cmd_get_output("ccd") + return re.search("OpenNoDevMode.*Always", output) and re.search( + "OpenFromUSB.*Always", output + ) def _requires_dev_mode_open(self): """Return True if the image requires dev mode to open""" @@ -354,78 +357,83 @@ class RMAOpen(object): def _run_on_dut(self, command): """Run the command on the DUT.""" - return subprocess.check_output(['ssh', self.ip, command], - encoding='utf-8') + return subprocess.check_output( + ["ssh", self.ip, command], encoding="utf-8" + ) def _open_in_dev_mode(self): """Open Cr50 when it's in dev mode""" - output = self.send_cmd_get_output('ccd') + output = self.send_cmd_get_output("ccd") # If the device is already open, nothing needs to be done. - if 'State: Open' not in output: + if "State: Open" not in output: # Verify the device is in devmode before trying to run open. - if 'dev_mode' not in output: - logging.warning('Enter dev mode to open ccd or update to %s', - TESTLAB_PROD) - raise ValueError('DUT not in dev mode') + if "dev_mode" not in output: + logging.warning( + "Enter dev mode to open ccd or update to %s", TESTLAB_PROD + ) + raise ValueError("DUT not in dev mode") if not self.ip: - logging.warning("If your DUT doesn't have ssh support, run " - "'gsctool -a -o' from the AP") - raise ValueError('Cannot run ccd open without dut ip') - self._run_on_dut('gsctool -a -o') + logging.warning( + "If your DUT doesn't have ssh support, run " + "'gsctool -a -o' from the AP" + ) + raise ValueError("Cannot run ccd open without dut ip") + self._run_on_dut("gsctool -a -o") # Wait >1 second for cr50 to update ccd state time.sleep(3) - output = self.send_cmd_get_output('ccd') - if 'State: Open' not in output: - raise ValueError('Could not open cr50') - logging.info('ccd is open') + output = self.send_cmd_get_output("ccd") + if "State: Open" not in output: + raise ValueError("Could not open cr50") + logging.info("ccd is open") def enable_testlab(self): """Disable write protect""" if not self._has_testlab_support(): - logging.warning('Testlab mode is not supported in prod iamges') + logging.warning("Testlab mode is not supported in prod iamges") return # Some cr50 images need to be in dev mode before they can be opened. if self._requires_dev_mode_open(): self._open_in_dev_mode() else: - self.send_cmd_get_output('ccd open') - logging.info('Enabling testlab mode reqires pressing the power button.') - logging.info('Once the process starts keep tapping the power button ' - 'for 10 seconds.') + self.send_cmd_get_output("ccd open") + logging.info("Enabling testlab mode reqires pressing the power button.") + logging.info( + "Once the process starts keep tapping the power button for 10 seconds." + ) input("Press Enter when you're ready to start...") end_time = time.time() + 15 ser = serial.Serial(self.device, timeout=1) - printed_lines = '' - output = '' + printed_lines = "" + output = "" # start ccd testlab enable - ser.write(self.ENABLE_TESTLAB_CMD.encode('utf-8')) - logging.info('start pressing the power button\n\n') + ser.write(self.ENABLE_TESTLAB_CMD.encode("utf-8")) + logging.info("start pressing the power button\n\n") # Print all of the cr50 output as we get it, so the user will have more # information about pressing the power button. Tapping the power button # a couple of times should do it, but this will give us more confidence # the process is still running/worked. try: while time.time() < end_time: - output += ser.read(100).decode('utf-8') - full_lines = output.rsplit('\n', 1)[0] + output += ser.read(100).decode("utf-8") + full_lines = output.rsplit("\n", 1)[0] new_lines = full_lines if printed_lines: new_lines = full_lines.split(printed_lines, 1)[-1].strip() - logging.info('\n%s', new_lines) + logging.info("\n%s", new_lines) printed_lines = full_lines # Make sure the process hasn't ended. If it has, print the last # of the output and exit. new_lines = output.split(printed_lines, 1)[-1] - if 'CCD test lab mode enabled' in output: + if "CCD test lab mode enabled" in output: # print the last of the ou logging.info(new_lines) break - elif 'Physical presence check timeout' in output: + elif "Physical presence check timeout" in output: logging.info(new_lines) - logging.warning('Did not detect power button press in time') - raise ValueError('Could not enable testlab mode try again') + logging.warning("Did not detect power button press in time") + raise ValueError("Could not enable testlab mode try again") finally: ser.close() # Wait for the ccd hook to update things @@ -433,44 +441,50 @@ class RMAOpen(object): # Update the state after attempting to disable write protect self.update_ccd_state() if not self.check(TESTLAB_IS_ENABLED): - raise ValueError('Could not enable testlab mode try again') + raise ValueError("Could not enable testlab mode try again") def wp_disable(self): """Disable write protect""" - logging.info('Disabling write protect') - self.send_cmd_get_output('wp disable') + logging.info("Disabling write protect") + self.send_cmd_get_output("wp disable") # Update the state after attempting to disable write protect self.update_ccd_state() if not self.check(WP_IS_DISABLED): - raise ValueError('Could not disable write protect') + raise ValueError("Could not disable write protect") def check_version(self): """Make sure cr50 is running a version that supports RMA Open""" - output = self.send_cmd_get_output('version') + output = self.send_cmd_get_output("version") if not output.strip(): logging.warning(DEBUG_DEVICE, self.device) - raise ValueError('Could not communicate with %s' % self.device) + raise ValueError("Could not communicate with %s" % self.device) - version = re.search(r'RW.*\* ([\d\.]+)/', output).group(1) - logging.info('Running Cr50 Version: %s', version) - self.running_ver_fields = [int(field) for field in version.split('.')] + version = re.search(r"RW.*\* ([\d\.]+)/", output).group(1) + logging.info("Running Cr50 Version: %s", version) + self.running_ver_fields = [int(field) for field in version.split(".")] # prePVT images have even major versions. Prod have odd self.is_prepvt = self.running_ver_fields[1] % 2 == 0 rma_support = RMA_SUPPORT_PREPVT if self.is_prepvt else RMA_SUPPORT_PROD - logging.info('%s RMA support added in: %s', - 'prePVT' if self.is_prepvt else 'prod', rma_support) + logging.info( + "%s RMA support added in: %s", + "prePVT" if self.is_prepvt else "prod", + rma_support, + ) if not self.is_prepvt and self._running_version_is_older(TESTLAB_PROD): - raise ValueError('Update cr50. No testlab support in old prod ' - 'images.') + raise ValueError( + "Update cr50. No testlab support in old prod images." + ) if self._running_version_is_older(rma_support): - raise ValueError('%s does not have RMA support. Update to at ' - 'least %s' % (version, rma_support)) + raise ValueError( + "%s does not have RMA support. Update to at least %s" + % (version, rma_support) + ) def _running_version_is_older(self, target_ver): """Returns True if running version is older than target_ver.""" - target_ver_fields = [int(field) for field in target_ver.split('.')] + target_ver_fields = [int(field) for field in target_ver.split(".")] for i, field in enumerate(self.running_ver_fields): if field > int(target_ver_fields[i]): return False @@ -486,11 +500,11 @@ class RMAOpen(object): is no output or sysinfo doesn't contain the devid. """ self.set_cr50_device(device) - sysinfo = self.send_cmd_get_output('sysinfo') + sysinfo = self.send_cmd_get_output("sysinfo") # Make sure there is some output, and it shows it's from Cr50 - if not sysinfo or 'cr50' not in sysinfo: + if not sysinfo or "cr50" not in sysinfo: return False - logging.debug('Sysinfo output: %s', sysinfo) + logging.debug("Sysinfo output: %s", sysinfo) # The cr50 device id should be in the sysinfo output, if we found # the right console. Make sure it is return devid in sysinfo @@ -508,104 +522,150 @@ class RMAOpen(object): ValueError if the console can't be found with the given serialname """ usb_serial = self.find_cr50_usb(usb_serial) - logging.info('SERIALNAME: %s', usb_serial) - devid = '0x' + ' 0x'.join(usb_serial.lower().split('-')) - logging.info('DEVID: %s', devid) + logging.info("SERIALNAME: %s", usb_serial) + devid = "0x" + " 0x".join(usb_serial.lower().split("-")) + logging.info("DEVID: %s", devid) # Get all the usb devices - devices = glob.glob('/dev/ttyUSB*') + devices = glob.glob("/dev/ttyUSB*") # Typically Cr50 has the lowest number. Sort the devices, so we're more # likely to try the cr50 console first. devices.sort() # Find the one that is the cr50 console for device in devices: - logging.info('testing %s', device) + logging.info("testing %s", device) if self.device_matches_devid(devid, device): - logging.info('found device: %s', device) + logging.info("found device: %s", device) return logging.warning(DEBUG_CONNECTION) - raise ValueError('Found USB device, but could not communicate with ' - 'cr50 console') + raise ValueError( + "Found USB device, but could not communicate with cr50 console" + ) def print_platform_info(self): """Print the cr50 BID RLZ code""" - bid_output = self.send_cmd_get_output('bid') - bid = re.search(r'Board ID: (\S+?)[:,]', bid_output).group(1) + bid_output = self.send_cmd_get_output("bid") + bid = re.search(r"Board ID: (\S+?)[:,]", bid_output).group(1) if bid == ERASED_BID: logging.warning(DEBUG_ERASED_BOARD_ID) - raise ValueError('Cannot run RMA Open when board id is erased') + raise ValueError("Cannot run RMA Open when board id is erased") bid = int(bid, 16) - chrs = [chr((bid >> (8 * i)) & 0xff) for i in range(4)] - logging.info('RLZ: %s', ''.join(chrs[::-1])) + chrs = [chr((bid >> (8 * i)) & 0xFF) for i in range(4)] + logging.info("RLZ: %s", "".join(chrs[::-1])) @staticmethod def find_cr50_usb(usb_serial): """Make sure the Cr50 USB device exists""" try: - output = subprocess.check_output(CR50_LSUSB_CMD, encoding='utf-8') + output = subprocess.check_output(CR50_LSUSB_CMD, encoding="utf-8") except: logging.warning(DEBUG_MISSING_USB) - raise ValueError('Could not find Cr50 USB device') - serialnames = re.findall(r'iSerial +\d+ (\S+)\s', output) + raise ValueError("Could not find Cr50 USB device") + serialnames = re.findall(r"iSerial +\d+ (\S+)\s", output) if usb_serial: if usb_serial not in serialnames: logging.warning(DEBUG_SERIALNAME) raise ValueError('Could not find usb device "%s"' % usb_serial) return usb_serial if len(serialnames) > 1: - logging.info('Found Cr50 device serialnames %s', - ', '.join(serialnames)) + logging.info( + "Found Cr50 device serialnames %s", ", ".join(serialnames) + ) logging.warning(DEBUG_TOO_MANY_USB_DEVICES) - raise ValueError('Too many cr50 usb devices') + raise ValueError("Too many cr50 usb devices") return serialnames[0] def print_dut_state(self): """Print CCD RMA and testlab mode state.""" if not self.check(CCD_IS_UNRESTRICTED): - logging.info('CCD is still restricted.') - logging.info('Run cr50_rma_open.py -g -i $HWID to generate a url') - logging.info('Run cr50_rma_open.py -a $AUTHCODE to open cr50 with ' - 'an authcode') + logging.info("CCD is still restricted.") + logging.info("Run cr50_rma_open.py -g -i $HWID to generate a url") + logging.info( + "Run cr50_rma_open.py -a $AUTHCODE to open cr50 with an authcode" + ) elif not self.check(WP_IS_DISABLED): - logging.info('WP is still enabled.') - logging.info('Run cr50_rma_open.py -w to disable write protect') + logging.info("WP is still enabled.") + logging.info("Run cr50_rma_open.py -w to disable write protect") if self.check(RMA_OPENED): - logging.info('RMA Open complete') + logging.info("RMA Open complete") if not self.check(TESTLAB_IS_ENABLED) and self.is_prepvt: - logging.info('testlab mode is disabled.') - logging.info('If you are prepping a device for the testlab, you ' - 'should enable testlab mode.') - logging.info('Run cr50_rma_open.py -t to enable testlab mode') + logging.info("testlab mode is disabled.") + logging.info( + "If you are prepping a device for the testlab, you " + "should enable testlab mode." + ) + logging.info("Run cr50_rma_open.py -t to enable testlab mode") def parse_args(argv): """Get cr50_rma_open args.""" parser = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument('-g', '--generate_challenge', action='store_true', - help='Generate Cr50 challenge. Must be used with -i') - parser.add_argument('-t', '--enable_testlab', action='store_true', - help='enable testlab mode') - parser.add_argument('-w', '--wp_disable', action='store_true', - help='Disable write protect') - parser.add_argument('-c', '--check_connection', action='store_true', - help='Check cr50 console connection works') - parser.add_argument('-s', '--serialname', type=str, default='', - help='The cr50 usb serialname') - parser.add_argument('-D', '--debug', action='store_true', - help='print debug messages') - parser.add_argument('-d', '--device', type=str, default='', - help='cr50 console device ex /dev/ttyUSB0') - parser.add_argument('-i', '--hwid', type=str, default='', - help='The board hwid. Needed to generate a challenge') - parser.add_argument('-a', '--authcode', type=str, default='', - help='The authcode string from the challenge url') - parser.add_argument('-P', '--servo_port', type=str, default='', - help='the servo port') - parser.add_argument('-I', '--ip', type=str, default='', - help='The DUT IP. Necessary to do ccd open') + description=__doc__, formatter_class=argparse.RawTextHelpFormatter + ) + parser.add_argument( + "-g", + "--generate_challenge", + action="store_true", + help="Generate Cr50 challenge. Must be used with -i", + ) + parser.add_argument( + "-t", + "--enable_testlab", + action="store_true", + help="enable testlab mode", + ) + parser.add_argument( + "-w", "--wp_disable", action="store_true", help="Disable write protect" + ) + parser.add_argument( + "-c", + "--check_connection", + action="store_true", + help="Check cr50 console connection works", + ) + parser.add_argument( + "-s", + "--serialname", + type=str, + default="", + help="The cr50 usb serialname", + ) + parser.add_argument( + "-D", "--debug", action="store_true", help="print debug messages" + ) + parser.add_argument( + "-d", + "--device", + type=str, + default="", + help="cr50 console device ex /dev/ttyUSB0", + ) + parser.add_argument( + "-i", + "--hwid", + type=str, + default="", + help="The board hwid. Needed to generate a challenge", + ) + parser.add_argument( + "-a", + "--authcode", + type=str, + default="", + help="The authcode string from the challenge url", + ) + parser.add_argument( + "-P", "--servo_port", type=str, default="", help="the servo port" + ) + parser.add_argument( + "-I", + "--ip", + type=str, + default="", + help="The DUT IP. Necessary to do ccd open", + ) return parser.parse_args(argv) @@ -614,52 +674,59 @@ def main(argv): opts = parse_args(argv) loglevel = logging.INFO - log_format = '%(levelname)7s' + log_format = "%(levelname)7s" if opts.debug: loglevel = logging.DEBUG - log_format += ' - %(lineno)3d:%(funcName)-15s' - log_format += ' - %(message)s' + log_format += " - %(lineno)3d:%(funcName)-15s" + log_format += " - %(message)s" logging.basicConfig(level=loglevel, format=log_format) tried_authcode = False - logging.info('Running cr50_rma_open version %s', SCRIPT_VERSION) + logging.info("Running cr50_rma_open version %s", SCRIPT_VERSION) - cr50_rma_open = RMAOpen(opts.device, opts.serialname, opts.servo_port, - opts.ip) + cr50_rma_open = RMAOpen( + opts.device, opts.serialname, opts.servo_port, opts.ip + ) if opts.check_connection: sys.exit(0) if not cr50_rma_open.check(CCD_IS_UNRESTRICTED): if opts.generate_challenge: if not opts.hwid: - logging.warning('--hwid necessary to generate challenge url') + logging.warning("--hwid necessary to generate challenge url") sys.exit(0) cr50_rma_open.generate_challenge_url(opts.hwid) sys.exit(0) elif opts.authcode: - logging.info('Using authcode: %s', opts.authcode) + logging.info("Using authcode: %s", opts.authcode) cr50_rma_open.try_authcode(opts.authcode) tried_authcode = True - if not cr50_rma_open.check(WP_IS_DISABLED) and (tried_authcode or - opts.wp_disable): + if not cr50_rma_open.check(WP_IS_DISABLED) and ( + tried_authcode or opts.wp_disable + ): if not cr50_rma_open.check(CCD_IS_UNRESTRICTED): - raise ValueError("Can't disable write protect unless ccd is " - "open. Run through the rma open process first") + raise ValueError( + "Can't disable write protect unless ccd is " + "open. Run through the rma open process first" + ) if tried_authcode: - logging.warning('RMA Open did not disable write protect. File a ' - 'bug') - logging.warning('Trying to disable it manually') + logging.warning( + "RMA Open did not disable write protect. File a bug" + ) + logging.warning("Trying to disable it manually") cr50_rma_open.wp_disable() if not cr50_rma_open.check(TESTLAB_IS_ENABLED) and opts.enable_testlab: if not cr50_rma_open.check(CCD_IS_UNRESTRICTED): - raise ValueError("Can't enable testlab mode unless ccd is open." - "Run through the rma open process first") + raise ValueError( + "Can't enable testlab mode unless ccd is open." + "Run through the rma open process first" + ) cr50_rma_open.enable_testlab() cr50_rma_open.print_dut_state() -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main(sys.argv[1:])) diff --git a/extra/ftdi_hostcmd/Makefile b/extra/ftdi_hostcmd/Makefile index d46b4b1c72..10f0e2d390 100644 --- a/extra/ftdi_hostcmd/Makefile +++ b/extra/ftdi_hostcmd/Makefile @@ -1,4 +1,4 @@ -# Copyright 2015 The Chromium OS Authors. All rights reserved. +# Copyright 2015 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. diff --git a/extra/ftdi_hostcmd/test_cmds.c b/extra/ftdi_hostcmd/test_cmds.c index 4552476d0f..7bd3413032 100644 --- a/extra/ftdi_hostcmd/test_cmds.c +++ b/extra/ftdi_hostcmd/test_cmds.c @@ -1,4 +1,4 @@ -/* Copyright 2015 The Chromium OS Authors. All rights reserved. +/* Copyright 2015 The ChromiumOS Authors * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ @@ -22,7 +22,7 @@ static struct mpsse_context *mpsse; /* enum ec_status meaning */ static const char *ec_strerr(enum ec_status r) { - static const char * const strs[] = { + static const char *const strs[] = { "SUCCESS", "INVALID_COMMAND", "ERROR", @@ -47,10 +47,9 @@ static const char *ec_strerr(enum ec_status r) return "<undefined result>"; }; - -/**************************************************************************** - * Debugging output - */ + /**************************************************************************** + * Debugging output + */ #define LINELEN 16 @@ -64,8 +63,7 @@ static void showline(uint8_t *buf, int len) printf(" "); printf(" "); for (i = 0; i < len; i++) - printf("%c", - (buf[i] >= ' ' && buf[i] <= '~') ? buf[i] : '.'); + printf("%c", (buf[i] >= ' ' && buf[i] <= '~') ? buf[i] : '.'); printf("\n"); } @@ -104,8 +102,8 @@ static uint8_t txbuf[128]; * Load the output buffer with a proto v3 request (header, then data, with * checksum correct in header). */ -static size_t prepare_request(int cmd, int version, - const uint8_t *data, size_t data_len) +static size_t prepare_request(int cmd, int version, const uint8_t *data, + size_t data_len) { struct ec_host_request *request; size_t i, total_len; @@ -113,8 +111,8 @@ static size_t prepare_request(int cmd, int version, total_len = sizeof(*request) + data_len; if (total_len > sizeof(txbuf)) { - printf("Request too large (%zd > %zd)\n", - total_len, sizeof(txbuf)); + printf("Request too large (%zd > %zd)\n", total_len, + sizeof(txbuf)); return -1; } @@ -153,8 +151,7 @@ static int send_request(uint8_t *txbuf, size_t len) tptr = Transfer(mpsse, txbuf, len); if (!tptr) { - fprintf(stderr, "Transfer failed: %s\n", - ErrorString(mpsse)); + fprintf(stderr, "Transfer failed: %s\n", ErrorString(mpsse)); return -1; } @@ -178,7 +175,6 @@ static int send_request(uint8_t *txbuf, size_t len) return ret; } - /* Timeout flag, so we don't wait forever */ static int timedout; static void alarm_handler(int sig) @@ -195,8 +191,8 @@ static void alarm_handler(int sig) * 0 = response received (check hdr for EC result and body size) * -1 = problems */ -static int get_response(struct ec_host_response *hdr, - uint8_t *bodydest, size_t bodylen) +static int get_response(struct ec_host_response *hdr, uint8_t *bodydest, + size_t bodylen) { uint8_t *hptr = 0, *bptr = 0; uint8_t sum = 0; @@ -237,8 +233,7 @@ static int get_response(struct ec_host_response *hdr, /* Now read the response header */ hptr = Read(mpsse, sizeof(*hdr)); if (!hptr) { - fprintf(stderr, "Read failed: %s\n", - ErrorString(mpsse)); + fprintf(stderr, "Read failed: %s\n", ErrorString(mpsse)); goto out; } show("Header(%d):\n", hptr, sizeof(*hdr)); @@ -247,14 +242,12 @@ static int get_response(struct ec_host_response *hdr, /* Check the header */ if (hdr->struct_version != EC_HOST_RESPONSE_VERSION) { printf("response version %d (should be %d)\n", - hdr->struct_version, - EC_HOST_RESPONSE_VERSION); + hdr->struct_version, EC_HOST_RESPONSE_VERSION); goto out; } if (hdr->data_len > bodylen) { - printf("response data_len %d is > %zd\n", - hdr->data_len, + printf("response data_len %d is > %zd\n", hdr->data_len, bodylen); goto out; } @@ -290,19 +283,13 @@ out: return ret; } - /* * Send command, wait for result. Return zero if communication succeeded; check * response to see if the EC liked the command. */ -static int send_cmd(int cmd, int version, - void *outbuf, - size_t outsize, - struct ec_host_response *resp, - void *inbuf, - size_t insize) +static int send_cmd(int cmd, int version, void *outbuf, size_t outsize, + struct ec_host_response *resp, void *inbuf, size_t insize) { - size_t len; int ret = -1; @@ -312,8 +299,7 @@ static int send_cmd(int cmd, int version, return -1; if (MPSSE_OK != Start(mpsse)) { - fprintf(stderr, "Start failed: %s\n", - ErrorString(mpsse)); + fprintf(stderr, "Start failed: %s\n", ErrorString(mpsse)); return -1; } @@ -322,15 +308,13 @@ static int send_cmd(int cmd, int version, ret = 0; if (MPSSE_OK != Stop(mpsse)) { - fprintf(stderr, "Stop failed: %s\n", - ErrorString(mpsse)); + fprintf(stderr, "Stop failed: %s\n", ErrorString(mpsse)); return -1; } return ret; } - /**************************************************************************** * Probe for basic protocol info */ @@ -352,10 +336,8 @@ static int probe_v3(void) if (opt_verbose) printf("Trying EC_CMD_GET_PROTOCOL_INFO...\n"); - ret = send_cmd(EC_CMD_GET_PROTOCOL_INFO, 0, - 0, 0, - &resp, - &info, sizeof(info)); + ret = send_cmd(EC_CMD_GET_PROTOCOL_INFO, 0, 0, 0, &resp, &info, + sizeof(info)); if (ret) { printf("EC_CMD_GET_PROTOCOL_INFO failed\n"); @@ -363,8 +345,8 @@ static int probe_v3(void) } if (EC_RES_SUCCESS != resp.result) { - printf("EC result is %d: %s\n", - resp.result, ec_strerr(resp.result)); + printf("EC result is %d: %s\n", resp.result, + ec_strerr(resp.result)); return -1; } @@ -378,8 +360,7 @@ static int probe_v3(void) info.max_request_packet_size); printf(" max_response_packet_size: %d\n", info.max_response_packet_size); - printf(" flags: 0x%x\n", - info.flags); + printf(" flags: 0x%x\n", info.flags); return 0; } @@ -390,119 +371,118 @@ static int probe_v3(void) struct lookup { uint16_t cmd; - const char * const desc; + const char *const desc; }; static struct lookup cmd_table[] = { - {0x00, "EC_CMD_PROTO_VERSION"}, - {0x01, "EC_CMD_HELLO"}, - {0x02, "EC_CMD_GET_VERSION"}, - {0x03, "EC_CMD_READ_TEST"}, - {0x04, "EC_CMD_GET_BUILD_INFO"}, - {0x05, "EC_CMD_GET_CHIP_INFO"}, - {0x06, "EC_CMD_GET_BOARD_VERSION"}, - {0x07, "EC_CMD_READ_MEMMAP"}, - {0x08, "EC_CMD_GET_CMD_VERSIONS"}, - {0x09, "EC_CMD_GET_COMMS_STATUS"}, - {0x0a, "EC_CMD_TEST_PROTOCOL"}, - {0x0b, "EC_CMD_GET_PROTOCOL_INFO"}, - {0x0c, "EC_CMD_GSV_PAUSE_IN_S5"}, - {0x0d, "EC_CMD_GET_FEATURES"}, - {0x10, "EC_CMD_FLASH_INFO"}, - {0x11, "EC_CMD_FLASH_READ"}, - {0x12, "EC_CMD_FLASH_WRITE"}, - {0x13, "EC_CMD_FLASH_ERASE"}, - {0x15, "EC_CMD_FLASH_PROTECT"}, - {0x16, "EC_CMD_FLASH_REGION_INFO"}, - {0x17, "EC_CMD_VBNV_CONTEXT"}, - {0x20, "EC_CMD_PWM_GET_FAN_TARGET_RPM"}, - {0x21, "EC_CMD_PWM_SET_FAN_TARGET_RPM"}, - {0x22, "EC_CMD_PWM_GET_KEYBOARD_BACKLIGHT"}, - {0x23, "EC_CMD_PWM_SET_KEYBOARD_BACKLIGHT"}, - {0x24, "EC_CMD_PWM_SET_FAN_DUTY"}, - {0x28, "EC_CMD_LIGHTBAR_CMD"}, - {0x29, "EC_CMD_LED_CONTROL"}, - {0x2a, "EC_CMD_VBOOT_HASH"}, - {0x2b, "EC_CMD_MOTION_SENSE_CMD"}, - {0x2c, "EC_CMD_FORCE_LID_OPEN"}, - {0x30, "EC_CMD_USB_CHARGE_SET_MODE"}, - {0x40, "EC_CMD_PSTORE_INFO"}, - {0x41, "EC_CMD_PSTORE_READ"}, - {0x42, "EC_CMD_PSTORE_WRITE"}, - {0x44, "EC_CMD_RTC_GET_VALUE"}, - {0x45, "EC_CMD_RTC_GET_ALARM"}, - {0x46, "EC_CMD_RTC_SET_VALUE"}, - {0x47, "EC_CMD_RTC_SET_ALARM"}, - {0x48, "EC_CMD_PORT80_LAST_BOOT"}, - {0x48, "EC_CMD_PORT80_READ"}, - {0x50, "EC_CMD_THERMAL_SET_THRESHOLD"}, - {0x51, "EC_CMD_THERMAL_GET_THRESHOLD"}, - {0x52, "EC_CMD_THERMAL_AUTO_FAN_CTRL"}, - {0x53, "EC_CMD_TMP006_GET_CALIBRATION"}, - {0x54, "EC_CMD_TMP006_SET_CALIBRATION"}, - {0x55, "EC_CMD_TMP006_GET_RAW"}, - {0x60, "EC_CMD_MKBP_STATE"}, - {0x61, "EC_CMD_MKBP_INFO"}, - {0x62, "EC_CMD_MKBP_SIMULATE_KEY"}, - {0x64, "EC_CMD_MKBP_SET_CONFIG"}, - {0x65, "EC_CMD_MKBP_GET_CONFIG"}, - {0x66, "EC_CMD_KEYSCAN_SEQ_CTRL"}, - {0x67, "EC_CMD_GET_NEXT_EVENT"}, - {0x70, "EC_CMD_TEMP_SENSOR_GET_INFO"}, - {0x87, "EC_CMD_HOST_EVENT_GET_B"}, - {0x88, "EC_CMD_HOST_EVENT_GET_SMI_MASK"}, - {0x89, "EC_CMD_HOST_EVENT_GET_SCI_MASK"}, - {0x8d, "EC_CMD_HOST_EVENT_GET_WAKE_MASK"}, - {0x8a, "EC_CMD_HOST_EVENT_SET_SMI_MASK"}, - {0x8b, "EC_CMD_HOST_EVENT_SET_SCI_MASK"}, - {0x8c, "EC_CMD_HOST_EVENT_CLEAR"}, - {0x8e, "EC_CMD_HOST_EVENT_SET_WAKE_MASK"}, - {0x8f, "EC_CMD_HOST_EVENT_CLEAR_B"}, - {0x90, "EC_CMD_SWITCH_ENABLE_BKLIGHT"}, - {0x91, "EC_CMD_SWITCH_ENABLE_WIRELESS"}, - {0x92, "EC_CMD_GPIO_SET"}, - {0x93, "EC_CMD_GPIO_GET"}, - {0x94, "EC_CMD_I2C_READ"}, - {0x95, "EC_CMD_I2C_WRITE"}, - {0x96, "EC_CMD_CHARGE_CONTROL"}, - {0x97, "EC_CMD_CONSOLE_SNAPSHOT"}, - {0x98, "EC_CMD_CONSOLE_READ"}, - {0x99, "EC_CMD_BATTERY_CUT_OFF"}, - {0x9a, "EC_CMD_USB_MUX"}, - {0x9b, "EC_CMD_LDO_SET"}, - {0x9c, "EC_CMD_LDO_GET"}, - {0x9d, "EC_CMD_POWER_INFO"}, - {0x9e, "EC_CMD_I2C_PASSTHRU"}, - {0x9f, "EC_CMD_HANG_DETECT"}, - {0xa0, "EC_CMD_CHARGE_STATE"}, - {0xa1, "EC_CMD_CHARGE_CURRENT_LIMIT"}, - {0xa2, "EC_CMD_EXT_POWER_CURRENT_LIMIT"}, - {0xb0, "EC_CMD_SB_READ_WORD"}, - {0xb1, "EC_CMD_SB_WRITE_WORD"}, - {0xb2, "EC_CMD_SB_READ_BLOCK"}, - {0xb3, "EC_CMD_SB_WRITE_BLOCK"}, - {0xb4, "EC_CMD_BATTERY_VENDOR_PARAM"}, - {0xb5, "EC_CMD_SB_FW_UPDATE"}, - {0xd2, "EC_CMD_REBOOT_EC"}, - {0xd3, "EC_CMD_GET_PANIC_INFO"}, - {0xd1, "EC_CMD_REBOOT"}, - {0xdb, "EC_CMD_RESEND_RESPONSE"}, - {0xdc, "EC_CMD_VERSION0"}, - {0x100, "EC_CMD_PD_EXCHANGE_STATUS"}, - {0x104, "EC_CMD_PD_HOST_EVENT_STATUS"}, - {0x101, "EC_CMD_USB_PD_CONTROL"}, - {0x102, "EC_CMD_USB_PD_PORTS"}, - {0x103, "EC_CMD_USB_PD_POWER_INFO"}, - {0x110, "EC_CMD_USB_PD_FW_UPDATE"}, - {0x111, "EC_CMD_USB_PD_RW_HASH_ENTRY"}, - {0x112, "EC_CMD_USB_PD_DEV_INFO"}, - {0x113, "EC_CMD_USB_PD_DISCOVERY"}, - {0x114, "EC_CMD_PD_CHARGE_PORT_OVERRIDE"}, - {0x115, "EC_CMD_PD_GET_LOG_ENTRY"}, - {0x116, "EC_CMD_USB_PD_GET_AMODE"}, - {0x117, "EC_CMD_USB_PD_SET_AMODE"}, - {0x118, "EC_CMD_PD_WRITE_LOG_ENTRY"}, - {0x200, "EC_CMD_BLOB"}, + { 0x00, "EC_CMD_PROTO_VERSION" }, + { 0x01, "EC_CMD_HELLO" }, + { 0x02, "EC_CMD_GET_VERSION" }, + { 0x03, "EC_CMD_READ_TEST" }, + { 0x04, "EC_CMD_GET_BUILD_INFO" }, + { 0x05, "EC_CMD_GET_CHIP_INFO" }, + { 0x06, "EC_CMD_GET_BOARD_VERSION" }, + { 0x07, "EC_CMD_READ_MEMMAP" }, + { 0x08, "EC_CMD_GET_CMD_VERSIONS" }, + { 0x09, "EC_CMD_GET_COMMS_STATUS" }, + { 0x0a, "EC_CMD_TEST_PROTOCOL" }, + { 0x0b, "EC_CMD_GET_PROTOCOL_INFO" }, + { 0x0c, "EC_CMD_GSV_PAUSE_IN_S5" }, + { 0x0d, "EC_CMD_GET_FEATURES" }, + { 0x10, "EC_CMD_FLASH_INFO" }, + { 0x11, "EC_CMD_FLASH_READ" }, + { 0x12, "EC_CMD_FLASH_WRITE" }, + { 0x13, "EC_CMD_FLASH_ERASE" }, + { 0x15, "EC_CMD_FLASH_PROTECT" }, + { 0x16, "EC_CMD_FLASH_REGION_INFO" }, + { 0x20, "EC_CMD_PWM_GET_FAN_TARGET_RPM" }, + { 0x21, "EC_CMD_PWM_SET_FAN_TARGET_RPM" }, + { 0x22, "EC_CMD_PWM_GET_KEYBOARD_BACKLIGHT" }, + { 0x23, "EC_CMD_PWM_SET_KEYBOARD_BACKLIGHT" }, + { 0x24, "EC_CMD_PWM_SET_FAN_DUTY" }, + { 0x28, "EC_CMD_LIGHTBAR_CMD" }, + { 0x29, "EC_CMD_LED_CONTROL" }, + { 0x2a, "EC_CMD_VBOOT_HASH" }, + { 0x2b, "EC_CMD_MOTION_SENSE_CMD" }, + { 0x2c, "EC_CMD_FORCE_LID_OPEN" }, + { 0x30, "EC_CMD_USB_CHARGE_SET_MODE" }, + { 0x40, "EC_CMD_PSTORE_INFO" }, + { 0x41, "EC_CMD_PSTORE_READ" }, + { 0x42, "EC_CMD_PSTORE_WRITE" }, + { 0x44, "EC_CMD_RTC_GET_VALUE" }, + { 0x45, "EC_CMD_RTC_GET_ALARM" }, + { 0x46, "EC_CMD_RTC_SET_VALUE" }, + { 0x47, "EC_CMD_RTC_SET_ALARM" }, + { 0x48, "EC_CMD_PORT80_LAST_BOOT" }, + { 0x48, "EC_CMD_PORT80_READ" }, + { 0x50, "EC_CMD_THERMAL_SET_THRESHOLD" }, + { 0x51, "EC_CMD_THERMAL_GET_THRESHOLD" }, + { 0x52, "EC_CMD_THERMAL_AUTO_FAN_CTRL" }, + { 0x53, "EC_CMD_TMP006_GET_CALIBRATION" }, + { 0x54, "EC_CMD_TMP006_SET_CALIBRATION" }, + { 0x55, "EC_CMD_TMP006_GET_RAW" }, + { 0x60, "EC_CMD_MKBP_STATE" }, + { 0x61, "EC_CMD_MKBP_INFO" }, + { 0x62, "EC_CMD_MKBP_SIMULATE_KEY" }, + { 0x64, "EC_CMD_MKBP_SET_CONFIG" }, + { 0x65, "EC_CMD_MKBP_GET_CONFIG" }, + { 0x66, "EC_CMD_KEYSCAN_SEQ_CTRL" }, + { 0x67, "EC_CMD_GET_NEXT_EVENT" }, + { 0x70, "EC_CMD_TEMP_SENSOR_GET_INFO" }, + { 0x87, "EC_CMD_HOST_EVENT_GET_B" }, + { 0x88, "EC_CMD_HOST_EVENT_GET_SMI_MASK" }, + { 0x89, "EC_CMD_HOST_EVENT_GET_SCI_MASK" }, + { 0x8d, "EC_CMD_HOST_EVENT_GET_WAKE_MASK" }, + { 0x8a, "EC_CMD_HOST_EVENT_SET_SMI_MASK" }, + { 0x8b, "EC_CMD_HOST_EVENT_SET_SCI_MASK" }, + { 0x8c, "EC_CMD_HOST_EVENT_CLEAR" }, + { 0x8e, "EC_CMD_HOST_EVENT_SET_WAKE_MASK" }, + { 0x8f, "EC_CMD_HOST_EVENT_CLEAR_B" }, + { 0x90, "EC_CMD_SWITCH_ENABLE_BKLIGHT" }, + { 0x91, "EC_CMD_SWITCH_ENABLE_WIRELESS" }, + { 0x92, "EC_CMD_GPIO_SET" }, + { 0x93, "EC_CMD_GPIO_GET" }, + { 0x94, "EC_CMD_I2C_READ" }, + { 0x95, "EC_CMD_I2C_WRITE" }, + { 0x96, "EC_CMD_CHARGE_CONTROL" }, + { 0x97, "EC_CMD_CONSOLE_SNAPSHOT" }, + { 0x98, "EC_CMD_CONSOLE_READ" }, + { 0x99, "EC_CMD_BATTERY_CUT_OFF" }, + { 0x9a, "EC_CMD_USB_MUX" }, + { 0x9b, "EC_CMD_LDO_SET" }, + { 0x9c, "EC_CMD_LDO_GET" }, + { 0x9d, "EC_CMD_POWER_INFO" }, + { 0x9e, "EC_CMD_I2C_PASSTHRU" }, + { 0x9f, "EC_CMD_HANG_DETECT" }, + { 0xa0, "EC_CMD_CHARGE_STATE" }, + { 0xa1, "EC_CMD_CHARGE_CURRENT_LIMIT" }, + { 0xa2, "EC_CMD_EXT_POWER_CURRENT_LIMIT" }, + { 0xb0, "EC_CMD_SB_READ_WORD" }, + { 0xb1, "EC_CMD_SB_WRITE_WORD" }, + { 0xb2, "EC_CMD_SB_READ_BLOCK" }, + { 0xb3, "EC_CMD_SB_WRITE_BLOCK" }, + { 0xb4, "EC_CMD_BATTERY_VENDOR_PARAM" }, + { 0xb5, "EC_CMD_SB_FW_UPDATE" }, + { 0xd2, "EC_CMD_REBOOT_EC" }, + { 0xd3, "EC_CMD_GET_PANIC_INFO" }, + { 0xd1, "EC_CMD_REBOOT" }, + { 0xdb, "EC_CMD_RESEND_RESPONSE" }, + { 0xdc, "EC_CMD_VERSION0" }, + { 0x100, "EC_CMD_PD_EXCHANGE_STATUS" }, + { 0x104, "EC_CMD_PD_HOST_EVENT_STATUS" }, + { 0x101, "EC_CMD_USB_PD_CONTROL" }, + { 0x102, "EC_CMD_USB_PD_PORTS" }, + { 0x103, "EC_CMD_USB_PD_POWER_INFO" }, + { 0x110, "EC_CMD_USB_PD_FW_UPDATE" }, + { 0x111, "EC_CMD_USB_PD_RW_HASH_ENTRY" }, + { 0x112, "EC_CMD_USB_PD_DEV_INFO" }, + { 0x113, "EC_CMD_USB_PD_DISCOVERY" }, + { 0x114, "EC_CMD_PD_CHARGE_PORT_OVERRIDE" }, + { 0x115, "EC_CMD_PD_GET_LOG_ENTRY" }, + { 0x116, "EC_CMD_USB_PD_GET_AMODE" }, + { 0x117, "EC_CMD_USB_PD_SET_AMODE" }, + { 0x118, "EC_CMD_PD_WRITE_LOG_ENTRY" }, + { 0x200, "EC_CMD_BLOB" }, }; #define ARRAY_SIZE(A) (sizeof(A) / sizeof(A[0])) @@ -532,15 +512,13 @@ static void scan_commands(uint16_t start, uint16_t stop) printf("Supported host commands:\n"); for (i = start; i <= stop; i++) { - if (opt_verbose) printf("Querying CMD %02x\n", i); q_vers.cmd = i; - if (0 != send_cmd(EC_CMD_GET_CMD_VERSIONS, 1, - &q_vers, sizeof(q_vers), - &ec_resp, - &r_vers, sizeof(r_vers))) { + if (0 != send_cmd(EC_CMD_GET_CMD_VERSIONS, 1, &q_vers, + sizeof(q_vers), &ec_resp, &r_vers, + sizeof(r_vers))) { printf("query failed on cmd %02x - aborting\n", i); return; } @@ -557,8 +535,7 @@ static void scan_commands(uint16_t start, uint16_t stop) break; default: printf("lookup of cmd %02x returned %d %s\n", i, - ec_resp.result, - ec_strerr(ec_resp.result)); + ec_resp.result, ec_strerr(ec_resp.result)); } } } diff --git a/extra/i2c_pseudo/Documentation.md b/extra/i2c_pseudo/Documentation.md new file mode 100644 index 0000000000..ebcef6a01e --- /dev/null +++ b/extra/i2c_pseudo/Documentation.md @@ -0,0 +1,279 @@ +# i2c-pseudo driver + +Usually I2C adapters are implemented in a kernel driver. It is also possible to +implement an adapter in userspace, through the /dev/i2c-pseudo-controller +interface. Load module i2c-pseudo for this. + +Use cases for this module include: + +* Using local I2C device drivers, particularly i2c-dev, with I2C busses on + remote systems. For example, interacting with a Device Under Test (DUT) + connected to a Linux host through a debug interface, or interacting with a + remote host over a network. + +* Implementing I2C device driver tests that are impractical with the i2c-stub + module. For example, when simulating an I2C device where its driver might + issue a sequence of reads and writes without interruption, and the value at a + certain address must change during the sequence. + +This is not intended to replace kernel drivers for actual I2C busses on the +local host machine. + +## Details + +Each time /dev/i2c-pseudo-controller is opened, and the correct initialization +command is written to it (ADAPTER_START), a new I2C adapter is created. The +adapter will live until its file descriptor is closed. Multiple pseudo adapters +can co-exist simultaneously, controlled by the same or different userspace +processes. When an I2C device driver sends an I2C message to a pseudo adapter, +the message becomes readable from its file descriptor. If a reply is written +before the adapter timeout expires, that reply will be sent back to the I2C +device driver. + +Reads and writes are buffered inside i2c-pseudo such that userspace controllers +may split them up into arbitrarily small chunks. Multiple commands, or portions +of multiple commands, may be read or written together. + +Blocking I/O is the default. Non-blocking I/O is supported as well, enabled by +O_NONBLOCK. Polling is supported, with or without non-blocking I/O. A special +command (ADAPTER_SHUTDOWN) is available to unblock any pollers or blocked +reads or writes, as a convenience for a multi-threaded or multi-process program +that wants to exit. + +It is safe to access a single controller fd from multiple threads or processes +concurrently, though it is up to the controller to ensure proper ordering, and +to ensure that writes for different commands do not get interleaved. However, +it is recommended (not required) that controller implementations have only one +reader thread and one writer thread, which may or may not be the same thread. +Avoiding multiple readers and multiple writers greatly simplifies controller +implementation, and there is likely no performance benefit to be gained from +concurrent reads or concurrent writes due to how i2c-pseudo serializes them +internally. After all, on a real I2C bus only one I2C message can be active at +a time. + +Commands are newline-terminated, both those read from the controller device, and +those written to it. + +## Read Commands + +The commands that may be read from a pseudo controller device are: + + +--- + +Read Command + +: `I2C_ADAPTER_NUM <num>` + +Example + +: `"I2C_ADAPTER_NUM 5\\n"` + +Details + + +--- + +Read Command + +: `I2C_PSEUDO_ID <num>` + +Example + +: `"I2C_PSEUDO_ID 98\\n"` + +Details + + +--- + +Read Command + +: `I2C_BEGIN_XFER` + +Example + +: `"I2C_BEGIN_XFER\\n"` + +Details + + +--- + +Read Command + +: `I2C_XFER_REQ <xfer_id> <msg_id> <addr> <flags> <data_len> [<write_byte>[:...]]` + +Example + +: `"I2C_XFER_REQ 3 0 0x0070 0x0000 2 AB:9F\\n"` + +Example + +: `"I2C_XFER_REQ 3 1 0x0070 0x0001 4\\n"` + +Details + + +--- + +Read Command + +: `I2C_COMMIT_XFER` + +Example + +: `"I2C_COMMIT_XFER\\n"` + +Details + +## Write Commands + +The commands that may be written to a pseudo controller device are: + +Write Command + +: `SET_ADAPTER_NAME_SUFFIX <suffix>` + +Example + +: `"SET_ADAPTER_NAME_SUFFIX My Adapter\\n"` + +Details + + +--- + +Write Command + +: `SET_ADAPTER_TIMEOUT_MS <ms>` + +Example + +: `"SET_ADAPTER_TIMEOUT_MS 2000\\n"` + +Details + + +--- + +Write Command + +: `ADAPTER_START` + +Example + +: `"ADAPTER_START\\n"` + +Details + + +--- + +Write Command + +: `GET_ADAPTER_NUM` + +Example + +: `"GET_ADAPTER_NUM\\n"` + +Details + + +--- + +Write Command + +: `GET_PSEUDO_ID` + +Example + +: `"GET_PSEUDO_ID\\n"` + +Details + + +--- + +Write Command + +: `I2C_XFER_REPLY <xfer_id> <msg_id> <addr> <flags> <errno> [<read_byte>[:...]]` + +Example + +: `"I2C_XFER_REPLY 3 0 0x0070 0x0000 0\\n"` + +Example + +: `"I2C_XFER_REPLY 3 1 0x0070 0x0001 0 0B:29:02:D9\\n"` + +Details + + +--- + +Write Command + +: `ADAPTER_SHUTDOWN` + +Example + +: `"ADAPTER_SHUTDOWN\\n"` + +Details + +## Example userspace controller code + +In C, a simple exchange between i2c-pseudo and userspace might look like the +example below. Note that for brevity this lacks any error checking and +handling, which a real pseudo controller implementation should have. + +``` +int fd; +char buf[1<<12]; + +fd = open("/dev/i2c-pseudo-controller", O_RDWR); +/* Create the I2C adapter. */ +dprintf(fd, "ADAPTER_START\n"); + +/* + * Pretend this I2C adapter number is 5, and the first I2C xfer sent to it was + * from this command (using its i2c-dev interface): + * $ i2cset -y 5 0x70 0xC2 + * + * Then this read would place the following into *buf: + * "I2C_BEGIN_XFER\n" + * "I2C_XFER_REQ 0 0 0x0070 0x0000 1 C2\n" + * "I2C_COMMIT_XFER\n" + */ +read(fd, buf, sizeof(buf)); + +/* This reply would allow the i2cset command above to exit successfully. */ +dprintf(fd, "I2C_XFER_REPLY 0 0 0x0070 0x0000 0\n"); + +/* + * Now pretend the next I2C xfer sent to this adapter was from: + * $ i2cget -y 5 0x70 0xAB + * + * Then this read would place the following into *buf: + * "I2C_BEGIN_XFER\n" + * "I2C_XFER_REQ 1 0 0x0070 0x0000 1 AB\n" + * "I2C_XFER_REQ 1 1 0x0070 0x0001 1\n'" + * "I2C_COMMIT_XFER\n" + */ +read(fd, buf, sizeof(buf)); + +/* + * These replies would allow the i2cget command above to print the following to + * stdout and exit successfully: + * 0x0b + * + * Note that it is also valid to write these together in one write(). + */ +dprintf(fd, "I2C_XFER_REPLY 1 0 0x0070 0x0000 0\n"); +dprintf(fd, "I2C_XFER_REPLY 1 1 0x0070 0x0001 0 0B\n"); + +/* Destroy the I2C adapter. */ +close(fd); +``` diff --git a/extra/i2c_pseudo/Documentation.rst b/extra/i2c_pseudo/Documentation.rst new file mode 100644 index 0000000000..2527eb5337 --- /dev/null +++ b/extra/i2c_pseudo/Documentation.rst @@ -0,0 +1,306 @@ +================= +i2c-pseudo driver +================= + +Usually I2C adapters are implemented in a kernel driver. It is also possible to +implement an adapter in userspace, through the /dev/i2c-pseudo-controller +interface. Load module i2c-pseudo for this. + +Use cases for this module include: + +- Using local I2C device drivers, particularly i2c-dev, with I2C busses on + remote systems. For example, interacting with a Device Under Test (DUT) + connected to a Linux host through a debug interface, or interacting with a + remote host over a network. + +- Implementing I2C device driver tests that are impractical with the i2c-stub + module. For example, when simulating an I2C device where its driver might + issue a sequence of reads and writes without interruption, and the value at a + certain address must change during the sequence. + +This is not intended to replace kernel drivers for actual I2C busses on the +local host machine. + + +Details +======= + +Each time /dev/i2c-pseudo-controller is opened, and the correct initialization +command is written to it (ADAPTER_START), a new I2C adapter is created. The +adapter will live until its file descriptor is closed. Multiple pseudo adapters +can co-exist simultaneously, controlled by the same or different userspace +processes. When an I2C device driver sends an I2C message to a pseudo adapter, +the message becomes readable from its file descriptor. If a reply is written +before the adapter timeout expires, that reply will be sent back to the I2C +device driver. + +Reads and writes are buffered inside i2c-pseudo such that userspace controllers +may split them up into arbitrarily small chunks. Multiple commands, or portions +of multiple commands, may be read or written together. + +Blocking I/O is the default. Non-blocking I/O is supported as well, enabled by +O_NONBLOCK. Polling is supported, with or without non-blocking I/O. A special +command (ADAPTER_SHUTDOWN) is available to unblock any pollers or blocked +reads or writes, as a convenience for a multi-threaded or multi-process program +that wants to exit. + +It is safe to access a single controller fd from multiple threads or processes +concurrently, though it is up to the controller to ensure proper ordering, and +to ensure that writes for different commands do not get interleaved. However, +it is recommended (not required) that controller implementations have only one +reader thread and one writer thread, which may or may not be the same thread. +Avoiding multiple readers and multiple writers greatly simplifies controller +implementation, and there is likely no performance benefit to be gained from +concurrent reads or concurrent writes due to how i2c-pseudo serializes them +internally. After all, on a real I2C bus only one I2C message can be active at +a time. + +Commands are newline-terminated, both those read from the controller device, and +those written to it. + + +Read Commands +============= + +The commands that may be read from a pseudo controller device are: + +---- + +:Read Command: ``I2C_ADAPTER_NUM <num>`` +:Example: ``"I2C_ADAPTER_NUM 5\n"`` +:Details: + | This is read in response to the GET_ADAPTER_NUM command being written. + The number is the I2C adapter number in decimal. This can only occur after + ADAPTER_START, because before that the number is not known and cannot be + predicted reliably. + +---- + +:Read Command: ``I2C_PSEUDO_ID <num>`` +:Example: ``"I2C_PSEUDO_ID 98\n"`` +:Details: + | This is read in response to the GET_PSEUDO_ID command being written. + The number is the pseudo ID in decimal. + +---- + +:Read Command: ``I2C_BEGIN_XFER`` +:Example: ``"I2C_BEGIN_XFER\n"`` +:Details: + | This indicates the start of an I2C transaction request, in other words + the start of the I2C messages from a single invocation of the I2C adapter's + master_xfer() callback. This can only occur after ADAPTER_START. + +---- + +:Read Command: ``I2C_XFER_REQ <xfer_id> <msg_id> <addr> <flags> <data_len> [<write_byte>[:...]]`` +:Example: ``"I2C_XFER_REQ 3 0 0x0070 0x0000 2 AB:9F\n"`` +:Example: ``"I2C_XFER_REQ 3 1 0x0070 0x0001 4\n"`` +:Details: + | This is a single I2C message that a device driver requested be sent on + the bus, in other words a single struct i2c_msg from master_xfer() msgs arg. + | + | The xfer_id is a number representing the whole I2C transaction, thus all + I2C_XFER_REQ between a I2C_BEGIN_XFER + I2C_COMMIT_XFER pair share an + xfer_id. The purpose is to ensure replies from the userspace controller are + always properly matched to the intended master_xfer() request. The first + transaction has xfer_id 0, and it increases by 1 with each transaction, + however it will eventually wrap back to 0 if enough transactions happen + during the lifetime of a pseudo adapter. It is guaranteed to have a large + enough maximum value such that there can never be multiple outstanding + transactions with the same ID, due to an internal limit in i2c-pseudo that + will block master_xfer() calls when the controller is falling behind in its + replies. + | + | The msg_id is a decimal number representing the index of the I2C message + within its transaction, in other words the index in master_xfer() \*msgs + array arg. This starts at 0 after each I2C_BEGIN_XFER. This is guaranteed + to not wrap. + | + | The addr is the hexadecimal I2C address for this I2C message. The address + is right-aligned without any read/write bit. + | + | The flags are the same bitmask flags used in struct i2c_msg, in hexadecimal + form. Of particular importance to any pseudo controller is the read bit, + which is guaranteed to be 0x1 per Linux I2C documentation. + | + | The data_len is the decimal number of either how many bytes to write that + will follow, or how many bytes to read and reply with if this is a read + request. + | + | If this is a read, data_len will be the final field in this command. If + this is a write, data_len will be followed by the given number of + colon-separated hexadecimal byte values, in the format shown in the example + above. + +---- + +:Read Command: ``I2C_COMMIT_XFER`` +:Example: ``"I2C_COMMIT_XFER\n"`` +:Details: + | This indicates the end of an I2C transaction request, in other words the + end of the I2C messages from a single invocation of the I2C adapter's + master_xfer() callback. This should be read exactly once after each + I2C_BEGIN_XFER, with a varying number of I2C_XFER_REQ between them. + + +Write Commands +============== + +The commands that may be written to a pseudo controller device are: + + +:Write Command: ``SET_ADAPTER_NAME_SUFFIX <suffix>`` +:Example: ``"SET_ADAPTER_NAME_SUFFIX My Adapter\n"`` +:Details: + | Sets a suffix to append to the auto-generated I2C adapter name. Only + valid before ADAPTER_START. A space or other separator character will be + placed between the auto-generated name and the suffix, so there is no need + to include a leading separator in the suffix. If the resulting name is too + long for the I2C adapter name field, it will be quietly truncated. + +---- + +:Write Command: ``SET_ADAPTER_TIMEOUT_MS <ms>`` +:Example: ``"SET_ADAPTER_TIMEOUT_MS 2000\n"`` +:Details: + | Sets the timeout in milliseconds for each I2C transaction, in other words + for each master_xfer() reply. Only valid before ADAPTER_START. The I2C + subsystem will automatically time out transactions based on this setting. + Set to 0 to use the I2C subsystem default timeout. The default timeout for + new pseudo adapters where this command has not been used is configurable at + i2c-pseudo module load time, and itself has a default independent from the + I2C subsystem default. (If the i2c-pseudo module level default is set to 0, + that has the same meaning as here.) + +---- + +:Write Command: ``ADAPTER_START`` +:Example: ``"ADAPTER_START\n"`` +:Details: + | Tells i2c-pseudo to actually create the I2C adapter. Only valid once per + open controller fd. + +---- + +:Write Command: ``GET_ADAPTER_NUM`` +:Example: ``"GET_ADAPTER_NUM\n"`` +:Details: + | Asks i2c-pseudo for the number assigned to this I2C adapter by the I2C + subsystem. Only valid after ADAPTER_START, because before that the number + is not known and cannot be predicted reliably. + +---- + +:Write Command: ``GET_PSEUDO_ID`` +:Example: ``"GET_PSEUDO_ID\n"`` +:Details: + | Asks i2c-pseudo for the pseudo ID of this I2C adapter. The pseudo ID will + not be reused for the lifetime of the i2c-pseudo module, unless an internal + counter wraps. I2C clients can use this to track specific instances of + pseudo adapters, even when adapter numbers have been reused. + +---- + +:Write Command: ``I2C_XFER_REPLY <xfer_id> <msg_id> <addr> <flags> <errno> [<read_byte>[:...]]`` +:Example: ``"I2C_XFER_REPLY 3 0 0x0070 0x0000 0\n"`` +:Example: ``"I2C_XFER_REPLY 3 1 0x0070 0x0001 0 0B:29:02:D9\n"`` +:Details: + | This is how a pseudo controller can reply to I2C_XFER_REQ. Only valid + after I2C_XFER_REQ. A pseudo controller should write one of these for each + I2C_XFER_REQ it reads, including for failures, so that I2C device drivers + need not wait for the adapter timeout upon failure (if failure is known + sooner). + | + | The fields in common with I2C_XFER_REQ have their same meanings, and their + values are expected to exactly match what was read in the I2C_XFER_REQ + command that this is in reply to. + | + | The errno field is how the pseudo controller indicates success or failure + for this I2C message. A 0 value indicates success. A non-zero value + indicates a failure. Pseudo controllers are encouraged to use errno values + to encode some meaning in a failure response, but that is not a requirement, + and the I2C adapter interface does not provide a way to pass per-message + errno values to a device driver anyways. + | + | Pseudo controllers are encouraged to reply in the same order as messages + were received, however i2c-pseudo will properly match up out-of-order + replies with their original requests. + +---- + +:Write Command: ``ADAPTER_SHUTDOWN`` +:Example: ``"ADAPTER_SHUTDOWN\n"`` +:Details: + | This tells i2c-pseudo that the pseudo controller wants to shutdown and + intends to close the controller device fd soon. Use of this is OPTIONAL, it + is perfectly valid to close the controller device fd without ever using this + command. + | + | This commands unblocks any blocked controller I/O (reads, writes, or polls), + and that is its main purpose. + | + | Any I2C transactions attempted by a device driver after this command will + fail, and will not be passed on to the userspace controller. + | + | This DOES NOT delete the I2C adapter. Only closing the fd will do that. + That MAY CHANGE in the future, such that this does delete the I2C adapter. + (However this will never be required, it will always be okay to simply close + the fd.) + + +Example userspace controller code +================================= + +In C, a simple exchange between i2c-pseudo and userspace might look like the +example below. Note that for brevity this lacks any error checking and +handling, which a real pseudo controller implementation should have. + +:: + + int fd; + char buf[1<<12]; + + fd = open("/dev/i2c-pseudo-controller", O_RDWR); + /* Create the I2C adapter. */ + dprintf(fd, "ADAPTER_START\n"); + + /* + * Pretend this I2C adapter number is 5, and the first I2C xfer sent to it was + * from this command (using its i2c-dev interface): + * $ i2cset -y 5 0x70 0xC2 + * + * Then this read would place the following into *buf: + * "I2C_BEGIN_XFER\n" + * "I2C_XFER_REQ 0 0 0x0070 0x0000 1 C2\n" + * "I2C_COMMIT_XFER\n" + */ + read(fd, buf, sizeof(buf)); + + /* This reply would allow the i2cset command above to exit successfully. */ + dprintf(fd, "I2C_XFER_REPLY 0 0 0x0070 0x0000 0\n"); + + /* + * Now pretend the next I2C xfer sent to this adapter was from: + * $ i2cget -y 5 0x70 0xAB + * + * Then this read would place the following into *buf: + * "I2C_BEGIN_XFER\n" + * "I2C_XFER_REQ 1 0 0x0070 0x0000 1 AB\n" + * "I2C_XFER_REQ 1 1 0x0070 0x0001 1\n'" + * "I2C_COMMIT_XFER\n" + */ + read(fd, buf, sizeof(buf)); + + /* + * These replies would allow the i2cget command above to print the following to + * stdout and exit successfully: + * 0x0b + * + * Note that it is also valid to write these together in one write(). + */ + dprintf(fd, "I2C_XFER_REPLY 1 0 0x0070 0x0000 0\n"); + dprintf(fd, "I2C_XFER_REPLY 1 1 0x0070 0x0001 0 0B\n"); + + /* Destroy the I2C adapter. */ + close(fd); diff --git a/extra/i2c_pseudo/Makefile b/extra/i2c_pseudo/Makefile index f7fda6e2de..b53085a970 100644 --- a/extra/i2c_pseudo/Makefile +++ b/extra/i2c_pseudo/Makefile @@ -1,4 +1,4 @@ -# Copyright 2020 The Chromium OS Authors. All rights reserved. +# Copyright 2020 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. # diff --git a/extra/i2c_pseudo/README b/extra/i2c_pseudo/README index 96efa062b1..1d1ef75641 100644 --- a/extra/i2c_pseudo/README +++ b/extra/i2c_pseudo/README @@ -2,12 +2,16 @@ This directory contains the i2c-pseudo Linux kernel module. The i2c-pseudo module was written with the intention of being submitted upstream in the Linux kernel. This copy exists because of as 2019-03 this module is not -yet in the upstream kernel, and even if/when this is included, it may take years -before making its way to the prepackaged Linux distribution kernels typically -used by CrOS developers. +yet in the upstream kernel, and even if/when this is included, it may take a +long time to get included in prepackaged Linux distribution kernels, especially +those based on Linux LTS branches. -See Documentation.txt for more information about the module itself. That file -is Documentation/i2c/pseudo-controller-interface in the upstream patch. +See Documentation.rst or Documentation.md for more information about the module +itself. The reStructuredText (.rst) file is +Documentation/i2c/pseudo-controller-interface.rst in the upstream patch. The +Markdown file (.md) is generated using rst2md from +nb2plots (https://github.com/matthew-brett/nb2plots) which uses +Sphinx (https://www.sphinx-doc.org/). When servod starts, if the i2c-pseudo module is loaded servod will automatically create an I2C pseudo adapter for the Servo I2C bus. That I2C adapter may then diff --git a/extra/i2c_pseudo/check_stream_open.sh b/extra/i2c_pseudo/check_stream_open.sh index da802cb282..70cffd7c73 100755 --- a/extra/i2c_pseudo/check_stream_open.sh +++ b/extra/i2c_pseudo/check_stream_open.sh @@ -1,6 +1,6 @@ #!/bin/sh # -# Copyright 2020 The Chromium OS Authors. All rights reserved. +# Copyright 2020 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. # diff --git a/extra/i2c_pseudo/i2c-pseudo.c b/extra/i2c_pseudo/i2c-pseudo.c index 325d140663..7cb2904322 100644 --- a/extra/i2c_pseudo/i2c-pseudo.c +++ b/extra/i2c_pseudo/i2c-pseudo.c @@ -30,47 +30,47 @@ #include <linux/wait.h> /* Minimum i2cp_limit module parameter value. */ -#define I2CP_ADAPTERS_MIN 0 +#define I2CP_ADAPTERS_MIN 0 /* Maximum i2cp_limit module parameter value. */ -#define I2CP_ADAPTERS_MAX 256 +#define I2CP_ADAPTERS_MAX 256 /* Default i2cp_limit module parameter value. */ -#define I2CP_DEFAULT_LIMIT 8 +#define I2CP_DEFAULT_LIMIT 8 /* Value for alloc_chrdev_region() baseminor arg. */ -#define I2CP_CDEV_BASEMINOR 0 -#define I2CP_TIMEOUT_MS_MIN 0 -#define I2CP_TIMEOUT_MS_MAX (60 * MSEC_PER_SEC) -#define I2CP_DEFAULT_TIMEOUT_MS (3 * MSEC_PER_SEC) +#define I2CP_CDEV_BASEMINOR 0 +#define I2CP_TIMEOUT_MS_MIN 0 +#define I2CP_TIMEOUT_MS_MAX (60 * MSEC_PER_SEC) +#define I2CP_DEFAULT_TIMEOUT_MS (3 * MSEC_PER_SEC) /* Used in struct device.kobj.name field. */ -#define I2CP_DEVICE_NAME "i2c-pseudo-controller" +#define I2CP_DEVICE_NAME "i2c-pseudo-controller" /* Value for alloc_chrdev_region() name arg. */ -#define I2CP_CHRDEV_NAME "i2c_pseudo" +#define I2CP_CHRDEV_NAME "i2c_pseudo" /* Value for class_create() name arg. */ -#define I2CP_CLASS_NAME "i2c-pseudo" +#define I2CP_CLASS_NAME "i2c-pseudo" /* Value for alloc_chrdev_region() count arg. Should always be 1. */ -#define I2CP_CDEV_COUNT 1 - -#define I2CP_ADAP_START_CMD "ADAPTER_START" -#define I2CP_ADAP_SHUTDOWN_CMD "ADAPTER_SHUTDOWN" -#define I2CP_GET_NUMBER_CMD "GET_ADAPTER_NUM" -#define I2CP_NUMBER_REPLY_CMD "I2C_ADAPTER_NUM" -#define I2CP_GET_PSEUDO_ID_CMD "GET_PSEUDO_ID" -#define I2CP_PSEUDO_ID_REPLY_CMD "I2C_PSEUDO_ID" -#define I2CP_SET_NAME_SUFFIX_CMD "SET_ADAPTER_NAME_SUFFIX" -#define I2CP_SET_TIMEOUT_CMD "SET_ADAPTER_TIMEOUT_MS" -#define I2CP_BEGIN_MXFER_REQ_CMD "I2C_BEGIN_XFER" -#define I2CP_COMMIT_MXFER_REQ_CMD "I2C_COMMIT_XFER" -#define I2CP_MXFER_REQ_CMD "I2C_XFER_REQ" -#define I2CP_MXFER_REPLY_CMD "I2C_XFER_REPLY" +#define I2CP_CDEV_COUNT 1 + +#define I2CP_ADAP_START_CMD "ADAPTER_START" +#define I2CP_ADAP_SHUTDOWN_CMD "ADAPTER_SHUTDOWN" +#define I2CP_GET_NUMBER_CMD "GET_ADAPTER_NUM" +#define I2CP_NUMBER_REPLY_CMD "I2C_ADAPTER_NUM" +#define I2CP_GET_PSEUDO_ID_CMD "GET_PSEUDO_ID" +#define I2CP_PSEUDO_ID_REPLY_CMD "I2C_PSEUDO_ID" +#define I2CP_SET_NAME_SUFFIX_CMD "SET_ADAPTER_NAME_SUFFIX" +#define I2CP_SET_TIMEOUT_CMD "SET_ADAPTER_TIMEOUT_MS" +#define I2CP_BEGIN_MXFER_REQ_CMD "I2C_BEGIN_XFER" +#define I2CP_COMMIT_MXFER_REQ_CMD "I2C_COMMIT_XFER" +#define I2CP_MXFER_REQ_CMD "I2C_XFER_REQ" +#define I2CP_MXFER_REPLY_CMD "I2C_XFER_REPLY" /* Maximum size of a controller command. */ -#define I2CP_CTRLR_CMD_LIMIT 255 +#define I2CP_CTRLR_CMD_LIMIT 255 /* Maximum number of controller read responses to allow enqueued at once. */ -#define I2CP_CTRLR_RSP_QUEUE_LIMIT 256 +#define I2CP_CTRLR_RSP_QUEUE_LIMIT 256 /* The maximum size of a single controller read response. */ -#define I2CP_MAX_MSG_BUF_SIZE 16384 +#define I2CP_MAX_MSG_BUF_SIZE 16384 /* Maximum size of a controller read or write. */ -#define I2CP_RW_SIZE_LIMIT 1048576 +#define I2CP_RW_SIZE_LIMIT 1048576 /* * Marks the end of a controller command or read response. @@ -85,11 +85,11 @@ * because of an assertion that the copy size (1) must match the size of the * string literal (2 with its trailing null). */ -static const char i2cp_ctrlr_end_char = '\n'; +static const char i2cp_ctrlr_end_char = '\n'; /* Separator between I2C message header fields in the controller bytestream. */ -static const char i2cp_ctrlr_header_sep_char = ' '; +static const char i2cp_ctrlr_header_sep_char = ' '; /* Separator between I2C message data bytes in the controller bytestream. */ -static const char i2cp_ctrlr_data_sep_char = ':'; +static const char i2cp_ctrlr_data_sep_char = ':'; /* * This used instead of strcmp(in_str, other_str) because in_str may have null @@ -99,10 +99,10 @@ static const char i2cp_ctrlr_data_sep_char = ':'; #define STRING_NEQ(in_str, in_size, other_str) \ (in_size != strlen(other_str) || memcmp(other_str, in_str, in_size)) -#define STR_HELPER(num) #num -#define STR(num) STR_HELPER(num) +#define STR_HELPER(num) #num +#define STR(num) STR_HELPER(num) -#define CONST_STRLEN(str) (sizeof(str) - 1) +#define CONST_STRLEN(str) (sizeof(str) - 1) /* * The number of pseudo I2C adapters permitted. This default value can be @@ -207,8 +207,8 @@ struct i2cp_cmd { * behavior with duplicate command names is undefined, subject to * change, and subject to become either a build-time or runtime error. */ - char *cmd_string; /* Must be non-NULL. */ - size_t cmd_size; /* Must be non-zero. */ + char *cmd_string; /* Must be non-NULL. */ + size_t cmd_size; /* Must be non-zero. */ /* * This is called once for each I2C pseudo controller to initialize @@ -308,7 +308,7 @@ struct i2cp_cmd { * This callback MUST NOT be NULL. */ int (*header_receiver)(void *data, char *in, size_t in_size, - bool non_blocking); + bool non_blocking); /* * This is called to process write command data, when requested by the * header_receiver() return value. @@ -347,7 +347,7 @@ struct i2cp_cmd { * should be NULL. Otherwise, this callback MUST NOT be NULL. */ int (*data_receiver)(void *data, char *in, size_t in_size, - bool non_blocking); + bool non_blocking); /* * This is called to complete processing of a command, after it has been * received in its entirety. @@ -394,7 +394,7 @@ struct i2cp_cmd { * This callback may be NULL. */ int (*cmd_completer)(void *data, struct i2cp_controller *pdata, - int receive_status, bool non_blocking); + int receive_status, bool non_blocking); }; /* @@ -749,13 +749,13 @@ struct i2cp_rsp_master_xfer { * Always initialize fields below here to zero. They are for internal * use by i2cp_rsp_master_xfer_formatter(). */ - int num_msgs_done; /* type of @num field */ + int num_msgs_done; /* type of @num field */ size_t buf_start_plus_one; }; /* vanprintf - See anprintf() documentation. */ static ssize_t vanprintf(char **out, ssize_t max_size, gfp_t gfp, - const char *fmt, va_list ap) + const char *fmt, va_list ap) { int ret; ssize_t buf_size; @@ -790,9 +790,9 @@ static ssize_t vanprintf(char **out, ssize_t max_size, gfp_t gfp, *out = buf; return ret; - fail_before_args1: +fail_before_args1: va_end(args1); - fail_after_args1: +fail_after_args1: kfree(buf); if (ret >= 0) ret = -ENOTRECOVERABLE; @@ -833,7 +833,7 @@ static ssize_t vanprintf(char **out, ssize_t max_size, gfp_t gfp, * a bug. */ static ssize_t anprintf(char **out, ssize_t max_size, gfp_t gfp, - const char *fmt, ...) + const char *fmt, ...) { ssize_t ret; va_list args; @@ -905,24 +905,26 @@ static ssize_t i2cp_rsp_master_xfer_formatter(void *data, char **out) * that no bytes were lost in kernel->userspace transmission. */ ret = anprintf(&buf_start, I2CP_MAX_MSG_BUF_SIZE, GFP_KERNEL, - "%*s%c%u%c%d%c0x%04X%c0x%04X%c%u", - (int)strlen(I2CP_MXFER_REQ_CMD), I2CP_MXFER_REQ_CMD, - i2cp_ctrlr_header_sep_char, mxfer_rsp->id, - i2cp_ctrlr_header_sep_char, mxfer_rsp->num_msgs_done, - i2cp_ctrlr_header_sep_char, i2c_msg->addr, - i2cp_ctrlr_header_sep_char, i2c_msg->flags, - i2cp_ctrlr_header_sep_char, i2c_msg->len); + "%*s%c%u%c%d%c0x%04X%c0x%04X%c%u", + (int)strlen(I2CP_MXFER_REQ_CMD), + I2CP_MXFER_REQ_CMD, i2cp_ctrlr_header_sep_char, + mxfer_rsp->id, i2cp_ctrlr_header_sep_char, + mxfer_rsp->num_msgs_done, + i2cp_ctrlr_header_sep_char, i2c_msg->addr, + i2cp_ctrlr_header_sep_char, i2c_msg->flags, + i2cp_ctrlr_header_sep_char, i2c_msg->len); if (ret > 0) { *out = buf_start; mxfer_rsp->buf_start_plus_one = 1; - /* - * If we have a zero return value, it means the output buffer - * was allocated as size one, containing only a terminating null - * character. This would be a bug given the requested format - * string above. Also, formatter functions must not mutate *out - * when returning zero. So if this matches, free the useless - * buffer and return an error. - */ + /* + * If we have a zero return value, it means the output + * buffer was allocated as size one, containing only a + * terminating null character. This would be a bug + * given the requested format string above. Also, + * formatter functions must not mutate *out when + * returning zero. So if this matches, free the useless + * buffer and return an error. + */ } else if (ret == 0) { ret = -EINVAL; kfree(buf_start); @@ -932,7 +934,7 @@ static ssize_t i2cp_rsp_master_xfer_formatter(void *data, char **out) byte_start = mxfer_rsp->buf_start_plus_one - 1; byte_limit = min_t(size_t, i2c_msg->len - byte_start, - I2CP_MAX_MSG_BUF_SIZE / 3); + I2CP_MAX_MSG_BUF_SIZE / 3); /* 3 chars per byte == 2 chars for hex + 1 char for separator */ buf_size = byte_limit * 3; @@ -943,34 +945,34 @@ static ssize_t i2cp_rsp_master_xfer_formatter(void *data, char **out) } for (buf_pos = buf_start, i = 0; i < byte_limit; ++i) { - *buf_pos++ = (i || byte_start) ? - i2cp_ctrlr_data_sep_char : i2cp_ctrlr_header_sep_char; - buf_pos = hex_byte_pack_upper( - buf_pos, i2c_msg->buf[byte_start + i]); + *buf_pos++ = (i || byte_start) ? i2cp_ctrlr_data_sep_char : + i2cp_ctrlr_header_sep_char; + buf_pos = hex_byte_pack_upper(buf_pos, + i2c_msg->buf[byte_start + i]); } *out = buf_start; ret = buf_size; mxfer_rsp->buf_start_plus_one += i; - maybe_free: +maybe_free: if (ret <= 0) { if (mxfer_rsp->num_msgs_done >= mxfer_rsp->num) { kfree(mxfer_rsp->msgs); kfree(mxfer_rsp); - /* - * If we are returning an error but have not consumed all of - * mxfer_rsp yet, we must not attempt to output any more I2C - * messages from the same mxfer_rsp. Setting mxfer_rsp->msgs to - * NULL tells the remaining invocations with this mxfer_rsp to - * output nothing. - * - * There can be more invocations with the same mxfer_rsp even - * after returning an error here because - * i2cp_adapter_master_xfer() reuses a single - * struct i2cp_rsp_master_xfer (mxfer_rsp) across multiple - * struct i2cp_rsp (rsp_wrappers), one for each struct i2c_msg - * within the mxfer_rsp. - */ + /* + * If we are returning an error but have not consumed + * all of mxfer_rsp yet, we must not attempt to output + * any more I2C messages from the same mxfer_rsp. + * Setting mxfer_rsp->msgs to NULL tells the remaining + * invocations with this mxfer_rsp to output nothing. + * + * There can be more invocations with the same mxfer_rsp + * even after returning an error here because + * i2cp_adapter_master_xfer() reuses a single + * struct i2cp_rsp_master_xfer (mxfer_rsp) across + * multiple struct i2cp_rsp (rsp_wrappers), one for each + * struct i2c_msg within the mxfer_rsp. + */ } else if (ret < 0) { kfree(mxfer_rsp->msgs); mxfer_rsp->msgs = NULL; @@ -980,7 +982,7 @@ static ssize_t i2cp_rsp_master_xfer_formatter(void *data, char **out) } static ssize_t i2cp_id_show(struct device *dev, struct device_attribute *attr, - char *buf) + char *buf) { int ret; struct i2c_adapter *adap; @@ -1039,9 +1041,10 @@ static void i2cp_cmd_mxfer_reply_data_shutdown(void *data) cmd_data = data; mutex_lock(&cmd_data->reply_queue_lock); - list_for_each(list_ptr, &cmd_data->reply_queue_head) { + list_for_each(list_ptr, &cmd_data->reply_queue_head) + { mxfer_reply = list_entry(list_ptr, struct i2cp_cmd_mxfer_reply, - reply_queue_item); + reply_queue_item); mutex_lock(&mxfer_reply->lock); complete_all(&mxfer_reply->data_filled); mutex_unlock(&mxfer_reply->lock); @@ -1059,29 +1062,30 @@ static void i2cp_cmd_mxfer_reply_data_destroyer(void *data) kfree(data); } -static inline bool i2cp_mxfer_reply_is_current( - struct i2cp_cmd_mxfer_reply_data *cmd_data, - struct i2cp_cmd_mxfer_reply *mxfer_reply) +static inline bool +i2cp_mxfer_reply_is_current(struct i2cp_cmd_mxfer_reply_data *cmd_data, + struct i2cp_cmd_mxfer_reply *mxfer_reply) { int i; i = cmd_data->current_msg_idx; - return cmd_data->current_id == mxfer_reply->id && - i >= 0 && i < mxfer_reply->num_msgs && - cmd_data->current_addr == mxfer_reply->msgs[i].addr && - cmd_data->current_flags == mxfer_reply->msgs[i].flags; + return cmd_data->current_id == mxfer_reply->id && i >= 0 && + i < mxfer_reply->num_msgs && + cmd_data->current_addr == mxfer_reply->msgs[i].addr && + cmd_data->current_flags == mxfer_reply->msgs[i].flags; } /* cmd_data->reply_queue_lock must be held. */ -static inline struct i2cp_cmd_mxfer_reply *i2cp_mxfer_reply_find_current( - struct i2cp_cmd_mxfer_reply_data *cmd_data) +static inline struct i2cp_cmd_mxfer_reply * +i2cp_mxfer_reply_find_current(struct i2cp_cmd_mxfer_reply_data *cmd_data) { struct list_head *list_ptr; struct i2cp_cmd_mxfer_reply *mxfer_reply; - list_for_each(list_ptr, &cmd_data->reply_queue_head) { + list_for_each(list_ptr, &cmd_data->reply_queue_head) + { mxfer_reply = list_entry(list_ptr, struct i2cp_cmd_mxfer_reply, - reply_queue_item); + reply_queue_item); if (i2cp_mxfer_reply_is_current(cmd_data, mxfer_reply)) return mxfer_reply; } @@ -1089,17 +1093,18 @@ static inline struct i2cp_cmd_mxfer_reply *i2cp_mxfer_reply_find_current( } /* cmd_data->reply_queue_lock must NOT already be held. */ -static inline void i2cp_mxfer_reply_update_current( - struct i2cp_cmd_mxfer_reply_data *cmd_data) +static inline void +i2cp_mxfer_reply_update_current(struct i2cp_cmd_mxfer_reply_data *cmd_data) { mutex_lock(&cmd_data->reply_queue_lock); - cmd_data->reply_queue_current_item = i2cp_mxfer_reply_find_current( - cmd_data); + cmd_data->reply_queue_current_item = + i2cp_mxfer_reply_find_current(cmd_data); mutex_unlock(&cmd_data->reply_queue_lock); } static int i2cp_cmd_mxfer_reply_header_receiver(void *data, char *in, - size_t in_size, bool non_blocking) + size_t in_size, + bool non_blocking) { int ret, reply_errno = 0; struct i2cp_cmd_mxfer_reply_data *cmd_data; @@ -1218,10 +1223,10 @@ static int i2cp_cmd_mxfer_reply_header_receiver(void *data, char *in, } static int i2cp_cmd_mxfer_reply_data_receiver(void *data, char *in, - size_t in_size, bool non_blocking) + size_t in_size, bool non_blocking) { int ret; - char u8_hex[3] = {0}; + char u8_hex[3] = { 0 }; struct i2cp_cmd_mxfer_reply_data *cmd_data; struct i2cp_cmd_mxfer_reply *mxfer_reply; struct i2c_msg *i2c_msg; @@ -1333,7 +1338,7 @@ static int i2cp_cmd_mxfer_reply_data_receiver(void *data, char *in, * I2C_M_DMA_SAFE bit? Do we ever need to use copy_to_user()? */ ret = kstrtou8(u8_hex, 16, - &i2c_msg->buf[cmd_data->current_buf_idx]); + &i2c_msg->buf[cmd_data->current_buf_idx]); if (ret < 0) goto unlock; if (i2c_msg->flags & I2C_M_RECV_LEN) @@ -1346,13 +1351,15 @@ static int i2cp_cmd_mxfer_reply_data_receiver(void *data, char *in, /* Quietly ignore any bytes beyond the buffer size. */ ret = 0; - unlock: +unlock: mutex_unlock(&mxfer_reply->lock); return ret; } static int i2cp_cmd_mxfer_reply_cmd_completer(void *data, - struct i2cp_controller *pdata, int receive_status, bool non_blocking) + struct i2cp_controller *pdata, + int receive_status, + bool non_blocking) { int ret; struct i2cp_cmd_mxfer_reply_data *cmd_data; @@ -1399,7 +1406,7 @@ static int i2cp_cmd_mxfer_reply_cmd_completer(void *data, mutex_unlock(&mxfer_reply->lock); ret = 0; - reset_cmd_data: +reset_cmd_data: cmd_data->state = I2CP_CMD_MXFER_REPLY_STATE_CMD_NEXT; cmd_data->current_id = 0; cmd_data->current_addr = 0; @@ -1410,7 +1417,8 @@ static int i2cp_cmd_mxfer_reply_cmd_completer(void *data, } static int i2cp_cmd_adap_start_header_receiver(void *data, char *in, - size_t in_size, bool non_blocking) + size_t in_size, + bool non_blocking) { /* * No more header fields or data are expected. This directs any further @@ -1421,7 +1429,7 @@ static int i2cp_cmd_adap_start_header_receiver(void *data, char *in, } static int i2cp_cmd_adap_start_data_receiver(void *data, char *in, - size_t in_size, bool non_blocking) + size_t in_size, bool non_blocking) { /* * Reaching here means the controller wrote extra data in the command @@ -1432,7 +1440,9 @@ static int i2cp_cmd_adap_start_data_receiver(void *data, char *in, } static int i2cp_cmd_adap_start_cmd_completer(void *data, - struct i2cp_controller *pdata, int receive_status, bool non_blocking) + struct i2cp_controller *pdata, + int receive_status, + bool non_blocking) { int ret; @@ -1466,13 +1476,14 @@ static int i2cp_cmd_adap_start_cmd_completer(void *data, ret = 0; - unlock: +unlock: mutex_unlock(&pdata->startstop_lock); return ret; } static int i2cp_cmd_adap_shutdown_header_receiver(void *data, char *in, - size_t in_size, bool non_blocking) + size_t in_size, + bool non_blocking) { /* * No more header fields or data are expected. This directs any further @@ -1483,7 +1494,8 @@ static int i2cp_cmd_adap_shutdown_header_receiver(void *data, char *in, } static int i2cp_cmd_adap_shutdown_data_receiver(void *data, char *in, - size_t in_size, bool non_blocking) + size_t in_size, + bool non_blocking) { /* * Reaching here means the controller wrote extra data in the command @@ -1494,7 +1506,9 @@ static int i2cp_cmd_adap_shutdown_data_receiver(void *data, char *in, } static int i2cp_cmd_adap_shutdown_cmd_completer(void *data, - struct i2cp_controller *pdata, int receive_status, bool non_blocking) + struct i2cp_controller *pdata, + int receive_status, + bool non_blocking) { /* Refuse to shutdown if there were errors processing this command. */ if (receive_status) @@ -1512,7 +1526,8 @@ static int i2cp_cmd_adap_shutdown_cmd_completer(void *data, } static int i2cp_cmd_get_number_header_receiver(void *data, char *in, - size_t in_size, bool non_blocking) + size_t in_size, + bool non_blocking) { /* * No more header fields or data are expected. This directs any further @@ -1523,7 +1538,7 @@ static int i2cp_cmd_get_number_header_receiver(void *data, char *in, } static int i2cp_cmd_get_number_data_receiver(void *data, char *in, - size_t in_size, bool non_blocking) + size_t in_size, bool non_blocking) { /* * Reaching here means the controller wrote extra data in the command @@ -1534,7 +1549,9 @@ static int i2cp_cmd_get_number_data_receiver(void *data, char *in, } static int i2cp_cmd_get_number_cmd_completer(void *data, - struct i2cp_controller *pdata, int receive_status, bool non_blocking) + struct i2cp_controller *pdata, + int receive_status, + bool non_blocking) { ssize_t ret; int i2c_adap_nr; @@ -1572,9 +1589,9 @@ static int i2cp_cmd_get_number_cmd_completer(void *data, } ret = anprintf(&rsp_buf->buf, I2CP_MAX_MSG_BUF_SIZE, GFP_KERNEL, - "%*s%c%d", - (int)strlen(I2CP_NUMBER_REPLY_CMD), I2CP_NUMBER_REPLY_CMD, - i2cp_ctrlr_header_sep_char, i2c_adap_nr); + "%*s%c%d", (int)strlen(I2CP_NUMBER_REPLY_CMD), + I2CP_NUMBER_REPLY_CMD, i2cp_ctrlr_header_sep_char, + i2c_adap_nr); if (ret < 0) { goto fail_after_rsp_buf_alloc; } else if (ret == 0) { @@ -1600,17 +1617,18 @@ static int i2cp_cmd_get_number_cmd_completer(void *data, mutex_unlock(&pdata->read_rsp_queue_lock); return 0; - fail_after_buf_alloc: +fail_after_buf_alloc: kfree(rsp_buf->buf); - fail_after_rsp_buf_alloc: +fail_after_rsp_buf_alloc: kfree(rsp_buf); - fail_after_rsp_wrapper_alloc: +fail_after_rsp_wrapper_alloc: kfree(rsp_wrapper); return ret; } static int i2cp_cmd_get_pseudo_id_header_receiver(void *data, char *in, - size_t in_size, bool non_blocking) + size_t in_size, + bool non_blocking) { /* * No more header fields or data are expected. This directs any further @@ -1621,7 +1639,8 @@ static int i2cp_cmd_get_pseudo_id_header_receiver(void *data, char *in, } static int i2cp_cmd_get_pseudo_id_data_receiver(void *data, char *in, - size_t in_size, bool non_blocking) + size_t in_size, + bool non_blocking) { /* * Reaching here means the controller wrote extra data in the command @@ -1632,7 +1651,9 @@ static int i2cp_cmd_get_pseudo_id_data_receiver(void *data, char *in, } static int i2cp_cmd_get_pseudo_id_cmd_completer(void *data, - struct i2cp_controller *pdata, int receive_status, bool non_blocking) + struct i2cp_controller *pdata, + int receive_status, + bool non_blocking) { ssize_t ret; struct i2cp_rsp_buffer *rsp_buf; @@ -1653,9 +1674,9 @@ static int i2cp_cmd_get_pseudo_id_cmd_completer(void *data, } ret = anprintf(&rsp_buf->buf, I2CP_MAX_MSG_BUF_SIZE, GFP_KERNEL, - "%*s%c%u", - (int)strlen(I2CP_PSEUDO_ID_REPLY_CMD), I2CP_PSEUDO_ID_REPLY_CMD, - i2cp_ctrlr_header_sep_char, pdata->id); + "%*s%c%u", (int)strlen(I2CP_PSEUDO_ID_REPLY_CMD), + I2CP_PSEUDO_ID_REPLY_CMD, i2cp_ctrlr_header_sep_char, + pdata->id); if (ret < 0) { goto fail_after_rsp_buf_alloc; } else if (ret == 0) { @@ -1681,11 +1702,11 @@ static int i2cp_cmd_get_pseudo_id_cmd_completer(void *data, mutex_unlock(&pdata->read_rsp_queue_lock); return 0; - fail_after_buf_alloc: +fail_after_buf_alloc: kfree(rsp_buf->buf); - fail_after_rsp_buf_alloc: +fail_after_rsp_buf_alloc: kfree(rsp_buf); - fail_after_rsp_wrapper_alloc: +fail_after_rsp_wrapper_alloc: kfree(rsp_wrapper); return ret; } @@ -1707,13 +1728,15 @@ static void i2cp_cmd_set_name_suffix_data_destroyer(void *data) } static int i2cp_cmd_set_name_suffix_header_receiver(void *data, char *in, - size_t in_size, bool non_blocking) + size_t in_size, + bool non_blocking) { return 1; } static int i2cp_cmd_set_name_suffix_data_receiver(void *data, char *in, - size_t in_size, bool non_blocking) + size_t in_size, + bool non_blocking) { size_t remaining; struct i2cp_cmd_set_name_suffix_data *cmd_data; @@ -1730,7 +1753,9 @@ static int i2cp_cmd_set_name_suffix_data_receiver(void *data, char *in, } static int i2cp_cmd_set_name_suffix_cmd_completer(void *data, - struct i2cp_controller *pdata, int receive_status, bool non_blocking) + struct i2cp_controller *pdata, + int receive_status, + bool non_blocking) { int ret; struct i2cp_cmd_set_name_suffix_data *cmd_data; @@ -1753,14 +1778,14 @@ static int i2cp_cmd_set_name_suffix_cmd_completer(void *data, cmd_data = data; ret = snprintf(pdata->i2c_adapter.name, sizeof(pdata->i2c_adapter.name), - "I2C pseudo ID %u %*s", pdata->id, - (int)cmd_data->name_suffix_len, cmd_data->name_suffix); + "I2C pseudo ID %u %*s", pdata->id, + (int)cmd_data->name_suffix_len, cmd_data->name_suffix); if (ret < 0) goto unlock; ret = 0; - unlock: +unlock: mutex_unlock(&pdata->startstop_lock); return ret; } @@ -1782,7 +1807,8 @@ static void i2cp_cmd_set_timeout_data_destroyer(void *data) } static int i2cp_cmd_set_timeout_header_receiver(void *data, char *in, - size_t in_size, bool non_blocking) + size_t in_size, + bool non_blocking) { int ret; struct i2cp_cmd_set_timeout_data *cmd_data; @@ -1802,7 +1828,7 @@ static int i2cp_cmd_set_timeout_header_receiver(void *data, char *in, } static int i2cp_cmd_set_timeout_data_receiver(void *data, char *in, - size_t in_size, bool non_blocking) + size_t in_size, bool non_blocking) { /* * Reaching here means the controller wrote extra data in the command @@ -1812,7 +1838,9 @@ static int i2cp_cmd_set_timeout_data_receiver(void *data, char *in, } static int i2cp_cmd_set_timeout_cmd_completer(void *data, - struct i2cp_controller *pdata, int receive_status, bool non_blocking) + struct i2cp_controller *pdata, + int receive_status, + bool non_blocking) { int ret; struct i2cp_cmd_set_timeout_data *cmd_data; @@ -1835,7 +1863,7 @@ static int i2cp_cmd_set_timeout_cmd_completer(void *data, cmd_data = data; if (cmd_data->timeout_ms < I2CP_TIMEOUT_MS_MIN || - cmd_data->timeout_ms > I2CP_TIMEOUT_MS_MAX) { + cmd_data->timeout_ms > I2CP_TIMEOUT_MS_MAX) { ret = -ERANGE; goto unlock; } @@ -1843,7 +1871,7 @@ static int i2cp_cmd_set_timeout_cmd_completer(void *data, pdata->i2c_adapter.timeout = msecs_to_jiffies(cmd_data->timeout_ms); ret = 0; - unlock: +unlock: mutex_unlock(&pdata->startstop_lock); return ret; } @@ -1914,11 +1942,12 @@ static const struct i2cp_cmd i2cp_cmds[] = { static inline bool i2cp_poll_in(struct i2cp_controller *pdata) { return pdata->rsp_invalidated || pdata->rsp_buf_remaining != 0 || - !list_empty(&pdata->read_rsp_queue_head); + !list_empty(&pdata->read_rsp_queue_head); } static inline int i2cp_fill_rsp_buf(struct i2cp_rsp *rsp_wrapper, - struct i2cp_rsp_buffer *rsp_buf, char *contents, size_t size) + struct i2cp_rsp_buffer *rsp_buf, + char *contents, size_t size) { rsp_buf->buf = kmemdup(contents, size, GFP_KERNEL); if (!rsp_buf->buf) @@ -1929,19 +1958,19 @@ static inline int i2cp_fill_rsp_buf(struct i2cp_rsp *rsp_wrapper, return 0; } -#define I2CP_FILL_RSP_BUF_WITH_LITERAL(rsp_wrapper, rsp_buf, str_literal)\ - i2cp_fill_rsp_buf(\ - rsp_wrapper, rsp_buf, str_literal, strlen(str_literal)) +#define I2CP_FILL_RSP_BUF_WITH_LITERAL(rsp_wrapper, rsp_buf, str_literal) \ + i2cp_fill_rsp_buf(rsp_wrapper, rsp_buf, str_literal, \ + strlen(str_literal)) static int i2cp_adapter_master_xfer(struct i2c_adapter *adap, - struct i2c_msg *msgs, int num) + struct i2c_msg *msgs, int num) { int i, ret = 0; long wait_ret; size_t wrappers_length, wrapper_idx = 0, rsp_bufs_idx = 0; struct i2cp_controller *pdata; struct i2cp_rsp **rsp_wrappers; - struct i2cp_rsp_buffer *rsp_bufs[2] = {0}; + struct i2cp_rsp_buffer *rsp_bufs[2] = { 0 }; struct i2cp_rsp_master_xfer *mxfer_rsp; struct i2cp_cmd_mxfer_reply_data *cmd_data; struct i2cp_cmd_mxfer_reply *mxfer_reply; @@ -1966,8 +1995,8 @@ static int i2cp_adapter_master_xfer(struct i2c_adapter *adap, } wrappers_length = (size_t)num + ARRAY_SIZE(rsp_bufs); - rsp_wrappers = kcalloc(wrappers_length, sizeof(*rsp_wrappers), - GFP_KERNEL); + rsp_wrappers = + kcalloc(wrappers_length, sizeof(*rsp_wrappers), GFP_KERNEL); if (!rsp_wrappers) return -ENOMEM; @@ -1981,15 +2010,15 @@ static int i2cp_adapter_master_xfer(struct i2c_adapter *adap, init_completion(&mxfer_reply->data_filled); mutex_init(&mxfer_reply->lock); - mxfer_reply->msgs = kcalloc(num, sizeof(*mxfer_reply->msgs), - GFP_KERNEL); + mxfer_reply->msgs = + kcalloc(num, sizeof(*mxfer_reply->msgs), GFP_KERNEL); if (!mxfer_reply->msgs) { ret = -ENOMEM; goto return_after_mxfer_reply_alloc; } - mxfer_reply->completed = kcalloc(num, sizeof(*mxfer_reply->completed), - GFP_KERNEL); + mxfer_reply->completed = + kcalloc(num, sizeof(*mxfer_reply->completed), GFP_KERNEL); if (!mxfer_reply->completed) { ret = -ENOMEM; goto return_after_reply_msgs_alloc; @@ -2034,8 +2063,8 @@ static int i2cp_adapter_master_xfer(struct i2c_adapter *adap, if (msgs[i].flags & I2C_M_RD) continue; /* Copy the data, not the address. */ - mxfer_rsp->msgs[i].buf = kmemdup(msgs[i].buf, msgs[i].len, - GFP_KERNEL); + mxfer_rsp->msgs[i].buf = + kmemdup(msgs[i].buf, msgs[i].len, GFP_KERNEL); if (!mxfer_rsp->msgs[i].buf) { ret = -ENOMEM; goto fail_after_rsp_msgs_alloc; @@ -2051,7 +2080,8 @@ static int i2cp_adapter_master_xfer(struct i2c_adapter *adap, } ret = I2CP_FILL_RSP_BUF_WITH_LITERAL(rsp_wrappers[wrapper_idx++], - rsp_bufs[rsp_bufs_idx++], I2CP_BEGIN_MXFER_REQ_CMD); + rsp_bufs[rsp_bufs_idx++], + I2CP_BEGIN_MXFER_REQ_CMD); if (ret < 0) goto fail_after_individual_rsp_wrappers_alloc; @@ -2062,7 +2092,8 @@ static int i2cp_adapter_master_xfer(struct i2c_adapter *adap, } ret = I2CP_FILL_RSP_BUF_WITH_LITERAL(rsp_wrappers[wrapper_idx++], - rsp_bufs[rsp_bufs_idx++], I2CP_COMMIT_MXFER_REQ_CMD); + rsp_bufs[rsp_bufs_idx++], + I2CP_COMMIT_MXFER_REQ_CMD); if (ret < 0) goto fail_after_individual_rsp_wrappers_alloc; @@ -2082,12 +2113,12 @@ static int i2cp_adapter_master_xfer(struct i2c_adapter *adap, mxfer_reply->id = mxfer_rsp->id; list_add_tail(&mxfer_reply->reply_queue_item, - &cmd_data->reply_queue_head); + &cmd_data->reply_queue_head); ++cmd_data->reply_queue_length; for (i = 0; i < wrappers_length; ++i) { list_add_tail(&rsp_wrappers[i]->queue, - &pdata->read_rsp_queue_head); + &pdata->read_rsp_queue_head); complete(&pdata->read_rsp_queued); } pdata->read_rsp_queue_length += wrappers_length; @@ -2132,31 +2163,31 @@ static int i2cp_adapter_master_xfer(struct i2c_adapter *adap, mutex_unlock(&cmd_data->reply_queue_lock); goto return_after_reply_msgs_alloc; - fail_with_reply_queue_lock: +fail_with_reply_queue_lock: mutex_unlock(&cmd_data->reply_queue_lock); - fail_with_read_rsp_queue_lock: +fail_with_read_rsp_queue_lock: mutex_unlock(&pdata->read_rsp_queue_lock); - fail_after_individual_rsp_wrappers_alloc: +fail_after_individual_rsp_wrappers_alloc: for (i = 0; i < wrappers_length; ++i) kfree(rsp_wrappers[i]); - fail_after_rsp_msgs_alloc: +fail_after_rsp_msgs_alloc: for (i = 0; i < num; ++i) kfree(mxfer_rsp->msgs[i].buf); kfree(mxfer_rsp->msgs); - fail_after_mxfer_rsp_alloc: +fail_after_mxfer_rsp_alloc: kfree(mxfer_rsp); - fail_after_individual_rsp_bufs_alloc: +fail_after_individual_rsp_bufs_alloc: for (i = 0; i < ARRAY_SIZE(rsp_bufs); ++i) { kfree(rsp_bufs[i]->buf); kfree(rsp_bufs[i]); } - return_after_reply_completed_alloc: +return_after_reply_completed_alloc: kfree(mxfer_reply->completed); - return_after_reply_msgs_alloc: +return_after_reply_msgs_alloc: kfree(mxfer_reply->msgs); - return_after_mxfer_reply_alloc: +return_after_mxfer_reply_alloc: kfree(mxfer_reply); - return_after_rsp_wrappers_ptrs_alloc: +return_after_rsp_wrappers_ptrs_alloc: kfree(rsp_wrappers); return ret; } @@ -2183,9 +2214,8 @@ static const struct i2c_algorithm i2cp_algorithm = { /* this_pseudo->counters.lock must _not_ be held when calling this. */ static void i2cp_remove_from_counters(struct i2cp_controller *pdata, - struct i2cp_device *this_pseudo) + struct i2cp_device *this_pseudo) { - mutex_lock(&this_pseudo->counters.lock); this_pseudo->counters.all_controllers[pdata->index] = NULL; --this_pseudo->counters.count; @@ -2290,7 +2320,7 @@ static int i2cp_cdev_open(struct inode *inodep, struct file *filep) pdata->i2c_adapter.timeout = msecs_to_jiffies(i2cp_default_timeout_ms); pdata->i2c_adapter.dev.parent = &this_pseudo->device; ret = snprintf(pdata->i2c_adapter.name, sizeof(pdata->i2c_adapter.name), - "I2C pseudo ID %u", pdata->id); + "I2C pseudo ID %u", pdata->id); if (ret < 0) goto fail_after_counters_update; @@ -2298,9 +2328,9 @@ static int i2cp_cdev_open(struct inode *inodep, struct file *filep) filep->private_data = pdata; return 0; - fail_after_counters_update: +fail_after_counters_update: i2cp_remove_from_counters(pdata, this_pseudo); - fail_after_cmd_data_created: +fail_after_cmd_data_created: for (i = 0; i < num_cmd_data_created; ++i) if (i2cp_cmds[i].data_destroyer) i2cp_cmds[i].data_destroyer(pdata->cmd_data[i]); @@ -2317,7 +2347,7 @@ static int i2cp_cdev_release(struct inode *inodep, struct file *filep) pdata = filep->private_data; this_pseudo = container_of(pdata->i2c_adapter.dev.parent, - struct i2cp_device, device); + struct i2cp_device, device); /* * The select(2) man page makes it clear that the behavior of pending @@ -2378,7 +2408,8 @@ static int i2cp_cdev_release(struct inode *inodep, struct file *filep) /* The caller must hold pdata->rsp_lock. */ /* Return value is whether or not to continue in calling loop. */ static bool i2cp_cdev_read_iteration(char __user **buf, size_t *count, - ssize_t *ret, bool non_blocking, struct i2cp_controller *pdata) + ssize_t *ret, bool non_blocking, + struct i2cp_controller *pdata) { long wait_ret; ssize_t copy_size; @@ -2450,9 +2481,9 @@ static bool i2cp_cdev_read_iteration(char __user **buf, size_t *count, mutex_lock(&pdata->read_rsp_queue_lock); if (!list_empty(&pdata->read_rsp_queue_head)) - rsp_wrapper = list_first_entry( - &pdata->read_rsp_queue_head, - struct i2cp_rsp, queue); + rsp_wrapper = + list_first_entry(&pdata->read_rsp_queue_head, + struct i2cp_rsp, queue); /* * Avoid holding pdata->read_rsp_queue_lock while * executing a formatter, allocating memory, or doing @@ -2543,7 +2574,7 @@ static bool i2cp_cdev_read_iteration(char __user **buf, size_t *count, return false; } - write_end_char: + write_end_char: copy_size = sizeof(i2cp_ctrlr_end_char); /* * This assertion is just in case someone changes @@ -2554,8 +2585,7 @@ static bool i2cp_cdev_read_iteration(char __user **buf, size_t *count, * block, we already know it's greater than zero. */ BUILD_BUG_ON(copy_size != 1); - copy_ret = copy_to_user(*buf, &i2cp_ctrlr_end_char, - copy_size); + copy_ret = copy_to_user(*buf, &i2cp_ctrlr_end_char, copy_size); copy_size -= copy_ret; /* * After writing to the userspace buffer, we need to @@ -2571,7 +2601,7 @@ static bool i2cp_cdev_read_iteration(char __user **buf, size_t *count, } copy_size = max_t(ssize_t, 0, - min_t(ssize_t, *count, pdata->rsp_buf_remaining)); + min_t(ssize_t, *count, pdata->rsp_buf_remaining)); copy_ret = copy_to_user(*buf, pdata->rsp_buf_pos, copy_size); copy_size -= copy_ret; pdata->rsp_buf_remaining -= copy_size; @@ -2584,14 +2614,14 @@ static bool i2cp_cdev_read_iteration(char __user **buf, size_t *count, pdata->rsp_buf_pos = NULL; } - /* - * When jumping here, the following variables should be set: - * copy_ret: Return value from copy_to_user() (bytes not copied). - * copy_size: The number of bytes successfully copied by copy_to_user(). In - * other words, this should be the size arg to copy_to_user() minus its - * return value (bytes not copied). - */ - after_copy_to_user: +/* + * When jumping here, the following variables should be set: + * copy_ret: Return value from copy_to_user() (bytes not copied). + * copy_size: The number of bytes successfully copied by copy_to_user(). In + * other words, this should be the size arg to copy_to_user() minus its + * return value (bytes not copied). + */ +after_copy_to_user: *ret += copy_size; *count -= copy_size; *buf += copy_size; @@ -2600,7 +2630,7 @@ static bool i2cp_cdev_read_iteration(char __user **buf, size_t *count, } static ssize_t i2cp_cdev_read(struct file *filep, char __user *buf, - size_t count, loff_t *f_ps) + size_t count, loff_t *f_ps) { ssize_t ret = 0; bool non_blocking; @@ -2638,20 +2668,20 @@ static ssize_t i2cp_cdev_read(struct file *filep, char __user *buf, goto unlock; } - while (count > 0 && i2cp_cdev_read_iteration( - &buf, &count, &ret, non_blocking, pdata)) + while (count > 0 && i2cp_cdev_read_iteration(&buf, &count, &ret, + non_blocking, pdata)) ; - unlock: +unlock: mutex_unlock(&pdata->rsp_lock); return ret; } /* Must be called with pdata->cmd_lock held. */ /* Must never consume past first i2cp_ctrlr_end_char in @start. */ -static ssize_t i2cp_receive_ctrlr_cmd_header( - struct i2cp_controller *pdata, char *start, size_t remaining, - bool non_blocking) +static ssize_t i2cp_receive_ctrlr_cmd_header(struct i2cp_controller *pdata, + char *start, size_t remaining, + bool non_blocking) { int found_deliminator_char = 0; int i, cmd_idx; @@ -2665,7 +2695,7 @@ static ssize_t i2cp_receive_ctrlr_cmd_header( start[i] == i2cp_ctrlr_header_sep_char) { found_deliminator_char = 1; break; - } + } if (i <= buf_remaining) { copy_size = i; @@ -2695,7 +2725,7 @@ static ssize_t i2cp_receive_ctrlr_cmd_header( for (i = 0; i < ARRAY_SIZE(i2cp_cmds); ++i) if (i2cp_cmds[i].cmd_size == pdata->cmd_size && !memcmp(i2cp_cmds[i].cmd_string, pdata->cmd_buf, - pdata->cmd_size)) + pdata->cmd_size)) break; if (i >= ARRAY_SIZE(i2cp_cmds)) { /* unrecognized command */ @@ -2725,7 +2755,7 @@ static ssize_t i2cp_receive_ctrlr_cmd_header( } } - clear_buffer: +clear_buffer: pdata->cmd_size = 0; /* * Ensure a trailing null character for the next header_receiver() or @@ -2745,7 +2775,8 @@ static ssize_t i2cp_receive_ctrlr_cmd_header( /* Must be called with pdata->cmd_lock held. */ /* Must never consume past first i2cp_ctrlr_end_char in @start. */ static ssize_t i2cp_receive_ctrlr_cmd_data(struct i2cp_controller *pdata, - char *start, size_t remaining, bool non_blocking) + char *start, size_t remaining, + bool non_blocking) { ssize_t i, ret, size_holder; int cmd_idx; @@ -2755,13 +2786,14 @@ static ssize_t i2cp_receive_ctrlr_cmd_data(struct i2cp_controller *pdata, if (cmd_idx < 0) return -EINVAL; - size_holder = min_t(size_t, + size_holder = min_t( + size_t, (I2CP_CTRLR_CMD_LIMIT - (I2CP_CTRLR_CMD_LIMIT % pdata->cmd_data_increment)) - - pdata->cmd_size, - (((pdata->cmd_size + remaining) / - pdata->cmd_data_increment) * - pdata->cmd_data_increment) - pdata->cmd_size); + pdata->cmd_size, + (((pdata->cmd_size + remaining) / pdata->cmd_data_increment) * + pdata->cmd_data_increment) - + pdata->cmd_size); /* Size of current buffer plus all remaining write bytes. */ size_holder = pdata->cmd_size + remaining; @@ -2791,8 +2823,10 @@ static ssize_t i2cp_receive_ctrlr_cmd_data(struct i2cp_controller *pdata, * buffer to end up with if there were unlimited write bytes * remaining (computed in-line below). */ - size_holder = min_t(ssize_t, size_holder, (I2CP_CTRLR_CMD_LIMIT - ( - I2CP_CTRLR_CMD_LIMIT % pdata->cmd_data_increment))); + size_holder = + min_t(ssize_t, size_holder, + (I2CP_CTRLR_CMD_LIMIT - + (I2CP_CTRLR_CMD_LIMIT % pdata->cmd_data_increment))); /* * Subtract the existing buffer size to get the number of bytes we * actually want to copy from the remaining write bytes in this loop @@ -2843,7 +2877,7 @@ static ssize_t i2cp_receive_ctrlr_cmd_data(struct i2cp_controller *pdata, /* Must be called with pdata->cmd_lock held. */ static int i2cp_receive_ctrlr_cmd_complete(struct i2cp_controller *pdata, - bool non_blocking) + bool non_blocking) { int ret = 0, cmd_idx; @@ -2851,8 +2885,9 @@ static int i2cp_receive_ctrlr_cmd_complete(struct i2cp_controller *pdata, cmd_idx = pdata->cmd_idx_plus_one - 1; if (cmd_idx >= 0 && i2cp_cmds[cmd_idx].cmd_completer) { - ret = i2cp_cmds[cmd_idx].cmd_completer(pdata->cmd_data[cmd_idx], - pdata, pdata->cmd_receive_status, non_blocking); + ret = i2cp_cmds[cmd_idx].cmd_completer( + pdata->cmd_data[cmd_idx], pdata, + pdata->cmd_receive_status, non_blocking); if (ret > 0) ret = 0; } @@ -2872,7 +2907,7 @@ static int i2cp_receive_ctrlr_cmd_complete(struct i2cp_controller *pdata, } static ssize_t i2cp_cdev_write(struct file *filep, const char __user *buf, - size_t count, loff_t *f_ps) + size_t count, loff_t *f_ps) { ssize_t ret = 0; bool non_blocking; @@ -2949,8 +2984,8 @@ static ssize_t i2cp_cdev_write(struct file *filep, const char __user *buf, start += ret; if (ret > 0 && start[-1] == i2cp_ctrlr_end_char) { - ret = i2cp_receive_ctrlr_cmd_complete( - pdata, non_blocking); + ret = i2cp_receive_ctrlr_cmd_complete(pdata, + non_blocking); if (ret < 0) break; } @@ -2963,7 +2998,7 @@ static ssize_t i2cp_cdev_write(struct file *filep, const char __user *buf, /* If successful the whole write is always consumed. */ ret = count; - free_kbuf: +free_kbuf: kfree(kbuf); return ret; } @@ -3056,7 +3091,7 @@ static const struct file_operations i2cp_fileops = { }; static ssize_t i2cp_limit_show(struct device *dev, - struct device_attribute *attr, char *buf) + struct device_attribute *attr, char *buf) { int ret; @@ -3075,7 +3110,7 @@ static struct device_attribute i2cp_limit_dev_attr = { }; static ssize_t i2cp_count_show(struct device *dev, - struct device_attribute *attr, char *buf) + struct device_attribute *attr, char *buf) { int count, ret; struct i2cp_device *this_pseudo; @@ -3138,9 +3173,9 @@ static int __init i2cp_init(void) int ret = -1; if (i2cp_limit < I2CP_ADAPTERS_MIN || i2cp_limit > I2CP_ADAPTERS_MAX) { - pr_err("%s: i2cp_limit=%u, must be in range [" - STR(I2CP_ADAPTERS_MIN) ", " STR(I2CP_ADAPTERS_MAX) - "]\n", __func__, i2cp_limit); + pr_err("%s: i2cp_limit=%u, must be in range [" STR( + I2CP_ADAPTERS_MIN) ", " STR(I2CP_ADAPTERS_MAX) "]\n", + __func__, i2cp_limit); return -EINVAL; } @@ -3151,7 +3186,7 @@ static int __init i2cp_init(void) i2cp_class->dev_groups = i2cp_device_sysfs_groups; ret = alloc_chrdev_region(&i2cp_dev_num, I2CP_CDEV_BASEMINOR, - I2CP_CDEV_COUNT, I2CP_CHRDEV_NAME); + I2CP_CDEV_COUNT, I2CP_CHRDEV_NAME); if (ret < 0) goto fail_after_class_create; @@ -3171,8 +3206,9 @@ static int __init i2cp_init(void) goto fail_after_device_init; mutex_init(&i2cp_device->counters.lock); - i2cp_device->counters.all_controllers = kcalloc(i2cp_limit, - sizeof(*i2cp_device->counters.all_controllers), GFP_KERNEL); + i2cp_device->counters.all_controllers = kcalloc( + i2cp_limit, sizeof(*i2cp_device->counters.all_controllers), + GFP_KERNEL); if (!i2cp_device->counters.all_controllers) { ret = -ENOMEM; goto fail_after_device_init; @@ -3187,11 +3223,11 @@ static int __init i2cp_init(void) return 0; - fail_after_device_init: +fail_after_device_init: put_device(&i2cp_device->device); - fail_after_chrdev_register: +fail_after_chrdev_register: unregister_chrdev_region(i2cp_dev_num, I2CP_CDEV_COUNT); - fail_after_class_create: +fail_after_class_create: i2c_p_class_destroy(); return ret; } diff --git a/extra/lightbar/Makefile b/extra/lightbar/Makefile index ce84428869..628f19ab81 100644 --- a/extra/lightbar/Makefile +++ b/extra/lightbar/Makefile @@ -1,4 +1,4 @@ -# Copyright 2014 The Chromium OS Authors. All rights reserved. +# Copyright 2014 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. diff --git a/extra/lightbar/input.c b/extra/lightbar/input.c index e6c5485e39..5b605600ea 100644 --- a/extra/lightbar/input.c +++ b/extra/lightbar/input.c @@ -1,5 +1,5 @@ /* - * Copyright 2014 The Chromium OS Authors. All rights reserved. + * Copyright 2014 The ChromiumOS Authors * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ @@ -32,7 +32,7 @@ char *get_input(const char *prompt) return line; } -#else /* no readline */ +#else /* no readline */ char *get_input(const char *prompt) { diff --git a/extra/lightbar/main.c b/extra/lightbar/main.c index ef011d35f1..321c0c73d2 100644 --- a/extra/lightbar/main.c +++ b/extra/lightbar/main.c @@ -1,5 +1,5 @@ /* - * Copyright 2014 The Chromium OS Authors. All rights reserved. + * Copyright 2014 The ChromiumOS Authors * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ @@ -55,7 +55,7 @@ void *entry_lightbar(void *ptr) /* timespec uses nanoseconds */ #define TS_USEC 1000L #define TS_MSEC 1000000L -#define TS_SEC 1000000000L +#define TS_SEC 1000000000L static void timespec_incr(struct timespec *v, time_t secs, long nsecs) { @@ -66,7 +66,6 @@ static void timespec_incr(struct timespec *v, time_t secs, long nsecs) v->tv_nsec %= TS_SEC; } - static pthread_mutex_t task_mutex = PTHREAD_MUTEX_INITIALIZER; static pthread_cond_t task_cond = PTHREAD_COND_INITIALIZER; static uint32_t task_event; @@ -82,8 +81,8 @@ uint32_t task_wait_event(int timeout_us) clock_gettime(CLOCK_REALTIME, &t); timespec_incr(&t, timeout_us / SECOND, timeout_us * TS_USEC); - if (ETIMEDOUT == pthread_cond_timedwait(&task_cond, - &task_mutex, &t)) + if (ETIMEDOUT == + pthread_cond_timedwait(&task_cond, &task_mutex, &t)) task_event |= TASK_EVENT_TIMER; } else { pthread_cond_wait(&task_cond, &task_mutex); @@ -96,7 +95,7 @@ uint32_t task_wait_event(int timeout_us) } void task_set_event(task_id_t tskid, /* always LIGHTBAR */ - uint32_t event) + uint32_t event) { pthread_mutex_lock(&task_mutex); task_event = event; @@ -104,8 +103,6 @@ void task_set_event(task_id_t tskid, /* always LIGHTBAR */ pthread_mutex_unlock(&task_mutex); } - - /* Stubbed functions */ void cprintf(int zero, const char *fmt, ...) @@ -146,7 +143,7 @@ timestamp_t get_time(void) clock_gettime(CLOCK_REALTIME, &t_start); clock_gettime(CLOCK_REALTIME, &t); ret.val = (t.tv_sec - t_start.tv_sec) * SECOND + - (t.tv_nsec - t_start.tv_nsec) / TS_USEC; + (t.tv_nsec - t_start.tv_nsec) / TS_USEC; return ret; } @@ -162,8 +159,7 @@ uint8_t *system_get_jump_tag(uint16_t tag, int *version, int *size) } /* Copied from util/ectool.c */ -int lb_read_params_from_file(const char *filename, - struct lightbar_params_v1 *p) +int lb_read_params_from_file(const char *filename, struct lightbar_params_v1 *p) { FILE *fp; char buf[80]; @@ -175,46 +171,65 @@ int lb_read_params_from_file(const char *filename, fp = fopen(filename, "rb"); if (!fp) { - fprintf(stderr, "Can't open %s: %s\n", - filename, strerror(errno)); + fprintf(stderr, "Can't open %s: %s\n", filename, + strerror(errno)); return 1; } /* We must read the correct number of params from each line */ -#define READ(N) do { \ - line++; \ - want = (N); \ - got = -1; \ - if (!fgets(buf, sizeof(buf), fp)) \ - goto done; \ - got = sscanf(buf, "%i %i %i %i", \ - &val[0], &val[1], &val[2], &val[3]); \ - if (want != got) \ - goto done; \ +#define READ(N) \ + do { \ + line++; \ + want = (N); \ + got = -1; \ + if (!fgets(buf, sizeof(buf), fp)) \ + goto done; \ + got = sscanf(buf, "%i %i %i %i", &val[0], &val[1], &val[2], \ + &val[3]); \ + if (want != got) \ + goto done; \ } while (0) - /* Do it */ - READ(1); p->google_ramp_up = val[0]; - READ(1); p->google_ramp_down = val[0]; - READ(1); p->s3s0_ramp_up = val[0]; - READ(1); p->s0_tick_delay[0] = val[0]; - READ(1); p->s0_tick_delay[1] = val[0]; - READ(1); p->s0a_tick_delay[0] = val[0]; - READ(1); p->s0a_tick_delay[1] = val[0]; - READ(1); p->s0s3_ramp_down = val[0]; - READ(1); p->s3_sleep_for = val[0]; - READ(1); p->s3_ramp_up = val[0]; - READ(1); p->s3_ramp_down = val[0]; - READ(1); p->tap_tick_delay = val[0]; - READ(1); p->tap_gate_delay = val[0]; - READ(1); p->tap_display_time = val[0]; - - READ(1); p->tap_pct_red = val[0]; - READ(1); p->tap_pct_green = val[0]; - READ(1); p->tap_seg_min_on = val[0]; - READ(1); p->tap_seg_max_on = val[0]; - READ(1); p->tap_seg_osc = val[0]; + READ(1); + p->google_ramp_up = val[0]; + READ(1); + p->google_ramp_down = val[0]; + READ(1); + p->s3s0_ramp_up = val[0]; + READ(1); + p->s0_tick_delay[0] = val[0]; + READ(1); + p->s0_tick_delay[1] = val[0]; + READ(1); + p->s0a_tick_delay[0] = val[0]; + READ(1); + p->s0a_tick_delay[1] = val[0]; + READ(1); + p->s0s3_ramp_down = val[0]; + READ(1); + p->s3_sleep_for = val[0]; + READ(1); + p->s3_ramp_up = val[0]; + READ(1); + p->s3_ramp_down = val[0]; + READ(1); + p->tap_tick_delay = val[0]; + READ(1); + p->tap_gate_delay = val[0]; + READ(1); + p->tap_display_time = val[0]; + + READ(1); + p->tap_pct_red = val[0]; + READ(1); + p->tap_pct_green = val[0]; + READ(1); + p->tap_seg_min_on = val[0]; + READ(1); + p->tap_seg_max_on = val[0]; + READ(1); + p->tap_seg_osc = val[0]; READ(3); p->tap_idx[0] = val[0]; p->tap_idx[1] = val[1]; @@ -298,19 +313,18 @@ int lb_load_program(const char *filename, struct lightbar_program *prog) fp = fopen(filename, "rb"); if (!fp) { - fprintf(stderr, "Can't open %s: %s\n", - filename, strerror(errno)); + fprintf(stderr, "Can't open %s: %s\n", filename, + strerror(errno)); return 1; } rc = fseek(fp, 0, SEEK_END); if (rc) { - fprintf(stderr, "Couldn't find end of file %s", - filename); + fprintf(stderr, "Couldn't find end of file %s", filename); fclose(fp); return 1; } - rc = (int) ftell(fp); + rc = (int)ftell(fp); if (rc > EC_LB_PROG_LEN) { fprintf(stderr, "File %s is too long, aborting\n", filename); fclose(fp); diff --git a/extra/lightbar/simulation.h b/extra/lightbar/simulation.h index edbe5f340e..c77583e6c9 100644 --- a/extra/lightbar/simulation.h +++ b/extra/lightbar/simulation.h @@ -1,5 +1,5 @@ /* - * Copyright 2014 The Chromium OS Authors. All rights reserved. + * Copyright 2014 The ChromiumOS Authors * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ @@ -38,13 +38,12 @@ int fake_consolecmd_lightbar(int argc, char *argv[]); #define CONFIG_LIGHTBAR_POWER_RAILS #endif - /* Stuff that's too interleaved with the rest of the EC to just include */ /* Test an important condition at compile time, not run time */ -#define _BA1_(cond, line) \ - extern int __build_assertion_ ## line[1 - 2*!(cond)] \ - __attribute__ ((unused)) +#define _BA1_(cond, line) \ + extern int __build_assertion_##line[1 - 2 * !(cond)] \ + __attribute__((unused)) #define _BA0_(c, x) _BA1_(c, x) #define BUILD_ASSERT(cond) _BA0_(cond, __LINE__) @@ -61,14 +60,14 @@ void cprints(int zero, const char *fmt, ...); /* Task events */ #define TASK_EVENT_CUSTOM_BIT(x) BUILD_CHECK_INLINE(BIT(x), BIT(x) & 0x0fffffff) -#define TASK_EVENT_I2C_IDLE 0x10000000 -#define TASK_EVENT_WAKE 0x20000000 -#define TASK_EVENT_MUTEX 0x40000000 -#define TASK_EVENT_TIMER 0x80000000 +#define TASK_EVENT_I2C_IDLE 0x10000000 +#define TASK_EVENT_WAKE 0x20000000 +#define TASK_EVENT_MUTEX 0x40000000 +#define TASK_EVENT_TIMER 0x80000000 /* Time units in usecs */ -#define MSEC 1000 -#define SECOND 1000000 +#define MSEC 1000 +#define SECOND 1000000 #define TASK_ID_LIGHTBAR 0 #define CC_LIGHTBAR 0 @@ -103,15 +102,22 @@ int system_add_jump_tag(uint16_t tag, int version, int size, const void *data); uint8_t *system_get_jump_tag(uint16_t tag, int *version, int *size); /* Export unused static functions to avoid compiler warnings. */ -#define DECLARE_HOOK(X, fn, Y) \ - void fake_hook_##fn(void) { fn(); } +#define DECLARE_HOOK(X, fn, Y) \ + void fake_hook_##fn(void) \ + { \ + fn(); \ + } -#define DECLARE_HOST_COMMAND(X, fn, Y) \ +#define DECLARE_HOST_COMMAND(X, fn, Y) \ enum ec_status fake_hostcmd_##fn(struct host_cmd_handler_args *args) \ - { return fn(args); } + { \ + return fn(args); \ + } -#define DECLARE_CONSOLE_COMMAND(X, fn, Y...) \ +#define DECLARE_CONSOLE_COMMAND(X, fn, Y...) \ int fake_consolecmd_##X(int argc, char *argv[]) \ - { return fn(argc, argv); } + { \ + return fn(argc, argv); \ + } -#endif /* __EXTRA_SIMULATION_H */ +#endif /* __EXTRA_SIMULATION_H */ diff --git a/extra/lightbar/windows.c b/extra/lightbar/windows.c index 115074363c..e0b14fae42 100644 --- a/extra/lightbar/windows.c +++ b/extra/lightbar/windows.c @@ -1,5 +1,5 @@ /* - * Copyright 2014 The Chromium OS Authors. All rights reserved. + * Copyright 2014 The ChromiumOS Authors * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ @@ -42,8 +42,8 @@ void init_windows(void) /* Get a colormap */ colormap_id = xcb_generate_id(c); - xcb_create_colormap(c, XCB_COLORMAP_ALLOC_NONE, - colormap_id, screen->root, screen->root_visual); + xcb_create_colormap(c, XCB_COLORMAP_ALLOC_NONE, colormap_id, + screen->root, screen->root_visual); /* Create foreground GC */ foreground = xcb_generate_id(c); @@ -57,16 +57,16 @@ void init_windows(void) mask = XCB_CW_BACK_PIXEL | XCB_CW_EVENT_MASK; values[0] = screen->black_pixel; values[1] = XCB_EVENT_MASK_EXPOSURE | XCB_EVENT_MASK_BUTTON_PRESS; - xcb_create_window(c, /* Connection */ - XCB_COPY_FROM_PARENT, /* depth */ - win, /* window Id */ - screen->root, /* parent window */ - 0, 0, /* x, y */ - win_w, win_h, /* width, height */ - 10, /* border_width */ + xcb_create_window(c, /* Connection */ + XCB_COPY_FROM_PARENT, /* depth */ + win, /* window Id */ + screen->root, /* parent window */ + 0, 0, /* x, y */ + win_w, win_h, /* width, height */ + 10, /* border_width */ XCB_WINDOW_CLASS_INPUT_OUTPUT, /* class */ - screen->root_visual, /* visual */ - mask, values); /* masks */ + screen->root_visual, /* visual */ + mask, values); /* masks */ /* Map the window on the screen */ xcb_map_window(c, win); @@ -88,10 +88,10 @@ void cleanup(void) /* xcb likes 16-bit colors */ uint16_t leds[NUM_LEDS][3] = { - {0xffff, 0x0000, 0x0000}, - {0x0000, 0xffff, 0x0000}, - {0x0000, 0x0000, 0xffff}, - {0xffff, 0xffff, 0x0000}, + { 0xffff, 0x0000, 0x0000 }, + { 0x0000, 0xffff, 0x0000 }, + { 0x0000, 0x0000, 0xffff }, + { 0xffff, 0xffff, 0x0000 }, }; pthread_mutex_t leds_mutex = PTHREAD_MUTEX_INITIALIZER; @@ -101,10 +101,8 @@ void change_gc_color(uint16_t red, uint16_t green, uint16_t blue) uint32_t values[2]; xcb_alloc_color_reply_t *reply; - reply = xcb_alloc_color_reply(c, - xcb_alloc_color(c, colormap_id, - red, green, blue), - NULL); + reply = xcb_alloc_color_reply( + c, xcb_alloc_color(c, colormap_id, red, green, blue), NULL); assert(reply); mask = XCB_GC_FOREGROUND; @@ -116,8 +114,8 @@ void change_gc_color(uint16_t red, uint16_t green, uint16_t blue) void update_window(void) { xcb_segment_t segments[] = { - {0, 0, win_w, win_h}, - {0, win_h, win_w, 0}, + { 0, 0, win_w, win_h }, + { 0, win_h, win_w, 0 }, }; xcb_rectangle_t rect; int w = win_w / NUM_LEDS; @@ -135,8 +133,7 @@ void update_window(void) rect.width = w; rect.height = win_h; - change_gc_color(copyleds[i][0], - copyleds[i][1], + change_gc_color(copyleds[i][0], copyleds[i][1], copyleds[i][2]); xcb_poly_fill_rectangle(c, win, foreground, 1, &rect); @@ -184,8 +181,6 @@ void setrgb(int led, int red, int green, int blue) /*****************************************************************************/ /* lb_common stubs */ - - /* Brightness serves no purpose here. It's automatic on the Chromebook. */ static int brightness = 0xc0; void lb_set_brightness(unsigned int newval) @@ -238,14 +233,13 @@ void lb_hc_cmd_dump(struct ec_response_lightbar *out) printf("lightbar is %s\n", fake_power ? "on" : "off"); memset(out, fake_power, sizeof(*out)); }; -void lb_hc_cmd_reg(const struct ec_params_lightbar *in) { }; +void lb_hc_cmd_reg(const struct ec_params_lightbar *in){}; int lb_power(int enabled) { return fake_power; } - /*****************************************************************************/ /* Event handling stuff */ @@ -257,7 +251,6 @@ void *entry_windows(void *ptr) int chg = 1; while ((e = xcb_wait_for_event(c))) { - switch (e->response_type & ~0x80) { case XCB_EXPOSE: ev = (xcb_expose_event_t *)e; diff --git a/extra/rma_reset/Makefile b/extra/rma_reset/Makefile index 4a640c5b4c..d4644e91c8 100644 --- a/extra/rma_reset/Makefile +++ b/extra/rma_reset/Makefile @@ -1,4 +1,4 @@ -# Copyright 2017 The Chromium OS Authors. All rights reserved. +# Copyright 2017 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. @@ -19,7 +19,7 @@ CFLAGS := -std=gnu99 \ -Wredundant-decls \ -Wmissing-declarations -ifeq ($(DEBUG),1) +ifneq ($(DEBUG),) CFLAGS += -g -O0 else CFLAGS += -O3 diff --git a/extra/rma_reset/board.h b/extra/rma_reset/board.h index f969ad0c56..38e3e7b382 100644 --- a/extra/rma_reset/board.h +++ b/extra/rma_reset/board.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The Chromium OS Authors. All rights reserved. +/* Copyright 2017 The ChromiumOS Authors * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ diff --git a/extra/rma_reset/rma_reset.c b/extra/rma_reset/rma_reset.c index fe1eb5e909..d437b63f1a 100644 --- a/extra/rma_reset/rma_reset.c +++ b/extra/rma_reset/rma_reset.c @@ -1,4 +1,4 @@ -/* Copyright 2017 The Chromium OS Authors. All rights reserved. +/* Copyright 2017 The ChromiumOS Authors * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ @@ -27,24 +27,22 @@ #define EC_COORDINATE_SZ 32 #define EC_PRIV_KEY_SZ 32 #define EC_P256_UNCOMPRESSED_PUB_KEY_SZ (EC_COORDINATE_SZ * 2 + 1) -#define EC_P256_COMPRESSED_PUB_KEY_SZ (EC_COORDINATE_SZ + 1) +#define EC_P256_COMPRESSED_PUB_KEY_SZ (EC_COORDINATE_SZ + 1) #define SERVER_ADDRESS \ "https://www.google.com/chromeos/partner/console/cr50reset/request" /* Test server keys for x25519 and p256 curves. */ static const uint8_t rma_test_server_x25519_public_key[] = { - 0x03, 0xae, 0x2d, 0x2c, 0x06, 0x23, 0xe0, 0x73, - 0x0d, 0xd3, 0xb7, 0x92, 0xac, 0x54, 0xc5, 0xfd, - 0x7e, 0x9c, 0xf0, 0xa8, 0xeb, 0x7e, 0x2a, 0xb5, - 0xdb, 0xf4, 0x79, 0x5f, 0x8a, 0x0f, 0x28, 0x3f + 0x03, 0xae, 0x2d, 0x2c, 0x06, 0x23, 0xe0, 0x73, 0x0d, 0xd3, 0xb7, + 0x92, 0xac, 0x54, 0xc5, 0xfd, 0x7e, 0x9c, 0xf0, 0xa8, 0xeb, 0x7e, + 0x2a, 0xb5, 0xdb, 0xf4, 0x79, 0x5f, 0x8a, 0x0f, 0x28, 0x3f }; static const uint8_t rma_test_server_x25519_private_key[] = { - 0x47, 0x3b, 0xa5, 0xdb, 0xc4, 0xbb, 0xd6, 0x77, - 0x20, 0xbd, 0xd8, 0xbd, 0xc8, 0x7a, 0xbb, 0x07, - 0x03, 0x79, 0xba, 0x7b, 0x52, 0x8c, 0xec, 0xb3, - 0x4d, 0xaa, 0x69, 0xf5, 0x65, 0xb4, 0x31, 0xad + 0x47, 0x3b, 0xa5, 0xdb, 0xc4, 0xbb, 0xd6, 0x77, 0x20, 0xbd, 0xd8, + 0xbd, 0xc8, 0x7a, 0xbb, 0x07, 0x03, 0x79, 0xba, 0x7b, 0x52, 0x8c, + 0xec, 0xb3, 0x4d, 0xaa, 0x69, 0xf5, 0x65, 0xb4, 0x31, 0xad }; #define RMA_TEST_SERVER_X25519_KEY_ID 0x10 @@ -57,10 +55,9 @@ static const uint8_t rma_test_server_x25519_private_key[] = { * openssl ec -in key.pem -text -noout */ static const uint8_t rma_test_server_p256_private_key[] = { - 0x54, 0xb0, 0x82, 0x92, 0x54, 0x92, 0xfc, 0x4a, - 0xa7, 0x6b, 0xea, 0x8f, 0x30, 0xcc, 0xf7, 0x3d, - 0xa2, 0xf6, 0xa7, 0xad, 0xf0, 0xec, 0x7d, 0xe9, - 0x26, 0x75, 0xd1, 0xec, 0xde, 0x20, 0x8f, 0x81 + 0x54, 0xb0, 0x82, 0x92, 0x54, 0x92, 0xfc, 0x4a, 0xa7, 0x6b, 0xea, + 0x8f, 0x30, 0xcc, 0xf7, 0x3d, 0xa2, 0xf6, 0xa7, 0xad, 0xf0, 0xec, + 0x7d, 0xe9, 0x26, 0x75, 0xd1, 0xec, 0xde, 0x20, 0x8f, 0x81 }; /* @@ -68,15 +65,12 @@ static const uint8_t rma_test_server_p256_private_key[] = { * prefix, 65 bytes total. */ static const uint8_t rma_test_server_p256_public_key[] = { - 0x04, 0xe7, 0xbe, 0x37, 0xaa, 0x68, 0xca, 0xcc, - 0x68, 0xf4, 0x8c, 0x56, 0x65, 0x5a, 0xcb, 0xf8, - 0xf4, 0x65, 0x3c, 0xd3, 0xc6, 0x1b, 0xae, 0xd6, - 0x51, 0x7a, 0xcc, 0x00, 0x8d, 0x59, 0x6d, 0x1b, - 0x0a, 0x66, 0xe8, 0x68, 0x5e, 0x6a, 0x82, 0x19, - 0x81, 0x76, 0x84, 0x92, 0x7f, 0x8d, 0xb2, 0xbe, - 0xf5, 0x39, 0x50, 0xd5, 0xfe, 0xee, 0x00, 0x67, - 0xcf, 0x40, 0x5f, 0x68, 0x12, 0x83, 0x4f, 0xa4, - 0x35 + 0x04, 0xe7, 0xbe, 0x37, 0xaa, 0x68, 0xca, 0xcc, 0x68, 0xf4, 0x8c, + 0x56, 0x65, 0x5a, 0xcb, 0xf8, 0xf4, 0x65, 0x3c, 0xd3, 0xc6, 0x1b, + 0xae, 0xd6, 0x51, 0x7a, 0xcc, 0x00, 0x8d, 0x59, 0x6d, 0x1b, 0x0a, + 0x66, 0xe8, 0x68, 0x5e, 0x6a, 0x82, 0x19, 0x81, 0x76, 0x84, 0x92, + 0x7f, 0x8d, 0xb2, 0xbe, 0xf5, 0x39, 0x50, 0xd5, 0xfe, 0xee, 0x00, + 0x67, 0xcf, 0x40, 0x5f, 0x68, 0x12, 0x83, 0x4f, 0xa4, 0x35 }; #define RMA_TEST_SERVER_P256_KEY_ID 0x20 @@ -84,8 +78,8 @@ static const uint8_t rma_test_server_p256_public_key[] = { /* Default values which can change based on command line arguments. */ static uint8_t server_key_id = RMA_TEST_SERVER_X25519_KEY_ID; -static uint8_t board_id[4] = {'Z', 'Z', 'C', 'R'}; -static uint8_t device_id[8] = {'T', 'H', 'X', 1, 1, 3, 8, 0xfe}; +static uint8_t board_id[4] = { 'Z', 'Z', 'C', 'R' }; +static uint8_t device_id[8] = { 'T', 'H', 'X', 1, 1, 3, 8, 0xfe }; static uint8_t hw_id[20] = "TESTSAMUS1234"; static char challenge[RMA_CHALLENGE_BUF_SIZE]; @@ -95,20 +89,15 @@ static char *progname; static char *short_opts = "a:b:c:d:hpk:tw:"; static const struct option long_opts[] = { /* name hasarg *flag val */ - {"auth_code", 1, NULL, 'a'}, - {"board_id", 1, NULL, 'b'}, - {"challenge", 1, NULL, 'c'}, - {"device_id", 1, NULL, 'd'}, - {"help", 0, NULL, 'h'}, - {"hw_id", 1, NULL, 'w'}, - {"key_id", 1, NULL, 'k'}, - {"p256", 0, NULL, 'p'}, - {"test", 0, NULL, 't'}, - {}, + { "auth_code", 1, NULL, 'a' }, { "board_id", 1, NULL, 'b' }, + { "challenge", 1, NULL, 'c' }, { "device_id", 1, NULL, 'd' }, + { "help", 0, NULL, 'h' }, { "hw_id", 1, NULL, 'w' }, + { "key_id", 1, NULL, 'k' }, { "p256", 0, NULL, 'p' }, + { "test", 0, NULL, 't' }, {}, }; void panic_assert_fail(const char *fname, int linenum); -void rand_bytes(void *buffer, size_t len); +void trng_rand_bytes(void *buffer, size_t len); int safe_memcmp(const void *s1, const void *s2, size_t size); void panic_assert_fail(const char *fname, int linenum) @@ -131,7 +120,7 @@ int safe_memcmp(const void *s1, const void *s2, size_t size) return result != 0; } -void rand_bytes(void *buffer, size_t len) +void trng_rand_bytes(void *buffer, size_t len) { RAND_bytes(buffer, len); } @@ -173,8 +162,8 @@ static void p256_key_and_secret_seed(uint8_t pub_key[32], /* Extract public key into an octal array. */ EC_POINT_point2oct(group, EC_KEY_get0_public_key(key), - POINT_CONVERSION_UNCOMPRESSED, - buf, sizeof(buf), NULL); + POINT_CONVERSION_UNCOMPRESSED, buf, + sizeof(buf), NULL); /* If Y coordinate is an odd value, we are done. */ } while (!(buf[sizeof(buf) - 1] & 1)); @@ -195,8 +184,8 @@ static void p256_key_and_secret_seed(uint8_t pub_key[32], secret_point = EC_POINT_new(group); /* Multiply server public key by our private key. */ - EC_POINT_mul(group, secret_point, 0, pub, - EC_KEY_get0_private_key(key), 0); + EC_POINT_mul(group, secret_point, 0, pub, EC_KEY_get0_private_key(key), + 0); /* Pull the result back into the octal buffer. */ EC_POINT_point2oct(group, secret_point, POINT_CONVERSION_UNCOMPRESSED, @@ -252,9 +241,8 @@ static void p256_calculate_secret(uint8_t secret[32], secret_point = EC_POINT_new(group); /* Multiply client's point by our private key. */ - EC_POINT_mul(group, secret_point, 0, - EC_KEY_get0_public_key(key), - priv, 0); + EC_POINT_mul(group, secret_point, 0, EC_KEY_get0_public_key(key), priv, + 0); /* Pull the result back into the octal buffer. */ EC_POINT_point2oct(group, secret_point, POINT_CONVERSION_UNCOMPRESSED, @@ -274,7 +262,7 @@ static int rma_server_side(const char *generated_challenge) /* Convert the challenge back into binary */ if (base32_decode(cptr, 8 * sizeof(c), generated_challenge, 9) != - 8 * sizeof(c)) { + 8 * sizeof(c)) { printf("Error decoding challenge\n"); return -1; } @@ -311,8 +299,8 @@ static int rma_server_side(const char *generated_challenge) * and DeviceID. */ hmac_SHA256(hmac, secret, sizeof(secret), cptr + 1, sizeof(c) - 1); - if (base32_encode(authcode, RMA_AUTHCODE_BUF_SIZE, - hmac, RMA_AUTHCODE_CHARS * 5, 0)) { + if (base32_encode(authcode, RMA_AUTHCODE_BUF_SIZE, hmac, + RMA_AUTHCODE_CHARS * 5, 0)) { printf("Error encoding auth code\n"); return -1; } @@ -323,7 +311,7 @@ static int rma_server_side(const char *generated_challenge) static int rma_create_test_challenge(int p256_mode) { - uint8_t temp[32]; /* Private key or HMAC */ + uint8_t temp[32]; /* Private key or HMAC */ uint8_t secret_seed[32]; struct rma_challenge c; uint8_t *cptr = (uint8_t *)&c; @@ -334,8 +322,8 @@ static int rma_create_test_challenge(int p256_mode) memset(authcode, 0, sizeof(authcode)); memset(&c, 0, sizeof(c)); - c.version_key_id = RMA_CHALLENGE_VKID_BYTE( - RMA_CHALLENGE_VERSION, server_key_id); + c.version_key_id = + RMA_CHALLENGE_VKID_BYTE(RMA_CHALLENGE_VERSION, server_key_id); memcpy(&bid, board_id, sizeof(bid)); bid = be32toh(bid); @@ -361,8 +349,8 @@ static int rma_create_test_challenge(int p256_mode) * and DeviceID. Those are all in the right order in the challenge * struct, after the version/key id byte. */ - hmac_SHA256(temp, secret_seed, sizeof(secret_seed), - cptr + 1, sizeof(c) - 1); + hmac_SHA256(temp, secret_seed, sizeof(secret_seed), cptr + 1, + sizeof(c) - 1); if (base32_encode(authcode, sizeof(authcode), temp, RMA_AUTHCODE_CHARS * 5, 0)) return 1; @@ -382,7 +370,8 @@ static void dump_key(const char *title, const uint8_t *key, size_t key_size) printf("\n\n\%s\n", title); for (i = 0; i < key_size; i++) - printf("%02x%c", key[i], ((i + 1) % bytes_per_line) ? ' ':'\n'); + printf("%02x%c", key[i], + ((i + 1) % bytes_per_line) ? ' ' : '\n'); if (i % bytes_per_line) printf("\n"); @@ -453,25 +442,26 @@ static void usage(void) "--device_id <arg> --hw_id <arg> |\n" " --auth_code <arg> |\n" " --challenge <arg>\n" - "\n" - "This is used to generate the cr50 or server responses for rma " - "open.\n" - "The cr50 side can be used to generate a challenge response " - "and sends authoriztion code to reset device.\n" - "The server side can generate an authcode from cr50's " - "rma challenge.\n" - "\n" - " -c,--challenge The challenge generated by cr50\n" - " -k,--key_id Index of the server private key\n" - " -b,--board_id BoardID type field\n" - " -d,--device_id Device-unique identifier\n" - " -a,--auth_code Reset authorization code\n" - " -w,--hw_id Hardware id\n" - " -h,--help Show this message\n" - " -p,--p256 Use prime256v1 curve instead of x25519\n" - " -t,--test " - "Generate challenge using default test inputs\n" - "\n", progname); + "\n" + "This is used to generate the cr50 or server responses for rma " + "open.\n" + "The cr50 side can be used to generate a challenge response " + "and sends authoriztion code to reset device.\n" + "The server side can generate an authcode from cr50's " + "rma challenge.\n" + "\n" + " -c,--challenge The challenge generated by cr50\n" + " -k,--key_id Index of the server private key\n" + " -b,--board_id BoardID type field\n" + " -d,--device_id Device-unique identifier\n" + " -a,--auth_code Reset authorization code\n" + " -w,--hw_id Hardware id\n" + " -h,--help Show this message\n" + " -p,--p256 Use prime256v1 curve instead of x25519\n" + " -t,--test " + "Generate challenge using default test inputs\n" + "\n", + progname); } static int atoh(char *v) @@ -498,7 +488,7 @@ static int set_server_key_id(char *id) return 1; /* verify digits */ - if (!isxdigit(*id) || !isxdigit(*(id+1))) + if (!isxdigit(*id) || !isxdigit(*(id + 1))) return 1; server_key_id = atoh(id); @@ -520,7 +510,7 @@ static int set_board_id(char *id) return 1; for (i = 0; i < 4; i++) - board_id[i] = atoh((id + (i*2))); + board_id[i] = atoh((id + (i * 2))); return 0; } @@ -538,7 +528,7 @@ static int set_device_id(char *id) return 1; for (i = 0; i < 8; i++) - device_id[i] = atoh((id + (i*2))); + device_id[i] = atoh((id + (i * 2))); return 0; } @@ -635,14 +625,14 @@ int main(int argc, char **argv) case 'h': usage(); return 0; - case 0: /* auto-handled option */ + case 0: /* auto-handled option */ break; case '?': if (optopt) printf("Unrecognized option: -%c\n", optopt); else printf("Unrecognized option: %s\n", - argv[optind - 1]); + argv[optind - 1]); break; case ':': printf("Missing argument to %s\n", argv[optind - 1]); @@ -683,7 +673,7 @@ int main(int argc, char **argv) if (!k_flag || !b_flag || !d_flag || !w_flag) { printf("server-side: Flag -c is mandatory\n"); printf("cr50-side: Flags -k, -b, -d, and -w " - "are mandatory\n"); + "are mandatory\n"); return 1; } } diff --git a/extra/sps_errs/Makefile b/extra/sps_errs/Makefile index 12224ad803..b25eecbdab 100644 --- a/extra/sps_errs/Makefile +++ b/extra/sps_errs/Makefile @@ -1,4 +1,4 @@ -# Copyright 2015 The Chromium OS Authors. All rights reserved. +# Copyright 2015 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. diff --git a/extra/sps_errs/prog.c b/extra/sps_errs/prog.c index b649199068..bf44dd182c 100644 --- a/extra/sps_errs/prog.c +++ b/extra/sps_errs/prog.c @@ -1,4 +1,4 @@ -/* Copyright 2015 The Chromium OS Authors. All rights reserved. +/* Copyright 2015 The ChromiumOS Authors * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ @@ -23,7 +23,7 @@ static struct mpsse_context *mpsse; /* enum ec_status meaning */ static const char *ec_strerr(enum ec_status r) { - static const char * const strs[] = { + static const char *const strs[] = { "SUCCESS", "INVALID_COMMAND", "ERROR", @@ -48,10 +48,9 @@ static const char *ec_strerr(enum ec_status r) return "<undefined result>"; }; - -/**************************************************************************** - * Debugging output - */ + /**************************************************************************** + * Debugging output + */ #define LINELEN 16 @@ -65,8 +64,7 @@ static void showline(uint8_t *buf, int len) printf(" "); printf(" "); for (i = 0; i < len; i++) - printf("%c", - (buf[i] >= ' ' && buf[i] <= '~') ? buf[i] : '.'); + printf("%c", (buf[i] >= ' ' && buf[i] <= '~') ? buf[i] : '.'); printf("\n"); } @@ -105,8 +103,8 @@ static uint8_t txbuf[128]; * Load the output buffer with a proto v3 request (header, then data, with * checksum correct in header). */ -static size_t prepare_request(int cmd, int version, - const uint8_t *data, size_t data_len) +static size_t prepare_request(int cmd, int version, const uint8_t *data, + size_t data_len) { struct ec_host_request *request; size_t i, total_len; @@ -114,8 +112,8 @@ static size_t prepare_request(int cmd, int version, total_len = sizeof(*request) + data_len; if (total_len > sizeof(txbuf)) { - printf("Request too large (%zd > %zd)\n", - total_len, sizeof(txbuf)); + printf("Request too large (%zd > %zd)\n", total_len, + sizeof(txbuf)); return -1; } @@ -139,7 +137,6 @@ static size_t prepare_request(int cmd, int version, return total_len; } - /* Timeout flag, so we don't wait forever */ static int timedout; static void alarm_handler(int sig) @@ -151,11 +148,8 @@ static void alarm_handler(int sig) * Send command, wait for result. Return zero if communication succeeded; check * response to see if the EC liked the command. */ -static int send_cmd(int cmd, int version, - void *outbuf, - size_t outsize, - struct ec_host_response *hdr, - void *bodydest, +static int send_cmd(int cmd, int version, void *outbuf, size_t outsize, + struct ec_host_response *hdr, void *bodydest, size_t bodylen) { uint8_t *tptr, *hptr = 0, *bptr = 0; @@ -166,15 +160,13 @@ static int send_cmd(int cmd, int version, size_t bytes_left = stop_after; size_t bytes_sent = 0; - /* Load up the txbuf with the stuff to send */ len = prepare_request(cmd, version, outbuf, outsize); if (len < 0) return -1; if (MPSSE_OK != Start(mpsse)) { - fprintf(stderr, "Start failed: %s\n", - ErrorString(mpsse)); + fprintf(stderr, "Start failed: %s\n", ErrorString(mpsse)); return -1; } @@ -189,8 +181,7 @@ static int send_cmd(int cmd, int version, bytes_left -= len; bytes_sent += len; if (!tptr) { - fprintf(stderr, "Transfer failed: %s\n", - ErrorString(mpsse)); + fprintf(stderr, "Transfer failed: %s\n", ErrorString(mpsse)); goto out; } @@ -278,8 +269,7 @@ static int send_cmd(int cmd, int version, bytes_left -= len; bytes_sent += len; if (!hptr) { - fprintf(stderr, "Read failed: %s\n", - ErrorString(mpsse)); + fprintf(stderr, "Read failed: %s\n", ErrorString(mpsse)); goto out; } show("Header(%d):\n", hptr, sizeof(*hdr)); @@ -288,14 +278,12 @@ static int send_cmd(int cmd, int version, /* Check the header */ if (hdr->struct_version != EC_HOST_RESPONSE_VERSION) { printf("HEY: response version %d (should be %d)\n", - hdr->struct_version, - EC_HOST_RESPONSE_VERSION); + hdr->struct_version, EC_HOST_RESPONSE_VERSION); goto out; } if (hdr->data_len > bodylen) { - printf("HEY: response data_len %d is > %zd\n", - hdr->data_len, + printf("HEY: response data_len %d is > %zd\n", hdr->data_len, bodylen); goto out; } @@ -341,15 +329,13 @@ out: free(bptr); if (MPSSE_OK != Stop(mpsse)) { - fprintf(stderr, "Stop failed: %s\n", - ErrorString(mpsse)); + fprintf(stderr, "Stop failed: %s\n", ErrorString(mpsse)); return -1; } return 0; } - /****************************************************************************/ /** @@ -372,10 +358,7 @@ static int hello(void) p.in_data = 0xa5a5a5a5; expected = p.in_data + 0x01020304; - retval = send_cmd(EC_CMD_HELLO, 0, - &p, sizeof(p), - &resp, - &r, sizeof(r)); + retval = send_cmd(EC_CMD_HELLO, 0, &p, sizeof(p), &resp, &r, sizeof(r)); if (retval) { printf("Transmission error\n"); @@ -383,14 +366,13 @@ static int hello(void) } if (EC_RES_SUCCESS != resp.result) { - printf("EC result is %d: %s\n", - resp.result, ec_strerr(resp.result)); + printf("EC result is %d: %s\n", resp.result, + ec_strerr(resp.result)); return -1; } - printf("sent %08x, expected %08x, got %08x => %s\n", - p.in_data, expected, r.out_data, - expected == r.out_data ? "yay" : "boo"); + printf("sent %08x, expected %08x, got %08x => %s\n", p.in_data, + expected, r.out_data, expected == r.out_data ? "yay" : "boo"); return !(expected == r.out_data); } diff --git a/extra/stack_analyzer/run_tests.sh b/extra/stack_analyzer/run_tests.sh index 5662f60b8b..d5e65045c3 100755 --- a/extra/stack_analyzer/run_tests.sh +++ b/extra/stack_analyzer/run_tests.sh @@ -1,6 +1,6 @@ #!/bin/bash # -# Copyright 2017 The Chromium OS Authors. All rights reserved. +# Copyright 2017 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. diff --git a/extra/stack_analyzer/stack_analyzer.py b/extra/stack_analyzer/stack_analyzer.py index 77d16d5450..2431545c6a 100755 --- a/extra/stack_analyzer/stack_analyzer.py +++ b/extra/stack_analyzer/stack_analyzer.py @@ -1,11 +1,7 @@ #!/usr/bin/env python3 -# Copyright 2017 The Chromium OS Authors. All rights reserved. +# Copyright 2017 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes """Statically analyze stack usage of EC firmware. @@ -25,1848 +21,2090 @@ import ctypes import os import re import subprocess -import yaml +import yaml # pylint:disable=import-error -SECTION_RO = 'RO' -SECTION_RW = 'RW' +SECTION_RO = "RO" +SECTION_RW = "RW" # Default size of extra stack frame needed by exception context switch. # This value is for cortex-m with FPU enabled. DEFAULT_EXCEPTION_FRAME_SIZE = 224 class StackAnalyzerError(Exception): - """Exception class for stack analyzer utility.""" + """Exception class for stack analyzer utility.""" class TaskInfo(ctypes.Structure): - """Taskinfo ctypes structure. - - The structure definition is corresponding to the "struct taskinfo" - in "util/export_taskinfo.so.c". - """ - _fields_ = [('name', ctypes.c_char_p), - ('routine', ctypes.c_char_p), - ('stack_size', ctypes.c_uint32)] + """Taskinfo ctypes structure. + The structure definition is corresponding to the "struct taskinfo" + in "util/export_taskinfo.so.c". + """ -class Task(object): - """Task information. + _fields_ = [ + ("name", ctypes.c_char_p), + ("routine", ctypes.c_char_p), + ("stack_size", ctypes.c_uint32), + ] - Attributes: - name: Task name. - routine_name: Routine function name. - stack_max_size: Max stack size. - routine_address: Resolved routine address. None if it hasn't been resolved. - """ - def __init__(self, name, routine_name, stack_max_size, routine_address=None): - """Constructor. +class Task(object): + """Task information. - Args: + Attributes: name: Task name. routine_name: Routine function name. stack_max_size: Max stack size. - routine_address: Resolved routine address. - """ - self.name = name - self.routine_name = routine_name - self.stack_max_size = stack_max_size - self.routine_address = routine_address - - def __eq__(self, other): - """Task equality. - - Args: - other: The compared object. - - Returns: - True if equal, False if not. + routine_address: Resolved routine address. None if it hasn't been resolved. """ - if not isinstance(other, Task): - return False - return (self.name == other.name and - self.routine_name == other.routine_name and - self.stack_max_size == other.stack_max_size and - self.routine_address == other.routine_address) + def __init__( + self, name, routine_name, stack_max_size, routine_address=None + ): + """Constructor. + + Args: + name: Task name. + routine_name: Routine function name. + stack_max_size: Max stack size. + routine_address: Resolved routine address. + """ + self.name = name + self.routine_name = routine_name + self.stack_max_size = stack_max_size + self.routine_address = routine_address + + def __eq__(self, other): + """Task equality. + + Args: + other: The compared object. + + Returns: + True if equal, False if not. + """ + if not isinstance(other, Task): + return False + + return ( + self.name == other.name + and self.routine_name == other.routine_name + and self.stack_max_size == other.stack_max_size + and self.routine_address == other.routine_address + ) class Symbol(object): - """Symbol information. + """Symbol information. - Attributes: - address: Symbol address. - symtype: Symbol type, 'O' (data, object) or 'F' (function). - size: Symbol size. - name: Symbol name. - """ - - def __init__(self, address, symtype, size, name): - """Constructor. - - Args: + Attributes: address: Symbol address. - symtype: Symbol type. + symtype: Symbol type, 'O' (data, object) or 'F' (function). size: Symbol size. name: Symbol name. """ - assert symtype in ['O', 'F'] - self.address = address - self.symtype = symtype - self.size = size - self.name = name - def __eq__(self, other): - """Symbol equality. - - Args: - other: The compared object. - - Returns: - True if equal, False if not. - """ - if not isinstance(other, Symbol): - return False - - return (self.address == other.address and - self.symtype == other.symtype and - self.size == other.size and - self.name == other.name) + def __init__(self, address, symtype, size, name): + """Constructor. + + Args: + address: Symbol address. + symtype: Symbol type. + size: Symbol size. + name: Symbol name. + """ + assert symtype in ["O", "F"] + self.address = address + self.symtype = symtype + self.size = size + self.name = name + + def __eq__(self, other): + """Symbol equality. + + Args: + other: The compared object. + + Returns: + True if equal, False if not. + """ + if not isinstance(other, Symbol): + return False + + return ( + self.address == other.address + and self.symtype == other.symtype + and self.size == other.size + and self.name == other.name + ) class Callsite(object): - """Function callsite. - - Attributes: - address: Address of callsite location. None if it is unknown. - target: Callee address. None if it is unknown. - is_tail: A bool indicates that it is a tailing call. - callee: Resolved callee function. None if it hasn't been resolved. - """ + """Function callsite. - def __init__(self, address, target, is_tail, callee=None): - """Constructor. - - Args: + Attributes: address: Address of callsite location. None if it is unknown. target: Callee address. None if it is unknown. - is_tail: A bool indicates that it is a tailing call. (function jump to - another function without restoring the stack frame) - callee: Resolved callee function. + is_tail: A bool indicates that it is a tailing call. + callee: Resolved callee function. None if it hasn't been resolved. """ - # It makes no sense that both address and target are unknown. - assert not (address is None and target is None) - self.address = address - self.target = target - self.is_tail = is_tail - self.callee = callee - - def __eq__(self, other): - """Callsite equality. - Args: - other: The compared object. - - Returns: - True if equal, False if not. - """ - if not isinstance(other, Callsite): - return False - - if not (self.address == other.address and - self.target == other.target and - self.is_tail == other.is_tail): - return False - - if self.callee is None: - return other.callee is None - elif other.callee is None: - return False - - # Assume the addresses of functions are unique. - return self.callee.address == other.callee.address + def __init__(self, address, target, is_tail, callee=None): + """Constructor. + + Args: + address: Address of callsite location. None if it is unknown. + target: Callee address. None if it is unknown. + is_tail: A bool indicates that it is a tailing call. (function jump to + another function without restoring the stack frame) + callee: Resolved callee function. + """ + # It makes no sense that both address and target are unknown. + assert not (address is None and target is None) + self.address = address + self.target = target + self.is_tail = is_tail + self.callee = callee + + def __eq__(self, other): + """Callsite equality. + + Args: + other: The compared object. + + Returns: + True if equal, False if not. + """ + if not isinstance(other, Callsite): + return False + + if not ( + self.address == other.address + and self.target == other.target + and self.is_tail == other.is_tail + ): + return False + + if self.callee is None: + return other.callee is None + elif other.callee is None: + return False + + # Assume the addresses of functions are unique. + return self.callee.address == other.callee.address class Function(object): - """Function. + """Function. - Attributes: - address: Address of function. - name: Name of function from its symbol. - stack_frame: Size of stack frame. - callsites: Callsite list. - stack_max_usage: Max stack usage. None if it hasn't been analyzed. - stack_max_path: Max stack usage path. None if it hasn't been analyzed. - """ - - def __init__(self, address, name, stack_frame, callsites): - """Constructor. - - Args: + Attributes: address: Address of function. name: Name of function from its symbol. stack_frame: Size of stack frame. callsites: Callsite list. + stack_max_usage: Max stack usage. None if it hasn't been analyzed. + stack_max_path: Max stack usage path. None if it hasn't been analyzed. """ - self.address = address - self.name = name - self.stack_frame = stack_frame - self.callsites = callsites - self.stack_max_usage = None - self.stack_max_path = None - - def __eq__(self, other): - """Function equality. - - Args: - other: The compared object. - - Returns: - True if equal, False if not. - """ - if not isinstance(other, Function): - return False - if not (self.address == other.address and - self.name == other.name and - self.stack_frame == other.stack_frame and - self.callsites == other.callsites and - self.stack_max_usage == other.stack_max_usage): - return False + def __init__(self, address, name, stack_frame, callsites): + """Constructor. + + Args: + address: Address of function. + name: Name of function from its symbol. + stack_frame: Size of stack frame. + callsites: Callsite list. + """ + self.address = address + self.name = name + self.stack_frame = stack_frame + self.callsites = callsites + self.stack_max_usage = None + self.stack_max_path = None + + def __eq__(self, other): + """Function equality. + + Args: + other: The compared object. + + Returns: + True if equal, False if not. + """ + if not isinstance(other, Function): + return False + + if not ( + self.address == other.address + and self.name == other.name + and self.stack_frame == other.stack_frame + and self.callsites == other.callsites + and self.stack_max_usage == other.stack_max_usage + ): + return False + + if self.stack_max_path is None: + return other.stack_max_path is None + elif other.stack_max_path is None: + return False + + if len(self.stack_max_path) != len(other.stack_max_path): + return False + + for self_func, other_func in zip( + self.stack_max_path, other.stack_max_path + ): + # Assume the addresses of functions are unique. + if self_func.address != other_func.address: + return False + + return True + + def __hash__(self): + return id(self) - if self.stack_max_path is None: - return other.stack_max_path is None - elif other.stack_max_path is None: - return False - - if len(self.stack_max_path) != len(other.stack_max_path): - return False - - for self_func, other_func in zip(self.stack_max_path, other.stack_max_path): - # Assume the addresses of functions are unique. - if self_func.address != other_func.address: - return False - - return True - - def __hash__(self): - return id(self) class AndesAnalyzer(object): - """Disassembly analyzer for Andes architecture. - - Public Methods: - AnalyzeFunction: Analyze stack frame and callsites of the function. - """ - - GENERAL_PURPOSE_REGISTER_SIZE = 4 - - # Possible condition code suffixes. - CONDITION_CODES = [ 'eq', 'eqz', 'gez', 'gtz', 'lez', 'ltz', 'ne', 'nez', - 'eqc', 'nec', 'nezs', 'nes', 'eqs'] - CONDITION_CODES_RE = '({})'.format('|'.join(CONDITION_CODES)) - - IMM_ADDRESS_RE = r'([0-9A-Fa-f]+)\s+<([^>]+)>' - # Branch instructions. - JUMP_OPCODE_RE = re.compile(r'^(b{0}|j|jr|jr.|jrnez)(\d?|\d\d)$' \ - .format(CONDITION_CODES_RE)) - # Call instructions. - CALL_OPCODE_RE = re.compile \ - (r'^(jal|jral|jral.|jralnez|beqzal|bltzal|bgezal)(\d)?$') - CALL_OPERAND_RE = re.compile(r'^{}$'.format(IMM_ADDRESS_RE)) - # Ignore lp register because it's for return. - INDIRECT_CALL_OPERAND_RE = re.compile \ - (r'^\$r\d{1,}$|\$fp$|\$gp$|\$ta$|\$sp$|\$pc$') - # TODO: Handle other kinds of store instructions. - PUSH_OPCODE_RE = re.compile(r'^push(\d{1,})$') - PUSH_OPERAND_RE = re.compile(r'^\$r\d{1,}, \#\d{1,} \! \{([^\]]+)\}') - SMW_OPCODE_RE = re.compile(r'^smw(\.\w\w|\.\w\w\w)$') - SMW_OPERAND_RE = re.compile(r'^(\$r\d{1,}|\$\wp), \[\$\wp\], ' - r'(\$r\d{1,}|\$\wp), \#\d\w\d \! \{([^\]]+)\}') - OPERANDGROUP_RE = re.compile(r'^\$r\d{1,}\~\$r\d{1,}') - - LWI_OPCODE_RE = re.compile(r'^lwi(\.\w\w)$') - LWI_PC_OPERAND_RE = re.compile(r'^\$pc, \[([^\]]+)\]') - # Example: "34280: 3f c8 0f ec addi.gp $fp, #0xfec" - # Assume there is always a "\t" after the hex data. - DISASM_REGEX_RE = re.compile(r'^(?P<address>[0-9A-Fa-f]+):\s+' - r'(?P<words>[0-9A-Fa-f ]+)' - r'\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?') - - def ParseInstruction(self, line, function_end): - """Parse the line of instruction. + """Disassembly analyzer for Andes architecture. - Args: - line: Text of disassembly. - function_end: End address of the current function. None if unknown. - - Returns: - (address, words, opcode, operand_text): The instruction address, words, - opcode, and the text of operands. - None if it isn't an instruction line. + Public Methods: + AnalyzeFunction: Analyze stack frame and callsites of the function. """ - result = self.DISASM_REGEX_RE.match(line) - if result is None: - return None - - address = int(result.group('address'), 16) - # Check if it's out of bound. - if function_end is not None and address >= function_end: - return None - - opcode = result.group('opcode').strip() - operand_text = result.group('operand') - words = result.group('words') - if operand_text is None: - operand_text = '' - else: - operand_text = operand_text.strip() - - return (address, words, opcode, operand_text) - - def AnalyzeFunction(self, function_symbol, instructions): - - stack_frame = 0 - callsites = [] - for address, words, opcode, operand_text in instructions: - is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None - is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None - - if is_jump_opcode or is_call_opcode: - is_tail = is_jump_opcode - - result = self.CALL_OPERAND_RE.match(operand_text) + GENERAL_PURPOSE_REGISTER_SIZE = 4 + + # Possible condition code suffixes. + CONDITION_CODES = [ + "eq", + "eqz", + "gez", + "gtz", + "lez", + "ltz", + "ne", + "nez", + "eqc", + "nec", + "nezs", + "nes", + "eqs", + ] + CONDITION_CODES_RE = "({})".format("|".join(CONDITION_CODES)) + + IMM_ADDRESS_RE = r"([0-9A-Fa-f]+)\s+<([^>]+)>" + # Branch instructions. + JUMP_OPCODE_RE = re.compile( + r"^(b{0}|j|jr|jr.|jrnez)(\d?|\d\d)$".format(CONDITION_CODES_RE) + ) + # Call instructions. + CALL_OPCODE_RE = re.compile( + r"^(jal|jral|jral.|jralnez|beqzal|bltzal|bgezal)(\d)?$" + ) + CALL_OPERAND_RE = re.compile(r"^{}$".format(IMM_ADDRESS_RE)) + # Ignore lp register because it's for return. + INDIRECT_CALL_OPERAND_RE = re.compile( + r"^\$r\d{1,}$|\$fp$|\$gp$|\$ta$|\$sp$|\$pc$" + ) + # TODO: Handle other kinds of store instructions. + PUSH_OPCODE_RE = re.compile(r"^push(\d{1,})$") + PUSH_OPERAND_RE = re.compile(r"^\$r\d{1,}, \#\d{1,} \! \{([^\]]+)\}") + SMW_OPCODE_RE = re.compile(r"^smw(\.\w\w|\.\w\w\w)$") + SMW_OPERAND_RE = re.compile( + r"^(\$r\d{1,}|\$\wp), \[\$\wp\], " + r"(\$r\d{1,}|\$\wp), \#\d\w\d \! \{([^\]]+)\}" + ) + OPERANDGROUP_RE = re.compile(r"^\$r\d{1,}\~\$r\d{1,}") + + LWI_OPCODE_RE = re.compile(r"^lwi(\.\w\w)$") + LWI_PC_OPERAND_RE = re.compile(r"^\$pc, \[([^\]]+)\]") + # Example: "34280: 3f c8 0f ec addi.gp $fp, #0xfec" + # Assume there is always a "\t" after the hex data. + DISASM_REGEX_RE = re.compile( + r"^(?P<address>[0-9A-Fa-f]+):\s+" + r"(?P<words>[0-9A-Fa-f ]+)" + r"\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?" + ) + + def ParseInstruction(self, line, function_end): + """Parse the line of instruction. + + Args: + line: Text of disassembly. + function_end: End address of the current function. None if unknown. + + Returns: + (address, words, opcode, operand_text): The instruction address, words, + opcode, and the text of operands. + None if it isn't an instruction line. + """ + result = self.DISASM_REGEX_RE.match(line) if result is None: - if (self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None): - # Found an indirect call. - callsites.append(Callsite(address, None, is_tail)) - + return None + + address = int(result.group("address"), 16) + # Check if it's out of bound. + if function_end is not None and address >= function_end: + return None + + opcode = result.group("opcode").strip() + operand_text = result.group("operand") + words = result.group("words") + if operand_text is None: + operand_text = "" else: - target_address = int(result.group(1), 16) - # Filter out the in-function target (branches and in-function calls, - # which are actually branches). - if not (function_symbol.size > 0 and - function_symbol.address < target_address < - (function_symbol.address + function_symbol.size)): - # Maybe it is a callsite. - callsites.append(Callsite(address, target_address, is_tail)) - - elif self.LWI_OPCODE_RE.match(opcode) is not None: - result = self.LWI_PC_OPERAND_RE.match(operand_text) - if result is not None: - # Ignore "lwi $pc, [$sp], xx" because it's usually a return. - if result.group(1) != '$sp': - # Found an indirect call. - callsites.append(Callsite(address, None, True)) - - elif self.PUSH_OPCODE_RE.match(opcode) is not None: - # Example: fc 20 push25 $r8, #0 ! {$r6~$r8, $fp, $gp, $lp} - if self.PUSH_OPERAND_RE.match(operand_text) is not None: - # capture fc 20 - imm5u = int(words.split(' ')[1], 16) - # sp = sp - (imm5u << 3) - imm8u = (imm5u<<3) & 0xff - stack_frame += imm8u - - result = self.PUSH_OPERAND_RE.match(operand_text) - operandgroup_text = result.group(1) - # capture $rx~$ry - if self.OPERANDGROUP_RE.match(operandgroup_text) is not None: - # capture number & transfer string to integer - oprandgrouphead = operandgroup_text.split(',')[0] - rx=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[0]))) - ry=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[1]))) - - stack_frame += ((len(operandgroup_text.split(','))+ry-rx) * - self.GENERAL_PURPOSE_REGISTER_SIZE) - else: - stack_frame += (len(operandgroup_text.split(',')) * - self.GENERAL_PURPOSE_REGISTER_SIZE) - - elif self.SMW_OPCODE_RE.match(opcode) is not None: - # Example: smw.adm $r6, [$sp], $r10, #0x2 ! {$r6~$r10, $lp} - if self.SMW_OPERAND_RE.match(operand_text) is not None: - result = self.SMW_OPERAND_RE.match(operand_text) - operandgroup_text = result.group(3) - # capture $rx~$ry - if self.OPERANDGROUP_RE.match(operandgroup_text) is not None: - # capture number & transfer string to integer - oprandgrouphead = operandgroup_text.split(',')[0] - rx=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[0]))) - ry=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[1]))) - - stack_frame += ((len(operandgroup_text.split(','))+ry-rx) * - self.GENERAL_PURPOSE_REGISTER_SIZE) - else: - stack_frame += (len(operandgroup_text.split(',')) * - self.GENERAL_PURPOSE_REGISTER_SIZE) - - return (stack_frame, callsites) + operand_text = operand_text.strip() + + return (address, words, opcode, operand_text) + + def AnalyzeFunction(self, function_symbol, instructions): + + stack_frame = 0 + callsites = [] + for address, words, opcode, operand_text in instructions: + is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None + is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None + + if is_jump_opcode or is_call_opcode: + is_tail = is_jump_opcode + + result = self.CALL_OPERAND_RE.match(operand_text) + + if result is None: + if ( + self.INDIRECT_CALL_OPERAND_RE.match(operand_text) + is not None + ): + # Found an indirect call. + callsites.append(Callsite(address, None, is_tail)) + + else: + target_address = int(result.group(1), 16) + # Filter out the in-function target (branches and in-function calls, + # which are actually branches). + if not ( + function_symbol.size > 0 + and function_symbol.address + < target_address + < (function_symbol.address + function_symbol.size) + ): + # Maybe it is a callsite. + callsites.append( + Callsite(address, target_address, is_tail) + ) + + elif self.LWI_OPCODE_RE.match(opcode) is not None: + result = self.LWI_PC_OPERAND_RE.match(operand_text) + if result is not None: + # Ignore "lwi $pc, [$sp], xx" because it's usually a return. + if result.group(1) != "$sp": + # Found an indirect call. + callsites.append(Callsite(address, None, True)) + + elif self.PUSH_OPCODE_RE.match(opcode) is not None: + # Example: fc 20 push25 $r8, #0 ! {$r6~$r8, $fp, $gp, $lp} + if self.PUSH_OPERAND_RE.match(operand_text) is not None: + # capture fc 20 + imm5u = int(words.split(" ")[1], 16) + # sp = sp - (imm5u << 3) + imm8u = (imm5u << 3) & 0xFF + stack_frame += imm8u + + result = self.PUSH_OPERAND_RE.match(operand_text) + operandgroup_text = result.group(1) + # capture $rx~$ry + if ( + self.OPERANDGROUP_RE.match(operandgroup_text) + is not None + ): + # capture number & transfer string to integer + oprandgrouphead = operandgroup_text.split(",")[0] + rx = int( + "".join( + filter( + str.isdigit, oprandgrouphead.split("~")[0] + ) + ) + ) + ry = int( + "".join( + filter( + str.isdigit, oprandgrouphead.split("~")[1] + ) + ) + ) + + stack_frame += ( + len(operandgroup_text.split(",")) + ry - rx + ) * self.GENERAL_PURPOSE_REGISTER_SIZE + else: + stack_frame += ( + len(operandgroup_text.split(",")) + * self.GENERAL_PURPOSE_REGISTER_SIZE + ) + + elif self.SMW_OPCODE_RE.match(opcode) is not None: + # Example: smw.adm $r6, [$sp], $r10, #0x2 ! {$r6~$r10, $lp} + if self.SMW_OPERAND_RE.match(operand_text) is not None: + result = self.SMW_OPERAND_RE.match(operand_text) + operandgroup_text = result.group(3) + # capture $rx~$ry + if ( + self.OPERANDGROUP_RE.match(operandgroup_text) + is not None + ): + # capture number & transfer string to integer + oprandgrouphead = operandgroup_text.split(",")[0] + rx = int( + "".join( + filter( + str.isdigit, oprandgrouphead.split("~")[0] + ) + ) + ) + ry = int( + "".join( + filter( + str.isdigit, oprandgrouphead.split("~")[1] + ) + ) + ) + + stack_frame += ( + len(operandgroup_text.split(",")) + ry - rx + ) * self.GENERAL_PURPOSE_REGISTER_SIZE + else: + stack_frame += ( + len(operandgroup_text.split(",")) + * self.GENERAL_PURPOSE_REGISTER_SIZE + ) + + return (stack_frame, callsites) -class ArmAnalyzer(object): - """Disassembly analyzer for ARM architecture. - - Public Methods: - AnalyzeFunction: Analyze stack frame and callsites of the function. - """ - - GENERAL_PURPOSE_REGISTER_SIZE = 4 - - # Possible condition code suffixes. - CONDITION_CODES = ['', 'eq', 'ne', 'cs', 'hs', 'cc', 'lo', 'mi', 'pl', 'vs', - 'vc', 'hi', 'ls', 'ge', 'lt', 'gt', 'le'] - CONDITION_CODES_RE = '({})'.format('|'.join(CONDITION_CODES)) - # Assume there is no function name containing ">". - IMM_ADDRESS_RE = r'([0-9A-Fa-f]+)\s+<([^>]+)>' - - # Fuzzy regular expressions for instruction and operand parsing. - # Branch instructions. - JUMP_OPCODE_RE = re.compile( - r'^(b{0}|bx{0})(\.\w)?$'.format(CONDITION_CODES_RE)) - # Call instructions. - CALL_OPCODE_RE = re.compile( - r'^(bl{0}|blx{0})(\.\w)?$'.format(CONDITION_CODES_RE)) - CALL_OPERAND_RE = re.compile(r'^{}$'.format(IMM_ADDRESS_RE)) - CBZ_CBNZ_OPCODE_RE = re.compile(r'^(cbz|cbnz)(\.\w)?$') - # Example: "r0, 1009bcbe <host_cmd_motion_sense+0x1d2>" - CBZ_CBNZ_OPERAND_RE = re.compile(r'^[^,]+,\s+{}$'.format(IMM_ADDRESS_RE)) - # Ignore lr register because it's for return. - INDIRECT_CALL_OPERAND_RE = re.compile(r'^r\d+|sb|sl|fp|ip|sp|pc$') - # TODO(cheyuw): Handle conditional versions of following - # instructions. - # TODO(cheyuw): Handle other kinds of pc modifying instructions (e.g. mov pc). - LDR_OPCODE_RE = re.compile(r'^ldr(\.\w)?$') - # Example: "pc, [sp], #4" - LDR_PC_OPERAND_RE = re.compile(r'^pc, \[([^\]]+)\]') - # TODO(cheyuw): Handle other kinds of stm instructions. - PUSH_OPCODE_RE = re.compile(r'^push$') - STM_OPCODE_RE = re.compile(r'^stmdb$') - # Stack subtraction instructions. - SUB_OPCODE_RE = re.compile(r'^sub(s|w)?(\.\w)?$') - SUB_OPERAND_RE = re.compile(r'^sp[^#]+#(\d+)') - # Example: "44d94: f893 0068 ldrb.w r0, [r3, #104] ; 0x68" - # Assume there is always a "\t" after the hex data. - DISASM_REGEX_RE = re.compile(r'^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+' - r'\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?') - - def ParseInstruction(self, line, function_end): - """Parse the line of instruction. - Args: - line: Text of disassembly. - function_end: End address of the current function. None if unknown. +class ArmAnalyzer(object): + """Disassembly analyzer for ARM architecture. - Returns: - (address, opcode, operand_text): The instruction address, opcode, - and the text of operands. None if it - isn't an instruction line. + Public Methods: + AnalyzeFunction: Analyze stack frame and callsites of the function. """ - result = self.DISASM_REGEX_RE.match(line) - if result is None: - return None - - address = int(result.group('address'), 16) - # Check if it's out of bound. - if function_end is not None and address >= function_end: - return None - - opcode = result.group('opcode').strip() - operand_text = result.group('operand') - if operand_text is None: - operand_text = '' - else: - operand_text = operand_text.strip() - - return (address, opcode, operand_text) - - def AnalyzeFunction(self, function_symbol, instructions): - """Analyze function, resolve the size of stack frame and callsites. - - Args: - function_symbol: Function symbol. - instructions: Instruction list. - - Returns: - (stack_frame, callsites): Size of stack frame, callsite list. - """ - stack_frame = 0 - callsites = [] - for address, opcode, operand_text in instructions: - is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None - is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None - is_cbz_cbnz_opcode = self.CBZ_CBNZ_OPCODE_RE.match(opcode) is not None - if is_jump_opcode or is_call_opcode or is_cbz_cbnz_opcode: - is_tail = is_jump_opcode or is_cbz_cbnz_opcode - - if is_cbz_cbnz_opcode: - result = self.CBZ_CBNZ_OPERAND_RE.match(operand_text) - else: - result = self.CALL_OPERAND_RE.match(operand_text) + GENERAL_PURPOSE_REGISTER_SIZE = 4 + + # Possible condition code suffixes. + CONDITION_CODES = [ + "", + "eq", + "ne", + "cs", + "hs", + "cc", + "lo", + "mi", + "pl", + "vs", + "vc", + "hi", + "ls", + "ge", + "lt", + "gt", + "le", + ] + CONDITION_CODES_RE = "({})".format("|".join(CONDITION_CODES)) + # Assume there is no function name containing ">". + IMM_ADDRESS_RE = r"([0-9A-Fa-f]+)\s+<([^>]+)>" + + # Fuzzy regular expressions for instruction and operand parsing. + # Branch instructions. + JUMP_OPCODE_RE = re.compile( + r"^(b{0}|bx{0})(\.\w)?$".format(CONDITION_CODES_RE) + ) + # Call instructions. + CALL_OPCODE_RE = re.compile( + r"^(bl{0}|blx{0})(\.\w)?$".format(CONDITION_CODES_RE) + ) + CALL_OPERAND_RE = re.compile(r"^{}$".format(IMM_ADDRESS_RE)) + CBZ_CBNZ_OPCODE_RE = re.compile(r"^(cbz|cbnz)(\.\w)?$") + # Example: "r0, 1009bcbe <host_cmd_motion_sense+0x1d2>" + CBZ_CBNZ_OPERAND_RE = re.compile(r"^[^,]+,\s+{}$".format(IMM_ADDRESS_RE)) + # Ignore lr register because it's for return. + INDIRECT_CALL_OPERAND_RE = re.compile(r"^r\d+|sb|sl|fp|ip|sp|pc$") + # TODO(cheyuw): Handle conditional versions of following + # instructions. + # TODO(cheyuw): Handle other kinds of pc modifying instructions (e.g. mov pc). + LDR_OPCODE_RE = re.compile(r"^ldr(\.\w)?$") + # Example: "pc, [sp], #4" + LDR_PC_OPERAND_RE = re.compile(r"^pc, \[([^\]]+)\]") + # TODO(cheyuw): Handle other kinds of stm instructions. + PUSH_OPCODE_RE = re.compile(r"^push$") + STM_OPCODE_RE = re.compile(r"^stmdb$") + # Stack subtraction instructions. + SUB_OPCODE_RE = re.compile(r"^sub(s|w)?(\.\w)?$") + SUB_OPERAND_RE = re.compile(r"^sp[^#]+#(\d+)") + # Example: "44d94: f893 0068 ldrb.w r0, [r3, #104] ; 0x68" + # Assume there is always a "\t" after the hex data. + DISASM_REGEX_RE = re.compile( + r"^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+" + r"\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?" + ) + + def ParseInstruction(self, line, function_end): + """Parse the line of instruction. + + Args: + line: Text of disassembly. + function_end: End address of the current function. None if unknown. + + Returns: + (address, opcode, operand_text): The instruction address, opcode, + and the text of operands. None if it + isn't an instruction line. + """ + result = self.DISASM_REGEX_RE.match(line) if result is None: - # Failed to match immediate address, maybe it is an indirect call. - # CBZ and CBNZ can't be indirect calls. - if (not is_cbz_cbnz_opcode and - self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None): - # Found an indirect call. - callsites.append(Callsite(address, None, is_tail)) + return None - else: - target_address = int(result.group(1), 16) - # Filter out the in-function target (branches and in-function calls, - # which are actually branches). - if not (function_symbol.size > 0 and - function_symbol.address < target_address < - (function_symbol.address + function_symbol.size)): - # Maybe it is a callsite. - callsites.append(Callsite(address, target_address, is_tail)) - - elif self.LDR_OPCODE_RE.match(opcode) is not None: - result = self.LDR_PC_OPERAND_RE.match(operand_text) - if result is not None: - # Ignore "ldr pc, [sp], xx" because it's usually a return. - if result.group(1) != 'sp': - # Found an indirect call. - callsites.append(Callsite(address, None, True)) - - elif self.PUSH_OPCODE_RE.match(opcode) is not None: - # Example: "{r4, r5, r6, r7, lr}" - stack_frame += (len(operand_text.split(',')) * - self.GENERAL_PURPOSE_REGISTER_SIZE) - elif self.SUB_OPCODE_RE.match(opcode) is not None: - result = self.SUB_OPERAND_RE.match(operand_text) - if result is not None: - stack_frame += int(result.group(1)) - else: - # Unhandled stack register subtraction. - assert not operand_text.startswith('sp') + address = int(result.group("address"), 16) + # Check if it's out of bound. + if function_end is not None and address >= function_end: + return None - elif self.STM_OPCODE_RE.match(opcode) is not None: - if operand_text.startswith('sp!'): - # Subtract and writeback to stack register. - # Example: "sp!, {r4, r5, r6, r7, r8, r9, lr}" - # Get the text of pushed register list. - unused_sp, unused_sep, parameter_text = operand_text.partition(',') - stack_frame += (len(parameter_text.split(',')) * - self.GENERAL_PURPOSE_REGISTER_SIZE) + opcode = result.group("opcode").strip() + operand_text = result.group("operand") + if operand_text is None: + operand_text = "" + else: + operand_text = operand_text.strip() + + return (address, opcode, operand_text) + + def AnalyzeFunction(self, function_symbol, instructions): + """Analyze function, resolve the size of stack frame and callsites. + + Args: + function_symbol: Function symbol. + instructions: Instruction list. + + Returns: + (stack_frame, callsites): Size of stack frame, callsite list. + """ + stack_frame = 0 + callsites = [] + for address, opcode, operand_text in instructions: + is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None + is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None + is_cbz_cbnz_opcode = ( + self.CBZ_CBNZ_OPCODE_RE.match(opcode) is not None + ) + if is_jump_opcode or is_call_opcode or is_cbz_cbnz_opcode: + is_tail = is_jump_opcode or is_cbz_cbnz_opcode + + if is_cbz_cbnz_opcode: + result = self.CBZ_CBNZ_OPERAND_RE.match(operand_text) + else: + result = self.CALL_OPERAND_RE.match(operand_text) + + if result is None: + # Failed to match immediate address, maybe it is an indirect call. + # CBZ and CBNZ can't be indirect calls. + if ( + not is_cbz_cbnz_opcode + and self.INDIRECT_CALL_OPERAND_RE.match(operand_text) + is not None + ): + # Found an indirect call. + callsites.append(Callsite(address, None, is_tail)) + + else: + target_address = int(result.group(1), 16) + # Filter out the in-function target (branches and in-function calls, + # which are actually branches). + if not ( + function_symbol.size > 0 + and function_symbol.address + < target_address + < (function_symbol.address + function_symbol.size) + ): + # Maybe it is a callsite. + callsites.append( + Callsite(address, target_address, is_tail) + ) + + elif self.LDR_OPCODE_RE.match(opcode) is not None: + result = self.LDR_PC_OPERAND_RE.match(operand_text) + if result is not None: + # Ignore "ldr pc, [sp], xx" because it's usually a return. + if result.group(1) != "sp": + # Found an indirect call. + callsites.append(Callsite(address, None, True)) + + elif self.PUSH_OPCODE_RE.match(opcode) is not None: + # Example: "{r4, r5, r6, r7, lr}" + stack_frame += ( + len(operand_text.split(",")) + * self.GENERAL_PURPOSE_REGISTER_SIZE + ) + elif self.SUB_OPCODE_RE.match(opcode) is not None: + result = self.SUB_OPERAND_RE.match(operand_text) + if result is not None: + stack_frame += int(result.group(1)) + else: + # Unhandled stack register subtraction. + assert not operand_text.startswith("sp") + + elif self.STM_OPCODE_RE.match(opcode) is not None: + if operand_text.startswith("sp!"): + # Subtract and writeback to stack register. + # Example: "sp!, {r4, r5, r6, r7, r8, r9, lr}" + # Get the text of pushed register list. + ( + unused_sp, + unused_sep, + parameter_text, + ) = operand_text.partition(",") + stack_frame += ( + len(parameter_text.split(",")) + * self.GENERAL_PURPOSE_REGISTER_SIZE + ) + + return (stack_frame, callsites) - return (stack_frame, callsites) class RiscvAnalyzer(object): - """Disassembly analyzer for RISC-V architecture. - - Public Methods: - AnalyzeFunction: Analyze stack frame and callsites of the function. - """ - - # Possible condition code suffixes. - CONDITION_CODES = [ 'eqz', 'nez', 'lez', 'gez', 'ltz', 'gtz', 'gt', 'le', - 'gtu', 'leu', 'eq', 'ne', 'ge', 'lt', 'ltu', 'geu'] - CONDITION_CODES_RE = '({})'.format('|'.join(CONDITION_CODES)) - # Branch instructions. - JUMP_OPCODE_RE = re.compile(r'^(b{0}|j|jr)$'.format(CONDITION_CODES_RE)) - # Call instructions. - CALL_OPCODE_RE = re.compile(r'^(jal|jalr)$') - # Example: "j 8009b318 <set_state_prl_hr>" or - # "jal ra,800a4394 <power_get_signals>" or - # "bltu t0,t1,80080300 <data_loop>" - JUMP_ADDRESS_RE = r'((\w(\w|\d\d),){0,2})([0-9A-Fa-f]+)\s+<([^>]+)>' - CALL_OPERAND_RE = re.compile(r'^{}$'.format(JUMP_ADDRESS_RE)) - # Capture address, Example: 800a4394 - CAPTURE_ADDRESS = re.compile(r'[0-9A-Fa-f]{8}') - # Indirect jump, Example: jalr a5 - INDIRECT_CALL_OPERAND_RE = re.compile(r'^t\d+|s\d+|a\d+$') - # Example: addi - ADDI_OPCODE_RE = re.compile(r'^addi$') - # Allocate stack instructions. - ADDI_OPERAND_RE = re.compile(r'^(sp,sp,-\d+)$') - # Example: "800804b6: 1101 addi sp,sp,-32" - DISASM_REGEX_RE = re.compile(r'^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+' - r'\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?') - - def ParseInstruction(self, line, function_end): - """Parse the line of instruction. + """Disassembly analyzer for RISC-V architecture. - Args: - line: Text of disassembly. - function_end: End address of the current function. None if unknown. - - Returns: - (address, opcode, operand_text): The instruction address, opcode, - and the text of operands. None if it - isn't an instruction line. + Public Methods: + AnalyzeFunction: Analyze stack frame and callsites of the function. """ - result = self.DISASM_REGEX_RE.match(line) - if result is None: - return None - - address = int(result.group('address'), 16) - # Check if it's out of bound. - if function_end is not None and address >= function_end: - return None - - opcode = result.group('opcode').strip() - operand_text = result.group('operand') - if operand_text is None: - operand_text = '' - else: - operand_text = operand_text.strip() - - return (address, opcode, operand_text) - - def AnalyzeFunction(self, function_symbol, instructions): - - stack_frame = 0 - callsites = [] - for address, opcode, operand_text in instructions: - is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None - is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None - if is_jump_opcode or is_call_opcode: - is_tail = is_jump_opcode - - result = self.CALL_OPERAND_RE.match(operand_text) + # Possible condition code suffixes. + CONDITION_CODES = [ + "eqz", + "nez", + "lez", + "gez", + "ltz", + "gtz", + "gt", + "le", + "gtu", + "leu", + "eq", + "ne", + "ge", + "lt", + "ltu", + "geu", + ] + CONDITION_CODES_RE = "({})".format("|".join(CONDITION_CODES)) + # Branch instructions. + JUMP_OPCODE_RE = re.compile(r"^(b{0}|j|jr)$".format(CONDITION_CODES_RE)) + # Call instructions. + CALL_OPCODE_RE = re.compile(r"^(jal|jalr)$") + # Example: "j 8009b318 <set_state_prl_hr>" or + # "jal ra,800a4394 <power_get_signals>" or + # "bltu t0,t1,80080300 <data_loop>" + JUMP_ADDRESS_RE = r"((\w(\w|\d\d),){0,2})([0-9A-Fa-f]+)\s+<([^>]+)>" + CALL_OPERAND_RE = re.compile(r"^{}$".format(JUMP_ADDRESS_RE)) + # Capture address, Example: 800a4394 + CAPTURE_ADDRESS = re.compile(r"[0-9A-Fa-f]{8}") + # Indirect jump, Example: jalr a5 + INDIRECT_CALL_OPERAND_RE = re.compile(r"^t\d+|s\d+|a\d+$") + # Example: addi + ADDI_OPCODE_RE = re.compile(r"^addi$") + # Allocate stack instructions. + ADDI_OPERAND_RE = re.compile(r"^(sp,sp,-\d+)$") + # Example: "800804b6: 1101 addi sp,sp,-32" + DISASM_REGEX_RE = re.compile( + r"^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+" + r"\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?" + ) + + def ParseInstruction(self, line, function_end): + """Parse the line of instruction. + + Args: + line: Text of disassembly. + function_end: End address of the current function. None if unknown. + + Returns: + (address, opcode, operand_text): The instruction address, opcode, + and the text of operands. None if it + isn't an instruction line. + """ + result = self.DISASM_REGEX_RE.match(line) if result is None: - if (self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None): - # Found an indirect call. - callsites.append(Callsite(address, None, is_tail)) - - else: - # Capture address form operand_text and then convert to string - address_str = "".join(self.CAPTURE_ADDRESS.findall(operand_text)) - # String to integer - target_address = int(address_str, 16) - # Filter out the in-function target (branches and in-function calls, - # which are actually branches). - if not (function_symbol.size > 0 and - function_symbol.address < target_address < - (function_symbol.address + function_symbol.size)): - # Maybe it is a callsite. - callsites.append(Callsite(address, target_address, is_tail)) - - elif self.ADDI_OPCODE_RE.match(opcode) is not None: - # Example: sp,sp,-32 - if self.ADDI_OPERAND_RE.match(operand_text) is not None: - stack_frame += abs(int(operand_text.split(",")[2])) - - return (stack_frame, callsites) - -class StackAnalyzer(object): - """Class to analyze stack usage. - - Public Methods: - Analyze: Run the stack analysis. - """ - - C_FUNCTION_NAME = r'_A-Za-z0-9' - - # Assume there is no ":" in the path. - # Example: "driver/accel_kionix.c:321 (discriminator 3)" - ADDRTOLINE_RE = re.compile( - r'^(?P<path>[^:]+):(?P<linenum>\d+)(\s+\(discriminator\s+\d+\))?$') - # To eliminate the suffix appended by compilers, try to extract the - # C function name from the prefix of symbol name. - # Example: "SHA256_transform.constprop.28" - FUNCTION_PREFIX_NAME_RE = re.compile( - r'^(?P<name>[{0}]+)([^{0}].*)?$'.format(C_FUNCTION_NAME)) - - # Errors of annotation resolving. - ANNOTATION_ERROR_INVALID = 'invalid signature' - ANNOTATION_ERROR_NOTFOUND = 'function is not found' - ANNOTATION_ERROR_AMBIGUOUS = 'signature is ambiguous' - - def __init__(self, options, symbols, rodata, tasklist, annotation): - """Constructor. - - Args: - options: Namespace from argparse.parse_args(). - symbols: Symbol list. - rodata: Content of .rodata section (offset, data) - tasklist: Task list. - annotation: Annotation config. - """ - self.options = options - self.symbols = symbols - self.rodata_offset = rodata[0] - self.rodata = rodata[1] - self.tasklist = tasklist - self.annotation = annotation - self.address_to_line_cache = {} - - def AddressToLine(self, address, resolve_inline=False): - """Convert address to line. + return None - Args: - address: Target address. - resolve_inline: Output the stack of inlining. - - Returns: - lines: List of the corresponding lines. - - Raises: - StackAnalyzerError: If addr2line is failed. - """ - cache_key = (address, resolve_inline) - if cache_key in self.address_to_line_cache: - return self.address_to_line_cache[cache_key] - - try: - args = [self.options.addr2line, - '-f', - '-e', - self.options.elf_path, - '{:x}'.format(address)] - if resolve_inline: - args.append('-i') - - line_text = subprocess.check_output(args, encoding='utf-8') - except subprocess.CalledProcessError: - raise StackAnalyzerError('addr2line failed to resolve lines.') - except OSError: - raise StackAnalyzerError('Failed to run addr2line.') - - lines = [line.strip() for line in line_text.splitlines()] - # Assume the output has at least one pair like "function\nlocation\n", and - # they always show up in pairs. - # Example: "handle_request\n - # common/usb_pd_protocol.c:1191\n" - assert len(lines) >= 2 and len(lines) % 2 == 0 - - line_infos = [] - for index in range(0, len(lines), 2): - (function_name, line_text) = lines[index:index + 2] - if line_text in ['??:0', ':?']: - line_infos.append(None) - else: - result = self.ADDRTOLINE_RE.match(line_text) - # Assume the output is always well-formed. - assert result is not None - line_infos.append((function_name.strip(), - os.path.realpath(result.group('path').strip()), - int(result.group('linenum')))) - - self.address_to_line_cache[cache_key] = line_infos - return line_infos - - def AnalyzeDisassembly(self, disasm_text): - """Parse the disassembly text, analyze, and build a map of all functions. - - Args: - disasm_text: Disassembly text. + address = int(result.group("address"), 16) + # Check if it's out of bound. + if function_end is not None and address >= function_end: + return None - Returns: - function_map: Dict of functions. - """ - disasm_lines = [line.strip() for line in disasm_text.splitlines()] - - if 'nds' in disasm_lines[1]: - analyzer = AndesAnalyzer() - elif 'arm' in disasm_lines[1]: - analyzer = ArmAnalyzer() - elif 'riscv' in disasm_lines[1]: - analyzer = RiscvAnalyzer() - else: - raise StackAnalyzerError('Unsupported architecture.') - - # Example: "08028c8c <motion_lid_calc>:" - function_signature_regex = re.compile( - r'^(?P<address>[0-9A-Fa-f]+)\s+<(?P<name>[^>]+)>:$') - - def DetectFunctionHead(line): - """Check if the line is a function head. - - Args: - line: Text of disassembly. - - Returns: - symbol: Function symbol. None if it isn't a function head. - """ - result = function_signature_regex.match(line) - if result is None: - return None - - address = int(result.group('address'), 16) - symbol = symbol_map.get(address) - - # Check if the function exists and matches. - if symbol is None or symbol.symtype != 'F': - return None - - return symbol - - # Build symbol map, indexed by symbol address. - symbol_map = {} - for symbol in self.symbols: - # If there are multiple symbols with same address, keeping any of them is - # good enough. - symbol_map[symbol.address] = symbol - - # Parse the disassembly text. We update the variable "line" to next line - # when needed. There are two steps of parser: - # - # Step 1: Searching for the function head. Once reach the function head, - # move to the next line, which is the first line of function body. - # - # Step 2: Parsing each instruction line of function body. Once reach a - # non-instruction line, stop parsing and analyze the parsed instructions. - # - # Finally turn back to the step 1 without updating the line, because the - # current non-instruction line can be another function head. - function_map = {} - # The following three variables are the states of the parsing processing. - # They will be initialized properly during the state changes. - function_symbol = None - function_end = None - instructions = [] - - # Remove heading and tailing spaces for each line. - line_index = 0 - while line_index < len(disasm_lines): - # Get the current line. - line = disasm_lines[line_index] - - if function_symbol is None: - # Step 1: Search for the function head. - - function_symbol = DetectFunctionHead(line) - if function_symbol is not None: - # Assume there is no empty function. If the function head is followed - # by EOF, it is an empty function. - assert line_index + 1 < len(disasm_lines) - - # Found the function head, initialize and turn to the step 2. - instructions = [] - # If symbol size exists, use it as a hint of function size. - if function_symbol.size > 0: - function_end = function_symbol.address + function_symbol.size - else: - function_end = None - - else: - # Step 2: Parse the function body. - - instruction = analyzer.ParseInstruction(line, function_end) - if instruction is not None: - instructions.append(instruction) - - if instruction is None or line_index + 1 == len(disasm_lines): - # Either the invalid instruction or EOF indicates the end of the - # function, finalize the function analysis. - - # Assume there is no empty function. - assert len(instructions) > 0 - - (stack_frame, callsites) = analyzer.AnalyzeFunction(function_symbol, - instructions) - # Assume the function addresses are unique in the disassembly. - assert function_symbol.address not in function_map - function_map[function_symbol.address] = Function( - function_symbol.address, - function_symbol.name, - stack_frame, - callsites) - - # Initialize and turn back to the step 1. - function_symbol = None - - # If the current line isn't an instruction, it can be another function - # head, skip moving to the next line. - if instruction is None: - continue - - # Move to the next line. - line_index += 1 - - # Resolve callees of functions. - for function in function_map.values(): - for callsite in function.callsites: - if callsite.target is not None: - # Remain the callee as None if we can't resolve it. - callsite.callee = function_map.get(callsite.target) - - return function_map + opcode = result.group("opcode").strip() + operand_text = result.group("operand") + if operand_text is None: + operand_text = "" + else: + operand_text = operand_text.strip() + + return (address, opcode, operand_text) + + def AnalyzeFunction(self, function_symbol, instructions): + + stack_frame = 0 + callsites = [] + for address, opcode, operand_text in instructions: + is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None + is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None + + if is_jump_opcode or is_call_opcode: + is_tail = is_jump_opcode + + result = self.CALL_OPERAND_RE.match(operand_text) + if result is None: + if ( + self.INDIRECT_CALL_OPERAND_RE.match(operand_text) + is not None + ): + # Found an indirect call. + callsites.append(Callsite(address, None, is_tail)) + + else: + # Capture address form operand_text and then convert to string + address_str = "".join( + self.CAPTURE_ADDRESS.findall(operand_text) + ) + # String to integer + target_address = int(address_str, 16) + # Filter out the in-function target (branches and in-function calls, + # which are actually branches). + if not ( + function_symbol.size > 0 + and function_symbol.address + < target_address + < (function_symbol.address + function_symbol.size) + ): + # Maybe it is a callsite. + callsites.append( + Callsite(address, target_address, is_tail) + ) + + elif self.ADDI_OPCODE_RE.match(opcode) is not None: + # Example: sp,sp,-32 + if self.ADDI_OPERAND_RE.match(operand_text) is not None: + stack_frame += abs(int(operand_text.split(",")[2])) + + return (stack_frame, callsites) - def MapAnnotation(self, function_map, signature_set): - """Map annotation signatures to functions. - Args: - function_map: Function map. - signature_set: Set of annotation signatures. +class StackAnalyzer(object): + """Class to analyze stack usage. - Returns: - Map of signatures to functions, map of signatures which can't be resolved. + Public Methods: + Analyze: Run the stack analysis. """ - # Build the symbol map indexed by symbol name. If there are multiple symbols - # with the same name, add them into a set. (e.g. symbols of static function - # with the same name) - symbol_map = collections.defaultdict(set) - for symbol in self.symbols: - if symbol.symtype == 'F': - # Function symbol. - result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name) - if result is not None: - function = function_map.get(symbol.address) - # Ignore the symbol not in disassembly. - if function is not None: - # If there are multiple symbol with the same name and point to the - # same function, the set will deduplicate them. - symbol_map[result.group('name').strip()].add(function) - - # Build the signature map indexed by annotation signature. - signature_map = {} - sig_error_map = {} - symbol_path_map = {} - for sig in signature_set: - (name, path, _) = sig - - functions = symbol_map.get(name) - if functions is None: - sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND - continue - - if name not in symbol_path_map: - # Lazy symbol path resolving. Since the addr2line isn't fast, only - # resolve needed symbol paths. - group_map = collections.defaultdict(list) - for function in functions: - line_info = self.AddressToLine(function.address)[0] - if line_info is None: - continue - (_, symbol_path, _) = line_info + C_FUNCTION_NAME = r"_A-Za-z0-9" - # Group the functions with the same symbol signature (symbol name + - # symbol path). Assume they are the same copies and do the same - # annotation operations of them because we don't know which copy is - # indicated by the users. - group_map[symbol_path].append(function) + # Assume there is no ":" in the path. + # Example: "driver/accel_kionix.c:321 (discriminator 3)" + ADDRTOLINE_RE = re.compile( + r"^(?P<path>[^:]+):(?P<linenum>\d+)(\s+\(discriminator\s+\d+\))?$" + ) + # To eliminate the suffix appended by compilers, try to extract the + # C function name from the prefix of symbol name. + # Example: "SHA256_transform.constprop.28" + FUNCTION_PREFIX_NAME_RE = re.compile( + r"^(?P<name>[{0}]+)([^{0}].*)?$".format(C_FUNCTION_NAME) + ) + + # Errors of annotation resolving. + ANNOTATION_ERROR_INVALID = "invalid signature" + ANNOTATION_ERROR_NOTFOUND = "function is not found" + ANNOTATION_ERROR_AMBIGUOUS = "signature is ambiguous" + + def __init__(self, options, symbols, rodata, tasklist, annotation): + """Constructor. + + Args: + options: Namespace from argparse.parse_args(). + symbols: Symbol list. + rodata: Content of .rodata section (offset, data) + tasklist: Task list. + annotation: Annotation config. + """ + self.options = options + self.symbols = symbols + self.rodata_offset = rodata[0] + self.rodata = rodata[1] + self.tasklist = tasklist + self.annotation = annotation + self.address_to_line_cache = {} + + def AddressToLine(self, address, resolve_inline=False): + """Convert address to line. + + Args: + address: Target address. + resolve_inline: Output the stack of inlining. + + Returns: + lines: List of the corresponding lines. + + Raises: + StackAnalyzerError: If addr2line is failed. + """ + cache_key = (address, resolve_inline) + if cache_key in self.address_to_line_cache: + return self.address_to_line_cache[cache_key] + + try: + args = [ + self.options.addr2line, + "-f", + "-e", + self.options.elf_path, + "{:x}".format(address), + ] + if resolve_inline: + args.append("-i") + + line_text = subprocess.check_output(args, encoding="utf-8") + except subprocess.CalledProcessError: + raise StackAnalyzerError("addr2line failed to resolve lines.") + except OSError: + raise StackAnalyzerError("Failed to run addr2line.") + + lines = [line.strip() for line in line_text.splitlines()] + # Assume the output has at least one pair like "function\nlocation\n", and + # they always show up in pairs. + # Example: "handle_request\n + # common/usb_pd_protocol.c:1191\n" + assert len(lines) >= 2 and len(lines) % 2 == 0 + + line_infos = [] + for index in range(0, len(lines), 2): + (function_name, line_text) = lines[index : index + 2] + if line_text in ["??:0", ":?"]: + line_infos.append(None) + else: + result = self.ADDRTOLINE_RE.match(line_text) + # Assume the output is always well-formed. + assert result is not None + line_infos.append( + ( + function_name.strip(), + os.path.realpath(result.group("path").strip()), + int(result.group("linenum")), + ) + ) + + self.address_to_line_cache[cache_key] = line_infos + return line_infos + + def AnalyzeDisassembly(self, disasm_text): + """Parse the disassembly text, analyze, and build a map of all functions. + + Args: + disasm_text: Disassembly text. + + Returns: + function_map: Dict of functions. + """ + disasm_lines = [line.strip() for line in disasm_text.splitlines()] + + if "nds" in disasm_lines[1]: + analyzer = AndesAnalyzer() + elif "arm" in disasm_lines[1]: + analyzer = ArmAnalyzer() + elif "riscv" in disasm_lines[1]: + analyzer = RiscvAnalyzer() + else: + raise StackAnalyzerError("Unsupported architecture.") - symbol_path_map[name] = group_map + # Example: "08028c8c <motion_lid_calc>:" + function_signature_regex = re.compile( + r"^(?P<address>[0-9A-Fa-f]+)\s+<(?P<name>[^>]+)>:$" + ) - # Symbol matching. - function_group = None - group_map = symbol_path_map[name] - if len(group_map) > 0: - if path is None: - if len(group_map) > 1: - # There is ambiguity but the path isn't specified. - sig_error_map[sig] = self.ANNOTATION_ERROR_AMBIGUOUS - continue + def DetectFunctionHead(line): + """Check if the line is a function head. - # No path signature but all symbol signatures of functions are same. - # Assume they are the same functions, so there is no ambiguity. - (function_group,) = group_map.values() - else: - function_group = group_map.get(path) + Args: + line: Text of disassembly. - if function_group is None: - sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND - continue + Returns: + symbol: Function symbol. None if it isn't a function head. + """ + result = function_signature_regex.match(line) + if result is None: + return None - # The function_group is a list of all the same functions (according to - # our assumption) which should be annotated together. - signature_map[sig] = function_group + address = int(result.group("address"), 16) + symbol = symbol_map.get(address) - return (signature_map, sig_error_map) + # Check if the function exists and matches. + if symbol is None or symbol.symtype != "F": + return None - def LoadAnnotation(self): - """Load annotation rules. + return symbol - Returns: - Map of add rules, set of remove rules, set of text signatures which can't - be parsed. - """ - # Assume there is no ":" in the path. - # Example: "get_range.lto.2501[driver/accel_kionix.c:327]" - annotation_signature_regex = re.compile( - r'^(?P<name>[^\[]+)(\[(?P<path>[^:]+)(:(?P<linenum>\d+))?\])?$') - - def NormalizeSignature(signature_text): - """Parse and normalize the annotation signature. - - Args: - signature_text: Text of the annotation signature. - - Returns: - (function name, path, line number) of the signature. The path and line - number can be None if not exist. None if failed to parse. - """ - result = annotation_signature_regex.match(signature_text.strip()) - if result is None: - return None - - name_result = self.FUNCTION_PREFIX_NAME_RE.match( - result.group('name').strip()) - if name_result is None: - return None - - path = result.group('path') - if path is not None: - path = os.path.realpath(path.strip()) - - linenum = result.group('linenum') - if linenum is not None: - linenum = int(linenum.strip()) - - return (name_result.group('name').strip(), path, linenum) - - def ExpandArray(dic): - """Parse and expand a symbol array - - Args: - dic: Dictionary for the array annotation - - Returns: - array of (symbol name, None, None). - """ - # TODO(drinkcat): This function is quite inefficient, as it goes through - # the symbol table multiple times. - - begin_name = dic['name'] - end_name = dic['name'] + "_end" - offset = dic['offset'] if 'offset' in dic else 0 - stride = dic['stride'] - - begin_address = None - end_address = None - - for symbol in self.symbols: - if (symbol.name == begin_name): - begin_address = symbol.address - if (symbol.name == end_name): - end_address = symbol.address - - if (not begin_address or not end_address): - return None - - output = [] - # TODO(drinkcat): This is inefficient as we go from address to symbol - # object then to symbol name, and later on we'll go back from symbol name - # to symbol object. - for addr in range(begin_address+offset, end_address, stride): - # TODO(drinkcat): Not all architectures need to drop the first bit. - val = self.rodata[(addr-self.rodata_offset) // 4] & 0xfffffffe - name = None + # Build symbol map, indexed by symbol address. + symbol_map = {} for symbol in self.symbols: - if (symbol.address == val): - result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name) - name = result.group('name') - break - - if not name: - raise StackAnalyzerError('Cannot find function for address %s.', - hex(val)) - - output.append((name, None, None)) - - return output - - add_rules = collections.defaultdict(set) - remove_rules = list() - invalid_sigtxts = set() - - if 'add' in self.annotation and self.annotation['add'] is not None: - for src_sigtxt, dst_sigtxts in self.annotation['add'].items(): - src_sig = NormalizeSignature(src_sigtxt) - if src_sig is None: - invalid_sigtxts.add(src_sigtxt) - continue - - for dst_sigtxt in dst_sigtxts: - if isinstance(dst_sigtxt, dict): - dst_sig = ExpandArray(dst_sigtxt) - if dst_sig is None: - invalid_sigtxts.add(str(dst_sigtxt)) + # If there are multiple symbols with same address, keeping any of them is + # good enough. + symbol_map[symbol.address] = symbol + + # Parse the disassembly text. We update the variable "line" to next line + # when needed. There are two steps of parser: + # + # Step 1: Searching for the function head. Once reach the function head, + # move to the next line, which is the first line of function body. + # + # Step 2: Parsing each instruction line of function body. Once reach a + # non-instruction line, stop parsing and analyze the parsed instructions. + # + # Finally turn back to the step 1 without updating the line, because the + # current non-instruction line can be another function head. + function_map = {} + # The following three variables are the states of the parsing processing. + # They will be initialized properly during the state changes. + function_symbol = None + function_end = None + instructions = [] + + # Remove heading and tailing spaces for each line. + line_index = 0 + while line_index < len(disasm_lines): + # Get the current line. + line = disasm_lines[line_index] + + if function_symbol is None: + # Step 1: Search for the function head. + + function_symbol = DetectFunctionHead(line) + if function_symbol is not None: + # Assume there is no empty function. If the function head is followed + # by EOF, it is an empty function. + assert line_index + 1 < len(disasm_lines) + + # Found the function head, initialize and turn to the step 2. + instructions = [] + # If symbol size exists, use it as a hint of function size. + if function_symbol.size > 0: + function_end = ( + function_symbol.address + function_symbol.size + ) + else: + function_end = None + else: - add_rules[src_sig].update(dst_sig) - else: - dst_sig = NormalizeSignature(dst_sigtxt) - if dst_sig is None: - invalid_sigtxts.add(dst_sigtxt) + # Step 2: Parse the function body. + + instruction = analyzer.ParseInstruction(line, function_end) + if instruction is not None: + instructions.append(instruction) + + if instruction is None or line_index + 1 == len(disasm_lines): + # Either the invalid instruction or EOF indicates the end of the + # function, finalize the function analysis. + + # Assume there is no empty function. + assert len(instructions) > 0 + + (stack_frame, callsites) = analyzer.AnalyzeFunction( + function_symbol, instructions + ) + # Assume the function addresses are unique in the disassembly. + assert function_symbol.address not in function_map + function_map[function_symbol.address] = Function( + function_symbol.address, + function_symbol.name, + stack_frame, + callsites, + ) + + # Initialize and turn back to the step 1. + function_symbol = None + + # If the current line isn't an instruction, it can be another function + # head, skip moving to the next line. + if instruction is None: + continue + + # Move to the next line. + line_index += 1 + + # Resolve callees of functions. + for function in function_map.values(): + for callsite in function.callsites: + if callsite.target is not None: + # Remain the callee as None if we can't resolve it. + callsite.callee = function_map.get(callsite.target) + + return function_map + + def MapAnnotation(self, function_map, signature_set): + """Map annotation signatures to functions. + + Args: + function_map: Function map. + signature_set: Set of annotation signatures. + + Returns: + Map of signatures to functions, map of signatures which can't be resolved. + """ + # Build the symbol map indexed by symbol name. If there are multiple symbols + # with the same name, add them into a set. (e.g. symbols of static function + # with the same name) + symbol_map = collections.defaultdict(set) + for symbol in self.symbols: + if symbol.symtype == "F": + # Function symbol. + result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name) + if result is not None: + function = function_map.get(symbol.address) + # Ignore the symbol not in disassembly. + if function is not None: + # If there are multiple symbol with the same name and point to the + # same function, the set will deduplicate them. + symbol_map[result.group("name").strip()].add(function) + + # Build the signature map indexed by annotation signature. + signature_map = {} + sig_error_map = {} + symbol_path_map = {} + for sig in signature_set: + (name, path, _) = sig + + functions = symbol_map.get(name) + if functions is None: + sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND + continue + + if name not in symbol_path_map: + # Lazy symbol path resolving. Since the addr2line isn't fast, only + # resolve needed symbol paths. + group_map = collections.defaultdict(list) + for function in functions: + line_info = self.AddressToLine(function.address)[0] + if line_info is None: + continue + + (_, symbol_path, _) = line_info + + # Group the functions with the same symbol signature (symbol name + + # symbol path). Assume they are the same copies and do the same + # annotation operations of them because we don't know which copy is + # indicated by the users. + group_map[symbol_path].append(function) + + symbol_path_map[name] = group_map + + # Symbol matching. + function_group = None + group_map = symbol_path_map[name] + if len(group_map) > 0: + if path is None: + if len(group_map) > 1: + # There is ambiguity but the path isn't specified. + sig_error_map[sig] = self.ANNOTATION_ERROR_AMBIGUOUS + continue + + # No path signature but all symbol signatures of functions are same. + # Assume they are the same functions, so there is no ambiguity. + (function_group,) = group_map.values() + else: + function_group = group_map.get(path) + + if function_group is None: + sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND + continue + + # The function_group is a list of all the same functions (according to + # our assumption) which should be annotated together. + signature_map[sig] = function_group + + return (signature_map, sig_error_map) + + def LoadAnnotation(self): + """Load annotation rules. + + Returns: + Map of add rules, set of remove rules, set of text signatures which can't + be parsed. + """ + # Assume there is no ":" in the path. + # Example: "get_range.lto.2501[driver/accel_kionix.c:327]" + annotation_signature_regex = re.compile( + r"^(?P<name>[^\[]+)(\[(?P<path>[^:]+)(:(?P<linenum>\d+))?\])?$" + ) + + def NormalizeSignature(signature_text): + """Parse and normalize the annotation signature. + + Args: + signature_text: Text of the annotation signature. + + Returns: + (function name, path, line number) of the signature. The path and line + number can be None if not exist. None if failed to parse. + """ + result = annotation_signature_regex.match(signature_text.strip()) + if result is None: + return None + + name_result = self.FUNCTION_PREFIX_NAME_RE.match( + result.group("name").strip() + ) + if name_result is None: + return None + + path = result.group("path") + if path is not None: + path = os.path.realpath(path.strip()) + + linenum = result.group("linenum") + if linenum is not None: + linenum = int(linenum.strip()) + + return (name_result.group("name").strip(), path, linenum) + + def ExpandArray(dic): + """Parse and expand a symbol array + + Args: + dic: Dictionary for the array annotation + + Returns: + array of (symbol name, None, None). + """ + # TODO(drinkcat): This function is quite inefficient, as it goes through + # the symbol table multiple times. + + begin_name = dic["name"] + end_name = dic["name"] + "_end" + offset = dic["offset"] if "offset" in dic else 0 + stride = dic["stride"] + + begin_address = None + end_address = None + + for symbol in self.symbols: + if symbol.name == begin_name: + begin_address = symbol.address + if symbol.name == end_name: + end_address = symbol.address + + if not begin_address or not end_address: + return None + + output = [] + # TODO(drinkcat): This is inefficient as we go from address to symbol + # object then to symbol name, and later on we'll go back from symbol name + # to symbol object. + for addr in range(begin_address + offset, end_address, stride): + # TODO(drinkcat): Not all architectures need to drop the first bit. + val = self.rodata[(addr - self.rodata_offset) // 4] & 0xFFFFFFFE + name = None + for symbol in self.symbols: + if symbol.address == val: + result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name) + name = result.group("name") + break + + if not name: + raise StackAnalyzerError( + "Cannot find function for address %s." % hex(val) + ) + + output.append((name, None, None)) + + return output + + add_rules = collections.defaultdict(set) + remove_rules = list() + invalid_sigtxts = set() + + if "add" in self.annotation and self.annotation["add"] is not None: + for src_sigtxt, dst_sigtxts in self.annotation["add"].items(): + src_sig = NormalizeSignature(src_sigtxt) + if src_sig is None: + invalid_sigtxts.add(src_sigtxt) + continue + + for dst_sigtxt in dst_sigtxts: + if isinstance(dst_sigtxt, dict): + dst_sig = ExpandArray(dst_sigtxt) + if dst_sig is None: + invalid_sigtxts.add(str(dst_sigtxt)) + else: + add_rules[src_sig].update(dst_sig) + else: + dst_sig = NormalizeSignature(dst_sigtxt) + if dst_sig is None: + invalid_sigtxts.add(dst_sigtxt) + else: + add_rules[src_sig].add(dst_sig) + + if ( + "remove" in self.annotation + and self.annotation["remove"] is not None + ): + for sigtxt_path in self.annotation["remove"]: + if isinstance(sigtxt_path, str): + # The path has only one vertex. + sigtxt_path = [sigtxt_path] + + if len(sigtxt_path) == 0: + continue + + # Generate multiple remove paths from all the combinations of the + # signatures of each vertex. + sig_paths = [[]] + broken_flag = False + for sigtxt_node in sigtxt_path: + if isinstance(sigtxt_node, str): + # The vertex has only one signature. + sigtxt_set = {sigtxt_node} + elif isinstance(sigtxt_node, list): + # The vertex has multiple signatures. + sigtxt_set = set(sigtxt_node) + else: + # Assume the format of annotation is verified. There should be no + # invalid case. + assert False + + sig_set = set() + for sigtxt in sigtxt_set: + sig = NormalizeSignature(sigtxt) + if sig is None: + invalid_sigtxts.add(sigtxt) + broken_flag = True + elif not broken_flag: + sig_set.add(sig) + + if broken_flag: + continue + + # Append each signature of the current node to the all previous + # remove paths. + sig_paths = [ + path + [sig] for path in sig_paths for sig in sig_set + ] + + if not broken_flag: + # All signatures are normalized. The remove path has no error. + remove_rules.extend(sig_paths) + + return (add_rules, remove_rules, invalid_sigtxts) + + def ResolveAnnotation(self, function_map): + """Resolve annotation. + + Args: + function_map: Function map. + + Returns: + Set of added call edges, list of remove paths, set of eliminated + callsite addresses, set of annotation signatures which can't be resolved. + """ + + def StringifySignature(signature): + """Stringify the tupled signature. + + Args: + signature: Tupled signature. + + Returns: + Signature string. + """ + (name, path, linenum) = signature + bracket_text = "" + if path is not None: + path = os.path.relpath(path) + if linenum is None: + bracket_text = "[{}]".format(path) + else: + bracket_text = "[{}:{}]".format(path, linenum) + + return name + bracket_text + + (add_rules, remove_rules, invalid_sigtxts) = self.LoadAnnotation() + + signature_set = set() + for src_sig, dst_sigs in add_rules.items(): + signature_set.add(src_sig) + signature_set.update(dst_sigs) + + for remove_sigs in remove_rules: + signature_set.update(remove_sigs) + + # Map signatures to functions. + (signature_map, sig_error_map) = self.MapAnnotation( + function_map, signature_set + ) + + # Build the indirect callsite map indexed by callsite signature. + indirect_map = collections.defaultdict(set) + for function in function_map.values(): + for callsite in function.callsites: + if callsite.target is not None: + continue + + # Found an indirect callsite. + line_info = self.AddressToLine(callsite.address)[0] + if line_info is None: + continue + + (name, path, linenum) = line_info + result = self.FUNCTION_PREFIX_NAME_RE.match(name) + if result is None: + continue + + indirect_map[(result.group("name").strip(), path, linenum)].add( + (function, callsite.address) + ) + + # Generate the annotation sets. + add_set = set() + remove_list = list() + eliminated_addrs = set() + + for src_sig, dst_sigs in add_rules.items(): + src_funcs = set(signature_map.get(src_sig, [])) + # Try to match the source signature to the indirect callsites. Even if it + # can't be found in disassembly. + indirect_calls = indirect_map.get(src_sig) + if indirect_calls is not None: + for function, callsite_address in indirect_calls: + # Add the caller of the indirect callsite to the source functions. + src_funcs.add(function) + # Assume each callsite can be represented by a unique address. + eliminated_addrs.add(callsite_address) + + if src_sig in sig_error_map: + # Assume the error is always the not found error. Since the signature + # found in indirect callsite map must be a full signature, it can't + # happen the ambiguous error. + assert ( + sig_error_map[src_sig] == self.ANNOTATION_ERROR_NOTFOUND + ) + # Found in inline stack, remove the not found error. + del sig_error_map[src_sig] + + for dst_sig in dst_sigs: + dst_funcs = signature_map.get(dst_sig) + if dst_funcs is None: + continue + + # Duplicate the call edge for all the same source and destination + # functions. + for src_func in src_funcs: + for dst_func in dst_funcs: + add_set.add((src_func, dst_func)) + + for remove_sigs in remove_rules: + # Since each signature can be mapped to multiple functions, generate + # multiple remove paths from all the combinations of these functions. + remove_paths = [[]] + skip_flag = False + for remove_sig in remove_sigs: + # Transform each signature to the corresponding functions. + remove_funcs = signature_map.get(remove_sig) + if remove_funcs is None: + # There is an unresolved signature in the remove path. Ignore the + # whole broken remove path. + skip_flag = True + break + else: + # Append each function of the current signature to the all previous + # remove paths. + remove_paths = [ + p + [f] for p in remove_paths for f in remove_funcs + ] + + if skip_flag: + # Ignore the broken remove path. + continue + + for remove_path in remove_paths: + # Deduplicate the remove paths. + if remove_path not in remove_list: + remove_list.append(remove_path) + + # Format the error messages. + failed_sigtxts = set() + for sigtxt in invalid_sigtxts: + failed_sigtxts.add((sigtxt, self.ANNOTATION_ERROR_INVALID)) + + for sig, error in sig_error_map.items(): + failed_sigtxts.add((StringifySignature(sig), error)) + + return (add_set, remove_list, eliminated_addrs, failed_sigtxts) + + def PreprocessAnnotation( + self, function_map, add_set, remove_list, eliminated_addrs + ): + """Preprocess the annotation and callgraph. + + Add the missing call edges, and delete simple remove paths (the paths have + one or two vertices) from the function_map. + + Eliminate the annotated indirect callsites. + + Return the remaining remove list. + + Args: + function_map: Function map. + add_set: Set of missing call edges. + remove_list: List of remove paths. + eliminated_addrs: Set of eliminated callsite addresses. + + Returns: + List of remaining remove paths. + """ + + def CheckEdge(path): + """Check if all edges of the path are on the callgraph. + + Args: + path: Path. + + Returns: + True or False. + """ + for index in range(len(path) - 1): + if (path[index], path[index + 1]) not in edge_set: + return False + + return True + + for src_func, dst_func in add_set: + # TODO(cheyuw): Support tailing call annotation. + src_func.callsites.append( + Callsite(None, dst_func.address, False, dst_func) + ) + + # Delete simple remove paths. + remove_simple = set(tuple(p) for p in remove_list if len(p) <= 2) + edge_set = set() + for function in function_map.values(): + cleaned_callsites = [] + for callsite in function.callsites: + if (callsite.callee,) in remove_simple or ( + function, + callsite.callee, + ) in remove_simple: + continue + + if ( + callsite.target is None + and callsite.address in eliminated_addrs + ): + continue + + cleaned_callsites.append(callsite) + if callsite.callee is not None: + edge_set.add((function, callsite.callee)) + + function.callsites = cleaned_callsites + + return [p for p in remove_list if len(p) >= 3 and CheckEdge(p)] + + def AnalyzeCallGraph(self, function_map, remove_list): + """Analyze callgraph. + + It will update the max stack size and path for each function. + + Args: + function_map: Function map. + remove_list: List of remove paths. + + Returns: + List of function cycles. + """ + + def Traverse(curr_state): + """Traverse the callgraph and calculate the max stack usages of functions. + + Args: + curr_state: Current state. + + Returns: + SCC lowest link. + """ + scc_index = scc_index_counter[0] + scc_index_counter[0] += 1 + scc_index_map[curr_state] = scc_index + scc_lowlink = scc_index + scc_stack.append(curr_state) + # Push the current state in the stack. We can use a set to maintain this + # because the stacked states are unique; otherwise we will find a cycle + # first. + stacked_states.add(curr_state) + + (curr_address, curr_positions) = curr_state + curr_func = function_map[curr_address] + + invalid_flag = False + new_positions = list(curr_positions) + for index, position in enumerate(curr_positions): + remove_path = remove_list[index] + + # The position of each remove path in the state is the length of the + # longest matching path between the prefix of the remove path and the + # suffix of the current traversing path. We maintain this length when + # appending the next callee to the traversing path. And it can be used + # to check if the remove path appears in the traversing path. + + # TODO(cheyuw): Implement KMP algorithm to match remove paths + # efficiently. + if remove_path[position] is curr_func: + # Matches the current function, extend the length. + new_positions[index] = position + 1 + if new_positions[index] == len(remove_path): + # The length of the longest matching path is equal to the length of + # the remove path, which means the suffix of the current traversing + # path matches the remove path. + invalid_flag = True + break + + else: + # We can't get the new longest matching path by extending the previous + # one directly. Fallback to search the new longest matching path. + + # If we can't find any matching path in the following search, reset + # the matching length to 0. + new_positions[index] = 0 + + # We want to find the new longest matching prefix of remove path with + # the suffix of the current traversing path. Because the new longest + # matching path won't be longer than the prevous one now, and part of + # the suffix matches the prefix of remove path, we can get the needed + # suffix from the previous matching prefix of the invalid path. + suffix = remove_path[:position] + [curr_func] + for offset in range(1, len(suffix)): + length = position - offset + if remove_path[:length] == suffix[offset:]: + new_positions[index] = length + break + + new_positions = tuple(new_positions) + + # If the current suffix is invalid, set the max stack usage to 0. + max_stack_usage = 0 + max_callee_state = None + self_loop = False + + if not invalid_flag: + # Max stack usage is at least equal to the stack frame. + max_stack_usage = curr_func.stack_frame + for callsite in curr_func.callsites: + callee = callsite.callee + if callee is None: + continue + + callee_state = (callee.address, new_positions) + if callee_state not in scc_index_map: + # Unvisited state. + scc_lowlink = min(scc_lowlink, Traverse(callee_state)) + elif callee_state in stacked_states: + # The state is shown in the stack. There is a cycle. + sub_stack_usage = 0 + scc_lowlink = min( + scc_lowlink, scc_index_map[callee_state] + ) + if callee_state == curr_state: + self_loop = True + + done_result = done_states.get(callee_state) + if done_result is not None: + # Already done this state and use its result. If the state reaches a + # cycle, reusing the result will cause inaccuracy (the stack usage + # of cycle depends on where the entrance is). But it's fine since we + # can't get accurate stack usage under this situation, and we rely + # on user-provided annotations to break the cycle, after which the + # result will be accurate again. + (sub_stack_usage, _) = done_result + + if callsite.is_tail: + # For tailing call, since the callee reuses the stack frame of the + # caller, choose the larger one directly. + stack_usage = max( + curr_func.stack_frame, sub_stack_usage + ) + else: + stack_usage = ( + curr_func.stack_frame + sub_stack_usage + ) + + if stack_usage > max_stack_usage: + max_stack_usage = stack_usage + max_callee_state = callee_state + + if scc_lowlink == scc_index: + group = [] + while scc_stack[-1] != curr_state: + scc_state = scc_stack.pop() + stacked_states.remove(scc_state) + group.append(scc_state) + + scc_stack.pop() + stacked_states.remove(curr_state) + + # If the cycle is not empty, record it. + if len(group) > 0 or self_loop: + group.append(curr_state) + cycle_groups.append(group) + + # Store the done result. + done_states[curr_state] = (max_stack_usage, max_callee_state) + + if curr_positions == initial_positions: + # If the current state is initial state, we traversed the callgraph by + # using the current function as start point. Update the stack usage of + # the function. + # If the function matches a single vertex remove path, this will set its + # max stack usage to 0, which is not expected (we still calculate its + # max stack usage, but prevent any function from calling it). However, + # all the single vertex remove paths have been preprocessed and removed. + curr_func.stack_max_usage = max_stack_usage + + # Reconstruct the max stack path by traversing the state transitions. + max_stack_path = [curr_func] + callee_state = max_callee_state + while callee_state is not None: + # The first element of state tuple is function address. + max_stack_path.append(function_map[callee_state[0]]) + done_result = done_states.get(callee_state) + # All of the descendants should be done. + assert done_result is not None + (_, callee_state) = done_result + + curr_func.stack_max_path = max_stack_path + + return scc_lowlink + + # The state is the concatenation of the current function address and the + # state of matching position. + initial_positions = (0,) * len(remove_list) + done_states = {} + stacked_states = set() + scc_index_counter = [0] + scc_index_map = {} + scc_stack = [] + cycle_groups = [] + for function in function_map.values(): + if function.stack_max_usage is None: + Traverse((function.address, initial_positions)) + + cycle_functions = [] + for group in cycle_groups: + cycle = set(function_map[state[0]] for state in group) + if cycle not in cycle_functions: + cycle_functions.append(cycle) + + return cycle_functions + + def Analyze(self): + """Run the stack analysis. + + Raises: + StackAnalyzerError: If disassembly fails. + """ + + def OutputInlineStack(address, prefix=""): + """Output beautiful inline stack. + + Args: + address: Address. + prefix: Prefix of each line. + + Returns: + Key for sorting, output text + """ + line_infos = self.AddressToLine(address, True) + + if line_infos[0] is None: + order_key = (None, None) else: - add_rules[src_sig].add(dst_sig) - - if 'remove' in self.annotation and self.annotation['remove'] is not None: - for sigtxt_path in self.annotation['remove']: - if isinstance(sigtxt_path, str): - # The path has only one vertex. - sigtxt_path = [sigtxt_path] - - if len(sigtxt_path) == 0: - continue - - # Generate multiple remove paths from all the combinations of the - # signatures of each vertex. - sig_paths = [[]] - broken_flag = False - for sigtxt_node in sigtxt_path: - if isinstance(sigtxt_node, str): - # The vertex has only one signature. - sigtxt_set = {sigtxt_node} - elif isinstance(sigtxt_node, list): - # The vertex has multiple signatures. - sigtxt_set = set(sigtxt_node) - else: - # Assume the format of annotation is verified. There should be no - # invalid case. - assert False - - sig_set = set() - for sigtxt in sigtxt_set: - sig = NormalizeSignature(sigtxt) - if sig is None: - invalid_sigtxts.add(sigtxt) - broken_flag = True - elif not broken_flag: - sig_set.add(sig) - - if broken_flag: - continue - - # Append each signature of the current node to the all previous - # remove paths. - sig_paths = [path + [sig] for path in sig_paths for sig in sig_set] - - if not broken_flag: - # All signatures are normalized. The remove path has no error. - remove_rules.extend(sig_paths) - - return (add_rules, remove_rules, invalid_sigtxts) + (_, path, linenum) = line_infos[0] + order_key = (linenum, path) + + line_texts = [] + for line_info in reversed(line_infos): + if line_info is None: + (function_name, path, linenum) = ("??", "??", 0) + else: + (function_name, path, linenum) = line_info + + line_texts.append( + "{}[{}:{}]".format( + function_name, os.path.relpath(path), linenum + ) + ) + + output = "{}-> {} {:x}\n".format(prefix, line_texts[0], address) + for depth, line_text in enumerate(line_texts[1:]): + output += "{} {}- {}\n".format( + prefix, " " * depth, line_text + ) + + # Remove the last newline character. + return (order_key, output.rstrip("\n")) + + # Analyze disassembly. + try: + disasm_text = subprocess.check_output( + [self.options.objdump, "-d", self.options.elf_path], + encoding="utf-8", + ) + except subprocess.CalledProcessError: + raise StackAnalyzerError("objdump failed to disassemble.") + except OSError: + raise StackAnalyzerError("Failed to run objdump.") + + function_map = self.AnalyzeDisassembly(disasm_text) + result = self.ResolveAnnotation(function_map) + (add_set, remove_list, eliminated_addrs, failed_sigtxts) = result + remove_list = self.PreprocessAnnotation( + function_map, add_set, remove_list, eliminated_addrs + ) + cycle_functions = self.AnalyzeCallGraph(function_map, remove_list) + + # Print the results of task-aware stack analysis. + extra_stack_frame = self.annotation.get( + "exception_frame_size", DEFAULT_EXCEPTION_FRAME_SIZE + ) + for task in self.tasklist: + routine_func = function_map[task.routine_address] + print( + "Task: {}, Max size: {} ({} + {}), Allocated size: {}".format( + task.name, + routine_func.stack_max_usage + extra_stack_frame, + routine_func.stack_max_usage, + extra_stack_frame, + task.stack_max_size, + ) + ) + + print("Call Trace:") + max_stack_path = routine_func.stack_max_path + # Assume the routine function is resolved. + assert max_stack_path is not None + for depth, curr_func in enumerate(max_stack_path): + line_info = self.AddressToLine(curr_func.address)[0] + if line_info is None: + (path, linenum) = ("??", 0) + else: + (_, path, linenum) = line_info + + print( + " {} ({}) [{}:{}] {:x}".format( + curr_func.name, + curr_func.stack_frame, + os.path.relpath(path), + linenum, + curr_func.address, + ) + ) + + if depth + 1 < len(max_stack_path): + succ_func = max_stack_path[depth + 1] + text_list = [] + for callsite in curr_func.callsites: + if callsite.callee is succ_func: + indent_prefix = " " + if callsite.address is None: + order_text = ( + None, + "{}-> [annotation]".format(indent_prefix), + ) + else: + order_text = OutputInlineStack( + callsite.address, indent_prefix + ) + + text_list.append(order_text) + + for _, text in sorted(text_list, key=lambda item: item[0]): + print(text) + + print("Unresolved indirect callsites:") + for function in function_map.values(): + indirect_callsites = [] + for callsite in function.callsites: + if callsite.target is None: + indirect_callsites.append(callsite.address) + + if len(indirect_callsites) > 0: + print(" In function {}:".format(function.name)) + text_list = [] + for address in indirect_callsites: + text_list.append(OutputInlineStack(address, " ")) + + for _, text in sorted(text_list, key=lambda item: item[0]): + print(text) + + print("Unresolved annotation signatures:") + for sigtxt, error in failed_sigtxts: + print(" {}: {}".format(sigtxt, error)) + + if len(cycle_functions) > 0: + print("There are cycles in the following function sets:") + for functions in cycle_functions: + print( + "[{}]".format( + ", ".join(function.name for function in functions) + ) + ) - def ResolveAnnotation(self, function_map): - """Resolve annotation. - Args: - function_map: Function map. +def ParseArgs(): + """Parse commandline arguments. Returns: - Set of added call edges, list of remove paths, set of eliminated - callsite addresses, set of annotation signatures which can't be resolved. + options: Namespace from argparse.parse_args(). """ - def StringifySignature(signature): - """Stringify the tupled signature. - - Args: - signature: Tupled signature. - - Returns: - Signature string. - """ - (name, path, linenum) = signature - bracket_text = '' - if path is not None: - path = os.path.relpath(path) - if linenum is None: - bracket_text = '[{}]'.format(path) - else: - bracket_text = '[{}:{}]'.format(path, linenum) - - return name + bracket_text - - (add_rules, remove_rules, invalid_sigtxts) = self.LoadAnnotation() - - signature_set = set() - for src_sig, dst_sigs in add_rules.items(): - signature_set.add(src_sig) - signature_set.update(dst_sigs) - - for remove_sigs in remove_rules: - signature_set.update(remove_sigs) - - # Map signatures to functions. - (signature_map, sig_error_map) = self.MapAnnotation(function_map, - signature_set) - - # Build the indirect callsite map indexed by callsite signature. - indirect_map = collections.defaultdict(set) - for function in function_map.values(): - for callsite in function.callsites: - if callsite.target is not None: - continue - - # Found an indirect callsite. - line_info = self.AddressToLine(callsite.address)[0] - if line_info is None: - continue - - (name, path, linenum) = line_info - result = self.FUNCTION_PREFIX_NAME_RE.match(name) - if result is None: - continue - - indirect_map[(result.group('name').strip(), path, linenum)].add( - (function, callsite.address)) - - # Generate the annotation sets. - add_set = set() - remove_list = list() - eliminated_addrs = set() - - for src_sig, dst_sigs in add_rules.items(): - src_funcs = set(signature_map.get(src_sig, [])) - # Try to match the source signature to the indirect callsites. Even if it - # can't be found in disassembly. - indirect_calls = indirect_map.get(src_sig) - if indirect_calls is not None: - for function, callsite_address in indirect_calls: - # Add the caller of the indirect callsite to the source functions. - src_funcs.add(function) - # Assume each callsite can be represented by a unique address. - eliminated_addrs.add(callsite_address) - - if src_sig in sig_error_map: - # Assume the error is always the not found error. Since the signature - # found in indirect callsite map must be a full signature, it can't - # happen the ambiguous error. - assert sig_error_map[src_sig] == self.ANNOTATION_ERROR_NOTFOUND - # Found in inline stack, remove the not found error. - del sig_error_map[src_sig] - - for dst_sig in dst_sigs: - dst_funcs = signature_map.get(dst_sig) - if dst_funcs is None: - continue - - # Duplicate the call edge for all the same source and destination - # functions. - for src_func in src_funcs: - for dst_func in dst_funcs: - add_set.add((src_func, dst_func)) - - for remove_sigs in remove_rules: - # Since each signature can be mapped to multiple functions, generate - # multiple remove paths from all the combinations of these functions. - remove_paths = [[]] - skip_flag = False - for remove_sig in remove_sigs: - # Transform each signature to the corresponding functions. - remove_funcs = signature_map.get(remove_sig) - if remove_funcs is None: - # There is an unresolved signature in the remove path. Ignore the - # whole broken remove path. - skip_flag = True - break - else: - # Append each function of the current signature to the all previous - # remove paths. - remove_paths = [p + [f] for p in remove_paths for f in remove_funcs] - - if skip_flag: - # Ignore the broken remove path. - continue - - for remove_path in remove_paths: - # Deduplicate the remove paths. - if remove_path not in remove_list: - remove_list.append(remove_path) - - # Format the error messages. - failed_sigtxts = set() - for sigtxt in invalid_sigtxts: - failed_sigtxts.add((sigtxt, self.ANNOTATION_ERROR_INVALID)) - - for sig, error in sig_error_map.items(): - failed_sigtxts.add((StringifySignature(sig), error)) - - return (add_set, remove_list, eliminated_addrs, failed_sigtxts) - - def PreprocessAnnotation(self, function_map, add_set, remove_list, - eliminated_addrs): - """Preprocess the annotation and callgraph. + parser = argparse.ArgumentParser(description="EC firmware stack analyzer.") + parser.add_argument("elf_path", help="the path of EC firmware ELF") + parser.add_argument( + "--export_taskinfo", + required=True, + help="the path of export_taskinfo.so utility", + ) + parser.add_argument( + "--section", + required=True, + help="the section.", + choices=[SECTION_RO, SECTION_RW], + ) + parser.add_argument( + "--objdump", default="objdump", help="the path of objdump" + ) + parser.add_argument( + "--addr2line", default="addr2line", help="the path of addr2line" + ) + parser.add_argument( + "--annotation", default=None, help="the path of annotation file" + ) + + # TODO(cheyuw): Add an option for dumping stack usage of all functions. + + return parser.parse_args() - Add the missing call edges, and delete simple remove paths (the paths have - one or two vertices) from the function_map. - Eliminate the annotated indirect callsites. - - Return the remaining remove list. +def ParseSymbolText(symbol_text): + """Parse the content of the symbol text. Args: - function_map: Function map. - add_set: Set of missing call edges. - remove_list: List of remove paths. - eliminated_addrs: Set of eliminated callsite addresses. + symbol_text: Text of the symbols. Returns: - List of remaining remove paths. + symbols: Symbol list. """ - def CheckEdge(path): - """Check if all edges of the path are on the callgraph. - - Args: - path: Path. - - Returns: - True or False. - """ - for index in range(len(path) - 1): - if (path[index], path[index + 1]) not in edge_set: - return False - - return True - - for src_func, dst_func in add_set: - # TODO(cheyuw): Support tailing call annotation. - src_func.callsites.append( - Callsite(None, dst_func.address, False, dst_func)) - - # Delete simple remove paths. - remove_simple = set(tuple(p) for p in remove_list if len(p) <= 2) - edge_set = set() - for function in function_map.values(): - cleaned_callsites = [] - for callsite in function.callsites: - if ((callsite.callee,) in remove_simple or - (function, callsite.callee) in remove_simple): - continue - - if callsite.target is None and callsite.address in eliminated_addrs: - continue - - cleaned_callsites.append(callsite) - if callsite.callee is not None: - edge_set.add((function, callsite.callee)) + # Example: "10093064 g F .text 0000015c .hidden hook_task" + symbol_regex = re.compile( + r"^(?P<address>[0-9A-Fa-f]+)\s+[lwg]\s+" + r"((?P<type>[OF])\s+)?\S+\s+" + r"(?P<size>[0-9A-Fa-f]+)\s+" + r"(\S+\s+)?(?P<name>\S+)$" + ) + + symbols = [] + for line in symbol_text.splitlines(): + line = line.strip() + result = symbol_regex.match(line) + if result is not None: + address = int(result.group("address"), 16) + symtype = result.group("type") + if symtype is None: + symtype = "O" - function.callsites = cleaned_callsites + size = int(result.group("size"), 16) + name = result.group("name") + symbols.append(Symbol(address, symtype, size, name)) - return [p for p in remove_list if len(p) >= 3 and CheckEdge(p)] + return symbols - def AnalyzeCallGraph(self, function_map, remove_list): - """Analyze callgraph. - It will update the max stack size and path for each function. +def ParseRoDataText(rodata_text): + """Parse the content of rodata Args: - function_map: Function map. - remove_list: List of remove paths. + symbol_text: Text of the rodata dump. Returns: - List of function cycles. + symbols: Symbol list. """ - def Traverse(curr_state): - """Traverse the callgraph and calculate the max stack usages of functions. - - Args: - curr_state: Current state. - - Returns: - SCC lowest link. - """ - scc_index = scc_index_counter[0] - scc_index_counter[0] += 1 - scc_index_map[curr_state] = scc_index - scc_lowlink = scc_index - scc_stack.append(curr_state) - # Push the current state in the stack. We can use a set to maintain this - # because the stacked states are unique; otherwise we will find a cycle - # first. - stacked_states.add(curr_state) - - (curr_address, curr_positions) = curr_state - curr_func = function_map[curr_address] - - invalid_flag = False - new_positions = list(curr_positions) - for index, position in enumerate(curr_positions): - remove_path = remove_list[index] - - # The position of each remove path in the state is the length of the - # longest matching path between the prefix of the remove path and the - # suffix of the current traversing path. We maintain this length when - # appending the next callee to the traversing path. And it can be used - # to check if the remove path appears in the traversing path. - - # TODO(cheyuw): Implement KMP algorithm to match remove paths - # efficiently. - if remove_path[position] is curr_func: - # Matches the current function, extend the length. - new_positions[index] = position + 1 - if new_positions[index] == len(remove_path): - # The length of the longest matching path is equal to the length of - # the remove path, which means the suffix of the current traversing - # path matches the remove path. - invalid_flag = True - break - - else: - # We can't get the new longest matching path by extending the previous - # one directly. Fallback to search the new longest matching path. - - # If we can't find any matching path in the following search, reset - # the matching length to 0. - new_positions[index] = 0 - - # We want to find the new longest matching prefix of remove path with - # the suffix of the current traversing path. Because the new longest - # matching path won't be longer than the prevous one now, and part of - # the suffix matches the prefix of remove path, we can get the needed - # suffix from the previous matching prefix of the invalid path. - suffix = remove_path[:position] + [curr_func] - for offset in range(1, len(suffix)): - length = position - offset - if remove_path[:length] == suffix[offset:]: - new_positions[index] = length - break - - new_positions = tuple(new_positions) - - # If the current suffix is invalid, set the max stack usage to 0. - max_stack_usage = 0 - max_callee_state = None - self_loop = False - - if not invalid_flag: - # Max stack usage is at least equal to the stack frame. - max_stack_usage = curr_func.stack_frame - for callsite in curr_func.callsites: - callee = callsite.callee - if callee is None: + # Examples: 8018ab0 00040048 00010000 10020000 4b8e0108 ...H........K... + # 100a7294 00000000 00000000 01000000 ............ + + base_offset = None + offset = None + rodata = [] + for line in rodata_text.splitlines(): + line = line.strip() + space = line.find(" ") + if space < 0: + continue + try: + address = int(line[0:space], 16) + except ValueError: continue - callee_state = (callee.address, new_positions) - if callee_state not in scc_index_map: - # Unvisited state. - scc_lowlink = min(scc_lowlink, Traverse(callee_state)) - elif callee_state in stacked_states: - # The state is shown in the stack. There is a cycle. - sub_stack_usage = 0 - scc_lowlink = min(scc_lowlink, scc_index_map[callee_state]) - if callee_state == curr_state: - self_loop = True - - done_result = done_states.get(callee_state) - if done_result is not None: - # Already done this state and use its result. If the state reaches a - # cycle, reusing the result will cause inaccuracy (the stack usage - # of cycle depends on where the entrance is). But it's fine since we - # can't get accurate stack usage under this situation, and we rely - # on user-provided annotations to break the cycle, after which the - # result will be accurate again. - (sub_stack_usage, _) = done_result - - if callsite.is_tail: - # For tailing call, since the callee reuses the stack frame of the - # caller, choose the larger one directly. - stack_usage = max(curr_func.stack_frame, sub_stack_usage) - else: - stack_usage = curr_func.stack_frame + sub_stack_usage - - if stack_usage > max_stack_usage: - max_stack_usage = stack_usage - max_callee_state = callee_state - - if scc_lowlink == scc_index: - group = [] - while scc_stack[-1] != curr_state: - scc_state = scc_stack.pop() - stacked_states.remove(scc_state) - group.append(scc_state) - - scc_stack.pop() - stacked_states.remove(curr_state) - - # If the cycle is not empty, record it. - if len(group) > 0 or self_loop: - group.append(curr_state) - cycle_groups.append(group) - - # Store the done result. - done_states[curr_state] = (max_stack_usage, max_callee_state) - - if curr_positions == initial_positions: - # If the current state is initial state, we traversed the callgraph by - # using the current function as start point. Update the stack usage of - # the function. - # If the function matches a single vertex remove path, this will set its - # max stack usage to 0, which is not expected (we still calculate its - # max stack usage, but prevent any function from calling it). However, - # all the single vertex remove paths have been preprocessed and removed. - curr_func.stack_max_usage = max_stack_usage - - # Reconstruct the max stack path by traversing the state transitions. - max_stack_path = [curr_func] - callee_state = max_callee_state - while callee_state is not None: - # The first element of state tuple is function address. - max_stack_path.append(function_map[callee_state[0]]) - done_result = done_states.get(callee_state) - # All of the descendants should be done. - assert done_result is not None - (_, callee_state) = done_result - - curr_func.stack_max_path = max_stack_path - - return scc_lowlink - - # The state is the concatenation of the current function address and the - # state of matching position. - initial_positions = (0,) * len(remove_list) - done_states = {} - stacked_states = set() - scc_index_counter = [0] - scc_index_map = {} - scc_stack = [] - cycle_groups = [] - for function in function_map.values(): - if function.stack_max_usage is None: - Traverse((function.address, initial_positions)) - - cycle_functions = [] - for group in cycle_groups: - cycle = set(function_map[state[0]] for state in group) - if cycle not in cycle_functions: - cycle_functions.append(cycle) - - return cycle_functions - - def Analyze(self): - """Run the stack analysis. - - Raises: - StackAnalyzerError: If disassembly fails. - """ - def OutputInlineStack(address, prefix=''): - """Output beautiful inline stack. - - Args: - address: Address. - prefix: Prefix of each line. - - Returns: - Key for sorting, output text - """ - line_infos = self.AddressToLine(address, True) - - if line_infos[0] is None: - order_key = (None, None) - else: - (_, path, linenum) = line_infos[0] - order_key = (linenum, path) - - line_texts = [] - for line_info in reversed(line_infos): - if line_info is None: - (function_name, path, linenum) = ('??', '??', 0) - else: - (function_name, path, linenum) = line_info - - line_texts.append('{}[{}:{}]'.format(function_name, - os.path.relpath(path), - linenum)) - - output = '{}-> {} {:x}\n'.format(prefix, line_texts[0], address) - for depth, line_text in enumerate(line_texts[1:]): - output += '{} {}- {}\n'.format(prefix, ' ' * depth, line_text) - - # Remove the last newline character. - return (order_key, output.rstrip('\n')) - - # Analyze disassembly. - try: - disasm_text = subprocess.check_output([self.options.objdump, - '-d', - self.options.elf_path], - encoding='utf-8') - except subprocess.CalledProcessError: - raise StackAnalyzerError('objdump failed to disassemble.') - except OSError: - raise StackAnalyzerError('Failed to run objdump.') - - function_map = self.AnalyzeDisassembly(disasm_text) - result = self.ResolveAnnotation(function_map) - (add_set, remove_list, eliminated_addrs, failed_sigtxts) = result - remove_list = self.PreprocessAnnotation(function_map, - add_set, - remove_list, - eliminated_addrs) - cycle_functions = self.AnalyzeCallGraph(function_map, remove_list) - - # Print the results of task-aware stack analysis. - extra_stack_frame = self.annotation.get('exception_frame_size', - DEFAULT_EXCEPTION_FRAME_SIZE) - for task in self.tasklist: - routine_func = function_map[task.routine_address] - print('Task: {}, Max size: {} ({} + {}), Allocated size: {}'.format( - task.name, - routine_func.stack_max_usage + extra_stack_frame, - routine_func.stack_max_usage, - extra_stack_frame, - task.stack_max_size)) - - print('Call Trace:') - max_stack_path = routine_func.stack_max_path - # Assume the routine function is resolved. - assert max_stack_path is not None - for depth, curr_func in enumerate(max_stack_path): - line_info = self.AddressToLine(curr_func.address)[0] - if line_info is None: - (path, linenum) = ('??', 0) - else: - (_, path, linenum) = line_info - - print(' {} ({}) [{}:{}] {:x}'.format(curr_func.name, - curr_func.stack_frame, - os.path.relpath(path), - linenum, - curr_func.address)) - - if depth + 1 < len(max_stack_path): - succ_func = max_stack_path[depth + 1] - text_list = [] - for callsite in curr_func.callsites: - if callsite.callee is succ_func: - indent_prefix = ' ' - if callsite.address is None: - order_text = (None, '{}-> [annotation]'.format(indent_prefix)) - else: - order_text = OutputInlineStack(callsite.address, indent_prefix) - - text_list.append(order_text) - - for _, text in sorted(text_list, key=lambda item: item[0]): - print(text) - - print('Unresolved indirect callsites:') - for function in function_map.values(): - indirect_callsites = [] - for callsite in function.callsites: - if callsite.target is None: - indirect_callsites.append(callsite.address) - - if len(indirect_callsites) > 0: - print(' In function {}:'.format(function.name)) - text_list = [] - for address in indirect_callsites: - text_list.append(OutputInlineStack(address, ' ')) - - for _, text in sorted(text_list, key=lambda item: item[0]): - print(text) - - print('Unresolved annotation signatures:') - for sigtxt, error in failed_sigtxts: - print(' {}: {}'.format(sigtxt, error)) - - if len(cycle_functions) > 0: - print('There are cycles in the following function sets:') - for functions in cycle_functions: - print('[{}]'.format(', '.join(function.name for function in functions))) - - -def ParseArgs(): - """Parse commandline arguments. - - Returns: - options: Namespace from argparse.parse_args(). - """ - parser = argparse.ArgumentParser(description="EC firmware stack analyzer.") - parser.add_argument('elf_path', help="the path of EC firmware ELF") - parser.add_argument('--export_taskinfo', required=True, - help="the path of export_taskinfo.so utility") - parser.add_argument('--section', required=True, help='the section.', - choices=[SECTION_RO, SECTION_RW]) - parser.add_argument('--objdump', default='objdump', - help='the path of objdump') - parser.add_argument('--addr2line', default='addr2line', - help='the path of addr2line') - parser.add_argument('--annotation', default=None, - help='the path of annotation file') - - # TODO(cheyuw): Add an option for dumping stack usage of all functions. - - return parser.parse_args() - - -def ParseSymbolText(symbol_text): - """Parse the content of the symbol text. - - Args: - symbol_text: Text of the symbols. - - Returns: - symbols: Symbol list. - """ - # Example: "10093064 g F .text 0000015c .hidden hook_task" - symbol_regex = re.compile(r'^(?P<address>[0-9A-Fa-f]+)\s+[lwg]\s+' - r'((?P<type>[OF])\s+)?\S+\s+' - r'(?P<size>[0-9A-Fa-f]+)\s+' - r'(\S+\s+)?(?P<name>\S+)$') - - symbols = [] - for line in symbol_text.splitlines(): - line = line.strip() - result = symbol_regex.match(line) - if result is not None: - address = int(result.group('address'), 16) - symtype = result.group('type') - if symtype is None: - symtype = 'O' - - size = int(result.group('size'), 16) - name = result.group('name') - symbols.append(Symbol(address, symtype, size, name)) - - return symbols - - -def ParseRoDataText(rodata_text): - """Parse the content of rodata - - Args: - symbol_text: Text of the rodata dump. - - Returns: - symbols: Symbol list. - """ - # Examples: 8018ab0 00040048 00010000 10020000 4b8e0108 ...H........K... - # 100a7294 00000000 00000000 01000000 ............ - - base_offset = None - offset = None - rodata = [] - for line in rodata_text.splitlines(): - line = line.strip() - space = line.find(' ') - if space < 0: - continue - try: - address = int(line[0:space], 16) - except ValueError: - continue - - if not base_offset: - base_offset = address - offset = address - elif address != offset: - raise StackAnalyzerError('objdump of rodata not contiguous.') + if not base_offset: + base_offset = address + offset = address + elif address != offset: + raise StackAnalyzerError("objdump of rodata not contiguous.") - for i in range(0, 4): - num = line[(space + 1 + i*9):(space + 9 + i*9)] - if len(num.strip()) > 0: - val = int(num, 16) - else: - val = 0 - # TODO(drinkcat): Not all platforms are necessarily big-endian - rodata.append((val & 0x000000ff) << 24 | - (val & 0x0000ff00) << 8 | - (val & 0x00ff0000) >> 8 | - (val & 0xff000000) >> 24) + for i in range(0, 4): + num = line[(space + 1 + i * 9) : (space + 9 + i * 9)] + if len(num.strip()) > 0: + val = int(num, 16) + else: + val = 0 + # TODO(drinkcat): Not all platforms are necessarily big-endian + rodata.append( + (val & 0x000000FF) << 24 + | (val & 0x0000FF00) << 8 + | (val & 0x00FF0000) >> 8 + | (val & 0xFF000000) >> 24 + ) - offset = offset + 4*4 + offset = offset + 4 * 4 - return (base_offset, rodata) + return (base_offset, rodata) def LoadTasklist(section, export_taskinfo, symbols): - """Load the task information. + """Load the task information. - Args: - section: Section (RO | RW). - export_taskinfo: Handle of export_taskinfo.so. - symbols: Symbol list. + Args: + section: Section (RO | RW). + export_taskinfo: Handle of export_taskinfo.so. + symbols: Symbol list. - Returns: - tasklist: Task list. - """ + Returns: + tasklist: Task list. + """ - TaskInfoPointer = ctypes.POINTER(TaskInfo) - taskinfos = TaskInfoPointer() - if section == SECTION_RO: - get_taskinfos_func = export_taskinfo.get_ro_taskinfos - else: - get_taskinfos_func = export_taskinfo.get_rw_taskinfos + TaskInfoPointer = ctypes.POINTER(TaskInfo) + taskinfos = TaskInfoPointer() + if section == SECTION_RO: + get_taskinfos_func = export_taskinfo.get_ro_taskinfos + else: + get_taskinfos_func = export_taskinfo.get_rw_taskinfos - taskinfo_num = get_taskinfos_func(ctypes.pointer(taskinfos)) + taskinfo_num = get_taskinfos_func(ctypes.pointer(taskinfos)) - tasklist = [] - for index in range(taskinfo_num): - taskinfo = taskinfos[index] - tasklist.append(Task(taskinfo.name.decode('utf-8'), - taskinfo.routine.decode('utf-8'), - taskinfo.stack_size)) + tasklist = [] + for index in range(taskinfo_num): + taskinfo = taskinfos[index] + tasklist.append( + Task( + taskinfo.name.decode("utf-8"), + taskinfo.routine.decode("utf-8"), + taskinfo.stack_size, + ) + ) - # Resolve routine address for each task. It's more efficient to resolve all - # routine addresses of tasks together. - routine_map = dict((task.routine_name, None) for task in tasklist) + # Resolve routine address for each task. It's more efficient to resolve all + # routine addresses of tasks together. + routine_map = dict((task.routine_name, None) for task in tasklist) - for symbol in symbols: - # Resolve task routine address. - if symbol.name in routine_map: - # Assume the symbol of routine is unique. - assert routine_map[symbol.name] is None - routine_map[symbol.name] = symbol.address + for symbol in symbols: + # Resolve task routine address. + if symbol.name in routine_map: + # Assume the symbol of routine is unique. + assert routine_map[symbol.name] is None + routine_map[symbol.name] = symbol.address - for task in tasklist: - address = routine_map[task.routine_name] - # Assume we have resolved all routine addresses. - assert address is not None - task.routine_address = address + for task in tasklist: + address = routine_map[task.routine_name] + # Assume we have resolved all routine addresses. + assert address is not None + task.routine_address = address - return tasklist + return tasklist def main(): - """Main function.""" - try: - options = ParseArgs() - - # Load annotation config. - if options.annotation is None: - annotation = {} - elif not os.path.exists(options.annotation): - print('Warning: Annotation file {} does not exist.' - .format(options.annotation)) - annotation = {} - else: - try: - with open(options.annotation, 'r') as annotation_file: - annotation = yaml.safe_load(annotation_file) - - except yaml.YAMLError: - raise StackAnalyzerError('Failed to parse annotation file {}.' - .format(options.annotation)) - except IOError: - raise StackAnalyzerError('Failed to open annotation file {}.' - .format(options.annotation)) - - # TODO(cheyuw): Do complete annotation format verification. - if not isinstance(annotation, dict): - raise StackAnalyzerError('Invalid annotation file {}.' - .format(options.annotation)) - - # Generate and parse the symbols. + """Main function.""" try: - symbol_text = subprocess.check_output([options.objdump, - '-t', - options.elf_path], - encoding='utf-8') - rodata_text = subprocess.check_output([options.objdump, - '-s', - '-j', '.rodata', - options.elf_path], - encoding='utf-8') - except subprocess.CalledProcessError: - raise StackAnalyzerError('objdump failed to dump symbol table or rodata.') - except OSError: - raise StackAnalyzerError('Failed to run objdump.') - - symbols = ParseSymbolText(symbol_text) - rodata = ParseRoDataText(rodata_text) - - # Load the tasklist. - try: - export_taskinfo = ctypes.CDLL(options.export_taskinfo) - except OSError: - raise StackAnalyzerError('Failed to load export_taskinfo.') - - tasklist = LoadTasklist(options.section, export_taskinfo, symbols) - - analyzer = StackAnalyzer(options, symbols, rodata, tasklist, annotation) - analyzer.Analyze() - except StackAnalyzerError as e: - print('Error: {}'.format(e)) - - -if __name__ == '__main__': - main() + options = ParseArgs() + + # Load annotation config. + if options.annotation is None: + annotation = {} + elif not os.path.exists(options.annotation): + print( + "Warning: Annotation file {} does not exist.".format( + options.annotation + ) + ) + annotation = {} + else: + try: + with open(options.annotation, "r") as annotation_file: + annotation = yaml.safe_load(annotation_file) + + except yaml.YAMLError: + raise StackAnalyzerError( + "Failed to parse annotation file {}.".format( + options.annotation + ) + ) + except IOError: + raise StackAnalyzerError( + "Failed to open annotation file {}.".format( + options.annotation + ) + ) + + # TODO(cheyuw): Do complete annotation format verification. + if not isinstance(annotation, dict): + raise StackAnalyzerError( + "Invalid annotation file {}.".format(options.annotation) + ) + + # Generate and parse the symbols. + try: + symbol_text = subprocess.check_output( + [options.objdump, "-t", options.elf_path], encoding="utf-8" + ) + rodata_text = subprocess.check_output( + [options.objdump, "-s", "-j", ".rodata", options.elf_path], + encoding="utf-8", + ) + except subprocess.CalledProcessError: + raise StackAnalyzerError( + "objdump failed to dump symbol table or rodata." + ) + except OSError: + raise StackAnalyzerError("Failed to run objdump.") + + symbols = ParseSymbolText(symbol_text) + rodata = ParseRoDataText(rodata_text) + + # Load the tasklist. + try: + export_taskinfo = ctypes.CDLL(options.export_taskinfo) + except OSError: + raise StackAnalyzerError("Failed to load export_taskinfo.") + + tasklist = LoadTasklist(options.section, export_taskinfo, symbols) + + analyzer = StackAnalyzer(options, symbols, rodata, tasklist, annotation) + analyzer.Analyze() + except StackAnalyzerError as e: + print("Error: {}".format(e)) + + +if __name__ == "__main__": + main() diff --git a/extra/stack_analyzer/stack_analyzer_unittest.py b/extra/stack_analyzer/stack_analyzer_unittest.py index c36fa9da45..23a8fb93ea 100755 --- a/extra/stack_analyzer/stack_analyzer_unittest.py +++ b/extra/stack_analyzer/stack_analyzer_unittest.py @@ -1,830 +1,993 @@ #!/usr/bin/env python3 -# Copyright 2017 The Chromium OS Authors. All rights reserved. +# Copyright 2017 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes """Tests for Stack Analyzer classes and functions.""" from __future__ import print_function -import mock import os import subprocess import unittest +import mock # pylint:disable=import-error import stack_analyzer as sa class ObjectTest(unittest.TestCase): - """Tests for classes of basic objects.""" - - def testTask(self): - task_a = sa.Task('a', 'a_task', 1234) - task_b = sa.Task('b', 'b_task', 5678, 0x1000) - self.assertEqual(task_a, task_a) - self.assertNotEqual(task_a, task_b) - self.assertNotEqual(task_a, None) - - def testSymbol(self): - symbol_a = sa.Symbol(0x1234, 'F', 32, 'a') - symbol_b = sa.Symbol(0x234, 'O', 42, 'b') - self.assertEqual(symbol_a, symbol_a) - self.assertNotEqual(symbol_a, symbol_b) - self.assertNotEqual(symbol_a, None) - - def testCallsite(self): - callsite_a = sa.Callsite(0x1002, 0x3000, False) - callsite_b = sa.Callsite(0x1002, 0x3000, True) - self.assertEqual(callsite_a, callsite_a) - self.assertNotEqual(callsite_a, callsite_b) - self.assertNotEqual(callsite_a, None) - - def testFunction(self): - func_a = sa.Function(0x100, 'a', 0, []) - func_b = sa.Function(0x200, 'b', 0, []) - self.assertEqual(func_a, func_a) - self.assertNotEqual(func_a, func_b) - self.assertNotEqual(func_a, None) + """Tests for classes of basic objects.""" + + def testTask(self): + task_a = sa.Task("a", "a_task", 1234) + task_b = sa.Task("b", "b_task", 5678, 0x1000) + self.assertEqual(task_a, task_a) + self.assertNotEqual(task_a, task_b) + self.assertNotEqual(task_a, None) + + def testSymbol(self): + symbol_a = sa.Symbol(0x1234, "F", 32, "a") + symbol_b = sa.Symbol(0x234, "O", 42, "b") + self.assertEqual(symbol_a, symbol_a) + self.assertNotEqual(symbol_a, symbol_b) + self.assertNotEqual(symbol_a, None) + + def testCallsite(self): + callsite_a = sa.Callsite(0x1002, 0x3000, False) + callsite_b = sa.Callsite(0x1002, 0x3000, True) + self.assertEqual(callsite_a, callsite_a) + self.assertNotEqual(callsite_a, callsite_b) + self.assertNotEqual(callsite_a, None) + + def testFunction(self): + func_a = sa.Function(0x100, "a", 0, []) + func_b = sa.Function(0x200, "b", 0, []) + self.assertEqual(func_a, func_a) + self.assertNotEqual(func_a, func_b) + self.assertNotEqual(func_a, None) class ArmAnalyzerTest(unittest.TestCase): - """Tests for class ArmAnalyzer.""" - - def AppendConditionCode(self, opcodes): - rets = [] - for opcode in opcodes: - rets.extend(opcode + cc for cc in sa.ArmAnalyzer.CONDITION_CODES) - - return rets - - def testInstructionMatching(self): - jump_list = self.AppendConditionCode(['b', 'bx']) - jump_list += (list(opcode + '.n' for opcode in jump_list) + - list(opcode + '.w' for opcode in jump_list)) - for opcode in jump_list: - self.assertIsNotNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match(opcode)) - - self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match('bl')) - self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match('blx')) - - cbz_list = ['cbz', 'cbnz', 'cbz.n', 'cbnz.n', 'cbz.w', 'cbnz.w'] - for opcode in cbz_list: - self.assertIsNotNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match(opcode)) - - self.assertIsNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match('cbn')) - - call_list = self.AppendConditionCode(['bl', 'blx']) - call_list += list(opcode + '.n' for opcode in call_list) - for opcode in call_list: - self.assertIsNotNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match(opcode)) - - self.assertIsNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match('ble')) - - result = sa.ArmAnalyzer.CALL_OPERAND_RE.match('53f90 <get_time+0x18>') - self.assertIsNotNone(result) - self.assertEqual(result.group(1), '53f90') - self.assertEqual(result.group(2), 'get_time+0x18') - - result = sa.ArmAnalyzer.CBZ_CBNZ_OPERAND_RE.match('r6, 53f90 <get+0x0>') - self.assertIsNotNone(result) - self.assertEqual(result.group(1), '53f90') - self.assertEqual(result.group(2), 'get+0x0') - - self.assertIsNotNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match('push')) - self.assertIsNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match('pushal')) - self.assertIsNotNone(sa.ArmAnalyzer.STM_OPCODE_RE.match('stmdb')) - self.assertIsNone(sa.ArmAnalyzer.STM_OPCODE_RE.match('lstm')) - self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('sub')) - self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('subs')) - self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('subw')) - self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('sub.w')) - self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('subs.w')) - - result = sa.ArmAnalyzer.SUB_OPERAND_RE.match('sp, sp, #1668 ; 0x684') - self.assertIsNotNone(result) - self.assertEqual(result.group(1), '1668') - result = sa.ArmAnalyzer.SUB_OPERAND_RE.match('sp, #1668') - self.assertIsNotNone(result) - self.assertEqual(result.group(1), '1668') - self.assertIsNone(sa.ArmAnalyzer.SUB_OPERAND_RE.match('sl, #1668')) - - def testAnalyzeFunction(self): - analyzer = sa.ArmAnalyzer() - symbol = sa.Symbol(0x10, 'F', 0x100, 'foo') - instructions = [ - (0x10, 'push', '{r4, r5, r6, r7, lr}'), - (0x12, 'subw', 'sp, sp, #16 ; 0x10'), - (0x16, 'movs', 'lr, r1'), - (0x18, 'beq.n', '26 <foo+0x26>'), - (0x1a, 'bl', '30 <foo+0x30>'), - (0x1e, 'bl', 'deadbeef <bar>'), - (0x22, 'blx', '0 <woo>'), - (0x26, 'push', '{r1}'), - (0x28, 'stmdb', 'sp!, {r4, r5, r6, r7, r8, r9, lr}'), - (0x2c, 'stmdb', 'sp!, {r4}'), - (0x30, 'stmdb', 'sp, {r4}'), - (0x34, 'bx.n', '10 <foo>'), - (0x36, 'bx.n', 'r3'), - (0x38, 'ldr', 'pc, [r10]'), - ] - (size, callsites) = analyzer.AnalyzeFunction(symbol, instructions) - self.assertEqual(size, 72) - expect_callsites = [sa.Callsite(0x1e, 0xdeadbeef, False), - sa.Callsite(0x22, 0x0, False), - sa.Callsite(0x34, 0x10, True), - sa.Callsite(0x36, None, True), - sa.Callsite(0x38, None, True)] - self.assertEqual(callsites, expect_callsites) + """Tests for class ArmAnalyzer.""" + + def AppendConditionCode(self, opcodes): + rets = [] + for opcode in opcodes: + rets.extend(opcode + cc for cc in sa.ArmAnalyzer.CONDITION_CODES) + + return rets + + def testInstructionMatching(self): + jump_list = self.AppendConditionCode(["b", "bx"]) + jump_list += list(opcode + ".n" for opcode in jump_list) + list( + opcode + ".w" for opcode in jump_list + ) + for opcode in jump_list: + self.assertIsNotNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match(opcode)) + + self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match("bl")) + self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match("blx")) + + cbz_list = ["cbz", "cbnz", "cbz.n", "cbnz.n", "cbz.w", "cbnz.w"] + for opcode in cbz_list: + self.assertIsNotNone( + sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match(opcode) + ) + + self.assertIsNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match("cbn")) + + call_list = self.AppendConditionCode(["bl", "blx"]) + call_list += list(opcode + ".n" for opcode in call_list) + for opcode in call_list: + self.assertIsNotNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match(opcode)) + + self.assertIsNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match("ble")) + + result = sa.ArmAnalyzer.CALL_OPERAND_RE.match("53f90 <get_time+0x18>") + self.assertIsNotNone(result) + self.assertEqual(result.group(1), "53f90") + self.assertEqual(result.group(2), "get_time+0x18") + + result = sa.ArmAnalyzer.CBZ_CBNZ_OPERAND_RE.match("r6, 53f90 <get+0x0>") + self.assertIsNotNone(result) + self.assertEqual(result.group(1), "53f90") + self.assertEqual(result.group(2), "get+0x0") + + self.assertIsNotNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match("push")) + self.assertIsNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match("pushal")) + self.assertIsNotNone(sa.ArmAnalyzer.STM_OPCODE_RE.match("stmdb")) + self.assertIsNone(sa.ArmAnalyzer.STM_OPCODE_RE.match("lstm")) + self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("sub")) + self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("subs")) + self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("subw")) + self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("sub.w")) + self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("subs.w")) + + result = sa.ArmAnalyzer.SUB_OPERAND_RE.match("sp, sp, #1668 ; 0x684") + self.assertIsNotNone(result) + self.assertEqual(result.group(1), "1668") + result = sa.ArmAnalyzer.SUB_OPERAND_RE.match("sp, #1668") + self.assertIsNotNone(result) + self.assertEqual(result.group(1), "1668") + self.assertIsNone(sa.ArmAnalyzer.SUB_OPERAND_RE.match("sl, #1668")) + + def testAnalyzeFunction(self): + analyzer = sa.ArmAnalyzer() + symbol = sa.Symbol(0x10, "F", 0x100, "foo") + instructions = [ + (0x10, "push", "{r4, r5, r6, r7, lr}"), + (0x12, "subw", "sp, sp, #16 ; 0x10"), + (0x16, "movs", "lr, r1"), + (0x18, "beq.n", "26 <foo+0x26>"), + (0x1A, "bl", "30 <foo+0x30>"), + (0x1E, "bl", "deadbeef <bar>"), + (0x22, "blx", "0 <woo>"), + (0x26, "push", "{r1}"), + (0x28, "stmdb", "sp!, {r4, r5, r6, r7, r8, r9, lr}"), + (0x2C, "stmdb", "sp!, {r4}"), + (0x30, "stmdb", "sp, {r4}"), + (0x34, "bx.n", "10 <foo>"), + (0x36, "bx.n", "r3"), + (0x38, "ldr", "pc, [r10]"), + ] + (size, callsites) = analyzer.AnalyzeFunction(symbol, instructions) + self.assertEqual(size, 72) + expect_callsites = [ + sa.Callsite(0x1E, 0xDEADBEEF, False), + sa.Callsite(0x22, 0x0, False), + sa.Callsite(0x34, 0x10, True), + sa.Callsite(0x36, None, True), + sa.Callsite(0x38, None, True), + ] + self.assertEqual(callsites, expect_callsites) class StackAnalyzerTest(unittest.TestCase): - """Tests for class StackAnalyzer.""" - - def setUp(self): - symbols = [sa.Symbol(0x1000, 'F', 0x15C, 'hook_task'), - sa.Symbol(0x2000, 'F', 0x51C, 'console_task'), - sa.Symbol(0x3200, 'O', 0x124, '__just_data'), - sa.Symbol(0x4000, 'F', 0x11C, 'touchpad_calc'), - sa.Symbol(0x5000, 'F', 0x12C, 'touchpad_calc.constprop.42'), - sa.Symbol(0x12000, 'F', 0x13C, 'trackpad_range'), - sa.Symbol(0x13000, 'F', 0x200, 'inlined_mul'), - sa.Symbol(0x13100, 'F', 0x200, 'inlined_mul'), - sa.Symbol(0x13100, 'F', 0x200, 'inlined_mul_alias'), - sa.Symbol(0x20000, 'O', 0x0, '__array'), - sa.Symbol(0x20010, 'O', 0x0, '__array_end'), - ] - tasklist = [sa.Task('HOOKS', 'hook_task', 2048, 0x1000), - sa.Task('CONSOLE', 'console_task', 460, 0x2000)] - # Array at 0x20000 that contains pointers to hook_task and console_task, - # with stride=8, offset=4 - rodata = (0x20000, [ 0xDEAD1000, 0x00001000, 0xDEAD2000, 0x00002000 ]) - options = mock.MagicMock(elf_path='./ec.RW.elf', - export_taskinfo='fake', - section='RW', - objdump='objdump', - addr2line='addr2line', - annotation=None) - self.analyzer = sa.StackAnalyzer(options, symbols, rodata, tasklist, {}) - - def testParseSymbolText(self): - symbol_text = ( - '0 g F .text e8 Foo\n' - '0000dead w F .text 000000e8 .hidden Bar\n' - 'deadbeef l O .bss 00000004 .hidden Woooo\n' - 'deadbee g O .rodata 00000008 __Hooo_ooo\n' - 'deadbee g .rodata 00000000 __foo_doo_coo_end\n' - ) - symbols = sa.ParseSymbolText(symbol_text) - expect_symbols = [sa.Symbol(0x0, 'F', 0xe8, 'Foo'), - sa.Symbol(0xdead, 'F', 0xe8, 'Bar'), - sa.Symbol(0xdeadbeef, 'O', 0x4, 'Woooo'), - sa.Symbol(0xdeadbee, 'O', 0x8, '__Hooo_ooo'), - sa.Symbol(0xdeadbee, 'O', 0x0, '__foo_doo_coo_end')] - self.assertEqual(symbols, expect_symbols) - - def testParseRoData(self): - rodata_text = ( - '\n' - 'Contents of section .rodata:\n' - ' 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n' - ) - rodata = sa.ParseRoDataText(rodata_text) - expect_rodata = (0x20000, - [ 0x0010adde, 0x00001000, 0x0020adde, 0x00002000 ]) - self.assertEqual(rodata, expect_rodata) - - def testLoadTasklist(self): - def tasklist_to_taskinfos(pointer, tasklist): - taskinfos = [] - for task in tasklist: - taskinfos.append(sa.TaskInfo(name=task.name.encode('utf-8'), - routine=task.routine_name.encode('utf-8'), - stack_size=task.stack_max_size)) - - TaskInfoArray = sa.TaskInfo * len(taskinfos) - pointer.contents.contents = TaskInfoArray(*taskinfos) - return len(taskinfos) - - def ro_taskinfos(pointer): - return tasklist_to_taskinfos(pointer, expect_ro_tasklist) - - def rw_taskinfos(pointer): - return tasklist_to_taskinfos(pointer, expect_rw_tasklist) - - expect_ro_tasklist = [ - sa.Task('HOOKS', 'hook_task', 2048, 0x1000), - ] - - expect_rw_tasklist = [ - sa.Task('HOOKS', 'hook_task', 2048, 0x1000), - sa.Task('WOOKS', 'hook_task', 4096, 0x1000), - sa.Task('CONSOLE', 'console_task', 460, 0x2000), - ] - - export_taskinfo = mock.MagicMock( - get_ro_taskinfos=mock.MagicMock(side_effect=ro_taskinfos), - get_rw_taskinfos=mock.MagicMock(side_effect=rw_taskinfos)) - - tasklist = sa.LoadTasklist('RO', export_taskinfo, self.analyzer.symbols) - self.assertEqual(tasklist, expect_ro_tasklist) - tasklist = sa.LoadTasklist('RW', export_taskinfo, self.analyzer.symbols) - self.assertEqual(tasklist, expect_rw_tasklist) - - def testResolveAnnotation(self): - self.analyzer.annotation = {} - (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() - self.assertEqual(add_rules, {}) - self.assertEqual(remove_rules, []) - self.assertEqual(invalid_sigtxts, set()) - - self.analyzer.annotation = {'add': None, 'remove': None} - (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() - self.assertEqual(add_rules, {}) - self.assertEqual(remove_rules, []) - self.assertEqual(invalid_sigtxts, set()) - - self.analyzer.annotation = { - 'add': None, - 'remove': [ - [['a', 'b'], ['0', '[', '2'], 'x'], - [['a', 'b[x:3]'], ['0', '1', '2'], 'x'], - ], - } - (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() - self.assertEqual(add_rules, {}) - self.assertEqual(list.sort(remove_rules), list.sort([ - [('a', None, None), ('1', None, None), ('x', None, None)], - [('a', None, None), ('0', None, None), ('x', None, None)], - [('a', None, None), ('2', None, None), ('x', None, None)], - [('b', os.path.abspath('x'), 3), ('1', None, None), ('x', None, None)], - [('b', os.path.abspath('x'), 3), ('0', None, None), ('x', None, None)], - [('b', os.path.abspath('x'), 3), ('2', None, None), ('x', None, None)], - ])) - self.assertEqual(invalid_sigtxts, {'['}) - - self.analyzer.annotation = { - 'add': { - 'touchpad_calc': [ dict(name='__array', stride=8, offset=4) ], + """Tests for class StackAnalyzer.""" + + def setUp(self): + symbols = [ + sa.Symbol(0x1000, "F", 0x15C, "hook_task"), + sa.Symbol(0x2000, "F", 0x51C, "console_task"), + sa.Symbol(0x3200, "O", 0x124, "__just_data"), + sa.Symbol(0x4000, "F", 0x11C, "touchpad_calc"), + sa.Symbol(0x5000, "F", 0x12C, "touchpad_calc.constprop.42"), + sa.Symbol(0x12000, "F", 0x13C, "trackpad_range"), + sa.Symbol(0x13000, "F", 0x200, "inlined_mul"), + sa.Symbol(0x13100, "F", 0x200, "inlined_mul"), + sa.Symbol(0x13100, "F", 0x200, "inlined_mul_alias"), + sa.Symbol(0x20000, "O", 0x0, "__array"), + sa.Symbol(0x20010, "O", 0x0, "__array_end"), + ] + tasklist = [ + sa.Task("HOOKS", "hook_task", 2048, 0x1000), + sa.Task("CONSOLE", "console_task", 460, 0x2000), + ] + # Array at 0x20000 that contains pointers to hook_task and console_task, + # with stride=8, offset=4 + rodata = (0x20000, [0xDEAD1000, 0x00001000, 0xDEAD2000, 0x00002000]) + options = mock.MagicMock( + elf_path="./ec.RW.elf", + export_taskinfo="fake", + section="RW", + objdump="objdump", + addr2line="addr2line", + annotation=None, + ) + self.analyzer = sa.StackAnalyzer(options, symbols, rodata, tasklist, {}) + + def testParseSymbolText(self): + symbol_text = ( + "0 g F .text e8 Foo\n" + "0000dead w F .text 000000e8 .hidden Bar\n" + "deadbeef l O .bss 00000004 .hidden Woooo\n" + "deadbee g O .rodata 00000008 __Hooo_ooo\n" + "deadbee g .rodata 00000000 __foo_doo_coo_end\n" + ) + symbols = sa.ParseSymbolText(symbol_text) + expect_symbols = [ + sa.Symbol(0x0, "F", 0xE8, "Foo"), + sa.Symbol(0xDEAD, "F", 0xE8, "Bar"), + sa.Symbol(0xDEADBEEF, "O", 0x4, "Woooo"), + sa.Symbol(0xDEADBEE, "O", 0x8, "__Hooo_ooo"), + sa.Symbol(0xDEADBEE, "O", 0x0, "__foo_doo_coo_end"), + ] + self.assertEqual(symbols, expect_symbols) + + def testParseRoData(self): + rodata_text = ( + "\n" + "Contents of section .rodata:\n" + " 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n" + ) + rodata = sa.ParseRoDataText(rodata_text) + expect_rodata = ( + 0x20000, + [0x0010ADDE, 0x00001000, 0x0020ADDE, 0x00002000], + ) + self.assertEqual(rodata, expect_rodata) + + def testLoadTasklist(self): + def tasklist_to_taskinfos(pointer, tasklist): + taskinfos = [] + for task in tasklist: + taskinfos.append( + sa.TaskInfo( + name=task.name.encode("utf-8"), + routine=task.routine_name.encode("utf-8"), + stack_size=task.stack_max_size, + ) + ) + + TaskInfoArray = sa.TaskInfo * len(taskinfos) + pointer.contents.contents = TaskInfoArray(*taskinfos) + return len(taskinfos) + + def ro_taskinfos(pointer): + return tasklist_to_taskinfos(pointer, expect_ro_tasklist) + + def rw_taskinfos(pointer): + return tasklist_to_taskinfos(pointer, expect_rw_tasklist) + + expect_ro_tasklist = [ + sa.Task("HOOKS", "hook_task", 2048, 0x1000), + ] + + expect_rw_tasklist = [ + sa.Task("HOOKS", "hook_task", 2048, 0x1000), + sa.Task("WOOKS", "hook_task", 4096, 0x1000), + sa.Task("CONSOLE", "console_task", 460, 0x2000), + ] + + export_taskinfo = mock.MagicMock( + get_ro_taskinfos=mock.MagicMock(side_effect=ro_taskinfos), + get_rw_taskinfos=mock.MagicMock(side_effect=rw_taskinfos), + ) + + tasklist = sa.LoadTasklist("RO", export_taskinfo, self.analyzer.symbols) + self.assertEqual(tasklist, expect_ro_tasklist) + tasklist = sa.LoadTasklist("RW", export_taskinfo, self.analyzer.symbols) + self.assertEqual(tasklist, expect_rw_tasklist) + + def testResolveAnnotation(self): + self.analyzer.annotation = {} + ( + add_rules, + remove_rules, + invalid_sigtxts, + ) = self.analyzer.LoadAnnotation() + self.assertEqual(add_rules, {}) + self.assertEqual(remove_rules, []) + self.assertEqual(invalid_sigtxts, set()) + + self.analyzer.annotation = {"add": None, "remove": None} + ( + add_rules, + remove_rules, + invalid_sigtxts, + ) = self.analyzer.LoadAnnotation() + self.assertEqual(add_rules, {}) + self.assertEqual(remove_rules, []) + self.assertEqual(invalid_sigtxts, set()) + + self.analyzer.annotation = { + "add": None, + "remove": [ + [["a", "b"], ["0", "[", "2"], "x"], + [["a", "b[x:3]"], ["0", "1", "2"], "x"], + ], } - } - (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() - self.assertEqual(add_rules, { - ('touchpad_calc', None, None): - set([('console_task', None, None), ('hook_task', None, None)])}) - - funcs = { - 0x1000: sa.Function(0x1000, 'hook_task', 0, []), - 0x2000: sa.Function(0x2000, 'console_task', 0, []), - 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []), - 0x5000: sa.Function(0x5000, 'touchpad_calc.constprop.42', 0, []), - 0x13000: sa.Function(0x13000, 'inlined_mul', 0, []), - 0x13100: sa.Function(0x13100, 'inlined_mul', 0, []), - } - funcs[0x1000].callsites = [ - sa.Callsite(0x1002, None, False, None)] - # Set address_to_line_cache to fake the results of addr2line. - self.analyzer.address_to_line_cache = { - (0x1000, False): [('hook_task', os.path.abspath('a.c'), 10)], - (0x1002, False): [('toot_calc', os.path.abspath('t.c'), 1234)], - (0x2000, False): [('console_task', os.path.abspath('b.c'), 20)], - (0x4000, False): [('toudhpad_calc', os.path.abspath('a.c'), 20)], - (0x5000, False): [ - ('touchpad_calc.constprop.42', os.path.abspath('b.c'), 40)], - (0x12000, False): [('trackpad_range', os.path.abspath('t.c'), 10)], - (0x13000, False): [('inlined_mul', os.path.abspath('x.c'), 12)], - (0x13100, False): [('inlined_mul', os.path.abspath('x.c'), 12)], - } - self.analyzer.annotation = { - 'add': { - 'hook_task.lto.573': ['touchpad_calc.lto.2501[a.c]'], - 'console_task': ['touchpad_calc[b.c]', 'inlined_mul_alias'], - 'hook_task[q.c]': ['hook_task'], - 'inlined_mul[x.c]': ['inlined_mul'], - 'toot_calc[t.c:1234]': ['hook_task'], - }, - 'remove': [ - ['touchpad?calc['], - 'touchpad_calc', - ['touchpad_calc[a.c]'], - ['task_unk[a.c]'], - ['touchpad_calc[x/a.c]'], - ['trackpad_range'], - ['inlined_mul'], - ['inlined_mul', 'console_task', 'touchpad_calc[a.c]'], - ['inlined_mul', 'inlined_mul_alias', 'console_task'], - ['inlined_mul', 'inlined_mul_alias', 'console_task'], - ], - } - (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() - self.assertEqual(invalid_sigtxts, {'touchpad?calc['}) - - signature_set = set() - for src_sig, dst_sigs in add_rules.items(): - signature_set.add(src_sig) - signature_set.update(dst_sigs) - - for remove_sigs in remove_rules: - signature_set.update(remove_sigs) - - (signature_map, failed_sigs) = self.analyzer.MapAnnotation(funcs, - signature_set) - result = self.analyzer.ResolveAnnotation(funcs) - (add_set, remove_list, eliminated_addrs, failed_sigs) = result - - expect_signature_map = { - ('hook_task', None, None): {funcs[0x1000]}, - ('touchpad_calc', os.path.abspath('a.c'), None): {funcs[0x4000]}, - ('touchpad_calc', os.path.abspath('b.c'), None): {funcs[0x5000]}, - ('console_task', None, None): {funcs[0x2000]}, - ('inlined_mul_alias', None, None): {funcs[0x13100]}, - ('inlined_mul', os.path.abspath('x.c'), None): {funcs[0x13000], - funcs[0x13100]}, - ('inlined_mul', None, None): {funcs[0x13000], funcs[0x13100]}, - } - self.assertEqual(len(signature_map), len(expect_signature_map)) - for sig, funclist in signature_map.items(): - self.assertEqual(set(funclist), expect_signature_map[sig]) - - self.assertEqual(add_set, { - (funcs[0x1000], funcs[0x4000]), - (funcs[0x1000], funcs[0x1000]), - (funcs[0x2000], funcs[0x5000]), - (funcs[0x2000], funcs[0x13100]), - (funcs[0x13000], funcs[0x13000]), - (funcs[0x13000], funcs[0x13100]), - (funcs[0x13100], funcs[0x13000]), - (funcs[0x13100], funcs[0x13100]), - }) - expect_remove_list = [ - [funcs[0x4000]], - [funcs[0x13000]], - [funcs[0x13100]], - [funcs[0x13000], funcs[0x2000], funcs[0x4000]], - [funcs[0x13100], funcs[0x2000], funcs[0x4000]], - [funcs[0x13000], funcs[0x13100], funcs[0x2000]], - [funcs[0x13100], funcs[0x13100], funcs[0x2000]], - ] - self.assertEqual(len(remove_list), len(expect_remove_list)) - for remove_path in remove_list: - self.assertTrue(remove_path in expect_remove_list) - - self.assertEqual(eliminated_addrs, {0x1002}) - self.assertEqual(failed_sigs, { - ('touchpad?calc[', sa.StackAnalyzer.ANNOTATION_ERROR_INVALID), - ('touchpad_calc', sa.StackAnalyzer.ANNOTATION_ERROR_AMBIGUOUS), - ('hook_task[q.c]', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), - ('task_unk[a.c]', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), - ('touchpad_calc[x/a.c]', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), - ('trackpad_range', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), - }) - - def testPreprocessAnnotation(self): - funcs = { - 0x1000: sa.Function(0x1000, 'hook_task', 0, []), - 0x2000: sa.Function(0x2000, 'console_task', 0, []), - 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []), - } - funcs[0x1000].callsites = [ - sa.Callsite(0x1002, 0x1000, False, funcs[0x1000])] - funcs[0x2000].callsites = [ - sa.Callsite(0x2002, 0x1000, False, funcs[0x1000]), - sa.Callsite(0x2006, None, True, None), - ] - add_set = { - (funcs[0x2000], funcs[0x2000]), - (funcs[0x2000], funcs[0x4000]), - (funcs[0x4000], funcs[0x1000]), - (funcs[0x4000], funcs[0x2000]), - } - remove_list = [ - [funcs[0x1000]], - [funcs[0x2000], funcs[0x2000]], - [funcs[0x4000], funcs[0x1000]], - [funcs[0x2000], funcs[0x4000], funcs[0x2000]], - [funcs[0x4000], funcs[0x1000], funcs[0x4000]], - ] - eliminated_addrs = {0x2006} - - remaining_remove_list = self.analyzer.PreprocessAnnotation(funcs, - add_set, - remove_list, - eliminated_addrs) - - expect_funcs = { - 0x1000: sa.Function(0x1000, 'hook_task', 0, []), - 0x2000: sa.Function(0x2000, 'console_task', 0, []), - 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []), - } - expect_funcs[0x2000].callsites = [ - sa.Callsite(None, 0x4000, False, expect_funcs[0x4000])] - expect_funcs[0x4000].callsites = [ - sa.Callsite(None, 0x2000, False, expect_funcs[0x2000])] - self.assertEqual(funcs, expect_funcs) - self.assertEqual(remaining_remove_list, [ - [funcs[0x2000], funcs[0x4000], funcs[0x2000]], - ]) - - def testAndesAnalyzeDisassembly(self): - disasm_text = ( - '\n' - 'build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le' - '\n' - 'Disassembly of section .text:\n' - '\n' - '00000900 <wook_task>:\n' - ' ...\n' - '00001000 <hook_task>:\n' - ' 1000: fc 42\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n' - ' 1004: 47 70\t\tmovi55 $r0, #1\n' - ' 1006: b1 13\tbnezs8 100929de <flash_command_write>\n' - ' 1008: 00 01 5c fc\tbne $r6, $r0, 2af6a\n' - '00002000 <console_task>:\n' - ' 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n' - ' 2002: f0 0e fc c5\tjal 1000 <hook_task>\n' - ' 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n' - ' 200a: de ad be ef\tswi.gp $r0, [ + #-11036]\n' - '00004000 <touchpad_calc>:\n' - ' 4000: 47 70\t\tmovi55 $r0, #1\n' - '00010000 <look_task>:' - ) - function_map = self.analyzer.AnalyzeDisassembly(disasm_text) - func_hook_task = sa.Function(0x1000, 'hook_task', 48, [ - sa.Callsite(0x1006, 0x100929de, True, None)]) - expect_funcmap = { - 0x1000: func_hook_task, - 0x2000: sa.Function(0x2000, 'console_task', 16, - [sa.Callsite(0x2002, 0x1000, False, func_hook_task), - sa.Callsite(0x2006, 0x53968, True, None)]), - 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []), - } - self.assertEqual(function_map, expect_funcmap) - - def testArmAnalyzeDisassembly(self): - disasm_text = ( - '\n' - 'build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm' - '\n' - 'Disassembly of section .text:\n' - '\n' - '00000900 <wook_task>:\n' - ' ...\n' - '00001000 <hook_task>:\n' - ' 1000: dead beef\tfake\n' - ' 1004: 4770\t\tbx lr\n' - ' 1006: b113\tcbz r3, 100929de <flash_command_write>\n' - ' 1008: 00015cfc\t.word 0x00015cfc\n' - '00002000 <console_task>:\n' - ' 2000: b508\t\tpush {r3, lr} ; malformed comments,; r0, r1 \n' - ' 2002: f00e fcc5\tbl 1000 <hook_task>\n' - ' 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n' - ' 200a: dead beef\tfake\n' - '00004000 <touchpad_calc>:\n' - ' 4000: 4770\t\tbx lr\n' - '00010000 <look_task>:' - ) - function_map = self.analyzer.AnalyzeDisassembly(disasm_text) - func_hook_task = sa.Function(0x1000, 'hook_task', 0, [ - sa.Callsite(0x1006, 0x100929de, True, None)]) - expect_funcmap = { - 0x1000: func_hook_task, - 0x2000: sa.Function(0x2000, 'console_task', 8, - [sa.Callsite(0x2002, 0x1000, False, func_hook_task), - sa.Callsite(0x2006, 0x53968, True, None)]), - 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []), - } - self.assertEqual(function_map, expect_funcmap) - - def testAnalyzeCallGraph(self): - funcs = { - 0x1000: sa.Function(0x1000, 'hook_task', 0, []), - 0x2000: sa.Function(0x2000, 'console_task', 8, []), - 0x3000: sa.Function(0x3000, 'task_a', 12, []), - 0x4000: sa.Function(0x4000, 'task_b', 96, []), - 0x5000: sa.Function(0x5000, 'task_c', 32, []), - 0x6000: sa.Function(0x6000, 'task_d', 100, []), - 0x7000: sa.Function(0x7000, 'task_e', 24, []), - 0x8000: sa.Function(0x8000, 'task_f', 20, []), - 0x9000: sa.Function(0x9000, 'task_g', 20, []), - 0x10000: sa.Function(0x10000, 'task_x', 16, []), - } - funcs[0x1000].callsites = [ - sa.Callsite(0x1002, 0x3000, False, funcs[0x3000]), - sa.Callsite(0x1006, 0x4000, False, funcs[0x4000])] - funcs[0x2000].callsites = [ - sa.Callsite(0x2002, 0x5000, False, funcs[0x5000]), - sa.Callsite(0x2006, 0x2000, False, funcs[0x2000]), - sa.Callsite(0x200a, 0x10000, False, funcs[0x10000])] - funcs[0x3000].callsites = [ - sa.Callsite(0x3002, 0x4000, False, funcs[0x4000]), - sa.Callsite(0x3006, 0x1000, False, funcs[0x1000])] - funcs[0x4000].callsites = [ - sa.Callsite(0x4002, 0x6000, True, funcs[0x6000]), - sa.Callsite(0x4006, 0x7000, False, funcs[0x7000]), - sa.Callsite(0x400a, 0x8000, False, funcs[0x8000])] - funcs[0x5000].callsites = [ - sa.Callsite(0x5002, 0x4000, False, funcs[0x4000])] - funcs[0x7000].callsites = [ - sa.Callsite(0x7002, 0x7000, False, funcs[0x7000])] - funcs[0x8000].callsites = [ - sa.Callsite(0x8002, 0x9000, False, funcs[0x9000])] - funcs[0x9000].callsites = [ - sa.Callsite(0x9002, 0x4000, False, funcs[0x4000])] - funcs[0x10000].callsites = [ - sa.Callsite(0x10002, 0x2000, False, funcs[0x2000])] - - cycles = self.analyzer.AnalyzeCallGraph(funcs, [ - [funcs[0x2000]] * 2, - [funcs[0x10000], funcs[0x2000]] * 3, - [funcs[0x1000], funcs[0x3000], funcs[0x1000]] - ]) - - expect_func_stack = { - 0x1000: (268, [funcs[0x1000], - funcs[0x3000], - funcs[0x4000], - funcs[0x8000], - funcs[0x9000], - funcs[0x4000], - funcs[0x7000]]), - 0x2000: (208, [funcs[0x2000], - funcs[0x10000], - funcs[0x2000], - funcs[0x10000], - funcs[0x2000], - funcs[0x5000], - funcs[0x4000], - funcs[0x7000]]), - 0x3000: (280, [funcs[0x3000], - funcs[0x1000], - funcs[0x3000], - funcs[0x4000], - funcs[0x8000], - funcs[0x9000], - funcs[0x4000], - funcs[0x7000]]), - 0x4000: (120, [funcs[0x4000], funcs[0x7000]]), - 0x5000: (152, [funcs[0x5000], funcs[0x4000], funcs[0x7000]]), - 0x6000: (100, [funcs[0x6000]]), - 0x7000: (24, [funcs[0x7000]]), - 0x8000: (160, [funcs[0x8000], - funcs[0x9000], - funcs[0x4000], - funcs[0x7000]]), - 0x9000: (140, [funcs[0x9000], funcs[0x4000], funcs[0x7000]]), - 0x10000: (200, [funcs[0x10000], - funcs[0x2000], - funcs[0x10000], - funcs[0x2000], - funcs[0x5000], - funcs[0x4000], - funcs[0x7000]]), - } - expect_cycles = [ - {funcs[0x4000], funcs[0x8000], funcs[0x9000]}, - {funcs[0x7000]}, - ] - for func in funcs.values(): - (stack_max_usage, stack_max_path) = expect_func_stack[func.address] - self.assertEqual(func.stack_max_usage, stack_max_usage) - self.assertEqual(func.stack_max_path, stack_max_path) - - self.assertEqual(len(cycles), len(expect_cycles)) - for cycle in cycles: - self.assertTrue(cycle in expect_cycles) - - @mock.patch('subprocess.check_output') - def testAddressToLine(self, checkoutput_mock): - checkoutput_mock.return_value = 'fake_func\n/test.c:1' - self.assertEqual(self.analyzer.AddressToLine(0x1234), - [('fake_func', '/test.c', 1)]) - checkoutput_mock.assert_called_once_with( - ['addr2line', '-f', '-e', './ec.RW.elf', '1234'], encoding='utf-8') - checkoutput_mock.reset_mock() - - checkoutput_mock.return_value = 'fake_func\n/a.c:1\nbake_func\n/b.c:2\n' - self.assertEqual(self.analyzer.AddressToLine(0x1234, True), - [('fake_func', '/a.c', 1), ('bake_func', '/b.c', 2)]) - checkoutput_mock.assert_called_once_with( - ['addr2line', '-f', '-e', './ec.RW.elf', '1234', '-i'], - encoding='utf-8') - checkoutput_mock.reset_mock() - - checkoutput_mock.return_value = 'fake_func\n/test.c:1 (discriminator 128)' - self.assertEqual(self.analyzer.AddressToLine(0x12345), - [('fake_func', '/test.c', 1)]) - checkoutput_mock.assert_called_once_with( - ['addr2line', '-f', '-e', './ec.RW.elf', '12345'], encoding='utf-8') - checkoutput_mock.reset_mock() - - checkoutput_mock.return_value = '??\n:?\nbake_func\n/b.c:2\n' - self.assertEqual(self.analyzer.AddressToLine(0x123456), - [None, ('bake_func', '/b.c', 2)]) - checkoutput_mock.assert_called_once_with( - ['addr2line', '-f', '-e', './ec.RW.elf', '123456'], encoding='utf-8') - checkoutput_mock.reset_mock() - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'addr2line failed to resolve lines.'): - checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '') - self.analyzer.AddressToLine(0x5678) - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'Failed to run addr2line.'): - checkoutput_mock.side_effect = OSError() - self.analyzer.AddressToLine(0x9012) - - @mock.patch('subprocess.check_output') - @mock.patch('stack_analyzer.StackAnalyzer.AddressToLine') - def testAndesAnalyze(self, addrtoline_mock, checkoutput_mock): - disasm_text = ( - '\n' - 'build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le' - '\n' - 'Disassembly of section .text:\n' - '\n' - '00000900 <wook_task>:\n' - ' ...\n' - '00001000 <hook_task>:\n' - ' 1000: fc 00\t\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n' - ' 1002: 47 70\t\tmovi55 $r0, #1\n' - ' 1006: 00 01 5c fc\tbne $r6, $r0, 2af6a\n' - '00002000 <console_task>:\n' - ' 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n' - ' 2002: f0 0e fc c5\tjal 1000 <hook_task>\n' - ' 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n' - ' 200a: 12 34 56 78\tjral5 $r0\n' - ) - - addrtoline_mock.return_value = [('??', '??', 0)] - self.analyzer.annotation = { - 'exception_frame_size': 64, - 'remove': [['fake_func']], - } - - with mock.patch('builtins.print') as print_mock: - checkoutput_mock.return_value = disasm_text - self.analyzer.Analyze() - print_mock.assert_has_calls([ - mock.call( - 'Task: HOOKS, Max size: 96 (32 + 64), Allocated size: 2048'), - mock.call('Call Trace:'), - mock.call(' hook_task (32) [??:0] 1000'), - mock.call( - 'Task: CONSOLE, Max size: 112 (48 + 64), Allocated size: 460'), - mock.call('Call Trace:'), - mock.call(' console_task (16) [??:0] 2000'), - mock.call(' -> ??[??:0] 2002'), - mock.call(' hook_task (32) [??:0] 1000'), - mock.call('Unresolved indirect callsites:'), - mock.call(' In function console_task:'), - mock.call(' -> ??[??:0] 200a'), - mock.call('Unresolved annotation signatures:'), - mock.call(' fake_func: function is not found'), - ]) - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'Failed to run objdump.'): - checkoutput_mock.side_effect = OSError() - self.analyzer.Analyze() - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'objdump failed to disassemble.'): - checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '') - self.analyzer.Analyze() - - @mock.patch('subprocess.check_output') - @mock.patch('stack_analyzer.StackAnalyzer.AddressToLine') - def testArmAnalyze(self, addrtoline_mock, checkoutput_mock): - disasm_text = ( - '\n' - 'build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm' - '\n' - 'Disassembly of section .text:\n' - '\n' - '00000900 <wook_task>:\n' - ' ...\n' - '00001000 <hook_task>:\n' - ' 1000: b508\t\tpush {r3, lr}\n' - ' 1002: 4770\t\tbx lr\n' - ' 1006: 00015cfc\t.word 0x00015cfc\n' - '00002000 <console_task>:\n' - ' 2000: b508\t\tpush {r3, lr}\n' - ' 2002: f00e fcc5\tbl 1000 <hook_task>\n' - ' 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n' - ' 200a: 1234 5678\tb.w sl\n' - ) - - addrtoline_mock.return_value = [('??', '??', 0)] - self.analyzer.annotation = { - 'exception_frame_size': 64, - 'remove': [['fake_func']], - } - - with mock.patch('builtins.print') as print_mock: - checkoutput_mock.return_value = disasm_text - self.analyzer.Analyze() - print_mock.assert_has_calls([ - mock.call( - 'Task: HOOKS, Max size: 72 (8 + 64), Allocated size: 2048'), - mock.call('Call Trace:'), - mock.call(' hook_task (8) [??:0] 1000'), - mock.call( - 'Task: CONSOLE, Max size: 80 (16 + 64), Allocated size: 460'), - mock.call('Call Trace:'), - mock.call(' console_task (8) [??:0] 2000'), - mock.call(' -> ??[??:0] 2002'), - mock.call(' hook_task (8) [??:0] 1000'), - mock.call('Unresolved indirect callsites:'), - mock.call(' In function console_task:'), - mock.call(' -> ??[??:0] 200a'), - mock.call('Unresolved annotation signatures:'), - mock.call(' fake_func: function is not found'), - ]) - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'Failed to run objdump.'): - checkoutput_mock.side_effect = OSError() - self.analyzer.Analyze() - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'objdump failed to disassemble.'): - checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '') - self.analyzer.Analyze() - - @mock.patch('subprocess.check_output') - @mock.patch('stack_analyzer.ParseArgs') - def testMain(self, parseargs_mock, checkoutput_mock): - symbol_text = ('1000 g F .text 0000015c .hidden hook_task\n' - '2000 g F .text 0000051c .hidden console_task\n') - rodata_text = ( - '\n' - 'Contents of section .rodata:\n' - ' 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n' - ) - - args = mock.MagicMock(elf_path='./ec.RW.elf', - export_taskinfo='fake', - section='RW', - objdump='objdump', - addr2line='addr2line', - annotation='fake') - parseargs_mock.return_value = args - - with mock.patch('os.path.exists') as path_mock: - path_mock.return_value = False - with mock.patch('builtins.print') as print_mock: - with mock.patch('builtins.open', mock.mock_open()) as open_mock: - sa.main() - print_mock.assert_any_call( - 'Warning: Annotation file fake does not exist.') - - with mock.patch('os.path.exists') as path_mock: - path_mock.return_value = True - with mock.patch('builtins.print') as print_mock: - with mock.patch('builtins.open', mock.mock_open()) as open_mock: - open_mock.side_effect = IOError() - sa.main() - print_mock.assert_called_once_with( - 'Error: Failed to open annotation file fake.') - - with mock.patch('builtins.print') as print_mock: - with mock.patch('builtins.open', mock.mock_open()) as open_mock: - open_mock.return_value.read.side_effect = ['{', ''] - sa.main() - open_mock.assert_called_once_with('fake', 'r') - print_mock.assert_called_once_with( - 'Error: Failed to parse annotation file fake.') - - with mock.patch('builtins.print') as print_mock: - with mock.patch('builtins.open', - mock.mock_open(read_data='')) as open_mock: - sa.main() - print_mock.assert_called_once_with( - 'Error: Invalid annotation file fake.') - - args.annotation = None - - with mock.patch('builtins.print') as print_mock: - checkoutput_mock.side_effect = [symbol_text, rodata_text] - sa.main() - print_mock.assert_called_once_with( - 'Error: Failed to load export_taskinfo.') - - with mock.patch('builtins.print') as print_mock: - checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '') - sa.main() - print_mock.assert_called_once_with( - 'Error: objdump failed to dump symbol table or rodata.') - - with mock.patch('builtins.print') as print_mock: - checkoutput_mock.side_effect = OSError() - sa.main() - print_mock.assert_called_once_with('Error: Failed to run objdump.') - - -if __name__ == '__main__': - unittest.main() + ( + add_rules, + remove_rules, + invalid_sigtxts, + ) = self.analyzer.LoadAnnotation() + self.assertEqual(add_rules, {}) + self.assertEqual( + list.sort(remove_rules), + list.sort( + [ + [("a", None, None), ("1", None, None), ("x", None, None)], + [("a", None, None), ("0", None, None), ("x", None, None)], + [("a", None, None), ("2", None, None), ("x", None, None)], + [ + ("b", os.path.abspath("x"), 3), + ("1", None, None), + ("x", None, None), + ], + [ + ("b", os.path.abspath("x"), 3), + ("0", None, None), + ("x", None, None), + ], + [ + ("b", os.path.abspath("x"), 3), + ("2", None, None), + ("x", None, None), + ], + ] + ), + ) + self.assertEqual(invalid_sigtxts, {"["}) + + self.analyzer.annotation = { + "add": { + "touchpad_calc": [dict(name="__array", stride=8, offset=4)], + } + } + ( + add_rules, + remove_rules, + invalid_sigtxts, + ) = self.analyzer.LoadAnnotation() + self.assertEqual( + add_rules, + { + ("touchpad_calc", None, None): set( + [("console_task", None, None), ("hook_task", None, None)] + ) + }, + ) + + funcs = { + 0x1000: sa.Function(0x1000, "hook_task", 0, []), + 0x2000: sa.Function(0x2000, "console_task", 0, []), + 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []), + 0x5000: sa.Function(0x5000, "touchpad_calc.constprop.42", 0, []), + 0x13000: sa.Function(0x13000, "inlined_mul", 0, []), + 0x13100: sa.Function(0x13100, "inlined_mul", 0, []), + } + funcs[0x1000].callsites = [sa.Callsite(0x1002, None, False, None)] + # Set address_to_line_cache to fake the results of addr2line. + self.analyzer.address_to_line_cache = { + (0x1000, False): [("hook_task", os.path.abspath("a.c"), 10)], + (0x1002, False): [("toot_calc", os.path.abspath("t.c"), 1234)], + (0x2000, False): [("console_task", os.path.abspath("b.c"), 20)], + (0x4000, False): [("toudhpad_calc", os.path.abspath("a.c"), 20)], + (0x5000, False): [ + ("touchpad_calc.constprop.42", os.path.abspath("b.c"), 40) + ], + (0x12000, False): [("trackpad_range", os.path.abspath("t.c"), 10)], + (0x13000, False): [("inlined_mul", os.path.abspath("x.c"), 12)], + (0x13100, False): [("inlined_mul", os.path.abspath("x.c"), 12)], + } + self.analyzer.annotation = { + "add": { + "hook_task.lto.573": ["touchpad_calc.lto.2501[a.c]"], + "console_task": ["touchpad_calc[b.c]", "inlined_mul_alias"], + "hook_task[q.c]": ["hook_task"], + "inlined_mul[x.c]": ["inlined_mul"], + "toot_calc[t.c:1234]": ["hook_task"], + }, + "remove": [ + ["touchpad?calc["], + "touchpad_calc", + ["touchpad_calc[a.c]"], + ["task_unk[a.c]"], + ["touchpad_calc[x/a.c]"], + ["trackpad_range"], + ["inlined_mul"], + ["inlined_mul", "console_task", "touchpad_calc[a.c]"], + ["inlined_mul", "inlined_mul_alias", "console_task"], + ["inlined_mul", "inlined_mul_alias", "console_task"], + ], + } + ( + add_rules, + remove_rules, + invalid_sigtxts, + ) = self.analyzer.LoadAnnotation() + self.assertEqual(invalid_sigtxts, {"touchpad?calc["}) + + signature_set = set() + for src_sig, dst_sigs in add_rules.items(): + signature_set.add(src_sig) + signature_set.update(dst_sigs) + + for remove_sigs in remove_rules: + signature_set.update(remove_sigs) + + (signature_map, failed_sigs) = self.analyzer.MapAnnotation( + funcs, signature_set + ) + result = self.analyzer.ResolveAnnotation(funcs) + (add_set, remove_list, eliminated_addrs, failed_sigs) = result + + expect_signature_map = { + ("hook_task", None, None): {funcs[0x1000]}, + ("touchpad_calc", os.path.abspath("a.c"), None): {funcs[0x4000]}, + ("touchpad_calc", os.path.abspath("b.c"), None): {funcs[0x5000]}, + ("console_task", None, None): {funcs[0x2000]}, + ("inlined_mul_alias", None, None): {funcs[0x13100]}, + ("inlined_mul", os.path.abspath("x.c"), None): { + funcs[0x13000], + funcs[0x13100], + }, + ("inlined_mul", None, None): {funcs[0x13000], funcs[0x13100]}, + } + self.assertEqual(len(signature_map), len(expect_signature_map)) + for sig, funclist in signature_map.items(): + self.assertEqual(set(funclist), expect_signature_map[sig]) + + self.assertEqual( + add_set, + { + (funcs[0x1000], funcs[0x4000]), + (funcs[0x1000], funcs[0x1000]), + (funcs[0x2000], funcs[0x5000]), + (funcs[0x2000], funcs[0x13100]), + (funcs[0x13000], funcs[0x13000]), + (funcs[0x13000], funcs[0x13100]), + (funcs[0x13100], funcs[0x13000]), + (funcs[0x13100], funcs[0x13100]), + }, + ) + expect_remove_list = [ + [funcs[0x4000]], + [funcs[0x13000]], + [funcs[0x13100]], + [funcs[0x13000], funcs[0x2000], funcs[0x4000]], + [funcs[0x13100], funcs[0x2000], funcs[0x4000]], + [funcs[0x13000], funcs[0x13100], funcs[0x2000]], + [funcs[0x13100], funcs[0x13100], funcs[0x2000]], + ] + self.assertEqual(len(remove_list), len(expect_remove_list)) + for remove_path in remove_list: + self.assertTrue(remove_path in expect_remove_list) + + self.assertEqual(eliminated_addrs, {0x1002}) + self.assertEqual( + failed_sigs, + { + ("touchpad?calc[", sa.StackAnalyzer.ANNOTATION_ERROR_INVALID), + ("touchpad_calc", sa.StackAnalyzer.ANNOTATION_ERROR_AMBIGUOUS), + ("hook_task[q.c]", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), + ("task_unk[a.c]", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), + ( + "touchpad_calc[x/a.c]", + sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND, + ), + ("trackpad_range", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), + }, + ) + + def testPreprocessAnnotation(self): + funcs = { + 0x1000: sa.Function(0x1000, "hook_task", 0, []), + 0x2000: sa.Function(0x2000, "console_task", 0, []), + 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []), + } + funcs[0x1000].callsites = [ + sa.Callsite(0x1002, 0x1000, False, funcs[0x1000]) + ] + funcs[0x2000].callsites = [ + sa.Callsite(0x2002, 0x1000, False, funcs[0x1000]), + sa.Callsite(0x2006, None, True, None), + ] + add_set = { + (funcs[0x2000], funcs[0x2000]), + (funcs[0x2000], funcs[0x4000]), + (funcs[0x4000], funcs[0x1000]), + (funcs[0x4000], funcs[0x2000]), + } + remove_list = [ + [funcs[0x1000]], + [funcs[0x2000], funcs[0x2000]], + [funcs[0x4000], funcs[0x1000]], + [funcs[0x2000], funcs[0x4000], funcs[0x2000]], + [funcs[0x4000], funcs[0x1000], funcs[0x4000]], + ] + eliminated_addrs = {0x2006} + + remaining_remove_list = self.analyzer.PreprocessAnnotation( + funcs, add_set, remove_list, eliminated_addrs + ) + + expect_funcs = { + 0x1000: sa.Function(0x1000, "hook_task", 0, []), + 0x2000: sa.Function(0x2000, "console_task", 0, []), + 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []), + } + expect_funcs[0x2000].callsites = [ + sa.Callsite(None, 0x4000, False, expect_funcs[0x4000]) + ] + expect_funcs[0x4000].callsites = [ + sa.Callsite(None, 0x2000, False, expect_funcs[0x2000]) + ] + self.assertEqual(funcs, expect_funcs) + self.assertEqual( + remaining_remove_list, + [ + [funcs[0x2000], funcs[0x4000], funcs[0x2000]], + ], + ) + + def testAndesAnalyzeDisassembly(self): + disasm_text = ( + "\n" + "build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le" + "\n" + "Disassembly of section .text:\n" + "\n" + "00000900 <wook_task>:\n" + " ...\n" + "00001000 <hook_task>:\n" + " 1000: fc 42\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n" + " 1004: 47 70\t\tmovi55 $r0, #1\n" + " 1006: b1 13\tbnezs8 100929de <flash_command_write>\n" + " 1008: 00 01 5c fc\tbne $r6, $r0, 2af6a\n" + "00002000 <console_task>:\n" + " 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n" + " 2002: f0 0e fc c5\tjal 1000 <hook_task>\n" + " 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n" + " 200a: de ad be ef\tswi.gp $r0, [ + #-11036]\n" + "00004000 <touchpad_calc>:\n" + " 4000: 47 70\t\tmovi55 $r0, #1\n" + "00010000 <look_task>:" + ) + function_map = self.analyzer.AnalyzeDisassembly(disasm_text) + func_hook_task = sa.Function( + 0x1000, + "hook_task", + 48, + [sa.Callsite(0x1006, 0x100929DE, True, None)], + ) + expect_funcmap = { + 0x1000: func_hook_task, + 0x2000: sa.Function( + 0x2000, + "console_task", + 16, + [ + sa.Callsite(0x2002, 0x1000, False, func_hook_task), + sa.Callsite(0x2006, 0x53968, True, None), + ], + ), + 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []), + } + self.assertEqual(function_map, expect_funcmap) + + def testArmAnalyzeDisassembly(self): + disasm_text = ( + "\n" + "build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm" + "\n" + "Disassembly of section .text:\n" + "\n" + "00000900 <wook_task>:\n" + " ...\n" + "00001000 <hook_task>:\n" + " 1000: dead beef\tfake\n" + " 1004: 4770\t\tbx lr\n" + " 1006: b113\tcbz r3, 100929de <flash_command_write>\n" + " 1008: 00015cfc\t.word 0x00015cfc\n" + "00002000 <console_task>:\n" + " 2000: b508\t\tpush {r3, lr} ; malformed comments,; r0, r1 \n" + " 2002: f00e fcc5\tbl 1000 <hook_task>\n" + " 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n" + " 200a: dead beef\tfake\n" + "00004000 <touchpad_calc>:\n" + " 4000: 4770\t\tbx lr\n" + "00010000 <look_task>:" + ) + function_map = self.analyzer.AnalyzeDisassembly(disasm_text) + func_hook_task = sa.Function( + 0x1000, + "hook_task", + 0, + [sa.Callsite(0x1006, 0x100929DE, True, None)], + ) + expect_funcmap = { + 0x1000: func_hook_task, + 0x2000: sa.Function( + 0x2000, + "console_task", + 8, + [ + sa.Callsite(0x2002, 0x1000, False, func_hook_task), + sa.Callsite(0x2006, 0x53968, True, None), + ], + ), + 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []), + } + self.assertEqual(function_map, expect_funcmap) + + def testAnalyzeCallGraph(self): + funcs = { + 0x1000: sa.Function(0x1000, "hook_task", 0, []), + 0x2000: sa.Function(0x2000, "console_task", 8, []), + 0x3000: sa.Function(0x3000, "task_a", 12, []), + 0x4000: sa.Function(0x4000, "task_b", 96, []), + 0x5000: sa.Function(0x5000, "task_c", 32, []), + 0x6000: sa.Function(0x6000, "task_d", 100, []), + 0x7000: sa.Function(0x7000, "task_e", 24, []), + 0x8000: sa.Function(0x8000, "task_f", 20, []), + 0x9000: sa.Function(0x9000, "task_g", 20, []), + 0x10000: sa.Function(0x10000, "task_x", 16, []), + } + funcs[0x1000].callsites = [ + sa.Callsite(0x1002, 0x3000, False, funcs[0x3000]), + sa.Callsite(0x1006, 0x4000, False, funcs[0x4000]), + ] + funcs[0x2000].callsites = [ + sa.Callsite(0x2002, 0x5000, False, funcs[0x5000]), + sa.Callsite(0x2006, 0x2000, False, funcs[0x2000]), + sa.Callsite(0x200A, 0x10000, False, funcs[0x10000]), + ] + funcs[0x3000].callsites = [ + sa.Callsite(0x3002, 0x4000, False, funcs[0x4000]), + sa.Callsite(0x3006, 0x1000, False, funcs[0x1000]), + ] + funcs[0x4000].callsites = [ + sa.Callsite(0x4002, 0x6000, True, funcs[0x6000]), + sa.Callsite(0x4006, 0x7000, False, funcs[0x7000]), + sa.Callsite(0x400A, 0x8000, False, funcs[0x8000]), + ] + funcs[0x5000].callsites = [ + sa.Callsite(0x5002, 0x4000, False, funcs[0x4000]) + ] + funcs[0x7000].callsites = [ + sa.Callsite(0x7002, 0x7000, False, funcs[0x7000]) + ] + funcs[0x8000].callsites = [ + sa.Callsite(0x8002, 0x9000, False, funcs[0x9000]) + ] + funcs[0x9000].callsites = [ + sa.Callsite(0x9002, 0x4000, False, funcs[0x4000]) + ] + funcs[0x10000].callsites = [ + sa.Callsite(0x10002, 0x2000, False, funcs[0x2000]) + ] + + cycles = self.analyzer.AnalyzeCallGraph( + funcs, + [ + [funcs[0x2000]] * 2, + [funcs[0x10000], funcs[0x2000]] * 3, + [funcs[0x1000], funcs[0x3000], funcs[0x1000]], + ], + ) + + expect_func_stack = { + 0x1000: ( + 268, + [ + funcs[0x1000], + funcs[0x3000], + funcs[0x4000], + funcs[0x8000], + funcs[0x9000], + funcs[0x4000], + funcs[0x7000], + ], + ), + 0x2000: ( + 208, + [ + funcs[0x2000], + funcs[0x10000], + funcs[0x2000], + funcs[0x10000], + funcs[0x2000], + funcs[0x5000], + funcs[0x4000], + funcs[0x7000], + ], + ), + 0x3000: ( + 280, + [ + funcs[0x3000], + funcs[0x1000], + funcs[0x3000], + funcs[0x4000], + funcs[0x8000], + funcs[0x9000], + funcs[0x4000], + funcs[0x7000], + ], + ), + 0x4000: (120, [funcs[0x4000], funcs[0x7000]]), + 0x5000: (152, [funcs[0x5000], funcs[0x4000], funcs[0x7000]]), + 0x6000: (100, [funcs[0x6000]]), + 0x7000: (24, [funcs[0x7000]]), + 0x8000: ( + 160, + [funcs[0x8000], funcs[0x9000], funcs[0x4000], funcs[0x7000]], + ), + 0x9000: (140, [funcs[0x9000], funcs[0x4000], funcs[0x7000]]), + 0x10000: ( + 200, + [ + funcs[0x10000], + funcs[0x2000], + funcs[0x10000], + funcs[0x2000], + funcs[0x5000], + funcs[0x4000], + funcs[0x7000], + ], + ), + } + expect_cycles = [ + {funcs[0x4000], funcs[0x8000], funcs[0x9000]}, + {funcs[0x7000]}, + ] + for func in funcs.values(): + (stack_max_usage, stack_max_path) = expect_func_stack[func.address] + self.assertEqual(func.stack_max_usage, stack_max_usage) + self.assertEqual(func.stack_max_path, stack_max_path) + + self.assertEqual(len(cycles), len(expect_cycles)) + for cycle in cycles: + self.assertTrue(cycle in expect_cycles) + + @mock.patch("subprocess.check_output") + def testAddressToLine(self, checkoutput_mock): + checkoutput_mock.return_value = "fake_func\n/test.c:1" + self.assertEqual( + self.analyzer.AddressToLine(0x1234), [("fake_func", "/test.c", 1)] + ) + checkoutput_mock.assert_called_once_with( + ["addr2line", "-f", "-e", "./ec.RW.elf", "1234"], encoding="utf-8" + ) + checkoutput_mock.reset_mock() + + checkoutput_mock.return_value = "fake_func\n/a.c:1\nbake_func\n/b.c:2\n" + self.assertEqual( + self.analyzer.AddressToLine(0x1234, True), + [("fake_func", "/a.c", 1), ("bake_func", "/b.c", 2)], + ) + checkoutput_mock.assert_called_once_with( + ["addr2line", "-f", "-e", "./ec.RW.elf", "1234", "-i"], + encoding="utf-8", + ) + checkoutput_mock.reset_mock() + + checkoutput_mock.return_value = ( + "fake_func\n/test.c:1 (discriminator 128)" + ) + self.assertEqual( + self.analyzer.AddressToLine(0x12345), [("fake_func", "/test.c", 1)] + ) + checkoutput_mock.assert_called_once_with( + ["addr2line", "-f", "-e", "./ec.RW.elf", "12345"], encoding="utf-8" + ) + checkoutput_mock.reset_mock() + + checkoutput_mock.return_value = "??\n:?\nbake_func\n/b.c:2\n" + self.assertEqual( + self.analyzer.AddressToLine(0x123456), + [None, ("bake_func", "/b.c", 2)], + ) + checkoutput_mock.assert_called_once_with( + ["addr2line", "-f", "-e", "./ec.RW.elf", "123456"], encoding="utf-8" + ) + checkoutput_mock.reset_mock() + + with self.assertRaisesRegexp( + sa.StackAnalyzerError, "addr2line failed to resolve lines." + ): + checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "") + self.analyzer.AddressToLine(0x5678) + + with self.assertRaisesRegexp( + sa.StackAnalyzerError, "Failed to run addr2line." + ): + checkoutput_mock.side_effect = OSError() + self.analyzer.AddressToLine(0x9012) + + @mock.patch("subprocess.check_output") + @mock.patch("stack_analyzer.StackAnalyzer.AddressToLine") + def testAndesAnalyze(self, addrtoline_mock, checkoutput_mock): + disasm_text = ( + "\n" + "build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le" + "\n" + "Disassembly of section .text:\n" + "\n" + "00000900 <wook_task>:\n" + " ...\n" + "00001000 <hook_task>:\n" + " 1000: fc 00\t\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n" + " 1002: 47 70\t\tmovi55 $r0, #1\n" + " 1006: 00 01 5c fc\tbne $r6, $r0, 2af6a\n" + "00002000 <console_task>:\n" + " 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n" + " 2002: f0 0e fc c5\tjal 1000 <hook_task>\n" + " 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n" + " 200a: 12 34 56 78\tjral5 $r0\n" + ) + + addrtoline_mock.return_value = [("??", "??", 0)] + self.analyzer.annotation = { + "exception_frame_size": 64, + "remove": [["fake_func"]], + } + + with mock.patch("builtins.print") as print_mock: + checkoutput_mock.return_value = disasm_text + self.analyzer.Analyze() + print_mock.assert_has_calls( + [ + mock.call( + "Task: HOOKS, Max size: 96 (32 + 64), Allocated size: 2048" + ), + mock.call("Call Trace:"), + mock.call(" hook_task (32) [??:0] 1000"), + mock.call( + "Task: CONSOLE, Max size: 112 (48 + 64), Allocated size: 460" + ), + mock.call("Call Trace:"), + mock.call(" console_task (16) [??:0] 2000"), + mock.call(" -> ??[??:0] 2002"), + mock.call(" hook_task (32) [??:0] 1000"), + mock.call("Unresolved indirect callsites:"), + mock.call(" In function console_task:"), + mock.call(" -> ??[??:0] 200a"), + mock.call("Unresolved annotation signatures:"), + mock.call(" fake_func: function is not found"), + ] + ) + + with self.assertRaisesRegexp( + sa.StackAnalyzerError, "Failed to run objdump." + ): + checkoutput_mock.side_effect = OSError() + self.analyzer.Analyze() + + with self.assertRaisesRegexp( + sa.StackAnalyzerError, "objdump failed to disassemble." + ): + checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "") + self.analyzer.Analyze() + + @mock.patch("subprocess.check_output") + @mock.patch("stack_analyzer.StackAnalyzer.AddressToLine") + def testArmAnalyze(self, addrtoline_mock, checkoutput_mock): + disasm_text = ( + "\n" + "build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm" + "\n" + "Disassembly of section .text:\n" + "\n" + "00000900 <wook_task>:\n" + " ...\n" + "00001000 <hook_task>:\n" + " 1000: b508\t\tpush {r3, lr}\n" + " 1002: 4770\t\tbx lr\n" + " 1006: 00015cfc\t.word 0x00015cfc\n" + "00002000 <console_task>:\n" + " 2000: b508\t\tpush {r3, lr}\n" + " 2002: f00e fcc5\tbl 1000 <hook_task>\n" + " 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n" + " 200a: 1234 5678\tb.w sl\n" + ) + + addrtoline_mock.return_value = [("??", "??", 0)] + self.analyzer.annotation = { + "exception_frame_size": 64, + "remove": [["fake_func"]], + } + + with mock.patch("builtins.print") as print_mock: + checkoutput_mock.return_value = disasm_text + self.analyzer.Analyze() + print_mock.assert_has_calls( + [ + mock.call( + "Task: HOOKS, Max size: 72 (8 + 64), Allocated size: 2048" + ), + mock.call("Call Trace:"), + mock.call(" hook_task (8) [??:0] 1000"), + mock.call( + "Task: CONSOLE, Max size: 80 (16 + 64), Allocated size: 460" + ), + mock.call("Call Trace:"), + mock.call(" console_task (8) [??:0] 2000"), + mock.call(" -> ??[??:0] 2002"), + mock.call(" hook_task (8) [??:0] 1000"), + mock.call("Unresolved indirect callsites:"), + mock.call(" In function console_task:"), + mock.call(" -> ??[??:0] 200a"), + mock.call("Unresolved annotation signatures:"), + mock.call(" fake_func: function is not found"), + ] + ) + + with self.assertRaisesRegexp( + sa.StackAnalyzerError, "Failed to run objdump." + ): + checkoutput_mock.side_effect = OSError() + self.analyzer.Analyze() + + with self.assertRaisesRegexp( + sa.StackAnalyzerError, "objdump failed to disassemble." + ): + checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "") + self.analyzer.Analyze() + + @mock.patch("subprocess.check_output") + @mock.patch("stack_analyzer.ParseArgs") + def testMain(self, parseargs_mock, checkoutput_mock): + symbol_text = ( + "1000 g F .text 0000015c .hidden hook_task\n" + "2000 g F .text 0000051c .hidden console_task\n" + ) + rodata_text = ( + "\n" + "Contents of section .rodata:\n" + " 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n" + ) + + args = mock.MagicMock( + elf_path="./ec.RW.elf", + export_taskinfo="fake", + section="RW", + objdump="objdump", + addr2line="addr2line", + annotation="fake", + ) + parseargs_mock.return_value = args + + with mock.patch("os.path.exists") as path_mock: + path_mock.return_value = False + with mock.patch("builtins.print") as print_mock: + with mock.patch("builtins.open", mock.mock_open()) as open_mock: + sa.main() + print_mock.assert_any_call( + "Warning: Annotation file fake does not exist." + ) + + with mock.patch("os.path.exists") as path_mock: + path_mock.return_value = True + with mock.patch("builtins.print") as print_mock: + with mock.patch("builtins.open", mock.mock_open()) as open_mock: + open_mock.side_effect = IOError() + sa.main() + print_mock.assert_called_once_with( + "Error: Failed to open annotation file fake." + ) + + with mock.patch("builtins.print") as print_mock: + with mock.patch("builtins.open", mock.mock_open()) as open_mock: + open_mock.return_value.read.side_effect = ["{", ""] + sa.main() + open_mock.assert_called_once_with("fake", "r") + print_mock.assert_called_once_with( + "Error: Failed to parse annotation file fake." + ) + + with mock.patch("builtins.print") as print_mock: + with mock.patch( + "builtins.open", mock.mock_open(read_data="") + ) as open_mock: + sa.main() + print_mock.assert_called_once_with( + "Error: Invalid annotation file fake." + ) + + args.annotation = None + + with mock.patch("builtins.print") as print_mock: + checkoutput_mock.side_effect = [symbol_text, rodata_text] + sa.main() + print_mock.assert_called_once_with( + "Error: Failed to load export_taskinfo." + ) + + with mock.patch("builtins.print") as print_mock: + checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "") + sa.main() + print_mock.assert_called_once_with( + "Error: objdump failed to dump symbol table or rodata." + ) + + with mock.patch("builtins.print") as print_mock: + checkoutput_mock.side_effect = OSError() + sa.main() + print_mock.assert_called_once_with("Error: Failed to run objdump.") + + +if __name__ == "__main__": + unittest.main() diff --git a/extra/tigertool/ecusb/__init__.py b/extra/tigertool/ecusb/__init__.py index fe4dbc6749..9451551f37 100644 --- a/extra/tigertool/ecusb/__init__.py +++ b/extra/tigertool/ecusb/__init__.py @@ -1,9 +1,5 @@ -# Copyright 2017 The Chromium OS Authors. All rights reserved. +# Copyright 2017 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes -__all__ = ['tiny_servo_common', 'stm32usb', 'stm32uart', 'pty_driver'] +__all__ = ["tiny_servo_common", "stm32usb", "stm32uart", "pty_driver"] diff --git a/extra/tigertool/ecusb/pty_driver.py b/extra/tigertool/ecusb/pty_driver.py index 09ef8c42e4..723bf41b57 100644 --- a/extra/tigertool/ecusb/pty_driver.py +++ b/extra/tigertool/ecusb/pty_driver.py @@ -1,10 +1,6 @@ -# Copyright 2017 The Chromium OS Authors. All rights reserved. +# Copyright 2017 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes """ptyDriver class @@ -17,9 +13,10 @@ import ast import errno import fcntl import os -import pexpect import time -from pexpect import fdpexpect + +import pexpect # pylint:disable=import-error +from pexpect import fdpexpect # pylint:disable=import-error # Expecting a result in 3 seconds is plenty even for slow platforms. DEFAULT_UART_TIMEOUT = 3 @@ -27,281 +24,291 @@ FLUSH_UART_TIMEOUT = 1 class ptyError(Exception): - """Exception class for pty errors.""" + """Exception class for pty errors.""" UART_PARAMS = { - 'uart_cmd': None, - 'uart_multicmd': None, - 'uart_regexp': None, - 'uart_timeout': DEFAULT_UART_TIMEOUT, + "uart_cmd": None, + "uart_multicmd": None, + "uart_regexp": None, + "uart_timeout": DEFAULT_UART_TIMEOUT, } class ptyDriver(object): - """Automate interactive commands on a pty interface.""" - def __init__(self, interface, params, fast=False): - """Init class variables.""" - self._child = None - self._fd = None - self._interface = interface - self._pty_path = self._interface.get_pty() - self._dict = UART_PARAMS.copy() - self._fast = fast - - def __del__(self): - self.close() - - def close(self): - """Close any open files and interfaces.""" - if self._fd: - self._close() - self._interface.close() - - def _open(self): - """Connect to serial device and create pexpect interface.""" - assert self._fd is None - self._fd = os.open(self._pty_path, os.O_RDWR | os.O_NONBLOCK) - # Don't allow forked processes to access. - fcntl.fcntl(self._fd, fcntl.F_SETFD, - fcntl.fcntl(self._fd, fcntl.F_GETFD) | fcntl.FD_CLOEXEC) - self._child = fdpexpect.fdspawn(self._fd) - # pexpect defaults to a 100ms delay before sending characters, to - # work around race conditions in ssh. We don't need this feature - # so we'll change delaybeforesend from 0.1 to 0.001 to speed things up. - if self._fast: - self._child.delaybeforesend = 0.001 - - def _close(self): - """Close serial device connection.""" - os.close(self._fd) - self._fd = None - self._child = None - - def _flush(self): - """Flush device output to prevent previous messages interfering.""" - if self._child.sendline('') != 1: - raise ptyError('Failed to send newline.') - # Have a maximum timeout for the flush operation. We should have cleared - # all data from the buffer, but if data is regularly being generated, we - # can't guarantee it will ever stop. - flush_end_time = time.time() + FLUSH_UART_TIMEOUT - while time.time() <= flush_end_time: - try: - self._child.expect('.', timeout=0.01) - except (pexpect.TIMEOUT, pexpect.EOF): - break - except OSError as e: - # EAGAIN indicates no data available, maybe we didn't wait long enough. - if e.errno != errno.EAGAIN: - raise - break - - def _send(self, cmds): - """Send command to EC. - - This function always flushes serial device before sending, and is used as - a wrapper function to make sure the channel is always flushed before - sending commands. - - Args: - cmds: The commands to send to the device, either a list or a string. - - Raises: - ptyError: Raised when writing to the device fails. - """ - self._flush() - if not isinstance(cmds, list): - cmds = [cmds] - for cmd in cmds: - if self._child.sendline(cmd) != len(cmd) + 1: - raise ptyError('Failed to send command.') - - def _issue_cmd(self, cmds): - """Send command to the device and do not wait for response. - - Args: - cmds: The commands to send to the device, either a list or a string. - """ - self._issue_cmd_get_results(cmds, []) - - def _issue_cmd_get_results(self, cmds, - regex_list, timeout=DEFAULT_UART_TIMEOUT): - """Send command to the device and wait for response. - - This function waits for response message matching a regular - expressions. - - Args: - cmds: The commands issued, either a list or a string. - regex_list: List of Regular expressions used to match response message. - Note1, list must be ordered. - Note2, empty list sends and returns. - timeout: time to wait for matching results before failing. - - Returns: - List of tuples, each of which contains the entire matched string and - all the subgroups of the match. None if not matched. - For example: - response of the given command: - High temp: 37.2 - Low temp: 36.4 - regex_list: - ['High temp: (\d+)\.(\d+)', 'Low temp: (\d+)\.(\d+)'] - returns: - [('High temp: 37.2', '37', '2'), ('Low temp: 36.4', '36', '4')] - - Raises: - ptyError: If timed out waiting for a response - """ - result_list = [] - self._open() - try: - self._send(cmds) - for regex in regex_list: - self._child.expect(regex, timeout) - match = self._child.match - lastindex = match.lastindex if match and match.lastindex else 0 - # Create a tuple which contains the entire matched string and all - # the subgroups of the match. - result = match.group(*range(lastindex + 1)) if match else None - if result: - result = tuple(res.decode('utf-8') for res in result) - result_list.append(result) - except pexpect.TIMEOUT: - raise ptyError('Timeout waiting for response.') - finally: - if not regex_list: - # Must be longer than delaybeforesend - time.sleep(0.1) - self._close() - return result_list - - def _issue_cmd_get_multi_results(self, cmd, regex): - """Send command to the device and wait for multiple response. - - This function waits for arbitrary number of response message - matching a regular expression. - - Args: - cmd: The command issued. - regex: Regular expression used to match response message. - - Returns: - List of tuples, each of which contains the entire matched string and - all the subgroups of the match. None if not matched. - """ - result_list = [] - self._open() - try: - self._send(cmd) - while True: + """Automate interactive commands on a pty interface.""" + + def __init__(self, interface, params, fast=False): + """Init class variables.""" + self._child = None + self._fd = None + self._interface = interface + self._pty_path = self._interface.get_pty() + self._dict = UART_PARAMS.copy() + self._fast = fast + + def __del__(self): + self.close() + + def close(self): + """Close any open files and interfaces.""" + if self._fd: + self._close() + self._interface.close() + + def _open(self): + """Connect to serial device and create pexpect interface.""" + assert self._fd is None + self._fd = os.open(self._pty_path, os.O_RDWR | os.O_NONBLOCK) + # Don't allow forked processes to access. + fcntl.fcntl( + self._fd, + fcntl.F_SETFD, + fcntl.fcntl(self._fd, fcntl.F_GETFD) | fcntl.FD_CLOEXEC, + ) + self._child = fdpexpect.fdspawn(self._fd) + # pexpect defaults to a 100ms delay before sending characters, to + # work around race conditions in ssh. We don't need this feature + # so we'll change delaybeforesend from 0.1 to 0.001 to speed things up. + if self._fast: + self._child.delaybeforesend = 0.001 + + def _close(self): + """Close serial device connection.""" + os.close(self._fd) + self._fd = None + self._child = None + + def _flush(self): + """Flush device output to prevent previous messages interfering.""" + if self._child.sendline("") != 1: + raise ptyError("Failed to send newline.") + # Have a maximum timeout for the flush operation. We should have cleared + # all data from the buffer, but if data is regularly being generated, we + # can't guarantee it will ever stop. + flush_end_time = time.time() + FLUSH_UART_TIMEOUT + while time.time() <= flush_end_time: + try: + self._child.expect(".", timeout=0.01) + except (pexpect.TIMEOUT, pexpect.EOF): + break + except OSError as e: + # EAGAIN indicates no data available, maybe we didn't wait long enough. + if e.errno != errno.EAGAIN: + raise + break + + def _send(self, cmds): + """Send command to EC. + + This function always flushes serial device before sending, and is used as + a wrapper function to make sure the channel is always flushed before + sending commands. + + Args: + cmds: The commands to send to the device, either a list or a string. + + Raises: + ptyError: Raised when writing to the device fails. + """ + self._flush() + if not isinstance(cmds, list): + cmds = [cmds] + for cmd in cmds: + if self._child.sendline(cmd) != len(cmd) + 1: + raise ptyError("Failed to send command.") + + def _issue_cmd(self, cmds): + """Send command to the device and do not wait for response. + + Args: + cmds: The commands to send to the device, either a list or a string. + """ + self._issue_cmd_get_results(cmds, []) + + def _issue_cmd_get_results( + self, cmds, regex_list, timeout=DEFAULT_UART_TIMEOUT + ): + """Send command to the device and wait for response. + + This function waits for response message matching a regular + expressions. + + Args: + cmds: The commands issued, either a list or a string. + regex_list: List of Regular expressions used to match response message. + Note1, list must be ordered. + Note2, empty list sends and returns. + timeout: time to wait for matching results before failing. + + Returns: + List of tuples, each of which contains the entire matched string and + all the subgroups of the match. None if not matched. + For example: + response of the given command: + High temp: 37.2 + Low temp: 36.4 + regex_list: + ['High temp: (\d+)\.(\d+)', 'Low temp: (\d+)\.(\d+)'] + returns: + [('High temp: 37.2', '37', '2'), ('Low temp: 36.4', '36', '4')] + + Raises: + ptyError: If timed out waiting for a response + """ + result_list = [] + self._open() try: - self._child.expect(regex, timeout=0.1) - match = self._child.match - lastindex = match.lastindex if match and match.lastindex else 0 - # Create a tuple which contains the entire matched string and all - # the subgroups of the match. - result = match.group(*range(lastindex + 1)) if match else None - if result: - result = tuple(res.decode('utf-8') for res in result) - result_list.append(result) + self._send(cmds) + for regex in regex_list: + self._child.expect(regex, timeout) + match = self._child.match + lastindex = match.lastindex if match and match.lastindex else 0 + # Create a tuple which contains the entire matched string and all + # the subgroups of the match. + result = match.group(*range(lastindex + 1)) if match else None + if result: + result = tuple(res.decode("utf-8") for res in result) + result_list.append(result) except pexpect.TIMEOUT: - break - finally: - self._close() - return result_list - - def _Set_uart_timeout(self, timeout): - """Set timeout value for waiting for the device response. - - Args: - timeout: Timeout value in second. - """ - self._dict['uart_timeout'] = timeout - - def _Get_uart_timeout(self): - """Get timeout value for waiting for the device response. - - Returns: - Timeout value in second. - """ - return self._dict['uart_timeout'] - - def _Set_uart_regexp(self, regexp): - """Set the list of regular expressions which matches the command response. - - Args: - regexp: A string which contains a list of regular expressions. - """ - if not isinstance(regexp, str): - raise ptyError('The argument regexp should be a string.') - self._dict['uart_regexp'] = ast.literal_eval(regexp) - - def _Get_uart_regexp(self): - """Get the list of regular expressions which matches the command response. - - Returns: - A string which contains a list of regular expressions. - """ - return str(self._dict['uart_regexp']) - - def _Set_uart_cmd(self, cmd): - """Set the UART command and send it to the device. - - If ec_uart_regexp is 'None', the command is just sent and it doesn't care - about its response. - - If ec_uart_regexp is not 'None', the command is send and its response, - which matches the regular expression of ec_uart_regexp, will be kept. - Use its getter to obtain this result. If no match after ec_uart_timeout - seconds, a timeout error will be raised. - - Args: - cmd: A string of UART command. - """ - if self._dict['uart_regexp']: - self._dict['uart_cmd'] = self._issue_cmd_get_results( - cmd, self._dict['uart_regexp'], self._dict['uart_timeout']) - else: - self._dict['uart_cmd'] = None - self._issue_cmd(cmd) - - def _Set_uart_multicmd(self, cmds): - """Set multiple UART commands and send them to the device. - - Note that ec_uart_regexp is not supported to match the results. - - Args: - cmds: A semicolon-separated string of UART commands. - """ - self._issue_cmd(cmds.split(';')) - - def _Get_uart_cmd(self): - """Get the result of the latest UART command. - - Returns: - A string which contains a list of tuples, each of which contains the - entire matched string and all the subgroups of the match. 'None' if - the ec_uart_regexp is 'None'. - """ - return str(self._dict['uart_cmd']) - - def _Set_uart_capture(self, cmd): - """Set UART capture mode (on or off). - - Once capture is enabled, UART output could be collected periodically by - invoking _Get_uart_stream() below. - - Args: - cmd: True for on, False for off - """ - self._interface.set_capture_active(cmd) - - def _Get_uart_capture(self): - """Get the UART capture mode (on or off).""" - return self._interface.get_capture_active() - - def _Get_uart_stream(self): - """Get uart stream generated since last time.""" - return self._interface.get_stream() + raise ptyError("Timeout waiting for response.") + finally: + if not regex_list: + # Must be longer than delaybeforesend + time.sleep(0.1) + self._close() + return result_list + + def _issue_cmd_get_multi_results(self, cmd, regex): + """Send command to the device and wait for multiple response. + + This function waits for arbitrary number of response message + matching a regular expression. + + Args: + cmd: The command issued. + regex: Regular expression used to match response message. + + Returns: + List of tuples, each of which contains the entire matched string and + all the subgroups of the match. None if not matched. + """ + result_list = [] + self._open() + try: + self._send(cmd) + while True: + try: + self._child.expect(regex, timeout=0.1) + match = self._child.match + lastindex = ( + match.lastindex if match and match.lastindex else 0 + ) + # Create a tuple which contains the entire matched string and all + # the subgroups of the match. + result = ( + match.group(*range(lastindex + 1)) if match else None + ) + if result: + result = tuple(res.decode("utf-8") for res in result) + result_list.append(result) + except pexpect.TIMEOUT: + break + finally: + self._close() + return result_list + + def _Set_uart_timeout(self, timeout): + """Set timeout value for waiting for the device response. + + Args: + timeout: Timeout value in second. + """ + self._dict["uart_timeout"] = timeout + + def _Get_uart_timeout(self): + """Get timeout value for waiting for the device response. + + Returns: + Timeout value in second. + """ + return self._dict["uart_timeout"] + + def _Set_uart_regexp(self, regexp): + """Set the list of regular expressions which matches the command response. + + Args: + regexp: A string which contains a list of regular expressions. + """ + if not isinstance(regexp, str): + raise ptyError("The argument regexp should be a string.") + self._dict["uart_regexp"] = ast.literal_eval(regexp) + + def _Get_uart_regexp(self): + """Get the list of regular expressions which matches the command response. + + Returns: + A string which contains a list of regular expressions. + """ + return str(self._dict["uart_regexp"]) + + def _Set_uart_cmd(self, cmd): + """Set the UART command and send it to the device. + + If ec_uart_regexp is 'None', the command is just sent and it doesn't care + about its response. + + If ec_uart_regexp is not 'None', the command is send and its response, + which matches the regular expression of ec_uart_regexp, will be kept. + Use its getter to obtain this result. If no match after ec_uart_timeout + seconds, a timeout error will be raised. + + Args: + cmd: A string of UART command. + """ + if self._dict["uart_regexp"]: + self._dict["uart_cmd"] = self._issue_cmd_get_results( + cmd, self._dict["uart_regexp"], self._dict["uart_timeout"] + ) + else: + self._dict["uart_cmd"] = None + self._issue_cmd(cmd) + + def _Set_uart_multicmd(self, cmds): + """Set multiple UART commands and send them to the device. + + Note that ec_uart_regexp is not supported to match the results. + + Args: + cmds: A semicolon-separated string of UART commands. + """ + self._issue_cmd(cmds.split(";")) + + def _Get_uart_cmd(self): + """Get the result of the latest UART command. + + Returns: + A string which contains a list of tuples, each of which contains the + entire matched string and all the subgroups of the match. 'None' if + the ec_uart_regexp is 'None'. + """ + return str(self._dict["uart_cmd"]) + + def _Set_uart_capture(self, cmd): + """Set UART capture mode (on or off). + + Once capture is enabled, UART output could be collected periodically by + invoking _Get_uart_stream() below. + + Args: + cmd: True for on, False for off + """ + self._interface.set_capture_active(cmd) + + def _Get_uart_capture(self): + """Get the UART capture mode (on or off).""" + return self._interface.get_capture_active() + + def _Get_uart_stream(self): + """Get uart stream generated since last time.""" + return self._interface.get_stream() diff --git a/extra/tigertool/ecusb/stm32uart.py b/extra/tigertool/ecusb/stm32uart.py index 95219455a9..64d0234f06 100644 --- a/extra/tigertool/ecusb/stm32uart.py +++ b/extra/tigertool/ecusb/stm32uart.py @@ -1,10 +1,6 @@ -# Copyright 2017 The Chromium OS Authors. All rights reserved. +# Copyright 2017 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes """Allow creation of uart/console interface via stm32 usb endpoint.""" @@ -17,232 +13,247 @@ import termios import threading import time import tty -import usb + +import usb # pylint:disable=import-error from . import stm32usb class SuartError(Exception): - """Class for exceptions of Suart.""" - def __init__(self, msg, value=0): - """SuartError constructor. + """Class for exceptions of Suart.""" - Args: - msg: string, message describing error in detail - value: integer, value of error when non-zero status returned. Default=0 - """ - super(SuartError, self).__init__(msg, value) - self.msg = msg - self.value = value + def __init__(self, msg, value=0): + """SuartError constructor. + Args: + msg: string, message describing error in detail + value: integer, value of error when non-zero status returned. Default=0 + """ + super(SuartError, self).__init__(msg, value) + self.msg = msg + self.value = value -class Suart(object): - """Provide interface to stm32 serial usb endpoint.""" - def __init__(self, vendor=0x18d1, product=0x501a, interface=0, - serialname=None, debuglog=False): - """Suart contstructor. - - Initializes stm32 USB stream interface. - - Args: - vendor: usb vendor id of stm32 device - product: usb product id of stm32 device - interface: interface number of stm32 device to use - serialname: serial name to target. Defaults to None. - debuglog: chatty output. Defaults to False. - - Raises: - SuartError: If init fails - """ - self._ptym = None - self._ptys = None - self._ptyname = None - self._rx_thread = None - self._tx_thread = None - self._debuglog = debuglog - self._susb = stm32usb.Susb(vendor=vendor, product=product, - interface=interface, serialname=serialname) - self._running = False - - def __del__(self): - """Suart destructor.""" - self.close() - - def close(self): - """Stop all running threads.""" - self._running = False - if self._rx_thread: - self._rx_thread.join(2) - self._rx_thread = None - if self._tx_thread: - self._tx_thread.join(2) - self._tx_thread = None - self._susb.close() - - def run_rx_thread(self): - """Background loop to pass data from USB to pty.""" - ep = select.epoll() - ep.register(self._ptym, select.EPOLLHUP) - try: - while self._running: - events = ep.poll(0) - # Check if the pty is connected to anything, or hungup. - if not events: - try: - r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS) - if r: - if self._debuglog: - print(''.join([chr(x) for x in r]), end='') - os.write(self._ptym, r) - - # If we miss some characters on pty disconnect, that's fine. - # ep.read() also throws USBError on timeout, which we discard. - except OSError: - pass - except usb.core.USBError: - pass - else: - time.sleep(.1) - except Exception as e: - raise e - - def run_tx_thread(self): - """Background loop to pass data from pty to USB.""" - ep = select.epoll() - ep.register(self._ptym, select.EPOLLHUP) - try: - while self._running: - events = ep.poll(0) - # Check if the pty is connected to anything, or hungup. - if not events: - try: - r = os.read(self._ptym, 64) - # TODO(crosbug.com/936182): Remove when the servo v4/micro console - # issues are fixed. - time.sleep(0.001) - if r: - self._susb._write_ep.write(r, self._susb.TIMEOUT_MS) - - except OSError: - pass - except usb.core.USBError: - pass - else: - time.sleep(.1) - except Exception as e: - raise e - - def run(self): - """Creates pthreads to poll stm32 & PTY for data.""" - m, s = os.openpty() - self._ptyname = os.ttyname(s) - - self._ptym = m - self._ptys = s - - os.fchmod(s, 0o660) - - # Change the owner and group of the PTY to the user who started servod. - try: - uid = int(os.environ.get('SUDO_UID', -1)) - except TypeError: - uid = -1 - try: - gid = int(os.environ.get('SUDO_GID', -1)) - except TypeError: - gid = -1 - os.fchown(s, uid, gid) - - tty.setraw(self._ptym, termios.TCSADRAIN) - - # Generate a HUP flag on pty slave fd. - os.fdopen(s).close() - - self._running = True - - self._rx_thread = threading.Thread(target=self.run_rx_thread, args=[]) - self._rx_thread.daemon = True - self._rx_thread.start() - - self._tx_thread = threading.Thread(target=self.run_tx_thread, args=[]) - self._tx_thread.daemon = True - self._tx_thread.start() - - def get_uart_props(self): - """Get the uart's properties. - - Returns: - dict where: - baudrate: integer of uarts baudrate - bits: integer, number of bits of data Can be 5|6|7|8 inclusive - parity: integer, parity of 0-2 inclusive where: - 0: no parity - 1: odd parity - 2: even parity - sbits: integer, number of stop bits. Can be 0|1|2 inclusive where: - 0: 1 stop bit - 1: 1.5 stop bits - 2: 2 stop bits - """ - return { - 'baudrate': 115200, - 'bits': 8, - 'parity': 0, - 'sbits': 1, - } - - def set_uart_props(self, line_props): - """Set the uart's properties. - - Note that Suart cannot set properties - and will fail if the properties are not the default 115200,8n1. - - Args: - line_props: dict where: - baudrate: integer of uarts baudrate - bits: integer, number of bits of data ( prior to stop bit) - parity: integer, parity of 0-2 inclusive where - 0: no parity - 1: odd parity - 2: even parity - sbits: integer, number of stop bits. Can be 0|1|2 inclusive where: - 0: 1 stop bit - 1: 1.5 stop bits - 2: 2 stop bits - - Raises: - SuartError: If requested line properties are not the default. - """ - curr_props = self.get_uart_props() - for prop in line_props: - if line_props[prop] != curr_props[prop]: - raise SuartError('Line property %s cannot be set from %s to %s' % ( - prop, curr_props[prop], line_props[prop])) - return True - - def get_pty(self): - """Gets path to pty for communication to/from uart. - - Returns: - String path to the pty connected to the uart - """ - return self._ptyname +class Suart(object): + """Provide interface to stm32 serial usb endpoint.""" + + def __init__( + self, + vendor=0x18D1, + product=0x501A, + interface=0, + serialname=None, + debuglog=False, + ): + """Suart contstructor. + + Initializes stm32 USB stream interface. + + Args: + vendor: usb vendor id of stm32 device + product: usb product id of stm32 device + interface: interface number of stm32 device to use + serialname: serial name to target. Defaults to None. + debuglog: chatty output. Defaults to False. + + Raises: + SuartError: If init fails + """ + self._ptym = None + self._ptys = None + self._ptyname = None + self._rx_thread = None + self._tx_thread = None + self._debuglog = debuglog + self._susb = stm32usb.Susb( + vendor=vendor, + product=product, + interface=interface, + serialname=serialname, + ) + self._running = False + + def __del__(self): + """Suart destructor.""" + self.close() + + def close(self): + """Stop all running threads.""" + self._running = False + if self._rx_thread: + self._rx_thread.join(2) + self._rx_thread = None + if self._tx_thread: + self._tx_thread.join(2) + self._tx_thread = None + self._susb.close() + + def run_rx_thread(self): + """Background loop to pass data from USB to pty.""" + ep = select.epoll() + ep.register(self._ptym, select.EPOLLHUP) + try: + while self._running: + events = ep.poll(0) + # Check if the pty is connected to anything, or hungup. + if not events: + try: + r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS) + if r: + if self._debuglog: + print("".join([chr(x) for x in r]), end="") + os.write(self._ptym, r) + + # If we miss some characters on pty disconnect, that's fine. + # ep.read() also throws USBError on timeout, which we discard. + except OSError: + pass + except usb.core.USBError: + pass + else: + time.sleep(0.1) + except Exception as e: + raise e + + def run_tx_thread(self): + """Background loop to pass data from pty to USB.""" + ep = select.epoll() + ep.register(self._ptym, select.EPOLLHUP) + try: + while self._running: + events = ep.poll(0) + # Check if the pty is connected to anything, or hungup. + if not events: + try: + r = os.read(self._ptym, 64) + # TODO(crosbug.com/936182): Remove when the servo v4/micro console + # issues are fixed. + time.sleep(0.001) + if r: + self._susb._write_ep.write(r, self._susb.TIMEOUT_MS) + + except OSError: + pass + except usb.core.USBError: + pass + else: + time.sleep(0.1) + except Exception as e: + raise e + + def run(self): + """Creates pthreads to poll stm32 & PTY for data.""" + m, s = os.openpty() + self._ptyname = os.ttyname(s) + + self._ptym = m + self._ptys = s + + os.fchmod(s, 0o660) + + # Change the owner and group of the PTY to the user who started servod. + try: + uid = int(os.environ.get("SUDO_UID", -1)) + except TypeError: + uid = -1 + + try: + gid = int(os.environ.get("SUDO_GID", -1)) + except TypeError: + gid = -1 + os.fchown(s, uid, gid) + + tty.setraw(self._ptym, termios.TCSADRAIN) + + # Generate a HUP flag on pty slave fd. + os.fdopen(s).close() + + self._running = True + + self._rx_thread = threading.Thread(target=self.run_rx_thread, args=[]) + self._rx_thread.daemon = True + self._rx_thread.start() + + self._tx_thread = threading.Thread(target=self.run_tx_thread, args=[]) + self._tx_thread.daemon = True + self._tx_thread.start() + + def get_uart_props(self): + """Get the uart's properties. + + Returns: + dict where: + baudrate: integer of uarts baudrate + bits: integer, number of bits of data Can be 5|6|7|8 inclusive + parity: integer, parity of 0-2 inclusive where: + 0: no parity + 1: odd parity + 2: even parity + sbits: integer, number of stop bits. Can be 0|1|2 inclusive where: + 0: 1 stop bit + 1: 1.5 stop bits + 2: 2 stop bits + """ + return { + "baudrate": 115200, + "bits": 8, + "parity": 0, + "sbits": 1, + } + + def set_uart_props(self, line_props): + """Set the uart's properties. + + Note that Suart cannot set properties + and will fail if the properties are not the default 115200,8n1. + + Args: + line_props: dict where: + baudrate: integer of uarts baudrate + bits: integer, number of bits of data ( prior to stop bit) + parity: integer, parity of 0-2 inclusive where + 0: no parity + 1: odd parity + 2: even parity + sbits: integer, number of stop bits. Can be 0|1|2 inclusive where: + 0: 1 stop bit + 1: 1.5 stop bits + 2: 2 stop bits + + Raises: + SuartError: If requested line properties are not the default. + """ + curr_props = self.get_uart_props() + for prop in line_props: + if line_props[prop] != curr_props[prop]: + raise SuartError( + "Line property %s cannot be set from %s to %s" + % (prop, curr_props[prop], line_props[prop]) + ) + return True + + def get_pty(self): + """Gets path to pty for communication to/from uart. + + Returns: + String path to the pty connected to the uart + """ + return self._ptyname def main(): - """Run a suart test with the default parameters.""" - try: - sobj = Suart() - sobj.run() + """Run a suart test with the default parameters.""" + try: + sobj = Suart() + sobj.run() - # run() is a thread so just busy wait to mimic server. - while True: - # Ours sleeps to eleven! - time.sleep(11) - except KeyboardInterrupt: - sys.exit(0) + # run() is a thread so just busy wait to mimic server. + while True: + # Ours sleeps to eleven! + time.sleep(11) + except KeyboardInterrupt: + sys.exit(0) -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/extra/tigertool/ecusb/stm32usb.py b/extra/tigertool/ecusb/stm32usb.py index bfd5fbb1fb..f9c700466a 100644 --- a/extra/tigertool/ecusb/stm32usb.py +++ b/extra/tigertool/ecusb/stm32usb.py @@ -1,119 +1,132 @@ -# Copyright 2017 The Chromium OS Authors. All rights reserved. +# Copyright 2017 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes """Allows creation of an interface via stm32 usb.""" -import usb +import usb # pylint:disable=import-error class SusbError(Exception): - """Class for exceptions of Susb.""" - def __init__(self, msg, value=0): - """SusbError constructor. + """Class for exceptions of Susb.""" - Args: - msg: string, message describing error in detail - value: integer, value of error when non-zero status returned. Default=0 - """ - super(SusbError, self).__init__(msg, value) - self.msg = msg - self.value = value + def __init__(self, msg, value=0): + """SusbError constructor. + + Args: + msg: string, message describing error in detail + value: integer, value of error when non-zero status returned. Default=0 + """ + super(SusbError, self).__init__(msg, value) + self.msg = msg + self.value = value class Susb(object): - """Provide stm32 USB functionality. - - Instance Variables: - _read_ep: pyUSB read endpoint for this interface - _write_ep: pyUSB write endpoint for this interface - """ - READ_ENDPOINT = 0x81 - WRITE_ENDPOINT = 0x1 - TIMEOUT_MS = 100 - - def __init__(self, vendor=0x18d1, - product=0x5027, interface=1, serialname=None, logger=None): - """Susb constructor. - - Discovers and connects to stm32 USB endpoints. - - Args: - vendor: usb vendor id of stm32 device. - product: usb product id of stm32 device. - interface: interface number ( 1 - 4 ) of stm32 device to use. - serialname: string of device serialname. - logger: none - - Raises: - SusbError: An error accessing Susb object + """Provide stm32 USB functionality. + + Instance Variables: + _read_ep: pyUSB read endpoint for this interface + _write_ep: pyUSB write endpoint for this interface """ - self._vendor = vendor - self._product = product - self._interface = interface - self._serialname = serialname - self._find_device() - - def _find_device(self): - """Set up the usb endpoint""" - # Find the stm32. - dev_g = usb.core.find(idVendor=self._vendor, idProduct=self._product, - find_all=True) - dev_list = list(dev_g) - - if not dev_list: - raise SusbError('USB device not found') - - # Check if we have multiple stm32s and we've specified the serial. - dev = None - if self._serialname: - for d in dev_list: - dev_serial = usb.util.get_string(d, d.iSerialNumber) - if dev_serial == self._serialname: - dev = d - break - if dev is None: - raise SusbError('USB device(%s) not found' % self._serialname) - else: - try: - dev = dev_list[0] - except StopIteration: - raise SusbError('USB device %04x:%04x not found' % ( - self._vendor, self._product)) - - # If we can't set configuration, it's already been set. - try: - dev.set_configuration() - except usb.core.USBError: - pass - - self._dev = dev - - # Get an endpoint instance. - cfg = dev.get_active_configuration() - intf = usb.util.find_descriptor(cfg, bInterfaceNumber=self._interface) - self._intf = intf - if not intf: - raise SusbError('Interface %04x:%04x - 0x%x not found' % ( - self._vendor, self._product, self._interface)) - - # Detach raiden.ko if it is loaded. CCD endpoints support either a kernel - # module driver that produces a ttyUSB, or direct endpoint access, but - # can't do both at the same time. - if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True: - dev.detach_kernel_driver(intf.bInterfaceNumber) - - read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT - read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number) - self._read_ep = read_ep - - write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT - write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number) - self._write_ep = write_ep - - def close(self): - usb.util.dispose_resources(self._dev) + + READ_ENDPOINT = 0x81 + WRITE_ENDPOINT = 0x1 + TIMEOUT_MS = 100 + + def __init__( + self, + vendor=0x18D1, + product=0x5027, + interface=1, + serialname=None, + logger=None, + ): + """Susb constructor. + + Discovers and connects to stm32 USB endpoints. + + Args: + vendor: usb vendor id of stm32 device. + product: usb product id of stm32 device. + interface: interface number ( 1 - 4 ) of stm32 device to use. + serialname: string of device serialname. + logger: none + + Raises: + SusbError: An error accessing Susb object + """ + self._vendor = vendor + self._product = product + self._interface = interface + self._serialname = serialname + self._find_device() + + def _find_device(self): + """Set up the usb endpoint""" + # Find the stm32. + dev_g = usb.core.find( + idVendor=self._vendor, idProduct=self._product, find_all=True + ) + dev_list = list(dev_g) + + if not dev_list: + raise SusbError("USB device not found") + + # Check if we have multiple stm32s and we've specified the serial. + dev = None + if self._serialname: + for d in dev_list: + dev_serial = usb.util.get_string(d, d.iSerialNumber) + if dev_serial == self._serialname: + dev = d + break + if dev is None: + raise SusbError("USB device(%s) not found" % self._serialname) + else: + try: + dev = dev_list[0] + except StopIteration: + raise SusbError( + "USB device %04x:%04x not found" + % (self._vendor, self._product) + ) + + # If we can't set configuration, it's already been set. + try: + dev.set_configuration() + except usb.core.USBError: + pass + + self._dev = dev + + # Get an endpoint instance. + cfg = dev.get_active_configuration() + intf = usb.util.find_descriptor(cfg, bInterfaceNumber=self._interface) + self._intf = intf + if not intf: + raise SusbError( + "Interface %04x:%04x - 0x%x not found" + % (self._vendor, self._product, self._interface) + ) + + # Detach raiden.ko if it is loaded. CCD endpoints support either a kernel + # module driver that produces a ttyUSB, or direct endpoint access, but + # can't do both at the same time. + if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True: + dev.detach_kernel_driver(intf.bInterfaceNumber) + + read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT + read_ep = usb.util.find_descriptor( + intf, bEndpointAddress=read_ep_number + ) + self._read_ep = read_ep + + write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT + write_ep = usb.util.find_descriptor( + intf, bEndpointAddress=write_ep_number + ) + self._write_ep = write_ep + + def close(self): + usb.util.dispose_resources(self._dev) diff --git a/extra/tigertool/ecusb/tiny_servo_common.py b/extra/tigertool/ecusb/tiny_servo_common.py index e27736a9dc..fc028104ed 100644 --- a/extra/tigertool/ecusb/tiny_servo_common.py +++ b/extra/tigertool/ecusb/tiny_servo_common.py @@ -1,238 +1,241 @@ -# Copyright 2017 The Chromium OS Authors. All rights reserved. +# Copyright 2017 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes """Utilities for using lightweight console functions.""" # Note: This is a py2/3 compatible file. import datetime -import errno -import os -import re -import subprocess import sys import time -import usb import six +import usb # pylint:disable=import-error -from . import pty_driver -from . import stm32uart +from . import pty_driver, stm32uart def get_subprocess_args(): - if six.PY3: - return {'encoding': 'utf-8'} - return {} + if six.PY3: + return {"encoding": "utf-8"} + return {} class TinyServoError(Exception): - """Exceptions.""" + """Exceptions.""" def log(output): - """Print output to console, logfiles can be added here. + """Print output to console, logfiles can be added here. + + Args: + output: string to output. + """ + sys.stdout.write(output) + sys.stdout.write("\n") + sys.stdout.flush() - Args: - output: string to output. - """ - sys.stdout.write(output) - sys.stdout.write('\n') - sys.stdout.flush() def check_usb(vidpid, serialname=None): - """Check if |vidpid| is present on the system's USB. + """Check if |vidpid| is present on the system's USB. - Args: - vidpid: string representation of the usb vid:pid, eg. '18d1:2001' - serialname: serialname if specified. + Args: + vidpid: string representation of the usb vid:pid, eg. '18d1:2001' + serialname: serialname if specified. - Returns: True if found, False, otherwise. - """ - if serialname: - output = subprocess.check_output(['lsusb', '-v', '-d', vidpid], - **get_subprocess_args()) - m = re.search(r'^\s*iSerial\s+\d+\s+%s$' % serialname, output, flags=re.M) - if m: - return True + Returns: + True if found, False, otherwise. + """ + if get_usb_dev(vidpid, serialname): + return True return False - else: - if subprocess.call(['lsusb', '-d', vidpid], stdout=open('/dev/null', 'w')): - return False - return True - -def check_usb_sn(vidpid): - """Return the serial number - - Return the serial number of the first USB device with VID:PID vidpid, - or None if no device is found. This will not work well with two of - the same device attached. - Args: - vidpid: string representation of the usb vid:pid, eg. '18d1:2001' - Returns: string serial number if found, None otherwise. - """ - output = subprocess.check_output(['lsusb', '-v', '-d', vidpid], - **get_subprocess_args()) - m = re.search(r'^\s*iSerial\s+(.*)$', output, flags=re.M) - if m: - return m.group(1) +def check_usb_sn(vidpid): + """Return the serial number - return None + Return the serial number of the first USB device with VID:PID vidpid, + or None if no device is found. This will not work well with two of + the same device attached. -def get_usb_dev(vidpid, serialname=None): - """Return the USB pyusb devie struct + Args: + vidpid: string representation of the usb vid:pid, eg. '18d1:2001' - Return the dev struct of the first USB device with VID:PID vidpid, - or None if no device is found. If more than one device check serial - if supplied. + Returns: + string serial number if found, None otherwise. + """ + dev = get_usb_dev(vidpid) - Args: - vidpid: string representation of the usb vid:pid, eg. '18d1:2001' - serialname: serialname if specified. + if dev: + dev_serial = usb.util.get_string(dev, dev.iSerialNumber) - Returns: pyusb device if found, None otherwise. - """ - vidpidst = vidpid.split(':') - vid = int(vidpidst[0], 16) - pid = int(vidpidst[1], 16) + return dev_serial + return None - dev_g = usb.core.find(idVendor=vid, idProduct=pid, find_all=True) - dev_list = list(dev_g) - if not dev_list: - return None +def get_usb_dev(vidpid, serialname=None): + """Return the USB pyusb devie struct + + Return the dev struct of the first USB device with VID:PID vidpid, + or None if no device is found. If more than one device check serial + if supplied. + + Args: + vidpid: string representation of the usb vid:pid, eg. '18d1:2001' + serialname: serialname if specified. + + Returns: + pyusb device if found, None otherwise. + """ + vidpidst = vidpid.split(":") + vid = int(vidpidst[0], 16) + pid = int(vidpidst[1], 16) + + dev_g = usb.core.find(idVendor=vid, idProduct=pid, find_all=True) + dev_list = list(dev_g) + + if not dev_list: + return None + + # Check if we have multiple devices and we've specified the serial. + dev = None + if serialname: + for d in dev_list: + dev_serial = usb.util.get_string(d, d.iSerialNumber) + if dev_serial == serialname: + dev = d + break + if dev is None: + return None + else: + try: + dev = dev_list[0] + except StopIteration: + return None + + return dev - # Check if we have multiple devices and we've specified the serial. - dev = None - if serialname: - for d in dev_list: - dev_serial = usb.util.get_string(d, d.iSerialNumber) - if dev_serial == serialname: - dev = d - break - if dev is None: - return None - else: - try: - dev = dev_list[0] - except StopIteration: - return None - - return dev def check_usb_dev(vidpid, serialname=None): - """Return the USB dev number + """Return the USB dev number - Return the dev number of the first USB device with VID:PID vidpid, - or None if no device is found. If more than one device check serial - if supplied. + Return the dev number of the first USB device with VID:PID vidpid, + or None if no device is found. If more than one device check serial + if supplied. - Args: - vidpid: string representation of the usb vid:pid, eg. '18d1:2001' - serialname: serialname if specified. + Args: + vidpid: string representation of the usb vid:pid, eg. '18d1:2001' + serialname: serialname if specified. - Returns: usb device number if found, None otherwise. - """ - dev = get_usb_dev(vidpid, serialname=serialname) + Returns: + usb device number if found, None otherwise. + """ + dev = get_usb_dev(vidpid, serialname=serialname) - if dev: - return dev.address + if dev: + return dev.address - return None + return None def wait_for_usb_remove(vidpid, serialname=None, timeout=None): - """Wait for USB device with vidpid to be removed. + """Wait for USB device with vidpid to be removed. + + Wrapper for wait_for_usb below + """ + wait_for_usb( + vidpid, serialname=serialname, timeout=timeout, desiredpresence=False + ) - Wrapper for wait_for_usb below - """ - wait_for_usb(vidpid, serialname=serialname, - timeout=timeout, desiredpresence=False) def wait_for_usb(vidpid, serialname=None, timeout=None, desiredpresence=True): - """Wait for usb device with vidpid to be present/absent. - - Args: - vidpid: string representation of the usb vid:pid, eg. '18d1:2001' - serialname: serialname if specificed. - timeout: timeout in seconds, None for no timeout. - desiredpresence: True for present, False for not present. - - Raises: - TinyServoError: on timeout. - """ - if timeout: - finish = datetime.datetime.now() + datetime.timedelta(seconds=timeout) - while check_usb(vidpid, serialname) != desiredpresence: - time.sleep(.01) + """Wait for usb device with vidpid to be present/absent. + + Args: + vidpid: string representation of the usb vid:pid, eg. '18d1:2001' + serialname: serialname if specificed. + timeout: timeout in seconds, None for no timeout. + desiredpresence: True for present, False for not present. + + Raises: + TinyServoError: on timeout. + """ if timeout: - if datetime.datetime.now() > finish: - raise TinyServoError('Timeout', 'Timeout waiting for USB %s' % vidpid) + finish = datetime.datetime.now() + datetime.timedelta(seconds=timeout) + while check_usb(vidpid, serialname) != desiredpresence: + time.sleep(0.1) + if timeout: + if datetime.datetime.now() > finish: + raise TinyServoError( + "Timeout", "Timeout waiting for USB %s" % vidpid + ) + def do_serialno(serialno, pty): - """Set serialnumber 'serialno' via ec console 'pty'. - - Commands are: - # > serialno set 1234 - # Saving serial number - # Serial number: 1234 - - Args: - serialno: string serial number to set. - pty: tinyservo console to send commands. - - Raises: - TinyServoError: on failure to set. - ptyError: on command interface error. - """ - cmd = 'serialno set %s' % serialno - regex = 'Serial number:\s+(\S+)' - - results = pty._issue_cmd_get_results(cmd, [regex])[0] - sn = results[1].strip().strip('\n\r') - - if sn == serialno: - log('Success !') - log('Serial set to %s' % sn) - else: - log('Serial number set to %s but saved as %s.' % (serialno, sn)) - raise TinyServoError( - 'Serial Number', - 'Serial number set to %s but saved as %s.' % (serialno, sn)) + """Set serialnumber 'serialno' via ec console 'pty'. + + Commands are: + # > serialno set 1234 + # Saving serial number + # Serial number: 1234 + + Args: + serialno: string serial number to set. + pty: tinyservo console to send commands. + + Raises: + TinyServoError: on failure to set. + ptyError: on command interface error. + """ + cmd = r"serialno set %s" % serialno + regex = r"Serial number:\s+(\S+)" + + results = pty._issue_cmd_get_results(cmd, [regex])[0] + sn = results[1].strip().strip("\n\r") + + if sn == serialno: + log("Success !") + log("Serial set to %s" % sn) + else: + log("Serial number set to %s but saved as %s." % (serialno, sn)) + raise TinyServoError( + "Serial Number", + "Serial number set to %s but saved as %s." % (serialno, sn), + ) + def setup_tinyservod(vidpid, interface, serialname=None, debuglog=False): - """Set up a pty - - Set up a pty to the ec console in order - to send commands. Returns a pty_driver object. - - Args: - vidpid: string vidpid of device to access. - interface: not used. - serialname: string serial name of device requested, optional. - debuglog: chatty printout (boolean) - - Returns: pty object - - Raises: - UsbError, SusbError: on device not found - """ - vidstr, pidstr = vidpid.split(':') - vid = int(vidstr, 16) - pid = int(pidstr, 16) - suart = stm32uart.Suart(vendor=vid, product=pid, - interface=interface, serialname=serialname, - debuglog=debuglog) - suart.run() - pty = pty_driver.ptyDriver(suart, []) - - return pty + """Set up a pty + + Set up a pty to the ec console in order + to send commands. Returns a pty_driver object. + + Args: + vidpid: string vidpid of device to access. + interface: not used. + serialname: string serial name of device requested, optional. + debuglog: chatty printout (boolean) + + Returns: + pty object + + Raises: + UsbError, SusbError: on device not found + """ + vidstr, pidstr = vidpid.split(":") + vid = int(vidstr, 16) + pid = int(pidstr, 16) + suart = stm32uart.Suart( + vendor=vid, + product=pid, + interface=interface, + serialname=serialname, + debuglog=debuglog, + ) + suart.run() + pty = pty_driver.ptyDriver(suart, []) + + return pty diff --git a/extra/tigertool/ecusb/tiny_servod.py b/extra/tigertool/ecusb/tiny_servod.py index 632d9c3a20..f8d61b5305 100644 --- a/extra/tigertool/ecusb/tiny_servod.py +++ b/extra/tigertool/ecusb/tiny_servod.py @@ -1,54 +1,51 @@ -# Copyright 2020 The Chromium OS Authors. All rights reserved. +# Copyright 2020 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes """Helper class to facilitate communication to servo ec console.""" -from ecusb import pty_driver -from ecusb import stm32uart +from ecusb import pty_driver, stm32uart class TinyServod(object): - """Helper class to wrap a pty_driver with interface.""" - - def __init__(self, vid, pid, interface, serialname=None, debug=False): - """Build the driver and interface. - - Args: - vid: servo device vid - pid: servo device pid - interface: which usb interface the servo console is on - serialname: the servo device serial (if available) - """ - self._vid = vid - self._pid = pid - self._interface = interface - self._serial = serialname - self._debug = debug - self._init() - - def _init(self): - self.suart = stm32uart.Suart(vendor=self._vid, - product=self._pid, - interface=self._interface, - serialname=self._serial, - debuglog=self._debug) - self.suart.run() - self.pty = pty_driver.ptyDriver(self.suart, []) - - def reinitialize(self): - """Reinitialize the connect after a reset/disconnect/etc.""" - self.close() - self._init() - - def close(self): - """Close out the connection and release resources. - - Note: if another TinyServod process or servod itself needs the same device - it's necessary to call this to ensure the usb device is available. - """ - self.suart.close() + """Helper class to wrap a pty_driver with interface.""" + + def __init__(self, vid, pid, interface, serialname=None, debug=False): + """Build the driver and interface. + + Args: + vid: servo device vid + pid: servo device pid + interface: which usb interface the servo console is on + serialname: the servo device serial (if available) + """ + self._vid = vid + self._pid = pid + self._interface = interface + self._serial = serialname + self._debug = debug + self._init() + + def _init(self): + self.suart = stm32uart.Suart( + vendor=self._vid, + product=self._pid, + interface=self._interface, + serialname=self._serial, + debuglog=self._debug, + ) + self.suart.run() + self.pty = pty_driver.ptyDriver(self.suart, []) + + def reinitialize(self): + """Reinitialize the connect after a reset/disconnect/etc.""" + self.close() + self._init() + + def close(self): + """Close out the connection and release resources. + + Note: if another TinyServod process or servod itself needs the same device + it's necessary to call this to ensure the usb device is available. + """ + self.suart.close() diff --git a/extra/tigertool/flash_dfu.sh b/extra/tigertool/flash_dfu.sh index 7aa6c24f09..9578ef626e 100755 --- a/extra/tigertool/flash_dfu.sh +++ b/extra/tigertool/flash_dfu.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2017 The Chromium OS Authors. All rights reserved. +# Copyright 2017 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. diff --git a/extra/tigertool/make_pkg.sh b/extra/tigertool/make_pkg.sh index 5a63862242..ae0ae95cfe 100755 --- a/extra/tigertool/make_pkg.sh +++ b/extra/tigertool/make_pkg.sh @@ -1,8 +1,10 @@ #!/bin/bash -# Copyright 2017 The Chromium OS Authors. All rights reserved. +# Copyright 2017 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. +set -e + # Make sure we are in the correct dir. cd "$( dirname "${BASH_SOURCE[0]}" )" || exit @@ -21,10 +23,11 @@ cp tigertest.py "${DEST}" cp README.md "${DEST}" cp -r ecusb "${DEST}" -cp -r ../../../../../chroot/usr/lib64/python2.7/site-packages/usb "${DEST}" +# Not compatible with glinux as of 4/28/2022. +# cp -r ../../../../../chroot/usr/lib64/python3.6/site-packages/usb "${DEST}" find "${DEST}" -name "*.py[co]" -delete cp -r ../usb_serial "${DEST}" -(cd build; tar -czf tigertool_${DATE}.tgz tigertool) +(cd build && tar -czf tigertool_"${DATE}".tgz tigertool) echo "Done packaging tigertool_${DATE}.tgz" diff --git a/extra/tigertool/tigertest.py b/extra/tigertool/tigertest.py index 0cd31c8cce..b1186cca77 100755 --- a/extra/tigertool/tigertest.py +++ b/extra/tigertool/tigertest.py @@ -1,11 +1,7 @@ #!/usr/bin/env python3 -# Copyright 2022 The Chromium OS Authors. All rights reserved. +# Copyright 2022 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes """Smoke test of tigertool binary.""" @@ -13,7 +9,6 @@ import argparse import subprocess import sys - # Script to control tigertail USB-C Mux board. # # optional arguments: @@ -35,58 +30,62 @@ import sys def testCmd(cmd, expected_results): - """Run command on console, check for success. - - Args: - cmd: shell command to run. - expected_results: a list object of strings expected in the result. - - Raises: - Exception on fail. - """ - print('run: ' + cmd) - try: - p = subprocess.run(cmd, shell=True, check=False, capture_output=True) - output = p.stdout.decode('utf-8') - error = p.stderr.decode('utf-8') - assert p.returncode == 0 - for result in expected_results: - output.index(result) - except Exception as e: - print('FAIL') - print('cmd: ' + cmd) - print('error: ' + str(e)) - print('stdout:\n' + output) - print('stderr:\n' + error) - print('expected: ' + str(expected_results)) - print('RC: ' + str(p.returncode)) - raise e + """Run command on console, check for success. + + Args: + cmd: shell command to run. + expected_results: a list object of strings expected in the result. + + Raises: + Exception on fail. + """ + print("run: " + cmd) + try: + p = subprocess.run(cmd, shell=True, check=False, capture_output=True) + output = p.stdout.decode("utf-8") + error = p.stderr.decode("utf-8") + assert p.returncode == 0 + for result in expected_results: + output.index(result) + except Exception as e: + print("FAIL") + print("cmd: " + cmd) + print("error: " + str(e)) + print("stdout:\n" + output) + print("stderr:\n" + error) + print("expected: " + str(expected_results)) + print("RC: " + str(p.returncode)) + raise e + def test_sequence(): - testCmd('./tigertool.py --reboot', ['PASS']) - testCmd('./tigertool.py --setserialno test', ['PASS']) - testCmd('./tigertool.py --check_serial', ['test', 'PASS']) - testCmd('./tigertool.py -s test --check_serial', ['test', 'PASS']) - testCmd('./tigertool.py -m A', ['Mux set to A', 'PASS']) - testCmd('./tigertool.py -m B', ['Mux set to B', 'PASS']) - testCmd('./tigertool.py -m off', ['Mux set to off', 'PASS']) - testCmd('./tigertool.py -p', ['PASS']) - testCmd('./tigertool.py -r rw', ['PASS']) - testCmd('./tigertool.py -r ro', ['PASS']) - testCmd('./tigertool.py --check_version', ['RW', 'RO', 'PASS']) - - print('PASS') + testCmd("./tigertool.py --reboot", ["PASS"]) + testCmd("./tigertool.py --setserialno test", ["PASS"]) + testCmd("./tigertool.py --check_serial", ["test", "PASS"]) + testCmd("./tigertool.py -s test --check_serial", ["test", "PASS"]) + testCmd("./tigertool.py -m A", ["Mux set to A", "PASS"]) + testCmd("./tigertool.py -m B", ["Mux set to B", "PASS"]) + testCmd("./tigertool.py -m off", ["Mux set to off", "PASS"]) + testCmd("./tigertool.py -p", ["PASS"]) + testCmd("./tigertool.py -r rw", ["PASS"]) + testCmd("./tigertool.py -r ro", ["PASS"]) + testCmd("./tigertool.py --check_version", ["RW", "RO", "PASS"]) + + print("PASS") + def main(argv): - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument('-c', '--count', type=int, default=1, - help='loops to run') + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "-c", "--count", type=int, default=1, help="loops to run" + ) + + opts = parser.parse_args(argv) - opts = parser.parse_args(argv) + for i in range(1, opts.count + 1): + print("Iteration: %d" % i) + test_sequence() - for i in range(1, opts.count + 1): - print('Iteration: %d' % i) - test_sequence() -if __name__ == '__main__': +if __name__ == "__main__": main(sys.argv[1:]) diff --git a/extra/tigertool/tigertool.py b/extra/tigertool/tigertool.py index 6baae8abdf..69303aa02a 100755 --- a/extra/tigertool/tigertool.py +++ b/extra/tigertool/tigertool.py @@ -1,11 +1,7 @@ #!/usr/bin/env python3 -# Copyright 2017 The Chromium OS Authors. All rights reserved. +# Copyright 2017 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes """Script to control tigertail USB-C Mux board.""" @@ -17,287 +13,318 @@ import time import ecusb.tiny_servo_common as c -STM_VIDPID = '18d1:5027' -serialno = 'Uninitialized' +STM_VIDPID = "18d1:5027" +serialno = "Uninitialized" + def do_mux(mux, pty): - """Set mux via ec console 'pty'. + """Set mux via ec console 'pty'. + + Args: + mux: mux to connect to DUT, 'A', 'B', or 'off' + pty: a pty object connected to tigertail - Args: - mux: mux to connect to DUT, 'A', 'B', or 'off' - pty: a pty object connected to tigertail + Commands are: + # > mux A + # TYPE-C mux is A + """ + validmux = ["A", "B", "off"] + if mux not in validmux: + c.log("Mux setting %s invalid, try one of %s" % (mux, validmux)) + return False - Commands are: - # > mux A - # TYPE-C mux is A - """ - validmux = ['A', 'B', 'off'] - if mux not in validmux: - c.log('Mux setting %s invalid, try one of %s' % (mux, validmux)) - return False + cmd = "mux %s" % mux + regex = "TYPE\-C mux is ([^\s\r\n]*)\r" - cmd = 'mux %s' % mux - regex = 'TYPE\-C mux is ([^\s\r\n]*)\r' + results = pty._issue_cmd_get_results(cmd, [regex])[0] + result = results[1].strip().strip("\n\r") - results = pty._issue_cmd_get_results(cmd, [regex])[0] - result = results[1].strip().strip('\n\r') + if result != mux: + c.log("Mux set to %s but saved as %s." % (mux, result)) + return False + c.log("Mux set to %s" % result) + return True - if result != mux: - c.log('Mux set to %s but saved as %s.' % (mux, result)) - return False - c.log('Mux set to %s' % result) - return True def do_version(pty): - """Check version via ec console 'pty'. - - Args: - pty: a pty object connected to tigertail - - Commands are: - # > version - # Chip: stm stm32f07x - # Board: 0 - # RO: tigertail_v1.1.6749-74d1a312e - # RW: tigertail_v1.1.6749-74d1a312e - # Build: tigertail_v1.1.6749-74d1a312e - # 2017-07-25 20:08:34 nsanders@meatball.mtv.corp.google.com - """ - cmd = 'version' - regex = r'RO:\s+(\S+)\s+RW:\s+(\S+)\s+Build:\s+(\S+)\s+' \ - r'(\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d) (\S+)' - - results = pty._issue_cmd_get_results(cmd, [regex])[0] - c.log('Version is %s' % results[3]) - c.log('RO: %s' % results[1]) - c.log('RW: %s' % results[2]) - c.log('Date: %s' % results[4]) - c.log('Src: %s' % results[5]) - - return True + """Check version via ec console 'pty'. + + Args: + pty: a pty object connected to tigertail + + Commands are: + # > version + # Chip: stm stm32f07x + # Board: 0 + # RO: tigertail_v1.1.6749-74d1a312e + # RW: tigertail_v1.1.6749-74d1a312e + # Build: tigertail_v1.1.6749-74d1a312e + # 2017-07-25 20:08:34 nsanders@meatball.mtv.corp.google.com + """ + cmd = "version" + regex = ( + r"RO:\s+(\S+)\s+RW:\s+(\S+)\s+Build:\s+(\S+)\s+" + r"(\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d) (\S+)" + ) + + results = pty._issue_cmd_get_results(cmd, [regex])[0] + c.log("Version is %s" % results[3]) + c.log("RO: %s" % results[1]) + c.log("RW: %s" % results[2]) + c.log("Date: %s" % results[4]) + c.log("Src: %s" % results[5]) + + return True + def do_check_serial(pty): - """Check serial via ec console 'pty'. + """Check serial via ec console 'pty'. - Args: - pty: a pty object connected to tigertail + Args: + pty: a pty object connected to tigertail - Commands are: - # > serialno - # Serial number: number - """ - cmd = 'serialno' - regex = r'Serial number: ([^\n\r]+)' + Commands are: + # > serialno + # Serial number: number + """ + cmd = "serialno" + regex = r"Serial number: ([^\n\r]+)" - results = pty._issue_cmd_get_results(cmd, [regex])[0] - c.log('Serial is %s' % results[1]) + results = pty._issue_cmd_get_results(cmd, [regex])[0] + c.log("Serial is %s" % results[1]) - return True + return True def do_power(count, bus, pty): - """Check power usage via ec console 'pty'. - - Args: - count: number of samples to capture - bus: rail to monitor, 'vbus', 'cc1', or 'cc2' - pty: a pty object connected to tigertail - - Commands are: - # > ina 0 - # Configuration: 4127 - # Shunt voltage: 02c4 => 1770 uV - # Bus voltage : 1008 => 5130 mV - # Power : 0019 => 625 mW - # Current : 0082 => 130 mA - # Calibration : 0155 - # Mask/Enable : 0008 - # Alert limit : 0000 - """ - if bus == 'vbus': - ina = 0 - if bus == 'cc1': - ina = 4 - if bus == 'cc2': - ina = 1 - - start = time.time() - - c.log('time,\tmV,\tmW,\tmA') - - cmd = 'ina %s' % ina - regex = r'Bus voltage : \S+ \S+ (\d+) mV\s+' \ - r'Power : \S+ \S+ (\d+) mW\s+' \ - r'Current : \S+ \S+ (\d+) mA' - - for i in range(0, count): - results = pty._issue_cmd_get_results(cmd, [regex])[0] - c.log('%.2f,\t%s,\t%s\t%s' % ( - time.time() - start, - results[1], results[2], results[3])) + """Check power usage via ec console 'pty'. + + Args: + count: number of samples to capture + bus: rail to monitor, 'vbus', 'cc1', or 'cc2' + pty: a pty object connected to tigertail + + Commands are: + # > ina 0 + # Configuration: 4127 + # Shunt voltage: 02c4 => 1770 uV + # Bus voltage : 1008 => 5130 mV + # Power : 0019 => 625 mW + # Current : 0082 => 130 mA + # Calibration : 0155 + # Mask/Enable : 0008 + # Alert limit : 0000 + """ + if bus == "vbus": + ina = 0 + if bus == "cc1": + ina = 4 + if bus == "cc2": + ina = 1 + + start = time.time() + + c.log("time,\tmV,\tmW,\tmA") + + cmd = "ina %s" % ina + regex = ( + r"Bus voltage : \S+ \S+ (\d+) mV\s+" + r"Power : \S+ \S+ (\d+) mW\s+" + r"Current : \S+ \S+ (\d+) mA" + ) + + for i in range(0, count): + results = pty._issue_cmd_get_results(cmd, [regex])[0] + c.log( + "%.2f,\t%s,\t%s\t%s" + % (time.time() - start, results[1], results[2], results[3]) + ) + + return True - return True def do_reboot(pty, serialname): - """Reboot via ec console pty - - Args: - pty: a pty object connected to tigertail - serialname: serial name, can be None. - - Command is: reboot. - """ - cmd = 'reboot' - - # Check usb dev number on current instance. - devno = c.check_usb_dev(STM_VIDPID, serialname=serialname) - if not devno: - c.log('Device not found') - return False - - try: - pty._issue_cmd(cmd) - except Exception as e: - c.log('Failed to send command: ' + str(e)) - return False - - try: - c.wait_for_usb_remove(STM_VIDPID, timeout=3., serialname=serialname) - except Exception as e: - # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds - # it's not going to. This step just goes faster if it's detected. - pass - - try: - c.wait_for_usb(STM_VIDPID, timeout=3., serialname=serialname) - except Exception as e: - c.log('Failed to return from reboot: ' + str(e)) - return False - - # Check that the device had a new device number, i.e. it's - # disconnected and reconnected. - newdevno = c.check_usb_dev(STM_VIDPID, serialname=serialname) - if newdevno == devno: - c.log("Device didn't reboot") - return False - - return True + """Reboot via ec console pty + + Args: + pty: a pty object connected to tigertail + serialname: serial name, can be None. + + Command is: reboot. + """ + cmd = "reboot" + + # Check usb dev number on current instance. + devno = c.check_usb_dev(STM_VIDPID, serialname=serialname) + if not devno: + c.log("Device not found") + return False + + try: + pty._issue_cmd(cmd) + except Exception as e: + c.log("Failed to send command: " + str(e)) + return False + + try: + c.wait_for_usb_remove(STM_VIDPID, timeout=3.0, serialname=serialname) + except Exception as e: + # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds + # it's not going to. This step just goes faster if it's detected. + pass + + try: + c.wait_for_usb(STM_VIDPID, timeout=3.0, serialname=serialname) + except Exception as e: + c.log("Failed to return from reboot: " + str(e)) + return False + + # Check that the device had a new device number, i.e. it's + # disconnected and reconnected. + newdevno = c.check_usb_dev(STM_VIDPID, serialname=serialname) + if newdevno == devno: + c.log("Device didn't reboot") + return False + + return True + def do_sysjump(region, pty, serialname): - """Set region via ec console 'pty'. - - Args: - region: ec code region to execute, 'ro' or 'rw' - pty: a pty object connected to tigertail - serialname: serial name, can be None. - - Commands are: - # > sysjump rw - """ - validregion = ['ro', 'rw'] - if region not in validregion: - c.log('Region setting %s invalid, try one of %s' % ( - region, validregion)) - return False - - cmd = 'sysjump %s' % region - try: - pty._issue_cmd(cmd) - except Exception as e: - c.log('Exception: ' + str(e)) - return False - - try: - c.wait_for_usb_remove(STM_VIDPID, timeout=3., serialname=serialname) - except Exception as e: - # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds - # it's not going to. This step just goes faster if it's detected. - pass - - try: - c.wait_for_usb(STM_VIDPID, timeout=3., serialname=serialname) - except Exception as e: - c.log('Failed to return from restart: ' + str(e)) - return False - - c.log('Region requested %s' % region) - return True + """Set region via ec console 'pty'. + + Args: + region: ec code region to execute, 'ro' or 'rw' + pty: a pty object connected to tigertail + serialname: serial name, can be None. + + Commands are: + # > sysjump rw + """ + validregion = ["ro", "rw"] + if region not in validregion: + c.log( + "Region setting %s invalid, try one of %s" % (region, validregion) + ) + return False + + cmd = "sysjump %s" % region + try: + pty._issue_cmd(cmd) + except Exception as e: + c.log("Exception: " + str(e)) + return False + + try: + c.wait_for_usb_remove(STM_VIDPID, timeout=3.0, serialname=serialname) + except Exception as e: + # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds + # it's not going to. This step just goes faster if it's detected. + pass + + try: + c.wait_for_usb(STM_VIDPID, timeout=3.0, serialname=serialname) + except Exception as e: + c.log("Failed to return from restart: " + str(e)) + return False + + c.log("Region requested %s" % region) + return True + def get_parser(): - parser = argparse.ArgumentParser( - description=__doc__) - parser.add_argument('-s', '--serialno', type=str, default=None, - help='serial number of board to use') - parser.add_argument('-b', '--bus', type=str, default='vbus', - help='Which rail to log: [vbus|cc1|cc2]') - group = parser.add_mutually_exclusive_group() - group.add_argument('--setserialno', type=str, default=None, - help='serial number to set on the board.') - group.add_argument('--check_serial', action='store_true', - help='check serial number set on the board.') - group.add_argument('-m', '--mux', type=str, default=None, - help='mux selection') - group.add_argument('-p', '--power', action='store_true', - help='check VBUS') - group.add_argument('-l', '--powerlog', type=int, default=None, - help='log VBUS') - group.add_argument('-r', '--sysjump', type=str, default=None, - help='region selection') - group.add_argument('--reboot', action='store_true', - help='reboot tigertail') - group.add_argument('--check_version', action='store_true', - help='check tigertail version') - return parser + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "-s", + "--serialno", + type=str, + default=None, + help="serial number of board to use", + ) + parser.add_argument( + "-b", + "--bus", + type=str, + default="vbus", + help="Which rail to log: [vbus|cc1|cc2]", + ) + group = parser.add_mutually_exclusive_group() + group.add_argument( + "--setserialno", + type=str, + default=None, + help="serial number to set on the board.", + ) + group.add_argument( + "--check_serial", + action="store_true", + help="check serial number set on the board.", + ) + group.add_argument( + "-m", "--mux", type=str, default=None, help="mux selection" + ) + group.add_argument("-p", "--power", action="store_true", help="check VBUS") + group.add_argument( + "-l", "--powerlog", type=int, default=None, help="log VBUS" + ) + group.add_argument( + "-r", "--sysjump", type=str, default=None, help="region selection" + ) + group.add_argument("--reboot", action="store_true", help="reboot tigertail") + group.add_argument( + "--check_version", action="store_true", help="check tigertail version" + ) + return parser + def main(argv): - parser = get_parser() - opts = parser.parse_args(argv) + parser = get_parser() + opts = parser.parse_args(argv) - result = True + result = True - # Let's make sure there's a tigertail - # If nothing found in 5 seconds, fail. - c.wait_for_usb(STM_VIDPID, timeout=5., serialname=opts.serialno) + # Let's make sure there's a tigertail + # If nothing found in 5 seconds, fail. + c.wait_for_usb(STM_VIDPID, timeout=5.0, serialname=opts.serialno) - pty = c.setup_tinyservod(STM_VIDPID, 0, serialname=opts.serialno) + pty = c.setup_tinyservod(STM_VIDPID, 0, serialname=opts.serialno) - if opts.bus not in ('vbus', 'cc1', 'cc2'): - c.log('Try --bus [vbus|cc1|cc2]') - result = False + if opts.bus not in ("vbus", "cc1", "cc2"): + c.log("Try --bus [vbus|cc1|cc2]") + result = False - elif opts.setserialno: - try: - c.do_serialno(opts.setserialno, pty) - except Exception: - result = False + elif opts.setserialno: + try: + c.do_serialno(opts.setserialno, pty) + except Exception: + result = False - elif opts.mux: - result &= do_mux(opts.mux, pty) + elif opts.mux: + result &= do_mux(opts.mux, pty) - elif opts.sysjump: - result &= do_sysjump(opts.sysjump, pty, serialname=opts.serialno) + elif opts.sysjump: + result &= do_sysjump(opts.sysjump, pty, serialname=opts.serialno) - elif opts.reboot: - result &= do_reboot(pty, serialname=opts.serialno) + elif opts.reboot: + result &= do_reboot(pty, serialname=opts.serialno) - elif opts.check_version: - result &= do_version(pty) + elif opts.check_version: + result &= do_version(pty) - elif opts.check_serial: - result &= do_check_serial(pty) + elif opts.check_serial: + result &= do_check_serial(pty) - elif opts.power: - result &= do_power(1, opts.bus, pty) + elif opts.power: + result &= do_power(1, opts.bus, pty) - elif opts.powerlog: - result &= do_power(opts.powerlog, opts.bus, pty) + elif opts.powerlog: + result &= do_power(opts.powerlog, opts.bus, pty) - if result: - c.log('PASS') - else: - c.log('FAIL') - sys.exit(-1) + if result: + c.log("PASS") + else: + c.log("FAIL") + sys.exit(-1) -if __name__ == '__main__': - sys.exit(main(sys.argv[1:])) +if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) diff --git a/extra/touchpad_updater/Makefile b/extra/touchpad_updater/Makefile index ebf9c3212d..df824e8757 100644 --- a/extra/touchpad_updater/Makefile +++ b/extra/touchpad_updater/Makefile @@ -1,4 +1,4 @@ -# Copyright 2017 The Chromium OS Authors. All rights reserved. +# Copyright 2017 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. diff --git a/extra/touchpad_updater/touchpad_updater.c b/extra/touchpad_updater/touchpad_updater.c index 716ded00f5..fee898ca06 100644 --- a/extra/touchpad_updater/touchpad_updater.c +++ b/extra/touchpad_updater/touchpad_updater.c @@ -1,5 +1,5 @@ /* - * Copyright 2017 The Chromium OS Authors. All rights reserved. + * Copyright 2017 The ChromiumOS Authors * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ @@ -18,16 +18,16 @@ #include <libusb.h> /* Command line options */ -static uint16_t vid = 0x18d1; /* Google */ -static uint16_t pid = 0x5022; /* Hammer */ -static uint8_t ep_num = 4; /* console endpoint */ -static uint8_t extended_i2c_exercise; /* non-zero to exercise */ -static char *firmware_binary = "144.0_2.0.bin"; /* firmware blob */ +static uint16_t vid = 0x18d1; /* Google */ +static uint16_t pid = 0x5022; /* Hammer */ +static uint8_t ep_num = 4; /* console endpoint */ +static uint8_t extended_i2c_exercise; /* non-zero to exercise */ +static char *firmware_binary = "144.0_2.0.bin"; /* firmware blob */ /* Firmware binary blob related */ -#define MAX_FW_PAGE_SIZE 512 -#define MAX_FW_PAGE_COUNT 1024 -#define MAX_FW_SIZE (128 * 1024) +#define MAX_FW_PAGE_SIZE 512 +#define MAX_FW_PAGE_COUNT 1024 +#define MAX_FW_SIZE (128 * 1024) static uint8_t fw_data[MAX_FW_SIZE]; int fw_page_count; @@ -47,13 +47,10 @@ static char *progname; static char *short_opts = ":f:v:p:e:hd"; static const struct option long_opts[] = { /* name hasarg *flag val */ - {"file", 1, NULL, 'f'}, - {"vid", 1, NULL, 'v'}, - {"pid", 1, NULL, 'p'}, - {"ep", 1, NULL, 'e'}, - {"help", 0, NULL, 'h'}, - {"debug", 0, NULL, 'd'}, - {NULL, 0, NULL, 0}, + { "file", 1, NULL, 'f' }, { "vid", 1, NULL, 'v' }, + { "pid", 1, NULL, 'p' }, { "ep", 1, NULL, 'e' }, + { "help", 0, NULL, 'h' }, { "debug", 0, NULL, 'd' }, + { NULL, 0, NULL, 0 }, }; static void usage(int errs) @@ -71,7 +68,8 @@ static void usage(int errs) " -d,--debug Exercise extended read I2C over USB\n" " and print verbose debug messages.\n" " -h,--help Show this message\n" - "\n", progname, firmware_binary, vid, pid, ep_num); + "\n", + progname, firmware_binary, vid, pid, ep_num); exit(!!errs); } @@ -87,28 +85,28 @@ static void parse_cmdline(int argc, char *argv[]) else progname = argv[0]; - opterr = 0; /* quiet, you */ + opterr = 0; /* quiet, you */ while ((i = getopt_long(argc, argv, short_opts, long_opts, 0)) != -1) { switch (i) { case 'f': firmware_binary = optarg; break; case 'p': - pid = (uint16_t) strtoull(optarg, &e, 16); + pid = (uint16_t)strtoull(optarg, &e, 16); if (!*optarg || (e && *e)) { printf("Invalid argument: \"%s\"\n", optarg); errorcnt++; } break; case 'v': - vid = (uint16_t) strtoull(optarg, &e, 16); + vid = (uint16_t)strtoull(optarg, &e, 16); if (!*optarg || (e && *e)) { printf("Invalid argument: \"%s\"\n", optarg); errorcnt++; } break; case 'e': - ep_num = (uint8_t) strtoull(optarg, &e, 0); + ep_num = (uint8_t)strtoull(optarg, &e, 0); if (!*optarg || (e && *e)) { printf("Invalid argument: \"%s\"\n", optarg); errorcnt++; @@ -120,7 +118,7 @@ static void parse_cmdline(int argc, char *argv[]) case 'h': usage(errorcnt); break; - case 0: /* auto-handled option */ + case 0: /* auto-handled option */ break; case '?': if (optopt) @@ -142,7 +140,6 @@ static void parse_cmdline(int argc, char *argv[]) if (errorcnt) usage(errorcnt); - } /* USB transfer related */ @@ -163,7 +160,7 @@ static void request_exit(const char *format, ...) va_start(ap, format); vfprintf(stderr, format, ap); va_end(ap); - do_exit++; /* Why need this ? */ + do_exit++; /* Why need this ? */ if (tx_transfer) libusb_free_transfer(tx_transfer); @@ -178,9 +175,8 @@ static void request_exit(const char *format, ...) exit(1); } -#define DIE(msg, r) \ - request_exit("%s: line %d, %s\n", msg, __LINE__, \ - libusb_error_name(r)) +#define DIE(msg, r) \ + request_exit("%s: line %d, %s\n", msg, __LINE__, libusb_error_name(r)) static void sighandler(int signum) { @@ -259,8 +255,8 @@ static void register_sigaction(void) } /* Transfer over libusb */ -#define I2C_PORT_ON_HAMMER 0x00 -#define I2C_ADDRESS_ON_HAMMER 0x15 +#define I2C_PORT_ON_HAMMER 0x00 +#define I2C_ADDRESS_ON_HAMMER 0x15 static int check_read_status(int r, int expected, int actual) { @@ -291,12 +287,12 @@ static int check_read_status(int r, int expected, int actual) return r; } -#define MAX_USB_PACKET_SIZE 64 -#define PRIMITIVE_READING_SIZE 60 +#define MAX_USB_PACKET_SIZE 64 +#define PRIMITIVE_READING_SIZE 60 -static int libusb_single_write_and_read( - const uint8_t *to_write, uint16_t write_length, - uint8_t *to_read, uint16_t read_length) +static int libusb_single_write_and_read(const uint8_t *to_write, + uint16_t write_length, uint8_t *to_read, + uint16_t read_length) { int r; int tx_ready; @@ -315,10 +311,10 @@ static int libusb_single_write_and_read( tx_buf[4] = read_length >> 7; if (extended_i2c_exercise) { printf("Triggering extended reading." - "rc:%0x, rc1:%0x\n", - tx_buf[3], tx_buf[4]); + "rc:%0x, rc1:%0x\n", + tx_buf[3], tx_buf[4]); printf("Expecting %d Bytes.\n", - (tx_buf[3] & 0x7f) | (tx_buf[4] << 7)); + (tx_buf[3] & 0x7f) | (tx_buf[4] << 7)); } } else { tx_buf[3] = read_length; @@ -331,19 +327,18 @@ static int libusb_single_write_and_read( while (sent_bytes < (offset + write_length)) { tx_ready = remains = (offset + write_length) - sent_bytes; - r = libusb_bulk_transfer(devh, - (ep_num | LIBUSB_ENDPOINT_OUT), - tx_buf + sent_bytes, tx_ready, - &actual_length, 5000); + r = libusb_bulk_transfer(devh, (ep_num | LIBUSB_ENDPOINT_OUT), + tx_buf + sent_bytes, tx_ready, + &actual_length, 5000); if (r == 0 && actual_length == tx_ready) { r = libusb_bulk_transfer(devh, - (ep_num | LIBUSB_ENDPOINT_IN), - rx_buf, sizeof(rx_buf), - &actual_length, 5000); + (ep_num | LIBUSB_ENDPOINT_IN), + rx_buf, sizeof(rx_buf), + &actual_length, 5000); } - r = check_read_status( - r, (remains == tx_ready) ? read_length : 0, - actual_length); + r = check_read_status(r, + (remains == tx_ready) ? read_length : 0, + actual_length); if (r) break; sent_bytes += tx_ready; @@ -352,21 +347,19 @@ static int libusb_single_write_and_read( } /* Control Elan trackpad I2C over USB */ -#define ETP_I2C_INF_LENGTH 2 +#define ETP_I2C_INF_LENGTH 2 -static int elan_write_and_read( - int reg, uint8_t *buf, int read_length, - int with_cmd, int cmd) +static int elan_write_and_read(int reg, uint8_t *buf, int read_length, + int with_cmd, int cmd) { - tx_buf[0] = (reg >> 0) & 0xff; tx_buf[1] = (reg >> 8) & 0xff; if (with_cmd) { tx_buf[2] = (cmd >> 0) & 0xff; tx_buf[3] = (cmd >> 8) & 0xff; } - return libusb_single_write_and_read( - tx_buf, with_cmd ? 4 : 2, rx_buf, read_length); + return libusb_single_write_and_read(tx_buf, with_cmd ? 4 : 2, rx_buf, + read_length); } static int elan_read_block(int reg, uint8_t *buf, int read_length) @@ -385,16 +378,16 @@ static int elan_write_cmd(int reg, int cmd) } /* Elan trackpad firmware information related */ -#define ETP_I2C_IAP_VERSION_CMD 0x0110 -#define ETP_I2C_FW_VERSION_CMD 0x0102 -#define ETP_I2C_IAP_CHECKSUM_CMD 0x0315 -#define ETP_I2C_FW_CHECKSUM_CMD 0x030F -#define ETP_I2C_OSM_VERSION_CMD 0x0103 +#define ETP_I2C_IAP_VERSION_CMD 0x0110 +#define ETP_I2C_FW_VERSION_CMD 0x0102 +#define ETP_I2C_IAP_CHECKSUM_CMD 0x0315 +#define ETP_I2C_FW_CHECKSUM_CMD 0x030F +#define ETP_I2C_OSM_VERSION_CMD 0x0103 static int elan_get_version(int is_iap) { - elan_read_cmd( - is_iap ? ETP_I2C_IAP_VERSION_CMD : ETP_I2C_FW_VERSION_CMD); + elan_read_cmd(is_iap ? ETP_I2C_IAP_VERSION_CMD : + ETP_I2C_FW_VERSION_CMD); return le_bytes_to_int(rx_buf + 4); } @@ -435,8 +428,8 @@ static void elan_get_ic_page_count(void) static int elan_get_checksum(int is_iap) { - elan_read_cmd( - is_iap ? ETP_I2C_IAP_CHECKSUM_CMD : ETP_I2C_FW_CHECKSUM_CMD); + elan_read_cmd(is_iap ? ETP_I2C_IAP_CHECKSUM_CMD : + ETP_I2C_FW_CHECKSUM_CMD); return le_bytes_to_int(rx_buf + 4); } @@ -451,21 +444,21 @@ static uint16_t elan_get_fw_info(void) iap_checksum = elan_get_checksum(1); fw_version = elan_get_version(0); iap_version = elan_get_version(1); - printf("IAP version: %4x, FW version: %4x\n", - iap_version, fw_version); - printf("IAP checksum: %4x, FW checksum: %4x\n", - iap_checksum, fw_checksum); + printf("IAP version: %4x, FW version: %4x\n", iap_version, + fw_version); + printf("IAP checksum: %4x, FW checksum: %4x\n", iap_checksum, + fw_checksum); return fw_checksum; } /* Update preparation */ -#define ETP_I2C_IAP_RESET_CMD 0x0314 -#define ETP_I2C_IAP_RESET 0xF0F0 -#define ETP_I2C_IAP_CTRL_CMD 0x0310 -#define ETP_I2C_MAIN_MODE_ON (1 << 9) -#define ETP_I2C_IAP_CMD 0x0311 -#define ETP_I2C_IAP_PASSWORD 0x1EA5 -#define ETP_I2C_IAP_TYPE_CMD 0x0304 +#define ETP_I2C_IAP_RESET_CMD 0x0314 +#define ETP_I2C_IAP_RESET 0xF0F0 +#define ETP_I2C_IAP_CTRL_CMD 0x0310 +#define ETP_I2C_MAIN_MODE_ON (1 << 9) +#define ETP_I2C_IAP_CMD 0x0311 +#define ETP_I2C_IAP_PASSWORD 0x1EA5 +#define ETP_I2C_IAP_TYPE_CMD 0x0304 static int elan_in_main_mode(void) { @@ -478,8 +471,7 @@ static int elan_read_write_iap_type(void) for (int retry = 0; retry < 3; ++retry) { uint16_t val; - if (elan_write_cmd(ETP_I2C_IAP_TYPE_CMD, - fw_page_size / 2)) + if (elan_write_cmd(ETP_I2C_IAP_TYPE_CMD, fw_page_size / 2)) return -1; if (elan_read_cmd(ETP_I2C_IAP_TYPE_CMD)) @@ -490,7 +482,6 @@ static int elan_read_write_iap_type(void) printf("%s: OK\n", __func__); return 0; } - } return -1; } @@ -528,17 +519,17 @@ static void elan_prepare_for_update(void) request_exit("cannot read iap password.\n"); if (le_bytes_to_int(rx_buf + 4) != ETP_I2C_IAP_PASSWORD) request_exit("Got an unexpected IAP password %4x\n", - le_bytes_to_int(rx_buf + 4)); + le_bytes_to_int(rx_buf + 4)); } /* Firmware block update */ -#define ETP_IAP_START_ADDR 0x0083 +#define ETP_IAP_START_ADDR 0x0083 static uint16_t elan_calc_checksum(uint8_t *data, int length) { uint16_t checksum = 0; for (int i = 0; i < length; i += 2) - checksum += ((uint16_t)(data[i+1]) << 8) | (data[i]); + checksum += ((uint16_t)(data[i + 1]) << 8) | (data[i]); return checksum; } @@ -547,11 +538,11 @@ static int elan_get_iap_addr(void) return le_bytes_to_int(fw_data + ETP_IAP_START_ADDR * 2) * 2; } -#define ETP_I2C_IAP_REG_L 0x01 -#define ETP_I2C_IAP_REG_H 0x06 +#define ETP_I2C_IAP_REG_L 0x01 +#define ETP_I2C_IAP_REG_H 0x06 -#define ETP_FW_IAP_PAGE_ERR (1 << 5) -#define ETP_FW_IAP_INTF_ERR (1 << 4) +#define ETP_FW_IAP_PAGE_ERR (1 << 5) +#define ETP_FW_IAP_INTF_ERR (1 << 4) static int elan_write_fw_block(uint8_t *raw_data, uint16_t checksum) { @@ -564,8 +555,8 @@ static int elan_write_fw_block(uint8_t *raw_data, uint16_t checksum) page_store[fw_page_size + 2 + 0] = (checksum >> 0) & 0xff; page_store[fw_page_size + 2 + 1] = (checksum >> 8) & 0xff; - rv = libusb_single_write_and_read( - page_store, fw_page_size + 4, rx_buf, 0); + rv = libusb_single_write_and_read(page_store, fw_page_size + 4, rx_buf, + 0); if (rv) return rv; usleep((fw_page_size >= 512 ? 50 : 35) * 1000); @@ -578,7 +569,6 @@ static int elan_write_fw_block(uint8_t *raw_data, uint16_t checksum) return 0; } - static uint16_t elan_update_firmware(void) { uint16_t checksum = 0, block_checksum; @@ -661,7 +651,7 @@ int main(int argc, char *argv[]) remote_checksum = elan_get_checksum(1); if (remote_checksum != local_checksum) printf("checksum diff local=[%04X], remote=[%04X]\n", - local_checksum, remote_checksum); + local_checksum, remote_checksum); /* Print the updated firmware information */ elan_get_fw_info(); diff --git a/extra/usb_console/Makefile b/extra/usb_console/Makefile index bddca1d0a2..bc4c5909a2 100644 --- a/extra/usb_console/Makefile +++ b/extra/usb_console/Makefile @@ -1,4 +1,4 @@ -# Copyright 2015 The Chromium OS Authors. All rights reserved. +# Copyright 2015 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. diff --git a/extra/usb_console/usb_console.c b/extra/usb_console/usb_console.c index e4f8ea504f..aea9eb8293 100644 --- a/extra/usb_console/usb_console.c +++ b/extra/usb_console/usb_console.c @@ -1,5 +1,5 @@ /* - * Copyright 2015 The Chromium OS Authors. All rights reserved. + * Copyright 2015 The ChromiumOS Authors * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ @@ -18,12 +18,12 @@ #include <libusb.h> /* Options */ -static uint16_t vid = 0x18d1; /* Google */ -static uint16_t pid = 0x500f; /* discovery-stm32f072 */ -static uint8_t ep_num = 4; /* console endpoint */ +static uint16_t vid = 0x18d1; /* Google */ +static uint16_t pid = 0x500f; /* discovery-stm32f072 */ +static uint8_t ep_num = 4; /* console endpoint */ -static unsigned char rx_buf[1024]; /* much too big */ -static unsigned char tx_buf[1024]; /* much too big */ +static unsigned char rx_buf[1024]; /* much too big */ +static unsigned char tx_buf[1024]; /* much too big */ static const struct libusb_pollfd **usb_fds; static struct libusb_device_handle *devh; static struct libusb_transfer *rx_transfer; @@ -40,9 +40,8 @@ static void request_exit(const char *format, ...) do_exit++; } -#define BOO(msg, r) \ - request_exit("%s: line %d, %s\n", msg, __LINE__, \ - libusb_error_name(r)) +#define BOO(msg, r) \ + request_exit("%s: line %d, %s\n", msg, __LINE__, libusb_error_name(r)) static void sighandler(int signum) { @@ -105,8 +104,8 @@ static void send_tx(int len) { int r; - libusb_fill_bulk_transfer(tx_transfer, devh, - ep_num, tx_buf, len, cb_tx, NULL, 0); + libusb_fill_bulk_transfer(tx_transfer, devh, ep_num, tx_buf, len, cb_tx, + NULL, 0); r = libusb_submit_transfer(tx_transfer); if (r < 0) @@ -185,7 +184,7 @@ static int wait_for_stuff_to_happen(void) return -1; } - if (r == 0) /* timed out */ + if (r == 0) /* timed out */ return 0; /* Ignore stdin until we've finished sending the current line */ @@ -235,11 +234,9 @@ static char *progname; static char *short_opts = ":v:p:e:h"; static const struct option long_opts[] = { /* name hasarg *flag val */ - {"vid", 1, NULL, 'v'}, - {"pid", 1, NULL, 'p'}, - {"ep", 1, NULL, 'e'}, - {"help", 0, NULL, 'h'}, - {NULL, 0, NULL, 0}, + { "vid", 1, NULL, 'v' }, { "pid", 1, NULL, 'p' }, + { "ep", 1, NULL, 'e' }, { "help", 0, NULL, 'h' }, + { NULL, 0, NULL, 0 }, }; static void usage(int errs) @@ -254,7 +251,8 @@ static void usage(int errs) " -p,--pid HEXVAL Product ID (default %04x)\n" " -e,--ep NUM Endpoint (default %d)\n" " -h,--help Show this message\n" - "\n", progname, vid, pid, ep_num); + "\n", + progname, vid, pid, ep_num); exit(!!errs); } @@ -275,25 +273,25 @@ int main(int argc, char *argv[]) else progname = argv[0]; - opterr = 0; /* quiet, you */ + opterr = 0; /* quiet, you */ while ((i = getopt_long(argc, argv, short_opts, long_opts, 0)) != -1) { switch (i) { case 'p': - pid = (uint16_t) strtoull(optarg, &e, 16); + pid = (uint16_t)strtoull(optarg, &e, 16); if (!*optarg || (e && *e)) { printf("Invalid argument: \"%s\"\n", optarg); errorcnt++; } break; case 'v': - vid = (uint16_t) strtoull(optarg, &e, 16); + vid = (uint16_t)strtoull(optarg, &e, 16); if (!*optarg || (e && *e)) { printf("Invalid argument: \"%s\"\n", optarg); errorcnt++; } break; case 'e': - ep_num = (uint8_t) strtoull(optarg, &e, 0); + ep_num = (uint8_t)strtoull(optarg, &e, 0); if (!*optarg || (e && *e)) { printf("Invalid argument: \"%s\"\n", optarg); errorcnt++; @@ -302,7 +300,7 @@ int main(int argc, char *argv[]) case 'h': usage(errorcnt); break; - case 0: /* auto-handled option */ + case 0: /* auto-handled option */ break; case '?': if (optopt) @@ -368,9 +366,8 @@ int main(int argc, char *argv[]) printf("can't alloc rx_transfer"); goto out; } - libusb_fill_bulk_transfer(rx_transfer, devh, - 0x80 | ep_num, - rx_buf, sizeof(rx_buf), cb_rx, NULL, 0); + libusb_fill_bulk_transfer(rx_transfer, devh, 0x80 | ep_num, rx_buf, + sizeof(rx_buf), cb_rx, NULL, 0); tx_transfer = libusb_alloc_transfer(0); if (!tx_transfer) { @@ -396,14 +393,14 @@ int main(int argc, char *argv[]) while (!do_exit) { r = wait_for_stuff_to_happen(); switch (r) { - case 0: /* timed out */ + case 0: /* timed out */ /* printf("."); */ /* fflush(stdout); */ break; - case 1: /* stdin ready */ + case 1: /* stdin ready */ handle_stdin(); break; - case 2: /* libusb ready */ + case 2: /* libusb ready */ handle_libusb(); break; } @@ -440,7 +437,7 @@ int main(int argc, char *argv[]) printf("bye\n"); r = 0; - out: +out: if (tx_transfer) libusb_free_transfer(tx_transfer); if (rx_transfer) diff --git a/extra/usb_gpio/Makefile b/extra/usb_gpio/Makefile index 644e3ee70f..84a27ccc12 100644 --- a/extra/usb_gpio/Makefile +++ b/extra/usb_gpio/Makefile @@ -1,4 +1,4 @@ -# Copyright 2014 The Chromium OS Authors. All rights reserved. +# Copyright 2014 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. diff --git a/extra/usb_gpio/usb_gpio.c b/extra/usb_gpio/usb_gpio.c index 8973f3d304..7f2121d2b0 100644 --- a/extra/usb_gpio/usb_gpio.c +++ b/extra/usb_gpio/usb_gpio.c @@ -1,5 +1,5 @@ /* - * Copyright 2014 The Chromium OS Authors. All rights reserved. + * Copyright 2014 The ChromiumOS Authors * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ @@ -11,54 +11,46 @@ #include <string.h> #include <unistd.h> -#define CHECK(expression) \ - ({ \ - int error__ = (expression); \ - \ - if (error__ != 0) { \ - fprintf(stderr, \ - "libusb error: %s:%d %s\n", \ - __FILE__, \ - __LINE__, \ - libusb_error_name(error__)); \ - return error__; \ - } \ - \ - error__; \ +#define CHECK(expression) \ + ({ \ + int error__ = (expression); \ + \ + if (error__ != 0) { \ + fprintf(stderr, "libusb error: %s:%d %s\n", __FILE__, \ + __LINE__, libusb_error_name(error__)); \ + return error__; \ + } \ + \ + error__; \ }) #define TRANSFER_TIMEOUT_MS 100 -static int gpio_write(libusb_device_handle *device, - uint32_t set_mask, +static int gpio_write(libusb_device_handle *device, uint32_t set_mask, uint32_t clear_mask) { uint8_t command[8]; - int transferred; + int transferred; - command[0] = (set_mask >> 0) & 0xff; - command[1] = (set_mask >> 8) & 0xff; + command[0] = (set_mask >> 0) & 0xff; + command[1] = (set_mask >> 8) & 0xff; command[2] = (set_mask >> 16) & 0xff; command[3] = (set_mask >> 24) & 0xff; - command[4] = (clear_mask >> 0) & 0xff; - command[5] = (clear_mask >> 8) & 0xff; + command[4] = (clear_mask >> 0) & 0xff; + command[5] = (clear_mask >> 8) & 0xff; command[6] = (clear_mask >> 16) & 0xff; command[7] = (clear_mask >> 24) & 0xff; - CHECK(libusb_bulk_transfer(device, - LIBUSB_ENDPOINT_OUT | 2, - command, - sizeof(command), - &transferred, + CHECK(libusb_bulk_transfer(device, LIBUSB_ENDPOINT_OUT | 2, command, + sizeof(command), &transferred, TRANSFER_TIMEOUT_MS)); if (transferred != sizeof(command)) { fprintf(stderr, "Failed to transfer full command " "(sent %d of %d bytes)\n", - transferred, - (int)sizeof(command)); + transferred, (int)sizeof(command)); return LIBUSB_ERROR_OTHER; } @@ -68,38 +60,29 @@ static int gpio_write(libusb_device_handle *device, static int gpio_read(libusb_device_handle *device, uint32_t *mask) { uint8_t response[4]; - int transferred; + int transferred; /* * The first query does triggers the sampling of the GPIO values, the * second query reads them back. */ - CHECK(libusb_bulk_transfer(device, - LIBUSB_ENDPOINT_IN | 2, - response, - sizeof(response), - &transferred, + CHECK(libusb_bulk_transfer(device, LIBUSB_ENDPOINT_IN | 2, response, + sizeof(response), &transferred, TRANSFER_TIMEOUT_MS)); - CHECK(libusb_bulk_transfer(device, - LIBUSB_ENDPOINT_IN | 2, - response, - sizeof(response), - &transferred, + CHECK(libusb_bulk_transfer(device, LIBUSB_ENDPOINT_IN | 2, response, + sizeof(response), &transferred, TRANSFER_TIMEOUT_MS)); if (transferred != sizeof(response)) { fprintf(stderr, "Failed to transfer full response " "(read %d of %d bytes)\n", - transferred, - (int)sizeof(response)); + transferred, (int)sizeof(response)); return LIBUSB_ERROR_OTHER; } - *mask = (response[0] << 0 | - response[1] << 8 | - response[2] << 16 | + *mask = (response[0] << 0 | response[1] << 8 | response[2] << 16 | response[3] << 24); return 0; @@ -107,13 +90,13 @@ static int gpio_read(libusb_device_handle *device, uint32_t *mask) int main(int argc, char **argv) { - libusb_context *context; + libusb_context *context; libusb_device_handle *device; - uint16_t vendor_id = 0x18d1; /* Google */ - uint16_t product_id = 0x500f; /* discovery-stm32f072 */ - int interface = 1; /* gpio interface */ + uint16_t vendor_id = 0x18d1; /* Google */ + uint16_t product_id = 0x500f; /* discovery-stm32f072 */ + int interface = 1; /* gpio interface */ - if (!(argc == 2 && strcmp(argv[1], "read") == 0) && + if (!(argc == 2 && strcmp(argv[1], "read") == 0) && !(argc == 4 && strcmp(argv[1], "write") == 0)) { puts("Usage: usb_gpio read\n" " usb_gpio write <set_mask> <clear_mask>\n"); @@ -122,15 +105,12 @@ int main(int argc, char **argv) CHECK(libusb_init(&context)); - device = libusb_open_device_with_vid_pid(context, - vendor_id, - product_id); + device = + libusb_open_device_with_vid_pid(context, vendor_id, product_id); if (device == NULL) { - fprintf(stderr, - "Unable to find device 0x%04x:0x%04x\n", - vendor_id, - product_id); + fprintf(stderr, "Unable to find device 0x%04x:0x%04x\n", + vendor_id, product_id); return 1; } @@ -146,7 +126,7 @@ int main(int argc, char **argv) } if (argc == 4 && strcmp(argv[1], "write") == 0) { - uint32_t set_mask = strtol(argv[2], NULL, 0); + uint32_t set_mask = strtol(argv[2], NULL, 0); uint32_t clear_mask = strtol(argv[3], NULL, 0); CHECK(gpio_write(device, set_mask, clear_mask)); diff --git a/extra/usb_power/convert_power_log_board.py b/extra/usb_power/convert_power_log_board.py index 8aab77ee4c..f5fb7e925d 100644 --- a/extra/usb_power/convert_power_log_board.py +++ b/extra/usb_power/convert_power_log_board.py @@ -1,11 +1,7 @@ #!/usr/bin/env python -# Copyright 2018 The Chromium OS Authors. All rights reserved. +# Copyright 2018 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes """ Program to convert sweetberry config to servod config template. @@ -14,11 +10,12 @@ Program to convert sweetberry config to servod config template. # Note: This is a py2/3 compatible file. from __future__ import print_function + import json import os import sys -from powerlog import Spower +from powerlog import Spower # pylint:disable=import-error def fetch_records(board_file): @@ -48,21 +45,29 @@ def write_to_file(file, sweetberry, inas): inas: list of inas read from board file. """ - with open(file, 'w') as pyfile: + with open(file, "w") as pyfile: - pyfile.write('inas = [\n') + pyfile.write("inas = [\n") for rec in inas: - if rec['sweetberry'] != sweetberry: + if rec["sweetberry"] != sweetberry: continue # EX : ('sweetberry', 0x40, 'SB_FW_CAM_2P8', 5.0, 1.000, 3, False), - channel, i2c_addr = Spower.CHMAP[rec['channel']] - record = (" ('sweetberry', 0x%02x, '%s', 5.0, %f, %d, 'True')" - ",\n" % (i2c_addr, rec['name'], rec['rs'], channel)) + channel, i2c_addr = Spower.CHMAP[rec["channel"]] + record = ( + " ('sweetberry', 0x%02x, '%s', 5.0, %f, %d, 'True')" + ",\n" + % ( + i2c_addr, + rec["name"], + rec["rs"], + channel, + ) + ) pyfile.write(record) - pyfile.write(']\n') + pyfile.write("]\n") def main(argv): @@ -76,16 +81,18 @@ def main(argv): inas = fetch_records(inputf) - sweetberry = set(rec['sweetberry'] for rec in inas) + sweetberry = set(rec["sweetberry"] for rec in inas) if len(sweetberry) == 2: - print("Converting %s to %s and %s" % (inputf, basename + '_a.py', - basename + '_b.py')) - write_to_file(basename + '_a.py', 'A', inas) - write_to_file(basename + '_b.py', 'B', inas) + print( + "Converting %s to %s and %s" + % (inputf, basename + "_a.py", basename + "_b.py") + ) + write_to_file(basename + "_a.py", "A", inas) + write_to_file(basename + "_b.py", "B", inas) else: - print("Converting %s to %s" % (inputf, basename + '.py')) - write_to_file(basename + '.py', sweetberry.pop(), inas) + print("Converting %s to %s" % (inputf, basename + ".py")) + write_to_file(basename + ".py", sweetberry.pop(), inas) if __name__ == "__main__": diff --git a/extra/usb_power/convert_servo_ina.py b/extra/usb_power/convert_servo_ina.py index 1c70f31aeb..1deb75cda4 100755 --- a/extra/usb_power/convert_servo_ina.py +++ b/extra/usb_power/convert_servo_ina.py @@ -1,11 +1,7 @@ #!/usr/bin/env python -# Copyright 2017 The Chromium OS Authors. All rights reserved. +# Copyright 2017 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes """Program to convert power logging config from a servo_ina device to a sweetberry config. @@ -14,67 +10,74 @@ # Note: This is a py2/3 compatible file. from __future__ import print_function + import os import sys def fetch_records(basename): - """Import records from servo_ina file. + """Import records from servo_ina file. - servo_ina files are python imports, and have a list of tuples with - the INA data. - (inatype, i2caddr, rail name, bus voltage, shunt ohms, mux, True) + servo_ina files are python imports, and have a list of tuples with + the INA data. + (inatype, i2caddr, rail name, bus voltage, shunt ohms, mux, True) - Args: - basename: python import name (filename -.py) + Args: + basename: python import name (filename -.py) - Returns: - list of tuples as described above. - """ - ina_desc = __import__(basename) - return ina_desc.inas + Returns: + list of tuples as described above. + """ + ina_desc = __import__(basename) + return ina_desc.inas def main(argv): - if len(argv) != 2: - print("usage:") - print(" %s input.py" % argv[0]) - return + if len(argv) != 2: + print("usage:") + print(" %s input.py" % argv[0]) + return - inputf = argv[1] - basename = os.path.splitext(inputf)[0] - outputf = basename + '.board' - outputs = basename + '.scenario' + inputf = argv[1] + basename = os.path.splitext(inputf)[0] + outputf = basename + ".board" + outputs = basename + ".scenario" - print("Converting %s to %s, %s" % (inputf, outputf, outputs)) + print("Converting %s to %s, %s" % (inputf, outputf, outputs)) - inas = fetch_records(basename) + inas = fetch_records(basename) + boardfile = open(outputf, "w") + scenario = open(outputs, "w") - boardfile = open(outputf, 'w') - scenario = open(outputs, 'w') + boardfile.write("[\n") + scenario.write("[\n") + start = True - boardfile.write('[\n') - scenario.write('[\n') - start = True + for rec in inas: + if start: + start = False + else: + boardfile.write(",\n") + scenario.write(",\n") - for rec in inas: - if start: - start = False - else: - boardfile.write(',\n') - scenario.write(',\n') + record = ( + ' {"name": "%s", "rs": %f, "sweetberry": "A", "channel": %d}' + % ( + rec[2], + rec[4], + rec[1] - 64, + ) + ) + boardfile.write(record) + scenario.write('"%s"' % rec[2]) - record = ' {"name": "%s", "rs": %f, "sweetberry": "A", "channel": %d}' % ( - rec[2], rec[4], rec[1] - 64) - boardfile.write(record) - scenario.write('"%s"' % rec[2]) + boardfile.write("\n") + boardfile.write("]") - boardfile.write('\n') - boardfile.write(']') + scenario.write("\n") + scenario.write("]") - scenario.write('\n') - scenario.write(']') if __name__ == "__main__": - main(sys.argv) + main(sys.argv) diff --git a/extra/usb_power/powerlog.py b/extra/usb_power/powerlog.py index 82cce3daed..13e41bd23a 100755 --- a/extra/usb_power/powerlog.py +++ b/extra/usb_power/powerlog.py @@ -1,11 +1,7 @@ #!/usr/bin/env python -# Copyright 2016 The Chromium OS Authors. All rights reserved. +# Copyright 2016 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes """Program to fetch power logging data from a sweetberry device or other usb device that exports a USB power logging interface. @@ -14,9 +10,9 @@ # Note: This is a py2/3 compatible file. from __future__ import print_function + import argparse import array -from distutils import sysconfig import json import logging import os @@ -25,884 +21,1041 @@ import struct import sys import time import traceback +from distutils import sysconfig -import usb - -from stats_manager import StatsManager +import usb # pylint:disable=import-error +from stats_manager import StatsManager # pylint:disable=import-error # Directory where hdctools installs configuration files into. -LIB_DIR = os.path.join(sysconfig.get_python_lib(standard_lib=False), 'servo', - 'data') +LIB_DIR = os.path.join( + sysconfig.get_python_lib(standard_lib=False), "servo", "data" +) # Potential config file locations: current working directory, the same directory # as powerlog.py file or LIB_DIR. -CONFIG_LOCATIONS = [os.getcwd(), os.path.dirname(os.path.realpath(__file__)), - LIB_DIR] - -def logoutput(msg): - print(msg) - sys.stdout.flush() - -def process_filename(filename): - """Find the file path from the filename. - - If filename is already the complete path, return that directly. If filename is - just the short name, look for the file in the current working directory, in - the directory of the current .py file, and then in the directory installed by - hdctools. If the file is found, return the complete path of the file. - - Args: - filename: complete file path or short file name. - - Returns: - a complete file path. - - Raises: - IOError if filename does not exist. - """ - # Check if filename is absolute path. - if os.path.isabs(filename) and os.path.isfile(filename): - return filename - # Check if filename is relative to a known config location. - for dirname in CONFIG_LOCATIONS: - file_at_dir = os.path.join(dirname, filename) - if os.path.isfile(file_at_dir): - return file_at_dir - raise IOError('No such file or directory: \'%s\'' % filename) - - -class Spower(object): - """Power class to access devices on the bus. - - Usage: - bus = Spower() - - Instance Variables: - _dev: pyUSB device object - _read_ep: pyUSB read endpoint for this interface - _write_ep: pyUSB write endpoint for this interface - """ - - # INA interface type. - INA_POWER = 1 - INA_BUSV = 2 - INA_CURRENT = 3 - INA_SHUNTV = 4 - # INA_SUFFIX is used to differentiate multiple ina types for the same power - # rail. No suffix for when ina type is 0 (non-existent) and when ina type is 1 - # (power, no suffix for backward compatibility). - INA_SUFFIX = ['', '', '_busv', '_cur', '_shuntv'] - - # usb power commands - CMD_RESET = 0x0000 - CMD_STOP = 0x0001 - CMD_ADDINA = 0x0002 - CMD_START = 0x0003 - CMD_NEXT = 0x0004 - CMD_SETTIME = 0x0005 - - # Map between header channel number (0-47) - # and INA I2C bus/addr on sweetberry. - CHMAP = { - 0: (3, 0x40), - 1: (1, 0x40), - 2: (2, 0x40), - 3: (0, 0x40), - 4: (3, 0x41), - 5: (1, 0x41), - 6: (2, 0x41), - 7: (0, 0x41), - 8: (3, 0x42), - 9: (1, 0x42), - 10: (2, 0x42), - 11: (0, 0x42), - 12: (3, 0x43), - 13: (1, 0x43), - 14: (2, 0x43), - 15: (0, 0x43), - 16: (3, 0x44), - 17: (1, 0x44), - 18: (2, 0x44), - 19: (0, 0x44), - 20: (3, 0x45), - 21: (1, 0x45), - 22: (2, 0x45), - 23: (0, 0x45), - 24: (3, 0x46), - 25: (1, 0x46), - 26: (2, 0x46), - 27: (0, 0x46), - 28: (3, 0x47), - 29: (1, 0x47), - 30: (2, 0x47), - 31: (0, 0x47), - 32: (3, 0x48), - 33: (1, 0x48), - 34: (2, 0x48), - 35: (0, 0x48), - 36: (3, 0x49), - 37: (1, 0x49), - 38: (2, 0x49), - 39: (0, 0x49), - 40: (3, 0x4a), - 41: (1, 0x4a), - 42: (2, 0x4a), - 43: (0, 0x4a), - 44: (3, 0x4b), - 45: (1, 0x4b), - 46: (2, 0x4b), - 47: (0, 0x4b), - } - - def __init__(self, board, vendor=0x18d1, - product=0x5020, interface=1, serialname=None): - self._logger = logging.getLogger(__name__) - self._board = board - - # Find the stm32. - dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) - dev_list = list(dev_g) - if dev_list is None: - raise Exception("Power", "USB device not found") - - # Check if we have multiple stm32s and we've specified the serial. - dev = None - if serialname: - for d in dev_list: - dev_serial = "PyUSB dioesn't have a stable interface" - try: - dev_serial = usb.util.get_string(d, 256, d.iSerialNumber) - except ValueError: - # Incompatible pyUsb version. - dev_serial = usb.util.get_string(d, d.iSerialNumber) - if dev_serial == serialname: - dev = d - break - if dev is None: - raise Exception("Power", "USB device(%s) not found" % serialname) - else: - try: - dev = dev_list[0] - except TypeError: - # Incompatible pyUsb version. - dev = dev_list.next() - - self._logger.debug("Found USB device: %04x:%04x", vendor, product) - self._dev = dev - - # Get an endpoint instance. - try: - dev.set_configuration() - except usb.USBError: - pass - cfg = dev.get_active_configuration() - - intf = usb.util.find_descriptor(cfg, custom_match=lambda i: \ - i.bInterfaceClass==255 and i.bInterfaceSubClass==0x54) - - self._intf = intf - self._logger.debug("InterfaceNumber: %s", intf.bInterfaceNumber) - - read_ep = usb.util.find_descriptor( - intf, - # match the first IN endpoint - custom_match = \ - lambda e: \ - usb.util.endpoint_direction(e.bEndpointAddress) == \ - usb.util.ENDPOINT_IN - ) - - self._read_ep = read_ep - self._logger.debug("Reader endpoint: 0x%x", read_ep.bEndpointAddress) - - write_ep = usb.util.find_descriptor( - intf, - # match the first OUT endpoint - custom_match = \ - lambda e: \ - usb.util.endpoint_direction(e.bEndpointAddress) == \ - usb.util.ENDPOINT_OUT - ) - - self._write_ep = write_ep - self._logger.debug("Writer endpoint: 0x%x", write_ep.bEndpointAddress) - - self.clear_ina_struct() - - self._logger.debug("Found power logging USB endpoint.") - - def clear_ina_struct(self): - """ Clear INA description struct.""" - self._inas = [] +CONFIG_LOCATIONS = [ + os.getcwd(), + os.path.dirname(os.path.realpath(__file__)), + LIB_DIR, +] - def append_ina_struct(self, name, rs, port, addr, - data=None, ina_type=INA_POWER): - """Add an INA descriptor into the list of active INAs. - Args: - name: Readable name of this channel. - rs: Sense resistor value in ohms, floating point. - port: I2C channel this INA is connected to. - addr: I2C addr of this INA. - data: Misc data for special handling, board specific. - ina_type: INA function to use, power, voltage, etc. - """ - ina = {} - ina['name'] = name - ina['rs'] = rs - ina['port'] = port - ina['addr'] = addr - ina['type'] = ina_type - # Calculate INA231 Calibration register - # (see INA231 spec p.15) - # CurrentLSB = uA per div = 80mV / (Rsh * 2^15) - # CurrentLSB uA = 80000000nV / (Rsh mOhm * 0x8000) - ina['uAscale'] = 80000000. / (rs * 0x8000); - ina['uWscale'] = 25. * ina['uAscale']; - ina['mVscale'] = 1.25 - ina['uVscale'] = 2.5 - ina['data'] = data - self._inas.append(ina) - - def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=1000): - """Write command to logger logic. - - This function writes byte command values list to stm, then reads - byte status. - - Args: - write_list: list of command byte values [0~255]. - read_count: number of status byte values to read. - - Interface: - write: [command, data ... ] - read: [status ] - - Returns: - bytes read, or None on failure. - """ - self._logger.debug("Spower.wr_command(write_list=[%s] (%d), read_count=%s)", - list(bytearray(write_list)), len(write_list), read_count) - - # Clean up args from python style to correct types. - write_length = 0 - if write_list: - write_length = len(write_list) - if not read_count: - read_count = 0 - - # Send command to stm32. - if write_list: - cmd = write_list - ret = self._write_ep.write(cmd, wtimeout) - - self._logger.debug("RET: %s ", ret) - - # Read back response if necessary. - if read_count: - bytesread = self._read_ep.read(512, rtimeout) - self._logger.debug("BYTES: [%s]", bytesread) - - if len(bytesread) != read_count: - pass - - self._logger.debug("STATUS: 0x%02x", int(bytesread[0])) - if read_count == 1: - return bytesread[0] - else: - return bytesread - - return None - - def clear(self): - """Clear pending reads on the stm32""" - try: - while True: - ret = self.wr_command(b"", read_count=512, rtimeout=100, wtimeout=50) - self._logger.debug("Try Clear: read %s", - "success" if ret == 0 else "failure") - except: - pass - - def send_reset(self): - """Reset the power interface on the stm32""" - cmd = struct.pack("<H", self.CMD_RESET) - ret = self.wr_command(cmd, rtimeout=50, wtimeout=50) - self._logger.debug("Command RESET: %s", - "success" if ret == 0 else "failure") - - def reset(self): - """Try resetting the USB interface until success. - - Use linear back off strategy when encounter the error with 10ms increment. - - Raises: - Exception on failure. - """ - max_reset_retry = 100 - for count in range(1, max_reset_retry + 1): - self.clear() - try: - self.send_reset() - return - except Exception as e: - self.clear() - self.clear() - self._logger.debug("TRY %d of %d: %s", count, max_reset_retry, e) - time.sleep(count * 0.01) - raise Exception("Power", "Failed to reset") - - def stop(self): - """Stop any active data acquisition.""" - cmd = struct.pack("<H", self.CMD_STOP) - ret = self.wr_command(cmd) - self._logger.debug("Command STOP: %s", - "success" if ret == 0 else "failure") - - def start(self, integration_us): - """Start data acquisition. - - Args: - integration_us: int, how many us between samples, and - how often the data block must be read. +def logoutput(msg): + print(msg) + sys.stdout.flush() - Returns: - actual sampling interval in ms. - """ - cmd = struct.pack("<HI", self.CMD_START, integration_us) - read = self.wr_command(cmd, read_count=5) - actual_us = 0 - if len(read) == 5: - ret, actual_us = struct.unpack("<BI", read) - self._logger.debug("Command START: %s %dus", - "success" if ret == 0 else "failure", actual_us) - else: - self._logger.debug("Command START: FAIL") - return actual_us +def process_filename(filename): + """Find the file path from the filename. - def add_ina_name(self, name_tuple): - """Add INA from board config. + If filename is already the complete path, return that directly. If filename is + just the short name, look for the file in the current working directory, in + the directory of the current .py file, and then in the directory installed by + hdctools. If the file is found, return the complete path of the file. Args: - name_tuple: name and type of power rail in board config. + filename: complete file path or short file name. Returns: - True if INA added, False if the INA is not on this board. + a complete file path. Raises: - Exception on unexpected failure. + IOError if filename does not exist. """ - name, ina_type = name_tuple - - for datum in self._brdcfg: - if datum["name"] == name: - rs = int(float(datum["rs"]) * 1000.) - board = datum["sweetberry"] - - if board == self._board: - if 'port' in datum and 'addr' in datum: - port = datum['port'] - addr = datum['addr'] - else: - channel = int(datum["channel"]) - port, addr = self.CHMAP[channel] - self.add_ina(port, ina_type, addr, 0, rs, data=datum) - return True - else: - return False - raise Exception("Power", "Failed to find INA %s" % name) + # Check if filename is absolute path. + if os.path.isabs(filename) and os.path.isfile(filename): + return filename + # Check if filename is relative to a known config location. + for dirname in CONFIG_LOCATIONS: + file_at_dir = os.path.join(dirname, filename) + if os.path.isfile(file_at_dir): + return file_at_dir + raise IOError("No such file or directory: '%s'" % filename) - def set_time(self, timestamp_us): - """Set sweetberry time to match host time. - Args: - timestamp_us: host timestmap in us. - """ - # 0x0005 , 8 byte timestamp - cmd = struct.pack("<HQ", self.CMD_SETTIME, timestamp_us) - ret = self.wr_command(cmd) - - self._logger.debug("Command SETTIME: %s", - "success" if ret == 0 else "failure") +class Spower(object): + """Power class to access devices on the bus. - def add_ina(self, bus, ina_type, addr, extra, resistance, data=None): - """Add an INA to the data acquisition list. + Usage: + bus = Spower() - Args: - bus: which i2c bus the INA is on. Same ordering as Si2c. - ina_type: Ina interface: INA_POWER/BUSV/etc. - addr: 7 bit i2c addr of this INA - extra: extra data for nonstandard configs. - resistance: int, shunt resistance in mOhm + Instance Variables: + _dev: pyUSB device object + _read_ep: pyUSB read endpoint for this interface + _write_ep: pyUSB write endpoint for this interface """ - # 0x0002, 1B: bus, 1B:INA type, 1B: INA addr, 1B: extra, 4B: Rs - cmd = struct.pack("<HBBBBI", self.CMD_ADDINA, - bus, ina_type, addr, extra, resistance) - ret = self.wr_command(cmd) - if ret == 0: - if data: - name = data['name'] - else: - name = "ina%d_%02x" % (bus, addr) - self.append_ina_struct(name, resistance, bus, addr, - data=data, ina_type=ina_type) - self._logger.debug("Command ADD_INA: %s", - "success" if ret == 0 else "failure") - - def report_header_size(self): - """Helper function to calculate power record header size.""" - result = 2 - timestamp = 8 - return result + timestamp - - def report_size(self, ina_count): - """Helper function to calculate full power record size.""" - record = 2 - - datasize = self.report_header_size() + ina_count * record - # Round to multiple of 4 bytes. - datasize = int(((datasize + 3) // 4) * 4) - - return datasize - - def read_line(self): - """Read a line of data from the setup INAs - Returns: - list of dicts of the values read by ina/type tuple, otherwise None. - [{ts:100, (vbat, power):450}, {ts:200, (vbat, power):440}] - """ - try: - expected_bytes = self.report_size(len(self._inas)) - cmd = struct.pack("<H", self.CMD_NEXT) - bytesread = self.wr_command(cmd, read_count=expected_bytes) - except usb.core.USBError as e: - self._logger.error("READ LINE FAILED %s", e) - return None - - if len(bytesread) == 1: - if bytesread[0] != 0x6: - self._logger.debug("READ LINE FAILED bytes: %d ret: %02x", - len(bytesread), bytesread[0]) - return None - - if len(bytesread) % expected_bytes != 0: - self._logger.debug("READ LINE WARNING: expected %d, got %d", - expected_bytes, len(bytesread)) - - packet_count = len(bytesread) // expected_bytes - - values = [] - for i in range(0, packet_count): - start = i * expected_bytes - end = (i + 1) * expected_bytes - record = self.interpret_line(bytesread[start:end]) - values.append(record) - - return values - - def interpret_line(self, data): - """Interpret a power record from INAs + # INA interface type. + INA_POWER = 1 + INA_BUSV = 2 + INA_CURRENT = 3 + INA_SHUNTV = 4 + # INA_SUFFIX is used to differentiate multiple ina types for the same power + # rail. No suffix for when ina type is 0 (non-existent) and when ina type is 1 + # (power, no suffix for backward compatibility). + INA_SUFFIX = ["", "", "_busv", "_cur", "_shuntv"] + + # usb power commands + CMD_RESET = 0x0000 + CMD_STOP = 0x0001 + CMD_ADDINA = 0x0002 + CMD_START = 0x0003 + CMD_NEXT = 0x0004 + CMD_SETTIME = 0x0005 + + # Map between header channel number (0-47) + # and INA I2C bus/addr on sweetberry. + CHMAP = { + 0: (3, 0x40), + 1: (1, 0x40), + 2: (2, 0x40), + 3: (0, 0x40), + 4: (3, 0x41), + 5: (1, 0x41), + 6: (2, 0x41), + 7: (0, 0x41), + 8: (3, 0x42), + 9: (1, 0x42), + 10: (2, 0x42), + 11: (0, 0x42), + 12: (3, 0x43), + 13: (1, 0x43), + 14: (2, 0x43), + 15: (0, 0x43), + 16: (3, 0x44), + 17: (1, 0x44), + 18: (2, 0x44), + 19: (0, 0x44), + 20: (3, 0x45), + 21: (1, 0x45), + 22: (2, 0x45), + 23: (0, 0x45), + 24: (3, 0x46), + 25: (1, 0x46), + 26: (2, 0x46), + 27: (0, 0x46), + 28: (3, 0x47), + 29: (1, 0x47), + 30: (2, 0x47), + 31: (0, 0x47), + 32: (3, 0x48), + 33: (1, 0x48), + 34: (2, 0x48), + 35: (0, 0x48), + 36: (3, 0x49), + 37: (1, 0x49), + 38: (2, 0x49), + 39: (0, 0x49), + 40: (3, 0x4A), + 41: (1, 0x4A), + 42: (2, 0x4A), + 43: (0, 0x4A), + 44: (3, 0x4B), + 45: (1, 0x4B), + 46: (2, 0x4B), + 47: (0, 0x4B), + } + + def __init__( + self, board, vendor=0x18D1, product=0x5020, interface=1, serialname=None + ): + self._logger = logging.getLogger(__name__) + self._board = board + + # Find the stm32. + dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) + dev_list = list(dev_g) + if dev_list is None: + raise Exception("Power", "USB device not found") + + # Check if we have multiple stm32s and we've specified the serial. + dev = None + if serialname: + for d in dev_list: + dev_serial = "PyUSB dioesn't have a stable interface" + try: + dev_serial = usb.util.get_string(d, 256, d.iSerialNumber) + except ValueError: + # Incompatible pyUsb version. + dev_serial = usb.util.get_string(d, d.iSerialNumber) + if dev_serial == serialname: + dev = d + break + if dev is None: + raise Exception( + "Power", "USB device(%s) not found" % serialname + ) + else: + dev = dev_list[0] - Args: - data: one single record of bytes. + self._logger.debug("Found USB device: %04x:%04x", vendor, product) + self._dev = dev - Output: - stdout of the record in csv format. + # Get an endpoint instance. + try: + dev.set_configuration() + except usb.USBError: + pass + cfg = dev.get_active_configuration() + + intf = usb.util.find_descriptor( + cfg, + custom_match=lambda i: i.bInterfaceClass == 255 + and i.bInterfaceSubClass == 0x54, + ) + + self._intf = intf + self._logger.debug("InterfaceNumber: %s", intf.bInterfaceNumber) + + read_ep = usb.util.find_descriptor( + intf, + # match the first IN endpoint + custom_match=lambda e: usb.util.endpoint_direction( + e.bEndpointAddress + ) + == usb.util.ENDPOINT_IN, + ) + + self._read_ep = read_ep + self._logger.debug("Reader endpoint: 0x%x", read_ep.bEndpointAddress) + + write_ep = usb.util.find_descriptor( + intf, + # match the first OUT endpoint + custom_match=lambda e: usb.util.endpoint_direction( + e.bEndpointAddress + ) + == usb.util.ENDPOINT_OUT, + ) + + self._write_ep = write_ep + self._logger.debug("Writer endpoint: 0x%x", write_ep.bEndpointAddress) + + self.clear_ina_struct() + + self._logger.debug("Found power logging USB endpoint.") + + def clear_ina_struct(self): + """Clear INA description struct.""" + self._inas = [] + + def append_ina_struct( + self, name, rs, port, addr, data=None, ina_type=INA_POWER + ): + """Add an INA descriptor into the list of active INAs. + + Args: + name: Readable name of this channel. + rs: Sense resistor value in ohms, floating point. + port: I2C channel this INA is connected to. + addr: I2C addr of this INA. + data: Misc data for special handling, board specific. + ina_type: INA function to use, power, voltage, etc. + """ + ina = {} + ina["name"] = name + ina["rs"] = rs + ina["port"] = port + ina["addr"] = addr + ina["type"] = ina_type + # Calculate INA231 Calibration register + # (see INA231 spec p.15) + # CurrentLSB = uA per div = 80mV / (Rsh * 2^15) + # CurrentLSB uA = 80000000nV / (Rsh mOhm * 0x8000) + ina["uAscale"] = 80000000.0 / (rs * 0x8000) + ina["uWscale"] = 25.0 * ina["uAscale"] + ina["mVscale"] = 1.25 + ina["uVscale"] = 2.5 + ina["data"] = data + self._inas.append(ina) + + def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=1000): + """Write command to logger logic. + + This function writes byte command values list to stm, then reads + byte status. + + Args: + write_list: list of command byte values [0~255]. + read_count: number of status byte values to read. + + Interface: + write: [command, data ... ] + read: [status ] + + Returns: + bytes read, or None on failure. + """ + self._logger.debug( + "Spower.wr_command(write_list=[%s] (%d), read_count=%s)", + list(bytearray(write_list)), + len(write_list), + read_count, + ) + + # Clean up args from python style to correct types. + write_length = 0 + if write_list: + write_length = len(write_list) + if not read_count: + read_count = 0 + + # Send command to stm32. + if write_list: + cmd = write_list + ret = self._write_ep.write(cmd, wtimeout) + + self._logger.debug("RET: %s ", ret) + + # Read back response if necessary. + if read_count: + bytesread = self._read_ep.read(512, rtimeout) + self._logger.debug("BYTES: [%s]", bytesread) + + if len(bytesread) != read_count: + pass + + self._logger.debug("STATUS: 0x%02x", int(bytesread[0])) + if read_count == 1: + return bytesread[0] + else: + return bytesread + + return None + + def clear(self): + """Clear pending reads on the stm32""" + try: + while True: + ret = self.wr_command( + b"", read_count=512, rtimeout=100, wtimeout=50 + ) + self._logger.debug( + "Try Clear: read %s", "success" if ret == 0 else "failure" + ) + except: + pass + + def send_reset(self): + """Reset the power interface on the stm32""" + cmd = struct.pack("<H", self.CMD_RESET) + ret = self.wr_command(cmd, rtimeout=50, wtimeout=50) + self._logger.debug( + "Command RESET: %s", "success" if ret == 0 else "failure" + ) + + def reset(self): + """Try resetting the USB interface until success. + + Use linear back off strategy when encounter the error with 10ms increment. + + Raises: + Exception on failure. + """ + max_reset_retry = 100 + for count in range(1, max_reset_retry + 1): + self.clear() + try: + self.send_reset() + return + except Exception as e: + self.clear() + self.clear() + self._logger.debug( + "TRY %d of %d: %s", count, max_reset_retry, e + ) + time.sleep(count * 0.01) + raise Exception("Power", "Failed to reset") + + def stop(self): + """Stop any active data acquisition.""" + cmd = struct.pack("<H", self.CMD_STOP) + ret = self.wr_command(cmd) + self._logger.debug( + "Command STOP: %s", "success" if ret == 0 else "failure" + ) + + def start(self, integration_us): + """Start data acquisition. + + Args: + integration_us: int, how many us between samples, and + how often the data block must be read. + + Returns: + actual sampling interval in ms. + """ + cmd = struct.pack("<HI", self.CMD_START, integration_us) + read = self.wr_command(cmd, read_count=5) + actual_us = 0 + if len(read) == 5: + ret, actual_us = struct.unpack("<BI", read) + self._logger.debug( + "Command START: %s %dus", + "success" if ret == 0 else "failure", + actual_us, + ) + else: + self._logger.debug("Command START: FAIL") + + return actual_us + + def add_ina_name(self, name_tuple): + """Add INA from board config. + + Args: + name_tuple: name and type of power rail in board config. + + Returns: + True if INA added, False if the INA is not on this board. + + Raises: + Exception on unexpected failure. + """ + name, ina_type = name_tuple + + for datum in self._brdcfg: + if datum["name"] == name: + rs = int(float(datum["rs"]) * 1000.0) + board = datum["sweetberry"] + + if board == self._board: + if "port" in datum and "addr" in datum: + port = datum["port"] + addr = datum["addr"] + else: + channel = int(datum["channel"]) + port, addr = self.CHMAP[channel] + self.add_ina(port, ina_type, addr, 0, rs, data=datum) + return True + else: + return False + raise Exception("Power", "Failed to find INA %s" % name) + + def set_time(self, timestamp_us): + """Set sweetberry time to match host time. + + Args: + timestamp_us: host timestmap in us. + """ + # 0x0005 , 8 byte timestamp + cmd = struct.pack("<HQ", self.CMD_SETTIME, timestamp_us) + ret = self.wr_command(cmd) + + self._logger.debug( + "Command SETTIME: %s", "success" if ret == 0 else "failure" + ) + + def add_ina(self, bus, ina_type, addr, extra, resistance, data=None): + """Add an INA to the data acquisition list. + + Args: + bus: which i2c bus the INA is on. Same ordering as Si2c. + ina_type: Ina interface: INA_POWER/BUSV/etc. + addr: 7 bit i2c addr of this INA + extra: extra data for nonstandard configs. + resistance: int, shunt resistance in mOhm + """ + # 0x0002, 1B: bus, 1B:INA type, 1B: INA addr, 1B: extra, 4B: Rs + cmd = struct.pack( + "<HBBBBI", self.CMD_ADDINA, bus, ina_type, addr, extra, resistance + ) + ret = self.wr_command(cmd) + if ret == 0: + if data: + name = data["name"] + else: + name = "ina%d_%02x" % (bus, addr) + self.append_ina_struct( + name, resistance, bus, addr, data=data, ina_type=ina_type + ) + self._logger.debug( + "Command ADD_INA: %s", "success" if ret == 0 else "failure" + ) + + def report_header_size(self): + """Helper function to calculate power record header size.""" + result = 2 + timestamp = 8 + return result + timestamp + + def report_size(self, ina_count): + """Helper function to calculate full power record size.""" + record = 2 + + datasize = self.report_header_size() + ina_count * record + # Round to multiple of 4 bytes. + datasize = int(((datasize + 3) // 4) * 4) + + return datasize + + def read_line(self): + """Read a line of data from the setup INAs + + Returns: + list of dicts of the values read by ina/type tuple, otherwise None. + [{ts:100, (vbat, power):450}, {ts:200, (vbat, power):440}] + """ + try: + expected_bytes = self.report_size(len(self._inas)) + cmd = struct.pack("<H", self.CMD_NEXT) + bytesread = self.wr_command(cmd, read_count=expected_bytes) + except usb.core.USBError as e: + self._logger.error("READ LINE FAILED %s", e) + return None + + if len(bytesread) == 1: + if bytesread[0] != 0x6: + self._logger.debug( + "READ LINE FAILED bytes: %d ret: %02x", + len(bytesread), + bytesread[0], + ) + return None + + if len(bytesread) % expected_bytes != 0: + self._logger.debug( + "READ LINE WARNING: expected %d, got %d", + expected_bytes, + len(bytesread), + ) + + packet_count = len(bytesread) // expected_bytes + + values = [] + for i in range(0, packet_count): + start = i * expected_bytes + end = (i + 1) * expected_bytes + record = self.interpret_line(bytesread[start:end]) + values.append(record) + + return values + + def interpret_line(self, data): + """Interpret a power record from INAs + + Args: + data: one single record of bytes. + + Output: + stdout of the record in csv format. + + Returns: + dict containing name, value of recorded data. + """ + status, size = struct.unpack("<BB", data[0:2]) + if len(data) != self.report_size(size): + self._logger.error( + "READ LINE FAILED st:%d size:%d expected:%d len:%d", + status, + size, + self.report_size(size), + len(data), + ) + else: + pass - Returns: - dict containing name, value of recorded data. - """ - status, size = struct.unpack("<BB", data[0:2]) - if len(data) != self.report_size(size): - self._logger.error("READ LINE FAILED st:%d size:%d expected:%d len:%d", - status, size, self.report_size(size), len(data)) - else: - pass + timestamp = struct.unpack("<Q", data[2:10])[0] + self._logger.debug( + "READ LINE: st:%d size:%d time:%dus", status, size, timestamp + ) + ftimestamp = float(timestamp) / 1000000.0 - timestamp = struct.unpack("<Q", data[2:10])[0] - self._logger.debug("READ LINE: st:%d size:%d time:%dus", status, size, - timestamp) - ftimestamp = float(timestamp) / 1000000. + record = {"ts": ftimestamp, "status": status, "berry": self._board} - record = {"ts": ftimestamp, "status": status, "berry":self._board} + for i in range(0, size): + idx = self.report_header_size() + 2 * i + name = self._inas[i]["name"] + name_tuple = (self._inas[i]["name"], self._inas[i]["type"]) - for i in range(0, size): - idx = self.report_header_size() + 2*i - name = self._inas[i]['name'] - name_tuple = (self._inas[i]['name'], self._inas[i]['type']) + raw_val = struct.unpack("<h", data[idx : idx + 2])[0] - raw_val = struct.unpack("<h", data[idx:idx+2])[0] + if self._inas[i]["type"] == Spower.INA_POWER: + val = raw_val * self._inas[i]["uWscale"] + elif self._inas[i]["type"] == Spower.INA_BUSV: + val = raw_val * self._inas[i]["mVscale"] + elif self._inas[i]["type"] == Spower.INA_CURRENT: + val = raw_val * self._inas[i]["uAscale"] + elif self._inas[i]["type"] == Spower.INA_SHUNTV: + val = raw_val * self._inas[i]["uVscale"] - if self._inas[i]['type'] == Spower.INA_POWER: - val = raw_val * self._inas[i]['uWscale'] - elif self._inas[i]['type'] == Spower.INA_BUSV: - val = raw_val * self._inas[i]['mVscale'] - elif self._inas[i]['type'] == Spower.INA_CURRENT: - val = raw_val * self._inas[i]['uAscale'] - elif self._inas[i]['type'] == Spower.INA_SHUNTV: - val = raw_val * self._inas[i]['uVscale'] + self._logger.debug( + "READ %d %s: %fs: 0x%04x %f", i, name, ftimestamp, raw_val, val + ) + record[name_tuple] = val - self._logger.debug("READ %d %s: %fs: 0x%04x %f", i, name, ftimestamp, - raw_val, val) - record[name_tuple] = val + return record - return record + def load_board(self, brdfile): + """Load a board config. - def load_board(self, brdfile): - """Load a board config. + Args: + brdfile: Filename of a json file decribing the INA wiring of this board. + """ + with open(process_filename(brdfile)) as data_file: + data = json.load(data_file) - Args: - brdfile: Filename of a json file decribing the INA wiring of this board. - """ - with open(process_filename(brdfile)) as data_file: - data = json.load(data_file) - - #TODO: validate this. - self._brdcfg = data; - self._logger.debug(pprint.pformat(data)) + # TODO: validate this. + self._brdcfg = data + self._logger.debug(pprint.pformat(data)) class powerlog(object): - """Power class to log aggregated power. - - Usage: - obj = powerlog() - - Instance Variables: - _data: a StatsManager object that records sweetberry readings and calculates - statistics. - _pwr[]: Spower objects for individual sweetberries. - """ + """Power class to log aggregated power. - def __init__(self, brdfile, cfgfile, serial_a=None, serial_b=None, - sync_date=False, use_ms=False, use_mW=False, print_stats=False, - stats_dir=None, stats_json_dir=None, print_raw_data=True, - raw_data_dir=None): - """Init the powerlog class and set the variables. - - Args: - brdfile: string name of json file containing board layout. - cfgfile: string name of json containing list of rails to read. - serial_a: serial number of sweetberry A. - serial_b: serial number of sweetberry B. - sync_date: report timestamps synced with host datetime. - use_ms: report timestamps in ms rather than us. - use_mW: report power as milliwatts, otherwise default to microwatts. - print_stats: print statistics for sweetberry readings at the end. - stats_dir: directory to save sweetberry readings statistics; if None then - do not save the statistics. - stats_json_dir: directory to save means of sweetberry readings in json - format; if None then do not save the statistics. - print_raw_data: print sweetberry readings raw data in real time, default - is to print. - raw_data_dir: directory to save sweetberry readings raw data; if None then - do not save the raw data. - """ - self._logger = logging.getLogger(__name__) - self._data = StatsManager() - self._pwr = {} - self._use_ms = use_ms - self._use_mW = use_mW - self._print_stats = print_stats - self._stats_dir = stats_dir - self._stats_json_dir = stats_json_dir - self._print_raw_data = print_raw_data - self._raw_data_dir = raw_data_dir - - if not serial_a and not serial_b: - self._pwr['A'] = Spower('A') - if serial_a: - self._pwr['A'] = Spower('A', serialname=serial_a) - if serial_b: - self._pwr['B'] = Spower('B', serialname=serial_b) - - with open(process_filename(cfgfile)) as data_file: - names = json.load(data_file) - self._names = self.process_scenario(names) - - for key in self._pwr: - self._pwr[key].load_board(brdfile) - self._pwr[key].reset() - - # Allocate the rails to the appropriate boards. - used_boards = [] - for name in self._names: - success = False - for key in self._pwr.keys(): - if self._pwr[key].add_ina_name(name): - success = True - if key not in used_boards: - used_boards.append(key) - if not success: - raise Exception("Failed to add %s (maybe missing " - "sweetberry, or bad board file?)" % name) - - # Evict unused boards. - for key in list(self._pwr.keys()): - if key not in used_boards: - self._pwr.pop(key) - - for key in self._pwr.keys(): - if sync_date: - self._pwr[key].set_time(time.time() * 1000000) - else: - self._pwr[key].set_time(0) - - def process_scenario(self, name_list): - """Return list of tuples indicating name and type. + Usage: + obj = powerlog() - Args: - json originated list of names, or [name, type] - Returns: - list of tuples of (name, type) defaulting to type "POWER" - Raises: exception, invalid INA type. + Instance Variables: + _data: a StatsManager object that records sweetberry readings and calculates + statistics. + _pwr[]: Spower objects for individual sweetberries. """ - names = [] - for entry in name_list: - if isinstance(entry, list): - name = entry[0] - if entry[1] == "POWER": - type = Spower.INA_POWER - elif entry[1] == "BUSV": - type = Spower.INA_BUSV - elif entry[1] == "CURRENT": - type = Spower.INA_CURRENT - elif entry[1] == "SHUNTV": - type = Spower.INA_SHUNTV - else: - raise Exception("Invalid INA type", "Type of %s [%s] not recognized," - " try one of POWER, BUSV, CURRENT" % (entry[0], entry[1])) - else: - name = entry - type = Spower.INA_POWER - names.append((name, type)) - return names + def __init__( + self, + brdfile, + cfgfile, + serial_a=None, + serial_b=None, + sync_date=False, + use_ms=False, + use_mW=False, + print_stats=False, + stats_dir=None, + stats_json_dir=None, + print_raw_data=True, + raw_data_dir=None, + ): + """Init the powerlog class and set the variables. + + Args: + brdfile: string name of json file containing board layout. + cfgfile: string name of json containing list of rails to read. + serial_a: serial number of sweetberry A. + serial_b: serial number of sweetberry B. + sync_date: report timestamps synced with host datetime. + use_ms: report timestamps in ms rather than us. + use_mW: report power as milliwatts, otherwise default to microwatts. + print_stats: print statistics for sweetberry readings at the end. + stats_dir: directory to save sweetberry readings statistics; if None then + do not save the statistics. + stats_json_dir: directory to save means of sweetberry readings in json + format; if None then do not save the statistics. + print_raw_data: print sweetberry readings raw data in real time, default + is to print. + raw_data_dir: directory to save sweetberry readings raw data; if None then + do not save the raw data. + """ + self._logger = logging.getLogger(__name__) + self._data = StatsManager() + self._pwr = {} + self._use_ms = use_ms + self._use_mW = use_mW + self._print_stats = print_stats + self._stats_dir = stats_dir + self._stats_json_dir = stats_json_dir + self._print_raw_data = print_raw_data + self._raw_data_dir = raw_data_dir + + if not serial_a and not serial_b: + self._pwr["A"] = Spower("A") + if serial_a: + self._pwr["A"] = Spower("A", serialname=serial_a) + if serial_b: + self._pwr["B"] = Spower("B", serialname=serial_b) + + with open(process_filename(cfgfile)) as data_file: + names = json.load(data_file) + self._names = self.process_scenario(names) - def start(self, integration_us_request, seconds, sync_speed=.8): - """Starts sampling. - - Args: - integration_us_request: requested interval between sample values. - seconds: time until exit, or None to run until cancel. - sync_speed: A usb request is sent every [.8] * integration_us. - """ - # We will get back the actual integration us. - # It should be the same for all devices. - integration_us = None - for key in self._pwr: - integration_us_new = self._pwr[key].start(integration_us_request) - if integration_us: - if integration_us != integration_us_new: - raise Exception("FAIL", - "Integration on A: %dus != integration on B %dus" % ( - integration_us, integration_us_new)) - integration_us = integration_us_new - - # CSV header - title = "ts:%dus" % integration_us - for name_tuple in self._names: - name, ina_type = name_tuple - - if ina_type == Spower.INA_POWER: - unit = "mW" if self._use_mW else "uW" - elif ina_type == Spower.INA_BUSV: - unit = "mV" - elif ina_type == Spower.INA_CURRENT: - unit = "uA" - elif ina_type == Spower.INA_SHUNTV: - unit = "uV" - - title += ", %s %s" % (name, unit) - name_type = name + Spower.INA_SUFFIX[ina_type] - self._data.SetUnit(name_type, unit) - title += ", status" - if self._print_raw_data: - logoutput(title) - - forever = False - if not seconds: - forever = True - end_time = time.time() + seconds - try: - pending_records = [] - while forever or end_time > time.time(): - if (integration_us > 5000): - time.sleep((integration_us / 1000000.) * sync_speed) for key in self._pwr: - records = self._pwr[key].read_line() - if not records: - continue - - for record in records: - pending_records.append(record) - - pending_records.sort(key=lambda r: r['ts']) - - aggregate_record = {"boards": set()} - for record in pending_records: - if record["berry"] not in aggregate_record["boards"]: - for rkey in record.keys(): - aggregate_record[rkey] = record[rkey] - aggregate_record["boards"].add(record["berry"]) - else: - self._logger.info("break %s, %s", record["berry"], - aggregate_record["boards"]) - break - - if aggregate_record["boards"] == set(self._pwr.keys()): - csv = "%f" % aggregate_record["ts"] - for name in self._names: - if name in aggregate_record: - multiplier = 0.001 if (self._use_mW and - name[1]==Spower.INA_POWER) else 1 - value = aggregate_record[name] * multiplier - csv += ", %.2f" % value - name_type = name[0] + Spower.INA_SUFFIX[name[1]] - self._data.AddSample(name_type, value) - else: - csv += ", " - csv += ", %d" % aggregate_record["status"] - if self._print_raw_data: - logoutput(csv) - - aggregate_record = {"boards": set()} - for r in range(0, len(self._pwr)): - pending_records.pop(0) - - except KeyboardInterrupt: - self._logger.info('\nCTRL+C caught.') - - finally: - for key in self._pwr: - self._pwr[key].stop() - self._data.CalculateStats() - if self._print_stats: - print(self._data.SummaryToString()) - save_dir = 'sweetberry%s' % time.time() - if self._stats_dir: - stats_dir = os.path.join(self._stats_dir, save_dir) - self._data.SaveSummary(stats_dir) - if self._stats_json_dir: - stats_json_dir = os.path.join(self._stats_json_dir, save_dir) - self._data.SaveSummaryJSON(stats_json_dir) - if self._raw_data_dir: - raw_data_dir = os.path.join(self._raw_data_dir, save_dir) - self._data.SaveRawData(raw_data_dir) + self._pwr[key].load_board(brdfile) + self._pwr[key].reset() + + # Allocate the rails to the appropriate boards. + used_boards = [] + for name in self._names: + success = False + for key in self._pwr.keys(): + if self._pwr[key].add_ina_name(name): + success = True + if key not in used_boards: + used_boards.append(key) + if not success: + raise Exception( + "Failed to add %s (maybe missing " + "sweetberry, or bad board file?)" % name + ) + + # Evict unused boards. + for key in list(self._pwr.keys()): + if key not in used_boards: + self._pwr.pop(key) + + for key in self._pwr.keys(): + if sync_date: + self._pwr[key].set_time(time.time() * 1000000) + else: + self._pwr[key].set_time(0) + + def process_scenario(self, name_list): + """Return list of tuples indicating name and type. + + Args: + json originated list of names, or [name, type] + Returns: + list of tuples of (name, type) defaulting to type "POWER" + Raises: exception, invalid INA type. + """ + names = [] + for entry in name_list: + if isinstance(entry, list): + name = entry[0] + if entry[1] == "POWER": + type = Spower.INA_POWER + elif entry[1] == "BUSV": + type = Spower.INA_BUSV + elif entry[1] == "CURRENT": + type = Spower.INA_CURRENT + elif entry[1] == "SHUNTV": + type = Spower.INA_SHUNTV + else: + raise Exception( + "Invalid INA type", + "Type of %s [%s] not recognized," + " try one of POWER, BUSV, CURRENT" + % (entry[0], entry[1]), + ) + else: + name = entry + type = Spower.INA_POWER + + names.append((name, type)) + return names + + def start(self, integration_us_request, seconds, sync_speed=0.8): + """Starts sampling. + + Args: + integration_us_request: requested interval between sample values. + seconds: time until exit, or None to run until cancel. + sync_speed: A usb request is sent every [.8] * integration_us. + """ + # We will get back the actual integration us. + # It should be the same for all devices. + integration_us = None + for key in self._pwr: + integration_us_new = self._pwr[key].start(integration_us_request) + if integration_us: + if integration_us != integration_us_new: + raise Exception( + "FAIL", + # pylint:disable=bad-string-format-type + "Integration on A: %dus != integration on B %dus" + % (integration_us, integration_us_new), + ) + integration_us = integration_us_new + + # CSV header + title = "ts:%dus" % integration_us + for name_tuple in self._names: + name, ina_type = name_tuple + + if ina_type == Spower.INA_POWER: + unit = "mW" if self._use_mW else "uW" + elif ina_type == Spower.INA_BUSV: + unit = "mV" + elif ina_type == Spower.INA_CURRENT: + unit = "uA" + elif ina_type == Spower.INA_SHUNTV: + unit = "uV" + + title += ", %s %s" % (name, unit) + name_type = name + Spower.INA_SUFFIX[ina_type] + self._data.SetUnit(name_type, unit) + title += ", status" + if self._print_raw_data: + logoutput(title) + + forever = False + if not seconds: + forever = True + end_time = time.time() + seconds + try: + pending_records = [] + while forever or end_time > time.time(): + if integration_us > 5000: + time.sleep((integration_us / 1000000.0) * sync_speed) + for key in self._pwr: + records = self._pwr[key].read_line() + if not records: + continue + + for record in records: + pending_records.append(record) + + pending_records.sort(key=lambda r: r["ts"]) + + aggregate_record = {"boards": set()} + for record in pending_records: + if record["berry"] not in aggregate_record["boards"]: + for rkey in record.keys(): + aggregate_record[rkey] = record[rkey] + aggregate_record["boards"].add(record["berry"]) + else: + self._logger.info( + "break %s, %s", + record["berry"], + aggregate_record["boards"], + ) + break + + if aggregate_record["boards"] == set(self._pwr.keys()): + csv = "%f" % aggregate_record["ts"] + for name in self._names: + if name in aggregate_record: + multiplier = ( + 0.001 + if ( + self._use_mW + and name[1] == Spower.INA_POWER + ) + else 1 + ) + value = aggregate_record[name] * multiplier + csv += ", %.2f" % value + name_type = name[0] + Spower.INA_SUFFIX[name[1]] + self._data.AddSample(name_type, value) + else: + csv += ", " + csv += ", %d" % aggregate_record["status"] + if self._print_raw_data: + logoutput(csv) + + aggregate_record = {"boards": set()} + for r in range(0, len(self._pwr)): + pending_records.pop(0) + + except KeyboardInterrupt: + self._logger.info("\nCTRL+C caught.") + + finally: + for key in self._pwr: + self._pwr[key].stop() + self._data.CalculateStats() + if self._print_stats: + print(self._data.SummaryToString()) + save_dir = "sweetberry%s" % time.time() + if self._stats_dir: + stats_dir = os.path.join(self._stats_dir, save_dir) + self._data.SaveSummary(stats_dir) + if self._stats_json_dir: + stats_json_dir = os.path.join(self._stats_json_dir, save_dir) + self._data.SaveSummaryJSON(stats_json_dir) + if self._raw_data_dir: + raw_data_dir = os.path.join(self._raw_data_dir, save_dir) + self._data.SaveRawData(raw_data_dir) def main(argv=None): - if argv is None: - argv = sys.argv[1:] - # Command line argument description. - parser = argparse.ArgumentParser( - description="Gather CSV data from sweetberry") - parser.add_argument('-b', '--board', type=str, - help="Board configuration file, eg. my.board", default="") - parser.add_argument('-c', '--config', type=str, - help="Rail config to monitor, eg my.scenario", default="") - parser.add_argument('-A', '--serial', type=str, - help="Serial number of sweetberry A", default="") - parser.add_argument('-B', '--serial_b', type=str, - help="Serial number of sweetberry B", default="") - parser.add_argument('-t', '--integration_us', type=int, - help="Target integration time for samples", default=100000) - parser.add_argument('-s', '--seconds', type=float, - help="Seconds to run capture", default=0.) - parser.add_argument('--date', default=False, - help="Sync logged timestamp to host date", action="store_true") - parser.add_argument('--ms', default=False, - help="Print timestamp as milliseconds", action="store_true") - parser.add_argument('--mW', default=False, - help="Print power as milliwatts, otherwise default to microwatts", - action="store_true") - parser.add_argument('--slow', default=False, - help="Intentionally overflow", action="store_true") - parser.add_argument('--print_stats', default=False, action="store_true", - help="Print statistics for sweetberry readings at the end") - parser.add_argument('--save_stats', type=str, nargs='?', - dest='stats_dir', metavar='STATS_DIR', - const=os.path.dirname(os.path.abspath(__file__)), default=None, - help="Save statistics for sweetberry readings to %(metavar)s if " - "%(metavar)s is specified, %(metavar)s will be created if it does " - "not exist; if %(metavar)s is not specified but the flag is set, " - "stats will be saved to where %(prog)s is located; if this flag is " - "not set, then do not save stats") - parser.add_argument('--save_stats_json', type=str, nargs='?', - dest='stats_json_dir', metavar='STATS_JSON_DIR', - const=os.path.dirname(os.path.abspath(__file__)), default=None, - help="Save means for sweetberry readings in json to %(metavar)s if " - "%(metavar)s is specified, %(metavar)s will be created if it does " - "not exist; if %(metavar)s is not specified but the flag is set, " - "stats will be saved to where %(prog)s is located; if this flag is " - "not set, then do not save stats") - parser.add_argument('--no_print_raw_data', - dest='print_raw_data', default=True, action="store_false", - help="Not print raw sweetberry readings at real time, default is to " - "print") - parser.add_argument('--save_raw_data', type=str, nargs='?', - dest='raw_data_dir', metavar='RAW_DATA_DIR', - const=os.path.dirname(os.path.abspath(__file__)), default=None, - help="Save raw data for sweetberry readings to %(metavar)s if " - "%(metavar)s is specified, %(metavar)s will be created if it does " - "not exist; if %(metavar)s is not specified but the flag is set, " - "raw data will be saved to where %(prog)s is located; if this flag " - "is not set, then do not save raw data") - parser.add_argument('-v', '--verbose', default=False, - help="Very chatty printout", action="store_true") - - args = parser.parse_args(argv) - - root_logger = logging.getLogger(__name__) - if args.verbose: - root_logger.setLevel(logging.DEBUG) - else: - root_logger.setLevel(logging.INFO) - - # if powerlog is used through main, log to sys.stdout - if __name__ == "__main__": - stdout_handler = logging.StreamHandler(sys.stdout) - stdout_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s')) - root_logger.addHandler(stdout_handler) - - integration_us_request = args.integration_us - if not args.board: - raise Exception("Power", "No board file selected, see board.README") - if not args.config: - raise Exception("Power", "No config file selected, see board.README") - - brdfile = args.board - cfgfile = args.config - seconds = args.seconds - serial_a = args.serial - serial_b = args.serial_b - sync_date = args.date - use_ms = args.ms - use_mW = args.mW - print_stats = args.print_stats - stats_dir = args.stats_dir - stats_json_dir = args.stats_json_dir - print_raw_data = args.print_raw_data - raw_data_dir = args.raw_data_dir - - boards = [] - - sync_speed = .8 - if args.slow: - sync_speed = 1.2 - - # Set up logging interface. - powerlogger = powerlog(brdfile, cfgfile, serial_a=serial_a, serial_b=serial_b, - sync_date=sync_date, use_ms=use_ms, use_mW=use_mW, - print_stats=print_stats, stats_dir=stats_dir, - stats_json_dir=stats_json_dir, - print_raw_data=print_raw_data,raw_data_dir=raw_data_dir) - - # Start logging. - powerlogger.start(integration_us_request, seconds, sync_speed=sync_speed) + if argv is None: + argv = sys.argv[1:] + # Command line argument description. + parser = argparse.ArgumentParser( + description="Gather CSV data from sweetberry" + ) + parser.add_argument( + "-b", + "--board", + type=str, + help="Board configuration file, eg. my.board", + default="", + ) + parser.add_argument( + "-c", + "--config", + type=str, + help="Rail config to monitor, eg my.scenario", + default="", + ) + parser.add_argument( + "-A", + "--serial", + type=str, + help="Serial number of sweetberry A", + default="", + ) + parser.add_argument( + "-B", + "--serial_b", + type=str, + help="Serial number of sweetberry B", + default="", + ) + parser.add_argument( + "-t", + "--integration_us", + type=int, + help="Target integration time for samples", + default=100000, + ) + parser.add_argument( + "-s", + "--seconds", + type=float, + help="Seconds to run capture", + default=0.0, + ) + parser.add_argument( + "--date", + default=False, + help="Sync logged timestamp to host date", + action="store_true", + ) + parser.add_argument( + "--ms", + default=False, + help="Print timestamp as milliseconds", + action="store_true", + ) + parser.add_argument( + "--mW", + default=False, + help="Print power as milliwatts, otherwise default to microwatts", + action="store_true", + ) + parser.add_argument( + "--slow", + default=False, + help="Intentionally overflow", + action="store_true", + ) + parser.add_argument( + "--print_stats", + default=False, + action="store_true", + help="Print statistics for sweetberry readings at the end", + ) + parser.add_argument( + "--save_stats", + type=str, + nargs="?", + dest="stats_dir", + metavar="STATS_DIR", + const=os.path.dirname(os.path.abspath(__file__)), + default=None, + help="Save statistics for sweetberry readings to %(metavar)s if " + "%(metavar)s is specified, %(metavar)s will be created if it does " + "not exist; if %(metavar)s is not specified but the flag is set, " + "stats will be saved to where %(prog)s is located; if this flag is " + "not set, then do not save stats", + ) + parser.add_argument( + "--save_stats_json", + type=str, + nargs="?", + dest="stats_json_dir", + metavar="STATS_JSON_DIR", + const=os.path.dirname(os.path.abspath(__file__)), + default=None, + help="Save means for sweetberry readings in json to %(metavar)s if " + "%(metavar)s is specified, %(metavar)s will be created if it does " + "not exist; if %(metavar)s is not specified but the flag is set, " + "stats will be saved to where %(prog)s is located; if this flag is " + "not set, then do not save stats", + ) + parser.add_argument( + "--no_print_raw_data", + dest="print_raw_data", + default=True, + action="store_false", + help="Not print raw sweetberry readings at real time, default is to " + "print", + ) + parser.add_argument( + "--save_raw_data", + type=str, + nargs="?", + dest="raw_data_dir", + metavar="RAW_DATA_DIR", + const=os.path.dirname(os.path.abspath(__file__)), + default=None, + help="Save raw data for sweetberry readings to %(metavar)s if " + "%(metavar)s is specified, %(metavar)s will be created if it does " + "not exist; if %(metavar)s is not specified but the flag is set, " + "raw data will be saved to where %(prog)s is located; if this flag " + "is not set, then do not save raw data", + ) + parser.add_argument( + "-v", + "--verbose", + default=False, + help="Very chatty printout", + action="store_true", + ) + + args = parser.parse_args(argv) + + root_logger = logging.getLogger(__name__) + if args.verbose: + root_logger.setLevel(logging.DEBUG) + else: + root_logger.setLevel(logging.INFO) + + # if powerlog is used through main, log to sys.stdout + if __name__ == "__main__": + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setFormatter( + logging.Formatter("%(levelname)s: %(message)s") + ) + root_logger.addHandler(stdout_handler) + + integration_us_request = args.integration_us + if not args.board: + raise Exception("Power", "No board file selected, see board.README") + if not args.config: + raise Exception("Power", "No config file selected, see board.README") + + brdfile = args.board + cfgfile = args.config + seconds = args.seconds + serial_a = args.serial + serial_b = args.serial_b + sync_date = args.date + use_ms = args.ms + use_mW = args.mW + print_stats = args.print_stats + stats_dir = args.stats_dir + stats_json_dir = args.stats_json_dir + print_raw_data = args.print_raw_data + raw_data_dir = args.raw_data_dir + + boards = [] + + sync_speed = 0.8 + if args.slow: + sync_speed = 1.2 + + # Set up logging interface. + powerlogger = powerlog( + brdfile, + cfgfile, + serial_a=serial_a, + serial_b=serial_b, + sync_date=sync_date, + use_ms=use_ms, + use_mW=use_mW, + print_stats=print_stats, + stats_dir=stats_dir, + stats_json_dir=stats_json_dir, + print_raw_data=print_raw_data, + raw_data_dir=raw_data_dir, + ) + + # Start logging. + powerlogger.start(integration_us_request, seconds, sync_speed=sync_speed) if __name__ == "__main__": - main() + main() diff --git a/extra/usb_power/powerlog_unittest.py b/extra/usb_power/powerlog_unittest.py index 1d0718530e..62667e35b8 100644 --- a/extra/usb_power/powerlog_unittest.py +++ b/extra/usb_power/powerlog_unittest.py @@ -1,10 +1,6 @@ -# Copyright 2018 The Chromium OS Authors. All rights reserved. +# Copyright 2018 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes """Unit tests for powerlog.""" @@ -13,42 +9,44 @@ import shutil import tempfile import unittest -import powerlog +from usb_power import powerlog + class TestPowerlog(unittest.TestCase): - """Test to verify powerlog util methods work as expected.""" - - def setUp(self): - """Set up data and create a temporary directory to save data and stats.""" - self.tempdir = tempfile.mkdtemp() - self.filename = 'testfile' - self.filepath = os.path.join(self.tempdir, self.filename) - with open(self.filepath, 'w') as f: - f.write('') - - def tearDown(self): - """Delete the temporary directory and its content.""" - shutil.rmtree(self.tempdir) - - def test_ProcessFilenameAbsoluteFilePath(self): - """Absolute file path is returned unchanged.""" - processed_fname = powerlog.process_filename(self.filepath) - self.assertEqual(self.filepath, processed_fname) - - def test_ProcessFilenameRelativeFilePath(self): - """Finds relative file path inside a known config location.""" - original = powerlog.CONFIG_LOCATIONS - powerlog.CONFIG_LOCATIONS = [self.tempdir] - processed_fname = powerlog.process_filename(self.filename) - try: - self.assertEqual(self.filepath, processed_fname) - finally: - powerlog.CONFIG_LOCATIONS = original - - def test_ProcessFilenameInvalid(self): - """IOError is raised when file cannot be found by any of the four ways.""" - with self.assertRaises(IOError): - powerlog.process_filename(self.filename) - -if __name__ == '__main__': - unittest.main() + """Test to verify powerlog util methods work as expected.""" + + def setUp(self): + """Set up data and create a temporary directory to save data and stats.""" + self.tempdir = tempfile.mkdtemp() + self.filename = "testfile" + self.filepath = os.path.join(self.tempdir, self.filename) + with open(self.filepath, "w") as f: + f.write("") + + def tearDown(self): + """Delete the temporary directory and its content.""" + shutil.rmtree(self.tempdir) + + def test_ProcessFilenameAbsoluteFilePath(self): + """Absolute file path is returned unchanged.""" + processed_fname = powerlog.process_filename(self.filepath) + self.assertEqual(self.filepath, processed_fname) + + def test_ProcessFilenameRelativeFilePath(self): + """Finds relative file path inside a known config location.""" + original = powerlog.CONFIG_LOCATIONS + powerlog.CONFIG_LOCATIONS = [self.tempdir] + processed_fname = powerlog.process_filename(self.filename) + try: + self.assertEqual(self.filepath, processed_fname) + finally: + powerlog.CONFIG_LOCATIONS = original + + def test_ProcessFilenameInvalid(self): + """IOError is raised when file cannot be found by any of the four ways.""" + with self.assertRaises(IOError): + powerlog.process_filename(self.filename) + + +if __name__ == "__main__": + unittest.main() diff --git a/extra/usb_power/stats_manager.py b/extra/usb_power/stats_manager.py index 0f8c3fcb15..2035138731 100644 --- a/extra/usb_power/stats_manager.py +++ b/extra/usb_power/stats_manager.py @@ -1,10 +1,6 @@ -# Copyright 2017 The Chromium OS Authors. All rights reserved. +# Copyright 2017 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes """Calculates statistics for lists of data and pretty print them.""" @@ -18,384 +14,409 @@ import logging import math import os -import numpy +import numpy # pylint:disable=import-error -STATS_PREFIX = '@@' -NAN_TAG = '*' -NAN_DESCRIPTION = '%s domains contain NaN samples' % NAN_TAG +STATS_PREFIX = "@@" +NAN_TAG = "*" +NAN_DESCRIPTION = "%s domains contain NaN samples" % NAN_TAG LONG_UNIT = { - '': 'N/A', - 'mW': 'milliwatt', - 'uW': 'microwatt', - 'mV': 'millivolt', - 'uA': 'microamp', - 'uV': 'microvolt' + "": "N/A", + "mW": "milliwatt", + "uW": "microwatt", + "mV": "millivolt", + "uA": "microamp", + "uV": "microvolt", } class StatsManagerError(Exception): - """Errors in StatsManager class.""" - pass + """Errors in StatsManager class.""" + pass -class StatsManager(object): - """Calculates statistics for several lists of data(float). - - Example usage: - - >>> stats = StatsManager(title='Title Banner') - >>> stats.AddSample(TIME_KEY, 50.0) - >>> stats.AddSample(TIME_KEY, 25.0) - >>> stats.AddSample(TIME_KEY, 40.0) - >>> stats.AddSample(TIME_KEY, 10.0) - >>> stats.AddSample(TIME_KEY, 10.0) - >>> stats.AddSample('frobnicate', 11.5) - >>> stats.AddSample('frobnicate', 9.0) - >>> stats.AddSample('foobar', 11111.0) - >>> stats.AddSample('foobar', 22222.0) - >>> stats.CalculateStats() - >>> print(stats.SummaryToString()) - ` @@-------------------------------------------------------------- - ` @@ Title Banner - @@-------------------------------------------------------------- - @@ NAME COUNT MEAN STDDEV MAX MIN - @@ sample_msecs 4 31.25 15.16 50.00 10.00 - @@ foobar 2 16666.50 5555.50 22222.00 11111.00 - @@ frobnicate 2 10.25 1.25 11.50 9.00 - ` @@-------------------------------------------------------------- - - Attributes: - _data: dict of list of readings for each domain(key) - _unit: dict of unit for each domain(key) - _smid: id supplied to differentiate data output to other StatsManager - instances that potentially save to the same directory - if smid all output files will be named |smid|_|fname| - _title: title to add as banner to formatted summary. If no title, - no banner gets added - _order: list of formatting order for domains. Domains not listed are - displayed in sorted order - _hide_domains: collection of domains to hide when formatting summary string - _accept_nan: flag to indicate if NaN samples are acceptable - _nan_domains: set to keep track of which domains contain NaN samples - _summary: dict of stats per domain (key): min, max, count, mean, stddev - _logger = StatsManager logger - - Note: - _summary is empty until CalculateStats() is called, and is updated when - CalculateStats() is called. - """ - - # pylint: disable=W0102 - def __init__(self, smid='', title='', order=[], hide_domains=[], - accept_nan=True): - """Initialize infrastructure for data and their statistics.""" - self._title = title - self._data = collections.defaultdict(list) - self._unit = collections.defaultdict(str) - self._smid = smid - self._order = order - self._hide_domains = hide_domains - self._accept_nan = accept_nan - self._nan_domains = set() - self._summary = {} - self._logger = logging.getLogger(type(self).__name__) - - def AddSample(self, domain, sample): - """Add one sample for a domain. - - Args: - domain: the domain name for the sample. - sample: one time sample for domain, expect type float. - - Raises: - StatsManagerError: if trying to add NaN and |_accept_nan| is false - """ - try: - sample = float(sample) - except ValueError: - # if we don't accept nan this will be caught below - self._logger.debug('sample %s for domain %s is not a number. Making NaN', - sample, domain) - sample = float('NaN') - if not self._accept_nan and math.isnan(sample): - raise StatsManagerError('accept_nan is false. Cannot add NaN sample.') - self._data[domain].append(sample) - if math.isnan(sample): - self._nan_domains.add(domain) - - def SetUnit(self, domain, unit): - """Set the unit for a domain. - - There can be only one unit for each domain. Setting unit twice will - overwrite the original unit. - - Args: - domain: the domain name. - unit: unit of the domain. - """ - if domain in self._unit: - self._logger.warning('overwriting the unit of %s, old unit is %s, new ' - 'unit is %s.', domain, self._unit[domain], unit) - self._unit[domain] = unit - def CalculateStats(self): - """Calculate stats for all domain-data pairs. - - First erases all previous stats, then calculate stats for all data. - """ - self._summary = {} - for domain, data in self._data.items(): - data_np = numpy.array(data) - self._summary[domain] = { - 'mean': numpy.nanmean(data_np), - 'min': numpy.nanmin(data_np), - 'max': numpy.nanmax(data_np), - 'stddev': numpy.nanstd(data_np), - 'count': data_np.size, - } - - @property - def DomainsToDisplay(self): - """List of domains that the manager will output in summaries.""" - return set(self._summary.keys()) - set(self._hide_domains) - - @property - def NanInOutput(self): - """Return whether any of the domains to display have NaN values.""" - return bool(len(set(self._nan_domains) & self.DomainsToDisplay)) - - def _SummaryTable(self): - """Generate the matrix to output as a summary. - - Returns: - A 2d matrix of headers and their data for each domain - e.g. - [[NAME, COUNT, MEAN, STDDEV, MAX, MIN], - [pp5000_mw, 10, 50, 0, 50, 50]] - """ - headers = ('NAME', 'COUNT', 'MEAN', 'STDDEV', 'MAX', 'MIN') - table = [headers] - # determine what domains to display & and the order - domains_to_display = self.DomainsToDisplay - display_order = [key for key in self._order if key in domains_to_display] - domains_to_display -= set(display_order) - display_order.extend(sorted(domains_to_display)) - for domain in display_order: - stats = self._summary[domain] - if not domain.endswith(self._unit[domain]): - domain = '%s_%s' % (domain, self._unit[domain]) - if domain in self._nan_domains: - domain = '%s%s' % (domain, NAN_TAG) - row = [domain] - row.append(str(stats['count'])) - for entry in headers[2:]: - row.append('%.2f' % stats[entry.lower()]) - table.append(row) - return table - - def SummaryToMarkdownString(self): - """Format the summary into a b/ compatible markdown table string. - - This requires this sort of output format - - | header1 | header2 | header3 | ... - | --------- | --------- | --------- | ... - | sample1h1 | sample1h2 | sample1h3 | ... - . - . - . - - Returns: - formatted summary string. - """ - # All we need to do before processing is insert a row of '-' between - # the headers, and the data - table = self._SummaryTable() - columns = len(table[0]) - # Using '-:' to allow the numbers to be right aligned - sep_row = ['-'] + ['-:'] * (columns - 1) - table.insert(1, sep_row) - text_rows = ['|'.join(r) for r in table] - body = '\n'.join(['|%s|' % r for r in text_rows]) - if self._title: - title_section = '**%s** \n\n' % self._title - body = title_section + body - # Make sure that the body is terminated with a newline. - return body + '\n' - - def SummaryToString(self, prefix=STATS_PREFIX): - """Format summary into a string, ready for pretty print. - - See class description for format example. - - Args: - prefix: start every row in summary string with prefix, for easier reading. - - Returns: - formatted summary string. - """ - table = self._SummaryTable() - max_col_width = [] - for col_idx in range(len(table[0])): - col_item_widths = [len(row[col_idx]) for row in table] - max_col_width.append(max(col_item_widths)) - - formatted_lines = [] - for row in table: - formatted_row = prefix + ' ' - for i in range(len(row)): - formatted_row += row[i].rjust(max_col_width[i] + 2) - formatted_lines.append(formatted_row) - if self.NanInOutput: - formatted_lines.append('%s %s' % (prefix, NAN_DESCRIPTION)) - - if self._title: - line_length = len(formatted_lines[0]) - dec_length = len(prefix) - # trim title to be at most as long as the longest line without the prefix - title = self._title[:(line_length - dec_length)] - # line is a seperator line consisting of ----- - line = '%s%s' % (prefix, '-' * (line_length - dec_length)) - # prepend the prefix to the centered title - padded_title = '%s%s' % (prefix, title.center(line_length)[dec_length:]) - formatted_lines = [line, padded_title, line] + formatted_lines + [line] - formatted_output = '\n'.join(formatted_lines) - return formatted_output - - def GetSummary(self): - """Getter for summary.""" - return self._summary - - def _MakeUniqueFName(self, fname): - """prepend |_smid| to fname & rotate fname to ensure uniqueness. - - Before saving a file through the StatsManager, make sure that the filename - is unique, first by prepending the smid if any and otherwise by appending - increasing integer suffixes until the filename is unique. - - If |smid| is defined /path/to/example/file.txt becomes - /path/to/example/{smid}_file.txt. - - The rotation works by changing /path/to/example/somename.txt to - /path/to/example/somename1.txt if the first one already exists on the - system. - - Note: this is not thread-safe. While it makes sense to use StatsManager - in a threaded data-collection, the data retrieval should happen in a - single threaded environment to ensure files don't get potentially clobbered. - - Args: - fname: filename to ensure uniqueness. - - Returns: - {smid_}fname{tag}.[b].ext - the smid portion gets prepended if |smid| is defined - the tag portion gets appended if necessary to ensure unique fname - """ - fdir = os.path.dirname(fname) - base, ext = os.path.splitext(os.path.basename(fname)) - if self._smid: - base = '%s_%s' % (self._smid, base) - unique_fname = os.path.join(fdir, '%s%s' % (base, ext)) - tag = 0 - while os.path.exists(unique_fname): - old_fname = unique_fname - unique_fname = os.path.join(fdir, '%s%d%s' % (base, tag, ext)) - self._logger.warning('Attempted to store stats information at %s, but ' - 'file already exists. Attempting to store at %s ' - 'now.', old_fname, unique_fname) - tag += 1 - return unique_fname - - def SaveSummary(self, directory, fname='summary.txt', prefix=STATS_PREFIX): - """Save summary to file. - - Args: - directory: directory to save the summary in. - fname: filename to save summary under. - prefix: start every row in summary string with prefix, for easier reading. - - Returns: - full path of summary save location - """ - summary_str = self.SummaryToString(prefix=prefix) + '\n' - return self._SaveSummary(summary_str, directory, fname) - - def SaveSummaryJSON(self, directory, fname='summary.json'): - """Save summary (only MEAN) into a JSON file. - - Args: - directory: directory to save the JSON summary in. - fname: filename to save summary under. - - Returns: - full path of summary save location - """ - data = {} - for domain in self._summary: - unit = LONG_UNIT.get(self._unit[domain], self._unit[domain]) - data_entry = {'mean': self._summary[domain]['mean'], 'unit': unit} - data[domain] = data_entry - summary_str = json.dumps(data, indent=2) - return self._SaveSummary(summary_str, directory, fname) - - def SaveSummaryMD(self, directory, fname='summary.md'): - """Save summary into a MD file to paste into b/. - - Args: - directory: directory to save the MD summary in. - fname: filename to save summary under. - - Returns: - full path of summary save location +class StatsManager(object): + """Calculates statistics for several lists of data(float). + + Example usage: + + >>> stats = StatsManager(title='Title Banner') + >>> stats.AddSample(TIME_KEY, 50.0) + >>> stats.AddSample(TIME_KEY, 25.0) + >>> stats.AddSample(TIME_KEY, 40.0) + >>> stats.AddSample(TIME_KEY, 10.0) + >>> stats.AddSample(TIME_KEY, 10.0) + >>> stats.AddSample('frobnicate', 11.5) + >>> stats.AddSample('frobnicate', 9.0) + >>> stats.AddSample('foobar', 11111.0) + >>> stats.AddSample('foobar', 22222.0) + >>> stats.CalculateStats() + >>> print(stats.SummaryToString()) + ` @@-------------------------------------------------------------- + ` @@ Title Banner + @@-------------------------------------------------------------- + @@ NAME COUNT MEAN STDDEV MAX MIN + @@ sample_msecs 4 31.25 15.16 50.00 10.00 + @@ foobar 2 16666.50 5555.50 22222.00 11111.00 + @@ frobnicate 2 10.25 1.25 11.50 9.00 + ` @@-------------------------------------------------------------- + + Attributes: + _data: dict of list of readings for each domain(key) + _unit: dict of unit for each domain(key) + _smid: id supplied to differentiate data output to other StatsManager + instances that potentially save to the same directory + if smid all output files will be named |smid|_|fname| + _title: title to add as banner to formatted summary. If no title, + no banner gets added + _order: list of formatting order for domains. Domains not listed are + displayed in sorted order + _hide_domains: collection of domains to hide when formatting summary string + _accept_nan: flag to indicate if NaN samples are acceptable + _nan_domains: set to keep track of which domains contain NaN samples + _summary: dict of stats per domain (key): min, max, count, mean, stddev + _logger = StatsManager logger + + Note: + _summary is empty until CalculateStats() is called, and is updated when + CalculateStats() is called. """ - summary_str = self.SummaryToMarkdownString() - return self._SaveSummary(summary_str, directory, fname) - def _SaveSummary(self, output_str, directory, fname): - """Wrote |output_str| to |fname|. - - Args: - output_str: formatted output string - directory: directory to save the summary in. - fname: filename to save summary under. - - Returns: - full path of summary save location - """ - if not os.path.exists(directory): - os.makedirs(directory) - fname = self._MakeUniqueFName(os.path.join(directory, fname)) - with open(fname, 'w') as f: - f.write(output_str) - return fname - - def GetRawData(self): - """Getter for all raw_data.""" - return self._data - - def SaveRawData(self, directory, dirname='raw_data'): - """Save raw data to file. - - Args: - directory: directory to create the raw data folder in. - dirname: folder in which raw data live. - - Returns: - list of full path of each domain's raw data save location - """ - if not os.path.exists(directory): - os.makedirs(directory) - dirname = os.path.join(directory, dirname) - if not os.path.exists(dirname): - os.makedirs(dirname) - fnames = [] - for domain, data in self._data.items(): - if not domain.endswith(self._unit[domain]): - domain = '%s_%s' % (domain, self._unit[domain]) - fname = self._MakeUniqueFName(os.path.join(dirname, '%s.txt' % domain)) - with open(fname, 'w') as f: - f.write('\n'.join('%.2f' % sample for sample in data) + '\n') - fnames.append(fname) - return fnames + # pylint: disable=W0102 + def __init__( + self, smid="", title="", order=[], hide_domains=[], accept_nan=True + ): + """Initialize infrastructure for data and their statistics.""" + self._title = title + self._data = collections.defaultdict(list) + self._unit = collections.defaultdict(str) + self._smid = smid + self._order = order + self._hide_domains = hide_domains + self._accept_nan = accept_nan + self._nan_domains = set() + self._summary = {} + self._logger = logging.getLogger(type(self).__name__) + + def AddSample(self, domain, sample): + """Add one sample for a domain. + + Args: + domain: the domain name for the sample. + sample: one time sample for domain, expect type float. + + Raises: + StatsManagerError: if trying to add NaN and |_accept_nan| is false + """ + try: + sample = float(sample) + except ValueError: + # if we don't accept nan this will be caught below + self._logger.debug( + "sample %s for domain %s is not a number. Making NaN", + sample, + domain, + ) + sample = float("NaN") + if not self._accept_nan and math.isnan(sample): + raise StatsManagerError( + "accept_nan is false. Cannot add NaN sample." + ) + self._data[domain].append(sample) + if math.isnan(sample): + self._nan_domains.add(domain) + + def SetUnit(self, domain, unit): + """Set the unit for a domain. + + There can be only one unit for each domain. Setting unit twice will + overwrite the original unit. + + Args: + domain: the domain name. + unit: unit of the domain. + """ + if domain in self._unit: + self._logger.warning( + "overwriting the unit of %s, old unit is %s, new " + "unit is %s.", + domain, + self._unit[domain], + unit, + ) + self._unit[domain] = unit + + def CalculateStats(self): + """Calculate stats for all domain-data pairs. + + First erases all previous stats, then calculate stats for all data. + """ + self._summary = {} + for domain, data in self._data.items(): + data_np = numpy.array(data) + self._summary[domain] = { + "mean": numpy.nanmean(data_np), + "min": numpy.nanmin(data_np), + "max": numpy.nanmax(data_np), + "stddev": numpy.nanstd(data_np), + "count": data_np.size, + } + + @property + def DomainsToDisplay(self): + """List of domains that the manager will output in summaries.""" + return set(self._summary.keys()) - set(self._hide_domains) + + @property + def NanInOutput(self): + """Return whether any of the domains to display have NaN values.""" + return bool(len(set(self._nan_domains) & self.DomainsToDisplay)) + + def _SummaryTable(self): + """Generate the matrix to output as a summary. + + Returns: + A 2d matrix of headers and their data for each domain + e.g. + [[NAME, COUNT, MEAN, STDDEV, MAX, MIN], + [pp5000_mw, 10, 50, 0, 50, 50]] + """ + headers = ("NAME", "COUNT", "MEAN", "STDDEV", "MAX", "MIN") + table = [headers] + # determine what domains to display & and the order + domains_to_display = self.DomainsToDisplay + display_order = [ + key for key in self._order if key in domains_to_display + ] + domains_to_display -= set(display_order) + display_order.extend(sorted(domains_to_display)) + for domain in display_order: + stats = self._summary[domain] + if not domain.endswith(self._unit[domain]): + domain = "%s_%s" % (domain, self._unit[domain]) + if domain in self._nan_domains: + domain = "%s%s" % (domain, NAN_TAG) + row = [domain] + row.append(str(stats["count"])) + for entry in headers[2:]: + row.append("%.2f" % stats[entry.lower()]) + table.append(row) + return table + + def SummaryToMarkdownString(self): + """Format the summary into a b/ compatible markdown table string. + + This requires this sort of output format + + | header1 | header2 | header3 | ... + | --------- | --------- | --------- | ... + | sample1h1 | sample1h2 | sample1h3 | ... + . + . + . + + Returns: + formatted summary string. + """ + # All we need to do before processing is insert a row of '-' between + # the headers, and the data + table = self._SummaryTable() + columns = len(table[0]) + # Using '-:' to allow the numbers to be right aligned + sep_row = ["-"] + ["-:"] * (columns - 1) + table.insert(1, sep_row) + text_rows = ["|".join(r) for r in table] + body = "\n".join(["|%s|" % r for r in text_rows]) + if self._title: + title_section = "**%s** \n\n" % self._title + body = title_section + body + # Make sure that the body is terminated with a newline. + return body + "\n" + + def SummaryToString(self, prefix=STATS_PREFIX): + """Format summary into a string, ready for pretty print. + + See class description for format example. + + Args: + prefix: start every row in summary string with prefix, for easier reading. + + Returns: + formatted summary string. + """ + table = self._SummaryTable() + max_col_width = [] + for col_idx in range(len(table[0])): + col_item_widths = [len(row[col_idx]) for row in table] + max_col_width.append(max(col_item_widths)) + + formatted_lines = [] + for row in table: + formatted_row = prefix + " " + for i in range(len(row)): + formatted_row += row[i].rjust(max_col_width[i] + 2) + formatted_lines.append(formatted_row) + if self.NanInOutput: + formatted_lines.append("%s %s" % (prefix, NAN_DESCRIPTION)) + + if self._title: + line_length = len(formatted_lines[0]) + dec_length = len(prefix) + # trim title to be at most as long as the longest line without the prefix + title = self._title[: (line_length - dec_length)] + # line is a seperator line consisting of ----- + line = "%s%s" % (prefix, "-" * (line_length - dec_length)) + # prepend the prefix to the centered title + padded_title = "%s%s" % ( + prefix, + title.center(line_length)[dec_length:], + ) + formatted_lines = ( + [line, padded_title, line] + formatted_lines + [line] + ) + formatted_output = "\n".join(formatted_lines) + return formatted_output + + def GetSummary(self): + """Getter for summary.""" + return self._summary + + def _MakeUniqueFName(self, fname): + """prepend |_smid| to fname & rotate fname to ensure uniqueness. + + Before saving a file through the StatsManager, make sure that the filename + is unique, first by prepending the smid if any and otherwise by appending + increasing integer suffixes until the filename is unique. + + If |smid| is defined /path/to/example/file.txt becomes + /path/to/example/{smid}_file.txt. + + The rotation works by changing /path/to/example/somename.txt to + /path/to/example/somename1.txt if the first one already exists on the + system. + + Note: this is not thread-safe. While it makes sense to use StatsManager + in a threaded data-collection, the data retrieval should happen in a + single threaded environment to ensure files don't get potentially clobbered. + + Args: + fname: filename to ensure uniqueness. + + Returns: + {smid_}fname{tag}.[b].ext + the smid portion gets prepended if |smid| is defined + the tag portion gets appended if necessary to ensure unique fname + """ + fdir = os.path.dirname(fname) + base, ext = os.path.splitext(os.path.basename(fname)) + if self._smid: + base = "%s_%s" % (self._smid, base) + unique_fname = os.path.join(fdir, "%s%s" % (base, ext)) + tag = 0 + while os.path.exists(unique_fname): + old_fname = unique_fname + unique_fname = os.path.join(fdir, "%s%d%s" % (base, tag, ext)) + self._logger.warning( + "Attempted to store stats information at %s, but " + "file already exists. Attempting to store at %s " + "now.", + old_fname, + unique_fname, + ) + tag += 1 + return unique_fname + + def SaveSummary(self, directory, fname="summary.txt", prefix=STATS_PREFIX): + """Save summary to file. + + Args: + directory: directory to save the summary in. + fname: filename to save summary under. + prefix: start every row in summary string with prefix, for easier reading. + + Returns: + full path of summary save location + """ + summary_str = self.SummaryToString(prefix=prefix) + "\n" + return self._SaveSummary(summary_str, directory, fname) + + def SaveSummaryJSON(self, directory, fname="summary.json"): + """Save summary (only MEAN) into a JSON file. + + Args: + directory: directory to save the JSON summary in. + fname: filename to save summary under. + + Returns: + full path of summary save location + """ + data = {} + for domain in self._summary: + unit = LONG_UNIT.get(self._unit[domain], self._unit[domain]) + data_entry = {"mean": self._summary[domain]["mean"], "unit": unit} + data[domain] = data_entry + summary_str = json.dumps(data, indent=2) + return self._SaveSummary(summary_str, directory, fname) + + def SaveSummaryMD(self, directory, fname="summary.md"): + """Save summary into a MD file to paste into b/. + + Args: + directory: directory to save the MD summary in. + fname: filename to save summary under. + + Returns: + full path of summary save location + """ + summary_str = self.SummaryToMarkdownString() + return self._SaveSummary(summary_str, directory, fname) + + def _SaveSummary(self, output_str, directory, fname): + """Wrote |output_str| to |fname|. + + Args: + output_str: formatted output string + directory: directory to save the summary in. + fname: filename to save summary under. + + Returns: + full path of summary save location + """ + if not os.path.exists(directory): + os.makedirs(directory) + fname = self._MakeUniqueFName(os.path.join(directory, fname)) + with open(fname, "w") as f: + f.write(output_str) + return fname + + def GetRawData(self): + """Getter for all raw_data.""" + return self._data + + def SaveRawData(self, directory, dirname="raw_data"): + """Save raw data to file. + + Args: + directory: directory to create the raw data folder in. + dirname: folder in which raw data live. + + Returns: + list of full path of each domain's raw data save location + """ + if not os.path.exists(directory): + os.makedirs(directory) + dirname = os.path.join(directory, dirname) + if not os.path.exists(dirname): + os.makedirs(dirname) + fnames = [] + for domain, data in self._data.items(): + if not domain.endswith(self._unit[domain]): + domain = "%s_%s" % (domain, self._unit[domain]) + fname = self._MakeUniqueFName( + os.path.join(dirname, "%s.txt" % domain) + ) + with open(fname, "w") as f: + f.write("\n".join("%.2f" % sample for sample in data) + "\n") + fnames.append(fname) + return fnames diff --git a/extra/usb_power/stats_manager_unittest.py b/extra/usb_power/stats_manager_unittest.py index beb9984b93..2bfaa5c83d 100644 --- a/extra/usb_power/stats_manager_unittest.py +++ b/extra/usb_power/stats_manager_unittest.py @@ -1,14 +1,11 @@ -# Copyright 2017 The Chromium OS Authors. All rights reserved. +# Copyright 2017 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes """Unit tests for StatsManager.""" from __future__ import print_function + import json import os import re @@ -16,300 +13,314 @@ import shutil import tempfile import unittest -import stats_manager +import stats_manager # pylint:disable=import-error class TestStatsManager(unittest.TestCase): - """Test to verify StatsManager methods work as expected. - - StatsManager should collect raw data, calculate their statistics, and save - them in expected format. - """ - - def _populate_mock_stats(self): - """Create a populated & processed StatsManager to test data retrieval.""" - self.data.AddSample('A', 99999.5) - self.data.AddSample('A', 100000.5) - self.data.SetUnit('A', 'uW') - self.data.SetUnit('A', 'mW') - self.data.AddSample('B', 1.5) - self.data.AddSample('B', 2.5) - self.data.AddSample('B', 3.5) - self.data.SetUnit('B', 'mV') - self.data.CalculateStats() - - def _populate_mock_stats_no_unit(self): - self.data.AddSample('B', 1000) - self.data.AddSample('A', 200) - self.data.SetUnit('A', 'blue') - - def setUp(self): - """Set up StatsManager and create a temporary directory for test.""" - self.tempdir = tempfile.mkdtemp() - self.data = stats_manager.StatsManager() - - def tearDown(self): - """Delete the temporary directory and its content.""" - shutil.rmtree(self.tempdir) - - def test_AddSample(self): - """Adding a sample successfully adds a sample.""" - self.data.AddSample('Test', 1000) - self.data.SetUnit('Test', 'test') - self.data.CalculateStats() - summary = self.data.GetSummary() - self.assertEqual(1, summary['Test']['count']) - - def test_AddSampleNoFloatAcceptNaN(self): - """Adding a non-number adds 'NaN' and doesn't raise an exception.""" - self.data.AddSample('Test', 10) - self.data.AddSample('Test', 20) - # adding a fake NaN: one that gets converted into NaN internally - self.data.AddSample('Test', 'fiesta') - # adding a real NaN - self.data.AddSample('Test', float('NaN')) - self.data.SetUnit('Test', 'test') - self.data.CalculateStats() - summary = self.data.GetSummary() - # assert that 'NaN' as added. - self.assertEqual(4, summary['Test']['count']) - # assert that mean, min, and max calculatings ignore the 'NaN' - self.assertEqual(10, summary['Test']['min']) - self.assertEqual(20, summary['Test']['max']) - self.assertEqual(15, summary['Test']['mean']) - - def test_AddSampleNoFloatNotAcceptNaN(self): - """Adding a non-number raises a StatsManagerError if accept_nan is False.""" - self.data = stats_manager.StatsManager(accept_nan=False) - with self.assertRaisesRegexp(stats_manager.StatsManagerError, - 'accept_nan is false. Cannot add NaN sample.'): - # adding a fake NaN: one that gets converted into NaN internally - self.data.AddSample('Test', 'fiesta') - with self.assertRaisesRegexp(stats_manager.StatsManagerError, - 'accept_nan is false. Cannot add NaN sample.'): - # adding a real NaN - self.data.AddSample('Test', float('NaN')) - - def test_AddSampleNoUnit(self): - """Not adding a unit does not cause an exception on CalculateStats().""" - self.data.AddSample('Test', 17) - self.data.CalculateStats() - summary = self.data.GetSummary() - self.assertEqual(1, summary['Test']['count']) - - def test_UnitSuffix(self): - """Unit gets appended as a suffix in the displayed summary.""" - self.data.AddSample('test', 250) - self.data.SetUnit('test', 'mw') - self.data.CalculateStats() - summary_str = self.data.SummaryToString() - self.assertIn('test_mw', summary_str) - - def test_DoubleUnitSuffix(self): - """If domain already ends in unit, verify that unit doesn't get appended.""" - self.data.AddSample('test_mw', 250) - self.data.SetUnit('test_mw', 'mw') - self.data.CalculateStats() - summary_str = self.data.SummaryToString() - self.assertIn('test_mw', summary_str) - self.assertNotIn('test_mw_mw', summary_str) - - def test_GetRawData(self): - """GetRawData returns exact same data as fed in.""" - self._populate_mock_stats() - raw_data = self.data.GetRawData() - self.assertListEqual([99999.5, 100000.5], raw_data['A']) - self.assertListEqual([1.5, 2.5, 3.5], raw_data['B']) - - def test_GetSummary(self): - """GetSummary returns expected stats about the data fed in.""" - self._populate_mock_stats() - summary = self.data.GetSummary() - self.assertEqual(2, summary['A']['count']) - self.assertAlmostEqual(100000.5, summary['A']['max']) - self.assertAlmostEqual(99999.5, summary['A']['min']) - self.assertAlmostEqual(0.5, summary['A']['stddev']) - self.assertAlmostEqual(100000.0, summary['A']['mean']) - self.assertEqual(3, summary['B']['count']) - self.assertAlmostEqual(3.5, summary['B']['max']) - self.assertAlmostEqual(1.5, summary['B']['min']) - self.assertAlmostEqual(0.81649658092773, summary['B']['stddev']) - self.assertAlmostEqual(2.5, summary['B']['mean']) - - def test_SaveRawData(self): - """SaveRawData stores same data as fed in.""" - self._populate_mock_stats() - dirname = 'unittest_raw_data' - expected_files = set(['A_mW.txt', 'B_mV.txt']) - fnames = self.data.SaveRawData(self.tempdir, dirname) - files_returned = set([os.path.basename(f) for f in fnames]) - # Assert that only the expected files got returned. - self.assertEqual(expected_files, files_returned) - # Assert that only the returned files are in the outdir. - self.assertEqual(set(os.listdir(os.path.join(self.tempdir, dirname))), - files_returned) - for fname in fnames: - with open(fname, 'r') as f: - if 'A_mW' in fname: - self.assertEqual('99999.50', f.readline().strip()) - self.assertEqual('100000.50', f.readline().strip()) - if 'B_mV' in fname: - self.assertEqual('1.50', f.readline().strip()) - self.assertEqual('2.50', f.readline().strip()) - self.assertEqual('3.50', f.readline().strip()) - - def test_SaveRawDataNoUnit(self): - """SaveRawData appends no unit suffix if the unit is not specified.""" - self._populate_mock_stats_no_unit() - self.data.CalculateStats() - outdir = 'unittest_raw_data' - files = self.data.SaveRawData(self.tempdir, outdir) - files = [os.path.basename(f) for f in files] - # Verify nothing gets appended to domain for filename if no unit exists. - self.assertIn('B.txt', files) - - def test_SaveRawDataSMID(self): - """SaveRawData uses the smid when creating output filename.""" - identifier = 'ec' - self.data = stats_manager.StatsManager(smid=identifier) - self._populate_mock_stats() - files = self.data.SaveRawData(self.tempdir) - for fname in files: - self.assertTrue(os.path.basename(fname).startswith(identifier)) - - def test_SummaryToStringNaNHelp(self): - """NaN containing row gets tagged with *, help banner gets added.""" - help_banner_exp = '%s %s' % (stats_manager.STATS_PREFIX, - stats_manager.NAN_DESCRIPTION) - nan_domain = 'A-domain' - nan_domain_exp = '%s%s' % (nan_domain, stats_manager.NAN_TAG) - # NaN helper banner is added when a NaN domain is found & domain gets tagged - data = stats_manager.StatsManager() - data.AddSample(nan_domain, float('NaN')) - data.AddSample(nan_domain, 17) - data.AddSample('B-domain', 17) - data.CalculateStats() - summarystr = data.SummaryToString() - self.assertIn(help_banner_exp, summarystr) - self.assertIn(nan_domain_exp, summarystr) - # NaN helper banner is not added when no NaN domain output, no tagging - data = stats_manager.StatsManager() - # nan_domain in this scenario does not contain any NaN - data.AddSample(nan_domain, 19) - data.AddSample('B-domain', 17) - data.CalculateStats() - summarystr = data.SummaryToString() - self.assertNotIn(help_banner_exp, summarystr) - self.assertNotIn(nan_domain_exp, summarystr) - - def test_SummaryToStringTitle(self): - """Title shows up in SummaryToString if title specified.""" - title = 'titulo' - data = stats_manager.StatsManager(title=title) - self._populate_mock_stats() - summary_str = data.SummaryToString() - self.assertIn(title, summary_str) - - def test_SummaryToStringHideDomains(self): - """Keys indicated in hide_domains are not printed in the summary.""" - data = stats_manager.StatsManager(hide_domains=['A-domain']) - data.AddSample('A-domain', 17) - data.AddSample('B-domain', 17) - data.CalculateStats() - summary_str = data.SummaryToString() - self.assertIn('B-domain', summary_str) - self.assertNotIn('A-domain', summary_str) - - def test_SummaryToStringOrder(self): - """Order passed into StatsManager is honoured when formatting summary.""" - # StatsManager that should print D & B first, and the subsequent elements - # are sorted. - d_b_a_c_regexp = re.compile('D-domain.*B-domain.*A-domain.*C-domain', - re.DOTALL) - data = stats_manager.StatsManager(order=['D-domain', 'B-domain']) - data.AddSample('A-domain', 17) - data.AddSample('B-domain', 17) - data.AddSample('C-domain', 17) - data.AddSample('D-domain', 17) - data.CalculateStats() - summary_str = data.SummaryToString() - self.assertRegexpMatches(summary_str, d_b_a_c_regexp) - - def test_MakeUniqueFName(self): - data = stats_manager.StatsManager() - testfile = os.path.join(self.tempdir, 'testfile.txt') - with open(testfile, 'w') as f: - f.write('') - expected_fname = os.path.join(self.tempdir, 'testfile0.txt') - self.assertEqual(expected_fname, data._MakeUniqueFName(testfile)) - - def test_SaveSummary(self): - """SaveSummary properly dumps the summary into a file.""" - self._populate_mock_stats() - fname = 'unittest_summary.txt' - expected_fname = os.path.join(self.tempdir, fname) - fname = self.data.SaveSummary(self.tempdir, fname) - # Assert the reported fname is the same as the expected fname - self.assertEqual(expected_fname, fname) - # Assert only the reported fname is output (in the tempdir) - self.assertEqual(set([os.path.basename(fname)]), - set(os.listdir(self.tempdir))) - with open(fname, 'r') as f: - self.assertEqual( - '@@ NAME COUNT MEAN STDDEV MAX MIN\n', - f.readline()) - self.assertEqual( - '@@ A_mW 2 100000.00 0.50 100000.50 99999.50\n', - f.readline()) - self.assertEqual( - '@@ B_mV 3 2.50 0.82 3.50 1.50\n', - f.readline()) - - def test_SaveSummarySMID(self): - """SaveSummary uses the smid when creating output filename.""" - identifier = 'ec' - self.data = stats_manager.StatsManager(smid=identifier) - self._populate_mock_stats() - fname = os.path.basename(self.data.SaveSummary(self.tempdir)) - self.assertTrue(fname.startswith(identifier)) - - def test_SaveSummaryJSON(self): - """SaveSummaryJSON saves the added data properly in JSON format.""" - self._populate_mock_stats() - fname = 'unittest_summary.json' - expected_fname = os.path.join(self.tempdir, fname) - fname = self.data.SaveSummaryJSON(self.tempdir, fname) - # Assert the reported fname is the same as the expected fname - self.assertEqual(expected_fname, fname) - # Assert only the reported fname is output (in the tempdir) - self.assertEqual(set([os.path.basename(fname)]), - set(os.listdir(self.tempdir))) - with open(fname, 'r') as f: - summary = json.load(f) - self.assertAlmostEqual(100000.0, summary['A']['mean']) - self.assertEqual('milliwatt', summary['A']['unit']) - self.assertAlmostEqual(2.5, summary['B']['mean']) - self.assertEqual('millivolt', summary['B']['unit']) - - def test_SaveSummaryJSONSMID(self): - """SaveSummaryJSON uses the smid when creating output filename.""" - identifier = 'ec' - self.data = stats_manager.StatsManager(smid=identifier) - self._populate_mock_stats() - fname = os.path.basename(self.data.SaveSummaryJSON(self.tempdir)) - self.assertTrue(fname.startswith(identifier)) - - def test_SaveSummaryJSONNoUnit(self): - """SaveSummaryJSON marks unknown units properly as N/A.""" - self._populate_mock_stats_no_unit() - self.data.CalculateStats() - fname = 'unittest_summary.json' - fname = self.data.SaveSummaryJSON(self.tempdir, fname) - with open(fname, 'r') as f: - summary = json.load(f) - self.assertEqual('blue', summary['A']['unit']) - # if no unit is specified, JSON should save 'N/A' as the unit. - self.assertEqual('N/A', summary['B']['unit']) - -if __name__ == '__main__': - unittest.main() + """Test to verify StatsManager methods work as expected. + + StatsManager should collect raw data, calculate their statistics, and save + them in expected format. + """ + + def _populate_mock_stats(self): + """Create a populated & processed StatsManager to test data retrieval.""" + self.data.AddSample("A", 99999.5) + self.data.AddSample("A", 100000.5) + self.data.SetUnit("A", "uW") + self.data.SetUnit("A", "mW") + self.data.AddSample("B", 1.5) + self.data.AddSample("B", 2.5) + self.data.AddSample("B", 3.5) + self.data.SetUnit("B", "mV") + self.data.CalculateStats() + + def _populate_mock_stats_no_unit(self): + self.data.AddSample("B", 1000) + self.data.AddSample("A", 200) + self.data.SetUnit("A", "blue") + + def setUp(self): + """Set up StatsManager and create a temporary directory for test.""" + self.tempdir = tempfile.mkdtemp() + self.data = stats_manager.StatsManager() + + def tearDown(self): + """Delete the temporary directory and its content.""" + shutil.rmtree(self.tempdir) + + def test_AddSample(self): + """Adding a sample successfully adds a sample.""" + self.data.AddSample("Test", 1000) + self.data.SetUnit("Test", "test") + self.data.CalculateStats() + summary = self.data.GetSummary() + self.assertEqual(1, summary["Test"]["count"]) + + def test_AddSampleNoFloatAcceptNaN(self): + """Adding a non-number adds 'NaN' and doesn't raise an exception.""" + self.data.AddSample("Test", 10) + self.data.AddSample("Test", 20) + # adding a fake NaN: one that gets converted into NaN internally + self.data.AddSample("Test", "fiesta") + # adding a real NaN + self.data.AddSample("Test", float("NaN")) + self.data.SetUnit("Test", "test") + self.data.CalculateStats() + summary = self.data.GetSummary() + # assert that 'NaN' as added. + self.assertEqual(4, summary["Test"]["count"]) + # assert that mean, min, and max calculatings ignore the 'NaN' + self.assertEqual(10, summary["Test"]["min"]) + self.assertEqual(20, summary["Test"]["max"]) + self.assertEqual(15, summary["Test"]["mean"]) + + def test_AddSampleNoFloatNotAcceptNaN(self): + """Adding a non-number raises a StatsManagerError if accept_nan is False.""" + self.data = stats_manager.StatsManager(accept_nan=False) + with self.assertRaisesRegexp( + stats_manager.StatsManagerError, + "accept_nan is false. Cannot add NaN sample.", + ): + # adding a fake NaN: one that gets converted into NaN internally + self.data.AddSample("Test", "fiesta") + with self.assertRaisesRegexp( + stats_manager.StatsManagerError, + "accept_nan is false. Cannot add NaN sample.", + ): + # adding a real NaN + self.data.AddSample("Test", float("NaN")) + + def test_AddSampleNoUnit(self): + """Not adding a unit does not cause an exception on CalculateStats().""" + self.data.AddSample("Test", 17) + self.data.CalculateStats() + summary = self.data.GetSummary() + self.assertEqual(1, summary["Test"]["count"]) + + def test_UnitSuffix(self): + """Unit gets appended as a suffix in the displayed summary.""" + self.data.AddSample("test", 250) + self.data.SetUnit("test", "mw") + self.data.CalculateStats() + summary_str = self.data.SummaryToString() + self.assertIn("test_mw", summary_str) + + def test_DoubleUnitSuffix(self): + """If domain already ends in unit, verify that unit doesn't get appended.""" + self.data.AddSample("test_mw", 250) + self.data.SetUnit("test_mw", "mw") + self.data.CalculateStats() + summary_str = self.data.SummaryToString() + self.assertIn("test_mw", summary_str) + self.assertNotIn("test_mw_mw", summary_str) + + def test_GetRawData(self): + """GetRawData returns exact same data as fed in.""" + self._populate_mock_stats() + raw_data = self.data.GetRawData() + self.assertListEqual([99999.5, 100000.5], raw_data["A"]) + self.assertListEqual([1.5, 2.5, 3.5], raw_data["B"]) + + def test_GetSummary(self): + """GetSummary returns expected stats about the data fed in.""" + self._populate_mock_stats() + summary = self.data.GetSummary() + self.assertEqual(2, summary["A"]["count"]) + self.assertAlmostEqual(100000.5, summary["A"]["max"]) + self.assertAlmostEqual(99999.5, summary["A"]["min"]) + self.assertAlmostEqual(0.5, summary["A"]["stddev"]) + self.assertAlmostEqual(100000.0, summary["A"]["mean"]) + self.assertEqual(3, summary["B"]["count"]) + self.assertAlmostEqual(3.5, summary["B"]["max"]) + self.assertAlmostEqual(1.5, summary["B"]["min"]) + self.assertAlmostEqual(0.81649658092773, summary["B"]["stddev"]) + self.assertAlmostEqual(2.5, summary["B"]["mean"]) + + def test_SaveRawData(self): + """SaveRawData stores same data as fed in.""" + self._populate_mock_stats() + dirname = "unittest_raw_data" + expected_files = set(["A_mW.txt", "B_mV.txt"]) + fnames = self.data.SaveRawData(self.tempdir, dirname) + files_returned = set([os.path.basename(f) for f in fnames]) + # Assert that only the expected files got returned. + self.assertEqual(expected_files, files_returned) + # Assert that only the returned files are in the outdir. + self.assertEqual( + set(os.listdir(os.path.join(self.tempdir, dirname))), files_returned + ) + for fname in fnames: + with open(fname, "r") as f: + if "A_mW" in fname: + self.assertEqual("99999.50", f.readline().strip()) + self.assertEqual("100000.50", f.readline().strip()) + if "B_mV" in fname: + self.assertEqual("1.50", f.readline().strip()) + self.assertEqual("2.50", f.readline().strip()) + self.assertEqual("3.50", f.readline().strip()) + + def test_SaveRawDataNoUnit(self): + """SaveRawData appends no unit suffix if the unit is not specified.""" + self._populate_mock_stats_no_unit() + self.data.CalculateStats() + outdir = "unittest_raw_data" + files = self.data.SaveRawData(self.tempdir, outdir) + files = [os.path.basename(f) for f in files] + # Verify nothing gets appended to domain for filename if no unit exists. + self.assertIn("B.txt", files) + + def test_SaveRawDataSMID(self): + """SaveRawData uses the smid when creating output filename.""" + identifier = "ec" + self.data = stats_manager.StatsManager(smid=identifier) + self._populate_mock_stats() + files = self.data.SaveRawData(self.tempdir) + for fname in files: + self.assertTrue(os.path.basename(fname).startswith(identifier)) + + def test_SummaryToStringNaNHelp(self): + """NaN containing row gets tagged with *, help banner gets added.""" + help_banner_exp = "%s %s" % ( + stats_manager.STATS_PREFIX, + stats_manager.NAN_DESCRIPTION, + ) + nan_domain = "A-domain" + nan_domain_exp = "%s%s" % (nan_domain, stats_manager.NAN_TAG) + # NaN helper banner is added when a NaN domain is found & domain gets tagged + data = stats_manager.StatsManager() + data.AddSample(nan_domain, float("NaN")) + data.AddSample(nan_domain, 17) + data.AddSample("B-domain", 17) + data.CalculateStats() + summarystr = data.SummaryToString() + self.assertIn(help_banner_exp, summarystr) + self.assertIn(nan_domain_exp, summarystr) + # NaN helper banner is not added when no NaN domain output, no tagging + data = stats_manager.StatsManager() + # nan_domain in this scenario does not contain any NaN + data.AddSample(nan_domain, 19) + data.AddSample("B-domain", 17) + data.CalculateStats() + summarystr = data.SummaryToString() + self.assertNotIn(help_banner_exp, summarystr) + self.assertNotIn(nan_domain_exp, summarystr) + + def test_SummaryToStringTitle(self): + """Title shows up in SummaryToString if title specified.""" + title = "titulo" + data = stats_manager.StatsManager(title=title) + self._populate_mock_stats() + summary_str = data.SummaryToString() + self.assertIn(title, summary_str) + + def test_SummaryToStringHideDomains(self): + """Keys indicated in hide_domains are not printed in the summary.""" + data = stats_manager.StatsManager(hide_domains=["A-domain"]) + data.AddSample("A-domain", 17) + data.AddSample("B-domain", 17) + data.CalculateStats() + summary_str = data.SummaryToString() + self.assertIn("B-domain", summary_str) + self.assertNotIn("A-domain", summary_str) + + def test_SummaryToStringOrder(self): + """Order passed into StatsManager is honoured when formatting summary.""" + # StatsManager that should print D & B first, and the subsequent elements + # are sorted. + d_b_a_c_regexp = re.compile( + "D-domain.*B-domain.*A-domain.*C-domain", re.DOTALL + ) + data = stats_manager.StatsManager(order=["D-domain", "B-domain"]) + data.AddSample("A-domain", 17) + data.AddSample("B-domain", 17) + data.AddSample("C-domain", 17) + data.AddSample("D-domain", 17) + data.CalculateStats() + summary_str = data.SummaryToString() + self.assertRegexpMatches(summary_str, d_b_a_c_regexp) + + def test_MakeUniqueFName(self): + data = stats_manager.StatsManager() + testfile = os.path.join(self.tempdir, "testfile.txt") + with open(testfile, "w") as f: + f.write("") + expected_fname = os.path.join(self.tempdir, "testfile0.txt") + self.assertEqual(expected_fname, data._MakeUniqueFName(testfile)) + + def test_SaveSummary(self): + """SaveSummary properly dumps the summary into a file.""" + self._populate_mock_stats() + fname = "unittest_summary.txt" + expected_fname = os.path.join(self.tempdir, fname) + fname = self.data.SaveSummary(self.tempdir, fname) + # Assert the reported fname is the same as the expected fname + self.assertEqual(expected_fname, fname) + # Assert only the reported fname is output (in the tempdir) + self.assertEqual( + set([os.path.basename(fname)]), set(os.listdir(self.tempdir)) + ) + with open(fname, "r") as f: + self.assertEqual( + "@@ NAME COUNT MEAN STDDEV MAX MIN\n", + f.readline(), + ) + self.assertEqual( + "@@ A_mW 2 100000.00 0.50 100000.50 99999.50\n", + f.readline(), + ) + self.assertEqual( + "@@ B_mV 3 2.50 0.82 3.50 1.50\n", + f.readline(), + ) + + def test_SaveSummarySMID(self): + """SaveSummary uses the smid when creating output filename.""" + identifier = "ec" + self.data = stats_manager.StatsManager(smid=identifier) + self._populate_mock_stats() + fname = os.path.basename(self.data.SaveSummary(self.tempdir)) + self.assertTrue(fname.startswith(identifier)) + + def test_SaveSummaryJSON(self): + """SaveSummaryJSON saves the added data properly in JSON format.""" + self._populate_mock_stats() + fname = "unittest_summary.json" + expected_fname = os.path.join(self.tempdir, fname) + fname = self.data.SaveSummaryJSON(self.tempdir, fname) + # Assert the reported fname is the same as the expected fname + self.assertEqual(expected_fname, fname) + # Assert only the reported fname is output (in the tempdir) + self.assertEqual( + set([os.path.basename(fname)]), set(os.listdir(self.tempdir)) + ) + with open(fname, "r") as f: + summary = json.load(f) + self.assertAlmostEqual(100000.0, summary["A"]["mean"]) + self.assertEqual("milliwatt", summary["A"]["unit"]) + self.assertAlmostEqual(2.5, summary["B"]["mean"]) + self.assertEqual("millivolt", summary["B"]["unit"]) + + def test_SaveSummaryJSONSMID(self): + """SaveSummaryJSON uses the smid when creating output filename.""" + identifier = "ec" + self.data = stats_manager.StatsManager(smid=identifier) + self._populate_mock_stats() + fname = os.path.basename(self.data.SaveSummaryJSON(self.tempdir)) + self.assertTrue(fname.startswith(identifier)) + + def test_SaveSummaryJSONNoUnit(self): + """SaveSummaryJSON marks unknown units properly as N/A.""" + self._populate_mock_stats_no_unit() + self.data.CalculateStats() + fname = "unittest_summary.json" + fname = self.data.SaveSummaryJSON(self.tempdir, fname) + with open(fname, "r") as f: + summary = json.load(f) + self.assertEqual("blue", summary["A"]["unit"]) + # if no unit is specified, JSON should save 'N/A' as the unit. + self.assertEqual("N/A", summary["B"]["unit"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/extra/usb_serial/add_usb_serial_id b/extra/usb_serial/add_usb_serial_id index ef8336afdc..12e0055e0b 100755 --- a/extra/usb_serial/add_usb_serial_id +++ b/extra/usb_serial/add_usb_serial_id @@ -1,6 +1,6 @@ #!/bin/sh -e # -# Copyright 2016 The Chromium OS Authors. All rights reserved. +# Copyright 2016 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. # diff --git a/extra/usb_serial/console.py b/extra/usb_serial/console.py index d06b33ce23..2b0ecd5f13 100755 --- a/extra/usb_serial/console.py +++ b/extra/usb_serial/console.py @@ -1,17 +1,14 @@ -#!/usr/bin/env python -# Copyright 2016 The Chromium OS Authors. All rights reserved. +#!/usr/bin/env python3 +# Copyright 2016 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes """Allow creation of uart/console interface via usb google serial endpoint.""" # Note: This is a py2/3 compatible file. from __future__ import print_function + import argparse import array import os @@ -19,25 +16,25 @@ import sys import termios import threading import time -import traceback import tty + try: - import usb -except: - print("import usb failed") - print("try running these commands:") - print(" sudo apt-get install python-pip") - print(" sudo pip install --pre pyusb") - print() - sys.exit(-1) + import usb # pylint:disable=import-error +except ModuleNotFoundError: + print("import usb failed") + print("try running these commands:") + print(" sudo apt-get install python-pip") + print(" sudo pip install --pre pyusb") + print() + sys.exit(-1) import six def GetBuffer(stream): - if six.PY3: - return stream.buffer - return stream + if six.PY3: + return stream.buffer + return stream """Class Susb covers USB device discovery and initialization. @@ -46,99 +43,103 @@ def GetBuffer(stream): and interface number. """ + class SusbError(Exception): - """Class for exceptions of Susb.""" - def __init__(self, msg, value=0): - """SusbError constructor. + """Class for exceptions of Susb.""" - Args: - msg: string, message describing error in detail - value: integer, value of error when non-zero status returned. Default=0 - """ - super(SusbError, self).__init__(msg, value) - self.msg = msg - self.value = value - -class Susb(): - """Provide USB functionality. - - Instance Variables: - _read_ep: pyUSB read endpoint for this interface - _write_ep: pyUSB write endpoint for this interface - """ - READ_ENDPOINT = 0x81 - WRITE_ENDPOINT = 0x1 - TIMEOUT_MS = 100 - - def __init__(self, vendor=0x18d1, - product=0x500f, interface=1, serialname=None): - """Susb constructor. - - Discovers and connects to USB endpoints. - - Args: - vendor : usb vendor id of device - product : usb product id of device - interface : interface number ( 1 - 8 ) of device to use - serialname: string of device serialnumber. - - Raises: - SusbError: An error accessing Susb object + def __init__(self, msg, value=0): + """SusbError constructor. + + Args: + msg: string, message describing error in detail + value: integer, value of error when non-zero status returned. Default=0 + """ + super(SusbError, self).__init__(msg, value) + self.msg = msg + self.value = value + + +class Susb: + """Provide USB functionality. + + Instance Variables: + _read_ep: pyUSB read endpoint for this interface + _write_ep: pyUSB write endpoint for this interface """ - # Find the device. - dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) - dev_list = list(dev_g) - if dev_list is None: - raise SusbError("USB device not found") - - # Check if we have multiple devices. - dev = None - if serialname: - for d in dev_list: - dev_serial = "PyUSB doesn't have a stable interface" - try: - dev_serial = usb.util.get_string(d, 256, d.iSerialNumber) - except: - dev_serial = usb.util.get_string(d, d.iSerialNumber) - if dev_serial == serialname: - dev = d - break - if dev is None: - raise SusbError("USB device(%s) not found" % (serialname,)) - else: - try: - dev = dev_list[0] - except: - try: - dev = dev_list.next() - except: - raise SusbError("USB device %04x:%04x not found" % (vendor, product)) - # If we can't set configuration, it's already been set. - try: - dev.set_configuration() - except usb.core.USBError: - pass + READ_ENDPOINT = 0x81 + WRITE_ENDPOINT = 0x1 + TIMEOUT_MS = 100 + + def __init__( + self, vendor=0x18D1, product=0x500F, interface=1, serialname=None + ): + """Susb constructor. + + Discovers and connects to USB endpoints. + + Args: + vendor: usb vendor id of device + product: usb product id of device + interface: interface number ( 1 - 8 ) of device to use + serialname: string of device serialnumber. + + Raises: + SusbError: An error accessing Susb object + """ + # Find the device. + dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) + dev_list = list(dev_g) + if dev_list is None: + raise SusbError("USB device not found") + + # Check if we have multiple devices. + dev = None + if serialname: + for d in dev_list: + dev_serial = usb.util.get_string(d, d.iSerialNumber) + if dev_serial == serialname: + dev = d + break + if dev is None: + raise SusbError("USB device(%s) not found" % (serialname,)) + else: + try: + dev = dev_list[0] + except IndexError: + raise SusbError( + "USB device %04x:%04x not found" % (vendor, product) + ) + + # If we can't set configuration, it's already been set. + try: + dev.set_configuration() + except usb.core.USBError: + pass - # Get an endpoint instance. - cfg = dev.get_active_configuration() - intf = usb.util.find_descriptor(cfg, bInterfaceNumber=interface) - self._intf = intf + # Get an endpoint instance. + cfg = dev.get_active_configuration() + intf = usb.util.find_descriptor(cfg, bInterfaceNumber=interface) + self._intf = intf - if not intf: - raise SusbError("Interface not found") + if not intf: + raise SusbError("Interface not found") - # Detach raiden.ko if it is loaded. - if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True: + # Detach raiden.ko if it is loaded. + if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True: dev.detach_kernel_driver(intf.bInterfaceNumber) - read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT - read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number) - self._read_ep = read_ep + read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT + read_ep = usb.util.find_descriptor( + intf, bEndpointAddress=read_ep_number + ) + self._read_ep = read_ep - write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT - write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number) - self._write_ep = write_ep + write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT + write_ep = usb.util.find_descriptor( + intf, bEndpointAddress=write_ep_number + ) + self._write_ep = write_ep """Suart class implements a stream interface, to access Google's USB class. @@ -147,90 +148,96 @@ class Susb(): and forwards them across. This particular class is hardcoded to stdin/out. """ -class SuartError(Exception): - """Class for exceptions of Suart.""" - def __init__(self, msg, value=0): - """SuartError constructor. - - Args: - msg: string, message describing error in detail - value: integer, value of error when non-zero status returned. Default=0 - """ - super(SuartError, self).__init__(msg, value) - self.msg = msg - self.value = value - - -class Suart(): - """Provide interface to serial usb endpoint.""" - - def __init__(self, vendor=0x18d1, product=0x501c, interface=0, - serialname=None): - """Suart contstructor. - - Initializes USB stream interface. - - Args: - vendor: usb vendor id of device - product: usb product id of device - interface: interface number of device to use - serialname: Defaults to None. - - Raises: - SuartError: If init fails - """ - self._done = threading.Event() - self._susb = Susb(vendor=vendor, product=product, - interface=interface, serialname=serialname) - - def wait_until_done(self, timeout=None): - return self._done.wait(timeout=timeout) - - def run_rx_thread(self): - try: - while True: - try: - r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS) - if r: - GetBuffer(sys.stdout).write(r.tobytes()) - GetBuffer(sys.stdout).flush() - - except Exception as e: - # If we miss some characters on pty disconnect, that's fine. - # ep.read() also throws USBError on timeout, which we discard. - if not isinstance(e, (OSError, usb.core.USBError)): - print("rx %s" % e) - finally: - self._done.set() - - def run_tx_thread(self): - try: - while True: - try: - r = GetBuffer(sys.stdin).read(1) - if not r or r == b"\x03": - break - if r: - self._susb._write_ep.write(array.array('B', r), - self._susb.TIMEOUT_MS) - except Exception as e: - print("tx %s" % e) - finally: - self._done.set() - - def run(self): - """Creates pthreads to poll USB & PTY for data. - """ - self._exit = False - - self._rx_thread = threading.Thread(target=self.run_rx_thread) - self._rx_thread.daemon = True - self._rx_thread.start() - - self._tx_thread = threading.Thread(target=self.run_tx_thread) - self._tx_thread.daemon = True - self._tx_thread.start() +class SuartError(Exception): + """Class for exceptions of Suart.""" + + def __init__(self, msg, value=0): + """SuartError constructor. + + Args: + msg: string, message describing error in detail + value: integer, value of error when non-zero status returned. Default=0 + """ + super(SuartError, self).__init__(msg, value) + self.msg = msg + self.value = value + + +class Suart: + """Provide interface to serial usb endpoint.""" + + def __init__( + self, vendor=0x18D1, product=0x501C, interface=0, serialname=None + ): + """Suart contstructor. + + Initializes USB stream interface. + + Args: + vendor: usb vendor id of device + product: usb product id of device + interface: interface number of device to use + serialname: Defaults to None. + + Raises: + SuartError: If init fails + """ + self._done = threading.Event() + self._susb = Susb( + vendor=vendor, + product=product, + interface=interface, + serialname=serialname, + ) + + def wait_until_done(self, timeout=None): + return self._done.wait(timeout=timeout) + + def run_rx_thread(self): + try: + while True: + try: + r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS) + if r: + GetBuffer(sys.stdout).write(r.tobytes()) + GetBuffer(sys.stdout).flush() + + except Exception as e: + # If we miss some characters on pty disconnect, that's fine. + # ep.read() also throws USBError on timeout, which we discard. + if not isinstance(e, (OSError, usb.core.USBError)): + print("rx %s" % e) + finally: + self._done.set() + + def run_tx_thread(self): + try: + while True: + try: + r = GetBuffer(sys.stdin).read(1) + if not r or r == b"\x03": + break + if r: + self._susb._write_ep.write( + array.array("B", r), self._susb.TIMEOUT_MS + ) + except Exception as e: + print("tx %s" % e) + finally: + self._done.set() + + def run(self): + """Creates pthreads to poll USB & PTY for data.""" + self._exit = False + + self._rx_thread = threading.Thread(target=self.run_rx_thread) + self._rx_thread.daemon = True + self._rx_thread.start() + + self._tx_thread = threading.Thread(target=self.run_tx_thread) + self._tx_thread.daemon = True + self._tx_thread.start() """Command line functionality @@ -239,60 +246,76 @@ class Suart(): Ctrl-C exits. """ -parser = argparse.ArgumentParser(description="Open a console to a USB device") -parser.add_argument('-d', '--device', type=str, - help="vid:pid of target device", default="18d1:501c") -parser.add_argument('-i', '--interface', type=int, - help="interface number of console", default=0) -parser.add_argument('-s', '--serialno', type=str, - help="serial number of device", default="") -parser.add_argument('-S', '--notty-exit-sleep', type=float, default=0.2, - help="When stdin is *not* a TTY, wait this many seconds after EOF from " - "stdin before exiting, to give time for receiving a reply from the USB " - "device.") +parser = argparse.ArgumentParser( + description="Open a console to a USB device", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, +) +parser.add_argument( + "-d", + "--device", + type=str, + help="vid:pid of target device", + default="18d1:501c", +) +parser.add_argument( + "-i", "--interface", type=int, help="interface number of console", default=0 +) +parser.add_argument( + "-s", "--serialno", type=str, help="serial number of device", default="" +) +parser.add_argument( + "-S", + "--notty-exit-sleep", + type=float, + default=0.2, + help="When stdin is *not* a TTY, wait this many seconds " + "after EOF from stdin before exiting, to give time for " + "receiving a reply from the USB device.", +) def runconsole(): - """Run the usb console code + """Run the usb console code - Starts the pty thread, and idles until a ^C is caught. - """ - args = parser.parse_args() + Starts the pty thread, and idles until a ^C is caught. + """ + args = parser.parse_args() - vidstr, pidstr = args.device.split(':') - vid = int(vidstr, 16) - pid = int(pidstr, 16) + vidstr, pidstr = args.device.split(":") + vid = int(vidstr, 16) + pid = int(pidstr, 16) - serialno = args.serialno - interface = args.interface + serialno = args.serialno + interface = args.interface - sobj = Suart(vendor=vid, product=pid, interface=interface, - serialname=serialno) - if sys.stdin.isatty(): - tty.setraw(sys.stdin.fileno()) - sobj.run() - sobj.wait_until_done() - if not sys.stdin.isatty() and args.notty_exit_sleep > 0: - time.sleep(args.notty_exit_sleep) + sobj = Suart( + vendor=vid, product=pid, interface=interface, serialname=serialno + ) + if sys.stdin.isatty(): + tty.setraw(sys.stdin.fileno()) + sobj.run() + sobj.wait_until_done() + if not sys.stdin.isatty() and args.notty_exit_sleep > 0: + time.sleep(args.notty_exit_sleep) def main(): - stdin_isatty = sys.stdin.isatty() - if stdin_isatty: - fd = sys.stdin.fileno() - os.system("stty -echo") - old_settings = termios.tcgetattr(fd) - - try: - runconsole() - finally: + stdin_isatty = sys.stdin.isatty() if stdin_isatty: - termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) - os.system("stty echo") - # Avoid having the user's shell prompt start mid-line after the final output - # from this program. - print() + fd = sys.stdin.fileno() + os.system("stty -echo") + old_settings = termios.tcgetattr(fd) + + try: + runconsole() + finally: + if stdin_isatty: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + os.system("stty echo") + # Avoid having the user's shell prompt start mid-line after the final output + # from this program. + print() -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/extra/usb_serial/install b/extra/usb_serial/install index eba1d2ac83..b49ad990e1 100755 --- a/extra/usb_serial/install +++ b/extra/usb_serial/install @@ -1,6 +1,6 @@ #!/bin/sh -e # -# Copyright 2016 The Chromium OS Authors. All rights reserved. +# Copyright 2016 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. # diff --git a/extra/usb_serial/raiden.c b/extra/usb_serial/raiden.c index e4720b4357..131cddb00f 100644 --- a/extra/usb_serial/raiden.c +++ b/extra/usb_serial/raiden.c @@ -2,7 +2,7 @@ * USB Serial module for Raiden USB debug serial console forwarding. * SubClass and Protocol allocated in go/usb-ids * - * Copyright 2014 The Chromium OS Authors <chromium-os-dev@chromium.org> + * Copyright 2014 The ChromiumOS Authors <chromium-os-dev@chromium.org> * Author: Anton Staaf <robotboy@chromium.org> * * This program is free software; you can redistribute it and/or modify @@ -19,28 +19,25 @@ MODULE_LICENSE("GPL"); -#define USB_VENDOR_ID_GOOGLE 0x18d1 -#define USB_SUBCLASS_GOOGLE_SERIAL 0x50 -#define USB_PROTOCOL_GOOGLE_SERIAL 0x01 +#define USB_VENDOR_ID_GOOGLE 0x18d1 +#define USB_SUBCLASS_GOOGLE_SERIAL 0x50 +#define USB_PROTOCOL_GOOGLE_SERIAL 0x01 static struct usb_device_id const ids[] = { - { USB_VENDOR_AND_INTERFACE_INFO(USB_VENDOR_ID_GOOGLE, - USB_CLASS_VENDOR_SPEC, - USB_SUBCLASS_GOOGLE_SERIAL, - USB_PROTOCOL_GOOGLE_SERIAL) }, - { 0 } + { USB_VENDOR_AND_INTERFACE_INFO( + USB_VENDOR_ID_GOOGLE, USB_CLASS_VENDOR_SPEC, + USB_SUBCLASS_GOOGLE_SERIAL, USB_PROTOCOL_GOOGLE_SERIAL) }, + { 0 } }; MODULE_DEVICE_TABLE(usb, ids); -static struct usb_serial_driver device = -{ - .driver = { .owner = THIS_MODULE, - .name = "Google" }, - .id_table = ids, +static struct usb_serial_driver device = { + .driver = { .owner = THIS_MODULE, .name = "Google" }, + .id_table = ids, .num_ports = 1, }; -static struct usb_serial_driver * const drivers[] = { &device, NULL }; +static struct usb_serial_driver *const drivers[] = { &device, NULL }; module_usb_serial_driver(drivers, ids); diff --git a/extra/usb_updater/Makefile b/extra/usb_updater/Makefile index 1dfbc55645..5a8dc82c28 100644 --- a/extra/usb_updater/Makefile +++ b/extra/usb_updater/Makefile @@ -1,4 +1,4 @@ -# Copyright 2015 The Chromium OS Authors. All rights reserved. +# Copyright 2015 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. @@ -19,10 +19,10 @@ CFLAGS := -std=gnu99 \ -Wredundant-decls \ -Wmissing-declarations -ifeq (DEBUG,) -CFLAGS += -O3 -else +ifneq ($(DEBUG),) CFLAGS += -O0 +else +CFLAGS += -O3 endif # @@ -52,4 +52,3 @@ clean: parser_debug: desc_parser.c gcc -g -O0 -DTEST_PARSER desc_parser.c -o dp - diff --git a/extra/usb_updater/desc_parser.c b/extra/usb_updater/desc_parser.c index 5bd996bdda..7e9f583902 100644 --- a/extra/usb_updater/desc_parser.c +++ b/extra/usb_updater/desc_parser.c @@ -1,5 +1,5 @@ /* - * Copyright 2018 The Chromium OS Authors. All rights reserved. + * Copyright 2018 The ChromiumOS Authors * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ @@ -75,8 +75,7 @@ static int get_next_token(char *input, size_t expected_size, char **output) next_colon = strchr(input, ':'); if (next_colon) *next_colon = '\0'; - if (!next_colon || (expected_size && - strlen(input) != expected_size)) { + if (!next_colon || (expected_size && strlen(input) != expected_size)) { fprintf(stderr, "Invalid entry in section %d\n", section_count_); return -EINVAL; @@ -98,16 +97,15 @@ static int get_hex_value(char *input, char **output) value = strtol(input, &e, 16); if ((e && *e) || (strlen(input) > 8)) { - fprintf(stderr, "Invalid hex value %s in section %d\n", - input, section_count_); + fprintf(stderr, "Invalid hex value %s in section %d\n", input, + section_count_); return -EINVAL; } return value; } -static int parse_range(char *next_line, - size_t line_len, +static int parse_range(char *next_line, size_t line_len, struct addr_range *parsed_range) { char *line_cursor; @@ -299,7 +297,6 @@ int parser_get_next_range(struct addr_range **range) *range = new_range; return 0; - } int parser_find_board(const char *hash_file_name, const char *board_id) diff --git a/extra/usb_updater/desc_parser.h b/extra/usb_updater/desc_parser.h index faa80d1a63..e459927b57 100644 --- a/extra/usb_updater/desc_parser.h +++ b/extra/usb_updater/desc_parser.h @@ -1,5 +1,5 @@ /* - * Copyright 2018 The Chromium OS Authors. All rights reserved. + * Copyright 2018 The ChromiumOS Authors * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ diff --git a/extra/usb_updater/fw_update.py b/extra/usb_updater/fw_update.py index 0d7a570fc3..a77de94a7c 100755 --- a/extra/usb_updater/fw_update.py +++ b/extra/usb_updater/fw_update.py @@ -1,11 +1,7 @@ #!/usr/bin/env python -# Copyright 2016 The Chromium OS Authors. All rights reserved. +# Copyright 2016 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes # Upload firmware over USB # Note: This is a py2/3 compatible file. @@ -20,407 +16,436 @@ import struct import sys import time from pprint import pprint -import usb +import usb # pylint:disable=import-error +from ecusb.stm32usb import SusbError debug = False -def debuglog(msg): - if debug: - print(msg) - -def log(msg): - print(msg) - sys.stdout.flush() - - -"""Sends firmware update to CROS EC usb endpoint.""" - -class Supdate(object): - """Class to access firmware update endpoints. - - Usage: - d = Supdate() - - Instance Variables: - _dev: pyUSB device object - _read_ep: pyUSB read endpoint for this interface - _write_ep: pyUSB write endpoint for this interface - """ - USB_SUBCLASS_GOOGLE_UPDATE = 0x53 - USB_CLASS_VENDOR = 0xFF - def __init__(self): - pass +def debuglog(msg): + if debug: + print(msg) - def connect_usb(self, serialname=None): - """Initial discovery and connection to USB endpoint. - - This searches for a USB device matching the VID:PID specified - in the config file, optionally matching a specified serialname. - - Args: - serialname: Find the device with this serial, in case multiple - devices are attached. - - Returns: - True on success. - Raises: - Exception on error. - """ - # Find the stm32. - vendor = self._brdcfg['vid'] - product = self._brdcfg['pid'] - - dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) - dev_list = list(dev_g) - if dev_list is None: - raise Exception("Update", "USB device not found") - - # Check if we have multiple stm32s and we've specified the serial. - dev = None - if serialname: - for d in dev_list: - if usb.util.get_string(d, d.iSerialNumber) == serialname: - dev = d - break - if dev is None: - raise SusbError("USB device(%s) not found" % serialname) - else: - try: - dev = dev_list[0] - except: - dev = dev_list.next() - - debuglog("Found stm32: %04x:%04x" % (vendor, product)) - self._dev = dev - - # Get an endpoint instance. - try: - dev.set_configuration() - except: - pass - cfg = dev.get_active_configuration() - - intf = usb.util.find_descriptor(cfg, custom_match=lambda i: \ - i.bInterfaceClass==self.USB_CLASS_VENDOR and \ - i.bInterfaceSubClass==self.USB_SUBCLASS_GOOGLE_UPDATE) - - self._intf = intf - debuglog("Interface: %s" % intf) - debuglog("InterfaceNumber: %s" % intf.bInterfaceNumber) - - read_ep = usb.util.find_descriptor( - intf, - # match the first IN endpoint - custom_match = \ - lambda e: \ - usb.util.endpoint_direction(e.bEndpointAddress) == \ - usb.util.ENDPOINT_IN - ) - - self._read_ep = read_ep - debuglog("Reader endpoint: 0x%x" % read_ep.bEndpointAddress) - - write_ep = usb.util.find_descriptor( - intf, - # match the first OUT endpoint - custom_match = \ - lambda e: \ - usb.util.endpoint_direction(e.bEndpointAddress) == \ - usb.util.ENDPOINT_OUT - ) - - self._write_ep = write_ep - debuglog("Writer endpoint: 0x%x" % write_ep.bEndpointAddress) - - return True - - - def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=2000): - """Write command to logger logic.. - - This function writes byte command values list to stm, then reads - byte status. - - Args: - write_list: list of command byte values [0~255]. - read_count: number of status byte values to read. - wtimeout: mS to wait for write success - rtimeout: mS to wait for read success - - Returns: - status byte, if one byte is read, - byte list, if multiple bytes are read, - None, if no bytes are read. - - Interface: - write: [command, data ... ] - read: [status ] - """ - debuglog("wr_command(write_list=[%s] (%d), read_count=%s)" % ( - list(bytearray(write_list)), len(write_list), read_count)) - - # Clean up args from python style to correct types. - write_length = 0 - if write_list: - write_length = len(write_list) - if not read_count: - read_count = 0 - - # Send command to stm32. - if write_list: - cmd = write_list - ret = self._write_ep.write(cmd, wtimeout) - debuglog("RET: %s " % ret) - - # Read back response if necessary. - if read_count: - bytesread = self._read_ep.read(512, rtimeout) - debuglog("BYTES: [%s]" % bytesread) - - if len(bytesread) != read_count: - debuglog("Unexpected bytes read: %d, expected: %d" % (len(bytesread), read_count)) - pass - - debuglog("STATUS: 0x%02x" % int(bytesread[0])) - if read_count == 1: - return bytesread[0] - else: - return bytesread - - return None - - def stop(self): - """Finalize system flash and exit.""" - cmd = struct.pack(">I", 0xB007AB1E) - read = self.wr_command(cmd, read_count=4) - if len(read) == 4: - log("Finished flashing") - return +def log(msg): + print(msg) + sys.stdout.flush() - raise Exception("Update", "Stop failed [%s]" % read) +"""Sends firmware update to CROS EC usb endpoint.""" - def write_file(self): - """Write the update region packet by packet to USB - This sends write packets of size 128B out, in 32B chunks. - Overall, this will write all data in the inactive code region. +class Supdate(object): + """Class to access firmware update endpoints. - Raises: - Exception if write failed or address out of bounds. - """ - region = self._region - flash_base = self._brdcfg["flash"] - offset = self._base - flash_base - if offset != self._brdcfg['regions'][region][0]: - raise Exception("Update", "Region %s offset 0x%x != available offset 0x%x" % ( - region, self._brdcfg['regions'][region][0], offset)) - - length = self._brdcfg['regions'][region][1] - log("Sending") - - # Go to the correct region in the ec.bin file. - self._binfile.seek(offset) - - # Send 32 bytes at a time. Must be less than the endpoint's max packet size. - maxpacket = 32 - - # While data is left, create update packets. - while length > 0: - # Update packets are 128B. We can use any number - # but the micro must malloc this memory. - pagesize = min(length, 128) - - # Packet is: - # packet size: page bytes transferred plus 3 x 32b values header. - # cmd: n/a - # base: flash address to write this packet. - # data: 128B of data to write into flash_base - cmd = struct.pack(">III", pagesize + 12, 0, offset + flash_base) - read = self.wr_command(cmd, read_count=0) - - # Push 'todo' bytes out the pipe. - todo = pagesize - while todo > 0: - packetsize = min(maxpacket, todo) - data = self._binfile.read(packetsize) - if len(data) != packetsize: - raise Exception("Update", "No more data from file") - for i in range(0, 10): - try: - self.wr_command(data, read_count=0) - break - except: - log("Timeout fail") - todo -= packetsize - # Done with this packet, move to the next one. - length -= pagesize - offset += pagesize - - # Validate that the micro thinks it successfully wrote the data. - read = self.wr_command(''.encode(), read_count=4) - result = struct.unpack("<I", read) - result = result[0] - if result != 0: - raise Exception("Update", "Upload failed with rc: 0x%x" % result) - - - def start(self): - """Start a transaction and erase currently inactive region. - - This function sends a start command, and receives the base of the - preferred inactive region. This could be RW, RW_B, - or RO (if there's no RW_B) - - Note that the region is erased here, so you'd better program the RO if - you just erased it. TODO(nsanders): Modify the protocol to allow active - region select or query before erase. - """ + Usage: + d = Supdate() - # Size is 3 uint32 fields - # packet: [packetsize, cmd, base] - size = 4 + 4 + 4 - # Return value is [status, base_addr] - expected = 4 + 4 - - cmd = struct.pack("<III", size, 0, 0) - read = self.wr_command(cmd, read_count=expected) - - if len(read) == 4: - raise Exception("Update", "Protocol version 0 not supported") - elif len(read) == expected: - base, version = struct.unpack(">II", read) - log("Update protocol v. %d" % version) - log("Available flash region base: %x" % base) - else: - raise Exception("Update", "Start command returned %d bytes" % len(read)) - - if base < 256: - raise Exception("Update", "Start returned error code 0x%x" % base) - - self._base = base - flash_base = self._brdcfg["flash"] - self._offset = self._base - flash_base - - # Find our active region. - for region in self._brdcfg['regions']: - if (self._offset >= self._brdcfg['regions'][region][0]) and \ - (self._offset < (self._brdcfg['regions'][region][0] + \ - self._brdcfg['regions'][region][1])): - log("Active region: %s" % region) - self._region = region - - - def load_board(self, brdfile): - """Load firmware layout file. - - example as follows: - { - "board": "servo micro", - "vid": 6353, - "pid": 20506, - "flash": 134217728, - "regions": { - "RW": [65536, 65536], - "PSTATE": [61440, 4096], - "RO": [0, 61440] - } - } - - Args: - brdfile: path to board description file. + Instance Variables: + _dev: pyUSB device object + _read_ep: pyUSB read endpoint for this interface + _write_ep: pyUSB write endpoint for this interface """ - with open(brdfile) as data_file: - data = json.load(data_file) - - # TODO(nsanders): validate this data before moving on. - self._brdcfg = data; - if debug: - pprint(data) - - log("Board is %s" % self._brdcfg['board']) - # Cast hex strings to int. - self._brdcfg['flash'] = int(self._brdcfg['flash'], 0) - self._brdcfg['vid'] = int(self._brdcfg['vid'], 0) - self._brdcfg['pid'] = int(self._brdcfg['pid'], 0) - - log("Flash Base is %x" % self._brdcfg['flash']) - self._flashsize = 0 - for region in self._brdcfg['regions']: - base = int(self._brdcfg['regions'][region][0], 0) - length = int(self._brdcfg['regions'][region][1], 0) - log("region %s\tbase:0x%08x size:0x%08x" % ( - region, base, length)) - self._flashsize += length - # Convert these to int because json doesn't support hex. - self._brdcfg['regions'][region][0] = base - self._brdcfg['regions'][region][1] = length + USB_SUBCLASS_GOOGLE_UPDATE = 0x53 + USB_CLASS_VENDOR = 0xFF - log("Flash Size: 0x%x" % self._flashsize) - - def load_file(self, binfile): - """Open and verify size of the target ec.bin file. - - Args: - binfile: path to ec.bin - - Raises: - Exception on file not found or filesize not matching. - """ - self._filesize = os.path.getsize(binfile) - self._binfile = open(binfile, 'rb') - - if self._filesize != self._flashsize: - raise Exception("Update", "Flash size 0x%x != file size 0x%x" % (self._flashsize, self._filesize)) + def __init__(self): + pass + def connect_usb(self, serialname=None): + """Initial discovery and connection to USB endpoint. + + This searches for a USB device matching the VID:PID specified + in the config file, optionally matching a specified serialname. + + Args: + serialname: Find the device with this serial, in case multiple + devices are attached. + + Returns: + True on success. + Raises: + Exception on error. + """ + # Find the stm32. + vendor = self._brdcfg["vid"] + product = self._brdcfg["pid"] + + dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) + dev_list = list(dev_g) + if dev_list is None: + raise Exception("Update", "USB device not found") + + # Check if we have multiple stm32s and we've specified the serial. + dev = None + if serialname: + for d in dev_list: + if usb.util.get_string(d, d.iSerialNumber) == serialname: + dev = d + break + if dev is None: + raise SusbError("USB device(%s) not found" % serialname) + else: + dev = dev_list[0] + + debuglog("Found stm32: %04x:%04x" % (vendor, product)) + self._dev = dev + + # Get an endpoint instance. + try: + dev.set_configuration() + except: + pass + cfg = dev.get_active_configuration() + + intf = usb.util.find_descriptor( + cfg, + custom_match=lambda i: i.bInterfaceClass == self.USB_CLASS_VENDOR + and i.bInterfaceSubClass == self.USB_SUBCLASS_GOOGLE_UPDATE, + ) + + self._intf = intf + debuglog("Interface: %s" % intf) + debuglog("InterfaceNumber: %s" % intf.bInterfaceNumber) + + read_ep = usb.util.find_descriptor( + intf, + # match the first IN endpoint + custom_match=lambda e: usb.util.endpoint_direction( + e.bEndpointAddress + ) + == usb.util.ENDPOINT_IN, + ) + + self._read_ep = read_ep + debuglog("Reader endpoint: 0x%x" % read_ep.bEndpointAddress) + + write_ep = usb.util.find_descriptor( + intf, + # match the first OUT endpoint + custom_match=lambda e: usb.util.endpoint_direction( + e.bEndpointAddress + ) + == usb.util.ENDPOINT_OUT, + ) + + self._write_ep = write_ep + debuglog("Writer endpoint: 0x%x" % write_ep.bEndpointAddress) + + return True + + def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=2000): + """Write command to logger logic.. + + This function writes byte command values list to stm, then reads + byte status. + + Args: + write_list: list of command byte values [0~255]. + read_count: number of status byte values to read. + wtimeout: mS to wait for write success + rtimeout: mS to wait for read success + + Returns: + status byte, if one byte is read, + byte list, if multiple bytes are read, + None, if no bytes are read. + + Interface: + write: [command, data ... ] + read: [status ] + """ + debuglog( + "wr_command(write_list=[%s] (%d), read_count=%s)" + % (list(bytearray(write_list)), len(write_list), read_count) + ) + + # Clean up args from python style to correct types. + write_length = 0 + if write_list: + write_length = len(write_list) + if not read_count: + read_count = 0 + + # Send command to stm32. + if write_list: + cmd = write_list + ret = self._write_ep.write(cmd, wtimeout) + debuglog("RET: %s " % ret) + + # Read back response if necessary. + if read_count: + bytesread = self._read_ep.read(512, rtimeout) + debuglog("BYTES: [%s]" % bytesread) + + if len(bytesread) != read_count: + debuglog( + "Unexpected bytes read: %d, expected: %d" + % (len(bytesread), read_count) + ) + pass + + debuglog("STATUS: 0x%02x" % int(bytesread[0])) + if read_count == 1: + return bytesread[0] + else: + return bytesread + + return None + + def stop(self): + """Finalize system flash and exit.""" + cmd = struct.pack(">I", 0xB007AB1E) + read = self.wr_command(cmd, read_count=4) + + if len(read) == 4: + log("Finished flashing") + return + + raise Exception("Update", "Stop failed [%s]" % read) + + def write_file(self): + """Write the update region packet by packet to USB + + This sends write packets of size 128B out, in 32B chunks. + Overall, this will write all data in the inactive code region. + + Raises: + Exception if write failed or address out of bounds. + """ + region = self._region + flash_base = self._brdcfg["flash"] + offset = self._base - flash_base + if offset != self._brdcfg["regions"][region][0]: + raise Exception( + "Update", + "Region %s offset 0x%x != available offset 0x%x" + % (region, self._brdcfg["regions"][region][0], offset), + ) + + length = self._brdcfg["regions"][region][1] + log("Sending") + + # Go to the correct region in the ec.bin file. + self._binfile.seek(offset) + + # Send 32 bytes at a time. Must be less than the endpoint's max packet size. + maxpacket = 32 + + # While data is left, create update packets. + while length > 0: + # Update packets are 128B. We can use any number + # but the micro must malloc this memory. + pagesize = min(length, 128) + + # Packet is: + # packet size: page bytes transferred plus 3 x 32b values header. + # cmd: n/a + # base: flash address to write this packet. + # data: 128B of data to write into flash_base + cmd = struct.pack(">III", pagesize + 12, 0, offset + flash_base) + read = self.wr_command(cmd, read_count=0) + + # Push 'todo' bytes out the pipe. + todo = pagesize + while todo > 0: + packetsize = min(maxpacket, todo) + data = self._binfile.read(packetsize) + if len(data) != packetsize: + raise Exception("Update", "No more data from file") + for i in range(0, 10): + try: + self.wr_command(data, read_count=0) + break + except: + log("Timeout fail") + todo -= packetsize + # Done with this packet, move to the next one. + length -= pagesize + offset += pagesize + + # Validate that the micro thinks it successfully wrote the data. + read = self.wr_command("".encode(), read_count=4) + result = struct.unpack("<I", read) + result = result[0] + if result != 0: + raise Exception( + "Update", "Upload failed with rc: 0x%x" % result + ) + + def start(self): + """Start a transaction and erase currently inactive region. + + This function sends a start command, and receives the base of the + preferred inactive region. This could be RW, RW_B, + or RO (if there's no RW_B) + + Note that the region is erased here, so you'd better program the RO if + you just erased it. TODO(nsanders): Modify the protocol to allow active + region select or query before erase. + """ + + # Size is 3 uint32 fields + # packet: [packetsize, cmd, base] + size = 4 + 4 + 4 + # Return value is [status, base_addr] + expected = 4 + 4 + + cmd = struct.pack("<III", size, 0, 0) + read = self.wr_command(cmd, read_count=expected) + + if len(read) == 4: + raise Exception("Update", "Protocol version 0 not supported") + elif len(read) == expected: + base, version = struct.unpack(">II", read) + log("Update protocol v. %d" % version) + log("Available flash region base: %x" % base) + else: + raise Exception( + "Update", "Start command returned %d bytes" % len(read) + ) + + if base < 256: + raise Exception("Update", "Start returned error code 0x%x" % base) + + self._base = base + flash_base = self._brdcfg["flash"] + self._offset = self._base - flash_base + + # Find our active region. + for region in self._brdcfg["regions"]: + if (self._offset >= self._brdcfg["regions"][region][0]) and ( + self._offset + < ( + self._brdcfg["regions"][region][0] + + self._brdcfg["regions"][region][1] + ) + ): + log("Active region: %s" % region) + self._region = region + + def load_board(self, brdfile): + """Load firmware layout file. + + example as follows: + { + "board": "servo micro", + "vid": 6353, + "pid": 20506, + "flash": 134217728, + "regions": { + "RW": [65536, 65536], + "PSTATE": [61440, 4096], + "RO": [0, 61440] + } + } + + Args: + brdfile: path to board description file. + """ + with open(brdfile) as data_file: + data = json.load(data_file) + + # TODO(nsanders): validate this data before moving on. + self._brdcfg = data + if debug: + pprint(data) + + log("Board is %s" % self._brdcfg["board"]) + # Cast hex strings to int. + self._brdcfg["flash"] = int(self._brdcfg["flash"], 0) + self._brdcfg["vid"] = int(self._brdcfg["vid"], 0) + self._brdcfg["pid"] = int(self._brdcfg["pid"], 0) + + log("Flash Base is %x" % self._brdcfg["flash"]) + self._flashsize = 0 + for region in self._brdcfg["regions"]: + base = int(self._brdcfg["regions"][region][0], 0) + length = int(self._brdcfg["regions"][region][1], 0) + log("region %s\tbase:0x%08x size:0x%08x" % (region, base, length)) + self._flashsize += length + + # Convert these to int because json doesn't support hex. + self._brdcfg["regions"][region][0] = base + self._brdcfg["regions"][region][1] = length + + log("Flash Size: 0x%x" % self._flashsize) + + def load_file(self, binfile): + """Open and verify size of the target ec.bin file. + + Args: + binfile: path to ec.bin + + Raises: + Exception on file not found or filesize not matching. + """ + self._filesize = os.path.getsize(binfile) + self._binfile = open(binfile, "rb") + + if self._filesize != self._flashsize: + raise Exception( + "Update", + "Flash size 0x%x != file size 0x%x" + % (self._flashsize, self._filesize), + ) # Generate command line arguments parser = argparse.ArgumentParser(description="Update firmware over usb") -parser.add_argument('-b', '--board', type=str, help="Board configuration json file", default="board.json") -parser.add_argument('-f', '--file', type=str, help="Complete ec.bin file", default="ec.bin") -parser.add_argument('-s', '--serial', type=str, help="Serial number", default="") -parser.add_argument('-l', '--list', action="store_true", help="List regions") -parser.add_argument('-v', '--verbose', action="store_true", help="Chatty output") +parser.add_argument( + "-b", + "--board", + type=str, + help="Board configuration json file", + default="board.json", +) +parser.add_argument( + "-f", "--file", type=str, help="Complete ec.bin file", default="ec.bin" +) +parser.add_argument( + "-s", "--serial", type=str, help="Serial number", default="" +) +parser.add_argument("-l", "--list", action="store_true", help="List regions") +parser.add_argument( + "-v", "--verbose", action="store_true", help="Chatty output" +) + def main(): - global debug - args = parser.parse_args() + global debug + args = parser.parse_args() + brdfile = args.board + serial = args.serial + binfile = args.file + if args.verbose: + debug = True - brdfile = args.board - serial = args.serial - binfile = args.file - if args.verbose: - debug = True + with open(brdfile) as data_file: + names = json.load(data_file) - with open(brdfile) as data_file: - names = json.load(data_file) + p = Supdate() + p.load_board(brdfile) + p.connect_usb(serialname=serial) + p.load_file(binfile) - p = Supdate() - p.load_board(brdfile) - p.connect_usb(serialname=serial) - p.load_file(binfile) + # List solely prints the config. + if args.list: + return - # List solely prints the config. - if (args.list): - return + # Start transfer and erase. + p.start() + # Upload the bin file + log("Uploading %s" % binfile) + p.write_file() - # Start transfer and erase. - p.start() - # Upload the bin file - log("Uploading %s" % binfile) - p.write_file() + # Finalize + log("Done. Finalizing.") + p.stop() - # Finalize - log("Done. Finalizing.") - p.stop() if __name__ == "__main__": - main() - - + main() diff --git a/extra/usb_updater/sample_descriptor b/extra/usb_updater/sample_descriptor index 1566e9e2e1..3be408b642 100644 --- a/extra/usb_updater/sample_descriptor +++ b/extra/usb_updater/sample_descriptor @@ -1,4 +1,4 @@ -# Copyright 2018 The Chromium OS Authors. All rights reserved. +# Copyright 2018 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. # diff --git a/extra/usb_updater/servo_updater.py b/extra/usb_updater/servo_updater.py index fa0d21670c..c0be11fdde 100755 --- a/extra/usb_updater/servo_updater.py +++ b/extra/usb_updater/servo_updater.py @@ -1,53 +1,55 @@ #!/usr/bin/env python -# Copyright 2016 The Chromium OS Authors. All rights reserved. +# Copyright 2016 The ChromiumOS Authors # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Ignore indention messages, since legacy scripts use 2 spaces instead of 4. -# pylint: disable=bad-indentation,docstring-section-indent -# pylint: disable=docstring-trailing-quotes # Note: This is a py2/3 compatible file. +"""USB updater tool for servo and similar boards.""" + from __future__ import print_function import argparse -import errno +import json import os import re import subprocess import time -import tempfile - -import json -import fw_update import ecusb.tiny_servo_common as c +import fw_update from ecusb import tiny_servod + class ServoUpdaterException(Exception): - """Raised on exceptions generated by servo_updater.""" + """Raised on exceptions generated by servo_updater.""" -BOARD_C2D2 = 'c2d2' -BOARD_SERVO_MICRO = 'servo_micro' -BOARD_SERVO_V4 = 'servo_v4' -BOARD_SERVO_V4P1 = 'servo_v4p1' -BOARD_SWEETBERRY = 'sweetberry' + +BOARD_C2D2 = "c2d2" +BOARD_SERVO_MICRO = "servo_micro" +BOARD_SERVO_V4 = "servo_v4" +BOARD_SERVO_V4P1 = "servo_v4p1" +BOARD_SWEETBERRY = "sweetberry" DEFAULT_BOARD = BOARD_SERVO_V4 # These lists are to facilitate exposing choices in the command-line tool # below. -BOARDS = [BOARD_C2D2, BOARD_SERVO_MICRO, BOARD_SERVO_V4, BOARD_SERVO_V4P1, - BOARD_SWEETBERRY] +BOARDS = [ + BOARD_C2D2, + BOARD_SERVO_MICRO, + BOARD_SERVO_V4, + BOARD_SERVO_V4P1, + BOARD_SWEETBERRY, +] # Servo firmware bundles four channels of firmware. We need to make sure the # user does not request a non-existing channel, so keep the lists around to # guard on command-line usage. -DEFAULT_CHANNEL = STABLE_CHANNEL = 'stable' +DEFAULT_CHANNEL = STABLE_CHANNEL = "stable" -PREV_CHANNEL = 'prev' +PREV_CHANNEL = "prev" # The ordering here matters. From left to right it's the channel that the user # is most likely to be running. This is used to inform and warn the user if @@ -55,12 +57,12 @@ PREV_CHANNEL = 'prev' # user know they are running the 'stable' version before letting them know they # are running 'dev' or even 'alpah' which (while true) might cause confusion. -CHANNELS = [DEFAULT_CHANNEL, PREV_CHANNEL, 'dev', 'alpha'] +CHANNELS = [DEFAULT_CHANNEL, PREV_CHANNEL, "dev", "alpha"] -DEFAULT_BASE_PATH = '/usr/' -TEST_IMAGE_BASE_PATH = '/usr/local/' +DEFAULT_BASE_PATH = "/usr/" +TEST_IMAGE_BASE_PATH = "/usr/local/" -COMMON_PATH = 'share/servo_updater' +COMMON_PATH = "share/servo_updater" FIRMWARE_DIR = "firmware/" CONFIGS_DIR = "configs/" @@ -68,389 +70,444 @@ CONFIGS_DIR = "configs/" RETRIES_COUNT = 10 RETRIES_DELAY = 1 + def do_with_retries(func, *args): - """ - Call function passed as argument and check if no error happened. - If exception was raised by function, - it will be retried up to RETRIES_COUNT times. - - Args: - func: function that will be called - args: arguments passed to 'func' - - Returns: - If call to function was successful, its result will be returned. - If retries count was exceeded, exception will be raised. - """ - - retry = 0 - while retry < RETRIES_COUNT: - try: - return func(*args) - except Exception as e: - print("Retrying function %s: %s" % (func.__name__, e)) - retry = retry + 1 - time.sleep(RETRIES_DELAY) - continue - - raise Exception("'{}' failed after {} retries".format(func.__name__, RETRIES_COUNT)) + """Try a function several times + + Call function passed as argument and check if no error happened. + If exception was raised by function, + it will be retried up to RETRIES_COUNT times. + + Args: + func: function that will be called + args: arguments passed to 'func' + + Returns: + If call to function was successful, its result will be returned. + If retries count was exceeded, exception will be raised. + """ + + retry = 0 + while retry < RETRIES_COUNT: + try: + return func(*args) + except Exception as e: + print("Retrying function %s: %s" % (func.__name__, e)) + retry = retry + 1 + time.sleep(RETRIES_DELAY) + continue + + raise Exception( + "'{}' failed after {} retries".format(func.__name__, RETRIES_COUNT) + ) + def flash(brdfile, serialno, binfile): - """ - Call fw_update to upload to updater USB endpoint. - - Args: - brdfile: path to board configuration file - serialno: device serial number - binfile: firmware file - """ - - p = fw_update.Supdate() - p.load_board(brdfile) - p.connect_usb(serialname=serialno) - p.load_file(binfile) - - # Start transfer and erase. - p.start() - # Upload the bin file - print("Uploading %s" % binfile) - p.write_file() - - # Finalize - print("Done. Finalizing.") - p.stop() + """Call fw_update to upload to updater USB endpoint. + + Args: + brdfile: path to board configuration file + serialno: device serial number + binfile: firmware file + """ + + p = fw_update.Supdate() + p.load_board(brdfile) + p.connect_usb(serialname=serialno) + p.load_file(binfile) + + # Start transfer and erase. + p.start() + # Upload the bin file + print("Uploading %s" % binfile) + p.write_file() + + # Finalize + print("Done. Finalizing.") + p.stop() + def flash2(vidpid, serialno, binfile): - """ - Call fw update via usb_updater2 commandline. - - Args: - vidpid: vendor id and product id of device - serialno: device serial number (optional) - binfile: firmware file - """ - - tool = 'usb_updater2' - cmd = "%s -d %s" % (tool, vidpid) - if serialno: - cmd += " -S %s" % serialno - cmd += " -n" - cmd += " %s" % binfile - - print(cmd) - help_cmd = '%s --help' % tool - with open('/dev/null') as devnull: - valid_check = subprocess.call(help_cmd.split(), stdout=devnull, - stderr=devnull) - if valid_check: - raise ServoUpdaterException('%s exit with res = %d. Make sure the tool ' - 'is available on the device.' % (help_cmd, - valid_check)) - res = subprocess.call(cmd.split()) - - if res in (0, 1, 2): - return res - else: - raise ServoUpdaterException("%s exit with res = %d" % (cmd, res)) + """Call fw update via usb_updater2 commandline. + + Args: + vidpid: vendor id and product id of device + serialno: device serial number (optional) + binfile: firmware file + """ + + tool = "usb_updater2" + cmd = "%s -d %s" % (tool, vidpid) + if serialno: + cmd += " -S %s" % serialno + cmd += " -n" + cmd += " %s" % binfile + + print(cmd) + help_cmd = "%s --help" % tool + with open("/dev/null") as devnull: + valid_check = subprocess.call( + help_cmd.split(), stdout=devnull, stderr=devnull + ) + if valid_check: + raise ServoUpdaterException( + "%s exit with res = %d. Make sure the tool " + "is available on the device." % (help_cmd, valid_check) + ) + res = subprocess.call(cmd.split()) + + if res in (0, 1, 2): + return res + else: + raise ServoUpdaterException("%s exit with res = %d" % (cmd, res)) + def select(tinys, region): - """ - Ensure the servo is in the expected ro/rw region. - This function jumps to the required region and verify if jump was - successful by executing 'sysinfo' command and reading current region. - If response was not received or region is invalid, exception is raised. + """Jump to specified boot region - Args: - tinys: TinyServod object - region: region to jump to, only "rw" and "ro" is allowed - """ + Ensure the servo is in the expected ro/rw region. + This function jumps to the required region and verify if jump was + successful by executing 'sysinfo' command and reading current region. + If response was not received or region is invalid, exception is raised. - if region not in ["rw", "ro"]: - raise Exception("Region must be ro or rw") + Args: + tinys: TinyServod object + region: region to jump to, only "rw" and "ro" is allowed + """ + + if region not in ["rw", "ro"]: + raise Exception("Region must be ro or rw") + + if region == "ro": + cmd = "reboot" + else: + cmd = "sysjump %s" % region - if region is "ro": - cmd = "reboot" - else: - cmd = "sysjump %s" % region + tinys.pty._issue_cmd(cmd) - tinys.pty._issue_cmd(cmd) + tinys.close() + time.sleep(2) + tinys.reinitialize() - tinys.close() - time.sleep(2) - tinys.reinitialize() + res = tinys.pty._issue_cmd_get_results("sysinfo", [r"Copy:[\s]+(RO|RW)"]) + current_region = res[0][1].lower() + if current_region != region: + raise Exception("Invalid region: %s/%s" % (current_region, region)) - res = tinys.pty._issue_cmd_get_results("sysinfo", ["Copy:[\s]+(RO|RW)"]) - current_region = res[0][1].lower() - if current_region != region: - raise Exception("Invalid region: %s/%s" % (current_region, region)) def do_version(tinys): - """Check version via ec console 'pty'. + """Check version via ec console 'pty'. - Args: - tinys: TinyServod object + Args: + tinys: TinyServod object - Returns: - detected version number + Returns: + detected version number - Commands are: - # > version - # ... - # Build: tigertail_v1.1.6749-74d1a312e - """ - cmd = '\r\nversion\r\n' - regex = 'Build:\s+(\S+)[\r\n]+' + Commands are: + # > version + # ... + # Build: tigertail_v1.1.6749-74d1a312e + """ + cmd = "version" + regex = r"Build:\s+(\S+)[\r\n]+" - results = tinys.pty._issue_cmd_get_results(cmd, [regex])[0] + results = tinys.pty._issue_cmd_get_results(cmd, [regex])[0] + + return results[1].strip(" \t\r\n\0") - return results[1].strip(' \t\r\n\0') def do_updater_version(tinys): - """Check whether this uses python updater or c++ updater - - Args: - tinys: TinyServod object - - Returns: - updater version number. 2 or 6. - """ - vers = do_version(tinys) - - # Servo versions below 58 are from servo-9040.B. Versions starting with _v2 - # are newer than anything _v1, no need to check the exact number. Updater - # version is not directly queryable. - if re.search('_v[2-9]\.\d', vers): - return 6 - m = re.search('_v1\.1\.(\d\d\d\d)', vers) - if m: - version_number = int(m.group(1)) - if version_number < 5800: - return 2 - else: - return 6 - raise ServoUpdaterException( - "Can't determine updater target from vers: [%s]" % vers) + """Check whether this uses python updater or c++ updater -def _extract_version(boardname, binfile): - """Find the version string from |binfile|. + Args: + tinys: TinyServod object - Args: - boardname: the name of the board, eg. "servo_micro" - binfile: path to the binary to search + Returns: + updater version number. 2 or 6. + """ + vers = do_version(tinys) - Returns: - the version string. - """ - if boardname is None: - # cannot extract the version if the name is None - return None - rawstrings = subprocess.check_output( - ['cbfstool', binfile, 'read', '-r', 'RO_FRID', '-f', '/dev/stdout'], - **c.get_subprocess_args()) - m = re.match(r'%s_v\S+' % boardname, rawstrings) - if m: - newvers = m.group(0).strip(' \t\r\n\0') - else: - raise ServoUpdaterException("Can't find version from file: %s." % binfile) + # Servo versions below 58 are from servo-9040.B. Versions starting with _v2 + # are newer than anything _v1, no need to check the exact number. Updater + # version is not directly queryable. + if re.search(r"_v[2-9]\.\d", vers): + return 6 + m = re.search(r"_v1\.1\.(\d\d\d\d)", vers) + if m: + version_number = int(m.group(1)) + if version_number < 5800: + return 2 + else: + return 6 + raise ServoUpdaterException( + "Can't determine updater target from vers: [%s]" % vers + ) + + +def _extract_version(boardname, binfile): + """Find the version string from |binfile|. + + Args: + boardname: the name of the board, eg. "servo_micro" + binfile: path to the binary to search + + Returns: + the version string. + """ + if boardname is None: + # cannot extract the version if the name is None + return None + rawstrings = subprocess.check_output( + ["cbfstool", binfile, "read", "-r", "RO_FRID", "-f", "/dev/stdout"], + **c.get_subprocess_args() + ) + m = re.match(r"%s_v\S+" % boardname, rawstrings) + if m: + newvers = m.group(0).strip(" \t\r\n\0") + else: + raise ServoUpdaterException( + "Can't find version from file: %s." % binfile + ) + + return newvers - return newvers def get_firmware_channel(bname, version): - """Find out which channel |version| for |bname| came from. - - Args: - bname: board name - version: current version string - - Returns: - one of the channel names if |version| came from one of those, or None - """ - for channel in CHANNELS: - # Pass |bname| as cname to find the board specific file, and pass None as - # fname to ensure the default directory is searched - _, _, vers = get_files_and_version(bname, None, channel=channel) - if version == vers: - return channel - # None of the channels matched. This firmware is currently unknown. - return None + """Find out which channel |version| for |bname| came from. + + Args: + bname: board name + version: current version string + + Returns: + one of the channel names if |version| came from one of those, or None + """ + for channel in CHANNELS: + # Pass |bname| as cname to find the board specific file, and pass None as + # fname to ensure the default directory is searched + _, _, vers = get_files_and_version(bname, None, channel=channel) + if version == vers: + return channel + # None of the channels matched. This firmware is currently unknown. + return None + def get_files_and_version(cname, fname=None, channel=DEFAULT_CHANNEL): - """Select config and firmware binary files. - - This checks default file names and paths. - In: /usr/share/servo_updater/[firmware|configs] - check for board.json, board.bin - - Args: - cname: board name, or config name. eg. "servo_v4" or "servo_v4.json" - fname: firmware binary name. Can be None to try default. - channel: the channel requested for servo firmware. See |CHANNELS| above. - - Returns: - cname, fname, version: validated filenames selected from the path. - """ - for p in (DEFAULT_BASE_PATH, TEST_IMAGE_BASE_PATH): - updater_path = os.path.join(p, COMMON_PATH) - if os.path.exists(updater_path): - break - else: - raise ServoUpdaterException('servo_updater/ dir not found in known spots.') - - firmware_path = os.path.join(updater_path, FIRMWARE_DIR) - configs_path = os.path.join(updater_path, CONFIGS_DIR) - - for p in (firmware_path, configs_path): - if not os.path.exists(p): - raise ServoUpdaterException('Could not find required path %r' % p) - - if not os.path.isfile(cname): - # If not an existing file, try checking on the default path. - newname = os.path.join(configs_path, cname) - if os.path.isfile(newname): - cname = newname + """Select config and firmware binary files. + + This checks default file names and paths. + In: /usr/share/servo_updater/[firmware|configs] + check for board.json, board.bin + + Args: + cname: board name, or config name. eg. "servo_v4" or "servo_v4.json" + fname: firmware binary name. Can be None to try default. + channel: the channel requested for servo firmware. See |CHANNELS| above. + + Returns: + cname, fname, version: validated filenames selected from the path. + """ + for p in (DEFAULT_BASE_PATH, TEST_IMAGE_BASE_PATH): + updater_path = os.path.join(p, COMMON_PATH) + if os.path.exists(updater_path): + break else: - # Try appending ".json" to convert board name to config file. - cname = newname + ".json" + raise ServoUpdaterException( + "servo_updater/ dir not found in known spots." + ) + + firmware_path = os.path.join(updater_path, FIRMWARE_DIR) + configs_path = os.path.join(updater_path, CONFIGS_DIR) + + for p in (firmware_path, configs_path): + if not os.path.exists(p): + raise ServoUpdaterException("Could not find required path %r" % p) + if not os.path.isfile(cname): - raise ServoUpdaterException("Can't find config file: %s." % cname) - - # Always retrieve the boardname - with open(cname) as data_file: - data = json.load(data_file) - boardname = data['board'] - - if not fname: - # If no |fname| supplied, look for the default locations with the board - # and channel requested. - binary_file = '%s.%s.bin' % (boardname, channel) - newname = os.path.join(firmware_path, binary_file) - if os.path.isfile(newname): - fname = newname - else: - raise ServoUpdaterException("Can't find firmware binary: %s." % - binary_file) - elif not os.path.isfile(fname): - # If a name is specified but not found, try the default path. - newname = os.path.join(firmware_path, fname) - if os.path.isfile(newname): - fname = newname + # If not an existing file, try checking on the default path. + newname = os.path.join(configs_path, cname) + if os.path.isfile(newname): + cname = newname + else: + # Try appending ".json" to convert board name to config file. + cname = newname + ".json" + if not os.path.isfile(cname): + raise ServoUpdaterException("Can't find config file: %s." % cname) + + # Always retrieve the boardname + with open(cname) as data_file: + data = json.load(data_file) + boardname = data["board"] + + if not fname: + # If no |fname| supplied, look for the default locations with the board + # and channel requested. + binary_file = "%s.%s.bin" % (boardname, channel) + newname = os.path.join(firmware_path, binary_file) + if os.path.isfile(newname): + fname = newname + else: + raise ServoUpdaterException( + "Can't find firmware binary: %s." % binary_file + ) + elif not os.path.isfile(fname): + # If a name is specified but not found, try the default path. + newname = os.path.join(firmware_path, fname) + if os.path.isfile(newname): + fname = newname + else: + raise ServoUpdaterException("Can't find file: %s." % fname) + + # Lastly, retrieve the version as well for decision making, debug, and + # informational purposes. + binvers = _extract_version(boardname, fname) + + return cname, fname, binvers + + +def main(): + parser = argparse.ArgumentParser(description="Image a servo device") + parser.add_argument( + "-p", + "--print", + dest="print_only", + action="store_true", + default=False, + help="only print available firmware for board/channel", + ) + parser.add_argument( + "-s", + "--serialno", + type=str, + help="serial number to program", + default=None, + ) + parser.add_argument( + "-b", + "--board", + type=str, + help="Board configuration json file", + default=DEFAULT_BOARD, + choices=BOARDS, + ) + parser.add_argument( + "-c", + "--channel", + type=str, + help="Firmware channel to use", + default=DEFAULT_CHANNEL, + choices=CHANNELS, + ) + parser.add_argument( + "-f", "--file", type=str, help="Complete ec.bin file", default=None + ) + parser.add_argument( + "--force", + action="store_true", + help="Update even if version match", + default=False, + ) + parser.add_argument( + "-v", "--verbose", action="store_true", help="Chatty output" + ) + parser.add_argument( + "-r", + "--reboot", + action="store_true", + help="Always reboot, even after probe.", + ) + + args = parser.parse_args() + + brdfile, binfile, newvers = get_files_and_version( + args.board, args.file, args.channel + ) + + # If the user only cares about the information then just print it here, + # and exit. + if args.print_only: + output = ("board: %s\nchannel: %s\nfirmware: %s") % ( + args.board, + args.channel, + newvers, + ) + print(output) + return + + serialno = args.serialno + + with open(brdfile) as data_file: + data = json.load(data_file) + vid, pid = int(data["vid"], 0), int(data["pid"], 0) + vidpid = "%04x:%04x" % (vid, pid) + iface = int(data["console"], 0) + boardname = data["board"] + + # Make sure device is up. + print("===== Waiting for USB device =====") + c.wait_for_usb(vidpid, serialname=serialno) + # We need a tiny_servod to query some information. Set it up first. + tinys = tiny_servod.TinyServod(vid, pid, iface, serialno, args.verbose) + + if not args.force: + vers = do_version(tinys) + print("Current %s version is %s" % (boardname, vers)) + print("Available %s version is %s" % (boardname, newvers)) + + if newvers == vers: + print("No version update needed") + if args.reboot: + select(tinys, "ro") + return + else: + print("Updating to recommended version.") + + # Make sure the servo MCU is in RO + print("===== Jumping to RO =====") + do_with_retries(select, tinys, "ro") + + print("===== Flashing RW =====") + vers = do_with_retries(do_updater_version, tinys) + # To make sure that the tiny_servod here does not interfere with other + # processes, close it out. + tinys.close() + + if vers == 2: + flash(brdfile, serialno, binfile) + elif vers == 6: + do_with_retries(flash2, vidpid, serialno, binfile) else: - raise ServoUpdaterException("Can't find file: %s." % fname) + raise ServoUpdaterException("Can't detect updater version") - # Lastly, retrieve the version as well for decision making, debug, and - # informational purposes. - binvers = _extract_version(boardname, fname) + # Make sure device is up. + c.wait_for_usb(vidpid, serialname=serialno) + # After we have made sure that it's back/available, reconnect the tiny servod. + tinys.reinitialize() - return cname, fname, binvers + # Make sure the servo MCU is in RW + print("===== Jumping to RW =====") + do_with_retries(select, tinys, "rw") -def main(): - parser = argparse.ArgumentParser(description="Image a servo device") - parser.add_argument('-p', '--print', dest='print_only', action='store_true', - default=False, - help='only print available firmware for board/channel') - parser.add_argument('-s', '--serialno', type=str, - help="serial number to program", default=None) - parser.add_argument('-b', '--board', type=str, - help="Board configuration json file", - default=DEFAULT_BOARD, choices=BOARDS) - parser.add_argument('-c', '--channel', type=str, - help="Firmware channel to use", - default=DEFAULT_CHANNEL, choices=CHANNELS) - parser.add_argument('-f', '--file', type=str, - help="Complete ec.bin file", default=None) - parser.add_argument('--force', action="store_true", - help="Update even if version match", default=False) - parser.add_argument('-v', '--verbose', action="store_true", - help="Chatty output") - parser.add_argument('-r', '--reboot', action="store_true", - help="Always reboot, even after probe.") - - args = parser.parse_args() - - brdfile, binfile, newvers = get_files_and_version(args.board, args.file, - args.channel) - - # If the user only cares about the information then just print it here, - # and exit. - if args.print_only: - output = ('board: %s\n' - 'channel: %s\n' - 'firmware: %s') % (args.board, args.channel, newvers) - print(output) - return - - serialno = args.serialno - - with open(brdfile) as data_file: - data = json.load(data_file) - vid, pid = int(data['vid'], 0), int(data['pid'], 0) - vidpid = "%04x:%04x" % (vid, pid) - iface = int(data['console'], 0) - boardname = data['board'] - - # Make sure device is up. - print("===== Waiting for USB device =====") - c.wait_for_usb(vidpid, serialname=serialno) - # We need a tiny_servod to query some information. Set it up first. - tinys = tiny_servod.TinyServod(vid, pid, iface, serialno, args.verbose) - - if not args.force: - vers = do_version(tinys) - print("Current %s version is %s" % (boardname, vers)) - print("Available %s version is %s" % (boardname, newvers)) - - if newvers == vers: - print("No version update needed") - if args.reboot: - select(tinys, 'ro') - return + print("===== Flashing RO =====") + vers = do_with_retries(do_updater_version, tinys) + + if vers == 2: + flash(brdfile, serialno, binfile) + elif vers == 6: + do_with_retries(flash2, vidpid, serialno, binfile) else: - print("Updating to recommended version.") - - # Make sure the servo MCU is in RO - print("===== Jumping to RO =====") - do_with_retries(select, tinys, 'ro') - - print("===== Flashing RW =====") - vers = do_with_retries(do_updater_version, tinys) - # To make sure that the tiny_servod here does not interfere with other - # processes, close it out. - tinys.close() - - if vers == 2: - flash(brdfile, serialno, binfile) - elif vers == 6: - do_with_retries(flash2, vidpid, serialno, binfile) - else: - raise ServoUpdaterException("Can't detect updater version") - - # Make sure device is up. - c.wait_for_usb(vidpid, serialname=serialno) - # After we have made sure that it's back/available, reconnect the tiny servod. - tinys.reinitialize() - - # Make sure the servo MCU is in RW - print("===== Jumping to RW =====") - do_with_retries(select, tinys, 'rw') - - print("===== Flashing RO =====") - vers = do_with_retries(do_updater_version, tinys) - - if vers == 2: - flash(brdfile, serialno, binfile) - elif vers == 6: - do_with_retries(flash2, vidpid, serialno, binfile) - else: - raise ServoUpdaterException("Can't detect updater version") - - # Make sure the servo MCU is in RO - print("===== Rebooting =====") - do_with_retries(select, tinys, 'ro') - # Perform additional reboot to free USB/UART resources, taken by tiny servod. - # See https://issuetracker.google.com/196021317 for background. - tinys.pty._issue_cmd("reboot") - - print("===== Finished =====") + raise ServoUpdaterException("Can't detect updater version") + + # Make sure the servo MCU is in RO + print("===== Rebooting =====") + do_with_retries(select, tinys, "ro") + # Perform additional reboot to free USB/UART resources, taken by tiny servod. + # See https://issuetracker.google.com/196021317 for background. + tinys.pty._issue_cmd("reboot") + + print("===== Finished =====") + if __name__ == "__main__": - main() + main() diff --git a/extra/usb_updater/usb_updater2.c b/extra/usb_updater/usb_updater2.c index 81cf48a680..d591811a2b 100644 --- a/extra/usb_updater/usb_updater2.c +++ b/extra/usb_updater/usb_updater2.c @@ -1,5 +1,5 @@ /* - * Copyright 2017 The Chromium OS Authors. All rights reserved. + * Copyright 2017 The ChromiumOS Authors * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ @@ -46,16 +46,16 @@ #define PROTOCOL USB_PROTOCOL_GOOGLE_UPDATE enum exit_values { - noop = 0, /* All up to date, no update needed. */ - all_updated = 1, /* Update completed, reboot required. */ - rw_updated = 2, /* RO was not updated, reboot required. */ - update_error = 3 /* Something went wrong. */ + noop = 0, /* All up to date, no update needed. */ + all_updated = 1, /* Update completed, reboot required. */ + rw_updated = 2, /* RO was not updated, reboot required. */ + update_error = 3 /* Something went wrong. */ }; struct usb_endpoint { struct libusb_device_handle *devh; uint8_t ep_num; - int chunk_len; + int chunk_len; }; struct transfer_descriptor { @@ -76,22 +76,22 @@ static char *progname; static char *short_opts = "bd:efg:hjlnp:rsS:tuw"; static const struct option long_opts[] = { /* name hasarg *flag val */ - {"binvers", 1, NULL, 'b'}, - {"device", 1, NULL, 'd'}, - {"entropy", 0, NULL, 'e'}, - {"fwver", 0, NULL, 'f'}, - {"tp_debug", 1, NULL, 'g'}, - {"help", 0, NULL, 'h'}, - {"jump_to_rw", 0, NULL, 'j'}, - {"follow_log", 0, NULL, 'l'}, - {"no_reset", 0, NULL, 'n'}, - {"tp_update", 1, NULL, 'p'}, - {"reboot", 0, NULL, 'r'}, - {"stay_in_ro", 0, NULL, 's'}, - {"serial", 1, NULL, 'S'}, - {"tp_info", 0, NULL, 't'}, - {"unlock_rollback", 0, NULL, 'u'}, - {"unlock_rw", 0, NULL, 'w'}, + { "binvers", 1, NULL, 'b' }, + { "device", 1, NULL, 'd' }, + { "entropy", 0, NULL, 'e' }, + { "fwver", 0, NULL, 'f' }, + { "tp_debug", 1, NULL, 'g' }, + { "help", 0, NULL, 'h' }, + { "jump_to_rw", 0, NULL, 'j' }, + { "follow_log", 0, NULL, 'l' }, + { "no_reset", 0, NULL, 'n' }, + { "tp_update", 1, NULL, 'p' }, + { "reboot", 0, NULL, 'r' }, + { "stay_in_ro", 0, NULL, 's' }, + { "serial", 1, NULL, 'S' }, + { "tp_info", 0, NULL, 't' }, + { "unlock_rollback", 0, NULL, 'u' }, + { "unlock_rw", 0, NULL, 'w' }, {}, }; @@ -113,7 +113,7 @@ static void usage(int errs) "Options:\n" "\n" " -b,--binvers Report versions of image's " - "RW and RO, do not update\n" + "RW and RO, do not update\n" " -d,--device VID:PID USB device (default %04x:%04x)\n" " -e,--entropy Add entropy to device secret\n" " -f,--fwver Report running firmware versions.\n" @@ -128,7 +128,8 @@ static void usage(int errs) " -t,--tp_info Get touchpad information\n" " -u,--unlock_rollback Tell EC to unlock the rollback region\n" " -w,--unlock_rw Tell EC to unlock the RW region\n" - "\n", progname, VID, PID); + "\n", + progname, VID, PID); exit(errs ? update_error : noop); } @@ -138,7 +139,7 @@ static void str2hex(const char *str, uint8_t *data, int *len) int i; int slen = strlen(str); - if (slen/2 > *len) { + if (slen / 2 > *len) { fprintf(stderr, "Hex string too long.\n"); exit(update_error); } @@ -153,7 +154,7 @@ static void str2hex(const char *str, uint8_t *data, int *len) char tmp[3]; tmp[0] = str[i]; - tmp[1] = str[i+1]; + tmp[1] = str[i + 1]; tmp[2] = 0; data[*len] = strtol(tmp, &end, 16); @@ -250,9 +251,9 @@ static uint8_t *get_file_or_die(const char *filename, size_t *len_ptr) return data; } -#define USB_ERROR(m, r) \ - fprintf(stderr, "%s:%d, %s returned %d (%s)\n", __FILE__, __LINE__, \ - m, r, libusb_strerror(r)) +#define USB_ERROR(m, r) \ + fprintf(stderr, "%s:%d, %s returned %d (%s)\n", __FILE__, __LINE__, m, \ + r, libusb_strerror(r)) /* * Actual USB transfer function, the 'allow_less' flag indicates that the @@ -261,17 +262,14 @@ static uint8_t *get_file_or_die(const char *filename, size_t *len_ptr) * bytes were received. */ static void do_xfer(struct usb_endpoint *uep, void *outbuf, int outlen, - void *inbuf, int inlen, int allow_less, - size_t *rxed_count) + void *inbuf, int inlen, int allow_less, size_t *rxed_count) { - int r, actual; /* Send data out */ if (outbuf && outlen) { actual = 0; - r = libusb_bulk_transfer(uep->devh, uep->ep_num, - outbuf, outlen, + r = libusb_bulk_transfer(uep->devh, uep->ep_num, outbuf, outlen, &actual, 2000); if (r < 0) { USB_ERROR("libusb_bulk_transfer", r); @@ -286,11 +284,9 @@ static void do_xfer(struct usb_endpoint *uep, void *outbuf, int outlen, /* Read reply back */ if (inbuf && inlen) { - actual = 0; - r = libusb_bulk_transfer(uep->devh, uep->ep_num | 0x80, - inbuf, inlen, - &actual, 5000); + r = libusb_bulk_transfer(uep->devh, uep->ep_num | 0x80, inbuf, + inlen, &actual, 5000); if (r < 0) { USB_ERROR("libusb_bulk_transfer", r); exit(update_error); @@ -307,8 +303,8 @@ static void do_xfer(struct usb_endpoint *uep, void *outbuf, int outlen, } } -static void xfer(struct usb_endpoint *uep, void *outbuf, - size_t outlen, void *inbuf, size_t inlen, int allow_less) +static void xfer(struct usb_endpoint *uep, void *outbuf, size_t outlen, + void *inbuf, size_t inlen, int allow_less) { do_xfer(uep, outbuf, outlen, inbuf, inlen, allow_less, NULL); } @@ -321,8 +317,7 @@ static int find_endpoint(const struct libusb_interface_descriptor *iface, if (iface->bInterfaceClass == 255 && iface->bInterfaceSubClass == SUBCLASS && - iface->bInterfaceProtocol == PROTOCOL && - iface->bNumEndpoints) { + iface->bInterfaceProtocol == PROTOCOL && iface->bNumEndpoints) { ep = &iface->endpoint[0]; uep->ep_num = ep->bEndpointAddress & 0x7f; uep->chunk_len = ep->wMaxPacketSize; @@ -377,19 +372,19 @@ static int parse_vidpid(const char *input, uint16_t *vid_ptr, uint16_t *pid_ptr) return 0; *s++ = '\0'; - *vid_ptr = (uint16_t) strtoull(copy, &e, 16); + *vid_ptr = (uint16_t)strtoull(copy, &e, 16); if (!*optarg || (e && *e)) return 0; - *pid_ptr = (uint16_t) strtoull(s, &e, 16); + *pid_ptr = (uint16_t)strtoull(s, &e, 16); if (!*optarg || (e && *e)) return 0; return 1; } -static libusb_device_handle *check_device(libusb_device *dev, - uint16_t vid, uint16_t pid, char *serialno) +static libusb_device_handle *check_device(libusb_device *dev, uint16_t vid, + uint16_t pid, char *serialno) { struct libusb_device_descriptor desc; libusb_device_handle *handle = NULL; @@ -409,7 +404,9 @@ static libusb_device_handle *check_device(libusb_device *dev, if (desc.iSerialNumber) { ret = libusb_get_string_descriptor_ascii(handle, - desc.iSerialNumber, (unsigned char *)sn, sizeof(sn)); + desc.iSerialNumber, + (unsigned char *)sn, + sizeof(sn)); if (ret > 0) snvalid = 1; } @@ -428,8 +425,8 @@ static libusb_device_handle *check_device(libusb_device *dev, return NULL; } -static void usb_findit(uint16_t vid, uint16_t pid, - char *serialno, struct usb_endpoint *uep) +static void usb_findit(uint16_t vid, uint16_t pid, char *serialno, + struct usb_endpoint *uep) { int iface_num, r, i; libusb_device **devs; @@ -475,8 +472,8 @@ static void usb_findit(uint16_t vid, uint16_t pid, shut_down(uep); } - printf("found interface %d endpoint %d, chunk_len %d\n", - iface_num, uep->ep_num, uep->chunk_len); + printf("found interface %d endpoint %d, chunk_len %d\n", iface_num, + uep->ep_num, uep->chunk_len); libusb_set_auto_detach_kernel_driver(uep->devh, 1); r = libusb_claim_interface(uep->devh, iface_num); @@ -511,9 +508,8 @@ static int transfer_block(struct usb_endpoint *uep, } /* Now get the reply. */ - r = libusb_bulk_transfer(uep->devh, uep->ep_num | 0x80, - (void *) &reply, sizeof(reply), - &actual, 5000); + r = libusb_bulk_transfer(uep->devh, uep->ep_num | 0x80, (void *)&reply, + sizeof(reply), &actual, 5000); if (r) { if (r == -7) { fprintf(stderr, "Timeout!\n"); @@ -541,10 +537,8 @@ static int transfer_block(struct usb_endpoint *uep, * data_len - section size * smart_update - non-zero to enable the smart trailing of 0xff. */ -static void transfer_section(struct transfer_descriptor *td, - uint8_t *data_ptr, - uint32_t section_addr, - size_t data_len, +static void transfer_section(struct transfer_descriptor *td, uint8_t *data_ptr, + uint32_t section_addr, size_t data_len, uint8_t smart_update) { /* @@ -571,17 +565,16 @@ static void transfer_section(struct transfer_descriptor *td, struct update_frame_header ufh; ufh.block_size = htobe32(payload_size + - sizeof(struct update_frame_header)); + sizeof(struct update_frame_header)); ufh.cmd.block_base = block_base; ufh.cmd.block_digest = 0; for (max_retries = 10; max_retries; max_retries--) - if (!transfer_block(&td->uep, &ufh, - data_ptr, payload_size)) + if (!transfer_block(&td->uep, &ufh, data_ptr, + payload_size)) break; if (!max_retries) { - fprintf(stderr, - "Failed to transfer block, %zd to go\n", + fprintf(stderr, "Failed to transfer block, %zd to go\n", data_len); exit(update_error); } @@ -596,30 +589,27 @@ static void transfer_section(struct transfer_descriptor *td, * states. */ enum upgrade_status { - not_needed = 0, /* Version below or equal that on the target. */ - not_possible, /* - * RO is newer, but can't be transferred due to - * target RW shortcomings. - */ - needed /* - * This section needs to be transferred to the - * target. - */ + not_needed = 0, /* Version below or equal that on the target. */ + not_possible, /* + * RO is newer, but can't be transferred due to + * target RW shortcomings. + */ + needed /* + * This section needs to be transferred to the + * target. + */ }; /* This array describes all sections of the new image. */ static struct { const char *name; - uint32_t offset; - uint32_t size; - enum upgrade_status ustatus; + uint32_t offset; + uint32_t size; + enum upgrade_status ustatus; char version[32]; int32_t rollback; uint32_t key_version; -} sections[] = { - {"RO"}, - {"RW"} -}; +} sections[] = { { "RO" }, { "RW" } }; static const struct fmap_area *fmap_find_area_or_die(const struct fmap *fmap, const char *name) @@ -650,7 +640,7 @@ static void fetch_header_versions(const uint8_t *image, size_t len) fprintf(stderr, "Cannot find FMAP in image\n"); exit(update_error); } - fmap = (const struct fmap *)(image+offset); + fmap = (const struct fmap *)(image + offset); /* FIXME: validate fmap struct more than this? */ if (fmap->size != len) { @@ -693,15 +683,15 @@ static void fetch_header_versions(const uint8_t *image, size_t len) fprintf(stderr, "Invalid fwid size\n"); exit(update_error); } - memcpy(sections[i].version, image+fmaparea->offset, - fmaparea->size); + memcpy(sections[i].version, image + fmaparea->offset, + fmaparea->size); sections[i].rollback = -1; if (fmap_rollback_name) { fmaparea = fmap_find_area(fmap, fmap_rollback_name); if (fmaparea) memcpy(§ions[i].rollback, - image+fmaparea->offset, + image + fmaparea->offset, sizeof(sections[i].rollback)); } @@ -710,7 +700,8 @@ static void fetch_header_versions(const uint8_t *image, size_t len) fmaparea = fmap_find_area(fmap, fmap_key_name); if (fmaparea) { const struct vb21_packed_key *key = - (const void *)(image+fmaparea->offset); + (const void *)(image + + fmaparea->offset); sections[i].key_version = key->key_version; } } @@ -723,9 +714,9 @@ static int show_headers_versions(const void *image) for (i = 0; i < ARRAY_SIZE(sections); i++) { printf("%s off=%08x/%08x v=%.32s rb=%d kv=%d\n", - sections[i].name, sections[i].offset, sections[i].size, - sections[i].version, sections[i].rollback, - sections[i].key_version); + sections[i].name, sections[i].offset, sections[i].size, + sections[i].version, sections[i].rollback, + sections[i].key_version); } return 0; } @@ -772,17 +763,16 @@ static void setup_connection(struct transfer_descriptor *td) int actual = 0; /* Flush all data from endpoint to recover in case of error. */ - while (!libusb_bulk_transfer(td->uep.devh, - td->uep.ep_num | 0x80, - (void *)&inbuf, td->uep.chunk_len, - &actual, 10)) { + while (!libusb_bulk_transfer(td->uep.devh, td->uep.ep_num | 0x80, + (void *)&inbuf, td->uep.chunk_len, &actual, + 10)) { printf("flush\n"); } memset(&ufh, 0, sizeof(ufh)); ufh.block_size = htobe32(sizeof(ufh)); - do_xfer(&td->uep, &ufh, sizeof(ufh), &start_resp, - sizeof(start_resp), 1, &rxed_size); + do_xfer(&td->uep, &ufh, sizeof(ufh), &start_resp, sizeof(start_resp), 1, + &rxed_size); /* We got something. Check for errors in response */ if (rxed_size < 8) { @@ -803,10 +793,9 @@ static void setup_connection(struct transfer_descriptor *td) header_type = be16toh(start_resp.rpdu.header_type); printf("target running protocol version %d (type %d)\n", - protocol_version, header_type); + protocol_version, header_type); if (header_type != UPDATE_HEADER_TYPE_COMMON) { - fprintf(stderr, "Unsupported header type %d\n", - header_type); + fprintf(stderr, "Unsupported header type %d\n", header_type); exit(update_error); } @@ -820,7 +809,7 @@ static void setup_connection(struct transfer_descriptor *td) td->offset = be32toh(start_resp.rpdu.common.offset); memcpy(targ.common.version, start_resp.rpdu.common.version, - sizeof(start_resp.rpdu.common.version)); + sizeof(start_resp.rpdu.common.version)); targ.common.maximum_pdu_size = be32toh(start_resp.rpdu.common.maximum_pdu_size); targ.common.flash_protection = @@ -845,21 +834,20 @@ static void setup_connection(struct transfer_descriptor *td) * if it is - of what maximum size. */ static int ext_cmd_over_usb(struct usb_endpoint *uep, uint16_t subcommand, - void *cmd_body, size_t body_size, - void *resp, size_t *resp_size, - int allow_less) + void *cmd_body, size_t body_size, void *resp, + size_t *resp_size, int allow_less) { struct update_frame_header *ufh; uint16_t *frame_ptr; size_t usb_msg_size; - usb_msg_size = sizeof(struct update_frame_header) + - sizeof(subcommand) + body_size; + usb_msg_size = sizeof(struct update_frame_header) + sizeof(subcommand) + + body_size; ufh = malloc(usb_msg_size); if (!ufh) { - printf("%s: failed to allocate %zd bytes\n", - __func__, usb_msg_size); + printf("%s: failed to allocate %zd bytes\n", __func__, + usb_msg_size); return -1; } @@ -895,30 +883,28 @@ static void send_done(struct usb_endpoint *uep) } static void send_subcommand(struct transfer_descriptor *td, uint16_t subcommand, - void *cmd_body, size_t body_size, - uint8_t *response, size_t response_size) + void *cmd_body, size_t body_size, uint8_t *response, + size_t response_size) { send_done(&td->uep); - ext_cmd_over_usb(&td->uep, subcommand, - cmd_body, body_size, - response, &response_size, 0); + ext_cmd_over_usb(&td->uep, subcommand, cmd_body, body_size, response, + &response_size, 0); printf("sent command %x, resp %x\n", subcommand, response[0]); } /* Returns number of successfully transmitted image sections. */ -static int transfer_image(struct transfer_descriptor *td, - uint8_t *data, size_t data_len) +static int transfer_image(struct transfer_descriptor *td, uint8_t *data, + size_t data_len) { size_t i; int num_txed_sections = 0; for (i = 0; i < ARRAY_SIZE(sections); i++) if (sections[i].ustatus == needed) { - transfer_section(td, - data + sections[i].offset, - sections[i].offset, - sections[i].size, 1); + transfer_section(td, data + sections[i].offset, + sections[i].offset, sections[i].size, + 1); num_txed_sections++; } @@ -968,9 +954,8 @@ static void generate_reset_request(struct transfer_descriptor *td) command_body_size = 0; response_size = 1; subcommand = UPDATE_EXTRA_CMD_IMMEDIATE_RESET; - ext_cmd_over_usb(&td->uep, subcommand, - command_body, command_body_size, - &response, &response_size, 0); + ext_cmd_over_usb(&td->uep, subcommand, command_body, command_body_size, + &response, &response_size, 0); printf("reboot not triggered\n"); } @@ -987,7 +972,7 @@ static void get_random(uint8_t *data, int len) } while (i < len) { - int ret = fread(data+i, len-i, 1, fp); + int ret = fread(data + i, len - i, 1, fp); if (ret < 0) { perror("fread"); @@ -1005,7 +990,8 @@ static void read_console(struct transfer_descriptor *td) uint8_t payload[] = { 0x1 }; uint8_t response[64]; size_t response_size = 64; - struct timespec sleep_duration = { /* 100 ms */ + struct timespec sleep_duration = { + /* 100 ms */ .tv_sec = 0, .tv_nsec = 100l * 1000l * 1000l, }; @@ -1015,17 +1001,15 @@ static void read_console(struct transfer_descriptor *td) printf("\n"); while (1) { response_size = 1; - ext_cmd_over_usb(&td->uep, - UPDATE_EXTRA_CMD_CONSOLE_READ_INIT, - NULL, 0, - response, &response_size, 0); + ext_cmd_over_usb(&td->uep, UPDATE_EXTRA_CMD_CONSOLE_READ_INIT, + NULL, 0, response, &response_size, 0); while (1) { response_size = 64; ext_cmd_over_usb(&td->uep, UPDATE_EXTRA_CMD_CONSOLE_READ_NEXT, - payload, sizeof(payload), - response, &response_size, 1); + payload, sizeof(payload), response, + &response_size, 1); if (response[0] == 0) break; /* make sure it's null-terminated. */ @@ -1067,7 +1051,7 @@ int main(int argc, char *argv[]) memset(&td, 0, sizeof(td)); errorcnt = 0; - opterr = 0; /* quiet, you */ + opterr = 0; /* quiet, you */ while ((i = getopt_long(argc, argv, short_opts, long_opts, 0)) != -1) { switch (i) { case 'b': @@ -1091,8 +1075,8 @@ int main(int argc, char *argv[]) extra_command = UPDATE_EXTRA_CMD_TOUCHPAD_DEBUG; /* Maximum length. */ extra_command_data_len = 50; - str2hex(optarg, - extra_command_data, &extra_command_data_len); + str2hex(optarg, extra_command_data, + &extra_command_data_len); hexdump(extra_command_data, extra_command_data_len); extra_command_answer_len = 64; break; @@ -1112,8 +1096,8 @@ int main(int argc, char *argv[]) touchpad_update = 1; data = get_file_or_die(optarg, &data_len); - printf("read %zd(%#zx) bytes from %s\n", - data_len, data_len, argv[optind - 1]); + printf("read %zd(%#zx) bytes from %s\n", data_len, + data_len, argv[optind - 1]); break; case 'r': @@ -1127,8 +1111,7 @@ int main(int argc, char *argv[]) break; case 't': extra_command = UPDATE_EXTRA_CMD_TOUCHPAD_INFO; - extra_command_answer_len = - sizeof(struct touchpad_info); + extra_command_answer_len = sizeof(struct touchpad_info); break; case 'u': extra_command = UPDATE_EXTRA_CMD_UNLOCK_ROLLBACK; @@ -1136,7 +1119,7 @@ int main(int argc, char *argv[]) case 'w': extra_command = UPDATE_EXTRA_CMD_UNLOCK_RW; break; - case 0: /* auto-handled option */ + case 0: /* auto-handled option */ break; case '?': if (optopt) @@ -1167,8 +1150,8 @@ int main(int argc, char *argv[]) } data = get_file_or_die(argv[optind], &data_len); - printf("read %zd(%#zx) bytes from %s\n", - data_len, data_len, argv[optind]); + printf("read %zd(%#zx) bytes from %s\n", data_len, data_len, + argv[optind]); fetch_header_versions(data, data_len); @@ -1190,16 +1173,13 @@ int main(int argc, char *argv[]) if (data) { if (touchpad_update) { - transfer_section(&td, - data, - 0x80000000, - data_len, 0); + transfer_section(&td, data, 0x80000000, data_len, 0); free(data); send_done(&td.uep); } else { - transferred_sections = transfer_image(&td, - data, data_len); + transferred_sections = + transfer_image(&td, data, data_len); free(data); if (transferred_sections && !no_reset_request) @@ -1208,16 +1188,16 @@ int main(int argc, char *argv[]) } else if (extra_command == UPDATE_EXTRA_CMD_CONSOLE_READ_INIT) { read_console(&td); } else if (extra_command > -1) { - send_subcommand(&td, extra_command, - extra_command_data, extra_command_data_len, - extra_command_answer, extra_command_answer_len); + send_subcommand(&td, extra_command, extra_command_data, + extra_command_data_len, extra_command_answer, + extra_command_answer_len); switch (extra_command) { case UPDATE_EXTRA_CMD_TOUCHPAD_INFO: dump_touchpad_info(extra_command_answer, extra_command_answer_len); break; - case UPDATE_EXTRA_CMD_TOUCHPAD_DEBUG: + case UPDATE_EXTRA_CMD_TOUCHPAD_DEBUG: hexdump(extra_command_answer, extra_command_answer_len); break; } |