diff options
author | Jeremy Bettis <jbettis@google.com> | 2022-07-08 10:58:19 -0600 |
---|---|---|
committer | Chromeos LUCI <chromeos-scoped@luci-project-accounts.iam.gserviceaccount.com> | 2022-07-12 19:13:33 +0000 |
commit | 7540e7b47b55447475bb8191fb3520dd67cf7998 (patch) | |
tree | 13309dbcf1db48e60fa2c2e5aed79f63bce00b5e /extra | |
parent | 7c114b8e1a3bb29991da70b9de394ac5d4f6c909 (diff) | |
download | chrome-ec-7540e7b47b55447475bb8191fb3520dd67cf7998.tar.gz |
ec: Format all python files with black and isort
find . \( -path ./private -prune \) -o -name '*.py' -print | xargs black
find . \( -path ./private -prune \) -o -name '*.py' -print | xargs ~/chromiumos/chromite/scripts/isort --settings-file=.isort.cfg
BRANCH=None
BUG=b:238434058
TEST=None
Signed-off-by: Jeremy Bettis <jbettis@google.com>
Change-Id: I63462d6f15d1eaf3db84eb20d1404ee976be8382
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/ec/+/3749242
Commit-Queue: Jeremy Bettis <jbettis@chromium.org>
Reviewed-by: Tom Hughes <tomhughes@chromium.org>
Tested-by: Jeremy Bettis <jbettis@chromium.org>
Commit-Queue: Jack Rosenthal <jrosenth@chromium.org>
Auto-Submit: Jeremy Bettis <jbettis@chromium.org>
Reviewed-by: Jack Rosenthal <jrosenth@chromium.org>
Diffstat (limited to 'extra')
-rwxr-xr-x | extra/cr50_rma_open/cr50_rma_open.py | 418 | ||||
-rwxr-xr-x | extra/stack_analyzer/stack_analyzer.py | 3562 | ||||
-rwxr-xr-x | extra/stack_analyzer/stack_analyzer_unittest.py | 1708 | ||||
-rw-r--r-- | extra/tigertool/ecusb/__init__.py | 2 | ||||
-rw-r--r-- | extra/tigertool/ecusb/pty_driver.py | 545 | ||||
-rw-r--r-- | extra/tigertool/ecusb/stm32uart.py | 438 | ||||
-rw-r--r-- | extra/tigertool/ecusb/stm32usb.py | 207 | ||||
-rw-r--r-- | extra/tigertool/ecusb/tiny_servo_common.py | 343 | ||||
-rw-r--r-- | extra/tigertool/ecusb/tiny_servod.py | 83 | ||||
-rwxr-xr-x | extra/tigertool/tigertest.py | 97 | ||||
-rwxr-xr-x | extra/tigertool/tigertool.py | 505 | ||||
-rw-r--r-- | extra/usb_power/convert_power_log_board.py | 35 | ||||
-rwxr-xr-x | extra/usb_power/convert_servo_ina.py | 86 | ||||
-rwxr-xr-x | extra/usb_power/powerlog.py | 1762 | ||||
-rw-r--r-- | extra/usb_power/powerlog_unittest.py | 74 | ||||
-rw-r--r-- | extra/usb_power/stats_manager.py | 745 | ||||
-rw-r--r-- | extra/usb_power/stats_manager_unittest.py | 595 | ||||
-rwxr-xr-x | extra/usb_serial/console.py | 435 | ||||
-rwxr-xr-x | extra/usb_updater/fw_update.py | 762 | ||||
-rwxr-xr-x | extra/usb_updater/servo_updater.py | 767 |
20 files changed, 6864 insertions, 6305 deletions
diff --git a/extra/cr50_rma_open/cr50_rma_open.py b/extra/cr50_rma_open/cr50_rma_open.py index 42ddbbac2d..b77b8f3dbb 100755 --- a/extra/cr50_rma_open/cr50_rma_open.py +++ b/extra/cr50_rma_open/cr50_rma_open.py @@ -57,16 +57,17 @@ CCD_IS_UNRESTRICTED = 1 << 0 WP_IS_DISABLED = 1 << 1 TESTLAB_IS_ENABLED = 1 << 2 RMA_OPENED = CCD_IS_UNRESTRICTED | WP_IS_DISABLED -URL = ('https://www.google.com/chromeos/partner/console/cr50reset?' - 'challenge=%s&hwid=%s') -RMA_SUPPORT_PROD = '0.3.3' -RMA_SUPPORT_PREPVT = '0.4.5' -DEV_MODE_OPEN_PROD = '0.3.9' -DEV_MODE_OPEN_PREPVT = '0.4.7' -TESTLAB_PROD = '0.3.10' -CR50_USB = '18d1:5014' -CR50_LSUSB_CMD = ['lsusb', '-vd', CR50_USB] -ERASED_BID = 'ffffffff' +URL = ( + "https://www.google.com/chromeos/partner/console/cr50reset?" "challenge=%s&hwid=%s" +) +RMA_SUPPORT_PROD = "0.3.3" +RMA_SUPPORT_PREPVT = "0.4.5" +DEV_MODE_OPEN_PROD = "0.3.9" +DEV_MODE_OPEN_PREPVT = "0.4.7" +TESTLAB_PROD = "0.3.10" +CR50_USB = "18d1:5014" +CR50_LSUSB_CMD = ["lsusb", "-vd", CR50_USB] +ERASED_BID = "ffffffff" DEBUG_MISSING_USB = """ Unable to find Cr50 Device 18d1:5014 @@ -128,13 +129,14 @@ DEBUG_DUT_CONTROL_OSERROR = """ Run from chroot if you are trying to use a /dev/pts ccd servo console """ + class RMAOpen(object): """Used to find the cr50 console and run RMA open""" - ENABLE_TESTLAB_CMD = 'ccd testlab enabled\n' + ENABLE_TESTLAB_CMD = "ccd testlab enabled\n" def __init__(self, device=None, usb_serial=None, servo_port=None, ip=None): - self.servo_port = servo_port if servo_port else '9999' + self.servo_port = servo_port if servo_port else "9999" self.ip = ip if device: self.set_cr50_device(device) @@ -142,18 +144,18 @@ class RMAOpen(object): self.find_cr50_servo_uart() else: self.find_cr50_device(usb_serial) - logging.info('DEVICE: %s', self.device) + logging.info("DEVICE: %s", self.device) self.check_version() self.print_platform_info() - logging.info('Cr50 setup ok') + logging.info("Cr50 setup ok") self.update_ccd_state() self.using_ccd = self.device_is_running_with_servo_ccd() def _dut_control(self, control): """Run dut-control and return the response""" try: - cmd = ['dut-control', '-p', self.servo_port, control] - return subprocess.check_output(cmd, encoding='utf-8').strip() + cmd = ["dut-control", "-p", self.servo_port, control] + return subprocess.check_output(cmd, encoding="utf-8").strip() except OSError: logging.warning(DEBUG_DUT_CONTROL_OSERROR) raise @@ -163,8 +165,8 @@ class RMAOpen(object): Find the console and configure it, so it can be used with this script. """ - self._dut_control('cr50_uart_timestamp:off') - self.device = self._dut_control('cr50_uart_pty').split(':')[-1] + self._dut_control("cr50_uart_timestamp:off") + self.device = self._dut_control("cr50_uart_pty").split(":")[-1] def set_cr50_device(self, device): """Save the device used for the console""" @@ -183,38 +185,38 @@ class RMAOpen(object): try: ser = serial.Serial(self.device, timeout=1) except OSError: - logging.warning('Permission denied %s', self.device) - logging.warning('Try running cr50_rma_open with sudo') + logging.warning("Permission denied %s", self.device) + logging.warning("Try running cr50_rma_open with sudo") raise - write_cmd = cmd + '\n\n' - ser.write(write_cmd.encode('utf-8')) + write_cmd = cmd + "\n\n" + ser.write(write_cmd.encode("utf-8")) if nbytes: output = ser.read(nbytes) else: output = ser.readall() ser.close() - output = output.decode('utf-8').strip() if output else '' + output = output.decode("utf-8").strip() if output else "" # Return only the command output - split_cmd = cmd + '\r' + split_cmd = cmd + "\r" if cmd and split_cmd in output: - return ''.join(output.rpartition(split_cmd)[1::]).split('>')[0] + return "".join(output.rpartition(split_cmd)[1::]).split(">")[0] return output def device_is_running_with_servo_ccd(self): """Return True if the device is a servod ccd console""" # servod uses /dev/pts consoles. Non-servod uses /dev/ttyUSBX - if '/dev/pts' not in self.device: + if "/dev/pts" not in self.device: return False # If cr50 doesn't show rdd is connected, cr50 the device must not be # a ccd device - if 'Rdd: connected' not in self.send_cmd_get_output('ccdstate'): + if "Rdd: connected" not in self.send_cmd_get_output("ccdstate"): return False # Check if the servod is running with ccd. This requires the script # is run in the chroot, so run it last. - if 'ccd_cr50' not in self._dut_control('servo_type'): + if "ccd_cr50" not in self._dut_control("servo_type"): return False - logging.info('running through servod ccd') + logging.info("running through servod ccd") return True def get_rma_challenge(self): @@ -239,14 +241,14 @@ class RMAOpen(object): Returns: The RMA challenge with all whitespace removed. """ - output = self.send_cmd_get_output('rma_auth').strip() - logging.info('rma_auth output:\n%s', output) + output = self.send_cmd_get_output("rma_auth").strip() + logging.info("rma_auth output:\n%s", output) # Extract the challenge from the console output - if 'generated challenge:' in output: - return output.split('generated challenge:')[-1].strip() - challenge = ''.join(re.findall(r' \S{5}' * 4, output)) + if "generated challenge:" in output: + return output.split("generated challenge:")[-1].strip() + challenge = "".join(re.findall(r" \S{5}" * 4, output)) # Remove all whitespace - return re.sub(r'\s', '', challenge) + return re.sub(r"\s", "", challenge) def generate_challenge_url(self, hwid): """Get the rma_auth challenge @@ -257,12 +259,14 @@ class RMAOpen(object): challenge = self.get_rma_challenge() self.print_platform_info() - logging.info('CHALLENGE: %s', challenge) - logging.info('HWID: %s', hwid) + logging.info("CHALLENGE: %s", challenge) + logging.info("HWID: %s", hwid) url = URL % (challenge, hwid) - logging.info('GOTO:\n %s', url) - logging.info('If the server fails to debug the challenge make sure the ' - 'RLZ is allowlisted') + logging.info("GOTO:\n %s", url) + logging.info( + "If the server fails to debug the challenge make sure the " + "RLZ is allowlisted" + ) def try_authcode(self, authcode): """Try opening cr50 with the authcode @@ -272,48 +276,48 @@ class RMAOpen(object): """ # rma_auth may cause the system to reboot. Don't wait to read all that # output. Read the first 300 bytes and call it a day. - output = self.send_cmd_get_output('rma_auth ' + authcode, nbytes=300) - logging.info('CR50 RESPONSE: %s', output) - logging.info('waiting for cr50 reboot') + output = self.send_cmd_get_output("rma_auth " + authcode, nbytes=300) + logging.info("CR50 RESPONSE: %s", output) + logging.info("waiting for cr50 reboot") # Cr50 may be rebooting. Wait a bit time.sleep(5) if self.using_ccd: # After reboot, reset the ccd endpoints - self._dut_control('power_state:ccd_reset') + self._dut_control("power_state:ccd_reset") # Update the ccd state after the authcode attempt self.update_ccd_state() - authcode_match = 'process_response: success!' in output + authcode_match = "process_response: success!" in output if not self.check(CCD_IS_UNRESTRICTED): if not authcode_match: logging.warning(DEBUG_AUTHCODE_MISMATCH) - message = 'Authcode mismatch. Check args and url' + message = "Authcode mismatch. Check args and url" else: - message = 'Could not set all capability privileges to Always' + message = "Could not set all capability privileges to Always" raise ValueError(message) def wp_is_force_disabled(self): """Returns True if write protect is forced disabled""" - output = self.send_cmd_get_output('wp') - wp_state = output.split('Flash WP:', 1)[-1].split('\n', 1)[0].strip() - logging.info('wp: %s', wp_state) - return wp_state == 'forced disabled' + output = self.send_cmd_get_output("wp") + wp_state = output.split("Flash WP:", 1)[-1].split("\n", 1)[0].strip() + logging.info("wp: %s", wp_state) + return wp_state == "forced disabled" def testlab_is_enabled(self): """Returns True if testlab mode is enabled""" - output = self.send_cmd_get_output('ccd testlab') - testlab_state = output.split('mode')[-1].strip().lower() - logging.info('testlab: %s', testlab_state) - return testlab_state == 'enabled' + output = self.send_cmd_get_output("ccd testlab") + testlab_state = output.split("mode")[-1].strip().lower() + logging.info("testlab: %s", testlab_state) + return testlab_state == "enabled" def ccd_is_restricted(self): """Returns True if any of the capabilities are still restricted""" - output = self.send_cmd_get_output('ccd') - if 'Capabilities' not in output: - raise ValueError('Could not get ccd output') - logging.debug('CURRENT CCD SETTINGS:\n%s', output) - restricted = 'IfOpened' in output or 'IfUnlocked' in output - logging.info('ccd: %srestricted', '' if restricted else 'Un') + output = self.send_cmd_get_output("ccd") + if "Capabilities" not in output: + raise ValueError("Could not get ccd output") + logging.debug("CURRENT CCD SETTINGS:\n%s", output) + restricted = "IfOpened" in output or "IfUnlocked" in output + logging.info("ccd: %srestricted", "" if restricted else "Un") return restricted def update_ccd_state(self): @@ -339,9 +343,10 @@ class RMAOpen(object): def _capabilities_allow_open_from_console(self): """Return True if ccd open is Always allowed from usb""" - output = self.send_cmd_get_output('ccd') - return (re.search('OpenNoDevMode.*Always', output) and - re.search('OpenFromUSB.*Always', output)) + output = self.send_cmd_get_output("ccd") + return re.search("OpenNoDevMode.*Always", output) and re.search( + "OpenFromUSB.*Always", output + ) def _requires_dev_mode_open(self): """Return True if the image requires dev mode to open""" @@ -354,78 +359,81 @@ class RMAOpen(object): def _run_on_dut(self, command): """Run the command on the DUT.""" - return subprocess.check_output(['ssh', self.ip, command], - encoding='utf-8') + return subprocess.check_output(["ssh", self.ip, command], encoding="utf-8") def _open_in_dev_mode(self): """Open Cr50 when it's in dev mode""" - output = self.send_cmd_get_output('ccd') + output = self.send_cmd_get_output("ccd") # If the device is already open, nothing needs to be done. - if 'State: Open' not in output: + if "State: Open" not in output: # Verify the device is in devmode before trying to run open. - if 'dev_mode' not in output: - logging.warning('Enter dev mode to open ccd or update to %s', - TESTLAB_PROD) - raise ValueError('DUT not in dev mode') + if "dev_mode" not in output: + logging.warning( + "Enter dev mode to open ccd or update to %s", TESTLAB_PROD + ) + raise ValueError("DUT not in dev mode") if not self.ip: - logging.warning("If your DUT doesn't have ssh support, run " - "'gsctool -a -o' from the AP") - raise ValueError('Cannot run ccd open without dut ip') - self._run_on_dut('gsctool -a -o') + logging.warning( + "If your DUT doesn't have ssh support, run " + "'gsctool -a -o' from the AP" + ) + raise ValueError("Cannot run ccd open without dut ip") + self._run_on_dut("gsctool -a -o") # Wait >1 second for cr50 to update ccd state time.sleep(3) - output = self.send_cmd_get_output('ccd') - if 'State: Open' not in output: - raise ValueError('Could not open cr50') - logging.info('ccd is open') + output = self.send_cmd_get_output("ccd") + if "State: Open" not in output: + raise ValueError("Could not open cr50") + logging.info("ccd is open") def enable_testlab(self): """Disable write protect""" if not self._has_testlab_support(): - logging.warning('Testlab mode is not supported in prod iamges') + logging.warning("Testlab mode is not supported in prod iamges") return # Some cr50 images need to be in dev mode before they can be opened. if self._requires_dev_mode_open(): self._open_in_dev_mode() else: - self.send_cmd_get_output('ccd open') - logging.info('Enabling testlab mode reqires pressing the power button.') - logging.info('Once the process starts keep tapping the power button ' - 'for 10 seconds.') + self.send_cmd_get_output("ccd open") + logging.info("Enabling testlab mode reqires pressing the power button.") + logging.info( + "Once the process starts keep tapping the power button " "for 10 seconds." + ) input("Press Enter when you're ready to start...") end_time = time.time() + 15 ser = serial.Serial(self.device, timeout=1) - printed_lines = '' - output = '' + printed_lines = "" + output = "" # start ccd testlab enable - ser.write(self.ENABLE_TESTLAB_CMD.encode('utf-8')) - logging.info('start pressing the power button\n\n') + ser.write(self.ENABLE_TESTLAB_CMD.encode("utf-8")) + logging.info("start pressing the power button\n\n") # Print all of the cr50 output as we get it, so the user will have more # information about pressing the power button. Tapping the power button # a couple of times should do it, but this will give us more confidence # the process is still running/worked. try: while time.time() < end_time: - output += ser.read(100).decode('utf-8') - full_lines = output.rsplit('\n', 1)[0] + output += ser.read(100).decode("utf-8") + full_lines = output.rsplit("\n", 1)[0] new_lines = full_lines if printed_lines: new_lines = full_lines.split(printed_lines, 1)[-1].strip() - logging.info('\n%s', new_lines) + logging.info("\n%s", new_lines) printed_lines = full_lines # Make sure the process hasn't ended. If it has, print the last # of the output and exit. new_lines = output.split(printed_lines, 1)[-1] - if 'CCD test lab mode enabled' in output: + if "CCD test lab mode enabled" in output: # print the last of the ou logging.info(new_lines) break - elif 'Physical presence check timeout' in output: + elif "Physical presence check timeout" in output: logging.info(new_lines) - logging.warning('Did not detect power button press in time') - raise ValueError('Could not enable testlab mode try again') + logging.warning("Did not detect power button press in time") + raise ValueError("Could not enable testlab mode try again") finally: ser.close() # Wait for the ccd hook to update things @@ -433,44 +441,48 @@ class RMAOpen(object): # Update the state after attempting to disable write protect self.update_ccd_state() if not self.check(TESTLAB_IS_ENABLED): - raise ValueError('Could not enable testlab mode try again') + raise ValueError("Could not enable testlab mode try again") def wp_disable(self): """Disable write protect""" - logging.info('Disabling write protect') - self.send_cmd_get_output('wp disable') + logging.info("Disabling write protect") + self.send_cmd_get_output("wp disable") # Update the state after attempting to disable write protect self.update_ccd_state() if not self.check(WP_IS_DISABLED): - raise ValueError('Could not disable write protect') + raise ValueError("Could not disable write protect") def check_version(self): """Make sure cr50 is running a version that supports RMA Open""" - output = self.send_cmd_get_output('version') + output = self.send_cmd_get_output("version") if not output.strip(): logging.warning(DEBUG_DEVICE, self.device) - raise ValueError('Could not communicate with %s' % self.device) + raise ValueError("Could not communicate with %s" % self.device) - version = re.search(r'RW.*\* ([\d\.]+)/', output).group(1) - logging.info('Running Cr50 Version: %s', version) - self.running_ver_fields = [int(field) for field in version.split('.')] + version = re.search(r"RW.*\* ([\d\.]+)/", output).group(1) + logging.info("Running Cr50 Version: %s", version) + self.running_ver_fields = [int(field) for field in version.split(".")] # prePVT images have even major versions. Prod have odd self.is_prepvt = self.running_ver_fields[1] % 2 == 0 rma_support = RMA_SUPPORT_PREPVT if self.is_prepvt else RMA_SUPPORT_PROD - logging.info('%s RMA support added in: %s', - 'prePVT' if self.is_prepvt else 'prod', rma_support) + logging.info( + "%s RMA support added in: %s", + "prePVT" if self.is_prepvt else "prod", + rma_support, + ) if not self.is_prepvt and self._running_version_is_older(TESTLAB_PROD): - raise ValueError('Update cr50. No testlab support in old prod ' - 'images.') + raise ValueError("Update cr50. No testlab support in old prod " "images.") if self._running_version_is_older(rma_support): - raise ValueError('%s does not have RMA support. Update to at ' - 'least %s' % (version, rma_support)) + raise ValueError( + "%s does not have RMA support. Update to at " + "least %s" % (version, rma_support) + ) def _running_version_is_older(self, target_ver): """Returns True if running version is older than target_ver.""" - target_ver_fields = [int(field) for field in target_ver.split('.')] + target_ver_fields = [int(field) for field in target_ver.split(".")] for i, field in enumerate(self.running_ver_fields): if field > int(target_ver_fields[i]): return False @@ -486,11 +498,11 @@ class RMAOpen(object): is no output or sysinfo doesn't contain the devid. """ self.set_cr50_device(device) - sysinfo = self.send_cmd_get_output('sysinfo') + sysinfo = self.send_cmd_get_output("sysinfo") # Make sure there is some output, and it shows it's from Cr50 - if not sysinfo or 'cr50' not in sysinfo: + if not sysinfo or "cr50" not in sysinfo: return False - logging.debug('Sysinfo output: %s', sysinfo) + logging.debug("Sysinfo output: %s", sysinfo) # The cr50 device id should be in the sysinfo output, if we found # the right console. Make sure it is return devid in sysinfo @@ -508,104 +520,137 @@ class RMAOpen(object): ValueError if the console can't be found with the given serialname """ usb_serial = self.find_cr50_usb(usb_serial) - logging.info('SERIALNAME: %s', usb_serial) - devid = '0x' + ' 0x'.join(usb_serial.lower().split('-')) - logging.info('DEVID: %s', devid) + logging.info("SERIALNAME: %s", usb_serial) + devid = "0x" + " 0x".join(usb_serial.lower().split("-")) + logging.info("DEVID: %s", devid) # Get all the usb devices - devices = glob.glob('/dev/ttyUSB*') + devices = glob.glob("/dev/ttyUSB*") # Typically Cr50 has the lowest number. Sort the devices, so we're more # likely to try the cr50 console first. devices.sort() # Find the one that is the cr50 console for device in devices: - logging.info('testing %s', device) + logging.info("testing %s", device) if self.device_matches_devid(devid, device): - logging.info('found device: %s', device) + logging.info("found device: %s", device) return logging.warning(DEBUG_CONNECTION) - raise ValueError('Found USB device, but could not communicate with ' - 'cr50 console') + raise ValueError( + "Found USB device, but could not communicate with " "cr50 console" + ) def print_platform_info(self): """Print the cr50 BID RLZ code""" - bid_output = self.send_cmd_get_output('bid') - bid = re.search(r'Board ID: (\S+?)[:,]', bid_output).group(1) + bid_output = self.send_cmd_get_output("bid") + bid = re.search(r"Board ID: (\S+?)[:,]", bid_output).group(1) if bid == ERASED_BID: logging.warning(DEBUG_ERASED_BOARD_ID) - raise ValueError('Cannot run RMA Open when board id is erased') + raise ValueError("Cannot run RMA Open when board id is erased") bid = int(bid, 16) - chrs = [chr((bid >> (8 * i)) & 0xff) for i in range(4)] - logging.info('RLZ: %s', ''.join(chrs[::-1])) + chrs = [chr((bid >> (8 * i)) & 0xFF) for i in range(4)] + logging.info("RLZ: %s", "".join(chrs[::-1])) @staticmethod def find_cr50_usb(usb_serial): """Make sure the Cr50 USB device exists""" try: - output = subprocess.check_output(CR50_LSUSB_CMD, encoding='utf-8') + output = subprocess.check_output(CR50_LSUSB_CMD, encoding="utf-8") except: logging.warning(DEBUG_MISSING_USB) - raise ValueError('Could not find Cr50 USB device') - serialnames = re.findall(r'iSerial +\d+ (\S+)\s', output) + raise ValueError("Could not find Cr50 USB device") + serialnames = re.findall(r"iSerial +\d+ (\S+)\s", output) if usb_serial: if usb_serial not in serialnames: logging.warning(DEBUG_SERIALNAME) raise ValueError('Could not find usb device "%s"' % usb_serial) return usb_serial if len(serialnames) > 1: - logging.info('Found Cr50 device serialnames %s', - ', '.join(serialnames)) + logging.info("Found Cr50 device serialnames %s", ", ".join(serialnames)) logging.warning(DEBUG_TOO_MANY_USB_DEVICES) - raise ValueError('Too many cr50 usb devices') + raise ValueError("Too many cr50 usb devices") return serialnames[0] def print_dut_state(self): """Print CCD RMA and testlab mode state.""" if not self.check(CCD_IS_UNRESTRICTED): - logging.info('CCD is still restricted.') - logging.info('Run cr50_rma_open.py -g -i $HWID to generate a url') - logging.info('Run cr50_rma_open.py -a $AUTHCODE to open cr50 with ' - 'an authcode') + logging.info("CCD is still restricted.") + logging.info("Run cr50_rma_open.py -g -i $HWID to generate a url") + logging.info( + "Run cr50_rma_open.py -a $AUTHCODE to open cr50 with " "an authcode" + ) elif not self.check(WP_IS_DISABLED): - logging.info('WP is still enabled.') - logging.info('Run cr50_rma_open.py -w to disable write protect') + logging.info("WP is still enabled.") + logging.info("Run cr50_rma_open.py -w to disable write protect") if self.check(RMA_OPENED): - logging.info('RMA Open complete') + logging.info("RMA Open complete") if not self.check(TESTLAB_IS_ENABLED) and self.is_prepvt: - logging.info('testlab mode is disabled.') - logging.info('If you are prepping a device for the testlab, you ' - 'should enable testlab mode.') - logging.info('Run cr50_rma_open.py -t to enable testlab mode') + logging.info("testlab mode is disabled.") + logging.info( + "If you are prepping a device for the testlab, you " + "should enable testlab mode." + ) + logging.info("Run cr50_rma_open.py -t to enable testlab mode") def parse_args(argv): """Get cr50_rma_open args.""" parser = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument('-g', '--generate_challenge', action='store_true', - help='Generate Cr50 challenge. Must be used with -i') - parser.add_argument('-t', '--enable_testlab', action='store_true', - help='enable testlab mode') - parser.add_argument('-w', '--wp_disable', action='store_true', - help='Disable write protect') - parser.add_argument('-c', '--check_connection', action='store_true', - help='Check cr50 console connection works') - parser.add_argument('-s', '--serialname', type=str, default='', - help='The cr50 usb serialname') - parser.add_argument('-D', '--debug', action='store_true', - help='print debug messages') - parser.add_argument('-d', '--device', type=str, default='', - help='cr50 console device ex /dev/ttyUSB0') - parser.add_argument('-i', '--hwid', type=str, default='', - help='The board hwid. Needed to generate a challenge') - parser.add_argument('-a', '--authcode', type=str, default='', - help='The authcode string from the challenge url') - parser.add_argument('-P', '--servo_port', type=str, default='', - help='the servo port') - parser.add_argument('-I', '--ip', type=str, default='', - help='The DUT IP. Necessary to do ccd open') + description=__doc__, formatter_class=argparse.RawTextHelpFormatter + ) + parser.add_argument( + "-g", + "--generate_challenge", + action="store_true", + help="Generate Cr50 challenge. Must be used with -i", + ) + parser.add_argument( + "-t", "--enable_testlab", action="store_true", help="enable testlab mode" + ) + parser.add_argument( + "-w", "--wp_disable", action="store_true", help="Disable write protect" + ) + parser.add_argument( + "-c", + "--check_connection", + action="store_true", + help="Check cr50 console connection works", + ) + parser.add_argument( + "-s", "--serialname", type=str, default="", help="The cr50 usb serialname" + ) + parser.add_argument( + "-D", "--debug", action="store_true", help="print debug messages" + ) + parser.add_argument( + "-d", + "--device", + type=str, + default="", + help="cr50 console device ex /dev/ttyUSB0", + ) + parser.add_argument( + "-i", + "--hwid", + type=str, + default="", + help="The board hwid. Needed to generate a challenge", + ) + parser.add_argument( + "-a", + "--authcode", + type=str, + default="", + help="The authcode string from the challenge url", + ) + parser.add_argument( + "-P", "--servo_port", type=str, default="", help="the servo port" + ) + parser.add_argument( + "-I", "--ip", type=str, default="", help="The DUT IP. Necessary to do ccd open" + ) return parser.parse_args(argv) @@ -614,52 +659,53 @@ def main(argv): opts = parse_args(argv) loglevel = logging.INFO - log_format = '%(levelname)7s' + log_format = "%(levelname)7s" if opts.debug: loglevel = logging.DEBUG - log_format += ' - %(lineno)3d:%(funcName)-15s' - log_format += ' - %(message)s' + log_format += " - %(lineno)3d:%(funcName)-15s" + log_format += " - %(message)s" logging.basicConfig(level=loglevel, format=log_format) tried_authcode = False - logging.info('Running cr50_rma_open version %s', SCRIPT_VERSION) + logging.info("Running cr50_rma_open version %s", SCRIPT_VERSION) - cr50_rma_open = RMAOpen(opts.device, opts.serialname, opts.servo_port, - opts.ip) + cr50_rma_open = RMAOpen(opts.device, opts.serialname, opts.servo_port, opts.ip) if opts.check_connection: sys.exit(0) if not cr50_rma_open.check(CCD_IS_UNRESTRICTED): if opts.generate_challenge: if not opts.hwid: - logging.warning('--hwid necessary to generate challenge url') + logging.warning("--hwid necessary to generate challenge url") sys.exit(0) cr50_rma_open.generate_challenge_url(opts.hwid) sys.exit(0) elif opts.authcode: - logging.info('Using authcode: %s', opts.authcode) + logging.info("Using authcode: %s", opts.authcode) cr50_rma_open.try_authcode(opts.authcode) tried_authcode = True - if not cr50_rma_open.check(WP_IS_DISABLED) and (tried_authcode or - opts.wp_disable): + if not cr50_rma_open.check(WP_IS_DISABLED) and (tried_authcode or opts.wp_disable): if not cr50_rma_open.check(CCD_IS_UNRESTRICTED): - raise ValueError("Can't disable write protect unless ccd is " - "open. Run through the rma open process first") + raise ValueError( + "Can't disable write protect unless ccd is " + "open. Run through the rma open process first" + ) if tried_authcode: - logging.warning('RMA Open did not disable write protect. File a ' - 'bug') - logging.warning('Trying to disable it manually') + logging.warning("RMA Open did not disable write protect. File a " "bug") + logging.warning("Trying to disable it manually") cr50_rma_open.wp_disable() if not cr50_rma_open.check(TESTLAB_IS_ENABLED) and opts.enable_testlab: if not cr50_rma_open.check(CCD_IS_UNRESTRICTED): - raise ValueError("Can't enable testlab mode unless ccd is open." - "Run through the rma open process first") + raise ValueError( + "Can't enable testlab mode unless ccd is open." + "Run through the rma open process first" + ) cr50_rma_open.enable_testlab() cr50_rma_open.print_dut_state() -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main(sys.argv[1:])) diff --git a/extra/stack_analyzer/stack_analyzer.py b/extra/stack_analyzer/stack_analyzer.py index 77d16d5450..17b2651972 100755 --- a/extra/stack_analyzer/stack_analyzer.py +++ b/extra/stack_analyzer/stack_analyzer.py @@ -25,1848 +25,1992 @@ import ctypes import os import re import subprocess -import yaml +import yaml -SECTION_RO = 'RO' -SECTION_RW = 'RW' +SECTION_RO = "RO" +SECTION_RW = "RW" # Default size of extra stack frame needed by exception context switch. # This value is for cortex-m with FPU enabled. DEFAULT_EXCEPTION_FRAME_SIZE = 224 class StackAnalyzerError(Exception): - """Exception class for stack analyzer utility.""" + """Exception class for stack analyzer utility.""" class TaskInfo(ctypes.Structure): - """Taskinfo ctypes structure. + """Taskinfo ctypes structure. - The structure definition is corresponding to the "struct taskinfo" - in "util/export_taskinfo.so.c". - """ - _fields_ = [('name', ctypes.c_char_p), - ('routine', ctypes.c_char_p), - ('stack_size', ctypes.c_uint32)] + The structure definition is corresponding to the "struct taskinfo" + in "util/export_taskinfo.so.c". + """ + _fields_ = [ + ("name", ctypes.c_char_p), + ("routine", ctypes.c_char_p), + ("stack_size", ctypes.c_uint32), + ] -class Task(object): - """Task information. - Attributes: - name: Task name. - routine_name: Routine function name. - stack_max_size: Max stack size. - routine_address: Resolved routine address. None if it hasn't been resolved. - """ - - def __init__(self, name, routine_name, stack_max_size, routine_address=None): - """Constructor. +class Task(object): + """Task information. - Args: + Attributes: name: Task name. routine_name: Routine function name. stack_max_size: Max stack size. - routine_address: Resolved routine address. + routine_address: Resolved routine address. None if it hasn't been resolved. """ - self.name = name - self.routine_name = routine_name - self.stack_max_size = stack_max_size - self.routine_address = routine_address - def __eq__(self, other): - """Task equality. + def __init__(self, name, routine_name, stack_max_size, routine_address=None): + """Constructor. - Args: - other: The compared object. + Args: + name: Task name. + routine_name: Routine function name. + stack_max_size: Max stack size. + routine_address: Resolved routine address. + """ + self.name = name + self.routine_name = routine_name + self.stack_max_size = stack_max_size + self.routine_address = routine_address - Returns: - True if equal, False if not. - """ - if not isinstance(other, Task): - return False + def __eq__(self, other): + """Task equality. - return (self.name == other.name and - self.routine_name == other.routine_name and - self.stack_max_size == other.stack_max_size and - self.routine_address == other.routine_address) + Args: + other: The compared object. + Returns: + True if equal, False if not. + """ + if not isinstance(other, Task): + return False -class Symbol(object): - """Symbol information. + return ( + self.name == other.name + and self.routine_name == other.routine_name + and self.stack_max_size == other.stack_max_size + and self.routine_address == other.routine_address + ) - Attributes: - address: Symbol address. - symtype: Symbol type, 'O' (data, object) or 'F' (function). - size: Symbol size. - name: Symbol name. - """ - def __init__(self, address, symtype, size, name): - """Constructor. +class Symbol(object): + """Symbol information. - Args: + Attributes: address: Symbol address. - symtype: Symbol type. + symtype: Symbol type, 'O' (data, object) or 'F' (function). size: Symbol size. name: Symbol name. """ - assert symtype in ['O', 'F'] - self.address = address - self.symtype = symtype - self.size = size - self.name = name - def __eq__(self, other): - """Symbol equality. - - Args: - other: The compared object. - - Returns: - True if equal, False if not. - """ - if not isinstance(other, Symbol): - return False - - return (self.address == other.address and - self.symtype == other.symtype and - self.size == other.size and - self.name == other.name) + def __init__(self, address, symtype, size, name): + """Constructor. + + Args: + address: Symbol address. + symtype: Symbol type. + size: Symbol size. + name: Symbol name. + """ + assert symtype in ["O", "F"] + self.address = address + self.symtype = symtype + self.size = size + self.name = name + + def __eq__(self, other): + """Symbol equality. + + Args: + other: The compared object. + + Returns: + True if equal, False if not. + """ + if not isinstance(other, Symbol): + return False + + return ( + self.address == other.address + and self.symtype == other.symtype + and self.size == other.size + and self.name == other.name + ) class Callsite(object): - """Function callsite. - - Attributes: - address: Address of callsite location. None if it is unknown. - target: Callee address. None if it is unknown. - is_tail: A bool indicates that it is a tailing call. - callee: Resolved callee function. None if it hasn't been resolved. - """ - - def __init__(self, address, target, is_tail, callee=None): - """Constructor. + """Function callsite. - Args: + Attributes: address: Address of callsite location. None if it is unknown. target: Callee address. None if it is unknown. - is_tail: A bool indicates that it is a tailing call. (function jump to - another function without restoring the stack frame) - callee: Resolved callee function. + is_tail: A bool indicates that it is a tailing call. + callee: Resolved callee function. None if it hasn't been resolved. """ - # It makes no sense that both address and target are unknown. - assert not (address is None and target is None) - self.address = address - self.target = target - self.is_tail = is_tail - self.callee = callee - - def __eq__(self, other): - """Callsite equality. - - Args: - other: The compared object. - Returns: - True if equal, False if not. - """ - if not isinstance(other, Callsite): - return False - - if not (self.address == other.address and - self.target == other.target and - self.is_tail == other.is_tail): - return False - - if self.callee is None: - return other.callee is None - elif other.callee is None: - return False - - # Assume the addresses of functions are unique. - return self.callee.address == other.callee.address + def __init__(self, address, target, is_tail, callee=None): + """Constructor. + + Args: + address: Address of callsite location. None if it is unknown. + target: Callee address. None if it is unknown. + is_tail: A bool indicates that it is a tailing call. (function jump to + another function without restoring the stack frame) + callee: Resolved callee function. + """ + # It makes no sense that both address and target are unknown. + assert not (address is None and target is None) + self.address = address + self.target = target + self.is_tail = is_tail + self.callee = callee + + def __eq__(self, other): + """Callsite equality. + + Args: + other: The compared object. + + Returns: + True if equal, False if not. + """ + if not isinstance(other, Callsite): + return False + + if not ( + self.address == other.address + and self.target == other.target + and self.is_tail == other.is_tail + ): + return False + + if self.callee is None: + return other.callee is None + elif other.callee is None: + return False + + # Assume the addresses of functions are unique. + return self.callee.address == other.callee.address class Function(object): - """Function. - - Attributes: - address: Address of function. - name: Name of function from its symbol. - stack_frame: Size of stack frame. - callsites: Callsite list. - stack_max_usage: Max stack usage. None if it hasn't been analyzed. - stack_max_path: Max stack usage path. None if it hasn't been analyzed. - """ + """Function. - def __init__(self, address, name, stack_frame, callsites): - """Constructor. - - Args: + Attributes: address: Address of function. name: Name of function from its symbol. stack_frame: Size of stack frame. callsites: Callsite list. + stack_max_usage: Max stack usage. None if it hasn't been analyzed. + stack_max_path: Max stack usage path. None if it hasn't been analyzed. """ - self.address = address - self.name = name - self.stack_frame = stack_frame - self.callsites = callsites - self.stack_max_usage = None - self.stack_max_path = None - - def __eq__(self, other): - """Function equality. - - Args: - other: The compared object. - - Returns: - True if equal, False if not. - """ - if not isinstance(other, Function): - return False - - if not (self.address == other.address and - self.name == other.name and - self.stack_frame == other.stack_frame and - self.callsites == other.callsites and - self.stack_max_usage == other.stack_max_usage): - return False - if self.stack_max_path is None: - return other.stack_max_path is None - elif other.stack_max_path is None: - return False + def __init__(self, address, name, stack_frame, callsites): + """Constructor. + + Args: + address: Address of function. + name: Name of function from its symbol. + stack_frame: Size of stack frame. + callsites: Callsite list. + """ + self.address = address + self.name = name + self.stack_frame = stack_frame + self.callsites = callsites + self.stack_max_usage = None + self.stack_max_path = None + + def __eq__(self, other): + """Function equality. + + Args: + other: The compared object. + + Returns: + True if equal, False if not. + """ + if not isinstance(other, Function): + return False + + if not ( + self.address == other.address + and self.name == other.name + and self.stack_frame == other.stack_frame + and self.callsites == other.callsites + and self.stack_max_usage == other.stack_max_usage + ): + return False + + if self.stack_max_path is None: + return other.stack_max_path is None + elif other.stack_max_path is None: + return False + + if len(self.stack_max_path) != len(other.stack_max_path): + return False + + for self_func, other_func in zip(self.stack_max_path, other.stack_max_path): + # Assume the addresses of functions are unique. + if self_func.address != other_func.address: + return False + + return True + + def __hash__(self): + return id(self) - if len(self.stack_max_path) != len(other.stack_max_path): - return False - - for self_func, other_func in zip(self.stack_max_path, other.stack_max_path): - # Assume the addresses of functions are unique. - if self_func.address != other_func.address: - return False - - return True - - def __hash__(self): - return id(self) class AndesAnalyzer(object): - """Disassembly analyzer for Andes architecture. - - Public Methods: - AnalyzeFunction: Analyze stack frame and callsites of the function. - """ - - GENERAL_PURPOSE_REGISTER_SIZE = 4 - - # Possible condition code suffixes. - CONDITION_CODES = [ 'eq', 'eqz', 'gez', 'gtz', 'lez', 'ltz', 'ne', 'nez', - 'eqc', 'nec', 'nezs', 'nes', 'eqs'] - CONDITION_CODES_RE = '({})'.format('|'.join(CONDITION_CODES)) - - IMM_ADDRESS_RE = r'([0-9A-Fa-f]+)\s+<([^>]+)>' - # Branch instructions. - JUMP_OPCODE_RE = re.compile(r'^(b{0}|j|jr|jr.|jrnez)(\d?|\d\d)$' \ - .format(CONDITION_CODES_RE)) - # Call instructions. - CALL_OPCODE_RE = re.compile \ - (r'^(jal|jral|jral.|jralnez|beqzal|bltzal|bgezal)(\d)?$') - CALL_OPERAND_RE = re.compile(r'^{}$'.format(IMM_ADDRESS_RE)) - # Ignore lp register because it's for return. - INDIRECT_CALL_OPERAND_RE = re.compile \ - (r'^\$r\d{1,}$|\$fp$|\$gp$|\$ta$|\$sp$|\$pc$') - # TODO: Handle other kinds of store instructions. - PUSH_OPCODE_RE = re.compile(r'^push(\d{1,})$') - PUSH_OPERAND_RE = re.compile(r'^\$r\d{1,}, \#\d{1,} \! \{([^\]]+)\}') - SMW_OPCODE_RE = re.compile(r'^smw(\.\w\w|\.\w\w\w)$') - SMW_OPERAND_RE = re.compile(r'^(\$r\d{1,}|\$\wp), \[\$\wp\], ' - r'(\$r\d{1,}|\$\wp), \#\d\w\d \! \{([^\]]+)\}') - OPERANDGROUP_RE = re.compile(r'^\$r\d{1,}\~\$r\d{1,}') - - LWI_OPCODE_RE = re.compile(r'^lwi(\.\w\w)$') - LWI_PC_OPERAND_RE = re.compile(r'^\$pc, \[([^\]]+)\]') - # Example: "34280: 3f c8 0f ec addi.gp $fp, #0xfec" - # Assume there is always a "\t" after the hex data. - DISASM_REGEX_RE = re.compile(r'^(?P<address>[0-9A-Fa-f]+):\s+' - r'(?P<words>[0-9A-Fa-f ]+)' - r'\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?') - - def ParseInstruction(self, line, function_end): - """Parse the line of instruction. + """Disassembly analyzer for Andes architecture. - Args: - line: Text of disassembly. - function_end: End address of the current function. None if unknown. - - Returns: - (address, words, opcode, operand_text): The instruction address, words, - opcode, and the text of operands. - None if it isn't an instruction line. + Public Methods: + AnalyzeFunction: Analyze stack frame and callsites of the function. """ - result = self.DISASM_REGEX_RE.match(line) - if result is None: - return None - - address = int(result.group('address'), 16) - # Check if it's out of bound. - if function_end is not None and address >= function_end: - return None - - opcode = result.group('opcode').strip() - operand_text = result.group('operand') - words = result.group('words') - if operand_text is None: - operand_text = '' - else: - operand_text = operand_text.strip() - - return (address, words, opcode, operand_text) - - def AnalyzeFunction(self, function_symbol, instructions): - - stack_frame = 0 - callsites = [] - for address, words, opcode, operand_text in instructions: - is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None - is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None - - if is_jump_opcode or is_call_opcode: - is_tail = is_jump_opcode - - result = self.CALL_OPERAND_RE.match(operand_text) + GENERAL_PURPOSE_REGISTER_SIZE = 4 + + # Possible condition code suffixes. + CONDITION_CODES = [ + "eq", + "eqz", + "gez", + "gtz", + "lez", + "ltz", + "ne", + "nez", + "eqc", + "nec", + "nezs", + "nes", + "eqs", + ] + CONDITION_CODES_RE = "({})".format("|".join(CONDITION_CODES)) + + IMM_ADDRESS_RE = r"([0-9A-Fa-f]+)\s+<([^>]+)>" + # Branch instructions. + JUMP_OPCODE_RE = re.compile( + r"^(b{0}|j|jr|jr.|jrnez)(\d?|\d\d)$".format(CONDITION_CODES_RE) + ) + # Call instructions. + CALL_OPCODE_RE = re.compile(r"^(jal|jral|jral.|jralnez|beqzal|bltzal|bgezal)(\d)?$") + CALL_OPERAND_RE = re.compile(r"^{}$".format(IMM_ADDRESS_RE)) + # Ignore lp register because it's for return. + INDIRECT_CALL_OPERAND_RE = re.compile(r"^\$r\d{1,}$|\$fp$|\$gp$|\$ta$|\$sp$|\$pc$") + # TODO: Handle other kinds of store instructions. + PUSH_OPCODE_RE = re.compile(r"^push(\d{1,})$") + PUSH_OPERAND_RE = re.compile(r"^\$r\d{1,}, \#\d{1,} \! \{([^\]]+)\}") + SMW_OPCODE_RE = re.compile(r"^smw(\.\w\w|\.\w\w\w)$") + SMW_OPERAND_RE = re.compile( + r"^(\$r\d{1,}|\$\wp), \[\$\wp\], " + r"(\$r\d{1,}|\$\wp), \#\d\w\d \! \{([^\]]+)\}" + ) + OPERANDGROUP_RE = re.compile(r"^\$r\d{1,}\~\$r\d{1,}") + + LWI_OPCODE_RE = re.compile(r"^lwi(\.\w\w)$") + LWI_PC_OPERAND_RE = re.compile(r"^\$pc, \[([^\]]+)\]") + # Example: "34280: 3f c8 0f ec addi.gp $fp, #0xfec" + # Assume there is always a "\t" after the hex data. + DISASM_REGEX_RE = re.compile( + r"^(?P<address>[0-9A-Fa-f]+):\s+" + r"(?P<words>[0-9A-Fa-f ]+)" + r"\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?" + ) + + def ParseInstruction(self, line, function_end): + """Parse the line of instruction. + + Args: + line: Text of disassembly. + function_end: End address of the current function. None if unknown. + + Returns: + (address, words, opcode, operand_text): The instruction address, words, + opcode, and the text of operands. + None if it isn't an instruction line. + """ + result = self.DISASM_REGEX_RE.match(line) if result is None: - if (self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None): - # Found an indirect call. - callsites.append(Callsite(address, None, is_tail)) - + return None + + address = int(result.group("address"), 16) + # Check if it's out of bound. + if function_end is not None and address >= function_end: + return None + + opcode = result.group("opcode").strip() + operand_text = result.group("operand") + words = result.group("words") + if operand_text is None: + operand_text = "" else: - target_address = int(result.group(1), 16) - # Filter out the in-function target (branches and in-function calls, - # which are actually branches). - if not (function_symbol.size > 0 and - function_symbol.address < target_address < - (function_symbol.address + function_symbol.size)): - # Maybe it is a callsite. - callsites.append(Callsite(address, target_address, is_tail)) - - elif self.LWI_OPCODE_RE.match(opcode) is not None: - result = self.LWI_PC_OPERAND_RE.match(operand_text) - if result is not None: - # Ignore "lwi $pc, [$sp], xx" because it's usually a return. - if result.group(1) != '$sp': - # Found an indirect call. - callsites.append(Callsite(address, None, True)) - - elif self.PUSH_OPCODE_RE.match(opcode) is not None: - # Example: fc 20 push25 $r8, #0 ! {$r6~$r8, $fp, $gp, $lp} - if self.PUSH_OPERAND_RE.match(operand_text) is not None: - # capture fc 20 - imm5u = int(words.split(' ')[1], 16) - # sp = sp - (imm5u << 3) - imm8u = (imm5u<<3) & 0xff - stack_frame += imm8u - - result = self.PUSH_OPERAND_RE.match(operand_text) - operandgroup_text = result.group(1) - # capture $rx~$ry - if self.OPERANDGROUP_RE.match(operandgroup_text) is not None: - # capture number & transfer string to integer - oprandgrouphead = operandgroup_text.split(',')[0] - rx=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[0]))) - ry=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[1]))) - - stack_frame += ((len(operandgroup_text.split(','))+ry-rx) * - self.GENERAL_PURPOSE_REGISTER_SIZE) - else: - stack_frame += (len(operandgroup_text.split(',')) * - self.GENERAL_PURPOSE_REGISTER_SIZE) - - elif self.SMW_OPCODE_RE.match(opcode) is not None: - # Example: smw.adm $r6, [$sp], $r10, #0x2 ! {$r6~$r10, $lp} - if self.SMW_OPERAND_RE.match(operand_text) is not None: - result = self.SMW_OPERAND_RE.match(operand_text) - operandgroup_text = result.group(3) - # capture $rx~$ry - if self.OPERANDGROUP_RE.match(operandgroup_text) is not None: - # capture number & transfer string to integer - oprandgrouphead = operandgroup_text.split(',')[0] - rx=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[0]))) - ry=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[1]))) - - stack_frame += ((len(operandgroup_text.split(','))+ry-rx) * - self.GENERAL_PURPOSE_REGISTER_SIZE) - else: - stack_frame += (len(operandgroup_text.split(',')) * - self.GENERAL_PURPOSE_REGISTER_SIZE) - - return (stack_frame, callsites) - -class ArmAnalyzer(object): - """Disassembly analyzer for ARM architecture. - - Public Methods: - AnalyzeFunction: Analyze stack frame and callsites of the function. - """ - - GENERAL_PURPOSE_REGISTER_SIZE = 4 - - # Possible condition code suffixes. - CONDITION_CODES = ['', 'eq', 'ne', 'cs', 'hs', 'cc', 'lo', 'mi', 'pl', 'vs', - 'vc', 'hi', 'ls', 'ge', 'lt', 'gt', 'le'] - CONDITION_CODES_RE = '({})'.format('|'.join(CONDITION_CODES)) - # Assume there is no function name containing ">". - IMM_ADDRESS_RE = r'([0-9A-Fa-f]+)\s+<([^>]+)>' - - # Fuzzy regular expressions for instruction and operand parsing. - # Branch instructions. - JUMP_OPCODE_RE = re.compile( - r'^(b{0}|bx{0})(\.\w)?$'.format(CONDITION_CODES_RE)) - # Call instructions. - CALL_OPCODE_RE = re.compile( - r'^(bl{0}|blx{0})(\.\w)?$'.format(CONDITION_CODES_RE)) - CALL_OPERAND_RE = re.compile(r'^{}$'.format(IMM_ADDRESS_RE)) - CBZ_CBNZ_OPCODE_RE = re.compile(r'^(cbz|cbnz)(\.\w)?$') - # Example: "r0, 1009bcbe <host_cmd_motion_sense+0x1d2>" - CBZ_CBNZ_OPERAND_RE = re.compile(r'^[^,]+,\s+{}$'.format(IMM_ADDRESS_RE)) - # Ignore lr register because it's for return. - INDIRECT_CALL_OPERAND_RE = re.compile(r'^r\d+|sb|sl|fp|ip|sp|pc$') - # TODO(cheyuw): Handle conditional versions of following - # instructions. - # TODO(cheyuw): Handle other kinds of pc modifying instructions (e.g. mov pc). - LDR_OPCODE_RE = re.compile(r'^ldr(\.\w)?$') - # Example: "pc, [sp], #4" - LDR_PC_OPERAND_RE = re.compile(r'^pc, \[([^\]]+)\]') - # TODO(cheyuw): Handle other kinds of stm instructions. - PUSH_OPCODE_RE = re.compile(r'^push$') - STM_OPCODE_RE = re.compile(r'^stmdb$') - # Stack subtraction instructions. - SUB_OPCODE_RE = re.compile(r'^sub(s|w)?(\.\w)?$') - SUB_OPERAND_RE = re.compile(r'^sp[^#]+#(\d+)') - # Example: "44d94: f893 0068 ldrb.w r0, [r3, #104] ; 0x68" - # Assume there is always a "\t" after the hex data. - DISASM_REGEX_RE = re.compile(r'^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+' - r'\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?') - - def ParseInstruction(self, line, function_end): - """Parse the line of instruction. + operand_text = operand_text.strip() + + return (address, words, opcode, operand_text) + + def AnalyzeFunction(self, function_symbol, instructions): + + stack_frame = 0 + callsites = [] + for address, words, opcode, operand_text in instructions: + is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None + is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None + + if is_jump_opcode or is_call_opcode: + is_tail = is_jump_opcode + + result = self.CALL_OPERAND_RE.match(operand_text) + + if result is None: + if self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None: + # Found an indirect call. + callsites.append(Callsite(address, None, is_tail)) + + else: + target_address = int(result.group(1), 16) + # Filter out the in-function target (branches and in-function calls, + # which are actually branches). + if not ( + function_symbol.size > 0 + and function_symbol.address + < target_address + < (function_symbol.address + function_symbol.size) + ): + # Maybe it is a callsite. + callsites.append(Callsite(address, target_address, is_tail)) + + elif self.LWI_OPCODE_RE.match(opcode) is not None: + result = self.LWI_PC_OPERAND_RE.match(operand_text) + if result is not None: + # Ignore "lwi $pc, [$sp], xx" because it's usually a return. + if result.group(1) != "$sp": + # Found an indirect call. + callsites.append(Callsite(address, None, True)) + + elif self.PUSH_OPCODE_RE.match(opcode) is not None: + # Example: fc 20 push25 $r8, #0 ! {$r6~$r8, $fp, $gp, $lp} + if self.PUSH_OPERAND_RE.match(operand_text) is not None: + # capture fc 20 + imm5u = int(words.split(" ")[1], 16) + # sp = sp - (imm5u << 3) + imm8u = (imm5u << 3) & 0xFF + stack_frame += imm8u + + result = self.PUSH_OPERAND_RE.match(operand_text) + operandgroup_text = result.group(1) + # capture $rx~$ry + if self.OPERANDGROUP_RE.match(operandgroup_text) is not None: + # capture number & transfer string to integer + oprandgrouphead = operandgroup_text.split(",")[0] + rx = int( + "".join(filter(str.isdigit, oprandgrouphead.split("~")[0])) + ) + ry = int( + "".join(filter(str.isdigit, oprandgrouphead.split("~")[1])) + ) + + stack_frame += ( + len(operandgroup_text.split(",")) + ry - rx + ) * self.GENERAL_PURPOSE_REGISTER_SIZE + else: + stack_frame += ( + len(operandgroup_text.split(",")) + * self.GENERAL_PURPOSE_REGISTER_SIZE + ) + + elif self.SMW_OPCODE_RE.match(opcode) is not None: + # Example: smw.adm $r6, [$sp], $r10, #0x2 ! {$r6~$r10, $lp} + if self.SMW_OPERAND_RE.match(operand_text) is not None: + result = self.SMW_OPERAND_RE.match(operand_text) + operandgroup_text = result.group(3) + # capture $rx~$ry + if self.OPERANDGROUP_RE.match(operandgroup_text) is not None: + # capture number & transfer string to integer + oprandgrouphead = operandgroup_text.split(",")[0] + rx = int( + "".join(filter(str.isdigit, oprandgrouphead.split("~")[0])) + ) + ry = int( + "".join(filter(str.isdigit, oprandgrouphead.split("~")[1])) + ) + + stack_frame += ( + len(operandgroup_text.split(",")) + ry - rx + ) * self.GENERAL_PURPOSE_REGISTER_SIZE + else: + stack_frame += ( + len(operandgroup_text.split(",")) + * self.GENERAL_PURPOSE_REGISTER_SIZE + ) + + return (stack_frame, callsites) - Args: - line: Text of disassembly. - function_end: End address of the current function. None if unknown. - Returns: - (address, opcode, operand_text): The instruction address, opcode, - and the text of operands. None if it - isn't an instruction line. - """ - result = self.DISASM_REGEX_RE.match(line) - if result is None: - return None - - address = int(result.group('address'), 16) - # Check if it's out of bound. - if function_end is not None and address >= function_end: - return None - - opcode = result.group('opcode').strip() - operand_text = result.group('operand') - if operand_text is None: - operand_text = '' - else: - operand_text = operand_text.strip() - - return (address, opcode, operand_text) - - def AnalyzeFunction(self, function_symbol, instructions): - """Analyze function, resolve the size of stack frame and callsites. - - Args: - function_symbol: Function symbol. - instructions: Instruction list. +class ArmAnalyzer(object): + """Disassembly analyzer for ARM architecture. - Returns: - (stack_frame, callsites): Size of stack frame, callsite list. + Public Methods: + AnalyzeFunction: Analyze stack frame and callsites of the function. """ - stack_frame = 0 - callsites = [] - for address, opcode, operand_text in instructions: - is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None - is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None - is_cbz_cbnz_opcode = self.CBZ_CBNZ_OPCODE_RE.match(opcode) is not None - if is_jump_opcode or is_call_opcode or is_cbz_cbnz_opcode: - is_tail = is_jump_opcode or is_cbz_cbnz_opcode - - if is_cbz_cbnz_opcode: - result = self.CBZ_CBNZ_OPERAND_RE.match(operand_text) - else: - result = self.CALL_OPERAND_RE.match(operand_text) + GENERAL_PURPOSE_REGISTER_SIZE = 4 + + # Possible condition code suffixes. + CONDITION_CODES = [ + "", + "eq", + "ne", + "cs", + "hs", + "cc", + "lo", + "mi", + "pl", + "vs", + "vc", + "hi", + "ls", + "ge", + "lt", + "gt", + "le", + ] + CONDITION_CODES_RE = "({})".format("|".join(CONDITION_CODES)) + # Assume there is no function name containing ">". + IMM_ADDRESS_RE = r"([0-9A-Fa-f]+)\s+<([^>]+)>" + + # Fuzzy regular expressions for instruction and operand parsing. + # Branch instructions. + JUMP_OPCODE_RE = re.compile(r"^(b{0}|bx{0})(\.\w)?$".format(CONDITION_CODES_RE)) + # Call instructions. + CALL_OPCODE_RE = re.compile(r"^(bl{0}|blx{0})(\.\w)?$".format(CONDITION_CODES_RE)) + CALL_OPERAND_RE = re.compile(r"^{}$".format(IMM_ADDRESS_RE)) + CBZ_CBNZ_OPCODE_RE = re.compile(r"^(cbz|cbnz)(\.\w)?$") + # Example: "r0, 1009bcbe <host_cmd_motion_sense+0x1d2>" + CBZ_CBNZ_OPERAND_RE = re.compile(r"^[^,]+,\s+{}$".format(IMM_ADDRESS_RE)) + # Ignore lr register because it's for return. + INDIRECT_CALL_OPERAND_RE = re.compile(r"^r\d+|sb|sl|fp|ip|sp|pc$") + # TODO(cheyuw): Handle conditional versions of following + # instructions. + # TODO(cheyuw): Handle other kinds of pc modifying instructions (e.g. mov pc). + LDR_OPCODE_RE = re.compile(r"^ldr(\.\w)?$") + # Example: "pc, [sp], #4" + LDR_PC_OPERAND_RE = re.compile(r"^pc, \[([^\]]+)\]") + # TODO(cheyuw): Handle other kinds of stm instructions. + PUSH_OPCODE_RE = re.compile(r"^push$") + STM_OPCODE_RE = re.compile(r"^stmdb$") + # Stack subtraction instructions. + SUB_OPCODE_RE = re.compile(r"^sub(s|w)?(\.\w)?$") + SUB_OPERAND_RE = re.compile(r"^sp[^#]+#(\d+)") + # Example: "44d94: f893 0068 ldrb.w r0, [r3, #104] ; 0x68" + # Assume there is always a "\t" after the hex data. + DISASM_REGEX_RE = re.compile( + r"^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+" + r"\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?" + ) + + def ParseInstruction(self, line, function_end): + """Parse the line of instruction. + + Args: + line: Text of disassembly. + function_end: End address of the current function. None if unknown. + + Returns: + (address, opcode, operand_text): The instruction address, opcode, + and the text of operands. None if it + isn't an instruction line. + """ + result = self.DISASM_REGEX_RE.match(line) if result is None: - # Failed to match immediate address, maybe it is an indirect call. - # CBZ and CBNZ can't be indirect calls. - if (not is_cbz_cbnz_opcode and - self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None): - # Found an indirect call. - callsites.append(Callsite(address, None, is_tail)) + return None - else: - target_address = int(result.group(1), 16) - # Filter out the in-function target (branches and in-function calls, - # which are actually branches). - if not (function_symbol.size > 0 and - function_symbol.address < target_address < - (function_symbol.address + function_symbol.size)): - # Maybe it is a callsite. - callsites.append(Callsite(address, target_address, is_tail)) - - elif self.LDR_OPCODE_RE.match(opcode) is not None: - result = self.LDR_PC_OPERAND_RE.match(operand_text) - if result is not None: - # Ignore "ldr pc, [sp], xx" because it's usually a return. - if result.group(1) != 'sp': - # Found an indirect call. - callsites.append(Callsite(address, None, True)) - - elif self.PUSH_OPCODE_RE.match(opcode) is not None: - # Example: "{r4, r5, r6, r7, lr}" - stack_frame += (len(operand_text.split(',')) * - self.GENERAL_PURPOSE_REGISTER_SIZE) - elif self.SUB_OPCODE_RE.match(opcode) is not None: - result = self.SUB_OPERAND_RE.match(operand_text) - if result is not None: - stack_frame += int(result.group(1)) - else: - # Unhandled stack register subtraction. - assert not operand_text.startswith('sp') + address = int(result.group("address"), 16) + # Check if it's out of bound. + if function_end is not None and address >= function_end: + return None - elif self.STM_OPCODE_RE.match(opcode) is not None: - if operand_text.startswith('sp!'): - # Subtract and writeback to stack register. - # Example: "sp!, {r4, r5, r6, r7, r8, r9, lr}" - # Get the text of pushed register list. - unused_sp, unused_sep, parameter_text = operand_text.partition(',') - stack_frame += (len(parameter_text.split(',')) * - self.GENERAL_PURPOSE_REGISTER_SIZE) + opcode = result.group("opcode").strip() + operand_text = result.group("operand") + if operand_text is None: + operand_text = "" + else: + operand_text = operand_text.strip() + + return (address, opcode, operand_text) + + def AnalyzeFunction(self, function_symbol, instructions): + """Analyze function, resolve the size of stack frame and callsites. + + Args: + function_symbol: Function symbol. + instructions: Instruction list. + + Returns: + (stack_frame, callsites): Size of stack frame, callsite list. + """ + stack_frame = 0 + callsites = [] + for address, opcode, operand_text in instructions: + is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None + is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None + is_cbz_cbnz_opcode = self.CBZ_CBNZ_OPCODE_RE.match(opcode) is not None + if is_jump_opcode or is_call_opcode or is_cbz_cbnz_opcode: + is_tail = is_jump_opcode or is_cbz_cbnz_opcode + + if is_cbz_cbnz_opcode: + result = self.CBZ_CBNZ_OPERAND_RE.match(operand_text) + else: + result = self.CALL_OPERAND_RE.match(operand_text) + + if result is None: + # Failed to match immediate address, maybe it is an indirect call. + # CBZ and CBNZ can't be indirect calls. + if ( + not is_cbz_cbnz_opcode + and self.INDIRECT_CALL_OPERAND_RE.match(operand_text) + is not None + ): + # Found an indirect call. + callsites.append(Callsite(address, None, is_tail)) + + else: + target_address = int(result.group(1), 16) + # Filter out the in-function target (branches and in-function calls, + # which are actually branches). + if not ( + function_symbol.size > 0 + and function_symbol.address + < target_address + < (function_symbol.address + function_symbol.size) + ): + # Maybe it is a callsite. + callsites.append(Callsite(address, target_address, is_tail)) + + elif self.LDR_OPCODE_RE.match(opcode) is not None: + result = self.LDR_PC_OPERAND_RE.match(operand_text) + if result is not None: + # Ignore "ldr pc, [sp], xx" because it's usually a return. + if result.group(1) != "sp": + # Found an indirect call. + callsites.append(Callsite(address, None, True)) + + elif self.PUSH_OPCODE_RE.match(opcode) is not None: + # Example: "{r4, r5, r6, r7, lr}" + stack_frame += ( + len(operand_text.split(",")) * self.GENERAL_PURPOSE_REGISTER_SIZE + ) + elif self.SUB_OPCODE_RE.match(opcode) is not None: + result = self.SUB_OPERAND_RE.match(operand_text) + if result is not None: + stack_frame += int(result.group(1)) + else: + # Unhandled stack register subtraction. + assert not operand_text.startswith("sp") + + elif self.STM_OPCODE_RE.match(opcode) is not None: + if operand_text.startswith("sp!"): + # Subtract and writeback to stack register. + # Example: "sp!, {r4, r5, r6, r7, r8, r9, lr}" + # Get the text of pushed register list. + unused_sp, unused_sep, parameter_text = operand_text.partition(",") + stack_frame += ( + len(parameter_text.split(",")) + * self.GENERAL_PURPOSE_REGISTER_SIZE + ) + + return (stack_frame, callsites) - return (stack_frame, callsites) class RiscvAnalyzer(object): - """Disassembly analyzer for RISC-V architecture. - - Public Methods: - AnalyzeFunction: Analyze stack frame and callsites of the function. - """ - - # Possible condition code suffixes. - CONDITION_CODES = [ 'eqz', 'nez', 'lez', 'gez', 'ltz', 'gtz', 'gt', 'le', - 'gtu', 'leu', 'eq', 'ne', 'ge', 'lt', 'ltu', 'geu'] - CONDITION_CODES_RE = '({})'.format('|'.join(CONDITION_CODES)) - # Branch instructions. - JUMP_OPCODE_RE = re.compile(r'^(b{0}|j|jr)$'.format(CONDITION_CODES_RE)) - # Call instructions. - CALL_OPCODE_RE = re.compile(r'^(jal|jalr)$') - # Example: "j 8009b318 <set_state_prl_hr>" or - # "jal ra,800a4394 <power_get_signals>" or - # "bltu t0,t1,80080300 <data_loop>" - JUMP_ADDRESS_RE = r'((\w(\w|\d\d),){0,2})([0-9A-Fa-f]+)\s+<([^>]+)>' - CALL_OPERAND_RE = re.compile(r'^{}$'.format(JUMP_ADDRESS_RE)) - # Capture address, Example: 800a4394 - CAPTURE_ADDRESS = re.compile(r'[0-9A-Fa-f]{8}') - # Indirect jump, Example: jalr a5 - INDIRECT_CALL_OPERAND_RE = re.compile(r'^t\d+|s\d+|a\d+$') - # Example: addi - ADDI_OPCODE_RE = re.compile(r'^addi$') - # Allocate stack instructions. - ADDI_OPERAND_RE = re.compile(r'^(sp,sp,-\d+)$') - # Example: "800804b6: 1101 addi sp,sp,-32" - DISASM_REGEX_RE = re.compile(r'^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+' - r'\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?') - - def ParseInstruction(self, line, function_end): - """Parse the line of instruction. + """Disassembly analyzer for RISC-V architecture. - Args: - line: Text of disassembly. - function_end: End address of the current function. None if unknown. - - Returns: - (address, opcode, operand_text): The instruction address, opcode, - and the text of operands. None if it - isn't an instruction line. + Public Methods: + AnalyzeFunction: Analyze stack frame and callsites of the function. """ - result = self.DISASM_REGEX_RE.match(line) - if result is None: - return None - - address = int(result.group('address'), 16) - # Check if it's out of bound. - if function_end is not None and address >= function_end: - return None - - opcode = result.group('opcode').strip() - operand_text = result.group('operand') - if operand_text is None: - operand_text = '' - else: - operand_text = operand_text.strip() - - return (address, opcode, operand_text) - - def AnalyzeFunction(self, function_symbol, instructions): - - stack_frame = 0 - callsites = [] - for address, opcode, operand_text in instructions: - is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None - is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None - if is_jump_opcode or is_call_opcode: - is_tail = is_jump_opcode - - result = self.CALL_OPERAND_RE.match(operand_text) + # Possible condition code suffixes. + CONDITION_CODES = [ + "eqz", + "nez", + "lez", + "gez", + "ltz", + "gtz", + "gt", + "le", + "gtu", + "leu", + "eq", + "ne", + "ge", + "lt", + "ltu", + "geu", + ] + CONDITION_CODES_RE = "({})".format("|".join(CONDITION_CODES)) + # Branch instructions. + JUMP_OPCODE_RE = re.compile(r"^(b{0}|j|jr)$".format(CONDITION_CODES_RE)) + # Call instructions. + CALL_OPCODE_RE = re.compile(r"^(jal|jalr)$") + # Example: "j 8009b318 <set_state_prl_hr>" or + # "jal ra,800a4394 <power_get_signals>" or + # "bltu t0,t1,80080300 <data_loop>" + JUMP_ADDRESS_RE = r"((\w(\w|\d\d),){0,2})([0-9A-Fa-f]+)\s+<([^>]+)>" + CALL_OPERAND_RE = re.compile(r"^{}$".format(JUMP_ADDRESS_RE)) + # Capture address, Example: 800a4394 + CAPTURE_ADDRESS = re.compile(r"[0-9A-Fa-f]{8}") + # Indirect jump, Example: jalr a5 + INDIRECT_CALL_OPERAND_RE = re.compile(r"^t\d+|s\d+|a\d+$") + # Example: addi + ADDI_OPCODE_RE = re.compile(r"^addi$") + # Allocate stack instructions. + ADDI_OPERAND_RE = re.compile(r"^(sp,sp,-\d+)$") + # Example: "800804b6: 1101 addi sp,sp,-32" + DISASM_REGEX_RE = re.compile( + r"^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+" + r"\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?" + ) + + def ParseInstruction(self, line, function_end): + """Parse the line of instruction. + + Args: + line: Text of disassembly. + function_end: End address of the current function. None if unknown. + + Returns: + (address, opcode, operand_text): The instruction address, opcode, + and the text of operands. None if it + isn't an instruction line. + """ + result = self.DISASM_REGEX_RE.match(line) if result is None: - if (self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None): - # Found an indirect call. - callsites.append(Callsite(address, None, is_tail)) - - else: - # Capture address form operand_text and then convert to string - address_str = "".join(self.CAPTURE_ADDRESS.findall(operand_text)) - # String to integer - target_address = int(address_str, 16) - # Filter out the in-function target (branches and in-function calls, - # which are actually branches). - if not (function_symbol.size > 0 and - function_symbol.address < target_address < - (function_symbol.address + function_symbol.size)): - # Maybe it is a callsite. - callsites.append(Callsite(address, target_address, is_tail)) - - elif self.ADDI_OPCODE_RE.match(opcode) is not None: - # Example: sp,sp,-32 - if self.ADDI_OPERAND_RE.match(operand_text) is not None: - stack_frame += abs(int(operand_text.split(",")[2])) - - return (stack_frame, callsites) - -class StackAnalyzer(object): - """Class to analyze stack usage. - - Public Methods: - Analyze: Run the stack analysis. - """ - - C_FUNCTION_NAME = r'_A-Za-z0-9' - - # Assume there is no ":" in the path. - # Example: "driver/accel_kionix.c:321 (discriminator 3)" - ADDRTOLINE_RE = re.compile( - r'^(?P<path>[^:]+):(?P<linenum>\d+)(\s+\(discriminator\s+\d+\))?$') - # To eliminate the suffix appended by compilers, try to extract the - # C function name from the prefix of symbol name. - # Example: "SHA256_transform.constprop.28" - FUNCTION_PREFIX_NAME_RE = re.compile( - r'^(?P<name>[{0}]+)([^{0}].*)?$'.format(C_FUNCTION_NAME)) - - # Errors of annotation resolving. - ANNOTATION_ERROR_INVALID = 'invalid signature' - ANNOTATION_ERROR_NOTFOUND = 'function is not found' - ANNOTATION_ERROR_AMBIGUOUS = 'signature is ambiguous' - - def __init__(self, options, symbols, rodata, tasklist, annotation): - """Constructor. - - Args: - options: Namespace from argparse.parse_args(). - symbols: Symbol list. - rodata: Content of .rodata section (offset, data) - tasklist: Task list. - annotation: Annotation config. - """ - self.options = options - self.symbols = symbols - self.rodata_offset = rodata[0] - self.rodata = rodata[1] - self.tasklist = tasklist - self.annotation = annotation - self.address_to_line_cache = {} - - def AddressToLine(self, address, resolve_inline=False): - """Convert address to line. - - Args: - address: Target address. - resolve_inline: Output the stack of inlining. - - Returns: - lines: List of the corresponding lines. - - Raises: - StackAnalyzerError: If addr2line is failed. - """ - cache_key = (address, resolve_inline) - if cache_key in self.address_to_line_cache: - return self.address_to_line_cache[cache_key] - - try: - args = [self.options.addr2line, - '-f', - '-e', - self.options.elf_path, - '{:x}'.format(address)] - if resolve_inline: - args.append('-i') - - line_text = subprocess.check_output(args, encoding='utf-8') - except subprocess.CalledProcessError: - raise StackAnalyzerError('addr2line failed to resolve lines.') - except OSError: - raise StackAnalyzerError('Failed to run addr2line.') - - lines = [line.strip() for line in line_text.splitlines()] - # Assume the output has at least one pair like "function\nlocation\n", and - # they always show up in pairs. - # Example: "handle_request\n - # common/usb_pd_protocol.c:1191\n" - assert len(lines) >= 2 and len(lines) % 2 == 0 - - line_infos = [] - for index in range(0, len(lines), 2): - (function_name, line_text) = lines[index:index + 2] - if line_text in ['??:0', ':?']: - line_infos.append(None) - else: - result = self.ADDRTOLINE_RE.match(line_text) - # Assume the output is always well-formed. - assert result is not None - line_infos.append((function_name.strip(), - os.path.realpath(result.group('path').strip()), - int(result.group('linenum')))) - - self.address_to_line_cache[cache_key] = line_infos - return line_infos - - def AnalyzeDisassembly(self, disasm_text): - """Parse the disassembly text, analyze, and build a map of all functions. - - Args: - disasm_text: Disassembly text. - - Returns: - function_map: Dict of functions. - """ - disasm_lines = [line.strip() for line in disasm_text.splitlines()] - - if 'nds' in disasm_lines[1]: - analyzer = AndesAnalyzer() - elif 'arm' in disasm_lines[1]: - analyzer = ArmAnalyzer() - elif 'riscv' in disasm_lines[1]: - analyzer = RiscvAnalyzer() - else: - raise StackAnalyzerError('Unsupported architecture.') - - # Example: "08028c8c <motion_lid_calc>:" - function_signature_regex = re.compile( - r'^(?P<address>[0-9A-Fa-f]+)\s+<(?P<name>[^>]+)>:$') - - def DetectFunctionHead(line): - """Check if the line is a function head. - - Args: - line: Text of disassembly. - - Returns: - symbol: Function symbol. None if it isn't a function head. - """ - result = function_signature_regex.match(line) - if result is None: - return None - - address = int(result.group('address'), 16) - symbol = symbol_map.get(address) - - # Check if the function exists and matches. - if symbol is None or symbol.symtype != 'F': - return None - - return symbol - - # Build symbol map, indexed by symbol address. - symbol_map = {} - for symbol in self.symbols: - # If there are multiple symbols with same address, keeping any of them is - # good enough. - symbol_map[symbol.address] = symbol - - # Parse the disassembly text. We update the variable "line" to next line - # when needed. There are two steps of parser: - # - # Step 1: Searching for the function head. Once reach the function head, - # move to the next line, which is the first line of function body. - # - # Step 2: Parsing each instruction line of function body. Once reach a - # non-instruction line, stop parsing and analyze the parsed instructions. - # - # Finally turn back to the step 1 without updating the line, because the - # current non-instruction line can be another function head. - function_map = {} - # The following three variables are the states of the parsing processing. - # They will be initialized properly during the state changes. - function_symbol = None - function_end = None - instructions = [] - - # Remove heading and tailing spaces for each line. - line_index = 0 - while line_index < len(disasm_lines): - # Get the current line. - line = disasm_lines[line_index] - - if function_symbol is None: - # Step 1: Search for the function head. - - function_symbol = DetectFunctionHead(line) - if function_symbol is not None: - # Assume there is no empty function. If the function head is followed - # by EOF, it is an empty function. - assert line_index + 1 < len(disasm_lines) - - # Found the function head, initialize and turn to the step 2. - instructions = [] - # If symbol size exists, use it as a hint of function size. - if function_symbol.size > 0: - function_end = function_symbol.address + function_symbol.size - else: - function_end = None - - else: - # Step 2: Parse the function body. - - instruction = analyzer.ParseInstruction(line, function_end) - if instruction is not None: - instructions.append(instruction) - - if instruction is None or line_index + 1 == len(disasm_lines): - # Either the invalid instruction or EOF indicates the end of the - # function, finalize the function analysis. - - # Assume there is no empty function. - assert len(instructions) > 0 - - (stack_frame, callsites) = analyzer.AnalyzeFunction(function_symbol, - instructions) - # Assume the function addresses are unique in the disassembly. - assert function_symbol.address not in function_map - function_map[function_symbol.address] = Function( - function_symbol.address, - function_symbol.name, - stack_frame, - callsites) - - # Initialize and turn back to the step 1. - function_symbol = None - - # If the current line isn't an instruction, it can be another function - # head, skip moving to the next line. - if instruction is None: - continue - - # Move to the next line. - line_index += 1 + return None - # Resolve callees of functions. - for function in function_map.values(): - for callsite in function.callsites: - if callsite.target is not None: - # Remain the callee as None if we can't resolve it. - callsite.callee = function_map.get(callsite.target) + address = int(result.group("address"), 16) + # Check if it's out of bound. + if function_end is not None and address >= function_end: + return None - return function_map + opcode = result.group("opcode").strip() + operand_text = result.group("operand") + if operand_text is None: + operand_text = "" + else: + operand_text = operand_text.strip() + + return (address, opcode, operand_text) + + def AnalyzeFunction(self, function_symbol, instructions): + + stack_frame = 0 + callsites = [] + for address, opcode, operand_text in instructions: + is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None + is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None + + if is_jump_opcode or is_call_opcode: + is_tail = is_jump_opcode + + result = self.CALL_OPERAND_RE.match(operand_text) + if result is None: + if self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None: + # Found an indirect call. + callsites.append(Callsite(address, None, is_tail)) + + else: + # Capture address form operand_text and then convert to string + address_str = "".join(self.CAPTURE_ADDRESS.findall(operand_text)) + # String to integer + target_address = int(address_str, 16) + # Filter out the in-function target (branches and in-function calls, + # which are actually branches). + if not ( + function_symbol.size > 0 + and function_symbol.address + < target_address + < (function_symbol.address + function_symbol.size) + ): + # Maybe it is a callsite. + callsites.append(Callsite(address, target_address, is_tail)) + + elif self.ADDI_OPCODE_RE.match(opcode) is not None: + # Example: sp,sp,-32 + if self.ADDI_OPERAND_RE.match(operand_text) is not None: + stack_frame += abs(int(operand_text.split(",")[2])) + + return (stack_frame, callsites) - def MapAnnotation(self, function_map, signature_set): - """Map annotation signatures to functions. - Args: - function_map: Function map. - signature_set: Set of annotation signatures. +class StackAnalyzer(object): + """Class to analyze stack usage. - Returns: - Map of signatures to functions, map of signatures which can't be resolved. + Public Methods: + Analyze: Run the stack analysis. """ - # Build the symbol map indexed by symbol name. If there are multiple symbols - # with the same name, add them into a set. (e.g. symbols of static function - # with the same name) - symbol_map = collections.defaultdict(set) - for symbol in self.symbols: - if symbol.symtype == 'F': - # Function symbol. - result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name) - if result is not None: - function = function_map.get(symbol.address) - # Ignore the symbol not in disassembly. - if function is not None: - # If there are multiple symbol with the same name and point to the - # same function, the set will deduplicate them. - symbol_map[result.group('name').strip()].add(function) - - # Build the signature map indexed by annotation signature. - signature_map = {} - sig_error_map = {} - symbol_path_map = {} - for sig in signature_set: - (name, path, _) = sig - - functions = symbol_map.get(name) - if functions is None: - sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND - continue - - if name not in symbol_path_map: - # Lazy symbol path resolving. Since the addr2line isn't fast, only - # resolve needed symbol paths. - group_map = collections.defaultdict(list) - for function in functions: - line_info = self.AddressToLine(function.address)[0] - if line_info is None: - continue - (_, symbol_path, _) = line_info + C_FUNCTION_NAME = r"_A-Za-z0-9" - # Group the functions with the same symbol signature (symbol name + - # symbol path). Assume they are the same copies and do the same - # annotation operations of them because we don't know which copy is - # indicated by the users. - group_map[symbol_path].append(function) + # Assume there is no ":" in the path. + # Example: "driver/accel_kionix.c:321 (discriminator 3)" + ADDRTOLINE_RE = re.compile( + r"^(?P<path>[^:]+):(?P<linenum>\d+)(\s+\(discriminator\s+\d+\))?$" + ) + # To eliminate the suffix appended by compilers, try to extract the + # C function name from the prefix of symbol name. + # Example: "SHA256_transform.constprop.28" + FUNCTION_PREFIX_NAME_RE = re.compile( + r"^(?P<name>[{0}]+)([^{0}].*)?$".format(C_FUNCTION_NAME) + ) + + # Errors of annotation resolving. + ANNOTATION_ERROR_INVALID = "invalid signature" + ANNOTATION_ERROR_NOTFOUND = "function is not found" + ANNOTATION_ERROR_AMBIGUOUS = "signature is ambiguous" + + def __init__(self, options, symbols, rodata, tasklist, annotation): + """Constructor. + + Args: + options: Namespace from argparse.parse_args(). + symbols: Symbol list. + rodata: Content of .rodata section (offset, data) + tasklist: Task list. + annotation: Annotation config. + """ + self.options = options + self.symbols = symbols + self.rodata_offset = rodata[0] + self.rodata = rodata[1] + self.tasklist = tasklist + self.annotation = annotation + self.address_to_line_cache = {} + + def AddressToLine(self, address, resolve_inline=False): + """Convert address to line. + + Args: + address: Target address. + resolve_inline: Output the stack of inlining. + + Returns: + lines: List of the corresponding lines. + + Raises: + StackAnalyzerError: If addr2line is failed. + """ + cache_key = (address, resolve_inline) + if cache_key in self.address_to_line_cache: + return self.address_to_line_cache[cache_key] + + try: + args = [ + self.options.addr2line, + "-f", + "-e", + self.options.elf_path, + "{:x}".format(address), + ] + if resolve_inline: + args.append("-i") + + line_text = subprocess.check_output(args, encoding="utf-8") + except subprocess.CalledProcessError: + raise StackAnalyzerError("addr2line failed to resolve lines.") + except OSError: + raise StackAnalyzerError("Failed to run addr2line.") + + lines = [line.strip() for line in line_text.splitlines()] + # Assume the output has at least one pair like "function\nlocation\n", and + # they always show up in pairs. + # Example: "handle_request\n + # common/usb_pd_protocol.c:1191\n" + assert len(lines) >= 2 and len(lines) % 2 == 0 + + line_infos = [] + for index in range(0, len(lines), 2): + (function_name, line_text) = lines[index : index + 2] + if line_text in ["??:0", ":?"]: + line_infos.append(None) + else: + result = self.ADDRTOLINE_RE.match(line_text) + # Assume the output is always well-formed. + assert result is not None + line_infos.append( + ( + function_name.strip(), + os.path.realpath(result.group("path").strip()), + int(result.group("linenum")), + ) + ) + + self.address_to_line_cache[cache_key] = line_infos + return line_infos + + def AnalyzeDisassembly(self, disasm_text): + """Parse the disassembly text, analyze, and build a map of all functions. + + Args: + disasm_text: Disassembly text. + + Returns: + function_map: Dict of functions. + """ + disasm_lines = [line.strip() for line in disasm_text.splitlines()] + + if "nds" in disasm_lines[1]: + analyzer = AndesAnalyzer() + elif "arm" in disasm_lines[1]: + analyzer = ArmAnalyzer() + elif "riscv" in disasm_lines[1]: + analyzer = RiscvAnalyzer() + else: + raise StackAnalyzerError("Unsupported architecture.") - symbol_path_map[name] = group_map + # Example: "08028c8c <motion_lid_calc>:" + function_signature_regex = re.compile( + r"^(?P<address>[0-9A-Fa-f]+)\s+<(?P<name>[^>]+)>:$" + ) - # Symbol matching. - function_group = None - group_map = symbol_path_map[name] - if len(group_map) > 0: - if path is None: - if len(group_map) > 1: - # There is ambiguity but the path isn't specified. - sig_error_map[sig] = self.ANNOTATION_ERROR_AMBIGUOUS - continue + def DetectFunctionHead(line): + """Check if the line is a function head. - # No path signature but all symbol signatures of functions are same. - # Assume they are the same functions, so there is no ambiguity. - (function_group,) = group_map.values() - else: - function_group = group_map.get(path) + Args: + line: Text of disassembly. - if function_group is None: - sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND - continue + Returns: + symbol: Function symbol. None if it isn't a function head. + """ + result = function_signature_regex.match(line) + if result is None: + return None - # The function_group is a list of all the same functions (according to - # our assumption) which should be annotated together. - signature_map[sig] = function_group + address = int(result.group("address"), 16) + symbol = symbol_map.get(address) - return (signature_map, sig_error_map) + # Check if the function exists and matches. + if symbol is None or symbol.symtype != "F": + return None - def LoadAnnotation(self): - """Load annotation rules. + return symbol - Returns: - Map of add rules, set of remove rules, set of text signatures which can't - be parsed. - """ - # Assume there is no ":" in the path. - # Example: "get_range.lto.2501[driver/accel_kionix.c:327]" - annotation_signature_regex = re.compile( - r'^(?P<name>[^\[]+)(\[(?P<path>[^:]+)(:(?P<linenum>\d+))?\])?$') - - def NormalizeSignature(signature_text): - """Parse and normalize the annotation signature. - - Args: - signature_text: Text of the annotation signature. - - Returns: - (function name, path, line number) of the signature. The path and line - number can be None if not exist. None if failed to parse. - """ - result = annotation_signature_regex.match(signature_text.strip()) - if result is None: - return None - - name_result = self.FUNCTION_PREFIX_NAME_RE.match( - result.group('name').strip()) - if name_result is None: - return None - - path = result.group('path') - if path is not None: - path = os.path.realpath(path.strip()) - - linenum = result.group('linenum') - if linenum is not None: - linenum = int(linenum.strip()) - - return (name_result.group('name').strip(), path, linenum) - - def ExpandArray(dic): - """Parse and expand a symbol array - - Args: - dic: Dictionary for the array annotation - - Returns: - array of (symbol name, None, None). - """ - # TODO(drinkcat): This function is quite inefficient, as it goes through - # the symbol table multiple times. - - begin_name = dic['name'] - end_name = dic['name'] + "_end" - offset = dic['offset'] if 'offset' in dic else 0 - stride = dic['stride'] - - begin_address = None - end_address = None - - for symbol in self.symbols: - if (symbol.name == begin_name): - begin_address = symbol.address - if (symbol.name == end_name): - end_address = symbol.address - - if (not begin_address or not end_address): - return None - - output = [] - # TODO(drinkcat): This is inefficient as we go from address to symbol - # object then to symbol name, and later on we'll go back from symbol name - # to symbol object. - for addr in range(begin_address+offset, end_address, stride): - # TODO(drinkcat): Not all architectures need to drop the first bit. - val = self.rodata[(addr-self.rodata_offset) // 4] & 0xfffffffe - name = None + # Build symbol map, indexed by symbol address. + symbol_map = {} for symbol in self.symbols: - if (symbol.address == val): - result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name) - name = result.group('name') - break - - if not name: - raise StackAnalyzerError('Cannot find function for address %s.', - hex(val)) - - output.append((name, None, None)) - - return output - - add_rules = collections.defaultdict(set) - remove_rules = list() - invalid_sigtxts = set() - - if 'add' in self.annotation and self.annotation['add'] is not None: - for src_sigtxt, dst_sigtxts in self.annotation['add'].items(): - src_sig = NormalizeSignature(src_sigtxt) - if src_sig is None: - invalid_sigtxts.add(src_sigtxt) - continue - - for dst_sigtxt in dst_sigtxts: - if isinstance(dst_sigtxt, dict): - dst_sig = ExpandArray(dst_sigtxt) - if dst_sig is None: - invalid_sigtxts.add(str(dst_sigtxt)) + # If there are multiple symbols with same address, keeping any of them is + # good enough. + symbol_map[symbol.address] = symbol + + # Parse the disassembly text. We update the variable "line" to next line + # when needed. There are two steps of parser: + # + # Step 1: Searching for the function head. Once reach the function head, + # move to the next line, which is the first line of function body. + # + # Step 2: Parsing each instruction line of function body. Once reach a + # non-instruction line, stop parsing and analyze the parsed instructions. + # + # Finally turn back to the step 1 without updating the line, because the + # current non-instruction line can be another function head. + function_map = {} + # The following three variables are the states of the parsing processing. + # They will be initialized properly during the state changes. + function_symbol = None + function_end = None + instructions = [] + + # Remove heading and tailing spaces for each line. + line_index = 0 + while line_index < len(disasm_lines): + # Get the current line. + line = disasm_lines[line_index] + + if function_symbol is None: + # Step 1: Search for the function head. + + function_symbol = DetectFunctionHead(line) + if function_symbol is not None: + # Assume there is no empty function. If the function head is followed + # by EOF, it is an empty function. + assert line_index + 1 < len(disasm_lines) + + # Found the function head, initialize and turn to the step 2. + instructions = [] + # If symbol size exists, use it as a hint of function size. + if function_symbol.size > 0: + function_end = function_symbol.address + function_symbol.size + else: + function_end = None + else: - add_rules[src_sig].update(dst_sig) - else: - dst_sig = NormalizeSignature(dst_sigtxt) - if dst_sig is None: - invalid_sigtxts.add(dst_sigtxt) + # Step 2: Parse the function body. + + instruction = analyzer.ParseInstruction(line, function_end) + if instruction is not None: + instructions.append(instruction) + + if instruction is None or line_index + 1 == len(disasm_lines): + # Either the invalid instruction or EOF indicates the end of the + # function, finalize the function analysis. + + # Assume there is no empty function. + assert len(instructions) > 0 + + (stack_frame, callsites) = analyzer.AnalyzeFunction( + function_symbol, instructions + ) + # Assume the function addresses are unique in the disassembly. + assert function_symbol.address not in function_map + function_map[function_symbol.address] = Function( + function_symbol.address, + function_symbol.name, + stack_frame, + callsites, + ) + + # Initialize and turn back to the step 1. + function_symbol = None + + # If the current line isn't an instruction, it can be another function + # head, skip moving to the next line. + if instruction is None: + continue + + # Move to the next line. + line_index += 1 + + # Resolve callees of functions. + for function in function_map.values(): + for callsite in function.callsites: + if callsite.target is not None: + # Remain the callee as None if we can't resolve it. + callsite.callee = function_map.get(callsite.target) + + return function_map + + def MapAnnotation(self, function_map, signature_set): + """Map annotation signatures to functions. + + Args: + function_map: Function map. + signature_set: Set of annotation signatures. + + Returns: + Map of signatures to functions, map of signatures which can't be resolved. + """ + # Build the symbol map indexed by symbol name. If there are multiple symbols + # with the same name, add them into a set. (e.g. symbols of static function + # with the same name) + symbol_map = collections.defaultdict(set) + for symbol in self.symbols: + if symbol.symtype == "F": + # Function symbol. + result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name) + if result is not None: + function = function_map.get(symbol.address) + # Ignore the symbol not in disassembly. + if function is not None: + # If there are multiple symbol with the same name and point to the + # same function, the set will deduplicate them. + symbol_map[result.group("name").strip()].add(function) + + # Build the signature map indexed by annotation signature. + signature_map = {} + sig_error_map = {} + symbol_path_map = {} + for sig in signature_set: + (name, path, _) = sig + + functions = symbol_map.get(name) + if functions is None: + sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND + continue + + if name not in symbol_path_map: + # Lazy symbol path resolving. Since the addr2line isn't fast, only + # resolve needed symbol paths. + group_map = collections.defaultdict(list) + for function in functions: + line_info = self.AddressToLine(function.address)[0] + if line_info is None: + continue + + (_, symbol_path, _) = line_info + + # Group the functions with the same symbol signature (symbol name + + # symbol path). Assume they are the same copies and do the same + # annotation operations of them because we don't know which copy is + # indicated by the users. + group_map[symbol_path].append(function) + + symbol_path_map[name] = group_map + + # Symbol matching. + function_group = None + group_map = symbol_path_map[name] + if len(group_map) > 0: + if path is None: + if len(group_map) > 1: + # There is ambiguity but the path isn't specified. + sig_error_map[sig] = self.ANNOTATION_ERROR_AMBIGUOUS + continue + + # No path signature but all symbol signatures of functions are same. + # Assume they are the same functions, so there is no ambiguity. + (function_group,) = group_map.values() + else: + function_group = group_map.get(path) + + if function_group is None: + sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND + continue + + # The function_group is a list of all the same functions (according to + # our assumption) which should be annotated together. + signature_map[sig] = function_group + + return (signature_map, sig_error_map) + + def LoadAnnotation(self): + """Load annotation rules. + + Returns: + Map of add rules, set of remove rules, set of text signatures which can't + be parsed. + """ + # Assume there is no ":" in the path. + # Example: "get_range.lto.2501[driver/accel_kionix.c:327]" + annotation_signature_regex = re.compile( + r"^(?P<name>[^\[]+)(\[(?P<path>[^:]+)(:(?P<linenum>\d+))?\])?$" + ) + + def NormalizeSignature(signature_text): + """Parse and normalize the annotation signature. + + Args: + signature_text: Text of the annotation signature. + + Returns: + (function name, path, line number) of the signature. The path and line + number can be None if not exist. None if failed to parse. + """ + result = annotation_signature_regex.match(signature_text.strip()) + if result is None: + return None + + name_result = self.FUNCTION_PREFIX_NAME_RE.match( + result.group("name").strip() + ) + if name_result is None: + return None + + path = result.group("path") + if path is not None: + path = os.path.realpath(path.strip()) + + linenum = result.group("linenum") + if linenum is not None: + linenum = int(linenum.strip()) + + return (name_result.group("name").strip(), path, linenum) + + def ExpandArray(dic): + """Parse and expand a symbol array + + Args: + dic: Dictionary for the array annotation + + Returns: + array of (symbol name, None, None). + """ + # TODO(drinkcat): This function is quite inefficient, as it goes through + # the symbol table multiple times. + + begin_name = dic["name"] + end_name = dic["name"] + "_end" + offset = dic["offset"] if "offset" in dic else 0 + stride = dic["stride"] + + begin_address = None + end_address = None + + for symbol in self.symbols: + if symbol.name == begin_name: + begin_address = symbol.address + if symbol.name == end_name: + end_address = symbol.address + + if not begin_address or not end_address: + return None + + output = [] + # TODO(drinkcat): This is inefficient as we go from address to symbol + # object then to symbol name, and later on we'll go back from symbol name + # to symbol object. + for addr in range(begin_address + offset, end_address, stride): + # TODO(drinkcat): Not all architectures need to drop the first bit. + val = self.rodata[(addr - self.rodata_offset) // 4] & 0xFFFFFFFE + name = None + for symbol in self.symbols: + if symbol.address == val: + result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name) + name = result.group("name") + break + + if not name: + raise StackAnalyzerError( + "Cannot find function for address %s.", hex(val) + ) + + output.append((name, None, None)) + + return output + + add_rules = collections.defaultdict(set) + remove_rules = list() + invalid_sigtxts = set() + + if "add" in self.annotation and self.annotation["add"] is not None: + for src_sigtxt, dst_sigtxts in self.annotation["add"].items(): + src_sig = NormalizeSignature(src_sigtxt) + if src_sig is None: + invalid_sigtxts.add(src_sigtxt) + continue + + for dst_sigtxt in dst_sigtxts: + if isinstance(dst_sigtxt, dict): + dst_sig = ExpandArray(dst_sigtxt) + if dst_sig is None: + invalid_sigtxts.add(str(dst_sigtxt)) + else: + add_rules[src_sig].update(dst_sig) + else: + dst_sig = NormalizeSignature(dst_sigtxt) + if dst_sig is None: + invalid_sigtxts.add(dst_sigtxt) + else: + add_rules[src_sig].add(dst_sig) + + if "remove" in self.annotation and self.annotation["remove"] is not None: + for sigtxt_path in self.annotation["remove"]: + if isinstance(sigtxt_path, str): + # The path has only one vertex. + sigtxt_path = [sigtxt_path] + + if len(sigtxt_path) == 0: + continue + + # Generate multiple remove paths from all the combinations of the + # signatures of each vertex. + sig_paths = [[]] + broken_flag = False + for sigtxt_node in sigtxt_path: + if isinstance(sigtxt_node, str): + # The vertex has only one signature. + sigtxt_set = {sigtxt_node} + elif isinstance(sigtxt_node, list): + # The vertex has multiple signatures. + sigtxt_set = set(sigtxt_node) + else: + # Assume the format of annotation is verified. There should be no + # invalid case. + assert False + + sig_set = set() + for sigtxt in sigtxt_set: + sig = NormalizeSignature(sigtxt) + if sig is None: + invalid_sigtxts.add(sigtxt) + broken_flag = True + elif not broken_flag: + sig_set.add(sig) + + if broken_flag: + continue + + # Append each signature of the current node to the all previous + # remove paths. + sig_paths = [path + [sig] for path in sig_paths for sig in sig_set] + + if not broken_flag: + # All signatures are normalized. The remove path has no error. + remove_rules.extend(sig_paths) + + return (add_rules, remove_rules, invalid_sigtxts) + + def ResolveAnnotation(self, function_map): + """Resolve annotation. + + Args: + function_map: Function map. + + Returns: + Set of added call edges, list of remove paths, set of eliminated + callsite addresses, set of annotation signatures which can't be resolved. + """ + + def StringifySignature(signature): + """Stringify the tupled signature. + + Args: + signature: Tupled signature. + + Returns: + Signature string. + """ + (name, path, linenum) = signature + bracket_text = "" + if path is not None: + path = os.path.relpath(path) + if linenum is None: + bracket_text = "[{}]".format(path) + else: + bracket_text = "[{}:{}]".format(path, linenum) + + return name + bracket_text + + (add_rules, remove_rules, invalid_sigtxts) = self.LoadAnnotation() + + signature_set = set() + for src_sig, dst_sigs in add_rules.items(): + signature_set.add(src_sig) + signature_set.update(dst_sigs) + + for remove_sigs in remove_rules: + signature_set.update(remove_sigs) + + # Map signatures to functions. + (signature_map, sig_error_map) = self.MapAnnotation(function_map, signature_set) + + # Build the indirect callsite map indexed by callsite signature. + indirect_map = collections.defaultdict(set) + for function in function_map.values(): + for callsite in function.callsites: + if callsite.target is not None: + continue + + # Found an indirect callsite. + line_info = self.AddressToLine(callsite.address)[0] + if line_info is None: + continue + + (name, path, linenum) = line_info + result = self.FUNCTION_PREFIX_NAME_RE.match(name) + if result is None: + continue + + indirect_map[(result.group("name").strip(), path, linenum)].add( + (function, callsite.address) + ) + + # Generate the annotation sets. + add_set = set() + remove_list = list() + eliminated_addrs = set() + + for src_sig, dst_sigs in add_rules.items(): + src_funcs = set(signature_map.get(src_sig, [])) + # Try to match the source signature to the indirect callsites. Even if it + # can't be found in disassembly. + indirect_calls = indirect_map.get(src_sig) + if indirect_calls is not None: + for function, callsite_address in indirect_calls: + # Add the caller of the indirect callsite to the source functions. + src_funcs.add(function) + # Assume each callsite can be represented by a unique address. + eliminated_addrs.add(callsite_address) + + if src_sig in sig_error_map: + # Assume the error is always the not found error. Since the signature + # found in indirect callsite map must be a full signature, it can't + # happen the ambiguous error. + assert sig_error_map[src_sig] == self.ANNOTATION_ERROR_NOTFOUND + # Found in inline stack, remove the not found error. + del sig_error_map[src_sig] + + for dst_sig in dst_sigs: + dst_funcs = signature_map.get(dst_sig) + if dst_funcs is None: + continue + + # Duplicate the call edge for all the same source and destination + # functions. + for src_func in src_funcs: + for dst_func in dst_funcs: + add_set.add((src_func, dst_func)) + + for remove_sigs in remove_rules: + # Since each signature can be mapped to multiple functions, generate + # multiple remove paths from all the combinations of these functions. + remove_paths = [[]] + skip_flag = False + for remove_sig in remove_sigs: + # Transform each signature to the corresponding functions. + remove_funcs = signature_map.get(remove_sig) + if remove_funcs is None: + # There is an unresolved signature in the remove path. Ignore the + # whole broken remove path. + skip_flag = True + break + else: + # Append each function of the current signature to the all previous + # remove paths. + remove_paths = [p + [f] for p in remove_paths for f in remove_funcs] + + if skip_flag: + # Ignore the broken remove path. + continue + + for remove_path in remove_paths: + # Deduplicate the remove paths. + if remove_path not in remove_list: + remove_list.append(remove_path) + + # Format the error messages. + failed_sigtxts = set() + for sigtxt in invalid_sigtxts: + failed_sigtxts.add((sigtxt, self.ANNOTATION_ERROR_INVALID)) + + for sig, error in sig_error_map.items(): + failed_sigtxts.add((StringifySignature(sig), error)) + + return (add_set, remove_list, eliminated_addrs, failed_sigtxts) + + def PreprocessAnnotation( + self, function_map, add_set, remove_list, eliminated_addrs + ): + """Preprocess the annotation and callgraph. + + Add the missing call edges, and delete simple remove paths (the paths have + one or two vertices) from the function_map. + + Eliminate the annotated indirect callsites. + + Return the remaining remove list. + + Args: + function_map: Function map. + add_set: Set of missing call edges. + remove_list: List of remove paths. + eliminated_addrs: Set of eliminated callsite addresses. + + Returns: + List of remaining remove paths. + """ + + def CheckEdge(path): + """Check if all edges of the path are on the callgraph. + + Args: + path: Path. + + Returns: + True or False. + """ + for index in range(len(path) - 1): + if (path[index], path[index + 1]) not in edge_set: + return False + + return True + + for src_func, dst_func in add_set: + # TODO(cheyuw): Support tailing call annotation. + src_func.callsites.append(Callsite(None, dst_func.address, False, dst_func)) + + # Delete simple remove paths. + remove_simple = set(tuple(p) for p in remove_list if len(p) <= 2) + edge_set = set() + for function in function_map.values(): + cleaned_callsites = [] + for callsite in function.callsites: + if (callsite.callee,) in remove_simple or ( + function, + callsite.callee, + ) in remove_simple: + continue + + if callsite.target is None and callsite.address in eliminated_addrs: + continue + + cleaned_callsites.append(callsite) + if callsite.callee is not None: + edge_set.add((function, callsite.callee)) + + function.callsites = cleaned_callsites + + return [p for p in remove_list if len(p) >= 3 and CheckEdge(p)] + + def AnalyzeCallGraph(self, function_map, remove_list): + """Analyze callgraph. + + It will update the max stack size and path for each function. + + Args: + function_map: Function map. + remove_list: List of remove paths. + + Returns: + List of function cycles. + """ + + def Traverse(curr_state): + """Traverse the callgraph and calculate the max stack usages of functions. + + Args: + curr_state: Current state. + + Returns: + SCC lowest link. + """ + scc_index = scc_index_counter[0] + scc_index_counter[0] += 1 + scc_index_map[curr_state] = scc_index + scc_lowlink = scc_index + scc_stack.append(curr_state) + # Push the current state in the stack. We can use a set to maintain this + # because the stacked states are unique; otherwise we will find a cycle + # first. + stacked_states.add(curr_state) + + (curr_address, curr_positions) = curr_state + curr_func = function_map[curr_address] + + invalid_flag = False + new_positions = list(curr_positions) + for index, position in enumerate(curr_positions): + remove_path = remove_list[index] + + # The position of each remove path in the state is the length of the + # longest matching path between the prefix of the remove path and the + # suffix of the current traversing path. We maintain this length when + # appending the next callee to the traversing path. And it can be used + # to check if the remove path appears in the traversing path. + + # TODO(cheyuw): Implement KMP algorithm to match remove paths + # efficiently. + if remove_path[position] is curr_func: + # Matches the current function, extend the length. + new_positions[index] = position + 1 + if new_positions[index] == len(remove_path): + # The length of the longest matching path is equal to the length of + # the remove path, which means the suffix of the current traversing + # path matches the remove path. + invalid_flag = True + break + + else: + # We can't get the new longest matching path by extending the previous + # one directly. Fallback to search the new longest matching path. + + # If we can't find any matching path in the following search, reset + # the matching length to 0. + new_positions[index] = 0 + + # We want to find the new longest matching prefix of remove path with + # the suffix of the current traversing path. Because the new longest + # matching path won't be longer than the prevous one now, and part of + # the suffix matches the prefix of remove path, we can get the needed + # suffix from the previous matching prefix of the invalid path. + suffix = remove_path[:position] + [curr_func] + for offset in range(1, len(suffix)): + length = position - offset + if remove_path[:length] == suffix[offset:]: + new_positions[index] = length + break + + new_positions = tuple(new_positions) + + # If the current suffix is invalid, set the max stack usage to 0. + max_stack_usage = 0 + max_callee_state = None + self_loop = False + + if not invalid_flag: + # Max stack usage is at least equal to the stack frame. + max_stack_usage = curr_func.stack_frame + for callsite in curr_func.callsites: + callee = callsite.callee + if callee is None: + continue + + callee_state = (callee.address, new_positions) + if callee_state not in scc_index_map: + # Unvisited state. + scc_lowlink = min(scc_lowlink, Traverse(callee_state)) + elif callee_state in stacked_states: + # The state is shown in the stack. There is a cycle. + sub_stack_usage = 0 + scc_lowlink = min(scc_lowlink, scc_index_map[callee_state]) + if callee_state == curr_state: + self_loop = True + + done_result = done_states.get(callee_state) + if done_result is not None: + # Already done this state and use its result. If the state reaches a + # cycle, reusing the result will cause inaccuracy (the stack usage + # of cycle depends on where the entrance is). But it's fine since we + # can't get accurate stack usage under this situation, and we rely + # on user-provided annotations to break the cycle, after which the + # result will be accurate again. + (sub_stack_usage, _) = done_result + + if callsite.is_tail: + # For tailing call, since the callee reuses the stack frame of the + # caller, choose the larger one directly. + stack_usage = max(curr_func.stack_frame, sub_stack_usage) + else: + stack_usage = curr_func.stack_frame + sub_stack_usage + + if stack_usage > max_stack_usage: + max_stack_usage = stack_usage + max_callee_state = callee_state + + if scc_lowlink == scc_index: + group = [] + while scc_stack[-1] != curr_state: + scc_state = scc_stack.pop() + stacked_states.remove(scc_state) + group.append(scc_state) + + scc_stack.pop() + stacked_states.remove(curr_state) + + # If the cycle is not empty, record it. + if len(group) > 0 or self_loop: + group.append(curr_state) + cycle_groups.append(group) + + # Store the done result. + done_states[curr_state] = (max_stack_usage, max_callee_state) + + if curr_positions == initial_positions: + # If the current state is initial state, we traversed the callgraph by + # using the current function as start point. Update the stack usage of + # the function. + # If the function matches a single vertex remove path, this will set its + # max stack usage to 0, which is not expected (we still calculate its + # max stack usage, but prevent any function from calling it). However, + # all the single vertex remove paths have been preprocessed and removed. + curr_func.stack_max_usage = max_stack_usage + + # Reconstruct the max stack path by traversing the state transitions. + max_stack_path = [curr_func] + callee_state = max_callee_state + while callee_state is not None: + # The first element of state tuple is function address. + max_stack_path.append(function_map[callee_state[0]]) + done_result = done_states.get(callee_state) + # All of the descendants should be done. + assert done_result is not None + (_, callee_state) = done_result + + curr_func.stack_max_path = max_stack_path + + return scc_lowlink + + # The state is the concatenation of the current function address and the + # state of matching position. + initial_positions = (0,) * len(remove_list) + done_states = {} + stacked_states = set() + scc_index_counter = [0] + scc_index_map = {} + scc_stack = [] + cycle_groups = [] + for function in function_map.values(): + if function.stack_max_usage is None: + Traverse((function.address, initial_positions)) + + cycle_functions = [] + for group in cycle_groups: + cycle = set(function_map[state[0]] for state in group) + if cycle not in cycle_functions: + cycle_functions.append(cycle) + + return cycle_functions + + def Analyze(self): + """Run the stack analysis. + + Raises: + StackAnalyzerError: If disassembly fails. + """ + + def OutputInlineStack(address, prefix=""): + """Output beautiful inline stack. + + Args: + address: Address. + prefix: Prefix of each line. + + Returns: + Key for sorting, output text + """ + line_infos = self.AddressToLine(address, True) + + if line_infos[0] is None: + order_key = (None, None) else: - add_rules[src_sig].add(dst_sig) - - if 'remove' in self.annotation and self.annotation['remove'] is not None: - for sigtxt_path in self.annotation['remove']: - if isinstance(sigtxt_path, str): - # The path has only one vertex. - sigtxt_path = [sigtxt_path] - - if len(sigtxt_path) == 0: - continue - - # Generate multiple remove paths from all the combinations of the - # signatures of each vertex. - sig_paths = [[]] - broken_flag = False - for sigtxt_node in sigtxt_path: - if isinstance(sigtxt_node, str): - # The vertex has only one signature. - sigtxt_set = {sigtxt_node} - elif isinstance(sigtxt_node, list): - # The vertex has multiple signatures. - sigtxt_set = set(sigtxt_node) - else: - # Assume the format of annotation is verified. There should be no - # invalid case. - assert False - - sig_set = set() - for sigtxt in sigtxt_set: - sig = NormalizeSignature(sigtxt) - if sig is None: - invalid_sigtxts.add(sigtxt) - broken_flag = True - elif not broken_flag: - sig_set.add(sig) - - if broken_flag: - continue - - # Append each signature of the current node to the all previous - # remove paths. - sig_paths = [path + [sig] for path in sig_paths for sig in sig_set] - - if not broken_flag: - # All signatures are normalized. The remove path has no error. - remove_rules.extend(sig_paths) - - return (add_rules, remove_rules, invalid_sigtxts) + (_, path, linenum) = line_infos[0] + order_key = (linenum, path) + + line_texts = [] + for line_info in reversed(line_infos): + if line_info is None: + (function_name, path, linenum) = ("??", "??", 0) + else: + (function_name, path, linenum) = line_info + + line_texts.append( + "{}[{}:{}]".format(function_name, os.path.relpath(path), linenum) + ) + + output = "{}-> {} {:x}\n".format(prefix, line_texts[0], address) + for depth, line_text in enumerate(line_texts[1:]): + output += "{} {}- {}\n".format(prefix, " " * depth, line_text) + + # Remove the last newline character. + return (order_key, output.rstrip("\n")) + + # Analyze disassembly. + try: + disasm_text = subprocess.check_output( + [self.options.objdump, "-d", self.options.elf_path], encoding="utf-8" + ) + except subprocess.CalledProcessError: + raise StackAnalyzerError("objdump failed to disassemble.") + except OSError: + raise StackAnalyzerError("Failed to run objdump.") + + function_map = self.AnalyzeDisassembly(disasm_text) + result = self.ResolveAnnotation(function_map) + (add_set, remove_list, eliminated_addrs, failed_sigtxts) = result + remove_list = self.PreprocessAnnotation( + function_map, add_set, remove_list, eliminated_addrs + ) + cycle_functions = self.AnalyzeCallGraph(function_map, remove_list) + + # Print the results of task-aware stack analysis. + extra_stack_frame = self.annotation.get( + "exception_frame_size", DEFAULT_EXCEPTION_FRAME_SIZE + ) + for task in self.tasklist: + routine_func = function_map[task.routine_address] + print( + "Task: {}, Max size: {} ({} + {}), Allocated size: {}".format( + task.name, + routine_func.stack_max_usage + extra_stack_frame, + routine_func.stack_max_usage, + extra_stack_frame, + task.stack_max_size, + ) + ) + + print("Call Trace:") + max_stack_path = routine_func.stack_max_path + # Assume the routine function is resolved. + assert max_stack_path is not None + for depth, curr_func in enumerate(max_stack_path): + line_info = self.AddressToLine(curr_func.address)[0] + if line_info is None: + (path, linenum) = ("??", 0) + else: + (_, path, linenum) = line_info + + print( + " {} ({}) [{}:{}] {:x}".format( + curr_func.name, + curr_func.stack_frame, + os.path.relpath(path), + linenum, + curr_func.address, + ) + ) + + if depth + 1 < len(max_stack_path): + succ_func = max_stack_path[depth + 1] + text_list = [] + for callsite in curr_func.callsites: + if callsite.callee is succ_func: + indent_prefix = " " + if callsite.address is None: + order_text = ( + None, + "{}-> [annotation]".format(indent_prefix), + ) + else: + order_text = OutputInlineStack( + callsite.address, indent_prefix + ) + + text_list.append(order_text) + + for _, text in sorted(text_list, key=lambda item: item[0]): + print(text) + + print("Unresolved indirect callsites:") + for function in function_map.values(): + indirect_callsites = [] + for callsite in function.callsites: + if callsite.target is None: + indirect_callsites.append(callsite.address) + + if len(indirect_callsites) > 0: + print(" In function {}:".format(function.name)) + text_list = [] + for address in indirect_callsites: + text_list.append(OutputInlineStack(address, " ")) + + for _, text in sorted(text_list, key=lambda item: item[0]): + print(text) + + print("Unresolved annotation signatures:") + for sigtxt, error in failed_sigtxts: + print(" {}: {}".format(sigtxt, error)) + + if len(cycle_functions) > 0: + print("There are cycles in the following function sets:") + for functions in cycle_functions: + print("[{}]".format(", ".join(function.name for function in functions))) - def ResolveAnnotation(self, function_map): - """Resolve annotation. - Args: - function_map: Function map. +def ParseArgs(): + """Parse commandline arguments. Returns: - Set of added call edges, list of remove paths, set of eliminated - callsite addresses, set of annotation signatures which can't be resolved. + options: Namespace from argparse.parse_args(). """ - def StringifySignature(signature): - """Stringify the tupled signature. - - Args: - signature: Tupled signature. - - Returns: - Signature string. - """ - (name, path, linenum) = signature - bracket_text = '' - if path is not None: - path = os.path.relpath(path) - if linenum is None: - bracket_text = '[{}]'.format(path) - else: - bracket_text = '[{}:{}]'.format(path, linenum) - - return name + bracket_text - - (add_rules, remove_rules, invalid_sigtxts) = self.LoadAnnotation() - - signature_set = set() - for src_sig, dst_sigs in add_rules.items(): - signature_set.add(src_sig) - signature_set.update(dst_sigs) - - for remove_sigs in remove_rules: - signature_set.update(remove_sigs) - - # Map signatures to functions. - (signature_map, sig_error_map) = self.MapAnnotation(function_map, - signature_set) - - # Build the indirect callsite map indexed by callsite signature. - indirect_map = collections.defaultdict(set) - for function in function_map.values(): - for callsite in function.callsites: - if callsite.target is not None: - continue - - # Found an indirect callsite. - line_info = self.AddressToLine(callsite.address)[0] - if line_info is None: - continue - - (name, path, linenum) = line_info - result = self.FUNCTION_PREFIX_NAME_RE.match(name) - if result is None: - continue - - indirect_map[(result.group('name').strip(), path, linenum)].add( - (function, callsite.address)) - - # Generate the annotation sets. - add_set = set() - remove_list = list() - eliminated_addrs = set() - - for src_sig, dst_sigs in add_rules.items(): - src_funcs = set(signature_map.get(src_sig, [])) - # Try to match the source signature to the indirect callsites. Even if it - # can't be found in disassembly. - indirect_calls = indirect_map.get(src_sig) - if indirect_calls is not None: - for function, callsite_address in indirect_calls: - # Add the caller of the indirect callsite to the source functions. - src_funcs.add(function) - # Assume each callsite can be represented by a unique address. - eliminated_addrs.add(callsite_address) - - if src_sig in sig_error_map: - # Assume the error is always the not found error. Since the signature - # found in indirect callsite map must be a full signature, it can't - # happen the ambiguous error. - assert sig_error_map[src_sig] == self.ANNOTATION_ERROR_NOTFOUND - # Found in inline stack, remove the not found error. - del sig_error_map[src_sig] - - for dst_sig in dst_sigs: - dst_funcs = signature_map.get(dst_sig) - if dst_funcs is None: - continue - - # Duplicate the call edge for all the same source and destination - # functions. - for src_func in src_funcs: - for dst_func in dst_funcs: - add_set.add((src_func, dst_func)) - - for remove_sigs in remove_rules: - # Since each signature can be mapped to multiple functions, generate - # multiple remove paths from all the combinations of these functions. - remove_paths = [[]] - skip_flag = False - for remove_sig in remove_sigs: - # Transform each signature to the corresponding functions. - remove_funcs = signature_map.get(remove_sig) - if remove_funcs is None: - # There is an unresolved signature in the remove path. Ignore the - # whole broken remove path. - skip_flag = True - break - else: - # Append each function of the current signature to the all previous - # remove paths. - remove_paths = [p + [f] for p in remove_paths for f in remove_funcs] - - if skip_flag: - # Ignore the broken remove path. - continue - - for remove_path in remove_paths: - # Deduplicate the remove paths. - if remove_path not in remove_list: - remove_list.append(remove_path) - - # Format the error messages. - failed_sigtxts = set() - for sigtxt in invalid_sigtxts: - failed_sigtxts.add((sigtxt, self.ANNOTATION_ERROR_INVALID)) - - for sig, error in sig_error_map.items(): - failed_sigtxts.add((StringifySignature(sig), error)) - - return (add_set, remove_list, eliminated_addrs, failed_sigtxts) - - def PreprocessAnnotation(self, function_map, add_set, remove_list, - eliminated_addrs): - """Preprocess the annotation and callgraph. + parser = argparse.ArgumentParser(description="EC firmware stack analyzer.") + parser.add_argument("elf_path", help="the path of EC firmware ELF") + parser.add_argument( + "--export_taskinfo", + required=True, + help="the path of export_taskinfo.so utility", + ) + parser.add_argument( + "--section", + required=True, + help="the section.", + choices=[SECTION_RO, SECTION_RW], + ) + parser.add_argument("--objdump", default="objdump", help="the path of objdump") + parser.add_argument( + "--addr2line", default="addr2line", help="the path of addr2line" + ) + parser.add_argument( + "--annotation", default=None, help="the path of annotation file" + ) + + # TODO(cheyuw): Add an option for dumping stack usage of all functions. + + return parser.parse_args() - Add the missing call edges, and delete simple remove paths (the paths have - one or two vertices) from the function_map. - Eliminate the annotated indirect callsites. - - Return the remaining remove list. +def ParseSymbolText(symbol_text): + """Parse the content of the symbol text. Args: - function_map: Function map. - add_set: Set of missing call edges. - remove_list: List of remove paths. - eliminated_addrs: Set of eliminated callsite addresses. + symbol_text: Text of the symbols. Returns: - List of remaining remove paths. + symbols: Symbol list. """ - def CheckEdge(path): - """Check if all edges of the path are on the callgraph. - - Args: - path: Path. - - Returns: - True or False. - """ - for index in range(len(path) - 1): - if (path[index], path[index + 1]) not in edge_set: - return False - - return True - - for src_func, dst_func in add_set: - # TODO(cheyuw): Support tailing call annotation. - src_func.callsites.append( - Callsite(None, dst_func.address, False, dst_func)) - - # Delete simple remove paths. - remove_simple = set(tuple(p) for p in remove_list if len(p) <= 2) - edge_set = set() - for function in function_map.values(): - cleaned_callsites = [] - for callsite in function.callsites: - if ((callsite.callee,) in remove_simple or - (function, callsite.callee) in remove_simple): - continue - - if callsite.target is None and callsite.address in eliminated_addrs: - continue - - cleaned_callsites.append(callsite) - if callsite.callee is not None: - edge_set.add((function, callsite.callee)) + # Example: "10093064 g F .text 0000015c .hidden hook_task" + symbol_regex = re.compile( + r"^(?P<address>[0-9A-Fa-f]+)\s+[lwg]\s+" + r"((?P<type>[OF])\s+)?\S+\s+" + r"(?P<size>[0-9A-Fa-f]+)\s+" + r"(\S+\s+)?(?P<name>\S+)$" + ) + + symbols = [] + for line in symbol_text.splitlines(): + line = line.strip() + result = symbol_regex.match(line) + if result is not None: + address = int(result.group("address"), 16) + symtype = result.group("type") + if symtype is None: + symtype = "O" - function.callsites = cleaned_callsites + size = int(result.group("size"), 16) + name = result.group("name") + symbols.append(Symbol(address, symtype, size, name)) - return [p for p in remove_list if len(p) >= 3 and CheckEdge(p)] + return symbols - def AnalyzeCallGraph(self, function_map, remove_list): - """Analyze callgraph. - It will update the max stack size and path for each function. +def ParseRoDataText(rodata_text): + """Parse the content of rodata Args: - function_map: Function map. - remove_list: List of remove paths. + symbol_text: Text of the rodata dump. Returns: - List of function cycles. + symbols: Symbol list. """ - def Traverse(curr_state): - """Traverse the callgraph and calculate the max stack usages of functions. - - Args: - curr_state: Current state. - - Returns: - SCC lowest link. - """ - scc_index = scc_index_counter[0] - scc_index_counter[0] += 1 - scc_index_map[curr_state] = scc_index - scc_lowlink = scc_index - scc_stack.append(curr_state) - # Push the current state in the stack. We can use a set to maintain this - # because the stacked states are unique; otherwise we will find a cycle - # first. - stacked_states.add(curr_state) - - (curr_address, curr_positions) = curr_state - curr_func = function_map[curr_address] - - invalid_flag = False - new_positions = list(curr_positions) - for index, position in enumerate(curr_positions): - remove_path = remove_list[index] - - # The position of each remove path in the state is the length of the - # longest matching path between the prefix of the remove path and the - # suffix of the current traversing path. We maintain this length when - # appending the next callee to the traversing path. And it can be used - # to check if the remove path appears in the traversing path. - - # TODO(cheyuw): Implement KMP algorithm to match remove paths - # efficiently. - if remove_path[position] is curr_func: - # Matches the current function, extend the length. - new_positions[index] = position + 1 - if new_positions[index] == len(remove_path): - # The length of the longest matching path is equal to the length of - # the remove path, which means the suffix of the current traversing - # path matches the remove path. - invalid_flag = True - break - - else: - # We can't get the new longest matching path by extending the previous - # one directly. Fallback to search the new longest matching path. - - # If we can't find any matching path in the following search, reset - # the matching length to 0. - new_positions[index] = 0 - - # We want to find the new longest matching prefix of remove path with - # the suffix of the current traversing path. Because the new longest - # matching path won't be longer than the prevous one now, and part of - # the suffix matches the prefix of remove path, we can get the needed - # suffix from the previous matching prefix of the invalid path. - suffix = remove_path[:position] + [curr_func] - for offset in range(1, len(suffix)): - length = position - offset - if remove_path[:length] == suffix[offset:]: - new_positions[index] = length - break - - new_positions = tuple(new_positions) - - # If the current suffix is invalid, set the max stack usage to 0. - max_stack_usage = 0 - max_callee_state = None - self_loop = False - - if not invalid_flag: - # Max stack usage is at least equal to the stack frame. - max_stack_usage = curr_func.stack_frame - for callsite in curr_func.callsites: - callee = callsite.callee - if callee is None: + # Examples: 8018ab0 00040048 00010000 10020000 4b8e0108 ...H........K... + # 100a7294 00000000 00000000 01000000 ............ + + base_offset = None + offset = None + rodata = [] + for line in rodata_text.splitlines(): + line = line.strip() + space = line.find(" ") + if space < 0: + continue + try: + address = int(line[0:space], 16) + except ValueError: continue - callee_state = (callee.address, new_positions) - if callee_state not in scc_index_map: - # Unvisited state. - scc_lowlink = min(scc_lowlink, Traverse(callee_state)) - elif callee_state in stacked_states: - # The state is shown in the stack. There is a cycle. - sub_stack_usage = 0 - scc_lowlink = min(scc_lowlink, scc_index_map[callee_state]) - if callee_state == curr_state: - self_loop = True - - done_result = done_states.get(callee_state) - if done_result is not None: - # Already done this state and use its result. If the state reaches a - # cycle, reusing the result will cause inaccuracy (the stack usage - # of cycle depends on where the entrance is). But it's fine since we - # can't get accurate stack usage under this situation, and we rely - # on user-provided annotations to break the cycle, after which the - # result will be accurate again. - (sub_stack_usage, _) = done_result - - if callsite.is_tail: - # For tailing call, since the callee reuses the stack frame of the - # caller, choose the larger one directly. - stack_usage = max(curr_func.stack_frame, sub_stack_usage) - else: - stack_usage = curr_func.stack_frame + sub_stack_usage - - if stack_usage > max_stack_usage: - max_stack_usage = stack_usage - max_callee_state = callee_state - - if scc_lowlink == scc_index: - group = [] - while scc_stack[-1] != curr_state: - scc_state = scc_stack.pop() - stacked_states.remove(scc_state) - group.append(scc_state) - - scc_stack.pop() - stacked_states.remove(curr_state) - - # If the cycle is not empty, record it. - if len(group) > 0 or self_loop: - group.append(curr_state) - cycle_groups.append(group) - - # Store the done result. - done_states[curr_state] = (max_stack_usage, max_callee_state) - - if curr_positions == initial_positions: - # If the current state is initial state, we traversed the callgraph by - # using the current function as start point. Update the stack usage of - # the function. - # If the function matches a single vertex remove path, this will set its - # max stack usage to 0, which is not expected (we still calculate its - # max stack usage, but prevent any function from calling it). However, - # all the single vertex remove paths have been preprocessed and removed. - curr_func.stack_max_usage = max_stack_usage - - # Reconstruct the max stack path by traversing the state transitions. - max_stack_path = [curr_func] - callee_state = max_callee_state - while callee_state is not None: - # The first element of state tuple is function address. - max_stack_path.append(function_map[callee_state[0]]) - done_result = done_states.get(callee_state) - # All of the descendants should be done. - assert done_result is not None - (_, callee_state) = done_result - - curr_func.stack_max_path = max_stack_path - - return scc_lowlink - - # The state is the concatenation of the current function address and the - # state of matching position. - initial_positions = (0,) * len(remove_list) - done_states = {} - stacked_states = set() - scc_index_counter = [0] - scc_index_map = {} - scc_stack = [] - cycle_groups = [] - for function in function_map.values(): - if function.stack_max_usage is None: - Traverse((function.address, initial_positions)) - - cycle_functions = [] - for group in cycle_groups: - cycle = set(function_map[state[0]] for state in group) - if cycle not in cycle_functions: - cycle_functions.append(cycle) - - return cycle_functions - - def Analyze(self): - """Run the stack analysis. - - Raises: - StackAnalyzerError: If disassembly fails. - """ - def OutputInlineStack(address, prefix=''): - """Output beautiful inline stack. - - Args: - address: Address. - prefix: Prefix of each line. - - Returns: - Key for sorting, output text - """ - line_infos = self.AddressToLine(address, True) - - if line_infos[0] is None: - order_key = (None, None) - else: - (_, path, linenum) = line_infos[0] - order_key = (linenum, path) - - line_texts = [] - for line_info in reversed(line_infos): - if line_info is None: - (function_name, path, linenum) = ('??', '??', 0) - else: - (function_name, path, linenum) = line_info - - line_texts.append('{}[{}:{}]'.format(function_name, - os.path.relpath(path), - linenum)) - - output = '{}-> {} {:x}\n'.format(prefix, line_texts[0], address) - for depth, line_text in enumerate(line_texts[1:]): - output += '{} {}- {}\n'.format(prefix, ' ' * depth, line_text) - - # Remove the last newline character. - return (order_key, output.rstrip('\n')) - - # Analyze disassembly. - try: - disasm_text = subprocess.check_output([self.options.objdump, - '-d', - self.options.elf_path], - encoding='utf-8') - except subprocess.CalledProcessError: - raise StackAnalyzerError('objdump failed to disassemble.') - except OSError: - raise StackAnalyzerError('Failed to run objdump.') - - function_map = self.AnalyzeDisassembly(disasm_text) - result = self.ResolveAnnotation(function_map) - (add_set, remove_list, eliminated_addrs, failed_sigtxts) = result - remove_list = self.PreprocessAnnotation(function_map, - add_set, - remove_list, - eliminated_addrs) - cycle_functions = self.AnalyzeCallGraph(function_map, remove_list) - - # Print the results of task-aware stack analysis. - extra_stack_frame = self.annotation.get('exception_frame_size', - DEFAULT_EXCEPTION_FRAME_SIZE) - for task in self.tasklist: - routine_func = function_map[task.routine_address] - print('Task: {}, Max size: {} ({} + {}), Allocated size: {}'.format( - task.name, - routine_func.stack_max_usage + extra_stack_frame, - routine_func.stack_max_usage, - extra_stack_frame, - task.stack_max_size)) - - print('Call Trace:') - max_stack_path = routine_func.stack_max_path - # Assume the routine function is resolved. - assert max_stack_path is not None - for depth, curr_func in enumerate(max_stack_path): - line_info = self.AddressToLine(curr_func.address)[0] - if line_info is None: - (path, linenum) = ('??', 0) - else: - (_, path, linenum) = line_info - - print(' {} ({}) [{}:{}] {:x}'.format(curr_func.name, - curr_func.stack_frame, - os.path.relpath(path), - linenum, - curr_func.address)) - - if depth + 1 < len(max_stack_path): - succ_func = max_stack_path[depth + 1] - text_list = [] - for callsite in curr_func.callsites: - if callsite.callee is succ_func: - indent_prefix = ' ' - if callsite.address is None: - order_text = (None, '{}-> [annotation]'.format(indent_prefix)) - else: - order_text = OutputInlineStack(callsite.address, indent_prefix) - - text_list.append(order_text) - - for _, text in sorted(text_list, key=lambda item: item[0]): - print(text) - - print('Unresolved indirect callsites:') - for function in function_map.values(): - indirect_callsites = [] - for callsite in function.callsites: - if callsite.target is None: - indirect_callsites.append(callsite.address) - - if len(indirect_callsites) > 0: - print(' In function {}:'.format(function.name)) - text_list = [] - for address in indirect_callsites: - text_list.append(OutputInlineStack(address, ' ')) - - for _, text in sorted(text_list, key=lambda item: item[0]): - print(text) - - print('Unresolved annotation signatures:') - for sigtxt, error in failed_sigtxts: - print(' {}: {}'.format(sigtxt, error)) - - if len(cycle_functions) > 0: - print('There are cycles in the following function sets:') - for functions in cycle_functions: - print('[{}]'.format(', '.join(function.name for function in functions))) - - -def ParseArgs(): - """Parse commandline arguments. - - Returns: - options: Namespace from argparse.parse_args(). - """ - parser = argparse.ArgumentParser(description="EC firmware stack analyzer.") - parser.add_argument('elf_path', help="the path of EC firmware ELF") - parser.add_argument('--export_taskinfo', required=True, - help="the path of export_taskinfo.so utility") - parser.add_argument('--section', required=True, help='the section.', - choices=[SECTION_RO, SECTION_RW]) - parser.add_argument('--objdump', default='objdump', - help='the path of objdump') - parser.add_argument('--addr2line', default='addr2line', - help='the path of addr2line') - parser.add_argument('--annotation', default=None, - help='the path of annotation file') - - # TODO(cheyuw): Add an option for dumping stack usage of all functions. - - return parser.parse_args() - - -def ParseSymbolText(symbol_text): - """Parse the content of the symbol text. - - Args: - symbol_text: Text of the symbols. - - Returns: - symbols: Symbol list. - """ - # Example: "10093064 g F .text 0000015c .hidden hook_task" - symbol_regex = re.compile(r'^(?P<address>[0-9A-Fa-f]+)\s+[lwg]\s+' - r'((?P<type>[OF])\s+)?\S+\s+' - r'(?P<size>[0-9A-Fa-f]+)\s+' - r'(\S+\s+)?(?P<name>\S+)$') - - symbols = [] - for line in symbol_text.splitlines(): - line = line.strip() - result = symbol_regex.match(line) - if result is not None: - address = int(result.group('address'), 16) - symtype = result.group('type') - if symtype is None: - symtype = 'O' - - size = int(result.group('size'), 16) - name = result.group('name') - symbols.append(Symbol(address, symtype, size, name)) - - return symbols - - -def ParseRoDataText(rodata_text): - """Parse the content of rodata - - Args: - symbol_text: Text of the rodata dump. - - Returns: - symbols: Symbol list. - """ - # Examples: 8018ab0 00040048 00010000 10020000 4b8e0108 ...H........K... - # 100a7294 00000000 00000000 01000000 ............ - - base_offset = None - offset = None - rodata = [] - for line in rodata_text.splitlines(): - line = line.strip() - space = line.find(' ') - if space < 0: - continue - try: - address = int(line[0:space], 16) - except ValueError: - continue - - if not base_offset: - base_offset = address - offset = address - elif address != offset: - raise StackAnalyzerError('objdump of rodata not contiguous.') + if not base_offset: + base_offset = address + offset = address + elif address != offset: + raise StackAnalyzerError("objdump of rodata not contiguous.") - for i in range(0, 4): - num = line[(space + 1 + i*9):(space + 9 + i*9)] - if len(num.strip()) > 0: - val = int(num, 16) - else: - val = 0 - # TODO(drinkcat): Not all platforms are necessarily big-endian - rodata.append((val & 0x000000ff) << 24 | - (val & 0x0000ff00) << 8 | - (val & 0x00ff0000) >> 8 | - (val & 0xff000000) >> 24) + for i in range(0, 4): + num = line[(space + 1 + i * 9) : (space + 9 + i * 9)] + if len(num.strip()) > 0: + val = int(num, 16) + else: + val = 0 + # TODO(drinkcat): Not all platforms are necessarily big-endian + rodata.append( + (val & 0x000000FF) << 24 + | (val & 0x0000FF00) << 8 + | (val & 0x00FF0000) >> 8 + | (val & 0xFF000000) >> 24 + ) - offset = offset + 4*4 + offset = offset + 4 * 4 - return (base_offset, rodata) + return (base_offset, rodata) def LoadTasklist(section, export_taskinfo, symbols): - """Load the task information. + """Load the task information. - Args: - section: Section (RO | RW). - export_taskinfo: Handle of export_taskinfo.so. - symbols: Symbol list. + Args: + section: Section (RO | RW). + export_taskinfo: Handle of export_taskinfo.so. + symbols: Symbol list. - Returns: - tasklist: Task list. - """ + Returns: + tasklist: Task list. + """ - TaskInfoPointer = ctypes.POINTER(TaskInfo) - taskinfos = TaskInfoPointer() - if section == SECTION_RO: - get_taskinfos_func = export_taskinfo.get_ro_taskinfos - else: - get_taskinfos_func = export_taskinfo.get_rw_taskinfos + TaskInfoPointer = ctypes.POINTER(TaskInfo) + taskinfos = TaskInfoPointer() + if section == SECTION_RO: + get_taskinfos_func = export_taskinfo.get_ro_taskinfos + else: + get_taskinfos_func = export_taskinfo.get_rw_taskinfos - taskinfo_num = get_taskinfos_func(ctypes.pointer(taskinfos)) + taskinfo_num = get_taskinfos_func(ctypes.pointer(taskinfos)) - tasklist = [] - for index in range(taskinfo_num): - taskinfo = taskinfos[index] - tasklist.append(Task(taskinfo.name.decode('utf-8'), - taskinfo.routine.decode('utf-8'), - taskinfo.stack_size)) + tasklist = [] + for index in range(taskinfo_num): + taskinfo = taskinfos[index] + tasklist.append( + Task( + taskinfo.name.decode("utf-8"), + taskinfo.routine.decode("utf-8"), + taskinfo.stack_size, + ) + ) - # Resolve routine address for each task. It's more efficient to resolve all - # routine addresses of tasks together. - routine_map = dict((task.routine_name, None) for task in tasklist) + # Resolve routine address for each task. It's more efficient to resolve all + # routine addresses of tasks together. + routine_map = dict((task.routine_name, None) for task in tasklist) - for symbol in symbols: - # Resolve task routine address. - if symbol.name in routine_map: - # Assume the symbol of routine is unique. - assert routine_map[symbol.name] is None - routine_map[symbol.name] = symbol.address + for symbol in symbols: + # Resolve task routine address. + if symbol.name in routine_map: + # Assume the symbol of routine is unique. + assert routine_map[symbol.name] is None + routine_map[symbol.name] = symbol.address - for task in tasklist: - address = routine_map[task.routine_name] - # Assume we have resolved all routine addresses. - assert address is not None - task.routine_address = address + for task in tasklist: + address = routine_map[task.routine_name] + # Assume we have resolved all routine addresses. + assert address is not None + task.routine_address = address - return tasklist + return tasklist def main(): - """Main function.""" - try: - options = ParseArgs() - - # Load annotation config. - if options.annotation is None: - annotation = {} - elif not os.path.exists(options.annotation): - print('Warning: Annotation file {} does not exist.' - .format(options.annotation)) - annotation = {} - else: - try: - with open(options.annotation, 'r') as annotation_file: - annotation = yaml.safe_load(annotation_file) - - except yaml.YAMLError: - raise StackAnalyzerError('Failed to parse annotation file {}.' - .format(options.annotation)) - except IOError: - raise StackAnalyzerError('Failed to open annotation file {}.' - .format(options.annotation)) - - # TODO(cheyuw): Do complete annotation format verification. - if not isinstance(annotation, dict): - raise StackAnalyzerError('Invalid annotation file {}.' - .format(options.annotation)) - - # Generate and parse the symbols. + """Main function.""" try: - symbol_text = subprocess.check_output([options.objdump, - '-t', - options.elf_path], - encoding='utf-8') - rodata_text = subprocess.check_output([options.objdump, - '-s', - '-j', '.rodata', - options.elf_path], - encoding='utf-8') - except subprocess.CalledProcessError: - raise StackAnalyzerError('objdump failed to dump symbol table or rodata.') - except OSError: - raise StackAnalyzerError('Failed to run objdump.') - - symbols = ParseSymbolText(symbol_text) - rodata = ParseRoDataText(rodata_text) - - # Load the tasklist. - try: - export_taskinfo = ctypes.CDLL(options.export_taskinfo) - except OSError: - raise StackAnalyzerError('Failed to load export_taskinfo.') - - tasklist = LoadTasklist(options.section, export_taskinfo, symbols) - - analyzer = StackAnalyzer(options, symbols, rodata, tasklist, annotation) - analyzer.Analyze() - except StackAnalyzerError as e: - print('Error: {}'.format(e)) - - -if __name__ == '__main__': - main() + options = ParseArgs() + + # Load annotation config. + if options.annotation is None: + annotation = {} + elif not os.path.exists(options.annotation): + print( + "Warning: Annotation file {} does not exist.".format(options.annotation) + ) + annotation = {} + else: + try: + with open(options.annotation, "r") as annotation_file: + annotation = yaml.safe_load(annotation_file) + + except yaml.YAMLError: + raise StackAnalyzerError( + "Failed to parse annotation file {}.".format(options.annotation) + ) + except IOError: + raise StackAnalyzerError( + "Failed to open annotation file {}.".format(options.annotation) + ) + + # TODO(cheyuw): Do complete annotation format verification. + if not isinstance(annotation, dict): + raise StackAnalyzerError( + "Invalid annotation file {}.".format(options.annotation) + ) + + # Generate and parse the symbols. + try: + symbol_text = subprocess.check_output( + [options.objdump, "-t", options.elf_path], encoding="utf-8" + ) + rodata_text = subprocess.check_output( + [options.objdump, "-s", "-j", ".rodata", options.elf_path], + encoding="utf-8", + ) + except subprocess.CalledProcessError: + raise StackAnalyzerError("objdump failed to dump symbol table or rodata.") + except OSError: + raise StackAnalyzerError("Failed to run objdump.") + + symbols = ParseSymbolText(symbol_text) + rodata = ParseRoDataText(rodata_text) + + # Load the tasklist. + try: + export_taskinfo = ctypes.CDLL(options.export_taskinfo) + except OSError: + raise StackAnalyzerError("Failed to load export_taskinfo.") + + tasklist = LoadTasklist(options.section, export_taskinfo, symbols) + + analyzer = StackAnalyzer(options, symbols, rodata, tasklist, annotation) + analyzer.Analyze() + except StackAnalyzerError as e: + print("Error: {}".format(e)) + + +if __name__ == "__main__": + main() diff --git a/extra/stack_analyzer/stack_analyzer_unittest.py b/extra/stack_analyzer/stack_analyzer_unittest.py index c36fa9da45..ad2837a8a4 100755 --- a/extra/stack_analyzer/stack_analyzer_unittest.py +++ b/extra/stack_analyzer/stack_analyzer_unittest.py @@ -11,820 +11,924 @@ from __future__ import print_function -import mock import os import subprocess import unittest +import mock import stack_analyzer as sa class ObjectTest(unittest.TestCase): - """Tests for classes of basic objects.""" - - def testTask(self): - task_a = sa.Task('a', 'a_task', 1234) - task_b = sa.Task('b', 'b_task', 5678, 0x1000) - self.assertEqual(task_a, task_a) - self.assertNotEqual(task_a, task_b) - self.assertNotEqual(task_a, None) - - def testSymbol(self): - symbol_a = sa.Symbol(0x1234, 'F', 32, 'a') - symbol_b = sa.Symbol(0x234, 'O', 42, 'b') - self.assertEqual(symbol_a, symbol_a) - self.assertNotEqual(symbol_a, symbol_b) - self.assertNotEqual(symbol_a, None) - - def testCallsite(self): - callsite_a = sa.Callsite(0x1002, 0x3000, False) - callsite_b = sa.Callsite(0x1002, 0x3000, True) - self.assertEqual(callsite_a, callsite_a) - self.assertNotEqual(callsite_a, callsite_b) - self.assertNotEqual(callsite_a, None) - - def testFunction(self): - func_a = sa.Function(0x100, 'a', 0, []) - func_b = sa.Function(0x200, 'b', 0, []) - self.assertEqual(func_a, func_a) - self.assertNotEqual(func_a, func_b) - self.assertNotEqual(func_a, None) + """Tests for classes of basic objects.""" + + def testTask(self): + task_a = sa.Task("a", "a_task", 1234) + task_b = sa.Task("b", "b_task", 5678, 0x1000) + self.assertEqual(task_a, task_a) + self.assertNotEqual(task_a, task_b) + self.assertNotEqual(task_a, None) + + def testSymbol(self): + symbol_a = sa.Symbol(0x1234, "F", 32, "a") + symbol_b = sa.Symbol(0x234, "O", 42, "b") + self.assertEqual(symbol_a, symbol_a) + self.assertNotEqual(symbol_a, symbol_b) + self.assertNotEqual(symbol_a, None) + + def testCallsite(self): + callsite_a = sa.Callsite(0x1002, 0x3000, False) + callsite_b = sa.Callsite(0x1002, 0x3000, True) + self.assertEqual(callsite_a, callsite_a) + self.assertNotEqual(callsite_a, callsite_b) + self.assertNotEqual(callsite_a, None) + + def testFunction(self): + func_a = sa.Function(0x100, "a", 0, []) + func_b = sa.Function(0x200, "b", 0, []) + self.assertEqual(func_a, func_a) + self.assertNotEqual(func_a, func_b) + self.assertNotEqual(func_a, None) class ArmAnalyzerTest(unittest.TestCase): - """Tests for class ArmAnalyzer.""" - - def AppendConditionCode(self, opcodes): - rets = [] - for opcode in opcodes: - rets.extend(opcode + cc for cc in sa.ArmAnalyzer.CONDITION_CODES) - - return rets - - def testInstructionMatching(self): - jump_list = self.AppendConditionCode(['b', 'bx']) - jump_list += (list(opcode + '.n' for opcode in jump_list) + - list(opcode + '.w' for opcode in jump_list)) - for opcode in jump_list: - self.assertIsNotNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match(opcode)) - - self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match('bl')) - self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match('blx')) - - cbz_list = ['cbz', 'cbnz', 'cbz.n', 'cbnz.n', 'cbz.w', 'cbnz.w'] - for opcode in cbz_list: - self.assertIsNotNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match(opcode)) - - self.assertIsNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match('cbn')) - - call_list = self.AppendConditionCode(['bl', 'blx']) - call_list += list(opcode + '.n' for opcode in call_list) - for opcode in call_list: - self.assertIsNotNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match(opcode)) - - self.assertIsNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match('ble')) - - result = sa.ArmAnalyzer.CALL_OPERAND_RE.match('53f90 <get_time+0x18>') - self.assertIsNotNone(result) - self.assertEqual(result.group(1), '53f90') - self.assertEqual(result.group(2), 'get_time+0x18') - - result = sa.ArmAnalyzer.CBZ_CBNZ_OPERAND_RE.match('r6, 53f90 <get+0x0>') - self.assertIsNotNone(result) - self.assertEqual(result.group(1), '53f90') - self.assertEqual(result.group(2), 'get+0x0') - - self.assertIsNotNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match('push')) - self.assertIsNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match('pushal')) - self.assertIsNotNone(sa.ArmAnalyzer.STM_OPCODE_RE.match('stmdb')) - self.assertIsNone(sa.ArmAnalyzer.STM_OPCODE_RE.match('lstm')) - self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('sub')) - self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('subs')) - self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('subw')) - self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('sub.w')) - self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('subs.w')) - - result = sa.ArmAnalyzer.SUB_OPERAND_RE.match('sp, sp, #1668 ; 0x684') - self.assertIsNotNone(result) - self.assertEqual(result.group(1), '1668') - result = sa.ArmAnalyzer.SUB_OPERAND_RE.match('sp, #1668') - self.assertIsNotNone(result) - self.assertEqual(result.group(1), '1668') - self.assertIsNone(sa.ArmAnalyzer.SUB_OPERAND_RE.match('sl, #1668')) - - def testAnalyzeFunction(self): - analyzer = sa.ArmAnalyzer() - symbol = sa.Symbol(0x10, 'F', 0x100, 'foo') - instructions = [ - (0x10, 'push', '{r4, r5, r6, r7, lr}'), - (0x12, 'subw', 'sp, sp, #16 ; 0x10'), - (0x16, 'movs', 'lr, r1'), - (0x18, 'beq.n', '26 <foo+0x26>'), - (0x1a, 'bl', '30 <foo+0x30>'), - (0x1e, 'bl', 'deadbeef <bar>'), - (0x22, 'blx', '0 <woo>'), - (0x26, 'push', '{r1}'), - (0x28, 'stmdb', 'sp!, {r4, r5, r6, r7, r8, r9, lr}'), - (0x2c, 'stmdb', 'sp!, {r4}'), - (0x30, 'stmdb', 'sp, {r4}'), - (0x34, 'bx.n', '10 <foo>'), - (0x36, 'bx.n', 'r3'), - (0x38, 'ldr', 'pc, [r10]'), - ] - (size, callsites) = analyzer.AnalyzeFunction(symbol, instructions) - self.assertEqual(size, 72) - expect_callsites = [sa.Callsite(0x1e, 0xdeadbeef, False), - sa.Callsite(0x22, 0x0, False), - sa.Callsite(0x34, 0x10, True), - sa.Callsite(0x36, None, True), - sa.Callsite(0x38, None, True)] - self.assertEqual(callsites, expect_callsites) + """Tests for class ArmAnalyzer.""" + + def AppendConditionCode(self, opcodes): + rets = [] + for opcode in opcodes: + rets.extend(opcode + cc for cc in sa.ArmAnalyzer.CONDITION_CODES) + + return rets + + def testInstructionMatching(self): + jump_list = self.AppendConditionCode(["b", "bx"]) + jump_list += list(opcode + ".n" for opcode in jump_list) + list( + opcode + ".w" for opcode in jump_list + ) + for opcode in jump_list: + self.assertIsNotNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match(opcode)) + + self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match("bl")) + self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match("blx")) + + cbz_list = ["cbz", "cbnz", "cbz.n", "cbnz.n", "cbz.w", "cbnz.w"] + for opcode in cbz_list: + self.assertIsNotNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match(opcode)) + + self.assertIsNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match("cbn")) + + call_list = self.AppendConditionCode(["bl", "blx"]) + call_list += list(opcode + ".n" for opcode in call_list) + for opcode in call_list: + self.assertIsNotNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match(opcode)) + + self.assertIsNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match("ble")) + + result = sa.ArmAnalyzer.CALL_OPERAND_RE.match("53f90 <get_time+0x18>") + self.assertIsNotNone(result) + self.assertEqual(result.group(1), "53f90") + self.assertEqual(result.group(2), "get_time+0x18") + + result = sa.ArmAnalyzer.CBZ_CBNZ_OPERAND_RE.match("r6, 53f90 <get+0x0>") + self.assertIsNotNone(result) + self.assertEqual(result.group(1), "53f90") + self.assertEqual(result.group(2), "get+0x0") + + self.assertIsNotNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match("push")) + self.assertIsNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match("pushal")) + self.assertIsNotNone(sa.ArmAnalyzer.STM_OPCODE_RE.match("stmdb")) + self.assertIsNone(sa.ArmAnalyzer.STM_OPCODE_RE.match("lstm")) + self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("sub")) + self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("subs")) + self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("subw")) + self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("sub.w")) + self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("subs.w")) + + result = sa.ArmAnalyzer.SUB_OPERAND_RE.match("sp, sp, #1668 ; 0x684") + self.assertIsNotNone(result) + self.assertEqual(result.group(1), "1668") + result = sa.ArmAnalyzer.SUB_OPERAND_RE.match("sp, #1668") + self.assertIsNotNone(result) + self.assertEqual(result.group(1), "1668") + self.assertIsNone(sa.ArmAnalyzer.SUB_OPERAND_RE.match("sl, #1668")) + + def testAnalyzeFunction(self): + analyzer = sa.ArmAnalyzer() + symbol = sa.Symbol(0x10, "F", 0x100, "foo") + instructions = [ + (0x10, "push", "{r4, r5, r6, r7, lr}"), + (0x12, "subw", "sp, sp, #16 ; 0x10"), + (0x16, "movs", "lr, r1"), + (0x18, "beq.n", "26 <foo+0x26>"), + (0x1A, "bl", "30 <foo+0x30>"), + (0x1E, "bl", "deadbeef <bar>"), + (0x22, "blx", "0 <woo>"), + (0x26, "push", "{r1}"), + (0x28, "stmdb", "sp!, {r4, r5, r6, r7, r8, r9, lr}"), + (0x2C, "stmdb", "sp!, {r4}"), + (0x30, "stmdb", "sp, {r4}"), + (0x34, "bx.n", "10 <foo>"), + (0x36, "bx.n", "r3"), + (0x38, "ldr", "pc, [r10]"), + ] + (size, callsites) = analyzer.AnalyzeFunction(symbol, instructions) + self.assertEqual(size, 72) + expect_callsites = [ + sa.Callsite(0x1E, 0xDEADBEEF, False), + sa.Callsite(0x22, 0x0, False), + sa.Callsite(0x34, 0x10, True), + sa.Callsite(0x36, None, True), + sa.Callsite(0x38, None, True), + ] + self.assertEqual(callsites, expect_callsites) class StackAnalyzerTest(unittest.TestCase): - """Tests for class StackAnalyzer.""" - - def setUp(self): - symbols = [sa.Symbol(0x1000, 'F', 0x15C, 'hook_task'), - sa.Symbol(0x2000, 'F', 0x51C, 'console_task'), - sa.Symbol(0x3200, 'O', 0x124, '__just_data'), - sa.Symbol(0x4000, 'F', 0x11C, 'touchpad_calc'), - sa.Symbol(0x5000, 'F', 0x12C, 'touchpad_calc.constprop.42'), - sa.Symbol(0x12000, 'F', 0x13C, 'trackpad_range'), - sa.Symbol(0x13000, 'F', 0x200, 'inlined_mul'), - sa.Symbol(0x13100, 'F', 0x200, 'inlined_mul'), - sa.Symbol(0x13100, 'F', 0x200, 'inlined_mul_alias'), - sa.Symbol(0x20000, 'O', 0x0, '__array'), - sa.Symbol(0x20010, 'O', 0x0, '__array_end'), - ] - tasklist = [sa.Task('HOOKS', 'hook_task', 2048, 0x1000), - sa.Task('CONSOLE', 'console_task', 460, 0x2000)] - # Array at 0x20000 that contains pointers to hook_task and console_task, - # with stride=8, offset=4 - rodata = (0x20000, [ 0xDEAD1000, 0x00001000, 0xDEAD2000, 0x00002000 ]) - options = mock.MagicMock(elf_path='./ec.RW.elf', - export_taskinfo='fake', - section='RW', - objdump='objdump', - addr2line='addr2line', - annotation=None) - self.analyzer = sa.StackAnalyzer(options, symbols, rodata, tasklist, {}) - - def testParseSymbolText(self): - symbol_text = ( - '0 g F .text e8 Foo\n' - '0000dead w F .text 000000e8 .hidden Bar\n' - 'deadbeef l O .bss 00000004 .hidden Woooo\n' - 'deadbee g O .rodata 00000008 __Hooo_ooo\n' - 'deadbee g .rodata 00000000 __foo_doo_coo_end\n' - ) - symbols = sa.ParseSymbolText(symbol_text) - expect_symbols = [sa.Symbol(0x0, 'F', 0xe8, 'Foo'), - sa.Symbol(0xdead, 'F', 0xe8, 'Bar'), - sa.Symbol(0xdeadbeef, 'O', 0x4, 'Woooo'), - sa.Symbol(0xdeadbee, 'O', 0x8, '__Hooo_ooo'), - sa.Symbol(0xdeadbee, 'O', 0x0, '__foo_doo_coo_end')] - self.assertEqual(symbols, expect_symbols) - - def testParseRoData(self): - rodata_text = ( - '\n' - 'Contents of section .rodata:\n' - ' 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n' - ) - rodata = sa.ParseRoDataText(rodata_text) - expect_rodata = (0x20000, - [ 0x0010adde, 0x00001000, 0x0020adde, 0x00002000 ]) - self.assertEqual(rodata, expect_rodata) - - def testLoadTasklist(self): - def tasklist_to_taskinfos(pointer, tasklist): - taskinfos = [] - for task in tasklist: - taskinfos.append(sa.TaskInfo(name=task.name.encode('utf-8'), - routine=task.routine_name.encode('utf-8'), - stack_size=task.stack_max_size)) - - TaskInfoArray = sa.TaskInfo * len(taskinfos) - pointer.contents.contents = TaskInfoArray(*taskinfos) - return len(taskinfos) - - def ro_taskinfos(pointer): - return tasklist_to_taskinfos(pointer, expect_ro_tasklist) - - def rw_taskinfos(pointer): - return tasklist_to_taskinfos(pointer, expect_rw_tasklist) - - expect_ro_tasklist = [ - sa.Task('HOOKS', 'hook_task', 2048, 0x1000), - ] - - expect_rw_tasklist = [ - sa.Task('HOOKS', 'hook_task', 2048, 0x1000), - sa.Task('WOOKS', 'hook_task', 4096, 0x1000), - sa.Task('CONSOLE', 'console_task', 460, 0x2000), - ] - - export_taskinfo = mock.MagicMock( - get_ro_taskinfos=mock.MagicMock(side_effect=ro_taskinfos), - get_rw_taskinfos=mock.MagicMock(side_effect=rw_taskinfos)) - - tasklist = sa.LoadTasklist('RO', export_taskinfo, self.analyzer.symbols) - self.assertEqual(tasklist, expect_ro_tasklist) - tasklist = sa.LoadTasklist('RW', export_taskinfo, self.analyzer.symbols) - self.assertEqual(tasklist, expect_rw_tasklist) - - def testResolveAnnotation(self): - self.analyzer.annotation = {} - (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() - self.assertEqual(add_rules, {}) - self.assertEqual(remove_rules, []) - self.assertEqual(invalid_sigtxts, set()) - - self.analyzer.annotation = {'add': None, 'remove': None} - (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() - self.assertEqual(add_rules, {}) - self.assertEqual(remove_rules, []) - self.assertEqual(invalid_sigtxts, set()) - - self.analyzer.annotation = { - 'add': None, - 'remove': [ - [['a', 'b'], ['0', '[', '2'], 'x'], - [['a', 'b[x:3]'], ['0', '1', '2'], 'x'], - ], - } - (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() - self.assertEqual(add_rules, {}) - self.assertEqual(list.sort(remove_rules), list.sort([ - [('a', None, None), ('1', None, None), ('x', None, None)], - [('a', None, None), ('0', None, None), ('x', None, None)], - [('a', None, None), ('2', None, None), ('x', None, None)], - [('b', os.path.abspath('x'), 3), ('1', None, None), ('x', None, None)], - [('b', os.path.abspath('x'), 3), ('0', None, None), ('x', None, None)], - [('b', os.path.abspath('x'), 3), ('2', None, None), ('x', None, None)], - ])) - self.assertEqual(invalid_sigtxts, {'['}) - - self.analyzer.annotation = { - 'add': { - 'touchpad_calc': [ dict(name='__array', stride=8, offset=4) ], + """Tests for class StackAnalyzer.""" + + def setUp(self): + symbols = [ + sa.Symbol(0x1000, "F", 0x15C, "hook_task"), + sa.Symbol(0x2000, "F", 0x51C, "console_task"), + sa.Symbol(0x3200, "O", 0x124, "__just_data"), + sa.Symbol(0x4000, "F", 0x11C, "touchpad_calc"), + sa.Symbol(0x5000, "F", 0x12C, "touchpad_calc.constprop.42"), + sa.Symbol(0x12000, "F", 0x13C, "trackpad_range"), + sa.Symbol(0x13000, "F", 0x200, "inlined_mul"), + sa.Symbol(0x13100, "F", 0x200, "inlined_mul"), + sa.Symbol(0x13100, "F", 0x200, "inlined_mul_alias"), + sa.Symbol(0x20000, "O", 0x0, "__array"), + sa.Symbol(0x20010, "O", 0x0, "__array_end"), + ] + tasklist = [ + sa.Task("HOOKS", "hook_task", 2048, 0x1000), + sa.Task("CONSOLE", "console_task", 460, 0x2000), + ] + # Array at 0x20000 that contains pointers to hook_task and console_task, + # with stride=8, offset=4 + rodata = (0x20000, [0xDEAD1000, 0x00001000, 0xDEAD2000, 0x00002000]) + options = mock.MagicMock( + elf_path="./ec.RW.elf", + export_taskinfo="fake", + section="RW", + objdump="objdump", + addr2line="addr2line", + annotation=None, + ) + self.analyzer = sa.StackAnalyzer(options, symbols, rodata, tasklist, {}) + + def testParseSymbolText(self): + symbol_text = ( + "0 g F .text e8 Foo\n" + "0000dead w F .text 000000e8 .hidden Bar\n" + "deadbeef l O .bss 00000004 .hidden Woooo\n" + "deadbee g O .rodata 00000008 __Hooo_ooo\n" + "deadbee g .rodata 00000000 __foo_doo_coo_end\n" + ) + symbols = sa.ParseSymbolText(symbol_text) + expect_symbols = [ + sa.Symbol(0x0, "F", 0xE8, "Foo"), + sa.Symbol(0xDEAD, "F", 0xE8, "Bar"), + sa.Symbol(0xDEADBEEF, "O", 0x4, "Woooo"), + sa.Symbol(0xDEADBEE, "O", 0x8, "__Hooo_ooo"), + sa.Symbol(0xDEADBEE, "O", 0x0, "__foo_doo_coo_end"), + ] + self.assertEqual(symbols, expect_symbols) + + def testParseRoData(self): + rodata_text = ( + "\n" + "Contents of section .rodata:\n" + " 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n" + ) + rodata = sa.ParseRoDataText(rodata_text) + expect_rodata = (0x20000, [0x0010ADDE, 0x00001000, 0x0020ADDE, 0x00002000]) + self.assertEqual(rodata, expect_rodata) + + def testLoadTasklist(self): + def tasklist_to_taskinfos(pointer, tasklist): + taskinfos = [] + for task in tasklist: + taskinfos.append( + sa.TaskInfo( + name=task.name.encode("utf-8"), + routine=task.routine_name.encode("utf-8"), + stack_size=task.stack_max_size, + ) + ) + + TaskInfoArray = sa.TaskInfo * len(taskinfos) + pointer.contents.contents = TaskInfoArray(*taskinfos) + return len(taskinfos) + + def ro_taskinfos(pointer): + return tasklist_to_taskinfos(pointer, expect_ro_tasklist) + + def rw_taskinfos(pointer): + return tasklist_to_taskinfos(pointer, expect_rw_tasklist) + + expect_ro_tasklist = [ + sa.Task("HOOKS", "hook_task", 2048, 0x1000), + ] + + expect_rw_tasklist = [ + sa.Task("HOOKS", "hook_task", 2048, 0x1000), + sa.Task("WOOKS", "hook_task", 4096, 0x1000), + sa.Task("CONSOLE", "console_task", 460, 0x2000), + ] + + export_taskinfo = mock.MagicMock( + get_ro_taskinfos=mock.MagicMock(side_effect=ro_taskinfos), + get_rw_taskinfos=mock.MagicMock(side_effect=rw_taskinfos), + ) + + tasklist = sa.LoadTasklist("RO", export_taskinfo, self.analyzer.symbols) + self.assertEqual(tasklist, expect_ro_tasklist) + tasklist = sa.LoadTasklist("RW", export_taskinfo, self.analyzer.symbols) + self.assertEqual(tasklist, expect_rw_tasklist) + + def testResolveAnnotation(self): + self.analyzer.annotation = {} + (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() + self.assertEqual(add_rules, {}) + self.assertEqual(remove_rules, []) + self.assertEqual(invalid_sigtxts, set()) + + self.analyzer.annotation = {"add": None, "remove": None} + (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() + self.assertEqual(add_rules, {}) + self.assertEqual(remove_rules, []) + self.assertEqual(invalid_sigtxts, set()) + + self.analyzer.annotation = { + "add": None, + "remove": [ + [["a", "b"], ["0", "[", "2"], "x"], + [["a", "b[x:3]"], ["0", "1", "2"], "x"], + ], + } + (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() + self.assertEqual(add_rules, {}) + self.assertEqual( + list.sort(remove_rules), + list.sort( + [ + [("a", None, None), ("1", None, None), ("x", None, None)], + [("a", None, None), ("0", None, None), ("x", None, None)], + [("a", None, None), ("2", None, None), ("x", None, None)], + [ + ("b", os.path.abspath("x"), 3), + ("1", None, None), + ("x", None, None), + ], + [ + ("b", os.path.abspath("x"), 3), + ("0", None, None), + ("x", None, None), + ], + [ + ("b", os.path.abspath("x"), 3), + ("2", None, None), + ("x", None, None), + ], + ] + ), + ) + self.assertEqual(invalid_sigtxts, {"["}) + + self.analyzer.annotation = { + "add": { + "touchpad_calc": [dict(name="__array", stride=8, offset=4)], + } + } + (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() + self.assertEqual( + add_rules, + { + ("touchpad_calc", None, None): set( + [("console_task", None, None), ("hook_task", None, None)] + ) + }, + ) + + funcs = { + 0x1000: sa.Function(0x1000, "hook_task", 0, []), + 0x2000: sa.Function(0x2000, "console_task", 0, []), + 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []), + 0x5000: sa.Function(0x5000, "touchpad_calc.constprop.42", 0, []), + 0x13000: sa.Function(0x13000, "inlined_mul", 0, []), + 0x13100: sa.Function(0x13100, "inlined_mul", 0, []), + } + funcs[0x1000].callsites = [sa.Callsite(0x1002, None, False, None)] + # Set address_to_line_cache to fake the results of addr2line. + self.analyzer.address_to_line_cache = { + (0x1000, False): [("hook_task", os.path.abspath("a.c"), 10)], + (0x1002, False): [("toot_calc", os.path.abspath("t.c"), 1234)], + (0x2000, False): [("console_task", os.path.abspath("b.c"), 20)], + (0x4000, False): [("toudhpad_calc", os.path.abspath("a.c"), 20)], + (0x5000, False): [ + ("touchpad_calc.constprop.42", os.path.abspath("b.c"), 40) + ], + (0x12000, False): [("trackpad_range", os.path.abspath("t.c"), 10)], + (0x13000, False): [("inlined_mul", os.path.abspath("x.c"), 12)], + (0x13100, False): [("inlined_mul", os.path.abspath("x.c"), 12)], + } + self.analyzer.annotation = { + "add": { + "hook_task.lto.573": ["touchpad_calc.lto.2501[a.c]"], + "console_task": ["touchpad_calc[b.c]", "inlined_mul_alias"], + "hook_task[q.c]": ["hook_task"], + "inlined_mul[x.c]": ["inlined_mul"], + "toot_calc[t.c:1234]": ["hook_task"], + }, + "remove": [ + ["touchpad?calc["], + "touchpad_calc", + ["touchpad_calc[a.c]"], + ["task_unk[a.c]"], + ["touchpad_calc[x/a.c]"], + ["trackpad_range"], + ["inlined_mul"], + ["inlined_mul", "console_task", "touchpad_calc[a.c]"], + ["inlined_mul", "inlined_mul_alias", "console_task"], + ["inlined_mul", "inlined_mul_alias", "console_task"], + ], + } + (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() + self.assertEqual(invalid_sigtxts, {"touchpad?calc["}) + + signature_set = set() + for src_sig, dst_sigs in add_rules.items(): + signature_set.add(src_sig) + signature_set.update(dst_sigs) + + for remove_sigs in remove_rules: + signature_set.update(remove_sigs) + + (signature_map, failed_sigs) = self.analyzer.MapAnnotation(funcs, signature_set) + result = self.analyzer.ResolveAnnotation(funcs) + (add_set, remove_list, eliminated_addrs, failed_sigs) = result + + expect_signature_map = { + ("hook_task", None, None): {funcs[0x1000]}, + ("touchpad_calc", os.path.abspath("a.c"), None): {funcs[0x4000]}, + ("touchpad_calc", os.path.abspath("b.c"), None): {funcs[0x5000]}, + ("console_task", None, None): {funcs[0x2000]}, + ("inlined_mul_alias", None, None): {funcs[0x13100]}, + ("inlined_mul", os.path.abspath("x.c"), None): { + funcs[0x13000], + funcs[0x13100], + }, + ("inlined_mul", None, None): {funcs[0x13000], funcs[0x13100]}, + } + self.assertEqual(len(signature_map), len(expect_signature_map)) + for sig, funclist in signature_map.items(): + self.assertEqual(set(funclist), expect_signature_map[sig]) + + self.assertEqual( + add_set, + { + (funcs[0x1000], funcs[0x4000]), + (funcs[0x1000], funcs[0x1000]), + (funcs[0x2000], funcs[0x5000]), + (funcs[0x2000], funcs[0x13100]), + (funcs[0x13000], funcs[0x13000]), + (funcs[0x13000], funcs[0x13100]), + (funcs[0x13100], funcs[0x13000]), + (funcs[0x13100], funcs[0x13100]), + }, + ) + expect_remove_list = [ + [funcs[0x4000]], + [funcs[0x13000]], + [funcs[0x13100]], + [funcs[0x13000], funcs[0x2000], funcs[0x4000]], + [funcs[0x13100], funcs[0x2000], funcs[0x4000]], + [funcs[0x13000], funcs[0x13100], funcs[0x2000]], + [funcs[0x13100], funcs[0x13100], funcs[0x2000]], + ] + self.assertEqual(len(remove_list), len(expect_remove_list)) + for remove_path in remove_list: + self.assertTrue(remove_path in expect_remove_list) + + self.assertEqual(eliminated_addrs, {0x1002}) + self.assertEqual( + failed_sigs, + { + ("touchpad?calc[", sa.StackAnalyzer.ANNOTATION_ERROR_INVALID), + ("touchpad_calc", sa.StackAnalyzer.ANNOTATION_ERROR_AMBIGUOUS), + ("hook_task[q.c]", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), + ("task_unk[a.c]", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), + ("touchpad_calc[x/a.c]", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), + ("trackpad_range", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), + }, + ) + + def testPreprocessAnnotation(self): + funcs = { + 0x1000: sa.Function(0x1000, "hook_task", 0, []), + 0x2000: sa.Function(0x2000, "console_task", 0, []), + 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []), + } + funcs[0x1000].callsites = [sa.Callsite(0x1002, 0x1000, False, funcs[0x1000])] + funcs[0x2000].callsites = [ + sa.Callsite(0x2002, 0x1000, False, funcs[0x1000]), + sa.Callsite(0x2006, None, True, None), + ] + add_set = { + (funcs[0x2000], funcs[0x2000]), + (funcs[0x2000], funcs[0x4000]), + (funcs[0x4000], funcs[0x1000]), + (funcs[0x4000], funcs[0x2000]), } - } - (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() - self.assertEqual(add_rules, { - ('touchpad_calc', None, None): - set([('console_task', None, None), ('hook_task', None, None)])}) - - funcs = { - 0x1000: sa.Function(0x1000, 'hook_task', 0, []), - 0x2000: sa.Function(0x2000, 'console_task', 0, []), - 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []), - 0x5000: sa.Function(0x5000, 'touchpad_calc.constprop.42', 0, []), - 0x13000: sa.Function(0x13000, 'inlined_mul', 0, []), - 0x13100: sa.Function(0x13100, 'inlined_mul', 0, []), - } - funcs[0x1000].callsites = [ - sa.Callsite(0x1002, None, False, None)] - # Set address_to_line_cache to fake the results of addr2line. - self.analyzer.address_to_line_cache = { - (0x1000, False): [('hook_task', os.path.abspath('a.c'), 10)], - (0x1002, False): [('toot_calc', os.path.abspath('t.c'), 1234)], - (0x2000, False): [('console_task', os.path.abspath('b.c'), 20)], - (0x4000, False): [('toudhpad_calc', os.path.abspath('a.c'), 20)], - (0x5000, False): [ - ('touchpad_calc.constprop.42', os.path.abspath('b.c'), 40)], - (0x12000, False): [('trackpad_range', os.path.abspath('t.c'), 10)], - (0x13000, False): [('inlined_mul', os.path.abspath('x.c'), 12)], - (0x13100, False): [('inlined_mul', os.path.abspath('x.c'), 12)], - } - self.analyzer.annotation = { - 'add': { - 'hook_task.lto.573': ['touchpad_calc.lto.2501[a.c]'], - 'console_task': ['touchpad_calc[b.c]', 'inlined_mul_alias'], - 'hook_task[q.c]': ['hook_task'], - 'inlined_mul[x.c]': ['inlined_mul'], - 'toot_calc[t.c:1234]': ['hook_task'], - }, - 'remove': [ - ['touchpad?calc['], - 'touchpad_calc', - ['touchpad_calc[a.c]'], - ['task_unk[a.c]'], - ['touchpad_calc[x/a.c]'], - ['trackpad_range'], - ['inlined_mul'], - ['inlined_mul', 'console_task', 'touchpad_calc[a.c]'], - ['inlined_mul', 'inlined_mul_alias', 'console_task'], - ['inlined_mul', 'inlined_mul_alias', 'console_task'], - ], - } - (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() - self.assertEqual(invalid_sigtxts, {'touchpad?calc['}) - - signature_set = set() - for src_sig, dst_sigs in add_rules.items(): - signature_set.add(src_sig) - signature_set.update(dst_sigs) - - for remove_sigs in remove_rules: - signature_set.update(remove_sigs) - - (signature_map, failed_sigs) = self.analyzer.MapAnnotation(funcs, - signature_set) - result = self.analyzer.ResolveAnnotation(funcs) - (add_set, remove_list, eliminated_addrs, failed_sigs) = result - - expect_signature_map = { - ('hook_task', None, None): {funcs[0x1000]}, - ('touchpad_calc', os.path.abspath('a.c'), None): {funcs[0x4000]}, - ('touchpad_calc', os.path.abspath('b.c'), None): {funcs[0x5000]}, - ('console_task', None, None): {funcs[0x2000]}, - ('inlined_mul_alias', None, None): {funcs[0x13100]}, - ('inlined_mul', os.path.abspath('x.c'), None): {funcs[0x13000], - funcs[0x13100]}, - ('inlined_mul', None, None): {funcs[0x13000], funcs[0x13100]}, - } - self.assertEqual(len(signature_map), len(expect_signature_map)) - for sig, funclist in signature_map.items(): - self.assertEqual(set(funclist), expect_signature_map[sig]) - - self.assertEqual(add_set, { - (funcs[0x1000], funcs[0x4000]), - (funcs[0x1000], funcs[0x1000]), - (funcs[0x2000], funcs[0x5000]), - (funcs[0x2000], funcs[0x13100]), - (funcs[0x13000], funcs[0x13000]), - (funcs[0x13000], funcs[0x13100]), - (funcs[0x13100], funcs[0x13000]), - (funcs[0x13100], funcs[0x13100]), - }) - expect_remove_list = [ - [funcs[0x4000]], - [funcs[0x13000]], - [funcs[0x13100]], - [funcs[0x13000], funcs[0x2000], funcs[0x4000]], - [funcs[0x13100], funcs[0x2000], funcs[0x4000]], - [funcs[0x13000], funcs[0x13100], funcs[0x2000]], - [funcs[0x13100], funcs[0x13100], funcs[0x2000]], - ] - self.assertEqual(len(remove_list), len(expect_remove_list)) - for remove_path in remove_list: - self.assertTrue(remove_path in expect_remove_list) - - self.assertEqual(eliminated_addrs, {0x1002}) - self.assertEqual(failed_sigs, { - ('touchpad?calc[', sa.StackAnalyzer.ANNOTATION_ERROR_INVALID), - ('touchpad_calc', sa.StackAnalyzer.ANNOTATION_ERROR_AMBIGUOUS), - ('hook_task[q.c]', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), - ('task_unk[a.c]', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), - ('touchpad_calc[x/a.c]', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), - ('trackpad_range', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), - }) - - def testPreprocessAnnotation(self): - funcs = { - 0x1000: sa.Function(0x1000, 'hook_task', 0, []), - 0x2000: sa.Function(0x2000, 'console_task', 0, []), - 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []), - } - funcs[0x1000].callsites = [ - sa.Callsite(0x1002, 0x1000, False, funcs[0x1000])] - funcs[0x2000].callsites = [ - sa.Callsite(0x2002, 0x1000, False, funcs[0x1000]), - sa.Callsite(0x2006, None, True, None), - ] - add_set = { - (funcs[0x2000], funcs[0x2000]), - (funcs[0x2000], funcs[0x4000]), - (funcs[0x4000], funcs[0x1000]), - (funcs[0x4000], funcs[0x2000]), - } - remove_list = [ - [funcs[0x1000]], - [funcs[0x2000], funcs[0x2000]], - [funcs[0x4000], funcs[0x1000]], - [funcs[0x2000], funcs[0x4000], funcs[0x2000]], - [funcs[0x4000], funcs[0x1000], funcs[0x4000]], - ] - eliminated_addrs = {0x2006} - - remaining_remove_list = self.analyzer.PreprocessAnnotation(funcs, - add_set, - remove_list, - eliminated_addrs) - - expect_funcs = { - 0x1000: sa.Function(0x1000, 'hook_task', 0, []), - 0x2000: sa.Function(0x2000, 'console_task', 0, []), - 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []), - } - expect_funcs[0x2000].callsites = [ - sa.Callsite(None, 0x4000, False, expect_funcs[0x4000])] - expect_funcs[0x4000].callsites = [ - sa.Callsite(None, 0x2000, False, expect_funcs[0x2000])] - self.assertEqual(funcs, expect_funcs) - self.assertEqual(remaining_remove_list, [ - [funcs[0x2000], funcs[0x4000], funcs[0x2000]], - ]) - - def testAndesAnalyzeDisassembly(self): - disasm_text = ( - '\n' - 'build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le' - '\n' - 'Disassembly of section .text:\n' - '\n' - '00000900 <wook_task>:\n' - ' ...\n' - '00001000 <hook_task>:\n' - ' 1000: fc 42\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n' - ' 1004: 47 70\t\tmovi55 $r0, #1\n' - ' 1006: b1 13\tbnezs8 100929de <flash_command_write>\n' - ' 1008: 00 01 5c fc\tbne $r6, $r0, 2af6a\n' - '00002000 <console_task>:\n' - ' 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n' - ' 2002: f0 0e fc c5\tjal 1000 <hook_task>\n' - ' 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n' - ' 200a: de ad be ef\tswi.gp $r0, [ + #-11036]\n' - '00004000 <touchpad_calc>:\n' - ' 4000: 47 70\t\tmovi55 $r0, #1\n' - '00010000 <look_task>:' - ) - function_map = self.analyzer.AnalyzeDisassembly(disasm_text) - func_hook_task = sa.Function(0x1000, 'hook_task', 48, [ - sa.Callsite(0x1006, 0x100929de, True, None)]) - expect_funcmap = { - 0x1000: func_hook_task, - 0x2000: sa.Function(0x2000, 'console_task', 16, - [sa.Callsite(0x2002, 0x1000, False, func_hook_task), - sa.Callsite(0x2006, 0x53968, True, None)]), - 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []), - } - self.assertEqual(function_map, expect_funcmap) - - def testArmAnalyzeDisassembly(self): - disasm_text = ( - '\n' - 'build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm' - '\n' - 'Disassembly of section .text:\n' - '\n' - '00000900 <wook_task>:\n' - ' ...\n' - '00001000 <hook_task>:\n' - ' 1000: dead beef\tfake\n' - ' 1004: 4770\t\tbx lr\n' - ' 1006: b113\tcbz r3, 100929de <flash_command_write>\n' - ' 1008: 00015cfc\t.word 0x00015cfc\n' - '00002000 <console_task>:\n' - ' 2000: b508\t\tpush {r3, lr} ; malformed comments,; r0, r1 \n' - ' 2002: f00e fcc5\tbl 1000 <hook_task>\n' - ' 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n' - ' 200a: dead beef\tfake\n' - '00004000 <touchpad_calc>:\n' - ' 4000: 4770\t\tbx lr\n' - '00010000 <look_task>:' - ) - function_map = self.analyzer.AnalyzeDisassembly(disasm_text) - func_hook_task = sa.Function(0x1000, 'hook_task', 0, [ - sa.Callsite(0x1006, 0x100929de, True, None)]) - expect_funcmap = { - 0x1000: func_hook_task, - 0x2000: sa.Function(0x2000, 'console_task', 8, - [sa.Callsite(0x2002, 0x1000, False, func_hook_task), - sa.Callsite(0x2006, 0x53968, True, None)]), - 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []), - } - self.assertEqual(function_map, expect_funcmap) - - def testAnalyzeCallGraph(self): - funcs = { - 0x1000: sa.Function(0x1000, 'hook_task', 0, []), - 0x2000: sa.Function(0x2000, 'console_task', 8, []), - 0x3000: sa.Function(0x3000, 'task_a', 12, []), - 0x4000: sa.Function(0x4000, 'task_b', 96, []), - 0x5000: sa.Function(0x5000, 'task_c', 32, []), - 0x6000: sa.Function(0x6000, 'task_d', 100, []), - 0x7000: sa.Function(0x7000, 'task_e', 24, []), - 0x8000: sa.Function(0x8000, 'task_f', 20, []), - 0x9000: sa.Function(0x9000, 'task_g', 20, []), - 0x10000: sa.Function(0x10000, 'task_x', 16, []), - } - funcs[0x1000].callsites = [ - sa.Callsite(0x1002, 0x3000, False, funcs[0x3000]), - sa.Callsite(0x1006, 0x4000, False, funcs[0x4000])] - funcs[0x2000].callsites = [ - sa.Callsite(0x2002, 0x5000, False, funcs[0x5000]), - sa.Callsite(0x2006, 0x2000, False, funcs[0x2000]), - sa.Callsite(0x200a, 0x10000, False, funcs[0x10000])] - funcs[0x3000].callsites = [ - sa.Callsite(0x3002, 0x4000, False, funcs[0x4000]), - sa.Callsite(0x3006, 0x1000, False, funcs[0x1000])] - funcs[0x4000].callsites = [ - sa.Callsite(0x4002, 0x6000, True, funcs[0x6000]), - sa.Callsite(0x4006, 0x7000, False, funcs[0x7000]), - sa.Callsite(0x400a, 0x8000, False, funcs[0x8000])] - funcs[0x5000].callsites = [ - sa.Callsite(0x5002, 0x4000, False, funcs[0x4000])] - funcs[0x7000].callsites = [ - sa.Callsite(0x7002, 0x7000, False, funcs[0x7000])] - funcs[0x8000].callsites = [ - sa.Callsite(0x8002, 0x9000, False, funcs[0x9000])] - funcs[0x9000].callsites = [ - sa.Callsite(0x9002, 0x4000, False, funcs[0x4000])] - funcs[0x10000].callsites = [ - sa.Callsite(0x10002, 0x2000, False, funcs[0x2000])] - - cycles = self.analyzer.AnalyzeCallGraph(funcs, [ - [funcs[0x2000]] * 2, - [funcs[0x10000], funcs[0x2000]] * 3, - [funcs[0x1000], funcs[0x3000], funcs[0x1000]] - ]) - - expect_func_stack = { - 0x1000: (268, [funcs[0x1000], - funcs[0x3000], - funcs[0x4000], - funcs[0x8000], - funcs[0x9000], - funcs[0x4000], - funcs[0x7000]]), - 0x2000: (208, [funcs[0x2000], - funcs[0x10000], - funcs[0x2000], - funcs[0x10000], - funcs[0x2000], - funcs[0x5000], - funcs[0x4000], - funcs[0x7000]]), - 0x3000: (280, [funcs[0x3000], - funcs[0x1000], - funcs[0x3000], - funcs[0x4000], - funcs[0x8000], - funcs[0x9000], - funcs[0x4000], - funcs[0x7000]]), - 0x4000: (120, [funcs[0x4000], funcs[0x7000]]), - 0x5000: (152, [funcs[0x5000], funcs[0x4000], funcs[0x7000]]), - 0x6000: (100, [funcs[0x6000]]), - 0x7000: (24, [funcs[0x7000]]), - 0x8000: (160, [funcs[0x8000], - funcs[0x9000], - funcs[0x4000], - funcs[0x7000]]), - 0x9000: (140, [funcs[0x9000], funcs[0x4000], funcs[0x7000]]), - 0x10000: (200, [funcs[0x10000], - funcs[0x2000], - funcs[0x10000], - funcs[0x2000], - funcs[0x5000], - funcs[0x4000], - funcs[0x7000]]), - } - expect_cycles = [ - {funcs[0x4000], funcs[0x8000], funcs[0x9000]}, - {funcs[0x7000]}, - ] - for func in funcs.values(): - (stack_max_usage, stack_max_path) = expect_func_stack[func.address] - self.assertEqual(func.stack_max_usage, stack_max_usage) - self.assertEqual(func.stack_max_path, stack_max_path) - - self.assertEqual(len(cycles), len(expect_cycles)) - for cycle in cycles: - self.assertTrue(cycle in expect_cycles) - - @mock.patch('subprocess.check_output') - def testAddressToLine(self, checkoutput_mock): - checkoutput_mock.return_value = 'fake_func\n/test.c:1' - self.assertEqual(self.analyzer.AddressToLine(0x1234), - [('fake_func', '/test.c', 1)]) - checkoutput_mock.assert_called_once_with( - ['addr2line', '-f', '-e', './ec.RW.elf', '1234'], encoding='utf-8') - checkoutput_mock.reset_mock() - - checkoutput_mock.return_value = 'fake_func\n/a.c:1\nbake_func\n/b.c:2\n' - self.assertEqual(self.analyzer.AddressToLine(0x1234, True), - [('fake_func', '/a.c', 1), ('bake_func', '/b.c', 2)]) - checkoutput_mock.assert_called_once_with( - ['addr2line', '-f', '-e', './ec.RW.elf', '1234', '-i'], - encoding='utf-8') - checkoutput_mock.reset_mock() - - checkoutput_mock.return_value = 'fake_func\n/test.c:1 (discriminator 128)' - self.assertEqual(self.analyzer.AddressToLine(0x12345), - [('fake_func', '/test.c', 1)]) - checkoutput_mock.assert_called_once_with( - ['addr2line', '-f', '-e', './ec.RW.elf', '12345'], encoding='utf-8') - checkoutput_mock.reset_mock() - - checkoutput_mock.return_value = '??\n:?\nbake_func\n/b.c:2\n' - self.assertEqual(self.analyzer.AddressToLine(0x123456), - [None, ('bake_func', '/b.c', 2)]) - checkoutput_mock.assert_called_once_with( - ['addr2line', '-f', '-e', './ec.RW.elf', '123456'], encoding='utf-8') - checkoutput_mock.reset_mock() - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'addr2line failed to resolve lines.'): - checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '') - self.analyzer.AddressToLine(0x5678) - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'Failed to run addr2line.'): - checkoutput_mock.side_effect = OSError() - self.analyzer.AddressToLine(0x9012) - - @mock.patch('subprocess.check_output') - @mock.patch('stack_analyzer.StackAnalyzer.AddressToLine') - def testAndesAnalyze(self, addrtoline_mock, checkoutput_mock): - disasm_text = ( - '\n' - 'build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le' - '\n' - 'Disassembly of section .text:\n' - '\n' - '00000900 <wook_task>:\n' - ' ...\n' - '00001000 <hook_task>:\n' - ' 1000: fc 00\t\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n' - ' 1002: 47 70\t\tmovi55 $r0, #1\n' - ' 1006: 00 01 5c fc\tbne $r6, $r0, 2af6a\n' - '00002000 <console_task>:\n' - ' 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n' - ' 2002: f0 0e fc c5\tjal 1000 <hook_task>\n' - ' 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n' - ' 200a: 12 34 56 78\tjral5 $r0\n' - ) - - addrtoline_mock.return_value = [('??', '??', 0)] - self.analyzer.annotation = { - 'exception_frame_size': 64, - 'remove': [['fake_func']], - } - - with mock.patch('builtins.print') as print_mock: - checkoutput_mock.return_value = disasm_text - self.analyzer.Analyze() - print_mock.assert_has_calls([ - mock.call( - 'Task: HOOKS, Max size: 96 (32 + 64), Allocated size: 2048'), - mock.call('Call Trace:'), - mock.call(' hook_task (32) [??:0] 1000'), - mock.call( - 'Task: CONSOLE, Max size: 112 (48 + 64), Allocated size: 460'), - mock.call('Call Trace:'), - mock.call(' console_task (16) [??:0] 2000'), - mock.call(' -> ??[??:0] 2002'), - mock.call(' hook_task (32) [??:0] 1000'), - mock.call('Unresolved indirect callsites:'), - mock.call(' In function console_task:'), - mock.call(' -> ??[??:0] 200a'), - mock.call('Unresolved annotation signatures:'), - mock.call(' fake_func: function is not found'), - ]) - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'Failed to run objdump.'): - checkoutput_mock.side_effect = OSError() - self.analyzer.Analyze() - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'objdump failed to disassemble.'): - checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '') - self.analyzer.Analyze() - - @mock.patch('subprocess.check_output') - @mock.patch('stack_analyzer.StackAnalyzer.AddressToLine') - def testArmAnalyze(self, addrtoline_mock, checkoutput_mock): - disasm_text = ( - '\n' - 'build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm' - '\n' - 'Disassembly of section .text:\n' - '\n' - '00000900 <wook_task>:\n' - ' ...\n' - '00001000 <hook_task>:\n' - ' 1000: b508\t\tpush {r3, lr}\n' - ' 1002: 4770\t\tbx lr\n' - ' 1006: 00015cfc\t.word 0x00015cfc\n' - '00002000 <console_task>:\n' - ' 2000: b508\t\tpush {r3, lr}\n' - ' 2002: f00e fcc5\tbl 1000 <hook_task>\n' - ' 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n' - ' 200a: 1234 5678\tb.w sl\n' - ) - - addrtoline_mock.return_value = [('??', '??', 0)] - self.analyzer.annotation = { - 'exception_frame_size': 64, - 'remove': [['fake_func']], - } - - with mock.patch('builtins.print') as print_mock: - checkoutput_mock.return_value = disasm_text - self.analyzer.Analyze() - print_mock.assert_has_calls([ - mock.call( - 'Task: HOOKS, Max size: 72 (8 + 64), Allocated size: 2048'), - mock.call('Call Trace:'), - mock.call(' hook_task (8) [??:0] 1000'), - mock.call( - 'Task: CONSOLE, Max size: 80 (16 + 64), Allocated size: 460'), - mock.call('Call Trace:'), - mock.call(' console_task (8) [??:0] 2000'), - mock.call(' -> ??[??:0] 2002'), - mock.call(' hook_task (8) [??:0] 1000'), - mock.call('Unresolved indirect callsites:'), - mock.call(' In function console_task:'), - mock.call(' -> ??[??:0] 200a'), - mock.call('Unresolved annotation signatures:'), - mock.call(' fake_func: function is not found'), - ]) - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'Failed to run objdump.'): - checkoutput_mock.side_effect = OSError() - self.analyzer.Analyze() - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'objdump failed to disassemble.'): - checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '') - self.analyzer.Analyze() - - @mock.patch('subprocess.check_output') - @mock.patch('stack_analyzer.ParseArgs') - def testMain(self, parseargs_mock, checkoutput_mock): - symbol_text = ('1000 g F .text 0000015c .hidden hook_task\n' - '2000 g F .text 0000051c .hidden console_task\n') - rodata_text = ( - '\n' - 'Contents of section .rodata:\n' - ' 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n' - ) - - args = mock.MagicMock(elf_path='./ec.RW.elf', - export_taskinfo='fake', - section='RW', - objdump='objdump', - addr2line='addr2line', - annotation='fake') - parseargs_mock.return_value = args - - with mock.patch('os.path.exists') as path_mock: - path_mock.return_value = False - with mock.patch('builtins.print') as print_mock: - with mock.patch('builtins.open', mock.mock_open()) as open_mock: - sa.main() - print_mock.assert_any_call( - 'Warning: Annotation file fake does not exist.') - - with mock.patch('os.path.exists') as path_mock: - path_mock.return_value = True - with mock.patch('builtins.print') as print_mock: - with mock.patch('builtins.open', mock.mock_open()) as open_mock: - open_mock.side_effect = IOError() - sa.main() - print_mock.assert_called_once_with( - 'Error: Failed to open annotation file fake.') - - with mock.patch('builtins.print') as print_mock: - with mock.patch('builtins.open', mock.mock_open()) as open_mock: - open_mock.return_value.read.side_effect = ['{', ''] - sa.main() - open_mock.assert_called_once_with('fake', 'r') - print_mock.assert_called_once_with( - 'Error: Failed to parse annotation file fake.') - - with mock.patch('builtins.print') as print_mock: - with mock.patch('builtins.open', - mock.mock_open(read_data='')) as open_mock: - sa.main() - print_mock.assert_called_once_with( - 'Error: Invalid annotation file fake.') - - args.annotation = None - - with mock.patch('builtins.print') as print_mock: - checkoutput_mock.side_effect = [symbol_text, rodata_text] - sa.main() - print_mock.assert_called_once_with( - 'Error: Failed to load export_taskinfo.') - - with mock.patch('builtins.print') as print_mock: - checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '') - sa.main() - print_mock.assert_called_once_with( - 'Error: objdump failed to dump symbol table or rodata.') - - with mock.patch('builtins.print') as print_mock: - checkoutput_mock.side_effect = OSError() - sa.main() - print_mock.assert_called_once_with('Error: Failed to run objdump.') - - -if __name__ == '__main__': - unittest.main() + remove_list = [ + [funcs[0x1000]], + [funcs[0x2000], funcs[0x2000]], + [funcs[0x4000], funcs[0x1000]], + [funcs[0x2000], funcs[0x4000], funcs[0x2000]], + [funcs[0x4000], funcs[0x1000], funcs[0x4000]], + ] + eliminated_addrs = {0x2006} + + remaining_remove_list = self.analyzer.PreprocessAnnotation( + funcs, add_set, remove_list, eliminated_addrs + ) + + expect_funcs = { + 0x1000: sa.Function(0x1000, "hook_task", 0, []), + 0x2000: sa.Function(0x2000, "console_task", 0, []), + 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []), + } + expect_funcs[0x2000].callsites = [ + sa.Callsite(None, 0x4000, False, expect_funcs[0x4000]) + ] + expect_funcs[0x4000].callsites = [ + sa.Callsite(None, 0x2000, False, expect_funcs[0x2000]) + ] + self.assertEqual(funcs, expect_funcs) + self.assertEqual( + remaining_remove_list, + [ + [funcs[0x2000], funcs[0x4000], funcs[0x2000]], + ], + ) + + def testAndesAnalyzeDisassembly(self): + disasm_text = ( + "\n" + "build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le" + "\n" + "Disassembly of section .text:\n" + "\n" + "00000900 <wook_task>:\n" + " ...\n" + "00001000 <hook_task>:\n" + " 1000: fc 42\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n" + " 1004: 47 70\t\tmovi55 $r0, #1\n" + " 1006: b1 13\tbnezs8 100929de <flash_command_write>\n" + " 1008: 00 01 5c fc\tbne $r6, $r0, 2af6a\n" + "00002000 <console_task>:\n" + " 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n" + " 2002: f0 0e fc c5\tjal 1000 <hook_task>\n" + " 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n" + " 200a: de ad be ef\tswi.gp $r0, [ + #-11036]\n" + "00004000 <touchpad_calc>:\n" + " 4000: 47 70\t\tmovi55 $r0, #1\n" + "00010000 <look_task>:" + ) + function_map = self.analyzer.AnalyzeDisassembly(disasm_text) + func_hook_task = sa.Function( + 0x1000, "hook_task", 48, [sa.Callsite(0x1006, 0x100929DE, True, None)] + ) + expect_funcmap = { + 0x1000: func_hook_task, + 0x2000: sa.Function( + 0x2000, + "console_task", + 16, + [ + sa.Callsite(0x2002, 0x1000, False, func_hook_task), + sa.Callsite(0x2006, 0x53968, True, None), + ], + ), + 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []), + } + self.assertEqual(function_map, expect_funcmap) + + def testArmAnalyzeDisassembly(self): + disasm_text = ( + "\n" + "build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm" + "\n" + "Disassembly of section .text:\n" + "\n" + "00000900 <wook_task>:\n" + " ...\n" + "00001000 <hook_task>:\n" + " 1000: dead beef\tfake\n" + " 1004: 4770\t\tbx lr\n" + " 1006: b113\tcbz r3, 100929de <flash_command_write>\n" + " 1008: 00015cfc\t.word 0x00015cfc\n" + "00002000 <console_task>:\n" + " 2000: b508\t\tpush {r3, lr} ; malformed comments,; r0, r1 \n" + " 2002: f00e fcc5\tbl 1000 <hook_task>\n" + " 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n" + " 200a: dead beef\tfake\n" + "00004000 <touchpad_calc>:\n" + " 4000: 4770\t\tbx lr\n" + "00010000 <look_task>:" + ) + function_map = self.analyzer.AnalyzeDisassembly(disasm_text) + func_hook_task = sa.Function( + 0x1000, "hook_task", 0, [sa.Callsite(0x1006, 0x100929DE, True, None)] + ) + expect_funcmap = { + 0x1000: func_hook_task, + 0x2000: sa.Function( + 0x2000, + "console_task", + 8, + [ + sa.Callsite(0x2002, 0x1000, False, func_hook_task), + sa.Callsite(0x2006, 0x53968, True, None), + ], + ), + 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []), + } + self.assertEqual(function_map, expect_funcmap) + + def testAnalyzeCallGraph(self): + funcs = { + 0x1000: sa.Function(0x1000, "hook_task", 0, []), + 0x2000: sa.Function(0x2000, "console_task", 8, []), + 0x3000: sa.Function(0x3000, "task_a", 12, []), + 0x4000: sa.Function(0x4000, "task_b", 96, []), + 0x5000: sa.Function(0x5000, "task_c", 32, []), + 0x6000: sa.Function(0x6000, "task_d", 100, []), + 0x7000: sa.Function(0x7000, "task_e", 24, []), + 0x8000: sa.Function(0x8000, "task_f", 20, []), + 0x9000: sa.Function(0x9000, "task_g", 20, []), + 0x10000: sa.Function(0x10000, "task_x", 16, []), + } + funcs[0x1000].callsites = [ + sa.Callsite(0x1002, 0x3000, False, funcs[0x3000]), + sa.Callsite(0x1006, 0x4000, False, funcs[0x4000]), + ] + funcs[0x2000].callsites = [ + sa.Callsite(0x2002, 0x5000, False, funcs[0x5000]), + sa.Callsite(0x2006, 0x2000, False, funcs[0x2000]), + sa.Callsite(0x200A, 0x10000, False, funcs[0x10000]), + ] + funcs[0x3000].callsites = [ + sa.Callsite(0x3002, 0x4000, False, funcs[0x4000]), + sa.Callsite(0x3006, 0x1000, False, funcs[0x1000]), + ] + funcs[0x4000].callsites = [ + sa.Callsite(0x4002, 0x6000, True, funcs[0x6000]), + sa.Callsite(0x4006, 0x7000, False, funcs[0x7000]), + sa.Callsite(0x400A, 0x8000, False, funcs[0x8000]), + ] + funcs[0x5000].callsites = [sa.Callsite(0x5002, 0x4000, False, funcs[0x4000])] + funcs[0x7000].callsites = [sa.Callsite(0x7002, 0x7000, False, funcs[0x7000])] + funcs[0x8000].callsites = [sa.Callsite(0x8002, 0x9000, False, funcs[0x9000])] + funcs[0x9000].callsites = [sa.Callsite(0x9002, 0x4000, False, funcs[0x4000])] + funcs[0x10000].callsites = [sa.Callsite(0x10002, 0x2000, False, funcs[0x2000])] + + cycles = self.analyzer.AnalyzeCallGraph( + funcs, + [ + [funcs[0x2000]] * 2, + [funcs[0x10000], funcs[0x2000]] * 3, + [funcs[0x1000], funcs[0x3000], funcs[0x1000]], + ], + ) + + expect_func_stack = { + 0x1000: ( + 268, + [ + funcs[0x1000], + funcs[0x3000], + funcs[0x4000], + funcs[0x8000], + funcs[0x9000], + funcs[0x4000], + funcs[0x7000], + ], + ), + 0x2000: ( + 208, + [ + funcs[0x2000], + funcs[0x10000], + funcs[0x2000], + funcs[0x10000], + funcs[0x2000], + funcs[0x5000], + funcs[0x4000], + funcs[0x7000], + ], + ), + 0x3000: ( + 280, + [ + funcs[0x3000], + funcs[0x1000], + funcs[0x3000], + funcs[0x4000], + funcs[0x8000], + funcs[0x9000], + funcs[0x4000], + funcs[0x7000], + ], + ), + 0x4000: (120, [funcs[0x4000], funcs[0x7000]]), + 0x5000: (152, [funcs[0x5000], funcs[0x4000], funcs[0x7000]]), + 0x6000: (100, [funcs[0x6000]]), + 0x7000: (24, [funcs[0x7000]]), + 0x8000: (160, [funcs[0x8000], funcs[0x9000], funcs[0x4000], funcs[0x7000]]), + 0x9000: (140, [funcs[0x9000], funcs[0x4000], funcs[0x7000]]), + 0x10000: ( + 200, + [ + funcs[0x10000], + funcs[0x2000], + funcs[0x10000], + funcs[0x2000], + funcs[0x5000], + funcs[0x4000], + funcs[0x7000], + ], + ), + } + expect_cycles = [ + {funcs[0x4000], funcs[0x8000], funcs[0x9000]}, + {funcs[0x7000]}, + ] + for func in funcs.values(): + (stack_max_usage, stack_max_path) = expect_func_stack[func.address] + self.assertEqual(func.stack_max_usage, stack_max_usage) + self.assertEqual(func.stack_max_path, stack_max_path) + + self.assertEqual(len(cycles), len(expect_cycles)) + for cycle in cycles: + self.assertTrue(cycle in expect_cycles) + + @mock.patch("subprocess.check_output") + def testAddressToLine(self, checkoutput_mock): + checkoutput_mock.return_value = "fake_func\n/test.c:1" + self.assertEqual( + self.analyzer.AddressToLine(0x1234), [("fake_func", "/test.c", 1)] + ) + checkoutput_mock.assert_called_once_with( + ["addr2line", "-f", "-e", "./ec.RW.elf", "1234"], encoding="utf-8" + ) + checkoutput_mock.reset_mock() + + checkoutput_mock.return_value = "fake_func\n/a.c:1\nbake_func\n/b.c:2\n" + self.assertEqual( + self.analyzer.AddressToLine(0x1234, True), + [("fake_func", "/a.c", 1), ("bake_func", "/b.c", 2)], + ) + checkoutput_mock.assert_called_once_with( + ["addr2line", "-f", "-e", "./ec.RW.elf", "1234", "-i"], encoding="utf-8" + ) + checkoutput_mock.reset_mock() + + checkoutput_mock.return_value = "fake_func\n/test.c:1 (discriminator 128)" + self.assertEqual( + self.analyzer.AddressToLine(0x12345), [("fake_func", "/test.c", 1)] + ) + checkoutput_mock.assert_called_once_with( + ["addr2line", "-f", "-e", "./ec.RW.elf", "12345"], encoding="utf-8" + ) + checkoutput_mock.reset_mock() + + checkoutput_mock.return_value = "??\n:?\nbake_func\n/b.c:2\n" + self.assertEqual( + self.analyzer.AddressToLine(0x123456), [None, ("bake_func", "/b.c", 2)] + ) + checkoutput_mock.assert_called_once_with( + ["addr2line", "-f", "-e", "./ec.RW.elf", "123456"], encoding="utf-8" + ) + checkoutput_mock.reset_mock() + + with self.assertRaisesRegexp( + sa.StackAnalyzerError, "addr2line failed to resolve lines." + ): + checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "") + self.analyzer.AddressToLine(0x5678) + + with self.assertRaisesRegexp(sa.StackAnalyzerError, "Failed to run addr2line."): + checkoutput_mock.side_effect = OSError() + self.analyzer.AddressToLine(0x9012) + + @mock.patch("subprocess.check_output") + @mock.patch("stack_analyzer.StackAnalyzer.AddressToLine") + def testAndesAnalyze(self, addrtoline_mock, checkoutput_mock): + disasm_text = ( + "\n" + "build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le" + "\n" + "Disassembly of section .text:\n" + "\n" + "00000900 <wook_task>:\n" + " ...\n" + "00001000 <hook_task>:\n" + " 1000: fc 00\t\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n" + " 1002: 47 70\t\tmovi55 $r0, #1\n" + " 1006: 00 01 5c fc\tbne $r6, $r0, 2af6a\n" + "00002000 <console_task>:\n" + " 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n" + " 2002: f0 0e fc c5\tjal 1000 <hook_task>\n" + " 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n" + " 200a: 12 34 56 78\tjral5 $r0\n" + ) + + addrtoline_mock.return_value = [("??", "??", 0)] + self.analyzer.annotation = { + "exception_frame_size": 64, + "remove": [["fake_func"]], + } + + with mock.patch("builtins.print") as print_mock: + checkoutput_mock.return_value = disasm_text + self.analyzer.Analyze() + print_mock.assert_has_calls( + [ + mock.call( + "Task: HOOKS, Max size: 96 (32 + 64), Allocated size: 2048" + ), + mock.call("Call Trace:"), + mock.call(" hook_task (32) [??:0] 1000"), + mock.call( + "Task: CONSOLE, Max size: 112 (48 + 64), Allocated size: 460" + ), + mock.call("Call Trace:"), + mock.call(" console_task (16) [??:0] 2000"), + mock.call(" -> ??[??:0] 2002"), + mock.call(" hook_task (32) [??:0] 1000"), + mock.call("Unresolved indirect callsites:"), + mock.call(" In function console_task:"), + mock.call(" -> ??[??:0] 200a"), + mock.call("Unresolved annotation signatures:"), + mock.call(" fake_func: function is not found"), + ] + ) + + with self.assertRaisesRegexp(sa.StackAnalyzerError, "Failed to run objdump."): + checkoutput_mock.side_effect = OSError() + self.analyzer.Analyze() + + with self.assertRaisesRegexp( + sa.StackAnalyzerError, "objdump failed to disassemble." + ): + checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "") + self.analyzer.Analyze() + + @mock.patch("subprocess.check_output") + @mock.patch("stack_analyzer.StackAnalyzer.AddressToLine") + def testArmAnalyze(self, addrtoline_mock, checkoutput_mock): + disasm_text = ( + "\n" + "build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm" + "\n" + "Disassembly of section .text:\n" + "\n" + "00000900 <wook_task>:\n" + " ...\n" + "00001000 <hook_task>:\n" + " 1000: b508\t\tpush {r3, lr}\n" + " 1002: 4770\t\tbx lr\n" + " 1006: 00015cfc\t.word 0x00015cfc\n" + "00002000 <console_task>:\n" + " 2000: b508\t\tpush {r3, lr}\n" + " 2002: f00e fcc5\tbl 1000 <hook_task>\n" + " 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n" + " 200a: 1234 5678\tb.w sl\n" + ) + + addrtoline_mock.return_value = [("??", "??", 0)] + self.analyzer.annotation = { + "exception_frame_size": 64, + "remove": [["fake_func"]], + } + + with mock.patch("builtins.print") as print_mock: + checkoutput_mock.return_value = disasm_text + self.analyzer.Analyze() + print_mock.assert_has_calls( + [ + mock.call( + "Task: HOOKS, Max size: 72 (8 + 64), Allocated size: 2048" + ), + mock.call("Call Trace:"), + mock.call(" hook_task (8) [??:0] 1000"), + mock.call( + "Task: CONSOLE, Max size: 80 (16 + 64), Allocated size: 460" + ), + mock.call("Call Trace:"), + mock.call(" console_task (8) [??:0] 2000"), + mock.call(" -> ??[??:0] 2002"), + mock.call(" hook_task (8) [??:0] 1000"), + mock.call("Unresolved indirect callsites:"), + mock.call(" In function console_task:"), + mock.call(" -> ??[??:0] 200a"), + mock.call("Unresolved annotation signatures:"), + mock.call(" fake_func: function is not found"), + ] + ) + + with self.assertRaisesRegexp(sa.StackAnalyzerError, "Failed to run objdump."): + checkoutput_mock.side_effect = OSError() + self.analyzer.Analyze() + + with self.assertRaisesRegexp( + sa.StackAnalyzerError, "objdump failed to disassemble." + ): + checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "") + self.analyzer.Analyze() + + @mock.patch("subprocess.check_output") + @mock.patch("stack_analyzer.ParseArgs") + def testMain(self, parseargs_mock, checkoutput_mock): + symbol_text = ( + "1000 g F .text 0000015c .hidden hook_task\n" + "2000 g F .text 0000051c .hidden console_task\n" + ) + rodata_text = ( + "\n" + "Contents of section .rodata:\n" + " 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n" + ) + + args = mock.MagicMock( + elf_path="./ec.RW.elf", + export_taskinfo="fake", + section="RW", + objdump="objdump", + addr2line="addr2line", + annotation="fake", + ) + parseargs_mock.return_value = args + + with mock.patch("os.path.exists") as path_mock: + path_mock.return_value = False + with mock.patch("builtins.print") as print_mock: + with mock.patch("builtins.open", mock.mock_open()) as open_mock: + sa.main() + print_mock.assert_any_call( + "Warning: Annotation file fake does not exist." + ) + + with mock.patch("os.path.exists") as path_mock: + path_mock.return_value = True + with mock.patch("builtins.print") as print_mock: + with mock.patch("builtins.open", mock.mock_open()) as open_mock: + open_mock.side_effect = IOError() + sa.main() + print_mock.assert_called_once_with( + "Error: Failed to open annotation file fake." + ) + + with mock.patch("builtins.print") as print_mock: + with mock.patch("builtins.open", mock.mock_open()) as open_mock: + open_mock.return_value.read.side_effect = ["{", ""] + sa.main() + open_mock.assert_called_once_with("fake", "r") + print_mock.assert_called_once_with( + "Error: Failed to parse annotation file fake." + ) + + with mock.patch("builtins.print") as print_mock: + with mock.patch( + "builtins.open", mock.mock_open(read_data="") + ) as open_mock: + sa.main() + print_mock.assert_called_once_with( + "Error: Invalid annotation file fake." + ) + + args.annotation = None + + with mock.patch("builtins.print") as print_mock: + checkoutput_mock.side_effect = [symbol_text, rodata_text] + sa.main() + print_mock.assert_called_once_with("Error: Failed to load export_taskinfo.") + + with mock.patch("builtins.print") as print_mock: + checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "") + sa.main() + print_mock.assert_called_once_with( + "Error: objdump failed to dump symbol table or rodata." + ) + + with mock.patch("builtins.print") as print_mock: + checkoutput_mock.side_effect = OSError() + sa.main() + print_mock.assert_called_once_with("Error: Failed to run objdump.") + + +if __name__ == "__main__": + unittest.main() diff --git a/extra/tigertool/ecusb/__init__.py b/extra/tigertool/ecusb/__init__.py index fe4dbc6749..7a48fdb360 100644 --- a/extra/tigertool/ecusb/__init__.py +++ b/extra/tigertool/ecusb/__init__.py @@ -6,4 +6,4 @@ # pylint: disable=bad-indentation,docstring-section-indent # pylint: disable=docstring-trailing-quotes -__all__ = ['tiny_servo_common', 'stm32usb', 'stm32uart', 'pty_driver'] +__all__ = ["tiny_servo_common", "stm32usb", "stm32uart", "pty_driver"] diff --git a/extra/tigertool/ecusb/pty_driver.py b/extra/tigertool/ecusb/pty_driver.py index 09ef8c42e4..037ff8d529 100644 --- a/extra/tigertool/ecusb/pty_driver.py +++ b/extra/tigertool/ecusb/pty_driver.py @@ -17,8 +17,9 @@ import ast import errno import fcntl import os -import pexpect import time + +import pexpect from pexpect import fdpexpect # Expecting a result in 3 seconds is plenty even for slow platforms. @@ -27,281 +28,285 @@ FLUSH_UART_TIMEOUT = 1 class ptyError(Exception): - """Exception class for pty errors.""" + """Exception class for pty errors.""" UART_PARAMS = { - 'uart_cmd': None, - 'uart_multicmd': None, - 'uart_regexp': None, - 'uart_timeout': DEFAULT_UART_TIMEOUT, + "uart_cmd": None, + "uart_multicmd": None, + "uart_regexp": None, + "uart_timeout": DEFAULT_UART_TIMEOUT, } class ptyDriver(object): - """Automate interactive commands on a pty interface.""" - def __init__(self, interface, params, fast=False): - """Init class variables.""" - self._child = None - self._fd = None - self._interface = interface - self._pty_path = self._interface.get_pty() - self._dict = UART_PARAMS.copy() - self._fast = fast - - def __del__(self): - self.close() - - def close(self): - """Close any open files and interfaces.""" - if self._fd: - self._close() - self._interface.close() - - def _open(self): - """Connect to serial device and create pexpect interface.""" - assert self._fd is None - self._fd = os.open(self._pty_path, os.O_RDWR | os.O_NONBLOCK) - # Don't allow forked processes to access. - fcntl.fcntl(self._fd, fcntl.F_SETFD, - fcntl.fcntl(self._fd, fcntl.F_GETFD) | fcntl.FD_CLOEXEC) - self._child = fdpexpect.fdspawn(self._fd) - # pexpect defaults to a 100ms delay before sending characters, to - # work around race conditions in ssh. We don't need this feature - # so we'll change delaybeforesend from 0.1 to 0.001 to speed things up. - if self._fast: - self._child.delaybeforesend = 0.001 - - def _close(self): - """Close serial device connection.""" - os.close(self._fd) - self._fd = None - self._child = None - - def _flush(self): - """Flush device output to prevent previous messages interfering.""" - if self._child.sendline('') != 1: - raise ptyError('Failed to send newline.') - # Have a maximum timeout for the flush operation. We should have cleared - # all data from the buffer, but if data is regularly being generated, we - # can't guarantee it will ever stop. - flush_end_time = time.time() + FLUSH_UART_TIMEOUT - while time.time() <= flush_end_time: - try: - self._child.expect('.', timeout=0.01) - except (pexpect.TIMEOUT, pexpect.EOF): - break - except OSError as e: - # EAGAIN indicates no data available, maybe we didn't wait long enough. - if e.errno != errno.EAGAIN: - raise - break - - def _send(self, cmds): - """Send command to EC. - - This function always flushes serial device before sending, and is used as - a wrapper function to make sure the channel is always flushed before - sending commands. - - Args: - cmds: The commands to send to the device, either a list or a string. - - Raises: - ptyError: Raised when writing to the device fails. - """ - self._flush() - if not isinstance(cmds, list): - cmds = [cmds] - for cmd in cmds: - if self._child.sendline(cmd) != len(cmd) + 1: - raise ptyError('Failed to send command.') - - def _issue_cmd(self, cmds): - """Send command to the device and do not wait for response. - - Args: - cmds: The commands to send to the device, either a list or a string. - """ - self._issue_cmd_get_results(cmds, []) - - def _issue_cmd_get_results(self, cmds, - regex_list, timeout=DEFAULT_UART_TIMEOUT): - """Send command to the device and wait for response. - - This function waits for response message matching a regular - expressions. - - Args: - cmds: The commands issued, either a list or a string. - regex_list: List of Regular expressions used to match response message. - Note1, list must be ordered. - Note2, empty list sends and returns. - timeout: time to wait for matching results before failing. - - Returns: - List of tuples, each of which contains the entire matched string and - all the subgroups of the match. None if not matched. - For example: - response of the given command: - High temp: 37.2 - Low temp: 36.4 - regex_list: - ['High temp: (\d+)\.(\d+)', 'Low temp: (\d+)\.(\d+)'] - returns: - [('High temp: 37.2', '37', '2'), ('Low temp: 36.4', '36', '4')] - - Raises: - ptyError: If timed out waiting for a response - """ - result_list = [] - self._open() - try: - self._send(cmds) - for regex in regex_list: - self._child.expect(regex, timeout) - match = self._child.match - lastindex = match.lastindex if match and match.lastindex else 0 - # Create a tuple which contains the entire matched string and all - # the subgroups of the match. - result = match.group(*range(lastindex + 1)) if match else None - if result: - result = tuple(res.decode('utf-8') for res in result) - result_list.append(result) - except pexpect.TIMEOUT: - raise ptyError('Timeout waiting for response.') - finally: - if not regex_list: - # Must be longer than delaybeforesend - time.sleep(0.1) - self._close() - return result_list - - def _issue_cmd_get_multi_results(self, cmd, regex): - """Send command to the device and wait for multiple response. - - This function waits for arbitrary number of response message - matching a regular expression. - - Args: - cmd: The command issued. - regex: Regular expression used to match response message. - - Returns: - List of tuples, each of which contains the entire matched string and - all the subgroups of the match. None if not matched. - """ - result_list = [] - self._open() - try: - self._send(cmd) - while True: + """Automate interactive commands on a pty interface.""" + + def __init__(self, interface, params, fast=False): + """Init class variables.""" + self._child = None + self._fd = None + self._interface = interface + self._pty_path = self._interface.get_pty() + self._dict = UART_PARAMS.copy() + self._fast = fast + + def __del__(self): + self.close() + + def close(self): + """Close any open files and interfaces.""" + if self._fd: + self._close() + self._interface.close() + + def _open(self): + """Connect to serial device and create pexpect interface.""" + assert self._fd is None + self._fd = os.open(self._pty_path, os.O_RDWR | os.O_NONBLOCK) + # Don't allow forked processes to access. + fcntl.fcntl( + self._fd, + fcntl.F_SETFD, + fcntl.fcntl(self._fd, fcntl.F_GETFD) | fcntl.FD_CLOEXEC, + ) + self._child = fdpexpect.fdspawn(self._fd) + # pexpect defaults to a 100ms delay before sending characters, to + # work around race conditions in ssh. We don't need this feature + # so we'll change delaybeforesend from 0.1 to 0.001 to speed things up. + if self._fast: + self._child.delaybeforesend = 0.001 + + def _close(self): + """Close serial device connection.""" + os.close(self._fd) + self._fd = None + self._child = None + + def _flush(self): + """Flush device output to prevent previous messages interfering.""" + if self._child.sendline("") != 1: + raise ptyError("Failed to send newline.") + # Have a maximum timeout for the flush operation. We should have cleared + # all data from the buffer, but if data is regularly being generated, we + # can't guarantee it will ever stop. + flush_end_time = time.time() + FLUSH_UART_TIMEOUT + while time.time() <= flush_end_time: + try: + self._child.expect(".", timeout=0.01) + except (pexpect.TIMEOUT, pexpect.EOF): + break + except OSError as e: + # EAGAIN indicates no data available, maybe we didn't wait long enough. + if e.errno != errno.EAGAIN: + raise + break + + def _send(self, cmds): + """Send command to EC. + + This function always flushes serial device before sending, and is used as + a wrapper function to make sure the channel is always flushed before + sending commands. + + Args: + cmds: The commands to send to the device, either a list or a string. + + Raises: + ptyError: Raised when writing to the device fails. + """ + self._flush() + if not isinstance(cmds, list): + cmds = [cmds] + for cmd in cmds: + if self._child.sendline(cmd) != len(cmd) + 1: + raise ptyError("Failed to send command.") + + def _issue_cmd(self, cmds): + """Send command to the device and do not wait for response. + + Args: + cmds: The commands to send to the device, either a list or a string. + """ + self._issue_cmd_get_results(cmds, []) + + def _issue_cmd_get_results(self, cmds, regex_list, timeout=DEFAULT_UART_TIMEOUT): + """Send command to the device and wait for response. + + This function waits for response message matching a regular + expressions. + + Args: + cmds: The commands issued, either a list or a string. + regex_list: List of Regular expressions used to match response message. + Note1, list must be ordered. + Note2, empty list sends and returns. + timeout: time to wait for matching results before failing. + + Returns: + List of tuples, each of which contains the entire matched string and + all the subgroups of the match. None if not matched. + For example: + response of the given command: + High temp: 37.2 + Low temp: 36.4 + regex_list: + ['High temp: (\d+)\.(\d+)', 'Low temp: (\d+)\.(\d+)'] + returns: + [('High temp: 37.2', '37', '2'), ('Low temp: 36.4', '36', '4')] + + Raises: + ptyError: If timed out waiting for a response + """ + result_list = [] + self._open() try: - self._child.expect(regex, timeout=0.1) - match = self._child.match - lastindex = match.lastindex if match and match.lastindex else 0 - # Create a tuple which contains the entire matched string and all - # the subgroups of the match. - result = match.group(*range(lastindex + 1)) if match else None - if result: - result = tuple(res.decode('utf-8') for res in result) - result_list.append(result) + self._send(cmds) + for regex in regex_list: + self._child.expect(regex, timeout) + match = self._child.match + lastindex = match.lastindex if match and match.lastindex else 0 + # Create a tuple which contains the entire matched string and all + # the subgroups of the match. + result = match.group(*range(lastindex + 1)) if match else None + if result: + result = tuple(res.decode("utf-8") for res in result) + result_list.append(result) except pexpect.TIMEOUT: - break - finally: - self._close() - return result_list - - def _Set_uart_timeout(self, timeout): - """Set timeout value for waiting for the device response. - - Args: - timeout: Timeout value in second. - """ - self._dict['uart_timeout'] = timeout - - def _Get_uart_timeout(self): - """Get timeout value for waiting for the device response. - - Returns: - Timeout value in second. - """ - return self._dict['uart_timeout'] - - def _Set_uart_regexp(self, regexp): - """Set the list of regular expressions which matches the command response. - - Args: - regexp: A string which contains a list of regular expressions. - """ - if not isinstance(regexp, str): - raise ptyError('The argument regexp should be a string.') - self._dict['uart_regexp'] = ast.literal_eval(regexp) - - def _Get_uart_regexp(self): - """Get the list of regular expressions which matches the command response. - - Returns: - A string which contains a list of regular expressions. - """ - return str(self._dict['uart_regexp']) - - def _Set_uart_cmd(self, cmd): - """Set the UART command and send it to the device. - - If ec_uart_regexp is 'None', the command is just sent and it doesn't care - about its response. - - If ec_uart_regexp is not 'None', the command is send and its response, - which matches the regular expression of ec_uart_regexp, will be kept. - Use its getter to obtain this result. If no match after ec_uart_timeout - seconds, a timeout error will be raised. - - Args: - cmd: A string of UART command. - """ - if self._dict['uart_regexp']: - self._dict['uart_cmd'] = self._issue_cmd_get_results( - cmd, self._dict['uart_regexp'], self._dict['uart_timeout']) - else: - self._dict['uart_cmd'] = None - self._issue_cmd(cmd) - - def _Set_uart_multicmd(self, cmds): - """Set multiple UART commands and send them to the device. - - Note that ec_uart_regexp is not supported to match the results. - - Args: - cmds: A semicolon-separated string of UART commands. - """ - self._issue_cmd(cmds.split(';')) - - def _Get_uart_cmd(self): - """Get the result of the latest UART command. - - Returns: - A string which contains a list of tuples, each of which contains the - entire matched string and all the subgroups of the match. 'None' if - the ec_uart_regexp is 'None'. - """ - return str(self._dict['uart_cmd']) - - def _Set_uart_capture(self, cmd): - """Set UART capture mode (on or off). - - Once capture is enabled, UART output could be collected periodically by - invoking _Get_uart_stream() below. - - Args: - cmd: True for on, False for off - """ - self._interface.set_capture_active(cmd) - - def _Get_uart_capture(self): - """Get the UART capture mode (on or off).""" - return self._interface.get_capture_active() - - def _Get_uart_stream(self): - """Get uart stream generated since last time.""" - return self._interface.get_stream() + raise ptyError("Timeout waiting for response.") + finally: + if not regex_list: + # Must be longer than delaybeforesend + time.sleep(0.1) + self._close() + return result_list + + def _issue_cmd_get_multi_results(self, cmd, regex): + """Send command to the device and wait for multiple response. + + This function waits for arbitrary number of response message + matching a regular expression. + + Args: + cmd: The command issued. + regex: Regular expression used to match response message. + + Returns: + List of tuples, each of which contains the entire matched string and + all the subgroups of the match. None if not matched. + """ + result_list = [] + self._open() + try: + self._send(cmd) + while True: + try: + self._child.expect(regex, timeout=0.1) + match = self._child.match + lastindex = match.lastindex if match and match.lastindex else 0 + # Create a tuple which contains the entire matched string and all + # the subgroups of the match. + result = match.group(*range(lastindex + 1)) if match else None + if result: + result = tuple(res.decode("utf-8") for res in result) + result_list.append(result) + except pexpect.TIMEOUT: + break + finally: + self._close() + return result_list + + def _Set_uart_timeout(self, timeout): + """Set timeout value for waiting for the device response. + + Args: + timeout: Timeout value in second. + """ + self._dict["uart_timeout"] = timeout + + def _Get_uart_timeout(self): + """Get timeout value for waiting for the device response. + + Returns: + Timeout value in second. + """ + return self._dict["uart_timeout"] + + def _Set_uart_regexp(self, regexp): + """Set the list of regular expressions which matches the command response. + + Args: + regexp: A string which contains a list of regular expressions. + """ + if not isinstance(regexp, str): + raise ptyError("The argument regexp should be a string.") + self._dict["uart_regexp"] = ast.literal_eval(regexp) + + def _Get_uart_regexp(self): + """Get the list of regular expressions which matches the command response. + + Returns: + A string which contains a list of regular expressions. + """ + return str(self._dict["uart_regexp"]) + + def _Set_uart_cmd(self, cmd): + """Set the UART command and send it to the device. + + If ec_uart_regexp is 'None', the command is just sent and it doesn't care + about its response. + + If ec_uart_regexp is not 'None', the command is send and its response, + which matches the regular expression of ec_uart_regexp, will be kept. + Use its getter to obtain this result. If no match after ec_uart_timeout + seconds, a timeout error will be raised. + + Args: + cmd: A string of UART command. + """ + if self._dict["uart_regexp"]: + self._dict["uart_cmd"] = self._issue_cmd_get_results( + cmd, self._dict["uart_regexp"], self._dict["uart_timeout"] + ) + else: + self._dict["uart_cmd"] = None + self._issue_cmd(cmd) + + def _Set_uart_multicmd(self, cmds): + """Set multiple UART commands and send them to the device. + + Note that ec_uart_regexp is not supported to match the results. + + Args: + cmds: A semicolon-separated string of UART commands. + """ + self._issue_cmd(cmds.split(";")) + + def _Get_uart_cmd(self): + """Get the result of the latest UART command. + + Returns: + A string which contains a list of tuples, each of which contains the + entire matched string and all the subgroups of the match. 'None' if + the ec_uart_regexp is 'None'. + """ + return str(self._dict["uart_cmd"]) + + def _Set_uart_capture(self, cmd): + """Set UART capture mode (on or off). + + Once capture is enabled, UART output could be collected periodically by + invoking _Get_uart_stream() below. + + Args: + cmd: True for on, False for off + """ + self._interface.set_capture_active(cmd) + + def _Get_uart_capture(self): + """Get the UART capture mode (on or off).""" + return self._interface.get_capture_active() + + def _Get_uart_stream(self): + """Get uart stream generated since last time.""" + return self._interface.get_stream() diff --git a/extra/tigertool/ecusb/stm32uart.py b/extra/tigertool/ecusb/stm32uart.py index 95219455a9..5794bd091b 100644 --- a/extra/tigertool/ecusb/stm32uart.py +++ b/extra/tigertool/ecusb/stm32uart.py @@ -17,232 +17,244 @@ import termios import threading import time import tty + import usb from . import stm32usb class SuartError(Exception): - """Class for exceptions of Suart.""" - def __init__(self, msg, value=0): - """SuartError constructor. + """Class for exceptions of Suart.""" - Args: - msg: string, message describing error in detail - value: integer, value of error when non-zero status returned. Default=0 - """ - super(SuartError, self).__init__(msg, value) - self.msg = msg - self.value = value + def __init__(self, msg, value=0): + """SuartError constructor. + Args: + msg: string, message describing error in detail + value: integer, value of error when non-zero status returned. Default=0 + """ + super(SuartError, self).__init__(msg, value) + self.msg = msg + self.value = value -class Suart(object): - """Provide interface to stm32 serial usb endpoint.""" - def __init__(self, vendor=0x18d1, product=0x501a, interface=0, - serialname=None, debuglog=False): - """Suart contstructor. - - Initializes stm32 USB stream interface. - - Args: - vendor: usb vendor id of stm32 device - product: usb product id of stm32 device - interface: interface number of stm32 device to use - serialname: serial name to target. Defaults to None. - debuglog: chatty output. Defaults to False. - - Raises: - SuartError: If init fails - """ - self._ptym = None - self._ptys = None - self._ptyname = None - self._rx_thread = None - self._tx_thread = None - self._debuglog = debuglog - self._susb = stm32usb.Susb(vendor=vendor, product=product, - interface=interface, serialname=serialname) - self._running = False - - def __del__(self): - """Suart destructor.""" - self.close() - - def close(self): - """Stop all running threads.""" - self._running = False - if self._rx_thread: - self._rx_thread.join(2) - self._rx_thread = None - if self._tx_thread: - self._tx_thread.join(2) - self._tx_thread = None - self._susb.close() - - def run_rx_thread(self): - """Background loop to pass data from USB to pty.""" - ep = select.epoll() - ep.register(self._ptym, select.EPOLLHUP) - try: - while self._running: - events = ep.poll(0) - # Check if the pty is connected to anything, or hungup. - if not events: - try: - r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS) - if r: - if self._debuglog: - print(''.join([chr(x) for x in r]), end='') - os.write(self._ptym, r) - - # If we miss some characters on pty disconnect, that's fine. - # ep.read() also throws USBError on timeout, which we discard. - except OSError: - pass - except usb.core.USBError: - pass - else: - time.sleep(.1) - except Exception as e: - raise e - - def run_tx_thread(self): - """Background loop to pass data from pty to USB.""" - ep = select.epoll() - ep.register(self._ptym, select.EPOLLHUP) - try: - while self._running: - events = ep.poll(0) - # Check if the pty is connected to anything, or hungup. - if not events: - try: - r = os.read(self._ptym, 64) - # TODO(crosbug.com/936182): Remove when the servo v4/micro console - # issues are fixed. - time.sleep(0.001) - if r: - self._susb._write_ep.write(r, self._susb.TIMEOUT_MS) - - except OSError: - pass - except usb.core.USBError: - pass - else: - time.sleep(.1) - except Exception as e: - raise e - - def run(self): - """Creates pthreads to poll stm32 & PTY for data.""" - m, s = os.openpty() - self._ptyname = os.ttyname(s) - - self._ptym = m - self._ptys = s - - os.fchmod(s, 0o660) - - # Change the owner and group of the PTY to the user who started servod. - try: - uid = int(os.environ.get('SUDO_UID', -1)) - except TypeError: - uid = -1 - try: - gid = int(os.environ.get('SUDO_GID', -1)) - except TypeError: - gid = -1 - os.fchown(s, uid, gid) - - tty.setraw(self._ptym, termios.TCSADRAIN) - - # Generate a HUP flag on pty slave fd. - os.fdopen(s).close() - - self._running = True - - self._rx_thread = threading.Thread(target=self.run_rx_thread, args=[]) - self._rx_thread.daemon = True - self._rx_thread.start() - - self._tx_thread = threading.Thread(target=self.run_tx_thread, args=[]) - self._tx_thread.daemon = True - self._tx_thread.start() - - def get_uart_props(self): - """Get the uart's properties. - - Returns: - dict where: - baudrate: integer of uarts baudrate - bits: integer, number of bits of data Can be 5|6|7|8 inclusive - parity: integer, parity of 0-2 inclusive where: - 0: no parity - 1: odd parity - 2: even parity - sbits: integer, number of stop bits. Can be 0|1|2 inclusive where: - 0: 1 stop bit - 1: 1.5 stop bits - 2: 2 stop bits - """ - return { - 'baudrate': 115200, - 'bits': 8, - 'parity': 0, - 'sbits': 1, - } - - def set_uart_props(self, line_props): - """Set the uart's properties. - - Note that Suart cannot set properties - and will fail if the properties are not the default 115200,8n1. - - Args: - line_props: dict where: - baudrate: integer of uarts baudrate - bits: integer, number of bits of data ( prior to stop bit) - parity: integer, parity of 0-2 inclusive where - 0: no parity - 1: odd parity - 2: even parity - sbits: integer, number of stop bits. Can be 0|1|2 inclusive where: - 0: 1 stop bit - 1: 1.5 stop bits - 2: 2 stop bits - - Raises: - SuartError: If requested line properties are not the default. - """ - curr_props = self.get_uart_props() - for prop in line_props: - if line_props[prop] != curr_props[prop]: - raise SuartError('Line property %s cannot be set from %s to %s' % ( - prop, curr_props[prop], line_props[prop])) - return True - - def get_pty(self): - """Gets path to pty for communication to/from uart. - - Returns: - String path to the pty connected to the uart - """ - return self._ptyname +class Suart(object): + """Provide interface to stm32 serial usb endpoint.""" + + def __init__( + self, + vendor=0x18D1, + product=0x501A, + interface=0, + serialname=None, + debuglog=False, + ): + """Suart contstructor. + + Initializes stm32 USB stream interface. + + Args: + vendor: usb vendor id of stm32 device + product: usb product id of stm32 device + interface: interface number of stm32 device to use + serialname: serial name to target. Defaults to None. + debuglog: chatty output. Defaults to False. + + Raises: + SuartError: If init fails + """ + self._ptym = None + self._ptys = None + self._ptyname = None + self._rx_thread = None + self._tx_thread = None + self._debuglog = debuglog + self._susb = stm32usb.Susb( + vendor=vendor, product=product, interface=interface, serialname=serialname + ) + self._running = False + + def __del__(self): + """Suart destructor.""" + self.close() + + def close(self): + """Stop all running threads.""" + self._running = False + if self._rx_thread: + self._rx_thread.join(2) + self._rx_thread = None + if self._tx_thread: + self._tx_thread.join(2) + self._tx_thread = None + self._susb.close() + + def run_rx_thread(self): + """Background loop to pass data from USB to pty.""" + ep = select.epoll() + ep.register(self._ptym, select.EPOLLHUP) + try: + while self._running: + events = ep.poll(0) + # Check if the pty is connected to anything, or hungup. + if not events: + try: + r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS) + if r: + if self._debuglog: + print("".join([chr(x) for x in r]), end="") + os.write(self._ptym, r) + + # If we miss some characters on pty disconnect, that's fine. + # ep.read() also throws USBError on timeout, which we discard. + except OSError: + pass + except usb.core.USBError: + pass + else: + time.sleep(0.1) + except Exception as e: + raise e + + def run_tx_thread(self): + """Background loop to pass data from pty to USB.""" + ep = select.epoll() + ep.register(self._ptym, select.EPOLLHUP) + try: + while self._running: + events = ep.poll(0) + # Check if the pty is connected to anything, or hungup. + if not events: + try: + r = os.read(self._ptym, 64) + # TODO(crosbug.com/936182): Remove when the servo v4/micro console + # issues are fixed. + time.sleep(0.001) + if r: + self._susb._write_ep.write(r, self._susb.TIMEOUT_MS) + + except OSError: + pass + except usb.core.USBError: + pass + else: + time.sleep(0.1) + except Exception as e: + raise e + + def run(self): + """Creates pthreads to poll stm32 & PTY for data.""" + m, s = os.openpty() + self._ptyname = os.ttyname(s) + + self._ptym = m + self._ptys = s + + os.fchmod(s, 0o660) + + # Change the owner and group of the PTY to the user who started servod. + try: + uid = int(os.environ.get("SUDO_UID", -1)) + except TypeError: + uid = -1 + + try: + gid = int(os.environ.get("SUDO_GID", -1)) + except TypeError: + gid = -1 + os.fchown(s, uid, gid) + + tty.setraw(self._ptym, termios.TCSADRAIN) + + # Generate a HUP flag on pty slave fd. + os.fdopen(s).close() + + self._running = True + + self._rx_thread = threading.Thread(target=self.run_rx_thread, args=[]) + self._rx_thread.daemon = True + self._rx_thread.start() + + self._tx_thread = threading.Thread(target=self.run_tx_thread, args=[]) + self._tx_thread.daemon = True + self._tx_thread.start() + + def get_uart_props(self): + """Get the uart's properties. + + Returns: + dict where: + baudrate: integer of uarts baudrate + bits: integer, number of bits of data Can be 5|6|7|8 inclusive + parity: integer, parity of 0-2 inclusive where: + 0: no parity + 1: odd parity + 2: even parity + sbits: integer, number of stop bits. Can be 0|1|2 inclusive where: + 0: 1 stop bit + 1: 1.5 stop bits + 2: 2 stop bits + """ + return { + "baudrate": 115200, + "bits": 8, + "parity": 0, + "sbits": 1, + } + + def set_uart_props(self, line_props): + """Set the uart's properties. + + Note that Suart cannot set properties + and will fail if the properties are not the default 115200,8n1. + + Args: + line_props: dict where: + baudrate: integer of uarts baudrate + bits: integer, number of bits of data ( prior to stop bit) + parity: integer, parity of 0-2 inclusive where + 0: no parity + 1: odd parity + 2: even parity + sbits: integer, number of stop bits. Can be 0|1|2 inclusive where: + 0: 1 stop bit + 1: 1.5 stop bits + 2: 2 stop bits + + Raises: + SuartError: If requested line properties are not the default. + """ + curr_props = self.get_uart_props() + for prop in line_props: + if line_props[prop] != curr_props[prop]: + raise SuartError( + "Line property %s cannot be set from %s to %s" + % (prop, curr_props[prop], line_props[prop]) + ) + return True + + def get_pty(self): + """Gets path to pty for communication to/from uart. + + Returns: + String path to the pty connected to the uart + """ + return self._ptyname def main(): - """Run a suart test with the default parameters.""" - try: - sobj = Suart() - sobj.run() + """Run a suart test with the default parameters.""" + try: + sobj = Suart() + sobj.run() - # run() is a thread so just busy wait to mimic server. - while True: - # Ours sleeps to eleven! - time.sleep(11) - except KeyboardInterrupt: - sys.exit(0) + # run() is a thread so just busy wait to mimic server. + while True: + # Ours sleeps to eleven! + time.sleep(11) + except KeyboardInterrupt: + sys.exit(0) -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/extra/tigertool/ecusb/stm32usb.py b/extra/tigertool/ecusb/stm32usb.py index bfd5fbb1fb..4b2b23fbac 100644 --- a/extra/tigertool/ecusb/stm32usb.py +++ b/extra/tigertool/ecusb/stm32usb.py @@ -12,108 +12,115 @@ import usb class SusbError(Exception): - """Class for exceptions of Susb.""" - def __init__(self, msg, value=0): - """SusbError constructor. + """Class for exceptions of Susb.""" - Args: - msg: string, message describing error in detail - value: integer, value of error when non-zero status returned. Default=0 - """ - super(SusbError, self).__init__(msg, value) - self.msg = msg - self.value = value + def __init__(self, msg, value=0): + """SusbError constructor. + + Args: + msg: string, message describing error in detail + value: integer, value of error when non-zero status returned. Default=0 + """ + super(SusbError, self).__init__(msg, value) + self.msg = msg + self.value = value class Susb(object): - """Provide stm32 USB functionality. - - Instance Variables: - _read_ep: pyUSB read endpoint for this interface - _write_ep: pyUSB write endpoint for this interface - """ - READ_ENDPOINT = 0x81 - WRITE_ENDPOINT = 0x1 - TIMEOUT_MS = 100 - - def __init__(self, vendor=0x18d1, - product=0x5027, interface=1, serialname=None, logger=None): - """Susb constructor. - - Discovers and connects to stm32 USB endpoints. - - Args: - vendor: usb vendor id of stm32 device. - product: usb product id of stm32 device. - interface: interface number ( 1 - 4 ) of stm32 device to use. - serialname: string of device serialname. - logger: none - - Raises: - SusbError: An error accessing Susb object + """Provide stm32 USB functionality. + + Instance Variables: + _read_ep: pyUSB read endpoint for this interface + _write_ep: pyUSB write endpoint for this interface """ - self._vendor = vendor - self._product = product - self._interface = interface - self._serialname = serialname - self._find_device() - - def _find_device(self): - """Set up the usb endpoint""" - # Find the stm32. - dev_g = usb.core.find(idVendor=self._vendor, idProduct=self._product, - find_all=True) - dev_list = list(dev_g) - - if not dev_list: - raise SusbError('USB device not found') - - # Check if we have multiple stm32s and we've specified the serial. - dev = None - if self._serialname: - for d in dev_list: - dev_serial = usb.util.get_string(d, d.iSerialNumber) - if dev_serial == self._serialname: - dev = d - break - if dev is None: - raise SusbError('USB device(%s) not found' % self._serialname) - else: - try: - dev = dev_list[0] - except StopIteration: - raise SusbError('USB device %04x:%04x not found' % ( - self._vendor, self._product)) - - # If we can't set configuration, it's already been set. - try: - dev.set_configuration() - except usb.core.USBError: - pass - - self._dev = dev - - # Get an endpoint instance. - cfg = dev.get_active_configuration() - intf = usb.util.find_descriptor(cfg, bInterfaceNumber=self._interface) - self._intf = intf - if not intf: - raise SusbError('Interface %04x:%04x - 0x%x not found' % ( - self._vendor, self._product, self._interface)) - - # Detach raiden.ko if it is loaded. CCD endpoints support either a kernel - # module driver that produces a ttyUSB, or direct endpoint access, but - # can't do both at the same time. - if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True: - dev.detach_kernel_driver(intf.bInterfaceNumber) - - read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT - read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number) - self._read_ep = read_ep - - write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT - write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number) - self._write_ep = write_ep - - def close(self): - usb.util.dispose_resources(self._dev) + + READ_ENDPOINT = 0x81 + WRITE_ENDPOINT = 0x1 + TIMEOUT_MS = 100 + + def __init__( + self, vendor=0x18D1, product=0x5027, interface=1, serialname=None, logger=None + ): + """Susb constructor. + + Discovers and connects to stm32 USB endpoints. + + Args: + vendor: usb vendor id of stm32 device. + product: usb product id of stm32 device. + interface: interface number ( 1 - 4 ) of stm32 device to use. + serialname: string of device serialname. + logger: none + + Raises: + SusbError: An error accessing Susb object + """ + self._vendor = vendor + self._product = product + self._interface = interface + self._serialname = serialname + self._find_device() + + def _find_device(self): + """Set up the usb endpoint""" + # Find the stm32. + dev_g = usb.core.find( + idVendor=self._vendor, idProduct=self._product, find_all=True + ) + dev_list = list(dev_g) + + if not dev_list: + raise SusbError("USB device not found") + + # Check if we have multiple stm32s and we've specified the serial. + dev = None + if self._serialname: + for d in dev_list: + dev_serial = usb.util.get_string(d, d.iSerialNumber) + if dev_serial == self._serialname: + dev = d + break + if dev is None: + raise SusbError("USB device(%s) not found" % self._serialname) + else: + try: + dev = dev_list[0] + except StopIteration: + raise SusbError( + "USB device %04x:%04x not found" % (self._vendor, self._product) + ) + + # If we can't set configuration, it's already been set. + try: + dev.set_configuration() + except usb.core.USBError: + pass + + self._dev = dev + + # Get an endpoint instance. + cfg = dev.get_active_configuration() + intf = usb.util.find_descriptor(cfg, bInterfaceNumber=self._interface) + self._intf = intf + if not intf: + raise SusbError( + "Interface %04x:%04x - 0x%x not found" + % (self._vendor, self._product, self._interface) + ) + + # Detach raiden.ko if it is loaded. CCD endpoints support either a kernel + # module driver that produces a ttyUSB, or direct endpoint access, but + # can't do both at the same time. + if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True: + dev.detach_kernel_driver(intf.bInterfaceNumber) + + read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT + read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number) + self._read_ep = read_ep + + write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT + write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number) + self._write_ep = write_ep + + def close(self): + usb.util.dispose_resources(self._dev) diff --git a/extra/tigertool/ecusb/tiny_servo_common.py b/extra/tigertool/ecusb/tiny_servo_common.py index 726e2e64b7..599bae80dc 100644 --- a/extra/tigertool/ecusb/tiny_servo_common.py +++ b/extra/tigertool/ecusb/tiny_servo_common.py @@ -17,215 +17,224 @@ import time import six import usb -from . import pty_driver -from . import stm32uart +from . import pty_driver, stm32uart def get_subprocess_args(): - if six.PY3: - return {'encoding': 'utf-8'} - return {} + if six.PY3: + return {"encoding": "utf-8"} + return {} class TinyServoError(Exception): - """Exceptions.""" + """Exceptions.""" def log(output): - """Print output to console, logfiles can be added here. + """Print output to console, logfiles can be added here. + + Args: + output: string to output. + """ + sys.stdout.write(output) + sys.stdout.write("\n") + sys.stdout.flush() - Args: - output: string to output. - """ - sys.stdout.write(output) - sys.stdout.write('\n') - sys.stdout.flush() def check_usb(vidpid, serialname=None): - """Check if |vidpid| is present on the system's USB. + """Check if |vidpid| is present on the system's USB. + + Args: + vidpid: string representation of the usb vid:pid, eg. '18d1:2001' + serialname: serialname if specified. - Args: - vidpid: string representation of the usb vid:pid, eg. '18d1:2001' - serialname: serialname if specified. + Returns: + True if found, False, otherwise. + """ + if get_usb_dev(vidpid, serialname): + return True - Returns: - True if found, False, otherwise. - """ - if get_usb_dev(vidpid, serialname): - return True + return False - return False def check_usb_sn(vidpid): - """Return the serial number + """Return the serial number - Return the serial number of the first USB device with VID:PID vidpid, - or None if no device is found. This will not work well with two of - the same device attached. + Return the serial number of the first USB device with VID:PID vidpid, + or None if no device is found. This will not work well with two of + the same device attached. - Args: - vidpid: string representation of the usb vid:pid, eg. '18d1:2001' + Args: + vidpid: string representation of the usb vid:pid, eg. '18d1:2001' - Returns: - string serial number if found, None otherwise. - """ - dev = get_usb_dev(vidpid) + Returns: + string serial number if found, None otherwise. + """ + dev = get_usb_dev(vidpid) - if dev: - dev_serial = usb.util.get_string(dev, dev.iSerialNumber) + if dev: + dev_serial = usb.util.get_string(dev, dev.iSerialNumber) - return dev_serial + return dev_serial + + return None - return None def get_usb_dev(vidpid, serialname=None): - """Return the USB pyusb devie struct + """Return the USB pyusb devie struct + + Return the dev struct of the first USB device with VID:PID vidpid, + or None if no device is found. If more than one device check serial + if supplied. + + Args: + vidpid: string representation of the usb vid:pid, eg. '18d1:2001' + serialname: serialname if specified. + + Returns: + pyusb device if found, None otherwise. + """ + vidpidst = vidpid.split(":") + vid = int(vidpidst[0], 16) + pid = int(vidpidst[1], 16) + + dev_g = usb.core.find(idVendor=vid, idProduct=pid, find_all=True) + dev_list = list(dev_g) + + if not dev_list: + return None + + # Check if we have multiple devices and we've specified the serial. + dev = None + if serialname: + for d in dev_list: + dev_serial = usb.util.get_string(d, d.iSerialNumber) + if dev_serial == serialname: + dev = d + break + if dev is None: + return None + else: + try: + dev = dev_list[0] + except StopIteration: + return None + + return dev + - Return the dev struct of the first USB device with VID:PID vidpid, - or None if no device is found. If more than one device check serial - if supplied. +def check_usb_dev(vidpid, serialname=None): + """Return the USB dev number + + Return the dev number of the first USB device with VID:PID vidpid, + or None if no device is found. If more than one device check serial + if supplied. - Args: - vidpid: string representation of the usb vid:pid, eg. '18d1:2001' - serialname: serialname if specified. + Args: + vidpid: string representation of the usb vid:pid, eg. '18d1:2001' + serialname: serialname if specified. - Returns: - pyusb device if found, None otherwise. - """ - vidpidst = vidpid.split(':') - vid = int(vidpidst[0], 16) - pid = int(vidpidst[1], 16) + Returns: + usb device number if found, None otherwise. + """ + dev = get_usb_dev(vidpid, serialname=serialname) - dev_g = usb.core.find(idVendor=vid, idProduct=pid, find_all=True) - dev_list = list(dev_g) + if dev: + return dev.address - if not dev_list: return None - # Check if we have multiple devices and we've specified the serial. - dev = None - if serialname: - for d in dev_list: - dev_serial = usb.util.get_string(d, d.iSerialNumber) - if dev_serial == serialname: - dev = d - break - if dev is None: - return None - else: - try: - dev = dev_list[0] - except StopIteration: - return None - - return dev -def check_usb_dev(vidpid, serialname=None): - """Return the USB dev number +def wait_for_usb_remove(vidpid, serialname=None, timeout=None): + """Wait for USB device with vidpid to be removed. - Return the dev number of the first USB device with VID:PID vidpid, - or None if no device is found. If more than one device check serial - if supplied. + Wrapper for wait_for_usb below + """ + wait_for_usb(vidpid, serialname=serialname, timeout=timeout, desiredpresence=False) - Args: - vidpid: string representation of the usb vid:pid, eg. '18d1:2001' - serialname: serialname if specified. - Returns: - usb device number if found, None otherwise. - """ - dev = get_usb_dev(vidpid, serialname=serialname) +def wait_for_usb(vidpid, serialname=None, timeout=None, desiredpresence=True): + """Wait for usb device with vidpid to be present/absent. - if dev: - return dev.address + Args: + vidpid: string representation of the usb vid:pid, eg. '18d1:2001' + serialname: serialname if specificed. + timeout: timeout in seconds, None for no timeout. + desiredpresence: True for present, False for not present. - return None + Raises: + TinyServoError: on timeout. + """ + if timeout: + finish = datetime.datetime.now() + datetime.timedelta(seconds=timeout) + while check_usb(vidpid, serialname) != desiredpresence: + time.sleep(0.1) + if timeout: + if datetime.datetime.now() > finish: + raise TinyServoError("Timeout", "Timeout waiting for USB %s" % vidpid) -def wait_for_usb_remove(vidpid, serialname=None, timeout=None): - """Wait for USB device with vidpid to be removed. +def do_serialno(serialno, pty): + """Set serialnumber 'serialno' via ec console 'pty'. - Wrapper for wait_for_usb below - """ - wait_for_usb(vidpid, serialname=serialname, - timeout=timeout, desiredpresence=False) + Commands are: + # > serialno set 1234 + # Saving serial number + # Serial number: 1234 -def wait_for_usb(vidpid, serialname=None, timeout=None, desiredpresence=True): - """Wait for usb device with vidpid to be present/absent. - - Args: - vidpid: string representation of the usb vid:pid, eg. '18d1:2001' - serialname: serialname if specificed. - timeout: timeout in seconds, None for no timeout. - desiredpresence: True for present, False for not present. - - Raises: - TinyServoError: on timeout. - """ - if timeout: - finish = datetime.datetime.now() + datetime.timedelta(seconds=timeout) - while check_usb(vidpid, serialname) != desiredpresence: - time.sleep(.1) - if timeout: - if datetime.datetime.now() > finish: - raise TinyServoError('Timeout', 'Timeout waiting for USB %s' % vidpid) + Args: + serialno: string serial number to set. + pty: tinyservo console to send commands. + + Raises: + TinyServoError: on failure to set. + ptyError: on command interface error. + """ + cmd = r"serialno set %s" % serialno + regex = r"Serial number:\s+(\S+)" + + results = pty._issue_cmd_get_results(cmd, [regex])[0] + sn = results[1].strip().strip("\n\r") + + if sn == serialno: + log("Success !") + log("Serial set to %s" % sn) + else: + log("Serial number set to %s but saved as %s." % (serialno, sn)) + raise TinyServoError( + "Serial Number", "Serial number set to %s but saved as %s." % (serialno, sn) + ) -def do_serialno(serialno, pty): - """Set serialnumber 'serialno' via ec console 'pty'. - - Commands are: - # > serialno set 1234 - # Saving serial number - # Serial number: 1234 - - Args: - serialno: string serial number to set. - pty: tinyservo console to send commands. - - Raises: - TinyServoError: on failure to set. - ptyError: on command interface error. - """ - cmd = r'serialno set %s' % serialno - regex = r'Serial number:\s+(\S+)' - - results = pty._issue_cmd_get_results(cmd, [regex])[0] - sn = results[1].strip().strip('\n\r') - - if sn == serialno: - log('Success !') - log('Serial set to %s' % sn) - else: - log('Serial number set to %s but saved as %s.' % (serialno, sn)) - raise TinyServoError( - 'Serial Number', - 'Serial number set to %s but saved as %s.' % (serialno, sn)) def setup_tinyservod(vidpid, interface, serialname=None, debuglog=False): - """Set up a pty - - Set up a pty to the ec console in order - to send commands. Returns a pty_driver object. - - Args: - vidpid: string vidpid of device to access. - interface: not used. - serialname: string serial name of device requested, optional. - debuglog: chatty printout (boolean) - - Returns: - pty object - - Raises: - UsbError, SusbError: on device not found - """ - vidstr, pidstr = vidpid.split(':') - vid = int(vidstr, 16) - pid = int(pidstr, 16) - suart = stm32uart.Suart(vendor=vid, product=pid, - interface=interface, serialname=serialname, - debuglog=debuglog) - suart.run() - pty = pty_driver.ptyDriver(suart, []) - - return pty + """Set up a pty + + Set up a pty to the ec console in order + to send commands. Returns a pty_driver object. + + Args: + vidpid: string vidpid of device to access. + interface: not used. + serialname: string serial name of device requested, optional. + debuglog: chatty printout (boolean) + + Returns: + pty object + + Raises: + UsbError, SusbError: on device not found + """ + vidstr, pidstr = vidpid.split(":") + vid = int(vidstr, 16) + pid = int(pidstr, 16) + suart = stm32uart.Suart( + vendor=vid, + product=pid, + interface=interface, + serialname=serialname, + debuglog=debuglog, + ) + suart.run() + pty = pty_driver.ptyDriver(suart, []) + + return pty diff --git a/extra/tigertool/ecusb/tiny_servod.py b/extra/tigertool/ecusb/tiny_servod.py index 632d9c3a20..ca5ca63f31 100644 --- a/extra/tigertool/ecusb/tiny_servod.py +++ b/extra/tigertool/ecusb/tiny_servod.py @@ -8,47 +8,48 @@ """Helper class to facilitate communication to servo ec console.""" -from ecusb import pty_driver -from ecusb import stm32uart +from ecusb import pty_driver, stm32uart class TinyServod(object): - """Helper class to wrap a pty_driver with interface.""" - - def __init__(self, vid, pid, interface, serialname=None, debug=False): - """Build the driver and interface. - - Args: - vid: servo device vid - pid: servo device pid - interface: which usb interface the servo console is on - serialname: the servo device serial (if available) - """ - self._vid = vid - self._pid = pid - self._interface = interface - self._serial = serialname - self._debug = debug - self._init() - - def _init(self): - self.suart = stm32uart.Suart(vendor=self._vid, - product=self._pid, - interface=self._interface, - serialname=self._serial, - debuglog=self._debug) - self.suart.run() - self.pty = pty_driver.ptyDriver(self.suart, []) - - def reinitialize(self): - """Reinitialize the connect after a reset/disconnect/etc.""" - self.close() - self._init() - - def close(self): - """Close out the connection and release resources. - - Note: if another TinyServod process or servod itself needs the same device - it's necessary to call this to ensure the usb device is available. - """ - self.suart.close() + """Helper class to wrap a pty_driver with interface.""" + + def __init__(self, vid, pid, interface, serialname=None, debug=False): + """Build the driver and interface. + + Args: + vid: servo device vid + pid: servo device pid + interface: which usb interface the servo console is on + serialname: the servo device serial (if available) + """ + self._vid = vid + self._pid = pid + self._interface = interface + self._serial = serialname + self._debug = debug + self._init() + + def _init(self): + self.suart = stm32uart.Suart( + vendor=self._vid, + product=self._pid, + interface=self._interface, + serialname=self._serial, + debuglog=self._debug, + ) + self.suart.run() + self.pty = pty_driver.ptyDriver(self.suart, []) + + def reinitialize(self): + """Reinitialize the connect after a reset/disconnect/etc.""" + self.close() + self._init() + + def close(self): + """Close out the connection and release resources. + + Note: if another TinyServod process or servod itself needs the same device + it's necessary to call this to ensure the usb device is available. + """ + self.suart.close() diff --git a/extra/tigertool/tigertest.py b/extra/tigertool/tigertest.py index 0cd31c8cce..8f8b2c7f03 100755 --- a/extra/tigertool/tigertest.py +++ b/extra/tigertool/tigertest.py @@ -13,7 +13,6 @@ import argparse import subprocess import sys - # Script to control tigertail USB-C Mux board. # # optional arguments: @@ -35,58 +34,60 @@ import sys def testCmd(cmd, expected_results): - """Run command on console, check for success. - - Args: - cmd: shell command to run. - expected_results: a list object of strings expected in the result. - - Raises: - Exception on fail. - """ - print('run: ' + cmd) - try: - p = subprocess.run(cmd, shell=True, check=False, capture_output=True) - output = p.stdout.decode('utf-8') - error = p.stderr.decode('utf-8') - assert p.returncode == 0 - for result in expected_results: - output.index(result) - except Exception as e: - print('FAIL') - print('cmd: ' + cmd) - print('error: ' + str(e)) - print('stdout:\n' + output) - print('stderr:\n' + error) - print('expected: ' + str(expected_results)) - print('RC: ' + str(p.returncode)) - raise e + """Run command on console, check for success. + + Args: + cmd: shell command to run. + expected_results: a list object of strings expected in the result. + + Raises: + Exception on fail. + """ + print("run: " + cmd) + try: + p = subprocess.run(cmd, shell=True, check=False, capture_output=True) + output = p.stdout.decode("utf-8") + error = p.stderr.decode("utf-8") + assert p.returncode == 0 + for result in expected_results: + output.index(result) + except Exception as e: + print("FAIL") + print("cmd: " + cmd) + print("error: " + str(e)) + print("stdout:\n" + output) + print("stderr:\n" + error) + print("expected: " + str(expected_results)) + print("RC: " + str(p.returncode)) + raise e + def test_sequence(): - testCmd('./tigertool.py --reboot', ['PASS']) - testCmd('./tigertool.py --setserialno test', ['PASS']) - testCmd('./tigertool.py --check_serial', ['test', 'PASS']) - testCmd('./tigertool.py -s test --check_serial', ['test', 'PASS']) - testCmd('./tigertool.py -m A', ['Mux set to A', 'PASS']) - testCmd('./tigertool.py -m B', ['Mux set to B', 'PASS']) - testCmd('./tigertool.py -m off', ['Mux set to off', 'PASS']) - testCmd('./tigertool.py -p', ['PASS']) - testCmd('./tigertool.py -r rw', ['PASS']) - testCmd('./tigertool.py -r ro', ['PASS']) - testCmd('./tigertool.py --check_version', ['RW', 'RO', 'PASS']) - - print('PASS') + testCmd("./tigertool.py --reboot", ["PASS"]) + testCmd("./tigertool.py --setserialno test", ["PASS"]) + testCmd("./tigertool.py --check_serial", ["test", "PASS"]) + testCmd("./tigertool.py -s test --check_serial", ["test", "PASS"]) + testCmd("./tigertool.py -m A", ["Mux set to A", "PASS"]) + testCmd("./tigertool.py -m B", ["Mux set to B", "PASS"]) + testCmd("./tigertool.py -m off", ["Mux set to off", "PASS"]) + testCmd("./tigertool.py -p", ["PASS"]) + testCmd("./tigertool.py -r rw", ["PASS"]) + testCmd("./tigertool.py -r ro", ["PASS"]) + testCmd("./tigertool.py --check_version", ["RW", "RO", "PASS"]) + + print("PASS") + def main(argv): - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument('-c', '--count', type=int, default=1, - help='loops to run') + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("-c", "--count", type=int, default=1, help="loops to run") + + opts = parser.parse_args(argv) - opts = parser.parse_args(argv) + for i in range(1, opts.count + 1): + print("Iteration: %d" % i) + test_sequence() - for i in range(1, opts.count + 1): - print('Iteration: %d' % i) - test_sequence() -if __name__ == '__main__': +if __name__ == "__main__": main(sys.argv[1:]) diff --git a/extra/tigertool/tigertool.py b/extra/tigertool/tigertool.py index 6baae8abdf..37b7b01495 100755 --- a/extra/tigertool/tigertool.py +++ b/extra/tigertool/tigertool.py @@ -17,287 +17,308 @@ import time import ecusb.tiny_servo_common as c -STM_VIDPID = '18d1:5027' -serialno = 'Uninitialized' +STM_VIDPID = "18d1:5027" +serialno = "Uninitialized" + def do_mux(mux, pty): - """Set mux via ec console 'pty'. + """Set mux via ec console 'pty'. + + Args: + mux: mux to connect to DUT, 'A', 'B', or 'off' + pty: a pty object connected to tigertail - Args: - mux: mux to connect to DUT, 'A', 'B', or 'off' - pty: a pty object connected to tigertail + Commands are: + # > mux A + # TYPE-C mux is A + """ + validmux = ["A", "B", "off"] + if mux not in validmux: + c.log("Mux setting %s invalid, try one of %s" % (mux, validmux)) + return False - Commands are: - # > mux A - # TYPE-C mux is A - """ - validmux = ['A', 'B', 'off'] - if mux not in validmux: - c.log('Mux setting %s invalid, try one of %s' % (mux, validmux)) - return False + cmd = "mux %s" % mux + regex = "TYPE\-C mux is ([^\s\r\n]*)\r" - cmd = 'mux %s' % mux - regex = 'TYPE\-C mux is ([^\s\r\n]*)\r' + results = pty._issue_cmd_get_results(cmd, [regex])[0] + result = results[1].strip().strip("\n\r") - results = pty._issue_cmd_get_results(cmd, [regex])[0] - result = results[1].strip().strip('\n\r') + if result != mux: + c.log("Mux set to %s but saved as %s." % (mux, result)) + return False + c.log("Mux set to %s" % result) + return True - if result != mux: - c.log('Mux set to %s but saved as %s.' % (mux, result)) - return False - c.log('Mux set to %s' % result) - return True def do_version(pty): - """Check version via ec console 'pty'. - - Args: - pty: a pty object connected to tigertail - - Commands are: - # > version - # Chip: stm stm32f07x - # Board: 0 - # RO: tigertail_v1.1.6749-74d1a312e - # RW: tigertail_v1.1.6749-74d1a312e - # Build: tigertail_v1.1.6749-74d1a312e - # 2017-07-25 20:08:34 nsanders@meatball.mtv.corp.google.com - """ - cmd = 'version' - regex = r'RO:\s+(\S+)\s+RW:\s+(\S+)\s+Build:\s+(\S+)\s+' \ - r'(\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d) (\S+)' - - results = pty._issue_cmd_get_results(cmd, [regex])[0] - c.log('Version is %s' % results[3]) - c.log('RO: %s' % results[1]) - c.log('RW: %s' % results[2]) - c.log('Date: %s' % results[4]) - c.log('Src: %s' % results[5]) - - return True + """Check version via ec console 'pty'. + + Args: + pty: a pty object connected to tigertail + + Commands are: + # > version + # Chip: stm stm32f07x + # Board: 0 + # RO: tigertail_v1.1.6749-74d1a312e + # RW: tigertail_v1.1.6749-74d1a312e + # Build: tigertail_v1.1.6749-74d1a312e + # 2017-07-25 20:08:34 nsanders@meatball.mtv.corp.google.com + """ + cmd = "version" + regex = ( + r"RO:\s+(\S+)\s+RW:\s+(\S+)\s+Build:\s+(\S+)\s+" + r"(\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d) (\S+)" + ) + + results = pty._issue_cmd_get_results(cmd, [regex])[0] + c.log("Version is %s" % results[3]) + c.log("RO: %s" % results[1]) + c.log("RW: %s" % results[2]) + c.log("Date: %s" % results[4]) + c.log("Src: %s" % results[5]) + + return True + def do_check_serial(pty): - """Check serial via ec console 'pty'. + """Check serial via ec console 'pty'. - Args: - pty: a pty object connected to tigertail + Args: + pty: a pty object connected to tigertail - Commands are: - # > serialno - # Serial number: number - """ - cmd = 'serialno' - regex = r'Serial number: ([^\n\r]+)' + Commands are: + # > serialno + # Serial number: number + """ + cmd = "serialno" + regex = r"Serial number: ([^\n\r]+)" - results = pty._issue_cmd_get_results(cmd, [regex])[0] - c.log('Serial is %s' % results[1]) + results = pty._issue_cmd_get_results(cmd, [regex])[0] + c.log("Serial is %s" % results[1]) - return True + return True def do_power(count, bus, pty): - """Check power usage via ec console 'pty'. - - Args: - count: number of samples to capture - bus: rail to monitor, 'vbus', 'cc1', or 'cc2' - pty: a pty object connected to tigertail - - Commands are: - # > ina 0 - # Configuration: 4127 - # Shunt voltage: 02c4 => 1770 uV - # Bus voltage : 1008 => 5130 mV - # Power : 0019 => 625 mW - # Current : 0082 => 130 mA - # Calibration : 0155 - # Mask/Enable : 0008 - # Alert limit : 0000 - """ - if bus == 'vbus': - ina = 0 - if bus == 'cc1': - ina = 4 - if bus == 'cc2': - ina = 1 - - start = time.time() - - c.log('time,\tmV,\tmW,\tmA') - - cmd = 'ina %s' % ina - regex = r'Bus voltage : \S+ \S+ (\d+) mV\s+' \ - r'Power : \S+ \S+ (\d+) mW\s+' \ - r'Current : \S+ \S+ (\d+) mA' - - for i in range(0, count): - results = pty._issue_cmd_get_results(cmd, [regex])[0] - c.log('%.2f,\t%s,\t%s\t%s' % ( - time.time() - start, - results[1], results[2], results[3])) + """Check power usage via ec console 'pty'. + + Args: + count: number of samples to capture + bus: rail to monitor, 'vbus', 'cc1', or 'cc2' + pty: a pty object connected to tigertail + + Commands are: + # > ina 0 + # Configuration: 4127 + # Shunt voltage: 02c4 => 1770 uV + # Bus voltage : 1008 => 5130 mV + # Power : 0019 => 625 mW + # Current : 0082 => 130 mA + # Calibration : 0155 + # Mask/Enable : 0008 + # Alert limit : 0000 + """ + if bus == "vbus": + ina = 0 + if bus == "cc1": + ina = 4 + if bus == "cc2": + ina = 1 + + start = time.time() + + c.log("time,\tmV,\tmW,\tmA") + + cmd = "ina %s" % ina + regex = ( + r"Bus voltage : \S+ \S+ (\d+) mV\s+" + r"Power : \S+ \S+ (\d+) mW\s+" + r"Current : \S+ \S+ (\d+) mA" + ) + + for i in range(0, count): + results = pty._issue_cmd_get_results(cmd, [regex])[0] + c.log( + "%.2f,\t%s,\t%s\t%s" + % (time.time() - start, results[1], results[2], results[3]) + ) + + return True - return True def do_reboot(pty, serialname): - """Reboot via ec console pty - - Args: - pty: a pty object connected to tigertail - serialname: serial name, can be None. - - Command is: reboot. - """ - cmd = 'reboot' - - # Check usb dev number on current instance. - devno = c.check_usb_dev(STM_VIDPID, serialname=serialname) - if not devno: - c.log('Device not found') - return False - - try: - pty._issue_cmd(cmd) - except Exception as e: - c.log('Failed to send command: ' + str(e)) - return False - - try: - c.wait_for_usb_remove(STM_VIDPID, timeout=3., serialname=serialname) - except Exception as e: - # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds - # it's not going to. This step just goes faster if it's detected. - pass - - try: - c.wait_for_usb(STM_VIDPID, timeout=3., serialname=serialname) - except Exception as e: - c.log('Failed to return from reboot: ' + str(e)) - return False - - # Check that the device had a new device number, i.e. it's - # disconnected and reconnected. - newdevno = c.check_usb_dev(STM_VIDPID, serialname=serialname) - if newdevno == devno: - c.log("Device didn't reboot") - return False - - return True + """Reboot via ec console pty + + Args: + pty: a pty object connected to tigertail + serialname: serial name, can be None. + + Command is: reboot. + """ + cmd = "reboot" + + # Check usb dev number on current instance. + devno = c.check_usb_dev(STM_VIDPID, serialname=serialname) + if not devno: + c.log("Device not found") + return False + + try: + pty._issue_cmd(cmd) + except Exception as e: + c.log("Failed to send command: " + str(e)) + return False + + try: + c.wait_for_usb_remove(STM_VIDPID, timeout=3.0, serialname=serialname) + except Exception as e: + # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds + # it's not going to. This step just goes faster if it's detected. + pass + + try: + c.wait_for_usb(STM_VIDPID, timeout=3.0, serialname=serialname) + except Exception as e: + c.log("Failed to return from reboot: " + str(e)) + return False + + # Check that the device had a new device number, i.e. it's + # disconnected and reconnected. + newdevno = c.check_usb_dev(STM_VIDPID, serialname=serialname) + if newdevno == devno: + c.log("Device didn't reboot") + return False + + return True + def do_sysjump(region, pty, serialname): - """Set region via ec console 'pty'. - - Args: - region: ec code region to execute, 'ro' or 'rw' - pty: a pty object connected to tigertail - serialname: serial name, can be None. - - Commands are: - # > sysjump rw - """ - validregion = ['ro', 'rw'] - if region not in validregion: - c.log('Region setting %s invalid, try one of %s' % ( - region, validregion)) - return False - - cmd = 'sysjump %s' % region - try: - pty._issue_cmd(cmd) - except Exception as e: - c.log('Exception: ' + str(e)) - return False - - try: - c.wait_for_usb_remove(STM_VIDPID, timeout=3., serialname=serialname) - except Exception as e: - # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds - # it's not going to. This step just goes faster if it's detected. - pass - - try: - c.wait_for_usb(STM_VIDPID, timeout=3., serialname=serialname) - except Exception as e: - c.log('Failed to return from restart: ' + str(e)) - return False - - c.log('Region requested %s' % region) - return True + """Set region via ec console 'pty'. + + Args: + region: ec code region to execute, 'ro' or 'rw' + pty: a pty object connected to tigertail + serialname: serial name, can be None. + + Commands are: + # > sysjump rw + """ + validregion = ["ro", "rw"] + if region not in validregion: + c.log("Region setting %s invalid, try one of %s" % (region, validregion)) + return False + + cmd = "sysjump %s" % region + try: + pty._issue_cmd(cmd) + except Exception as e: + c.log("Exception: " + str(e)) + return False + + try: + c.wait_for_usb_remove(STM_VIDPID, timeout=3.0, serialname=serialname) + except Exception as e: + # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds + # it's not going to. This step just goes faster if it's detected. + pass + + try: + c.wait_for_usb(STM_VIDPID, timeout=3.0, serialname=serialname) + except Exception as e: + c.log("Failed to return from restart: " + str(e)) + return False + + c.log("Region requested %s" % region) + return True + def get_parser(): - parser = argparse.ArgumentParser( - description=__doc__) - parser.add_argument('-s', '--serialno', type=str, default=None, - help='serial number of board to use') - parser.add_argument('-b', '--bus', type=str, default='vbus', - help='Which rail to log: [vbus|cc1|cc2]') - group = parser.add_mutually_exclusive_group() - group.add_argument('--setserialno', type=str, default=None, - help='serial number to set on the board.') - group.add_argument('--check_serial', action='store_true', - help='check serial number set on the board.') - group.add_argument('-m', '--mux', type=str, default=None, - help='mux selection') - group.add_argument('-p', '--power', action='store_true', - help='check VBUS') - group.add_argument('-l', '--powerlog', type=int, default=None, - help='log VBUS') - group.add_argument('-r', '--sysjump', type=str, default=None, - help='region selection') - group.add_argument('--reboot', action='store_true', - help='reboot tigertail') - group.add_argument('--check_version', action='store_true', - help='check tigertail version') - return parser + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "-s", "--serialno", type=str, default=None, help="serial number of board to use" + ) + parser.add_argument( + "-b", + "--bus", + type=str, + default="vbus", + help="Which rail to log: [vbus|cc1|cc2]", + ) + group = parser.add_mutually_exclusive_group() + group.add_argument( + "--setserialno", + type=str, + default=None, + help="serial number to set on the board.", + ) + group.add_argument( + "--check_serial", + action="store_true", + help="check serial number set on the board.", + ) + group.add_argument("-m", "--mux", type=str, default=None, help="mux selection") + group.add_argument("-p", "--power", action="store_true", help="check VBUS") + group.add_argument("-l", "--powerlog", type=int, default=None, help="log VBUS") + group.add_argument( + "-r", "--sysjump", type=str, default=None, help="region selection" + ) + group.add_argument("--reboot", action="store_true", help="reboot tigertail") + group.add_argument( + "--check_version", action="store_true", help="check tigertail version" + ) + return parser + def main(argv): - parser = get_parser() - opts = parser.parse_args(argv) + parser = get_parser() + opts = parser.parse_args(argv) - result = True + result = True - # Let's make sure there's a tigertail - # If nothing found in 5 seconds, fail. - c.wait_for_usb(STM_VIDPID, timeout=5., serialname=opts.serialno) + # Let's make sure there's a tigertail + # If nothing found in 5 seconds, fail. + c.wait_for_usb(STM_VIDPID, timeout=5.0, serialname=opts.serialno) - pty = c.setup_tinyservod(STM_VIDPID, 0, serialname=opts.serialno) + pty = c.setup_tinyservod(STM_VIDPID, 0, serialname=opts.serialno) - if opts.bus not in ('vbus', 'cc1', 'cc2'): - c.log('Try --bus [vbus|cc1|cc2]') - result = False + if opts.bus not in ("vbus", "cc1", "cc2"): + c.log("Try --bus [vbus|cc1|cc2]") + result = False - elif opts.setserialno: - try: - c.do_serialno(opts.setserialno, pty) - except Exception: - result = False + elif opts.setserialno: + try: + c.do_serialno(opts.setserialno, pty) + except Exception: + result = False - elif opts.mux: - result &= do_mux(opts.mux, pty) + elif opts.mux: + result &= do_mux(opts.mux, pty) - elif opts.sysjump: - result &= do_sysjump(opts.sysjump, pty, serialname=opts.serialno) + elif opts.sysjump: + result &= do_sysjump(opts.sysjump, pty, serialname=opts.serialno) - elif opts.reboot: - result &= do_reboot(pty, serialname=opts.serialno) + elif opts.reboot: + result &= do_reboot(pty, serialname=opts.serialno) - elif opts.check_version: - result &= do_version(pty) + elif opts.check_version: + result &= do_version(pty) - elif opts.check_serial: - result &= do_check_serial(pty) + elif opts.check_serial: + result &= do_check_serial(pty) - elif opts.power: - result &= do_power(1, opts.bus, pty) + elif opts.power: + result &= do_power(1, opts.bus, pty) - elif opts.powerlog: - result &= do_power(opts.powerlog, opts.bus, pty) + elif opts.powerlog: + result &= do_power(opts.powerlog, opts.bus, pty) - if result: - c.log('PASS') - else: - c.log('FAIL') - sys.exit(-1) + if result: + c.log("PASS") + else: + c.log("FAIL") + sys.exit(-1) -if __name__ == '__main__': - sys.exit(main(sys.argv[1:])) +if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) diff --git a/extra/usb_power/convert_power_log_board.py b/extra/usb_power/convert_power_log_board.py index 8aab77ee4c..c1b25f57db 100644 --- a/extra/usb_power/convert_power_log_board.py +++ b/extra/usb_power/convert_power_log_board.py @@ -14,6 +14,7 @@ Program to convert sweetberry config to servod config template. # Note: This is a py2/3 compatible file. from __future__ import print_function + import json import os import sys @@ -48,21 +49,25 @@ def write_to_file(file, sweetberry, inas): inas: list of inas read from board file. """ - with open(file, 'w') as pyfile: + with open(file, "w") as pyfile: - pyfile.write('inas = [\n') + pyfile.write("inas = [\n") for rec in inas: - if rec['sweetberry'] != sweetberry: + if rec["sweetberry"] != sweetberry: continue # EX : ('sweetberry', 0x40, 'SB_FW_CAM_2P8', 5.0, 1.000, 3, False), - channel, i2c_addr = Spower.CHMAP[rec['channel']] - record = (" ('sweetberry', 0x%02x, '%s', 5.0, %f, %d, 'True')" - ",\n" % (i2c_addr, rec['name'], rec['rs'], channel)) + channel, i2c_addr = Spower.CHMAP[rec["channel"]] + record = " ('sweetberry', 0x%02x, '%s', 5.0, %f, %d, 'True')" ",\n" % ( + i2c_addr, + rec["name"], + rec["rs"], + channel, + ) pyfile.write(record) - pyfile.write(']\n') + pyfile.write("]\n") def main(argv): @@ -76,16 +81,18 @@ def main(argv): inas = fetch_records(inputf) - sweetberry = set(rec['sweetberry'] for rec in inas) + sweetberry = set(rec["sweetberry"] for rec in inas) if len(sweetberry) == 2: - print("Converting %s to %s and %s" % (inputf, basename + '_a.py', - basename + '_b.py')) - write_to_file(basename + '_a.py', 'A', inas) - write_to_file(basename + '_b.py', 'B', inas) + print( + "Converting %s to %s and %s" + % (inputf, basename + "_a.py", basename + "_b.py") + ) + write_to_file(basename + "_a.py", "A", inas) + write_to_file(basename + "_b.py", "B", inas) else: - print("Converting %s to %s" % (inputf, basename + '.py')) - write_to_file(basename + '.py', sweetberry.pop(), inas) + print("Converting %s to %s" % (inputf, basename + ".py")) + write_to_file(basename + ".py", sweetberry.pop(), inas) if __name__ == "__main__": diff --git a/extra/usb_power/convert_servo_ina.py b/extra/usb_power/convert_servo_ina.py index 1c70f31aeb..6ccd474e6c 100755 --- a/extra/usb_power/convert_servo_ina.py +++ b/extra/usb_power/convert_servo_ina.py @@ -14,67 +14,71 @@ # Note: This is a py2/3 compatible file. from __future__ import print_function + import os import sys def fetch_records(basename): - """Import records from servo_ina file. + """Import records from servo_ina file. - servo_ina files are python imports, and have a list of tuples with - the INA data. - (inatype, i2caddr, rail name, bus voltage, shunt ohms, mux, True) + servo_ina files are python imports, and have a list of tuples with + the INA data. + (inatype, i2caddr, rail name, bus voltage, shunt ohms, mux, True) - Args: - basename: python import name (filename -.py) + Args: + basename: python import name (filename -.py) - Returns: - list of tuples as described above. - """ - ina_desc = __import__(basename) - return ina_desc.inas + Returns: + list of tuples as described above. + """ + ina_desc = __import__(basename) + return ina_desc.inas def main(argv): - if len(argv) != 2: - print("usage:") - print(" %s input.py" % argv[0]) - return + if len(argv) != 2: + print("usage:") + print(" %s input.py" % argv[0]) + return - inputf = argv[1] - basename = os.path.splitext(inputf)[0] - outputf = basename + '.board' - outputs = basename + '.scenario' + inputf = argv[1] + basename = os.path.splitext(inputf)[0] + outputf = basename + ".board" + outputs = basename + ".scenario" - print("Converting %s to %s, %s" % (inputf, outputf, outputs)) + print("Converting %s to %s, %s" % (inputf, outputf, outputs)) - inas = fetch_records(basename) + inas = fetch_records(basename) + boardfile = open(outputf, "w") + scenario = open(outputs, "w") - boardfile = open(outputf, 'w') - scenario = open(outputs, 'w') + boardfile.write("[\n") + scenario.write("[\n") + start = True - boardfile.write('[\n') - scenario.write('[\n') - start = True + for rec in inas: + if start: + start = False + else: + boardfile.write(",\n") + scenario.write(",\n") - for rec in inas: - if start: - start = False - else: - boardfile.write(',\n') - scenario.write(',\n') + record = ' {"name": "%s", "rs": %f, "sweetberry": "A", "channel": %d}' % ( + rec[2], + rec[4], + rec[1] - 64, + ) + boardfile.write(record) + scenario.write('"%s"' % rec[2]) - record = ' {"name": "%s", "rs": %f, "sweetberry": "A", "channel": %d}' % ( - rec[2], rec[4], rec[1] - 64) - boardfile.write(record) - scenario.write('"%s"' % rec[2]) + boardfile.write("\n") + boardfile.write("]") - boardfile.write('\n') - boardfile.write(']') + scenario.write("\n") + scenario.write("]") - scenario.write('\n') - scenario.write(']') if __name__ == "__main__": - main(sys.argv) + main(sys.argv) diff --git a/extra/usb_power/powerlog.py b/extra/usb_power/powerlog.py index 82cce3daed..d893e5c6b9 100755 --- a/extra/usb_power/powerlog.py +++ b/extra/usb_power/powerlog.py @@ -14,9 +14,9 @@ # Note: This is a py2/3 compatible file. from __future__ import print_function + import argparse import array -from distutils import sysconfig import json import logging import os @@ -25,884 +25,988 @@ import struct import sys import time import traceback +from distutils import sysconfig import usb - from stats_manager import StatsManager # Directory where hdctools installs configuration files into. -LIB_DIR = os.path.join(sysconfig.get_python_lib(standard_lib=False), 'servo', - 'data') +LIB_DIR = os.path.join(sysconfig.get_python_lib(standard_lib=False), "servo", "data") # Potential config file locations: current working directory, the same directory # as powerlog.py file or LIB_DIR. -CONFIG_LOCATIONS = [os.getcwd(), os.path.dirname(os.path.realpath(__file__)), - LIB_DIR] - -def logoutput(msg): - print(msg) - sys.stdout.flush() - -def process_filename(filename): - """Find the file path from the filename. - - If filename is already the complete path, return that directly. If filename is - just the short name, look for the file in the current working directory, in - the directory of the current .py file, and then in the directory installed by - hdctools. If the file is found, return the complete path of the file. - - Args: - filename: complete file path or short file name. - - Returns: - a complete file path. - - Raises: - IOError if filename does not exist. - """ - # Check if filename is absolute path. - if os.path.isabs(filename) and os.path.isfile(filename): - return filename - # Check if filename is relative to a known config location. - for dirname in CONFIG_LOCATIONS: - file_at_dir = os.path.join(dirname, filename) - if os.path.isfile(file_at_dir): - return file_at_dir - raise IOError('No such file or directory: \'%s\'' % filename) - - -class Spower(object): - """Power class to access devices on the bus. - - Usage: - bus = Spower() - - Instance Variables: - _dev: pyUSB device object - _read_ep: pyUSB read endpoint for this interface - _write_ep: pyUSB write endpoint for this interface - """ - - # INA interface type. - INA_POWER = 1 - INA_BUSV = 2 - INA_CURRENT = 3 - INA_SHUNTV = 4 - # INA_SUFFIX is used to differentiate multiple ina types for the same power - # rail. No suffix for when ina type is 0 (non-existent) and when ina type is 1 - # (power, no suffix for backward compatibility). - INA_SUFFIX = ['', '', '_busv', '_cur', '_shuntv'] - - # usb power commands - CMD_RESET = 0x0000 - CMD_STOP = 0x0001 - CMD_ADDINA = 0x0002 - CMD_START = 0x0003 - CMD_NEXT = 0x0004 - CMD_SETTIME = 0x0005 - - # Map between header channel number (0-47) - # and INA I2C bus/addr on sweetberry. - CHMAP = { - 0: (3, 0x40), - 1: (1, 0x40), - 2: (2, 0x40), - 3: (0, 0x40), - 4: (3, 0x41), - 5: (1, 0x41), - 6: (2, 0x41), - 7: (0, 0x41), - 8: (3, 0x42), - 9: (1, 0x42), - 10: (2, 0x42), - 11: (0, 0x42), - 12: (3, 0x43), - 13: (1, 0x43), - 14: (2, 0x43), - 15: (0, 0x43), - 16: (3, 0x44), - 17: (1, 0x44), - 18: (2, 0x44), - 19: (0, 0x44), - 20: (3, 0x45), - 21: (1, 0x45), - 22: (2, 0x45), - 23: (0, 0x45), - 24: (3, 0x46), - 25: (1, 0x46), - 26: (2, 0x46), - 27: (0, 0x46), - 28: (3, 0x47), - 29: (1, 0x47), - 30: (2, 0x47), - 31: (0, 0x47), - 32: (3, 0x48), - 33: (1, 0x48), - 34: (2, 0x48), - 35: (0, 0x48), - 36: (3, 0x49), - 37: (1, 0x49), - 38: (2, 0x49), - 39: (0, 0x49), - 40: (3, 0x4a), - 41: (1, 0x4a), - 42: (2, 0x4a), - 43: (0, 0x4a), - 44: (3, 0x4b), - 45: (1, 0x4b), - 46: (2, 0x4b), - 47: (0, 0x4b), - } - - def __init__(self, board, vendor=0x18d1, - product=0x5020, interface=1, serialname=None): - self._logger = logging.getLogger(__name__) - self._board = board - - # Find the stm32. - dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) - dev_list = list(dev_g) - if dev_list is None: - raise Exception("Power", "USB device not found") - - # Check if we have multiple stm32s and we've specified the serial. - dev = None - if serialname: - for d in dev_list: - dev_serial = "PyUSB dioesn't have a stable interface" - try: - dev_serial = usb.util.get_string(d, 256, d.iSerialNumber) - except ValueError: - # Incompatible pyUsb version. - dev_serial = usb.util.get_string(d, d.iSerialNumber) - if dev_serial == serialname: - dev = d - break - if dev is None: - raise Exception("Power", "USB device(%s) not found" % serialname) - else: - try: - dev = dev_list[0] - except TypeError: - # Incompatible pyUsb version. - dev = dev_list.next() - - self._logger.debug("Found USB device: %04x:%04x", vendor, product) - self._dev = dev - - # Get an endpoint instance. - try: - dev.set_configuration() - except usb.USBError: - pass - cfg = dev.get_active_configuration() - - intf = usb.util.find_descriptor(cfg, custom_match=lambda i: \ - i.bInterfaceClass==255 and i.bInterfaceSubClass==0x54) - - self._intf = intf - self._logger.debug("InterfaceNumber: %s", intf.bInterfaceNumber) - - read_ep = usb.util.find_descriptor( - intf, - # match the first IN endpoint - custom_match = \ - lambda e: \ - usb.util.endpoint_direction(e.bEndpointAddress) == \ - usb.util.ENDPOINT_IN - ) - - self._read_ep = read_ep - self._logger.debug("Reader endpoint: 0x%x", read_ep.bEndpointAddress) - - write_ep = usb.util.find_descriptor( - intf, - # match the first OUT endpoint - custom_match = \ - lambda e: \ - usb.util.endpoint_direction(e.bEndpointAddress) == \ - usb.util.ENDPOINT_OUT - ) - - self._write_ep = write_ep - self._logger.debug("Writer endpoint: 0x%x", write_ep.bEndpointAddress) - - self.clear_ina_struct() - - self._logger.debug("Found power logging USB endpoint.") - - def clear_ina_struct(self): - """ Clear INA description struct.""" - self._inas = [] +CONFIG_LOCATIONS = [os.getcwd(), os.path.dirname(os.path.realpath(__file__)), LIB_DIR] - def append_ina_struct(self, name, rs, port, addr, - data=None, ina_type=INA_POWER): - """Add an INA descriptor into the list of active INAs. - Args: - name: Readable name of this channel. - rs: Sense resistor value in ohms, floating point. - port: I2C channel this INA is connected to. - addr: I2C addr of this INA. - data: Misc data for special handling, board specific. - ina_type: INA function to use, power, voltage, etc. - """ - ina = {} - ina['name'] = name - ina['rs'] = rs - ina['port'] = port - ina['addr'] = addr - ina['type'] = ina_type - # Calculate INA231 Calibration register - # (see INA231 spec p.15) - # CurrentLSB = uA per div = 80mV / (Rsh * 2^15) - # CurrentLSB uA = 80000000nV / (Rsh mOhm * 0x8000) - ina['uAscale'] = 80000000. / (rs * 0x8000); - ina['uWscale'] = 25. * ina['uAscale']; - ina['mVscale'] = 1.25 - ina['uVscale'] = 2.5 - ina['data'] = data - self._inas.append(ina) - - def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=1000): - """Write command to logger logic. - - This function writes byte command values list to stm, then reads - byte status. - - Args: - write_list: list of command byte values [0~255]. - read_count: number of status byte values to read. - - Interface: - write: [command, data ... ] - read: [status ] - - Returns: - bytes read, or None on failure. - """ - self._logger.debug("Spower.wr_command(write_list=[%s] (%d), read_count=%s)", - list(bytearray(write_list)), len(write_list), read_count) - - # Clean up args from python style to correct types. - write_length = 0 - if write_list: - write_length = len(write_list) - if not read_count: - read_count = 0 - - # Send command to stm32. - if write_list: - cmd = write_list - ret = self._write_ep.write(cmd, wtimeout) - - self._logger.debug("RET: %s ", ret) - - # Read back response if necessary. - if read_count: - bytesread = self._read_ep.read(512, rtimeout) - self._logger.debug("BYTES: [%s]", bytesread) - - if len(bytesread) != read_count: - pass - - self._logger.debug("STATUS: 0x%02x", int(bytesread[0])) - if read_count == 1: - return bytesread[0] - else: - return bytesread - - return None - - def clear(self): - """Clear pending reads on the stm32""" - try: - while True: - ret = self.wr_command(b"", read_count=512, rtimeout=100, wtimeout=50) - self._logger.debug("Try Clear: read %s", - "success" if ret == 0 else "failure") - except: - pass - - def send_reset(self): - """Reset the power interface on the stm32""" - cmd = struct.pack("<H", self.CMD_RESET) - ret = self.wr_command(cmd, rtimeout=50, wtimeout=50) - self._logger.debug("Command RESET: %s", - "success" if ret == 0 else "failure") - - def reset(self): - """Try resetting the USB interface until success. - - Use linear back off strategy when encounter the error with 10ms increment. - - Raises: - Exception on failure. - """ - max_reset_retry = 100 - for count in range(1, max_reset_retry + 1): - self.clear() - try: - self.send_reset() - return - except Exception as e: - self.clear() - self.clear() - self._logger.debug("TRY %d of %d: %s", count, max_reset_retry, e) - time.sleep(count * 0.01) - raise Exception("Power", "Failed to reset") - - def stop(self): - """Stop any active data acquisition.""" - cmd = struct.pack("<H", self.CMD_STOP) - ret = self.wr_command(cmd) - self._logger.debug("Command STOP: %s", - "success" if ret == 0 else "failure") - - def start(self, integration_us): - """Start data acquisition. - - Args: - integration_us: int, how many us between samples, and - how often the data block must be read. +def logoutput(msg): + print(msg) + sys.stdout.flush() - Returns: - actual sampling interval in ms. - """ - cmd = struct.pack("<HI", self.CMD_START, integration_us) - read = self.wr_command(cmd, read_count=5) - actual_us = 0 - if len(read) == 5: - ret, actual_us = struct.unpack("<BI", read) - self._logger.debug("Command START: %s %dus", - "success" if ret == 0 else "failure", actual_us) - else: - self._logger.debug("Command START: FAIL") - return actual_us +def process_filename(filename): + """Find the file path from the filename. - def add_ina_name(self, name_tuple): - """Add INA from board config. + If filename is already the complete path, return that directly. If filename is + just the short name, look for the file in the current working directory, in + the directory of the current .py file, and then in the directory installed by + hdctools. If the file is found, return the complete path of the file. Args: - name_tuple: name and type of power rail in board config. + filename: complete file path or short file name. Returns: - True if INA added, False if the INA is not on this board. + a complete file path. Raises: - Exception on unexpected failure. + IOError if filename does not exist. """ - name, ina_type = name_tuple - - for datum in self._brdcfg: - if datum["name"] == name: - rs = int(float(datum["rs"]) * 1000.) - board = datum["sweetberry"] - - if board == self._board: - if 'port' in datum and 'addr' in datum: - port = datum['port'] - addr = datum['addr'] - else: - channel = int(datum["channel"]) - port, addr = self.CHMAP[channel] - self.add_ina(port, ina_type, addr, 0, rs, data=datum) - return True - else: - return False - raise Exception("Power", "Failed to find INA %s" % name) - - def set_time(self, timestamp_us): - """Set sweetberry time to match host time. - - Args: - timestamp_us: host timestmap in us. - """ - # 0x0005 , 8 byte timestamp - cmd = struct.pack("<HQ", self.CMD_SETTIME, timestamp_us) - ret = self.wr_command(cmd) + # Check if filename is absolute path. + if os.path.isabs(filename) and os.path.isfile(filename): + return filename + # Check if filename is relative to a known config location. + for dirname in CONFIG_LOCATIONS: + file_at_dir = os.path.join(dirname, filename) + if os.path.isfile(file_at_dir): + return file_at_dir + raise IOError("No such file or directory: '%s'" % filename) - self._logger.debug("Command SETTIME: %s", - "success" if ret == 0 else "failure") - def add_ina(self, bus, ina_type, addr, extra, resistance, data=None): - """Add an INA to the data acquisition list. +class Spower(object): + """Power class to access devices on the bus. - Args: - bus: which i2c bus the INA is on. Same ordering as Si2c. - ina_type: Ina interface: INA_POWER/BUSV/etc. - addr: 7 bit i2c addr of this INA - extra: extra data for nonstandard configs. - resistance: int, shunt resistance in mOhm - """ - # 0x0002, 1B: bus, 1B:INA type, 1B: INA addr, 1B: extra, 4B: Rs - cmd = struct.pack("<HBBBBI", self.CMD_ADDINA, - bus, ina_type, addr, extra, resistance) - ret = self.wr_command(cmd) - if ret == 0: - if data: - name = data['name'] - else: - name = "ina%d_%02x" % (bus, addr) - self.append_ina_struct(name, resistance, bus, addr, - data=data, ina_type=ina_type) - self._logger.debug("Command ADD_INA: %s", - "success" if ret == 0 else "failure") - - def report_header_size(self): - """Helper function to calculate power record header size.""" - result = 2 - timestamp = 8 - return result + timestamp - - def report_size(self, ina_count): - """Helper function to calculate full power record size.""" - record = 2 - - datasize = self.report_header_size() + ina_count * record - # Round to multiple of 4 bytes. - datasize = int(((datasize + 3) // 4) * 4) - - return datasize - - def read_line(self): - """Read a line of data from the setup INAs + Usage: + bus = Spower() - Returns: - list of dicts of the values read by ina/type tuple, otherwise None. - [{ts:100, (vbat, power):450}, {ts:200, (vbat, power):440}] + Instance Variables: + _dev: pyUSB device object + _read_ep: pyUSB read endpoint for this interface + _write_ep: pyUSB write endpoint for this interface """ - try: - expected_bytes = self.report_size(len(self._inas)) - cmd = struct.pack("<H", self.CMD_NEXT) - bytesread = self.wr_command(cmd, read_count=expected_bytes) - except usb.core.USBError as e: - self._logger.error("READ LINE FAILED %s", e) - return None - - if len(bytesread) == 1: - if bytesread[0] != 0x6: - self._logger.debug("READ LINE FAILED bytes: %d ret: %02x", - len(bytesread), bytesread[0]) - return None - - if len(bytesread) % expected_bytes != 0: - self._logger.debug("READ LINE WARNING: expected %d, got %d", - expected_bytes, len(bytesread)) - - packet_count = len(bytesread) // expected_bytes - - values = [] - for i in range(0, packet_count): - start = i * expected_bytes - end = (i + 1) * expected_bytes - record = self.interpret_line(bytesread[start:end]) - values.append(record) - - return values - - def interpret_line(self, data): - """Interpret a power record from INAs - Args: - data: one single record of bytes. + # INA interface type. + INA_POWER = 1 + INA_BUSV = 2 + INA_CURRENT = 3 + INA_SHUNTV = 4 + # INA_SUFFIX is used to differentiate multiple ina types for the same power + # rail. No suffix for when ina type is 0 (non-existent) and when ina type is 1 + # (power, no suffix for backward compatibility). + INA_SUFFIX = ["", "", "_busv", "_cur", "_shuntv"] + + # usb power commands + CMD_RESET = 0x0000 + CMD_STOP = 0x0001 + CMD_ADDINA = 0x0002 + CMD_START = 0x0003 + CMD_NEXT = 0x0004 + CMD_SETTIME = 0x0005 + + # Map between header channel number (0-47) + # and INA I2C bus/addr on sweetberry. + CHMAP = { + 0: (3, 0x40), + 1: (1, 0x40), + 2: (2, 0x40), + 3: (0, 0x40), + 4: (3, 0x41), + 5: (1, 0x41), + 6: (2, 0x41), + 7: (0, 0x41), + 8: (3, 0x42), + 9: (1, 0x42), + 10: (2, 0x42), + 11: (0, 0x42), + 12: (3, 0x43), + 13: (1, 0x43), + 14: (2, 0x43), + 15: (0, 0x43), + 16: (3, 0x44), + 17: (1, 0x44), + 18: (2, 0x44), + 19: (0, 0x44), + 20: (3, 0x45), + 21: (1, 0x45), + 22: (2, 0x45), + 23: (0, 0x45), + 24: (3, 0x46), + 25: (1, 0x46), + 26: (2, 0x46), + 27: (0, 0x46), + 28: (3, 0x47), + 29: (1, 0x47), + 30: (2, 0x47), + 31: (0, 0x47), + 32: (3, 0x48), + 33: (1, 0x48), + 34: (2, 0x48), + 35: (0, 0x48), + 36: (3, 0x49), + 37: (1, 0x49), + 38: (2, 0x49), + 39: (0, 0x49), + 40: (3, 0x4A), + 41: (1, 0x4A), + 42: (2, 0x4A), + 43: (0, 0x4A), + 44: (3, 0x4B), + 45: (1, 0x4B), + 46: (2, 0x4B), + 47: (0, 0x4B), + } + + def __init__( + self, board, vendor=0x18D1, product=0x5020, interface=1, serialname=None + ): + self._logger = logging.getLogger(__name__) + self._board = board + + # Find the stm32. + dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) + dev_list = list(dev_g) + if dev_list is None: + raise Exception("Power", "USB device not found") + + # Check if we have multiple stm32s and we've specified the serial. + dev = None + if serialname: + for d in dev_list: + dev_serial = "PyUSB dioesn't have a stable interface" + try: + dev_serial = usb.util.get_string(d, 256, d.iSerialNumber) + except ValueError: + # Incompatible pyUsb version. + dev_serial = usb.util.get_string(d, d.iSerialNumber) + if dev_serial == serialname: + dev = d + break + if dev is None: + raise Exception("Power", "USB device(%s) not found" % serialname) + else: + try: + dev = dev_list[0] + except TypeError: + # Incompatible pyUsb version. + dev = dev_list.next() - Output: - stdout of the record in csv format. + self._logger.debug("Found USB device: %04x:%04x", vendor, product) + self._dev = dev - Returns: - dict containing name, value of recorded data. - """ - status, size = struct.unpack("<BB", data[0:2]) - if len(data) != self.report_size(size): - self._logger.error("READ LINE FAILED st:%d size:%d expected:%d len:%d", - status, size, self.report_size(size), len(data)) - else: - pass + # Get an endpoint instance. + try: + dev.set_configuration() + except usb.USBError: + pass + cfg = dev.get_active_configuration() + + intf = usb.util.find_descriptor( + cfg, + custom_match=lambda i: i.bInterfaceClass == 255 + and i.bInterfaceSubClass == 0x54, + ) + + self._intf = intf + self._logger.debug("InterfaceNumber: %s", intf.bInterfaceNumber) + + read_ep = usb.util.find_descriptor( + intf, + # match the first IN endpoint + custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) + == usb.util.ENDPOINT_IN, + ) + + self._read_ep = read_ep + self._logger.debug("Reader endpoint: 0x%x", read_ep.bEndpointAddress) + + write_ep = usb.util.find_descriptor( + intf, + # match the first OUT endpoint + custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) + == usb.util.ENDPOINT_OUT, + ) + + self._write_ep = write_ep + self._logger.debug("Writer endpoint: 0x%x", write_ep.bEndpointAddress) + + self.clear_ina_struct() + + self._logger.debug("Found power logging USB endpoint.") + + def clear_ina_struct(self): + """Clear INA description struct.""" + self._inas = [] + + def append_ina_struct(self, name, rs, port, addr, data=None, ina_type=INA_POWER): + """Add an INA descriptor into the list of active INAs. + + Args: + name: Readable name of this channel. + rs: Sense resistor value in ohms, floating point. + port: I2C channel this INA is connected to. + addr: I2C addr of this INA. + data: Misc data for special handling, board specific. + ina_type: INA function to use, power, voltage, etc. + """ + ina = {} + ina["name"] = name + ina["rs"] = rs + ina["port"] = port + ina["addr"] = addr + ina["type"] = ina_type + # Calculate INA231 Calibration register + # (see INA231 spec p.15) + # CurrentLSB = uA per div = 80mV / (Rsh * 2^15) + # CurrentLSB uA = 80000000nV / (Rsh mOhm * 0x8000) + ina["uAscale"] = 80000000.0 / (rs * 0x8000) + ina["uWscale"] = 25.0 * ina["uAscale"] + ina["mVscale"] = 1.25 + ina["uVscale"] = 2.5 + ina["data"] = data + self._inas.append(ina) + + def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=1000): + """Write command to logger logic. + + This function writes byte command values list to stm, then reads + byte status. + + Args: + write_list: list of command byte values [0~255]. + read_count: number of status byte values to read. + + Interface: + write: [command, data ... ] + read: [status ] + + Returns: + bytes read, or None on failure. + """ + self._logger.debug( + "Spower.wr_command(write_list=[%s] (%d), read_count=%s)", + list(bytearray(write_list)), + len(write_list), + read_count, + ) + + # Clean up args from python style to correct types. + write_length = 0 + if write_list: + write_length = len(write_list) + if not read_count: + read_count = 0 + + # Send command to stm32. + if write_list: + cmd = write_list + ret = self._write_ep.write(cmd, wtimeout) + + self._logger.debug("RET: %s ", ret) + + # Read back response if necessary. + if read_count: + bytesread = self._read_ep.read(512, rtimeout) + self._logger.debug("BYTES: [%s]", bytesread) + + if len(bytesread) != read_count: + pass + + self._logger.debug("STATUS: 0x%02x", int(bytesread[0])) + if read_count == 1: + return bytesread[0] + else: + return bytesread + + return None + + def clear(self): + """Clear pending reads on the stm32""" + try: + while True: + ret = self.wr_command(b"", read_count=512, rtimeout=100, wtimeout=50) + self._logger.debug( + "Try Clear: read %s", "success" if ret == 0 else "failure" + ) + except: + pass + + def send_reset(self): + """Reset the power interface on the stm32""" + cmd = struct.pack("<H", self.CMD_RESET) + ret = self.wr_command(cmd, rtimeout=50, wtimeout=50) + self._logger.debug("Command RESET: %s", "success" if ret == 0 else "failure") + + def reset(self): + """Try resetting the USB interface until success. + + Use linear back off strategy when encounter the error with 10ms increment. + + Raises: + Exception on failure. + """ + max_reset_retry = 100 + for count in range(1, max_reset_retry + 1): + self.clear() + try: + self.send_reset() + return + except Exception as e: + self.clear() + self.clear() + self._logger.debug("TRY %d of %d: %s", count, max_reset_retry, e) + time.sleep(count * 0.01) + raise Exception("Power", "Failed to reset") + + def stop(self): + """Stop any active data acquisition.""" + cmd = struct.pack("<H", self.CMD_STOP) + ret = self.wr_command(cmd) + self._logger.debug("Command STOP: %s", "success" if ret == 0 else "failure") + + def start(self, integration_us): + """Start data acquisition. + + Args: + integration_us: int, how many us between samples, and + how often the data block must be read. + + Returns: + actual sampling interval in ms. + """ + cmd = struct.pack("<HI", self.CMD_START, integration_us) + read = self.wr_command(cmd, read_count=5) + actual_us = 0 + if len(read) == 5: + ret, actual_us = struct.unpack("<BI", read) + self._logger.debug( + "Command START: %s %dus", + "success" if ret == 0 else "failure", + actual_us, + ) + else: + self._logger.debug("Command START: FAIL") + + return actual_us + + def add_ina_name(self, name_tuple): + """Add INA from board config. + + Args: + name_tuple: name and type of power rail in board config. + + Returns: + True if INA added, False if the INA is not on this board. + + Raises: + Exception on unexpected failure. + """ + name, ina_type = name_tuple + + for datum in self._brdcfg: + if datum["name"] == name: + rs = int(float(datum["rs"]) * 1000.0) + board = datum["sweetberry"] + + if board == self._board: + if "port" in datum and "addr" in datum: + port = datum["port"] + addr = datum["addr"] + else: + channel = int(datum["channel"]) + port, addr = self.CHMAP[channel] + self.add_ina(port, ina_type, addr, 0, rs, data=datum) + return True + else: + return False + raise Exception("Power", "Failed to find INA %s" % name) + + def set_time(self, timestamp_us): + """Set sweetberry time to match host time. + + Args: + timestamp_us: host timestmap in us. + """ + # 0x0005 , 8 byte timestamp + cmd = struct.pack("<HQ", self.CMD_SETTIME, timestamp_us) + ret = self.wr_command(cmd) + + self._logger.debug("Command SETTIME: %s", "success" if ret == 0 else "failure") + + def add_ina(self, bus, ina_type, addr, extra, resistance, data=None): + """Add an INA to the data acquisition list. + + Args: + bus: which i2c bus the INA is on. Same ordering as Si2c. + ina_type: Ina interface: INA_POWER/BUSV/etc. + addr: 7 bit i2c addr of this INA + extra: extra data for nonstandard configs. + resistance: int, shunt resistance in mOhm + """ + # 0x0002, 1B: bus, 1B:INA type, 1B: INA addr, 1B: extra, 4B: Rs + cmd = struct.pack( + "<HBBBBI", self.CMD_ADDINA, bus, ina_type, addr, extra, resistance + ) + ret = self.wr_command(cmd) + if ret == 0: + if data: + name = data["name"] + else: + name = "ina%d_%02x" % (bus, addr) + self.append_ina_struct( + name, resistance, bus, addr, data=data, ina_type=ina_type + ) + self._logger.debug("Command ADD_INA: %s", "success" if ret == 0 else "failure") + + def report_header_size(self): + """Helper function to calculate power record header size.""" + result = 2 + timestamp = 8 + return result + timestamp + + def report_size(self, ina_count): + """Helper function to calculate full power record size.""" + record = 2 + + datasize = self.report_header_size() + ina_count * record + # Round to multiple of 4 bytes. + datasize = int(((datasize + 3) // 4) * 4) + + return datasize + + def read_line(self): + """Read a line of data from the setup INAs + + Returns: + list of dicts of the values read by ina/type tuple, otherwise None. + [{ts:100, (vbat, power):450}, {ts:200, (vbat, power):440}] + """ + try: + expected_bytes = self.report_size(len(self._inas)) + cmd = struct.pack("<H", self.CMD_NEXT) + bytesread = self.wr_command(cmd, read_count=expected_bytes) + except usb.core.USBError as e: + self._logger.error("READ LINE FAILED %s", e) + return None + + if len(bytesread) == 1: + if bytesread[0] != 0x6: + self._logger.debug( + "READ LINE FAILED bytes: %d ret: %02x", len(bytesread), bytesread[0] + ) + return None + + if len(bytesread) % expected_bytes != 0: + self._logger.debug( + "READ LINE WARNING: expected %d, got %d", expected_bytes, len(bytesread) + ) + + packet_count = len(bytesread) // expected_bytes + + values = [] + for i in range(0, packet_count): + start = i * expected_bytes + end = (i + 1) * expected_bytes + record = self.interpret_line(bytesread[start:end]) + values.append(record) + + return values + + def interpret_line(self, data): + """Interpret a power record from INAs + + Args: + data: one single record of bytes. + + Output: + stdout of the record in csv format. + + Returns: + dict containing name, value of recorded data. + """ + status, size = struct.unpack("<BB", data[0:2]) + if len(data) != self.report_size(size): + self._logger.error( + "READ LINE FAILED st:%d size:%d expected:%d len:%d", + status, + size, + self.report_size(size), + len(data), + ) + else: + pass - timestamp = struct.unpack("<Q", data[2:10])[0] - self._logger.debug("READ LINE: st:%d size:%d time:%dus", status, size, - timestamp) - ftimestamp = float(timestamp) / 1000000. + timestamp = struct.unpack("<Q", data[2:10])[0] + self._logger.debug( + "READ LINE: st:%d size:%d time:%dus", status, size, timestamp + ) + ftimestamp = float(timestamp) / 1000000.0 - record = {"ts": ftimestamp, "status": status, "berry":self._board} + record = {"ts": ftimestamp, "status": status, "berry": self._board} - for i in range(0, size): - idx = self.report_header_size() + 2*i - name = self._inas[i]['name'] - name_tuple = (self._inas[i]['name'], self._inas[i]['type']) + for i in range(0, size): + idx = self.report_header_size() + 2 * i + name = self._inas[i]["name"] + name_tuple = (self._inas[i]["name"], self._inas[i]["type"]) - raw_val = struct.unpack("<h", data[idx:idx+2])[0] + raw_val = struct.unpack("<h", data[idx : idx + 2])[0] - if self._inas[i]['type'] == Spower.INA_POWER: - val = raw_val * self._inas[i]['uWscale'] - elif self._inas[i]['type'] == Spower.INA_BUSV: - val = raw_val * self._inas[i]['mVscale'] - elif self._inas[i]['type'] == Spower.INA_CURRENT: - val = raw_val * self._inas[i]['uAscale'] - elif self._inas[i]['type'] == Spower.INA_SHUNTV: - val = raw_val * self._inas[i]['uVscale'] + if self._inas[i]["type"] == Spower.INA_POWER: + val = raw_val * self._inas[i]["uWscale"] + elif self._inas[i]["type"] == Spower.INA_BUSV: + val = raw_val * self._inas[i]["mVscale"] + elif self._inas[i]["type"] == Spower.INA_CURRENT: + val = raw_val * self._inas[i]["uAscale"] + elif self._inas[i]["type"] == Spower.INA_SHUNTV: + val = raw_val * self._inas[i]["uVscale"] - self._logger.debug("READ %d %s: %fs: 0x%04x %f", i, name, ftimestamp, - raw_val, val) - record[name_tuple] = val + self._logger.debug( + "READ %d %s: %fs: 0x%04x %f", i, name, ftimestamp, raw_val, val + ) + record[name_tuple] = val - return record + return record - def load_board(self, brdfile): - """Load a board config. + def load_board(self, brdfile): + """Load a board config. - Args: - brdfile: Filename of a json file decribing the INA wiring of this board. - """ - with open(process_filename(brdfile)) as data_file: - data = json.load(data_file) + Args: + brdfile: Filename of a json file decribing the INA wiring of this board. + """ + with open(process_filename(brdfile)) as data_file: + data = json.load(data_file) - #TODO: validate this. - self._brdcfg = data; - self._logger.debug(pprint.pformat(data)) + # TODO: validate this. + self._brdcfg = data + self._logger.debug(pprint.pformat(data)) class powerlog(object): - """Power class to log aggregated power. - - Usage: - obj = powerlog() + """Power class to log aggregated power. - Instance Variables: - _data: a StatsManager object that records sweetberry readings and calculates - statistics. - _pwr[]: Spower objects for individual sweetberries. - """ - - def __init__(self, brdfile, cfgfile, serial_a=None, serial_b=None, - sync_date=False, use_ms=False, use_mW=False, print_stats=False, - stats_dir=None, stats_json_dir=None, print_raw_data=True, - raw_data_dir=None): - """Init the powerlog class and set the variables. - - Args: - brdfile: string name of json file containing board layout. - cfgfile: string name of json containing list of rails to read. - serial_a: serial number of sweetberry A. - serial_b: serial number of sweetberry B. - sync_date: report timestamps synced with host datetime. - use_ms: report timestamps in ms rather than us. - use_mW: report power as milliwatts, otherwise default to microwatts. - print_stats: print statistics for sweetberry readings at the end. - stats_dir: directory to save sweetberry readings statistics; if None then - do not save the statistics. - stats_json_dir: directory to save means of sweetberry readings in json - format; if None then do not save the statistics. - print_raw_data: print sweetberry readings raw data in real time, default - is to print. - raw_data_dir: directory to save sweetberry readings raw data; if None then - do not save the raw data. - """ - self._logger = logging.getLogger(__name__) - self._data = StatsManager() - self._pwr = {} - self._use_ms = use_ms - self._use_mW = use_mW - self._print_stats = print_stats - self._stats_dir = stats_dir - self._stats_json_dir = stats_json_dir - self._print_raw_data = print_raw_data - self._raw_data_dir = raw_data_dir - - if not serial_a and not serial_b: - self._pwr['A'] = Spower('A') - if serial_a: - self._pwr['A'] = Spower('A', serialname=serial_a) - if serial_b: - self._pwr['B'] = Spower('B', serialname=serial_b) - - with open(process_filename(cfgfile)) as data_file: - names = json.load(data_file) - self._names = self.process_scenario(names) - - for key in self._pwr: - self._pwr[key].load_board(brdfile) - self._pwr[key].reset() - - # Allocate the rails to the appropriate boards. - used_boards = [] - for name in self._names: - success = False - for key in self._pwr.keys(): - if self._pwr[key].add_ina_name(name): - success = True - if key not in used_boards: - used_boards.append(key) - if not success: - raise Exception("Failed to add %s (maybe missing " - "sweetberry, or bad board file?)" % name) - - # Evict unused boards. - for key in list(self._pwr.keys()): - if key not in used_boards: - self._pwr.pop(key) - - for key in self._pwr.keys(): - if sync_date: - self._pwr[key].set_time(time.time() * 1000000) - else: - self._pwr[key].set_time(0) - - def process_scenario(self, name_list): - """Return list of tuples indicating name and type. + Usage: + obj = powerlog() - Args: - json originated list of names, or [name, type] - Returns: - list of tuples of (name, type) defaulting to type "POWER" - Raises: exception, invalid INA type. + Instance Variables: + _data: a StatsManager object that records sweetberry readings and calculates + statistics. + _pwr[]: Spower objects for individual sweetberries. """ - names = [] - for entry in name_list: - if isinstance(entry, list): - name = entry[0] - if entry[1] == "POWER": - type = Spower.INA_POWER - elif entry[1] == "BUSV": - type = Spower.INA_BUSV - elif entry[1] == "CURRENT": - type = Spower.INA_CURRENT - elif entry[1] == "SHUNTV": - type = Spower.INA_SHUNTV - else: - raise Exception("Invalid INA type", "Type of %s [%s] not recognized," - " try one of POWER, BUSV, CURRENT" % (entry[0], entry[1])) - else: - name = entry - type = Spower.INA_POWER - - names.append((name, type)) - return names - def start(self, integration_us_request, seconds, sync_speed=.8): - """Starts sampling. + def __init__( + self, + brdfile, + cfgfile, + serial_a=None, + serial_b=None, + sync_date=False, + use_ms=False, + use_mW=False, + print_stats=False, + stats_dir=None, + stats_json_dir=None, + print_raw_data=True, + raw_data_dir=None, + ): + """Init the powerlog class and set the variables. + + Args: + brdfile: string name of json file containing board layout. + cfgfile: string name of json containing list of rails to read. + serial_a: serial number of sweetberry A. + serial_b: serial number of sweetberry B. + sync_date: report timestamps synced with host datetime. + use_ms: report timestamps in ms rather than us. + use_mW: report power as milliwatts, otherwise default to microwatts. + print_stats: print statistics for sweetberry readings at the end. + stats_dir: directory to save sweetberry readings statistics; if None then + do not save the statistics. + stats_json_dir: directory to save means of sweetberry readings in json + format; if None then do not save the statistics. + print_raw_data: print sweetberry readings raw data in real time, default + is to print. + raw_data_dir: directory to save sweetberry readings raw data; if None then + do not save the raw data. + """ + self._logger = logging.getLogger(__name__) + self._data = StatsManager() + self._pwr = {} + self._use_ms = use_ms + self._use_mW = use_mW + self._print_stats = print_stats + self._stats_dir = stats_dir + self._stats_json_dir = stats_json_dir + self._print_raw_data = print_raw_data + self._raw_data_dir = raw_data_dir + + if not serial_a and not serial_b: + self._pwr["A"] = Spower("A") + if serial_a: + self._pwr["A"] = Spower("A", serialname=serial_a) + if serial_b: + self._pwr["B"] = Spower("B", serialname=serial_b) + + with open(process_filename(cfgfile)) as data_file: + names = json.load(data_file) + self._names = self.process_scenario(names) - Args: - integration_us_request: requested interval between sample values. - seconds: time until exit, or None to run until cancel. - sync_speed: A usb request is sent every [.8] * integration_us. - """ - # We will get back the actual integration us. - # It should be the same for all devices. - integration_us = None - for key in self._pwr: - integration_us_new = self._pwr[key].start(integration_us_request) - if integration_us: - if integration_us != integration_us_new: - raise Exception("FAIL", - "Integration on A: %dus != integration on B %dus" % ( - integration_us, integration_us_new)) - integration_us = integration_us_new - - # CSV header - title = "ts:%dus" % integration_us - for name_tuple in self._names: - name, ina_type = name_tuple - - if ina_type == Spower.INA_POWER: - unit = "mW" if self._use_mW else "uW" - elif ina_type == Spower.INA_BUSV: - unit = "mV" - elif ina_type == Spower.INA_CURRENT: - unit = "uA" - elif ina_type == Spower.INA_SHUNTV: - unit = "uV" - - title += ", %s %s" % (name, unit) - name_type = name + Spower.INA_SUFFIX[ina_type] - self._data.SetUnit(name_type, unit) - title += ", status" - if self._print_raw_data: - logoutput(title) - - forever = False - if not seconds: - forever = True - end_time = time.time() + seconds - try: - pending_records = [] - while forever or end_time > time.time(): - if (integration_us > 5000): - time.sleep((integration_us / 1000000.) * sync_speed) for key in self._pwr: - records = self._pwr[key].read_line() - if not records: - continue - - for record in records: - pending_records.append(record) - - pending_records.sort(key=lambda r: r['ts']) - - aggregate_record = {"boards": set()} - for record in pending_records: - if record["berry"] not in aggregate_record["boards"]: - for rkey in record.keys(): - aggregate_record[rkey] = record[rkey] - aggregate_record["boards"].add(record["berry"]) - else: - self._logger.info("break %s, %s", record["berry"], - aggregate_record["boards"]) - break - - if aggregate_record["boards"] == set(self._pwr.keys()): - csv = "%f" % aggregate_record["ts"] - for name in self._names: - if name in aggregate_record: - multiplier = 0.001 if (self._use_mW and - name[1]==Spower.INA_POWER) else 1 - value = aggregate_record[name] * multiplier - csv += ", %.2f" % value - name_type = name[0] + Spower.INA_SUFFIX[name[1]] - self._data.AddSample(name_type, value) - else: - csv += ", " - csv += ", %d" % aggregate_record["status"] - if self._print_raw_data: - logoutput(csv) - - aggregate_record = {"boards": set()} - for r in range(0, len(self._pwr)): - pending_records.pop(0) - - except KeyboardInterrupt: - self._logger.info('\nCTRL+C caught.') - - finally: - for key in self._pwr: - self._pwr[key].stop() - self._data.CalculateStats() - if self._print_stats: - print(self._data.SummaryToString()) - save_dir = 'sweetberry%s' % time.time() - if self._stats_dir: - stats_dir = os.path.join(self._stats_dir, save_dir) - self._data.SaveSummary(stats_dir) - if self._stats_json_dir: - stats_json_dir = os.path.join(self._stats_json_dir, save_dir) - self._data.SaveSummaryJSON(stats_json_dir) - if self._raw_data_dir: - raw_data_dir = os.path.join(self._raw_data_dir, save_dir) - self._data.SaveRawData(raw_data_dir) + self._pwr[key].load_board(brdfile) + self._pwr[key].reset() + + # Allocate the rails to the appropriate boards. + used_boards = [] + for name in self._names: + success = False + for key in self._pwr.keys(): + if self._pwr[key].add_ina_name(name): + success = True + if key not in used_boards: + used_boards.append(key) + if not success: + raise Exception( + "Failed to add %s (maybe missing " + "sweetberry, or bad board file?)" % name + ) + + # Evict unused boards. + for key in list(self._pwr.keys()): + if key not in used_boards: + self._pwr.pop(key) + + for key in self._pwr.keys(): + if sync_date: + self._pwr[key].set_time(time.time() * 1000000) + else: + self._pwr[key].set_time(0) + + def process_scenario(self, name_list): + """Return list of tuples indicating name and type. + + Args: + json originated list of names, or [name, type] + Returns: + list of tuples of (name, type) defaulting to type "POWER" + Raises: exception, invalid INA type. + """ + names = [] + for entry in name_list: + if isinstance(entry, list): + name = entry[0] + if entry[1] == "POWER": + type = Spower.INA_POWER + elif entry[1] == "BUSV": + type = Spower.INA_BUSV + elif entry[1] == "CURRENT": + type = Spower.INA_CURRENT + elif entry[1] == "SHUNTV": + type = Spower.INA_SHUNTV + else: + raise Exception( + "Invalid INA type", + "Type of %s [%s] not recognized," + " try one of POWER, BUSV, CURRENT" % (entry[0], entry[1]), + ) + else: + name = entry + type = Spower.INA_POWER + + names.append((name, type)) + return names + + def start(self, integration_us_request, seconds, sync_speed=0.8): + """Starts sampling. + + Args: + integration_us_request: requested interval between sample values. + seconds: time until exit, or None to run until cancel. + sync_speed: A usb request is sent every [.8] * integration_us. + """ + # We will get back the actual integration us. + # It should be the same for all devices. + integration_us = None + for key in self._pwr: + integration_us_new = self._pwr[key].start(integration_us_request) + if integration_us: + if integration_us != integration_us_new: + raise Exception( + "FAIL", + "Integration on A: %dus != integration on B %dus" + % (integration_us, integration_us_new), + ) + integration_us = integration_us_new + + # CSV header + title = "ts:%dus" % integration_us + for name_tuple in self._names: + name, ina_type = name_tuple + + if ina_type == Spower.INA_POWER: + unit = "mW" if self._use_mW else "uW" + elif ina_type == Spower.INA_BUSV: + unit = "mV" + elif ina_type == Spower.INA_CURRENT: + unit = "uA" + elif ina_type == Spower.INA_SHUNTV: + unit = "uV" + + title += ", %s %s" % (name, unit) + name_type = name + Spower.INA_SUFFIX[ina_type] + self._data.SetUnit(name_type, unit) + title += ", status" + if self._print_raw_data: + logoutput(title) + + forever = False + if not seconds: + forever = True + end_time = time.time() + seconds + try: + pending_records = [] + while forever or end_time > time.time(): + if integration_us > 5000: + time.sleep((integration_us / 1000000.0) * sync_speed) + for key in self._pwr: + records = self._pwr[key].read_line() + if not records: + continue + + for record in records: + pending_records.append(record) + + pending_records.sort(key=lambda r: r["ts"]) + + aggregate_record = {"boards": set()} + for record in pending_records: + if record["berry"] not in aggregate_record["boards"]: + for rkey in record.keys(): + aggregate_record[rkey] = record[rkey] + aggregate_record["boards"].add(record["berry"]) + else: + self._logger.info( + "break %s, %s", record["berry"], aggregate_record["boards"] + ) + break + + if aggregate_record["boards"] == set(self._pwr.keys()): + csv = "%f" % aggregate_record["ts"] + for name in self._names: + if name in aggregate_record: + multiplier = ( + 0.001 + if (self._use_mW and name[1] == Spower.INA_POWER) + else 1 + ) + value = aggregate_record[name] * multiplier + csv += ", %.2f" % value + name_type = name[0] + Spower.INA_SUFFIX[name[1]] + self._data.AddSample(name_type, value) + else: + csv += ", " + csv += ", %d" % aggregate_record["status"] + if self._print_raw_data: + logoutput(csv) + + aggregate_record = {"boards": set()} + for r in range(0, len(self._pwr)): + pending_records.pop(0) + + except KeyboardInterrupt: + self._logger.info("\nCTRL+C caught.") + + finally: + for key in self._pwr: + self._pwr[key].stop() + self._data.CalculateStats() + if self._print_stats: + print(self._data.SummaryToString()) + save_dir = "sweetberry%s" % time.time() + if self._stats_dir: + stats_dir = os.path.join(self._stats_dir, save_dir) + self._data.SaveSummary(stats_dir) + if self._stats_json_dir: + stats_json_dir = os.path.join(self._stats_json_dir, save_dir) + self._data.SaveSummaryJSON(stats_json_dir) + if self._raw_data_dir: + raw_data_dir = os.path.join(self._raw_data_dir, save_dir) + self._data.SaveRawData(raw_data_dir) def main(argv=None): - if argv is None: - argv = sys.argv[1:] - # Command line argument description. - parser = argparse.ArgumentParser( - description="Gather CSV data from sweetberry") - parser.add_argument('-b', '--board', type=str, - help="Board configuration file, eg. my.board", default="") - parser.add_argument('-c', '--config', type=str, - help="Rail config to monitor, eg my.scenario", default="") - parser.add_argument('-A', '--serial', type=str, - help="Serial number of sweetberry A", default="") - parser.add_argument('-B', '--serial_b', type=str, - help="Serial number of sweetberry B", default="") - parser.add_argument('-t', '--integration_us', type=int, - help="Target integration time for samples", default=100000) - parser.add_argument('-s', '--seconds', type=float, - help="Seconds to run capture", default=0.) - parser.add_argument('--date', default=False, - help="Sync logged timestamp to host date", action="store_true") - parser.add_argument('--ms', default=False, - help="Print timestamp as milliseconds", action="store_true") - parser.add_argument('--mW', default=False, - help="Print power as milliwatts, otherwise default to microwatts", - action="store_true") - parser.add_argument('--slow', default=False, - help="Intentionally overflow", action="store_true") - parser.add_argument('--print_stats', default=False, action="store_true", - help="Print statistics for sweetberry readings at the end") - parser.add_argument('--save_stats', type=str, nargs='?', - dest='stats_dir', metavar='STATS_DIR', - const=os.path.dirname(os.path.abspath(__file__)), default=None, - help="Save statistics for sweetberry readings to %(metavar)s if " - "%(metavar)s is specified, %(metavar)s will be created if it does " - "not exist; if %(metavar)s is not specified but the flag is set, " - "stats will be saved to where %(prog)s is located; if this flag is " - "not set, then do not save stats") - parser.add_argument('--save_stats_json', type=str, nargs='?', - dest='stats_json_dir', metavar='STATS_JSON_DIR', - const=os.path.dirname(os.path.abspath(__file__)), default=None, - help="Save means for sweetberry readings in json to %(metavar)s if " - "%(metavar)s is specified, %(metavar)s will be created if it does " - "not exist; if %(metavar)s is not specified but the flag is set, " - "stats will be saved to where %(prog)s is located; if this flag is " - "not set, then do not save stats") - parser.add_argument('--no_print_raw_data', - dest='print_raw_data', default=True, action="store_false", - help="Not print raw sweetberry readings at real time, default is to " - "print") - parser.add_argument('--save_raw_data', type=str, nargs='?', - dest='raw_data_dir', metavar='RAW_DATA_DIR', - const=os.path.dirname(os.path.abspath(__file__)), default=None, - help="Save raw data for sweetberry readings to %(metavar)s if " - "%(metavar)s is specified, %(metavar)s will be created if it does " - "not exist; if %(metavar)s is not specified but the flag is set, " - "raw data will be saved to where %(prog)s is located; if this flag " - "is not set, then do not save raw data") - parser.add_argument('-v', '--verbose', default=False, - help="Very chatty printout", action="store_true") - - args = parser.parse_args(argv) - - root_logger = logging.getLogger(__name__) - if args.verbose: - root_logger.setLevel(logging.DEBUG) - else: - root_logger.setLevel(logging.INFO) - - # if powerlog is used through main, log to sys.stdout - if __name__ == "__main__": - stdout_handler = logging.StreamHandler(sys.stdout) - stdout_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s')) - root_logger.addHandler(stdout_handler) - - integration_us_request = args.integration_us - if not args.board: - raise Exception("Power", "No board file selected, see board.README") - if not args.config: - raise Exception("Power", "No config file selected, see board.README") - - brdfile = args.board - cfgfile = args.config - seconds = args.seconds - serial_a = args.serial - serial_b = args.serial_b - sync_date = args.date - use_ms = args.ms - use_mW = args.mW - print_stats = args.print_stats - stats_dir = args.stats_dir - stats_json_dir = args.stats_json_dir - print_raw_data = args.print_raw_data - raw_data_dir = args.raw_data_dir - - boards = [] - - sync_speed = .8 - if args.slow: - sync_speed = 1.2 - - # Set up logging interface. - powerlogger = powerlog(brdfile, cfgfile, serial_a=serial_a, serial_b=serial_b, - sync_date=sync_date, use_ms=use_ms, use_mW=use_mW, - print_stats=print_stats, stats_dir=stats_dir, - stats_json_dir=stats_json_dir, - print_raw_data=print_raw_data,raw_data_dir=raw_data_dir) - - # Start logging. - powerlogger.start(integration_us_request, seconds, sync_speed=sync_speed) + if argv is None: + argv = sys.argv[1:] + # Command line argument description. + parser = argparse.ArgumentParser(description="Gather CSV data from sweetberry") + parser.add_argument( + "-b", + "--board", + type=str, + help="Board configuration file, eg. my.board", + default="", + ) + parser.add_argument( + "-c", + "--config", + type=str, + help="Rail config to monitor, eg my.scenario", + default="", + ) + parser.add_argument( + "-A", "--serial", type=str, help="Serial number of sweetberry A", default="" + ) + parser.add_argument( + "-B", "--serial_b", type=str, help="Serial number of sweetberry B", default="" + ) + parser.add_argument( + "-t", + "--integration_us", + type=int, + help="Target integration time for samples", + default=100000, + ) + parser.add_argument( + "-s", "--seconds", type=float, help="Seconds to run capture", default=0.0 + ) + parser.add_argument( + "--date", + default=False, + help="Sync logged timestamp to host date", + action="store_true", + ) + parser.add_argument( + "--ms", + default=False, + help="Print timestamp as milliseconds", + action="store_true", + ) + parser.add_argument( + "--mW", + default=False, + help="Print power as milliwatts, otherwise default to microwatts", + action="store_true", + ) + parser.add_argument( + "--slow", default=False, help="Intentionally overflow", action="store_true" + ) + parser.add_argument( + "--print_stats", + default=False, + action="store_true", + help="Print statistics for sweetberry readings at the end", + ) + parser.add_argument( + "--save_stats", + type=str, + nargs="?", + dest="stats_dir", + metavar="STATS_DIR", + const=os.path.dirname(os.path.abspath(__file__)), + default=None, + help="Save statistics for sweetberry readings to %(metavar)s if " + "%(metavar)s is specified, %(metavar)s will be created if it does " + "not exist; if %(metavar)s is not specified but the flag is set, " + "stats will be saved to where %(prog)s is located; if this flag is " + "not set, then do not save stats", + ) + parser.add_argument( + "--save_stats_json", + type=str, + nargs="?", + dest="stats_json_dir", + metavar="STATS_JSON_DIR", + const=os.path.dirname(os.path.abspath(__file__)), + default=None, + help="Save means for sweetberry readings in json to %(metavar)s if " + "%(metavar)s is specified, %(metavar)s will be created if it does " + "not exist; if %(metavar)s is not specified but the flag is set, " + "stats will be saved to where %(prog)s is located; if this flag is " + "not set, then do not save stats", + ) + parser.add_argument( + "--no_print_raw_data", + dest="print_raw_data", + default=True, + action="store_false", + help="Not print raw sweetberry readings at real time, default is to " "print", + ) + parser.add_argument( + "--save_raw_data", + type=str, + nargs="?", + dest="raw_data_dir", + metavar="RAW_DATA_DIR", + const=os.path.dirname(os.path.abspath(__file__)), + default=None, + help="Save raw data for sweetberry readings to %(metavar)s if " + "%(metavar)s is specified, %(metavar)s will be created if it does " + "not exist; if %(metavar)s is not specified but the flag is set, " + "raw data will be saved to where %(prog)s is located; if this flag " + "is not set, then do not save raw data", + ) + parser.add_argument( + "-v", + "--verbose", + default=False, + help="Very chatty printout", + action="store_true", + ) + + args = parser.parse_args(argv) + + root_logger = logging.getLogger(__name__) + if args.verbose: + root_logger.setLevel(logging.DEBUG) + else: + root_logger.setLevel(logging.INFO) + + # if powerlog is used through main, log to sys.stdout + if __name__ == "__main__": + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) + root_logger.addHandler(stdout_handler) + + integration_us_request = args.integration_us + if not args.board: + raise Exception("Power", "No board file selected, see board.README") + if not args.config: + raise Exception("Power", "No config file selected, see board.README") + + brdfile = args.board + cfgfile = args.config + seconds = args.seconds + serial_a = args.serial + serial_b = args.serial_b + sync_date = args.date + use_ms = args.ms + use_mW = args.mW + print_stats = args.print_stats + stats_dir = args.stats_dir + stats_json_dir = args.stats_json_dir + print_raw_data = args.print_raw_data + raw_data_dir = args.raw_data_dir + + boards = [] + + sync_speed = 0.8 + if args.slow: + sync_speed = 1.2 + + # Set up logging interface. + powerlogger = powerlog( + brdfile, + cfgfile, + serial_a=serial_a, + serial_b=serial_b, + sync_date=sync_date, + use_ms=use_ms, + use_mW=use_mW, + print_stats=print_stats, + stats_dir=stats_dir, + stats_json_dir=stats_json_dir, + print_raw_data=print_raw_data, + raw_data_dir=raw_data_dir, + ) + + # Start logging. + powerlogger.start(integration_us_request, seconds, sync_speed=sync_speed) if __name__ == "__main__": - main() + main() diff --git a/extra/usb_power/powerlog_unittest.py b/extra/usb_power/powerlog_unittest.py index 1d0718530e..693826c16d 100644 --- a/extra/usb_power/powerlog_unittest.py +++ b/extra/usb_power/powerlog_unittest.py @@ -15,40 +15,42 @@ import unittest import powerlog + class TestPowerlog(unittest.TestCase): - """Test to verify powerlog util methods work as expected.""" - - def setUp(self): - """Set up data and create a temporary directory to save data and stats.""" - self.tempdir = tempfile.mkdtemp() - self.filename = 'testfile' - self.filepath = os.path.join(self.tempdir, self.filename) - with open(self.filepath, 'w') as f: - f.write('') - - def tearDown(self): - """Delete the temporary directory and its content.""" - shutil.rmtree(self.tempdir) - - def test_ProcessFilenameAbsoluteFilePath(self): - """Absolute file path is returned unchanged.""" - processed_fname = powerlog.process_filename(self.filepath) - self.assertEqual(self.filepath, processed_fname) - - def test_ProcessFilenameRelativeFilePath(self): - """Finds relative file path inside a known config location.""" - original = powerlog.CONFIG_LOCATIONS - powerlog.CONFIG_LOCATIONS = [self.tempdir] - processed_fname = powerlog.process_filename(self.filename) - try: - self.assertEqual(self.filepath, processed_fname) - finally: - powerlog.CONFIG_LOCATIONS = original - - def test_ProcessFilenameInvalid(self): - """IOError is raised when file cannot be found by any of the four ways.""" - with self.assertRaises(IOError): - powerlog.process_filename(self.filename) - -if __name__ == '__main__': - unittest.main() + """Test to verify powerlog util methods work as expected.""" + + def setUp(self): + """Set up data and create a temporary directory to save data and stats.""" + self.tempdir = tempfile.mkdtemp() + self.filename = "testfile" + self.filepath = os.path.join(self.tempdir, self.filename) + with open(self.filepath, "w") as f: + f.write("") + + def tearDown(self): + """Delete the temporary directory and its content.""" + shutil.rmtree(self.tempdir) + + def test_ProcessFilenameAbsoluteFilePath(self): + """Absolute file path is returned unchanged.""" + processed_fname = powerlog.process_filename(self.filepath) + self.assertEqual(self.filepath, processed_fname) + + def test_ProcessFilenameRelativeFilePath(self): + """Finds relative file path inside a known config location.""" + original = powerlog.CONFIG_LOCATIONS + powerlog.CONFIG_LOCATIONS = [self.tempdir] + processed_fname = powerlog.process_filename(self.filename) + try: + self.assertEqual(self.filepath, processed_fname) + finally: + powerlog.CONFIG_LOCATIONS = original + + def test_ProcessFilenameInvalid(self): + """IOError is raised when file cannot be found by any of the four ways.""" + with self.assertRaises(IOError): + powerlog.process_filename(self.filename) + + +if __name__ == "__main__": + unittest.main() diff --git a/extra/usb_power/stats_manager.py b/extra/usb_power/stats_manager.py index 0f8c3fcb15..633312311d 100644 --- a/extra/usb_power/stats_manager.py +++ b/extra/usb_power/stats_manager.py @@ -20,382 +20,391 @@ import os import numpy -STATS_PREFIX = '@@' -NAN_TAG = '*' -NAN_DESCRIPTION = '%s domains contain NaN samples' % NAN_TAG +STATS_PREFIX = "@@" +NAN_TAG = "*" +NAN_DESCRIPTION = "%s domains contain NaN samples" % NAN_TAG LONG_UNIT = { - '': 'N/A', - 'mW': 'milliwatt', - 'uW': 'microwatt', - 'mV': 'millivolt', - 'uA': 'microamp', - 'uV': 'microvolt' + "": "N/A", + "mW": "milliwatt", + "uW": "microwatt", + "mV": "millivolt", + "uA": "microamp", + "uV": "microvolt", } class StatsManagerError(Exception): - """Errors in StatsManager class.""" - pass + """Errors in StatsManager class.""" + pass -class StatsManager(object): - """Calculates statistics for several lists of data(float). - - Example usage: - - >>> stats = StatsManager(title='Title Banner') - >>> stats.AddSample(TIME_KEY, 50.0) - >>> stats.AddSample(TIME_KEY, 25.0) - >>> stats.AddSample(TIME_KEY, 40.0) - >>> stats.AddSample(TIME_KEY, 10.0) - >>> stats.AddSample(TIME_KEY, 10.0) - >>> stats.AddSample('frobnicate', 11.5) - >>> stats.AddSample('frobnicate', 9.0) - >>> stats.AddSample('foobar', 11111.0) - >>> stats.AddSample('foobar', 22222.0) - >>> stats.CalculateStats() - >>> print(stats.SummaryToString()) - ` @@-------------------------------------------------------------- - ` @@ Title Banner - @@-------------------------------------------------------------- - @@ NAME COUNT MEAN STDDEV MAX MIN - @@ sample_msecs 4 31.25 15.16 50.00 10.00 - @@ foobar 2 16666.50 5555.50 22222.00 11111.00 - @@ frobnicate 2 10.25 1.25 11.50 9.00 - ` @@-------------------------------------------------------------- - - Attributes: - _data: dict of list of readings for each domain(key) - _unit: dict of unit for each domain(key) - _smid: id supplied to differentiate data output to other StatsManager - instances that potentially save to the same directory - if smid all output files will be named |smid|_|fname| - _title: title to add as banner to formatted summary. If no title, - no banner gets added - _order: list of formatting order for domains. Domains not listed are - displayed in sorted order - _hide_domains: collection of domains to hide when formatting summary string - _accept_nan: flag to indicate if NaN samples are acceptable - _nan_domains: set to keep track of which domains contain NaN samples - _summary: dict of stats per domain (key): min, max, count, mean, stddev - _logger = StatsManager logger - - Note: - _summary is empty until CalculateStats() is called, and is updated when - CalculateStats() is called. - """ - - # pylint: disable=W0102 - def __init__(self, smid='', title='', order=[], hide_domains=[], - accept_nan=True): - """Initialize infrastructure for data and their statistics.""" - self._title = title - self._data = collections.defaultdict(list) - self._unit = collections.defaultdict(str) - self._smid = smid - self._order = order - self._hide_domains = hide_domains - self._accept_nan = accept_nan - self._nan_domains = set() - self._summary = {} - self._logger = logging.getLogger(type(self).__name__) - - def AddSample(self, domain, sample): - """Add one sample for a domain. - - Args: - domain: the domain name for the sample. - sample: one time sample for domain, expect type float. - - Raises: - StatsManagerError: if trying to add NaN and |_accept_nan| is false - """ - try: - sample = float(sample) - except ValueError: - # if we don't accept nan this will be caught below - self._logger.debug('sample %s for domain %s is not a number. Making NaN', - sample, domain) - sample = float('NaN') - if not self._accept_nan and math.isnan(sample): - raise StatsManagerError('accept_nan is false. Cannot add NaN sample.') - self._data[domain].append(sample) - if math.isnan(sample): - self._nan_domains.add(domain) - - def SetUnit(self, domain, unit): - """Set the unit for a domain. - - There can be only one unit for each domain. Setting unit twice will - overwrite the original unit. - - Args: - domain: the domain name. - unit: unit of the domain. - """ - if domain in self._unit: - self._logger.warning('overwriting the unit of %s, old unit is %s, new ' - 'unit is %s.', domain, self._unit[domain], unit) - self._unit[domain] = unit - def CalculateStats(self): - """Calculate stats for all domain-data pairs. - - First erases all previous stats, then calculate stats for all data. - """ - self._summary = {} - for domain, data in self._data.items(): - data_np = numpy.array(data) - self._summary[domain] = { - 'mean': numpy.nanmean(data_np), - 'min': numpy.nanmin(data_np), - 'max': numpy.nanmax(data_np), - 'stddev': numpy.nanstd(data_np), - 'count': data_np.size, - } - - @property - def DomainsToDisplay(self): - """List of domains that the manager will output in summaries.""" - return set(self._summary.keys()) - set(self._hide_domains) - - @property - def NanInOutput(self): - """Return whether any of the domains to display have NaN values.""" - return bool(len(set(self._nan_domains) & self.DomainsToDisplay)) - - def _SummaryTable(self): - """Generate the matrix to output as a summary. - - Returns: - A 2d matrix of headers and their data for each domain - e.g. - [[NAME, COUNT, MEAN, STDDEV, MAX, MIN], - [pp5000_mw, 10, 50, 0, 50, 50]] - """ - headers = ('NAME', 'COUNT', 'MEAN', 'STDDEV', 'MAX', 'MIN') - table = [headers] - # determine what domains to display & and the order - domains_to_display = self.DomainsToDisplay - display_order = [key for key in self._order if key in domains_to_display] - domains_to_display -= set(display_order) - display_order.extend(sorted(domains_to_display)) - for domain in display_order: - stats = self._summary[domain] - if not domain.endswith(self._unit[domain]): - domain = '%s_%s' % (domain, self._unit[domain]) - if domain in self._nan_domains: - domain = '%s%s' % (domain, NAN_TAG) - row = [domain] - row.append(str(stats['count'])) - for entry in headers[2:]: - row.append('%.2f' % stats[entry.lower()]) - table.append(row) - return table - - def SummaryToMarkdownString(self): - """Format the summary into a b/ compatible markdown table string. - - This requires this sort of output format - - | header1 | header2 | header3 | ... - | --------- | --------- | --------- | ... - | sample1h1 | sample1h2 | sample1h3 | ... - . - . - . - - Returns: - formatted summary string. - """ - # All we need to do before processing is insert a row of '-' between - # the headers, and the data - table = self._SummaryTable() - columns = len(table[0]) - # Using '-:' to allow the numbers to be right aligned - sep_row = ['-'] + ['-:'] * (columns - 1) - table.insert(1, sep_row) - text_rows = ['|'.join(r) for r in table] - body = '\n'.join(['|%s|' % r for r in text_rows]) - if self._title: - title_section = '**%s** \n\n' % self._title - body = title_section + body - # Make sure that the body is terminated with a newline. - return body + '\n' - - def SummaryToString(self, prefix=STATS_PREFIX): - """Format summary into a string, ready for pretty print. - - See class description for format example. - - Args: - prefix: start every row in summary string with prefix, for easier reading. - - Returns: - formatted summary string. - """ - table = self._SummaryTable() - max_col_width = [] - for col_idx in range(len(table[0])): - col_item_widths = [len(row[col_idx]) for row in table] - max_col_width.append(max(col_item_widths)) - - formatted_lines = [] - for row in table: - formatted_row = prefix + ' ' - for i in range(len(row)): - formatted_row += row[i].rjust(max_col_width[i] + 2) - formatted_lines.append(formatted_row) - if self.NanInOutput: - formatted_lines.append('%s %s' % (prefix, NAN_DESCRIPTION)) - - if self._title: - line_length = len(formatted_lines[0]) - dec_length = len(prefix) - # trim title to be at most as long as the longest line without the prefix - title = self._title[:(line_length - dec_length)] - # line is a seperator line consisting of ----- - line = '%s%s' % (prefix, '-' * (line_length - dec_length)) - # prepend the prefix to the centered title - padded_title = '%s%s' % (prefix, title.center(line_length)[dec_length:]) - formatted_lines = [line, padded_title, line] + formatted_lines + [line] - formatted_output = '\n'.join(formatted_lines) - return formatted_output - - def GetSummary(self): - """Getter for summary.""" - return self._summary - - def _MakeUniqueFName(self, fname): - """prepend |_smid| to fname & rotate fname to ensure uniqueness. - - Before saving a file through the StatsManager, make sure that the filename - is unique, first by prepending the smid if any and otherwise by appending - increasing integer suffixes until the filename is unique. - - If |smid| is defined /path/to/example/file.txt becomes - /path/to/example/{smid}_file.txt. - - The rotation works by changing /path/to/example/somename.txt to - /path/to/example/somename1.txt if the first one already exists on the - system. - - Note: this is not thread-safe. While it makes sense to use StatsManager - in a threaded data-collection, the data retrieval should happen in a - single threaded environment to ensure files don't get potentially clobbered. - - Args: - fname: filename to ensure uniqueness. - - Returns: - {smid_}fname{tag}.[b].ext - the smid portion gets prepended if |smid| is defined - the tag portion gets appended if necessary to ensure unique fname - """ - fdir = os.path.dirname(fname) - base, ext = os.path.splitext(os.path.basename(fname)) - if self._smid: - base = '%s_%s' % (self._smid, base) - unique_fname = os.path.join(fdir, '%s%s' % (base, ext)) - tag = 0 - while os.path.exists(unique_fname): - old_fname = unique_fname - unique_fname = os.path.join(fdir, '%s%d%s' % (base, tag, ext)) - self._logger.warning('Attempted to store stats information at %s, but ' - 'file already exists. Attempting to store at %s ' - 'now.', old_fname, unique_fname) - tag += 1 - return unique_fname - - def SaveSummary(self, directory, fname='summary.txt', prefix=STATS_PREFIX): - """Save summary to file. - - Args: - directory: directory to save the summary in. - fname: filename to save summary under. - prefix: start every row in summary string with prefix, for easier reading. - - Returns: - full path of summary save location - """ - summary_str = self.SummaryToString(prefix=prefix) + '\n' - return self._SaveSummary(summary_str, directory, fname) - - def SaveSummaryJSON(self, directory, fname='summary.json'): - """Save summary (only MEAN) into a JSON file. - - Args: - directory: directory to save the JSON summary in. - fname: filename to save summary under. - - Returns: - full path of summary save location - """ - data = {} - for domain in self._summary: - unit = LONG_UNIT.get(self._unit[domain], self._unit[domain]) - data_entry = {'mean': self._summary[domain]['mean'], 'unit': unit} - data[domain] = data_entry - summary_str = json.dumps(data, indent=2) - return self._SaveSummary(summary_str, directory, fname) - - def SaveSummaryMD(self, directory, fname='summary.md'): - """Save summary into a MD file to paste into b/. - - Args: - directory: directory to save the MD summary in. - fname: filename to save summary under. - - Returns: - full path of summary save location +class StatsManager(object): + """Calculates statistics for several lists of data(float). + + Example usage: + + >>> stats = StatsManager(title='Title Banner') + >>> stats.AddSample(TIME_KEY, 50.0) + >>> stats.AddSample(TIME_KEY, 25.0) + >>> stats.AddSample(TIME_KEY, 40.0) + >>> stats.AddSample(TIME_KEY, 10.0) + >>> stats.AddSample(TIME_KEY, 10.0) + >>> stats.AddSample('frobnicate', 11.5) + >>> stats.AddSample('frobnicate', 9.0) + >>> stats.AddSample('foobar', 11111.0) + >>> stats.AddSample('foobar', 22222.0) + >>> stats.CalculateStats() + >>> print(stats.SummaryToString()) + ` @@-------------------------------------------------------------- + ` @@ Title Banner + @@-------------------------------------------------------------- + @@ NAME COUNT MEAN STDDEV MAX MIN + @@ sample_msecs 4 31.25 15.16 50.00 10.00 + @@ foobar 2 16666.50 5555.50 22222.00 11111.00 + @@ frobnicate 2 10.25 1.25 11.50 9.00 + ` @@-------------------------------------------------------------- + + Attributes: + _data: dict of list of readings for each domain(key) + _unit: dict of unit for each domain(key) + _smid: id supplied to differentiate data output to other StatsManager + instances that potentially save to the same directory + if smid all output files will be named |smid|_|fname| + _title: title to add as banner to formatted summary. If no title, + no banner gets added + _order: list of formatting order for domains. Domains not listed are + displayed in sorted order + _hide_domains: collection of domains to hide when formatting summary string + _accept_nan: flag to indicate if NaN samples are acceptable + _nan_domains: set to keep track of which domains contain NaN samples + _summary: dict of stats per domain (key): min, max, count, mean, stddev + _logger = StatsManager logger + + Note: + _summary is empty until CalculateStats() is called, and is updated when + CalculateStats() is called. """ - summary_str = self.SummaryToMarkdownString() - return self._SaveSummary(summary_str, directory, fname) - def _SaveSummary(self, output_str, directory, fname): - """Wrote |output_str| to |fname|. - - Args: - output_str: formatted output string - directory: directory to save the summary in. - fname: filename to save summary under. - - Returns: - full path of summary save location - """ - if not os.path.exists(directory): - os.makedirs(directory) - fname = self._MakeUniqueFName(os.path.join(directory, fname)) - with open(fname, 'w') as f: - f.write(output_str) - return fname - - def GetRawData(self): - """Getter for all raw_data.""" - return self._data - - def SaveRawData(self, directory, dirname='raw_data'): - """Save raw data to file. - - Args: - directory: directory to create the raw data folder in. - dirname: folder in which raw data live. - - Returns: - list of full path of each domain's raw data save location - """ - if not os.path.exists(directory): - os.makedirs(directory) - dirname = os.path.join(directory, dirname) - if not os.path.exists(dirname): - os.makedirs(dirname) - fnames = [] - for domain, data in self._data.items(): - if not domain.endswith(self._unit[domain]): - domain = '%s_%s' % (domain, self._unit[domain]) - fname = self._MakeUniqueFName(os.path.join(dirname, '%s.txt' % domain)) - with open(fname, 'w') as f: - f.write('\n'.join('%.2f' % sample for sample in data) + '\n') - fnames.append(fname) - return fnames + # pylint: disable=W0102 + def __init__(self, smid="", title="", order=[], hide_domains=[], accept_nan=True): + """Initialize infrastructure for data and their statistics.""" + self._title = title + self._data = collections.defaultdict(list) + self._unit = collections.defaultdict(str) + self._smid = smid + self._order = order + self._hide_domains = hide_domains + self._accept_nan = accept_nan + self._nan_domains = set() + self._summary = {} + self._logger = logging.getLogger(type(self).__name__) + + def AddSample(self, domain, sample): + """Add one sample for a domain. + + Args: + domain: the domain name for the sample. + sample: one time sample for domain, expect type float. + + Raises: + StatsManagerError: if trying to add NaN and |_accept_nan| is false + """ + try: + sample = float(sample) + except ValueError: + # if we don't accept nan this will be caught below + self._logger.debug( + "sample %s for domain %s is not a number. Making NaN", sample, domain + ) + sample = float("NaN") + if not self._accept_nan and math.isnan(sample): + raise StatsManagerError("accept_nan is false. Cannot add NaN sample.") + self._data[domain].append(sample) + if math.isnan(sample): + self._nan_domains.add(domain) + + def SetUnit(self, domain, unit): + """Set the unit for a domain. + + There can be only one unit for each domain. Setting unit twice will + overwrite the original unit. + + Args: + domain: the domain name. + unit: unit of the domain. + """ + if domain in self._unit: + self._logger.warning( + "overwriting the unit of %s, old unit is %s, new " "unit is %s.", + domain, + self._unit[domain], + unit, + ) + self._unit[domain] = unit + + def CalculateStats(self): + """Calculate stats for all domain-data pairs. + + First erases all previous stats, then calculate stats for all data. + """ + self._summary = {} + for domain, data in self._data.items(): + data_np = numpy.array(data) + self._summary[domain] = { + "mean": numpy.nanmean(data_np), + "min": numpy.nanmin(data_np), + "max": numpy.nanmax(data_np), + "stddev": numpy.nanstd(data_np), + "count": data_np.size, + } + + @property + def DomainsToDisplay(self): + """List of domains that the manager will output in summaries.""" + return set(self._summary.keys()) - set(self._hide_domains) + + @property + def NanInOutput(self): + """Return whether any of the domains to display have NaN values.""" + return bool(len(set(self._nan_domains) & self.DomainsToDisplay)) + + def _SummaryTable(self): + """Generate the matrix to output as a summary. + + Returns: + A 2d matrix of headers and their data for each domain + e.g. + [[NAME, COUNT, MEAN, STDDEV, MAX, MIN], + [pp5000_mw, 10, 50, 0, 50, 50]] + """ + headers = ("NAME", "COUNT", "MEAN", "STDDEV", "MAX", "MIN") + table = [headers] + # determine what domains to display & and the order + domains_to_display = self.DomainsToDisplay + display_order = [key for key in self._order if key in domains_to_display] + domains_to_display -= set(display_order) + display_order.extend(sorted(domains_to_display)) + for domain in display_order: + stats = self._summary[domain] + if not domain.endswith(self._unit[domain]): + domain = "%s_%s" % (domain, self._unit[domain]) + if domain in self._nan_domains: + domain = "%s%s" % (domain, NAN_TAG) + row = [domain] + row.append(str(stats["count"])) + for entry in headers[2:]: + row.append("%.2f" % stats[entry.lower()]) + table.append(row) + return table + + def SummaryToMarkdownString(self): + """Format the summary into a b/ compatible markdown table string. + + This requires this sort of output format + + | header1 | header2 | header3 | ... + | --------- | --------- | --------- | ... + | sample1h1 | sample1h2 | sample1h3 | ... + . + . + . + + Returns: + formatted summary string. + """ + # All we need to do before processing is insert a row of '-' between + # the headers, and the data + table = self._SummaryTable() + columns = len(table[0]) + # Using '-:' to allow the numbers to be right aligned + sep_row = ["-"] + ["-:"] * (columns - 1) + table.insert(1, sep_row) + text_rows = ["|".join(r) for r in table] + body = "\n".join(["|%s|" % r for r in text_rows]) + if self._title: + title_section = "**%s** \n\n" % self._title + body = title_section + body + # Make sure that the body is terminated with a newline. + return body + "\n" + + def SummaryToString(self, prefix=STATS_PREFIX): + """Format summary into a string, ready for pretty print. + + See class description for format example. + + Args: + prefix: start every row in summary string with prefix, for easier reading. + + Returns: + formatted summary string. + """ + table = self._SummaryTable() + max_col_width = [] + for col_idx in range(len(table[0])): + col_item_widths = [len(row[col_idx]) for row in table] + max_col_width.append(max(col_item_widths)) + + formatted_lines = [] + for row in table: + formatted_row = prefix + " " + for i in range(len(row)): + formatted_row += row[i].rjust(max_col_width[i] + 2) + formatted_lines.append(formatted_row) + if self.NanInOutput: + formatted_lines.append("%s %s" % (prefix, NAN_DESCRIPTION)) + + if self._title: + line_length = len(formatted_lines[0]) + dec_length = len(prefix) + # trim title to be at most as long as the longest line without the prefix + title = self._title[: (line_length - dec_length)] + # line is a seperator line consisting of ----- + line = "%s%s" % (prefix, "-" * (line_length - dec_length)) + # prepend the prefix to the centered title + padded_title = "%s%s" % (prefix, title.center(line_length)[dec_length:]) + formatted_lines = [line, padded_title, line] + formatted_lines + [line] + formatted_output = "\n".join(formatted_lines) + return formatted_output + + def GetSummary(self): + """Getter for summary.""" + return self._summary + + def _MakeUniqueFName(self, fname): + """prepend |_smid| to fname & rotate fname to ensure uniqueness. + + Before saving a file through the StatsManager, make sure that the filename + is unique, first by prepending the smid if any and otherwise by appending + increasing integer suffixes until the filename is unique. + + If |smid| is defined /path/to/example/file.txt becomes + /path/to/example/{smid}_file.txt. + + The rotation works by changing /path/to/example/somename.txt to + /path/to/example/somename1.txt if the first one already exists on the + system. + + Note: this is not thread-safe. While it makes sense to use StatsManager + in a threaded data-collection, the data retrieval should happen in a + single threaded environment to ensure files don't get potentially clobbered. + + Args: + fname: filename to ensure uniqueness. + + Returns: + {smid_}fname{tag}.[b].ext + the smid portion gets prepended if |smid| is defined + the tag portion gets appended if necessary to ensure unique fname + """ + fdir = os.path.dirname(fname) + base, ext = os.path.splitext(os.path.basename(fname)) + if self._smid: + base = "%s_%s" % (self._smid, base) + unique_fname = os.path.join(fdir, "%s%s" % (base, ext)) + tag = 0 + while os.path.exists(unique_fname): + old_fname = unique_fname + unique_fname = os.path.join(fdir, "%s%d%s" % (base, tag, ext)) + self._logger.warning( + "Attempted to store stats information at %s, but " + "file already exists. Attempting to store at %s " + "now.", + old_fname, + unique_fname, + ) + tag += 1 + return unique_fname + + def SaveSummary(self, directory, fname="summary.txt", prefix=STATS_PREFIX): + """Save summary to file. + + Args: + directory: directory to save the summary in. + fname: filename to save summary under. + prefix: start every row in summary string with prefix, for easier reading. + + Returns: + full path of summary save location + """ + summary_str = self.SummaryToString(prefix=prefix) + "\n" + return self._SaveSummary(summary_str, directory, fname) + + def SaveSummaryJSON(self, directory, fname="summary.json"): + """Save summary (only MEAN) into a JSON file. + + Args: + directory: directory to save the JSON summary in. + fname: filename to save summary under. + + Returns: + full path of summary save location + """ + data = {} + for domain in self._summary: + unit = LONG_UNIT.get(self._unit[domain], self._unit[domain]) + data_entry = {"mean": self._summary[domain]["mean"], "unit": unit} + data[domain] = data_entry + summary_str = json.dumps(data, indent=2) + return self._SaveSummary(summary_str, directory, fname) + + def SaveSummaryMD(self, directory, fname="summary.md"): + """Save summary into a MD file to paste into b/. + + Args: + directory: directory to save the MD summary in. + fname: filename to save summary under. + + Returns: + full path of summary save location + """ + summary_str = self.SummaryToMarkdownString() + return self._SaveSummary(summary_str, directory, fname) + + def _SaveSummary(self, output_str, directory, fname): + """Wrote |output_str| to |fname|. + + Args: + output_str: formatted output string + directory: directory to save the summary in. + fname: filename to save summary under. + + Returns: + full path of summary save location + """ + if not os.path.exists(directory): + os.makedirs(directory) + fname = self._MakeUniqueFName(os.path.join(directory, fname)) + with open(fname, "w") as f: + f.write(output_str) + return fname + + def GetRawData(self): + """Getter for all raw_data.""" + return self._data + + def SaveRawData(self, directory, dirname="raw_data"): + """Save raw data to file. + + Args: + directory: directory to create the raw data folder in. + dirname: folder in which raw data live. + + Returns: + list of full path of each domain's raw data save location + """ + if not os.path.exists(directory): + os.makedirs(directory) + dirname = os.path.join(directory, dirname) + if not os.path.exists(dirname): + os.makedirs(dirname) + fnames = [] + for domain, data in self._data.items(): + if not domain.endswith(self._unit[domain]): + domain = "%s_%s" % (domain, self._unit[domain]) + fname = self._MakeUniqueFName(os.path.join(dirname, "%s.txt" % domain)) + with open(fname, "w") as f: + f.write("\n".join("%.2f" % sample for sample in data) + "\n") + fnames.append(fname) + return fnames diff --git a/extra/usb_power/stats_manager_unittest.py b/extra/usb_power/stats_manager_unittest.py index beb9984b93..37a472af98 100644 --- a/extra/usb_power/stats_manager_unittest.py +++ b/extra/usb_power/stats_manager_unittest.py @@ -9,6 +9,7 @@ """Unit tests for StatsManager.""" from __future__ import print_function + import json import os import re @@ -20,296 +21,304 @@ import stats_manager class TestStatsManager(unittest.TestCase): - """Test to verify StatsManager methods work as expected. - - StatsManager should collect raw data, calculate their statistics, and save - them in expected format. - """ - - def _populate_mock_stats(self): - """Create a populated & processed StatsManager to test data retrieval.""" - self.data.AddSample('A', 99999.5) - self.data.AddSample('A', 100000.5) - self.data.SetUnit('A', 'uW') - self.data.SetUnit('A', 'mW') - self.data.AddSample('B', 1.5) - self.data.AddSample('B', 2.5) - self.data.AddSample('B', 3.5) - self.data.SetUnit('B', 'mV') - self.data.CalculateStats() - - def _populate_mock_stats_no_unit(self): - self.data.AddSample('B', 1000) - self.data.AddSample('A', 200) - self.data.SetUnit('A', 'blue') - - def setUp(self): - """Set up StatsManager and create a temporary directory for test.""" - self.tempdir = tempfile.mkdtemp() - self.data = stats_manager.StatsManager() - - def tearDown(self): - """Delete the temporary directory and its content.""" - shutil.rmtree(self.tempdir) - - def test_AddSample(self): - """Adding a sample successfully adds a sample.""" - self.data.AddSample('Test', 1000) - self.data.SetUnit('Test', 'test') - self.data.CalculateStats() - summary = self.data.GetSummary() - self.assertEqual(1, summary['Test']['count']) - - def test_AddSampleNoFloatAcceptNaN(self): - """Adding a non-number adds 'NaN' and doesn't raise an exception.""" - self.data.AddSample('Test', 10) - self.data.AddSample('Test', 20) - # adding a fake NaN: one that gets converted into NaN internally - self.data.AddSample('Test', 'fiesta') - # adding a real NaN - self.data.AddSample('Test', float('NaN')) - self.data.SetUnit('Test', 'test') - self.data.CalculateStats() - summary = self.data.GetSummary() - # assert that 'NaN' as added. - self.assertEqual(4, summary['Test']['count']) - # assert that mean, min, and max calculatings ignore the 'NaN' - self.assertEqual(10, summary['Test']['min']) - self.assertEqual(20, summary['Test']['max']) - self.assertEqual(15, summary['Test']['mean']) - - def test_AddSampleNoFloatNotAcceptNaN(self): - """Adding a non-number raises a StatsManagerError if accept_nan is False.""" - self.data = stats_manager.StatsManager(accept_nan=False) - with self.assertRaisesRegexp(stats_manager.StatsManagerError, - 'accept_nan is false. Cannot add NaN sample.'): - # adding a fake NaN: one that gets converted into NaN internally - self.data.AddSample('Test', 'fiesta') - with self.assertRaisesRegexp(stats_manager.StatsManagerError, - 'accept_nan is false. Cannot add NaN sample.'): - # adding a real NaN - self.data.AddSample('Test', float('NaN')) - - def test_AddSampleNoUnit(self): - """Not adding a unit does not cause an exception on CalculateStats().""" - self.data.AddSample('Test', 17) - self.data.CalculateStats() - summary = self.data.GetSummary() - self.assertEqual(1, summary['Test']['count']) - - def test_UnitSuffix(self): - """Unit gets appended as a suffix in the displayed summary.""" - self.data.AddSample('test', 250) - self.data.SetUnit('test', 'mw') - self.data.CalculateStats() - summary_str = self.data.SummaryToString() - self.assertIn('test_mw', summary_str) - - def test_DoubleUnitSuffix(self): - """If domain already ends in unit, verify that unit doesn't get appended.""" - self.data.AddSample('test_mw', 250) - self.data.SetUnit('test_mw', 'mw') - self.data.CalculateStats() - summary_str = self.data.SummaryToString() - self.assertIn('test_mw', summary_str) - self.assertNotIn('test_mw_mw', summary_str) - - def test_GetRawData(self): - """GetRawData returns exact same data as fed in.""" - self._populate_mock_stats() - raw_data = self.data.GetRawData() - self.assertListEqual([99999.5, 100000.5], raw_data['A']) - self.assertListEqual([1.5, 2.5, 3.5], raw_data['B']) - - def test_GetSummary(self): - """GetSummary returns expected stats about the data fed in.""" - self._populate_mock_stats() - summary = self.data.GetSummary() - self.assertEqual(2, summary['A']['count']) - self.assertAlmostEqual(100000.5, summary['A']['max']) - self.assertAlmostEqual(99999.5, summary['A']['min']) - self.assertAlmostEqual(0.5, summary['A']['stddev']) - self.assertAlmostEqual(100000.0, summary['A']['mean']) - self.assertEqual(3, summary['B']['count']) - self.assertAlmostEqual(3.5, summary['B']['max']) - self.assertAlmostEqual(1.5, summary['B']['min']) - self.assertAlmostEqual(0.81649658092773, summary['B']['stddev']) - self.assertAlmostEqual(2.5, summary['B']['mean']) - - def test_SaveRawData(self): - """SaveRawData stores same data as fed in.""" - self._populate_mock_stats() - dirname = 'unittest_raw_data' - expected_files = set(['A_mW.txt', 'B_mV.txt']) - fnames = self.data.SaveRawData(self.tempdir, dirname) - files_returned = set([os.path.basename(f) for f in fnames]) - # Assert that only the expected files got returned. - self.assertEqual(expected_files, files_returned) - # Assert that only the returned files are in the outdir. - self.assertEqual(set(os.listdir(os.path.join(self.tempdir, dirname))), - files_returned) - for fname in fnames: - with open(fname, 'r') as f: - if 'A_mW' in fname: - self.assertEqual('99999.50', f.readline().strip()) - self.assertEqual('100000.50', f.readline().strip()) - if 'B_mV' in fname: - self.assertEqual('1.50', f.readline().strip()) - self.assertEqual('2.50', f.readline().strip()) - self.assertEqual('3.50', f.readline().strip()) - - def test_SaveRawDataNoUnit(self): - """SaveRawData appends no unit suffix if the unit is not specified.""" - self._populate_mock_stats_no_unit() - self.data.CalculateStats() - outdir = 'unittest_raw_data' - files = self.data.SaveRawData(self.tempdir, outdir) - files = [os.path.basename(f) for f in files] - # Verify nothing gets appended to domain for filename if no unit exists. - self.assertIn('B.txt', files) - - def test_SaveRawDataSMID(self): - """SaveRawData uses the smid when creating output filename.""" - identifier = 'ec' - self.data = stats_manager.StatsManager(smid=identifier) - self._populate_mock_stats() - files = self.data.SaveRawData(self.tempdir) - for fname in files: - self.assertTrue(os.path.basename(fname).startswith(identifier)) - - def test_SummaryToStringNaNHelp(self): - """NaN containing row gets tagged with *, help banner gets added.""" - help_banner_exp = '%s %s' % (stats_manager.STATS_PREFIX, - stats_manager.NAN_DESCRIPTION) - nan_domain = 'A-domain' - nan_domain_exp = '%s%s' % (nan_domain, stats_manager.NAN_TAG) - # NaN helper banner is added when a NaN domain is found & domain gets tagged - data = stats_manager.StatsManager() - data.AddSample(nan_domain, float('NaN')) - data.AddSample(nan_domain, 17) - data.AddSample('B-domain', 17) - data.CalculateStats() - summarystr = data.SummaryToString() - self.assertIn(help_banner_exp, summarystr) - self.assertIn(nan_domain_exp, summarystr) - # NaN helper banner is not added when no NaN domain output, no tagging - data = stats_manager.StatsManager() - # nan_domain in this scenario does not contain any NaN - data.AddSample(nan_domain, 19) - data.AddSample('B-domain', 17) - data.CalculateStats() - summarystr = data.SummaryToString() - self.assertNotIn(help_banner_exp, summarystr) - self.assertNotIn(nan_domain_exp, summarystr) - - def test_SummaryToStringTitle(self): - """Title shows up in SummaryToString if title specified.""" - title = 'titulo' - data = stats_manager.StatsManager(title=title) - self._populate_mock_stats() - summary_str = data.SummaryToString() - self.assertIn(title, summary_str) - - def test_SummaryToStringHideDomains(self): - """Keys indicated in hide_domains are not printed in the summary.""" - data = stats_manager.StatsManager(hide_domains=['A-domain']) - data.AddSample('A-domain', 17) - data.AddSample('B-domain', 17) - data.CalculateStats() - summary_str = data.SummaryToString() - self.assertIn('B-domain', summary_str) - self.assertNotIn('A-domain', summary_str) - - def test_SummaryToStringOrder(self): - """Order passed into StatsManager is honoured when formatting summary.""" - # StatsManager that should print D & B first, and the subsequent elements - # are sorted. - d_b_a_c_regexp = re.compile('D-domain.*B-domain.*A-domain.*C-domain', - re.DOTALL) - data = stats_manager.StatsManager(order=['D-domain', 'B-domain']) - data.AddSample('A-domain', 17) - data.AddSample('B-domain', 17) - data.AddSample('C-domain', 17) - data.AddSample('D-domain', 17) - data.CalculateStats() - summary_str = data.SummaryToString() - self.assertRegexpMatches(summary_str, d_b_a_c_regexp) - - def test_MakeUniqueFName(self): - data = stats_manager.StatsManager() - testfile = os.path.join(self.tempdir, 'testfile.txt') - with open(testfile, 'w') as f: - f.write('') - expected_fname = os.path.join(self.tempdir, 'testfile0.txt') - self.assertEqual(expected_fname, data._MakeUniqueFName(testfile)) - - def test_SaveSummary(self): - """SaveSummary properly dumps the summary into a file.""" - self._populate_mock_stats() - fname = 'unittest_summary.txt' - expected_fname = os.path.join(self.tempdir, fname) - fname = self.data.SaveSummary(self.tempdir, fname) - # Assert the reported fname is the same as the expected fname - self.assertEqual(expected_fname, fname) - # Assert only the reported fname is output (in the tempdir) - self.assertEqual(set([os.path.basename(fname)]), - set(os.listdir(self.tempdir))) - with open(fname, 'r') as f: - self.assertEqual( - '@@ NAME COUNT MEAN STDDEV MAX MIN\n', - f.readline()) - self.assertEqual( - '@@ A_mW 2 100000.00 0.50 100000.50 99999.50\n', - f.readline()) - self.assertEqual( - '@@ B_mV 3 2.50 0.82 3.50 1.50\n', - f.readline()) - - def test_SaveSummarySMID(self): - """SaveSummary uses the smid when creating output filename.""" - identifier = 'ec' - self.data = stats_manager.StatsManager(smid=identifier) - self._populate_mock_stats() - fname = os.path.basename(self.data.SaveSummary(self.tempdir)) - self.assertTrue(fname.startswith(identifier)) - - def test_SaveSummaryJSON(self): - """SaveSummaryJSON saves the added data properly in JSON format.""" - self._populate_mock_stats() - fname = 'unittest_summary.json' - expected_fname = os.path.join(self.tempdir, fname) - fname = self.data.SaveSummaryJSON(self.tempdir, fname) - # Assert the reported fname is the same as the expected fname - self.assertEqual(expected_fname, fname) - # Assert only the reported fname is output (in the tempdir) - self.assertEqual(set([os.path.basename(fname)]), - set(os.listdir(self.tempdir))) - with open(fname, 'r') as f: - summary = json.load(f) - self.assertAlmostEqual(100000.0, summary['A']['mean']) - self.assertEqual('milliwatt', summary['A']['unit']) - self.assertAlmostEqual(2.5, summary['B']['mean']) - self.assertEqual('millivolt', summary['B']['unit']) - - def test_SaveSummaryJSONSMID(self): - """SaveSummaryJSON uses the smid when creating output filename.""" - identifier = 'ec' - self.data = stats_manager.StatsManager(smid=identifier) - self._populate_mock_stats() - fname = os.path.basename(self.data.SaveSummaryJSON(self.tempdir)) - self.assertTrue(fname.startswith(identifier)) - - def test_SaveSummaryJSONNoUnit(self): - """SaveSummaryJSON marks unknown units properly as N/A.""" - self._populate_mock_stats_no_unit() - self.data.CalculateStats() - fname = 'unittest_summary.json' - fname = self.data.SaveSummaryJSON(self.tempdir, fname) - with open(fname, 'r') as f: - summary = json.load(f) - self.assertEqual('blue', summary['A']['unit']) - # if no unit is specified, JSON should save 'N/A' as the unit. - self.assertEqual('N/A', summary['B']['unit']) - -if __name__ == '__main__': - unittest.main() + """Test to verify StatsManager methods work as expected. + + StatsManager should collect raw data, calculate their statistics, and save + them in expected format. + """ + + def _populate_mock_stats(self): + """Create a populated & processed StatsManager to test data retrieval.""" + self.data.AddSample("A", 99999.5) + self.data.AddSample("A", 100000.5) + self.data.SetUnit("A", "uW") + self.data.SetUnit("A", "mW") + self.data.AddSample("B", 1.5) + self.data.AddSample("B", 2.5) + self.data.AddSample("B", 3.5) + self.data.SetUnit("B", "mV") + self.data.CalculateStats() + + def _populate_mock_stats_no_unit(self): + self.data.AddSample("B", 1000) + self.data.AddSample("A", 200) + self.data.SetUnit("A", "blue") + + def setUp(self): + """Set up StatsManager and create a temporary directory for test.""" + self.tempdir = tempfile.mkdtemp() + self.data = stats_manager.StatsManager() + + def tearDown(self): + """Delete the temporary directory and its content.""" + shutil.rmtree(self.tempdir) + + def test_AddSample(self): + """Adding a sample successfully adds a sample.""" + self.data.AddSample("Test", 1000) + self.data.SetUnit("Test", "test") + self.data.CalculateStats() + summary = self.data.GetSummary() + self.assertEqual(1, summary["Test"]["count"]) + + def test_AddSampleNoFloatAcceptNaN(self): + """Adding a non-number adds 'NaN' and doesn't raise an exception.""" + self.data.AddSample("Test", 10) + self.data.AddSample("Test", 20) + # adding a fake NaN: one that gets converted into NaN internally + self.data.AddSample("Test", "fiesta") + # adding a real NaN + self.data.AddSample("Test", float("NaN")) + self.data.SetUnit("Test", "test") + self.data.CalculateStats() + summary = self.data.GetSummary() + # assert that 'NaN' as added. + self.assertEqual(4, summary["Test"]["count"]) + # assert that mean, min, and max calculatings ignore the 'NaN' + self.assertEqual(10, summary["Test"]["min"]) + self.assertEqual(20, summary["Test"]["max"]) + self.assertEqual(15, summary["Test"]["mean"]) + + def test_AddSampleNoFloatNotAcceptNaN(self): + """Adding a non-number raises a StatsManagerError if accept_nan is False.""" + self.data = stats_manager.StatsManager(accept_nan=False) + with self.assertRaisesRegexp( + stats_manager.StatsManagerError, + "accept_nan is false. Cannot add NaN sample.", + ): + # adding a fake NaN: one that gets converted into NaN internally + self.data.AddSample("Test", "fiesta") + with self.assertRaisesRegexp( + stats_manager.StatsManagerError, + "accept_nan is false. Cannot add NaN sample.", + ): + # adding a real NaN + self.data.AddSample("Test", float("NaN")) + + def test_AddSampleNoUnit(self): + """Not adding a unit does not cause an exception on CalculateStats().""" + self.data.AddSample("Test", 17) + self.data.CalculateStats() + summary = self.data.GetSummary() + self.assertEqual(1, summary["Test"]["count"]) + + def test_UnitSuffix(self): + """Unit gets appended as a suffix in the displayed summary.""" + self.data.AddSample("test", 250) + self.data.SetUnit("test", "mw") + self.data.CalculateStats() + summary_str = self.data.SummaryToString() + self.assertIn("test_mw", summary_str) + + def test_DoubleUnitSuffix(self): + """If domain already ends in unit, verify that unit doesn't get appended.""" + self.data.AddSample("test_mw", 250) + self.data.SetUnit("test_mw", "mw") + self.data.CalculateStats() + summary_str = self.data.SummaryToString() + self.assertIn("test_mw", summary_str) + self.assertNotIn("test_mw_mw", summary_str) + + def test_GetRawData(self): + """GetRawData returns exact same data as fed in.""" + self._populate_mock_stats() + raw_data = self.data.GetRawData() + self.assertListEqual([99999.5, 100000.5], raw_data["A"]) + self.assertListEqual([1.5, 2.5, 3.5], raw_data["B"]) + + def test_GetSummary(self): + """GetSummary returns expected stats about the data fed in.""" + self._populate_mock_stats() + summary = self.data.GetSummary() + self.assertEqual(2, summary["A"]["count"]) + self.assertAlmostEqual(100000.5, summary["A"]["max"]) + self.assertAlmostEqual(99999.5, summary["A"]["min"]) + self.assertAlmostEqual(0.5, summary["A"]["stddev"]) + self.assertAlmostEqual(100000.0, summary["A"]["mean"]) + self.assertEqual(3, summary["B"]["count"]) + self.assertAlmostEqual(3.5, summary["B"]["max"]) + self.assertAlmostEqual(1.5, summary["B"]["min"]) + self.assertAlmostEqual(0.81649658092773, summary["B"]["stddev"]) + self.assertAlmostEqual(2.5, summary["B"]["mean"]) + + def test_SaveRawData(self): + """SaveRawData stores same data as fed in.""" + self._populate_mock_stats() + dirname = "unittest_raw_data" + expected_files = set(["A_mW.txt", "B_mV.txt"]) + fnames = self.data.SaveRawData(self.tempdir, dirname) + files_returned = set([os.path.basename(f) for f in fnames]) + # Assert that only the expected files got returned. + self.assertEqual(expected_files, files_returned) + # Assert that only the returned files are in the outdir. + self.assertEqual( + set(os.listdir(os.path.join(self.tempdir, dirname))), files_returned + ) + for fname in fnames: + with open(fname, "r") as f: + if "A_mW" in fname: + self.assertEqual("99999.50", f.readline().strip()) + self.assertEqual("100000.50", f.readline().strip()) + if "B_mV" in fname: + self.assertEqual("1.50", f.readline().strip()) + self.assertEqual("2.50", f.readline().strip()) + self.assertEqual("3.50", f.readline().strip()) + + def test_SaveRawDataNoUnit(self): + """SaveRawData appends no unit suffix if the unit is not specified.""" + self._populate_mock_stats_no_unit() + self.data.CalculateStats() + outdir = "unittest_raw_data" + files = self.data.SaveRawData(self.tempdir, outdir) + files = [os.path.basename(f) for f in files] + # Verify nothing gets appended to domain for filename if no unit exists. + self.assertIn("B.txt", files) + + def test_SaveRawDataSMID(self): + """SaveRawData uses the smid when creating output filename.""" + identifier = "ec" + self.data = stats_manager.StatsManager(smid=identifier) + self._populate_mock_stats() + files = self.data.SaveRawData(self.tempdir) + for fname in files: + self.assertTrue(os.path.basename(fname).startswith(identifier)) + + def test_SummaryToStringNaNHelp(self): + """NaN containing row gets tagged with *, help banner gets added.""" + help_banner_exp = "%s %s" % ( + stats_manager.STATS_PREFIX, + stats_manager.NAN_DESCRIPTION, + ) + nan_domain = "A-domain" + nan_domain_exp = "%s%s" % (nan_domain, stats_manager.NAN_TAG) + # NaN helper banner is added when a NaN domain is found & domain gets tagged + data = stats_manager.StatsManager() + data.AddSample(nan_domain, float("NaN")) + data.AddSample(nan_domain, 17) + data.AddSample("B-domain", 17) + data.CalculateStats() + summarystr = data.SummaryToString() + self.assertIn(help_banner_exp, summarystr) + self.assertIn(nan_domain_exp, summarystr) + # NaN helper banner is not added when no NaN domain output, no tagging + data = stats_manager.StatsManager() + # nan_domain in this scenario does not contain any NaN + data.AddSample(nan_domain, 19) + data.AddSample("B-domain", 17) + data.CalculateStats() + summarystr = data.SummaryToString() + self.assertNotIn(help_banner_exp, summarystr) + self.assertNotIn(nan_domain_exp, summarystr) + + def test_SummaryToStringTitle(self): + """Title shows up in SummaryToString if title specified.""" + title = "titulo" + data = stats_manager.StatsManager(title=title) + self._populate_mock_stats() + summary_str = data.SummaryToString() + self.assertIn(title, summary_str) + + def test_SummaryToStringHideDomains(self): + """Keys indicated in hide_domains are not printed in the summary.""" + data = stats_manager.StatsManager(hide_domains=["A-domain"]) + data.AddSample("A-domain", 17) + data.AddSample("B-domain", 17) + data.CalculateStats() + summary_str = data.SummaryToString() + self.assertIn("B-domain", summary_str) + self.assertNotIn("A-domain", summary_str) + + def test_SummaryToStringOrder(self): + """Order passed into StatsManager is honoured when formatting summary.""" + # StatsManager that should print D & B first, and the subsequent elements + # are sorted. + d_b_a_c_regexp = re.compile("D-domain.*B-domain.*A-domain.*C-domain", re.DOTALL) + data = stats_manager.StatsManager(order=["D-domain", "B-domain"]) + data.AddSample("A-domain", 17) + data.AddSample("B-domain", 17) + data.AddSample("C-domain", 17) + data.AddSample("D-domain", 17) + data.CalculateStats() + summary_str = data.SummaryToString() + self.assertRegexpMatches(summary_str, d_b_a_c_regexp) + + def test_MakeUniqueFName(self): + data = stats_manager.StatsManager() + testfile = os.path.join(self.tempdir, "testfile.txt") + with open(testfile, "w") as f: + f.write("") + expected_fname = os.path.join(self.tempdir, "testfile0.txt") + self.assertEqual(expected_fname, data._MakeUniqueFName(testfile)) + + def test_SaveSummary(self): + """SaveSummary properly dumps the summary into a file.""" + self._populate_mock_stats() + fname = "unittest_summary.txt" + expected_fname = os.path.join(self.tempdir, fname) + fname = self.data.SaveSummary(self.tempdir, fname) + # Assert the reported fname is the same as the expected fname + self.assertEqual(expected_fname, fname) + # Assert only the reported fname is output (in the tempdir) + self.assertEqual(set([os.path.basename(fname)]), set(os.listdir(self.tempdir))) + with open(fname, "r") as f: + self.assertEqual( + "@@ NAME COUNT MEAN STDDEV MAX MIN\n", + f.readline(), + ) + self.assertEqual( + "@@ A_mW 2 100000.00 0.50 100000.50 99999.50\n", + f.readline(), + ) + self.assertEqual( + "@@ B_mV 3 2.50 0.82 3.50 1.50\n", + f.readline(), + ) + + def test_SaveSummarySMID(self): + """SaveSummary uses the smid when creating output filename.""" + identifier = "ec" + self.data = stats_manager.StatsManager(smid=identifier) + self._populate_mock_stats() + fname = os.path.basename(self.data.SaveSummary(self.tempdir)) + self.assertTrue(fname.startswith(identifier)) + + def test_SaveSummaryJSON(self): + """SaveSummaryJSON saves the added data properly in JSON format.""" + self._populate_mock_stats() + fname = "unittest_summary.json" + expected_fname = os.path.join(self.tempdir, fname) + fname = self.data.SaveSummaryJSON(self.tempdir, fname) + # Assert the reported fname is the same as the expected fname + self.assertEqual(expected_fname, fname) + # Assert only the reported fname is output (in the tempdir) + self.assertEqual(set([os.path.basename(fname)]), set(os.listdir(self.tempdir))) + with open(fname, "r") as f: + summary = json.load(f) + self.assertAlmostEqual(100000.0, summary["A"]["mean"]) + self.assertEqual("milliwatt", summary["A"]["unit"]) + self.assertAlmostEqual(2.5, summary["B"]["mean"]) + self.assertEqual("millivolt", summary["B"]["unit"]) + + def test_SaveSummaryJSONSMID(self): + """SaveSummaryJSON uses the smid when creating output filename.""" + identifier = "ec" + self.data = stats_manager.StatsManager(smid=identifier) + self._populate_mock_stats() + fname = os.path.basename(self.data.SaveSummaryJSON(self.tempdir)) + self.assertTrue(fname.startswith(identifier)) + + def test_SaveSummaryJSONNoUnit(self): + """SaveSummaryJSON marks unknown units properly as N/A.""" + self._populate_mock_stats_no_unit() + self.data.CalculateStats() + fname = "unittest_summary.json" + fname = self.data.SaveSummaryJSON(self.tempdir, fname) + with open(fname, "r") as f: + summary = json.load(f) + self.assertEqual("blue", summary["A"]["unit"]) + # if no unit is specified, JSON should save 'N/A' as the unit. + self.assertEqual("N/A", summary["B"]["unit"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/extra/usb_serial/console.py b/extra/usb_serial/console.py index 7211dceff6..c08cb72092 100755 --- a/extra/usb_serial/console.py +++ b/extra/usb_serial/console.py @@ -12,6 +12,7 @@ # Note: This is a py2/3 compatible file. from __future__ import print_function + import argparse import array import os @@ -20,23 +21,24 @@ import termios import threading import time import tty + try: - import usb + import usb except ModuleNotFoundError: - print('import usb failed') - print('try running these commands:') - print(' sudo apt-get install python-pip') - print(' sudo pip install --pre pyusb') - print() - sys.exit(-1) + print("import usb failed") + print("try running these commands:") + print(" sudo apt-get install python-pip") + print(" sudo pip install --pre pyusb") + print() + sys.exit(-1) import six def GetBuffer(stream): - if six.PY3: - return stream.buffer - return stream + if six.PY3: + return stream.buffer + return stream """Class Susb covers USB device discovery and initialization. @@ -45,99 +47,104 @@ def GetBuffer(stream): and interface number. """ + class SusbError(Exception): - """Class for exceptions of Susb.""" - def __init__(self, msg, value=0): - """SusbError constructor. + """Class for exceptions of Susb.""" - Args: - msg: string, message describing error in detail - value: integer, value of error when non-zero status returned. Default=0 - """ - super(SusbError, self).__init__(msg, value) - self.msg = msg - self.value = value - -class Susb(): - """Provide USB functionality. - - Instance Variables: - _read_ep: pyUSB read endpoint for this interface - _write_ep: pyUSB write endpoint for this interface - """ - READ_ENDPOINT = 0x81 - WRITE_ENDPOINT = 0x1 - TIMEOUT_MS = 100 - - def __init__(self, vendor=0x18d1, - product=0x500f, interface=1, serialname=None): - """Susb constructor. - - Discovers and connects to USB endpoints. - - Args: - vendor: usb vendor id of device - product: usb product id of device - interface: interface number ( 1 - 8 ) of device to use - serialname: string of device serialnumber. - - Raises: - SusbError: An error accessing Susb object + def __init__(self, msg, value=0): + """SusbError constructor. + + Args: + msg: string, message describing error in detail + value: integer, value of error when non-zero status returned. Default=0 + """ + super(SusbError, self).__init__(msg, value) + self.msg = msg + self.value = value + + +class Susb: + """Provide USB functionality. + + Instance Variables: + _read_ep: pyUSB read endpoint for this interface + _write_ep: pyUSB write endpoint for this interface """ - # Find the device. - dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) - dev_list = list(dev_g) - if dev_list is None: - raise SusbError('USB device not found') - - # Check if we have multiple devices. - dev = None - if serialname: - for d in dev_list: - dev_serial = "PyUSB doesn't have a stable interface" - try: - dev_serial = usb.util.get_string(d, 256, d.iSerialNumber) - except: - dev_serial = usb.util.get_string(d, d.iSerialNumber) - if dev_serial == serialname: - dev = d - break - if dev is None: - raise SusbError('USB device(%s) not found' % (serialname,)) - else: - try: - dev = dev_list[0] - except: - try: - dev = dev_list.next() - except: - raise SusbError('USB device %04x:%04x not found' % (vendor, product)) - # If we can't set configuration, it's already been set. - try: - dev.set_configuration() - except usb.core.USBError: - pass + READ_ENDPOINT = 0x81 + WRITE_ENDPOINT = 0x1 + TIMEOUT_MS = 100 + + def __init__(self, vendor=0x18D1, product=0x500F, interface=1, serialname=None): + """Susb constructor. + + Discovers and connects to USB endpoints. + + Args: + vendor: usb vendor id of device + product: usb product id of device + interface: interface number ( 1 - 8 ) of device to use + serialname: string of device serialnumber. + + Raises: + SusbError: An error accessing Susb object + """ + # Find the device. + dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) + dev_list = list(dev_g) + if dev_list is None: + raise SusbError("USB device not found") + + # Check if we have multiple devices. + dev = None + if serialname: + for d in dev_list: + dev_serial = "PyUSB doesn't have a stable interface" + try: + dev_serial = usb.util.get_string(d, 256, d.iSerialNumber) + except: + dev_serial = usb.util.get_string(d, d.iSerialNumber) + if dev_serial == serialname: + dev = d + break + if dev is None: + raise SusbError("USB device(%s) not found" % (serialname,)) + else: + try: + dev = dev_list[0] + except: + try: + dev = dev_list.next() + except: + raise SusbError( + "USB device %04x:%04x not found" % (vendor, product) + ) + + # If we can't set configuration, it's already been set. + try: + dev.set_configuration() + except usb.core.USBError: + pass - # Get an endpoint instance. - cfg = dev.get_active_configuration() - intf = usb.util.find_descriptor(cfg, bInterfaceNumber=interface) - self._intf = intf + # Get an endpoint instance. + cfg = dev.get_active_configuration() + intf = usb.util.find_descriptor(cfg, bInterfaceNumber=interface) + self._intf = intf - if not intf: - raise SusbError('Interface not found') + if not intf: + raise SusbError("Interface not found") - # Detach raiden.ko if it is loaded. - if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True: + # Detach raiden.ko if it is loaded. + if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True: dev.detach_kernel_driver(intf.bInterfaceNumber) - read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT - read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number) - self._read_ep = read_ep + read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT + read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number) + self._read_ep = read_ep - write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT - write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number) - self._write_ep = write_ep + write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT + write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number) + self._write_ep = write_ep """Suart class implements a stream interface, to access Google's USB class. @@ -146,89 +153,91 @@ class Susb(): and forwards them across. This particular class is hardcoded to stdin/out. """ -class SuartError(Exception): - """Class for exceptions of Suart.""" - def __init__(self, msg, value=0): - """SuartError constructor. - - Args: - msg: string, message describing error in detail - value: integer, value of error when non-zero status returned. Default=0 - """ - super(SuartError, self).__init__(msg, value) - self.msg = msg - self.value = value - - -class Suart(): - """Provide interface to serial usb endpoint.""" - def __init__(self, vendor=0x18d1, product=0x501c, interface=0, - serialname=None): - """Suart contstructor. +class SuartError(Exception): + """Class for exceptions of Suart.""" - Initializes USB stream interface. + def __init__(self, msg, value=0): + """SuartError constructor. - Args: - vendor: usb vendor id of device - product: usb product id of device - interface: interface number of device to use - serialname: Defaults to None. + Args: + msg: string, message describing error in detail + value: integer, value of error when non-zero status returned. Default=0 + """ + super(SuartError, self).__init__(msg, value) + self.msg = msg + self.value = value - Raises: - SuartError: If init fails - """ - self._done = threading.Event() - self._susb = Susb(vendor=vendor, product=product, - interface=interface, serialname=serialname) - def wait_until_done(self, timeout=None): - return self._done.wait(timeout=timeout) +class Suart: + """Provide interface to serial usb endpoint.""" - def run_rx_thread(self): - try: - while True: - try: - r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS) - if r: - GetBuffer(sys.stdout).write(r.tobytes()) - GetBuffer(sys.stdout).flush() - - except Exception as e: - # If we miss some characters on pty disconnect, that's fine. - # ep.read() also throws USBError on timeout, which we discard. - if not isinstance(e, (OSError, usb.core.USBError)): - print('rx %s' % e) - finally: - self._done.set() + def __init__(self, vendor=0x18D1, product=0x501C, interface=0, serialname=None): + """Suart contstructor. - def run_tx_thread(self): - try: - while True: - try: - r = GetBuffer(sys.stdin).read(1) - if not r or r == b'\x03': - break - if r: - self._susb._write_ep.write(array.array('B', r), - self._susb.TIMEOUT_MS) - except Exception as e: - print('tx %s' % e) - finally: - self._done.set() + Initializes USB stream interface. - def run(self): - """Creates pthreads to poll USB & PTY for data.""" - self._exit = False + Args: + vendor: usb vendor id of device + product: usb product id of device + interface: interface number of device to use + serialname: Defaults to None. - self._rx_thread = threading.Thread(target=self.run_rx_thread) - self._rx_thread.daemon = True - self._rx_thread.start() + Raises: + SuartError: If init fails + """ + self._done = threading.Event() + self._susb = Susb( + vendor=vendor, product=product, interface=interface, serialname=serialname + ) - self._tx_thread = threading.Thread(target=self.run_tx_thread) - self._tx_thread.daemon = True - self._tx_thread.start() + def wait_until_done(self, timeout=None): + return self._done.wait(timeout=timeout) + def run_rx_thread(self): + try: + while True: + try: + r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS) + if r: + GetBuffer(sys.stdout).write(r.tobytes()) + GetBuffer(sys.stdout).flush() + + except Exception as e: + # If we miss some characters on pty disconnect, that's fine. + # ep.read() also throws USBError on timeout, which we discard. + if not isinstance(e, (OSError, usb.core.USBError)): + print("rx %s" % e) + finally: + self._done.set() + + def run_tx_thread(self): + try: + while True: + try: + r = GetBuffer(sys.stdin).read(1) + if not r or r == b"\x03": + break + if r: + self._susb._write_ep.write( + array.array("B", r), self._susb.TIMEOUT_MS + ) + except Exception as e: + print("tx %s" % e) + finally: + self._done.set() + + def run(self): + """Creates pthreads to poll USB & PTY for data.""" + self._exit = False + + self._rx_thread = threading.Thread(target=self.run_rx_thread) + self._rx_thread.daemon = True + self._rx_thread.start() + + self._tx_thread = threading.Thread(target=self.run_tx_thread) + self._tx_thread.daemon = True + self._tx_thread.start() """Command line functionality @@ -237,59 +246,67 @@ class Suart(): Ctrl-C exits. """ -parser = argparse.ArgumentParser(description='Open a console to a USB device') -parser.add_argument('-d', '--device', type=str, - help='vid:pid of target device', default='18d1:501c') -parser.add_argument('-i', '--interface', type=int, - help='interface number of console', default=0) -parser.add_argument('-s', '--serialno', type=str, - help='serial number of device', default='') -parser.add_argument('-S', '--notty-exit-sleep', type=float, default=0.2, - help='When stdin is *not* a TTY, wait this many seconds ' - 'after EOF from stdin before exiting, to give time for ' - 'receiving a reply from the USB device.') +parser = argparse.ArgumentParser(description="Open a console to a USB device") +parser.add_argument( + "-d", "--device", type=str, help="vid:pid of target device", default="18d1:501c" +) +parser.add_argument( + "-i", "--interface", type=int, help="interface number of console", default=0 +) +parser.add_argument( + "-s", "--serialno", type=str, help="serial number of device", default="" +) +parser.add_argument( + "-S", + "--notty-exit-sleep", + type=float, + default=0.2, + help="When stdin is *not* a TTY, wait this many seconds " + "after EOF from stdin before exiting, to give time for " + "receiving a reply from the USB device.", +) + def runconsole(): - """Run the usb console code + """Run the usb console code - Starts the pty thread, and idles until a ^C is caught. - """ - args = parser.parse_args() + Starts the pty thread, and idles until a ^C is caught. + """ + args = parser.parse_args() - vidstr, pidstr = args.device.split(':') - vid = int(vidstr, 16) - pid = int(pidstr, 16) + vidstr, pidstr = args.device.split(":") + vid = int(vidstr, 16) + pid = int(pidstr, 16) - serialno = args.serialno - interface = args.interface + serialno = args.serialno + interface = args.interface - sobj = Suart(vendor=vid, product=pid, interface=interface, - serialname=serialno) - if sys.stdin.isatty(): - tty.setraw(sys.stdin.fileno()) - sobj.run() - sobj.wait_until_done() - if not sys.stdin.isatty() and args.notty_exit_sleep > 0: - time.sleep(args.notty_exit_sleep) + sobj = Suart(vendor=vid, product=pid, interface=interface, serialname=serialno) + if sys.stdin.isatty(): + tty.setraw(sys.stdin.fileno()) + sobj.run() + sobj.wait_until_done() + if not sys.stdin.isatty() and args.notty_exit_sleep > 0: + time.sleep(args.notty_exit_sleep) def main(): - stdin_isatty = sys.stdin.isatty() - if stdin_isatty: - fd = sys.stdin.fileno() - os.system('stty -echo') - old_settings = termios.tcgetattr(fd) - - try: - runconsole() - finally: + stdin_isatty = sys.stdin.isatty() if stdin_isatty: - termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) - os.system('stty echo') - # Avoid having the user's shell prompt start mid-line after the final output - # from this program. - print() + fd = sys.stdin.fileno() + os.system("stty -echo") + old_settings = termios.tcgetattr(fd) + + try: + runconsole() + finally: + if stdin_isatty: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + os.system("stty echo") + # Avoid having the user's shell prompt start mid-line after the final output + # from this program. + print() -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/extra/usb_updater/fw_update.py b/extra/usb_updater/fw_update.py index 0d7a570fc3..f05797bfb6 100755 --- a/extra/usb_updater/fw_update.py +++ b/extra/usb_updater/fw_update.py @@ -20,407 +20,425 @@ import struct import sys import time from pprint import pprint -import usb +import usb debug = False -def debuglog(msg): - if debug: - print(msg) - -def log(msg): - print(msg) - sys.stdout.flush() - - -"""Sends firmware update to CROS EC usb endpoint.""" - -class Supdate(object): - """Class to access firmware update endpoints. - - Usage: - d = Supdate() - Instance Variables: - _dev: pyUSB device object - _read_ep: pyUSB read endpoint for this interface - _write_ep: pyUSB write endpoint for this interface - """ - USB_SUBCLASS_GOOGLE_UPDATE = 0x53 - USB_CLASS_VENDOR = 0xFF - def __init__(self): - pass - - - def connect_usb(self, serialname=None): - """Initial discovery and connection to USB endpoint. - - This searches for a USB device matching the VID:PID specified - in the config file, optionally matching a specified serialname. - - Args: - serialname: Find the device with this serial, in case multiple - devices are attached. - - Returns: - True on success. - Raises: - Exception on error. - """ - # Find the stm32. - vendor = self._brdcfg['vid'] - product = self._brdcfg['pid'] - - dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) - dev_list = list(dev_g) - if dev_list is None: - raise Exception("Update", "USB device not found") - - # Check if we have multiple stm32s and we've specified the serial. - dev = None - if serialname: - for d in dev_list: - if usb.util.get_string(d, d.iSerialNumber) == serialname: - dev = d - break - if dev is None: - raise SusbError("USB device(%s) not found" % serialname) - else: - try: - dev = dev_list[0] - except: - dev = dev_list.next() - - debuglog("Found stm32: %04x:%04x" % (vendor, product)) - self._dev = dev - - # Get an endpoint instance. - try: - dev.set_configuration() - except: - pass - cfg = dev.get_active_configuration() - - intf = usb.util.find_descriptor(cfg, custom_match=lambda i: \ - i.bInterfaceClass==self.USB_CLASS_VENDOR and \ - i.bInterfaceSubClass==self.USB_SUBCLASS_GOOGLE_UPDATE) - - self._intf = intf - debuglog("Interface: %s" % intf) - debuglog("InterfaceNumber: %s" % intf.bInterfaceNumber) - - read_ep = usb.util.find_descriptor( - intf, - # match the first IN endpoint - custom_match = \ - lambda e: \ - usb.util.endpoint_direction(e.bEndpointAddress) == \ - usb.util.ENDPOINT_IN - ) - - self._read_ep = read_ep - debuglog("Reader endpoint: 0x%x" % read_ep.bEndpointAddress) - - write_ep = usb.util.find_descriptor( - intf, - # match the first OUT endpoint - custom_match = \ - lambda e: \ - usb.util.endpoint_direction(e.bEndpointAddress) == \ - usb.util.ENDPOINT_OUT - ) - - self._write_ep = write_ep - debuglog("Writer endpoint: 0x%x" % write_ep.bEndpointAddress) - - return True - - - def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=2000): - """Write command to logger logic.. - - This function writes byte command values list to stm, then reads - byte status. - - Args: - write_list: list of command byte values [0~255]. - read_count: number of status byte values to read. - wtimeout: mS to wait for write success - rtimeout: mS to wait for read success - - Returns: - status byte, if one byte is read, - byte list, if multiple bytes are read, - None, if no bytes are read. - - Interface: - write: [command, data ... ] - read: [status ] - """ - debuglog("wr_command(write_list=[%s] (%d), read_count=%s)" % ( - list(bytearray(write_list)), len(write_list), read_count)) - - # Clean up args from python style to correct types. - write_length = 0 - if write_list: - write_length = len(write_list) - if not read_count: - read_count = 0 - - # Send command to stm32. - if write_list: - cmd = write_list - ret = self._write_ep.write(cmd, wtimeout) - debuglog("RET: %s " % ret) - - # Read back response if necessary. - if read_count: - bytesread = self._read_ep.read(512, rtimeout) - debuglog("BYTES: [%s]" % bytesread) - - if len(bytesread) != read_count: - debuglog("Unexpected bytes read: %d, expected: %d" % (len(bytesread), read_count)) - pass - - debuglog("STATUS: 0x%02x" % int(bytesread[0])) - if read_count == 1: - return bytesread[0] - else: - return bytesread - - return None +def debuglog(msg): + if debug: + print(msg) - def stop(self): - """Finalize system flash and exit.""" - cmd = struct.pack(">I", 0xB007AB1E) - read = self.wr_command(cmd, read_count=4) - if len(read) == 4: - log("Finished flashing") - return +def log(msg): + print(msg) + sys.stdout.flush() - raise Exception("Update", "Stop failed [%s]" % read) +"""Sends firmware update to CROS EC usb endpoint.""" - def write_file(self): - """Write the update region packet by packet to USB - This sends write packets of size 128B out, in 32B chunks. - Overall, this will write all data in the inactive code region. +class Supdate(object): + """Class to access firmware update endpoints. - Raises: - Exception if write failed or address out of bounds. - """ - region = self._region - flash_base = self._brdcfg["flash"] - offset = self._base - flash_base - if offset != self._brdcfg['regions'][region][0]: - raise Exception("Update", "Region %s offset 0x%x != available offset 0x%x" % ( - region, self._brdcfg['regions'][region][0], offset)) - - length = self._brdcfg['regions'][region][1] - log("Sending") - - # Go to the correct region in the ec.bin file. - self._binfile.seek(offset) - - # Send 32 bytes at a time. Must be less than the endpoint's max packet size. - maxpacket = 32 - - # While data is left, create update packets. - while length > 0: - # Update packets are 128B. We can use any number - # but the micro must malloc this memory. - pagesize = min(length, 128) - - # Packet is: - # packet size: page bytes transferred plus 3 x 32b values header. - # cmd: n/a - # base: flash address to write this packet. - # data: 128B of data to write into flash_base - cmd = struct.pack(">III", pagesize + 12, 0, offset + flash_base) - read = self.wr_command(cmd, read_count=0) - - # Push 'todo' bytes out the pipe. - todo = pagesize - while todo > 0: - packetsize = min(maxpacket, todo) - data = self._binfile.read(packetsize) - if len(data) != packetsize: - raise Exception("Update", "No more data from file") - for i in range(0, 10): - try: - self.wr_command(data, read_count=0) - break - except: - log("Timeout fail") - todo -= packetsize - # Done with this packet, move to the next one. - length -= pagesize - offset += pagesize - - # Validate that the micro thinks it successfully wrote the data. - read = self.wr_command(''.encode(), read_count=4) - result = struct.unpack("<I", read) - result = result[0] - if result != 0: - raise Exception("Update", "Upload failed with rc: 0x%x" % result) - - - def start(self): - """Start a transaction and erase currently inactive region. - - This function sends a start command, and receives the base of the - preferred inactive region. This could be RW, RW_B, - or RO (if there's no RW_B) - - Note that the region is erased here, so you'd better program the RO if - you just erased it. TODO(nsanders): Modify the protocol to allow active - region select or query before erase. - """ + Usage: + d = Supdate() - # Size is 3 uint32 fields - # packet: [packetsize, cmd, base] - size = 4 + 4 + 4 - # Return value is [status, base_addr] - expected = 4 + 4 - - cmd = struct.pack("<III", size, 0, 0) - read = self.wr_command(cmd, read_count=expected) - - if len(read) == 4: - raise Exception("Update", "Protocol version 0 not supported") - elif len(read) == expected: - base, version = struct.unpack(">II", read) - log("Update protocol v. %d" % version) - log("Available flash region base: %x" % base) - else: - raise Exception("Update", "Start command returned %d bytes" % len(read)) - - if base < 256: - raise Exception("Update", "Start returned error code 0x%x" % base) - - self._base = base - flash_base = self._brdcfg["flash"] - self._offset = self._base - flash_base - - # Find our active region. - for region in self._brdcfg['regions']: - if (self._offset >= self._brdcfg['regions'][region][0]) and \ - (self._offset < (self._brdcfg['regions'][region][0] + \ - self._brdcfg['regions'][region][1])): - log("Active region: %s" % region) - self._region = region - - - def load_board(self, brdfile): - """Load firmware layout file. - - example as follows: - { - "board": "servo micro", - "vid": 6353, - "pid": 20506, - "flash": 134217728, - "regions": { - "RW": [65536, 65536], - "PSTATE": [61440, 4096], - "RO": [0, 61440] - } - } - - Args: - brdfile: path to board description file. + Instance Variables: + _dev: pyUSB device object + _read_ep: pyUSB read endpoint for this interface + _write_ep: pyUSB write endpoint for this interface """ - with open(brdfile) as data_file: - data = json.load(data_file) - - # TODO(nsanders): validate this data before moving on. - self._brdcfg = data; - if debug: - pprint(data) - - log("Board is %s" % self._brdcfg['board']) - # Cast hex strings to int. - self._brdcfg['flash'] = int(self._brdcfg['flash'], 0) - self._brdcfg['vid'] = int(self._brdcfg['vid'], 0) - self._brdcfg['pid'] = int(self._brdcfg['pid'], 0) - log("Flash Base is %x" % self._brdcfg['flash']) - self._flashsize = 0 - for region in self._brdcfg['regions']: - base = int(self._brdcfg['regions'][region][0], 0) - length = int(self._brdcfg['regions'][region][1], 0) - log("region %s\tbase:0x%08x size:0x%08x" % ( - region, base, length)) - self._flashsize += length + USB_SUBCLASS_GOOGLE_UPDATE = 0x53 + USB_CLASS_VENDOR = 0xFF - # Convert these to int because json doesn't support hex. - self._brdcfg['regions'][region][0] = base - self._brdcfg['regions'][region][1] = length - - log("Flash Size: 0x%x" % self._flashsize) - - def load_file(self, binfile): - """Open and verify size of the target ec.bin file. - - Args: - binfile: path to ec.bin - - Raises: - Exception on file not found or filesize not matching. - """ - self._filesize = os.path.getsize(binfile) - self._binfile = open(binfile, 'rb') - - if self._filesize != self._flashsize: - raise Exception("Update", "Flash size 0x%x != file size 0x%x" % (self._flashsize, self._filesize)) + def __init__(self): + pass + def connect_usb(self, serialname=None): + """Initial discovery and connection to USB endpoint. + + This searches for a USB device matching the VID:PID specified + in the config file, optionally matching a specified serialname. + + Args: + serialname: Find the device with this serial, in case multiple + devices are attached. + + Returns: + True on success. + Raises: + Exception on error. + """ + # Find the stm32. + vendor = self._brdcfg["vid"] + product = self._brdcfg["pid"] + + dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) + dev_list = list(dev_g) + if dev_list is None: + raise Exception("Update", "USB device not found") + + # Check if we have multiple stm32s and we've specified the serial. + dev = None + if serialname: + for d in dev_list: + if usb.util.get_string(d, d.iSerialNumber) == serialname: + dev = d + break + if dev is None: + raise SusbError("USB device(%s) not found" % serialname) + else: + try: + dev = dev_list[0] + except: + dev = dev_list.next() + + debuglog("Found stm32: %04x:%04x" % (vendor, product)) + self._dev = dev + + # Get an endpoint instance. + try: + dev.set_configuration() + except: + pass + cfg = dev.get_active_configuration() + + intf = usb.util.find_descriptor( + cfg, + custom_match=lambda i: i.bInterfaceClass == self.USB_CLASS_VENDOR + and i.bInterfaceSubClass == self.USB_SUBCLASS_GOOGLE_UPDATE, + ) + + self._intf = intf + debuglog("Interface: %s" % intf) + debuglog("InterfaceNumber: %s" % intf.bInterfaceNumber) + + read_ep = usb.util.find_descriptor( + intf, + # match the first IN endpoint + custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) + == usb.util.ENDPOINT_IN, + ) + + self._read_ep = read_ep + debuglog("Reader endpoint: 0x%x" % read_ep.bEndpointAddress) + + write_ep = usb.util.find_descriptor( + intf, + # match the first OUT endpoint + custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) + == usb.util.ENDPOINT_OUT, + ) + + self._write_ep = write_ep + debuglog("Writer endpoint: 0x%x" % write_ep.bEndpointAddress) + + return True + + def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=2000): + """Write command to logger logic.. + + This function writes byte command values list to stm, then reads + byte status. + + Args: + write_list: list of command byte values [0~255]. + read_count: number of status byte values to read. + wtimeout: mS to wait for write success + rtimeout: mS to wait for read success + + Returns: + status byte, if one byte is read, + byte list, if multiple bytes are read, + None, if no bytes are read. + + Interface: + write: [command, data ... ] + read: [status ] + """ + debuglog( + "wr_command(write_list=[%s] (%d), read_count=%s)" + % (list(bytearray(write_list)), len(write_list), read_count) + ) + + # Clean up args from python style to correct types. + write_length = 0 + if write_list: + write_length = len(write_list) + if not read_count: + read_count = 0 + + # Send command to stm32. + if write_list: + cmd = write_list + ret = self._write_ep.write(cmd, wtimeout) + debuglog("RET: %s " % ret) + + # Read back response if necessary. + if read_count: + bytesread = self._read_ep.read(512, rtimeout) + debuglog("BYTES: [%s]" % bytesread) + + if len(bytesread) != read_count: + debuglog( + "Unexpected bytes read: %d, expected: %d" + % (len(bytesread), read_count) + ) + pass + + debuglog("STATUS: 0x%02x" % int(bytesread[0])) + if read_count == 1: + return bytesread[0] + else: + return bytesread + + return None + + def stop(self): + """Finalize system flash and exit.""" + cmd = struct.pack(">I", 0xB007AB1E) + read = self.wr_command(cmd, read_count=4) + + if len(read) == 4: + log("Finished flashing") + return + + raise Exception("Update", "Stop failed [%s]" % read) + + def write_file(self): + """Write the update region packet by packet to USB + + This sends write packets of size 128B out, in 32B chunks. + Overall, this will write all data in the inactive code region. + + Raises: + Exception if write failed or address out of bounds. + """ + region = self._region + flash_base = self._brdcfg["flash"] + offset = self._base - flash_base + if offset != self._brdcfg["regions"][region][0]: + raise Exception( + "Update", + "Region %s offset 0x%x != available offset 0x%x" + % (region, self._brdcfg["regions"][region][0], offset), + ) + + length = self._brdcfg["regions"][region][1] + log("Sending") + + # Go to the correct region in the ec.bin file. + self._binfile.seek(offset) + + # Send 32 bytes at a time. Must be less than the endpoint's max packet size. + maxpacket = 32 + + # While data is left, create update packets. + while length > 0: + # Update packets are 128B. We can use any number + # but the micro must malloc this memory. + pagesize = min(length, 128) + + # Packet is: + # packet size: page bytes transferred plus 3 x 32b values header. + # cmd: n/a + # base: flash address to write this packet. + # data: 128B of data to write into flash_base + cmd = struct.pack(">III", pagesize + 12, 0, offset + flash_base) + read = self.wr_command(cmd, read_count=0) + + # Push 'todo' bytes out the pipe. + todo = pagesize + while todo > 0: + packetsize = min(maxpacket, todo) + data = self._binfile.read(packetsize) + if len(data) != packetsize: + raise Exception("Update", "No more data from file") + for i in range(0, 10): + try: + self.wr_command(data, read_count=0) + break + except: + log("Timeout fail") + todo -= packetsize + # Done with this packet, move to the next one. + length -= pagesize + offset += pagesize + + # Validate that the micro thinks it successfully wrote the data. + read = self.wr_command("".encode(), read_count=4) + result = struct.unpack("<I", read) + result = result[0] + if result != 0: + raise Exception("Update", "Upload failed with rc: 0x%x" % result) + + def start(self): + """Start a transaction and erase currently inactive region. + + This function sends a start command, and receives the base of the + preferred inactive region. This could be RW, RW_B, + or RO (if there's no RW_B) + + Note that the region is erased here, so you'd better program the RO if + you just erased it. TODO(nsanders): Modify the protocol to allow active + region select or query before erase. + """ + + # Size is 3 uint32 fields + # packet: [packetsize, cmd, base] + size = 4 + 4 + 4 + # Return value is [status, base_addr] + expected = 4 + 4 + + cmd = struct.pack("<III", size, 0, 0) + read = self.wr_command(cmd, read_count=expected) + + if len(read) == 4: + raise Exception("Update", "Protocol version 0 not supported") + elif len(read) == expected: + base, version = struct.unpack(">II", read) + log("Update protocol v. %d" % version) + log("Available flash region base: %x" % base) + else: + raise Exception("Update", "Start command returned %d bytes" % len(read)) + + if base < 256: + raise Exception("Update", "Start returned error code 0x%x" % base) + + self._base = base + flash_base = self._brdcfg["flash"] + self._offset = self._base - flash_base + + # Find our active region. + for region in self._brdcfg["regions"]: + if (self._offset >= self._brdcfg["regions"][region][0]) and ( + self._offset + < ( + self._brdcfg["regions"][region][0] + + self._brdcfg["regions"][region][1] + ) + ): + log("Active region: %s" % region) + self._region = region + + def load_board(self, brdfile): + """Load firmware layout file. + + example as follows: + { + "board": "servo micro", + "vid": 6353, + "pid": 20506, + "flash": 134217728, + "regions": { + "RW": [65536, 65536], + "PSTATE": [61440, 4096], + "RO": [0, 61440] + } + } + + Args: + brdfile: path to board description file. + """ + with open(brdfile) as data_file: + data = json.load(data_file) + + # TODO(nsanders): validate this data before moving on. + self._brdcfg = data + if debug: + pprint(data) + + log("Board is %s" % self._brdcfg["board"]) + # Cast hex strings to int. + self._brdcfg["flash"] = int(self._brdcfg["flash"], 0) + self._brdcfg["vid"] = int(self._brdcfg["vid"], 0) + self._brdcfg["pid"] = int(self._brdcfg["pid"], 0) + + log("Flash Base is %x" % self._brdcfg["flash"]) + self._flashsize = 0 + for region in self._brdcfg["regions"]: + base = int(self._brdcfg["regions"][region][0], 0) + length = int(self._brdcfg["regions"][region][1], 0) + log("region %s\tbase:0x%08x size:0x%08x" % (region, base, length)) + self._flashsize += length + + # Convert these to int because json doesn't support hex. + self._brdcfg["regions"][region][0] = base + self._brdcfg["regions"][region][1] = length + + log("Flash Size: 0x%x" % self._flashsize) + + def load_file(self, binfile): + """Open and verify size of the target ec.bin file. + + Args: + binfile: path to ec.bin + + Raises: + Exception on file not found or filesize not matching. + """ + self._filesize = os.path.getsize(binfile) + self._binfile = open(binfile, "rb") + + if self._filesize != self._flashsize: + raise Exception( + "Update", + "Flash size 0x%x != file size 0x%x" % (self._flashsize, self._filesize), + ) # Generate command line arguments parser = argparse.ArgumentParser(description="Update firmware over usb") -parser.add_argument('-b', '--board', type=str, help="Board configuration json file", default="board.json") -parser.add_argument('-f', '--file', type=str, help="Complete ec.bin file", default="ec.bin") -parser.add_argument('-s', '--serial', type=str, help="Serial number", default="") -parser.add_argument('-l', '--list', action="store_true", help="List regions") -parser.add_argument('-v', '--verbose', action="store_true", help="Chatty output") +parser.add_argument( + "-b", + "--board", + type=str, + help="Board configuration json file", + default="board.json", +) +parser.add_argument( + "-f", "--file", type=str, help="Complete ec.bin file", default="ec.bin" +) +parser.add_argument("-s", "--serial", type=str, help="Serial number", default="") +parser.add_argument("-l", "--list", action="store_true", help="List regions") +parser.add_argument("-v", "--verbose", action="store_true", help="Chatty output") + def main(): - global debug - args = parser.parse_args() + global debug + args = parser.parse_args() + brdfile = args.board + serial = args.serial + binfile = args.file + if args.verbose: + debug = True - brdfile = args.board - serial = args.serial - binfile = args.file - if args.verbose: - debug = True + with open(brdfile) as data_file: + names = json.load(data_file) - with open(brdfile) as data_file: - names = json.load(data_file) + p = Supdate() + p.load_board(brdfile) + p.connect_usb(serialname=serial) + p.load_file(binfile) - p = Supdate() - p.load_board(brdfile) - p.connect_usb(serialname=serial) - p.load_file(binfile) + # List solely prints the config. + if args.list: + return - # List solely prints the config. - if (args.list): - return + # Start transfer and erase. + p.start() + # Upload the bin file + log("Uploading %s" % binfile) + p.write_file() - # Start transfer and erase. - p.start() - # Upload the bin file - log("Uploading %s" % binfile) - p.write_file() + # Finalize + log("Done. Finalizing.") + p.stop() - # Finalize - log("Done. Finalizing.") - p.stop() if __name__ == "__main__": - main() - - + main() diff --git a/extra/usb_updater/servo_updater.py b/extra/usb_updater/servo_updater.py index 432ee120de..5402af70aa 100755 --- a/extra/usb_updater/servo_updater.py +++ b/extra/usb_updater/servo_updater.py @@ -14,40 +14,46 @@ from __future__ import print_function import argparse +import json import os import re import subprocess import time -import json - -import fw_update import ecusb.tiny_servo_common as c +import fw_update from ecusb import tiny_servod + class ServoUpdaterException(Exception): - """Raised on exceptions generated by servo_updater.""" + """Raised on exceptions generated by servo_updater.""" -BOARD_C2D2 = 'c2d2' -BOARD_SERVO_MICRO = 'servo_micro' -BOARD_SERVO_V4 = 'servo_v4' -BOARD_SERVO_V4P1 = 'servo_v4p1' -BOARD_SWEETBERRY = 'sweetberry' + +BOARD_C2D2 = "c2d2" +BOARD_SERVO_MICRO = "servo_micro" +BOARD_SERVO_V4 = "servo_v4" +BOARD_SERVO_V4P1 = "servo_v4p1" +BOARD_SWEETBERRY = "sweetberry" DEFAULT_BOARD = BOARD_SERVO_V4 # These lists are to facilitate exposing choices in the command-line tool # below. -BOARDS = [BOARD_C2D2, BOARD_SERVO_MICRO, BOARD_SERVO_V4, BOARD_SERVO_V4P1, - BOARD_SWEETBERRY] +BOARDS = [ + BOARD_C2D2, + BOARD_SERVO_MICRO, + BOARD_SERVO_V4, + BOARD_SERVO_V4P1, + BOARD_SWEETBERRY, +] # Servo firmware bundles four channels of firmware. We need to make sure the # user does not request a non-existing channel, so keep the lists around to # guard on command-line usage. -DEFAULT_CHANNEL = STABLE_CHANNEL = 'stable' +DEFAULT_CHANNEL = STABLE_CHANNEL = "stable" -PREV_CHANNEL = 'prev' +PREV_CHANNEL = "prev" # The ordering here matters. From left to right it's the channel that the user # is most likely to be running. This is used to inform and warn the user if @@ -55,403 +61,436 @@ PREV_CHANNEL = 'prev' # user know they are running the 'stable' version before letting them know they # are running 'dev' or even 'alpah' which (while true) might cause confusion. -CHANNELS = [DEFAULT_CHANNEL, PREV_CHANNEL, 'dev', 'alpha'] +CHANNELS = [DEFAULT_CHANNEL, PREV_CHANNEL, "dev", "alpha"] -DEFAULT_BASE_PATH = '/usr/' -TEST_IMAGE_BASE_PATH = '/usr/local/' +DEFAULT_BASE_PATH = "/usr/" +TEST_IMAGE_BASE_PATH = "/usr/local/" -COMMON_PATH = 'share/servo_updater' +COMMON_PATH = "share/servo_updater" -FIRMWARE_DIR = 'firmware/' -CONFIGS_DIR = 'configs/' +FIRMWARE_DIR = "firmware/" +CONFIGS_DIR = "configs/" RETRIES_COUNT = 10 RETRIES_DELAY = 1 + def do_with_retries(func, *args): - """Try a function several times - - Call function passed as argument and check if no error happened. - If exception was raised by function, - it will be retried up to RETRIES_COUNT times. - - Args: - func: function that will be called - args: arguments passed to 'func' - - Returns: - If call to function was successful, its result will be returned. - If retries count was exceeded, exception will be raised. - """ - - retry = 0 - while retry < RETRIES_COUNT: - try: - return func(*args) - except Exception as e: - print('Retrying function %s: %s' % (func.__name__, e)) - retry = retry + 1 - time.sleep(RETRIES_DELAY) - continue - - raise Exception("'{}' failed after {} retries".format( - func.__name__, RETRIES_COUNT)) + """Try a function several times + + Call function passed as argument and check if no error happened. + If exception was raised by function, + it will be retried up to RETRIES_COUNT times. + + Args: + func: function that will be called + args: arguments passed to 'func' + + Returns: + If call to function was successful, its result will be returned. + If retries count was exceeded, exception will be raised. + """ + + retry = 0 + while retry < RETRIES_COUNT: + try: + return func(*args) + except Exception as e: + print("Retrying function %s: %s" % (func.__name__, e)) + retry = retry + 1 + time.sleep(RETRIES_DELAY) + continue + + raise Exception("'{}' failed after {} retries".format(func.__name__, RETRIES_COUNT)) + def flash(brdfile, serialno, binfile): - """Call fw_update to upload to updater USB endpoint. + """Call fw_update to upload to updater USB endpoint. + + Args: + brdfile: path to board configuration file + serialno: device serial number + binfile: firmware file + """ - Args: - brdfile: path to board configuration file - serialno: device serial number - binfile: firmware file - """ + p = fw_update.Supdate() + p.load_board(brdfile) + p.connect_usb(serialname=serialno) + p.load_file(binfile) - p = fw_update.Supdate() - p.load_board(brdfile) - p.connect_usb(serialname=serialno) - p.load_file(binfile) + # Start transfer and erase. + p.start() + # Upload the bin file + print("Uploading %s" % binfile) + p.write_file() - # Start transfer and erase. - p.start() - # Upload the bin file - print('Uploading %s' % binfile) - p.write_file() + # Finalize + print("Done. Finalizing.") + p.stop() - # Finalize - print('Done. Finalizing.') - p.stop() def flash2(vidpid, serialno, binfile): - """Call fw update via usb_updater2 commandline. - - Args: - vidpid: vendor id and product id of device - serialno: device serial number (optional) - binfile: firmware file - """ - - tool = 'usb_updater2' - cmd = '%s -d %s' % (tool, vidpid) - if serialno: - cmd += ' -S %s' % serialno - cmd += ' -n' - cmd += ' %s' % binfile - - print(cmd) - help_cmd = '%s --help' % tool - with open('/dev/null') as devnull: - valid_check = subprocess.call(help_cmd.split(), stdout=devnull, - stderr=devnull) - if valid_check: - raise ServoUpdaterException('%s exit with res = %d. Make sure the tool ' - 'is available on the device.' % (help_cmd, - valid_check)) - res = subprocess.call(cmd.split()) - - if res in (0, 1, 2): - return res - else: - raise ServoUpdaterException('%s exit with res = %d' % (cmd, res)) + """Call fw update via usb_updater2 commandline. + + Args: + vidpid: vendor id and product id of device + serialno: device serial number (optional) + binfile: firmware file + """ + + tool = "usb_updater2" + cmd = "%s -d %s" % (tool, vidpid) + if serialno: + cmd += " -S %s" % serialno + cmd += " -n" + cmd += " %s" % binfile + + print(cmd) + help_cmd = "%s --help" % tool + with open("/dev/null") as devnull: + valid_check = subprocess.call(help_cmd.split(), stdout=devnull, stderr=devnull) + if valid_check: + raise ServoUpdaterException( + "%s exit with res = %d. Make sure the tool " + "is available on the device." % (help_cmd, valid_check) + ) + res = subprocess.call(cmd.split()) + + if res in (0, 1, 2): + return res + else: + raise ServoUpdaterException("%s exit with res = %d" % (cmd, res)) + def select(tinys, region): - """Jump to specified boot region + """Jump to specified boot region - Ensure the servo is in the expected ro/rw region. - This function jumps to the required region and verify if jump was - successful by executing 'sysinfo' command and reading current region. - If response was not received or region is invalid, exception is raised. + Ensure the servo is in the expected ro/rw region. + This function jumps to the required region and verify if jump was + successful by executing 'sysinfo' command and reading current region. + If response was not received or region is invalid, exception is raised. - Args: - tinys: TinyServod object - region: region to jump to, only "rw" and "ro" is allowed - """ + Args: + tinys: TinyServod object + region: region to jump to, only "rw" and "ro" is allowed + """ - if region not in ['rw', 'ro']: - raise Exception('Region must be ro or rw') + if region not in ["rw", "ro"]: + raise Exception("Region must be ro or rw") - if region == 'ro': - cmd = 'reboot' - else: - cmd = 'sysjump %s' % region + if region == "ro": + cmd = "reboot" + else: + cmd = "sysjump %s" % region - tinys.pty._issue_cmd(cmd) + tinys.pty._issue_cmd(cmd) - tinys.close() - time.sleep(2) - tinys.reinitialize() + tinys.close() + time.sleep(2) + tinys.reinitialize() + + res = tinys.pty._issue_cmd_get_results("sysinfo", [r"Copy:[\s]+(RO|RW)"]) + current_region = res[0][1].lower() + if current_region != region: + raise Exception("Invalid region: %s/%s" % (current_region, region)) - res = tinys.pty._issue_cmd_get_results('sysinfo', [r'Copy:[\s]+(RO|RW)']) - current_region = res[0][1].lower() - if current_region != region: - raise Exception('Invalid region: %s/%s' % (current_region, region)) def do_version(tinys): - """Check version via ec console 'pty'. + """Check version via ec console 'pty'. + + Args: + tinys: TinyServod object - Args: - tinys: TinyServod object + Returns: + detected version number - Returns: - detected version number + Commands are: + # > version + # ... + # Build: tigertail_v1.1.6749-74d1a312e + """ + cmd = "version" + regex = r"Build:\s+(\S+)[\r\n]+" - Commands are: - # > version - # ... - # Build: tigertail_v1.1.6749-74d1a312e - """ - cmd = 'version' - regex = r'Build:\s+(\S+)[\r\n]+' + results = tinys.pty._issue_cmd_get_results(cmd, [regex])[0] - results = tinys.pty._issue_cmd_get_results(cmd, [regex])[0] + return results[1].strip(" \t\r\n\0") - return results[1].strip(' \t\r\n\0') def do_updater_version(tinys): - """Check whether this uses python updater or c++ updater - - Args: - tinys: TinyServod object - - Returns: - updater version number. 2 or 6. - """ - vers = do_version(tinys) - - # Servo versions below 58 are from servo-9040.B. Versions starting with _v2 - # are newer than anything _v1, no need to check the exact number. Updater - # version is not directly queryable. - if re.search(r'_v[2-9]\.\d', vers): - return 6 - m = re.search(r'_v1\.1\.(\d\d\d\d)', vers) - if m: - version_number = int(m.group(1)) - if version_number < 5800: - return 2 - else: - return 6 - raise ServoUpdaterException( - "Can't determine updater target from vers: [%s]" % vers) + """Check whether this uses python updater or c++ updater -def _extract_version(boardname, binfile): - """Find the version string from |binfile|. + Args: + tinys: TinyServod object - Args: - boardname: the name of the board, eg. "servo_micro" - binfile: path to the binary to search + Returns: + updater version number. 2 or 6. + """ + vers = do_version(tinys) - Returns: - the version string. - """ - if boardname is None: - # cannot extract the version if the name is None - return None - rawstrings = subprocess.check_output( - ['cbfstool', binfile, 'read', '-r', 'RO_FRID', '-f', '/dev/stdout'], - **c.get_subprocess_args()) - m = re.match(r'%s_v\S+' % boardname, rawstrings) - if m: - newvers = m.group(0).strip(' \t\r\n\0') - else: - raise ServoUpdaterException("Can't find version from file: %s." % binfile) + # Servo versions below 58 are from servo-9040.B. Versions starting with _v2 + # are newer than anything _v1, no need to check the exact number. Updater + # version is not directly queryable. + if re.search(r"_v[2-9]\.\d", vers): + return 6 + m = re.search(r"_v1\.1\.(\d\d\d\d)", vers) + if m: + version_number = int(m.group(1)) + if version_number < 5800: + return 2 + else: + return 6 + raise ServoUpdaterException("Can't determine updater target from vers: [%s]" % vers) + + +def _extract_version(boardname, binfile): + """Find the version string from |binfile|. + + Args: + boardname: the name of the board, eg. "servo_micro" + binfile: path to the binary to search + + Returns: + the version string. + """ + if boardname is None: + # cannot extract the version if the name is None + return None + rawstrings = subprocess.check_output( + ["cbfstool", binfile, "read", "-r", "RO_FRID", "-f", "/dev/stdout"], + **c.get_subprocess_args() + ) + m = re.match(r"%s_v\S+" % boardname, rawstrings) + if m: + newvers = m.group(0).strip(" \t\r\n\0") + else: + raise ServoUpdaterException("Can't find version from file: %s." % binfile) + + return newvers - return newvers def get_firmware_channel(bname, version): - """Find out which channel |version| for |bname| came from. - - Args: - bname: board name - version: current version string - - Returns: - one of the channel names if |version| came from one of those, or None - """ - for channel in CHANNELS: - # Pass |bname| as cname to find the board specific file, and pass None as - # fname to ensure the default directory is searched - _, _, vers = get_files_and_version(bname, None, channel=channel) - if version == vers: - return channel - # None of the channels matched. This firmware is currently unknown. - return None + """Find out which channel |version| for |bname| came from. + + Args: + bname: board name + version: current version string + + Returns: + one of the channel names if |version| came from one of those, or None + """ + for channel in CHANNELS: + # Pass |bname| as cname to find the board specific file, and pass None as + # fname to ensure the default directory is searched + _, _, vers = get_files_and_version(bname, None, channel=channel) + if version == vers: + return channel + # None of the channels matched. This firmware is currently unknown. + return None + def get_files_and_version(cname, fname=None, channel=DEFAULT_CHANNEL): - """Select config and firmware binary files. - - This checks default file names and paths. - In: /usr/share/servo_updater/[firmware|configs] - check for board.json, board.bin - - Args: - cname: board name, or config name. eg. "servo_v4" or "servo_v4.json" - fname: firmware binary name. Can be None to try default. - channel: the channel requested for servo firmware. See |CHANNELS| above. - - Returns: - cname, fname, version: validated filenames selected from the path. - """ - for p in (DEFAULT_BASE_PATH, TEST_IMAGE_BASE_PATH): - updater_path = os.path.join(p, COMMON_PATH) - if os.path.exists(updater_path): - break - else: - raise ServoUpdaterException('servo_updater/ dir not found in known spots.') - - firmware_path = os.path.join(updater_path, FIRMWARE_DIR) - configs_path = os.path.join(updater_path, CONFIGS_DIR) - - for p in (firmware_path, configs_path): - if not os.path.exists(p): - raise ServoUpdaterException('Could not find required path %r' % p) - - if not os.path.isfile(cname): - # If not an existing file, try checking on the default path. - newname = os.path.join(configs_path, cname) - if os.path.isfile(newname): - cname = newname + """Select config and firmware binary files. + + This checks default file names and paths. + In: /usr/share/servo_updater/[firmware|configs] + check for board.json, board.bin + + Args: + cname: board name, or config name. eg. "servo_v4" or "servo_v4.json" + fname: firmware binary name. Can be None to try default. + channel: the channel requested for servo firmware. See |CHANNELS| above. + + Returns: + cname, fname, version: validated filenames selected from the path. + """ + for p in (DEFAULT_BASE_PATH, TEST_IMAGE_BASE_PATH): + updater_path = os.path.join(p, COMMON_PATH) + if os.path.exists(updater_path): + break else: - # Try appending ".json" to convert board name to config file. - cname = newname + '.json' + raise ServoUpdaterException("servo_updater/ dir not found in known spots.") + + firmware_path = os.path.join(updater_path, FIRMWARE_DIR) + configs_path = os.path.join(updater_path, CONFIGS_DIR) + + for p in (firmware_path, configs_path): + if not os.path.exists(p): + raise ServoUpdaterException("Could not find required path %r" % p) + if not os.path.isfile(cname): - raise ServoUpdaterException("Can't find config file: %s." % cname) - - # Always retrieve the boardname - with open(cname) as data_file: - data = json.load(data_file) - boardname = data['board'] - - if not fname: - # If no |fname| supplied, look for the default locations with the board - # and channel requested. - binary_file = '%s.%s.bin' % (boardname, channel) - newname = os.path.join(firmware_path, binary_file) - if os.path.isfile(newname): - fname = newname - else: - raise ServoUpdaterException("Can't find firmware binary: %s." % - binary_file) - elif not os.path.isfile(fname): - # If a name is specified but not found, try the default path. - newname = os.path.join(firmware_path, fname) - if os.path.isfile(newname): - fname = newname + # If not an existing file, try checking on the default path. + newname = os.path.join(configs_path, cname) + if os.path.isfile(newname): + cname = newname + else: + # Try appending ".json" to convert board name to config file. + cname = newname + ".json" + if not os.path.isfile(cname): + raise ServoUpdaterException("Can't find config file: %s." % cname) + + # Always retrieve the boardname + with open(cname) as data_file: + data = json.load(data_file) + boardname = data["board"] + + if not fname: + # If no |fname| supplied, look for the default locations with the board + # and channel requested. + binary_file = "%s.%s.bin" % (boardname, channel) + newname = os.path.join(firmware_path, binary_file) + if os.path.isfile(newname): + fname = newname + else: + raise ServoUpdaterException("Can't find firmware binary: %s." % binary_file) + elif not os.path.isfile(fname): + # If a name is specified but not found, try the default path. + newname = os.path.join(firmware_path, fname) + if os.path.isfile(newname): + fname = newname + else: + raise ServoUpdaterException("Can't find file: %s." % fname) + + # Lastly, retrieve the version as well for decision making, debug, and + # informational purposes. + binvers = _extract_version(boardname, fname) + + return cname, fname, binvers + + +def main(): + parser = argparse.ArgumentParser(description="Image a servo device") + parser.add_argument( + "-p", + "--print", + dest="print_only", + action="store_true", + default=False, + help="only print available firmware for board/channel", + ) + parser.add_argument( + "-s", "--serialno", type=str, help="serial number to program", default=None + ) + parser.add_argument( + "-b", + "--board", + type=str, + help="Board configuration json file", + default=DEFAULT_BOARD, + choices=BOARDS, + ) + parser.add_argument( + "-c", + "--channel", + type=str, + help="Firmware channel to use", + default=DEFAULT_CHANNEL, + choices=CHANNELS, + ) + parser.add_argument( + "-f", "--file", type=str, help="Complete ec.bin file", default=None + ) + parser.add_argument( + "--force", + action="store_true", + help="Update even if version match", + default=False, + ) + parser.add_argument("-v", "--verbose", action="store_true", help="Chatty output") + parser.add_argument( + "-r", "--reboot", action="store_true", help="Always reboot, even after probe." + ) + + args = parser.parse_args() + + brdfile, binfile, newvers = get_files_and_version( + args.board, args.file, args.channel + ) + + # If the user only cares about the information then just print it here, + # and exit. + if args.print_only: + output = ("board: %s\n" "channel: %s\n" "firmware: %s") % ( + args.board, + args.channel, + newvers, + ) + print(output) + return + + serialno = args.serialno + + with open(brdfile) as data_file: + data = json.load(data_file) + vid, pid = int(data["vid"], 0), int(data["pid"], 0) + vidpid = "%04x:%04x" % (vid, pid) + iface = int(data["console"], 0) + boardname = data["board"] + + # Make sure device is up. + print("===== Waiting for USB device =====") + c.wait_for_usb(vidpid, serialname=serialno) + # We need a tiny_servod to query some information. Set it up first. + tinys = tiny_servod.TinyServod(vid, pid, iface, serialno, args.verbose) + + if not args.force: + vers = do_version(tinys) + print("Current %s version is %s" % (boardname, vers)) + print("Available %s version is %s" % (boardname, newvers)) + + if newvers == vers: + print("No version update needed") + if args.reboot: + select(tinys, "ro") + return + else: + print("Updating to recommended version.") + + # Make sure the servo MCU is in RO + print("===== Jumping to RO =====") + do_with_retries(select, tinys, "ro") + + print("===== Flashing RW =====") + vers = do_with_retries(do_updater_version, tinys) + # To make sure that the tiny_servod here does not interfere with other + # processes, close it out. + tinys.close() + + if vers == 2: + flash(brdfile, serialno, binfile) + elif vers == 6: + do_with_retries(flash2, vidpid, serialno, binfile) else: - raise ServoUpdaterException("Can't find file: %s." % fname) + raise ServoUpdaterException("Can't detect updater version") - # Lastly, retrieve the version as well for decision making, debug, and - # informational purposes. - binvers = _extract_version(boardname, fname) + # Make sure device is up. + c.wait_for_usb(vidpid, serialname=serialno) + # After we have made sure that it's back/available, reconnect the tiny servod. + tinys.reinitialize() - return cname, fname, binvers + # Make sure the servo MCU is in RW + print("===== Jumping to RW =====") + do_with_retries(select, tinys, "rw") -def main(): - parser = argparse.ArgumentParser(description='Image a servo device') - parser.add_argument('-p', '--print', dest='print_only', action='store_true', - default=False, - help='only print available firmware for board/channel') - parser.add_argument('-s', '--serialno', type=str, - help='serial number to program', default=None) - parser.add_argument('-b', '--board', type=str, - help='Board configuration json file', - default=DEFAULT_BOARD, choices=BOARDS) - parser.add_argument('-c', '--channel', type=str, - help='Firmware channel to use', - default=DEFAULT_CHANNEL, choices=CHANNELS) - parser.add_argument('-f', '--file', type=str, - help='Complete ec.bin file', default=None) - parser.add_argument('--force', action='store_true', - help='Update even if version match', default=False) - parser.add_argument('-v', '--verbose', action='store_true', - help='Chatty output') - parser.add_argument('-r', '--reboot', action='store_true', - help='Always reboot, even after probe.') - - args = parser.parse_args() - - brdfile, binfile, newvers = get_files_and_version(args.board, args.file, - args.channel) - - # If the user only cares about the information then just print it here, - # and exit. - if args.print_only: - output = ('board: %s\n' - 'channel: %s\n' - 'firmware: %s') % (args.board, args.channel, newvers) - print(output) - return - - serialno = args.serialno - - with open(brdfile) as data_file: - data = json.load(data_file) - vid, pid = int(data['vid'], 0), int(data['pid'], 0) - vidpid = '%04x:%04x' % (vid, pid) - iface = int(data['console'], 0) - boardname = data['board'] - - # Make sure device is up. - print('===== Waiting for USB device =====') - c.wait_for_usb(vidpid, serialname=serialno) - # We need a tiny_servod to query some information. Set it up first. - tinys = tiny_servod.TinyServod(vid, pid, iface, serialno, args.verbose) - - if not args.force: - vers = do_version(tinys) - print('Current %s version is %s' % (boardname, vers)) - print('Available %s version is %s' % (boardname, newvers)) - - if newvers == vers: - print('No version update needed') - if args.reboot: - select(tinys, 'ro') - return + print("===== Flashing RO =====") + vers = do_with_retries(do_updater_version, tinys) + + if vers == 2: + flash(brdfile, serialno, binfile) + elif vers == 6: + do_with_retries(flash2, vidpid, serialno, binfile) else: - print('Updating to recommended version.') - - # Make sure the servo MCU is in RO - print('===== Jumping to RO =====') - do_with_retries(select, tinys, 'ro') - - print('===== Flashing RW =====') - vers = do_with_retries(do_updater_version, tinys) - # To make sure that the tiny_servod here does not interfere with other - # processes, close it out. - tinys.close() - - if vers == 2: - flash(brdfile, serialno, binfile) - elif vers == 6: - do_with_retries(flash2, vidpid, serialno, binfile) - else: - raise ServoUpdaterException("Can't detect updater version") - - # Make sure device is up. - c.wait_for_usb(vidpid, serialname=serialno) - # After we have made sure that it's back/available, reconnect the tiny servod. - tinys.reinitialize() - - # Make sure the servo MCU is in RW - print('===== Jumping to RW =====') - do_with_retries(select, tinys, 'rw') - - print('===== Flashing RO =====') - vers = do_with_retries(do_updater_version, tinys) - - if vers == 2: - flash(brdfile, serialno, binfile) - elif vers == 6: - do_with_retries(flash2, vidpid, serialno, binfile) - else: - raise ServoUpdaterException("Can't detect updater version") - - # Make sure the servo MCU is in RO - print('===== Rebooting =====') - do_with_retries(select, tinys, 'ro') - # Perform additional reboot to free USB/UART resources, taken by tiny servod. - # See https://issuetracker.google.com/196021317 for background. - tinys.pty._issue_cmd('reboot') - - print('===== Finished =====') - -if __name__ == '__main__': - main() + raise ServoUpdaterException("Can't detect updater version") + + # Make sure the servo MCU is in RO + print("===== Rebooting =====") + do_with_retries(select, tinys, "ro") + # Perform additional reboot to free USB/UART resources, taken by tiny servod. + # See https://issuetracker.google.com/196021317 for background. + tinys.pty._issue_cmd("reboot") + + print("===== Finished =====") + + +if __name__ == "__main__": + main() |