summaryrefslogtreecommitdiff
path: root/extra
diff options
context:
space:
mode:
authorJeremy Bettis <jbettis@google.com>2022-07-08 10:58:19 -0600
committerChromeos LUCI <chromeos-scoped@luci-project-accounts.iam.gserviceaccount.com>2022-07-12 19:13:33 +0000
commit7540e7b47b55447475bb8191fb3520dd67cf7998 (patch)
tree13309dbcf1db48e60fa2c2e5aed79f63bce00b5e /extra
parent7c114b8e1a3bb29991da70b9de394ac5d4f6c909 (diff)
downloadchrome-ec-7540e7b47b55447475bb8191fb3520dd67cf7998.tar.gz
ec: Format all python files with black and isort
find . \( -path ./private -prune \) -o -name '*.py' -print | xargs black find . \( -path ./private -prune \) -o -name '*.py' -print | xargs ~/chromiumos/chromite/scripts/isort --settings-file=.isort.cfg BRANCH=None BUG=b:238434058 TEST=None Signed-off-by: Jeremy Bettis <jbettis@google.com> Change-Id: I63462d6f15d1eaf3db84eb20d1404ee976be8382 Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/ec/+/3749242 Commit-Queue: Jeremy Bettis <jbettis@chromium.org> Reviewed-by: Tom Hughes <tomhughes@chromium.org> Tested-by: Jeremy Bettis <jbettis@chromium.org> Commit-Queue: Jack Rosenthal <jrosenth@chromium.org> Auto-Submit: Jeremy Bettis <jbettis@chromium.org> Reviewed-by: Jack Rosenthal <jrosenth@chromium.org>
Diffstat (limited to 'extra')
-rwxr-xr-xextra/cr50_rma_open/cr50_rma_open.py418
-rwxr-xr-xextra/stack_analyzer/stack_analyzer.py3562
-rwxr-xr-xextra/stack_analyzer/stack_analyzer_unittest.py1708
-rw-r--r--extra/tigertool/ecusb/__init__.py2
-rw-r--r--extra/tigertool/ecusb/pty_driver.py545
-rw-r--r--extra/tigertool/ecusb/stm32uart.py438
-rw-r--r--extra/tigertool/ecusb/stm32usb.py207
-rw-r--r--extra/tigertool/ecusb/tiny_servo_common.py343
-rw-r--r--extra/tigertool/ecusb/tiny_servod.py83
-rwxr-xr-xextra/tigertool/tigertest.py97
-rwxr-xr-xextra/tigertool/tigertool.py505
-rw-r--r--extra/usb_power/convert_power_log_board.py35
-rwxr-xr-xextra/usb_power/convert_servo_ina.py86
-rwxr-xr-xextra/usb_power/powerlog.py1762
-rw-r--r--extra/usb_power/powerlog_unittest.py74
-rw-r--r--extra/usb_power/stats_manager.py745
-rw-r--r--extra/usb_power/stats_manager_unittest.py595
-rwxr-xr-xextra/usb_serial/console.py435
-rwxr-xr-xextra/usb_updater/fw_update.py762
-rwxr-xr-xextra/usb_updater/servo_updater.py767
20 files changed, 6864 insertions, 6305 deletions
diff --git a/extra/cr50_rma_open/cr50_rma_open.py b/extra/cr50_rma_open/cr50_rma_open.py
index 42ddbbac2d..b77b8f3dbb 100755
--- a/extra/cr50_rma_open/cr50_rma_open.py
+++ b/extra/cr50_rma_open/cr50_rma_open.py
@@ -57,16 +57,17 @@ CCD_IS_UNRESTRICTED = 1 << 0
WP_IS_DISABLED = 1 << 1
TESTLAB_IS_ENABLED = 1 << 2
RMA_OPENED = CCD_IS_UNRESTRICTED | WP_IS_DISABLED
-URL = ('https://www.google.com/chromeos/partner/console/cr50reset?'
- 'challenge=%s&hwid=%s')
-RMA_SUPPORT_PROD = '0.3.3'
-RMA_SUPPORT_PREPVT = '0.4.5'
-DEV_MODE_OPEN_PROD = '0.3.9'
-DEV_MODE_OPEN_PREPVT = '0.4.7'
-TESTLAB_PROD = '0.3.10'
-CR50_USB = '18d1:5014'
-CR50_LSUSB_CMD = ['lsusb', '-vd', CR50_USB]
-ERASED_BID = 'ffffffff'
+URL = (
+ "https://www.google.com/chromeos/partner/console/cr50reset?" "challenge=%s&hwid=%s"
+)
+RMA_SUPPORT_PROD = "0.3.3"
+RMA_SUPPORT_PREPVT = "0.4.5"
+DEV_MODE_OPEN_PROD = "0.3.9"
+DEV_MODE_OPEN_PREPVT = "0.4.7"
+TESTLAB_PROD = "0.3.10"
+CR50_USB = "18d1:5014"
+CR50_LSUSB_CMD = ["lsusb", "-vd", CR50_USB]
+ERASED_BID = "ffffffff"
DEBUG_MISSING_USB = """
Unable to find Cr50 Device 18d1:5014
@@ -128,13 +129,14 @@ DEBUG_DUT_CONTROL_OSERROR = """
Run from chroot if you are trying to use a /dev/pts ccd servo console
"""
+
class RMAOpen(object):
"""Used to find the cr50 console and run RMA open"""
- ENABLE_TESTLAB_CMD = 'ccd testlab enabled\n'
+ ENABLE_TESTLAB_CMD = "ccd testlab enabled\n"
def __init__(self, device=None, usb_serial=None, servo_port=None, ip=None):
- self.servo_port = servo_port if servo_port else '9999'
+ self.servo_port = servo_port if servo_port else "9999"
self.ip = ip
if device:
self.set_cr50_device(device)
@@ -142,18 +144,18 @@ class RMAOpen(object):
self.find_cr50_servo_uart()
else:
self.find_cr50_device(usb_serial)
- logging.info('DEVICE: %s', self.device)
+ logging.info("DEVICE: %s", self.device)
self.check_version()
self.print_platform_info()
- logging.info('Cr50 setup ok')
+ logging.info("Cr50 setup ok")
self.update_ccd_state()
self.using_ccd = self.device_is_running_with_servo_ccd()
def _dut_control(self, control):
"""Run dut-control and return the response"""
try:
- cmd = ['dut-control', '-p', self.servo_port, control]
- return subprocess.check_output(cmd, encoding='utf-8').strip()
+ cmd = ["dut-control", "-p", self.servo_port, control]
+ return subprocess.check_output(cmd, encoding="utf-8").strip()
except OSError:
logging.warning(DEBUG_DUT_CONTROL_OSERROR)
raise
@@ -163,8 +165,8 @@ class RMAOpen(object):
Find the console and configure it, so it can be used with this script.
"""
- self._dut_control('cr50_uart_timestamp:off')
- self.device = self._dut_control('cr50_uart_pty').split(':')[-1]
+ self._dut_control("cr50_uart_timestamp:off")
+ self.device = self._dut_control("cr50_uart_pty").split(":")[-1]
def set_cr50_device(self, device):
"""Save the device used for the console"""
@@ -183,38 +185,38 @@ class RMAOpen(object):
try:
ser = serial.Serial(self.device, timeout=1)
except OSError:
- logging.warning('Permission denied %s', self.device)
- logging.warning('Try running cr50_rma_open with sudo')
+ logging.warning("Permission denied %s", self.device)
+ logging.warning("Try running cr50_rma_open with sudo")
raise
- write_cmd = cmd + '\n\n'
- ser.write(write_cmd.encode('utf-8'))
+ write_cmd = cmd + "\n\n"
+ ser.write(write_cmd.encode("utf-8"))
if nbytes:
output = ser.read(nbytes)
else:
output = ser.readall()
ser.close()
- output = output.decode('utf-8').strip() if output else ''
+ output = output.decode("utf-8").strip() if output else ""
# Return only the command output
- split_cmd = cmd + '\r'
+ split_cmd = cmd + "\r"
if cmd and split_cmd in output:
- return ''.join(output.rpartition(split_cmd)[1::]).split('>')[0]
+ return "".join(output.rpartition(split_cmd)[1::]).split(">")[0]
return output
def device_is_running_with_servo_ccd(self):
"""Return True if the device is a servod ccd console"""
# servod uses /dev/pts consoles. Non-servod uses /dev/ttyUSBX
- if '/dev/pts' not in self.device:
+ if "/dev/pts" not in self.device:
return False
# If cr50 doesn't show rdd is connected, cr50 the device must not be
# a ccd device
- if 'Rdd: connected' not in self.send_cmd_get_output('ccdstate'):
+ if "Rdd: connected" not in self.send_cmd_get_output("ccdstate"):
return False
# Check if the servod is running with ccd. This requires the script
# is run in the chroot, so run it last.
- if 'ccd_cr50' not in self._dut_control('servo_type'):
+ if "ccd_cr50" not in self._dut_control("servo_type"):
return False
- logging.info('running through servod ccd')
+ logging.info("running through servod ccd")
return True
def get_rma_challenge(self):
@@ -239,14 +241,14 @@ class RMAOpen(object):
Returns:
The RMA challenge with all whitespace removed.
"""
- output = self.send_cmd_get_output('rma_auth').strip()
- logging.info('rma_auth output:\n%s', output)
+ output = self.send_cmd_get_output("rma_auth").strip()
+ logging.info("rma_auth output:\n%s", output)
# Extract the challenge from the console output
- if 'generated challenge:' in output:
- return output.split('generated challenge:')[-1].strip()
- challenge = ''.join(re.findall(r' \S{5}' * 4, output))
+ if "generated challenge:" in output:
+ return output.split("generated challenge:")[-1].strip()
+ challenge = "".join(re.findall(r" \S{5}" * 4, output))
# Remove all whitespace
- return re.sub(r'\s', '', challenge)
+ return re.sub(r"\s", "", challenge)
def generate_challenge_url(self, hwid):
"""Get the rma_auth challenge
@@ -257,12 +259,14 @@ class RMAOpen(object):
challenge = self.get_rma_challenge()
self.print_platform_info()
- logging.info('CHALLENGE: %s', challenge)
- logging.info('HWID: %s', hwid)
+ logging.info("CHALLENGE: %s", challenge)
+ logging.info("HWID: %s", hwid)
url = URL % (challenge, hwid)
- logging.info('GOTO:\n %s', url)
- logging.info('If the server fails to debug the challenge make sure the '
- 'RLZ is allowlisted')
+ logging.info("GOTO:\n %s", url)
+ logging.info(
+ "If the server fails to debug the challenge make sure the "
+ "RLZ is allowlisted"
+ )
def try_authcode(self, authcode):
"""Try opening cr50 with the authcode
@@ -272,48 +276,48 @@ class RMAOpen(object):
"""
# rma_auth may cause the system to reboot. Don't wait to read all that
# output. Read the first 300 bytes and call it a day.
- output = self.send_cmd_get_output('rma_auth ' + authcode, nbytes=300)
- logging.info('CR50 RESPONSE: %s', output)
- logging.info('waiting for cr50 reboot')
+ output = self.send_cmd_get_output("rma_auth " + authcode, nbytes=300)
+ logging.info("CR50 RESPONSE: %s", output)
+ logging.info("waiting for cr50 reboot")
# Cr50 may be rebooting. Wait a bit
time.sleep(5)
if self.using_ccd:
# After reboot, reset the ccd endpoints
- self._dut_control('power_state:ccd_reset')
+ self._dut_control("power_state:ccd_reset")
# Update the ccd state after the authcode attempt
self.update_ccd_state()
- authcode_match = 'process_response: success!' in output
+ authcode_match = "process_response: success!" in output
if not self.check(CCD_IS_UNRESTRICTED):
if not authcode_match:
logging.warning(DEBUG_AUTHCODE_MISMATCH)
- message = 'Authcode mismatch. Check args and url'
+ message = "Authcode mismatch. Check args and url"
else:
- message = 'Could not set all capability privileges to Always'
+ message = "Could not set all capability privileges to Always"
raise ValueError(message)
def wp_is_force_disabled(self):
"""Returns True if write protect is forced disabled"""
- output = self.send_cmd_get_output('wp')
- wp_state = output.split('Flash WP:', 1)[-1].split('\n', 1)[0].strip()
- logging.info('wp: %s', wp_state)
- return wp_state == 'forced disabled'
+ output = self.send_cmd_get_output("wp")
+ wp_state = output.split("Flash WP:", 1)[-1].split("\n", 1)[0].strip()
+ logging.info("wp: %s", wp_state)
+ return wp_state == "forced disabled"
def testlab_is_enabled(self):
"""Returns True if testlab mode is enabled"""
- output = self.send_cmd_get_output('ccd testlab')
- testlab_state = output.split('mode')[-1].strip().lower()
- logging.info('testlab: %s', testlab_state)
- return testlab_state == 'enabled'
+ output = self.send_cmd_get_output("ccd testlab")
+ testlab_state = output.split("mode")[-1].strip().lower()
+ logging.info("testlab: %s", testlab_state)
+ return testlab_state == "enabled"
def ccd_is_restricted(self):
"""Returns True if any of the capabilities are still restricted"""
- output = self.send_cmd_get_output('ccd')
- if 'Capabilities' not in output:
- raise ValueError('Could not get ccd output')
- logging.debug('CURRENT CCD SETTINGS:\n%s', output)
- restricted = 'IfOpened' in output or 'IfUnlocked' in output
- logging.info('ccd: %srestricted', '' if restricted else 'Un')
+ output = self.send_cmd_get_output("ccd")
+ if "Capabilities" not in output:
+ raise ValueError("Could not get ccd output")
+ logging.debug("CURRENT CCD SETTINGS:\n%s", output)
+ restricted = "IfOpened" in output or "IfUnlocked" in output
+ logging.info("ccd: %srestricted", "" if restricted else "Un")
return restricted
def update_ccd_state(self):
@@ -339,9 +343,10 @@ class RMAOpen(object):
def _capabilities_allow_open_from_console(self):
"""Return True if ccd open is Always allowed from usb"""
- output = self.send_cmd_get_output('ccd')
- return (re.search('OpenNoDevMode.*Always', output) and
- re.search('OpenFromUSB.*Always', output))
+ output = self.send_cmd_get_output("ccd")
+ return re.search("OpenNoDevMode.*Always", output) and re.search(
+ "OpenFromUSB.*Always", output
+ )
def _requires_dev_mode_open(self):
"""Return True if the image requires dev mode to open"""
@@ -354,78 +359,81 @@ class RMAOpen(object):
def _run_on_dut(self, command):
"""Run the command on the DUT."""
- return subprocess.check_output(['ssh', self.ip, command],
- encoding='utf-8')
+ return subprocess.check_output(["ssh", self.ip, command], encoding="utf-8")
def _open_in_dev_mode(self):
"""Open Cr50 when it's in dev mode"""
- output = self.send_cmd_get_output('ccd')
+ output = self.send_cmd_get_output("ccd")
# If the device is already open, nothing needs to be done.
- if 'State: Open' not in output:
+ if "State: Open" not in output:
# Verify the device is in devmode before trying to run open.
- if 'dev_mode' not in output:
- logging.warning('Enter dev mode to open ccd or update to %s',
- TESTLAB_PROD)
- raise ValueError('DUT not in dev mode')
+ if "dev_mode" not in output:
+ logging.warning(
+ "Enter dev mode to open ccd or update to %s", TESTLAB_PROD
+ )
+ raise ValueError("DUT not in dev mode")
if not self.ip:
- logging.warning("If your DUT doesn't have ssh support, run "
- "'gsctool -a -o' from the AP")
- raise ValueError('Cannot run ccd open without dut ip')
- self._run_on_dut('gsctool -a -o')
+ logging.warning(
+ "If your DUT doesn't have ssh support, run "
+ "'gsctool -a -o' from the AP"
+ )
+ raise ValueError("Cannot run ccd open without dut ip")
+ self._run_on_dut("gsctool -a -o")
# Wait >1 second for cr50 to update ccd state
time.sleep(3)
- output = self.send_cmd_get_output('ccd')
- if 'State: Open' not in output:
- raise ValueError('Could not open cr50')
- logging.info('ccd is open')
+ output = self.send_cmd_get_output("ccd")
+ if "State: Open" not in output:
+ raise ValueError("Could not open cr50")
+ logging.info("ccd is open")
def enable_testlab(self):
"""Disable write protect"""
if not self._has_testlab_support():
- logging.warning('Testlab mode is not supported in prod iamges')
+ logging.warning("Testlab mode is not supported in prod iamges")
return
# Some cr50 images need to be in dev mode before they can be opened.
if self._requires_dev_mode_open():
self._open_in_dev_mode()
else:
- self.send_cmd_get_output('ccd open')
- logging.info('Enabling testlab mode reqires pressing the power button.')
- logging.info('Once the process starts keep tapping the power button '
- 'for 10 seconds.')
+ self.send_cmd_get_output("ccd open")
+ logging.info("Enabling testlab mode reqires pressing the power button.")
+ logging.info(
+ "Once the process starts keep tapping the power button " "for 10 seconds."
+ )
input("Press Enter when you're ready to start...")
end_time = time.time() + 15
ser = serial.Serial(self.device, timeout=1)
- printed_lines = ''
- output = ''
+ printed_lines = ""
+ output = ""
# start ccd testlab enable
- ser.write(self.ENABLE_TESTLAB_CMD.encode('utf-8'))
- logging.info('start pressing the power button\n\n')
+ ser.write(self.ENABLE_TESTLAB_CMD.encode("utf-8"))
+ logging.info("start pressing the power button\n\n")
# Print all of the cr50 output as we get it, so the user will have more
# information about pressing the power button. Tapping the power button
# a couple of times should do it, but this will give us more confidence
# the process is still running/worked.
try:
while time.time() < end_time:
- output += ser.read(100).decode('utf-8')
- full_lines = output.rsplit('\n', 1)[0]
+ output += ser.read(100).decode("utf-8")
+ full_lines = output.rsplit("\n", 1)[0]
new_lines = full_lines
if printed_lines:
new_lines = full_lines.split(printed_lines, 1)[-1].strip()
- logging.info('\n%s', new_lines)
+ logging.info("\n%s", new_lines)
printed_lines = full_lines
# Make sure the process hasn't ended. If it has, print the last
# of the output and exit.
new_lines = output.split(printed_lines, 1)[-1]
- if 'CCD test lab mode enabled' in output:
+ if "CCD test lab mode enabled" in output:
# print the last of the ou
logging.info(new_lines)
break
- elif 'Physical presence check timeout' in output:
+ elif "Physical presence check timeout" in output:
logging.info(new_lines)
- logging.warning('Did not detect power button press in time')
- raise ValueError('Could not enable testlab mode try again')
+ logging.warning("Did not detect power button press in time")
+ raise ValueError("Could not enable testlab mode try again")
finally:
ser.close()
# Wait for the ccd hook to update things
@@ -433,44 +441,48 @@ class RMAOpen(object):
# Update the state after attempting to disable write protect
self.update_ccd_state()
if not self.check(TESTLAB_IS_ENABLED):
- raise ValueError('Could not enable testlab mode try again')
+ raise ValueError("Could not enable testlab mode try again")
def wp_disable(self):
"""Disable write protect"""
- logging.info('Disabling write protect')
- self.send_cmd_get_output('wp disable')
+ logging.info("Disabling write protect")
+ self.send_cmd_get_output("wp disable")
# Update the state after attempting to disable write protect
self.update_ccd_state()
if not self.check(WP_IS_DISABLED):
- raise ValueError('Could not disable write protect')
+ raise ValueError("Could not disable write protect")
def check_version(self):
"""Make sure cr50 is running a version that supports RMA Open"""
- output = self.send_cmd_get_output('version')
+ output = self.send_cmd_get_output("version")
if not output.strip():
logging.warning(DEBUG_DEVICE, self.device)
- raise ValueError('Could not communicate with %s' % self.device)
+ raise ValueError("Could not communicate with %s" % self.device)
- version = re.search(r'RW.*\* ([\d\.]+)/', output).group(1)
- logging.info('Running Cr50 Version: %s', version)
- self.running_ver_fields = [int(field) for field in version.split('.')]
+ version = re.search(r"RW.*\* ([\d\.]+)/", output).group(1)
+ logging.info("Running Cr50 Version: %s", version)
+ self.running_ver_fields = [int(field) for field in version.split(".")]
# prePVT images have even major versions. Prod have odd
self.is_prepvt = self.running_ver_fields[1] % 2 == 0
rma_support = RMA_SUPPORT_PREPVT if self.is_prepvt else RMA_SUPPORT_PROD
- logging.info('%s RMA support added in: %s',
- 'prePVT' if self.is_prepvt else 'prod', rma_support)
+ logging.info(
+ "%s RMA support added in: %s",
+ "prePVT" if self.is_prepvt else "prod",
+ rma_support,
+ )
if not self.is_prepvt and self._running_version_is_older(TESTLAB_PROD):
- raise ValueError('Update cr50. No testlab support in old prod '
- 'images.')
+ raise ValueError("Update cr50. No testlab support in old prod " "images.")
if self._running_version_is_older(rma_support):
- raise ValueError('%s does not have RMA support. Update to at '
- 'least %s' % (version, rma_support))
+ raise ValueError(
+ "%s does not have RMA support. Update to at "
+ "least %s" % (version, rma_support)
+ )
def _running_version_is_older(self, target_ver):
"""Returns True if running version is older than target_ver."""
- target_ver_fields = [int(field) for field in target_ver.split('.')]
+ target_ver_fields = [int(field) for field in target_ver.split(".")]
for i, field in enumerate(self.running_ver_fields):
if field > int(target_ver_fields[i]):
return False
@@ -486,11 +498,11 @@ class RMAOpen(object):
is no output or sysinfo doesn't contain the devid.
"""
self.set_cr50_device(device)
- sysinfo = self.send_cmd_get_output('sysinfo')
+ sysinfo = self.send_cmd_get_output("sysinfo")
# Make sure there is some output, and it shows it's from Cr50
- if not sysinfo or 'cr50' not in sysinfo:
+ if not sysinfo or "cr50" not in sysinfo:
return False
- logging.debug('Sysinfo output: %s', sysinfo)
+ logging.debug("Sysinfo output: %s", sysinfo)
# The cr50 device id should be in the sysinfo output, if we found
# the right console. Make sure it is
return devid in sysinfo
@@ -508,104 +520,137 @@ class RMAOpen(object):
ValueError if the console can't be found with the given serialname
"""
usb_serial = self.find_cr50_usb(usb_serial)
- logging.info('SERIALNAME: %s', usb_serial)
- devid = '0x' + ' 0x'.join(usb_serial.lower().split('-'))
- logging.info('DEVID: %s', devid)
+ logging.info("SERIALNAME: %s", usb_serial)
+ devid = "0x" + " 0x".join(usb_serial.lower().split("-"))
+ logging.info("DEVID: %s", devid)
# Get all the usb devices
- devices = glob.glob('/dev/ttyUSB*')
+ devices = glob.glob("/dev/ttyUSB*")
# Typically Cr50 has the lowest number. Sort the devices, so we're more
# likely to try the cr50 console first.
devices.sort()
# Find the one that is the cr50 console
for device in devices:
- logging.info('testing %s', device)
+ logging.info("testing %s", device)
if self.device_matches_devid(devid, device):
- logging.info('found device: %s', device)
+ logging.info("found device: %s", device)
return
logging.warning(DEBUG_CONNECTION)
- raise ValueError('Found USB device, but could not communicate with '
- 'cr50 console')
+ raise ValueError(
+ "Found USB device, but could not communicate with " "cr50 console"
+ )
def print_platform_info(self):
"""Print the cr50 BID RLZ code"""
- bid_output = self.send_cmd_get_output('bid')
- bid = re.search(r'Board ID: (\S+?)[:,]', bid_output).group(1)
+ bid_output = self.send_cmd_get_output("bid")
+ bid = re.search(r"Board ID: (\S+?)[:,]", bid_output).group(1)
if bid == ERASED_BID:
logging.warning(DEBUG_ERASED_BOARD_ID)
- raise ValueError('Cannot run RMA Open when board id is erased')
+ raise ValueError("Cannot run RMA Open when board id is erased")
bid = int(bid, 16)
- chrs = [chr((bid >> (8 * i)) & 0xff) for i in range(4)]
- logging.info('RLZ: %s', ''.join(chrs[::-1]))
+ chrs = [chr((bid >> (8 * i)) & 0xFF) for i in range(4)]
+ logging.info("RLZ: %s", "".join(chrs[::-1]))
@staticmethod
def find_cr50_usb(usb_serial):
"""Make sure the Cr50 USB device exists"""
try:
- output = subprocess.check_output(CR50_LSUSB_CMD, encoding='utf-8')
+ output = subprocess.check_output(CR50_LSUSB_CMD, encoding="utf-8")
except:
logging.warning(DEBUG_MISSING_USB)
- raise ValueError('Could not find Cr50 USB device')
- serialnames = re.findall(r'iSerial +\d+ (\S+)\s', output)
+ raise ValueError("Could not find Cr50 USB device")
+ serialnames = re.findall(r"iSerial +\d+ (\S+)\s", output)
if usb_serial:
if usb_serial not in serialnames:
logging.warning(DEBUG_SERIALNAME)
raise ValueError('Could not find usb device "%s"' % usb_serial)
return usb_serial
if len(serialnames) > 1:
- logging.info('Found Cr50 device serialnames %s',
- ', '.join(serialnames))
+ logging.info("Found Cr50 device serialnames %s", ", ".join(serialnames))
logging.warning(DEBUG_TOO_MANY_USB_DEVICES)
- raise ValueError('Too many cr50 usb devices')
+ raise ValueError("Too many cr50 usb devices")
return serialnames[0]
def print_dut_state(self):
"""Print CCD RMA and testlab mode state."""
if not self.check(CCD_IS_UNRESTRICTED):
- logging.info('CCD is still restricted.')
- logging.info('Run cr50_rma_open.py -g -i $HWID to generate a url')
- logging.info('Run cr50_rma_open.py -a $AUTHCODE to open cr50 with '
- 'an authcode')
+ logging.info("CCD is still restricted.")
+ logging.info("Run cr50_rma_open.py -g -i $HWID to generate a url")
+ logging.info(
+ "Run cr50_rma_open.py -a $AUTHCODE to open cr50 with " "an authcode"
+ )
elif not self.check(WP_IS_DISABLED):
- logging.info('WP is still enabled.')
- logging.info('Run cr50_rma_open.py -w to disable write protect')
+ logging.info("WP is still enabled.")
+ logging.info("Run cr50_rma_open.py -w to disable write protect")
if self.check(RMA_OPENED):
- logging.info('RMA Open complete')
+ logging.info("RMA Open complete")
if not self.check(TESTLAB_IS_ENABLED) and self.is_prepvt:
- logging.info('testlab mode is disabled.')
- logging.info('If you are prepping a device for the testlab, you '
- 'should enable testlab mode.')
- logging.info('Run cr50_rma_open.py -t to enable testlab mode')
+ logging.info("testlab mode is disabled.")
+ logging.info(
+ "If you are prepping a device for the testlab, you "
+ "should enable testlab mode."
+ )
+ logging.info("Run cr50_rma_open.py -t to enable testlab mode")
def parse_args(argv):
"""Get cr50_rma_open args."""
parser = argparse.ArgumentParser(
- description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
- parser.add_argument('-g', '--generate_challenge', action='store_true',
- help='Generate Cr50 challenge. Must be used with -i')
- parser.add_argument('-t', '--enable_testlab', action='store_true',
- help='enable testlab mode')
- parser.add_argument('-w', '--wp_disable', action='store_true',
- help='Disable write protect')
- parser.add_argument('-c', '--check_connection', action='store_true',
- help='Check cr50 console connection works')
- parser.add_argument('-s', '--serialname', type=str, default='',
- help='The cr50 usb serialname')
- parser.add_argument('-D', '--debug', action='store_true',
- help='print debug messages')
- parser.add_argument('-d', '--device', type=str, default='',
- help='cr50 console device ex /dev/ttyUSB0')
- parser.add_argument('-i', '--hwid', type=str, default='',
- help='The board hwid. Needed to generate a challenge')
- parser.add_argument('-a', '--authcode', type=str, default='',
- help='The authcode string from the challenge url')
- parser.add_argument('-P', '--servo_port', type=str, default='',
- help='the servo port')
- parser.add_argument('-I', '--ip', type=str, default='',
- help='The DUT IP. Necessary to do ccd open')
+ description=__doc__, formatter_class=argparse.RawTextHelpFormatter
+ )
+ parser.add_argument(
+ "-g",
+ "--generate_challenge",
+ action="store_true",
+ help="Generate Cr50 challenge. Must be used with -i",
+ )
+ parser.add_argument(
+ "-t", "--enable_testlab", action="store_true", help="enable testlab mode"
+ )
+ parser.add_argument(
+ "-w", "--wp_disable", action="store_true", help="Disable write protect"
+ )
+ parser.add_argument(
+ "-c",
+ "--check_connection",
+ action="store_true",
+ help="Check cr50 console connection works",
+ )
+ parser.add_argument(
+ "-s", "--serialname", type=str, default="", help="The cr50 usb serialname"
+ )
+ parser.add_argument(
+ "-D", "--debug", action="store_true", help="print debug messages"
+ )
+ parser.add_argument(
+ "-d",
+ "--device",
+ type=str,
+ default="",
+ help="cr50 console device ex /dev/ttyUSB0",
+ )
+ parser.add_argument(
+ "-i",
+ "--hwid",
+ type=str,
+ default="",
+ help="The board hwid. Needed to generate a challenge",
+ )
+ parser.add_argument(
+ "-a",
+ "--authcode",
+ type=str,
+ default="",
+ help="The authcode string from the challenge url",
+ )
+ parser.add_argument(
+ "-P", "--servo_port", type=str, default="", help="the servo port"
+ )
+ parser.add_argument(
+ "-I", "--ip", type=str, default="", help="The DUT IP. Necessary to do ccd open"
+ )
return parser.parse_args(argv)
@@ -614,52 +659,53 @@ def main(argv):
opts = parse_args(argv)
loglevel = logging.INFO
- log_format = '%(levelname)7s'
+ log_format = "%(levelname)7s"
if opts.debug:
loglevel = logging.DEBUG
- log_format += ' - %(lineno)3d:%(funcName)-15s'
- log_format += ' - %(message)s'
+ log_format += " - %(lineno)3d:%(funcName)-15s"
+ log_format += " - %(message)s"
logging.basicConfig(level=loglevel, format=log_format)
tried_authcode = False
- logging.info('Running cr50_rma_open version %s', SCRIPT_VERSION)
+ logging.info("Running cr50_rma_open version %s", SCRIPT_VERSION)
- cr50_rma_open = RMAOpen(opts.device, opts.serialname, opts.servo_port,
- opts.ip)
+ cr50_rma_open = RMAOpen(opts.device, opts.serialname, opts.servo_port, opts.ip)
if opts.check_connection:
sys.exit(0)
if not cr50_rma_open.check(CCD_IS_UNRESTRICTED):
if opts.generate_challenge:
if not opts.hwid:
- logging.warning('--hwid necessary to generate challenge url')
+ logging.warning("--hwid necessary to generate challenge url")
sys.exit(0)
cr50_rma_open.generate_challenge_url(opts.hwid)
sys.exit(0)
elif opts.authcode:
- logging.info('Using authcode: %s', opts.authcode)
+ logging.info("Using authcode: %s", opts.authcode)
cr50_rma_open.try_authcode(opts.authcode)
tried_authcode = True
- if not cr50_rma_open.check(WP_IS_DISABLED) and (tried_authcode or
- opts.wp_disable):
+ if not cr50_rma_open.check(WP_IS_DISABLED) and (tried_authcode or opts.wp_disable):
if not cr50_rma_open.check(CCD_IS_UNRESTRICTED):
- raise ValueError("Can't disable write protect unless ccd is "
- "open. Run through the rma open process first")
+ raise ValueError(
+ "Can't disable write protect unless ccd is "
+ "open. Run through the rma open process first"
+ )
if tried_authcode:
- logging.warning('RMA Open did not disable write protect. File a '
- 'bug')
- logging.warning('Trying to disable it manually')
+ logging.warning("RMA Open did not disable write protect. File a " "bug")
+ logging.warning("Trying to disable it manually")
cr50_rma_open.wp_disable()
if not cr50_rma_open.check(TESTLAB_IS_ENABLED) and opts.enable_testlab:
if not cr50_rma_open.check(CCD_IS_UNRESTRICTED):
- raise ValueError("Can't enable testlab mode unless ccd is open."
- "Run through the rma open process first")
+ raise ValueError(
+ "Can't enable testlab mode unless ccd is open."
+ "Run through the rma open process first"
+ )
cr50_rma_open.enable_testlab()
cr50_rma_open.print_dut_state()
-if __name__ == '__main__':
+if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
diff --git a/extra/stack_analyzer/stack_analyzer.py b/extra/stack_analyzer/stack_analyzer.py
index 77d16d5450..17b2651972 100755
--- a/extra/stack_analyzer/stack_analyzer.py
+++ b/extra/stack_analyzer/stack_analyzer.py
@@ -25,1848 +25,1992 @@ import ctypes
import os
import re
import subprocess
-import yaml
+import yaml
-SECTION_RO = 'RO'
-SECTION_RW = 'RW'
+SECTION_RO = "RO"
+SECTION_RW = "RW"
# Default size of extra stack frame needed by exception context switch.
# This value is for cortex-m with FPU enabled.
DEFAULT_EXCEPTION_FRAME_SIZE = 224
class StackAnalyzerError(Exception):
- """Exception class for stack analyzer utility."""
+ """Exception class for stack analyzer utility."""
class TaskInfo(ctypes.Structure):
- """Taskinfo ctypes structure.
+ """Taskinfo ctypes structure.
- The structure definition is corresponding to the "struct taskinfo"
- in "util/export_taskinfo.so.c".
- """
- _fields_ = [('name', ctypes.c_char_p),
- ('routine', ctypes.c_char_p),
- ('stack_size', ctypes.c_uint32)]
+ The structure definition is corresponding to the "struct taskinfo"
+ in "util/export_taskinfo.so.c".
+ """
+ _fields_ = [
+ ("name", ctypes.c_char_p),
+ ("routine", ctypes.c_char_p),
+ ("stack_size", ctypes.c_uint32),
+ ]
-class Task(object):
- """Task information.
- Attributes:
- name: Task name.
- routine_name: Routine function name.
- stack_max_size: Max stack size.
- routine_address: Resolved routine address. None if it hasn't been resolved.
- """
-
- def __init__(self, name, routine_name, stack_max_size, routine_address=None):
- """Constructor.
+class Task(object):
+ """Task information.
- Args:
+ Attributes:
name: Task name.
routine_name: Routine function name.
stack_max_size: Max stack size.
- routine_address: Resolved routine address.
+ routine_address: Resolved routine address. None if it hasn't been resolved.
"""
- self.name = name
- self.routine_name = routine_name
- self.stack_max_size = stack_max_size
- self.routine_address = routine_address
- def __eq__(self, other):
- """Task equality.
+ def __init__(self, name, routine_name, stack_max_size, routine_address=None):
+ """Constructor.
- Args:
- other: The compared object.
+ Args:
+ name: Task name.
+ routine_name: Routine function name.
+ stack_max_size: Max stack size.
+ routine_address: Resolved routine address.
+ """
+ self.name = name
+ self.routine_name = routine_name
+ self.stack_max_size = stack_max_size
+ self.routine_address = routine_address
- Returns:
- True if equal, False if not.
- """
- if not isinstance(other, Task):
- return False
+ def __eq__(self, other):
+ """Task equality.
- return (self.name == other.name and
- self.routine_name == other.routine_name and
- self.stack_max_size == other.stack_max_size and
- self.routine_address == other.routine_address)
+ Args:
+ other: The compared object.
+ Returns:
+ True if equal, False if not.
+ """
+ if not isinstance(other, Task):
+ return False
-class Symbol(object):
- """Symbol information.
+ return (
+ self.name == other.name
+ and self.routine_name == other.routine_name
+ and self.stack_max_size == other.stack_max_size
+ and self.routine_address == other.routine_address
+ )
- Attributes:
- address: Symbol address.
- symtype: Symbol type, 'O' (data, object) or 'F' (function).
- size: Symbol size.
- name: Symbol name.
- """
- def __init__(self, address, symtype, size, name):
- """Constructor.
+class Symbol(object):
+ """Symbol information.
- Args:
+ Attributes:
address: Symbol address.
- symtype: Symbol type.
+ symtype: Symbol type, 'O' (data, object) or 'F' (function).
size: Symbol size.
name: Symbol name.
"""
- assert symtype in ['O', 'F']
- self.address = address
- self.symtype = symtype
- self.size = size
- self.name = name
- def __eq__(self, other):
- """Symbol equality.
-
- Args:
- other: The compared object.
-
- Returns:
- True if equal, False if not.
- """
- if not isinstance(other, Symbol):
- return False
-
- return (self.address == other.address and
- self.symtype == other.symtype and
- self.size == other.size and
- self.name == other.name)
+ def __init__(self, address, symtype, size, name):
+ """Constructor.
+
+ Args:
+ address: Symbol address.
+ symtype: Symbol type.
+ size: Symbol size.
+ name: Symbol name.
+ """
+ assert symtype in ["O", "F"]
+ self.address = address
+ self.symtype = symtype
+ self.size = size
+ self.name = name
+
+ def __eq__(self, other):
+ """Symbol equality.
+
+ Args:
+ other: The compared object.
+
+ Returns:
+ True if equal, False if not.
+ """
+ if not isinstance(other, Symbol):
+ return False
+
+ return (
+ self.address == other.address
+ and self.symtype == other.symtype
+ and self.size == other.size
+ and self.name == other.name
+ )
class Callsite(object):
- """Function callsite.
-
- Attributes:
- address: Address of callsite location. None if it is unknown.
- target: Callee address. None if it is unknown.
- is_tail: A bool indicates that it is a tailing call.
- callee: Resolved callee function. None if it hasn't been resolved.
- """
-
- def __init__(self, address, target, is_tail, callee=None):
- """Constructor.
+ """Function callsite.
- Args:
+ Attributes:
address: Address of callsite location. None if it is unknown.
target: Callee address. None if it is unknown.
- is_tail: A bool indicates that it is a tailing call. (function jump to
- another function without restoring the stack frame)
- callee: Resolved callee function.
+ is_tail: A bool indicates that it is a tailing call.
+ callee: Resolved callee function. None if it hasn't been resolved.
"""
- # It makes no sense that both address and target are unknown.
- assert not (address is None and target is None)
- self.address = address
- self.target = target
- self.is_tail = is_tail
- self.callee = callee
-
- def __eq__(self, other):
- """Callsite equality.
-
- Args:
- other: The compared object.
- Returns:
- True if equal, False if not.
- """
- if not isinstance(other, Callsite):
- return False
-
- if not (self.address == other.address and
- self.target == other.target and
- self.is_tail == other.is_tail):
- return False
-
- if self.callee is None:
- return other.callee is None
- elif other.callee is None:
- return False
-
- # Assume the addresses of functions are unique.
- return self.callee.address == other.callee.address
+ def __init__(self, address, target, is_tail, callee=None):
+ """Constructor.
+
+ Args:
+ address: Address of callsite location. None if it is unknown.
+ target: Callee address. None if it is unknown.
+ is_tail: A bool indicates that it is a tailing call. (function jump to
+ another function without restoring the stack frame)
+ callee: Resolved callee function.
+ """
+ # It makes no sense that both address and target are unknown.
+ assert not (address is None and target is None)
+ self.address = address
+ self.target = target
+ self.is_tail = is_tail
+ self.callee = callee
+
+ def __eq__(self, other):
+ """Callsite equality.
+
+ Args:
+ other: The compared object.
+
+ Returns:
+ True if equal, False if not.
+ """
+ if not isinstance(other, Callsite):
+ return False
+
+ if not (
+ self.address == other.address
+ and self.target == other.target
+ and self.is_tail == other.is_tail
+ ):
+ return False
+
+ if self.callee is None:
+ return other.callee is None
+ elif other.callee is None:
+ return False
+
+ # Assume the addresses of functions are unique.
+ return self.callee.address == other.callee.address
class Function(object):
- """Function.
-
- Attributes:
- address: Address of function.
- name: Name of function from its symbol.
- stack_frame: Size of stack frame.
- callsites: Callsite list.
- stack_max_usage: Max stack usage. None if it hasn't been analyzed.
- stack_max_path: Max stack usage path. None if it hasn't been analyzed.
- """
+ """Function.
- def __init__(self, address, name, stack_frame, callsites):
- """Constructor.
-
- Args:
+ Attributes:
address: Address of function.
name: Name of function from its symbol.
stack_frame: Size of stack frame.
callsites: Callsite list.
+ stack_max_usage: Max stack usage. None if it hasn't been analyzed.
+ stack_max_path: Max stack usage path. None if it hasn't been analyzed.
"""
- self.address = address
- self.name = name
- self.stack_frame = stack_frame
- self.callsites = callsites
- self.stack_max_usage = None
- self.stack_max_path = None
-
- def __eq__(self, other):
- """Function equality.
-
- Args:
- other: The compared object.
-
- Returns:
- True if equal, False if not.
- """
- if not isinstance(other, Function):
- return False
-
- if not (self.address == other.address and
- self.name == other.name and
- self.stack_frame == other.stack_frame and
- self.callsites == other.callsites and
- self.stack_max_usage == other.stack_max_usage):
- return False
- if self.stack_max_path is None:
- return other.stack_max_path is None
- elif other.stack_max_path is None:
- return False
+ def __init__(self, address, name, stack_frame, callsites):
+ """Constructor.
+
+ Args:
+ address: Address of function.
+ name: Name of function from its symbol.
+ stack_frame: Size of stack frame.
+ callsites: Callsite list.
+ """
+ self.address = address
+ self.name = name
+ self.stack_frame = stack_frame
+ self.callsites = callsites
+ self.stack_max_usage = None
+ self.stack_max_path = None
+
+ def __eq__(self, other):
+ """Function equality.
+
+ Args:
+ other: The compared object.
+
+ Returns:
+ True if equal, False if not.
+ """
+ if not isinstance(other, Function):
+ return False
+
+ if not (
+ self.address == other.address
+ and self.name == other.name
+ and self.stack_frame == other.stack_frame
+ and self.callsites == other.callsites
+ and self.stack_max_usage == other.stack_max_usage
+ ):
+ return False
+
+ if self.stack_max_path is None:
+ return other.stack_max_path is None
+ elif other.stack_max_path is None:
+ return False
+
+ if len(self.stack_max_path) != len(other.stack_max_path):
+ return False
+
+ for self_func, other_func in zip(self.stack_max_path, other.stack_max_path):
+ # Assume the addresses of functions are unique.
+ if self_func.address != other_func.address:
+ return False
+
+ return True
+
+ def __hash__(self):
+ return id(self)
- if len(self.stack_max_path) != len(other.stack_max_path):
- return False
-
- for self_func, other_func in zip(self.stack_max_path, other.stack_max_path):
- # Assume the addresses of functions are unique.
- if self_func.address != other_func.address:
- return False
-
- return True
-
- def __hash__(self):
- return id(self)
class AndesAnalyzer(object):
- """Disassembly analyzer for Andes architecture.
-
- Public Methods:
- AnalyzeFunction: Analyze stack frame and callsites of the function.
- """
-
- GENERAL_PURPOSE_REGISTER_SIZE = 4
-
- # Possible condition code suffixes.
- CONDITION_CODES = [ 'eq', 'eqz', 'gez', 'gtz', 'lez', 'ltz', 'ne', 'nez',
- 'eqc', 'nec', 'nezs', 'nes', 'eqs']
- CONDITION_CODES_RE = '({})'.format('|'.join(CONDITION_CODES))
-
- IMM_ADDRESS_RE = r'([0-9A-Fa-f]+)\s+<([^>]+)>'
- # Branch instructions.
- JUMP_OPCODE_RE = re.compile(r'^(b{0}|j|jr|jr.|jrnez)(\d?|\d\d)$' \
- .format(CONDITION_CODES_RE))
- # Call instructions.
- CALL_OPCODE_RE = re.compile \
- (r'^(jal|jral|jral.|jralnez|beqzal|bltzal|bgezal)(\d)?$')
- CALL_OPERAND_RE = re.compile(r'^{}$'.format(IMM_ADDRESS_RE))
- # Ignore lp register because it's for return.
- INDIRECT_CALL_OPERAND_RE = re.compile \
- (r'^\$r\d{1,}$|\$fp$|\$gp$|\$ta$|\$sp$|\$pc$')
- # TODO: Handle other kinds of store instructions.
- PUSH_OPCODE_RE = re.compile(r'^push(\d{1,})$')
- PUSH_OPERAND_RE = re.compile(r'^\$r\d{1,}, \#\d{1,} \! \{([^\]]+)\}')
- SMW_OPCODE_RE = re.compile(r'^smw(\.\w\w|\.\w\w\w)$')
- SMW_OPERAND_RE = re.compile(r'^(\$r\d{1,}|\$\wp), \[\$\wp\], '
- r'(\$r\d{1,}|\$\wp), \#\d\w\d \! \{([^\]]+)\}')
- OPERANDGROUP_RE = re.compile(r'^\$r\d{1,}\~\$r\d{1,}')
-
- LWI_OPCODE_RE = re.compile(r'^lwi(\.\w\w)$')
- LWI_PC_OPERAND_RE = re.compile(r'^\$pc, \[([^\]]+)\]')
- # Example: "34280: 3f c8 0f ec addi.gp $fp, #0xfec"
- # Assume there is always a "\t" after the hex data.
- DISASM_REGEX_RE = re.compile(r'^(?P<address>[0-9A-Fa-f]+):\s+'
- r'(?P<words>[0-9A-Fa-f ]+)'
- r'\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?')
-
- def ParseInstruction(self, line, function_end):
- """Parse the line of instruction.
+ """Disassembly analyzer for Andes architecture.
- Args:
- line: Text of disassembly.
- function_end: End address of the current function. None if unknown.
-
- Returns:
- (address, words, opcode, operand_text): The instruction address, words,
- opcode, and the text of operands.
- None if it isn't an instruction line.
+ Public Methods:
+ AnalyzeFunction: Analyze stack frame and callsites of the function.
"""
- result = self.DISASM_REGEX_RE.match(line)
- if result is None:
- return None
-
- address = int(result.group('address'), 16)
- # Check if it's out of bound.
- if function_end is not None and address >= function_end:
- return None
-
- opcode = result.group('opcode').strip()
- operand_text = result.group('operand')
- words = result.group('words')
- if operand_text is None:
- operand_text = ''
- else:
- operand_text = operand_text.strip()
-
- return (address, words, opcode, operand_text)
-
- def AnalyzeFunction(self, function_symbol, instructions):
-
- stack_frame = 0
- callsites = []
- for address, words, opcode, operand_text in instructions:
- is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None
- is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None
-
- if is_jump_opcode or is_call_opcode:
- is_tail = is_jump_opcode
-
- result = self.CALL_OPERAND_RE.match(operand_text)
+ GENERAL_PURPOSE_REGISTER_SIZE = 4
+
+ # Possible condition code suffixes.
+ CONDITION_CODES = [
+ "eq",
+ "eqz",
+ "gez",
+ "gtz",
+ "lez",
+ "ltz",
+ "ne",
+ "nez",
+ "eqc",
+ "nec",
+ "nezs",
+ "nes",
+ "eqs",
+ ]
+ CONDITION_CODES_RE = "({})".format("|".join(CONDITION_CODES))
+
+ IMM_ADDRESS_RE = r"([0-9A-Fa-f]+)\s+<([^>]+)>"
+ # Branch instructions.
+ JUMP_OPCODE_RE = re.compile(
+ r"^(b{0}|j|jr|jr.|jrnez)(\d?|\d\d)$".format(CONDITION_CODES_RE)
+ )
+ # Call instructions.
+ CALL_OPCODE_RE = re.compile(r"^(jal|jral|jral.|jralnez|beqzal|bltzal|bgezal)(\d)?$")
+ CALL_OPERAND_RE = re.compile(r"^{}$".format(IMM_ADDRESS_RE))
+ # Ignore lp register because it's for return.
+ INDIRECT_CALL_OPERAND_RE = re.compile(r"^\$r\d{1,}$|\$fp$|\$gp$|\$ta$|\$sp$|\$pc$")
+ # TODO: Handle other kinds of store instructions.
+ PUSH_OPCODE_RE = re.compile(r"^push(\d{1,})$")
+ PUSH_OPERAND_RE = re.compile(r"^\$r\d{1,}, \#\d{1,} \! \{([^\]]+)\}")
+ SMW_OPCODE_RE = re.compile(r"^smw(\.\w\w|\.\w\w\w)$")
+ SMW_OPERAND_RE = re.compile(
+ r"^(\$r\d{1,}|\$\wp), \[\$\wp\], "
+ r"(\$r\d{1,}|\$\wp), \#\d\w\d \! \{([^\]]+)\}"
+ )
+ OPERANDGROUP_RE = re.compile(r"^\$r\d{1,}\~\$r\d{1,}")
+
+ LWI_OPCODE_RE = re.compile(r"^lwi(\.\w\w)$")
+ LWI_PC_OPERAND_RE = re.compile(r"^\$pc, \[([^\]]+)\]")
+ # Example: "34280: 3f c8 0f ec addi.gp $fp, #0xfec"
+ # Assume there is always a "\t" after the hex data.
+ DISASM_REGEX_RE = re.compile(
+ r"^(?P<address>[0-9A-Fa-f]+):\s+"
+ r"(?P<words>[0-9A-Fa-f ]+)"
+ r"\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?"
+ )
+
+ def ParseInstruction(self, line, function_end):
+ """Parse the line of instruction.
+
+ Args:
+ line: Text of disassembly.
+ function_end: End address of the current function. None if unknown.
+
+ Returns:
+ (address, words, opcode, operand_text): The instruction address, words,
+ opcode, and the text of operands.
+ None if it isn't an instruction line.
+ """
+ result = self.DISASM_REGEX_RE.match(line)
if result is None:
- if (self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None):
- # Found an indirect call.
- callsites.append(Callsite(address, None, is_tail))
-
+ return None
+
+ address = int(result.group("address"), 16)
+ # Check if it's out of bound.
+ if function_end is not None and address >= function_end:
+ return None
+
+ opcode = result.group("opcode").strip()
+ operand_text = result.group("operand")
+ words = result.group("words")
+ if operand_text is None:
+ operand_text = ""
else:
- target_address = int(result.group(1), 16)
- # Filter out the in-function target (branches and in-function calls,
- # which are actually branches).
- if not (function_symbol.size > 0 and
- function_symbol.address < target_address <
- (function_symbol.address + function_symbol.size)):
- # Maybe it is a callsite.
- callsites.append(Callsite(address, target_address, is_tail))
-
- elif self.LWI_OPCODE_RE.match(opcode) is not None:
- result = self.LWI_PC_OPERAND_RE.match(operand_text)
- if result is not None:
- # Ignore "lwi $pc, [$sp], xx" because it's usually a return.
- if result.group(1) != '$sp':
- # Found an indirect call.
- callsites.append(Callsite(address, None, True))
-
- elif self.PUSH_OPCODE_RE.match(opcode) is not None:
- # Example: fc 20 push25 $r8, #0 ! {$r6~$r8, $fp, $gp, $lp}
- if self.PUSH_OPERAND_RE.match(operand_text) is not None:
- # capture fc 20
- imm5u = int(words.split(' ')[1], 16)
- # sp = sp - (imm5u << 3)
- imm8u = (imm5u<<3) & 0xff
- stack_frame += imm8u
-
- result = self.PUSH_OPERAND_RE.match(operand_text)
- operandgroup_text = result.group(1)
- # capture $rx~$ry
- if self.OPERANDGROUP_RE.match(operandgroup_text) is not None:
- # capture number & transfer string to integer
- oprandgrouphead = operandgroup_text.split(',')[0]
- rx=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[0])))
- ry=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[1])))
-
- stack_frame += ((len(operandgroup_text.split(','))+ry-rx) *
- self.GENERAL_PURPOSE_REGISTER_SIZE)
- else:
- stack_frame += (len(operandgroup_text.split(',')) *
- self.GENERAL_PURPOSE_REGISTER_SIZE)
-
- elif self.SMW_OPCODE_RE.match(opcode) is not None:
- # Example: smw.adm $r6, [$sp], $r10, #0x2 ! {$r6~$r10, $lp}
- if self.SMW_OPERAND_RE.match(operand_text) is not None:
- result = self.SMW_OPERAND_RE.match(operand_text)
- operandgroup_text = result.group(3)
- # capture $rx~$ry
- if self.OPERANDGROUP_RE.match(operandgroup_text) is not None:
- # capture number & transfer string to integer
- oprandgrouphead = operandgroup_text.split(',')[0]
- rx=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[0])))
- ry=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[1])))
-
- stack_frame += ((len(operandgroup_text.split(','))+ry-rx) *
- self.GENERAL_PURPOSE_REGISTER_SIZE)
- else:
- stack_frame += (len(operandgroup_text.split(',')) *
- self.GENERAL_PURPOSE_REGISTER_SIZE)
-
- return (stack_frame, callsites)
-
-class ArmAnalyzer(object):
- """Disassembly analyzer for ARM architecture.
-
- Public Methods:
- AnalyzeFunction: Analyze stack frame and callsites of the function.
- """
-
- GENERAL_PURPOSE_REGISTER_SIZE = 4
-
- # Possible condition code suffixes.
- CONDITION_CODES = ['', 'eq', 'ne', 'cs', 'hs', 'cc', 'lo', 'mi', 'pl', 'vs',
- 'vc', 'hi', 'ls', 'ge', 'lt', 'gt', 'le']
- CONDITION_CODES_RE = '({})'.format('|'.join(CONDITION_CODES))
- # Assume there is no function name containing ">".
- IMM_ADDRESS_RE = r'([0-9A-Fa-f]+)\s+<([^>]+)>'
-
- # Fuzzy regular expressions for instruction and operand parsing.
- # Branch instructions.
- JUMP_OPCODE_RE = re.compile(
- r'^(b{0}|bx{0})(\.\w)?$'.format(CONDITION_CODES_RE))
- # Call instructions.
- CALL_OPCODE_RE = re.compile(
- r'^(bl{0}|blx{0})(\.\w)?$'.format(CONDITION_CODES_RE))
- CALL_OPERAND_RE = re.compile(r'^{}$'.format(IMM_ADDRESS_RE))
- CBZ_CBNZ_OPCODE_RE = re.compile(r'^(cbz|cbnz)(\.\w)?$')
- # Example: "r0, 1009bcbe <host_cmd_motion_sense+0x1d2>"
- CBZ_CBNZ_OPERAND_RE = re.compile(r'^[^,]+,\s+{}$'.format(IMM_ADDRESS_RE))
- # Ignore lr register because it's for return.
- INDIRECT_CALL_OPERAND_RE = re.compile(r'^r\d+|sb|sl|fp|ip|sp|pc$')
- # TODO(cheyuw): Handle conditional versions of following
- # instructions.
- # TODO(cheyuw): Handle other kinds of pc modifying instructions (e.g. mov pc).
- LDR_OPCODE_RE = re.compile(r'^ldr(\.\w)?$')
- # Example: "pc, [sp], #4"
- LDR_PC_OPERAND_RE = re.compile(r'^pc, \[([^\]]+)\]')
- # TODO(cheyuw): Handle other kinds of stm instructions.
- PUSH_OPCODE_RE = re.compile(r'^push$')
- STM_OPCODE_RE = re.compile(r'^stmdb$')
- # Stack subtraction instructions.
- SUB_OPCODE_RE = re.compile(r'^sub(s|w)?(\.\w)?$')
- SUB_OPERAND_RE = re.compile(r'^sp[^#]+#(\d+)')
- # Example: "44d94: f893 0068 ldrb.w r0, [r3, #104] ; 0x68"
- # Assume there is always a "\t" after the hex data.
- DISASM_REGEX_RE = re.compile(r'^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+'
- r'\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?')
-
- def ParseInstruction(self, line, function_end):
- """Parse the line of instruction.
+ operand_text = operand_text.strip()
+
+ return (address, words, opcode, operand_text)
+
+ def AnalyzeFunction(self, function_symbol, instructions):
+
+ stack_frame = 0
+ callsites = []
+ for address, words, opcode, operand_text in instructions:
+ is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None
+ is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None
+
+ if is_jump_opcode or is_call_opcode:
+ is_tail = is_jump_opcode
+
+ result = self.CALL_OPERAND_RE.match(operand_text)
+
+ if result is None:
+ if self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None:
+ # Found an indirect call.
+ callsites.append(Callsite(address, None, is_tail))
+
+ else:
+ target_address = int(result.group(1), 16)
+ # Filter out the in-function target (branches and in-function calls,
+ # which are actually branches).
+ if not (
+ function_symbol.size > 0
+ and function_symbol.address
+ < target_address
+ < (function_symbol.address + function_symbol.size)
+ ):
+ # Maybe it is a callsite.
+ callsites.append(Callsite(address, target_address, is_tail))
+
+ elif self.LWI_OPCODE_RE.match(opcode) is not None:
+ result = self.LWI_PC_OPERAND_RE.match(operand_text)
+ if result is not None:
+ # Ignore "lwi $pc, [$sp], xx" because it's usually a return.
+ if result.group(1) != "$sp":
+ # Found an indirect call.
+ callsites.append(Callsite(address, None, True))
+
+ elif self.PUSH_OPCODE_RE.match(opcode) is not None:
+ # Example: fc 20 push25 $r8, #0 ! {$r6~$r8, $fp, $gp, $lp}
+ if self.PUSH_OPERAND_RE.match(operand_text) is not None:
+ # capture fc 20
+ imm5u = int(words.split(" ")[1], 16)
+ # sp = sp - (imm5u << 3)
+ imm8u = (imm5u << 3) & 0xFF
+ stack_frame += imm8u
+
+ result = self.PUSH_OPERAND_RE.match(operand_text)
+ operandgroup_text = result.group(1)
+ # capture $rx~$ry
+ if self.OPERANDGROUP_RE.match(operandgroup_text) is not None:
+ # capture number & transfer string to integer
+ oprandgrouphead = operandgroup_text.split(",")[0]
+ rx = int(
+ "".join(filter(str.isdigit, oprandgrouphead.split("~")[0]))
+ )
+ ry = int(
+ "".join(filter(str.isdigit, oprandgrouphead.split("~")[1]))
+ )
+
+ stack_frame += (
+ len(operandgroup_text.split(",")) + ry - rx
+ ) * self.GENERAL_PURPOSE_REGISTER_SIZE
+ else:
+ stack_frame += (
+ len(operandgroup_text.split(","))
+ * self.GENERAL_PURPOSE_REGISTER_SIZE
+ )
+
+ elif self.SMW_OPCODE_RE.match(opcode) is not None:
+ # Example: smw.adm $r6, [$sp], $r10, #0x2 ! {$r6~$r10, $lp}
+ if self.SMW_OPERAND_RE.match(operand_text) is not None:
+ result = self.SMW_OPERAND_RE.match(operand_text)
+ operandgroup_text = result.group(3)
+ # capture $rx~$ry
+ if self.OPERANDGROUP_RE.match(operandgroup_text) is not None:
+ # capture number & transfer string to integer
+ oprandgrouphead = operandgroup_text.split(",")[0]
+ rx = int(
+ "".join(filter(str.isdigit, oprandgrouphead.split("~")[0]))
+ )
+ ry = int(
+ "".join(filter(str.isdigit, oprandgrouphead.split("~")[1]))
+ )
+
+ stack_frame += (
+ len(operandgroup_text.split(",")) + ry - rx
+ ) * self.GENERAL_PURPOSE_REGISTER_SIZE
+ else:
+ stack_frame += (
+ len(operandgroup_text.split(","))
+ * self.GENERAL_PURPOSE_REGISTER_SIZE
+ )
+
+ return (stack_frame, callsites)
- Args:
- line: Text of disassembly.
- function_end: End address of the current function. None if unknown.
- Returns:
- (address, opcode, operand_text): The instruction address, opcode,
- and the text of operands. None if it
- isn't an instruction line.
- """
- result = self.DISASM_REGEX_RE.match(line)
- if result is None:
- return None
-
- address = int(result.group('address'), 16)
- # Check if it's out of bound.
- if function_end is not None and address >= function_end:
- return None
-
- opcode = result.group('opcode').strip()
- operand_text = result.group('operand')
- if operand_text is None:
- operand_text = ''
- else:
- operand_text = operand_text.strip()
-
- return (address, opcode, operand_text)
-
- def AnalyzeFunction(self, function_symbol, instructions):
- """Analyze function, resolve the size of stack frame and callsites.
-
- Args:
- function_symbol: Function symbol.
- instructions: Instruction list.
+class ArmAnalyzer(object):
+ """Disassembly analyzer for ARM architecture.
- Returns:
- (stack_frame, callsites): Size of stack frame, callsite list.
+ Public Methods:
+ AnalyzeFunction: Analyze stack frame and callsites of the function.
"""
- stack_frame = 0
- callsites = []
- for address, opcode, operand_text in instructions:
- is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None
- is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None
- is_cbz_cbnz_opcode = self.CBZ_CBNZ_OPCODE_RE.match(opcode) is not None
- if is_jump_opcode or is_call_opcode or is_cbz_cbnz_opcode:
- is_tail = is_jump_opcode or is_cbz_cbnz_opcode
-
- if is_cbz_cbnz_opcode:
- result = self.CBZ_CBNZ_OPERAND_RE.match(operand_text)
- else:
- result = self.CALL_OPERAND_RE.match(operand_text)
+ GENERAL_PURPOSE_REGISTER_SIZE = 4
+
+ # Possible condition code suffixes.
+ CONDITION_CODES = [
+ "",
+ "eq",
+ "ne",
+ "cs",
+ "hs",
+ "cc",
+ "lo",
+ "mi",
+ "pl",
+ "vs",
+ "vc",
+ "hi",
+ "ls",
+ "ge",
+ "lt",
+ "gt",
+ "le",
+ ]
+ CONDITION_CODES_RE = "({})".format("|".join(CONDITION_CODES))
+ # Assume there is no function name containing ">".
+ IMM_ADDRESS_RE = r"([0-9A-Fa-f]+)\s+<([^>]+)>"
+
+ # Fuzzy regular expressions for instruction and operand parsing.
+ # Branch instructions.
+ JUMP_OPCODE_RE = re.compile(r"^(b{0}|bx{0})(\.\w)?$".format(CONDITION_CODES_RE))
+ # Call instructions.
+ CALL_OPCODE_RE = re.compile(r"^(bl{0}|blx{0})(\.\w)?$".format(CONDITION_CODES_RE))
+ CALL_OPERAND_RE = re.compile(r"^{}$".format(IMM_ADDRESS_RE))
+ CBZ_CBNZ_OPCODE_RE = re.compile(r"^(cbz|cbnz)(\.\w)?$")
+ # Example: "r0, 1009bcbe <host_cmd_motion_sense+0x1d2>"
+ CBZ_CBNZ_OPERAND_RE = re.compile(r"^[^,]+,\s+{}$".format(IMM_ADDRESS_RE))
+ # Ignore lr register because it's for return.
+ INDIRECT_CALL_OPERAND_RE = re.compile(r"^r\d+|sb|sl|fp|ip|sp|pc$")
+ # TODO(cheyuw): Handle conditional versions of following
+ # instructions.
+ # TODO(cheyuw): Handle other kinds of pc modifying instructions (e.g. mov pc).
+ LDR_OPCODE_RE = re.compile(r"^ldr(\.\w)?$")
+ # Example: "pc, [sp], #4"
+ LDR_PC_OPERAND_RE = re.compile(r"^pc, \[([^\]]+)\]")
+ # TODO(cheyuw): Handle other kinds of stm instructions.
+ PUSH_OPCODE_RE = re.compile(r"^push$")
+ STM_OPCODE_RE = re.compile(r"^stmdb$")
+ # Stack subtraction instructions.
+ SUB_OPCODE_RE = re.compile(r"^sub(s|w)?(\.\w)?$")
+ SUB_OPERAND_RE = re.compile(r"^sp[^#]+#(\d+)")
+ # Example: "44d94: f893 0068 ldrb.w r0, [r3, #104] ; 0x68"
+ # Assume there is always a "\t" after the hex data.
+ DISASM_REGEX_RE = re.compile(
+ r"^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+"
+ r"\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?"
+ )
+
+ def ParseInstruction(self, line, function_end):
+ """Parse the line of instruction.
+
+ Args:
+ line: Text of disassembly.
+ function_end: End address of the current function. None if unknown.
+
+ Returns:
+ (address, opcode, operand_text): The instruction address, opcode,
+ and the text of operands. None if it
+ isn't an instruction line.
+ """
+ result = self.DISASM_REGEX_RE.match(line)
if result is None:
- # Failed to match immediate address, maybe it is an indirect call.
- # CBZ and CBNZ can't be indirect calls.
- if (not is_cbz_cbnz_opcode and
- self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None):
- # Found an indirect call.
- callsites.append(Callsite(address, None, is_tail))
+ return None
- else:
- target_address = int(result.group(1), 16)
- # Filter out the in-function target (branches and in-function calls,
- # which are actually branches).
- if not (function_symbol.size > 0 and
- function_symbol.address < target_address <
- (function_symbol.address + function_symbol.size)):
- # Maybe it is a callsite.
- callsites.append(Callsite(address, target_address, is_tail))
-
- elif self.LDR_OPCODE_RE.match(opcode) is not None:
- result = self.LDR_PC_OPERAND_RE.match(operand_text)
- if result is not None:
- # Ignore "ldr pc, [sp], xx" because it's usually a return.
- if result.group(1) != 'sp':
- # Found an indirect call.
- callsites.append(Callsite(address, None, True))
-
- elif self.PUSH_OPCODE_RE.match(opcode) is not None:
- # Example: "{r4, r5, r6, r7, lr}"
- stack_frame += (len(operand_text.split(',')) *
- self.GENERAL_PURPOSE_REGISTER_SIZE)
- elif self.SUB_OPCODE_RE.match(opcode) is not None:
- result = self.SUB_OPERAND_RE.match(operand_text)
- if result is not None:
- stack_frame += int(result.group(1))
- else:
- # Unhandled stack register subtraction.
- assert not operand_text.startswith('sp')
+ address = int(result.group("address"), 16)
+ # Check if it's out of bound.
+ if function_end is not None and address >= function_end:
+ return None
- elif self.STM_OPCODE_RE.match(opcode) is not None:
- if operand_text.startswith('sp!'):
- # Subtract and writeback to stack register.
- # Example: "sp!, {r4, r5, r6, r7, r8, r9, lr}"
- # Get the text of pushed register list.
- unused_sp, unused_sep, parameter_text = operand_text.partition(',')
- stack_frame += (len(parameter_text.split(',')) *
- self.GENERAL_PURPOSE_REGISTER_SIZE)
+ opcode = result.group("opcode").strip()
+ operand_text = result.group("operand")
+ if operand_text is None:
+ operand_text = ""
+ else:
+ operand_text = operand_text.strip()
+
+ return (address, opcode, operand_text)
+
+ def AnalyzeFunction(self, function_symbol, instructions):
+ """Analyze function, resolve the size of stack frame and callsites.
+
+ Args:
+ function_symbol: Function symbol.
+ instructions: Instruction list.
+
+ Returns:
+ (stack_frame, callsites): Size of stack frame, callsite list.
+ """
+ stack_frame = 0
+ callsites = []
+ for address, opcode, operand_text in instructions:
+ is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None
+ is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None
+ is_cbz_cbnz_opcode = self.CBZ_CBNZ_OPCODE_RE.match(opcode) is not None
+ if is_jump_opcode or is_call_opcode or is_cbz_cbnz_opcode:
+ is_tail = is_jump_opcode or is_cbz_cbnz_opcode
+
+ if is_cbz_cbnz_opcode:
+ result = self.CBZ_CBNZ_OPERAND_RE.match(operand_text)
+ else:
+ result = self.CALL_OPERAND_RE.match(operand_text)
+
+ if result is None:
+ # Failed to match immediate address, maybe it is an indirect call.
+ # CBZ and CBNZ can't be indirect calls.
+ if (
+ not is_cbz_cbnz_opcode
+ and self.INDIRECT_CALL_OPERAND_RE.match(operand_text)
+ is not None
+ ):
+ # Found an indirect call.
+ callsites.append(Callsite(address, None, is_tail))
+
+ else:
+ target_address = int(result.group(1), 16)
+ # Filter out the in-function target (branches and in-function calls,
+ # which are actually branches).
+ if not (
+ function_symbol.size > 0
+ and function_symbol.address
+ < target_address
+ < (function_symbol.address + function_symbol.size)
+ ):
+ # Maybe it is a callsite.
+ callsites.append(Callsite(address, target_address, is_tail))
+
+ elif self.LDR_OPCODE_RE.match(opcode) is not None:
+ result = self.LDR_PC_OPERAND_RE.match(operand_text)
+ if result is not None:
+ # Ignore "ldr pc, [sp], xx" because it's usually a return.
+ if result.group(1) != "sp":
+ # Found an indirect call.
+ callsites.append(Callsite(address, None, True))
+
+ elif self.PUSH_OPCODE_RE.match(opcode) is not None:
+ # Example: "{r4, r5, r6, r7, lr}"
+ stack_frame += (
+ len(operand_text.split(",")) * self.GENERAL_PURPOSE_REGISTER_SIZE
+ )
+ elif self.SUB_OPCODE_RE.match(opcode) is not None:
+ result = self.SUB_OPERAND_RE.match(operand_text)
+ if result is not None:
+ stack_frame += int(result.group(1))
+ else:
+ # Unhandled stack register subtraction.
+ assert not operand_text.startswith("sp")
+
+ elif self.STM_OPCODE_RE.match(opcode) is not None:
+ if operand_text.startswith("sp!"):
+ # Subtract and writeback to stack register.
+ # Example: "sp!, {r4, r5, r6, r7, r8, r9, lr}"
+ # Get the text of pushed register list.
+ unused_sp, unused_sep, parameter_text = operand_text.partition(",")
+ stack_frame += (
+ len(parameter_text.split(","))
+ * self.GENERAL_PURPOSE_REGISTER_SIZE
+ )
+
+ return (stack_frame, callsites)
- return (stack_frame, callsites)
class RiscvAnalyzer(object):
- """Disassembly analyzer for RISC-V architecture.
-
- Public Methods:
- AnalyzeFunction: Analyze stack frame and callsites of the function.
- """
-
- # Possible condition code suffixes.
- CONDITION_CODES = [ 'eqz', 'nez', 'lez', 'gez', 'ltz', 'gtz', 'gt', 'le',
- 'gtu', 'leu', 'eq', 'ne', 'ge', 'lt', 'ltu', 'geu']
- CONDITION_CODES_RE = '({})'.format('|'.join(CONDITION_CODES))
- # Branch instructions.
- JUMP_OPCODE_RE = re.compile(r'^(b{0}|j|jr)$'.format(CONDITION_CODES_RE))
- # Call instructions.
- CALL_OPCODE_RE = re.compile(r'^(jal|jalr)$')
- # Example: "j 8009b318 <set_state_prl_hr>" or
- # "jal ra,800a4394 <power_get_signals>" or
- # "bltu t0,t1,80080300 <data_loop>"
- JUMP_ADDRESS_RE = r'((\w(\w|\d\d),){0,2})([0-9A-Fa-f]+)\s+<([^>]+)>'
- CALL_OPERAND_RE = re.compile(r'^{}$'.format(JUMP_ADDRESS_RE))
- # Capture address, Example: 800a4394
- CAPTURE_ADDRESS = re.compile(r'[0-9A-Fa-f]{8}')
- # Indirect jump, Example: jalr a5
- INDIRECT_CALL_OPERAND_RE = re.compile(r'^t\d+|s\d+|a\d+$')
- # Example: addi
- ADDI_OPCODE_RE = re.compile(r'^addi$')
- # Allocate stack instructions.
- ADDI_OPERAND_RE = re.compile(r'^(sp,sp,-\d+)$')
- # Example: "800804b6: 1101 addi sp,sp,-32"
- DISASM_REGEX_RE = re.compile(r'^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+'
- r'\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?')
-
- def ParseInstruction(self, line, function_end):
- """Parse the line of instruction.
+ """Disassembly analyzer for RISC-V architecture.
- Args:
- line: Text of disassembly.
- function_end: End address of the current function. None if unknown.
-
- Returns:
- (address, opcode, operand_text): The instruction address, opcode,
- and the text of operands. None if it
- isn't an instruction line.
+ Public Methods:
+ AnalyzeFunction: Analyze stack frame and callsites of the function.
"""
- result = self.DISASM_REGEX_RE.match(line)
- if result is None:
- return None
-
- address = int(result.group('address'), 16)
- # Check if it's out of bound.
- if function_end is not None and address >= function_end:
- return None
-
- opcode = result.group('opcode').strip()
- operand_text = result.group('operand')
- if operand_text is None:
- operand_text = ''
- else:
- operand_text = operand_text.strip()
-
- return (address, opcode, operand_text)
-
- def AnalyzeFunction(self, function_symbol, instructions):
-
- stack_frame = 0
- callsites = []
- for address, opcode, operand_text in instructions:
- is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None
- is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None
- if is_jump_opcode or is_call_opcode:
- is_tail = is_jump_opcode
-
- result = self.CALL_OPERAND_RE.match(operand_text)
+ # Possible condition code suffixes.
+ CONDITION_CODES = [
+ "eqz",
+ "nez",
+ "lez",
+ "gez",
+ "ltz",
+ "gtz",
+ "gt",
+ "le",
+ "gtu",
+ "leu",
+ "eq",
+ "ne",
+ "ge",
+ "lt",
+ "ltu",
+ "geu",
+ ]
+ CONDITION_CODES_RE = "({})".format("|".join(CONDITION_CODES))
+ # Branch instructions.
+ JUMP_OPCODE_RE = re.compile(r"^(b{0}|j|jr)$".format(CONDITION_CODES_RE))
+ # Call instructions.
+ CALL_OPCODE_RE = re.compile(r"^(jal|jalr)$")
+ # Example: "j 8009b318 <set_state_prl_hr>" or
+ # "jal ra,800a4394 <power_get_signals>" or
+ # "bltu t0,t1,80080300 <data_loop>"
+ JUMP_ADDRESS_RE = r"((\w(\w|\d\d),){0,2})([0-9A-Fa-f]+)\s+<([^>]+)>"
+ CALL_OPERAND_RE = re.compile(r"^{}$".format(JUMP_ADDRESS_RE))
+ # Capture address, Example: 800a4394
+ CAPTURE_ADDRESS = re.compile(r"[0-9A-Fa-f]{8}")
+ # Indirect jump, Example: jalr a5
+ INDIRECT_CALL_OPERAND_RE = re.compile(r"^t\d+|s\d+|a\d+$")
+ # Example: addi
+ ADDI_OPCODE_RE = re.compile(r"^addi$")
+ # Allocate stack instructions.
+ ADDI_OPERAND_RE = re.compile(r"^(sp,sp,-\d+)$")
+ # Example: "800804b6: 1101 addi sp,sp,-32"
+ DISASM_REGEX_RE = re.compile(
+ r"^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+"
+ r"\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?"
+ )
+
+ def ParseInstruction(self, line, function_end):
+ """Parse the line of instruction.
+
+ Args:
+ line: Text of disassembly.
+ function_end: End address of the current function. None if unknown.
+
+ Returns:
+ (address, opcode, operand_text): The instruction address, opcode,
+ and the text of operands. None if it
+ isn't an instruction line.
+ """
+ result = self.DISASM_REGEX_RE.match(line)
if result is None:
- if (self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None):
- # Found an indirect call.
- callsites.append(Callsite(address, None, is_tail))
-
- else:
- # Capture address form operand_text and then convert to string
- address_str = "".join(self.CAPTURE_ADDRESS.findall(operand_text))
- # String to integer
- target_address = int(address_str, 16)
- # Filter out the in-function target (branches and in-function calls,
- # which are actually branches).
- if not (function_symbol.size > 0 and
- function_symbol.address < target_address <
- (function_symbol.address + function_symbol.size)):
- # Maybe it is a callsite.
- callsites.append(Callsite(address, target_address, is_tail))
-
- elif self.ADDI_OPCODE_RE.match(opcode) is not None:
- # Example: sp,sp,-32
- if self.ADDI_OPERAND_RE.match(operand_text) is not None:
- stack_frame += abs(int(operand_text.split(",")[2]))
-
- return (stack_frame, callsites)
-
-class StackAnalyzer(object):
- """Class to analyze stack usage.
-
- Public Methods:
- Analyze: Run the stack analysis.
- """
-
- C_FUNCTION_NAME = r'_A-Za-z0-9'
-
- # Assume there is no ":" in the path.
- # Example: "driver/accel_kionix.c:321 (discriminator 3)"
- ADDRTOLINE_RE = re.compile(
- r'^(?P<path>[^:]+):(?P<linenum>\d+)(\s+\(discriminator\s+\d+\))?$')
- # To eliminate the suffix appended by compilers, try to extract the
- # C function name from the prefix of symbol name.
- # Example: "SHA256_transform.constprop.28"
- FUNCTION_PREFIX_NAME_RE = re.compile(
- r'^(?P<name>[{0}]+)([^{0}].*)?$'.format(C_FUNCTION_NAME))
-
- # Errors of annotation resolving.
- ANNOTATION_ERROR_INVALID = 'invalid signature'
- ANNOTATION_ERROR_NOTFOUND = 'function is not found'
- ANNOTATION_ERROR_AMBIGUOUS = 'signature is ambiguous'
-
- def __init__(self, options, symbols, rodata, tasklist, annotation):
- """Constructor.
-
- Args:
- options: Namespace from argparse.parse_args().
- symbols: Symbol list.
- rodata: Content of .rodata section (offset, data)
- tasklist: Task list.
- annotation: Annotation config.
- """
- self.options = options
- self.symbols = symbols
- self.rodata_offset = rodata[0]
- self.rodata = rodata[1]
- self.tasklist = tasklist
- self.annotation = annotation
- self.address_to_line_cache = {}
-
- def AddressToLine(self, address, resolve_inline=False):
- """Convert address to line.
-
- Args:
- address: Target address.
- resolve_inline: Output the stack of inlining.
-
- Returns:
- lines: List of the corresponding lines.
-
- Raises:
- StackAnalyzerError: If addr2line is failed.
- """
- cache_key = (address, resolve_inline)
- if cache_key in self.address_to_line_cache:
- return self.address_to_line_cache[cache_key]
-
- try:
- args = [self.options.addr2line,
- '-f',
- '-e',
- self.options.elf_path,
- '{:x}'.format(address)]
- if resolve_inline:
- args.append('-i')
-
- line_text = subprocess.check_output(args, encoding='utf-8')
- except subprocess.CalledProcessError:
- raise StackAnalyzerError('addr2line failed to resolve lines.')
- except OSError:
- raise StackAnalyzerError('Failed to run addr2line.')
-
- lines = [line.strip() for line in line_text.splitlines()]
- # Assume the output has at least one pair like "function\nlocation\n", and
- # they always show up in pairs.
- # Example: "handle_request\n
- # common/usb_pd_protocol.c:1191\n"
- assert len(lines) >= 2 and len(lines) % 2 == 0
-
- line_infos = []
- for index in range(0, len(lines), 2):
- (function_name, line_text) = lines[index:index + 2]
- if line_text in ['??:0', ':?']:
- line_infos.append(None)
- else:
- result = self.ADDRTOLINE_RE.match(line_text)
- # Assume the output is always well-formed.
- assert result is not None
- line_infos.append((function_name.strip(),
- os.path.realpath(result.group('path').strip()),
- int(result.group('linenum'))))
-
- self.address_to_line_cache[cache_key] = line_infos
- return line_infos
-
- def AnalyzeDisassembly(self, disasm_text):
- """Parse the disassembly text, analyze, and build a map of all functions.
-
- Args:
- disasm_text: Disassembly text.
-
- Returns:
- function_map: Dict of functions.
- """
- disasm_lines = [line.strip() for line in disasm_text.splitlines()]
-
- if 'nds' in disasm_lines[1]:
- analyzer = AndesAnalyzer()
- elif 'arm' in disasm_lines[1]:
- analyzer = ArmAnalyzer()
- elif 'riscv' in disasm_lines[1]:
- analyzer = RiscvAnalyzer()
- else:
- raise StackAnalyzerError('Unsupported architecture.')
-
- # Example: "08028c8c <motion_lid_calc>:"
- function_signature_regex = re.compile(
- r'^(?P<address>[0-9A-Fa-f]+)\s+<(?P<name>[^>]+)>:$')
-
- def DetectFunctionHead(line):
- """Check if the line is a function head.
-
- Args:
- line: Text of disassembly.
-
- Returns:
- symbol: Function symbol. None if it isn't a function head.
- """
- result = function_signature_regex.match(line)
- if result is None:
- return None
-
- address = int(result.group('address'), 16)
- symbol = symbol_map.get(address)
-
- # Check if the function exists and matches.
- if symbol is None or symbol.symtype != 'F':
- return None
-
- return symbol
-
- # Build symbol map, indexed by symbol address.
- symbol_map = {}
- for symbol in self.symbols:
- # If there are multiple symbols with same address, keeping any of them is
- # good enough.
- symbol_map[symbol.address] = symbol
-
- # Parse the disassembly text. We update the variable "line" to next line
- # when needed. There are two steps of parser:
- #
- # Step 1: Searching for the function head. Once reach the function head,
- # move to the next line, which is the first line of function body.
- #
- # Step 2: Parsing each instruction line of function body. Once reach a
- # non-instruction line, stop parsing and analyze the parsed instructions.
- #
- # Finally turn back to the step 1 without updating the line, because the
- # current non-instruction line can be another function head.
- function_map = {}
- # The following three variables are the states of the parsing processing.
- # They will be initialized properly during the state changes.
- function_symbol = None
- function_end = None
- instructions = []
-
- # Remove heading and tailing spaces for each line.
- line_index = 0
- while line_index < len(disasm_lines):
- # Get the current line.
- line = disasm_lines[line_index]
-
- if function_symbol is None:
- # Step 1: Search for the function head.
-
- function_symbol = DetectFunctionHead(line)
- if function_symbol is not None:
- # Assume there is no empty function. If the function head is followed
- # by EOF, it is an empty function.
- assert line_index + 1 < len(disasm_lines)
-
- # Found the function head, initialize and turn to the step 2.
- instructions = []
- # If symbol size exists, use it as a hint of function size.
- if function_symbol.size > 0:
- function_end = function_symbol.address + function_symbol.size
- else:
- function_end = None
-
- else:
- # Step 2: Parse the function body.
-
- instruction = analyzer.ParseInstruction(line, function_end)
- if instruction is not None:
- instructions.append(instruction)
-
- if instruction is None or line_index + 1 == len(disasm_lines):
- # Either the invalid instruction or EOF indicates the end of the
- # function, finalize the function analysis.
-
- # Assume there is no empty function.
- assert len(instructions) > 0
-
- (stack_frame, callsites) = analyzer.AnalyzeFunction(function_symbol,
- instructions)
- # Assume the function addresses are unique in the disassembly.
- assert function_symbol.address not in function_map
- function_map[function_symbol.address] = Function(
- function_symbol.address,
- function_symbol.name,
- stack_frame,
- callsites)
-
- # Initialize and turn back to the step 1.
- function_symbol = None
-
- # If the current line isn't an instruction, it can be another function
- # head, skip moving to the next line.
- if instruction is None:
- continue
-
- # Move to the next line.
- line_index += 1
+ return None
- # Resolve callees of functions.
- for function in function_map.values():
- for callsite in function.callsites:
- if callsite.target is not None:
- # Remain the callee as None if we can't resolve it.
- callsite.callee = function_map.get(callsite.target)
+ address = int(result.group("address"), 16)
+ # Check if it's out of bound.
+ if function_end is not None and address >= function_end:
+ return None
- return function_map
+ opcode = result.group("opcode").strip()
+ operand_text = result.group("operand")
+ if operand_text is None:
+ operand_text = ""
+ else:
+ operand_text = operand_text.strip()
+
+ return (address, opcode, operand_text)
+
+ def AnalyzeFunction(self, function_symbol, instructions):
+
+ stack_frame = 0
+ callsites = []
+ for address, opcode, operand_text in instructions:
+ is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None
+ is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None
+
+ if is_jump_opcode or is_call_opcode:
+ is_tail = is_jump_opcode
+
+ result = self.CALL_OPERAND_RE.match(operand_text)
+ if result is None:
+ if self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None:
+ # Found an indirect call.
+ callsites.append(Callsite(address, None, is_tail))
+
+ else:
+ # Capture address form operand_text and then convert to string
+ address_str = "".join(self.CAPTURE_ADDRESS.findall(operand_text))
+ # String to integer
+ target_address = int(address_str, 16)
+ # Filter out the in-function target (branches and in-function calls,
+ # which are actually branches).
+ if not (
+ function_symbol.size > 0
+ and function_symbol.address
+ < target_address
+ < (function_symbol.address + function_symbol.size)
+ ):
+ # Maybe it is a callsite.
+ callsites.append(Callsite(address, target_address, is_tail))
+
+ elif self.ADDI_OPCODE_RE.match(opcode) is not None:
+ # Example: sp,sp,-32
+ if self.ADDI_OPERAND_RE.match(operand_text) is not None:
+ stack_frame += abs(int(operand_text.split(",")[2]))
+
+ return (stack_frame, callsites)
- def MapAnnotation(self, function_map, signature_set):
- """Map annotation signatures to functions.
- Args:
- function_map: Function map.
- signature_set: Set of annotation signatures.
+class StackAnalyzer(object):
+ """Class to analyze stack usage.
- Returns:
- Map of signatures to functions, map of signatures which can't be resolved.
+ Public Methods:
+ Analyze: Run the stack analysis.
"""
- # Build the symbol map indexed by symbol name. If there are multiple symbols
- # with the same name, add them into a set. (e.g. symbols of static function
- # with the same name)
- symbol_map = collections.defaultdict(set)
- for symbol in self.symbols:
- if symbol.symtype == 'F':
- # Function symbol.
- result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name)
- if result is not None:
- function = function_map.get(symbol.address)
- # Ignore the symbol not in disassembly.
- if function is not None:
- # If there are multiple symbol with the same name and point to the
- # same function, the set will deduplicate them.
- symbol_map[result.group('name').strip()].add(function)
-
- # Build the signature map indexed by annotation signature.
- signature_map = {}
- sig_error_map = {}
- symbol_path_map = {}
- for sig in signature_set:
- (name, path, _) = sig
-
- functions = symbol_map.get(name)
- if functions is None:
- sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND
- continue
-
- if name not in symbol_path_map:
- # Lazy symbol path resolving. Since the addr2line isn't fast, only
- # resolve needed symbol paths.
- group_map = collections.defaultdict(list)
- for function in functions:
- line_info = self.AddressToLine(function.address)[0]
- if line_info is None:
- continue
- (_, symbol_path, _) = line_info
+ C_FUNCTION_NAME = r"_A-Za-z0-9"
- # Group the functions with the same symbol signature (symbol name +
- # symbol path). Assume they are the same copies and do the same
- # annotation operations of them because we don't know which copy is
- # indicated by the users.
- group_map[symbol_path].append(function)
+ # Assume there is no ":" in the path.
+ # Example: "driver/accel_kionix.c:321 (discriminator 3)"
+ ADDRTOLINE_RE = re.compile(
+ r"^(?P<path>[^:]+):(?P<linenum>\d+)(\s+\(discriminator\s+\d+\))?$"
+ )
+ # To eliminate the suffix appended by compilers, try to extract the
+ # C function name from the prefix of symbol name.
+ # Example: "SHA256_transform.constprop.28"
+ FUNCTION_PREFIX_NAME_RE = re.compile(
+ r"^(?P<name>[{0}]+)([^{0}].*)?$".format(C_FUNCTION_NAME)
+ )
+
+ # Errors of annotation resolving.
+ ANNOTATION_ERROR_INVALID = "invalid signature"
+ ANNOTATION_ERROR_NOTFOUND = "function is not found"
+ ANNOTATION_ERROR_AMBIGUOUS = "signature is ambiguous"
+
+ def __init__(self, options, symbols, rodata, tasklist, annotation):
+ """Constructor.
+
+ Args:
+ options: Namespace from argparse.parse_args().
+ symbols: Symbol list.
+ rodata: Content of .rodata section (offset, data)
+ tasklist: Task list.
+ annotation: Annotation config.
+ """
+ self.options = options
+ self.symbols = symbols
+ self.rodata_offset = rodata[0]
+ self.rodata = rodata[1]
+ self.tasklist = tasklist
+ self.annotation = annotation
+ self.address_to_line_cache = {}
+
+ def AddressToLine(self, address, resolve_inline=False):
+ """Convert address to line.
+
+ Args:
+ address: Target address.
+ resolve_inline: Output the stack of inlining.
+
+ Returns:
+ lines: List of the corresponding lines.
+
+ Raises:
+ StackAnalyzerError: If addr2line is failed.
+ """
+ cache_key = (address, resolve_inline)
+ if cache_key in self.address_to_line_cache:
+ return self.address_to_line_cache[cache_key]
+
+ try:
+ args = [
+ self.options.addr2line,
+ "-f",
+ "-e",
+ self.options.elf_path,
+ "{:x}".format(address),
+ ]
+ if resolve_inline:
+ args.append("-i")
+
+ line_text = subprocess.check_output(args, encoding="utf-8")
+ except subprocess.CalledProcessError:
+ raise StackAnalyzerError("addr2line failed to resolve lines.")
+ except OSError:
+ raise StackAnalyzerError("Failed to run addr2line.")
+
+ lines = [line.strip() for line in line_text.splitlines()]
+ # Assume the output has at least one pair like "function\nlocation\n", and
+ # they always show up in pairs.
+ # Example: "handle_request\n
+ # common/usb_pd_protocol.c:1191\n"
+ assert len(lines) >= 2 and len(lines) % 2 == 0
+
+ line_infos = []
+ for index in range(0, len(lines), 2):
+ (function_name, line_text) = lines[index : index + 2]
+ if line_text in ["??:0", ":?"]:
+ line_infos.append(None)
+ else:
+ result = self.ADDRTOLINE_RE.match(line_text)
+ # Assume the output is always well-formed.
+ assert result is not None
+ line_infos.append(
+ (
+ function_name.strip(),
+ os.path.realpath(result.group("path").strip()),
+ int(result.group("linenum")),
+ )
+ )
+
+ self.address_to_line_cache[cache_key] = line_infos
+ return line_infos
+
+ def AnalyzeDisassembly(self, disasm_text):
+ """Parse the disassembly text, analyze, and build a map of all functions.
+
+ Args:
+ disasm_text: Disassembly text.
+
+ Returns:
+ function_map: Dict of functions.
+ """
+ disasm_lines = [line.strip() for line in disasm_text.splitlines()]
+
+ if "nds" in disasm_lines[1]:
+ analyzer = AndesAnalyzer()
+ elif "arm" in disasm_lines[1]:
+ analyzer = ArmAnalyzer()
+ elif "riscv" in disasm_lines[1]:
+ analyzer = RiscvAnalyzer()
+ else:
+ raise StackAnalyzerError("Unsupported architecture.")
- symbol_path_map[name] = group_map
+ # Example: "08028c8c <motion_lid_calc>:"
+ function_signature_regex = re.compile(
+ r"^(?P<address>[0-9A-Fa-f]+)\s+<(?P<name>[^>]+)>:$"
+ )
- # Symbol matching.
- function_group = None
- group_map = symbol_path_map[name]
- if len(group_map) > 0:
- if path is None:
- if len(group_map) > 1:
- # There is ambiguity but the path isn't specified.
- sig_error_map[sig] = self.ANNOTATION_ERROR_AMBIGUOUS
- continue
+ def DetectFunctionHead(line):
+ """Check if the line is a function head.
- # No path signature but all symbol signatures of functions are same.
- # Assume they are the same functions, so there is no ambiguity.
- (function_group,) = group_map.values()
- else:
- function_group = group_map.get(path)
+ Args:
+ line: Text of disassembly.
- if function_group is None:
- sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND
- continue
+ Returns:
+ symbol: Function symbol. None if it isn't a function head.
+ """
+ result = function_signature_regex.match(line)
+ if result is None:
+ return None
- # The function_group is a list of all the same functions (according to
- # our assumption) which should be annotated together.
- signature_map[sig] = function_group
+ address = int(result.group("address"), 16)
+ symbol = symbol_map.get(address)
- return (signature_map, sig_error_map)
+ # Check if the function exists and matches.
+ if symbol is None or symbol.symtype != "F":
+ return None
- def LoadAnnotation(self):
- """Load annotation rules.
+ return symbol
- Returns:
- Map of add rules, set of remove rules, set of text signatures which can't
- be parsed.
- """
- # Assume there is no ":" in the path.
- # Example: "get_range.lto.2501[driver/accel_kionix.c:327]"
- annotation_signature_regex = re.compile(
- r'^(?P<name>[^\[]+)(\[(?P<path>[^:]+)(:(?P<linenum>\d+))?\])?$')
-
- def NormalizeSignature(signature_text):
- """Parse and normalize the annotation signature.
-
- Args:
- signature_text: Text of the annotation signature.
-
- Returns:
- (function name, path, line number) of the signature. The path and line
- number can be None if not exist. None if failed to parse.
- """
- result = annotation_signature_regex.match(signature_text.strip())
- if result is None:
- return None
-
- name_result = self.FUNCTION_PREFIX_NAME_RE.match(
- result.group('name').strip())
- if name_result is None:
- return None
-
- path = result.group('path')
- if path is not None:
- path = os.path.realpath(path.strip())
-
- linenum = result.group('linenum')
- if linenum is not None:
- linenum = int(linenum.strip())
-
- return (name_result.group('name').strip(), path, linenum)
-
- def ExpandArray(dic):
- """Parse and expand a symbol array
-
- Args:
- dic: Dictionary for the array annotation
-
- Returns:
- array of (symbol name, None, None).
- """
- # TODO(drinkcat): This function is quite inefficient, as it goes through
- # the symbol table multiple times.
-
- begin_name = dic['name']
- end_name = dic['name'] + "_end"
- offset = dic['offset'] if 'offset' in dic else 0
- stride = dic['stride']
-
- begin_address = None
- end_address = None
-
- for symbol in self.symbols:
- if (symbol.name == begin_name):
- begin_address = symbol.address
- if (symbol.name == end_name):
- end_address = symbol.address
-
- if (not begin_address or not end_address):
- return None
-
- output = []
- # TODO(drinkcat): This is inefficient as we go from address to symbol
- # object then to symbol name, and later on we'll go back from symbol name
- # to symbol object.
- for addr in range(begin_address+offset, end_address, stride):
- # TODO(drinkcat): Not all architectures need to drop the first bit.
- val = self.rodata[(addr-self.rodata_offset) // 4] & 0xfffffffe
- name = None
+ # Build symbol map, indexed by symbol address.
+ symbol_map = {}
for symbol in self.symbols:
- if (symbol.address == val):
- result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name)
- name = result.group('name')
- break
-
- if not name:
- raise StackAnalyzerError('Cannot find function for address %s.',
- hex(val))
-
- output.append((name, None, None))
-
- return output
-
- add_rules = collections.defaultdict(set)
- remove_rules = list()
- invalid_sigtxts = set()
-
- if 'add' in self.annotation and self.annotation['add'] is not None:
- for src_sigtxt, dst_sigtxts in self.annotation['add'].items():
- src_sig = NormalizeSignature(src_sigtxt)
- if src_sig is None:
- invalid_sigtxts.add(src_sigtxt)
- continue
-
- for dst_sigtxt in dst_sigtxts:
- if isinstance(dst_sigtxt, dict):
- dst_sig = ExpandArray(dst_sigtxt)
- if dst_sig is None:
- invalid_sigtxts.add(str(dst_sigtxt))
+ # If there are multiple symbols with same address, keeping any of them is
+ # good enough.
+ symbol_map[symbol.address] = symbol
+
+ # Parse the disassembly text. We update the variable "line" to next line
+ # when needed. There are two steps of parser:
+ #
+ # Step 1: Searching for the function head. Once reach the function head,
+ # move to the next line, which is the first line of function body.
+ #
+ # Step 2: Parsing each instruction line of function body. Once reach a
+ # non-instruction line, stop parsing and analyze the parsed instructions.
+ #
+ # Finally turn back to the step 1 without updating the line, because the
+ # current non-instruction line can be another function head.
+ function_map = {}
+ # The following three variables are the states of the parsing processing.
+ # They will be initialized properly during the state changes.
+ function_symbol = None
+ function_end = None
+ instructions = []
+
+ # Remove heading and tailing spaces for each line.
+ line_index = 0
+ while line_index < len(disasm_lines):
+ # Get the current line.
+ line = disasm_lines[line_index]
+
+ if function_symbol is None:
+ # Step 1: Search for the function head.
+
+ function_symbol = DetectFunctionHead(line)
+ if function_symbol is not None:
+ # Assume there is no empty function. If the function head is followed
+ # by EOF, it is an empty function.
+ assert line_index + 1 < len(disasm_lines)
+
+ # Found the function head, initialize and turn to the step 2.
+ instructions = []
+ # If symbol size exists, use it as a hint of function size.
+ if function_symbol.size > 0:
+ function_end = function_symbol.address + function_symbol.size
+ else:
+ function_end = None
+
else:
- add_rules[src_sig].update(dst_sig)
- else:
- dst_sig = NormalizeSignature(dst_sigtxt)
- if dst_sig is None:
- invalid_sigtxts.add(dst_sigtxt)
+ # Step 2: Parse the function body.
+
+ instruction = analyzer.ParseInstruction(line, function_end)
+ if instruction is not None:
+ instructions.append(instruction)
+
+ if instruction is None or line_index + 1 == len(disasm_lines):
+ # Either the invalid instruction or EOF indicates the end of the
+ # function, finalize the function analysis.
+
+ # Assume there is no empty function.
+ assert len(instructions) > 0
+
+ (stack_frame, callsites) = analyzer.AnalyzeFunction(
+ function_symbol, instructions
+ )
+ # Assume the function addresses are unique in the disassembly.
+ assert function_symbol.address not in function_map
+ function_map[function_symbol.address] = Function(
+ function_symbol.address,
+ function_symbol.name,
+ stack_frame,
+ callsites,
+ )
+
+ # Initialize and turn back to the step 1.
+ function_symbol = None
+
+ # If the current line isn't an instruction, it can be another function
+ # head, skip moving to the next line.
+ if instruction is None:
+ continue
+
+ # Move to the next line.
+ line_index += 1
+
+ # Resolve callees of functions.
+ for function in function_map.values():
+ for callsite in function.callsites:
+ if callsite.target is not None:
+ # Remain the callee as None if we can't resolve it.
+ callsite.callee = function_map.get(callsite.target)
+
+ return function_map
+
+ def MapAnnotation(self, function_map, signature_set):
+ """Map annotation signatures to functions.
+
+ Args:
+ function_map: Function map.
+ signature_set: Set of annotation signatures.
+
+ Returns:
+ Map of signatures to functions, map of signatures which can't be resolved.
+ """
+ # Build the symbol map indexed by symbol name. If there are multiple symbols
+ # with the same name, add them into a set. (e.g. symbols of static function
+ # with the same name)
+ symbol_map = collections.defaultdict(set)
+ for symbol in self.symbols:
+ if symbol.symtype == "F":
+ # Function symbol.
+ result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name)
+ if result is not None:
+ function = function_map.get(symbol.address)
+ # Ignore the symbol not in disassembly.
+ if function is not None:
+ # If there are multiple symbol with the same name and point to the
+ # same function, the set will deduplicate them.
+ symbol_map[result.group("name").strip()].add(function)
+
+ # Build the signature map indexed by annotation signature.
+ signature_map = {}
+ sig_error_map = {}
+ symbol_path_map = {}
+ for sig in signature_set:
+ (name, path, _) = sig
+
+ functions = symbol_map.get(name)
+ if functions is None:
+ sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND
+ continue
+
+ if name not in symbol_path_map:
+ # Lazy symbol path resolving. Since the addr2line isn't fast, only
+ # resolve needed symbol paths.
+ group_map = collections.defaultdict(list)
+ for function in functions:
+ line_info = self.AddressToLine(function.address)[0]
+ if line_info is None:
+ continue
+
+ (_, symbol_path, _) = line_info
+
+ # Group the functions with the same symbol signature (symbol name +
+ # symbol path). Assume they are the same copies and do the same
+ # annotation operations of them because we don't know which copy is
+ # indicated by the users.
+ group_map[symbol_path].append(function)
+
+ symbol_path_map[name] = group_map
+
+ # Symbol matching.
+ function_group = None
+ group_map = symbol_path_map[name]
+ if len(group_map) > 0:
+ if path is None:
+ if len(group_map) > 1:
+ # There is ambiguity but the path isn't specified.
+ sig_error_map[sig] = self.ANNOTATION_ERROR_AMBIGUOUS
+ continue
+
+ # No path signature but all symbol signatures of functions are same.
+ # Assume they are the same functions, so there is no ambiguity.
+ (function_group,) = group_map.values()
+ else:
+ function_group = group_map.get(path)
+
+ if function_group is None:
+ sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND
+ continue
+
+ # The function_group is a list of all the same functions (according to
+ # our assumption) which should be annotated together.
+ signature_map[sig] = function_group
+
+ return (signature_map, sig_error_map)
+
+ def LoadAnnotation(self):
+ """Load annotation rules.
+
+ Returns:
+ Map of add rules, set of remove rules, set of text signatures which can't
+ be parsed.
+ """
+ # Assume there is no ":" in the path.
+ # Example: "get_range.lto.2501[driver/accel_kionix.c:327]"
+ annotation_signature_regex = re.compile(
+ r"^(?P<name>[^\[]+)(\[(?P<path>[^:]+)(:(?P<linenum>\d+))?\])?$"
+ )
+
+ def NormalizeSignature(signature_text):
+ """Parse and normalize the annotation signature.
+
+ Args:
+ signature_text: Text of the annotation signature.
+
+ Returns:
+ (function name, path, line number) of the signature. The path and line
+ number can be None if not exist. None if failed to parse.
+ """
+ result = annotation_signature_regex.match(signature_text.strip())
+ if result is None:
+ return None
+
+ name_result = self.FUNCTION_PREFIX_NAME_RE.match(
+ result.group("name").strip()
+ )
+ if name_result is None:
+ return None
+
+ path = result.group("path")
+ if path is not None:
+ path = os.path.realpath(path.strip())
+
+ linenum = result.group("linenum")
+ if linenum is not None:
+ linenum = int(linenum.strip())
+
+ return (name_result.group("name").strip(), path, linenum)
+
+ def ExpandArray(dic):
+ """Parse and expand a symbol array
+
+ Args:
+ dic: Dictionary for the array annotation
+
+ Returns:
+ array of (symbol name, None, None).
+ """
+ # TODO(drinkcat): This function is quite inefficient, as it goes through
+ # the symbol table multiple times.
+
+ begin_name = dic["name"]
+ end_name = dic["name"] + "_end"
+ offset = dic["offset"] if "offset" in dic else 0
+ stride = dic["stride"]
+
+ begin_address = None
+ end_address = None
+
+ for symbol in self.symbols:
+ if symbol.name == begin_name:
+ begin_address = symbol.address
+ if symbol.name == end_name:
+ end_address = symbol.address
+
+ if not begin_address or not end_address:
+ return None
+
+ output = []
+ # TODO(drinkcat): This is inefficient as we go from address to symbol
+ # object then to symbol name, and later on we'll go back from symbol name
+ # to symbol object.
+ for addr in range(begin_address + offset, end_address, stride):
+ # TODO(drinkcat): Not all architectures need to drop the first bit.
+ val = self.rodata[(addr - self.rodata_offset) // 4] & 0xFFFFFFFE
+ name = None
+ for symbol in self.symbols:
+ if symbol.address == val:
+ result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name)
+ name = result.group("name")
+ break
+
+ if not name:
+ raise StackAnalyzerError(
+ "Cannot find function for address %s.", hex(val)
+ )
+
+ output.append((name, None, None))
+
+ return output
+
+ add_rules = collections.defaultdict(set)
+ remove_rules = list()
+ invalid_sigtxts = set()
+
+ if "add" in self.annotation and self.annotation["add"] is not None:
+ for src_sigtxt, dst_sigtxts in self.annotation["add"].items():
+ src_sig = NormalizeSignature(src_sigtxt)
+ if src_sig is None:
+ invalid_sigtxts.add(src_sigtxt)
+ continue
+
+ for dst_sigtxt in dst_sigtxts:
+ if isinstance(dst_sigtxt, dict):
+ dst_sig = ExpandArray(dst_sigtxt)
+ if dst_sig is None:
+ invalid_sigtxts.add(str(dst_sigtxt))
+ else:
+ add_rules[src_sig].update(dst_sig)
+ else:
+ dst_sig = NormalizeSignature(dst_sigtxt)
+ if dst_sig is None:
+ invalid_sigtxts.add(dst_sigtxt)
+ else:
+ add_rules[src_sig].add(dst_sig)
+
+ if "remove" in self.annotation and self.annotation["remove"] is not None:
+ for sigtxt_path in self.annotation["remove"]:
+ if isinstance(sigtxt_path, str):
+ # The path has only one vertex.
+ sigtxt_path = [sigtxt_path]
+
+ if len(sigtxt_path) == 0:
+ continue
+
+ # Generate multiple remove paths from all the combinations of the
+ # signatures of each vertex.
+ sig_paths = [[]]
+ broken_flag = False
+ for sigtxt_node in sigtxt_path:
+ if isinstance(sigtxt_node, str):
+ # The vertex has only one signature.
+ sigtxt_set = {sigtxt_node}
+ elif isinstance(sigtxt_node, list):
+ # The vertex has multiple signatures.
+ sigtxt_set = set(sigtxt_node)
+ else:
+ # Assume the format of annotation is verified. There should be no
+ # invalid case.
+ assert False
+
+ sig_set = set()
+ for sigtxt in sigtxt_set:
+ sig = NormalizeSignature(sigtxt)
+ if sig is None:
+ invalid_sigtxts.add(sigtxt)
+ broken_flag = True
+ elif not broken_flag:
+ sig_set.add(sig)
+
+ if broken_flag:
+ continue
+
+ # Append each signature of the current node to all the previous
+ # remove paths.
+ sig_paths = [path + [sig] for path in sig_paths for sig in sig_set]
+
+ if not broken_flag:
+ # All signatures are normalized. The remove path has no error.
+ remove_rules.extend(sig_paths)
+
+ return (add_rules, remove_rules, invalid_sigtxts)
+
+ def ResolveAnnotation(self, function_map):
+ """Resolve annotation.
+
+ Args:
+ function_map: Function map.
+
+ Returns:
+ Set of added call edges, list of remove paths, set of eliminated
+ callsite addresses, set of annotation signatures which can't be resolved.
+ """
+
+ def StringifySignature(signature):
+ """Stringify the tupled signature.
+
+ Args:
+ signature: Tupled signature.
+
+ Returns:
+ Signature string.
+ """
+ (name, path, linenum) = signature
+ bracket_text = ""
+ if path is not None:
+ path = os.path.relpath(path)
+ if linenum is None:
+ bracket_text = "[{}]".format(path)
+ else:
+ bracket_text = "[{}:{}]".format(path, linenum)
+
+ return name + bracket_text
+
+ (add_rules, remove_rules, invalid_sigtxts) = self.LoadAnnotation()
+
+ signature_set = set()
+ for src_sig, dst_sigs in add_rules.items():
+ signature_set.add(src_sig)
+ signature_set.update(dst_sigs)
+
+ for remove_sigs in remove_rules:
+ signature_set.update(remove_sigs)
+
+ # Map signatures to functions.
+ (signature_map, sig_error_map) = self.MapAnnotation(function_map, signature_set)
+
+ # Build the indirect callsite map indexed by callsite signature.
+ indirect_map = collections.defaultdict(set)
+ for function in function_map.values():
+ for callsite in function.callsites:
+ if callsite.target is not None:
+ continue
+
+ # Found an indirect callsite.
+ line_info = self.AddressToLine(callsite.address)[0]
+ if line_info is None:
+ continue
+
+ (name, path, linenum) = line_info
+ result = self.FUNCTION_PREFIX_NAME_RE.match(name)
+ if result is None:
+ continue
+
+ indirect_map[(result.group("name").strip(), path, linenum)].add(
+ (function, callsite.address)
+ )
+
+ # Generate the annotation sets.
+ add_set = set()
+ remove_list = list()
+ eliminated_addrs = set()
+
+ for src_sig, dst_sigs in add_rules.items():
+ src_funcs = set(signature_map.get(src_sig, []))
+ # Try to match the source signature to the indirect callsites. Even if it
+ # can't be found in disassembly.
+ indirect_calls = indirect_map.get(src_sig)
+ if indirect_calls is not None:
+ for function, callsite_address in indirect_calls:
+ # Add the caller of the indirect callsite to the source functions.
+ src_funcs.add(function)
+ # Assume each callsite can be represented by a unique address.
+ eliminated_addrs.add(callsite_address)
+
+ if src_sig in sig_error_map:
+ # Assume the error is always the not found error. Since the signature
+ # found in indirect callsite map must be a full signature, it can't
+ # happen the ambiguous error.
+ assert sig_error_map[src_sig] == self.ANNOTATION_ERROR_NOTFOUND
+ # Found in inline stack, remove the not found error.
+ del sig_error_map[src_sig]
+
+ for dst_sig in dst_sigs:
+ dst_funcs = signature_map.get(dst_sig)
+ if dst_funcs is None:
+ continue
+
+ # Duplicate the call edge for all the same source and destination
+ # functions.
+ for src_func in src_funcs:
+ for dst_func in dst_funcs:
+ add_set.add((src_func, dst_func))
+
+ for remove_sigs in remove_rules:
+ # Since each signature can be mapped to multiple functions, generate
+ # multiple remove paths from all the combinations of these functions.
+ remove_paths = [[]]
+ skip_flag = False
+ for remove_sig in remove_sigs:
+ # Transform each signature to the corresponding functions.
+ remove_funcs = signature_map.get(remove_sig)
+ if remove_funcs is None:
+ # There is an unresolved signature in the remove path. Ignore the
+ # whole broken remove path.
+ skip_flag = True
+ break
+ else:
+ # Append each function of the current signature to all the previous
+ # remove paths.
+ remove_paths = [p + [f] for p in remove_paths for f in remove_funcs]
+
+ if skip_flag:
+ # Ignore the broken remove path.
+ continue
+
+ for remove_path in remove_paths:
+ # Deduplicate the remove paths.
+ if remove_path not in remove_list:
+ remove_list.append(remove_path)
+
+ # Format the error messages.
+ failed_sigtxts = set()
+ for sigtxt in invalid_sigtxts:
+ failed_sigtxts.add((sigtxt, self.ANNOTATION_ERROR_INVALID))
+
+ for sig, error in sig_error_map.items():
+ failed_sigtxts.add((StringifySignature(sig), error))
+
+ return (add_set, remove_list, eliminated_addrs, failed_sigtxts)
+
+ def PreprocessAnnotation(
+ self, function_map, add_set, remove_list, eliminated_addrs
+ ):
+ """Preprocess the annotation and callgraph.
+
+ Add the missing call edges, and delete simple remove paths (the paths have
+ one or two vertices) from the function_map.
+
+ Eliminate the annotated indirect callsites.
+
+ Return the remaining remove list.
+
+ Args:
+ function_map: Function map.
+ add_set: Set of missing call edges.
+ remove_list: List of remove paths.
+ eliminated_addrs: Set of eliminated callsite addresses.
+
+ Returns:
+ List of remaining remove paths.
+ """
+
+ def CheckEdge(path):
+ """Check if all edges of the path are on the callgraph.
+
+ Args:
+ path: Path.
+
+ Returns:
+ True or False.
+ """
+ for index in range(len(path) - 1):
+ if (path[index], path[index + 1]) not in edge_set:
+ return False
+
+ return True
+
+ for src_func, dst_func in add_set:
+ # TODO(cheyuw): Support tailing call annotation.
+ src_func.callsites.append(Callsite(None, dst_func.address, False, dst_func))
+
+ # Delete simple remove paths.
+ remove_simple = set(tuple(p) for p in remove_list if len(p) <= 2)
+ edge_set = set()
+ for function in function_map.values():
+ cleaned_callsites = []
+ for callsite in function.callsites:
+ if (callsite.callee,) in remove_simple or (
+ function,
+ callsite.callee,
+ ) in remove_simple:
+ continue
+
+ if callsite.target is None and callsite.address in eliminated_addrs:
+ continue
+
+ cleaned_callsites.append(callsite)
+ if callsite.callee is not None:
+ edge_set.add((function, callsite.callee))
+
+ function.callsites = cleaned_callsites
+
+ return [p for p in remove_list if len(p) >= 3 and CheckEdge(p)]
+
+ def AnalyzeCallGraph(self, function_map, remove_list):
+ """Analyze callgraph.
+
+ It will update the max stack size and path for each function.
+
+ Args:
+ function_map: Function map.
+ remove_list: List of remove paths.
+
+ Returns:
+ List of function cycles.
+ """
+
+ def Traverse(curr_state):
+ """Traverse the callgraph and calculate the max stack usages of functions.
+
+ Args:
+ curr_state: Current state.
+
+ Returns:
+ SCC lowest link.
+ """
+ scc_index = scc_index_counter[0]
+ scc_index_counter[0] += 1
+ scc_index_map[curr_state] = scc_index
+ scc_lowlink = scc_index
+ scc_stack.append(curr_state)
+ # Push the current state in the stack. We can use a set to maintain this
+ # because the stacked states are unique; otherwise we will find a cycle
+ # first.
+ stacked_states.add(curr_state)
+
+ (curr_address, curr_positions) = curr_state
+ curr_func = function_map[curr_address]
+
+ invalid_flag = False
+ new_positions = list(curr_positions)
+ for index, position in enumerate(curr_positions):
+ remove_path = remove_list[index]
+
+ # The position of each remove path in the state is the length of the
+ # longest matching path between the prefix of the remove path and the
+ # suffix of the current traversing path. We maintain this length when
+ # appending the next callee to the traversing path. And it can be used
+ # to check if the remove path appears in the traversing path.
+
+ # TODO(cheyuw): Implement KMP algorithm to match remove paths
+ # efficiently.
+ if remove_path[position] is curr_func:
+ # Matches the current function, extend the length.
+ new_positions[index] = position + 1
+ if new_positions[index] == len(remove_path):
+ # The length of the longest matching path is equal to the length of
+ # the remove path, which means the suffix of the current traversing
+ # path matches the remove path.
+ invalid_flag = True
+ break
+
+ else:
+ # We can't get the new longest matching path by extending the previous
+ # one directly. Fallback to search the new longest matching path.
+
+ # If we can't find any matching path in the following search, reset
+ # the matching length to 0.
+ new_positions[index] = 0
+
+ # We want to find the new longest matching prefix of remove path with
+ # the suffix of the current traversing path. Because the new longest
+ # matching path won't be longer than the previous one now, and part of
+ # the suffix matches the prefix of remove path, we can get the needed
+ # suffix from the previous matching prefix of the invalid path.
+ suffix = remove_path[:position] + [curr_func]
+ for offset in range(1, len(suffix)):
+ length = position - offset
+ if remove_path[:length] == suffix[offset:]:
+ new_positions[index] = length
+ break
+
+ new_positions = tuple(new_positions)
+
+ # If the current suffix is invalid, set the max stack usage to 0.
+ max_stack_usage = 0
+ max_callee_state = None
+ self_loop = False
+
+ if not invalid_flag:
+ # Max stack usage is at least equal to the stack frame.
+ max_stack_usage = curr_func.stack_frame
+ for callsite in curr_func.callsites:
+ callee = callsite.callee
+ if callee is None:
+ continue
+
+ callee_state = (callee.address, new_positions)
+ if callee_state not in scc_index_map:
+ # Unvisited state.
+ scc_lowlink = min(scc_lowlink, Traverse(callee_state))
+ elif callee_state in stacked_states:
+ # The state is shown in the stack. There is a cycle.
+ sub_stack_usage = 0
+ scc_lowlink = min(scc_lowlink, scc_index_map[callee_state])
+ if callee_state == curr_state:
+ self_loop = True
+
+ done_result = done_states.get(callee_state)
+ if done_result is not None:
+ # Already done this state and use its result. If the state reaches a
+ # cycle, reusing the result will cause inaccuracy (the stack usage
+ # of cycle depends on where the entrance is). But it's fine since we
+ # can't get accurate stack usage under this situation, and we rely
+ # on user-provided annotations to break the cycle, after which the
+ # result will be accurate again.
+ (sub_stack_usage, _) = done_result
+
+ if callsite.is_tail:
+ # For tailing call, since the callee reuses the stack frame of the
+ # caller, choose the larger one directly.
+ stack_usage = max(curr_func.stack_frame, sub_stack_usage)
+ else:
+ stack_usage = curr_func.stack_frame + sub_stack_usage
+
+ if stack_usage > max_stack_usage:
+ max_stack_usage = stack_usage
+ max_callee_state = callee_state
+
+ if scc_lowlink == scc_index:
+ group = []
+ while scc_stack[-1] != curr_state:
+ scc_state = scc_stack.pop()
+ stacked_states.remove(scc_state)
+ group.append(scc_state)
+
+ scc_stack.pop()
+ stacked_states.remove(curr_state)
+
+ # If the cycle is not empty, record it.
+ if len(group) > 0 or self_loop:
+ group.append(curr_state)
+ cycle_groups.append(group)
+
+ # Store the done result.
+ done_states[curr_state] = (max_stack_usage, max_callee_state)
+
+ if curr_positions == initial_positions:
+ # If the current state is initial state, we traversed the callgraph by
+ # using the current function as start point. Update the stack usage of
+ # the function.
+ # If the function matches a single vertex remove path, this will set its
+ # max stack usage to 0, which is not expected (we still calculate its
+ # max stack usage, but prevent any function from calling it). However,
+ # all the single vertex remove paths have been preprocessed and removed.
+ curr_func.stack_max_usage = max_stack_usage
+
+ # Reconstruct the max stack path by traversing the state transitions.
+ max_stack_path = [curr_func]
+ callee_state = max_callee_state
+ while callee_state is not None:
+ # The first element of state tuple is function address.
+ max_stack_path.append(function_map[callee_state[0]])
+ done_result = done_states.get(callee_state)
+ # All of the descendants should be done.
+ assert done_result is not None
+ (_, callee_state) = done_result
+
+ curr_func.stack_max_path = max_stack_path
+
+ return scc_lowlink
+
+ # The state is the concatenation of the current function address and the
+ # state of matching position.
+ initial_positions = (0,) * len(remove_list)
+ done_states = {}
+ stacked_states = set()
+ scc_index_counter = [0]
+ scc_index_map = {}
+ scc_stack = []
+ cycle_groups = []
+ for function in function_map.values():
+ if function.stack_max_usage is None:
+ Traverse((function.address, initial_positions))
+
+ cycle_functions = []
+ for group in cycle_groups:
+ cycle = set(function_map[state[0]] for state in group)
+ if cycle not in cycle_functions:
+ cycle_functions.append(cycle)
+
+ return cycle_functions
+
+ def Analyze(self):
+ """Run the stack analysis.
+
+ Raises:
+ StackAnalyzerError: If disassembly fails.
+ """
+
+ def OutputInlineStack(address, prefix=""):
+ """Output beautiful inline stack.
+
+ Args:
+ address: Address.
+ prefix: Prefix of each line.
+
+ Returns:
+ Key for sorting, output text
+ """
+ line_infos = self.AddressToLine(address, True)
+
+ if line_infos[0] is None:
+ order_key = (None, None)
else:
- add_rules[src_sig].add(dst_sig)
-
- if 'remove' in self.annotation and self.annotation['remove'] is not None:
- for sigtxt_path in self.annotation['remove']:
- if isinstance(sigtxt_path, str):
- # The path has only one vertex.
- sigtxt_path = [sigtxt_path]
-
- if len(sigtxt_path) == 0:
- continue
-
- # Generate multiple remove paths from all the combinations of the
- # signatures of each vertex.
- sig_paths = [[]]
- broken_flag = False
- for sigtxt_node in sigtxt_path:
- if isinstance(sigtxt_node, str):
- # The vertex has only one signature.
- sigtxt_set = {sigtxt_node}
- elif isinstance(sigtxt_node, list):
- # The vertex has multiple signatures.
- sigtxt_set = set(sigtxt_node)
- else:
- # Assume the format of annotation is verified. There should be no
- # invalid case.
- assert False
-
- sig_set = set()
- for sigtxt in sigtxt_set:
- sig = NormalizeSignature(sigtxt)
- if sig is None:
- invalid_sigtxts.add(sigtxt)
- broken_flag = True
- elif not broken_flag:
- sig_set.add(sig)
-
- if broken_flag:
- continue
-
- # Append each signature of the current node to the all previous
- # remove paths.
- sig_paths = [path + [sig] for path in sig_paths for sig in sig_set]
-
- if not broken_flag:
- # All signatures are normalized. The remove path has no error.
- remove_rules.extend(sig_paths)
-
- return (add_rules, remove_rules, invalid_sigtxts)
+ (_, path, linenum) = line_infos[0]
+ order_key = (linenum, path)
+
+ line_texts = []
+ for line_info in reversed(line_infos):
+ if line_info is None:
+ (function_name, path, linenum) = ("??", "??", 0)
+ else:
+ (function_name, path, linenum) = line_info
+
+ line_texts.append(
+ "{}[{}:{}]".format(function_name, os.path.relpath(path), linenum)
+ )
+
+ output = "{}-> {} {:x}\n".format(prefix, line_texts[0], address)
+ for depth, line_text in enumerate(line_texts[1:]):
+ output += "{} {}- {}\n".format(prefix, " " * depth, line_text)
+
+ # Remove the last newline character.
+ return (order_key, output.rstrip("\n"))
+
+ # Analyze disassembly.
+ try:
+ disasm_text = subprocess.check_output(
+ [self.options.objdump, "-d", self.options.elf_path], encoding="utf-8"
+ )
+ except subprocess.CalledProcessError:
+ raise StackAnalyzerError("objdump failed to disassemble.")
+ except OSError:
+ raise StackAnalyzerError("Failed to run objdump.")
+
+ function_map = self.AnalyzeDisassembly(disasm_text)
+ result = self.ResolveAnnotation(function_map)
+ (add_set, remove_list, eliminated_addrs, failed_sigtxts) = result
+ remove_list = self.PreprocessAnnotation(
+ function_map, add_set, remove_list, eliminated_addrs
+ )
+ cycle_functions = self.AnalyzeCallGraph(function_map, remove_list)
+
+ # Print the results of task-aware stack analysis.
+ extra_stack_frame = self.annotation.get(
+ "exception_frame_size", DEFAULT_EXCEPTION_FRAME_SIZE
+ )
+ for task in self.tasklist:
+ routine_func = function_map[task.routine_address]
+ print(
+ "Task: {}, Max size: {} ({} + {}), Allocated size: {}".format(
+ task.name,
+ routine_func.stack_max_usage + extra_stack_frame,
+ routine_func.stack_max_usage,
+ extra_stack_frame,
+ task.stack_max_size,
+ )
+ )
+
+ print("Call Trace:")
+ max_stack_path = routine_func.stack_max_path
+ # Assume the routine function is resolved.
+ assert max_stack_path is not None
+ for depth, curr_func in enumerate(max_stack_path):
+ line_info = self.AddressToLine(curr_func.address)[0]
+ if line_info is None:
+ (path, linenum) = ("??", 0)
+ else:
+ (_, path, linenum) = line_info
+
+ print(
+ " {} ({}) [{}:{}] {:x}".format(
+ curr_func.name,
+ curr_func.stack_frame,
+ os.path.relpath(path),
+ linenum,
+ curr_func.address,
+ )
+ )
+
+ if depth + 1 < len(max_stack_path):
+ succ_func = max_stack_path[depth + 1]
+ text_list = []
+ for callsite in curr_func.callsites:
+ if callsite.callee is succ_func:
+ indent_prefix = " "
+ if callsite.address is None:
+ order_text = (
+ None,
+ "{}-> [annotation]".format(indent_prefix),
+ )
+ else:
+ order_text = OutputInlineStack(
+ callsite.address, indent_prefix
+ )
+
+ text_list.append(order_text)
+
+ for _, text in sorted(text_list, key=lambda item: item[0]):
+ print(text)
+
+ print("Unresolved indirect callsites:")
+ for function in function_map.values():
+ indirect_callsites = []
+ for callsite in function.callsites:
+ if callsite.target is None:
+ indirect_callsites.append(callsite.address)
+
+ if len(indirect_callsites) > 0:
+ print(" In function {}:".format(function.name))
+ text_list = []
+ for address in indirect_callsites:
+ text_list.append(OutputInlineStack(address, " "))
+
+ for _, text in sorted(text_list, key=lambda item: item[0]):
+ print(text)
+
+ print("Unresolved annotation signatures:")
+ for sigtxt, error in failed_sigtxts:
+ print(" {}: {}".format(sigtxt, error))
+
+ if len(cycle_functions) > 0:
+ print("There are cycles in the following function sets:")
+ for functions in cycle_functions:
+ print("[{}]".format(", ".join(function.name for function in functions)))
- def ResolveAnnotation(self, function_map):
- """Resolve annotation.
- Args:
- function_map: Function map.
+def ParseArgs():
+ """Parse commandline arguments.
Returns:
- Set of added call edges, list of remove paths, set of eliminated
- callsite addresses, set of annotation signatures which can't be resolved.
+ options: Namespace from argparse.parse_args().
"""
- def StringifySignature(signature):
- """Stringify the tupled signature.
-
- Args:
- signature: Tupled signature.
-
- Returns:
- Signature string.
- """
- (name, path, linenum) = signature
- bracket_text = ''
- if path is not None:
- path = os.path.relpath(path)
- if linenum is None:
- bracket_text = '[{}]'.format(path)
- else:
- bracket_text = '[{}:{}]'.format(path, linenum)
-
- return name + bracket_text
-
- (add_rules, remove_rules, invalid_sigtxts) = self.LoadAnnotation()
-
- signature_set = set()
- for src_sig, dst_sigs in add_rules.items():
- signature_set.add(src_sig)
- signature_set.update(dst_sigs)
-
- for remove_sigs in remove_rules:
- signature_set.update(remove_sigs)
-
- # Map signatures to functions.
- (signature_map, sig_error_map) = self.MapAnnotation(function_map,
- signature_set)
-
- # Build the indirect callsite map indexed by callsite signature.
- indirect_map = collections.defaultdict(set)
- for function in function_map.values():
- for callsite in function.callsites:
- if callsite.target is not None:
- continue
-
- # Found an indirect callsite.
- line_info = self.AddressToLine(callsite.address)[0]
- if line_info is None:
- continue
-
- (name, path, linenum) = line_info
- result = self.FUNCTION_PREFIX_NAME_RE.match(name)
- if result is None:
- continue
-
- indirect_map[(result.group('name').strip(), path, linenum)].add(
- (function, callsite.address))
-
- # Generate the annotation sets.
- add_set = set()
- remove_list = list()
- eliminated_addrs = set()
-
- for src_sig, dst_sigs in add_rules.items():
- src_funcs = set(signature_map.get(src_sig, []))
- # Try to match the source signature to the indirect callsites. Even if it
- # can't be found in disassembly.
- indirect_calls = indirect_map.get(src_sig)
- if indirect_calls is not None:
- for function, callsite_address in indirect_calls:
- # Add the caller of the indirect callsite to the source functions.
- src_funcs.add(function)
- # Assume each callsite can be represented by a unique address.
- eliminated_addrs.add(callsite_address)
-
- if src_sig in sig_error_map:
- # Assume the error is always the not found error. Since the signature
- # found in indirect callsite map must be a full signature, it can't
- # happen the ambiguous error.
- assert sig_error_map[src_sig] == self.ANNOTATION_ERROR_NOTFOUND
- # Found in inline stack, remove the not found error.
- del sig_error_map[src_sig]
-
- for dst_sig in dst_sigs:
- dst_funcs = signature_map.get(dst_sig)
- if dst_funcs is None:
- continue
-
- # Duplicate the call edge for all the same source and destination
- # functions.
- for src_func in src_funcs:
- for dst_func in dst_funcs:
- add_set.add((src_func, dst_func))
-
- for remove_sigs in remove_rules:
- # Since each signature can be mapped to multiple functions, generate
- # multiple remove paths from all the combinations of these functions.
- remove_paths = [[]]
- skip_flag = False
- for remove_sig in remove_sigs:
- # Transform each signature to the corresponding functions.
- remove_funcs = signature_map.get(remove_sig)
- if remove_funcs is None:
- # There is an unresolved signature in the remove path. Ignore the
- # whole broken remove path.
- skip_flag = True
- break
- else:
- # Append each function of the current signature to the all previous
- # remove paths.
- remove_paths = [p + [f] for p in remove_paths for f in remove_funcs]
-
- if skip_flag:
- # Ignore the broken remove path.
- continue
-
- for remove_path in remove_paths:
- # Deduplicate the remove paths.
- if remove_path not in remove_list:
- remove_list.append(remove_path)
-
- # Format the error messages.
- failed_sigtxts = set()
- for sigtxt in invalid_sigtxts:
- failed_sigtxts.add((sigtxt, self.ANNOTATION_ERROR_INVALID))
-
- for sig, error in sig_error_map.items():
- failed_sigtxts.add((StringifySignature(sig), error))
-
- return (add_set, remove_list, eliminated_addrs, failed_sigtxts)
-
- def PreprocessAnnotation(self, function_map, add_set, remove_list,
- eliminated_addrs):
- """Preprocess the annotation and callgraph.
+ parser = argparse.ArgumentParser(description="EC firmware stack analyzer.")
+ parser.add_argument("elf_path", help="the path of EC firmware ELF")
+ parser.add_argument(
+ "--export_taskinfo",
+ required=True,
+ help="the path of export_taskinfo.so utility",
+ )
+ parser.add_argument(
+ "--section",
+ required=True,
+ help="the section.",
+ choices=[SECTION_RO, SECTION_RW],
+ )
+ parser.add_argument("--objdump", default="objdump", help="the path of objdump")
+ parser.add_argument(
+ "--addr2line", default="addr2line", help="the path of addr2line"
+ )
+ parser.add_argument(
+ "--annotation", default=None, help="the path of annotation file"
+ )
+
+ # TODO(cheyuw): Add an option for dumping stack usage of all functions.
+
+ return parser.parse_args()
- Add the missing call edges, and delete simple remove paths (the paths have
- one or two vertices) from the function_map.
- Eliminate the annotated indirect callsites.
-
- Return the remaining remove list.
+def ParseSymbolText(symbol_text):
+ """Parse the content of the symbol text.
Args:
- function_map: Function map.
- add_set: Set of missing call edges.
- remove_list: List of remove paths.
- eliminated_addrs: Set of eliminated callsite addresses.
+ symbol_text: Text of the symbols.
Returns:
- List of remaining remove paths.
+ symbols: Symbol list.
"""
- def CheckEdge(path):
- """Check if all edges of the path are on the callgraph.
-
- Args:
- path: Path.
-
- Returns:
- True or False.
- """
- for index in range(len(path) - 1):
- if (path[index], path[index + 1]) not in edge_set:
- return False
-
- return True
-
- for src_func, dst_func in add_set:
- # TODO(cheyuw): Support tailing call annotation.
- src_func.callsites.append(
- Callsite(None, dst_func.address, False, dst_func))
-
- # Delete simple remove paths.
- remove_simple = set(tuple(p) for p in remove_list if len(p) <= 2)
- edge_set = set()
- for function in function_map.values():
- cleaned_callsites = []
- for callsite in function.callsites:
- if ((callsite.callee,) in remove_simple or
- (function, callsite.callee) in remove_simple):
- continue
-
- if callsite.target is None and callsite.address in eliminated_addrs:
- continue
-
- cleaned_callsites.append(callsite)
- if callsite.callee is not None:
- edge_set.add((function, callsite.callee))
+ # Example: "10093064 g F .text 0000015c .hidden hook_task"
+ symbol_regex = re.compile(
+ r"^(?P<address>[0-9A-Fa-f]+)\s+[lwg]\s+"
+ r"((?P<type>[OF])\s+)?\S+\s+"
+ r"(?P<size>[0-9A-Fa-f]+)\s+"
+ r"(\S+\s+)?(?P<name>\S+)$"
+ )
+
+ symbols = []
+ for line in symbol_text.splitlines():
+ line = line.strip()
+ result = symbol_regex.match(line)
+ if result is not None:
+ address = int(result.group("address"), 16)
+ symtype = result.group("type")
+ if symtype is None:
+ symtype = "O"
- function.callsites = cleaned_callsites
+ size = int(result.group("size"), 16)
+ name = result.group("name")
+ symbols.append(Symbol(address, symtype, size, name))
- return [p for p in remove_list if len(p) >= 3 and CheckEdge(p)]
+ return symbols
- def AnalyzeCallGraph(self, function_map, remove_list):
- """Analyze callgraph.
- It will update the max stack size and path for each function.
+def ParseRoDataText(rodata_text):
+ """Parse the content of rodata
Args:
- function_map: Function map.
- remove_list: List of remove paths.
+ rodata_text: Text of the rodata dump.
Returns:
- List of function cycles.
+ rodata: Parsed rodata content.
"""
- def Traverse(curr_state):
- """Traverse the callgraph and calculate the max stack usages of functions.
-
- Args:
- curr_state: Current state.
-
- Returns:
- SCC lowest link.
- """
- scc_index = scc_index_counter[0]
- scc_index_counter[0] += 1
- scc_index_map[curr_state] = scc_index
- scc_lowlink = scc_index
- scc_stack.append(curr_state)
- # Push the current state in the stack. We can use a set to maintain this
- # because the stacked states are unique; otherwise we will find a cycle
- # first.
- stacked_states.add(curr_state)
-
- (curr_address, curr_positions) = curr_state
- curr_func = function_map[curr_address]
-
- invalid_flag = False
- new_positions = list(curr_positions)
- for index, position in enumerate(curr_positions):
- remove_path = remove_list[index]
-
- # The position of each remove path in the state is the length of the
- # longest matching path between the prefix of the remove path and the
- # suffix of the current traversing path. We maintain this length when
- # appending the next callee to the traversing path. And it can be used
- # to check if the remove path appears in the traversing path.
-
- # TODO(cheyuw): Implement KMP algorithm to match remove paths
- # efficiently.
- if remove_path[position] is curr_func:
- # Matches the current function, extend the length.
- new_positions[index] = position + 1
- if new_positions[index] == len(remove_path):
- # The length of the longest matching path is equal to the length of
- # the remove path, which means the suffix of the current traversing
- # path matches the remove path.
- invalid_flag = True
- break
-
- else:
- # We can't get the new longest matching path by extending the previous
- # one directly. Fallback to search the new longest matching path.
-
- # If we can't find any matching path in the following search, reset
- # the matching length to 0.
- new_positions[index] = 0
-
- # We want to find the new longest matching prefix of remove path with
- # the suffix of the current traversing path. Because the new longest
- # matching path won't be longer than the prevous one now, and part of
- # the suffix matches the prefix of remove path, we can get the needed
- # suffix from the previous matching prefix of the invalid path.
- suffix = remove_path[:position] + [curr_func]
- for offset in range(1, len(suffix)):
- length = position - offset
- if remove_path[:length] == suffix[offset:]:
- new_positions[index] = length
- break
-
- new_positions = tuple(new_positions)
-
- # If the current suffix is invalid, set the max stack usage to 0.
- max_stack_usage = 0
- max_callee_state = None
- self_loop = False
-
- if not invalid_flag:
- # Max stack usage is at least equal to the stack frame.
- max_stack_usage = curr_func.stack_frame
- for callsite in curr_func.callsites:
- callee = callsite.callee
- if callee is None:
+ # Examples: 8018ab0 00040048 00010000 10020000 4b8e0108 ...H........K...
+ # 100a7294 00000000 00000000 01000000 ............
+
+ base_offset = None
+ offset = None
+ rodata = []
+ for line in rodata_text.splitlines():
+ line = line.strip()
+ space = line.find(" ")
+ if space < 0:
+ continue
+ try:
+ address = int(line[0:space], 16)
+ except ValueError:
continue
- callee_state = (callee.address, new_positions)
- if callee_state not in scc_index_map:
- # Unvisited state.
- scc_lowlink = min(scc_lowlink, Traverse(callee_state))
- elif callee_state in stacked_states:
- # The state is shown in the stack. There is a cycle.
- sub_stack_usage = 0
- scc_lowlink = min(scc_lowlink, scc_index_map[callee_state])
- if callee_state == curr_state:
- self_loop = True
-
- done_result = done_states.get(callee_state)
- if done_result is not None:
- # Already done this state and use its result. If the state reaches a
- # cycle, reusing the result will cause inaccuracy (the stack usage
- # of cycle depends on where the entrance is). But it's fine since we
- # can't get accurate stack usage under this situation, and we rely
- # on user-provided annotations to break the cycle, after which the
- # result will be accurate again.
- (sub_stack_usage, _) = done_result
-
- if callsite.is_tail:
- # For tailing call, since the callee reuses the stack frame of the
- # caller, choose the larger one directly.
- stack_usage = max(curr_func.stack_frame, sub_stack_usage)
- else:
- stack_usage = curr_func.stack_frame + sub_stack_usage
-
- if stack_usage > max_stack_usage:
- max_stack_usage = stack_usage
- max_callee_state = callee_state
-
- if scc_lowlink == scc_index:
- group = []
- while scc_stack[-1] != curr_state:
- scc_state = scc_stack.pop()
- stacked_states.remove(scc_state)
- group.append(scc_state)
-
- scc_stack.pop()
- stacked_states.remove(curr_state)
-
- # If the cycle is not empty, record it.
- if len(group) > 0 or self_loop:
- group.append(curr_state)
- cycle_groups.append(group)
-
- # Store the done result.
- done_states[curr_state] = (max_stack_usage, max_callee_state)
-
- if curr_positions == initial_positions:
- # If the current state is initial state, we traversed the callgraph by
- # using the current function as start point. Update the stack usage of
- # the function.
- # If the function matches a single vertex remove path, this will set its
- # max stack usage to 0, which is not expected (we still calculate its
- # max stack usage, but prevent any function from calling it). However,
- # all the single vertex remove paths have been preprocessed and removed.
- curr_func.stack_max_usage = max_stack_usage
-
- # Reconstruct the max stack path by traversing the state transitions.
- max_stack_path = [curr_func]
- callee_state = max_callee_state
- while callee_state is not None:
- # The first element of state tuple is function address.
- max_stack_path.append(function_map[callee_state[0]])
- done_result = done_states.get(callee_state)
- # All of the descendants should be done.
- assert done_result is not None
- (_, callee_state) = done_result
-
- curr_func.stack_max_path = max_stack_path
-
- return scc_lowlink
-
- # The state is the concatenation of the current function address and the
- # state of matching position.
- initial_positions = (0,) * len(remove_list)
- done_states = {}
- stacked_states = set()
- scc_index_counter = [0]
- scc_index_map = {}
- scc_stack = []
- cycle_groups = []
- for function in function_map.values():
- if function.stack_max_usage is None:
- Traverse((function.address, initial_positions))
-
- cycle_functions = []
- for group in cycle_groups:
- cycle = set(function_map[state[0]] for state in group)
- if cycle not in cycle_functions:
- cycle_functions.append(cycle)
-
- return cycle_functions
-
- def Analyze(self):
- """Run the stack analysis.
-
- Raises:
- StackAnalyzerError: If disassembly fails.
- """
- def OutputInlineStack(address, prefix=''):
- """Output beautiful inline stack.
-
- Args:
- address: Address.
- prefix: Prefix of each line.
-
- Returns:
- Key for sorting, output text
- """
- line_infos = self.AddressToLine(address, True)
-
- if line_infos[0] is None:
- order_key = (None, None)
- else:
- (_, path, linenum) = line_infos[0]
- order_key = (linenum, path)
-
- line_texts = []
- for line_info in reversed(line_infos):
- if line_info is None:
- (function_name, path, linenum) = ('??', '??', 0)
- else:
- (function_name, path, linenum) = line_info
-
- line_texts.append('{}[{}:{}]'.format(function_name,
- os.path.relpath(path),
- linenum))
-
- output = '{}-> {} {:x}\n'.format(prefix, line_texts[0], address)
- for depth, line_text in enumerate(line_texts[1:]):
- output += '{} {}- {}\n'.format(prefix, ' ' * depth, line_text)
-
- # Remove the last newline character.
- return (order_key, output.rstrip('\n'))
-
- # Analyze disassembly.
- try:
- disasm_text = subprocess.check_output([self.options.objdump,
- '-d',
- self.options.elf_path],
- encoding='utf-8')
- except subprocess.CalledProcessError:
- raise StackAnalyzerError('objdump failed to disassemble.')
- except OSError:
- raise StackAnalyzerError('Failed to run objdump.')
-
- function_map = self.AnalyzeDisassembly(disasm_text)
- result = self.ResolveAnnotation(function_map)
- (add_set, remove_list, eliminated_addrs, failed_sigtxts) = result
- remove_list = self.PreprocessAnnotation(function_map,
- add_set,
- remove_list,
- eliminated_addrs)
- cycle_functions = self.AnalyzeCallGraph(function_map, remove_list)
-
- # Print the results of task-aware stack analysis.
- extra_stack_frame = self.annotation.get('exception_frame_size',
- DEFAULT_EXCEPTION_FRAME_SIZE)
- for task in self.tasklist:
- routine_func = function_map[task.routine_address]
- print('Task: {}, Max size: {} ({} + {}), Allocated size: {}'.format(
- task.name,
- routine_func.stack_max_usage + extra_stack_frame,
- routine_func.stack_max_usage,
- extra_stack_frame,
- task.stack_max_size))
-
- print('Call Trace:')
- max_stack_path = routine_func.stack_max_path
- # Assume the routine function is resolved.
- assert max_stack_path is not None
- for depth, curr_func in enumerate(max_stack_path):
- line_info = self.AddressToLine(curr_func.address)[0]
- if line_info is None:
- (path, linenum) = ('??', 0)
- else:
- (_, path, linenum) = line_info
-
- print(' {} ({}) [{}:{}] {:x}'.format(curr_func.name,
- curr_func.stack_frame,
- os.path.relpath(path),
- linenum,
- curr_func.address))
-
- if depth + 1 < len(max_stack_path):
- succ_func = max_stack_path[depth + 1]
- text_list = []
- for callsite in curr_func.callsites:
- if callsite.callee is succ_func:
- indent_prefix = ' '
- if callsite.address is None:
- order_text = (None, '{}-> [annotation]'.format(indent_prefix))
- else:
- order_text = OutputInlineStack(callsite.address, indent_prefix)
-
- text_list.append(order_text)
-
- for _, text in sorted(text_list, key=lambda item: item[0]):
- print(text)
-
- print('Unresolved indirect callsites:')
- for function in function_map.values():
- indirect_callsites = []
- for callsite in function.callsites:
- if callsite.target is None:
- indirect_callsites.append(callsite.address)
-
- if len(indirect_callsites) > 0:
- print(' In function {}:'.format(function.name))
- text_list = []
- for address in indirect_callsites:
- text_list.append(OutputInlineStack(address, ' '))
-
- for _, text in sorted(text_list, key=lambda item: item[0]):
- print(text)
-
- print('Unresolved annotation signatures:')
- for sigtxt, error in failed_sigtxts:
- print(' {}: {}'.format(sigtxt, error))
-
- if len(cycle_functions) > 0:
- print('There are cycles in the following function sets:')
- for functions in cycle_functions:
- print('[{}]'.format(', '.join(function.name for function in functions)))
-
-
-def ParseArgs():
- """Parse commandline arguments.
-
- Returns:
- options: Namespace from argparse.parse_args().
- """
- parser = argparse.ArgumentParser(description="EC firmware stack analyzer.")
- parser.add_argument('elf_path', help="the path of EC firmware ELF")
- parser.add_argument('--export_taskinfo', required=True,
- help="the path of export_taskinfo.so utility")
- parser.add_argument('--section', required=True, help='the section.',
- choices=[SECTION_RO, SECTION_RW])
- parser.add_argument('--objdump', default='objdump',
- help='the path of objdump')
- parser.add_argument('--addr2line', default='addr2line',
- help='the path of addr2line')
- parser.add_argument('--annotation', default=None,
- help='the path of annotation file')
-
- # TODO(cheyuw): Add an option for dumping stack usage of all functions.
-
- return parser.parse_args()
-
-
-def ParseSymbolText(symbol_text):
- """Parse the content of the symbol text.
-
- Args:
- symbol_text: Text of the symbols.
-
- Returns:
- symbols: Symbol list.
- """
- # Example: "10093064 g F .text 0000015c .hidden hook_task"
- symbol_regex = re.compile(r'^(?P<address>[0-9A-Fa-f]+)\s+[lwg]\s+'
- r'((?P<type>[OF])\s+)?\S+\s+'
- r'(?P<size>[0-9A-Fa-f]+)\s+'
- r'(\S+\s+)?(?P<name>\S+)$')
-
- symbols = []
- for line in symbol_text.splitlines():
- line = line.strip()
- result = symbol_regex.match(line)
- if result is not None:
- address = int(result.group('address'), 16)
- symtype = result.group('type')
- if symtype is None:
- symtype = 'O'
-
- size = int(result.group('size'), 16)
- name = result.group('name')
- symbols.append(Symbol(address, symtype, size, name))
-
- return symbols
-
-
-def ParseRoDataText(rodata_text):
- """Parse the content of rodata
-
- Args:
- symbol_text: Text of the rodata dump.
-
- Returns:
- symbols: Symbol list.
- """
- # Examples: 8018ab0 00040048 00010000 10020000 4b8e0108 ...H........K...
- # 100a7294 00000000 00000000 01000000 ............
-
- base_offset = None
- offset = None
- rodata = []
- for line in rodata_text.splitlines():
- line = line.strip()
- space = line.find(' ')
- if space < 0:
- continue
- try:
- address = int(line[0:space], 16)
- except ValueError:
- continue
-
- if not base_offset:
- base_offset = address
- offset = address
- elif address != offset:
- raise StackAnalyzerError('objdump of rodata not contiguous.')
+ if not base_offset:
+ base_offset = address
+ offset = address
+ elif address != offset:
+ raise StackAnalyzerError("objdump of rodata not contiguous.")
- for i in range(0, 4):
- num = line[(space + 1 + i*9):(space + 9 + i*9)]
- if len(num.strip()) > 0:
- val = int(num, 16)
- else:
- val = 0
- # TODO(drinkcat): Not all platforms are necessarily big-endian
- rodata.append((val & 0x000000ff) << 24 |
- (val & 0x0000ff00) << 8 |
- (val & 0x00ff0000) >> 8 |
- (val & 0xff000000) >> 24)
+ for i in range(0, 4):
+ num = line[(space + 1 + i * 9) : (space + 9 + i * 9)]
+ if len(num.strip()) > 0:
+ val = int(num, 16)
+ else:
+ val = 0
+ # TODO(drinkcat): Not all platforms are necessarily big-endian
+ rodata.append(
+ (val & 0x000000FF) << 24
+ | (val & 0x0000FF00) << 8
+ | (val & 0x00FF0000) >> 8
+ | (val & 0xFF000000) >> 24
+ )
- offset = offset + 4*4
+ offset = offset + 4 * 4
- return (base_offset, rodata)
+ return (base_offset, rodata)
def LoadTasklist(section, export_taskinfo, symbols):
- """Load the task information.
+ """Load the task information.
- Args:
- section: Section (RO | RW).
- export_taskinfo: Handle of export_taskinfo.so.
- symbols: Symbol list.
+ Args:
+ section: Section (RO | RW).
+ export_taskinfo: Handle of export_taskinfo.so.
+ symbols: Symbol list.
- Returns:
- tasklist: Task list.
- """
+ Returns:
+ tasklist: Task list.
+ """
- TaskInfoPointer = ctypes.POINTER(TaskInfo)
- taskinfos = TaskInfoPointer()
- if section == SECTION_RO:
- get_taskinfos_func = export_taskinfo.get_ro_taskinfos
- else:
- get_taskinfos_func = export_taskinfo.get_rw_taskinfos
+ TaskInfoPointer = ctypes.POINTER(TaskInfo)
+ taskinfos = TaskInfoPointer()
+ if section == SECTION_RO:
+ get_taskinfos_func = export_taskinfo.get_ro_taskinfos
+ else:
+ get_taskinfos_func = export_taskinfo.get_rw_taskinfos
- taskinfo_num = get_taskinfos_func(ctypes.pointer(taskinfos))
+ taskinfo_num = get_taskinfos_func(ctypes.pointer(taskinfos))
- tasklist = []
- for index in range(taskinfo_num):
- taskinfo = taskinfos[index]
- tasklist.append(Task(taskinfo.name.decode('utf-8'),
- taskinfo.routine.decode('utf-8'),
- taskinfo.stack_size))
+ tasklist = []
+ for index in range(taskinfo_num):
+ taskinfo = taskinfos[index]
+ tasklist.append(
+ Task(
+ taskinfo.name.decode("utf-8"),
+ taskinfo.routine.decode("utf-8"),
+ taskinfo.stack_size,
+ )
+ )
- # Resolve routine address for each task. It's more efficient to resolve all
- # routine addresses of tasks together.
- routine_map = dict((task.routine_name, None) for task in tasklist)
+ # Resolve routine address for each task. It's more efficient to resolve all
+ # routine addresses of tasks together.
+ routine_map = dict((task.routine_name, None) for task in tasklist)
- for symbol in symbols:
- # Resolve task routine address.
- if symbol.name in routine_map:
- # Assume the symbol of routine is unique.
- assert routine_map[symbol.name] is None
- routine_map[symbol.name] = symbol.address
+ for symbol in symbols:
+ # Resolve task routine address.
+ if symbol.name in routine_map:
+ # Assume the symbol of routine is unique.
+ assert routine_map[symbol.name] is None
+ routine_map[symbol.name] = symbol.address
- for task in tasklist:
- address = routine_map[task.routine_name]
- # Assume we have resolved all routine addresses.
- assert address is not None
- task.routine_address = address
+ for task in tasklist:
+ address = routine_map[task.routine_name]
+ # Assume we have resolved all routine addresses.
+ assert address is not None
+ task.routine_address = address
- return tasklist
+ return tasklist
def main():
- """Main function."""
- try:
- options = ParseArgs()
-
- # Load annotation config.
- if options.annotation is None:
- annotation = {}
- elif not os.path.exists(options.annotation):
- print('Warning: Annotation file {} does not exist.'
- .format(options.annotation))
- annotation = {}
- else:
- try:
- with open(options.annotation, 'r') as annotation_file:
- annotation = yaml.safe_load(annotation_file)
-
- except yaml.YAMLError:
- raise StackAnalyzerError('Failed to parse annotation file {}.'
- .format(options.annotation))
- except IOError:
- raise StackAnalyzerError('Failed to open annotation file {}.'
- .format(options.annotation))
-
- # TODO(cheyuw): Do complete annotation format verification.
- if not isinstance(annotation, dict):
- raise StackAnalyzerError('Invalid annotation file {}.'
- .format(options.annotation))
-
- # Generate and parse the symbols.
+ """Main function."""
try:
- symbol_text = subprocess.check_output([options.objdump,
- '-t',
- options.elf_path],
- encoding='utf-8')
- rodata_text = subprocess.check_output([options.objdump,
- '-s',
- '-j', '.rodata',
- options.elf_path],
- encoding='utf-8')
- except subprocess.CalledProcessError:
- raise StackAnalyzerError('objdump failed to dump symbol table or rodata.')
- except OSError:
- raise StackAnalyzerError('Failed to run objdump.')
-
- symbols = ParseSymbolText(symbol_text)
- rodata = ParseRoDataText(rodata_text)
-
- # Load the tasklist.
- try:
- export_taskinfo = ctypes.CDLL(options.export_taskinfo)
- except OSError:
- raise StackAnalyzerError('Failed to load export_taskinfo.')
-
- tasklist = LoadTasklist(options.section, export_taskinfo, symbols)
-
- analyzer = StackAnalyzer(options, symbols, rodata, tasklist, annotation)
- analyzer.Analyze()
- except StackAnalyzerError as e:
- print('Error: {}'.format(e))
-
-
-if __name__ == '__main__':
- main()
+ options = ParseArgs()
+
+ # Load annotation config.
+ if options.annotation is None:
+ annotation = {}
+ elif not os.path.exists(options.annotation):
+ print(
+ "Warning: Annotation file {} does not exist.".format(options.annotation)
+ )
+ annotation = {}
+ else:
+ try:
+ with open(options.annotation, "r") as annotation_file:
+ annotation = yaml.safe_load(annotation_file)
+
+ except yaml.YAMLError:
+ raise StackAnalyzerError(
+ "Failed to parse annotation file {}.".format(options.annotation)
+ )
+ except IOError:
+ raise StackAnalyzerError(
+ "Failed to open annotation file {}.".format(options.annotation)
+ )
+
+ # TODO(cheyuw): Do complete annotation format verification.
+ if not isinstance(annotation, dict):
+ raise StackAnalyzerError(
+ "Invalid annotation file {}.".format(options.annotation)
+ )
+
+ # Generate and parse the symbols.
+ try:
+ symbol_text = subprocess.check_output(
+ [options.objdump, "-t", options.elf_path], encoding="utf-8"
+ )
+ rodata_text = subprocess.check_output(
+ [options.objdump, "-s", "-j", ".rodata", options.elf_path],
+ encoding="utf-8",
+ )
+ except subprocess.CalledProcessError:
+ raise StackAnalyzerError("objdump failed to dump symbol table or rodata.")
+ except OSError:
+ raise StackAnalyzerError("Failed to run objdump.")
+
+ symbols = ParseSymbolText(symbol_text)
+ rodata = ParseRoDataText(rodata_text)
+
+ # Load the tasklist.
+ try:
+ export_taskinfo = ctypes.CDLL(options.export_taskinfo)
+ except OSError:
+ raise StackAnalyzerError("Failed to load export_taskinfo.")
+
+ tasklist = LoadTasklist(options.section, export_taskinfo, symbols)
+
+ analyzer = StackAnalyzer(options, symbols, rodata, tasklist, annotation)
+ analyzer.Analyze()
+ except StackAnalyzerError as e:
+ print("Error: {}".format(e))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/extra/stack_analyzer/stack_analyzer_unittest.py b/extra/stack_analyzer/stack_analyzer_unittest.py
index c36fa9da45..ad2837a8a4 100755
--- a/extra/stack_analyzer/stack_analyzer_unittest.py
+++ b/extra/stack_analyzer/stack_analyzer_unittest.py
@@ -11,820 +11,924 @@
from __future__ import print_function
-import mock
import os
import subprocess
import unittest
+import mock
import stack_analyzer as sa
class ObjectTest(unittest.TestCase):
- """Tests for classes of basic objects."""
-
- def testTask(self):
- task_a = sa.Task('a', 'a_task', 1234)
- task_b = sa.Task('b', 'b_task', 5678, 0x1000)
- self.assertEqual(task_a, task_a)
- self.assertNotEqual(task_a, task_b)
- self.assertNotEqual(task_a, None)
-
- def testSymbol(self):
- symbol_a = sa.Symbol(0x1234, 'F', 32, 'a')
- symbol_b = sa.Symbol(0x234, 'O', 42, 'b')
- self.assertEqual(symbol_a, symbol_a)
- self.assertNotEqual(symbol_a, symbol_b)
- self.assertNotEqual(symbol_a, None)
-
- def testCallsite(self):
- callsite_a = sa.Callsite(0x1002, 0x3000, False)
- callsite_b = sa.Callsite(0x1002, 0x3000, True)
- self.assertEqual(callsite_a, callsite_a)
- self.assertNotEqual(callsite_a, callsite_b)
- self.assertNotEqual(callsite_a, None)
-
- def testFunction(self):
- func_a = sa.Function(0x100, 'a', 0, [])
- func_b = sa.Function(0x200, 'b', 0, [])
- self.assertEqual(func_a, func_a)
- self.assertNotEqual(func_a, func_b)
- self.assertNotEqual(func_a, None)
+ """Tests for classes of basic objects."""
+
+ def testTask(self):
+ task_a = sa.Task("a", "a_task", 1234)
+ task_b = sa.Task("b", "b_task", 5678, 0x1000)
+ self.assertEqual(task_a, task_a)
+ self.assertNotEqual(task_a, task_b)
+ self.assertNotEqual(task_a, None)
+
+ def testSymbol(self):
+ symbol_a = sa.Symbol(0x1234, "F", 32, "a")
+ symbol_b = sa.Symbol(0x234, "O", 42, "b")
+ self.assertEqual(symbol_a, symbol_a)
+ self.assertNotEqual(symbol_a, symbol_b)
+ self.assertNotEqual(symbol_a, None)
+
+ def testCallsite(self):
+ callsite_a = sa.Callsite(0x1002, 0x3000, False)
+ callsite_b = sa.Callsite(0x1002, 0x3000, True)
+ self.assertEqual(callsite_a, callsite_a)
+ self.assertNotEqual(callsite_a, callsite_b)
+ self.assertNotEqual(callsite_a, None)
+
+ def testFunction(self):
+ func_a = sa.Function(0x100, "a", 0, [])
+ func_b = sa.Function(0x200, "b", 0, [])
+ self.assertEqual(func_a, func_a)
+ self.assertNotEqual(func_a, func_b)
+ self.assertNotEqual(func_a, None)
class ArmAnalyzerTest(unittest.TestCase):
- """Tests for class ArmAnalyzer."""
-
- def AppendConditionCode(self, opcodes):
- rets = []
- for opcode in opcodes:
- rets.extend(opcode + cc for cc in sa.ArmAnalyzer.CONDITION_CODES)
-
- return rets
-
- def testInstructionMatching(self):
- jump_list = self.AppendConditionCode(['b', 'bx'])
- jump_list += (list(opcode + '.n' for opcode in jump_list) +
- list(opcode + '.w' for opcode in jump_list))
- for opcode in jump_list:
- self.assertIsNotNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match(opcode))
-
- self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match('bl'))
- self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match('blx'))
-
- cbz_list = ['cbz', 'cbnz', 'cbz.n', 'cbnz.n', 'cbz.w', 'cbnz.w']
- for opcode in cbz_list:
- self.assertIsNotNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match(opcode))
-
- self.assertIsNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match('cbn'))
-
- call_list = self.AppendConditionCode(['bl', 'blx'])
- call_list += list(opcode + '.n' for opcode in call_list)
- for opcode in call_list:
- self.assertIsNotNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match(opcode))
-
- self.assertIsNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match('ble'))
-
- result = sa.ArmAnalyzer.CALL_OPERAND_RE.match('53f90 <get_time+0x18>')
- self.assertIsNotNone(result)
- self.assertEqual(result.group(1), '53f90')
- self.assertEqual(result.group(2), 'get_time+0x18')
-
- result = sa.ArmAnalyzer.CBZ_CBNZ_OPERAND_RE.match('r6, 53f90 <get+0x0>')
- self.assertIsNotNone(result)
- self.assertEqual(result.group(1), '53f90')
- self.assertEqual(result.group(2), 'get+0x0')
-
- self.assertIsNotNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match('push'))
- self.assertIsNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match('pushal'))
- self.assertIsNotNone(sa.ArmAnalyzer.STM_OPCODE_RE.match('stmdb'))
- self.assertIsNone(sa.ArmAnalyzer.STM_OPCODE_RE.match('lstm'))
- self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('sub'))
- self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('subs'))
- self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('subw'))
- self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('sub.w'))
- self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('subs.w'))
-
- result = sa.ArmAnalyzer.SUB_OPERAND_RE.match('sp, sp, #1668 ; 0x684')
- self.assertIsNotNone(result)
- self.assertEqual(result.group(1), '1668')
- result = sa.ArmAnalyzer.SUB_OPERAND_RE.match('sp, #1668')
- self.assertIsNotNone(result)
- self.assertEqual(result.group(1), '1668')
- self.assertIsNone(sa.ArmAnalyzer.SUB_OPERAND_RE.match('sl, #1668'))
-
- def testAnalyzeFunction(self):
- analyzer = sa.ArmAnalyzer()
- symbol = sa.Symbol(0x10, 'F', 0x100, 'foo')
- instructions = [
- (0x10, 'push', '{r4, r5, r6, r7, lr}'),
- (0x12, 'subw', 'sp, sp, #16 ; 0x10'),
- (0x16, 'movs', 'lr, r1'),
- (0x18, 'beq.n', '26 <foo+0x26>'),
- (0x1a, 'bl', '30 <foo+0x30>'),
- (0x1e, 'bl', 'deadbeef <bar>'),
- (0x22, 'blx', '0 <woo>'),
- (0x26, 'push', '{r1}'),
- (0x28, 'stmdb', 'sp!, {r4, r5, r6, r7, r8, r9, lr}'),
- (0x2c, 'stmdb', 'sp!, {r4}'),
- (0x30, 'stmdb', 'sp, {r4}'),
- (0x34, 'bx.n', '10 <foo>'),
- (0x36, 'bx.n', 'r3'),
- (0x38, 'ldr', 'pc, [r10]'),
- ]
- (size, callsites) = analyzer.AnalyzeFunction(symbol, instructions)
- self.assertEqual(size, 72)
- expect_callsites = [sa.Callsite(0x1e, 0xdeadbeef, False),
- sa.Callsite(0x22, 0x0, False),
- sa.Callsite(0x34, 0x10, True),
- sa.Callsite(0x36, None, True),
- sa.Callsite(0x38, None, True)]
- self.assertEqual(callsites, expect_callsites)
+ """Tests for class ArmAnalyzer."""
+
+ def AppendConditionCode(self, opcodes):
+ rets = []
+ for opcode in opcodes:
+ rets.extend(opcode + cc for cc in sa.ArmAnalyzer.CONDITION_CODES)
+
+ return rets
+
+ def testInstructionMatching(self):
+ jump_list = self.AppendConditionCode(["b", "bx"])
+ jump_list += list(opcode + ".n" for opcode in jump_list) + list(
+ opcode + ".w" for opcode in jump_list
+ )
+ for opcode in jump_list:
+ self.assertIsNotNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match(opcode))
+
+ self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match("bl"))
+ self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match("blx"))
+
+ cbz_list = ["cbz", "cbnz", "cbz.n", "cbnz.n", "cbz.w", "cbnz.w"]
+ for opcode in cbz_list:
+ self.assertIsNotNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match(opcode))
+
+ self.assertIsNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match("cbn"))
+
+ call_list = self.AppendConditionCode(["bl", "blx"])
+ call_list += list(opcode + ".n" for opcode in call_list)
+ for opcode in call_list:
+ self.assertIsNotNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match(opcode))
+
+ self.assertIsNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match("ble"))
+
+ result = sa.ArmAnalyzer.CALL_OPERAND_RE.match("53f90 <get_time+0x18>")
+ self.assertIsNotNone(result)
+ self.assertEqual(result.group(1), "53f90")
+ self.assertEqual(result.group(2), "get_time+0x18")
+
+ result = sa.ArmAnalyzer.CBZ_CBNZ_OPERAND_RE.match("r6, 53f90 <get+0x0>")
+ self.assertIsNotNone(result)
+ self.assertEqual(result.group(1), "53f90")
+ self.assertEqual(result.group(2), "get+0x0")
+
+ self.assertIsNotNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match("push"))
+ self.assertIsNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match("pushal"))
+ self.assertIsNotNone(sa.ArmAnalyzer.STM_OPCODE_RE.match("stmdb"))
+ self.assertIsNone(sa.ArmAnalyzer.STM_OPCODE_RE.match("lstm"))
+ self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("sub"))
+ self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("subs"))
+ self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("subw"))
+ self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("sub.w"))
+ self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("subs.w"))
+
+ result = sa.ArmAnalyzer.SUB_OPERAND_RE.match("sp, sp, #1668 ; 0x684")
+ self.assertIsNotNone(result)
+ self.assertEqual(result.group(1), "1668")
+ result = sa.ArmAnalyzer.SUB_OPERAND_RE.match("sp, #1668")
+ self.assertIsNotNone(result)
+ self.assertEqual(result.group(1), "1668")
+ self.assertIsNone(sa.ArmAnalyzer.SUB_OPERAND_RE.match("sl, #1668"))
+
+ def testAnalyzeFunction(self):
+ analyzer = sa.ArmAnalyzer()
+ symbol = sa.Symbol(0x10, "F", 0x100, "foo")
+ instructions = [
+ (0x10, "push", "{r4, r5, r6, r7, lr}"),
+ (0x12, "subw", "sp, sp, #16 ; 0x10"),
+ (0x16, "movs", "lr, r1"),
+ (0x18, "beq.n", "26 <foo+0x26>"),
+ (0x1A, "bl", "30 <foo+0x30>"),
+ (0x1E, "bl", "deadbeef <bar>"),
+ (0x22, "blx", "0 <woo>"),
+ (0x26, "push", "{r1}"),
+ (0x28, "stmdb", "sp!, {r4, r5, r6, r7, r8, r9, lr}"),
+ (0x2C, "stmdb", "sp!, {r4}"),
+ (0x30, "stmdb", "sp, {r4}"),
+ (0x34, "bx.n", "10 <foo>"),
+ (0x36, "bx.n", "r3"),
+ (0x38, "ldr", "pc, [r10]"),
+ ]
+ (size, callsites) = analyzer.AnalyzeFunction(symbol, instructions)
+ self.assertEqual(size, 72)
+ expect_callsites = [
+ sa.Callsite(0x1E, 0xDEADBEEF, False),
+ sa.Callsite(0x22, 0x0, False),
+ sa.Callsite(0x34, 0x10, True),
+ sa.Callsite(0x36, None, True),
+ sa.Callsite(0x38, None, True),
+ ]
+ self.assertEqual(callsites, expect_callsites)
class StackAnalyzerTest(unittest.TestCase):
- """Tests for class StackAnalyzer."""
-
- def setUp(self):
- symbols = [sa.Symbol(0x1000, 'F', 0x15C, 'hook_task'),
- sa.Symbol(0x2000, 'F', 0x51C, 'console_task'),
- sa.Symbol(0x3200, 'O', 0x124, '__just_data'),
- sa.Symbol(0x4000, 'F', 0x11C, 'touchpad_calc'),
- sa.Symbol(0x5000, 'F', 0x12C, 'touchpad_calc.constprop.42'),
- sa.Symbol(0x12000, 'F', 0x13C, 'trackpad_range'),
- sa.Symbol(0x13000, 'F', 0x200, 'inlined_mul'),
- sa.Symbol(0x13100, 'F', 0x200, 'inlined_mul'),
- sa.Symbol(0x13100, 'F', 0x200, 'inlined_mul_alias'),
- sa.Symbol(0x20000, 'O', 0x0, '__array'),
- sa.Symbol(0x20010, 'O', 0x0, '__array_end'),
- ]
- tasklist = [sa.Task('HOOKS', 'hook_task', 2048, 0x1000),
- sa.Task('CONSOLE', 'console_task', 460, 0x2000)]
- # Array at 0x20000 that contains pointers to hook_task and console_task,
- # with stride=8, offset=4
- rodata = (0x20000, [ 0xDEAD1000, 0x00001000, 0xDEAD2000, 0x00002000 ])
- options = mock.MagicMock(elf_path='./ec.RW.elf',
- export_taskinfo='fake',
- section='RW',
- objdump='objdump',
- addr2line='addr2line',
- annotation=None)
- self.analyzer = sa.StackAnalyzer(options, symbols, rodata, tasklist, {})
-
- def testParseSymbolText(self):
- symbol_text = (
- '0 g F .text e8 Foo\n'
- '0000dead w F .text 000000e8 .hidden Bar\n'
- 'deadbeef l O .bss 00000004 .hidden Woooo\n'
- 'deadbee g O .rodata 00000008 __Hooo_ooo\n'
- 'deadbee g .rodata 00000000 __foo_doo_coo_end\n'
- )
- symbols = sa.ParseSymbolText(symbol_text)
- expect_symbols = [sa.Symbol(0x0, 'F', 0xe8, 'Foo'),
- sa.Symbol(0xdead, 'F', 0xe8, 'Bar'),
- sa.Symbol(0xdeadbeef, 'O', 0x4, 'Woooo'),
- sa.Symbol(0xdeadbee, 'O', 0x8, '__Hooo_ooo'),
- sa.Symbol(0xdeadbee, 'O', 0x0, '__foo_doo_coo_end')]
- self.assertEqual(symbols, expect_symbols)
-
- def testParseRoData(self):
- rodata_text = (
- '\n'
- 'Contents of section .rodata:\n'
- ' 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n'
- )
- rodata = sa.ParseRoDataText(rodata_text)
- expect_rodata = (0x20000,
- [ 0x0010adde, 0x00001000, 0x0020adde, 0x00002000 ])
- self.assertEqual(rodata, expect_rodata)
-
- def testLoadTasklist(self):
- def tasklist_to_taskinfos(pointer, tasklist):
- taskinfos = []
- for task in tasklist:
- taskinfos.append(sa.TaskInfo(name=task.name.encode('utf-8'),
- routine=task.routine_name.encode('utf-8'),
- stack_size=task.stack_max_size))
-
- TaskInfoArray = sa.TaskInfo * len(taskinfos)
- pointer.contents.contents = TaskInfoArray(*taskinfos)
- return len(taskinfos)
-
- def ro_taskinfos(pointer):
- return tasklist_to_taskinfos(pointer, expect_ro_tasklist)
-
- def rw_taskinfos(pointer):
- return tasklist_to_taskinfos(pointer, expect_rw_tasklist)
-
- expect_ro_tasklist = [
- sa.Task('HOOKS', 'hook_task', 2048, 0x1000),
- ]
-
- expect_rw_tasklist = [
- sa.Task('HOOKS', 'hook_task', 2048, 0x1000),
- sa.Task('WOOKS', 'hook_task', 4096, 0x1000),
- sa.Task('CONSOLE', 'console_task', 460, 0x2000),
- ]
-
- export_taskinfo = mock.MagicMock(
- get_ro_taskinfos=mock.MagicMock(side_effect=ro_taskinfos),
- get_rw_taskinfos=mock.MagicMock(side_effect=rw_taskinfos))
-
- tasklist = sa.LoadTasklist('RO', export_taskinfo, self.analyzer.symbols)
- self.assertEqual(tasklist, expect_ro_tasklist)
- tasklist = sa.LoadTasklist('RW', export_taskinfo, self.analyzer.symbols)
- self.assertEqual(tasklist, expect_rw_tasklist)
-
- def testResolveAnnotation(self):
- self.analyzer.annotation = {}
- (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
- self.assertEqual(add_rules, {})
- self.assertEqual(remove_rules, [])
- self.assertEqual(invalid_sigtxts, set())
-
- self.analyzer.annotation = {'add': None, 'remove': None}
- (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
- self.assertEqual(add_rules, {})
- self.assertEqual(remove_rules, [])
- self.assertEqual(invalid_sigtxts, set())
-
- self.analyzer.annotation = {
- 'add': None,
- 'remove': [
- [['a', 'b'], ['0', '[', '2'], 'x'],
- [['a', 'b[x:3]'], ['0', '1', '2'], 'x'],
- ],
- }
- (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
- self.assertEqual(add_rules, {})
- self.assertEqual(list.sort(remove_rules), list.sort([
- [('a', None, None), ('1', None, None), ('x', None, None)],
- [('a', None, None), ('0', None, None), ('x', None, None)],
- [('a', None, None), ('2', None, None), ('x', None, None)],
- [('b', os.path.abspath('x'), 3), ('1', None, None), ('x', None, None)],
- [('b', os.path.abspath('x'), 3), ('0', None, None), ('x', None, None)],
- [('b', os.path.abspath('x'), 3), ('2', None, None), ('x', None, None)],
- ]))
- self.assertEqual(invalid_sigtxts, {'['})
-
- self.analyzer.annotation = {
- 'add': {
- 'touchpad_calc': [ dict(name='__array', stride=8, offset=4) ],
+ """Tests for class StackAnalyzer."""
+
+ def setUp(self):
+ symbols = [
+ sa.Symbol(0x1000, "F", 0x15C, "hook_task"),
+ sa.Symbol(0x2000, "F", 0x51C, "console_task"),
+ sa.Symbol(0x3200, "O", 0x124, "__just_data"),
+ sa.Symbol(0x4000, "F", 0x11C, "touchpad_calc"),
+ sa.Symbol(0x5000, "F", 0x12C, "touchpad_calc.constprop.42"),
+ sa.Symbol(0x12000, "F", 0x13C, "trackpad_range"),
+ sa.Symbol(0x13000, "F", 0x200, "inlined_mul"),
+ sa.Symbol(0x13100, "F", 0x200, "inlined_mul"),
+ sa.Symbol(0x13100, "F", 0x200, "inlined_mul_alias"),
+ sa.Symbol(0x20000, "O", 0x0, "__array"),
+ sa.Symbol(0x20010, "O", 0x0, "__array_end"),
+ ]
+ tasklist = [
+ sa.Task("HOOKS", "hook_task", 2048, 0x1000),
+ sa.Task("CONSOLE", "console_task", 460, 0x2000),
+ ]
+ # Array at 0x20000 that contains pointers to hook_task and console_task,
+ # with stride=8, offset=4
+ rodata = (0x20000, [0xDEAD1000, 0x00001000, 0xDEAD2000, 0x00002000])
+ options = mock.MagicMock(
+ elf_path="./ec.RW.elf",
+ export_taskinfo="fake",
+ section="RW",
+ objdump="objdump",
+ addr2line="addr2line",
+ annotation=None,
+ )
+ self.analyzer = sa.StackAnalyzer(options, symbols, rodata, tasklist, {})
+
+ def testParseSymbolText(self):
+ symbol_text = (
+ "0 g F .text e8 Foo\n"
+ "0000dead w F .text 000000e8 .hidden Bar\n"
+ "deadbeef l O .bss 00000004 .hidden Woooo\n"
+ "deadbee g O .rodata 00000008 __Hooo_ooo\n"
+ "deadbee g .rodata 00000000 __foo_doo_coo_end\n"
+ )
+ symbols = sa.ParseSymbolText(symbol_text)
+ expect_symbols = [
+ sa.Symbol(0x0, "F", 0xE8, "Foo"),
+ sa.Symbol(0xDEAD, "F", 0xE8, "Bar"),
+ sa.Symbol(0xDEADBEEF, "O", 0x4, "Woooo"),
+ sa.Symbol(0xDEADBEE, "O", 0x8, "__Hooo_ooo"),
+ sa.Symbol(0xDEADBEE, "O", 0x0, "__foo_doo_coo_end"),
+ ]
+ self.assertEqual(symbols, expect_symbols)
+
+ def testParseRoData(self):
+ rodata_text = (
+ "\n"
+ "Contents of section .rodata:\n"
+ " 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n"
+ )
+ rodata = sa.ParseRoDataText(rodata_text)
+ expect_rodata = (0x20000, [0x0010ADDE, 0x00001000, 0x0020ADDE, 0x00002000])
+ self.assertEqual(rodata, expect_rodata)
+
+ def testLoadTasklist(self):
+ def tasklist_to_taskinfos(pointer, tasklist):
+ taskinfos = []
+ for task in tasklist:
+ taskinfos.append(
+ sa.TaskInfo(
+ name=task.name.encode("utf-8"),
+ routine=task.routine_name.encode("utf-8"),
+ stack_size=task.stack_max_size,
+ )
+ )
+
+ TaskInfoArray = sa.TaskInfo * len(taskinfos)
+ pointer.contents.contents = TaskInfoArray(*taskinfos)
+ return len(taskinfos)
+
+ def ro_taskinfos(pointer):
+ return tasklist_to_taskinfos(pointer, expect_ro_tasklist)
+
+ def rw_taskinfos(pointer):
+ return tasklist_to_taskinfos(pointer, expect_rw_tasklist)
+
+ expect_ro_tasklist = [
+ sa.Task("HOOKS", "hook_task", 2048, 0x1000),
+ ]
+
+ expect_rw_tasklist = [
+ sa.Task("HOOKS", "hook_task", 2048, 0x1000),
+ sa.Task("WOOKS", "hook_task", 4096, 0x1000),
+ sa.Task("CONSOLE", "console_task", 460, 0x2000),
+ ]
+
+ export_taskinfo = mock.MagicMock(
+ get_ro_taskinfos=mock.MagicMock(side_effect=ro_taskinfos),
+ get_rw_taskinfos=mock.MagicMock(side_effect=rw_taskinfos),
+ )
+
+ tasklist = sa.LoadTasklist("RO", export_taskinfo, self.analyzer.symbols)
+ self.assertEqual(tasklist, expect_ro_tasklist)
+ tasklist = sa.LoadTasklist("RW", export_taskinfo, self.analyzer.symbols)
+ self.assertEqual(tasklist, expect_rw_tasklist)
+
+ def testResolveAnnotation(self):
+ self.analyzer.annotation = {}
+ (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
+ self.assertEqual(add_rules, {})
+ self.assertEqual(remove_rules, [])
+ self.assertEqual(invalid_sigtxts, set())
+
+ self.analyzer.annotation = {"add": None, "remove": None}
+ (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
+ self.assertEqual(add_rules, {})
+ self.assertEqual(remove_rules, [])
+ self.assertEqual(invalid_sigtxts, set())
+
+ self.analyzer.annotation = {
+ "add": None,
+ "remove": [
+ [["a", "b"], ["0", "[", "2"], "x"],
+ [["a", "b[x:3]"], ["0", "1", "2"], "x"],
+ ],
+ }
+ (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
+ self.assertEqual(add_rules, {})
+ self.assertEqual(
+ list.sort(remove_rules),
+ list.sort(
+ [
+ [("a", None, None), ("1", None, None), ("x", None, None)],
+ [("a", None, None), ("0", None, None), ("x", None, None)],
+ [("a", None, None), ("2", None, None), ("x", None, None)],
+ [
+ ("b", os.path.abspath("x"), 3),
+ ("1", None, None),
+ ("x", None, None),
+ ],
+ [
+ ("b", os.path.abspath("x"), 3),
+ ("0", None, None),
+ ("x", None, None),
+ ],
+ [
+ ("b", os.path.abspath("x"), 3),
+ ("2", None, None),
+ ("x", None, None),
+ ],
+ ]
+ ),
+ )
+ self.assertEqual(invalid_sigtxts, {"["})
+
+ self.analyzer.annotation = {
+ "add": {
+ "touchpad_calc": [dict(name="__array", stride=8, offset=4)],
+ }
+ }
+ (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
+ self.assertEqual(
+ add_rules,
+ {
+ ("touchpad_calc", None, None): set(
+ [("console_task", None, None), ("hook_task", None, None)]
+ )
+ },
+ )
+
+ funcs = {
+ 0x1000: sa.Function(0x1000, "hook_task", 0, []),
+ 0x2000: sa.Function(0x2000, "console_task", 0, []),
+ 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []),
+ 0x5000: sa.Function(0x5000, "touchpad_calc.constprop.42", 0, []),
+ 0x13000: sa.Function(0x13000, "inlined_mul", 0, []),
+ 0x13100: sa.Function(0x13100, "inlined_mul", 0, []),
+ }
+ funcs[0x1000].callsites = [sa.Callsite(0x1002, None, False, None)]
+ # Set address_to_line_cache to fake the results of addr2line.
+ self.analyzer.address_to_line_cache = {
+ (0x1000, False): [("hook_task", os.path.abspath("a.c"), 10)],
+ (0x1002, False): [("toot_calc", os.path.abspath("t.c"), 1234)],
+ (0x2000, False): [("console_task", os.path.abspath("b.c"), 20)],
+ (0x4000, False): [("toudhpad_calc", os.path.abspath("a.c"), 20)],
+ (0x5000, False): [
+ ("touchpad_calc.constprop.42", os.path.abspath("b.c"), 40)
+ ],
+ (0x12000, False): [("trackpad_range", os.path.abspath("t.c"), 10)],
+ (0x13000, False): [("inlined_mul", os.path.abspath("x.c"), 12)],
+ (0x13100, False): [("inlined_mul", os.path.abspath("x.c"), 12)],
+ }
+ self.analyzer.annotation = {
+ "add": {
+ "hook_task.lto.573": ["touchpad_calc.lto.2501[a.c]"],
+ "console_task": ["touchpad_calc[b.c]", "inlined_mul_alias"],
+ "hook_task[q.c]": ["hook_task"],
+ "inlined_mul[x.c]": ["inlined_mul"],
+ "toot_calc[t.c:1234]": ["hook_task"],
+ },
+ "remove": [
+ ["touchpad?calc["],
+ "touchpad_calc",
+ ["touchpad_calc[a.c]"],
+ ["task_unk[a.c]"],
+ ["touchpad_calc[x/a.c]"],
+ ["trackpad_range"],
+ ["inlined_mul"],
+ ["inlined_mul", "console_task", "touchpad_calc[a.c]"],
+ ["inlined_mul", "inlined_mul_alias", "console_task"],
+ ["inlined_mul", "inlined_mul_alias", "console_task"],
+ ],
+ }
+ (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
+ self.assertEqual(invalid_sigtxts, {"touchpad?calc["})
+
+ signature_set = set()
+ for src_sig, dst_sigs in add_rules.items():
+ signature_set.add(src_sig)
+ signature_set.update(dst_sigs)
+
+ for remove_sigs in remove_rules:
+ signature_set.update(remove_sigs)
+
+ (signature_map, failed_sigs) = self.analyzer.MapAnnotation(funcs, signature_set)
+ result = self.analyzer.ResolveAnnotation(funcs)
+ (add_set, remove_list, eliminated_addrs, failed_sigs) = result
+
+ expect_signature_map = {
+ ("hook_task", None, None): {funcs[0x1000]},
+ ("touchpad_calc", os.path.abspath("a.c"), None): {funcs[0x4000]},
+ ("touchpad_calc", os.path.abspath("b.c"), None): {funcs[0x5000]},
+ ("console_task", None, None): {funcs[0x2000]},
+ ("inlined_mul_alias", None, None): {funcs[0x13100]},
+ ("inlined_mul", os.path.abspath("x.c"), None): {
+ funcs[0x13000],
+ funcs[0x13100],
+ },
+ ("inlined_mul", None, None): {funcs[0x13000], funcs[0x13100]},
+ }
+ self.assertEqual(len(signature_map), len(expect_signature_map))
+ for sig, funclist in signature_map.items():
+ self.assertEqual(set(funclist), expect_signature_map[sig])
+
+ self.assertEqual(
+ add_set,
+ {
+ (funcs[0x1000], funcs[0x4000]),
+ (funcs[0x1000], funcs[0x1000]),
+ (funcs[0x2000], funcs[0x5000]),
+ (funcs[0x2000], funcs[0x13100]),
+ (funcs[0x13000], funcs[0x13000]),
+ (funcs[0x13000], funcs[0x13100]),
+ (funcs[0x13100], funcs[0x13000]),
+ (funcs[0x13100], funcs[0x13100]),
+ },
+ )
+ expect_remove_list = [
+ [funcs[0x4000]],
+ [funcs[0x13000]],
+ [funcs[0x13100]],
+ [funcs[0x13000], funcs[0x2000], funcs[0x4000]],
+ [funcs[0x13100], funcs[0x2000], funcs[0x4000]],
+ [funcs[0x13000], funcs[0x13100], funcs[0x2000]],
+ [funcs[0x13100], funcs[0x13100], funcs[0x2000]],
+ ]
+ self.assertEqual(len(remove_list), len(expect_remove_list))
+ for remove_path in remove_list:
+ self.assertTrue(remove_path in expect_remove_list)
+
+ self.assertEqual(eliminated_addrs, {0x1002})
+ self.assertEqual(
+ failed_sigs,
+ {
+ ("touchpad?calc[", sa.StackAnalyzer.ANNOTATION_ERROR_INVALID),
+ ("touchpad_calc", sa.StackAnalyzer.ANNOTATION_ERROR_AMBIGUOUS),
+ ("hook_task[q.c]", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND),
+ ("task_unk[a.c]", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND),
+ ("touchpad_calc[x/a.c]", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND),
+ ("trackpad_range", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND),
+ },
+ )
+
+ def testPreprocessAnnotation(self):
+ funcs = {
+ 0x1000: sa.Function(0x1000, "hook_task", 0, []),
+ 0x2000: sa.Function(0x2000, "console_task", 0, []),
+ 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []),
+ }
+ funcs[0x1000].callsites = [sa.Callsite(0x1002, 0x1000, False, funcs[0x1000])]
+ funcs[0x2000].callsites = [
+ sa.Callsite(0x2002, 0x1000, False, funcs[0x1000]),
+ sa.Callsite(0x2006, None, True, None),
+ ]
+ add_set = {
+ (funcs[0x2000], funcs[0x2000]),
+ (funcs[0x2000], funcs[0x4000]),
+ (funcs[0x4000], funcs[0x1000]),
+ (funcs[0x4000], funcs[0x2000]),
}
- }
- (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
- self.assertEqual(add_rules, {
- ('touchpad_calc', None, None):
- set([('console_task', None, None), ('hook_task', None, None)])})
-
- funcs = {
- 0x1000: sa.Function(0x1000, 'hook_task', 0, []),
- 0x2000: sa.Function(0x2000, 'console_task', 0, []),
- 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []),
- 0x5000: sa.Function(0x5000, 'touchpad_calc.constprop.42', 0, []),
- 0x13000: sa.Function(0x13000, 'inlined_mul', 0, []),
- 0x13100: sa.Function(0x13100, 'inlined_mul', 0, []),
- }
- funcs[0x1000].callsites = [
- sa.Callsite(0x1002, None, False, None)]
- # Set address_to_line_cache to fake the results of addr2line.
- self.analyzer.address_to_line_cache = {
- (0x1000, False): [('hook_task', os.path.abspath('a.c'), 10)],
- (0x1002, False): [('toot_calc', os.path.abspath('t.c'), 1234)],
- (0x2000, False): [('console_task', os.path.abspath('b.c'), 20)],
- (0x4000, False): [('toudhpad_calc', os.path.abspath('a.c'), 20)],
- (0x5000, False): [
- ('touchpad_calc.constprop.42', os.path.abspath('b.c'), 40)],
- (0x12000, False): [('trackpad_range', os.path.abspath('t.c'), 10)],
- (0x13000, False): [('inlined_mul', os.path.abspath('x.c'), 12)],
- (0x13100, False): [('inlined_mul', os.path.abspath('x.c'), 12)],
- }
- self.analyzer.annotation = {
- 'add': {
- 'hook_task.lto.573': ['touchpad_calc.lto.2501[a.c]'],
- 'console_task': ['touchpad_calc[b.c]', 'inlined_mul_alias'],
- 'hook_task[q.c]': ['hook_task'],
- 'inlined_mul[x.c]': ['inlined_mul'],
- 'toot_calc[t.c:1234]': ['hook_task'],
- },
- 'remove': [
- ['touchpad?calc['],
- 'touchpad_calc',
- ['touchpad_calc[a.c]'],
- ['task_unk[a.c]'],
- ['touchpad_calc[x/a.c]'],
- ['trackpad_range'],
- ['inlined_mul'],
- ['inlined_mul', 'console_task', 'touchpad_calc[a.c]'],
- ['inlined_mul', 'inlined_mul_alias', 'console_task'],
- ['inlined_mul', 'inlined_mul_alias', 'console_task'],
- ],
- }
- (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
- self.assertEqual(invalid_sigtxts, {'touchpad?calc['})
-
- signature_set = set()
- for src_sig, dst_sigs in add_rules.items():
- signature_set.add(src_sig)
- signature_set.update(dst_sigs)
-
- for remove_sigs in remove_rules:
- signature_set.update(remove_sigs)
-
- (signature_map, failed_sigs) = self.analyzer.MapAnnotation(funcs,
- signature_set)
- result = self.analyzer.ResolveAnnotation(funcs)
- (add_set, remove_list, eliminated_addrs, failed_sigs) = result
-
- expect_signature_map = {
- ('hook_task', None, None): {funcs[0x1000]},
- ('touchpad_calc', os.path.abspath('a.c'), None): {funcs[0x4000]},
- ('touchpad_calc', os.path.abspath('b.c'), None): {funcs[0x5000]},
- ('console_task', None, None): {funcs[0x2000]},
- ('inlined_mul_alias', None, None): {funcs[0x13100]},
- ('inlined_mul', os.path.abspath('x.c'), None): {funcs[0x13000],
- funcs[0x13100]},
- ('inlined_mul', None, None): {funcs[0x13000], funcs[0x13100]},
- }
- self.assertEqual(len(signature_map), len(expect_signature_map))
- for sig, funclist in signature_map.items():
- self.assertEqual(set(funclist), expect_signature_map[sig])
-
- self.assertEqual(add_set, {
- (funcs[0x1000], funcs[0x4000]),
- (funcs[0x1000], funcs[0x1000]),
- (funcs[0x2000], funcs[0x5000]),
- (funcs[0x2000], funcs[0x13100]),
- (funcs[0x13000], funcs[0x13000]),
- (funcs[0x13000], funcs[0x13100]),
- (funcs[0x13100], funcs[0x13000]),
- (funcs[0x13100], funcs[0x13100]),
- })
- expect_remove_list = [
- [funcs[0x4000]],
- [funcs[0x13000]],
- [funcs[0x13100]],
- [funcs[0x13000], funcs[0x2000], funcs[0x4000]],
- [funcs[0x13100], funcs[0x2000], funcs[0x4000]],
- [funcs[0x13000], funcs[0x13100], funcs[0x2000]],
- [funcs[0x13100], funcs[0x13100], funcs[0x2000]],
- ]
- self.assertEqual(len(remove_list), len(expect_remove_list))
- for remove_path in remove_list:
- self.assertTrue(remove_path in expect_remove_list)
-
- self.assertEqual(eliminated_addrs, {0x1002})
- self.assertEqual(failed_sigs, {
- ('touchpad?calc[', sa.StackAnalyzer.ANNOTATION_ERROR_INVALID),
- ('touchpad_calc', sa.StackAnalyzer.ANNOTATION_ERROR_AMBIGUOUS),
- ('hook_task[q.c]', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND),
- ('task_unk[a.c]', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND),
- ('touchpad_calc[x/a.c]', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND),
- ('trackpad_range', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND),
- })
-
- def testPreprocessAnnotation(self):
- funcs = {
- 0x1000: sa.Function(0x1000, 'hook_task', 0, []),
- 0x2000: sa.Function(0x2000, 'console_task', 0, []),
- 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []),
- }
- funcs[0x1000].callsites = [
- sa.Callsite(0x1002, 0x1000, False, funcs[0x1000])]
- funcs[0x2000].callsites = [
- sa.Callsite(0x2002, 0x1000, False, funcs[0x1000]),
- sa.Callsite(0x2006, None, True, None),
- ]
- add_set = {
- (funcs[0x2000], funcs[0x2000]),
- (funcs[0x2000], funcs[0x4000]),
- (funcs[0x4000], funcs[0x1000]),
- (funcs[0x4000], funcs[0x2000]),
- }
- remove_list = [
- [funcs[0x1000]],
- [funcs[0x2000], funcs[0x2000]],
- [funcs[0x4000], funcs[0x1000]],
- [funcs[0x2000], funcs[0x4000], funcs[0x2000]],
- [funcs[0x4000], funcs[0x1000], funcs[0x4000]],
- ]
- eliminated_addrs = {0x2006}
-
- remaining_remove_list = self.analyzer.PreprocessAnnotation(funcs,
- add_set,
- remove_list,
- eliminated_addrs)
-
- expect_funcs = {
- 0x1000: sa.Function(0x1000, 'hook_task', 0, []),
- 0x2000: sa.Function(0x2000, 'console_task', 0, []),
- 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []),
- }
- expect_funcs[0x2000].callsites = [
- sa.Callsite(None, 0x4000, False, expect_funcs[0x4000])]
- expect_funcs[0x4000].callsites = [
- sa.Callsite(None, 0x2000, False, expect_funcs[0x2000])]
- self.assertEqual(funcs, expect_funcs)
- self.assertEqual(remaining_remove_list, [
- [funcs[0x2000], funcs[0x4000], funcs[0x2000]],
- ])
-
- def testAndesAnalyzeDisassembly(self):
- disasm_text = (
- '\n'
- 'build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le'
- '\n'
- 'Disassembly of section .text:\n'
- '\n'
- '00000900 <wook_task>:\n'
- ' ...\n'
- '00001000 <hook_task>:\n'
- ' 1000: fc 42\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n'
- ' 1004: 47 70\t\tmovi55 $r0, #1\n'
- ' 1006: b1 13\tbnezs8 100929de <flash_command_write>\n'
- ' 1008: 00 01 5c fc\tbne $r6, $r0, 2af6a\n'
- '00002000 <console_task>:\n'
- ' 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n'
- ' 2002: f0 0e fc c5\tjal 1000 <hook_task>\n'
- ' 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n'
- ' 200a: de ad be ef\tswi.gp $r0, [ + #-11036]\n'
- '00004000 <touchpad_calc>:\n'
- ' 4000: 47 70\t\tmovi55 $r0, #1\n'
- '00010000 <look_task>:'
- )
- function_map = self.analyzer.AnalyzeDisassembly(disasm_text)
- func_hook_task = sa.Function(0x1000, 'hook_task', 48, [
- sa.Callsite(0x1006, 0x100929de, True, None)])
- expect_funcmap = {
- 0x1000: func_hook_task,
- 0x2000: sa.Function(0x2000, 'console_task', 16,
- [sa.Callsite(0x2002, 0x1000, False, func_hook_task),
- sa.Callsite(0x2006, 0x53968, True, None)]),
- 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []),
- }
- self.assertEqual(function_map, expect_funcmap)
-
- def testArmAnalyzeDisassembly(self):
- disasm_text = (
- '\n'
- 'build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm'
- '\n'
- 'Disassembly of section .text:\n'
- '\n'
- '00000900 <wook_task>:\n'
- ' ...\n'
- '00001000 <hook_task>:\n'
- ' 1000: dead beef\tfake\n'
- ' 1004: 4770\t\tbx lr\n'
- ' 1006: b113\tcbz r3, 100929de <flash_command_write>\n'
- ' 1008: 00015cfc\t.word 0x00015cfc\n'
- '00002000 <console_task>:\n'
- ' 2000: b508\t\tpush {r3, lr} ; malformed comments,; r0, r1 \n'
- ' 2002: f00e fcc5\tbl 1000 <hook_task>\n'
- ' 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n'
- ' 200a: dead beef\tfake\n'
- '00004000 <touchpad_calc>:\n'
- ' 4000: 4770\t\tbx lr\n'
- '00010000 <look_task>:'
- )
- function_map = self.analyzer.AnalyzeDisassembly(disasm_text)
- func_hook_task = sa.Function(0x1000, 'hook_task', 0, [
- sa.Callsite(0x1006, 0x100929de, True, None)])
- expect_funcmap = {
- 0x1000: func_hook_task,
- 0x2000: sa.Function(0x2000, 'console_task', 8,
- [sa.Callsite(0x2002, 0x1000, False, func_hook_task),
- sa.Callsite(0x2006, 0x53968, True, None)]),
- 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []),
- }
- self.assertEqual(function_map, expect_funcmap)
-
- def testAnalyzeCallGraph(self):
- funcs = {
- 0x1000: sa.Function(0x1000, 'hook_task', 0, []),
- 0x2000: sa.Function(0x2000, 'console_task', 8, []),
- 0x3000: sa.Function(0x3000, 'task_a', 12, []),
- 0x4000: sa.Function(0x4000, 'task_b', 96, []),
- 0x5000: sa.Function(0x5000, 'task_c', 32, []),
- 0x6000: sa.Function(0x6000, 'task_d', 100, []),
- 0x7000: sa.Function(0x7000, 'task_e', 24, []),
- 0x8000: sa.Function(0x8000, 'task_f', 20, []),
- 0x9000: sa.Function(0x9000, 'task_g', 20, []),
- 0x10000: sa.Function(0x10000, 'task_x', 16, []),
- }
- funcs[0x1000].callsites = [
- sa.Callsite(0x1002, 0x3000, False, funcs[0x3000]),
- sa.Callsite(0x1006, 0x4000, False, funcs[0x4000])]
- funcs[0x2000].callsites = [
- sa.Callsite(0x2002, 0x5000, False, funcs[0x5000]),
- sa.Callsite(0x2006, 0x2000, False, funcs[0x2000]),
- sa.Callsite(0x200a, 0x10000, False, funcs[0x10000])]
- funcs[0x3000].callsites = [
- sa.Callsite(0x3002, 0x4000, False, funcs[0x4000]),
- sa.Callsite(0x3006, 0x1000, False, funcs[0x1000])]
- funcs[0x4000].callsites = [
- sa.Callsite(0x4002, 0x6000, True, funcs[0x6000]),
- sa.Callsite(0x4006, 0x7000, False, funcs[0x7000]),
- sa.Callsite(0x400a, 0x8000, False, funcs[0x8000])]
- funcs[0x5000].callsites = [
- sa.Callsite(0x5002, 0x4000, False, funcs[0x4000])]
- funcs[0x7000].callsites = [
- sa.Callsite(0x7002, 0x7000, False, funcs[0x7000])]
- funcs[0x8000].callsites = [
- sa.Callsite(0x8002, 0x9000, False, funcs[0x9000])]
- funcs[0x9000].callsites = [
- sa.Callsite(0x9002, 0x4000, False, funcs[0x4000])]
- funcs[0x10000].callsites = [
- sa.Callsite(0x10002, 0x2000, False, funcs[0x2000])]
-
- cycles = self.analyzer.AnalyzeCallGraph(funcs, [
- [funcs[0x2000]] * 2,
- [funcs[0x10000], funcs[0x2000]] * 3,
- [funcs[0x1000], funcs[0x3000], funcs[0x1000]]
- ])
-
- expect_func_stack = {
- 0x1000: (268, [funcs[0x1000],
- funcs[0x3000],
- funcs[0x4000],
- funcs[0x8000],
- funcs[0x9000],
- funcs[0x4000],
- funcs[0x7000]]),
- 0x2000: (208, [funcs[0x2000],
- funcs[0x10000],
- funcs[0x2000],
- funcs[0x10000],
- funcs[0x2000],
- funcs[0x5000],
- funcs[0x4000],
- funcs[0x7000]]),
- 0x3000: (280, [funcs[0x3000],
- funcs[0x1000],
- funcs[0x3000],
- funcs[0x4000],
- funcs[0x8000],
- funcs[0x9000],
- funcs[0x4000],
- funcs[0x7000]]),
- 0x4000: (120, [funcs[0x4000], funcs[0x7000]]),
- 0x5000: (152, [funcs[0x5000], funcs[0x4000], funcs[0x7000]]),
- 0x6000: (100, [funcs[0x6000]]),
- 0x7000: (24, [funcs[0x7000]]),
- 0x8000: (160, [funcs[0x8000],
- funcs[0x9000],
- funcs[0x4000],
- funcs[0x7000]]),
- 0x9000: (140, [funcs[0x9000], funcs[0x4000], funcs[0x7000]]),
- 0x10000: (200, [funcs[0x10000],
- funcs[0x2000],
- funcs[0x10000],
- funcs[0x2000],
- funcs[0x5000],
- funcs[0x4000],
- funcs[0x7000]]),
- }
- expect_cycles = [
- {funcs[0x4000], funcs[0x8000], funcs[0x9000]},
- {funcs[0x7000]},
- ]
- for func in funcs.values():
- (stack_max_usage, stack_max_path) = expect_func_stack[func.address]
- self.assertEqual(func.stack_max_usage, stack_max_usage)
- self.assertEqual(func.stack_max_path, stack_max_path)
-
- self.assertEqual(len(cycles), len(expect_cycles))
- for cycle in cycles:
- self.assertTrue(cycle in expect_cycles)
-
- @mock.patch('subprocess.check_output')
- def testAddressToLine(self, checkoutput_mock):
- checkoutput_mock.return_value = 'fake_func\n/test.c:1'
- self.assertEqual(self.analyzer.AddressToLine(0x1234),
- [('fake_func', '/test.c', 1)])
- checkoutput_mock.assert_called_once_with(
- ['addr2line', '-f', '-e', './ec.RW.elf', '1234'], encoding='utf-8')
- checkoutput_mock.reset_mock()
-
- checkoutput_mock.return_value = 'fake_func\n/a.c:1\nbake_func\n/b.c:2\n'
- self.assertEqual(self.analyzer.AddressToLine(0x1234, True),
- [('fake_func', '/a.c', 1), ('bake_func', '/b.c', 2)])
- checkoutput_mock.assert_called_once_with(
- ['addr2line', '-f', '-e', './ec.RW.elf', '1234', '-i'],
- encoding='utf-8')
- checkoutput_mock.reset_mock()
-
- checkoutput_mock.return_value = 'fake_func\n/test.c:1 (discriminator 128)'
- self.assertEqual(self.analyzer.AddressToLine(0x12345),
- [('fake_func', '/test.c', 1)])
- checkoutput_mock.assert_called_once_with(
- ['addr2line', '-f', '-e', './ec.RW.elf', '12345'], encoding='utf-8')
- checkoutput_mock.reset_mock()
-
- checkoutput_mock.return_value = '??\n:?\nbake_func\n/b.c:2\n'
- self.assertEqual(self.analyzer.AddressToLine(0x123456),
- [None, ('bake_func', '/b.c', 2)])
- checkoutput_mock.assert_called_once_with(
- ['addr2line', '-f', '-e', './ec.RW.elf', '123456'], encoding='utf-8')
- checkoutput_mock.reset_mock()
-
- with self.assertRaisesRegexp(sa.StackAnalyzerError,
- 'addr2line failed to resolve lines.'):
- checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '')
- self.analyzer.AddressToLine(0x5678)
-
- with self.assertRaisesRegexp(sa.StackAnalyzerError,
- 'Failed to run addr2line.'):
- checkoutput_mock.side_effect = OSError()
- self.analyzer.AddressToLine(0x9012)
-
- @mock.patch('subprocess.check_output')
- @mock.patch('stack_analyzer.StackAnalyzer.AddressToLine')
- def testAndesAnalyze(self, addrtoline_mock, checkoutput_mock):
- disasm_text = (
- '\n'
- 'build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le'
- '\n'
- 'Disassembly of section .text:\n'
- '\n'
- '00000900 <wook_task>:\n'
- ' ...\n'
- '00001000 <hook_task>:\n'
- ' 1000: fc 00\t\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n'
- ' 1002: 47 70\t\tmovi55 $r0, #1\n'
- ' 1006: 00 01 5c fc\tbne $r6, $r0, 2af6a\n'
- '00002000 <console_task>:\n'
- ' 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n'
- ' 2002: f0 0e fc c5\tjal 1000 <hook_task>\n'
- ' 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n'
- ' 200a: 12 34 56 78\tjral5 $r0\n'
- )
-
- addrtoline_mock.return_value = [('??', '??', 0)]
- self.analyzer.annotation = {
- 'exception_frame_size': 64,
- 'remove': [['fake_func']],
- }
-
- with mock.patch('builtins.print') as print_mock:
- checkoutput_mock.return_value = disasm_text
- self.analyzer.Analyze()
- print_mock.assert_has_calls([
- mock.call(
- 'Task: HOOKS, Max size: 96 (32 + 64), Allocated size: 2048'),
- mock.call('Call Trace:'),
- mock.call(' hook_task (32) [??:0] 1000'),
- mock.call(
- 'Task: CONSOLE, Max size: 112 (48 + 64), Allocated size: 460'),
- mock.call('Call Trace:'),
- mock.call(' console_task (16) [??:0] 2000'),
- mock.call(' -> ??[??:0] 2002'),
- mock.call(' hook_task (32) [??:0] 1000'),
- mock.call('Unresolved indirect callsites:'),
- mock.call(' In function console_task:'),
- mock.call(' -> ??[??:0] 200a'),
- mock.call('Unresolved annotation signatures:'),
- mock.call(' fake_func: function is not found'),
- ])
-
- with self.assertRaisesRegexp(sa.StackAnalyzerError,
- 'Failed to run objdump.'):
- checkoutput_mock.side_effect = OSError()
- self.analyzer.Analyze()
-
- with self.assertRaisesRegexp(sa.StackAnalyzerError,
- 'objdump failed to disassemble.'):
- checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '')
- self.analyzer.Analyze()
-
- @mock.patch('subprocess.check_output')
- @mock.patch('stack_analyzer.StackAnalyzer.AddressToLine')
- def testArmAnalyze(self, addrtoline_mock, checkoutput_mock):
- disasm_text = (
- '\n'
- 'build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm'
- '\n'
- 'Disassembly of section .text:\n'
- '\n'
- '00000900 <wook_task>:\n'
- ' ...\n'
- '00001000 <hook_task>:\n'
- ' 1000: b508\t\tpush {r3, lr}\n'
- ' 1002: 4770\t\tbx lr\n'
- ' 1006: 00015cfc\t.word 0x00015cfc\n'
- '00002000 <console_task>:\n'
- ' 2000: b508\t\tpush {r3, lr}\n'
- ' 2002: f00e fcc5\tbl 1000 <hook_task>\n'
- ' 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n'
- ' 200a: 1234 5678\tb.w sl\n'
- )
-
- addrtoline_mock.return_value = [('??', '??', 0)]
- self.analyzer.annotation = {
- 'exception_frame_size': 64,
- 'remove': [['fake_func']],
- }
-
- with mock.patch('builtins.print') as print_mock:
- checkoutput_mock.return_value = disasm_text
- self.analyzer.Analyze()
- print_mock.assert_has_calls([
- mock.call(
- 'Task: HOOKS, Max size: 72 (8 + 64), Allocated size: 2048'),
- mock.call('Call Trace:'),
- mock.call(' hook_task (8) [??:0] 1000'),
- mock.call(
- 'Task: CONSOLE, Max size: 80 (16 + 64), Allocated size: 460'),
- mock.call('Call Trace:'),
- mock.call(' console_task (8) [??:0] 2000'),
- mock.call(' -> ??[??:0] 2002'),
- mock.call(' hook_task (8) [??:0] 1000'),
- mock.call('Unresolved indirect callsites:'),
- mock.call(' In function console_task:'),
- mock.call(' -> ??[??:0] 200a'),
- mock.call('Unresolved annotation signatures:'),
- mock.call(' fake_func: function is not found'),
- ])
-
- with self.assertRaisesRegexp(sa.StackAnalyzerError,
- 'Failed to run objdump.'):
- checkoutput_mock.side_effect = OSError()
- self.analyzer.Analyze()
-
- with self.assertRaisesRegexp(sa.StackAnalyzerError,
- 'objdump failed to disassemble.'):
- checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '')
- self.analyzer.Analyze()
-
- @mock.patch('subprocess.check_output')
- @mock.patch('stack_analyzer.ParseArgs')
- def testMain(self, parseargs_mock, checkoutput_mock):
- symbol_text = ('1000 g F .text 0000015c .hidden hook_task\n'
- '2000 g F .text 0000051c .hidden console_task\n')
- rodata_text = (
- '\n'
- 'Contents of section .rodata:\n'
- ' 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n'
- )
-
- args = mock.MagicMock(elf_path='./ec.RW.elf',
- export_taskinfo='fake',
- section='RW',
- objdump='objdump',
- addr2line='addr2line',
- annotation='fake')
- parseargs_mock.return_value = args
-
- with mock.patch('os.path.exists') as path_mock:
- path_mock.return_value = False
- with mock.patch('builtins.print') as print_mock:
- with mock.patch('builtins.open', mock.mock_open()) as open_mock:
- sa.main()
- print_mock.assert_any_call(
- 'Warning: Annotation file fake does not exist.')
-
- with mock.patch('os.path.exists') as path_mock:
- path_mock.return_value = True
- with mock.patch('builtins.print') as print_mock:
- with mock.patch('builtins.open', mock.mock_open()) as open_mock:
- open_mock.side_effect = IOError()
- sa.main()
- print_mock.assert_called_once_with(
- 'Error: Failed to open annotation file fake.')
-
- with mock.patch('builtins.print') as print_mock:
- with mock.patch('builtins.open', mock.mock_open()) as open_mock:
- open_mock.return_value.read.side_effect = ['{', '']
- sa.main()
- open_mock.assert_called_once_with('fake', 'r')
- print_mock.assert_called_once_with(
- 'Error: Failed to parse annotation file fake.')
-
- with mock.patch('builtins.print') as print_mock:
- with mock.patch('builtins.open',
- mock.mock_open(read_data='')) as open_mock:
- sa.main()
- print_mock.assert_called_once_with(
- 'Error: Invalid annotation file fake.')
-
- args.annotation = None
-
- with mock.patch('builtins.print') as print_mock:
- checkoutput_mock.side_effect = [symbol_text, rodata_text]
- sa.main()
- print_mock.assert_called_once_with(
- 'Error: Failed to load export_taskinfo.')
-
- with mock.patch('builtins.print') as print_mock:
- checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '')
- sa.main()
- print_mock.assert_called_once_with(
- 'Error: objdump failed to dump symbol table or rodata.')
-
- with mock.patch('builtins.print') as print_mock:
- checkoutput_mock.side_effect = OSError()
- sa.main()
- print_mock.assert_called_once_with('Error: Failed to run objdump.')
-
-
-if __name__ == '__main__':
- unittest.main()
+ remove_list = [
+ [funcs[0x1000]],
+ [funcs[0x2000], funcs[0x2000]],
+ [funcs[0x4000], funcs[0x1000]],
+ [funcs[0x2000], funcs[0x4000], funcs[0x2000]],
+ [funcs[0x4000], funcs[0x1000], funcs[0x4000]],
+ ]
+ eliminated_addrs = {0x2006}
+
+ remaining_remove_list = self.analyzer.PreprocessAnnotation(
+ funcs, add_set, remove_list, eliminated_addrs
+ )
+
+ expect_funcs = {
+ 0x1000: sa.Function(0x1000, "hook_task", 0, []),
+ 0x2000: sa.Function(0x2000, "console_task", 0, []),
+ 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []),
+ }
+ expect_funcs[0x2000].callsites = [
+ sa.Callsite(None, 0x4000, False, expect_funcs[0x4000])
+ ]
+ expect_funcs[0x4000].callsites = [
+ sa.Callsite(None, 0x2000, False, expect_funcs[0x2000])
+ ]
+ self.assertEqual(funcs, expect_funcs)
+ self.assertEqual(
+ remaining_remove_list,
+ [
+ [funcs[0x2000], funcs[0x4000], funcs[0x2000]],
+ ],
+ )
+
+ def testAndesAnalyzeDisassembly(self):
+ disasm_text = (
+ "\n"
+ "build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le"
+ "\n"
+ "Disassembly of section .text:\n"
+ "\n"
+ "00000900 <wook_task>:\n"
+ " ...\n"
+ "00001000 <hook_task>:\n"
+ " 1000: fc 42\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n"
+ " 1004: 47 70\t\tmovi55 $r0, #1\n"
+ " 1006: b1 13\tbnezs8 100929de <flash_command_write>\n"
+ " 1008: 00 01 5c fc\tbne $r6, $r0, 2af6a\n"
+ "00002000 <console_task>:\n"
+ " 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n"
+ " 2002: f0 0e fc c5\tjal 1000 <hook_task>\n"
+ " 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n"
+ " 200a: de ad be ef\tswi.gp $r0, [ + #-11036]\n"
+ "00004000 <touchpad_calc>:\n"
+ " 4000: 47 70\t\tmovi55 $r0, #1\n"
+ "00010000 <look_task>:"
+ )
+ function_map = self.analyzer.AnalyzeDisassembly(disasm_text)
+ func_hook_task = sa.Function(
+ 0x1000, "hook_task", 48, [sa.Callsite(0x1006, 0x100929DE, True, None)]
+ )
+ expect_funcmap = {
+ 0x1000: func_hook_task,
+ 0x2000: sa.Function(
+ 0x2000,
+ "console_task",
+ 16,
+ [
+ sa.Callsite(0x2002, 0x1000, False, func_hook_task),
+ sa.Callsite(0x2006, 0x53968, True, None),
+ ],
+ ),
+ 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []),
+ }
+ self.assertEqual(function_map, expect_funcmap)
+
+ def testArmAnalyzeDisassembly(self):
+ disasm_text = (
+ "\n"
+ "build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm"
+ "\n"
+ "Disassembly of section .text:\n"
+ "\n"
+ "00000900 <wook_task>:\n"
+ " ...\n"
+ "00001000 <hook_task>:\n"
+ " 1000: dead beef\tfake\n"
+ " 1004: 4770\t\tbx lr\n"
+ " 1006: b113\tcbz r3, 100929de <flash_command_write>\n"
+ " 1008: 00015cfc\t.word 0x00015cfc\n"
+ "00002000 <console_task>:\n"
+ " 2000: b508\t\tpush {r3, lr} ; malformed comments,; r0, r1 \n"
+ " 2002: f00e fcc5\tbl 1000 <hook_task>\n"
+ " 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n"
+ " 200a: dead beef\tfake\n"
+ "00004000 <touchpad_calc>:\n"
+ " 4000: 4770\t\tbx lr\n"
+ "00010000 <look_task>:"
+ )
+ function_map = self.analyzer.AnalyzeDisassembly(disasm_text)
+ func_hook_task = sa.Function(
+ 0x1000, "hook_task", 0, [sa.Callsite(0x1006, 0x100929DE, True, None)]
+ )
+ expect_funcmap = {
+ 0x1000: func_hook_task,
+ 0x2000: sa.Function(
+ 0x2000,
+ "console_task",
+ 8,
+ [
+ sa.Callsite(0x2002, 0x1000, False, func_hook_task),
+ sa.Callsite(0x2006, 0x53968, True, None),
+ ],
+ ),
+ 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []),
+ }
+ self.assertEqual(function_map, expect_funcmap)
+
+ def testAnalyzeCallGraph(self):
+ funcs = {
+ 0x1000: sa.Function(0x1000, "hook_task", 0, []),
+ 0x2000: sa.Function(0x2000, "console_task", 8, []),
+ 0x3000: sa.Function(0x3000, "task_a", 12, []),
+ 0x4000: sa.Function(0x4000, "task_b", 96, []),
+ 0x5000: sa.Function(0x5000, "task_c", 32, []),
+ 0x6000: sa.Function(0x6000, "task_d", 100, []),
+ 0x7000: sa.Function(0x7000, "task_e", 24, []),
+ 0x8000: sa.Function(0x8000, "task_f", 20, []),
+ 0x9000: sa.Function(0x9000, "task_g", 20, []),
+ 0x10000: sa.Function(0x10000, "task_x", 16, []),
+ }
+ funcs[0x1000].callsites = [
+ sa.Callsite(0x1002, 0x3000, False, funcs[0x3000]),
+ sa.Callsite(0x1006, 0x4000, False, funcs[0x4000]),
+ ]
+ funcs[0x2000].callsites = [
+ sa.Callsite(0x2002, 0x5000, False, funcs[0x5000]),
+ sa.Callsite(0x2006, 0x2000, False, funcs[0x2000]),
+ sa.Callsite(0x200A, 0x10000, False, funcs[0x10000]),
+ ]
+ funcs[0x3000].callsites = [
+ sa.Callsite(0x3002, 0x4000, False, funcs[0x4000]),
+ sa.Callsite(0x3006, 0x1000, False, funcs[0x1000]),
+ ]
+ funcs[0x4000].callsites = [
+ sa.Callsite(0x4002, 0x6000, True, funcs[0x6000]),
+ sa.Callsite(0x4006, 0x7000, False, funcs[0x7000]),
+ sa.Callsite(0x400A, 0x8000, False, funcs[0x8000]),
+ ]
+ funcs[0x5000].callsites = [sa.Callsite(0x5002, 0x4000, False, funcs[0x4000])]
+ funcs[0x7000].callsites = [sa.Callsite(0x7002, 0x7000, False, funcs[0x7000])]
+ funcs[0x8000].callsites = [sa.Callsite(0x8002, 0x9000, False, funcs[0x9000])]
+ funcs[0x9000].callsites = [sa.Callsite(0x9002, 0x4000, False, funcs[0x4000])]
+ funcs[0x10000].callsites = [sa.Callsite(0x10002, 0x2000, False, funcs[0x2000])]
+
+ cycles = self.analyzer.AnalyzeCallGraph(
+ funcs,
+ [
+ [funcs[0x2000]] * 2,
+ [funcs[0x10000], funcs[0x2000]] * 3,
+ [funcs[0x1000], funcs[0x3000], funcs[0x1000]],
+ ],
+ )
+
+ expect_func_stack = {
+ 0x1000: (
+ 268,
+ [
+ funcs[0x1000],
+ funcs[0x3000],
+ funcs[0x4000],
+ funcs[0x8000],
+ funcs[0x9000],
+ funcs[0x4000],
+ funcs[0x7000],
+ ],
+ ),
+ 0x2000: (
+ 208,
+ [
+ funcs[0x2000],
+ funcs[0x10000],
+ funcs[0x2000],
+ funcs[0x10000],
+ funcs[0x2000],
+ funcs[0x5000],
+ funcs[0x4000],
+ funcs[0x7000],
+ ],
+ ),
+ 0x3000: (
+ 280,
+ [
+ funcs[0x3000],
+ funcs[0x1000],
+ funcs[0x3000],
+ funcs[0x4000],
+ funcs[0x8000],
+ funcs[0x9000],
+ funcs[0x4000],
+ funcs[0x7000],
+ ],
+ ),
+ 0x4000: (120, [funcs[0x4000], funcs[0x7000]]),
+ 0x5000: (152, [funcs[0x5000], funcs[0x4000], funcs[0x7000]]),
+ 0x6000: (100, [funcs[0x6000]]),
+ 0x7000: (24, [funcs[0x7000]]),
+ 0x8000: (160, [funcs[0x8000], funcs[0x9000], funcs[0x4000], funcs[0x7000]]),
+ 0x9000: (140, [funcs[0x9000], funcs[0x4000], funcs[0x7000]]),
+ 0x10000: (
+ 200,
+ [
+ funcs[0x10000],
+ funcs[0x2000],
+ funcs[0x10000],
+ funcs[0x2000],
+ funcs[0x5000],
+ funcs[0x4000],
+ funcs[0x7000],
+ ],
+ ),
+ }
+ expect_cycles = [
+ {funcs[0x4000], funcs[0x8000], funcs[0x9000]},
+ {funcs[0x7000]},
+ ]
+ for func in funcs.values():
+ (stack_max_usage, stack_max_path) = expect_func_stack[func.address]
+ self.assertEqual(func.stack_max_usage, stack_max_usage)
+ self.assertEqual(func.stack_max_path, stack_max_path)
+
+ self.assertEqual(len(cycles), len(expect_cycles))
+ for cycle in cycles:
+ self.assertTrue(cycle in expect_cycles)
+
+ @mock.patch("subprocess.check_output")
+ def testAddressToLine(self, checkoutput_mock):
+ checkoutput_mock.return_value = "fake_func\n/test.c:1"
+ self.assertEqual(
+ self.analyzer.AddressToLine(0x1234), [("fake_func", "/test.c", 1)]
+ )
+ checkoutput_mock.assert_called_once_with(
+ ["addr2line", "-f", "-e", "./ec.RW.elf", "1234"], encoding="utf-8"
+ )
+ checkoutput_mock.reset_mock()
+
+ checkoutput_mock.return_value = "fake_func\n/a.c:1\nbake_func\n/b.c:2\n"
+ self.assertEqual(
+ self.analyzer.AddressToLine(0x1234, True),
+ [("fake_func", "/a.c", 1), ("bake_func", "/b.c", 2)],
+ )
+ checkoutput_mock.assert_called_once_with(
+ ["addr2line", "-f", "-e", "./ec.RW.elf", "1234", "-i"], encoding="utf-8"
+ )
+ checkoutput_mock.reset_mock()
+
+ checkoutput_mock.return_value = "fake_func\n/test.c:1 (discriminator 128)"
+ self.assertEqual(
+ self.analyzer.AddressToLine(0x12345), [("fake_func", "/test.c", 1)]
+ )
+ checkoutput_mock.assert_called_once_with(
+ ["addr2line", "-f", "-e", "./ec.RW.elf", "12345"], encoding="utf-8"
+ )
+ checkoutput_mock.reset_mock()
+
+ checkoutput_mock.return_value = "??\n:?\nbake_func\n/b.c:2\n"
+ self.assertEqual(
+ self.analyzer.AddressToLine(0x123456), [None, ("bake_func", "/b.c", 2)]
+ )
+ checkoutput_mock.assert_called_once_with(
+ ["addr2line", "-f", "-e", "./ec.RW.elf", "123456"], encoding="utf-8"
+ )
+ checkoutput_mock.reset_mock()
+
+ with self.assertRaisesRegexp(
+ sa.StackAnalyzerError, "addr2line failed to resolve lines."
+ ):
+ checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "")
+ self.analyzer.AddressToLine(0x5678)
+
+ with self.assertRaisesRegexp(sa.StackAnalyzerError, "Failed to run addr2line."):
+ checkoutput_mock.side_effect = OSError()
+ self.analyzer.AddressToLine(0x9012)
+
+ @mock.patch("subprocess.check_output")
+ @mock.patch("stack_analyzer.StackAnalyzer.AddressToLine")
+ def testAndesAnalyze(self, addrtoline_mock, checkoutput_mock):
+ disasm_text = (
+ "\n"
+ "build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le"
+ "\n"
+ "Disassembly of section .text:\n"
+ "\n"
+ "00000900 <wook_task>:\n"
+ " ...\n"
+ "00001000 <hook_task>:\n"
+ " 1000: fc 00\t\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n"
+ " 1002: 47 70\t\tmovi55 $r0, #1\n"
+ " 1006: 00 01 5c fc\tbne $r6, $r0, 2af6a\n"
+ "00002000 <console_task>:\n"
+ " 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n"
+ " 2002: f0 0e fc c5\tjal 1000 <hook_task>\n"
+ " 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n"
+ " 200a: 12 34 56 78\tjral5 $r0\n"
+ )
+
+ addrtoline_mock.return_value = [("??", "??", 0)]
+ self.analyzer.annotation = {
+ "exception_frame_size": 64,
+ "remove": [["fake_func"]],
+ }
+
+ with mock.patch("builtins.print") as print_mock:
+ checkoutput_mock.return_value = disasm_text
+ self.analyzer.Analyze()
+ print_mock.assert_has_calls(
+ [
+ mock.call(
+ "Task: HOOKS, Max size: 96 (32 + 64), Allocated size: 2048"
+ ),
+ mock.call("Call Trace:"),
+ mock.call(" hook_task (32) [??:0] 1000"),
+ mock.call(
+ "Task: CONSOLE, Max size: 112 (48 + 64), Allocated size: 460"
+ ),
+ mock.call("Call Trace:"),
+ mock.call(" console_task (16) [??:0] 2000"),
+ mock.call(" -> ??[??:0] 2002"),
+ mock.call(" hook_task (32) [??:0] 1000"),
+ mock.call("Unresolved indirect callsites:"),
+ mock.call(" In function console_task:"),
+ mock.call(" -> ??[??:0] 200a"),
+ mock.call("Unresolved annotation signatures:"),
+ mock.call(" fake_func: function is not found"),
+ ]
+ )
+
+ with self.assertRaisesRegexp(sa.StackAnalyzerError, "Failed to run objdump."):
+ checkoutput_mock.side_effect = OSError()
+ self.analyzer.Analyze()
+
+ with self.assertRaisesRegexp(
+ sa.StackAnalyzerError, "objdump failed to disassemble."
+ ):
+ checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "")
+ self.analyzer.Analyze()
+
+ @mock.patch("subprocess.check_output")
+ @mock.patch("stack_analyzer.StackAnalyzer.AddressToLine")
+ def testArmAnalyze(self, addrtoline_mock, checkoutput_mock):
+ disasm_text = (
+ "\n"
+ "build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm"
+ "\n"
+ "Disassembly of section .text:\n"
+ "\n"
+ "00000900 <wook_task>:\n"
+ " ...\n"
+ "00001000 <hook_task>:\n"
+ " 1000: b508\t\tpush {r3, lr}\n"
+ " 1002: 4770\t\tbx lr\n"
+ " 1006: 00015cfc\t.word 0x00015cfc\n"
+ "00002000 <console_task>:\n"
+ " 2000: b508\t\tpush {r3, lr}\n"
+ " 2002: f00e fcc5\tbl 1000 <hook_task>\n"
+ " 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n"
+ " 200a: 1234 5678\tb.w sl\n"
+ )
+
+ addrtoline_mock.return_value = [("??", "??", 0)]
+ self.analyzer.annotation = {
+ "exception_frame_size": 64,
+ "remove": [["fake_func"]],
+ }
+
+ with mock.patch("builtins.print") as print_mock:
+ checkoutput_mock.return_value = disasm_text
+ self.analyzer.Analyze()
+ print_mock.assert_has_calls(
+ [
+ mock.call(
+ "Task: HOOKS, Max size: 72 (8 + 64), Allocated size: 2048"
+ ),
+ mock.call("Call Trace:"),
+ mock.call(" hook_task (8) [??:0] 1000"),
+ mock.call(
+ "Task: CONSOLE, Max size: 80 (16 + 64), Allocated size: 460"
+ ),
+ mock.call("Call Trace:"),
+ mock.call(" console_task (8) [??:0] 2000"),
+ mock.call(" -> ??[??:0] 2002"),
+ mock.call(" hook_task (8) [??:0] 1000"),
+ mock.call("Unresolved indirect callsites:"),
+ mock.call(" In function console_task:"),
+ mock.call(" -> ??[??:0] 200a"),
+ mock.call("Unresolved annotation signatures:"),
+ mock.call(" fake_func: function is not found"),
+ ]
+ )
+
+ with self.assertRaisesRegexp(sa.StackAnalyzerError, "Failed to run objdump."):
+ checkoutput_mock.side_effect = OSError()
+ self.analyzer.Analyze()
+
+ with self.assertRaisesRegexp(
+ sa.StackAnalyzerError, "objdump failed to disassemble."
+ ):
+ checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "")
+ self.analyzer.Analyze()
+
+ @mock.patch("subprocess.check_output")
+ @mock.patch("stack_analyzer.ParseArgs")
+ def testMain(self, parseargs_mock, checkoutput_mock):
+ symbol_text = (
+ "1000 g F .text 0000015c .hidden hook_task\n"
+ "2000 g F .text 0000051c .hidden console_task\n"
+ )
+ rodata_text = (
+ "\n"
+ "Contents of section .rodata:\n"
+ " 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n"
+ )
+
+ args = mock.MagicMock(
+ elf_path="./ec.RW.elf",
+ export_taskinfo="fake",
+ section="RW",
+ objdump="objdump",
+ addr2line="addr2line",
+ annotation="fake",
+ )
+ parseargs_mock.return_value = args
+
+ with mock.patch("os.path.exists") as path_mock:
+ path_mock.return_value = False
+ with mock.patch("builtins.print") as print_mock:
+ with mock.patch("builtins.open", mock.mock_open()) as open_mock:
+ sa.main()
+ print_mock.assert_any_call(
+ "Warning: Annotation file fake does not exist."
+ )
+
+ with mock.patch("os.path.exists") as path_mock:
+ path_mock.return_value = True
+ with mock.patch("builtins.print") as print_mock:
+ with mock.patch("builtins.open", mock.mock_open()) as open_mock:
+ open_mock.side_effect = IOError()
+ sa.main()
+ print_mock.assert_called_once_with(
+ "Error: Failed to open annotation file fake."
+ )
+
+ with mock.patch("builtins.print") as print_mock:
+ with mock.patch("builtins.open", mock.mock_open()) as open_mock:
+ open_mock.return_value.read.side_effect = ["{", ""]
+ sa.main()
+ open_mock.assert_called_once_with("fake", "r")
+ print_mock.assert_called_once_with(
+ "Error: Failed to parse annotation file fake."
+ )
+
+ with mock.patch("builtins.print") as print_mock:
+ with mock.patch(
+ "builtins.open", mock.mock_open(read_data="")
+ ) as open_mock:
+ sa.main()
+ print_mock.assert_called_once_with(
+ "Error: Invalid annotation file fake."
+ )
+
+ args.annotation = None
+
+ with mock.patch("builtins.print") as print_mock:
+ checkoutput_mock.side_effect = [symbol_text, rodata_text]
+ sa.main()
+ print_mock.assert_called_once_with("Error: Failed to load export_taskinfo.")
+
+ with mock.patch("builtins.print") as print_mock:
+ checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "")
+ sa.main()
+ print_mock.assert_called_once_with(
+ "Error: objdump failed to dump symbol table or rodata."
+ )
+
+ with mock.patch("builtins.print") as print_mock:
+ checkoutput_mock.side_effect = OSError()
+ sa.main()
+ print_mock.assert_called_once_with("Error: Failed to run objdump.")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/extra/tigertool/ecusb/__init__.py b/extra/tigertool/ecusb/__init__.py
index fe4dbc6749..7a48fdb360 100644
--- a/extra/tigertool/ecusb/__init__.py
+++ b/extra/tigertool/ecusb/__init__.py
@@ -6,4 +6,4 @@
# pylint: disable=bad-indentation,docstring-section-indent
# pylint: disable=docstring-trailing-quotes
-__all__ = ['tiny_servo_common', 'stm32usb', 'stm32uart', 'pty_driver']
+__all__ = ["tiny_servo_common", "stm32usb", "stm32uart", "pty_driver"]
diff --git a/extra/tigertool/ecusb/pty_driver.py b/extra/tigertool/ecusb/pty_driver.py
index 09ef8c42e4..037ff8d529 100644
--- a/extra/tigertool/ecusb/pty_driver.py
+++ b/extra/tigertool/ecusb/pty_driver.py
@@ -17,8 +17,9 @@ import ast
import errno
import fcntl
import os
-import pexpect
import time
+
+import pexpect
from pexpect import fdpexpect
# Expecting a result in 3 seconds is plenty even for slow platforms.
@@ -27,281 +28,285 @@ FLUSH_UART_TIMEOUT = 1
class ptyError(Exception):
- """Exception class for pty errors."""
+ """Exception class for pty errors."""
UART_PARAMS = {
- 'uart_cmd': None,
- 'uart_multicmd': None,
- 'uart_regexp': None,
- 'uart_timeout': DEFAULT_UART_TIMEOUT,
+ "uart_cmd": None,
+ "uart_multicmd": None,
+ "uart_regexp": None,
+ "uart_timeout": DEFAULT_UART_TIMEOUT,
}
class ptyDriver(object):
- """Automate interactive commands on a pty interface."""
- def __init__(self, interface, params, fast=False):
- """Init class variables."""
- self._child = None
- self._fd = None
- self._interface = interface
- self._pty_path = self._interface.get_pty()
- self._dict = UART_PARAMS.copy()
- self._fast = fast
-
- def __del__(self):
- self.close()
-
- def close(self):
- """Close any open files and interfaces."""
- if self._fd:
- self._close()
- self._interface.close()
-
- def _open(self):
- """Connect to serial device and create pexpect interface."""
- assert self._fd is None
- self._fd = os.open(self._pty_path, os.O_RDWR | os.O_NONBLOCK)
- # Don't allow forked processes to access.
- fcntl.fcntl(self._fd, fcntl.F_SETFD,
- fcntl.fcntl(self._fd, fcntl.F_GETFD) | fcntl.FD_CLOEXEC)
- self._child = fdpexpect.fdspawn(self._fd)
- # pexpect defaults to a 100ms delay before sending characters, to
- # work around race conditions in ssh. We don't need this feature
- # so we'll change delaybeforesend from 0.1 to 0.001 to speed things up.
- if self._fast:
- self._child.delaybeforesend = 0.001
-
- def _close(self):
- """Close serial device connection."""
- os.close(self._fd)
- self._fd = None
- self._child = None
-
- def _flush(self):
- """Flush device output to prevent previous messages interfering."""
- if self._child.sendline('') != 1:
- raise ptyError('Failed to send newline.')
- # Have a maximum timeout for the flush operation. We should have cleared
- # all data from the buffer, but if data is regularly being generated, we
- # can't guarantee it will ever stop.
- flush_end_time = time.time() + FLUSH_UART_TIMEOUT
- while time.time() <= flush_end_time:
- try:
- self._child.expect('.', timeout=0.01)
- except (pexpect.TIMEOUT, pexpect.EOF):
- break
- except OSError as e:
- # EAGAIN indicates no data available, maybe we didn't wait long enough.
- if e.errno != errno.EAGAIN:
- raise
- break
-
- def _send(self, cmds):
- """Send command to EC.
-
- This function always flushes serial device before sending, and is used as
- a wrapper function to make sure the channel is always flushed before
- sending commands.
-
- Args:
- cmds: The commands to send to the device, either a list or a string.
-
- Raises:
- ptyError: Raised when writing to the device fails.
- """
- self._flush()
- if not isinstance(cmds, list):
- cmds = [cmds]
- for cmd in cmds:
- if self._child.sendline(cmd) != len(cmd) + 1:
- raise ptyError('Failed to send command.')
-
- def _issue_cmd(self, cmds):
- """Send command to the device and do not wait for response.
-
- Args:
- cmds: The commands to send to the device, either a list or a string.
- """
- self._issue_cmd_get_results(cmds, [])
-
- def _issue_cmd_get_results(self, cmds,
- regex_list, timeout=DEFAULT_UART_TIMEOUT):
- """Send command to the device and wait for response.
-
- This function waits for response message matching a regular
- expressions.
-
- Args:
- cmds: The commands issued, either a list or a string.
- regex_list: List of Regular expressions used to match response message.
- Note1, list must be ordered.
- Note2, empty list sends and returns.
- timeout: time to wait for matching results before failing.
-
- Returns:
- List of tuples, each of which contains the entire matched string and
- all the subgroups of the match. None if not matched.
- For example:
- response of the given command:
- High temp: 37.2
- Low temp: 36.4
- regex_list:
- ['High temp: (\d+)\.(\d+)', 'Low temp: (\d+)\.(\d+)']
- returns:
- [('High temp: 37.2', '37', '2'), ('Low temp: 36.4', '36', '4')]
-
- Raises:
- ptyError: If timed out waiting for a response
- """
- result_list = []
- self._open()
- try:
- self._send(cmds)
- for regex in regex_list:
- self._child.expect(regex, timeout)
- match = self._child.match
- lastindex = match.lastindex if match and match.lastindex else 0
- # Create a tuple which contains the entire matched string and all
- # the subgroups of the match.
- result = match.group(*range(lastindex + 1)) if match else None
- if result:
- result = tuple(res.decode('utf-8') for res in result)
- result_list.append(result)
- except pexpect.TIMEOUT:
- raise ptyError('Timeout waiting for response.')
- finally:
- if not regex_list:
- # Must be longer than delaybeforesend
- time.sleep(0.1)
- self._close()
- return result_list
-
- def _issue_cmd_get_multi_results(self, cmd, regex):
- """Send command to the device and wait for multiple response.
-
- This function waits for arbitrary number of response message
- matching a regular expression.
-
- Args:
- cmd: The command issued.
- regex: Regular expression used to match response message.
-
- Returns:
- List of tuples, each of which contains the entire matched string and
- all the subgroups of the match. None if not matched.
- """
- result_list = []
- self._open()
- try:
- self._send(cmd)
- while True:
+ """Automate interactive commands on a pty interface."""
+
+ def __init__(self, interface, params, fast=False):
+ """Init class variables."""
+ self._child = None
+ self._fd = None
+ self._interface = interface
+ self._pty_path = self._interface.get_pty()
+ self._dict = UART_PARAMS.copy()
+ self._fast = fast
+
+ def __del__(self):
+ self.close()
+
+ def close(self):
+ """Close any open files and interfaces."""
+ if self._fd:
+ self._close()
+ self._interface.close()
+
+ def _open(self):
+ """Connect to serial device and create pexpect interface."""
+ assert self._fd is None
+ self._fd = os.open(self._pty_path, os.O_RDWR | os.O_NONBLOCK)
+ # Don't allow forked processes to access.
+ fcntl.fcntl(
+ self._fd,
+ fcntl.F_SETFD,
+ fcntl.fcntl(self._fd, fcntl.F_GETFD) | fcntl.FD_CLOEXEC,
+ )
+ self._child = fdpexpect.fdspawn(self._fd)
+ # pexpect defaults to a 100ms delay before sending characters, to
+ # work around race conditions in ssh. We don't need this feature
+ # so we'll change delaybeforesend from 0.1 to 0.001 to speed things up.
+ if self._fast:
+ self._child.delaybeforesend = 0.001
+
+ def _close(self):
+ """Close serial device connection."""
+ os.close(self._fd)
+ self._fd = None
+ self._child = None
+
+ def _flush(self):
+ """Flush device output to prevent previous messages interfering."""
+ if self._child.sendline("") != 1:
+ raise ptyError("Failed to send newline.")
+ # Have a maximum timeout for the flush operation. We should have cleared
+ # all data from the buffer, but if data is regularly being generated, we
+ # can't guarantee it will ever stop.
+ flush_end_time = time.time() + FLUSH_UART_TIMEOUT
+ while time.time() <= flush_end_time:
+ try:
+ self._child.expect(".", timeout=0.01)
+ except (pexpect.TIMEOUT, pexpect.EOF):
+ break
+ except OSError as e:
+ # EAGAIN indicates no data available, maybe we didn't wait long enough.
+ if e.errno != errno.EAGAIN:
+ raise
+ break
+
+ def _send(self, cmds):
+ """Send command to EC.
+
+ This function always flushes serial device before sending, and is used as
+ a wrapper function to make sure the channel is always flushed before
+ sending commands.
+
+ Args:
+ cmds: The commands to send to the device, either a list or a string.
+
+ Raises:
+ ptyError: Raised when writing to the device fails.
+ """
+ self._flush()
+ if not isinstance(cmds, list):
+ cmds = [cmds]
+ for cmd in cmds:
+ if self._child.sendline(cmd) != len(cmd) + 1:
+ raise ptyError("Failed to send command.")
+
+ def _issue_cmd(self, cmds):
+ """Send command to the device and do not wait for response.
+
+ Args:
+ cmds: The commands to send to the device, either a list or a string.
+ """
+ self._issue_cmd_get_results(cmds, [])
+
+ def _issue_cmd_get_results(self, cmds, regex_list, timeout=DEFAULT_UART_TIMEOUT):
+ """Send command to the device and wait for response.
+
+ This function waits for response message matching a regular
+ expressions.
+
+ Args:
+ cmds: The commands issued, either a list or a string.
+ regex_list: List of Regular expressions used to match response message.
+ Note1, list must be ordered.
+ Note2, empty list sends and returns.
+ timeout: time to wait for matching results before failing.
+
+ Returns:
+ List of tuples, each of which contains the entire matched string and
+ all the subgroups of the match. None if not matched.
+ For example:
+ response of the given command:
+ High temp: 37.2
+ Low temp: 36.4
+ regex_list:
+ ['High temp: (\d+)\.(\d+)', 'Low temp: (\d+)\.(\d+)']
+ returns:
+ [('High temp: 37.2', '37', '2'), ('Low temp: 36.4', '36', '4')]
+
+ Raises:
+ ptyError: If timed out waiting for a response
+ """
+ result_list = []
+ self._open()
try:
- self._child.expect(regex, timeout=0.1)
- match = self._child.match
- lastindex = match.lastindex if match and match.lastindex else 0
- # Create a tuple which contains the entire matched string and all
- # the subgroups of the match.
- result = match.group(*range(lastindex + 1)) if match else None
- if result:
- result = tuple(res.decode('utf-8') for res in result)
- result_list.append(result)
+ self._send(cmds)
+ for regex in regex_list:
+ self._child.expect(regex, timeout)
+ match = self._child.match
+ lastindex = match.lastindex if match and match.lastindex else 0
+ # Create a tuple which contains the entire matched string and all
+ # the subgroups of the match.
+ result = match.group(*range(lastindex + 1)) if match else None
+ if result:
+ result = tuple(res.decode("utf-8") for res in result)
+ result_list.append(result)
except pexpect.TIMEOUT:
- break
- finally:
- self._close()
- return result_list
-
- def _Set_uart_timeout(self, timeout):
- """Set timeout value for waiting for the device response.
-
- Args:
- timeout: Timeout value in second.
- """
- self._dict['uart_timeout'] = timeout
-
- def _Get_uart_timeout(self):
- """Get timeout value for waiting for the device response.
-
- Returns:
- Timeout value in second.
- """
- return self._dict['uart_timeout']
-
- def _Set_uart_regexp(self, regexp):
- """Set the list of regular expressions which matches the command response.
-
- Args:
- regexp: A string which contains a list of regular expressions.
- """
- if not isinstance(regexp, str):
- raise ptyError('The argument regexp should be a string.')
- self._dict['uart_regexp'] = ast.literal_eval(regexp)
-
- def _Get_uart_regexp(self):
- """Get the list of regular expressions which matches the command response.
-
- Returns:
- A string which contains a list of regular expressions.
- """
- return str(self._dict['uart_regexp'])
-
- def _Set_uart_cmd(self, cmd):
- """Set the UART command and send it to the device.
-
- If ec_uart_regexp is 'None', the command is just sent and it doesn't care
- about its response.
-
- If ec_uart_regexp is not 'None', the command is send and its response,
- which matches the regular expression of ec_uart_regexp, will be kept.
- Use its getter to obtain this result. If no match after ec_uart_timeout
- seconds, a timeout error will be raised.
-
- Args:
- cmd: A string of UART command.
- """
- if self._dict['uart_regexp']:
- self._dict['uart_cmd'] = self._issue_cmd_get_results(
- cmd, self._dict['uart_regexp'], self._dict['uart_timeout'])
- else:
- self._dict['uart_cmd'] = None
- self._issue_cmd(cmd)
-
- def _Set_uart_multicmd(self, cmds):
- """Set multiple UART commands and send them to the device.
-
- Note that ec_uart_regexp is not supported to match the results.
-
- Args:
- cmds: A semicolon-separated string of UART commands.
- """
- self._issue_cmd(cmds.split(';'))
-
- def _Get_uart_cmd(self):
- """Get the result of the latest UART command.
-
- Returns:
- A string which contains a list of tuples, each of which contains the
- entire matched string and all the subgroups of the match. 'None' if
- the ec_uart_regexp is 'None'.
- """
- return str(self._dict['uart_cmd'])
-
- def _Set_uart_capture(self, cmd):
- """Set UART capture mode (on or off).
-
- Once capture is enabled, UART output could be collected periodically by
- invoking _Get_uart_stream() below.
-
- Args:
- cmd: True for on, False for off
- """
- self._interface.set_capture_active(cmd)
-
- def _Get_uart_capture(self):
- """Get the UART capture mode (on or off)."""
- return self._interface.get_capture_active()
-
- def _Get_uart_stream(self):
- """Get uart stream generated since last time."""
- return self._interface.get_stream()
+ raise ptyError("Timeout waiting for response.")
+ finally:
+ if not regex_list:
+ # Must be longer than delaybeforesend
+ time.sleep(0.1)
+ self._close()
+ return result_list
+
+ def _issue_cmd_get_multi_results(self, cmd, regex):
+ """Send command to the device and wait for multiple response.
+
+ This function waits for arbitrary number of response message
+ matching a regular expression.
+
+ Args:
+ cmd: The command issued.
+ regex: Regular expression used to match response message.
+
+ Returns:
+ List of tuples, each of which contains the entire matched string and
+ all the subgroups of the match. None if not matched.
+ """
+ result_list = []
+ self._open()
+ try:
+ self._send(cmd)
+ while True:
+ try:
+ self._child.expect(regex, timeout=0.1)
+ match = self._child.match
+ lastindex = match.lastindex if match and match.lastindex else 0
+ # Create a tuple which contains the entire matched string and all
+ # the subgroups of the match.
+ result = match.group(*range(lastindex + 1)) if match else None
+ if result:
+ result = tuple(res.decode("utf-8") for res in result)
+ result_list.append(result)
+ except pexpect.TIMEOUT:
+ break
+ finally:
+ self._close()
+ return result_list
+
+ def _Set_uart_timeout(self, timeout):
+ """Set timeout value for waiting for the device response.
+
+ Args:
+ timeout: Timeout value in second.
+ """
+ self._dict["uart_timeout"] = timeout
+
+ def _Get_uart_timeout(self):
+ """Get timeout value for waiting for the device response.
+
+ Returns:
+ Timeout value in second.
+ """
+ return self._dict["uart_timeout"]
+
+ def _Set_uart_regexp(self, regexp):
+ """Set the list of regular expressions which matches the command response.
+
+ Args:
+ regexp: A string which contains a list of regular expressions.
+ """
+ if not isinstance(regexp, str):
+ raise ptyError("The argument regexp should be a string.")
+ self._dict["uart_regexp"] = ast.literal_eval(regexp)
+
+ def _Get_uart_regexp(self):
+ """Get the list of regular expressions which matches the command response.
+
+ Returns:
+ A string which contains a list of regular expressions.
+ """
+ return str(self._dict["uart_regexp"])
+
+ def _Set_uart_cmd(self, cmd):
+ """Set the UART command and send it to the device.
+
+ If ec_uart_regexp is 'None', the command is just sent and it doesn't care
+ about its response.
+
+ If ec_uart_regexp is not 'None', the command is send and its response,
+ which matches the regular expression of ec_uart_regexp, will be kept.
+ Use its getter to obtain this result. If no match after ec_uart_timeout
+ seconds, a timeout error will be raised.
+
+ Args:
+ cmd: A string of UART command.
+ """
+ if self._dict["uart_regexp"]:
+ self._dict["uart_cmd"] = self._issue_cmd_get_results(
+ cmd, self._dict["uart_regexp"], self._dict["uart_timeout"]
+ )
+ else:
+ self._dict["uart_cmd"] = None
+ self._issue_cmd(cmd)
+
+ def _Set_uart_multicmd(self, cmds):
+ """Set multiple UART commands and send them to the device.
+
+ Note that ec_uart_regexp is not supported to match the results.
+
+ Args:
+ cmds: A semicolon-separated string of UART commands.
+ """
+ self._issue_cmd(cmds.split(";"))
+
+ def _Get_uart_cmd(self):
+ """Get the result of the latest UART command.
+
+ Returns:
+ A string which contains a list of tuples, each of which contains the
+ entire matched string and all the subgroups of the match. 'None' if
+ the ec_uart_regexp is 'None'.
+ """
+ return str(self._dict["uart_cmd"])
+
+ def _Set_uart_capture(self, cmd):
+ """Set UART capture mode (on or off).
+
+ Once capture is enabled, UART output could be collected periodically by
+ invoking _Get_uart_stream() below.
+
+ Args:
+ cmd: True for on, False for off
+ """
+ self._interface.set_capture_active(cmd)
+
+ def _Get_uart_capture(self):
+ """Get the UART capture mode (on or off)."""
+ return self._interface.get_capture_active()
+
+ def _Get_uart_stream(self):
+ """Get uart stream generated since last time."""
+ return self._interface.get_stream()
diff --git a/extra/tigertool/ecusb/stm32uart.py b/extra/tigertool/ecusb/stm32uart.py
index 95219455a9..5794bd091b 100644
--- a/extra/tigertool/ecusb/stm32uart.py
+++ b/extra/tigertool/ecusb/stm32uart.py
@@ -17,232 +17,244 @@ import termios
import threading
import time
import tty
+
import usb
from . import stm32usb
class SuartError(Exception):
- """Class for exceptions of Suart."""
- def __init__(self, msg, value=0):
- """SuartError constructor.
+ """Class for exceptions of Suart."""
- Args:
- msg: string, message describing error in detail
- value: integer, value of error when non-zero status returned. Default=0
- """
- super(SuartError, self).__init__(msg, value)
- self.msg = msg
- self.value = value
+ def __init__(self, msg, value=0):
+ """SuartError constructor.
+ Args:
+ msg: string, message describing error in detail
+ value: integer, value of error when non-zero status returned. Default=0
+ """
+ super(SuartError, self).__init__(msg, value)
+ self.msg = msg
+ self.value = value
-class Suart(object):
- """Provide interface to stm32 serial usb endpoint."""
- def __init__(self, vendor=0x18d1, product=0x501a, interface=0,
- serialname=None, debuglog=False):
- """Suart contstructor.
-
- Initializes stm32 USB stream interface.
-
- Args:
- vendor: usb vendor id of stm32 device
- product: usb product id of stm32 device
- interface: interface number of stm32 device to use
- serialname: serial name to target. Defaults to None.
- debuglog: chatty output. Defaults to False.
-
- Raises:
- SuartError: If init fails
- """
- self._ptym = None
- self._ptys = None
- self._ptyname = None
- self._rx_thread = None
- self._tx_thread = None
- self._debuglog = debuglog
- self._susb = stm32usb.Susb(vendor=vendor, product=product,
- interface=interface, serialname=serialname)
- self._running = False
-
- def __del__(self):
- """Suart destructor."""
- self.close()
-
- def close(self):
- """Stop all running threads."""
- self._running = False
- if self._rx_thread:
- self._rx_thread.join(2)
- self._rx_thread = None
- if self._tx_thread:
- self._tx_thread.join(2)
- self._tx_thread = None
- self._susb.close()
-
- def run_rx_thread(self):
- """Background loop to pass data from USB to pty."""
- ep = select.epoll()
- ep.register(self._ptym, select.EPOLLHUP)
- try:
- while self._running:
- events = ep.poll(0)
- # Check if the pty is connected to anything, or hungup.
- if not events:
- try:
- r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS)
- if r:
- if self._debuglog:
- print(''.join([chr(x) for x in r]), end='')
- os.write(self._ptym, r)
-
- # If we miss some characters on pty disconnect, that's fine.
- # ep.read() also throws USBError on timeout, which we discard.
- except OSError:
- pass
- except usb.core.USBError:
- pass
- else:
- time.sleep(.1)
- except Exception as e:
- raise e
-
- def run_tx_thread(self):
- """Background loop to pass data from pty to USB."""
- ep = select.epoll()
- ep.register(self._ptym, select.EPOLLHUP)
- try:
- while self._running:
- events = ep.poll(0)
- # Check if the pty is connected to anything, or hungup.
- if not events:
- try:
- r = os.read(self._ptym, 64)
- # TODO(crosbug.com/936182): Remove when the servo v4/micro console
- # issues are fixed.
- time.sleep(0.001)
- if r:
- self._susb._write_ep.write(r, self._susb.TIMEOUT_MS)
-
- except OSError:
- pass
- except usb.core.USBError:
- pass
- else:
- time.sleep(.1)
- except Exception as e:
- raise e
-
- def run(self):
- """Creates pthreads to poll stm32 & PTY for data."""
- m, s = os.openpty()
- self._ptyname = os.ttyname(s)
-
- self._ptym = m
- self._ptys = s
-
- os.fchmod(s, 0o660)
-
- # Change the owner and group of the PTY to the user who started servod.
- try:
- uid = int(os.environ.get('SUDO_UID', -1))
- except TypeError:
- uid = -1
- try:
- gid = int(os.environ.get('SUDO_GID', -1))
- except TypeError:
- gid = -1
- os.fchown(s, uid, gid)
-
- tty.setraw(self._ptym, termios.TCSADRAIN)
-
- # Generate a HUP flag on pty slave fd.
- os.fdopen(s).close()
-
- self._running = True
-
- self._rx_thread = threading.Thread(target=self.run_rx_thread, args=[])
- self._rx_thread.daemon = True
- self._rx_thread.start()
-
- self._tx_thread = threading.Thread(target=self.run_tx_thread, args=[])
- self._tx_thread.daemon = True
- self._tx_thread.start()
-
- def get_uart_props(self):
- """Get the uart's properties.
-
- Returns:
- dict where:
- baudrate: integer of uarts baudrate
- bits: integer, number of bits of data Can be 5|6|7|8 inclusive
- parity: integer, parity of 0-2 inclusive where:
- 0: no parity
- 1: odd parity
- 2: even parity
- sbits: integer, number of stop bits. Can be 0|1|2 inclusive where:
- 0: 1 stop bit
- 1: 1.5 stop bits
- 2: 2 stop bits
- """
- return {
- 'baudrate': 115200,
- 'bits': 8,
- 'parity': 0,
- 'sbits': 1,
- }
-
- def set_uart_props(self, line_props):
- """Set the uart's properties.
-
- Note that Suart cannot set properties
- and will fail if the properties are not the default 115200,8n1.
-
- Args:
- line_props: dict where:
- baudrate: integer of uarts baudrate
- bits: integer, number of bits of data ( prior to stop bit)
- parity: integer, parity of 0-2 inclusive where
- 0: no parity
- 1: odd parity
- 2: even parity
- sbits: integer, number of stop bits. Can be 0|1|2 inclusive where:
- 0: 1 stop bit
- 1: 1.5 stop bits
- 2: 2 stop bits
-
- Raises:
- SuartError: If requested line properties are not the default.
- """
- curr_props = self.get_uart_props()
- for prop in line_props:
- if line_props[prop] != curr_props[prop]:
- raise SuartError('Line property %s cannot be set from %s to %s' % (
- prop, curr_props[prop], line_props[prop]))
- return True
-
- def get_pty(self):
- """Gets path to pty for communication to/from uart.
-
- Returns:
- String path to the pty connected to the uart
- """
- return self._ptyname
+class Suart(object):
+ """Provide interface to stm32 serial usb endpoint."""
+
+ def __init__(
+ self,
+ vendor=0x18D1,
+ product=0x501A,
+ interface=0,
+ serialname=None,
+ debuglog=False,
+ ):
+        """Suart constructor.
+
+ Initializes stm32 USB stream interface.
+
+ Args:
+ vendor: usb vendor id of stm32 device
+ product: usb product id of stm32 device
+ interface: interface number of stm32 device to use
+ serialname: serial name to target. Defaults to None.
+ debuglog: chatty output. Defaults to False.
+
+ Raises:
+ SuartError: If init fails
+ """
+ self._ptym = None
+ self._ptys = None
+ self._ptyname = None
+ self._rx_thread = None
+ self._tx_thread = None
+ self._debuglog = debuglog
+ self._susb = stm32usb.Susb(
+ vendor=vendor, product=product, interface=interface, serialname=serialname
+ )
+ self._running = False
+
+ def __del__(self):
+ """Suart destructor."""
+ self.close()
+
+ def close(self):
+ """Stop all running threads."""
+ self._running = False
+ if self._rx_thread:
+ self._rx_thread.join(2)
+ self._rx_thread = None
+ if self._tx_thread:
+ self._tx_thread.join(2)
+ self._tx_thread = None
+ self._susb.close()
+
+ def run_rx_thread(self):
+ """Background loop to pass data from USB to pty."""
+ ep = select.epoll()
+ ep.register(self._ptym, select.EPOLLHUP)
+ try:
+ while self._running:
+ events = ep.poll(0)
+ # Check if the pty is connected to anything, or hungup.
+ if not events:
+ try:
+ r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS)
+ if r:
+ if self._debuglog:
+ print("".join([chr(x) for x in r]), end="")
+ os.write(self._ptym, r)
+
+ # If we miss some characters on pty disconnect, that's fine.
+ # ep.read() also throws USBError on timeout, which we discard.
+ except OSError:
+ pass
+ except usb.core.USBError:
+ pass
+ else:
+ time.sleep(0.1)
+ except Exception as e:
+ raise e
+
+ def run_tx_thread(self):
+ """Background loop to pass data from pty to USB."""
+ ep = select.epoll()
+ ep.register(self._ptym, select.EPOLLHUP)
+ try:
+ while self._running:
+ events = ep.poll(0)
+ # Check if the pty is connected to anything, or hungup.
+ if not events:
+ try:
+ r = os.read(self._ptym, 64)
+ # TODO(crosbug.com/936182): Remove when the servo v4/micro console
+ # issues are fixed.
+ time.sleep(0.001)
+ if r:
+ self._susb._write_ep.write(r, self._susb.TIMEOUT_MS)
+
+ except OSError:
+ pass
+ except usb.core.USBError:
+ pass
+ else:
+ time.sleep(0.1)
+ except Exception as e:
+ raise e
+
+ def run(self):
+ """Creates pthreads to poll stm32 & PTY for data."""
+ m, s = os.openpty()
+ self._ptyname = os.ttyname(s)
+
+ self._ptym = m
+ self._ptys = s
+
+ os.fchmod(s, 0o660)
+
+ # Change the owner and group of the PTY to the user who started servod.
+ try:
+ uid = int(os.environ.get("SUDO_UID", -1))
+ except TypeError:
+ uid = -1
+
+ try:
+ gid = int(os.environ.get("SUDO_GID", -1))
+ except TypeError:
+ gid = -1
+ os.fchown(s, uid, gid)
+
+ tty.setraw(self._ptym, termios.TCSADRAIN)
+
+ # Generate a HUP flag on pty slave fd.
+ os.fdopen(s).close()
+
+ self._running = True
+
+ self._rx_thread = threading.Thread(target=self.run_rx_thread, args=[])
+ self._rx_thread.daemon = True
+ self._rx_thread.start()
+
+ self._tx_thread = threading.Thread(target=self.run_tx_thread, args=[])
+ self._tx_thread.daemon = True
+ self._tx_thread.start()
+
+ def get_uart_props(self):
+ """Get the uart's properties.
+
+ Returns:
+ dict where:
+ baudrate: integer of uarts baudrate
+ bits: integer, number of bits of data Can be 5|6|7|8 inclusive
+ parity: integer, parity of 0-2 inclusive where:
+ 0: no parity
+ 1: odd parity
+ 2: even parity
+ sbits: integer, number of stop bits. Can be 0|1|2 inclusive where:
+ 0: 1 stop bit
+ 1: 1.5 stop bits
+ 2: 2 stop bits
+ """
+ return {
+ "baudrate": 115200,
+ "bits": 8,
+ "parity": 0,
+ "sbits": 1,
+ }
+
+ def set_uart_props(self, line_props):
+ """Set the uart's properties.
+
+ Note that Suart cannot set properties
+ and will fail if the properties are not the default 115200,8n1.
+
+ Args:
+ line_props: dict where:
+ baudrate: integer of uarts baudrate
+ bits: integer, number of bits of data ( prior to stop bit)
+ parity: integer, parity of 0-2 inclusive where
+ 0: no parity
+ 1: odd parity
+ 2: even parity
+ sbits: integer, number of stop bits. Can be 0|1|2 inclusive where:
+ 0: 1 stop bit
+ 1: 1.5 stop bits
+ 2: 2 stop bits
+
+ Raises:
+ SuartError: If requested line properties are not the default.
+ """
+ curr_props = self.get_uart_props()
+ for prop in line_props:
+ if line_props[prop] != curr_props[prop]:
+ raise SuartError(
+ "Line property %s cannot be set from %s to %s"
+ % (prop, curr_props[prop], line_props[prop])
+ )
+ return True
+
+ def get_pty(self):
+ """Gets path to pty for communication to/from uart.
+
+ Returns:
+ String path to the pty connected to the uart
+ """
+ return self._ptyname
def main():
- """Run a suart test with the default parameters."""
- try:
- sobj = Suart()
- sobj.run()
+ """Run a suart test with the default parameters."""
+ try:
+ sobj = Suart()
+ sobj.run()
- # run() is a thread so just busy wait to mimic server.
- while True:
- # Ours sleeps to eleven!
- time.sleep(11)
- except KeyboardInterrupt:
- sys.exit(0)
+ # run() is a thread so just busy wait to mimic server.
+ while True:
+ # Ours sleeps to eleven!
+ time.sleep(11)
+ except KeyboardInterrupt:
+ sys.exit(0)
-if __name__ == '__main__':
- main()
+if __name__ == "__main__":
+ main()
diff --git a/extra/tigertool/ecusb/stm32usb.py b/extra/tigertool/ecusb/stm32usb.py
index bfd5fbb1fb..4b2b23fbac 100644
--- a/extra/tigertool/ecusb/stm32usb.py
+++ b/extra/tigertool/ecusb/stm32usb.py
@@ -12,108 +12,115 @@ import usb
class SusbError(Exception):
- """Class for exceptions of Susb."""
- def __init__(self, msg, value=0):
- """SusbError constructor.
+ """Class for exceptions of Susb."""
- Args:
- msg: string, message describing error in detail
- value: integer, value of error when non-zero status returned. Default=0
- """
- super(SusbError, self).__init__(msg, value)
- self.msg = msg
- self.value = value
+ def __init__(self, msg, value=0):
+ """SusbError constructor.
+
+ Args:
+ msg: string, message describing error in detail
+ value: integer, value of error when non-zero status returned. Default=0
+ """
+ super(SusbError, self).__init__(msg, value)
+ self.msg = msg
+ self.value = value
class Susb(object):
- """Provide stm32 USB functionality.
-
- Instance Variables:
- _read_ep: pyUSB read endpoint for this interface
- _write_ep: pyUSB write endpoint for this interface
- """
- READ_ENDPOINT = 0x81
- WRITE_ENDPOINT = 0x1
- TIMEOUT_MS = 100
-
- def __init__(self, vendor=0x18d1,
- product=0x5027, interface=1, serialname=None, logger=None):
- """Susb constructor.
-
- Discovers and connects to stm32 USB endpoints.
-
- Args:
- vendor: usb vendor id of stm32 device.
- product: usb product id of stm32 device.
- interface: interface number ( 1 - 4 ) of stm32 device to use.
- serialname: string of device serialname.
- logger: none
-
- Raises:
- SusbError: An error accessing Susb object
+ """Provide stm32 USB functionality.
+
+ Instance Variables:
+ _read_ep: pyUSB read endpoint for this interface
+ _write_ep: pyUSB write endpoint for this interface
"""
- self._vendor = vendor
- self._product = product
- self._interface = interface
- self._serialname = serialname
- self._find_device()
-
- def _find_device(self):
- """Set up the usb endpoint"""
- # Find the stm32.
- dev_g = usb.core.find(idVendor=self._vendor, idProduct=self._product,
- find_all=True)
- dev_list = list(dev_g)
-
- if not dev_list:
- raise SusbError('USB device not found')
-
- # Check if we have multiple stm32s and we've specified the serial.
- dev = None
- if self._serialname:
- for d in dev_list:
- dev_serial = usb.util.get_string(d, d.iSerialNumber)
- if dev_serial == self._serialname:
- dev = d
- break
- if dev is None:
- raise SusbError('USB device(%s) not found' % self._serialname)
- else:
- try:
- dev = dev_list[0]
- except StopIteration:
- raise SusbError('USB device %04x:%04x not found' % (
- self._vendor, self._product))
-
- # If we can't set configuration, it's already been set.
- try:
- dev.set_configuration()
- except usb.core.USBError:
- pass
-
- self._dev = dev
-
- # Get an endpoint instance.
- cfg = dev.get_active_configuration()
- intf = usb.util.find_descriptor(cfg, bInterfaceNumber=self._interface)
- self._intf = intf
- if not intf:
- raise SusbError('Interface %04x:%04x - 0x%x not found' % (
- self._vendor, self._product, self._interface))
-
- # Detach raiden.ko if it is loaded. CCD endpoints support either a kernel
- # module driver that produces a ttyUSB, or direct endpoint access, but
- # can't do both at the same time.
- if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True:
- dev.detach_kernel_driver(intf.bInterfaceNumber)
-
- read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT
- read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number)
- self._read_ep = read_ep
-
- write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT
- write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number)
- self._write_ep = write_ep
-
- def close(self):
- usb.util.dispose_resources(self._dev)
+
+ READ_ENDPOINT = 0x81
+ WRITE_ENDPOINT = 0x1
+ TIMEOUT_MS = 100
+
+ def __init__(
+ self, vendor=0x18D1, product=0x5027, interface=1, serialname=None, logger=None
+ ):
+ """Susb constructor.
+
+ Discovers and connects to stm32 USB endpoints.
+
+ Args:
+ vendor: usb vendor id of stm32 device.
+ product: usb product id of stm32 device.
+ interface: interface number ( 1 - 4 ) of stm32 device to use.
+ serialname: string of device serialname.
+ logger: none
+
+ Raises:
+ SusbError: An error accessing Susb object
+ """
+ self._vendor = vendor
+ self._product = product
+ self._interface = interface
+ self._serialname = serialname
+ self._find_device()
+
+ def _find_device(self):
+ """Set up the usb endpoint"""
+ # Find the stm32.
+ dev_g = usb.core.find(
+ idVendor=self._vendor, idProduct=self._product, find_all=True
+ )
+ dev_list = list(dev_g)
+
+ if not dev_list:
+ raise SusbError("USB device not found")
+
+ # Check if we have multiple stm32s and we've specified the serial.
+ dev = None
+ if self._serialname:
+ for d in dev_list:
+ dev_serial = usb.util.get_string(d, d.iSerialNumber)
+ if dev_serial == self._serialname:
+ dev = d
+ break
+ if dev is None:
+ raise SusbError("USB device(%s) not found" % self._serialname)
+ else:
+ try:
+ dev = dev_list[0]
+ except StopIteration:
+ raise SusbError(
+ "USB device %04x:%04x not found" % (self._vendor, self._product)
+ )
+
+ # If we can't set configuration, it's already been set.
+ try:
+ dev.set_configuration()
+ except usb.core.USBError:
+ pass
+
+ self._dev = dev
+
+ # Get an endpoint instance.
+ cfg = dev.get_active_configuration()
+ intf = usb.util.find_descriptor(cfg, bInterfaceNumber=self._interface)
+ self._intf = intf
+ if not intf:
+ raise SusbError(
+ "Interface %04x:%04x - 0x%x not found"
+ % (self._vendor, self._product, self._interface)
+ )
+
+ # Detach raiden.ko if it is loaded. CCD endpoints support either a kernel
+ # module driver that produces a ttyUSB, or direct endpoint access, but
+ # can't do both at the same time.
+ if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True:
+ dev.detach_kernel_driver(intf.bInterfaceNumber)
+
+ read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT
+ read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number)
+ self._read_ep = read_ep
+
+ write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT
+ write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number)
+ self._write_ep = write_ep
+
+ def close(self):
+ usb.util.dispose_resources(self._dev)
diff --git a/extra/tigertool/ecusb/tiny_servo_common.py b/extra/tigertool/ecusb/tiny_servo_common.py
index 726e2e64b7..599bae80dc 100644
--- a/extra/tigertool/ecusb/tiny_servo_common.py
+++ b/extra/tigertool/ecusb/tiny_servo_common.py
@@ -17,215 +17,224 @@ import time
import six
import usb
-from . import pty_driver
-from . import stm32uart
+from . import pty_driver, stm32uart
def get_subprocess_args():
- if six.PY3:
- return {'encoding': 'utf-8'}
- return {}
+ if six.PY3:
+ return {"encoding": "utf-8"}
+ return {}
class TinyServoError(Exception):
- """Exceptions."""
+ """Exceptions."""
def log(output):
- """Print output to console, logfiles can be added here.
+ """Print output to console, logfiles can be added here.
+
+ Args:
+ output: string to output.
+ """
+ sys.stdout.write(output)
+ sys.stdout.write("\n")
+ sys.stdout.flush()
- Args:
- output: string to output.
- """
- sys.stdout.write(output)
- sys.stdout.write('\n')
- sys.stdout.flush()
def check_usb(vidpid, serialname=None):
- """Check if |vidpid| is present on the system's USB.
+ """Check if |vidpid| is present on the system's USB.
+
+ Args:
+ vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
+ serialname: serialname if specified.
- Args:
- vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
- serialname: serialname if specified.
+ Returns:
+ True if found, False, otherwise.
+ """
+ if get_usb_dev(vidpid, serialname):
+ return True
- Returns:
- True if found, False, otherwise.
- """
- if get_usb_dev(vidpid, serialname):
- return True
+ return False
- return False
def check_usb_sn(vidpid):
- """Return the serial number
+ """Return the serial number
- Return the serial number of the first USB device with VID:PID vidpid,
- or None if no device is found. This will not work well with two of
- the same device attached.
+ Return the serial number of the first USB device with VID:PID vidpid,
+ or None if no device is found. This will not work well with two of
+ the same device attached.
- Args:
- vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
+ Args:
+ vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
- Returns:
- string serial number if found, None otherwise.
- """
- dev = get_usb_dev(vidpid)
+ Returns:
+ string serial number if found, None otherwise.
+ """
+ dev = get_usb_dev(vidpid)
- if dev:
- dev_serial = usb.util.get_string(dev, dev.iSerialNumber)
+ if dev:
+ dev_serial = usb.util.get_string(dev, dev.iSerialNumber)
- return dev_serial
+ return dev_serial
+
+ return None
- return None
def get_usb_dev(vidpid, serialname=None):
- """Return the USB pyusb devie struct
+    """Return the USB pyusb device struct
+
+ Return the dev struct of the first USB device with VID:PID vidpid,
+ or None if no device is found. If more than one device check serial
+ if supplied.
+
+ Args:
+ vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
+ serialname: serialname if specified.
+
+ Returns:
+ pyusb device if found, None otherwise.
+ """
+ vidpidst = vidpid.split(":")
+ vid = int(vidpidst[0], 16)
+ pid = int(vidpidst[1], 16)
+
+ dev_g = usb.core.find(idVendor=vid, idProduct=pid, find_all=True)
+ dev_list = list(dev_g)
+
+ if not dev_list:
+ return None
+
+ # Check if we have multiple devices and we've specified the serial.
+ dev = None
+ if serialname:
+ for d in dev_list:
+ dev_serial = usb.util.get_string(d, d.iSerialNumber)
+ if dev_serial == serialname:
+ dev = d
+ break
+ if dev is None:
+ return None
+ else:
+ try:
+ dev = dev_list[0]
+ except StopIteration:
+ return None
+
+ return dev
+
- Return the dev struct of the first USB device with VID:PID vidpid,
- or None if no device is found. If more than one device check serial
- if supplied.
+def check_usb_dev(vidpid, serialname=None):
+ """Return the USB dev number
+
+ Return the dev number of the first USB device with VID:PID vidpid,
+ or None if no device is found. If more than one device check serial
+ if supplied.
- Args:
- vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
- serialname: serialname if specified.
+ Args:
+ vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
+ serialname: serialname if specified.
- Returns:
- pyusb device if found, None otherwise.
- """
- vidpidst = vidpid.split(':')
- vid = int(vidpidst[0], 16)
- pid = int(vidpidst[1], 16)
+ Returns:
+ usb device number if found, None otherwise.
+ """
+ dev = get_usb_dev(vidpid, serialname=serialname)
- dev_g = usb.core.find(idVendor=vid, idProduct=pid, find_all=True)
- dev_list = list(dev_g)
+ if dev:
+ return dev.address
- if not dev_list:
return None
- # Check if we have multiple devices and we've specified the serial.
- dev = None
- if serialname:
- for d in dev_list:
- dev_serial = usb.util.get_string(d, d.iSerialNumber)
- if dev_serial == serialname:
- dev = d
- break
- if dev is None:
- return None
- else:
- try:
- dev = dev_list[0]
- except StopIteration:
- return None
-
- return dev
-def check_usb_dev(vidpid, serialname=None):
- """Return the USB dev number
+def wait_for_usb_remove(vidpid, serialname=None, timeout=None):
+ """Wait for USB device with vidpid to be removed.
- Return the dev number of the first USB device with VID:PID vidpid,
- or None if no device is found. If more than one device check serial
- if supplied.
+ Wrapper for wait_for_usb below
+ """
+ wait_for_usb(vidpid, serialname=serialname, timeout=timeout, desiredpresence=False)
- Args:
- vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
- serialname: serialname if specified.
- Returns:
- usb device number if found, None otherwise.
- """
- dev = get_usb_dev(vidpid, serialname=serialname)
+def wait_for_usb(vidpid, serialname=None, timeout=None, desiredpresence=True):
+ """Wait for usb device with vidpid to be present/absent.
- if dev:
- return dev.address
+ Args:
+ vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
+        serialname: serialname if specified.
+ timeout: timeout in seconds, None for no timeout.
+ desiredpresence: True for present, False for not present.
- return None
+ Raises:
+ TinyServoError: on timeout.
+ """
+ if timeout:
+ finish = datetime.datetime.now() + datetime.timedelta(seconds=timeout)
+ while check_usb(vidpid, serialname) != desiredpresence:
+ time.sleep(0.1)
+ if timeout:
+ if datetime.datetime.now() > finish:
+ raise TinyServoError("Timeout", "Timeout waiting for USB %s" % vidpid)
-def wait_for_usb_remove(vidpid, serialname=None, timeout=None):
- """Wait for USB device with vidpid to be removed.
+def do_serialno(serialno, pty):
+ """Set serialnumber 'serialno' via ec console 'pty'.
- Wrapper for wait_for_usb below
- """
- wait_for_usb(vidpid, serialname=serialname,
- timeout=timeout, desiredpresence=False)
+ Commands are:
+ # > serialno set 1234
+ # Saving serial number
+ # Serial number: 1234
-def wait_for_usb(vidpid, serialname=None, timeout=None, desiredpresence=True):
- """Wait for usb device with vidpid to be present/absent.
-
- Args:
- vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
- serialname: serialname if specificed.
- timeout: timeout in seconds, None for no timeout.
- desiredpresence: True for present, False for not present.
-
- Raises:
- TinyServoError: on timeout.
- """
- if timeout:
- finish = datetime.datetime.now() + datetime.timedelta(seconds=timeout)
- while check_usb(vidpid, serialname) != desiredpresence:
- time.sleep(.1)
- if timeout:
- if datetime.datetime.now() > finish:
- raise TinyServoError('Timeout', 'Timeout waiting for USB %s' % vidpid)
+ Args:
+ serialno: string serial number to set.
+ pty: tinyservo console to send commands.
+
+ Raises:
+ TinyServoError: on failure to set.
+ ptyError: on command interface error.
+ """
+ cmd = r"serialno set %s" % serialno
+ regex = r"Serial number:\s+(\S+)"
+
+ results = pty._issue_cmd_get_results(cmd, [regex])[0]
+ sn = results[1].strip().strip("\n\r")
+
+ if sn == serialno:
+ log("Success !")
+ log("Serial set to %s" % sn)
+ else:
+ log("Serial number set to %s but saved as %s." % (serialno, sn))
+ raise TinyServoError(
+ "Serial Number", "Serial number set to %s but saved as %s." % (serialno, sn)
+ )
-def do_serialno(serialno, pty):
- """Set serialnumber 'serialno' via ec console 'pty'.
-
- Commands are:
- # > serialno set 1234
- # Saving serial number
- # Serial number: 1234
-
- Args:
- serialno: string serial number to set.
- pty: tinyservo console to send commands.
-
- Raises:
- TinyServoError: on failure to set.
- ptyError: on command interface error.
- """
- cmd = r'serialno set %s' % serialno
- regex = r'Serial number:\s+(\S+)'
-
- results = pty._issue_cmd_get_results(cmd, [regex])[0]
- sn = results[1].strip().strip('\n\r')
-
- if sn == serialno:
- log('Success !')
- log('Serial set to %s' % sn)
- else:
- log('Serial number set to %s but saved as %s.' % (serialno, sn))
- raise TinyServoError(
- 'Serial Number',
- 'Serial number set to %s but saved as %s.' % (serialno, sn))
def setup_tinyservod(vidpid, interface, serialname=None, debuglog=False):
- """Set up a pty
-
- Set up a pty to the ec console in order
- to send commands. Returns a pty_driver object.
-
- Args:
- vidpid: string vidpid of device to access.
- interface: not used.
- serialname: string serial name of device requested, optional.
- debuglog: chatty printout (boolean)
-
- Returns:
- pty object
-
- Raises:
- UsbError, SusbError: on device not found
- """
- vidstr, pidstr = vidpid.split(':')
- vid = int(vidstr, 16)
- pid = int(pidstr, 16)
- suart = stm32uart.Suart(vendor=vid, product=pid,
- interface=interface, serialname=serialname,
- debuglog=debuglog)
- suart.run()
- pty = pty_driver.ptyDriver(suart, [])
-
- return pty
+ """Set up a pty
+
+ Set up a pty to the ec console in order
+ to send commands. Returns a pty_driver object.
+
+ Args:
+ vidpid: string vidpid of device to access.
+ interface: not used.
+ serialname: string serial name of device requested, optional.
+ debuglog: chatty printout (boolean)
+
+ Returns:
+ pty object
+
+ Raises:
+ UsbError, SusbError: on device not found
+ """
+ vidstr, pidstr = vidpid.split(":")
+ vid = int(vidstr, 16)
+ pid = int(pidstr, 16)
+ suart = stm32uart.Suart(
+ vendor=vid,
+ product=pid,
+ interface=interface,
+ serialname=serialname,
+ debuglog=debuglog,
+ )
+ suart.run()
+ pty = pty_driver.ptyDriver(suart, [])
+
+ return pty
diff --git a/extra/tigertool/ecusb/tiny_servod.py b/extra/tigertool/ecusb/tiny_servod.py
index 632d9c3a20..ca5ca63f31 100644
--- a/extra/tigertool/ecusb/tiny_servod.py
+++ b/extra/tigertool/ecusb/tiny_servod.py
@@ -8,47 +8,48 @@
"""Helper class to facilitate communication to servo ec console."""
-from ecusb import pty_driver
-from ecusb import stm32uart
+from ecusb import pty_driver, stm32uart
class TinyServod(object):
- """Helper class to wrap a pty_driver with interface."""
-
- def __init__(self, vid, pid, interface, serialname=None, debug=False):
- """Build the driver and interface.
-
- Args:
- vid: servo device vid
- pid: servo device pid
- interface: which usb interface the servo console is on
- serialname: the servo device serial (if available)
- """
- self._vid = vid
- self._pid = pid
- self._interface = interface
- self._serial = serialname
- self._debug = debug
- self._init()
-
- def _init(self):
- self.suart = stm32uart.Suart(vendor=self._vid,
- product=self._pid,
- interface=self._interface,
- serialname=self._serial,
- debuglog=self._debug)
- self.suart.run()
- self.pty = pty_driver.ptyDriver(self.suart, [])
-
- def reinitialize(self):
- """Reinitialize the connect after a reset/disconnect/etc."""
- self.close()
- self._init()
-
- def close(self):
- """Close out the connection and release resources.
-
- Note: if another TinyServod process or servod itself needs the same device
- it's necessary to call this to ensure the usb device is available.
- """
- self.suart.close()
+ """Helper class to wrap a pty_driver with interface."""
+
+ def __init__(self, vid, pid, interface, serialname=None, debug=False):
+ """Build the driver and interface.
+
+ Args:
+ vid: servo device vid
+ pid: servo device pid
+ interface: which usb interface the servo console is on
+ serialname: the servo device serial (if available)
+ """
+ self._vid = vid
+ self._pid = pid
+ self._interface = interface
+ self._serial = serialname
+ self._debug = debug
+ self._init()
+
+ def _init(self):
+ self.suart = stm32uart.Suart(
+ vendor=self._vid,
+ product=self._pid,
+ interface=self._interface,
+ serialname=self._serial,
+ debuglog=self._debug,
+ )
+ self.suart.run()
+ self.pty = pty_driver.ptyDriver(self.suart, [])
+
+ def reinitialize(self):
+ """Reinitialize the connect after a reset/disconnect/etc."""
+ self.close()
+ self._init()
+
+ def close(self):
+ """Close out the connection and release resources.
+
+ Note: if another TinyServod process or servod itself needs the same device
+ it's necessary to call this to ensure the usb device is available.
+ """
+ self.suart.close()
diff --git a/extra/tigertool/tigertest.py b/extra/tigertool/tigertest.py
index 0cd31c8cce..8f8b2c7f03 100755
--- a/extra/tigertool/tigertest.py
+++ b/extra/tigertool/tigertest.py
@@ -13,7 +13,6 @@ import argparse
import subprocess
import sys
-
# Script to control tigertail USB-C Mux board.
#
# optional arguments:
@@ -35,58 +34,60 @@ import sys
def testCmd(cmd, expected_results):
- """Run command on console, check for success.
-
- Args:
- cmd: shell command to run.
- expected_results: a list object of strings expected in the result.
-
- Raises:
- Exception on fail.
- """
- print('run: ' + cmd)
- try:
- p = subprocess.run(cmd, shell=True, check=False, capture_output=True)
- output = p.stdout.decode('utf-8')
- error = p.stderr.decode('utf-8')
- assert p.returncode == 0
- for result in expected_results:
- output.index(result)
- except Exception as e:
- print('FAIL')
- print('cmd: ' + cmd)
- print('error: ' + str(e))
- print('stdout:\n' + output)
- print('stderr:\n' + error)
- print('expected: ' + str(expected_results))
- print('RC: ' + str(p.returncode))
- raise e
+ """Run command on console, check for success.
+
+ Args:
+ cmd: shell command to run.
+ expected_results: a list object of strings expected in the result.
+
+ Raises:
+ Exception on fail.
+ """
+ print("run: " + cmd)
+ try:
+ p = subprocess.run(cmd, shell=True, check=False, capture_output=True)
+ output = p.stdout.decode("utf-8")
+ error = p.stderr.decode("utf-8")
+ assert p.returncode == 0
+ for result in expected_results:
+ output.index(result)
+ except Exception as e:
+ print("FAIL")
+ print("cmd: " + cmd)
+ print("error: " + str(e))
+ print("stdout:\n" + output)
+ print("stderr:\n" + error)
+ print("expected: " + str(expected_results))
+ print("RC: " + str(p.returncode))
+ raise e
+
def test_sequence():
- testCmd('./tigertool.py --reboot', ['PASS'])
- testCmd('./tigertool.py --setserialno test', ['PASS'])
- testCmd('./tigertool.py --check_serial', ['test', 'PASS'])
- testCmd('./tigertool.py -s test --check_serial', ['test', 'PASS'])
- testCmd('./tigertool.py -m A', ['Mux set to A', 'PASS'])
- testCmd('./tigertool.py -m B', ['Mux set to B', 'PASS'])
- testCmd('./tigertool.py -m off', ['Mux set to off', 'PASS'])
- testCmd('./tigertool.py -p', ['PASS'])
- testCmd('./tigertool.py -r rw', ['PASS'])
- testCmd('./tigertool.py -r ro', ['PASS'])
- testCmd('./tigertool.py --check_version', ['RW', 'RO', 'PASS'])
-
- print('PASS')
+ testCmd("./tigertool.py --reboot", ["PASS"])
+ testCmd("./tigertool.py --setserialno test", ["PASS"])
+ testCmd("./tigertool.py --check_serial", ["test", "PASS"])
+ testCmd("./tigertool.py -s test --check_serial", ["test", "PASS"])
+ testCmd("./tigertool.py -m A", ["Mux set to A", "PASS"])
+ testCmd("./tigertool.py -m B", ["Mux set to B", "PASS"])
+ testCmd("./tigertool.py -m off", ["Mux set to off", "PASS"])
+ testCmd("./tigertool.py -p", ["PASS"])
+ testCmd("./tigertool.py -r rw", ["PASS"])
+ testCmd("./tigertool.py -r ro", ["PASS"])
+ testCmd("./tigertool.py --check_version", ["RW", "RO", "PASS"])
+
+ print("PASS")
+
def main(argv):
- parser = argparse.ArgumentParser(description=__doc__)
- parser.add_argument('-c', '--count', type=int, default=1,
- help='loops to run')
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument("-c", "--count", type=int, default=1, help="loops to run")
+
+ opts = parser.parse_args(argv)
- opts = parser.parse_args(argv)
+ for i in range(1, opts.count + 1):
+ print("Iteration: %d" % i)
+ test_sequence()
- for i in range(1, opts.count + 1):
- print('Iteration: %d' % i)
- test_sequence()
-if __name__ == '__main__':
+if __name__ == "__main__":
main(sys.argv[1:])
diff --git a/extra/tigertool/tigertool.py b/extra/tigertool/tigertool.py
index 6baae8abdf..37b7b01495 100755
--- a/extra/tigertool/tigertool.py
+++ b/extra/tigertool/tigertool.py
@@ -17,287 +17,308 @@ import time
import ecusb.tiny_servo_common as c
-STM_VIDPID = '18d1:5027'
-serialno = 'Uninitialized'
+STM_VIDPID = "18d1:5027"
+serialno = "Uninitialized"
+
def do_mux(mux, pty):
- """Set mux via ec console 'pty'.
+ """Set mux via ec console 'pty'.
+
+ Args:
+ mux: mux to connect to DUT, 'A', 'B', or 'off'
+ pty: a pty object connected to tigertail
- Args:
- mux: mux to connect to DUT, 'A', 'B', or 'off'
- pty: a pty object connected to tigertail
+ Commands are:
+ # > mux A
+ # TYPE-C mux is A
+ """
+ validmux = ["A", "B", "off"]
+ if mux not in validmux:
+ c.log("Mux setting %s invalid, try one of %s" % (mux, validmux))
+ return False
- Commands are:
- # > mux A
- # TYPE-C mux is A
- """
- validmux = ['A', 'B', 'off']
- if mux not in validmux:
- c.log('Mux setting %s invalid, try one of %s' % (mux, validmux))
- return False
+ cmd = "mux %s" % mux
+ regex = "TYPE\-C mux is ([^\s\r\n]*)\r"
- cmd = 'mux %s' % mux
- regex = 'TYPE\-C mux is ([^\s\r\n]*)\r'
+ results = pty._issue_cmd_get_results(cmd, [regex])[0]
+ result = results[1].strip().strip("\n\r")
- results = pty._issue_cmd_get_results(cmd, [regex])[0]
- result = results[1].strip().strip('\n\r')
+ if result != mux:
+ c.log("Mux set to %s but saved as %s." % (mux, result))
+ return False
+ c.log("Mux set to %s" % result)
+ return True
- if result != mux:
- c.log('Mux set to %s but saved as %s.' % (mux, result))
- return False
- c.log('Mux set to %s' % result)
- return True
def do_version(pty):
- """Check version via ec console 'pty'.
-
- Args:
- pty: a pty object connected to tigertail
-
- Commands are:
- # > version
- # Chip: stm stm32f07x
- # Board: 0
- # RO: tigertail_v1.1.6749-74d1a312e
- # RW: tigertail_v1.1.6749-74d1a312e
- # Build: tigertail_v1.1.6749-74d1a312e
- # 2017-07-25 20:08:34 nsanders@meatball.mtv.corp.google.com
- """
- cmd = 'version'
- regex = r'RO:\s+(\S+)\s+RW:\s+(\S+)\s+Build:\s+(\S+)\s+' \
- r'(\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d) (\S+)'
-
- results = pty._issue_cmd_get_results(cmd, [regex])[0]
- c.log('Version is %s' % results[3])
- c.log('RO: %s' % results[1])
- c.log('RW: %s' % results[2])
- c.log('Date: %s' % results[4])
- c.log('Src: %s' % results[5])
-
- return True
+ """Check version via ec console 'pty'.
+
+ Args:
+ pty: a pty object connected to tigertail
+
+ Commands are:
+ # > version
+ # Chip: stm stm32f07x
+ # Board: 0
+ # RO: tigertail_v1.1.6749-74d1a312e
+ # RW: tigertail_v1.1.6749-74d1a312e
+ # Build: tigertail_v1.1.6749-74d1a312e
+ # 2017-07-25 20:08:34 nsanders@meatball.mtv.corp.google.com
+ """
+ cmd = "version"
+ regex = (
+ r"RO:\s+(\S+)\s+RW:\s+(\S+)\s+Build:\s+(\S+)\s+"
+ r"(\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d) (\S+)"
+ )
+
+ results = pty._issue_cmd_get_results(cmd, [regex])[0]
+ c.log("Version is %s" % results[3])
+ c.log("RO: %s" % results[1])
+ c.log("RW: %s" % results[2])
+ c.log("Date: %s" % results[4])
+ c.log("Src: %s" % results[5])
+
+ return True
+
def do_check_serial(pty):
- """Check serial via ec console 'pty'.
+ """Check serial via ec console 'pty'.
- Args:
- pty: a pty object connected to tigertail
+ Args:
+ pty: a pty object connected to tigertail
- Commands are:
- # > serialno
- # Serial number: number
- """
- cmd = 'serialno'
- regex = r'Serial number: ([^\n\r]+)'
+ Commands are:
+ # > serialno
+ # Serial number: number
+ """
+ cmd = "serialno"
+ regex = r"Serial number: ([^\n\r]+)"
- results = pty._issue_cmd_get_results(cmd, [regex])[0]
- c.log('Serial is %s' % results[1])
+ results = pty._issue_cmd_get_results(cmd, [regex])[0]
+ c.log("Serial is %s" % results[1])
- return True
+ return True
def do_power(count, bus, pty):
- """Check power usage via ec console 'pty'.
-
- Args:
- count: number of samples to capture
- bus: rail to monitor, 'vbus', 'cc1', or 'cc2'
- pty: a pty object connected to tigertail
-
- Commands are:
- # > ina 0
- # Configuration: 4127
- # Shunt voltage: 02c4 => 1770 uV
- # Bus voltage : 1008 => 5130 mV
- # Power : 0019 => 625 mW
- # Current : 0082 => 130 mA
- # Calibration : 0155
- # Mask/Enable : 0008
- # Alert limit : 0000
- """
- if bus == 'vbus':
- ina = 0
- if bus == 'cc1':
- ina = 4
- if bus == 'cc2':
- ina = 1
-
- start = time.time()
-
- c.log('time,\tmV,\tmW,\tmA')
-
- cmd = 'ina %s' % ina
- regex = r'Bus voltage : \S+ \S+ (\d+) mV\s+' \
- r'Power : \S+ \S+ (\d+) mW\s+' \
- r'Current : \S+ \S+ (\d+) mA'
-
- for i in range(0, count):
- results = pty._issue_cmd_get_results(cmd, [regex])[0]
- c.log('%.2f,\t%s,\t%s\t%s' % (
- time.time() - start,
- results[1], results[2], results[3]))
+ """Check power usage via ec console 'pty'.
+
+ Args:
+ count: number of samples to capture
+ bus: rail to monitor, 'vbus', 'cc1', or 'cc2'
+ pty: a pty object connected to tigertail
+
+ Commands are:
+ # > ina 0
+ # Configuration: 4127
+ # Shunt voltage: 02c4 => 1770 uV
+ # Bus voltage : 1008 => 5130 mV
+ # Power : 0019 => 625 mW
+ # Current : 0082 => 130 mA
+ # Calibration : 0155
+ # Mask/Enable : 0008
+ # Alert limit : 0000
+ """
+ if bus == "vbus":
+ ina = 0
+ if bus == "cc1":
+ ina = 4
+ if bus == "cc2":
+ ina = 1
+
+ start = time.time()
+
+ c.log("time,\tmV,\tmW,\tmA")
+
+ cmd = "ina %s" % ina
+ regex = (
+ r"Bus voltage : \S+ \S+ (\d+) mV\s+"
+ r"Power : \S+ \S+ (\d+) mW\s+"
+ r"Current : \S+ \S+ (\d+) mA"
+ )
+
+ for i in range(0, count):
+ results = pty._issue_cmd_get_results(cmd, [regex])[0]
+ c.log(
+ "%.2f,\t%s,\t%s\t%s"
+ % (time.time() - start, results[1], results[2], results[3])
+ )
+
+ return True
- return True
def do_reboot(pty, serialname):
- """Reboot via ec console pty
-
- Args:
- pty: a pty object connected to tigertail
- serialname: serial name, can be None.
-
- Command is: reboot.
- """
- cmd = 'reboot'
-
- # Check usb dev number on current instance.
- devno = c.check_usb_dev(STM_VIDPID, serialname=serialname)
- if not devno:
- c.log('Device not found')
- return False
-
- try:
- pty._issue_cmd(cmd)
- except Exception as e:
- c.log('Failed to send command: ' + str(e))
- return False
-
- try:
- c.wait_for_usb_remove(STM_VIDPID, timeout=3., serialname=serialname)
- except Exception as e:
- # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds
- # it's not going to. This step just goes faster if it's detected.
- pass
-
- try:
- c.wait_for_usb(STM_VIDPID, timeout=3., serialname=serialname)
- except Exception as e:
- c.log('Failed to return from reboot: ' + str(e))
- return False
-
- # Check that the device had a new device number, i.e. it's
- # disconnected and reconnected.
- newdevno = c.check_usb_dev(STM_VIDPID, serialname=serialname)
- if newdevno == devno:
- c.log("Device didn't reboot")
- return False
-
- return True
+ """Reboot via ec console pty
+
+ Args:
+ pty: a pty object connected to tigertail
+ serialname: serial name, can be None.
+
+ Command is: reboot.
+ """
+ cmd = "reboot"
+
+ # Check usb dev number on current instance.
+ devno = c.check_usb_dev(STM_VIDPID, serialname=serialname)
+ if not devno:
+ c.log("Device not found")
+ return False
+
+ try:
+ pty._issue_cmd(cmd)
+ except Exception as e:
+ c.log("Failed to send command: " + str(e))
+ return False
+
+ try:
+ c.wait_for_usb_remove(STM_VIDPID, timeout=3.0, serialname=serialname)
+ except Exception as e:
+ # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds
+ # it's not going to. This step just goes faster if it's detected.
+ pass
+
+ try:
+ c.wait_for_usb(STM_VIDPID, timeout=3.0, serialname=serialname)
+ except Exception as e:
+ c.log("Failed to return from reboot: " + str(e))
+ return False
+
+ # Check that the device had a new device number, i.e. it's
+ # disconnected and reconnected.
+ newdevno = c.check_usb_dev(STM_VIDPID, serialname=serialname)
+ if newdevno == devno:
+ c.log("Device didn't reboot")
+ return False
+
+ return True
+
def do_sysjump(region, pty, serialname):
- """Set region via ec console 'pty'.
-
- Args:
- region: ec code region to execute, 'ro' or 'rw'
- pty: a pty object connected to tigertail
- serialname: serial name, can be None.
-
- Commands are:
- # > sysjump rw
- """
- validregion = ['ro', 'rw']
- if region not in validregion:
- c.log('Region setting %s invalid, try one of %s' % (
- region, validregion))
- return False
-
- cmd = 'sysjump %s' % region
- try:
- pty._issue_cmd(cmd)
- except Exception as e:
- c.log('Exception: ' + str(e))
- return False
-
- try:
- c.wait_for_usb_remove(STM_VIDPID, timeout=3., serialname=serialname)
- except Exception as e:
- # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds
- # it's not going to. This step just goes faster if it's detected.
- pass
-
- try:
- c.wait_for_usb(STM_VIDPID, timeout=3., serialname=serialname)
- except Exception as e:
- c.log('Failed to return from restart: ' + str(e))
- return False
-
- c.log('Region requested %s' % region)
- return True
+ """Set region via ec console 'pty'.
+
+ Args:
+ region: ec code region to execute, 'ro' or 'rw'
+ pty: a pty object connected to tigertail
+ serialname: serial name, can be None.
+
+ Commands are:
+ # > sysjump rw
+ """
+ validregion = ["ro", "rw"]
+ if region not in validregion:
+ c.log("Region setting %s invalid, try one of %s" % (region, validregion))
+ return False
+
+ cmd = "sysjump %s" % region
+ try:
+ pty._issue_cmd(cmd)
+ except Exception as e:
+ c.log("Exception: " + str(e))
+ return False
+
+ try:
+ c.wait_for_usb_remove(STM_VIDPID, timeout=3.0, serialname=serialname)
+ except Exception as e:
+ # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds
+ # it's not going to. This step just goes faster if it's detected.
+ pass
+
+ try:
+ c.wait_for_usb(STM_VIDPID, timeout=3.0, serialname=serialname)
+ except Exception as e:
+ c.log("Failed to return from restart: " + str(e))
+ return False
+
+ c.log("Region requested %s" % region)
+ return True
+
def get_parser():
- parser = argparse.ArgumentParser(
- description=__doc__)
- parser.add_argument('-s', '--serialno', type=str, default=None,
- help='serial number of board to use')
- parser.add_argument('-b', '--bus', type=str, default='vbus',
- help='Which rail to log: [vbus|cc1|cc2]')
- group = parser.add_mutually_exclusive_group()
- group.add_argument('--setserialno', type=str, default=None,
- help='serial number to set on the board.')
- group.add_argument('--check_serial', action='store_true',
- help='check serial number set on the board.')
- group.add_argument('-m', '--mux', type=str, default=None,
- help='mux selection')
- group.add_argument('-p', '--power', action='store_true',
- help='check VBUS')
- group.add_argument('-l', '--powerlog', type=int, default=None,
- help='log VBUS')
- group.add_argument('-r', '--sysjump', type=str, default=None,
- help='region selection')
- group.add_argument('--reboot', action='store_true',
- help='reboot tigertail')
- group.add_argument('--check_version', action='store_true',
- help='check tigertail version')
- return parser
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "-s", "--serialno", type=str, default=None, help="serial number of board to use"
+ )
+ parser.add_argument(
+ "-b",
+ "--bus",
+ type=str,
+ default="vbus",
+ help="Which rail to log: [vbus|cc1|cc2]",
+ )
+ group = parser.add_mutually_exclusive_group()
+ group.add_argument(
+ "--setserialno",
+ type=str,
+ default=None,
+ help="serial number to set on the board.",
+ )
+ group.add_argument(
+ "--check_serial",
+ action="store_true",
+ help="check serial number set on the board.",
+ )
+ group.add_argument("-m", "--mux", type=str, default=None, help="mux selection")
+ group.add_argument("-p", "--power", action="store_true", help="check VBUS")
+ group.add_argument("-l", "--powerlog", type=int, default=None, help="log VBUS")
+ group.add_argument(
+ "-r", "--sysjump", type=str, default=None, help="region selection"
+ )
+ group.add_argument("--reboot", action="store_true", help="reboot tigertail")
+ group.add_argument(
+ "--check_version", action="store_true", help="check tigertail version"
+ )
+ return parser
+
def main(argv):
- parser = get_parser()
- opts = parser.parse_args(argv)
+ parser = get_parser()
+ opts = parser.parse_args(argv)
- result = True
+ result = True
- # Let's make sure there's a tigertail
- # If nothing found in 5 seconds, fail.
- c.wait_for_usb(STM_VIDPID, timeout=5., serialname=opts.serialno)
+ # Let's make sure there's a tigertail
+ # If nothing found in 5 seconds, fail.
+ c.wait_for_usb(STM_VIDPID, timeout=5.0, serialname=opts.serialno)
- pty = c.setup_tinyservod(STM_VIDPID, 0, serialname=opts.serialno)
+ pty = c.setup_tinyservod(STM_VIDPID, 0, serialname=opts.serialno)
- if opts.bus not in ('vbus', 'cc1', 'cc2'):
- c.log('Try --bus [vbus|cc1|cc2]')
- result = False
+ if opts.bus not in ("vbus", "cc1", "cc2"):
+ c.log("Try --bus [vbus|cc1|cc2]")
+ result = False
- elif opts.setserialno:
- try:
- c.do_serialno(opts.setserialno, pty)
- except Exception:
- result = False
+ elif opts.setserialno:
+ try:
+ c.do_serialno(opts.setserialno, pty)
+ except Exception:
+ result = False
- elif opts.mux:
- result &= do_mux(opts.mux, pty)
+ elif opts.mux:
+ result &= do_mux(opts.mux, pty)
- elif opts.sysjump:
- result &= do_sysjump(opts.sysjump, pty, serialname=opts.serialno)
+ elif opts.sysjump:
+ result &= do_sysjump(opts.sysjump, pty, serialname=opts.serialno)
- elif opts.reboot:
- result &= do_reboot(pty, serialname=opts.serialno)
+ elif opts.reboot:
+ result &= do_reboot(pty, serialname=opts.serialno)
- elif opts.check_version:
- result &= do_version(pty)
+ elif opts.check_version:
+ result &= do_version(pty)
- elif opts.check_serial:
- result &= do_check_serial(pty)
+ elif opts.check_serial:
+ result &= do_check_serial(pty)
- elif opts.power:
- result &= do_power(1, opts.bus, pty)
+ elif opts.power:
+ result &= do_power(1, opts.bus, pty)
- elif opts.powerlog:
- result &= do_power(opts.powerlog, opts.bus, pty)
+ elif opts.powerlog:
+ result &= do_power(opts.powerlog, opts.bus, pty)
- if result:
- c.log('PASS')
- else:
- c.log('FAIL')
- sys.exit(-1)
+ if result:
+ c.log("PASS")
+ else:
+ c.log("FAIL")
+ sys.exit(-1)
-if __name__ == '__main__':
- sys.exit(main(sys.argv[1:]))
+if __name__ == "__main__":
+ sys.exit(main(sys.argv[1:]))
diff --git a/extra/usb_power/convert_power_log_board.py b/extra/usb_power/convert_power_log_board.py
index 8aab77ee4c..c1b25f57db 100644
--- a/extra/usb_power/convert_power_log_board.py
+++ b/extra/usb_power/convert_power_log_board.py
@@ -14,6 +14,7 @@ Program to convert sweetberry config to servod config template.
# Note: This is a py2/3 compatible file.
from __future__ import print_function
+
import json
import os
import sys
@@ -48,21 +49,25 @@ def write_to_file(file, sweetberry, inas):
inas: list of inas read from board file.
"""
- with open(file, 'w') as pyfile:
+ with open(file, "w") as pyfile:
- pyfile.write('inas = [\n')
+ pyfile.write("inas = [\n")
for rec in inas:
- if rec['sweetberry'] != sweetberry:
+ if rec["sweetberry"] != sweetberry:
continue
# EX : ('sweetberry', 0x40, 'SB_FW_CAM_2P8', 5.0, 1.000, 3, False),
- channel, i2c_addr = Spower.CHMAP[rec['channel']]
- record = (" ('sweetberry', 0x%02x, '%s', 5.0, %f, %d, 'True')"
- ",\n" % (i2c_addr, rec['name'], rec['rs'], channel))
+ channel, i2c_addr = Spower.CHMAP[rec["channel"]]
+ record = " ('sweetberry', 0x%02x, '%s', 5.0, %f, %d, 'True')" ",\n" % (
+ i2c_addr,
+ rec["name"],
+ rec["rs"],
+ channel,
+ )
pyfile.write(record)
- pyfile.write(']\n')
+ pyfile.write("]\n")
def main(argv):
@@ -76,16 +81,18 @@ def main(argv):
inas = fetch_records(inputf)
- sweetberry = set(rec['sweetberry'] for rec in inas)
+ sweetberry = set(rec["sweetberry"] for rec in inas)
if len(sweetberry) == 2:
- print("Converting %s to %s and %s" % (inputf, basename + '_a.py',
- basename + '_b.py'))
- write_to_file(basename + '_a.py', 'A', inas)
- write_to_file(basename + '_b.py', 'B', inas)
+ print(
+ "Converting %s to %s and %s"
+ % (inputf, basename + "_a.py", basename + "_b.py")
+ )
+ write_to_file(basename + "_a.py", "A", inas)
+ write_to_file(basename + "_b.py", "B", inas)
else:
- print("Converting %s to %s" % (inputf, basename + '.py'))
- write_to_file(basename + '.py', sweetberry.pop(), inas)
+ print("Converting %s to %s" % (inputf, basename + ".py"))
+ write_to_file(basename + ".py", sweetberry.pop(), inas)
if __name__ == "__main__":
diff --git a/extra/usb_power/convert_servo_ina.py b/extra/usb_power/convert_servo_ina.py
index 1c70f31aeb..6ccd474e6c 100755
--- a/extra/usb_power/convert_servo_ina.py
+++ b/extra/usb_power/convert_servo_ina.py
@@ -14,67 +14,71 @@
# Note: This is a py2/3 compatible file.
from __future__ import print_function
+
import os
import sys
def fetch_records(basename):
- """Import records from servo_ina file.
+ """Import records from servo_ina file.
- servo_ina files are python imports, and have a list of tuples with
- the INA data.
- (inatype, i2caddr, rail name, bus voltage, shunt ohms, mux, True)
+ servo_ina files are python imports, and have a list of tuples with
+ the INA data.
+ (inatype, i2caddr, rail name, bus voltage, shunt ohms, mux, True)
- Args:
- basename: python import name (filename -.py)
+ Args:
+ basename: python import name (filename -.py)
- Returns:
- list of tuples as described above.
- """
- ina_desc = __import__(basename)
- return ina_desc.inas
+ Returns:
+ list of tuples as described above.
+ """
+ ina_desc = __import__(basename)
+ return ina_desc.inas
def main(argv):
- if len(argv) != 2:
- print("usage:")
- print(" %s input.py" % argv[0])
- return
+ if len(argv) != 2:
+ print("usage:")
+ print(" %s input.py" % argv[0])
+ return
- inputf = argv[1]
- basename = os.path.splitext(inputf)[0]
- outputf = basename + '.board'
- outputs = basename + '.scenario'
+ inputf = argv[1]
+ basename = os.path.splitext(inputf)[0]
+ outputf = basename + ".board"
+ outputs = basename + ".scenario"
- print("Converting %s to %s, %s" % (inputf, outputf, outputs))
+ print("Converting %s to %s, %s" % (inputf, outputf, outputs))
- inas = fetch_records(basename)
+ inas = fetch_records(basename)
+ boardfile = open(outputf, "w")
+ scenario = open(outputs, "w")
- boardfile = open(outputf, 'w')
- scenario = open(outputs, 'w')
+ boardfile.write("[\n")
+ scenario.write("[\n")
+ start = True
- boardfile.write('[\n')
- scenario.write('[\n')
- start = True
+ for rec in inas:
+ if start:
+ start = False
+ else:
+ boardfile.write(",\n")
+ scenario.write(",\n")
- for rec in inas:
- if start:
- start = False
- else:
- boardfile.write(',\n')
- scenario.write(',\n')
+ record = ' {"name": "%s", "rs": %f, "sweetberry": "A", "channel": %d}' % (
+ rec[2],
+ rec[4],
+ rec[1] - 64,
+ )
+ boardfile.write(record)
+ scenario.write('"%s"' % rec[2])
- record = ' {"name": "%s", "rs": %f, "sweetberry": "A", "channel": %d}' % (
- rec[2], rec[4], rec[1] - 64)
- boardfile.write(record)
- scenario.write('"%s"' % rec[2])
+ boardfile.write("\n")
+ boardfile.write("]")
- boardfile.write('\n')
- boardfile.write(']')
+ scenario.write("\n")
+ scenario.write("]")
- scenario.write('\n')
- scenario.write(']')
if __name__ == "__main__":
- main(sys.argv)
+ main(sys.argv)
diff --git a/extra/usb_power/powerlog.py b/extra/usb_power/powerlog.py
index 82cce3daed..d893e5c6b9 100755
--- a/extra/usb_power/powerlog.py
+++ b/extra/usb_power/powerlog.py
@@ -14,9 +14,9 @@
# Note: This is a py2/3 compatible file.
from __future__ import print_function
+
import argparse
import array
-from distutils import sysconfig
import json
import logging
import os
@@ -25,884 +25,988 @@ import struct
import sys
import time
import traceback
+from distutils import sysconfig
import usb
-
from stats_manager import StatsManager
# Directory where hdctools installs configuration files into.
-LIB_DIR = os.path.join(sysconfig.get_python_lib(standard_lib=False), 'servo',
- 'data')
+LIB_DIR = os.path.join(sysconfig.get_python_lib(standard_lib=False), "servo", "data")
# Potential config file locations: current working directory, the same directory
# as powerlog.py file or LIB_DIR.
-CONFIG_LOCATIONS = [os.getcwd(), os.path.dirname(os.path.realpath(__file__)),
- LIB_DIR]
-
-def logoutput(msg):
- print(msg)
- sys.stdout.flush()
-
-def process_filename(filename):
- """Find the file path from the filename.
-
- If filename is already the complete path, return that directly. If filename is
- just the short name, look for the file in the current working directory, in
- the directory of the current .py file, and then in the directory installed by
- hdctools. If the file is found, return the complete path of the file.
-
- Args:
- filename: complete file path or short file name.
-
- Returns:
- a complete file path.
-
- Raises:
- IOError if filename does not exist.
- """
- # Check if filename is absolute path.
- if os.path.isabs(filename) and os.path.isfile(filename):
- return filename
- # Check if filename is relative to a known config location.
- for dirname in CONFIG_LOCATIONS:
- file_at_dir = os.path.join(dirname, filename)
- if os.path.isfile(file_at_dir):
- return file_at_dir
- raise IOError('No such file or directory: \'%s\'' % filename)
-
-
-class Spower(object):
- """Power class to access devices on the bus.
-
- Usage:
- bus = Spower()
-
- Instance Variables:
- _dev: pyUSB device object
- _read_ep: pyUSB read endpoint for this interface
- _write_ep: pyUSB write endpoint for this interface
- """
-
- # INA interface type.
- INA_POWER = 1
- INA_BUSV = 2
- INA_CURRENT = 3
- INA_SHUNTV = 4
- # INA_SUFFIX is used to differentiate multiple ina types for the same power
- # rail. No suffix for when ina type is 0 (non-existent) and when ina type is 1
- # (power, no suffix for backward compatibility).
- INA_SUFFIX = ['', '', '_busv', '_cur', '_shuntv']
-
- # usb power commands
- CMD_RESET = 0x0000
- CMD_STOP = 0x0001
- CMD_ADDINA = 0x0002
- CMD_START = 0x0003
- CMD_NEXT = 0x0004
- CMD_SETTIME = 0x0005
-
- # Map between header channel number (0-47)
- # and INA I2C bus/addr on sweetberry.
- CHMAP = {
- 0: (3, 0x40),
- 1: (1, 0x40),
- 2: (2, 0x40),
- 3: (0, 0x40),
- 4: (3, 0x41),
- 5: (1, 0x41),
- 6: (2, 0x41),
- 7: (0, 0x41),
- 8: (3, 0x42),
- 9: (1, 0x42),
- 10: (2, 0x42),
- 11: (0, 0x42),
- 12: (3, 0x43),
- 13: (1, 0x43),
- 14: (2, 0x43),
- 15: (0, 0x43),
- 16: (3, 0x44),
- 17: (1, 0x44),
- 18: (2, 0x44),
- 19: (0, 0x44),
- 20: (3, 0x45),
- 21: (1, 0x45),
- 22: (2, 0x45),
- 23: (0, 0x45),
- 24: (3, 0x46),
- 25: (1, 0x46),
- 26: (2, 0x46),
- 27: (0, 0x46),
- 28: (3, 0x47),
- 29: (1, 0x47),
- 30: (2, 0x47),
- 31: (0, 0x47),
- 32: (3, 0x48),
- 33: (1, 0x48),
- 34: (2, 0x48),
- 35: (0, 0x48),
- 36: (3, 0x49),
- 37: (1, 0x49),
- 38: (2, 0x49),
- 39: (0, 0x49),
- 40: (3, 0x4a),
- 41: (1, 0x4a),
- 42: (2, 0x4a),
- 43: (0, 0x4a),
- 44: (3, 0x4b),
- 45: (1, 0x4b),
- 46: (2, 0x4b),
- 47: (0, 0x4b),
- }
-
- def __init__(self, board, vendor=0x18d1,
- product=0x5020, interface=1, serialname=None):
- self._logger = logging.getLogger(__name__)
- self._board = board
-
- # Find the stm32.
- dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True)
- dev_list = list(dev_g)
- if dev_list is None:
- raise Exception("Power", "USB device not found")
-
- # Check if we have multiple stm32s and we've specified the serial.
- dev = None
- if serialname:
- for d in dev_list:
- dev_serial = "PyUSB dioesn't have a stable interface"
- try:
- dev_serial = usb.util.get_string(d, 256, d.iSerialNumber)
- except ValueError:
- # Incompatible pyUsb version.
- dev_serial = usb.util.get_string(d, d.iSerialNumber)
- if dev_serial == serialname:
- dev = d
- break
- if dev is None:
- raise Exception("Power", "USB device(%s) not found" % serialname)
- else:
- try:
- dev = dev_list[0]
- except TypeError:
- # Incompatible pyUsb version.
- dev = dev_list.next()
-
- self._logger.debug("Found USB device: %04x:%04x", vendor, product)
- self._dev = dev
-
- # Get an endpoint instance.
- try:
- dev.set_configuration()
- except usb.USBError:
- pass
- cfg = dev.get_active_configuration()
-
- intf = usb.util.find_descriptor(cfg, custom_match=lambda i: \
- i.bInterfaceClass==255 and i.bInterfaceSubClass==0x54)
-
- self._intf = intf
- self._logger.debug("InterfaceNumber: %s", intf.bInterfaceNumber)
-
- read_ep = usb.util.find_descriptor(
- intf,
- # match the first IN endpoint
- custom_match = \
- lambda e: \
- usb.util.endpoint_direction(e.bEndpointAddress) == \
- usb.util.ENDPOINT_IN
- )
-
- self._read_ep = read_ep
- self._logger.debug("Reader endpoint: 0x%x", read_ep.bEndpointAddress)
-
- write_ep = usb.util.find_descriptor(
- intf,
- # match the first OUT endpoint
- custom_match = \
- lambda e: \
- usb.util.endpoint_direction(e.bEndpointAddress) == \
- usb.util.ENDPOINT_OUT
- )
-
- self._write_ep = write_ep
- self._logger.debug("Writer endpoint: 0x%x", write_ep.bEndpointAddress)
-
- self.clear_ina_struct()
-
- self._logger.debug("Found power logging USB endpoint.")
-
- def clear_ina_struct(self):
- """ Clear INA description struct."""
- self._inas = []
+CONFIG_LOCATIONS = [os.getcwd(), os.path.dirname(os.path.realpath(__file__)), LIB_DIR]
- def append_ina_struct(self, name, rs, port, addr,
- data=None, ina_type=INA_POWER):
- """Add an INA descriptor into the list of active INAs.
- Args:
- name: Readable name of this channel.
- rs: Sense resistor value in ohms, floating point.
- port: I2C channel this INA is connected to.
- addr: I2C addr of this INA.
- data: Misc data for special handling, board specific.
- ina_type: INA function to use, power, voltage, etc.
- """
- ina = {}
- ina['name'] = name
- ina['rs'] = rs
- ina['port'] = port
- ina['addr'] = addr
- ina['type'] = ina_type
- # Calculate INA231 Calibration register
- # (see INA231 spec p.15)
- # CurrentLSB = uA per div = 80mV / (Rsh * 2^15)
- # CurrentLSB uA = 80000000nV / (Rsh mOhm * 0x8000)
- ina['uAscale'] = 80000000. / (rs * 0x8000);
- ina['uWscale'] = 25. * ina['uAscale'];
- ina['mVscale'] = 1.25
- ina['uVscale'] = 2.5
- ina['data'] = data
- self._inas.append(ina)
-
- def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=1000):
- """Write command to logger logic.
-
- This function writes byte command values list to stm, then reads
- byte status.
-
- Args:
- write_list: list of command byte values [0~255].
- read_count: number of status byte values to read.
-
- Interface:
- write: [command, data ... ]
- read: [status ]
-
- Returns:
- bytes read, or None on failure.
- """
- self._logger.debug("Spower.wr_command(write_list=[%s] (%d), read_count=%s)",
- list(bytearray(write_list)), len(write_list), read_count)
-
- # Clean up args from python style to correct types.
- write_length = 0
- if write_list:
- write_length = len(write_list)
- if not read_count:
- read_count = 0
-
- # Send command to stm32.
- if write_list:
- cmd = write_list
- ret = self._write_ep.write(cmd, wtimeout)
-
- self._logger.debug("RET: %s ", ret)
-
- # Read back response if necessary.
- if read_count:
- bytesread = self._read_ep.read(512, rtimeout)
- self._logger.debug("BYTES: [%s]", bytesread)
-
- if len(bytesread) != read_count:
- pass
-
- self._logger.debug("STATUS: 0x%02x", int(bytesread[0]))
- if read_count == 1:
- return bytesread[0]
- else:
- return bytesread
-
- return None
-
- def clear(self):
- """Clear pending reads on the stm32"""
- try:
- while True:
- ret = self.wr_command(b"", read_count=512, rtimeout=100, wtimeout=50)
- self._logger.debug("Try Clear: read %s",
- "success" if ret == 0 else "failure")
- except:
- pass
-
- def send_reset(self):
- """Reset the power interface on the stm32"""
- cmd = struct.pack("<H", self.CMD_RESET)
- ret = self.wr_command(cmd, rtimeout=50, wtimeout=50)
- self._logger.debug("Command RESET: %s",
- "success" if ret == 0 else "failure")
-
- def reset(self):
- """Try resetting the USB interface until success.
-
- Use linear back off strategy when encounter the error with 10ms increment.
-
- Raises:
- Exception on failure.
- """
- max_reset_retry = 100
- for count in range(1, max_reset_retry + 1):
- self.clear()
- try:
- self.send_reset()
- return
- except Exception as e:
- self.clear()
- self.clear()
- self._logger.debug("TRY %d of %d: %s", count, max_reset_retry, e)
- time.sleep(count * 0.01)
- raise Exception("Power", "Failed to reset")
-
- def stop(self):
- """Stop any active data acquisition."""
- cmd = struct.pack("<H", self.CMD_STOP)
- ret = self.wr_command(cmd)
- self._logger.debug("Command STOP: %s",
- "success" if ret == 0 else "failure")
-
- def start(self, integration_us):
- """Start data acquisition.
-
- Args:
- integration_us: int, how many us between samples, and
- how often the data block must be read.
+def logoutput(msg):
+ print(msg)
+ sys.stdout.flush()
- Returns:
- actual sampling interval in ms.
- """
- cmd = struct.pack("<HI", self.CMD_START, integration_us)
- read = self.wr_command(cmd, read_count=5)
- actual_us = 0
- if len(read) == 5:
- ret, actual_us = struct.unpack("<BI", read)
- self._logger.debug("Command START: %s %dus",
- "success" if ret == 0 else "failure", actual_us)
- else:
- self._logger.debug("Command START: FAIL")
- return actual_us
+def process_filename(filename):
+ """Find the file path from the filename.
- def add_ina_name(self, name_tuple):
- """Add INA from board config.
+ If filename is already the complete path, return that directly. If filename is
+ just the short name, look for the file in the current working directory, in
+ the directory of the current .py file, and then in the directory installed by
+ hdctools. If the file is found, return the complete path of the file.
Args:
- name_tuple: name and type of power rail in board config.
+ filename: complete file path or short file name.
Returns:
- True if INA added, False if the INA is not on this board.
+ a complete file path.
Raises:
- Exception on unexpected failure.
+ IOError if filename does not exist.
"""
- name, ina_type = name_tuple
-
- for datum in self._brdcfg:
- if datum["name"] == name:
- rs = int(float(datum["rs"]) * 1000.)
- board = datum["sweetberry"]
-
- if board == self._board:
- if 'port' in datum and 'addr' in datum:
- port = datum['port']
- addr = datum['addr']
- else:
- channel = int(datum["channel"])
- port, addr = self.CHMAP[channel]
- self.add_ina(port, ina_type, addr, 0, rs, data=datum)
- return True
- else:
- return False
- raise Exception("Power", "Failed to find INA %s" % name)
-
- def set_time(self, timestamp_us):
- """Set sweetberry time to match host time.
-
- Args:
- timestamp_us: host timestmap in us.
- """
- # 0x0005 , 8 byte timestamp
- cmd = struct.pack("<HQ", self.CMD_SETTIME, timestamp_us)
- ret = self.wr_command(cmd)
+ # Check if filename is absolute path.
+ if os.path.isabs(filename) and os.path.isfile(filename):
+ return filename
+ # Check if filename is relative to a known config location.
+ for dirname in CONFIG_LOCATIONS:
+ file_at_dir = os.path.join(dirname, filename)
+ if os.path.isfile(file_at_dir):
+ return file_at_dir
+ raise IOError("No such file or directory: '%s'" % filename)
- self._logger.debug("Command SETTIME: %s",
- "success" if ret == 0 else "failure")
- def add_ina(self, bus, ina_type, addr, extra, resistance, data=None):
- """Add an INA to the data acquisition list.
+class Spower(object):
+ """Power class to access devices on the bus.
- Args:
- bus: which i2c bus the INA is on. Same ordering as Si2c.
- ina_type: Ina interface: INA_POWER/BUSV/etc.
- addr: 7 bit i2c addr of this INA
- extra: extra data for nonstandard configs.
- resistance: int, shunt resistance in mOhm
- """
- # 0x0002, 1B: bus, 1B:INA type, 1B: INA addr, 1B: extra, 4B: Rs
- cmd = struct.pack("<HBBBBI", self.CMD_ADDINA,
- bus, ina_type, addr, extra, resistance)
- ret = self.wr_command(cmd)
- if ret == 0:
- if data:
- name = data['name']
- else:
- name = "ina%d_%02x" % (bus, addr)
- self.append_ina_struct(name, resistance, bus, addr,
- data=data, ina_type=ina_type)
- self._logger.debug("Command ADD_INA: %s",
- "success" if ret == 0 else "failure")
-
- def report_header_size(self):
- """Helper function to calculate power record header size."""
- result = 2
- timestamp = 8
- return result + timestamp
-
- def report_size(self, ina_count):
- """Helper function to calculate full power record size."""
- record = 2
-
- datasize = self.report_header_size() + ina_count * record
- # Round to multiple of 4 bytes.
- datasize = int(((datasize + 3) // 4) * 4)
-
- return datasize
-
- def read_line(self):
- """Read a line of data from the setup INAs
+ Usage:
+ bus = Spower()
- Returns:
- list of dicts of the values read by ina/type tuple, otherwise None.
- [{ts:100, (vbat, power):450}, {ts:200, (vbat, power):440}]
+ Instance Variables:
+ _dev: pyUSB device object
+ _read_ep: pyUSB read endpoint for this interface
+ _write_ep: pyUSB write endpoint for this interface
"""
- try:
- expected_bytes = self.report_size(len(self._inas))
- cmd = struct.pack("<H", self.CMD_NEXT)
- bytesread = self.wr_command(cmd, read_count=expected_bytes)
- except usb.core.USBError as e:
- self._logger.error("READ LINE FAILED %s", e)
- return None
-
- if len(bytesread) == 1:
- if bytesread[0] != 0x6:
- self._logger.debug("READ LINE FAILED bytes: %d ret: %02x",
- len(bytesread), bytesread[0])
- return None
-
- if len(bytesread) % expected_bytes != 0:
- self._logger.debug("READ LINE WARNING: expected %d, got %d",
- expected_bytes, len(bytesread))
-
- packet_count = len(bytesread) // expected_bytes
-
- values = []
- for i in range(0, packet_count):
- start = i * expected_bytes
- end = (i + 1) * expected_bytes
- record = self.interpret_line(bytesread[start:end])
- values.append(record)
-
- return values
-
- def interpret_line(self, data):
- """Interpret a power record from INAs
- Args:
- data: one single record of bytes.
+ # INA interface type.
+ INA_POWER = 1
+ INA_BUSV = 2
+ INA_CURRENT = 3
+ INA_SHUNTV = 4
+ # INA_SUFFIX is used to differentiate multiple ina types for the same power
+ # rail. No suffix for when ina type is 0 (non-existent) and when ina type is 1
+ # (power, no suffix for backward compatibility).
+ INA_SUFFIX = ["", "", "_busv", "_cur", "_shuntv"]
+
+ # usb power commands
+ CMD_RESET = 0x0000
+ CMD_STOP = 0x0001
+ CMD_ADDINA = 0x0002
+ CMD_START = 0x0003
+ CMD_NEXT = 0x0004
+ CMD_SETTIME = 0x0005
+
+ # Map between header channel number (0-47)
+ # and INA I2C bus/addr on sweetberry.
+ CHMAP = {
+ 0: (3, 0x40),
+ 1: (1, 0x40),
+ 2: (2, 0x40),
+ 3: (0, 0x40),
+ 4: (3, 0x41),
+ 5: (1, 0x41),
+ 6: (2, 0x41),
+ 7: (0, 0x41),
+ 8: (3, 0x42),
+ 9: (1, 0x42),
+ 10: (2, 0x42),
+ 11: (0, 0x42),
+ 12: (3, 0x43),
+ 13: (1, 0x43),
+ 14: (2, 0x43),
+ 15: (0, 0x43),
+ 16: (3, 0x44),
+ 17: (1, 0x44),
+ 18: (2, 0x44),
+ 19: (0, 0x44),
+ 20: (3, 0x45),
+ 21: (1, 0x45),
+ 22: (2, 0x45),
+ 23: (0, 0x45),
+ 24: (3, 0x46),
+ 25: (1, 0x46),
+ 26: (2, 0x46),
+ 27: (0, 0x46),
+ 28: (3, 0x47),
+ 29: (1, 0x47),
+ 30: (2, 0x47),
+ 31: (0, 0x47),
+ 32: (3, 0x48),
+ 33: (1, 0x48),
+ 34: (2, 0x48),
+ 35: (0, 0x48),
+ 36: (3, 0x49),
+ 37: (1, 0x49),
+ 38: (2, 0x49),
+ 39: (0, 0x49),
+ 40: (3, 0x4A),
+ 41: (1, 0x4A),
+ 42: (2, 0x4A),
+ 43: (0, 0x4A),
+ 44: (3, 0x4B),
+ 45: (1, 0x4B),
+ 46: (2, 0x4B),
+ 47: (0, 0x4B),
+ }
+
+ def __init__(
+ self, board, vendor=0x18D1, product=0x5020, interface=1, serialname=None
+ ):
+ self._logger = logging.getLogger(__name__)
+ self._board = board
+
+ # Find the stm32.
+ dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True)
+ dev_list = list(dev_g)
+ if dev_list is None:
+ raise Exception("Power", "USB device not found")
+
+ # Check if we have multiple stm32s and we've specified the serial.
+ dev = None
+ if serialname:
+ for d in dev_list:
+ dev_serial = "PyUSB dioesn't have a stable interface"
+ try:
+ dev_serial = usb.util.get_string(d, 256, d.iSerialNumber)
+ except ValueError:
+ # Incompatible pyUsb version.
+ dev_serial = usb.util.get_string(d, d.iSerialNumber)
+ if dev_serial == serialname:
+ dev = d
+ break
+ if dev is None:
+ raise Exception("Power", "USB device(%s) not found" % serialname)
+ else:
+ try:
+ dev = dev_list[0]
+ except TypeError:
+ # Incompatible pyUsb version.
+ dev = dev_list.next()
- Output:
- stdout of the record in csv format.
+ self._logger.debug("Found USB device: %04x:%04x", vendor, product)
+ self._dev = dev
- Returns:
- dict containing name, value of recorded data.
- """
- status, size = struct.unpack("<BB", data[0:2])
- if len(data) != self.report_size(size):
- self._logger.error("READ LINE FAILED st:%d size:%d expected:%d len:%d",
- status, size, self.report_size(size), len(data))
- else:
- pass
+ # Get an endpoint instance.
+ try:
+ dev.set_configuration()
+ except usb.USBError:
+ pass
+ cfg = dev.get_active_configuration()
+
+ intf = usb.util.find_descriptor(
+ cfg,
+ custom_match=lambda i: i.bInterfaceClass == 255
+ and i.bInterfaceSubClass == 0x54,
+ )
+
+ self._intf = intf
+ self._logger.debug("InterfaceNumber: %s", intf.bInterfaceNumber)
+
+ read_ep = usb.util.find_descriptor(
+ intf,
+ # match the first IN endpoint
+ custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress)
+ == usb.util.ENDPOINT_IN,
+ )
+
+ self._read_ep = read_ep
+ self._logger.debug("Reader endpoint: 0x%x", read_ep.bEndpointAddress)
+
+ write_ep = usb.util.find_descriptor(
+ intf,
+ # match the first OUT endpoint
+ custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress)
+ == usb.util.ENDPOINT_OUT,
+ )
+
+ self._write_ep = write_ep
+ self._logger.debug("Writer endpoint: 0x%x", write_ep.bEndpointAddress)
+
+ self.clear_ina_struct()
+
+ self._logger.debug("Found power logging USB endpoint.")
+
+ def clear_ina_struct(self):
+ """Clear INA description struct."""
+ self._inas = []
+
+ def append_ina_struct(self, name, rs, port, addr, data=None, ina_type=INA_POWER):
+ """Add an INA descriptor into the list of active INAs.
+
+ Args:
+ name: Readable name of this channel.
+ rs: Sense resistor value in ohms, floating point.
+ port: I2C channel this INA is connected to.
+ addr: I2C addr of this INA.
+ data: Misc data for special handling, board specific.
+ ina_type: INA function to use, power, voltage, etc.
+ """
+ ina = {}
+ ina["name"] = name
+ ina["rs"] = rs
+ ina["port"] = port
+ ina["addr"] = addr
+ ina["type"] = ina_type
+ # Calculate INA231 Calibration register
+ # (see INA231 spec p.15)
+ # CurrentLSB = uA per div = 80mV / (Rsh * 2^15)
+ # CurrentLSB uA = 80000000nV / (Rsh mOhm * 0x8000)
+ ina["uAscale"] = 80000000.0 / (rs * 0x8000)
+ ina["uWscale"] = 25.0 * ina["uAscale"]
+ ina["mVscale"] = 1.25
+ ina["uVscale"] = 2.5
+ ina["data"] = data
+ self._inas.append(ina)
+
+ def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=1000):
+ """Write command to logger logic.
+
+ This function writes byte command values list to stm, then reads
+ byte status.
+
+ Args:
+ write_list: list of command byte values [0~255].
+ read_count: number of status byte values to read.
+
+ Interface:
+ write: [command, data ... ]
+ read: [status ]
+
+ Returns:
+ bytes read, or None on failure.
+ """
+ self._logger.debug(
+ "Spower.wr_command(write_list=[%s] (%d), read_count=%s)",
+ list(bytearray(write_list)),
+ len(write_list),
+ read_count,
+ )
+
+ # Clean up args from python style to correct types.
+ write_length = 0
+ if write_list:
+ write_length = len(write_list)
+ if not read_count:
+ read_count = 0
+
+ # Send command to stm32.
+ if write_list:
+ cmd = write_list
+ ret = self._write_ep.write(cmd, wtimeout)
+
+ self._logger.debug("RET: %s ", ret)
+
+ # Read back response if necessary.
+ if read_count:
+ bytesread = self._read_ep.read(512, rtimeout)
+ self._logger.debug("BYTES: [%s]", bytesread)
+
+ if len(bytesread) != read_count:
+ pass
+
+ self._logger.debug("STATUS: 0x%02x", int(bytesread[0]))
+ if read_count == 1:
+ return bytesread[0]
+ else:
+ return bytesread
+
+ return None
+
+ def clear(self):
+ """Clear pending reads on the stm32"""
+ try:
+ while True:
+ ret = self.wr_command(b"", read_count=512, rtimeout=100, wtimeout=50)
+ self._logger.debug(
+ "Try Clear: read %s", "success" if ret == 0 else "failure"
+ )
+ except:
+ pass
+
+ def send_reset(self):
+ """Reset the power interface on the stm32"""
+ cmd = struct.pack("<H", self.CMD_RESET)
+ ret = self.wr_command(cmd, rtimeout=50, wtimeout=50)
+ self._logger.debug("Command RESET: %s", "success" if ret == 0 else "failure")
+
+ def reset(self):
+ """Try resetting the USB interface until success.
+
+ Use linear back off strategy when encounter the error with 10ms increment.
+
+ Raises:
+ Exception on failure.
+ """
+ max_reset_retry = 100
+ for count in range(1, max_reset_retry + 1):
+ self.clear()
+ try:
+ self.send_reset()
+ return
+ except Exception as e:
+ self.clear()
+ self.clear()
+ self._logger.debug("TRY %d of %d: %s", count, max_reset_retry, e)
+ time.sleep(count * 0.01)
+ raise Exception("Power", "Failed to reset")
+
+ def stop(self):
+ """Stop any active data acquisition."""
+ cmd = struct.pack("<H", self.CMD_STOP)
+ ret = self.wr_command(cmd)
+ self._logger.debug("Command STOP: %s", "success" if ret == 0 else "failure")
+
+ def start(self, integration_us):
+ """Start data acquisition.
+
+ Args:
+ integration_us: int, how many us between samples, and
+ how often the data block must be read.
+
+ Returns:
+ actual sampling interval in ms.
+ """
+ cmd = struct.pack("<HI", self.CMD_START, integration_us)
+ read = self.wr_command(cmd, read_count=5)
+ actual_us = 0
+ if len(read) == 5:
+ ret, actual_us = struct.unpack("<BI", read)
+ self._logger.debug(
+ "Command START: %s %dus",
+ "success" if ret == 0 else "failure",
+ actual_us,
+ )
+ else:
+ self._logger.debug("Command START: FAIL")
+
+ return actual_us
+
+ def add_ina_name(self, name_tuple):
+ """Add INA from board config.
+
+ Args:
+ name_tuple: name and type of power rail in board config.
+
+ Returns:
+ True if INA added, False if the INA is not on this board.
+
+ Raises:
+ Exception on unexpected failure.
+ """
+ name, ina_type = name_tuple
+
+ for datum in self._brdcfg:
+ if datum["name"] == name:
+ rs = int(float(datum["rs"]) * 1000.0)
+ board = datum["sweetberry"]
+
+ if board == self._board:
+ if "port" in datum and "addr" in datum:
+ port = datum["port"]
+ addr = datum["addr"]
+ else:
+ channel = int(datum["channel"])
+ port, addr = self.CHMAP[channel]
+ self.add_ina(port, ina_type, addr, 0, rs, data=datum)
+ return True
+ else:
+ return False
+ raise Exception("Power", "Failed to find INA %s" % name)
+
+ def set_time(self, timestamp_us):
+ """Set sweetberry time to match host time.
+
+ Args:
+ timestamp_us: host timestmap in us.
+ """
+ # 0x0005 , 8 byte timestamp
+ cmd = struct.pack("<HQ", self.CMD_SETTIME, timestamp_us)
+ ret = self.wr_command(cmd)
+
+ self._logger.debug("Command SETTIME: %s", "success" if ret == 0 else "failure")
+
+ def add_ina(self, bus, ina_type, addr, extra, resistance, data=None):
+ """Add an INA to the data acquisition list.
+
+ Args:
+ bus: which i2c bus the INA is on. Same ordering as Si2c.
+ ina_type: Ina interface: INA_POWER/BUSV/etc.
+ addr: 7 bit i2c addr of this INA
+ extra: extra data for nonstandard configs.
+ resistance: int, shunt resistance in mOhm
+ """
+ # 0x0002, 1B: bus, 1B:INA type, 1B: INA addr, 1B: extra, 4B: Rs
+ cmd = struct.pack(
+ "<HBBBBI", self.CMD_ADDINA, bus, ina_type, addr, extra, resistance
+ )
+ ret = self.wr_command(cmd)
+ if ret == 0:
+ if data:
+ name = data["name"]
+ else:
+ name = "ina%d_%02x" % (bus, addr)
+ self.append_ina_struct(
+ name, resistance, bus, addr, data=data, ina_type=ina_type
+ )
+ self._logger.debug("Command ADD_INA: %s", "success" if ret == 0 else "failure")
+
+ def report_header_size(self):
+ """Helper function to calculate power record header size."""
+ result = 2
+ timestamp = 8
+ return result + timestamp
+
+ def report_size(self, ina_count):
+ """Helper function to calculate full power record size."""
+ record = 2
+
+ datasize = self.report_header_size() + ina_count * record
+ # Round to multiple of 4 bytes.
+ datasize = int(((datasize + 3) // 4) * 4)
+
+ return datasize
+
+ def read_line(self):
+ """Read a line of data from the setup INAs
+
+ Returns:
+ list of dicts of the values read by ina/type tuple, otherwise None.
+ [{ts:100, (vbat, power):450}, {ts:200, (vbat, power):440}]
+ """
+ try:
+ expected_bytes = self.report_size(len(self._inas))
+ cmd = struct.pack("<H", self.CMD_NEXT)
+ bytesread = self.wr_command(cmd, read_count=expected_bytes)
+ except usb.core.USBError as e:
+ self._logger.error("READ LINE FAILED %s", e)
+ return None
+
+ if len(bytesread) == 1:
+ if bytesread[0] != 0x6:
+ self._logger.debug(
+ "READ LINE FAILED bytes: %d ret: %02x", len(bytesread), bytesread[0]
+ )
+ return None
+
+ if len(bytesread) % expected_bytes != 0:
+ self._logger.debug(
+ "READ LINE WARNING: expected %d, got %d", expected_bytes, len(bytesread)
+ )
+
+ packet_count = len(bytesread) // expected_bytes
+
+ values = []
+ for i in range(0, packet_count):
+ start = i * expected_bytes
+ end = (i + 1) * expected_bytes
+ record = self.interpret_line(bytesread[start:end])
+ values.append(record)
+
+ return values
+
+ def interpret_line(self, data):
+ """Interpret a power record from INAs
+
+ Args:
+ data: one single record of bytes.
+
+ Output:
+ stdout of the record in csv format.
+
+ Returns:
+ dict containing name, value of recorded data.
+ """
+ status, size = struct.unpack("<BB", data[0:2])
+ if len(data) != self.report_size(size):
+ self._logger.error(
+ "READ LINE FAILED st:%d size:%d expected:%d len:%d",
+ status,
+ size,
+ self.report_size(size),
+ len(data),
+ )
+ else:
+ pass
- timestamp = struct.unpack("<Q", data[2:10])[0]
- self._logger.debug("READ LINE: st:%d size:%d time:%dus", status, size,
- timestamp)
- ftimestamp = float(timestamp) / 1000000.
+ timestamp = struct.unpack("<Q", data[2:10])[0]
+ self._logger.debug(
+ "READ LINE: st:%d size:%d time:%dus", status, size, timestamp
+ )
+ ftimestamp = float(timestamp) / 1000000.0
- record = {"ts": ftimestamp, "status": status, "berry":self._board}
+ record = {"ts": ftimestamp, "status": status, "berry": self._board}
- for i in range(0, size):
- idx = self.report_header_size() + 2*i
- name = self._inas[i]['name']
- name_tuple = (self._inas[i]['name'], self._inas[i]['type'])
+ for i in range(0, size):
+ idx = self.report_header_size() + 2 * i
+ name = self._inas[i]["name"]
+ name_tuple = (self._inas[i]["name"], self._inas[i]["type"])
- raw_val = struct.unpack("<h", data[idx:idx+2])[0]
+ raw_val = struct.unpack("<h", data[idx : idx + 2])[0]
- if self._inas[i]['type'] == Spower.INA_POWER:
- val = raw_val * self._inas[i]['uWscale']
- elif self._inas[i]['type'] == Spower.INA_BUSV:
- val = raw_val * self._inas[i]['mVscale']
- elif self._inas[i]['type'] == Spower.INA_CURRENT:
- val = raw_val * self._inas[i]['uAscale']
- elif self._inas[i]['type'] == Spower.INA_SHUNTV:
- val = raw_val * self._inas[i]['uVscale']
+ if self._inas[i]["type"] == Spower.INA_POWER:
+ val = raw_val * self._inas[i]["uWscale"]
+ elif self._inas[i]["type"] == Spower.INA_BUSV:
+ val = raw_val * self._inas[i]["mVscale"]
+ elif self._inas[i]["type"] == Spower.INA_CURRENT:
+ val = raw_val * self._inas[i]["uAscale"]
+ elif self._inas[i]["type"] == Spower.INA_SHUNTV:
+ val = raw_val * self._inas[i]["uVscale"]
- self._logger.debug("READ %d %s: %fs: 0x%04x %f", i, name, ftimestamp,
- raw_val, val)
- record[name_tuple] = val
+ self._logger.debug(
+ "READ %d %s: %fs: 0x%04x %f", i, name, ftimestamp, raw_val, val
+ )
+ record[name_tuple] = val
- return record
+ return record
- def load_board(self, brdfile):
- """Load a board config.
+ def load_board(self, brdfile):
+ """Load a board config.
- Args:
- brdfile: Filename of a json file decribing the INA wiring of this board.
- """
- with open(process_filename(brdfile)) as data_file:
- data = json.load(data_file)
+ Args:
+ brdfile: Filename of a json file decribing the INA wiring of this board.
+ """
+ with open(process_filename(brdfile)) as data_file:
+ data = json.load(data_file)
- #TODO: validate this.
- self._brdcfg = data;
- self._logger.debug(pprint.pformat(data))
+ # TODO: validate this.
+ self._brdcfg = data
+ self._logger.debug(pprint.pformat(data))
class powerlog(object):
- """Power class to log aggregated power.
-
- Usage:
- obj = powerlog()
+ """Power class to log aggregated power.
- Instance Variables:
- _data: a StatsManager object that records sweetberry readings and calculates
- statistics.
- _pwr[]: Spower objects for individual sweetberries.
- """
-
- def __init__(self, brdfile, cfgfile, serial_a=None, serial_b=None,
- sync_date=False, use_ms=False, use_mW=False, print_stats=False,
- stats_dir=None, stats_json_dir=None, print_raw_data=True,
- raw_data_dir=None):
- """Init the powerlog class and set the variables.
-
- Args:
- brdfile: string name of json file containing board layout.
- cfgfile: string name of json containing list of rails to read.
- serial_a: serial number of sweetberry A.
- serial_b: serial number of sweetberry B.
- sync_date: report timestamps synced with host datetime.
- use_ms: report timestamps in ms rather than us.
- use_mW: report power as milliwatts, otherwise default to microwatts.
- print_stats: print statistics for sweetberry readings at the end.
- stats_dir: directory to save sweetberry readings statistics; if None then
- do not save the statistics.
- stats_json_dir: directory to save means of sweetberry readings in json
- format; if None then do not save the statistics.
- print_raw_data: print sweetberry readings raw data in real time, default
- is to print.
- raw_data_dir: directory to save sweetberry readings raw data; if None then
- do not save the raw data.
- """
- self._logger = logging.getLogger(__name__)
- self._data = StatsManager()
- self._pwr = {}
- self._use_ms = use_ms
- self._use_mW = use_mW
- self._print_stats = print_stats
- self._stats_dir = stats_dir
- self._stats_json_dir = stats_json_dir
- self._print_raw_data = print_raw_data
- self._raw_data_dir = raw_data_dir
-
- if not serial_a and not serial_b:
- self._pwr['A'] = Spower('A')
- if serial_a:
- self._pwr['A'] = Spower('A', serialname=serial_a)
- if serial_b:
- self._pwr['B'] = Spower('B', serialname=serial_b)
-
- with open(process_filename(cfgfile)) as data_file:
- names = json.load(data_file)
- self._names = self.process_scenario(names)
-
- for key in self._pwr:
- self._pwr[key].load_board(brdfile)
- self._pwr[key].reset()
-
- # Allocate the rails to the appropriate boards.
- used_boards = []
- for name in self._names:
- success = False
- for key in self._pwr.keys():
- if self._pwr[key].add_ina_name(name):
- success = True
- if key not in used_boards:
- used_boards.append(key)
- if not success:
- raise Exception("Failed to add %s (maybe missing "
- "sweetberry, or bad board file?)" % name)
-
- # Evict unused boards.
- for key in list(self._pwr.keys()):
- if key not in used_boards:
- self._pwr.pop(key)
-
- for key in self._pwr.keys():
- if sync_date:
- self._pwr[key].set_time(time.time() * 1000000)
- else:
- self._pwr[key].set_time(0)
-
- def process_scenario(self, name_list):
- """Return list of tuples indicating name and type.
+ Usage:
+ obj = powerlog()
- Args:
- json originated list of names, or [name, type]
- Returns:
- list of tuples of (name, type) defaulting to type "POWER"
- Raises: exception, invalid INA type.
+ Instance Variables:
+ _data: a StatsManager object that records sweetberry readings and calculates
+ statistics.
+ _pwr[]: Spower objects for individual sweetberries.
"""
- names = []
- for entry in name_list:
- if isinstance(entry, list):
- name = entry[0]
- if entry[1] == "POWER":
- type = Spower.INA_POWER
- elif entry[1] == "BUSV":
- type = Spower.INA_BUSV
- elif entry[1] == "CURRENT":
- type = Spower.INA_CURRENT
- elif entry[1] == "SHUNTV":
- type = Spower.INA_SHUNTV
- else:
- raise Exception("Invalid INA type", "Type of %s [%s] not recognized,"
- " try one of POWER, BUSV, CURRENT" % (entry[0], entry[1]))
- else:
- name = entry
- type = Spower.INA_POWER
-
- names.append((name, type))
- return names
- def start(self, integration_us_request, seconds, sync_speed=.8):
- """Starts sampling.
+ def __init__(
+ self,
+ brdfile,
+ cfgfile,
+ serial_a=None,
+ serial_b=None,
+ sync_date=False,
+ use_ms=False,
+ use_mW=False,
+ print_stats=False,
+ stats_dir=None,
+ stats_json_dir=None,
+ print_raw_data=True,
+ raw_data_dir=None,
+ ):
+ """Init the powerlog class and set the variables.
+
+ Args:
+ brdfile: string name of json file containing board layout.
+ cfgfile: string name of json containing list of rails to read.
+ serial_a: serial number of sweetberry A.
+ serial_b: serial number of sweetberry B.
+ sync_date: report timestamps synced with host datetime.
+ use_ms: report timestamps in ms rather than us.
+ use_mW: report power as milliwatts, otherwise default to microwatts.
+ print_stats: print statistics for sweetberry readings at the end.
+ stats_dir: directory to save sweetberry readings statistics; if None then
+ do not save the statistics.
+ stats_json_dir: directory to save means of sweetberry readings in json
+ format; if None then do not save the statistics.
+ print_raw_data: print sweetberry readings raw data in real time, default
+ is to print.
+ raw_data_dir: directory to save sweetberry readings raw data; if None then
+ do not save the raw data.
+ """
+ self._logger = logging.getLogger(__name__)
+ self._data = StatsManager()
+ self._pwr = {}
+ self._use_ms = use_ms
+ self._use_mW = use_mW
+ self._print_stats = print_stats
+ self._stats_dir = stats_dir
+ self._stats_json_dir = stats_json_dir
+ self._print_raw_data = print_raw_data
+ self._raw_data_dir = raw_data_dir
+
+ if not serial_a and not serial_b:
+ self._pwr["A"] = Spower("A")
+ if serial_a:
+ self._pwr["A"] = Spower("A", serialname=serial_a)
+ if serial_b:
+ self._pwr["B"] = Spower("B", serialname=serial_b)
+
+ with open(process_filename(cfgfile)) as data_file:
+ names = json.load(data_file)
+ self._names = self.process_scenario(names)
- Args:
- integration_us_request: requested interval between sample values.
- seconds: time until exit, or None to run until cancel.
- sync_speed: A usb request is sent every [.8] * integration_us.
- """
- # We will get back the actual integration us.
- # It should be the same for all devices.
- integration_us = None
- for key in self._pwr:
- integration_us_new = self._pwr[key].start(integration_us_request)
- if integration_us:
- if integration_us != integration_us_new:
- raise Exception("FAIL",
- "Integration on A: %dus != integration on B %dus" % (
- integration_us, integration_us_new))
- integration_us = integration_us_new
-
- # CSV header
- title = "ts:%dus" % integration_us
- for name_tuple in self._names:
- name, ina_type = name_tuple
-
- if ina_type == Spower.INA_POWER:
- unit = "mW" if self._use_mW else "uW"
- elif ina_type == Spower.INA_BUSV:
- unit = "mV"
- elif ina_type == Spower.INA_CURRENT:
- unit = "uA"
- elif ina_type == Spower.INA_SHUNTV:
- unit = "uV"
-
- title += ", %s %s" % (name, unit)
- name_type = name + Spower.INA_SUFFIX[ina_type]
- self._data.SetUnit(name_type, unit)
- title += ", status"
- if self._print_raw_data:
- logoutput(title)
-
- forever = False
- if not seconds:
- forever = True
- end_time = time.time() + seconds
- try:
- pending_records = []
- while forever or end_time > time.time():
- if (integration_us > 5000):
- time.sleep((integration_us / 1000000.) * sync_speed)
for key in self._pwr:
- records = self._pwr[key].read_line()
- if not records:
- continue
-
- for record in records:
- pending_records.append(record)
-
- pending_records.sort(key=lambda r: r['ts'])
-
- aggregate_record = {"boards": set()}
- for record in pending_records:
- if record["berry"] not in aggregate_record["boards"]:
- for rkey in record.keys():
- aggregate_record[rkey] = record[rkey]
- aggregate_record["boards"].add(record["berry"])
- else:
- self._logger.info("break %s, %s", record["berry"],
- aggregate_record["boards"])
- break
-
- if aggregate_record["boards"] == set(self._pwr.keys()):
- csv = "%f" % aggregate_record["ts"]
- for name in self._names:
- if name in aggregate_record:
- multiplier = 0.001 if (self._use_mW and
- name[1]==Spower.INA_POWER) else 1
- value = aggregate_record[name] * multiplier
- csv += ", %.2f" % value
- name_type = name[0] + Spower.INA_SUFFIX[name[1]]
- self._data.AddSample(name_type, value)
- else:
- csv += ", "
- csv += ", %d" % aggregate_record["status"]
- if self._print_raw_data:
- logoutput(csv)
-
- aggregate_record = {"boards": set()}
- for r in range(0, len(self._pwr)):
- pending_records.pop(0)
-
- except KeyboardInterrupt:
- self._logger.info('\nCTRL+C caught.')
-
- finally:
- for key in self._pwr:
- self._pwr[key].stop()
- self._data.CalculateStats()
- if self._print_stats:
- print(self._data.SummaryToString())
- save_dir = 'sweetberry%s' % time.time()
- if self._stats_dir:
- stats_dir = os.path.join(self._stats_dir, save_dir)
- self._data.SaveSummary(stats_dir)
- if self._stats_json_dir:
- stats_json_dir = os.path.join(self._stats_json_dir, save_dir)
- self._data.SaveSummaryJSON(stats_json_dir)
- if self._raw_data_dir:
- raw_data_dir = os.path.join(self._raw_data_dir, save_dir)
- self._data.SaveRawData(raw_data_dir)
+ self._pwr[key].load_board(brdfile)
+ self._pwr[key].reset()
+
+ # Allocate the rails to the appropriate boards.
+ used_boards = []
+ for name in self._names:
+ success = False
+ for key in self._pwr.keys():
+ if self._pwr[key].add_ina_name(name):
+ success = True
+ if key not in used_boards:
+ used_boards.append(key)
+ if not success:
+ raise Exception(
+ "Failed to add %s (maybe missing "
+ "sweetberry, or bad board file?)" % name
+ )
+
+ # Evict unused boards.
+ for key in list(self._pwr.keys()):
+ if key not in used_boards:
+ self._pwr.pop(key)
+
+ for key in self._pwr.keys():
+ if sync_date:
+ self._pwr[key].set_time(time.time() * 1000000)
+ else:
+ self._pwr[key].set_time(0)
+
+ def process_scenario(self, name_list):
+ """Return list of tuples indicating name and type.
+
+ Args:
+ json originated list of names, or [name, type]
+ Returns:
+ list of tuples of (name, type) defaulting to type "POWER"
+ Raises: exception, invalid INA type.
+ """
+ names = []
+ for entry in name_list:
+ if isinstance(entry, list):
+ name = entry[0]
+ if entry[1] == "POWER":
+ type = Spower.INA_POWER
+ elif entry[1] == "BUSV":
+ type = Spower.INA_BUSV
+ elif entry[1] == "CURRENT":
+ type = Spower.INA_CURRENT
+ elif entry[1] == "SHUNTV":
+ type = Spower.INA_SHUNTV
+ else:
+ raise Exception(
+ "Invalid INA type",
+ "Type of %s [%s] not recognized,"
+ " try one of POWER, BUSV, CURRENT" % (entry[0], entry[1]),
+ )
+ else:
+ name = entry
+ type = Spower.INA_POWER
+
+ names.append((name, type))
+ return names
+
+ def start(self, integration_us_request, seconds, sync_speed=0.8):
+ """Starts sampling.
+
+ Args:
+ integration_us_request: requested interval between sample values.
+ seconds: time until exit, or None to run until cancel.
+ sync_speed: A usb request is sent every [.8] * integration_us.
+ """
+ # We will get back the actual integration us.
+ # It should be the same for all devices.
+ integration_us = None
+ for key in self._pwr:
+ integration_us_new = self._pwr[key].start(integration_us_request)
+ if integration_us:
+ if integration_us != integration_us_new:
+ raise Exception(
+ "FAIL",
+ "Integration on A: %dus != integration on B %dus"
+ % (integration_us, integration_us_new),
+ )
+ integration_us = integration_us_new
+
+ # CSV header
+ title = "ts:%dus" % integration_us
+ for name_tuple in self._names:
+ name, ina_type = name_tuple
+
+ if ina_type == Spower.INA_POWER:
+ unit = "mW" if self._use_mW else "uW"
+ elif ina_type == Spower.INA_BUSV:
+ unit = "mV"
+ elif ina_type == Spower.INA_CURRENT:
+ unit = "uA"
+ elif ina_type == Spower.INA_SHUNTV:
+ unit = "uV"
+
+ title += ", %s %s" % (name, unit)
+ name_type = name + Spower.INA_SUFFIX[ina_type]
+ self._data.SetUnit(name_type, unit)
+ title += ", status"
+ if self._print_raw_data:
+ logoutput(title)
+
+ forever = False
+ if not seconds:
+ forever = True
+ end_time = time.time() + seconds
+ try:
+ pending_records = []
+ while forever or end_time > time.time():
+ if integration_us > 5000:
+ time.sleep((integration_us / 1000000.0) * sync_speed)
+ for key in self._pwr:
+ records = self._pwr[key].read_line()
+ if not records:
+ continue
+
+ for record in records:
+ pending_records.append(record)
+
+ pending_records.sort(key=lambda r: r["ts"])
+
+ aggregate_record = {"boards": set()}
+ for record in pending_records:
+ if record["berry"] not in aggregate_record["boards"]:
+ for rkey in record.keys():
+ aggregate_record[rkey] = record[rkey]
+ aggregate_record["boards"].add(record["berry"])
+ else:
+ self._logger.info(
+ "break %s, %s", record["berry"], aggregate_record["boards"]
+ )
+ break
+
+ if aggregate_record["boards"] == set(self._pwr.keys()):
+ csv = "%f" % aggregate_record["ts"]
+ for name in self._names:
+ if name in aggregate_record:
+ multiplier = (
+ 0.001
+ if (self._use_mW and name[1] == Spower.INA_POWER)
+ else 1
+ )
+ value = aggregate_record[name] * multiplier
+ csv += ", %.2f" % value
+ name_type = name[0] + Spower.INA_SUFFIX[name[1]]
+ self._data.AddSample(name_type, value)
+ else:
+ csv += ", "
+ csv += ", %d" % aggregate_record["status"]
+ if self._print_raw_data:
+ logoutput(csv)
+
+ aggregate_record = {"boards": set()}
+ for r in range(0, len(self._pwr)):
+ pending_records.pop(0)
+
+ except KeyboardInterrupt:
+ self._logger.info("\nCTRL+C caught.")
+
+ finally:
+ for key in self._pwr:
+ self._pwr[key].stop()
+ self._data.CalculateStats()
+ if self._print_stats:
+ print(self._data.SummaryToString())
+ save_dir = "sweetberry%s" % time.time()
+ if self._stats_dir:
+ stats_dir = os.path.join(self._stats_dir, save_dir)
+ self._data.SaveSummary(stats_dir)
+ if self._stats_json_dir:
+ stats_json_dir = os.path.join(self._stats_json_dir, save_dir)
+ self._data.SaveSummaryJSON(stats_json_dir)
+ if self._raw_data_dir:
+ raw_data_dir = os.path.join(self._raw_data_dir, save_dir)
+ self._data.SaveRawData(raw_data_dir)
def main(argv=None):
- if argv is None:
- argv = sys.argv[1:]
- # Command line argument description.
- parser = argparse.ArgumentParser(
- description="Gather CSV data from sweetberry")
- parser.add_argument('-b', '--board', type=str,
- help="Board configuration file, eg. my.board", default="")
- parser.add_argument('-c', '--config', type=str,
- help="Rail config to monitor, eg my.scenario", default="")
- parser.add_argument('-A', '--serial', type=str,
- help="Serial number of sweetberry A", default="")
- parser.add_argument('-B', '--serial_b', type=str,
- help="Serial number of sweetberry B", default="")
- parser.add_argument('-t', '--integration_us', type=int,
- help="Target integration time for samples", default=100000)
- parser.add_argument('-s', '--seconds', type=float,
- help="Seconds to run capture", default=0.)
- parser.add_argument('--date', default=False,
- help="Sync logged timestamp to host date", action="store_true")
- parser.add_argument('--ms', default=False,
- help="Print timestamp as milliseconds", action="store_true")
- parser.add_argument('--mW', default=False,
- help="Print power as milliwatts, otherwise default to microwatts",
- action="store_true")
- parser.add_argument('--slow', default=False,
- help="Intentionally overflow", action="store_true")
- parser.add_argument('--print_stats', default=False, action="store_true",
- help="Print statistics for sweetberry readings at the end")
- parser.add_argument('--save_stats', type=str, nargs='?',
- dest='stats_dir', metavar='STATS_DIR',
- const=os.path.dirname(os.path.abspath(__file__)), default=None,
- help="Save statistics for sweetberry readings to %(metavar)s if "
- "%(metavar)s is specified, %(metavar)s will be created if it does "
- "not exist; if %(metavar)s is not specified but the flag is set, "
- "stats will be saved to where %(prog)s is located; if this flag is "
- "not set, then do not save stats")
- parser.add_argument('--save_stats_json', type=str, nargs='?',
- dest='stats_json_dir', metavar='STATS_JSON_DIR',
- const=os.path.dirname(os.path.abspath(__file__)), default=None,
- help="Save means for sweetberry readings in json to %(metavar)s if "
- "%(metavar)s is specified, %(metavar)s will be created if it does "
- "not exist; if %(metavar)s is not specified but the flag is set, "
- "stats will be saved to where %(prog)s is located; if this flag is "
- "not set, then do not save stats")
- parser.add_argument('--no_print_raw_data',
- dest='print_raw_data', default=True, action="store_false",
- help="Not print raw sweetberry readings at real time, default is to "
- "print")
- parser.add_argument('--save_raw_data', type=str, nargs='?',
- dest='raw_data_dir', metavar='RAW_DATA_DIR',
- const=os.path.dirname(os.path.abspath(__file__)), default=None,
- help="Save raw data for sweetberry readings to %(metavar)s if "
- "%(metavar)s is specified, %(metavar)s will be created if it does "
- "not exist; if %(metavar)s is not specified but the flag is set, "
- "raw data will be saved to where %(prog)s is located; if this flag "
- "is not set, then do not save raw data")
- parser.add_argument('-v', '--verbose', default=False,
- help="Very chatty printout", action="store_true")
-
- args = parser.parse_args(argv)
-
- root_logger = logging.getLogger(__name__)
- if args.verbose:
- root_logger.setLevel(logging.DEBUG)
- else:
- root_logger.setLevel(logging.INFO)
-
- # if powerlog is used through main, log to sys.stdout
- if __name__ == "__main__":
- stdout_handler = logging.StreamHandler(sys.stdout)
- stdout_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
- root_logger.addHandler(stdout_handler)
-
- integration_us_request = args.integration_us
- if not args.board:
- raise Exception("Power", "No board file selected, see board.README")
- if not args.config:
- raise Exception("Power", "No config file selected, see board.README")
-
- brdfile = args.board
- cfgfile = args.config
- seconds = args.seconds
- serial_a = args.serial
- serial_b = args.serial_b
- sync_date = args.date
- use_ms = args.ms
- use_mW = args.mW
- print_stats = args.print_stats
- stats_dir = args.stats_dir
- stats_json_dir = args.stats_json_dir
- print_raw_data = args.print_raw_data
- raw_data_dir = args.raw_data_dir
-
- boards = []
-
- sync_speed = .8
- if args.slow:
- sync_speed = 1.2
-
- # Set up logging interface.
- powerlogger = powerlog(brdfile, cfgfile, serial_a=serial_a, serial_b=serial_b,
- sync_date=sync_date, use_ms=use_ms, use_mW=use_mW,
- print_stats=print_stats, stats_dir=stats_dir,
- stats_json_dir=stats_json_dir,
- print_raw_data=print_raw_data,raw_data_dir=raw_data_dir)
-
- # Start logging.
- powerlogger.start(integration_us_request, seconds, sync_speed=sync_speed)
+ if argv is None:
+ argv = sys.argv[1:]
+ # Command line argument description.
+ parser = argparse.ArgumentParser(description="Gather CSV data from sweetberry")
+ parser.add_argument(
+ "-b",
+ "--board",
+ type=str,
+ help="Board configuration file, eg. my.board",
+ default="",
+ )
+ parser.add_argument(
+ "-c",
+ "--config",
+ type=str,
+ help="Rail config to monitor, eg my.scenario",
+ default="",
+ )
+ parser.add_argument(
+ "-A", "--serial", type=str, help="Serial number of sweetberry A", default=""
+ )
+ parser.add_argument(
+ "-B", "--serial_b", type=str, help="Serial number of sweetberry B", default=""
+ )
+ parser.add_argument(
+ "-t",
+ "--integration_us",
+ type=int,
+ help="Target integration time for samples",
+ default=100000,
+ )
+ parser.add_argument(
+ "-s", "--seconds", type=float, help="Seconds to run capture", default=0.0
+ )
+ parser.add_argument(
+ "--date",
+ default=False,
+ help="Sync logged timestamp to host date",
+ action="store_true",
+ )
+ parser.add_argument(
+ "--ms",
+ default=False,
+ help="Print timestamp as milliseconds",
+ action="store_true",
+ )
+ parser.add_argument(
+ "--mW",
+ default=False,
+ help="Print power as milliwatts, otherwise default to microwatts",
+ action="store_true",
+ )
+ parser.add_argument(
+ "--slow", default=False, help="Intentionally overflow", action="store_true"
+ )
+ parser.add_argument(
+ "--print_stats",
+ default=False,
+ action="store_true",
+ help="Print statistics for sweetberry readings at the end",
+ )
+ parser.add_argument(
+ "--save_stats",
+ type=str,
+ nargs="?",
+ dest="stats_dir",
+ metavar="STATS_DIR",
+ const=os.path.dirname(os.path.abspath(__file__)),
+ default=None,
+ help="Save statistics for sweetberry readings to %(metavar)s if "
+ "%(metavar)s is specified, %(metavar)s will be created if it does "
+ "not exist; if %(metavar)s is not specified but the flag is set, "
+ "stats will be saved to where %(prog)s is located; if this flag is "
+ "not set, then do not save stats",
+ )
+ parser.add_argument(
+ "--save_stats_json",
+ type=str,
+ nargs="?",
+ dest="stats_json_dir",
+ metavar="STATS_JSON_DIR",
+ const=os.path.dirname(os.path.abspath(__file__)),
+ default=None,
+ help="Save means for sweetberry readings in json to %(metavar)s if "
+ "%(metavar)s is specified, %(metavar)s will be created if it does "
+ "not exist; if %(metavar)s is not specified but the flag is set, "
+ "stats will be saved to where %(prog)s is located; if this flag is "
+ "not set, then do not save stats",
+ )
+ parser.add_argument(
+ "--no_print_raw_data",
+ dest="print_raw_data",
+ default=True,
+ action="store_false",
+ help="Not print raw sweetberry readings at real time, default is to " "print",
+ )
+ parser.add_argument(
+ "--save_raw_data",
+ type=str,
+ nargs="?",
+ dest="raw_data_dir",
+ metavar="RAW_DATA_DIR",
+ const=os.path.dirname(os.path.abspath(__file__)),
+ default=None,
+ help="Save raw data for sweetberry readings to %(metavar)s if "
+ "%(metavar)s is specified, %(metavar)s will be created if it does "
+ "not exist; if %(metavar)s is not specified but the flag is set, "
+ "raw data will be saved to where %(prog)s is located; if this flag "
+ "is not set, then do not save raw data",
+ )
+ parser.add_argument(
+ "-v",
+ "--verbose",
+ default=False,
+ help="Very chatty printout",
+ action="store_true",
+ )
+
+ args = parser.parse_args(argv)
+
+ root_logger = logging.getLogger(__name__)
+ if args.verbose:
+ root_logger.setLevel(logging.DEBUG)
+ else:
+ root_logger.setLevel(logging.INFO)
+
+ # if powerlog is used through main, log to sys.stdout
+ if __name__ == "__main__":
+ stdout_handler = logging.StreamHandler(sys.stdout)
+ stdout_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
+ root_logger.addHandler(stdout_handler)
+
+ integration_us_request = args.integration_us
+ if not args.board:
+ raise Exception("Power", "No board file selected, see board.README")
+ if not args.config:
+ raise Exception("Power", "No config file selected, see board.README")
+
+ brdfile = args.board
+ cfgfile = args.config
+ seconds = args.seconds
+ serial_a = args.serial
+ serial_b = args.serial_b
+ sync_date = args.date
+ use_ms = args.ms
+ use_mW = args.mW
+ print_stats = args.print_stats
+ stats_dir = args.stats_dir
+ stats_json_dir = args.stats_json_dir
+ print_raw_data = args.print_raw_data
+ raw_data_dir = args.raw_data_dir
+
+ boards = []
+
+ sync_speed = 0.8
+ if args.slow:
+ sync_speed = 1.2
+
+ # Set up logging interface.
+ powerlogger = powerlog(
+ brdfile,
+ cfgfile,
+ serial_a=serial_a,
+ serial_b=serial_b,
+ sync_date=sync_date,
+ use_ms=use_ms,
+ use_mW=use_mW,
+ print_stats=print_stats,
+ stats_dir=stats_dir,
+ stats_json_dir=stats_json_dir,
+ print_raw_data=print_raw_data,
+ raw_data_dir=raw_data_dir,
+ )
+
+ # Start logging.
+ powerlogger.start(integration_us_request, seconds, sync_speed=sync_speed)
if __name__ == "__main__":
- main()
+ main()
diff --git a/extra/usb_power/powerlog_unittest.py b/extra/usb_power/powerlog_unittest.py
index 1d0718530e..693826c16d 100644
--- a/extra/usb_power/powerlog_unittest.py
+++ b/extra/usb_power/powerlog_unittest.py
@@ -15,40 +15,42 @@ import unittest
import powerlog
+
class TestPowerlog(unittest.TestCase):
- """Test to verify powerlog util methods work as expected."""
-
- def setUp(self):
- """Set up data and create a temporary directory to save data and stats."""
- self.tempdir = tempfile.mkdtemp()
- self.filename = 'testfile'
- self.filepath = os.path.join(self.tempdir, self.filename)
- with open(self.filepath, 'w') as f:
- f.write('')
-
- def tearDown(self):
- """Delete the temporary directory and its content."""
- shutil.rmtree(self.tempdir)
-
- def test_ProcessFilenameAbsoluteFilePath(self):
- """Absolute file path is returned unchanged."""
- processed_fname = powerlog.process_filename(self.filepath)
- self.assertEqual(self.filepath, processed_fname)
-
- def test_ProcessFilenameRelativeFilePath(self):
- """Finds relative file path inside a known config location."""
- original = powerlog.CONFIG_LOCATIONS
- powerlog.CONFIG_LOCATIONS = [self.tempdir]
- processed_fname = powerlog.process_filename(self.filename)
- try:
- self.assertEqual(self.filepath, processed_fname)
- finally:
- powerlog.CONFIG_LOCATIONS = original
-
- def test_ProcessFilenameInvalid(self):
- """IOError is raised when file cannot be found by any of the four ways."""
- with self.assertRaises(IOError):
- powerlog.process_filename(self.filename)
-
-if __name__ == '__main__':
- unittest.main()
+ """Test to verify powerlog util methods work as expected."""
+
+ def setUp(self):
+ """Set up data and create a temporary directory to save data and stats."""
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = "testfile"
+ self.filepath = os.path.join(self.tempdir, self.filename)
+ with open(self.filepath, "w") as f:
+ f.write("")
+
+ def tearDown(self):
+ """Delete the temporary directory and its content."""
+ shutil.rmtree(self.tempdir)
+
+ def test_ProcessFilenameAbsoluteFilePath(self):
+ """Absolute file path is returned unchanged."""
+ processed_fname = powerlog.process_filename(self.filepath)
+ self.assertEqual(self.filepath, processed_fname)
+
+ def test_ProcessFilenameRelativeFilePath(self):
+ """Finds relative file path inside a known config location."""
+ original = powerlog.CONFIG_LOCATIONS
+ powerlog.CONFIG_LOCATIONS = [self.tempdir]
+ processed_fname = powerlog.process_filename(self.filename)
+ try:
+ self.assertEqual(self.filepath, processed_fname)
+ finally:
+ powerlog.CONFIG_LOCATIONS = original
+
+ def test_ProcessFilenameInvalid(self):
+ """IOError is raised when file cannot be found by any of the four ways."""
+ with self.assertRaises(IOError):
+ powerlog.process_filename(self.filename)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/extra/usb_power/stats_manager.py b/extra/usb_power/stats_manager.py
index 0f8c3fcb15..633312311d 100644
--- a/extra/usb_power/stats_manager.py
+++ b/extra/usb_power/stats_manager.py
@@ -20,382 +20,391 @@ import os
import numpy
-STATS_PREFIX = '@@'
-NAN_TAG = '*'
-NAN_DESCRIPTION = '%s domains contain NaN samples' % NAN_TAG
+STATS_PREFIX = "@@"
+NAN_TAG = "*"
+NAN_DESCRIPTION = "%s domains contain NaN samples" % NAN_TAG
LONG_UNIT = {
- '': 'N/A',
- 'mW': 'milliwatt',
- 'uW': 'microwatt',
- 'mV': 'millivolt',
- 'uA': 'microamp',
- 'uV': 'microvolt'
+ "": "N/A",
+ "mW": "milliwatt",
+ "uW": "microwatt",
+ "mV": "millivolt",
+ "uA": "microamp",
+ "uV": "microvolt",
}
class StatsManagerError(Exception):
- """Errors in StatsManager class."""
- pass
+ """Errors in StatsManager class."""
+ pass
-class StatsManager(object):
- """Calculates statistics for several lists of data(float).
-
- Example usage:
-
- >>> stats = StatsManager(title='Title Banner')
- >>> stats.AddSample(TIME_KEY, 50.0)
- >>> stats.AddSample(TIME_KEY, 25.0)
- >>> stats.AddSample(TIME_KEY, 40.0)
- >>> stats.AddSample(TIME_KEY, 10.0)
- >>> stats.AddSample(TIME_KEY, 10.0)
- >>> stats.AddSample('frobnicate', 11.5)
- >>> stats.AddSample('frobnicate', 9.0)
- >>> stats.AddSample('foobar', 11111.0)
- >>> stats.AddSample('foobar', 22222.0)
- >>> stats.CalculateStats()
- >>> print(stats.SummaryToString())
- ` @@--------------------------------------------------------------
- ` @@ Title Banner
- @@--------------------------------------------------------------
- @@ NAME COUNT MEAN STDDEV MAX MIN
- @@ sample_msecs 4 31.25 15.16 50.00 10.00
- @@ foobar 2 16666.50 5555.50 22222.00 11111.00
- @@ frobnicate 2 10.25 1.25 11.50 9.00
- ` @@--------------------------------------------------------------
-
- Attributes:
- _data: dict of list of readings for each domain(key)
- _unit: dict of unit for each domain(key)
- _smid: id supplied to differentiate data output to other StatsManager
- instances that potentially save to the same directory
- if smid all output files will be named |smid|_|fname|
- _title: title to add as banner to formatted summary. If no title,
- no banner gets added
- _order: list of formatting order for domains. Domains not listed are
- displayed in sorted order
- _hide_domains: collection of domains to hide when formatting summary string
- _accept_nan: flag to indicate if NaN samples are acceptable
- _nan_domains: set to keep track of which domains contain NaN samples
- _summary: dict of stats per domain (key): min, max, count, mean, stddev
- _logger = StatsManager logger
-
- Note:
- _summary is empty until CalculateStats() is called, and is updated when
- CalculateStats() is called.
- """
-
- # pylint: disable=W0102
- def __init__(self, smid='', title='', order=[], hide_domains=[],
- accept_nan=True):
- """Initialize infrastructure for data and their statistics."""
- self._title = title
- self._data = collections.defaultdict(list)
- self._unit = collections.defaultdict(str)
- self._smid = smid
- self._order = order
- self._hide_domains = hide_domains
- self._accept_nan = accept_nan
- self._nan_domains = set()
- self._summary = {}
- self._logger = logging.getLogger(type(self).__name__)
-
- def AddSample(self, domain, sample):
- """Add one sample for a domain.
-
- Args:
- domain: the domain name for the sample.
- sample: one time sample for domain, expect type float.
-
- Raises:
- StatsManagerError: if trying to add NaN and |_accept_nan| is false
- """
- try:
- sample = float(sample)
- except ValueError:
- # if we don't accept nan this will be caught below
- self._logger.debug('sample %s for domain %s is not a number. Making NaN',
- sample, domain)
- sample = float('NaN')
- if not self._accept_nan and math.isnan(sample):
- raise StatsManagerError('accept_nan is false. Cannot add NaN sample.')
- self._data[domain].append(sample)
- if math.isnan(sample):
- self._nan_domains.add(domain)
-
- def SetUnit(self, domain, unit):
- """Set the unit for a domain.
-
- There can be only one unit for each domain. Setting unit twice will
- overwrite the original unit.
-
- Args:
- domain: the domain name.
- unit: unit of the domain.
- """
- if domain in self._unit:
- self._logger.warning('overwriting the unit of %s, old unit is %s, new '
- 'unit is %s.', domain, self._unit[domain], unit)
- self._unit[domain] = unit
- def CalculateStats(self):
- """Calculate stats for all domain-data pairs.
-
- First erases all previous stats, then calculate stats for all data.
- """
- self._summary = {}
- for domain, data in self._data.items():
- data_np = numpy.array(data)
- self._summary[domain] = {
- 'mean': numpy.nanmean(data_np),
- 'min': numpy.nanmin(data_np),
- 'max': numpy.nanmax(data_np),
- 'stddev': numpy.nanstd(data_np),
- 'count': data_np.size,
- }
-
- @property
- def DomainsToDisplay(self):
- """List of domains that the manager will output in summaries."""
- return set(self._summary.keys()) - set(self._hide_domains)
-
- @property
- def NanInOutput(self):
- """Return whether any of the domains to display have NaN values."""
- return bool(len(set(self._nan_domains) & self.DomainsToDisplay))
-
- def _SummaryTable(self):
- """Generate the matrix to output as a summary.
-
- Returns:
- A 2d matrix of headers and their data for each domain
- e.g.
- [[NAME, COUNT, MEAN, STDDEV, MAX, MIN],
- [pp5000_mw, 10, 50, 0, 50, 50]]
- """
- headers = ('NAME', 'COUNT', 'MEAN', 'STDDEV', 'MAX', 'MIN')
- table = [headers]
- # determine what domains to display & and the order
- domains_to_display = self.DomainsToDisplay
- display_order = [key for key in self._order if key in domains_to_display]
- domains_to_display -= set(display_order)
- display_order.extend(sorted(domains_to_display))
- for domain in display_order:
- stats = self._summary[domain]
- if not domain.endswith(self._unit[domain]):
- domain = '%s_%s' % (domain, self._unit[domain])
- if domain in self._nan_domains:
- domain = '%s%s' % (domain, NAN_TAG)
- row = [domain]
- row.append(str(stats['count']))
- for entry in headers[2:]:
- row.append('%.2f' % stats[entry.lower()])
- table.append(row)
- return table
-
- def SummaryToMarkdownString(self):
- """Format the summary into a b/ compatible markdown table string.
-
- This requires this sort of output format
-
- | header1 | header2 | header3 | ...
- | --------- | --------- | --------- | ...
- | sample1h1 | sample1h2 | sample1h3 | ...
- .
- .
- .
-
- Returns:
- formatted summary string.
- """
- # All we need to do before processing is insert a row of '-' between
- # the headers, and the data
- table = self._SummaryTable()
- columns = len(table[0])
- # Using '-:' to allow the numbers to be right aligned
- sep_row = ['-'] + ['-:'] * (columns - 1)
- table.insert(1, sep_row)
- text_rows = ['|'.join(r) for r in table]
- body = '\n'.join(['|%s|' % r for r in text_rows])
- if self._title:
- title_section = '**%s** \n\n' % self._title
- body = title_section + body
- # Make sure that the body is terminated with a newline.
- return body + '\n'
-
- def SummaryToString(self, prefix=STATS_PREFIX):
- """Format summary into a string, ready for pretty print.
-
- See class description for format example.
-
- Args:
- prefix: start every row in summary string with prefix, for easier reading.
-
- Returns:
- formatted summary string.
- """
- table = self._SummaryTable()
- max_col_width = []
- for col_idx in range(len(table[0])):
- col_item_widths = [len(row[col_idx]) for row in table]
- max_col_width.append(max(col_item_widths))
-
- formatted_lines = []
- for row in table:
- formatted_row = prefix + ' '
- for i in range(len(row)):
- formatted_row += row[i].rjust(max_col_width[i] + 2)
- formatted_lines.append(formatted_row)
- if self.NanInOutput:
- formatted_lines.append('%s %s' % (prefix, NAN_DESCRIPTION))
-
- if self._title:
- line_length = len(formatted_lines[0])
- dec_length = len(prefix)
- # trim title to be at most as long as the longest line without the prefix
- title = self._title[:(line_length - dec_length)]
- # line is a seperator line consisting of -----
- line = '%s%s' % (prefix, '-' * (line_length - dec_length))
- # prepend the prefix to the centered title
- padded_title = '%s%s' % (prefix, title.center(line_length)[dec_length:])
- formatted_lines = [line, padded_title, line] + formatted_lines + [line]
- formatted_output = '\n'.join(formatted_lines)
- return formatted_output
-
- def GetSummary(self):
- """Getter for summary."""
- return self._summary
-
- def _MakeUniqueFName(self, fname):
- """prepend |_smid| to fname & rotate fname to ensure uniqueness.
-
- Before saving a file through the StatsManager, make sure that the filename
- is unique, first by prepending the smid if any and otherwise by appending
- increasing integer suffixes until the filename is unique.
-
- If |smid| is defined /path/to/example/file.txt becomes
- /path/to/example/{smid}_file.txt.
-
- The rotation works by changing /path/to/example/somename.txt to
- /path/to/example/somename1.txt if the first one already exists on the
- system.
-
- Note: this is not thread-safe. While it makes sense to use StatsManager
- in a threaded data-collection, the data retrieval should happen in a
- single threaded environment to ensure files don't get potentially clobbered.
-
- Args:
- fname: filename to ensure uniqueness.
-
- Returns:
- {smid_}fname{tag}.[b].ext
- the smid portion gets prepended if |smid| is defined
- the tag portion gets appended if necessary to ensure unique fname
- """
- fdir = os.path.dirname(fname)
- base, ext = os.path.splitext(os.path.basename(fname))
- if self._smid:
- base = '%s_%s' % (self._smid, base)
- unique_fname = os.path.join(fdir, '%s%s' % (base, ext))
- tag = 0
- while os.path.exists(unique_fname):
- old_fname = unique_fname
- unique_fname = os.path.join(fdir, '%s%d%s' % (base, tag, ext))
- self._logger.warning('Attempted to store stats information at %s, but '
- 'file already exists. Attempting to store at %s '
- 'now.', old_fname, unique_fname)
- tag += 1
- return unique_fname
-
- def SaveSummary(self, directory, fname='summary.txt', prefix=STATS_PREFIX):
- """Save summary to file.
-
- Args:
- directory: directory to save the summary in.
- fname: filename to save summary under.
- prefix: start every row in summary string with prefix, for easier reading.
-
- Returns:
- full path of summary save location
- """
- summary_str = self.SummaryToString(prefix=prefix) + '\n'
- return self._SaveSummary(summary_str, directory, fname)
-
- def SaveSummaryJSON(self, directory, fname='summary.json'):
- """Save summary (only MEAN) into a JSON file.
-
- Args:
- directory: directory to save the JSON summary in.
- fname: filename to save summary under.
-
- Returns:
- full path of summary save location
- """
- data = {}
- for domain in self._summary:
- unit = LONG_UNIT.get(self._unit[domain], self._unit[domain])
- data_entry = {'mean': self._summary[domain]['mean'], 'unit': unit}
- data[domain] = data_entry
- summary_str = json.dumps(data, indent=2)
- return self._SaveSummary(summary_str, directory, fname)
-
- def SaveSummaryMD(self, directory, fname='summary.md'):
- """Save summary into a MD file to paste into b/.
-
- Args:
- directory: directory to save the MD summary in.
- fname: filename to save summary under.
-
- Returns:
- full path of summary save location
+class StatsManager(object):
+ """Calculates statistics for several lists of data(float).
+
+ Example usage:
+
+ >>> stats = StatsManager(title='Title Banner')
+ >>> stats.AddSample(TIME_KEY, 50.0)
+ >>> stats.AddSample(TIME_KEY, 25.0)
+ >>> stats.AddSample(TIME_KEY, 40.0)
+ >>> stats.AddSample(TIME_KEY, 10.0)
+ >>> stats.AddSample(TIME_KEY, 10.0)
+ >>> stats.AddSample('frobnicate', 11.5)
+ >>> stats.AddSample('frobnicate', 9.0)
+ >>> stats.AddSample('foobar', 11111.0)
+ >>> stats.AddSample('foobar', 22222.0)
+ >>> stats.CalculateStats()
+ >>> print(stats.SummaryToString())
+ ` @@--------------------------------------------------------------
+ ` @@ Title Banner
+ @@--------------------------------------------------------------
+ @@ NAME COUNT MEAN STDDEV MAX MIN
+ @@ sample_msecs 4 31.25 15.16 50.00 10.00
+ @@ foobar 2 16666.50 5555.50 22222.00 11111.00
+ @@ frobnicate 2 10.25 1.25 11.50 9.00
+ ` @@--------------------------------------------------------------
+
+ Attributes:
+ _data: dict of list of readings for each domain(key)
+ _unit: dict of unit for each domain(key)
+ _smid: id supplied to differentiate data output to other StatsManager
+ instances that potentially save to the same directory
+ if smid all output files will be named |smid|_|fname|
+ _title: title to add as banner to formatted summary. If no title,
+ no banner gets added
+ _order: list of formatting order for domains. Domains not listed are
+ displayed in sorted order
+ _hide_domains: collection of domains to hide when formatting summary string
+ _accept_nan: flag to indicate if NaN samples are acceptable
+ _nan_domains: set to keep track of which domains contain NaN samples
+ _summary: dict of stats per domain (key): min, max, count, mean, stddev
+ _logger = StatsManager logger
+
+ Note:
+ _summary is empty until CalculateStats() is called, and is updated when
+ CalculateStats() is called.
"""
- summary_str = self.SummaryToMarkdownString()
- return self._SaveSummary(summary_str, directory, fname)
- def _SaveSummary(self, output_str, directory, fname):
- """Wrote |output_str| to |fname|.
-
- Args:
- output_str: formatted output string
- directory: directory to save the summary in.
- fname: filename to save summary under.
-
- Returns:
- full path of summary save location
- """
- if not os.path.exists(directory):
- os.makedirs(directory)
- fname = self._MakeUniqueFName(os.path.join(directory, fname))
- with open(fname, 'w') as f:
- f.write(output_str)
- return fname
-
- def GetRawData(self):
- """Getter for all raw_data."""
- return self._data
-
- def SaveRawData(self, directory, dirname='raw_data'):
- """Save raw data to file.
-
- Args:
- directory: directory to create the raw data folder in.
- dirname: folder in which raw data live.
-
- Returns:
- list of full path of each domain's raw data save location
- """
- if not os.path.exists(directory):
- os.makedirs(directory)
- dirname = os.path.join(directory, dirname)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
- fnames = []
- for domain, data in self._data.items():
- if not domain.endswith(self._unit[domain]):
- domain = '%s_%s' % (domain, self._unit[domain])
- fname = self._MakeUniqueFName(os.path.join(dirname, '%s.txt' % domain))
- with open(fname, 'w') as f:
- f.write('\n'.join('%.2f' % sample for sample in data) + '\n')
- fnames.append(fname)
- return fnames
+ # pylint: disable=W0102
+ def __init__(self, smid="", title="", order=[], hide_domains=[], accept_nan=True):
+ """Initialize infrastructure for data and their statistics."""
+ self._title = title
+ self._data = collections.defaultdict(list)
+ self._unit = collections.defaultdict(str)
+ self._smid = smid
+ self._order = order
+ self._hide_domains = hide_domains
+ self._accept_nan = accept_nan
+ self._nan_domains = set()
+ self._summary = {}
+ self._logger = logging.getLogger(type(self).__name__)
+
+ def AddSample(self, domain, sample):
+ """Add one sample for a domain.
+
+ Args:
+ domain: the domain name for the sample.
+ sample: one time sample for domain, expect type float.
+
+ Raises:
+ StatsManagerError: if trying to add NaN and |_accept_nan| is false
+ """
+ try:
+ sample = float(sample)
+ except ValueError:
+ # if we don't accept nan this will be caught below
+ self._logger.debug(
+ "sample %s for domain %s is not a number. Making NaN", sample, domain
+ )
+ sample = float("NaN")
+ if not self._accept_nan and math.isnan(sample):
+ raise StatsManagerError("accept_nan is false. Cannot add NaN sample.")
+ self._data[domain].append(sample)
+ if math.isnan(sample):
+ self._nan_domains.add(domain)
+
+ def SetUnit(self, domain, unit):
+ """Set the unit for a domain.
+
+ There can be only one unit for each domain. Setting unit twice will
+ overwrite the original unit.
+
+ Args:
+ domain: the domain name.
+ unit: unit of the domain.
+ """
+ if domain in self._unit:
+ self._logger.warning(
+ "overwriting the unit of %s, old unit is %s, new " "unit is %s.",
+ domain,
+ self._unit[domain],
+ unit,
+ )
+ self._unit[domain] = unit
+
+ def CalculateStats(self):
+ """Calculate stats for all domain-data pairs.
+
+ First erases all previous stats, then calculate stats for all data.
+ """
+ self._summary = {}
+ for domain, data in self._data.items():
+ data_np = numpy.array(data)
+ self._summary[domain] = {
+ "mean": numpy.nanmean(data_np),
+ "min": numpy.nanmin(data_np),
+ "max": numpy.nanmax(data_np),
+ "stddev": numpy.nanstd(data_np),
+ "count": data_np.size,
+ }
+
+ @property
+ def DomainsToDisplay(self):
+ """List of domains that the manager will output in summaries."""
+ return set(self._summary.keys()) - set(self._hide_domains)
+
+ @property
+ def NanInOutput(self):
+ """Return whether any of the domains to display have NaN values."""
+ return bool(len(set(self._nan_domains) & self.DomainsToDisplay))
+
+ def _SummaryTable(self):
+ """Generate the matrix to output as a summary.
+
+ Returns:
+ A 2d matrix of headers and their data for each domain
+ e.g.
+ [[NAME, COUNT, MEAN, STDDEV, MAX, MIN],
+ [pp5000_mw, 10, 50, 0, 50, 50]]
+ """
+ headers = ("NAME", "COUNT", "MEAN", "STDDEV", "MAX", "MIN")
+ table = [headers]
+ # determine what domains to display & and the order
+ domains_to_display = self.DomainsToDisplay
+ display_order = [key for key in self._order if key in domains_to_display]
+ domains_to_display -= set(display_order)
+ display_order.extend(sorted(domains_to_display))
+ for domain in display_order:
+ stats = self._summary[domain]
+ if not domain.endswith(self._unit[domain]):
+ domain = "%s_%s" % (domain, self._unit[domain])
+ if domain in self._nan_domains:
+ domain = "%s%s" % (domain, NAN_TAG)
+ row = [domain]
+ row.append(str(stats["count"]))
+ for entry in headers[2:]:
+ row.append("%.2f" % stats[entry.lower()])
+ table.append(row)
+ return table
+
+ def SummaryToMarkdownString(self):
+ """Format the summary into a b/ compatible markdown table string.
+
+ This requires this sort of output format
+
+ | header1 | header2 | header3 | ...
+ | --------- | --------- | --------- | ...
+ | sample1h1 | sample1h2 | sample1h3 | ...
+ .
+ .
+ .
+
+ Returns:
+ formatted summary string.
+ """
+ # All we need to do before processing is insert a row of '-' between
+ # the headers, and the data
+ table = self._SummaryTable()
+ columns = len(table[0])
+ # Using '-:' to allow the numbers to be right aligned
+ sep_row = ["-"] + ["-:"] * (columns - 1)
+ table.insert(1, sep_row)
+ text_rows = ["|".join(r) for r in table]
+ body = "\n".join(["|%s|" % r for r in text_rows])
+ if self._title:
+ title_section = "**%s** \n\n" % self._title
+ body = title_section + body
+ # Make sure that the body is terminated with a newline.
+ return body + "\n"
+
+ def SummaryToString(self, prefix=STATS_PREFIX):
+ """Format summary into a string, ready for pretty print.
+
+ See class description for format example.
+
+ Args:
+ prefix: start every row in summary string with prefix, for easier reading.
+
+ Returns:
+ formatted summary string.
+ """
+ table = self._SummaryTable()
+ max_col_width = []
+ for col_idx in range(len(table[0])):
+ col_item_widths = [len(row[col_idx]) for row in table]
+ max_col_width.append(max(col_item_widths))
+
+ formatted_lines = []
+ for row in table:
+ formatted_row = prefix + " "
+ for i in range(len(row)):
+ formatted_row += row[i].rjust(max_col_width[i] + 2)
+ formatted_lines.append(formatted_row)
+ if self.NanInOutput:
+ formatted_lines.append("%s %s" % (prefix, NAN_DESCRIPTION))
+
+ if self._title:
+ line_length = len(formatted_lines[0])
+ dec_length = len(prefix)
+ # trim title to be at most as long as the longest line without the prefix
+ title = self._title[: (line_length - dec_length)]
+ # line is a separator line consisting of -----
+ line = "%s%s" % (prefix, "-" * (line_length - dec_length))
+ # prepend the prefix to the centered title
+ padded_title = "%s%s" % (prefix, title.center(line_length)[dec_length:])
+ formatted_lines = [line, padded_title, line] + formatted_lines + [line]
+ formatted_output = "\n".join(formatted_lines)
+ return formatted_output
+
+ def GetSummary(self):
+ """Getter for summary."""
+ return self._summary
+
+ def _MakeUniqueFName(self, fname):
+ """Prepend |_smid| to fname & rotate fname to ensure uniqueness.
+
+ Before saving a file through the StatsManager, make sure that the filename
+ is unique, first by prepending the smid if any and otherwise by appending
+ increasing integer suffixes until the filename is unique.
+
+ If |smid| is defined /path/to/example/file.txt becomes
+ /path/to/example/{smid}_file.txt.
+
+ The rotation works by changing /path/to/example/somename.txt to
+ /path/to/example/somename1.txt if the first one already exists on the
+ system.
+
+ Note: this is not thread-safe. While it makes sense to use StatsManager
+ in a threaded data-collection, the data retrieval should happen in a
+ single threaded environment to ensure files don't get potentially clobbered.
+
+ Args:
+ fname: filename to ensure uniqueness.
+
+ Returns:
+ {smid_}fname{tag}.[b].ext
+ the smid portion gets prepended if |smid| is defined
+ the tag portion gets appended if necessary to ensure unique fname
+ """
+ fdir = os.path.dirname(fname)
+ base, ext = os.path.splitext(os.path.basename(fname))
+ if self._smid:
+ base = "%s_%s" % (self._smid, base)
+ unique_fname = os.path.join(fdir, "%s%s" % (base, ext))
+ tag = 0
+ while os.path.exists(unique_fname):
+ old_fname = unique_fname
+ unique_fname = os.path.join(fdir, "%s%d%s" % (base, tag, ext))
+ self._logger.warning(
+ "Attempted to store stats information at %s, but "
+ "file already exists. Attempting to store at %s "
+ "now.",
+ old_fname,
+ unique_fname,
+ )
+ tag += 1
+ return unique_fname
+
+ def SaveSummary(self, directory, fname="summary.txt", prefix=STATS_PREFIX):
+ """Save summary to file.
+
+ Args:
+ directory: directory to save the summary in.
+ fname: filename to save summary under.
+ prefix: start every row in summary string with prefix, for easier reading.
+
+ Returns:
+ full path of summary save location
+ """
+ summary_str = self.SummaryToString(prefix=prefix) + "\n"
+ return self._SaveSummary(summary_str, directory, fname)
+
+ def SaveSummaryJSON(self, directory, fname="summary.json"):
+ """Save summary (only MEAN) into a JSON file.
+
+ Args:
+ directory: directory to save the JSON summary in.
+ fname: filename to save summary under.
+
+ Returns:
+ full path of summary save location
+ """
+ data = {}
+ for domain in self._summary:
+ unit = LONG_UNIT.get(self._unit[domain], self._unit[domain])
+ data_entry = {"mean": self._summary[domain]["mean"], "unit": unit}
+ data[domain] = data_entry
+ summary_str = json.dumps(data, indent=2)
+ return self._SaveSummary(summary_str, directory, fname)
+
+ def SaveSummaryMD(self, directory, fname="summary.md"):
+ """Save summary into a MD file to paste into b/.
+
+ Args:
+ directory: directory to save the MD summary in.
+ fname: filename to save summary under.
+
+ Returns:
+ full path of summary save location
+ """
+ summary_str = self.SummaryToMarkdownString()
+ return self._SaveSummary(summary_str, directory, fname)
+
+ def _SaveSummary(self, output_str, directory, fname):
+ """Write |output_str| to |fname|.
+
+ Args:
+ output_str: formatted output string
+ directory: directory to save the summary in.
+ fname: filename to save summary under.
+
+ Returns:
+ full path of summary save location
+ """
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+ fname = self._MakeUniqueFName(os.path.join(directory, fname))
+ with open(fname, "w") as f:
+ f.write(output_str)
+ return fname
+
+ def GetRawData(self):
+ """Getter for all raw_data."""
+ return self._data
+
+ def SaveRawData(self, directory, dirname="raw_data"):
+ """Save raw data to file.
+
+ Args:
+ directory: directory to create the raw data folder in.
+ dirname: folder in which raw data live.
+
+ Returns:
+ list of full path of each domain's raw data save location
+ """
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+ dirname = os.path.join(directory, dirname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+ fnames = []
+ for domain, data in self._data.items():
+ if not domain.endswith(self._unit[domain]):
+ domain = "%s_%s" % (domain, self._unit[domain])
+ fname = self._MakeUniqueFName(os.path.join(dirname, "%s.txt" % domain))
+ with open(fname, "w") as f:
+ f.write("\n".join("%.2f" % sample for sample in data) + "\n")
+ fnames.append(fname)
+ return fnames
diff --git a/extra/usb_power/stats_manager_unittest.py b/extra/usb_power/stats_manager_unittest.py
index beb9984b93..37a472af98 100644
--- a/extra/usb_power/stats_manager_unittest.py
+++ b/extra/usb_power/stats_manager_unittest.py
@@ -9,6 +9,7 @@
"""Unit tests for StatsManager."""
from __future__ import print_function
+
import json
import os
import re
@@ -20,296 +21,304 @@ import stats_manager
class TestStatsManager(unittest.TestCase):
- """Test to verify StatsManager methods work as expected.
-
- StatsManager should collect raw data, calculate their statistics, and save
- them in expected format.
- """
-
- def _populate_mock_stats(self):
- """Create a populated & processed StatsManager to test data retrieval."""
- self.data.AddSample('A', 99999.5)
- self.data.AddSample('A', 100000.5)
- self.data.SetUnit('A', 'uW')
- self.data.SetUnit('A', 'mW')
- self.data.AddSample('B', 1.5)
- self.data.AddSample('B', 2.5)
- self.data.AddSample('B', 3.5)
- self.data.SetUnit('B', 'mV')
- self.data.CalculateStats()
-
- def _populate_mock_stats_no_unit(self):
- self.data.AddSample('B', 1000)
- self.data.AddSample('A', 200)
- self.data.SetUnit('A', 'blue')
-
- def setUp(self):
- """Set up StatsManager and create a temporary directory for test."""
- self.tempdir = tempfile.mkdtemp()
- self.data = stats_manager.StatsManager()
-
- def tearDown(self):
- """Delete the temporary directory and its content."""
- shutil.rmtree(self.tempdir)
-
- def test_AddSample(self):
- """Adding a sample successfully adds a sample."""
- self.data.AddSample('Test', 1000)
- self.data.SetUnit('Test', 'test')
- self.data.CalculateStats()
- summary = self.data.GetSummary()
- self.assertEqual(1, summary['Test']['count'])
-
- def test_AddSampleNoFloatAcceptNaN(self):
- """Adding a non-number adds 'NaN' and doesn't raise an exception."""
- self.data.AddSample('Test', 10)
- self.data.AddSample('Test', 20)
- # adding a fake NaN: one that gets converted into NaN internally
- self.data.AddSample('Test', 'fiesta')
- # adding a real NaN
- self.data.AddSample('Test', float('NaN'))
- self.data.SetUnit('Test', 'test')
- self.data.CalculateStats()
- summary = self.data.GetSummary()
- # assert that 'NaN' as added.
- self.assertEqual(4, summary['Test']['count'])
- # assert that mean, min, and max calculatings ignore the 'NaN'
- self.assertEqual(10, summary['Test']['min'])
- self.assertEqual(20, summary['Test']['max'])
- self.assertEqual(15, summary['Test']['mean'])
-
- def test_AddSampleNoFloatNotAcceptNaN(self):
- """Adding a non-number raises a StatsManagerError if accept_nan is False."""
- self.data = stats_manager.StatsManager(accept_nan=False)
- with self.assertRaisesRegexp(stats_manager.StatsManagerError,
- 'accept_nan is false. Cannot add NaN sample.'):
- # adding a fake NaN: one that gets converted into NaN internally
- self.data.AddSample('Test', 'fiesta')
- with self.assertRaisesRegexp(stats_manager.StatsManagerError,
- 'accept_nan is false. Cannot add NaN sample.'):
- # adding a real NaN
- self.data.AddSample('Test', float('NaN'))
-
- def test_AddSampleNoUnit(self):
- """Not adding a unit does not cause an exception on CalculateStats()."""
- self.data.AddSample('Test', 17)
- self.data.CalculateStats()
- summary = self.data.GetSummary()
- self.assertEqual(1, summary['Test']['count'])
-
- def test_UnitSuffix(self):
- """Unit gets appended as a suffix in the displayed summary."""
- self.data.AddSample('test', 250)
- self.data.SetUnit('test', 'mw')
- self.data.CalculateStats()
- summary_str = self.data.SummaryToString()
- self.assertIn('test_mw', summary_str)
-
- def test_DoubleUnitSuffix(self):
- """If domain already ends in unit, verify that unit doesn't get appended."""
- self.data.AddSample('test_mw', 250)
- self.data.SetUnit('test_mw', 'mw')
- self.data.CalculateStats()
- summary_str = self.data.SummaryToString()
- self.assertIn('test_mw', summary_str)
- self.assertNotIn('test_mw_mw', summary_str)
-
- def test_GetRawData(self):
- """GetRawData returns exact same data as fed in."""
- self._populate_mock_stats()
- raw_data = self.data.GetRawData()
- self.assertListEqual([99999.5, 100000.5], raw_data['A'])
- self.assertListEqual([1.5, 2.5, 3.5], raw_data['B'])
-
- def test_GetSummary(self):
- """GetSummary returns expected stats about the data fed in."""
- self._populate_mock_stats()
- summary = self.data.GetSummary()
- self.assertEqual(2, summary['A']['count'])
- self.assertAlmostEqual(100000.5, summary['A']['max'])
- self.assertAlmostEqual(99999.5, summary['A']['min'])
- self.assertAlmostEqual(0.5, summary['A']['stddev'])
- self.assertAlmostEqual(100000.0, summary['A']['mean'])
- self.assertEqual(3, summary['B']['count'])
- self.assertAlmostEqual(3.5, summary['B']['max'])
- self.assertAlmostEqual(1.5, summary['B']['min'])
- self.assertAlmostEqual(0.81649658092773, summary['B']['stddev'])
- self.assertAlmostEqual(2.5, summary['B']['mean'])
-
- def test_SaveRawData(self):
- """SaveRawData stores same data as fed in."""
- self._populate_mock_stats()
- dirname = 'unittest_raw_data'
- expected_files = set(['A_mW.txt', 'B_mV.txt'])
- fnames = self.data.SaveRawData(self.tempdir, dirname)
- files_returned = set([os.path.basename(f) for f in fnames])
- # Assert that only the expected files got returned.
- self.assertEqual(expected_files, files_returned)
- # Assert that only the returned files are in the outdir.
- self.assertEqual(set(os.listdir(os.path.join(self.tempdir, dirname))),
- files_returned)
- for fname in fnames:
- with open(fname, 'r') as f:
- if 'A_mW' in fname:
- self.assertEqual('99999.50', f.readline().strip())
- self.assertEqual('100000.50', f.readline().strip())
- if 'B_mV' in fname:
- self.assertEqual('1.50', f.readline().strip())
- self.assertEqual('2.50', f.readline().strip())
- self.assertEqual('3.50', f.readline().strip())
-
- def test_SaveRawDataNoUnit(self):
- """SaveRawData appends no unit suffix if the unit is not specified."""
- self._populate_mock_stats_no_unit()
- self.data.CalculateStats()
- outdir = 'unittest_raw_data'
- files = self.data.SaveRawData(self.tempdir, outdir)
- files = [os.path.basename(f) for f in files]
- # Verify nothing gets appended to domain for filename if no unit exists.
- self.assertIn('B.txt', files)
-
- def test_SaveRawDataSMID(self):
- """SaveRawData uses the smid when creating output filename."""
- identifier = 'ec'
- self.data = stats_manager.StatsManager(smid=identifier)
- self._populate_mock_stats()
- files = self.data.SaveRawData(self.tempdir)
- for fname in files:
- self.assertTrue(os.path.basename(fname).startswith(identifier))
-
- def test_SummaryToStringNaNHelp(self):
- """NaN containing row gets tagged with *, help banner gets added."""
- help_banner_exp = '%s %s' % (stats_manager.STATS_PREFIX,
- stats_manager.NAN_DESCRIPTION)
- nan_domain = 'A-domain'
- nan_domain_exp = '%s%s' % (nan_domain, stats_manager.NAN_TAG)
- # NaN helper banner is added when a NaN domain is found & domain gets tagged
- data = stats_manager.StatsManager()
- data.AddSample(nan_domain, float('NaN'))
- data.AddSample(nan_domain, 17)
- data.AddSample('B-domain', 17)
- data.CalculateStats()
- summarystr = data.SummaryToString()
- self.assertIn(help_banner_exp, summarystr)
- self.assertIn(nan_domain_exp, summarystr)
- # NaN helper banner is not added when no NaN domain output, no tagging
- data = stats_manager.StatsManager()
- # nan_domain in this scenario does not contain any NaN
- data.AddSample(nan_domain, 19)
- data.AddSample('B-domain', 17)
- data.CalculateStats()
- summarystr = data.SummaryToString()
- self.assertNotIn(help_banner_exp, summarystr)
- self.assertNotIn(nan_domain_exp, summarystr)
-
- def test_SummaryToStringTitle(self):
- """Title shows up in SummaryToString if title specified."""
- title = 'titulo'
- data = stats_manager.StatsManager(title=title)
- self._populate_mock_stats()
- summary_str = data.SummaryToString()
- self.assertIn(title, summary_str)
-
- def test_SummaryToStringHideDomains(self):
- """Keys indicated in hide_domains are not printed in the summary."""
- data = stats_manager.StatsManager(hide_domains=['A-domain'])
- data.AddSample('A-domain', 17)
- data.AddSample('B-domain', 17)
- data.CalculateStats()
- summary_str = data.SummaryToString()
- self.assertIn('B-domain', summary_str)
- self.assertNotIn('A-domain', summary_str)
-
- def test_SummaryToStringOrder(self):
- """Order passed into StatsManager is honoured when formatting summary."""
- # StatsManager that should print D & B first, and the subsequent elements
- # are sorted.
- d_b_a_c_regexp = re.compile('D-domain.*B-domain.*A-domain.*C-domain',
- re.DOTALL)
- data = stats_manager.StatsManager(order=['D-domain', 'B-domain'])
- data.AddSample('A-domain', 17)
- data.AddSample('B-domain', 17)
- data.AddSample('C-domain', 17)
- data.AddSample('D-domain', 17)
- data.CalculateStats()
- summary_str = data.SummaryToString()
- self.assertRegexpMatches(summary_str, d_b_a_c_regexp)
-
- def test_MakeUniqueFName(self):
- data = stats_manager.StatsManager()
- testfile = os.path.join(self.tempdir, 'testfile.txt')
- with open(testfile, 'w') as f:
- f.write('')
- expected_fname = os.path.join(self.tempdir, 'testfile0.txt')
- self.assertEqual(expected_fname, data._MakeUniqueFName(testfile))
-
- def test_SaveSummary(self):
- """SaveSummary properly dumps the summary into a file."""
- self._populate_mock_stats()
- fname = 'unittest_summary.txt'
- expected_fname = os.path.join(self.tempdir, fname)
- fname = self.data.SaveSummary(self.tempdir, fname)
- # Assert the reported fname is the same as the expected fname
- self.assertEqual(expected_fname, fname)
- # Assert only the reported fname is output (in the tempdir)
- self.assertEqual(set([os.path.basename(fname)]),
- set(os.listdir(self.tempdir)))
- with open(fname, 'r') as f:
- self.assertEqual(
- '@@ NAME COUNT MEAN STDDEV MAX MIN\n',
- f.readline())
- self.assertEqual(
- '@@ A_mW 2 100000.00 0.50 100000.50 99999.50\n',
- f.readline())
- self.assertEqual(
- '@@ B_mV 3 2.50 0.82 3.50 1.50\n',
- f.readline())
-
- def test_SaveSummarySMID(self):
- """SaveSummary uses the smid when creating output filename."""
- identifier = 'ec'
- self.data = stats_manager.StatsManager(smid=identifier)
- self._populate_mock_stats()
- fname = os.path.basename(self.data.SaveSummary(self.tempdir))
- self.assertTrue(fname.startswith(identifier))
-
- def test_SaveSummaryJSON(self):
- """SaveSummaryJSON saves the added data properly in JSON format."""
- self._populate_mock_stats()
- fname = 'unittest_summary.json'
- expected_fname = os.path.join(self.tempdir, fname)
- fname = self.data.SaveSummaryJSON(self.tempdir, fname)
- # Assert the reported fname is the same as the expected fname
- self.assertEqual(expected_fname, fname)
- # Assert only the reported fname is output (in the tempdir)
- self.assertEqual(set([os.path.basename(fname)]),
- set(os.listdir(self.tempdir)))
- with open(fname, 'r') as f:
- summary = json.load(f)
- self.assertAlmostEqual(100000.0, summary['A']['mean'])
- self.assertEqual('milliwatt', summary['A']['unit'])
- self.assertAlmostEqual(2.5, summary['B']['mean'])
- self.assertEqual('millivolt', summary['B']['unit'])
-
- def test_SaveSummaryJSONSMID(self):
- """SaveSummaryJSON uses the smid when creating output filename."""
- identifier = 'ec'
- self.data = stats_manager.StatsManager(smid=identifier)
- self._populate_mock_stats()
- fname = os.path.basename(self.data.SaveSummaryJSON(self.tempdir))
- self.assertTrue(fname.startswith(identifier))
-
- def test_SaveSummaryJSONNoUnit(self):
- """SaveSummaryJSON marks unknown units properly as N/A."""
- self._populate_mock_stats_no_unit()
- self.data.CalculateStats()
- fname = 'unittest_summary.json'
- fname = self.data.SaveSummaryJSON(self.tempdir, fname)
- with open(fname, 'r') as f:
- summary = json.load(f)
- self.assertEqual('blue', summary['A']['unit'])
- # if no unit is specified, JSON should save 'N/A' as the unit.
- self.assertEqual('N/A', summary['B']['unit'])
-
-if __name__ == '__main__':
- unittest.main()
+ """Test to verify StatsManager methods work as expected.
+
+ StatsManager should collect raw data, calculate their statistics, and save
+ them in expected format.
+ """
+
+ def _populate_mock_stats(self):
+ """Create a populated & processed StatsManager to test data retrieval."""
+ self.data.AddSample("A", 99999.5)
+ self.data.AddSample("A", 100000.5)
+ self.data.SetUnit("A", "uW")
+ self.data.SetUnit("A", "mW")
+ self.data.AddSample("B", 1.5)
+ self.data.AddSample("B", 2.5)
+ self.data.AddSample("B", 3.5)
+ self.data.SetUnit("B", "mV")
+ self.data.CalculateStats()
+
+ def _populate_mock_stats_no_unit(self):
+ self.data.AddSample("B", 1000)
+ self.data.AddSample("A", 200)
+ self.data.SetUnit("A", "blue")
+
+ def setUp(self):
+ """Set up StatsManager and create a temporary directory for test."""
+ self.tempdir = tempfile.mkdtemp()
+ self.data = stats_manager.StatsManager()
+
+ def tearDown(self):
+ """Delete the temporary directory and its content."""
+ shutil.rmtree(self.tempdir)
+
+ def test_AddSample(self):
+ """Adding a sample successfully adds a sample."""
+ self.data.AddSample("Test", 1000)
+ self.data.SetUnit("Test", "test")
+ self.data.CalculateStats()
+ summary = self.data.GetSummary()
+ self.assertEqual(1, summary["Test"]["count"])
+
+ def test_AddSampleNoFloatAcceptNaN(self):
+ """Adding a non-number adds 'NaN' and doesn't raise an exception."""
+ self.data.AddSample("Test", 10)
+ self.data.AddSample("Test", 20)
+ # adding a fake NaN: one that gets converted into NaN internally
+ self.data.AddSample("Test", "fiesta")
+ # adding a real NaN
+ self.data.AddSample("Test", float("NaN"))
+ self.data.SetUnit("Test", "test")
+ self.data.CalculateStats()
+ summary = self.data.GetSummary()
+ # assert that 'NaN' was added.
+ self.assertEqual(4, summary["Test"]["count"])
+ # assert that mean, min, and max calculations ignore the 'NaN'
+ self.assertEqual(10, summary["Test"]["min"])
+ self.assertEqual(20, summary["Test"]["max"])
+ self.assertEqual(15, summary["Test"]["mean"])
+
+ def test_AddSampleNoFloatNotAcceptNaN(self):
+ """Adding a non-number raises a StatsManagerError if accept_nan is False."""
+ self.data = stats_manager.StatsManager(accept_nan=False)
+ with self.assertRaisesRegexp(
+ stats_manager.StatsManagerError,
+ "accept_nan is false. Cannot add NaN sample.",
+ ):
+ # adding a fake NaN: one that gets converted into NaN internally
+ self.data.AddSample("Test", "fiesta")
+ with self.assertRaisesRegexp(
+ stats_manager.StatsManagerError,
+ "accept_nan is false. Cannot add NaN sample.",
+ ):
+ # adding a real NaN
+ self.data.AddSample("Test", float("NaN"))
+
+ def test_AddSampleNoUnit(self):
+ """Not adding a unit does not cause an exception on CalculateStats()."""
+ self.data.AddSample("Test", 17)
+ self.data.CalculateStats()
+ summary = self.data.GetSummary()
+ self.assertEqual(1, summary["Test"]["count"])
+
+ def test_UnitSuffix(self):
+ """Unit gets appended as a suffix in the displayed summary."""
+ self.data.AddSample("test", 250)
+ self.data.SetUnit("test", "mw")
+ self.data.CalculateStats()
+ summary_str = self.data.SummaryToString()
+ self.assertIn("test_mw", summary_str)
+
+ def test_DoubleUnitSuffix(self):
+ """If domain already ends in unit, verify that unit doesn't get appended."""
+ self.data.AddSample("test_mw", 250)
+ self.data.SetUnit("test_mw", "mw")
+ self.data.CalculateStats()
+ summary_str = self.data.SummaryToString()
+ self.assertIn("test_mw", summary_str)
+ self.assertNotIn("test_mw_mw", summary_str)
+
+ def test_GetRawData(self):
+ """GetRawData returns exact same data as fed in."""
+ self._populate_mock_stats()
+ raw_data = self.data.GetRawData()
+ self.assertListEqual([99999.5, 100000.5], raw_data["A"])
+ self.assertListEqual([1.5, 2.5, 3.5], raw_data["B"])
+
+ def test_GetSummary(self):
+ """GetSummary returns expected stats about the data fed in."""
+ self._populate_mock_stats()
+ summary = self.data.GetSummary()
+ self.assertEqual(2, summary["A"]["count"])
+ self.assertAlmostEqual(100000.5, summary["A"]["max"])
+ self.assertAlmostEqual(99999.5, summary["A"]["min"])
+ self.assertAlmostEqual(0.5, summary["A"]["stddev"])
+ self.assertAlmostEqual(100000.0, summary["A"]["mean"])
+ self.assertEqual(3, summary["B"]["count"])
+ self.assertAlmostEqual(3.5, summary["B"]["max"])
+ self.assertAlmostEqual(1.5, summary["B"]["min"])
+ self.assertAlmostEqual(0.81649658092773, summary["B"]["stddev"])
+ self.assertAlmostEqual(2.5, summary["B"]["mean"])
+
+ def test_SaveRawData(self):
+ """SaveRawData stores same data as fed in."""
+ self._populate_mock_stats()
+ dirname = "unittest_raw_data"
+ expected_files = set(["A_mW.txt", "B_mV.txt"])
+ fnames = self.data.SaveRawData(self.tempdir, dirname)
+ files_returned = set([os.path.basename(f) for f in fnames])
+ # Assert that only the expected files got returned.
+ self.assertEqual(expected_files, files_returned)
+ # Assert that only the returned files are in the outdir.
+ self.assertEqual(
+ set(os.listdir(os.path.join(self.tempdir, dirname))), files_returned
+ )
+ for fname in fnames:
+ with open(fname, "r") as f:
+ if "A_mW" in fname:
+ self.assertEqual("99999.50", f.readline().strip())
+ self.assertEqual("100000.50", f.readline().strip())
+ if "B_mV" in fname:
+ self.assertEqual("1.50", f.readline().strip())
+ self.assertEqual("2.50", f.readline().strip())
+ self.assertEqual("3.50", f.readline().strip())
+
+ def test_SaveRawDataNoUnit(self):
+ """SaveRawData appends no unit suffix if the unit is not specified."""
+ self._populate_mock_stats_no_unit()
+ self.data.CalculateStats()
+ outdir = "unittest_raw_data"
+ files = self.data.SaveRawData(self.tempdir, outdir)
+ files = [os.path.basename(f) for f in files]
+ # Verify nothing gets appended to domain for filename if no unit exists.
+ self.assertIn("B.txt", files)
+
+ def test_SaveRawDataSMID(self):
+ """SaveRawData uses the smid when creating output filename."""
+ identifier = "ec"
+ self.data = stats_manager.StatsManager(smid=identifier)
+ self._populate_mock_stats()
+ files = self.data.SaveRawData(self.tempdir)
+ for fname in files:
+ self.assertTrue(os.path.basename(fname).startswith(identifier))
+
+ def test_SummaryToStringNaNHelp(self):
+ """NaN containing row gets tagged with *, help banner gets added."""
+ help_banner_exp = "%s %s" % (
+ stats_manager.STATS_PREFIX,
+ stats_manager.NAN_DESCRIPTION,
+ )
+ nan_domain = "A-domain"
+ nan_domain_exp = "%s%s" % (nan_domain, stats_manager.NAN_TAG)
+ # NaN helper banner is added when a NaN domain is found & domain gets tagged
+ data = stats_manager.StatsManager()
+ data.AddSample(nan_domain, float("NaN"))
+ data.AddSample(nan_domain, 17)
+ data.AddSample("B-domain", 17)
+ data.CalculateStats()
+ summarystr = data.SummaryToString()
+ self.assertIn(help_banner_exp, summarystr)
+ self.assertIn(nan_domain_exp, summarystr)
+ # NaN helper banner is not added when no NaN domain output, no tagging
+ data = stats_manager.StatsManager()
+ # nan_domain in this scenario does not contain any NaN
+ data.AddSample(nan_domain, 19)
+ data.AddSample("B-domain", 17)
+ data.CalculateStats()
+ summarystr = data.SummaryToString()
+ self.assertNotIn(help_banner_exp, summarystr)
+ self.assertNotIn(nan_domain_exp, summarystr)
+
+ def test_SummaryToStringTitle(self):
+ """Title shows up in SummaryToString if title specified."""
+ title = "titulo"
+ data = stats_manager.StatsManager(title=title)
+ self._populate_mock_stats()
+ summary_str = data.SummaryToString()
+ self.assertIn(title, summary_str)
+
+ def test_SummaryToStringHideDomains(self):
+ """Keys indicated in hide_domains are not printed in the summary."""
+ data = stats_manager.StatsManager(hide_domains=["A-domain"])
+ data.AddSample("A-domain", 17)
+ data.AddSample("B-domain", 17)
+ data.CalculateStats()
+ summary_str = data.SummaryToString()
+ self.assertIn("B-domain", summary_str)
+ self.assertNotIn("A-domain", summary_str)
+
+ def test_SummaryToStringOrder(self):
+ """Order passed into StatsManager is honoured when formatting summary."""
+ # StatsManager that should print D & B first, and the subsequent elements
+ # are sorted.
+ d_b_a_c_regexp = re.compile("D-domain.*B-domain.*A-domain.*C-domain", re.DOTALL)
+ data = stats_manager.StatsManager(order=["D-domain", "B-domain"])
+ data.AddSample("A-domain", 17)
+ data.AddSample("B-domain", 17)
+ data.AddSample("C-domain", 17)
+ data.AddSample("D-domain", 17)
+ data.CalculateStats()
+ summary_str = data.SummaryToString()
+ self.assertRegexpMatches(summary_str, d_b_a_c_regexp)
+
+ def test_MakeUniqueFName(self):
+ data = stats_manager.StatsManager()
+ testfile = os.path.join(self.tempdir, "testfile.txt")
+ with open(testfile, "w") as f:
+ f.write("")
+ expected_fname = os.path.join(self.tempdir, "testfile0.txt")
+ self.assertEqual(expected_fname, data._MakeUniqueFName(testfile))
+
+ def test_SaveSummary(self):
+ """SaveSummary properly dumps the summary into a file."""
+ self._populate_mock_stats()
+ fname = "unittest_summary.txt"
+ expected_fname = os.path.join(self.tempdir, fname)
+ fname = self.data.SaveSummary(self.tempdir, fname)
+ # Assert the reported fname is the same as the expected fname
+ self.assertEqual(expected_fname, fname)
+ # Assert only the reported fname is output (in the tempdir)
+ self.assertEqual(set([os.path.basename(fname)]), set(os.listdir(self.tempdir)))
+ with open(fname, "r") as f:
+ self.assertEqual(
+ "@@ NAME COUNT MEAN STDDEV MAX MIN\n",
+ f.readline(),
+ )
+ self.assertEqual(
+ "@@ A_mW 2 100000.00 0.50 100000.50 99999.50\n",
+ f.readline(),
+ )
+ self.assertEqual(
+ "@@ B_mV 3 2.50 0.82 3.50 1.50\n",
+ f.readline(),
+ )
+
+ def test_SaveSummarySMID(self):
+ """SaveSummary uses the smid when creating output filename."""
+ identifier = "ec"
+ self.data = stats_manager.StatsManager(smid=identifier)
+ self._populate_mock_stats()
+ fname = os.path.basename(self.data.SaveSummary(self.tempdir))
+ self.assertTrue(fname.startswith(identifier))
+
+ def test_SaveSummaryJSON(self):
+ """SaveSummaryJSON saves the added data properly in JSON format."""
+ self._populate_mock_stats()
+ fname = "unittest_summary.json"
+ expected_fname = os.path.join(self.tempdir, fname)
+ fname = self.data.SaveSummaryJSON(self.tempdir, fname)
+ # Assert the reported fname is the same as the expected fname
+ self.assertEqual(expected_fname, fname)
+ # Assert only the reported fname is output (in the tempdir)
+ self.assertEqual(set([os.path.basename(fname)]), set(os.listdir(self.tempdir)))
+ with open(fname, "r") as f:
+ summary = json.load(f)
+ self.assertAlmostEqual(100000.0, summary["A"]["mean"])
+ self.assertEqual("milliwatt", summary["A"]["unit"])
+ self.assertAlmostEqual(2.5, summary["B"]["mean"])
+ self.assertEqual("millivolt", summary["B"]["unit"])
+
+ def test_SaveSummaryJSONSMID(self):
+ """SaveSummaryJSON uses the smid when creating output filename."""
+ identifier = "ec"
+ self.data = stats_manager.StatsManager(smid=identifier)
+ self._populate_mock_stats()
+ fname = os.path.basename(self.data.SaveSummaryJSON(self.tempdir))
+ self.assertTrue(fname.startswith(identifier))
+
+ def test_SaveSummaryJSONNoUnit(self):
+ """SaveSummaryJSON marks unknown units properly as N/A."""
+ self._populate_mock_stats_no_unit()
+ self.data.CalculateStats()
+ fname = "unittest_summary.json"
+ fname = self.data.SaveSummaryJSON(self.tempdir, fname)
+ with open(fname, "r") as f:
+ summary = json.load(f)
+ self.assertEqual("blue", summary["A"]["unit"])
+ # if no unit is specified, JSON should save 'N/A' as the unit.
+ self.assertEqual("N/A", summary["B"]["unit"])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/extra/usb_serial/console.py b/extra/usb_serial/console.py
index 7211dceff6..c08cb72092 100755
--- a/extra/usb_serial/console.py
+++ b/extra/usb_serial/console.py
@@ -12,6 +12,7 @@
# Note: This is a py2/3 compatible file.
from __future__ import print_function
+
import argparse
import array
import os
@@ -20,23 +21,24 @@ import termios
import threading
import time
import tty
+
try:
- import usb
+ import usb
except ModuleNotFoundError:
- print('import usb failed')
- print('try running these commands:')
- print(' sudo apt-get install python-pip')
- print(' sudo pip install --pre pyusb')
- print()
- sys.exit(-1)
+ print("import usb failed")
+ print("try running these commands:")
+ print(" sudo apt-get install python-pip")
+ print(" sudo pip install --pre pyusb")
+ print()
+ sys.exit(-1)
import six
def GetBuffer(stream):
- if six.PY3:
- return stream.buffer
- return stream
+ if six.PY3:
+ return stream.buffer
+ return stream
"""Class Susb covers USB device discovery and initialization.
@@ -45,99 +47,104 @@ def GetBuffer(stream):
and interface number.
"""
+
class SusbError(Exception):
- """Class for exceptions of Susb."""
- def __init__(self, msg, value=0):
- """SusbError constructor.
+ """Class for exceptions of Susb."""
- Args:
- msg: string, message describing error in detail
- value: integer, value of error when non-zero status returned. Default=0
- """
- super(SusbError, self).__init__(msg, value)
- self.msg = msg
- self.value = value
-
-class Susb():
- """Provide USB functionality.
-
- Instance Variables:
- _read_ep: pyUSB read endpoint for this interface
- _write_ep: pyUSB write endpoint for this interface
- """
- READ_ENDPOINT = 0x81
- WRITE_ENDPOINT = 0x1
- TIMEOUT_MS = 100
-
- def __init__(self, vendor=0x18d1,
- product=0x500f, interface=1, serialname=None):
- """Susb constructor.
-
- Discovers and connects to USB endpoints.
-
- Args:
- vendor: usb vendor id of device
- product: usb product id of device
- interface: interface number ( 1 - 8 ) of device to use
- serialname: string of device serialnumber.
-
- Raises:
- SusbError: An error accessing Susb object
+ def __init__(self, msg, value=0):
+ """SusbError constructor.
+
+ Args:
+ msg: string, message describing error in detail
+ value: integer, value of error when non-zero status returned. Default=0
+ """
+ super(SusbError, self).__init__(msg, value)
+ self.msg = msg
+ self.value = value
+
+
+class Susb:
+ """Provide USB functionality.
+
+ Instance Variables:
+ _read_ep: pyUSB read endpoint for this interface
+ _write_ep: pyUSB write endpoint for this interface
"""
- # Find the device.
- dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True)
- dev_list = list(dev_g)
- if dev_list is None:
- raise SusbError('USB device not found')
-
- # Check if we have multiple devices.
- dev = None
- if serialname:
- for d in dev_list:
- dev_serial = "PyUSB doesn't have a stable interface"
- try:
- dev_serial = usb.util.get_string(d, 256, d.iSerialNumber)
- except:
- dev_serial = usb.util.get_string(d, d.iSerialNumber)
- if dev_serial == serialname:
- dev = d
- break
- if dev is None:
- raise SusbError('USB device(%s) not found' % (serialname,))
- else:
- try:
- dev = dev_list[0]
- except:
- try:
- dev = dev_list.next()
- except:
- raise SusbError('USB device %04x:%04x not found' % (vendor, product))
- # If we can't set configuration, it's already been set.
- try:
- dev.set_configuration()
- except usb.core.USBError:
- pass
+ READ_ENDPOINT = 0x81
+ WRITE_ENDPOINT = 0x1
+ TIMEOUT_MS = 100
+
+ def __init__(self, vendor=0x18D1, product=0x500F, interface=1, serialname=None):
+ """Susb constructor.
+
+ Discovers and connects to USB endpoints.
+
+ Args:
+ vendor: usb vendor id of device
+ product: usb product id of device
+ interface: interface number ( 1 - 8 ) of device to use
+ serialname: string of device serialnumber.
+
+ Raises:
+ SusbError: An error accessing Susb object
+ """
+ # Find the device.
+ dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True)
+ dev_list = list(dev_g)
+ if dev_list is None:
+ raise SusbError("USB device not found")
+
+ # Check if we have multiple devices.
+ dev = None
+ if serialname:
+ for d in dev_list:
+ dev_serial = "PyUSB doesn't have a stable interface"
+ try:
+ dev_serial = usb.util.get_string(d, 256, d.iSerialNumber)
+ except:
+ dev_serial = usb.util.get_string(d, d.iSerialNumber)
+ if dev_serial == serialname:
+ dev = d
+ break
+ if dev is None:
+ raise SusbError("USB device(%s) not found" % (serialname,))
+ else:
+ try:
+ dev = dev_list[0]
+ except:
+ try:
+ dev = dev_list.next()
+ except:
+ raise SusbError(
+ "USB device %04x:%04x not found" % (vendor, product)
+ )
+
+ # If we can't set configuration, it's already been set.
+ try:
+ dev.set_configuration()
+ except usb.core.USBError:
+ pass
- # Get an endpoint instance.
- cfg = dev.get_active_configuration()
- intf = usb.util.find_descriptor(cfg, bInterfaceNumber=interface)
- self._intf = intf
+ # Get an endpoint instance.
+ cfg = dev.get_active_configuration()
+ intf = usb.util.find_descriptor(cfg, bInterfaceNumber=interface)
+ self._intf = intf
- if not intf:
- raise SusbError('Interface not found')
+ if not intf:
+ raise SusbError("Interface not found")
- # Detach raiden.ko if it is loaded.
- if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True:
+ # Detach raiden.ko if it is loaded.
+ if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True:
dev.detach_kernel_driver(intf.bInterfaceNumber)
- read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT
- read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number)
- self._read_ep = read_ep
+ read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT
+ read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number)
+ self._read_ep = read_ep
- write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT
- write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number)
- self._write_ep = write_ep
+ write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT
+ write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number)
+ self._write_ep = write_ep
"""Suart class implements a stream interface, to access Google's USB class.
@@ -146,89 +153,91 @@ class Susb():
and forwards them across. This particular class is hardcoded to stdin/out.
"""
-class SuartError(Exception):
- """Class for exceptions of Suart."""
- def __init__(self, msg, value=0):
- """SuartError constructor.
-
- Args:
- msg: string, message describing error in detail
- value: integer, value of error when non-zero status returned. Default=0
- """
- super(SuartError, self).__init__(msg, value)
- self.msg = msg
- self.value = value
-
-
-class Suart():
- """Provide interface to serial usb endpoint."""
- def __init__(self, vendor=0x18d1, product=0x501c, interface=0,
- serialname=None):
- """Suart contstructor.
+class SuartError(Exception):
+ """Class for exceptions of Suart."""
- Initializes USB stream interface.
+ def __init__(self, msg, value=0):
+ """SuartError constructor.
- Args:
- vendor: usb vendor id of device
- product: usb product id of device
- interface: interface number of device to use
- serialname: Defaults to None.
+ Args:
+ msg: string, message describing error in detail
+ value: integer, value of error when non-zero status returned. Default=0
+ """
+ super(SuartError, self).__init__(msg, value)
+ self.msg = msg
+ self.value = value
- Raises:
- SuartError: If init fails
- """
- self._done = threading.Event()
- self._susb = Susb(vendor=vendor, product=product,
- interface=interface, serialname=serialname)
- def wait_until_done(self, timeout=None):
- return self._done.wait(timeout=timeout)
+class Suart:
+ """Provide interface to serial usb endpoint."""
- def run_rx_thread(self):
- try:
- while True:
- try:
- r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS)
- if r:
- GetBuffer(sys.stdout).write(r.tobytes())
- GetBuffer(sys.stdout).flush()
-
- except Exception as e:
- # If we miss some characters on pty disconnect, that's fine.
- # ep.read() also throws USBError on timeout, which we discard.
- if not isinstance(e, (OSError, usb.core.USBError)):
- print('rx %s' % e)
- finally:
- self._done.set()
+ def __init__(self, vendor=0x18D1, product=0x501C, interface=0, serialname=None):
+ """Suart contstructor.
- def run_tx_thread(self):
- try:
- while True:
- try:
- r = GetBuffer(sys.stdin).read(1)
- if not r or r == b'\x03':
- break
- if r:
- self._susb._write_ep.write(array.array('B', r),
- self._susb.TIMEOUT_MS)
- except Exception as e:
- print('tx %s' % e)
- finally:
- self._done.set()
+ Initializes USB stream interface.
- def run(self):
- """Creates pthreads to poll USB & PTY for data."""
- self._exit = False
+ Args:
+ vendor: usb vendor id of device
+ product: usb product id of device
+ interface: interface number of device to use
+ serialname: Defaults to None.
- self._rx_thread = threading.Thread(target=self.run_rx_thread)
- self._rx_thread.daemon = True
- self._rx_thread.start()
+ Raises:
+ SuartError: If init fails
+ """
+ self._done = threading.Event()
+ self._susb = Susb(
+ vendor=vendor, product=product, interface=interface, serialname=serialname
+ )
- self._tx_thread = threading.Thread(target=self.run_tx_thread)
- self._tx_thread.daemon = True
- self._tx_thread.start()
+ def wait_until_done(self, timeout=None):
+ return self._done.wait(timeout=timeout)
+ def run_rx_thread(self):
+ try:
+ while True:
+ try:
+ r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS)
+ if r:
+ GetBuffer(sys.stdout).write(r.tobytes())
+ GetBuffer(sys.stdout).flush()
+
+ except Exception as e:
+ # If we miss some characters on pty disconnect, that's fine.
+ # ep.read() also throws USBError on timeout, which we discard.
+ if not isinstance(e, (OSError, usb.core.USBError)):
+ print("rx %s" % e)
+ finally:
+ self._done.set()
+
+ def run_tx_thread(self):
+ try:
+ while True:
+ try:
+ r = GetBuffer(sys.stdin).read(1)
+ if not r or r == b"\x03":
+ break
+ if r:
+ self._susb._write_ep.write(
+ array.array("B", r), self._susb.TIMEOUT_MS
+ )
+ except Exception as e:
+ print("tx %s" % e)
+ finally:
+ self._done.set()
+
+ def run(self):
+ """Creates pthreads to poll USB & PTY for data."""
+ self._exit = False
+
+ self._rx_thread = threading.Thread(target=self.run_rx_thread)
+ self._rx_thread.daemon = True
+ self._rx_thread.start()
+
+ self._tx_thread = threading.Thread(target=self.run_tx_thread)
+ self._tx_thread.daemon = True
+ self._tx_thread.start()
"""Command line functionality
@@ -237,59 +246,67 @@ class Suart():
Ctrl-C exits.
"""
-parser = argparse.ArgumentParser(description='Open a console to a USB device')
-parser.add_argument('-d', '--device', type=str,
- help='vid:pid of target device', default='18d1:501c')
-parser.add_argument('-i', '--interface', type=int,
- help='interface number of console', default=0)
-parser.add_argument('-s', '--serialno', type=str,
- help='serial number of device', default='')
-parser.add_argument('-S', '--notty-exit-sleep', type=float, default=0.2,
- help='When stdin is *not* a TTY, wait this many seconds '
- 'after EOF from stdin before exiting, to give time for '
- 'receiving a reply from the USB device.')
+parser = argparse.ArgumentParser(description="Open a console to a USB device")
+parser.add_argument(
+ "-d", "--device", type=str, help="vid:pid of target device", default="18d1:501c"
+)
+parser.add_argument(
+ "-i", "--interface", type=int, help="interface number of console", default=0
+)
+parser.add_argument(
+ "-s", "--serialno", type=str, help="serial number of device", default=""
+)
+parser.add_argument(
+ "-S",
+ "--notty-exit-sleep",
+ type=float,
+ default=0.2,
+ help="When stdin is *not* a TTY, wait this many seconds "
+ "after EOF from stdin before exiting, to give time for "
+ "receiving a reply from the USB device.",
+)
+
def runconsole():
- """Run the usb console code
+ """Run the usb console code
- Starts the pty thread, and idles until a ^C is caught.
- """
- args = parser.parse_args()
+ Starts the pty thread, and idles until a ^C is caught.
+ """
+ args = parser.parse_args()
- vidstr, pidstr = args.device.split(':')
- vid = int(vidstr, 16)
- pid = int(pidstr, 16)
+ vidstr, pidstr = args.device.split(":")
+ vid = int(vidstr, 16)
+ pid = int(pidstr, 16)
- serialno = args.serialno
- interface = args.interface
+ serialno = args.serialno
+ interface = args.interface
- sobj = Suart(vendor=vid, product=pid, interface=interface,
- serialname=serialno)
- if sys.stdin.isatty():
- tty.setraw(sys.stdin.fileno())
- sobj.run()
- sobj.wait_until_done()
- if not sys.stdin.isatty() and args.notty_exit_sleep > 0:
- time.sleep(args.notty_exit_sleep)
+ sobj = Suart(vendor=vid, product=pid, interface=interface, serialname=serialno)
+ if sys.stdin.isatty():
+ tty.setraw(sys.stdin.fileno())
+ sobj.run()
+ sobj.wait_until_done()
+ if not sys.stdin.isatty() and args.notty_exit_sleep > 0:
+ time.sleep(args.notty_exit_sleep)
def main():
- stdin_isatty = sys.stdin.isatty()
- if stdin_isatty:
- fd = sys.stdin.fileno()
- os.system('stty -echo')
- old_settings = termios.tcgetattr(fd)
-
- try:
- runconsole()
- finally:
+ stdin_isatty = sys.stdin.isatty()
if stdin_isatty:
- termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
- os.system('stty echo')
- # Avoid having the user's shell prompt start mid-line after the final output
- # from this program.
- print()
+ fd = sys.stdin.fileno()
+ os.system("stty -echo")
+ old_settings = termios.tcgetattr(fd)
+
+ try:
+ runconsole()
+ finally:
+ if stdin_isatty:
+ termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
+ os.system("stty echo")
+ # Avoid having the user's shell prompt start mid-line after the final output
+ # from this program.
+ print()
-if __name__ == '__main__':
- main()
+if __name__ == "__main__":
+ main()
diff --git a/extra/usb_updater/fw_update.py b/extra/usb_updater/fw_update.py
index 0d7a570fc3..f05797bfb6 100755
--- a/extra/usb_updater/fw_update.py
+++ b/extra/usb_updater/fw_update.py
@@ -20,407 +20,425 @@ import struct
import sys
import time
from pprint import pprint
-import usb
+import usb
debug = False
-def debuglog(msg):
- if debug:
- print(msg)
-
-def log(msg):
- print(msg)
- sys.stdout.flush()
-
-
-"""Sends firmware update to CROS EC usb endpoint."""
-
-class Supdate(object):
- """Class to access firmware update endpoints.
-
- Usage:
- d = Supdate()
- Instance Variables:
- _dev: pyUSB device object
- _read_ep: pyUSB read endpoint for this interface
- _write_ep: pyUSB write endpoint for this interface
- """
- USB_SUBCLASS_GOOGLE_UPDATE = 0x53
- USB_CLASS_VENDOR = 0xFF
- def __init__(self):
- pass
-
-
- def connect_usb(self, serialname=None):
- """Initial discovery and connection to USB endpoint.
-
- This searches for a USB device matching the VID:PID specified
- in the config file, optionally matching a specified serialname.
-
- Args:
- serialname: Find the device with this serial, in case multiple
- devices are attached.
-
- Returns:
- True on success.
- Raises:
- Exception on error.
- """
- # Find the stm32.
- vendor = self._brdcfg['vid']
- product = self._brdcfg['pid']
-
- dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True)
- dev_list = list(dev_g)
- if dev_list is None:
- raise Exception("Update", "USB device not found")
-
- # Check if we have multiple stm32s and we've specified the serial.
- dev = None
- if serialname:
- for d in dev_list:
- if usb.util.get_string(d, d.iSerialNumber) == serialname:
- dev = d
- break
- if dev is None:
- raise SusbError("USB device(%s) not found" % serialname)
- else:
- try:
- dev = dev_list[0]
- except:
- dev = dev_list.next()
-
- debuglog("Found stm32: %04x:%04x" % (vendor, product))
- self._dev = dev
-
- # Get an endpoint instance.
- try:
- dev.set_configuration()
- except:
- pass
- cfg = dev.get_active_configuration()
-
- intf = usb.util.find_descriptor(cfg, custom_match=lambda i: \
- i.bInterfaceClass==self.USB_CLASS_VENDOR and \
- i.bInterfaceSubClass==self.USB_SUBCLASS_GOOGLE_UPDATE)
-
- self._intf = intf
- debuglog("Interface: %s" % intf)
- debuglog("InterfaceNumber: %s" % intf.bInterfaceNumber)
-
- read_ep = usb.util.find_descriptor(
- intf,
- # match the first IN endpoint
- custom_match = \
- lambda e: \
- usb.util.endpoint_direction(e.bEndpointAddress) == \
- usb.util.ENDPOINT_IN
- )
-
- self._read_ep = read_ep
- debuglog("Reader endpoint: 0x%x" % read_ep.bEndpointAddress)
-
- write_ep = usb.util.find_descriptor(
- intf,
- # match the first OUT endpoint
- custom_match = \
- lambda e: \
- usb.util.endpoint_direction(e.bEndpointAddress) == \
- usb.util.ENDPOINT_OUT
- )
-
- self._write_ep = write_ep
- debuglog("Writer endpoint: 0x%x" % write_ep.bEndpointAddress)
-
- return True
-
-
- def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=2000):
- """Write command to logger logic..
-
- This function writes byte command values list to stm, then reads
- byte status.
-
- Args:
- write_list: list of command byte values [0~255].
- read_count: number of status byte values to read.
- wtimeout: mS to wait for write success
- rtimeout: mS to wait for read success
-
- Returns:
- status byte, if one byte is read,
- byte list, if multiple bytes are read,
- None, if no bytes are read.
-
- Interface:
- write: [command, data ... ]
- read: [status ]
- """
- debuglog("wr_command(write_list=[%s] (%d), read_count=%s)" % (
- list(bytearray(write_list)), len(write_list), read_count))
-
- # Clean up args from python style to correct types.
- write_length = 0
- if write_list:
- write_length = len(write_list)
- if not read_count:
- read_count = 0
-
- # Send command to stm32.
- if write_list:
- cmd = write_list
- ret = self._write_ep.write(cmd, wtimeout)
- debuglog("RET: %s " % ret)
-
- # Read back response if necessary.
- if read_count:
- bytesread = self._read_ep.read(512, rtimeout)
- debuglog("BYTES: [%s]" % bytesread)
-
- if len(bytesread) != read_count:
- debuglog("Unexpected bytes read: %d, expected: %d" % (len(bytesread), read_count))
- pass
-
- debuglog("STATUS: 0x%02x" % int(bytesread[0]))
- if read_count == 1:
- return bytesread[0]
- else:
- return bytesread
-
- return None
+def debuglog(msg):
+ if debug:
+ print(msg)
- def stop(self):
- """Finalize system flash and exit."""
- cmd = struct.pack(">I", 0xB007AB1E)
- read = self.wr_command(cmd, read_count=4)
- if len(read) == 4:
- log("Finished flashing")
- return
+def log(msg):
+ print(msg)
+ sys.stdout.flush()
- raise Exception("Update", "Stop failed [%s]" % read)
+"""Sends firmware update to CROS EC usb endpoint."""
- def write_file(self):
- """Write the update region packet by packet to USB
- This sends write packets of size 128B out, in 32B chunks.
- Overall, this will write all data in the inactive code region.
+class Supdate(object):
+ """Class to access firmware update endpoints.
- Raises:
- Exception if write failed or address out of bounds.
- """
- region = self._region
- flash_base = self._brdcfg["flash"]
- offset = self._base - flash_base
- if offset != self._brdcfg['regions'][region][0]:
- raise Exception("Update", "Region %s offset 0x%x != available offset 0x%x" % (
- region, self._brdcfg['regions'][region][0], offset))
-
- length = self._brdcfg['regions'][region][1]
- log("Sending")
-
- # Go to the correct region in the ec.bin file.
- self._binfile.seek(offset)
-
- # Send 32 bytes at a time. Must be less than the endpoint's max packet size.
- maxpacket = 32
-
- # While data is left, create update packets.
- while length > 0:
- # Update packets are 128B. We can use any number
- # but the micro must malloc this memory.
- pagesize = min(length, 128)
-
- # Packet is:
- # packet size: page bytes transferred plus 3 x 32b values header.
- # cmd: n/a
- # base: flash address to write this packet.
- # data: 128B of data to write into flash_base
- cmd = struct.pack(">III", pagesize + 12, 0, offset + flash_base)
- read = self.wr_command(cmd, read_count=0)
-
- # Push 'todo' bytes out the pipe.
- todo = pagesize
- while todo > 0:
- packetsize = min(maxpacket, todo)
- data = self._binfile.read(packetsize)
- if len(data) != packetsize:
- raise Exception("Update", "No more data from file")
- for i in range(0, 10):
- try:
- self.wr_command(data, read_count=0)
- break
- except:
- log("Timeout fail")
- todo -= packetsize
- # Done with this packet, move to the next one.
- length -= pagesize
- offset += pagesize
-
- # Validate that the micro thinks it successfully wrote the data.
- read = self.wr_command(''.encode(), read_count=4)
- result = struct.unpack("<I", read)
- result = result[0]
- if result != 0:
- raise Exception("Update", "Upload failed with rc: 0x%x" % result)
-
-
- def start(self):
- """Start a transaction and erase currently inactive region.
-
- This function sends a start command, and receives the base of the
- preferred inactive region. This could be RW, RW_B,
- or RO (if there's no RW_B)
-
- Note that the region is erased here, so you'd better program the RO if
- you just erased it. TODO(nsanders): Modify the protocol to allow active
- region select or query before erase.
- """
+ Usage:
+ d = Supdate()
- # Size is 3 uint32 fields
- # packet: [packetsize, cmd, base]
- size = 4 + 4 + 4
- # Return value is [status, base_addr]
- expected = 4 + 4
-
- cmd = struct.pack("<III", size, 0, 0)
- read = self.wr_command(cmd, read_count=expected)
-
- if len(read) == 4:
- raise Exception("Update", "Protocol version 0 not supported")
- elif len(read) == expected:
- base, version = struct.unpack(">II", read)
- log("Update protocol v. %d" % version)
- log("Available flash region base: %x" % base)
- else:
- raise Exception("Update", "Start command returned %d bytes" % len(read))
-
- if base < 256:
- raise Exception("Update", "Start returned error code 0x%x" % base)
-
- self._base = base
- flash_base = self._brdcfg["flash"]
- self._offset = self._base - flash_base
-
- # Find our active region.
- for region in self._brdcfg['regions']:
- if (self._offset >= self._brdcfg['regions'][region][0]) and \
- (self._offset < (self._brdcfg['regions'][region][0] + \
- self._brdcfg['regions'][region][1])):
- log("Active region: %s" % region)
- self._region = region
-
-
- def load_board(self, brdfile):
- """Load firmware layout file.
-
- example as follows:
- {
- "board": "servo micro",
- "vid": 6353,
- "pid": 20506,
- "flash": 134217728,
- "regions": {
- "RW": [65536, 65536],
- "PSTATE": [61440, 4096],
- "RO": [0, 61440]
- }
- }
-
- Args:
- brdfile: path to board description file.
+ Instance Variables:
+ _dev: pyUSB device object
+ _read_ep: pyUSB read endpoint for this interface
+ _write_ep: pyUSB write endpoint for this interface
"""
- with open(brdfile) as data_file:
- data = json.load(data_file)
-
- # TODO(nsanders): validate this data before moving on.
- self._brdcfg = data;
- if debug:
- pprint(data)
-
- log("Board is %s" % self._brdcfg['board'])
- # Cast hex strings to int.
- self._brdcfg['flash'] = int(self._brdcfg['flash'], 0)
- self._brdcfg['vid'] = int(self._brdcfg['vid'], 0)
- self._brdcfg['pid'] = int(self._brdcfg['pid'], 0)
- log("Flash Base is %x" % self._brdcfg['flash'])
- self._flashsize = 0
- for region in self._brdcfg['regions']:
- base = int(self._brdcfg['regions'][region][0], 0)
- length = int(self._brdcfg['regions'][region][1], 0)
- log("region %s\tbase:0x%08x size:0x%08x" % (
- region, base, length))
- self._flashsize += length
+ USB_SUBCLASS_GOOGLE_UPDATE = 0x53
+ USB_CLASS_VENDOR = 0xFF
- # Convert these to int because json doesn't support hex.
- self._brdcfg['regions'][region][0] = base
- self._brdcfg['regions'][region][1] = length
-
- log("Flash Size: 0x%x" % self._flashsize)
-
- def load_file(self, binfile):
- """Open and verify size of the target ec.bin file.
-
- Args:
- binfile: path to ec.bin
-
- Raises:
- Exception on file not found or filesize not matching.
- """
- self._filesize = os.path.getsize(binfile)
- self._binfile = open(binfile, 'rb')
-
- if self._filesize != self._flashsize:
- raise Exception("Update", "Flash size 0x%x != file size 0x%x" % (self._flashsize, self._filesize))
+ def __init__(self):
+ pass
+ def connect_usb(self, serialname=None):
+ """Initial discovery and connection to USB endpoint.
+
+ This searches for a USB device matching the VID:PID specified
+ in the config file, optionally matching a specified serialname.
+
+ Args:
+ serialname: Find the device with this serial, in case multiple
+ devices are attached.
+
+ Returns:
+ True on success.
+ Raises:
+ Exception on error.
+ """
+ # Find the stm32.
+ vendor = self._brdcfg["vid"]
+ product = self._brdcfg["pid"]
+
+ dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True)
+ dev_list = list(dev_g)
+ if dev_list is None:
+ raise Exception("Update", "USB device not found")
+
+ # Check if we have multiple stm32s and we've specified the serial.
+ dev = None
+ if serialname:
+ for d in dev_list:
+ if usb.util.get_string(d, d.iSerialNumber) == serialname:
+ dev = d
+ break
+ if dev is None:
+ raise SusbError("USB device(%s) not found" % serialname)
+ else:
+ try:
+ dev = dev_list[0]
+ except:
+ dev = dev_list.next()
+
+ debuglog("Found stm32: %04x:%04x" % (vendor, product))
+ self._dev = dev
+
+ # Get an endpoint instance.
+ try:
+ dev.set_configuration()
+ except:
+ pass
+ cfg = dev.get_active_configuration()
+
+ intf = usb.util.find_descriptor(
+ cfg,
+ custom_match=lambda i: i.bInterfaceClass == self.USB_CLASS_VENDOR
+ and i.bInterfaceSubClass == self.USB_SUBCLASS_GOOGLE_UPDATE,
+ )
+
+ self._intf = intf
+ debuglog("Interface: %s" % intf)
+ debuglog("InterfaceNumber: %s" % intf.bInterfaceNumber)
+
+ read_ep = usb.util.find_descriptor(
+ intf,
+ # match the first IN endpoint
+ custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress)
+ == usb.util.ENDPOINT_IN,
+ )
+
+ self._read_ep = read_ep
+ debuglog("Reader endpoint: 0x%x" % read_ep.bEndpointAddress)
+
+ write_ep = usb.util.find_descriptor(
+ intf,
+ # match the first OUT endpoint
+ custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress)
+ == usb.util.ENDPOINT_OUT,
+ )
+
+ self._write_ep = write_ep
+ debuglog("Writer endpoint: 0x%x" % write_ep.bEndpointAddress)
+
+ return True
+
+ def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=2000):
+ """Write command to logger logic..
+
+ This function writes byte command values list to stm, then reads
+ byte status.
+
+ Args:
+ write_list: list of command byte values [0~255].
+ read_count: number of status byte values to read.
+ wtimeout: mS to wait for write success
+ rtimeout: mS to wait for read success
+
+ Returns:
+ status byte, if one byte is read,
+ byte list, if multiple bytes are read,
+ None, if no bytes are read.
+
+ Interface:
+ write: [command, data ... ]
+ read: [status ]
+ """
+ debuglog(
+ "wr_command(write_list=[%s] (%d), read_count=%s)"
+ % (list(bytearray(write_list)), len(write_list), read_count)
+ )
+
+ # Clean up args from python style to correct types.
+ write_length = 0
+ if write_list:
+ write_length = len(write_list)
+ if not read_count:
+ read_count = 0
+
+ # Send command to stm32.
+ if write_list:
+ cmd = write_list
+ ret = self._write_ep.write(cmd, wtimeout)
+ debuglog("RET: %s " % ret)
+
+ # Read back response if necessary.
+ if read_count:
+ bytesread = self._read_ep.read(512, rtimeout)
+ debuglog("BYTES: [%s]" % bytesread)
+
+ if len(bytesread) != read_count:
+ debuglog(
+ "Unexpected bytes read: %d, expected: %d"
+ % (len(bytesread), read_count)
+ )
+ pass
+
+ debuglog("STATUS: 0x%02x" % int(bytesread[0]))
+ if read_count == 1:
+ return bytesread[0]
+ else:
+ return bytesread
+
+ return None
+
+ def stop(self):
+ """Finalize system flash and exit."""
+ cmd = struct.pack(">I", 0xB007AB1E)
+ read = self.wr_command(cmd, read_count=4)
+
+ if len(read) == 4:
+ log("Finished flashing")
+ return
+
+ raise Exception("Update", "Stop failed [%s]" % read)
+
+ def write_file(self):
+ """Write the update region packet by packet to USB
+
+ This sends write packets of size 128B out, in 32B chunks.
+ Overall, this will write all data in the inactive code region.
+
+ Raises:
+ Exception if write failed or address out of bounds.
+ """
+ region = self._region
+ flash_base = self._brdcfg["flash"]
+ offset = self._base - flash_base
+ if offset != self._brdcfg["regions"][region][0]:
+ raise Exception(
+ "Update",
+ "Region %s offset 0x%x != available offset 0x%x"
+ % (region, self._brdcfg["regions"][region][0], offset),
+ )
+
+ length = self._brdcfg["regions"][region][1]
+ log("Sending")
+
+ # Go to the correct region in the ec.bin file.
+ self._binfile.seek(offset)
+
+ # Send 32 bytes at a time. Must be less than the endpoint's max packet size.
+ maxpacket = 32
+
+ # While data is left, create update packets.
+ while length > 0:
+ # Update packets are 128B. We can use any number
+ # but the micro must malloc this memory.
+ pagesize = min(length, 128)
+
+ # Packet is:
+ # packet size: page bytes transferred plus 3 x 32b values header.
+ # cmd: n/a
+ # base: flash address to write this packet.
+ # data: 128B of data to write into flash_base
+ cmd = struct.pack(">III", pagesize + 12, 0, offset + flash_base)
+ read = self.wr_command(cmd, read_count=0)
+
+ # Push 'todo' bytes out the pipe.
+ todo = pagesize
+ while todo > 0:
+ packetsize = min(maxpacket, todo)
+ data = self._binfile.read(packetsize)
+ if len(data) != packetsize:
+ raise Exception("Update", "No more data from file")
+ for i in range(0, 10):
+ try:
+ self.wr_command(data, read_count=0)
+ break
+ except:
+ log("Timeout fail")
+ todo -= packetsize
+ # Done with this packet, move to the next one.
+ length -= pagesize
+ offset += pagesize
+
+ # Validate that the micro thinks it successfully wrote the data.
+ read = self.wr_command("".encode(), read_count=4)
+ result = struct.unpack("<I", read)
+ result = result[0]
+ if result != 0:
+ raise Exception("Update", "Upload failed with rc: 0x%x" % result)
+
+ def start(self):
+ """Start a transaction and erase currently inactive region.
+
+ This function sends a start command, and receives the base of the
+ preferred inactive region. This could be RW, RW_B,
+ or RO (if there's no RW_B)
+
+ Note that the region is erased here, so you'd better program the RO if
+ you just erased it. TODO(nsanders): Modify the protocol to allow active
+ region select or query before erase.
+ """
+
+ # Size is 3 uint32 fields
+ # packet: [packetsize, cmd, base]
+ size = 4 + 4 + 4
+ # Return value is [status, base_addr]
+ expected = 4 + 4
+
+ cmd = struct.pack("<III", size, 0, 0)
+ read = self.wr_command(cmd, read_count=expected)
+
+ if len(read) == 4:
+ raise Exception("Update", "Protocol version 0 not supported")
+ elif len(read) == expected:
+ base, version = struct.unpack(">II", read)
+ log("Update protocol v. %d" % version)
+ log("Available flash region base: %x" % base)
+ else:
+ raise Exception("Update", "Start command returned %d bytes" % len(read))
+
+ if base < 256:
+ raise Exception("Update", "Start returned error code 0x%x" % base)
+
+ self._base = base
+ flash_base = self._brdcfg["flash"]
+ self._offset = self._base - flash_base
+
+ # Find our active region.
+ for region in self._brdcfg["regions"]:
+ if (self._offset >= self._brdcfg["regions"][region][0]) and (
+ self._offset
+ < (
+ self._brdcfg["regions"][region][0]
+ + self._brdcfg["regions"][region][1]
+ )
+ ):
+ log("Active region: %s" % region)
+ self._region = region
+
+ def load_board(self, brdfile):
+ """Load firmware layout file.
+
+ example as follows:
+ {
+ "board": "servo micro",
+ "vid": 6353,
+ "pid": 20506,
+ "flash": 134217728,
+ "regions": {
+ "RW": [65536, 65536],
+ "PSTATE": [61440, 4096],
+ "RO": [0, 61440]
+ }
+ }
+
+ Args:
+ brdfile: path to board description file.
+ """
+ with open(brdfile) as data_file:
+ data = json.load(data_file)
+
+ # TODO(nsanders): validate this data before moving on.
+ self._brdcfg = data
+ if debug:
+ pprint(data)
+
+ log("Board is %s" % self._brdcfg["board"])
+ # Cast hex strings to int.
+ self._brdcfg["flash"] = int(self._brdcfg["flash"], 0)
+ self._brdcfg["vid"] = int(self._brdcfg["vid"], 0)
+ self._brdcfg["pid"] = int(self._brdcfg["pid"], 0)
+
+ log("Flash Base is %x" % self._brdcfg["flash"])
+ self._flashsize = 0
+ for region in self._brdcfg["regions"]:
+ base = int(self._brdcfg["regions"][region][0], 0)
+ length = int(self._brdcfg["regions"][region][1], 0)
+ log("region %s\tbase:0x%08x size:0x%08x" % (region, base, length))
+ self._flashsize += length
+
+ # Convert these to int because json doesn't support hex.
+ self._brdcfg["regions"][region][0] = base
+ self._brdcfg["regions"][region][1] = length
+
+ log("Flash Size: 0x%x" % self._flashsize)
+
+ def load_file(self, binfile):
+ """Open and verify size of the target ec.bin file.
+
+ Args:
+ binfile: path to ec.bin
+
+ Raises:
+ Exception on file not found or filesize not matching.
+ """
+ self._filesize = os.path.getsize(binfile)
+ self._binfile = open(binfile, "rb")
+
+ if self._filesize != self._flashsize:
+ raise Exception(
+ "Update",
+ "Flash size 0x%x != file size 0x%x" % (self._flashsize, self._filesize),
+ )
# Generate command line arguments
parser = argparse.ArgumentParser(description="Update firmware over usb")
-parser.add_argument('-b', '--board', type=str, help="Board configuration json file", default="board.json")
-parser.add_argument('-f', '--file', type=str, help="Complete ec.bin file", default="ec.bin")
-parser.add_argument('-s', '--serial', type=str, help="Serial number", default="")
-parser.add_argument('-l', '--list', action="store_true", help="List regions")
-parser.add_argument('-v', '--verbose', action="store_true", help="Chatty output")
+parser.add_argument(
+ "-b",
+ "--board",
+ type=str,
+ help="Board configuration json file",
+ default="board.json",
+)
+parser.add_argument(
+ "-f", "--file", type=str, help="Complete ec.bin file", default="ec.bin"
+)
+parser.add_argument("-s", "--serial", type=str, help="Serial number", default="")
+parser.add_argument("-l", "--list", action="store_true", help="List regions")
+parser.add_argument("-v", "--verbose", action="store_true", help="Chatty output")
+
def main():
- global debug
- args = parser.parse_args()
+ global debug
+ args = parser.parse_args()
+ brdfile = args.board
+ serial = args.serial
+ binfile = args.file
+ if args.verbose:
+ debug = True
- brdfile = args.board
- serial = args.serial
- binfile = args.file
- if args.verbose:
- debug = True
+ with open(brdfile) as data_file:
+ names = json.load(data_file)
- with open(brdfile) as data_file:
- names = json.load(data_file)
+ p = Supdate()
+ p.load_board(brdfile)
+ p.connect_usb(serialname=serial)
+ p.load_file(binfile)
- p = Supdate()
- p.load_board(brdfile)
- p.connect_usb(serialname=serial)
- p.load_file(binfile)
+ # List solely prints the config.
+ if args.list:
+ return
- # List solely prints the config.
- if (args.list):
- return
+ # Start transfer and erase.
+ p.start()
+ # Upload the bin file
+ log("Uploading %s" % binfile)
+ p.write_file()
- # Start transfer and erase.
- p.start()
- # Upload the bin file
- log("Uploading %s" % binfile)
- p.write_file()
+ # Finalize
+ log("Done. Finalizing.")
+ p.stop()
- # Finalize
- log("Done. Finalizing.")
- p.stop()
if __name__ == "__main__":
- main()
-
-
+ main()
diff --git a/extra/usb_updater/servo_updater.py b/extra/usb_updater/servo_updater.py
index 432ee120de..5402af70aa 100755
--- a/extra/usb_updater/servo_updater.py
+++ b/extra/usb_updater/servo_updater.py
@@ -14,40 +14,46 @@
from __future__ import print_function
import argparse
+import json
import os
import re
import subprocess
import time
-import json
-
-import fw_update
import ecusb.tiny_servo_common as c
+import fw_update
from ecusb import tiny_servod
+
class ServoUpdaterException(Exception):
- """Raised on exceptions generated by servo_updater."""
+ """Raised on exceptions generated by servo_updater."""
-BOARD_C2D2 = 'c2d2'
-BOARD_SERVO_MICRO = 'servo_micro'
-BOARD_SERVO_V4 = 'servo_v4'
-BOARD_SERVO_V4P1 = 'servo_v4p1'
-BOARD_SWEETBERRY = 'sweetberry'
+
+BOARD_C2D2 = "c2d2"
+BOARD_SERVO_MICRO = "servo_micro"
+BOARD_SERVO_V4 = "servo_v4"
+BOARD_SERVO_V4P1 = "servo_v4p1"
+BOARD_SWEETBERRY = "sweetberry"
DEFAULT_BOARD = BOARD_SERVO_V4
# These lists are to facilitate exposing choices in the command-line tool
# below.
-BOARDS = [BOARD_C2D2, BOARD_SERVO_MICRO, BOARD_SERVO_V4, BOARD_SERVO_V4P1,
- BOARD_SWEETBERRY]
+BOARDS = [
+ BOARD_C2D2,
+ BOARD_SERVO_MICRO,
+ BOARD_SERVO_V4,
+ BOARD_SERVO_V4P1,
+ BOARD_SWEETBERRY,
+]
# Servo firmware bundles four channels of firmware. We need to make sure the
# user does not request a non-existing channel, so keep the lists around to
# guard on command-line usage.
-DEFAULT_CHANNEL = STABLE_CHANNEL = 'stable'
+DEFAULT_CHANNEL = STABLE_CHANNEL = "stable"
-PREV_CHANNEL = 'prev'
+PREV_CHANNEL = "prev"
# The ordering here matters. From left to right it's the channel that the user
# is most likely to be running. This is used to inform and warn the user if
@@ -55,403 +61,436 @@ PREV_CHANNEL = 'prev'
# user know they are running the 'stable' version before letting them know they
# are running 'dev' or even 'alpah' which (while true) might cause confusion.
-CHANNELS = [DEFAULT_CHANNEL, PREV_CHANNEL, 'dev', 'alpha']
+CHANNELS = [DEFAULT_CHANNEL, PREV_CHANNEL, "dev", "alpha"]
-DEFAULT_BASE_PATH = '/usr/'
-TEST_IMAGE_BASE_PATH = '/usr/local/'
+DEFAULT_BASE_PATH = "/usr/"
+TEST_IMAGE_BASE_PATH = "/usr/local/"
-COMMON_PATH = 'share/servo_updater'
+COMMON_PATH = "share/servo_updater"
-FIRMWARE_DIR = 'firmware/'
-CONFIGS_DIR = 'configs/'
+FIRMWARE_DIR = "firmware/"
+CONFIGS_DIR = "configs/"
RETRIES_COUNT = 10
RETRIES_DELAY = 1
+
def do_with_retries(func, *args):
- """Try a function several times
-
- Call function passed as argument and check if no error happened.
- If exception was raised by function,
- it will be retried up to RETRIES_COUNT times.
-
- Args:
- func: function that will be called
- args: arguments passed to 'func'
-
- Returns:
- If call to function was successful, its result will be returned.
- If retries count was exceeded, exception will be raised.
- """
-
- retry = 0
- while retry < RETRIES_COUNT:
- try:
- return func(*args)
- except Exception as e:
- print('Retrying function %s: %s' % (func.__name__, e))
- retry = retry + 1
- time.sleep(RETRIES_DELAY)
- continue
-
- raise Exception("'{}' failed after {} retries".format(
- func.__name__, RETRIES_COUNT))
+ """Try a function several times
+
+ Call function passed as argument and check if no error happened.
+ If exception was raised by function,
+ it will be retried up to RETRIES_COUNT times.
+
+ Args:
+ func: function that will be called
+ args: arguments passed to 'func'
+
+ Returns:
+ If call to function was successful, its result will be returned.
+ If retries count was exceeded, exception will be raised.
+ """
+
+ retry = 0
+ while retry < RETRIES_COUNT:
+ try:
+ return func(*args)
+ except Exception as e:
+ print("Retrying function %s: %s" % (func.__name__, e))
+ retry = retry + 1
+ time.sleep(RETRIES_DELAY)
+ continue
+
+ raise Exception("'{}' failed after {} retries".format(func.__name__, RETRIES_COUNT))
+
def flash(brdfile, serialno, binfile):
- """Call fw_update to upload to updater USB endpoint.
+ """Call fw_update to upload to updater USB endpoint.
+
+ Args:
+ brdfile: path to board configuration file
+ serialno: device serial number
+ binfile: firmware file
+ """
- Args:
- brdfile: path to board configuration file
- serialno: device serial number
- binfile: firmware file
- """
+ p = fw_update.Supdate()
+ p.load_board(brdfile)
+ p.connect_usb(serialname=serialno)
+ p.load_file(binfile)
- p = fw_update.Supdate()
- p.load_board(brdfile)
- p.connect_usb(serialname=serialno)
- p.load_file(binfile)
+ # Start transfer and erase.
+ p.start()
+ # Upload the bin file
+ print("Uploading %s" % binfile)
+ p.write_file()
- # Start transfer and erase.
- p.start()
- # Upload the bin file
- print('Uploading %s' % binfile)
- p.write_file()
+ # Finalize
+ print("Done. Finalizing.")
+ p.stop()
- # Finalize
- print('Done. Finalizing.')
- p.stop()
def flash2(vidpid, serialno, binfile):
- """Call fw update via usb_updater2 commandline.
-
- Args:
- vidpid: vendor id and product id of device
- serialno: device serial number (optional)
- binfile: firmware file
- """
-
- tool = 'usb_updater2'
- cmd = '%s -d %s' % (tool, vidpid)
- if serialno:
- cmd += ' -S %s' % serialno
- cmd += ' -n'
- cmd += ' %s' % binfile
-
- print(cmd)
- help_cmd = '%s --help' % tool
- with open('/dev/null') as devnull:
- valid_check = subprocess.call(help_cmd.split(), stdout=devnull,
- stderr=devnull)
- if valid_check:
- raise ServoUpdaterException('%s exit with res = %d. Make sure the tool '
- 'is available on the device.' % (help_cmd,
- valid_check))
- res = subprocess.call(cmd.split())
-
- if res in (0, 1, 2):
- return res
- else:
- raise ServoUpdaterException('%s exit with res = %d' % (cmd, res))
+ """Call fw update via usb_updater2 commandline.
+
+ Args:
+ vidpid: vendor id and product id of device
+ serialno: device serial number (optional)
+ binfile: firmware file
+ """
+
+ tool = "usb_updater2"
+ cmd = "%s -d %s" % (tool, vidpid)
+ if serialno:
+ cmd += " -S %s" % serialno
+ cmd += " -n"
+ cmd += " %s" % binfile
+
+ print(cmd)
+ help_cmd = "%s --help" % tool
+ with open("/dev/null") as devnull:
+ valid_check = subprocess.call(help_cmd.split(), stdout=devnull, stderr=devnull)
+ if valid_check:
+ raise ServoUpdaterException(
+ "%s exit with res = %d. Make sure the tool "
+ "is available on the device." % (help_cmd, valid_check)
+ )
+ res = subprocess.call(cmd.split())
+
+ if res in (0, 1, 2):
+ return res
+ else:
+ raise ServoUpdaterException("%s exit with res = %d" % (cmd, res))
+
def select(tinys, region):
- """Jump to specified boot region
+ """Jump to specified boot region
- Ensure the servo is in the expected ro/rw region.
- This function jumps to the required region and verify if jump was
- successful by executing 'sysinfo' command and reading current region.
- If response was not received or region is invalid, exception is raised.
+ Ensure the servo is in the expected ro/rw region.
+ This function jumps to the required region and verify if jump was
+ successful by executing 'sysinfo' command and reading current region.
+ If response was not received or region is invalid, exception is raised.
- Args:
- tinys: TinyServod object
- region: region to jump to, only "rw" and "ro" is allowed
- """
+ Args:
+ tinys: TinyServod object
+ region: region to jump to, only "rw" and "ro" is allowed
+ """
- if region not in ['rw', 'ro']:
- raise Exception('Region must be ro or rw')
+ if region not in ["rw", "ro"]:
+ raise Exception("Region must be ro or rw")
- if region == 'ro':
- cmd = 'reboot'
- else:
- cmd = 'sysjump %s' % region
+ if region == "ro":
+ cmd = "reboot"
+ else:
+ cmd = "sysjump %s" % region
- tinys.pty._issue_cmd(cmd)
+ tinys.pty._issue_cmd(cmd)
- tinys.close()
- time.sleep(2)
- tinys.reinitialize()
+ tinys.close()
+ time.sleep(2)
+ tinys.reinitialize()
+
+ res = tinys.pty._issue_cmd_get_results("sysinfo", [r"Copy:[\s]+(RO|RW)"])
+ current_region = res[0][1].lower()
+ if current_region != region:
+ raise Exception("Invalid region: %s/%s" % (current_region, region))
- res = tinys.pty._issue_cmd_get_results('sysinfo', [r'Copy:[\s]+(RO|RW)'])
- current_region = res[0][1].lower()
- if current_region != region:
- raise Exception('Invalid region: %s/%s' % (current_region, region))
def do_version(tinys):
- """Check version via ec console 'pty'.
+ """Check version via ec console 'pty'.
+
+ Args:
+ tinys: TinyServod object
- Args:
- tinys: TinyServod object
+ Returns:
+ detected version number
- Returns:
- detected version number
+ Commands are:
+ # > version
+ # ...
+ # Build: tigertail_v1.1.6749-74d1a312e
+ """
+ cmd = "version"
+ regex = r"Build:\s+(\S+)[\r\n]+"
- Commands are:
- # > version
- # ...
- # Build: tigertail_v1.1.6749-74d1a312e
- """
- cmd = 'version'
- regex = r'Build:\s+(\S+)[\r\n]+'
+ results = tinys.pty._issue_cmd_get_results(cmd, [regex])[0]
- results = tinys.pty._issue_cmd_get_results(cmd, [regex])[0]
+ return results[1].strip(" \t\r\n\0")
- return results[1].strip(' \t\r\n\0')
def do_updater_version(tinys):
- """Check whether this uses python updater or c++ updater
-
- Args:
- tinys: TinyServod object
-
- Returns:
- updater version number. 2 or 6.
- """
- vers = do_version(tinys)
-
- # Servo versions below 58 are from servo-9040.B. Versions starting with _v2
- # are newer than anything _v1, no need to check the exact number. Updater
- # version is not directly queryable.
- if re.search(r'_v[2-9]\.\d', vers):
- return 6
- m = re.search(r'_v1\.1\.(\d\d\d\d)', vers)
- if m:
- version_number = int(m.group(1))
- if version_number < 5800:
- return 2
- else:
- return 6
- raise ServoUpdaterException(
- "Can't determine updater target from vers: [%s]" % vers)
+ """Check whether this uses python updater or c++ updater
-def _extract_version(boardname, binfile):
- """Find the version string from |binfile|.
+ Args:
+ tinys: TinyServod object
- Args:
- boardname: the name of the board, eg. "servo_micro"
- binfile: path to the binary to search
+ Returns:
+ updater version number. 2 or 6.
+ """
+ vers = do_version(tinys)
- Returns:
- the version string.
- """
- if boardname is None:
- # cannot extract the version if the name is None
- return None
- rawstrings = subprocess.check_output(
- ['cbfstool', binfile, 'read', '-r', 'RO_FRID', '-f', '/dev/stdout'],
- **c.get_subprocess_args())
- m = re.match(r'%s_v\S+' % boardname, rawstrings)
- if m:
- newvers = m.group(0).strip(' \t\r\n\0')
- else:
- raise ServoUpdaterException("Can't find version from file: %s." % binfile)
+ # Servo versions below 58 are from servo-9040.B. Versions starting with _v2
+ # are newer than anything _v1, no need to check the exact number. Updater
+ # version is not directly queryable.
+ if re.search(r"_v[2-9]\.\d", vers):
+ return 6
+ m = re.search(r"_v1\.1\.(\d\d\d\d)", vers)
+ if m:
+ version_number = int(m.group(1))
+ if version_number < 5800:
+ return 2
+ else:
+ return 6
+ raise ServoUpdaterException("Can't determine updater target from vers: [%s]" % vers)
+
+
+def _extract_version(boardname, binfile):
+ """Find the version string from |binfile|.
+
+ Args:
+ boardname: the name of the board, eg. "servo_micro"
+ binfile: path to the binary to search
+
+ Returns:
+ the version string.
+ """
+ if boardname is None:
+ # cannot extract the version if the name is None
+ return None
+ rawstrings = subprocess.check_output(
+ ["cbfstool", binfile, "read", "-r", "RO_FRID", "-f", "/dev/stdout"],
+ **c.get_subprocess_args()
+ )
+ m = re.match(r"%s_v\S+" % boardname, rawstrings)
+ if m:
+ newvers = m.group(0).strip(" \t\r\n\0")
+ else:
+ raise ServoUpdaterException("Can't find version from file: %s." % binfile)
+
+ return newvers
- return newvers
def get_firmware_channel(bname, version):
- """Find out which channel |version| for |bname| came from.
-
- Args:
- bname: board name
- version: current version string
-
- Returns:
- one of the channel names if |version| came from one of those, or None
- """
- for channel in CHANNELS:
- # Pass |bname| as cname to find the board specific file, and pass None as
- # fname to ensure the default directory is searched
- _, _, vers = get_files_and_version(bname, None, channel=channel)
- if version == vers:
- return channel
- # None of the channels matched. This firmware is currently unknown.
- return None
+ """Find out which channel |version| for |bname| came from.
+
+ Args:
+ bname: board name
+ version: current version string
+
+ Returns:
+ one of the channel names if |version| came from one of those, or None
+ """
+ for channel in CHANNELS:
+ # Pass |bname| as cname to find the board specific file, and pass None as
+ # fname to ensure the default directory is searched
+ _, _, vers = get_files_and_version(bname, None, channel=channel)
+ if version == vers:
+ return channel
+ # None of the channels matched. This firmware is currently unknown.
+ return None
+
def get_files_and_version(cname, fname=None, channel=DEFAULT_CHANNEL):
- """Select config and firmware binary files.
-
- This checks default file names and paths.
- In: /usr/share/servo_updater/[firmware|configs]
- check for board.json, board.bin
-
- Args:
- cname: board name, or config name. eg. "servo_v4" or "servo_v4.json"
- fname: firmware binary name. Can be None to try default.
- channel: the channel requested for servo firmware. See |CHANNELS| above.
-
- Returns:
- cname, fname, version: validated filenames selected from the path.
- """
- for p in (DEFAULT_BASE_PATH, TEST_IMAGE_BASE_PATH):
- updater_path = os.path.join(p, COMMON_PATH)
- if os.path.exists(updater_path):
- break
- else:
- raise ServoUpdaterException('servo_updater/ dir not found in known spots.')
-
- firmware_path = os.path.join(updater_path, FIRMWARE_DIR)
- configs_path = os.path.join(updater_path, CONFIGS_DIR)
-
- for p in (firmware_path, configs_path):
- if not os.path.exists(p):
- raise ServoUpdaterException('Could not find required path %r' % p)
-
- if not os.path.isfile(cname):
- # If not an existing file, try checking on the default path.
- newname = os.path.join(configs_path, cname)
- if os.path.isfile(newname):
- cname = newname
+ """Select config and firmware binary files.
+
+ This checks default file names and paths.
+ In: /usr/share/servo_updater/[firmware|configs]
+ check for board.json, board.bin
+
+ Args:
+ cname: board name, or config name. eg. "servo_v4" or "servo_v4.json"
+ fname: firmware binary name. Can be None to try default.
+ channel: the channel requested for servo firmware. See |CHANNELS| above.
+
+ Returns:
+ cname, fname, version: validated filenames selected from the path.
+ """
+ for p in (DEFAULT_BASE_PATH, TEST_IMAGE_BASE_PATH):
+ updater_path = os.path.join(p, COMMON_PATH)
+ if os.path.exists(updater_path):
+ break
else:
- # Try appending ".json" to convert board name to config file.
- cname = newname + '.json'
+ raise ServoUpdaterException("servo_updater/ dir not found in known spots.")
+
+ firmware_path = os.path.join(updater_path, FIRMWARE_DIR)
+ configs_path = os.path.join(updater_path, CONFIGS_DIR)
+
+ for p in (firmware_path, configs_path):
+ if not os.path.exists(p):
+ raise ServoUpdaterException("Could not find required path %r" % p)
+
if not os.path.isfile(cname):
- raise ServoUpdaterException("Can't find config file: %s." % cname)
-
- # Always retrieve the boardname
- with open(cname) as data_file:
- data = json.load(data_file)
- boardname = data['board']
-
- if not fname:
- # If no |fname| supplied, look for the default locations with the board
- # and channel requested.
- binary_file = '%s.%s.bin' % (boardname, channel)
- newname = os.path.join(firmware_path, binary_file)
- if os.path.isfile(newname):
- fname = newname
- else:
- raise ServoUpdaterException("Can't find firmware binary: %s." %
- binary_file)
- elif not os.path.isfile(fname):
- # If a name is specified but not found, try the default path.
- newname = os.path.join(firmware_path, fname)
- if os.path.isfile(newname):
- fname = newname
+ # If not an existing file, try checking on the default path.
+ newname = os.path.join(configs_path, cname)
+ if os.path.isfile(newname):
+ cname = newname
+ else:
+ # Try appending ".json" to convert board name to config file.
+ cname = newname + ".json"
+ if not os.path.isfile(cname):
+ raise ServoUpdaterException("Can't find config file: %s." % cname)
+
+ # Always retrieve the boardname
+ with open(cname) as data_file:
+ data = json.load(data_file)
+ boardname = data["board"]
+
+ if not fname:
+ # If no |fname| supplied, look for the default locations with the board
+ # and channel requested.
+ binary_file = "%s.%s.bin" % (boardname, channel)
+ newname = os.path.join(firmware_path, binary_file)
+ if os.path.isfile(newname):
+ fname = newname
+ else:
+ raise ServoUpdaterException("Can't find firmware binary: %s." % binary_file)
+ elif not os.path.isfile(fname):
+ # If a name is specified but not found, try the default path.
+ newname = os.path.join(firmware_path, fname)
+ if os.path.isfile(newname):
+ fname = newname
+ else:
+ raise ServoUpdaterException("Can't find file: %s." % fname)
+
+ # Lastly, retrieve the version as well for decision making, debug, and
+ # informational purposes.
+ binvers = _extract_version(boardname, fname)
+
+ return cname, fname, binvers
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Image a servo device")
+ parser.add_argument(
+ "-p",
+ "--print",
+ dest="print_only",
+ action="store_true",
+ default=False,
+ help="only print available firmware for board/channel",
+ )
+ parser.add_argument(
+ "-s", "--serialno", type=str, help="serial number to program", default=None
+ )
+ parser.add_argument(
+ "-b",
+ "--board",
+ type=str,
+ help="Board configuration json file",
+ default=DEFAULT_BOARD,
+ choices=BOARDS,
+ )
+ parser.add_argument(
+ "-c",
+ "--channel",
+ type=str,
+ help="Firmware channel to use",
+ default=DEFAULT_CHANNEL,
+ choices=CHANNELS,
+ )
+ parser.add_argument(
+ "-f", "--file", type=str, help="Complete ec.bin file", default=None
+ )
+ parser.add_argument(
+ "--force",
+ action="store_true",
+ help="Update even if version match",
+ default=False,
+ )
+ parser.add_argument("-v", "--verbose", action="store_true", help="Chatty output")
+ parser.add_argument(
+ "-r", "--reboot", action="store_true", help="Always reboot, even after probe."
+ )
+
+ args = parser.parse_args()
+
+ brdfile, binfile, newvers = get_files_and_version(
+ args.board, args.file, args.channel
+ )
+
+ # If the user only cares about the information then just print it here,
+ # and exit.
+ if args.print_only:
+ output = ("board: %s\n" "channel: %s\n" "firmware: %s") % (
+ args.board,
+ args.channel,
+ newvers,
+ )
+ print(output)
+ return
+
+ serialno = args.serialno
+
+ with open(brdfile) as data_file:
+ data = json.load(data_file)
+ vid, pid = int(data["vid"], 0), int(data["pid"], 0)
+ vidpid = "%04x:%04x" % (vid, pid)
+ iface = int(data["console"], 0)
+ boardname = data["board"]
+
+ # Make sure device is up.
+ print("===== Waiting for USB device =====")
+ c.wait_for_usb(vidpid, serialname=serialno)
+ # We need a tiny_servod to query some information. Set it up first.
+ tinys = tiny_servod.TinyServod(vid, pid, iface, serialno, args.verbose)
+
+ if not args.force:
+ vers = do_version(tinys)
+ print("Current %s version is %s" % (boardname, vers))
+ print("Available %s version is %s" % (boardname, newvers))
+
+ if newvers == vers:
+ print("No version update needed")
+ if args.reboot:
+ select(tinys, "ro")
+ return
+ else:
+ print("Updating to recommended version.")
+
+ # Make sure the servo MCU is in RO
+ print("===== Jumping to RO =====")
+ do_with_retries(select, tinys, "ro")
+
+ print("===== Flashing RW =====")
+ vers = do_with_retries(do_updater_version, tinys)
+ # To make sure that the tiny_servod here does not interfere with other
+ # processes, close it out.
+ tinys.close()
+
+ if vers == 2:
+ flash(brdfile, serialno, binfile)
+ elif vers == 6:
+ do_with_retries(flash2, vidpid, serialno, binfile)
else:
- raise ServoUpdaterException("Can't find file: %s." % fname)
+ raise ServoUpdaterException("Can't detect updater version")
- # Lastly, retrieve the version as well for decision making, debug, and
- # informational purposes.
- binvers = _extract_version(boardname, fname)
+ # Make sure device is up.
+ c.wait_for_usb(vidpid, serialname=serialno)
+ # After we have made sure that it's back/available, reconnect the tiny servod.
+ tinys.reinitialize()
- return cname, fname, binvers
+ # Make sure the servo MCU is in RW
+ print("===== Jumping to RW =====")
+ do_with_retries(select, tinys, "rw")
-def main():
- parser = argparse.ArgumentParser(description='Image a servo device')
- parser.add_argument('-p', '--print', dest='print_only', action='store_true',
- default=False,
- help='only print available firmware for board/channel')
- parser.add_argument('-s', '--serialno', type=str,
- help='serial number to program', default=None)
- parser.add_argument('-b', '--board', type=str,
- help='Board configuration json file',
- default=DEFAULT_BOARD, choices=BOARDS)
- parser.add_argument('-c', '--channel', type=str,
- help='Firmware channel to use',
- default=DEFAULT_CHANNEL, choices=CHANNELS)
- parser.add_argument('-f', '--file', type=str,
- help='Complete ec.bin file', default=None)
- parser.add_argument('--force', action='store_true',
- help='Update even if version match', default=False)
- parser.add_argument('-v', '--verbose', action='store_true',
- help='Chatty output')
- parser.add_argument('-r', '--reboot', action='store_true',
- help='Always reboot, even after probe.')
-
- args = parser.parse_args()
-
- brdfile, binfile, newvers = get_files_and_version(args.board, args.file,
- args.channel)
-
- # If the user only cares about the information then just print it here,
- # and exit.
- if args.print_only:
- output = ('board: %s\n'
- 'channel: %s\n'
- 'firmware: %s') % (args.board, args.channel, newvers)
- print(output)
- return
-
- serialno = args.serialno
-
- with open(brdfile) as data_file:
- data = json.load(data_file)
- vid, pid = int(data['vid'], 0), int(data['pid'], 0)
- vidpid = '%04x:%04x' % (vid, pid)
- iface = int(data['console'], 0)
- boardname = data['board']
-
- # Make sure device is up.
- print('===== Waiting for USB device =====')
- c.wait_for_usb(vidpid, serialname=serialno)
- # We need a tiny_servod to query some information. Set it up first.
- tinys = tiny_servod.TinyServod(vid, pid, iface, serialno, args.verbose)
-
- if not args.force:
- vers = do_version(tinys)
- print('Current %s version is %s' % (boardname, vers))
- print('Available %s version is %s' % (boardname, newvers))
-
- if newvers == vers:
- print('No version update needed')
- if args.reboot:
- select(tinys, 'ro')
- return
+ print("===== Flashing RO =====")
+ vers = do_with_retries(do_updater_version, tinys)
+
+ if vers == 2:
+ flash(brdfile, serialno, binfile)
+ elif vers == 6:
+ do_with_retries(flash2, vidpid, serialno, binfile)
else:
- print('Updating to recommended version.')
-
- # Make sure the servo MCU is in RO
- print('===== Jumping to RO =====')
- do_with_retries(select, tinys, 'ro')
-
- print('===== Flashing RW =====')
- vers = do_with_retries(do_updater_version, tinys)
- # To make sure that the tiny_servod here does not interfere with other
- # processes, close it out.
- tinys.close()
-
- if vers == 2:
- flash(brdfile, serialno, binfile)
- elif vers == 6:
- do_with_retries(flash2, vidpid, serialno, binfile)
- else:
- raise ServoUpdaterException("Can't detect updater version")
-
- # Make sure device is up.
- c.wait_for_usb(vidpid, serialname=serialno)
- # After we have made sure that it's back/available, reconnect the tiny servod.
- tinys.reinitialize()
-
- # Make sure the servo MCU is in RW
- print('===== Jumping to RW =====')
- do_with_retries(select, tinys, 'rw')
-
- print('===== Flashing RO =====')
- vers = do_with_retries(do_updater_version, tinys)
-
- if vers == 2:
- flash(brdfile, serialno, binfile)
- elif vers == 6:
- do_with_retries(flash2, vidpid, serialno, binfile)
- else:
- raise ServoUpdaterException("Can't detect updater version")
-
- # Make sure the servo MCU is in RO
- print('===== Rebooting =====')
- do_with_retries(select, tinys, 'ro')
- # Perform additional reboot to free USB/UART resources, taken by tiny servod.
- # See https://issuetracker.google.com/196021317 for background.
- tinys.pty._issue_cmd('reboot')
-
- print('===== Finished =====')
-
-if __name__ == '__main__':
- main()
+ raise ServoUpdaterException("Can't detect updater version")
+
+ # Make sure the servo MCU is in RO
+ print("===== Rebooting =====")
+ do_with_retries(select, tinys, "ro")
+ # Perform additional reboot to free USB/UART resources, taken by tiny servod.
+ # See https://issuetracker.google.com/196021317 for background.
+ tinys.pty._issue_cmd("reboot")
+
+ print("===== Finished =====")
+
+
+if __name__ == "__main__":
+ main()