summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Bettis <jbettis@google.com>2022-07-08 10:58:19 -0600
committerChromeos LUCI <chromeos-scoped@luci-project-accounts.iam.gserviceaccount.com>2022-07-12 19:13:33 +0000
commit7540e7b47b55447475bb8191fb3520dd67cf7998 (patch)
tree13309dbcf1db48e60fa2c2e5aed79f63bce00b5e
parent7c114b8e1a3bb29991da70b9de394ac5d4f6c909 (diff)
downloadchrome-ec-7540e7b47b55447475bb8191fb3520dd67cf7998.tar.gz
ec: Format all python files with black and isort
find . \( -path ./private -prune \) -o -name '*.py' -print | xargs black find . \( -path ./private -prune \) -o -name '*.py' -print | xargs ~/chromiumos/chromite/scripts/isort --settings-file=.isort.cfg BRANCH=None BUG=b:238434058 TEST=None Signed-off-by: Jeremy Bettis <jbettis@google.com> Change-Id: I63462d6f15d1eaf3db84eb20d1404ee976be8382 Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/ec/+/3749242 Commit-Queue: Jeremy Bettis <jbettis@chromium.org> Reviewed-by: Tom Hughes <tomhughes@chromium.org> Tested-by: Jeremy Bettis <jbettis@chromium.org> Commit-Queue: Jack Rosenthal <jrosenth@chromium.org> Auto-Submit: Jeremy Bettis <jbettis@chromium.org> Reviewed-by: Jack Rosenthal <jrosenth@chromium.org>
-rwxr-xr-xchip/ish/util/pack_ec.py155
-rwxr-xr-xchip/mchp/util/pack_ec.py885
-rwxr-xr-xchip/mchp/util/pack_ec_mec152x.py1132
-rwxr-xr-xchip/mchp/util/pack_ec_mec172x.py1178
-rwxr-xr-xchip/mec1322/util/pack_ec.py479
-rw-r--r--cts/common/board.py698
-rwxr-xr-xcts/cts.py804
-rwxr-xr-xextra/cr50_rma_open/cr50_rma_open.py418
-rwxr-xr-xextra/stack_analyzer/stack_analyzer.py3562
-rwxr-xr-xextra/stack_analyzer/stack_analyzer_unittest.py1708
-rw-r--r--extra/tigertool/ecusb/__init__.py2
-rw-r--r--extra/tigertool/ecusb/pty_driver.py545
-rw-r--r--extra/tigertool/ecusb/stm32uart.py438
-rw-r--r--extra/tigertool/ecusb/stm32usb.py207
-rw-r--r--extra/tigertool/ecusb/tiny_servo_common.py343
-rw-r--r--extra/tigertool/ecusb/tiny_servod.py83
-rwxr-xr-xextra/tigertool/tigertest.py97
-rwxr-xr-xextra/tigertool/tigertool.py505
-rw-r--r--extra/usb_power/convert_power_log_board.py35
-rwxr-xr-xextra/usb_power/convert_servo_ina.py86
-rwxr-xr-xextra/usb_power/powerlog.py1762
-rw-r--r--extra/usb_power/powerlog_unittest.py74
-rw-r--r--extra/usb_power/stats_manager.py745
-rw-r--r--extra/usb_power/stats_manager_unittest.py595
-rwxr-xr-xextra/usb_serial/console.py435
-rwxr-xr-xextra/usb_updater/fw_update.py762
-rwxr-xr-xextra/usb_updater/servo_updater.py767
-rwxr-xr-xfirmware_builder.py129
-rw-r--r--setup.py39
-rwxr-xr-xtest/run_device_tests.py543
-rw-r--r--test/timer_calib.py78
-rw-r--r--test/timer_jump.py37
-rwxr-xr-xutil/build_with_clang.py42
-rw-r--r--util/chargen25
-rwxr-xr-xutil/config_option_check.py693
-rwxr-xr-xutil/ec3po/console.py2231
-rwxr-xr-xutil/ec3po/console_unittest.py2954
-rw-r--r--util/ec3po/interpreter.py818
-rwxr-xr-xutil/ec3po/interpreter_unittest.py728
-rw-r--r--util/ec3po/threadproc_shim.py31
-rwxr-xr-xutil/ec_openocd.py24
-rwxr-xr-xutil/flash_jlink.py144
-rwxr-xr-xutil/fptool.py17
-rwxr-xr-xutil/inject-keys.py149
-rwxr-xr-xutil/kconfig_check.py230
-rw-r--r--util/kconfiglib.py1581
-rw-r--r--util/run_ects.py144
-rw-r--r--util/test_kconfig_check.py181
-rwxr-xr-xutil/uart_stress_tester.py1015
-rwxr-xr-xutil/unpack_ftb.py124
-rwxr-xr-xutil/update_release_branch.py251
-rwxr-xr-xzephyr/firmware_builder.py209
-rw-r--r--zephyr/zmake/tests/conftest.py1
-rw-r--r--zephyr/zmake/tests/test_build_config.py1
-rw-r--r--zephyr/zmake/tests/test_generate_readme.py1
-rw-r--r--zephyr/zmake/tests/test_modules.py1
-rw-r--r--zephyr/zmake/tests/test_packers.py1
-rw-r--r--zephyr/zmake/tests/test_project.py1
-rw-r--r--zephyr/zmake/tests/test_reexec.py1
-rw-r--r--zephyr/zmake/tests/test_toolchains.py1
-rw-r--r--zephyr/zmake/tests/test_util.py1
-rw-r--r--zephyr/zmake/tests/test_zmake.py3
62 files changed, 16388 insertions, 14541 deletions
diff --git a/chip/ish/util/pack_ec.py b/chip/ish/util/pack_ec.py
index bd9b823cab..8dde6ab6a9 100755
--- a/chip/ish/util/pack_ec.py
+++ b/chip/ish/util/pack_ec.py
@@ -28,85 +28,100 @@ MANIFEST_ENTRY_SIZE = 0x80
HEADER_SIZE = 0x1000
PAGE_SIZE = 0x1000
+
def parseargs():
- parser = argparse.ArgumentParser()
- parser.add_argument("-k", "--kernel",
- help="EC kernel binary to pack, \
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-k",
+ "--kernel",
+ help="EC kernel binary to pack, \
usually ec.RW.bin or ec.RW.flat.",
- required=True)
- parser.add_argument("--kernel-size", type=int,
- help="Size of EC kernel image",
- required=True)
- parser.add_argument("-a", "--aon",
- help="EC aontask binary to pack, \
+ required=True,
+ )
+ parser.add_argument(
+ "--kernel-size", type=int, help="Size of EC kernel image", required=True
+ )
+ parser.add_argument(
+ "-a",
+ "--aon",
+ help="EC aontask binary to pack, \
usually ish_aontask.bin.",
- required=False)
- parser.add_argument("--aon-size", type=int,
- help="Size of EC aontask image",
- required=False)
- parser.add_argument("-o", "--output",
- help="Output flash binary file")
+ required=False,
+ )
+ parser.add_argument(
+ "--aon-size", type=int, help="Size of EC aontask image", required=False
+ )
+ parser.add_argument("-o", "--output", help="Output flash binary file")
+
+ return parser.parse_args()
- return parser.parse_args()
def gen_manifest(ext_id, comp_app_name, code_offset, module_size):
- """Returns a binary blob that represents a manifest entry"""
- m = bytearray(MANIFEST_ENTRY_SIZE)
+ """Returns a binary blob that represents a manifest entry"""
+ m = bytearray(MANIFEST_ENTRY_SIZE)
- # 4 bytes of ASCII encode ID (little endian)
- struct.pack_into('<4s', m, 0, ext_id)
- # 8 bytes of ASCII encode ID (little endian)
- struct.pack_into('<8s', m, 32, comp_app_name)
- # 4 bytes of code offset (little endian)
- struct.pack_into('<I', m, 96, code_offset)
- # 2 bytes of module in page size increments (little endian)
- struct.pack_into('<H', m, 100, module_size)
+ # 4 bytes of ASCII encode ID (little endian)
+ struct.pack_into("<4s", m, 0, ext_id)
+ # 8 bytes of ASCII encode ID (little endian)
+ struct.pack_into("<8s", m, 32, comp_app_name)
+ # 4 bytes of code offset (little endian)
+ struct.pack_into("<I", m, 96, code_offset)
+ # 2 bytes of module in page size increments (little endian)
+ struct.pack_into("<H", m, 100, module_size)
+
+ return m
- return m
def roundup_page(size):
- """Returns roundup-ed page size from size of bytes"""
- return int(size / PAGE_SIZE) + (size % PAGE_SIZE > 0)
+ """Returns roundup-ed page size from size of bytes"""
+ return int(size / PAGE_SIZE) + (size % PAGE_SIZE > 0)
+
def main():
- args = parseargs()
- print(" Packing EC image file for ISH")
-
- with open(args.output, 'wb') as f:
- print(" kernel binary size:", args.kernel_size)
- kern_rdup_pg_size = roundup_page(args.kernel_size)
- # Add manifest for main ISH binary
- f.write(gen_manifest(b'ISHM', b'ISH_KERN', HEADER_SIZE, kern_rdup_pg_size))
-
- if args.aon is not None:
- print(" AON binary size: ", args.aon_size)
- aon_rdup_pg_size = roundup_page(args.aon_size)
- # Add manifest for aontask binary
- f.write(gen_manifest(b'ISHM', b'AON_TASK',
- (HEADER_SIZE + kern_rdup_pg_size * PAGE_SIZE -
- MANIFEST_ENTRY_SIZE), aon_rdup_pg_size))
-
- # Add manifest that signals end of manifests
- f.write(gen_manifest(b'ISHE', b'', 0, 0))
-
- # Pad the remaining HEADER with 0s
- if args.aon is not None:
- f.write(b'\x00' * (HEADER_SIZE - (MANIFEST_ENTRY_SIZE * 3)))
- else:
- f.write(b'\x00' * (HEADER_SIZE - (MANIFEST_ENTRY_SIZE * 2)))
-
- # Append original kernel image
- with open(args.kernel, 'rb') as in_file:
- f.write(in_file.read())
- # Filling padings due to size round up as pages
- f.write(b'\x00' * (kern_rdup_pg_size * PAGE_SIZE - args.kernel_size))
-
- if args.aon is not None:
- # Append original aon image
- with open(args.aon, 'rb') as in_file:
- f.write(in_file.read())
- # Filling padings due to size round up as pages
- f.write(b'\x00' * (aon_rdup_pg_size * PAGE_SIZE - args.aon_size))
-
-if __name__ == '__main__':
- main()
+ args = parseargs()
+ print(" Packing EC image file for ISH")
+
+ with open(args.output, "wb") as f:
+ print(" kernel binary size:", args.kernel_size)
+ kern_rdup_pg_size = roundup_page(args.kernel_size)
+ # Add manifest for main ISH binary
+ f.write(gen_manifest(b"ISHM", b"ISH_KERN", HEADER_SIZE, kern_rdup_pg_size))
+
+ if args.aon is not None:
+ print(" AON binary size: ", args.aon_size)
+ aon_rdup_pg_size = roundup_page(args.aon_size)
+ # Add manifest for aontask binary
+ f.write(
+ gen_manifest(
+ b"ISHM",
+ b"AON_TASK",
+ (HEADER_SIZE + kern_rdup_pg_size * PAGE_SIZE - MANIFEST_ENTRY_SIZE),
+ aon_rdup_pg_size,
+ )
+ )
+
+ # Add manifest that signals end of manifests
+ f.write(gen_manifest(b"ISHE", b"", 0, 0))
+
+ # Pad the remaining HEADER with 0s
+ if args.aon is not None:
+ f.write(b"\x00" * (HEADER_SIZE - (MANIFEST_ENTRY_SIZE * 3)))
+ else:
+ f.write(b"\x00" * (HEADER_SIZE - (MANIFEST_ENTRY_SIZE * 2)))
+
+ # Append original kernel image
+ with open(args.kernel, "rb") as in_file:
+ f.write(in_file.read())
+ # Filling padings due to size round up as pages
+ f.write(b"\x00" * (kern_rdup_pg_size * PAGE_SIZE - args.kernel_size))
+
+ if args.aon is not None:
+ # Append original aon image
+ with open(args.aon, "rb") as in_file:
+ f.write(in_file.read())
+ # Filling padings due to size round up as pages
+ f.write(b"\x00" * (aon_rdup_pg_size * PAGE_SIZE - args.aon_size))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/chip/mchp/util/pack_ec.py b/chip/mchp/util/pack_ec.py
index 7908b0bf37..15be16c0d4 100755
--- a/chip/mchp/util/pack_ec.py
+++ b/chip/mchp/util/pack_ec.py
@@ -16,7 +16,7 @@ import os
import struct
import subprocess
import tempfile
-import zlib # CRC32
+import zlib # CRC32
# MEC1701 has 256KB SRAM from 0xE0000 - 0x120000
# SRAM is divided into contiguous CODE & DATA
@@ -30,165 +30,199 @@ LOAD_ADDR = 0x0E0000
LOAD_ADDR_RW = 0xE1000
HEADER_SIZE = 0x40
SPI_CLOCK_LIST = [48, 24, 16, 12]
-SPI_READ_CMD_LIST = [0x3, 0xb, 0x3b, 0x6b]
+SPI_READ_CMD_LIST = [0x3, 0xB, 0x3B, 0x6B]
+
+CRC_TABLE = [
+ 0x00,
+ 0x07,
+ 0x0E,
+ 0x09,
+ 0x1C,
+ 0x1B,
+ 0x12,
+ 0x15,
+ 0x38,
+ 0x3F,
+ 0x36,
+ 0x31,
+ 0x24,
+ 0x23,
+ 0x2A,
+ 0x2D,
+]
-CRC_TABLE = [0x00, 0x07, 0x0e, 0x09, 0x1c, 0x1b, 0x12, 0x15,
- 0x38, 0x3f, 0x36, 0x31, 0x24, 0x23, 0x2a, 0x2d]
def mock_print(*args, **kwargs):
- pass
+ pass
+
debug_print = mock_print
+
def Crc8(crc, data):
- """Update CRC8 value."""
- for v in data:
- crc = ((crc << 4) & 0xff) ^ (CRC_TABLE[(crc >> 4) ^ (v >> 4)]);
- crc = ((crc << 4) & 0xff) ^ (CRC_TABLE[(crc >> 4) ^ (v & 0xf)]);
- return crc ^ 0x55
+ """Update CRC8 value."""
+ for v in data:
+ crc = ((crc << 4) & 0xFF) ^ (CRC_TABLE[(crc >> 4) ^ (v >> 4)])
+ crc = ((crc << 4) & 0xFF) ^ (CRC_TABLE[(crc >> 4) ^ (v & 0xF)])
+ return crc ^ 0x55
+
def GetEntryPoint(payload_file):
- """Read entry point from payload EC image."""
- with open(payload_file, 'rb') as f:
- f.seek(4)
- s = f.read(4)
- return struct.unpack('<I', s)[0]
+ """Read entry point from payload EC image."""
+ with open(payload_file, "rb") as f:
+ f.seek(4)
+ s = f.read(4)
+ return struct.unpack("<I", s)[0]
+
def GetPayloadFromOffset(payload_file, offset):
- """Read payload and pad it to 64-byte aligned."""
- with open(payload_file, 'rb') as f:
- f.seek(offset)
- payload = bytearray(f.read())
- rem_len = len(payload) % 64
- if rem_len:
- payload += b'\0' * (64 - rem_len)
- return payload
+ """Read payload and pad it to 64-byte aligned."""
+ with open(payload_file, "rb") as f:
+ f.seek(offset)
+ payload = bytearray(f.read())
+ rem_len = len(payload) % 64
+ if rem_len:
+ payload += b"\0" * (64 - rem_len)
+ return payload
+
def GetPayload(payload_file):
- """Read payload and pad it to 64-byte aligned."""
- return GetPayloadFromOffset(payload_file, 0)
+ """Read payload and pad it to 64-byte aligned."""
+ return GetPayloadFromOffset(payload_file, 0)
+
def GetPublicKey(pem_file):
- """Extract public exponent and modulus from PEM file."""
- result = subprocess.run(['openssl', 'rsa', '-in', pem_file, '-text',
- '-noout'], stdout=subprocess.PIPE, encoding='utf-8')
- modulus_raw = []
- in_modulus = False
- for line in result.stdout.splitlines():
- if line.startswith('modulus'):
- in_modulus = True
- elif not line.startswith(' '):
- in_modulus = False
- elif in_modulus:
- modulus_raw.extend(line.strip().strip(':').split(':'))
- if line.startswith('publicExponent'):
- exp = int(line.split(' ')[1], 10)
- modulus_raw.reverse()
- modulus = bytearray((int(x, 16) for x in modulus_raw[:256]))
- return struct.pack('<Q', exp), modulus
+ """Extract public exponent and modulus from PEM file."""
+ result = subprocess.run(
+ ["openssl", "rsa", "-in", pem_file, "-text", "-noout"],
+ stdout=subprocess.PIPE,
+ encoding="utf-8",
+ )
+ modulus_raw = []
+ in_modulus = False
+ for line in result.stdout.splitlines():
+ if line.startswith("modulus"):
+ in_modulus = True
+ elif not line.startswith(" "):
+ in_modulus = False
+ elif in_modulus:
+ modulus_raw.extend(line.strip().strip(":").split(":"))
+ if line.startswith("publicExponent"):
+ exp = int(line.split(" ")[1], 10)
+ modulus_raw.reverse()
+ modulus = bytearray((int(x, 16) for x in modulus_raw[:256]))
+ return struct.pack("<Q", exp), modulus
+
def GetSpiClockParameter(args):
- assert args.spi_clock in SPI_CLOCK_LIST, \
- "Unsupported SPI clock speed %d MHz" % args.spi_clock
- return SPI_CLOCK_LIST.index(args.spi_clock)
+ assert args.spi_clock in SPI_CLOCK_LIST, (
+ "Unsupported SPI clock speed %d MHz" % args.spi_clock
+ )
+ return SPI_CLOCK_LIST.index(args.spi_clock)
+
def GetSpiReadCmdParameter(args):
- assert args.spi_read_cmd in SPI_READ_CMD_LIST, \
- "Unsupported SPI read command 0x%x" % args.spi_read_cmd
- return SPI_READ_CMD_LIST.index(args.spi_read_cmd)
+ assert args.spi_read_cmd in SPI_READ_CMD_LIST, (
+ "Unsupported SPI read command 0x%x" % args.spi_read_cmd
+ )
+ return SPI_READ_CMD_LIST.index(args.spi_read_cmd)
+
def PadZeroTo(data, size):
- data.extend(b'\0' * (size - len(data)))
+ data.extend(b"\0" * (size - len(data)))
+
def BuildHeader(args, payload_len, load_addr, rorofile):
- # Identifier and header version
- header = bytearray(b'PHCM\0')
+ # Identifier and header version
+ header = bytearray(b"PHCM\0")
- # byte[5]
- b = GetSpiClockParameter(args)
- b |= (1 << 2)
- header.append(b)
+ # byte[5]
+ b = GetSpiClockParameter(args)
+ b |= 1 << 2
+ header.append(b)
- # byte[6]
- b = 0
- header.append(b)
+ # byte[6]
+ b = 0
+ header.append(b)
- # byte[7]
- header.append(GetSpiReadCmdParameter(args))
+ # byte[7]
+ header.append(GetSpiReadCmdParameter(args))
- # bytes 0x08 - 0x0b
- header.extend(struct.pack('<I', load_addr))
- # bytes 0x0c - 0x0f
- header.extend(struct.pack('<I', GetEntryPoint(rorofile)))
- # bytes 0x10 - 0x13
- header.append((payload_len >> 6) & 0xff)
- header.append((payload_len >> 14) & 0xff)
- PadZeroTo(header, 0x14)
- # bytes 0x14 - 0x17
- header.extend(struct.pack('<I', args.payload_offset))
+ # bytes 0x08 - 0x0b
+ header.extend(struct.pack("<I", load_addr))
+ # bytes 0x0c - 0x0f
+ header.extend(struct.pack("<I", GetEntryPoint(rorofile)))
+ # bytes 0x10 - 0x13
+ header.append((payload_len >> 6) & 0xFF)
+ header.append((payload_len >> 14) & 0xFF)
+ PadZeroTo(header, 0x14)
+ # bytes 0x14 - 0x17
+ header.extend(struct.pack("<I", args.payload_offset))
- # bytes 0x14 - 0x3F all 0
- PadZeroTo(header, 0x40)
+ # bytes 0x14 - 0x3F all 0
+ PadZeroTo(header, 0x40)
- # header signature is appended by the caller
+ # header signature is appended by the caller
- return header
+ return header
def BuildHeader2(args, payload_len, load_addr, payload_entry):
- # Identifier and header version
- header = bytearray(b'PHCM\0')
+ # Identifier and header version
+ header = bytearray(b"PHCM\0")
- # byte[5]
- b = GetSpiClockParameter(args)
- b |= (1 << 2)
- header.append(b)
+ # byte[5]
+ b = GetSpiClockParameter(args)
+ b |= 1 << 2
+ header.append(b)
- # byte[6]
- b = 0
- header.append(b)
+ # byte[6]
+ b = 0
+ header.append(b)
- # byte[7]
- header.append(GetSpiReadCmdParameter(args))
+ # byte[7]
+ header.append(GetSpiReadCmdParameter(args))
- # bytes 0x08 - 0x0b
- header.extend(struct.pack('<I', load_addr))
- # bytes 0x0c - 0x0f
- header.extend(struct.pack('<I', payload_entry))
- # bytes 0x10 - 0x13
- header.append((payload_len >> 6) & 0xff)
- header.append((payload_len >> 14) & 0xff)
- PadZeroTo(header, 0x14)
- # bytes 0x14 - 0x17
- header.extend(struct.pack('<I', args.payload_offset))
+ # bytes 0x08 - 0x0b
+ header.extend(struct.pack("<I", load_addr))
+ # bytes 0x0c - 0x0f
+ header.extend(struct.pack("<I", payload_entry))
+ # bytes 0x10 - 0x13
+ header.append((payload_len >> 6) & 0xFF)
+ header.append((payload_len >> 14) & 0xFF)
+ PadZeroTo(header, 0x14)
+ # bytes 0x14 - 0x17
+ header.extend(struct.pack("<I", args.payload_offset))
- # bytes 0x14 - 0x3F all 0
- PadZeroTo(header, 0x40)
+ # bytes 0x14 - 0x3F all 0
+ PadZeroTo(header, 0x40)
- # header signature is appended by the caller
+ # header signature is appended by the caller
+
+ return header
- return header
#
# Compute SHA-256 of data and return digest
# as a bytearray
#
def HashByteArray(data):
- hasher = hashlib.sha256()
- hasher.update(data)
- h = hasher.digest()
- bah = bytearray(h)
- return bah
+ hasher = hashlib.sha256()
+ hasher.update(data)
+ h = hasher.digest()
+ bah = bytearray(h)
+ return bah
+
#
# Return 64-byte signature of byte array data.
# Signature is SHA256 of data with 32 0 bytes appended
#
def SignByteArray(data):
- debug_print("Signature is SHA-256 of data")
- sigb = HashByteArray(data)
- sigb.extend(b'\0' * 32)
- return sigb
+ debug_print("Signature is SHA-256 of data")
+ sigb = HashByteArray(data)
+ sigb.extend(b"\0" * 32)
+ return sigb
# MEC1701H supports two 32-bit Tags located at offsets 0x0 and 0x4
@@ -201,16 +235,21 @@ def SignByteArray(data):
# to the same flash part.
#
def BuildTag(args):
- tag = bytearray([(args.header_loc >> 8) & 0xff,
- (args.header_loc >> 16) & 0xff,
- (args.header_loc >> 24) & 0xff])
- tag.append(Crc8(0, tag))
- return tag
+ tag = bytearray(
+ [
+ (args.header_loc >> 8) & 0xFF,
+ (args.header_loc >> 16) & 0xFF,
+ (args.header_loc >> 24) & 0xFF,
+ ]
+ )
+ tag.append(Crc8(0, tag))
+ return tag
+
def BuildTagFromHdrAddr(header_loc):
- tag = bytearray([(header_loc >> 8) & 0xff,
- (header_loc >> 16) & 0xff,
- (header_loc >> 24) & 0xff])
+ tag = bytearray(
+ [(header_loc >> 8) & 0xFF, (header_loc >> 16) & 0xFF, (header_loc >> 24) & 0xFF]
+ )
tag.append(Crc8(0, tag))
return tag
@@ -224,20 +263,21 @@ def BuildTagFromHdrAddr(header_loc):
# Returns temporary file name
#
def PacklfwRoImage(rorw_file, loader_file, image_size):
- """Create a temp file with the
- first image_size bytes from the loader file and append bytes
- from the rorw file.
- return the filename"""
- fo=tempfile.NamedTemporaryFile(delete=False) # Need to keep file around
- with open(loader_file,'rb') as fin1: # read 4KB loader file
- pro = fin1.read()
- fo.write(pro) # write 4KB loader data to temp file
- with open(rorw_file, 'rb') as fin:
- ro = fin.read(image_size)
-
- fo.write(ro)
- fo.close()
- return fo.name
+ """Create a temp file with the
+ first image_size bytes from the loader file and append bytes
+ from the rorw file.
+ return the filename"""
+ fo = tempfile.NamedTemporaryFile(delete=False) # Need to keep file around
+ with open(loader_file, "rb") as fin1: # read 4KB loader file
+ pro = fin1.read()
+ fo.write(pro) # write 4KB loader data to temp file
+ with open(rorw_file, "rb") as fin:
+ ro = fin.read(image_size)
+
+ fo.write(ro)
+ fo.close()
+ return fo.name
+
#
# Generate a test EC_RW image of same size
@@ -248,105 +288,145 @@ def PacklfwRoImage(rorw_file, loader_file, image_size):
# process hash generation.
#
def gen_test_ecrw(pldrw):
- debug_print("gen_test_ecrw: pldrw type =", type(pldrw))
- debug_print("len pldrw =", len(pldrw), " = ", hex(len(pldrw)))
- cookie1_pos = pldrw.find(b'\x99\x88\x77\xce')
- cookie2_pos = pldrw.find(b'\xdd\xbb\xaa\xce', cookie1_pos+4)
- t = struct.unpack("<L", pldrw[cookie1_pos+0x24:cookie1_pos+0x28])
- size = t[0]
- debug_print("EC_RW size =", size, " = ", hex(size))
-
- debug_print("Found cookie1 at ", hex(cookie1_pos))
- debug_print("Found cookie2 at ", hex(cookie2_pos))
-
- if cookie1_pos > 0 and cookie2_pos > cookie1_pos:
- for i in range(0, cookie1_pos):
- pldrw[i] = 0xA5
- for i in range(cookie2_pos+4, len(pldrw)):
- pldrw[i] = 0xA5
-
- with open("ec_RW_test.bin", "wb") as fecrw:
- fecrw.write(pldrw[:size])
+ debug_print("gen_test_ecrw: pldrw type =", type(pldrw))
+ debug_print("len pldrw =", len(pldrw), " = ", hex(len(pldrw)))
+ cookie1_pos = pldrw.find(b"\x99\x88\x77\xce")
+ cookie2_pos = pldrw.find(b"\xdd\xbb\xaa\xce", cookie1_pos + 4)
+ t = struct.unpack("<L", pldrw[cookie1_pos + 0x24 : cookie1_pos + 0x28])
+ size = t[0]
+ debug_print("EC_RW size =", size, " = ", hex(size))
+
+ debug_print("Found cookie1 at ", hex(cookie1_pos))
+ debug_print("Found cookie2 at ", hex(cookie2_pos))
+
+ if cookie1_pos > 0 and cookie2_pos > cookie1_pos:
+ for i in range(0, cookie1_pos):
+ pldrw[i] = 0xA5
+ for i in range(cookie2_pos + 4, len(pldrw)):
+ pldrw[i] = 0xA5
+
+ with open("ec_RW_test.bin", "wb") as fecrw:
+ fecrw.write(pldrw[:size])
+
def parseargs():
- rpath = os.path.dirname(os.path.relpath(__file__))
-
- parser = argparse.ArgumentParser()
- parser.add_argument("-i", "--input",
- help="EC binary to pack, usually ec.bin or ec.RO.flat.",
- metavar="EC_BIN", default="ec.bin")
- parser.add_argument("-o", "--output",
- help="Output flash binary file",
- metavar="EC_SPI_FLASH", default="ec.packed.bin")
- parser.add_argument("--loader_file",
- help="EC loader binary",
- default="ecloader.bin")
- parser.add_argument("-s", "--spi_size", type=int,
- help="Size of the SPI flash in KB",
- default=512)
- parser.add_argument("-l", "--header_loc", type=int,
- help="Location of header in SPI flash",
- default=0x1000)
- parser.add_argument("-p", "--payload_offset", type=int,
- help="The offset of payload from the start of header",
- default=0x80)
- parser.add_argument("-r", "--rw_loc", type=int,
- help="Start offset of EC_RW. Default is -1 meaning 1/2 flash size",
- default=-1)
- parser.add_argument("--spi_clock", type=int,
- help="SPI clock speed. 8, 12, 24, or 48 MHz.",
- default=24)
- parser.add_argument("--spi_read_cmd", type=int,
- help="SPI read command. 0x3, 0xB, or 0x3B.",
- default=0xb)
- parser.add_argument("--image_size", type=int,
- help="Size of a single image. Default 220KB",
- default=(220 * 1024))
- parser.add_argument("--test_spi", action='store_true',
- help="Test SPI data integrity by adding CRC32 in last 4-bytes of RO/RW binaries",
- default=False)
- parser.add_argument("--test_ecrw", action='store_true',
- help="Use fixed pattern for EC_RW but preserve image_data",
- default=False)
- parser.add_argument("--verbose", action='store_true',
- help="Enable verbose output",
- default=False)
-
- return parser.parse_args()
+ rpath = os.path.dirname(os.path.relpath(__file__))
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-i",
+ "--input",
+ help="EC binary to pack, usually ec.bin or ec.RO.flat.",
+ metavar="EC_BIN",
+ default="ec.bin",
+ )
+ parser.add_argument(
+ "-o",
+ "--output",
+ help="Output flash binary file",
+ metavar="EC_SPI_FLASH",
+ default="ec.packed.bin",
+ )
+ parser.add_argument(
+ "--loader_file", help="EC loader binary", default="ecloader.bin"
+ )
+ parser.add_argument(
+ "-s", "--spi_size", type=int, help="Size of the SPI flash in KB", default=512
+ )
+ parser.add_argument(
+ "-l",
+ "--header_loc",
+ type=int,
+ help="Location of header in SPI flash",
+ default=0x1000,
+ )
+ parser.add_argument(
+ "-p",
+ "--payload_offset",
+ type=int,
+ help="The offset of payload from the start of header",
+ default=0x80,
+ )
+ parser.add_argument(
+ "-r",
+ "--rw_loc",
+ type=int,
+ help="Start offset of EC_RW. Default is -1 meaning 1/2 flash size",
+ default=-1,
+ )
+ parser.add_argument(
+ "--spi_clock",
+ type=int,
+ help="SPI clock speed. 8, 12, 24, or 48 MHz.",
+ default=24,
+ )
+ parser.add_argument(
+ "--spi_read_cmd",
+ type=int,
+ help="SPI read command. 0x3, 0xB, or 0x3B.",
+ default=0xB,
+ )
+ parser.add_argument(
+ "--image_size",
+ type=int,
+ help="Size of a single image. Default 220KB",
+ default=(220 * 1024),
+ )
+ parser.add_argument(
+ "--test_spi",
+ action="store_true",
+ help="Test SPI data integrity by adding CRC32 in last 4-bytes of RO/RW binaries",
+ default=False,
+ )
+ parser.add_argument(
+ "--test_ecrw",
+ action="store_true",
+ help="Use fixed pattern for EC_RW but preserve image_data",
+ default=False,
+ )
+ parser.add_argument(
+ "--verbose", action="store_true", help="Enable verbose output", default=False
+ )
+
+ return parser.parse_args()
+
# Debug helper routine
def dumpsects(spi_list):
- debug_print("spi_list has {0} entries".format(len(spi_list)))
- for s in spi_list:
- debug_print("0x{0:x} 0x{1:x} {2:s}".format(s[0],len(s[1]),s[2]))
+ debug_print("spi_list has {0} entries".format(len(spi_list)))
+ for s in spi_list:
+ debug_print("0x{0:x} 0x{1:x} {2:s}".format(s[0], len(s[1]), s[2]))
+
def printByteArrayAsHex(ba, title):
- debug_print(title,"= ")
- count = 0
- for b in ba:
- count = count + 1
- debug_print("0x{0:02x}, ".format(b),end="")
- if (count % 8) == 0:
- debug_print("")
- debug_print("\n")
+ debug_print(title, "= ")
+ count = 0
+ for b in ba:
+ count = count + 1
+ debug_print("0x{0:02x}, ".format(b), end="")
+ if (count % 8) == 0:
+ debug_print("")
+ debug_print("\n")
+
def print_args(args):
- debug_print("parsed arguments:")
- debug_print(".input = ", args.input)
- debug_print(".output = ", args.output)
- debug_print(".loader_file = ", args.loader_file)
- debug_print(".spi_size (KB) = ", hex(args.spi_size))
- debug_print(".image_size = ", hex(args.image_size))
- debug_print(".header_loc = ", hex(args.header_loc))
- debug_print(".payload_offset = ", hex(args.payload_offset))
- if args.rw_loc < 0:
- debug_print(".rw_loc = ", args.rw_loc)
- else:
- debug_print(".rw_loc = ", hex(args.rw_loc))
- debug_print(".spi_clock = ", args.spi_clock)
- debug_print(".spi_read_cmd = ", args.spi_read_cmd)
- debug_print(".test_spi = ", args.test_spi)
- debug_print(".verbose = ", args.verbose)
+ debug_print("parsed arguments:")
+ debug_print(".input = ", args.input)
+ debug_print(".output = ", args.output)
+ debug_print(".loader_file = ", args.loader_file)
+ debug_print(".spi_size (KB) = ", hex(args.spi_size))
+ debug_print(".image_size = ", hex(args.image_size))
+ debug_print(".header_loc = ", hex(args.header_loc))
+ debug_print(".payload_offset = ", hex(args.payload_offset))
+ if args.rw_loc < 0:
+ debug_print(".rw_loc = ", args.rw_loc)
+ else:
+ debug_print(".rw_loc = ", hex(args.rw_loc))
+ debug_print(".spi_clock = ", args.spi_clock)
+ debug_print(".spi_read_cmd = ", args.spi_read_cmd)
+ debug_print(".test_spi = ", args.test_spi)
+ debug_print(".verbose = ", args.verbose)
+
#
# Handle quiet mode build from Makefile
@@ -354,183 +434,188 @@ def print_args(args):
# Verbose mode when V=1
#
def main():
- global debug_print
-
- args = parseargs()
-
- if args.verbose:
- debug_print = print
-
- debug_print("Begin MEC17xx pack_ec.py script")
-
-
- # MEC17xx maximum 192KB each for RO & RW
- # mec1701 chip Makefile sets args.spi_size = 512
- # Tags at offset 0
- #
- print_args(args)
-
- spi_size = args.spi_size * 1024
- debug_print("SPI Flash image size in bytes =", hex(spi_size))
-
- # !!! IMPORTANT !!!
- # These values MUST match chip/mec1701/config_flash_layout.h
- # defines.
- # MEC17xx Boot-ROM TAGs are at offset 0 and 4.
- # lfw + EC_RO starts at beginning of second 4KB sector
- # EC_RW starts at offset 0x40000 (256KB)
-
- spi_list = []
-
- debug_print("args.input = ",args.input)
- debug_print("args.loader_file = ",args.loader_file)
- debug_print("args.image_size = ",hex(args.image_size))
-
- rorofile=PacklfwRoImage(args.input, args.loader_file, args.image_size)
-
- payload = GetPayload(rorofile)
- payload_len = len(payload)
- # debug
- debug_print("EC_LFW + EC_RO length = ",hex(payload_len))
-
- # SPI image integrity test
- # compute CRC32 of EC_RO except for last 4 bytes
- # skip over 4KB LFW
- # Store CRC32 in last 4 bytes
- if args.test_spi == True:
- crc = zlib.crc32(bytes(payload[LFW_SIZE:(payload_len - 4)]))
- crc_ofs = payload_len - 4
- debug_print("EC_RO CRC32 = 0x{0:08x} @ 0x{1:08x}".format(crc, crc_ofs))
- for i in range(4):
- payload[crc_ofs + i] = crc & 0xff
- crc = crc >> 8
-
- # Chromebooks are not using MEC BootROM ECDSA.
- # We implemented the ECDSA disabled case where
- # the 64-byte signature contains a SHA-256 of the binary plus
- # 32 zeros bytes.
- payload_signature = SignByteArray(payload)
- # debug
- printByteArrayAsHex(payload_signature, "LFW + EC_RO payload_signature")
-
- # MEC17xx Header is 0x80 bytes with an 64 byte signature
- # (32 byte SHA256 + 32 zero bytes)
- header = BuildHeader(args, payload_len, LOAD_ADDR, rorofile)
- # debug
- printByteArrayAsHex(header, "Header LFW + EC_RO")
-
- # MEC17xx payload ECDSA not used, 64 byte signature is
- # SHA256 + 32 zero bytes
- header_signature = SignByteArray(header)
- # debug
- printByteArrayAsHex(header_signature, "header_signature")
-
- tag = BuildTag(args)
- # MEC17xx truncate RW length to args.image_size to not overwrite LFW
- # offset may be different due to Header size and other changes
- # MCHP we want to append a SHA-256 to the end of the actual payload
- # to test SPI read routines.
- debug_print("Call to GetPayloadFromOffset")
- debug_print("args.input = ", args.input)
- debug_print("args.image_size = ", hex(args.image_size))
-
- payload_rw = GetPayloadFromOffset(args.input, args.image_size)
- debug_print("type(payload_rw) is ", type(payload_rw))
- debug_print("len(payload_rw) is ", hex(len(payload_rw)))
-
- # truncate to args.image_size
- rw_len = args.image_size
- payload_rw = payload_rw[:rw_len]
- payload_rw_len = len(payload_rw)
- debug_print("Truncated size of EC_RW = ", hex(payload_rw_len))
-
- payload_entry_tuple = struct.unpack_from('<I', payload_rw, 4)
- debug_print("payload_entry_tuple = ", payload_entry_tuple)
-
- payload_entry = payload_entry_tuple[0]
- debug_print("payload_entry = ", hex(payload_entry))
-
- # Note: payload_rw is a bytearray therefore is mutable
- if args.test_ecrw:
- gen_test_ecrw(payload_rw)
-
- # SPI image integrity test
- # compute CRC32 of EC_RW except for last 4 bytes
- # Store CRC32 in last 4 bytes
- if args.test_spi == True:
- crc = zlib.crc32(bytes(payload_rw[:(payload_rw_len - 32)]))
- crc_ofs = payload_rw_len - 4
- debug_print("EC_RW CRC32 = 0x{0:08x} at offset 0x{1:08x}".format(crc, crc_ofs))
- for i in range(4):
- payload_rw[crc_ofs + i] = crc & 0xff
- crc = crc >> 8
-
- payload_rw_sig = SignByteArray(payload_rw)
- # debug
- printByteArrayAsHex(payload_rw_sig, "payload_rw_sig")
-
- os.remove(rorofile) # clean up the temp file
-
- # MEC170x Boot-ROM Tags are located at SPI offset 0
- spi_list.append((0, tag, "tag"))
-
- spi_list.append((args.header_loc, header, "header(lwf + ro)"))
- spi_list.append((args.header_loc + HEADER_SIZE, header_signature,
- "header(lwf + ro) signature"))
- spi_list.append((args.header_loc + args.payload_offset, payload,
- "payload(lfw + ro)"))
-
- offset = args.header_loc + args.payload_offset + payload_len
-
- # No SPI Header for EC_RW as its not loaded by BootROM
- spi_list.append((offset, payload_signature,
- "payload(lfw_ro) signature"))
-
- # EC_RW location
- rw_offset = int(spi_size // 2)
- if args.rw_loc >= 0:
- rw_offset = args.rw_loc
-
- debug_print("rw_offset = 0x{0:08x}".format(rw_offset))
-
- if rw_offset < offset + len(payload_signature):
- print("ERROR: EC_RW overlaps EC_RO")
-
- spi_list.append((rw_offset, payload_rw, "payload(rw)"))
-
- # don't add to EC_RW. We don't know if Google will process
- # EC SPI flash binary with other tools during build of
- # coreboot and OS.
- #offset = rw_offset + payload_rw_len
- #spi_list.append((offset, payload_rw_sig, "payload(rw) signature"))
-
- spi_list = sorted(spi_list)
-
- dumpsects(spi_list)
-
- #
- # MEC17xx Boot-ROM locates TAG at SPI offset 0 instead of end of SPI.
- #
- with open(args.output, 'wb') as f:
- debug_print("Write spi list to file", args.output)
- addr = 0
- for s in spi_list:
- if addr < s[0]:
- debug_print("Offset ",hex(addr)," Length", hex(s[0]-addr),
- "fill with 0xff")
- f.write(b'\xff' * (s[0] - addr))
- addr = s[0]
- debug_print("Offset ",hex(addr), " Length", hex(len(s[1])), "write data")
-
- f.write(s[1])
- addr += len(s[1])
-
- if addr < spi_size:
- debug_print("Offset ",hex(addr), " Length", hex(spi_size - addr),
- "fill with 0xff")
- f.write(b'\xff' * (spi_size - addr))
-
- f.flush()
-
-if __name__ == '__main__':
- main()
+ global debug_print
+
+ args = parseargs()
+
+ if args.verbose:
+ debug_print = print
+
+ debug_print("Begin MEC17xx pack_ec.py script")
+
+ # MEC17xx maximum 192KB each for RO & RW
+ # mec1701 chip Makefile sets args.spi_size = 512
+ # Tags at offset 0
+ #
+ print_args(args)
+
+ spi_size = args.spi_size * 1024
+ debug_print("SPI Flash image size in bytes =", hex(spi_size))
+
+ # !!! IMPORTANT !!!
+ # These values MUST match chip/mec1701/config_flash_layout.h
+ # defines.
+ # MEC17xx Boot-ROM TAGs are at offset 0 and 4.
+ # lfw + EC_RO starts at beginning of second 4KB sector
+ # EC_RW starts at offset 0x40000 (256KB)
+
+ spi_list = []
+
+ debug_print("args.input = ", args.input)
+ debug_print("args.loader_file = ", args.loader_file)
+ debug_print("args.image_size = ", hex(args.image_size))
+
+ rorofile = PacklfwRoImage(args.input, args.loader_file, args.image_size)
+
+ payload = GetPayload(rorofile)
+ payload_len = len(payload)
+ # debug
+ debug_print("EC_LFW + EC_RO length = ", hex(payload_len))
+
+ # SPI image integrity test
+ # compute CRC32 of EC_RO except for last 4 bytes
+ # skip over 4KB LFW
+ # Store CRC32 in last 4 bytes
+ if args.test_spi == True:
+ crc = zlib.crc32(bytes(payload[LFW_SIZE : (payload_len - 4)]))
+ crc_ofs = payload_len - 4
+ debug_print("EC_RO CRC32 = 0x{0:08x} @ 0x{1:08x}".format(crc, crc_ofs))
+ for i in range(4):
+ payload[crc_ofs + i] = crc & 0xFF
+ crc = crc >> 8
+
+ # Chromebooks are not using MEC BootROM ECDSA.
+ # We implemented the ECDSA disabled case where
+ # the 64-byte signature contains a SHA-256 of the binary plus
+ # 32 zeros bytes.
+ payload_signature = SignByteArray(payload)
+ # debug
+ printByteArrayAsHex(payload_signature, "LFW + EC_RO payload_signature")
+
+ # MEC17xx Header is 0x80 bytes with an 64 byte signature
+ # (32 byte SHA256 + 32 zero bytes)
+ header = BuildHeader(args, payload_len, LOAD_ADDR, rorofile)
+ # debug
+ printByteArrayAsHex(header, "Header LFW + EC_RO")
+
+ # MEC17xx payload ECDSA not used, 64 byte signature is
+ # SHA256 + 32 zero bytes
+ header_signature = SignByteArray(header)
+ # debug
+ printByteArrayAsHex(header_signature, "header_signature")
+
+ tag = BuildTag(args)
+ # MEC17xx truncate RW length to args.image_size to not overwrite LFW
+ # offset may be different due to Header size and other changes
+ # MCHP we want to append a SHA-256 to the end of the actual payload
+ # to test SPI read routines.
+ debug_print("Call to GetPayloadFromOffset")
+ debug_print("args.input = ", args.input)
+ debug_print("args.image_size = ", hex(args.image_size))
+
+ payload_rw = GetPayloadFromOffset(args.input, args.image_size)
+ debug_print("type(payload_rw) is ", type(payload_rw))
+ debug_print("len(payload_rw) is ", hex(len(payload_rw)))
+
+ # truncate to args.image_size
+ rw_len = args.image_size
+ payload_rw = payload_rw[:rw_len]
+ payload_rw_len = len(payload_rw)
+ debug_print("Truncated size of EC_RW = ", hex(payload_rw_len))
+
+ payload_entry_tuple = struct.unpack_from("<I", payload_rw, 4)
+ debug_print("payload_entry_tuple = ", payload_entry_tuple)
+
+ payload_entry = payload_entry_tuple[0]
+ debug_print("payload_entry = ", hex(payload_entry))
+
+ # Note: payload_rw is a bytearray therefore is mutable
+ if args.test_ecrw:
+ gen_test_ecrw(payload_rw)
+
+ # SPI image integrity test
+ # compute CRC32 of EC_RW except for last 4 bytes
+ # Store CRC32 in last 4 bytes
+ if args.test_spi == True:
+ crc = zlib.crc32(bytes(payload_rw[: (payload_rw_len - 32)]))
+ crc_ofs = payload_rw_len - 4
+ debug_print("EC_RW CRC32 = 0x{0:08x} at offset 0x{1:08x}".format(crc, crc_ofs))
+ for i in range(4):
+ payload_rw[crc_ofs + i] = crc & 0xFF
+ crc = crc >> 8
+
+ payload_rw_sig = SignByteArray(payload_rw)
+ # debug
+ printByteArrayAsHex(payload_rw_sig, "payload_rw_sig")
+
+ os.remove(rorofile) # clean up the temp file
+
+ # MEC170x Boot-ROM Tags are located at SPI offset 0
+ spi_list.append((0, tag, "tag"))
+
+ spi_list.append((args.header_loc, header, "header(lwf + ro)"))
+ spi_list.append(
+ (args.header_loc + HEADER_SIZE, header_signature, "header(lwf + ro) signature")
+ )
+ spi_list.append(
+ (args.header_loc + args.payload_offset, payload, "payload(lfw + ro)")
+ )
+
+ offset = args.header_loc + args.payload_offset + payload_len
+
+ # No SPI Header for EC_RW as its not loaded by BootROM
+ spi_list.append((offset, payload_signature, "payload(lfw_ro) signature"))
+
+ # EC_RW location
+ rw_offset = int(spi_size // 2)
+ if args.rw_loc >= 0:
+ rw_offset = args.rw_loc
+
+ debug_print("rw_offset = 0x{0:08x}".format(rw_offset))
+
+ if rw_offset < offset + len(payload_signature):
+ print("ERROR: EC_RW overlaps EC_RO")
+
+ spi_list.append((rw_offset, payload_rw, "payload(rw)"))
+
+ # don't add to EC_RW. We don't know if Google will process
+ # EC SPI flash binary with other tools during build of
+ # coreboot and OS.
+ # offset = rw_offset + payload_rw_len
+ # spi_list.append((offset, payload_rw_sig, "payload(rw) signature"))
+
+ spi_list = sorted(spi_list)
+
+ dumpsects(spi_list)
+
+ #
+ # MEC17xx Boot-ROM locates TAG at SPI offset 0 instead of end of SPI.
+ #
+ with open(args.output, "wb") as f:
+ debug_print("Write spi list to file", args.output)
+ addr = 0
+ for s in spi_list:
+ if addr < s[0]:
+ debug_print(
+ "Offset ", hex(addr), " Length", hex(s[0] - addr), "fill with 0xff"
+ )
+ f.write(b"\xff" * (s[0] - addr))
+ addr = s[0]
+ debug_print(
+ "Offset ", hex(addr), " Length", hex(len(s[1])), "write data"
+ )
+
+ f.write(s[1])
+ addr += len(s[1])
+
+ if addr < spi_size:
+ debug_print(
+ "Offset ", hex(addr), " Length", hex(spi_size - addr), "fill with 0xff"
+ )
+ f.write(b"\xff" * (spi_size - addr))
+
+ f.flush()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/chip/mchp/util/pack_ec_mec152x.py b/chip/mchp/util/pack_ec_mec152x.py
index 34846cd6ba..89f90f5394 100755
--- a/chip/mchp/util/pack_ec_mec152x.py
+++ b/chip/mchp/util/pack_ec_mec152x.py
@@ -16,7 +16,7 @@ import os
import struct
import subprocess
import tempfile
-import zlib # CRC32
+import zlib # CRC32
# MEC152xH has 256KB SRAM from 0xE0000 - 0x120000
# SRAM is divided into contiguous CODE & DATA
@@ -29,119 +29,157 @@ LOAD_ADDR = 0x0E0000
LOAD_ADDR_RW = 0xE1000
MEC152X_HEADER_SIZE = 0x140
MEC152X_HEADER_VERSION = 0x02
-PAYLOAD_PAD_BYTE = b'\xff'
+PAYLOAD_PAD_BYTE = b"\xff"
SPI_ERASE_BLOCK_SIZE = 0x1000
SPI_CLOCK_LIST = [48, 24, 16, 12]
-SPI_READ_CMD_LIST = [0x3, 0xb, 0x3b, 0x6b]
-SPI_DRIVE_STR_DICT = {2:0, 4:1, 8:2, 12:3}
+SPI_READ_CMD_LIST = [0x3, 0xB, 0x3B, 0x6B]
+SPI_DRIVE_STR_DICT = {2: 0, 4: 1, 8: 2, 12: 3}
CHIP_MAX_CODE_SRAM_KB = 224
MEC152X_DICT = {
- "HEADER_SIZE":0x140,
- "HEADER_VER":0x02,
- "PAYLOAD_OFFSET":0x140,
- "PAYLOAD_GRANULARITY":128,
- "EC_INFO_BLK_SZ":128,
- "ENCR_KEY_HDR_SZ":128,
- "COSIG_SZ":96,
- "TRAILER_SZ":160,
- "TAILER_PAD_BYTE":b'\xff',
- "PAD_SIZE":128
- }
-
-CRC_TABLE = [0x00, 0x07, 0x0e, 0x09, 0x1c, 0x1b, 0x12, 0x15,
- 0x38, 0x3f, 0x36, 0x31, 0x24, 0x23, 0x2a, 0x2d]
+ "HEADER_SIZE": 0x140,
+ "HEADER_VER": 0x02,
+ "PAYLOAD_OFFSET": 0x140,
+ "PAYLOAD_GRANULARITY": 128,
+ "EC_INFO_BLK_SZ": 128,
+ "ENCR_KEY_HDR_SZ": 128,
+ "COSIG_SZ": 96,
+ "TRAILER_SZ": 160,
+ "TAILER_PAD_BYTE": b"\xff",
+ "PAD_SIZE": 128,
+}
+
+CRC_TABLE = [
+ 0x00,
+ 0x07,
+ 0x0E,
+ 0x09,
+ 0x1C,
+ 0x1B,
+ 0x12,
+ 0x15,
+ 0x38,
+ 0x3F,
+ 0x36,
+ 0x31,
+ 0x24,
+ 0x23,
+ 0x2A,
+ 0x2D,
+]
+
def mock_print(*args, **kwargs):
- pass
+ pass
+
debug_print = mock_print
# Debug helper routine
def dumpsects(spi_list):
- debug_print("spi_list has {0} entries".format(len(spi_list)))
- for s in spi_list:
- debug_print("0x{0:x} 0x{1:x} {2:s}".format(s[0],len(s[1]),s[2]))
+ debug_print("spi_list has {0} entries".format(len(spi_list)))
+ for s in spi_list:
+ debug_print("0x{0:x} 0x{1:x} {2:s}".format(s[0], len(s[1]), s[2]))
+
def printByteArrayAsHex(ba, title):
- debug_print(title,"= ")
- if ba == None:
- debug_print("None")
- return
-
- count = 0
- for b in ba:
- count = count + 1
- debug_print("0x{0:02x}, ".format(b),end="")
- if (count % 8) == 0:
- debug_print("")
- debug_print("")
+ debug_print(title, "= ")
+ if ba == None:
+ debug_print("None")
+ return
+
+ count = 0
+ for b in ba:
+ count = count + 1
+ debug_print("0x{0:02x}, ".format(b), end="")
+ if (count % 8) == 0:
+ debug_print("")
+ debug_print("")
+
def Crc8(crc, data):
- """Update CRC8 value."""
- for v in data:
- crc = ((crc << 4) & 0xff) ^ (CRC_TABLE[(crc >> 4) ^ (v >> 4)]);
- crc = ((crc << 4) & 0xff) ^ (CRC_TABLE[(crc >> 4) ^ (v & 0xf)]);
- return crc ^ 0x55
+ """Update CRC8 value."""
+ for v in data:
+ crc = ((crc << 4) & 0xFF) ^ (CRC_TABLE[(crc >> 4) ^ (v >> 4)])
+ crc = ((crc << 4) & 0xFF) ^ (CRC_TABLE[(crc >> 4) ^ (v & 0xF)])
+ return crc ^ 0x55
+
def GetEntryPoint(payload_file):
- """Read entry point from payload EC image."""
- with open(payload_file, 'rb') as f:
- f.seek(4)
- s = f.read(4)
- return int.from_bytes(s, byteorder='little')
+ """Read entry point from payload EC image."""
+ with open(payload_file, "rb") as f:
+ f.seek(4)
+ s = f.read(4)
+ return int.from_bytes(s, byteorder="little")
+
def GetPayloadFromOffset(payload_file, offset, padsize):
- """Read payload and pad it to padsize."""
- with open(payload_file, 'rb') as f:
- f.seek(offset)
- payload = bytearray(f.read())
- rem_len = len(payload) % padsize
- debug_print("GetPayload: padsize={0:0x} len(payload)={1:0x} rem={2:0x}".format(padsize,len(payload),rem_len))
+ """Read payload and pad it to padsize."""
+ with open(payload_file, "rb") as f:
+ f.seek(offset)
+ payload = bytearray(f.read())
+ rem_len = len(payload) % padsize
+ debug_print(
+ "GetPayload: padsize={0:0x} len(payload)={1:0x} rem={2:0x}".format(
+ padsize, len(payload), rem_len
+ )
+ )
+
+ if rem_len:
+ payload += PAYLOAD_PAD_BYTE * (padsize - rem_len)
+ debug_print("GetPayload: Added {0} padding bytes".format(padsize - rem_len))
- if rem_len:
- payload += PAYLOAD_PAD_BYTE * (padsize - rem_len)
- debug_print("GetPayload: Added {0} padding bytes".format(padsize - rem_len))
+ return payload
- return payload
def GetPayload(payload_file, padsize):
- """Read payload and pad it to padsize"""
- return GetPayloadFromOffset(payload_file, 0, padsize)
+ """Read payload and pad it to padsize"""
+ return GetPayloadFromOffset(payload_file, 0, padsize)
+
def GetPublicKey(pem_file):
- """Extract public exponent and modulus from PEM file."""
- result = subprocess.run(['openssl', 'rsa', '-in', pem_file, '-text',
- '-noout'], stdout=subprocess.PIPE, encoding='utf-8')
- modulus_raw = []
- in_modulus = False
- for line in result.stdout.splitlines():
- if line.startswith('modulus'):
- in_modulus = True
- elif not line.startswith(' '):
- in_modulus = False
- elif in_modulus:
- modulus_raw.extend(line.strip().strip(':').split(':'))
- if line.startswith('publicExponent'):
- exp = int(line.split(' ')[1], 10)
- modulus_raw.reverse()
- modulus = bytearray((int(x, 16) for x in modulus_raw[:256]))
- return struct.pack('<Q', exp), modulus
+ """Extract public exponent and modulus from PEM file."""
+ result = subprocess.run(
+ ["openssl", "rsa", "-in", pem_file, "-text", "-noout"],
+ stdout=subprocess.PIPE,
+ encoding="utf-8",
+ )
+ modulus_raw = []
+ in_modulus = False
+ for line in result.stdout.splitlines():
+ if line.startswith("modulus"):
+ in_modulus = True
+ elif not line.startswith(" "):
+ in_modulus = False
+ elif in_modulus:
+ modulus_raw.extend(line.strip().strip(":").split(":"))
+ if line.startswith("publicExponent"):
+ exp = int(line.split(" ")[1], 10)
+ modulus_raw.reverse()
+ modulus = bytearray((int(x, 16) for x in modulus_raw[:256]))
+ return struct.pack("<Q", exp), modulus
+
def GetSpiClockParameter(args):
- assert args.spi_clock in SPI_CLOCK_LIST, \
- "Unsupported SPI clock speed %d MHz" % args.spi_clock
- return SPI_CLOCK_LIST.index(args.spi_clock)
+ assert args.spi_clock in SPI_CLOCK_LIST, (
+ "Unsupported SPI clock speed %d MHz" % args.spi_clock
+ )
+ return SPI_CLOCK_LIST.index(args.spi_clock)
+
def GetSpiReadCmdParameter(args):
- assert args.spi_read_cmd in SPI_READ_CMD_LIST, \
- "Unsupported SPI read command 0x%x" % args.spi_read_cmd
- return SPI_READ_CMD_LIST.index(args.spi_read_cmd)
+ assert args.spi_read_cmd in SPI_READ_CMD_LIST, (
+ "Unsupported SPI read command 0x%x" % args.spi_read_cmd
+ )
+ return SPI_READ_CMD_LIST.index(args.spi_read_cmd)
+
def GetEncodedSpiDriveStrength(args):
- assert args.spi_drive_str in SPI_DRIVE_STR_DICT, \
- "Unsupported SPI drive strength %d mA" % args.spi_drive_str
- return SPI_DRIVE_STR_DICT.get(args.spi_drive_str)
+ assert args.spi_drive_str in SPI_DRIVE_STR_DICT, (
+ "Unsupported SPI drive strength %d mA" % args.spi_drive_str
+ )
+ return SPI_DRIVE_STR_DICT.get(args.spi_drive_str)
+
# Return 0=Slow slew rate or 1=Fast slew rate
def GetSpiSlewRate(args):
@@ -149,12 +187,14 @@ def GetSpiSlewRate(args):
return 1
return 0
+
# Return SPI CPOL = 0 or 1
def GetSpiCpol(args):
if args.spi_cpol == 0:
return 0
return 1
+
# Return SPI CPHA_MOSI
# 0 = SPI Master drives data is stable on inactive to clock edge
# 1 = SPI Master drives data is stable on active to inactive clock edge
@@ -163,6 +203,7 @@ def GetSpiCphaMosi(args):
return 0
return 1
+
# Return SPI CPHA_MISO 0 or 1
# 0 = SPI Master samples data on inactive to active clock edge
# 1 = SPI Master samples data on active to inactive clock edge
@@ -171,8 +212,10 @@ def GetSpiCphaMiso(args):
return 0
return 1
+
def PadZeroTo(data, size):
- data.extend(b'\0' * (size - len(data)))
+ data.extend(b"\0" * (size - len(data)))
+
#
# Boot-ROM SPI image encryption not used with Chromebooks
@@ -180,6 +223,7 @@ def PadZeroTo(data, size):
def EncryptPayload(args, chip_dict, payload):
return None
+
#
# Build SPI image header for MEC152x
# MEC152x image header size = 320(0x140) bytes
@@ -237,67 +281,69 @@ def EncryptPayload(args, chip_dict, payload):
# header[0x110:0x140] = Header ECDSA-384 signature y-coor. = 0 Auth. disabled
#
def BuildHeader2(args, chip_dict, payload_len, load_addr, payload_entry):
- header_size = MEC152X_HEADER_SIZE
-
- # allocate zero filled header
- header = bytearray(b'\x00' * header_size)
- debug_print("len(header) = ", len(header))
-
- # Identifier and header version
- header[0:4] = b'PHCM'
- header[4] = MEC152X_HEADER_VERSION
-
- # SPI frequency, drive strength, CPOL/CPHA encoding same for both chips
- spiFreqMHz = GetSpiClockParameter(args)
- header[5] = (int(spiFreqMHz // 48) - 1) & 0x03
- header[5] |= ((GetEncodedSpiDriveStrength(args) & 0x03) << 2)
- header[5] |= ((GetSpiSlewRate(args) & 0x01) << 4)
- header[5] |= ((GetSpiCpol(args) & 0x01) << 5)
- header[5] |= ((GetSpiCphaMosi(args) & 0x01) << 6)
- header[5] |= ((GetSpiCphaMiso(args) & 0x01) << 7)
-
- # b[0]=0 VTR1 must be 3.3V
- # b[1]=0(VTR2 3.3V), 1(VTR2 1.8V)
- # b[2]=0(VTR3 3.3V), 1(VTR3 1.8V)
- # b[5:3]=111b
- # b[6]=0 No ECDSA
- # b[7]=0 No encrypted FW image
- header[6] = 0x7 << 3
- if args.vtr2_V18 == True:
- header[6] |= 0x02
- if args.vtr3_V18 == True:
- header[6] |= 0x04
-
- # SPI read command set same for both chips
- header[7] = GetSpiReadCmdParameter(args) & 0xFF
-
- # bytes 0x08 - 0x0b
- header[0x08:0x0C] = load_addr.to_bytes(4, byteorder='little')
- # bytes 0x0c - 0x0f
- header[0x0C:0x10] = payload_entry.to_bytes(4, byteorder='little')
- # bytes 0x10 - 0x11 payload length in units of 128 bytes
-
- payload_units = int(payload_len // chip_dict["PAYLOAD_GRANULARITY"])
- assert payload_units < 0x10000, \
- print("Payload too large: len={0} units={1}".format(payload_len, payload_units))
-
- header[0x10:0x12] = payload_units.to_bytes(2, 'little')
-
- # bytes 0x14 - 0x17
- header[0x14:0x18] = chip_dict["PAYLOAD_OFFSET"].to_bytes(4, 'little')
-
- # MEC152x: Disable ECDSA and encryption
- header[0x18] = 0
-
- # header[0xB0:0xE0] = SHA384(header[0:0xB0])
- header[0xB0:0xE0] = hashlib.sha384(header[0:0xB0]).digest()
- # When ECDSA authentication is disabled MCHP SPI image generator
- # is filling the last 48 bytes of the Header with 0xff
- header[-48:] = b'\xff' * 48
-
- debug_print("After hash: len(header) = ", len(header))
-
- return header
+ header_size = MEC152X_HEADER_SIZE
+
+ # allocate zero filled header
+ header = bytearray(b"\x00" * header_size)
+ debug_print("len(header) = ", len(header))
+
+ # Identifier and header version
+ header[0:4] = b"PHCM"
+ header[4] = MEC152X_HEADER_VERSION
+
+ # SPI frequency, drive strength, CPOL/CPHA encoding same for both chips
+ spiFreqMHz = GetSpiClockParameter(args)
+ header[5] = (int(spiFreqMHz // 48) - 1) & 0x03
+ header[5] |= (GetEncodedSpiDriveStrength(args) & 0x03) << 2
+ header[5] |= (GetSpiSlewRate(args) & 0x01) << 4
+ header[5] |= (GetSpiCpol(args) & 0x01) << 5
+ header[5] |= (GetSpiCphaMosi(args) & 0x01) << 6
+ header[5] |= (GetSpiCphaMiso(args) & 0x01) << 7
+
+ # b[0]=0 VTR1 must be 3.3V
+ # b[1]=0(VTR2 3.3V), 1(VTR2 1.8V)
+ # b[2]=0(VTR3 3.3V), 1(VTR3 1.8V)
+ # b[5:3]=111b
+ # b[6]=0 No ECDSA
+ # b[7]=0 No encrypted FW image
+ header[6] = 0x7 << 3
+ if args.vtr2_V18 == True:
+ header[6] |= 0x02
+ if args.vtr3_V18 == True:
+ header[6] |= 0x04
+
+ # SPI read command set same for both chips
+ header[7] = GetSpiReadCmdParameter(args) & 0xFF
+
+ # bytes 0x08 - 0x0b
+ header[0x08:0x0C] = load_addr.to_bytes(4, byteorder="little")
+ # bytes 0x0c - 0x0f
+ header[0x0C:0x10] = payload_entry.to_bytes(4, byteorder="little")
+ # bytes 0x10 - 0x11 payload length in units of 128 bytes
+
+ payload_units = int(payload_len // chip_dict["PAYLOAD_GRANULARITY"])
+ assert payload_units < 0x10000, print(
+ "Payload too large: len={0} units={1}".format(payload_len, payload_units)
+ )
+
+ header[0x10:0x12] = payload_units.to_bytes(2, "little")
+
+ # bytes 0x14 - 0x17
+ header[0x14:0x18] = chip_dict["PAYLOAD_OFFSET"].to_bytes(4, "little")
+
+ # MEC152x: Disable ECDSA and encryption
+ header[0x18] = 0
+
+ # header[0xB0:0xE0] = SHA384(header[0:0xB0])
+ header[0xB0:0xE0] = hashlib.sha384(header[0:0xB0]).digest()
+ # When ECDSA authentication is disabled MCHP SPI image generator
+ # is filling the last 48 bytes of the Header with 0xff
+ header[-48:] = b"\xff" * 48
+
+ debug_print("After hash: len(header) = ", len(header))
+
+ return header
+
#
# MEC152x 128-byte EC Info Block appended to
@@ -311,8 +357,9 @@ def BuildHeader2(args, chip_dict, payload_len, load_addr, payload_entry):
# byte 127 = customer current image revision
#
def GenEcInfoBlock(args, chip_dict):
- ecinfo = bytearray(chip_dict["EC_INFO_BLK_SZ"])
- return ecinfo
+ ecinfo = bytearray(chip_dict["EC_INFO_BLK_SZ"])
+ return ecinfo
+
#
# Generate SPI FW image co-signature.
@@ -325,7 +372,8 @@ def GenEcInfoBlock(args, chip_dict):
# signature.
#
def GenCoSignature(args, chip_dict, payload):
- return bytearray(b'\xff' * chip_dict["COSIG_SZ"])
+ return bytearray(b"\xff" * chip_dict["COSIG_SZ"])
+
#
# Generate SPI FW Image trailer.
@@ -336,22 +384,24 @@ def GenCoSignature(args, chip_dict, payload):
# trailer[144:160] = 0xFF. Boot-ROM spec. says these bytes should be random.
# Authentication & encryption are not used therefore random data
# is not necessary.
-def GenTrailer(args, chip_dict, payload, encryption_key_header,
- ec_info_block, cosignature):
+def GenTrailer(
+ args, chip_dict, payload, encryption_key_header, ec_info_block, cosignature
+):
trailer = bytearray(chip_dict["TAILER_PAD_BYTE"] * chip_dict["TRAILER_SZ"])
hasher = hashlib.sha384()
hasher.update(payload)
if ec_info_block != None:
- hasher.update(ec_info_block)
+ hasher.update(ec_info_block)
if encryption_key_header != None:
- hasher.update(encryption_key_header)
+ hasher.update(encryption_key_header)
if cosignature != None:
- hasher.update(cosignature)
+ hasher.update(cosignature)
trailer[0:48] = hasher.digest()
- trailer[-16:] = 16 * b'\xff'
+ trailer[-16:] = 16 * b"\xff"
return trailer
+
# MEC152xH supports two 32-bit Tags located at offsets 0x0 and 0x4
# in the SPI flash.
# Tag format:
@@ -362,16 +412,21 @@ def GenTrailer(args, chip_dict, payload, encryption_key_header,
# to the same flash part.
#
def BuildTag(args):
- tag = bytearray([(args.header_loc >> 8) & 0xff,
- (args.header_loc >> 16) & 0xff,
- (args.header_loc >> 24) & 0xff])
- tag.append(Crc8(0, tag))
- return tag
+ tag = bytearray(
+ [
+ (args.header_loc >> 8) & 0xFF,
+ (args.header_loc >> 16) & 0xFF,
+ (args.header_loc >> 24) & 0xFF,
+ ]
+ )
+ tag.append(Crc8(0, tag))
+ return tag
+
def BuildTagFromHdrAddr(header_loc):
- tag = bytearray([(header_loc >> 8) & 0xff,
- (header_loc >> 16) & 0xff,
- (header_loc >> 24) & 0xff])
+ tag = bytearray(
+ [(header_loc >> 8) & 0xFF, (header_loc >> 16) & 0xFF, (header_loc >> 24) & 0xFF]
+ )
tag.append(Crc8(0, tag))
return tag
@@ -388,12 +443,13 @@ def BuildTagFromHdrAddr(header_loc):
# Output:
# bytearray of length 4
def BuildFlashMap(secondSpiFlashBaseAddr):
- flashmap = bytearray(4)
- flashmap[0] = (secondSpiFlashBaseAddr >> 12) & 0xff
- flashmap[1] = (secondSpiFlashBaseAddr >> 20) & 0xff
- flashmap[2] = (secondSpiFlashBaseAddr >> 28) & 0xff
- flashmap[3] = Crc8(0, flashmap)
- return flashmap
+ flashmap = bytearray(4)
+ flashmap[0] = (secondSpiFlashBaseAddr >> 12) & 0xFF
+ flashmap[1] = (secondSpiFlashBaseAddr >> 20) & 0xFF
+ flashmap[2] = (secondSpiFlashBaseAddr >> 28) & 0xFF
+ flashmap[3] = Crc8(0, flashmap)
+ return flashmap
+
#
# Creates temporary file for read/write
@@ -404,21 +460,22 @@ def BuildFlashMap(secondSpiFlashBaseAddr):
# Returns temporary file name
#
def PacklfwRoImage(rorw_file, loader_file, image_size):
- """Create a temp file with the
- first image_size bytes from the loader file and append bytes
- from the rorw file.
- return the filename"""
- fo=tempfile.NamedTemporaryFile(delete=False) # Need to keep file around
- with open(loader_file,'rb') as fin1: # read 4KB loader file
- pro = fin1.read()
- fo.write(pro) # write 4KB loader data to temp file
- with open(rorw_file, 'rb') as fin:
- ro = fin.read(image_size)
-
- fo.write(ro)
- fo.close()
-
- return fo.name
+ """Create a temp file with the
+ first image_size bytes from the loader file and append bytes
+ from the rorw file.
+ return the filename"""
+ fo = tempfile.NamedTemporaryFile(delete=False) # Need to keep file around
+ with open(loader_file, "rb") as fin1: # read 4KB loader file
+ pro = fin1.read()
+ fo.write(pro) # write 4KB loader data to temp file
+ with open(rorw_file, "rb") as fin:
+ ro = fin.read(image_size)
+
+ fo.write(ro)
+ fo.close()
+
+ return fo.name
+
#
# Generate a test EC_RW image of same size
@@ -429,129 +486,184 @@ def PacklfwRoImage(rorw_file, loader_file, image_size):
# process hash generation.
#
def gen_test_ecrw(pldrw):
- debug_print("gen_test_ecrw: pldrw type =", type(pldrw))
- debug_print("len pldrw =", len(pldrw), " = ", hex(len(pldrw)))
- cookie1_pos = pldrw.find(b'\x99\x88\x77\xce')
- cookie2_pos = pldrw.find(b'\xdd\xbb\xaa\xce', cookie1_pos+4)
- t = struct.unpack("<L", pldrw[cookie1_pos+0x24:cookie1_pos+0x28])
- size = t[0]
- debug_print("EC_RW size =", size, " = ", hex(size))
-
- debug_print("Found cookie1 at ", hex(cookie1_pos))
- debug_print("Found cookie2 at ", hex(cookie2_pos))
-
- if cookie1_pos > 0 and cookie2_pos > cookie1_pos:
- for i in range(0, cookie1_pos):
- pldrw[i] = 0xA5
- for i in range(cookie2_pos+4, len(pldrw)):
- pldrw[i] = 0xA5
-
- with open("ec_RW_test.bin", "wb") as fecrw:
- fecrw.write(pldrw[:size])
+ debug_print("gen_test_ecrw: pldrw type =", type(pldrw))
+ debug_print("len pldrw =", len(pldrw), " = ", hex(len(pldrw)))
+ cookie1_pos = pldrw.find(b"\x99\x88\x77\xce")
+ cookie2_pos = pldrw.find(b"\xdd\xbb\xaa\xce", cookie1_pos + 4)
+ t = struct.unpack("<L", pldrw[cookie1_pos + 0x24 : cookie1_pos + 0x28])
+ size = t[0]
+ debug_print("EC_RW size =", size, " = ", hex(size))
+
+ debug_print("Found cookie1 at ", hex(cookie1_pos))
+ debug_print("Found cookie2 at ", hex(cookie2_pos))
+
+ if cookie1_pos > 0 and cookie2_pos > cookie1_pos:
+ for i in range(0, cookie1_pos):
+ pldrw[i] = 0xA5
+ for i in range(cookie2_pos + 4, len(pldrw)):
+ pldrw[i] = 0xA5
+
+ with open("ec_RW_test.bin", "wb") as fecrw:
+ fecrw.write(pldrw[:size])
+
def parseargs():
- #TODO I commented this out. Why?
- rpath = os.path.dirname(os.path.relpath(__file__))
-
- parser = argparse.ArgumentParser()
- parser.add_argument("-i", "--input",
- help="EC binary to pack, usually ec.bin or ec.RO.flat.",
- metavar="EC_BIN", default="ec.bin")
- parser.add_argument("-o", "--output",
- help="Output flash binary file",
- metavar="EC_SPI_FLASH", default="ec.packed.bin")
- parser.add_argument("--loader_file",
- help="EC loader binary",
- default="ecloader.bin")
- parser.add_argument("-s", "--spi_size", type=int,
- help="Size of the SPI flash in KB",
- default=512)
- parser.add_argument("-l", "--header_loc", type=int,
- help="Location of header in SPI flash",
- default=0x1000)
- parser.add_argument("-r", "--rw_loc", type=int,
- help="Start offset of EC_RW. Default is -1 meaning 1/2 flash size",
- default=-1)
- parser.add_argument("--spi_clock", type=int,
- help="SPI clock speed. 8, 12, 24, or 48 MHz.",
- default=24)
- parser.add_argument("--spi_read_cmd", type=int,
- help="SPI read command. 0x3, 0xB, or 0x3B.",
- default=0xb)
- parser.add_argument("--image_size", type=int,
- help="Size of a single image. Default 220KB",
- default=(220 * 1024))
- parser.add_argument("--test_spi", action='store_true',
- help="Test SPI data integrity by adding CRC32 in last 4-bytes of RO/RW binaries",
- default=False)
- parser.add_argument("--test_ecrw", action='store_true',
- help="Use fixed pattern for EC_RW but preserve image_data",
- default=False)
- parser.add_argument("--verbose", action='store_true',
- help="Enable verbose output",
- default=False)
- parser.add_argument("--tag0_loc", type=int,
- help="MEC152X TAG0 SPI offset",
- default=0)
- parser.add_argument("--tag1_loc", type=int,
- help="MEC152X TAG1 SPI offset",
- default=4)
- parser.add_argument("--spi_drive_str", type=int,
- help="Chip SPI drive strength in mA: 2, 4, 8, or 12",
- default=4)
- parser.add_argument("--spi_slew_fast", action='store_true',
- help="SPI use fast slew rate. Default is False",
- default=False)
- parser.add_argument("--spi_cpol", type=int,
- help="SPI clock polarity when idle. Defealt is 0(low)",
- default=0)
- parser.add_argument("--spi_cpha_mosi", type=int,
- help="""SPI clock phase master drives data.
+ # TODO I commented this out. Why?
+ rpath = os.path.dirname(os.path.relpath(__file__))
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-i",
+ "--input",
+ help="EC binary to pack, usually ec.bin or ec.RO.flat.",
+ metavar="EC_BIN",
+ default="ec.bin",
+ )
+ parser.add_argument(
+ "-o",
+ "--output",
+ help="Output flash binary file",
+ metavar="EC_SPI_FLASH",
+ default="ec.packed.bin",
+ )
+ parser.add_argument(
+ "--loader_file", help="EC loader binary", default="ecloader.bin"
+ )
+ parser.add_argument(
+ "-s", "--spi_size", type=int, help="Size of the SPI flash in KB", default=512
+ )
+ parser.add_argument(
+ "-l",
+ "--header_loc",
+ type=int,
+ help="Location of header in SPI flash",
+ default=0x1000,
+ )
+ parser.add_argument(
+ "-r",
+ "--rw_loc",
+ type=int,
+ help="Start offset of EC_RW. Default is -1 meaning 1/2 flash size",
+ default=-1,
+ )
+ parser.add_argument(
+ "--spi_clock",
+ type=int,
+ help="SPI clock speed. 8, 12, 24, or 48 MHz.",
+ default=24,
+ )
+ parser.add_argument(
+ "--spi_read_cmd",
+ type=int,
+ help="SPI read command. 0x3, 0xB, or 0x3B.",
+ default=0xB,
+ )
+ parser.add_argument(
+ "--image_size",
+ type=int,
+ help="Size of a single image. Default 220KB",
+ default=(220 * 1024),
+ )
+ parser.add_argument(
+ "--test_spi",
+ action="store_true",
+ help="Test SPI data integrity by adding CRC32 in last 4-bytes of RO/RW binaries",
+ default=False,
+ )
+ parser.add_argument(
+ "--test_ecrw",
+ action="store_true",
+ help="Use fixed pattern for EC_RW but preserve image_data",
+ default=False,
+ )
+ parser.add_argument(
+ "--verbose", action="store_true", help="Enable verbose output", default=False
+ )
+ parser.add_argument(
+ "--tag0_loc", type=int, help="MEC152X TAG0 SPI offset", default=0
+ )
+ parser.add_argument(
+ "--tag1_loc", type=int, help="MEC152X TAG1 SPI offset", default=4
+ )
+ parser.add_argument(
+ "--spi_drive_str",
+ type=int,
+ help="Chip SPI drive strength in mA: 2, 4, 8, or 12",
+ default=4,
+ )
+ parser.add_argument(
+ "--spi_slew_fast",
+ action="store_true",
+ help="SPI use fast slew rate. Default is False",
+ default=False,
+ )
+ parser.add_argument(
+ "--spi_cpol",
+ type=int,
+ help="SPI clock polarity when idle. Defealt is 0(low)",
+ default=0,
+ )
+ parser.add_argument(
+ "--spi_cpha_mosi",
+ type=int,
+ help="""SPI clock phase master drives data.
0=Data driven on active to inactive clock edge,
1=Data driven on inactive to active clock edge""",
- default=0)
- parser.add_argument("--spi_cpha_miso", type=int,
- help="""SPI clock phase master samples data.
+ default=0,
+ )
+ parser.add_argument(
+ "--spi_cpha_miso",
+ type=int,
+ help="""SPI clock phase master samples data.
0=Data sampled on inactive to active clock edge,
1=Data sampled on active to inactive clock edge""",
- default=0)
+ default=0,
+ )
+
+ parser.add_argument(
+ "--vtr2_V18",
+ action="store_true",
+ help="Chip VTR2 rail is 1.8V. Default is False(3.3V)",
+ default=False,
+ )
- parser.add_argument("--vtr2_V18", action='store_true',
- help="Chip VTR2 rail is 1.8V. Default is False(3.3V)",
- default=False)
+ parser.add_argument(
+ "--vtr3_V18",
+ action="store_true",
+ help="Chip VTR3 rail is 1.8V. Default is False(3.3V)",
+ default=False,
+ )
- parser.add_argument("--vtr3_V18", action='store_true',
- help="Chip VTR3 rail is 1.8V. Default is False(3.3V)",
- default=False)
+ return parser.parse_args()
- return parser.parse_args()
def print_args(args):
- debug_print("parsed arguments:")
- debug_print(".input = ", args.input)
- debug_print(".output = ", args.output)
- debug_print(".loader_file = ", args.loader_file)
- debug_print(".spi_size (KB) = ", hex(args.spi_size))
- debug_print(".image_size = ", hex(args.image_size))
- debug_print(".tag0_loc = ", hex(args.tag0_loc))
- debug_print(".tag1_loc = ", hex(args.tag1_loc))
- debug_print(".header_loc = ", hex(args.header_loc))
- if args.rw_loc < 0:
- debug_print(".rw_loc = ", args.rw_loc)
- else:
- debug_print(".rw_loc = ", hex(args.rw_loc))
- debug_print(".spi_clock (MHz) = ", args.spi_clock)
- debug_print(".spi_read_cmd = ", hex(args.spi_read_cmd))
- debug_print(".test_spi = ", args.test_spi)
- debug_print(".test_ecrw = ", args.test_ecrw)
- debug_print(".verbose = ", args.verbose)
- debug_print(".spi_drive_str = ", args.spi_drive_str)
- debug_print(".spi_slew_fast = ", args.spi_slew_fast)
- debug_print(".spi_cpol = ", args.spi_cpol)
- debug_print(".spi_cpha_mosi = ", args.spi_cpha_mosi)
- debug_print(".spi_cpha_miso = ", args.spi_cpha_miso)
- debug_print(".vtr2_V18 = ", args.vtr2_V18)
- debug_print(".vtr3_V18 = ", args.vtr3_V18)
+ debug_print("parsed arguments:")
+ debug_print(".input = ", args.input)
+ debug_print(".output = ", args.output)
+ debug_print(".loader_file = ", args.loader_file)
+ debug_print(".spi_size (KB) = ", hex(args.spi_size))
+ debug_print(".image_size = ", hex(args.image_size))
+ debug_print(".tag0_loc = ", hex(args.tag0_loc))
+ debug_print(".tag1_loc = ", hex(args.tag1_loc))
+ debug_print(".header_loc = ", hex(args.header_loc))
+ if args.rw_loc < 0:
+ debug_print(".rw_loc = ", args.rw_loc)
+ else:
+ debug_print(".rw_loc = ", hex(args.rw_loc))
+ debug_print(".spi_clock (MHz) = ", args.spi_clock)
+ debug_print(".spi_read_cmd = ", hex(args.spi_read_cmd))
+ debug_print(".test_spi = ", args.test_spi)
+ debug_print(".test_ecrw = ", args.test_ecrw)
+ debug_print(".verbose = ", args.verbose)
+ debug_print(".spi_drive_str = ", args.spi_drive_str)
+ debug_print(".spi_slew_fast = ", args.spi_slew_fast)
+ debug_print(".spi_cpol = ", args.spi_cpol)
+ debug_print(".spi_cpha_mosi = ", args.spi_cpha_mosi)
+ debug_print(".spi_cpha_miso = ", args.spi_cpha_miso)
+ debug_print(".vtr2_V18 = ", args.vtr2_V18)
+ debug_print(".vtr3_V18 = ", args.vtr3_V18)
+
#
# Handle quiet mode build from Makefile
@@ -589,215 +701,229 @@ def print_args(args):
# || 48 * [0]
#
def main():
- global debug_print
+ global debug_print
+
+ args = parseargs()
+
+ if args.verbose:
+ debug_print = print
- args = parseargs()
-
- if args.verbose:
- debug_print = print
+ debug_print("Begin pack_ec_mec152x.py script")
+
+ print_args(args)
+
+ chip_dict = MEC152X_DICT
+
+ # Boot-ROM requires header location aligned >= 256 bytes.
+ # CrOS EC flash image update code requires EC_RO/RW location to be aligned
+ # on a flash erase size boundary and EC_RO/RW size to be a multiple of
+ # the smallest flash erase block size.
+ #
+ assert (args.header_loc % SPI_ERASE_BLOCK_SIZE) == 0, (
+ "Header location %d is not on a flash erase block boundary boundary"
+ % args.header_loc
+ )
+
+ max_image_size = CHIP_MAX_CODE_SRAM_KB - LFW_SIZE
+ if args.test_spi:
+ max_image_size -= 32 # SHA256 digest
+
+ assert args.image_size > max_image_size, (
+ "Image size exceeds maximum" % args.image_size
+ )
+
+ spi_size = args.spi_size * 1024
+ debug_print("SPI Flash image size in bytes =", hex(spi_size))
+
+ # !!! IMPORTANT !!!
+ # These values MUST match chip/mchp/config_flash_layout.h
+ # defines.
+ # MEC152x Boot-ROM TAGs are at offset 0 and 4.
+ # lfw + EC_RO starts at beginning of second 4KB sector
+ # EC_RW starts at (flash size / 2) i.e. 0x40000 for a 512KB flash.
+
+ spi_list = []
+
+ debug_print("args.input = ", args.input)
+ debug_print("args.loader_file = ", args.loader_file)
+ debug_print("args.image_size = ", hex(args.image_size))
+
+ rorofile = PacklfwRoImage(args.input, args.loader_file, args.image_size)
+ debug_print("Temporary file containing LFW + EC_RO is ", rorofile)
+
+ lfw_ecro = GetPayload(rorofile, chip_dict["PAD_SIZE"])
+ lfw_ecro_len = len(lfw_ecro)
+ debug_print("Padded LFW + EC_RO length = ", hex(lfw_ecro_len))
+
+ # SPI test mode compute CRC32 of EC_RO and store in last 4 bytes
+ if args.test_spi:
+ crc32_ecro = zlib.crc32(bytes(lfw_ecro[LFW_SIZE:-4]))
+ crc32_ecro_bytes = crc32_ecro.to_bytes(4, byteorder="little")
+ lfw_ecro[-4:] = crc32_ecro_bytes
+ debug_print("ecro len = ", hex(len(lfw_ecro) - LFW_SIZE))
+ debug_print("CRC32(ecro-4) = ", hex(crc32_ecro))
+
+ # Reads entry point from offset 4 of file.
+ # This assumes binary has Cortex-M4 vector table at offset 0.
+ # 32-bit word at offset 0x0 initial stack pointer value
+ # 32-bit word at offset 0x4 address of reset handler
+ # NOTE: reset address will have bit[0]=1 to ensure thumb mode.
+ lfw_ecro_entry = GetEntryPoint(rorofile)
+
+ # Chromebooks are not using MEC BootROM SPI header/payload authentication
+ # or payload encryption. In this case the header authentication signature
+ # is filled with the hash digest of the respective entity.
+ # BuildHeader2 computes the hash digest and stores it in the correct
+ # header location.
+ header = BuildHeader2(args, chip_dict, lfw_ecro_len, LOAD_ADDR, lfw_ecro_entry)
+ printByteArrayAsHex(header, "Header(lfw_ecro)")
+
+ # If payload encryption used then encrypt payload and
+ # generate Payload Key Header. If encryption not used
+ # payload is not modified and the method returns None
+ encryption_key_header = EncryptPayload(args, chip_dict, lfw_ecro)
+ printByteArrayAsHex(encryption_key_header, "LFW + EC_RO encryption_key_header")
+
+ ec_info_block = GenEcInfoBlock(args, chip_dict)
+ printByteArrayAsHex(ec_info_block, "EC Info Block")
+
+ cosignature = GenCoSignature(args, chip_dict, lfw_ecro)
+ printByteArrayAsHex(cosignature, "LFW + EC_RO cosignature")
+
+ trailer = GenTrailer(
+ args, chip_dict, lfw_ecro, encryption_key_header, ec_info_block, cosignature
+ )
+
+ printByteArrayAsHex(trailer, "LFW + EC_RO trailer")
+
+ # Build TAG0. Set TAG1=TAG0 Boot-ROM is allowed to load EC-RO only.
+ tag0 = BuildTag(args)
+ tag1 = tag0
+
+ debug_print("Call to GetPayloadFromOffset")
+ debug_print("args.input = ", args.input)
+ debug_print("args.image_size = ", hex(args.image_size))
+
+ ecrw = GetPayloadFromOffset(args.input, args.image_size, chip_dict["PAD_SIZE"])
+ debug_print("type(ecrw) is ", type(ecrw))
+ debug_print("len(ecrw) is ", hex(len(ecrw)))
+
+ # truncate to args.image_size
+ ecrw_len = len(ecrw)
+ if ecrw_len > args.image_size:
+ debug_print(
+ "Truncate EC_RW len={0:0x} to image_size={1:0x}".format(
+ ecrw_len, args.image_size
+ )
+ )
+ ecrw = ecrw[: args.image_size]
+ ecrw_len = len(ecrw)
+
+ debug_print("len(EC_RW) = ", hex(ecrw_len))
+
+ # SPI test mode compute CRC32 of EC_RW and store in last 4 bytes
+ if args.test_spi:
+ crc32_ecrw = zlib.crc32(bytes(ecrw[0:-4]))
+ crc32_ecrw_bytes = crc32_ecrw.to_bytes(4, byteorder="little")
+ ecrw[-4:] = crc32_ecrw_bytes
+ debug_print("ecrw len = ", hex(len(ecrw)))
+ debug_print("CRC32(ecrw) = ", hex(crc32_ecrw))
+
+ # Assume FW layout is standard Cortex-M style with vector
+ # table at start of binary.
+ # 32-bit word at offset 0x0 = Initial stack pointer
+ # 32-bit word at offset 0x4 = Address of reset handler
+ ecrw_entry_tuple = struct.unpack_from("<I", ecrw, 4)
+ debug_print("ecrw_entry_tuple[0] = ", hex(ecrw_entry_tuple[0]))
+
+ ecrw_entry = ecrw_entry_tuple[0]
+ debug_print("ecrw_entry = ", hex(ecrw_entry))
+
+ # Note: payload_rw is a bytearray therefore is mutable
+ if args.test_ecrw:
+ gen_test_ecrw(ecrw)
+
+ os.remove(rorofile) # clean up the temp file
+
+ # MEC152X Add TAG's
+ spi_list.append((args.tag0_loc, tag0, "tag0"))
+ spi_list.append((args.tag1_loc, tag1, "tag1"))
+
+ # flashmap is non-zero only for systems with two external
+ # SPI flash chips.
+ flashmap = BuildFlashMap(0)
+ spi_list.append((8, flashmap, "flashmap"))
+
+ # Boot-ROM SPI image header for LFW+EC-RO
+ spi_list.append((args.header_loc, header, "header(lfw + ro)"))
+ spi_list.append(
+ (args.header_loc + chip_dict["PAYLOAD_OFFSET"], lfw_ecro, "lfw_ecro")
+ )
+
+ offset = args.header_loc + chip_dict["PAYLOAD_OFFSET"] + lfw_ecro_len
- debug_print("Begin pack_ec_mec152x.py script")
-
- print_args(args)
-
- chip_dict = MEC152X_DICT
-
- # Boot-ROM requires header location aligned >= 256 bytes.
- # CrOS EC flash image update code requires EC_RO/RW location to be aligned
- # on a flash erase size boundary and EC_RO/RW size to be a multiple of
- # the smallest flash erase block size.
- #
- assert (args.header_loc % SPI_ERASE_BLOCK_SIZE) == 0, \
- "Header location %d is not on a flash erase block boundary boundary" % args.header_loc
-
- max_image_size = CHIP_MAX_CODE_SRAM_KB - LFW_SIZE
- if args.test_spi:
- max_image_size -= 32 # SHA256 digest
-
- assert args.image_size > max_image_size, \
- "Image size exceeds maximum" % args.image_size
-
- spi_size = args.spi_size * 1024
- debug_print("SPI Flash image size in bytes =", hex(spi_size))
-
- # !!! IMPORTANT !!!
- # These values MUST match chip/mchp/config_flash_layout.h
- # defines.
- # MEC152x Boot-ROM TAGs are at offset 0 and 4.
- # lfw + EC_RO starts at beginning of second 4KB sector
- # EC_RW starts at (flash size / 2) i.e. 0x40000 for a 512KB flash.
-
- spi_list = []
-
- debug_print("args.input = ",args.input)
- debug_print("args.loader_file = ",args.loader_file)
- debug_print("args.image_size = ",hex(args.image_size))
-
- rorofile=PacklfwRoImage(args.input, args.loader_file, args.image_size)
- debug_print("Temporary file containing LFW + EC_RO is ", rorofile)
-
- lfw_ecro = GetPayload(rorofile, chip_dict["PAD_SIZE"])
- lfw_ecro_len = len(lfw_ecro)
- debug_print("Padded LFW + EC_RO length = ", hex(lfw_ecro_len))
-
- # SPI test mode compute CRC32 of EC_RO and store in last 4 bytes
- if args.test_spi:
- crc32_ecro = zlib.crc32(bytes(lfw_ecro[LFW_SIZE:-4]))
- crc32_ecro_bytes = crc32_ecro.to_bytes(4, byteorder='little')
- lfw_ecro[-4:] = crc32_ecro_bytes
- debug_print("ecro len = ", hex(len(lfw_ecro) - LFW_SIZE))
- debug_print("CRC32(ecro-4) = ", hex(crc32_ecro))
-
- # Reads entry point from offset 4 of file.
- # This assumes binary has Cortex-M4 vector table at offset 0.
- # 32-bit word at offset 0x0 initial stack pointer value
- # 32-bit word at offset 0x4 address of reset handler
- # NOTE: reset address will have bit[0]=1 to ensure thumb mode.
- lfw_ecro_entry = GetEntryPoint(rorofile)
-
- # Chromebooks are not using MEC BootROM SPI header/payload authentication
- # or payload encryption. In this case the header authentication signature
- # is filled with the hash digest of the respective entity.
- # BuildHeader2 computes the hash digest and stores it in the correct
- # header location.
- header = BuildHeader2(args, chip_dict, lfw_ecro_len,
- LOAD_ADDR, lfw_ecro_entry)
- printByteArrayAsHex(header, "Header(lfw_ecro)")
-
- # If payload encryption used then encrypt payload and
- # generate Payload Key Header. If encryption not used
- # payload is not modified and the method returns None
- encryption_key_header = EncryptPayload(args, chip_dict, lfw_ecro)
- printByteArrayAsHex(encryption_key_header,
- "LFW + EC_RO encryption_key_header")
-
- ec_info_block = GenEcInfoBlock(args, chip_dict)
- printByteArrayAsHex(ec_info_block, "EC Info Block")
-
- cosignature = GenCoSignature(args, chip_dict, lfw_ecro)
- printByteArrayAsHex(cosignature, "LFW + EC_RO cosignature")
-
- trailer = GenTrailer(args, chip_dict, lfw_ecro, encryption_key_header,
- ec_info_block, cosignature)
-
- printByteArrayAsHex(trailer, "LFW + EC_RO trailer")
-
- # Build TAG0. Set TAG1=TAG0 Boot-ROM is allowed to load EC-RO only.
- tag0 = BuildTag(args)
- tag1 = tag0
-
- debug_print("Call to GetPayloadFromOffset")
- debug_print("args.input = ", args.input)
- debug_print("args.image_size = ", hex(args.image_size))
-
- ecrw = GetPayloadFromOffset(args.input, args.image_size,
- chip_dict["PAD_SIZE"])
- debug_print("type(ecrw) is ", type(ecrw))
- debug_print("len(ecrw) is ", hex(len(ecrw)))
-
- # truncate to args.image_size
- ecrw_len = len(ecrw)
- if ecrw_len > args.image_size:
- debug_print("Truncate EC_RW len={0:0x} to image_size={1:0x}".format(ecrw_len,args.image_size))
- ecrw = ecrw[:args.image_size]
- ecrw_len = len(ecrw)
-
- debug_print("len(EC_RW) = ", hex(ecrw_len))
-
- # SPI test mode compute CRC32 of EC_RW and store in last 4 bytes
- if args.test_spi:
- crc32_ecrw = zlib.crc32(bytes(ecrw[0:-4]))
- crc32_ecrw_bytes = crc32_ecrw.to_bytes(4, byteorder='little')
- ecrw[-4:] = crc32_ecrw_bytes
- debug_print("ecrw len = ", hex(len(ecrw)))
- debug_print("CRC32(ecrw) = ", hex(crc32_ecrw))
-
- # Assume FW layout is standard Cortex-M style with vector
- # table at start of binary.
- # 32-bit word at offset 0x0 = Initial stack pointer
- # 32-bit word at offset 0x4 = Address of reset handler
- ecrw_entry_tuple = struct.unpack_from('<I', ecrw, 4)
- debug_print("ecrw_entry_tuple[0] = ", hex(ecrw_entry_tuple[0]))
-
- ecrw_entry = ecrw_entry_tuple[0]
- debug_print("ecrw_entry = ", hex(ecrw_entry))
-
- # Note: payload_rw is a bytearray therefore is mutable
- if args.test_ecrw:
- gen_test_ecrw(ecrw)
-
- os.remove(rorofile) # clean up the temp file
-
- # MEC152X Add TAG's
- spi_list.append((args.tag0_loc, tag0, "tag0"))
- spi_list.append((args.tag1_loc, tag1, "tag1"))
-
- # flashmap is non-zero only for systems with two external
- # SPI flash chips.
- flashmap = BuildFlashMap(0)
- spi_list.append((8, flashmap, "flashmap"))
-
- # Boot-ROM SPI image header for LFW+EC-RO
- spi_list.append((args.header_loc, header, "header(lfw + ro)"))
- spi_list.append((args.header_loc + chip_dict["PAYLOAD_OFFSET"], lfw_ecro,
- "lfw_ecro"))
-
- offset = args.header_loc + chip_dict["PAYLOAD_OFFSET"] + lfw_ecro_len
-
- if ec_info_block != None:
- spi_list.append((offset, ec_info_block, "EC Info Block"))
- offset += len(ec_info_block)
-
- if cosignature != None:
- spi_list.append((offset, cosignature, "ECRO Cosignature"))
- offset += len(cosignature)
-
- if trailer != None:
- spi_list.append((offset, trailer, "ECRO Trailer"))
- offset += len(trailer)
-
- # EC_RW location
- rw_offset = int(spi_size // 2)
- if args.rw_loc >= 0:
- rw_offset = args.rw_loc
-
- debug_print("rw_offset = 0x{0:08x}".format(rw_offset))
-
- assert rw_offset >= offset, \
- print("""Offset of EC_RW at {0:08x} overlaps end
- of EC_RO at {0:08x}""".format(rw_offset, offset))
-
- spi_list.append((rw_offset, ecrw, "ecrw"))
- offset = rw_offset + len(ecrw)
-
- spi_list = sorted(spi_list)
-
- dumpsects(spi_list)
-
- #
- # MEC152X Boot-ROM locates TAG0/1 at SPI offset 0
- # instead of end of SPI.
- #
- with open(args.output, 'wb') as f:
- debug_print("Write spi list to file", args.output)
- addr = 0
- for s in spi_list:
- if addr < s[0]:
- debug_print("Offset ",hex(addr)," Length", hex(s[0]-addr),
- "fill with 0xff")
- f.write(b'\xff' * (s[0] - addr))
- addr = s[0]
- debug_print("Offset ",hex(addr), " Length", hex(len(s[1])), "write data")
-
- f.write(s[1])
- addr += len(s[1])
-
- if addr < spi_size:
- debug_print("Offset ",hex(addr), " Length", hex(spi_size - addr),
- "fill with 0xff")
- f.write(b'\xff' * (spi_size - addr))
-
- f.flush()
+ if ec_info_block != None:
+ spi_list.append((offset, ec_info_block, "EC Info Block"))
+ offset += len(ec_info_block)
-if __name__ == '__main__':
- main()
+ if cosignature != None:
+ spi_list.append((offset, cosignature, "ECRO Cosignature"))
+ offset += len(cosignature)
+
+ if trailer != None:
+ spi_list.append((offset, trailer, "ECRO Trailer"))
+ offset += len(trailer)
+
+ # EC_RW location
+ rw_offset = int(spi_size // 2)
+ if args.rw_loc >= 0:
+ rw_offset = args.rw_loc
+
+ debug_print("rw_offset = 0x{0:08x}".format(rw_offset))
+
+ assert rw_offset >= offset, print(
+ """Offset of EC_RW at {0:08x} overlaps end
+ of EC_RO at {0:08x}""".format(
+ rw_offset, offset
+ )
+ )
+
+ spi_list.append((rw_offset, ecrw, "ecrw"))
+ offset = rw_offset + len(ecrw)
+
+ spi_list = sorted(spi_list)
+
+ dumpsects(spi_list)
+
+ #
+ # MEC152X Boot-ROM locates TAG0/1 at SPI offset 0
+ # instead of end of SPI.
+ #
+ with open(args.output, "wb") as f:
+ debug_print("Write spi list to file", args.output)
+ addr = 0
+ for s in spi_list:
+ if addr < s[0]:
+ debug_print(
+ "Offset ", hex(addr), " Length", hex(s[0] - addr), "fill with 0xff"
+ )
+ f.write(b"\xff" * (s[0] - addr))
+ addr = s[0]
+ debug_print(
+ "Offset ", hex(addr), " Length", hex(len(s[1])), "write data"
+ )
+
+ f.write(s[1])
+ addr += len(s[1])
+
+ if addr < spi_size:
+ debug_print(
+ "Offset ", hex(addr), " Length", hex(spi_size - addr), "fill with 0xff"
+ )
+ f.write(b"\xff" * (spi_size - addr))
+
+ f.flush()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/chip/mchp/util/pack_ec_mec172x.py b/chip/mchp/util/pack_ec_mec172x.py
index 32747d3d9a..bd5ff6edba 100755
--- a/chip/mchp/util/pack_ec_mec172x.py
+++ b/chip/mchp/util/pack_ec_mec172x.py
@@ -16,7 +16,7 @@ import os
import struct
import subprocess
import tempfile
-import zlib # CRC32
+import zlib # CRC32
# MEC172x has 416KB SRAM from 0xC0000 - 0x127FFF
# SRAM is divided into contiguous CODE & DATA
@@ -28,8 +28,8 @@ import zlib # CRC32
#
SPI_ERASE_BLOCK_SIZE = 0x1000
SPI_CLOCK_LIST = [48, 24, 16, 12, 96]
-SPI_READ_CMD_LIST = [0x3, 0xb, 0x3b, 0x6b]
-SPI_DRIVE_STR_DICT = {2:0, 4:1, 8:2, 12:3}
+SPI_READ_CMD_LIST = [0x3, 0xB, 0x3B, 0x6B]
+SPI_DRIVE_STR_DICT = {2: 0, 4: 1, 8: 2, 12: 3}
# Maximum EC_RO/EC_RW code size is based upon SPI flash erase
# sector size, MEC172x Boot-ROM TAG, Header, Footer.
# SPI Offset Description
@@ -38,118 +38,156 @@ SPI_DRIVE_STR_DICT = {2:0, 4:1, 8:2, 12:3}
# 0x1140 - 0x213F 4KB LFW
# 0x2040 - 0x3EFFF
# 0x3F000 - 0x3FFFF BootROM EC_INFO_BLK || COSIG || ENCR_KEY_HDR(optional) || TRAILER
-CHIP_MAX_CODE_SRAM_KB = (256 - 12)
+CHIP_MAX_CODE_SRAM_KB = 256 - 12
MEC172X_DICT = {
- "LFW_SIZE": 0x1000,
- "LOAD_ADDR": 0xC0000,
- "TAG_SIZE": 4,
- "KEY_BLOB_SIZE": 1584,
- "HEADER_SIZE":0x140,
- "HEADER_VER":0x03,
- "PAYLOAD_GRANULARITY":128,
- "PAYLOAD_PAD_BYTE":b'\xff',
- "EC_INFO_BLK_SZ":128,
- "ENCR_KEY_HDR_SZ":128,
- "COSIG_SZ":96,
- "TRAILER_SZ":160,
- "TAILER_PAD_BYTE":b'\xff',
- "PAD_SIZE":128
- }
-
-CRC_TABLE = [0x00, 0x07, 0x0e, 0x09, 0x1c, 0x1b, 0x12, 0x15,
- 0x38, 0x3f, 0x36, 0x31, 0x24, 0x23, 0x2a, 0x2d]
+ "LFW_SIZE": 0x1000,
+ "LOAD_ADDR": 0xC0000,
+ "TAG_SIZE": 4,
+ "KEY_BLOB_SIZE": 1584,
+ "HEADER_SIZE": 0x140,
+ "HEADER_VER": 0x03,
+ "PAYLOAD_GRANULARITY": 128,
+ "PAYLOAD_PAD_BYTE": b"\xff",
+ "EC_INFO_BLK_SZ": 128,
+ "ENCR_KEY_HDR_SZ": 128,
+ "COSIG_SZ": 96,
+ "TRAILER_SZ": 160,
+ "TAILER_PAD_BYTE": b"\xff",
+ "PAD_SIZE": 128,
+}
+
+CRC_TABLE = [
+ 0x00,
+ 0x07,
+ 0x0E,
+ 0x09,
+ 0x1C,
+ 0x1B,
+ 0x12,
+ 0x15,
+ 0x38,
+ 0x3F,
+ 0x36,
+ 0x31,
+ 0x24,
+ 0x23,
+ 0x2A,
+ 0x2D,
+]
+
def mock_print(*args, **kwargs):
- pass
+ pass
+
debug_print = mock_print
# Debug helper routine
def dumpsects(spi_list):
- debug_print("spi_list has {0} entries".format(len(spi_list)))
- for s in spi_list:
- debug_print("0x{0:x} 0x{1:x} {2:s}".format(s[0],len(s[1]),s[2]))
+ debug_print("spi_list has {0} entries".format(len(spi_list)))
+ for s in spi_list:
+ debug_print("0x{0:x} 0x{1:x} {2:s}".format(s[0], len(s[1]), s[2]))
+
def printByteArrayAsHex(ba, title):
- debug_print(title,"= ")
- if ba == None:
- debug_print("None")
- return
-
- count = 0
- for b in ba:
- count = count + 1
- debug_print("0x{0:02x}, ".format(b),end="")
- if (count % 8) == 0:
- debug_print("")
- debug_print("")
+ debug_print(title, "= ")
+ if ba == None:
+ debug_print("None")
+ return
+
+ count = 0
+ for b in ba:
+ count = count + 1
+ debug_print("0x{0:02x}, ".format(b), end="")
+ if (count % 8) == 0:
+ debug_print("")
+ debug_print("")
+
def Crc8(crc, data):
- """Update CRC8 value."""
- for v in data:
- crc = ((crc << 4) & 0xff) ^ (CRC_TABLE[(crc >> 4) ^ (v >> 4)]);
- crc = ((crc << 4) & 0xff) ^ (CRC_TABLE[(crc >> 4) ^ (v & 0xf)]);
- return crc ^ 0x55
+ """Update CRC8 value."""
+ for v in data:
+ crc = ((crc << 4) & 0xFF) ^ (CRC_TABLE[(crc >> 4) ^ (v >> 4)])
+ crc = ((crc << 4) & 0xFF) ^ (CRC_TABLE[(crc >> 4) ^ (v & 0xF)])
+ return crc ^ 0x55
+
def GetEntryPoint(payload_file):
- """Read entry point from payload EC image."""
- with open(payload_file, 'rb') as f:
- f.seek(4)
- s = f.read(4)
- return int.from_bytes(s, byteorder='little')
+ """Read entry point from payload EC image."""
+ with open(payload_file, "rb") as f:
+ f.seek(4)
+ s = f.read(4)
+ return int.from_bytes(s, byteorder="little")
+
def GetPayloadFromOffset(payload_file, offset, padsize):
- """Read payload and pad it to padsize."""
- with open(payload_file, 'rb') as f:
- f.seek(offset)
- payload = bytearray(f.read())
- rem_len = len(payload) % padsize
- debug_print("GetPayload: padsize={0:0x} len(payload)={1:0x} rem={2:0x}".format(padsize,len(payload),rem_len))
+ """Read payload and pad it to padsize."""
+ with open(payload_file, "rb") as f:
+ f.seek(offset)
+ payload = bytearray(f.read())
+ rem_len = len(payload) % padsize
+ debug_print(
+ "GetPayload: padsize={0:0x} len(payload)={1:0x} rem={2:0x}".format(
+ padsize, len(payload), rem_len
+ )
+ )
- if rem_len:
- payload += PAYLOAD_PAD_BYTE * (padsize - rem_len)
- debug_print("GetPayload: Added {0} padding bytes".format(padsize - rem_len))
+ if rem_len:
+ payload += PAYLOAD_PAD_BYTE * (padsize - rem_len)
+ debug_print("GetPayload: Added {0} padding bytes".format(padsize - rem_len))
+
+ return payload
- return payload
def GetPayload(payload_file, padsize):
- """Read payload and pad it to padsize"""
- return GetPayloadFromOffset(payload_file, 0, padsize)
+ """Read payload and pad it to padsize"""
+ return GetPayloadFromOffset(payload_file, 0, padsize)
+
def GetPublicKey(pem_file):
- """Extract public exponent and modulus from PEM file."""
- result = subprocess.run(['openssl', 'rsa', '-in', pem_file, '-text',
- '-noout'], stdout=subprocess.PIPE, encoding='utf-8')
- modulus_raw = []
- in_modulus = False
- for line in result.stdout.splitlines():
- if line.startswith('modulus'):
- in_modulus = True
- elif not line.startswith(' '):
- in_modulus = False
- elif in_modulus:
- modulus_raw.extend(line.strip().strip(':').split(':'))
- if line.startswith('publicExponent'):
- exp = int(line.split(' ')[1], 10)
- modulus_raw.reverse()
- modulus = bytearray((int(x, 16) for x in modulus_raw[:256]))
- return struct.pack('<Q', exp), modulus
+ """Extract public exponent and modulus from PEM file."""
+ result = subprocess.run(
+ ["openssl", "rsa", "-in", pem_file, "-text", "-noout"],
+ stdout=subprocess.PIPE,
+ encoding="utf-8",
+ )
+ modulus_raw = []
+ in_modulus = False
+ for line in result.stdout.splitlines():
+ if line.startswith("modulus"):
+ in_modulus = True
+ elif not line.startswith(" "):
+ in_modulus = False
+ elif in_modulus:
+ modulus_raw.extend(line.strip().strip(":").split(":"))
+ if line.startswith("publicExponent"):
+ exp = int(line.split(" ")[1], 10)
+ modulus_raw.reverse()
+ modulus = bytearray((int(x, 16) for x in modulus_raw[:256]))
+ return struct.pack("<Q", exp), modulus
+
def GetSpiClockParameter(args):
- assert args.spi_clock in SPI_CLOCK_LIST, \
- "Unsupported SPI clock speed %d MHz" % args.spi_clock
- return SPI_CLOCK_LIST.index(args.spi_clock)
+ assert args.spi_clock in SPI_CLOCK_LIST, (
+ "Unsupported SPI clock speed %d MHz" % args.spi_clock
+ )
+ return SPI_CLOCK_LIST.index(args.spi_clock)
+
def GetSpiReadCmdParameter(args):
- assert args.spi_read_cmd in SPI_READ_CMD_LIST, \
- "Unsupported SPI read command 0x%x" % args.spi_read_cmd
- return SPI_READ_CMD_LIST.index(args.spi_read_cmd)
+ assert args.spi_read_cmd in SPI_READ_CMD_LIST, (
+ "Unsupported SPI read command 0x%x" % args.spi_read_cmd
+ )
+ return SPI_READ_CMD_LIST.index(args.spi_read_cmd)
+
def GetEncodedSpiDriveStrength(args):
- assert args.spi_drive_str in SPI_DRIVE_STR_DICT, \
- "Unsupported SPI drive strength %d mA" % args.spi_drive_str
- return SPI_DRIVE_STR_DICT.get(args.spi_drive_str)
+ assert args.spi_drive_str in SPI_DRIVE_STR_DICT, (
+ "Unsupported SPI drive strength %d mA" % args.spi_drive_str
+ )
+ return SPI_DRIVE_STR_DICT.get(args.spi_drive_str)
+
# Return 0=Slow slew rate or 1=Fast slew rate
def GetSpiSlewRate(args):
@@ -157,12 +195,14 @@ def GetSpiSlewRate(args):
return 1
return 0
+
# Return SPI CPOL = 0 or 1
def GetSpiCpol(args):
if args.spi_cpol == 0:
return 0
return 1
+
# Return SPI CPHA_MOSI
# 0 = SPI Master drives data is stable on inactive to clock edge
# 1 = SPI Master drives data is stable on active to inactive clock edge
@@ -171,6 +211,7 @@ def GetSpiCphaMosi(args):
return 0
return 1
+
# Return SPI CPHA_MISO 0 or 1
# 0 = SPI Master samples data on inactive to active clock edge
# 1 = SPI Master samples data on active to inactive clock edge
@@ -179,8 +220,10 @@ def GetSpiCphaMiso(args):
return 0
return 1
+
def PadZeroTo(data, size):
- data.extend(b'\0' * (size - len(data)))
+ data.extend(b"\0" * (size - len(data)))
+
#
# Boot-ROM SPI image encryption not used with Chromebooks
@@ -188,6 +231,7 @@ def PadZeroTo(data, size):
def EncryptPayload(args, chip_dict, payload):
return None
+
#
# Build SPI image header for MEC172x
# MEC172x image header size = 320(0x140) bytes
@@ -262,89 +306,94 @@ def EncryptPayload(args, chip_dict, payload):
# header[0x110:0x140] = Header ECDSA-384 signature y-coor. = 0 Auth. disabled
#
def BuildHeader2(args, chip_dict, payload_len, load_addr, payload_entry):
- header_size = chip_dict["HEADER_SIZE"]
-
- # allocate zero filled header
- header = bytearray(b'\x00' * header_size)
- debug_print("len(header) = ", len(header))
-
- # Identifier and header version
- header[0:4] = b'PHCM'
- header[4] = chip_dict["HEADER_VER"]
-
- # SPI frequency, drive strength, CPOL/CPHA encoding same for both chips
- spiFreqIndex = GetSpiClockParameter(args)
- if spiFreqIndex > 3:
- header[6] |= 0x01
- else:
- header[5] = spiFreqIndex
-
- header[5] |= ((GetEncodedSpiDriveStrength(args) & 0x03) << 2)
- header[5] |= ((GetSpiSlewRate(args) & 0x01) << 4)
- header[5] |= ((GetSpiCpol(args) & 0x01) << 5)
- header[5] |= ((GetSpiCphaMosi(args) & 0x01) << 6)
- header[5] |= ((GetSpiCphaMiso(args) & 0x01) << 7)
-
- # header[6]
- # b[0] value set above
- # b[2:1] = 00b, b[5:3]=111b
- # b[7]=0 No encryption of FW payload
- header[6] |= 0x7 << 3
-
- # SPI read command set same for both chips
- header[7] = GetSpiReadCmdParameter(args) & 0xFF
-
- # bytes 0x08 - 0x0b
- header[0x08:0x0C] = load_addr.to_bytes(4, byteorder='little')
- # bytes 0x0c - 0x0f
- header[0x0C:0x10] = payload_entry.to_bytes(4, byteorder='little')
-
- # bytes 0x10 - 0x11 payload length in units of 128 bytes
- assert payload_len % chip_dict["PAYLOAD_GRANULARITY"] == 0, \
- print("Payload size not a multiple of {0}".format(chip_dict["PAYLOAD_GRANULARITY"]))
-
- payload_units = int(payload_len // chip_dict["PAYLOAD_GRANULARITY"])
- assert payload_units < 0x10000, \
- print("Payload too large: len={0} units={1}".format(payload_len, payload_units))
-
- header[0x10:0x12] = payload_units.to_bytes(2, 'little')
-
- # bytes 0x14 - 0x17 TODO offset from start of payload to FW payload to be
- # loaded by Boot-ROM. We ask Boot-ROM to load (LFW || EC_RO).
- # LFW location provided on the command line.
- assert (args.lfw_loc % 4096 == 0), \
- print("LFW location not on a 4KB boundary! 0x{0:0x}".format(args.lfw_loc))
-
- assert args.lfw_loc >= (args.header_loc + chip_dict["HEADER_SIZE"]), \
- print("LFW location not greater than header location + header size")
-
- lfw_ofs = args.lfw_loc - args.header_loc
- header[0x14:0x18] = lfw_ofs.to_bytes(4, 'little')
-
- # MEC172x: authentication key select. Authentication not used, set to 0.
- header[0x18] = 0
-
- # header[0x19], header[0x20:0x28]
- # header[0x1A:0x20] reserved 0
- # MEC172x: supports SPI flash devices with drive strength settings
- # TODO leave these fields at 0 for now. We must add 6 command line
- # arguments.
-
- # header[0x28:0x48] reserve can be any value
- # header[0x48:0x50] Customer use. TODO
- # authentication disabled, leave these 0.
- # header[0x50:0x80] ECDSA P384 Authentication Public key Rx
- # header[0x80:0xB0] ECDSA P384 Authentication Public key Ry
-
- # header[0xB0:0xE0] = SHA384(header[0:0xB0])
- header[0xB0:0xE0] = hashlib.sha384(header[0:0xB0]).digest()
- # When ECDSA authentication is disabled MCHP SPI image generator
- # is filling the last 48 bytes of the Header with 0xff
- header[-48:] = b'\xff' * 48
-
- debug_print("After hash: len(header) = ", len(header))
-
- return header
+ header_size = chip_dict["HEADER_SIZE"]
+
+ # allocate zero filled header
+ header = bytearray(b"\x00" * header_size)
+ debug_print("len(header) = ", len(header))
+
+ # Identifier and header version
+ header[0:4] = b"PHCM"
+ header[4] = chip_dict["HEADER_VER"]
+
+ # SPI frequency, drive strength, CPOL/CPHA encoding same for both chips
+ spiFreqIndex = GetSpiClockParameter(args)
+ if spiFreqIndex > 3:
+ header[6] |= 0x01
+ else:
+ header[5] = spiFreqIndex
+
+ header[5] |= (GetEncodedSpiDriveStrength(args) & 0x03) << 2
+ header[5] |= (GetSpiSlewRate(args) & 0x01) << 4
+ header[5] |= (GetSpiCpol(args) & 0x01) << 5
+ header[5] |= (GetSpiCphaMosi(args) & 0x01) << 6
+ header[5] |= (GetSpiCphaMiso(args) & 0x01) << 7
+
+ # header[6]
+ # b[0] value set above
+ # b[2:1] = 00b, b[5:3]=111b
+ # b[7]=0 No encryption of FW payload
+ header[6] |= 0x7 << 3
+
+ # SPI read command set same for both chips
+ header[7] = GetSpiReadCmdParameter(args) & 0xFF
+
+ # bytes 0x08 - 0x0b
+ header[0x08:0x0C] = load_addr.to_bytes(4, byteorder="little")
+ # bytes 0x0c - 0x0f
+ header[0x0C:0x10] = payload_entry.to_bytes(4, byteorder="little")
+
+ # bytes 0x10 - 0x11 payload length in units of 128 bytes
+ assert payload_len % chip_dict["PAYLOAD_GRANULARITY"] == 0, print(
+ "Payload size not a multiple of {0}".format(chip_dict["PAYLOAD_GRANULARITY"])
+ )
+
+ payload_units = int(payload_len // chip_dict["PAYLOAD_GRANULARITY"])
+ assert payload_units < 0x10000, print(
+ "Payload too large: len={0} units={1}".format(payload_len, payload_units)
+ )
+
+ header[0x10:0x12] = payload_units.to_bytes(2, "little")
+
+ # bytes 0x14 - 0x17 TODO offset from start of payload to FW payload to be
+ # loaded by Boot-ROM. We ask Boot-ROM to load (LFW || EC_RO).
+ # LFW location provided on the command line.
+ assert args.lfw_loc % 4096 == 0, print(
+ "LFW location not on a 4KB boundary! 0x{0:0x}".format(args.lfw_loc)
+ )
+
+ assert args.lfw_loc >= (args.header_loc + chip_dict["HEADER_SIZE"]), print(
+ "LFW location not greater than header location + header size"
+ )
+
+ lfw_ofs = args.lfw_loc - args.header_loc
+ header[0x14:0x18] = lfw_ofs.to_bytes(4, "little")
+
+ # MEC172x: authentication key select. Authentication not used, set to 0.
+ header[0x18] = 0
+
+ # header[0x19], header[0x20:0x28]
+ # header[0x1A:0x20] reserved 0
+ # MEC172x: supports SPI flash devices with drive strength settings
+ # TODO leave these fields at 0 for now. We must add 6 command line
+ # arguments.
+
+ # header[0x28:0x48] reserve can be any value
+ # header[0x48:0x50] Customer use. TODO
+ # authentication disabled, leave these 0.
+ # header[0x50:0x80] ECDSA P384 Authentication Public key Rx
+ # header[0x80:0xB0] ECDSA P384 Authentication Public key Ry
+
+ # header[0xB0:0xE0] = SHA384(header[0:0xB0])
+ header[0xB0:0xE0] = hashlib.sha384(header[0:0xB0]).digest()
+ # When ECDSA authentication is disabled MCHP SPI image generator
+ # is filling the last 48 bytes of the Header with 0xff
+ header[-48:] = b"\xff" * 48
+
+ debug_print("After hash: len(header) = ", len(header))
+
+ return header
+
#
# MEC172x 128-byte EC Info Block appended to end of padded FW binary.
@@ -361,9 +410,10 @@ def BuildHeader2(args, chip_dict, payload_len, load_addr, payload_entry):
# byte[0x7f] = current image revision
#
def GenEcInfoBlock(args, chip_dict):
- # ecinfo = bytearray([0xff] * chip_dict["EC_INFO_BLK_SZ"])
- ecinfo = bytearray(chip_dict["EC_INFO_BLK_SZ"])
- return ecinfo
+ # ecinfo = bytearray([0xff] * chip_dict["EC_INFO_BLK_SZ"])
+ ecinfo = bytearray(chip_dict["EC_INFO_BLK_SZ"])
+ return ecinfo
+
#
# Generate SPI FW image co-signature.
@@ -376,7 +426,8 @@ def GenEcInfoBlock(args, chip_dict):
# signature.
#
def GenCoSignature(args, chip_dict, payload):
- return bytearray(b'\xff' * chip_dict["COSIG_SZ"])
+ return bytearray(b"\xff" * chip_dict["COSIG_SZ"])
+
#
# Generate SPI FW Image trailer.
@@ -387,27 +438,33 @@ def GenCoSignature(args, chip_dict, payload):
# trailer[144:160] = 0xFF. Boot-ROM spec. says these bytes should be random.
# Authentication & encryption are not used therefore random data
# is not necessary.
-def GenTrailer(args, chip_dict, payload, encryption_key_header,
- ec_info_block, cosignature):
+def GenTrailer(
+ args, chip_dict, payload, encryption_key_header, ec_info_block, cosignature
+):
debug_print("GenTrailer SHA384 computation")
trailer = bytearray(chip_dict["TAILER_PAD_BYTE"] * chip_dict["TRAILER_SZ"])
hasher = hashlib.sha384()
hasher.update(payload)
debug_print(" Update: payload len=0x{0:0x}".format(len(payload)))
if ec_info_block != None:
- hasher.update(ec_info_block)
- debug_print(" Update: ec_info_block len=0x{0:0x}".format(len(ec_info_block)))
+ hasher.update(ec_info_block)
+ debug_print(" Update: ec_info_block len=0x{0:0x}".format(len(ec_info_block)))
if encryption_key_header != None:
- hasher.update(encryption_key_header)
- debug_print(" Update: encryption_key_header len=0x{0:0x}".format(len(encryption_key_header)))
+ hasher.update(encryption_key_header)
+ debug_print(
+ " Update: encryption_key_header len=0x{0:0x}".format(
+ len(encryption_key_header)
+ )
+ )
if cosignature != None:
- hasher.update(cosignature)
- debug_print(" Update: cosignature len=0x{0:0x}".format(len(cosignature)))
+ hasher.update(cosignature)
+ debug_print(" Update: cosignature len=0x{0:0x}".format(len(cosignature)))
trailer[0:48] = hasher.digest()
- trailer[-16:] = 16 * b'\xff'
+ trailer[-16:] = 16 * b"\xff"
return trailer
+
# MEC172x supports two 32-bit Tags located at offsets 0x0 and 0x4
# in the SPI flash.
# Tag format:
@@ -418,16 +475,21 @@ def GenTrailer(args, chip_dict, payload, encryption_key_header,
# to the same flash part.
#
def BuildTag(args):
- tag = bytearray([(args.header_loc >> 8) & 0xff,
- (args.header_loc >> 16) & 0xff,
- (args.header_loc >> 24) & 0xff])
- tag.append(Crc8(0, tag))
- return tag
+ tag = bytearray(
+ [
+ (args.header_loc >> 8) & 0xFF,
+ (args.header_loc >> 16) & 0xFF,
+ (args.header_loc >> 24) & 0xFF,
+ ]
+ )
+ tag.append(Crc8(0, tag))
+ return tag
+
def BuildTagFromHdrAddr(header_loc):
- tag = bytearray([(header_loc >> 8) & 0xff,
- (header_loc >> 16) & 0xff,
- (header_loc >> 24) & 0xff])
+ tag = bytearray(
+ [(header_loc >> 8) & 0xFF, (header_loc >> 16) & 0xFF, (header_loc >> 24) & 0xFF]
+ )
tag.append(Crc8(0, tag))
return tag
@@ -444,12 +506,13 @@ def BuildTagFromHdrAddr(header_loc):
# Output:
# bytearray of length 4
def BuildFlashMap(secondSpiFlashBaseAddr):
- flashmap = bytearray(4)
- flashmap[0] = (secondSpiFlashBaseAddr >> 12) & 0xff
- flashmap[1] = (secondSpiFlashBaseAddr >> 20) & 0xff
- flashmap[2] = (secondSpiFlashBaseAddr >> 28) & 0xff
- flashmap[3] = Crc8(0, flashmap)
- return flashmap
+ flashmap = bytearray(4)
+ flashmap[0] = (secondSpiFlashBaseAddr >> 12) & 0xFF
+ flashmap[1] = (secondSpiFlashBaseAddr >> 20) & 0xFF
+ flashmap[2] = (secondSpiFlashBaseAddr >> 28) & 0xFF
+ flashmap[3] = Crc8(0, flashmap)
+ return flashmap
+
#
# Creates temporary file for read/write
@@ -460,21 +523,22 @@ def BuildFlashMap(secondSpiFlashBaseAddr):
# Returns temporary file name
#
def PacklfwRoImage(rorw_file, loader_file, image_size):
- """Create a temp file with the
- first image_size bytes from the loader file and append bytes
- from the rorw file.
- return the filename"""
- fo=tempfile.NamedTemporaryFile(delete=False) # Need to keep file around
- with open(loader_file,'rb') as fin1: # read 4KB loader file
- pro = fin1.read()
- fo.write(pro) # write 4KB loader data to temp file
- with open(rorw_file, 'rb') as fin:
- ro = fin.read(image_size)
-
- fo.write(ro)
- fo.close()
-
- return fo.name
+ """Create a temp file with the
+ first image_size bytes from the loader file and append bytes
+ from the rorw file.
+ return the filename"""
+ fo = tempfile.NamedTemporaryFile(delete=False) # Need to keep file around
+ with open(loader_file, "rb") as fin1: # read 4KB loader file
+ pro = fin1.read()
+ fo.write(pro) # write 4KB loader data to temp file
+ with open(rorw_file, "rb") as fin:
+ ro = fin.read(image_size)
+
+ fo.write(ro)
+ fo.close()
+
+ return fo.name
+
#
# Generate a test EC_RW image of same size
@@ -485,136 +549,193 @@ def PacklfwRoImage(rorw_file, loader_file, image_size):
# process hash generation.
#
def gen_test_ecrw(pldrw):
- debug_print("gen_test_ecrw: pldrw type =", type(pldrw))
- debug_print("len pldrw =", len(pldrw), " = ", hex(len(pldrw)))
- cookie1_pos = pldrw.find(b'\x99\x88\x77\xce')
- cookie2_pos = pldrw.find(b'\xdd\xbb\xaa\xce', cookie1_pos+4)
- t = struct.unpack("<L", pldrw[cookie1_pos+0x24:cookie1_pos+0x28])
- size = t[0]
- debug_print("EC_RW size =", size, " = ", hex(size))
-
- debug_print("Found cookie1 at ", hex(cookie1_pos))
- debug_print("Found cookie2 at ", hex(cookie2_pos))
-
- if cookie1_pos > 0 and cookie2_pos > cookie1_pos:
- for i in range(0, cookie1_pos):
- pldrw[i] = 0xA5
- for i in range(cookie2_pos+4, len(pldrw)):
- pldrw[i] = 0xA5
-
- with open("ec_RW_test.bin", "wb") as fecrw:
- fecrw.write(pldrw[:size])
+ debug_print("gen_test_ecrw: pldrw type =", type(pldrw))
+ debug_print("len pldrw =", len(pldrw), " = ", hex(len(pldrw)))
+ cookie1_pos = pldrw.find(b"\x99\x88\x77\xce")
+ cookie2_pos = pldrw.find(b"\xdd\xbb\xaa\xce", cookie1_pos + 4)
+ t = struct.unpack("<L", pldrw[cookie1_pos + 0x24 : cookie1_pos + 0x28])
+ size = t[0]
+ debug_print("EC_RW size =", size, " = ", hex(size))
+
+ debug_print("Found cookie1 at ", hex(cookie1_pos))
+ debug_print("Found cookie2 at ", hex(cookie2_pos))
+
+ if cookie1_pos > 0 and cookie2_pos > cookie1_pos:
+ for i in range(0, cookie1_pos):
+ pldrw[i] = 0xA5
+ for i in range(cookie2_pos + 4, len(pldrw)):
+ pldrw[i] = 0xA5
+
+ with open("ec_RW_test.bin", "wb") as fecrw:
+ fecrw.write(pldrw[:size])
+
def parseargs():
- rpath = os.path.dirname(os.path.relpath(__file__))
-
- parser = argparse.ArgumentParser()
- parser.add_argument("-i", "--input",
- help="EC binary to pack, usually ec.bin or ec.RO.flat.",
- metavar="EC_BIN", default="ec.bin")
- parser.add_argument("-o", "--output",
- help="Output flash binary file",
- metavar="EC_SPI_FLASH", default="ec.packed.bin")
- parser.add_argument("--loader_file",
- help="EC loader binary",
- default="ecloader.bin")
- parser.add_argument("--load_addr", type=int,
- help="EC SRAM load address",
- default=0xC0000)
- parser.add_argument("-s", "--spi_size", type=int,
- help="Size of the SPI flash in KB",
- default=512)
- parser.add_argument("-l", "--header_loc", type=int,
- help="Location of header in SPI flash. Must be on a 256 byte boundary",
- default=0x0100)
- parser.add_argument("--lfw_loc", type=int,
- help="Location of LFW in SPI flash. Must be on a 4KB boundary",
- default=0x1000)
- parser.add_argument("--lfw_size", type=int,
- help="LFW size in bytes",
- default=0x1000)
- parser.add_argument("-r", "--rw_loc", type=int,
- help="Start offset of EC_RW. Default is -1 meaning 1/2 flash size",
- default=-1)
- parser.add_argument("--spi_clock", type=int,
- help="SPI clock speed. 8, 12, 24, or 48 MHz.",
- default=24)
- parser.add_argument("--spi_read_cmd", type=int,
- help="SPI read command. 0x3, 0xB, 0x3B, or 0x6B.",
- default=0xb)
- parser.add_argument("--image_size", type=int,
- help="Size of a single image. Default 244KB",
- default=(244 * 1024))
- parser.add_argument("--test_spi", action='store_true',
- help="Test SPI data integrity by adding CRC32 in last 4-bytes of RO/RW binaries",
- default=False)
- parser.add_argument("--test_ecrw", action='store_true',
- help="Use fixed pattern for EC_RW but preserve image_data",
- default=False)
- parser.add_argument("--verbose", action='store_true',
- help="Enable verbose output",
- default=False)
- parser.add_argument("--tag0_loc", type=int,
- help="MEC172x TAG0 SPI offset",
- default=0)
- parser.add_argument("--tag1_loc", type=int,
- help="MEC172x TAG1 SPI offset",
- default=4)
- parser.add_argument("--spi_drive_str", type=int,
- help="Chip SPI drive strength in mA: 2, 4, 8, or 12",
- default=4)
- parser.add_argument("--spi_slew_fast", action='store_true',
- help="SPI use fast slew rate. Default is False",
- default=False)
- parser.add_argument("--spi_cpol", type=int,
- help="SPI clock polarity when idle. Defealt is 0(low)",
- default=0)
- parser.add_argument("--spi_cpha_mosi", type=int,
- help="""SPI clock phase controller drives data.
+ rpath = os.path.dirname(os.path.relpath(__file__))
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-i",
+ "--input",
+ help="EC binary to pack, usually ec.bin or ec.RO.flat.",
+ metavar="EC_BIN",
+ default="ec.bin",
+ )
+ parser.add_argument(
+ "-o",
+ "--output",
+ help="Output flash binary file",
+ metavar="EC_SPI_FLASH",
+ default="ec.packed.bin",
+ )
+ parser.add_argument(
+ "--loader_file", help="EC loader binary", default="ecloader.bin"
+ )
+ parser.add_argument(
+ "--load_addr", type=int, help="EC SRAM load address", default=0xC0000
+ )
+ parser.add_argument(
+ "-s", "--spi_size", type=int, help="Size of the SPI flash in KB", default=512
+ )
+ parser.add_argument(
+ "-l",
+ "--header_loc",
+ type=int,
+ help="Location of header in SPI flash. Must be on a 256 byte boundary",
+ default=0x0100,
+ )
+ parser.add_argument(
+ "--lfw_loc",
+ type=int,
+ help="Location of LFW in SPI flash. Must be on a 4KB boundary",
+ default=0x1000,
+ )
+ parser.add_argument(
+ "--lfw_size", type=int, help="LFW size in bytes", default=0x1000
+ )
+ parser.add_argument(
+ "-r",
+ "--rw_loc",
+ type=int,
+ help="Start offset of EC_RW. Default is -1 meaning 1/2 flash size",
+ default=-1,
+ )
+ parser.add_argument(
+ "--spi_clock",
+ type=int,
+ help="SPI clock speed. 8, 12, 24, or 48 MHz.",
+ default=24,
+ )
+ parser.add_argument(
+ "--spi_read_cmd",
+ type=int,
+ help="SPI read command. 0x3, 0xB, 0x3B, or 0x6B.",
+ default=0xB,
+ )
+ parser.add_argument(
+ "--image_size",
+ type=int,
+ help="Size of a single image. Default 244KB",
+ default=(244 * 1024),
+ )
+ parser.add_argument(
+ "--test_spi",
+ action="store_true",
+ help="Test SPI data integrity by adding CRC32 in last 4-bytes of RO/RW binaries",
+ default=False,
+ )
+ parser.add_argument(
+ "--test_ecrw",
+ action="store_true",
+ help="Use fixed pattern for EC_RW but preserve image_data",
+ default=False,
+ )
+ parser.add_argument(
+ "--verbose", action="store_true", help="Enable verbose output", default=False
+ )
+ parser.add_argument(
+ "--tag0_loc", type=int, help="MEC172x TAG0 SPI offset", default=0
+ )
+ parser.add_argument(
+ "--tag1_loc", type=int, help="MEC172x TAG1 SPI offset", default=4
+ )
+ parser.add_argument(
+ "--spi_drive_str",
+ type=int,
+ help="Chip SPI drive strength in mA: 2, 4, 8, or 12",
+ default=4,
+ )
+ parser.add_argument(
+ "--spi_slew_fast",
+ action="store_true",
+ help="SPI use fast slew rate. Default is False",
+ default=False,
+ )
+ parser.add_argument(
+ "--spi_cpol",
+ type=int,
+        help="SPI clock polarity when idle. Default is 0 (low)",
+ default=0,
+ )
+ parser.add_argument(
+ "--spi_cpha_mosi",
+ type=int,
+ help="""SPI clock phase controller drives data.
0=Data driven on active to inactive clock edge,
1=Data driven on inactive to active clock edge""",
- default=0)
- parser.add_argument("--spi_cpha_miso", type=int,
- help="""SPI clock phase controller samples data.
+ default=0,
+ )
+ parser.add_argument(
+ "--spi_cpha_miso",
+ type=int,
+ help="""SPI clock phase controller samples data.
0=Data sampled on inactive to active clock edge,
1=Data sampled on active to inactive clock edge""",
- default=0)
+ default=0,
+ )
+
+ return parser.parse_args()
- return parser.parse_args()
def print_args(args):
- debug_print("parsed arguments:")
- debug_print(".input = ", args.input)
- debug_print(".output = ", args.output)
- debug_print(".loader_file = ", args.loader_file)
- debug_print(".spi_size (KB) = ", hex(args.spi_size))
- debug_print(".image_size = ", hex(args.image_size))
- debug_print(".load_addr", hex(args.load_addr))
- debug_print(".tag0_loc = ", hex(args.tag0_loc))
- debug_print(".tag1_loc = ", hex(args.tag1_loc))
- debug_print(".header_loc = ", hex(args.header_loc))
- debug_print(".lfw_loc = ", hex(args.lfw_loc))
- debug_print(".lfw_size = ", hex(args.lfw_size))
- if args.rw_loc < 0:
- debug_print(".rw_loc = ", args.rw_loc)
- else:
- debug_print(".rw_loc = ", hex(args.rw_loc))
- debug_print(".spi_clock (MHz) = ", args.spi_clock)
- debug_print(".spi_read_cmd = ", hex(args.spi_read_cmd))
- debug_print(".test_spi = ", args.test_spi)
- debug_print(".test_ecrw = ", args.test_ecrw)
- debug_print(".verbose = ", args.verbose)
- debug_print(".spi_drive_str = ", args.spi_drive_str)
- debug_print(".spi_slew_fast = ", args.spi_slew_fast)
- debug_print(".spi_cpol = ", args.spi_cpol)
- debug_print(".spi_cpha_mosi = ", args.spi_cpha_mosi)
- debug_print(".spi_cpha_miso = ", args.spi_cpha_miso)
+ debug_print("parsed arguments:")
+ debug_print(".input = ", args.input)
+ debug_print(".output = ", args.output)
+ debug_print(".loader_file = ", args.loader_file)
+ debug_print(".spi_size (KB) = ", hex(args.spi_size))
+ debug_print(".image_size = ", hex(args.image_size))
+ debug_print(".load_addr", hex(args.load_addr))
+ debug_print(".tag0_loc = ", hex(args.tag0_loc))
+ debug_print(".tag1_loc = ", hex(args.tag1_loc))
+ debug_print(".header_loc = ", hex(args.header_loc))
+ debug_print(".lfw_loc = ", hex(args.lfw_loc))
+ debug_print(".lfw_size = ", hex(args.lfw_size))
+ if args.rw_loc < 0:
+ debug_print(".rw_loc = ", args.rw_loc)
+ else:
+ debug_print(".rw_loc = ", hex(args.rw_loc))
+ debug_print(".spi_clock (MHz) = ", args.spi_clock)
+ debug_print(".spi_read_cmd = ", hex(args.spi_read_cmd))
+ debug_print(".test_spi = ", args.test_spi)
+ debug_print(".test_ecrw = ", args.test_ecrw)
+ debug_print(".verbose = ", args.verbose)
+ debug_print(".spi_drive_str = ", args.spi_drive_str)
+ debug_print(".spi_slew_fast = ", args.spi_slew_fast)
+ debug_print(".spi_cpol = ", args.spi_cpol)
+ debug_print(".spi_cpha_mosi = ", args.spi_cpha_mosi)
+ debug_print(".spi_cpha_miso = ", args.spi_cpha_miso)
+
def spi_list_append(mylist, loc, data, description):
- """Append SPI data block tuple to list"""
- t = (loc, data, description)
- mylist.append(t)
- debug_print("Add SPI entry: offset=0x{0:08x} len=0x{1:0x} descr={2}".format(loc, len(data), description))
+ """Append SPI data block tuple to list"""
+ t = (loc, data, description)
+ mylist.append(t)
+ debug_print(
+ "Add SPI entry: offset=0x{0:08x} len=0x{1:0x} descr={2}".format(
+ loc, len(data), description
+ )
+ )
+
#
# Handle quiet mode build from Makefile
@@ -652,200 +773,207 @@ def spi_list_append(mylist, loc, data, description):
# || 48 * [0]
#
def main():
- global debug_print
-
- args = parseargs()
-
- if args.verbose:
- debug_print = print
-
- debug_print("Begin pack_ec_mec172x.py script")
-
- print_args(args)
-
- chip_dict = MEC172X_DICT
-
- # Boot-ROM requires header location aligned >= 256 bytes.
- # CrOS EC flash image update code requires EC_RO/RW location to be aligned
- # on a flash erase size boundary and EC_RO/RW size to be a multiple of
- # the smallest flash erase block size.
-
- spi_size = args.spi_size * 1024
- spi_image_size = spi_size // 2
-
- rorofile=PacklfwRoImage(args.input, args.loader_file, args.image_size)
- debug_print("Temporary file containing LFW + EC_RO is ", rorofile)
-
- lfw_ecro = GetPayload(rorofile, chip_dict["PAD_SIZE"])
- lfw_ecro_len = len(lfw_ecro)
- debug_print("Padded LFW + EC_RO length = ", hex(lfw_ecro_len))
-
- # SPI test mode compute CRC32 of EC_RO and store in last 4 bytes
- if args.test_spi:
- crc32_ecro = zlib.crc32(bytes(lfw_ecro[LFW_SIZE:-4]))
- crc32_ecro_bytes = crc32_ecro.to_bytes(4, byteorder='little')
- lfw_ecro[-4:] = crc32_ecro_bytes
- debug_print("ecro len = ", hex(len(lfw_ecro) - LFW_SIZE))
- debug_print("CRC32(ecro-4) = ", hex(crc32_ecro))
-
- # Reads entry point from offset 4 of file.
- # This assumes binary has Cortex-M4 vector table at offset 0.
- # 32-bit word at offset 0x0 initial stack pointer value
- # 32-bit word at offset 0x4 address of reset handler
- # NOTE: reset address will have bit[0]=1 to ensure thumb mode.
- lfw_ecro_entry = GetEntryPoint(rorofile)
- debug_print("LFW Entry point from GetEntryPoint = 0x{0:08x}".format(lfw_ecro_entry))
-
- # Chromebooks are not using MEC BootROM SPI header/payload authentication
- # or payload encryption. In this case the header authentication signature
- # is filled with the hash digest of the respective entity.
- # BuildHeader2 computes the hash digest and stores it in the correct
- # header location.
- header = BuildHeader2(args, chip_dict, lfw_ecro_len,
- args.load_addr, lfw_ecro_entry)
- printByteArrayAsHex(header, "Header(lfw_ecro)")
-
- # If payload encryption used then encrypt payload and
- # generate Payload Key Header. If encryption not used
- # payload is not modified and the method returns None
- encryption_key_header = EncryptPayload(args, chip_dict, lfw_ecro)
- printByteArrayAsHex(encryption_key_header,
- "LFW + EC_RO encryption_key_header")
-
- ec_info_block = GenEcInfoBlock(args, chip_dict)
- printByteArrayAsHex(ec_info_block, "EC Info Block")
-
- cosignature = GenCoSignature(args, chip_dict, lfw_ecro)
- printByteArrayAsHex(cosignature, "LFW + EC_RO cosignature")
-
- trailer = GenTrailer(args, chip_dict, lfw_ecro, encryption_key_header,
- ec_info_block, cosignature)
-
- printByteArrayAsHex(trailer, "LFW + EC_RO trailer")
-
- # Build TAG0. Set TAG1=TAG0 Boot-ROM is allowed to load EC-RO only.
- tag0 = BuildTag(args)
- tag1 = tag0
-
- debug_print("Call to GetPayloadFromOffset")
- debug_print("args.input = ", args.input)
- debug_print("args.image_size = ", hex(args.image_size))
-
- ecrw = GetPayloadFromOffset(args.input, args.image_size,
- chip_dict["PAD_SIZE"])
- debug_print("type(ecrw) is ", type(ecrw))
- debug_print("len(ecrw) is ", hex(len(ecrw)))
-
- # truncate to args.image_size
- ecrw_len = len(ecrw)
- if ecrw_len > args.image_size:
- debug_print("Truncate EC_RW len={0:0x} to image_size={1:0x}".format(ecrw_len,args.image_size))
- ecrw = ecrw[:args.image_size]
- ecrw_len = len(ecrw)
+ global debug_print
+
+ args = parseargs()
+
+ if args.verbose:
+ debug_print = print
+
+ debug_print("Begin pack_ec_mec172x.py script")
+
+ print_args(args)
+
+ chip_dict = MEC172X_DICT
+
+ # Boot-ROM requires header location aligned >= 256 bytes.
+ # CrOS EC flash image update code requires EC_RO/RW location to be aligned
+ # on a flash erase size boundary and EC_RO/RW size to be a multiple of
+ # the smallest flash erase block size.
+
+ spi_size = args.spi_size * 1024
+ spi_image_size = spi_size // 2
+
+ rorofile = PacklfwRoImage(args.input, args.loader_file, args.image_size)
+ debug_print("Temporary file containing LFW + EC_RO is ", rorofile)
+
+ lfw_ecro = GetPayload(rorofile, chip_dict["PAD_SIZE"])
+ lfw_ecro_len = len(lfw_ecro)
+ debug_print("Padded LFW + EC_RO length = ", hex(lfw_ecro_len))
+
+ # SPI test mode compute CRC32 of EC_RO and store in last 4 bytes
+ if args.test_spi:
+ crc32_ecro = zlib.crc32(bytes(lfw_ecro[LFW_SIZE:-4]))
+ crc32_ecro_bytes = crc32_ecro.to_bytes(4, byteorder="little")
+ lfw_ecro[-4:] = crc32_ecro_bytes
+ debug_print("ecro len = ", hex(len(lfw_ecro) - LFW_SIZE))
+ debug_print("CRC32(ecro-4) = ", hex(crc32_ecro))
+
+ # Reads entry point from offset 4 of file.
+ # This assumes binary has Cortex-M4 vector table at offset 0.
+ # 32-bit word at offset 0x0 initial stack pointer value
+ # 32-bit word at offset 0x4 address of reset handler
+ # NOTE: reset address will have bit[0]=1 to ensure thumb mode.
+ lfw_ecro_entry = GetEntryPoint(rorofile)
+ debug_print("LFW Entry point from GetEntryPoint = 0x{0:08x}".format(lfw_ecro_entry))
+
+ # Chromebooks are not using MEC BootROM SPI header/payload authentication
+ # or payload encryption. In this case the header authentication signature
+ # is filled with the hash digest of the respective entity.
+ # BuildHeader2 computes the hash digest and stores it in the correct
+ # header location.
+ header = BuildHeader2(args, chip_dict, lfw_ecro_len, args.load_addr, lfw_ecro_entry)
+ printByteArrayAsHex(header, "Header(lfw_ecro)")
+
+ # If payload encryption used then encrypt payload and
+ # generate Payload Key Header. If encryption not used
+ # payload is not modified and the method returns None
+ encryption_key_header = EncryptPayload(args, chip_dict, lfw_ecro)
+ printByteArrayAsHex(encryption_key_header, "LFW + EC_RO encryption_key_header")
+
+ ec_info_block = GenEcInfoBlock(args, chip_dict)
+ printByteArrayAsHex(ec_info_block, "EC Info Block")
+
+ cosignature = GenCoSignature(args, chip_dict, lfw_ecro)
+ printByteArrayAsHex(cosignature, "LFW + EC_RO cosignature")
+
+ trailer = GenTrailer(
+ args, chip_dict, lfw_ecro, encryption_key_header, ec_info_block, cosignature
+ )
+
+ printByteArrayAsHex(trailer, "LFW + EC_RO trailer")
+
+ # Build TAG0. Set TAG1=TAG0 Boot-ROM is allowed to load EC-RO only.
+ tag0 = BuildTag(args)
+ tag1 = tag0
+
+ debug_print("Call to GetPayloadFromOffset")
+ debug_print("args.input = ", args.input)
+ debug_print("args.image_size = ", hex(args.image_size))
+
+ ecrw = GetPayloadFromOffset(args.input, args.image_size, chip_dict["PAD_SIZE"])
+ debug_print("type(ecrw) is ", type(ecrw))
+ debug_print("len(ecrw) is ", hex(len(ecrw)))
+
+ # truncate to args.image_size
+ ecrw_len = len(ecrw)
+ if ecrw_len > args.image_size:
+ debug_print(
+ "Truncate EC_RW len={0:0x} to image_size={1:0x}".format(
+ ecrw_len, args.image_size
+ )
+ )
+ ecrw = ecrw[: args.image_size]
+ ecrw_len = len(ecrw)
+
+ debug_print("len(EC_RW) = ", hex(ecrw_len))
+
+ # SPI test mode compute CRC32 of EC_RW and store in last 4 bytes
+ if args.test_spi:
+ crc32_ecrw = zlib.crc32(bytes(ecrw[0:-4]))
+ crc32_ecrw_bytes = crc32_ecrw.to_bytes(4, byteorder="little")
+ ecrw[-4:] = crc32_ecrw_bytes
+ debug_print("ecrw len = ", hex(len(ecrw)))
+ debug_print("CRC32(ecrw) = ", hex(crc32_ecrw))
+
+ # Assume FW layout is standard Cortex-M style with vector
+ # table at start of binary.
+ # 32-bit word at offset 0x0 = Initial stack pointer
+ # 32-bit word at offset 0x4 = Address of reset handler
+ ecrw_entry_tuple = struct.unpack_from("<I", ecrw, 4)
+ debug_print("ecrw_entry_tuple[0] = ", hex(ecrw_entry_tuple[0]))
+
+ ecrw_entry = ecrw_entry_tuple[0]
+ debug_print("ecrw_entry = ", hex(ecrw_entry))
+
+ # Note: payload_rw is a bytearray therefore is mutable
+ if args.test_ecrw:
+ gen_test_ecrw(ecrw)
+
+ os.remove(rorofile) # clean up the temp file
+
+ spi_list = []
+
+ # MEC172x Add TAG's
+ # spi_list.append((args.tag0_loc, tag0, "tag0"))
+ # spi_list.append((args.tag1_loc, tag1, "tag1"))
+ spi_list_append(spi_list, args.tag0_loc, tag0, "TAG0")
+ spi_list_append(spi_list, args.tag1_loc, tag1, "TAG1")
+
+ # Boot-ROM SPI image header for LFW+EC-RO
+ # spi_list.append((args.header_loc, header, "header(lfw + ro)"))
+ spi_list_append(spi_list, args.header_loc, header, "LFW-EC_RO Header")
+
+ spi_list_append(spi_list, args.lfw_loc, lfw_ecro, "LFW-EC_RO FW")
+
+ offset = args.lfw_loc + len(lfw_ecro)
+ debug_print("SPI offset after LFW_ECRO = 0x{0:08x}".format(offset))
- debug_print("len(EC_RW) = ", hex(ecrw_len))
-
- # SPI test mode compute CRC32 of EC_RW and store in last 4 bytes
- if args.test_spi:
- crc32_ecrw = zlib.crc32(bytes(ecrw[0:-4]))
- crc32_ecrw_bytes = crc32_ecrw.to_bytes(4, byteorder='little')
- ecrw[-4:] = crc32_ecrw_bytes
- debug_print("ecrw len = ", hex(len(ecrw)))
- debug_print("CRC32(ecrw) = ", hex(crc32_ecrw))
-
- # Assume FW layout is standard Cortex-M style with vector
- # table at start of binary.
- # 32-bit word at offset 0x0 = Initial stack pointer
- # 32-bit word at offset 0x4 = Address of reset handler
- ecrw_entry_tuple = struct.unpack_from('<I', ecrw, 4)
- debug_print("ecrw_entry_tuple[0] = ", hex(ecrw_entry_tuple[0]))
-
- ecrw_entry = ecrw_entry_tuple[0]
- debug_print("ecrw_entry = ", hex(ecrw_entry))
-
- # Note: payload_rw is a bytearray therefore is mutable
- if args.test_ecrw:
- gen_test_ecrw(ecrw)
-
- os.remove(rorofile) # clean up the temp file
-
- spi_list = []
-
- # MEC172x Add TAG's
- #spi_list.append((args.tag0_loc, tag0, "tag0"))
- #spi_list.append((args.tag1_loc, tag1, "tag1"))
- spi_list_append(spi_list, args.tag0_loc, tag0, "TAG0")
- spi_list_append(spi_list, args.tag1_loc, tag1, "TAG1")
-
- # Boot-ROM SPI image header for LFW+EC-RO
- #spi_list.append((args.header_loc, header, "header(lfw + ro)"))
- spi_list_append(spi_list, args.header_loc, header, "LFW-EC_RO Header")
-
- spi_list_append(spi_list, args.lfw_loc, lfw_ecro, "LFW-EC_RO FW")
-
- offset = args.lfw_loc + len(lfw_ecro)
- debug_print("SPI offset after LFW_ECRO = 0x{0:08x}".format(offset))
-
- if ec_info_block != None:
- spi_list_append(spi_list, offset, ec_info_block, "LFW-EC_RO Info Block")
- offset += len(ec_info_block)
-
- debug_print("SPI offset after ec_info_block = 0x{0:08x}".format(offset))
-
- if cosignature != None:
- #spi_list.append((offset, co-signature, "ECRO Co-signature"))
- spi_list_append(spi_list, offset, cosignature, "LFW-EC_RO Co-signature")
- offset += len(cosignature)
-
- debug_print("SPI offset after co-signature = 0x{0:08x}".format(offset))
-
- if trailer != None:
- #spi_list.append((offset, trailer, "ECRO Trailer"))
- spi_list_append(spi_list, offset, trailer, "LFW-EC_RO trailer")
- offset += len(trailer)
-
- debug_print("SPI offset after trailer = 0x{0:08x}".format(offset))
-
- # EC_RW location
- rw_offset = int(spi_size // 2)
- if args.rw_loc >= 0:
- rw_offset = args.rw_loc
-
- debug_print("rw_offset = 0x{0:08x}".format(rw_offset))
-
- #spi_list.append((rw_offset, ecrw, "ecrw"))
- spi_list_append(spi_list, rw_offset, ecrw, "EC_RW")
- offset = rw_offset + len(ecrw)
-
- spi_list = sorted(spi_list)
-
- debug_print("Display spi_list:")
- dumpsects(spi_list)
-
- #
- # MEC172x Boot-ROM locates TAG0/1 at SPI offset 0
- # instead of end of SPI.
- #
- with open(args.output, 'wb') as f:
- debug_print("Write spi list to file", args.output)
- addr = 0
- for s in spi_list:
- if addr < s[0]:
- debug_print("Offset ",hex(addr)," Length", hex(s[0]-addr),
- "fill with 0xff")
- f.write(b'\xff' * (s[0] - addr))
- addr = s[0]
- debug_print("Offset ",hex(addr), " Length", hex(len(s[1])), "write data")
-
- f.write(s[1])
- addr += len(s[1])
-
- if addr < spi_size:
- debug_print("Offset ",hex(addr), " Length", hex(spi_size - addr),
- "fill with 0xff")
- f.write(b'\xff' * (spi_size - addr))
+ if ec_info_block != None:
+ spi_list_append(spi_list, offset, ec_info_block, "LFW-EC_RO Info Block")
+ offset += len(ec_info_block)
- f.flush()
+ debug_print("SPI offset after ec_info_block = 0x{0:08x}".format(offset))
-if __name__ == '__main__':
- main()
+ if cosignature != None:
+ # spi_list.append((offset, co-signature, "ECRO Co-signature"))
+ spi_list_append(spi_list, offset, cosignature, "LFW-EC_RO Co-signature")
+ offset += len(cosignature)
+
+ debug_print("SPI offset after co-signature = 0x{0:08x}".format(offset))
+
+ if trailer != None:
+ # spi_list.append((offset, trailer, "ECRO Trailer"))
+ spi_list_append(spi_list, offset, trailer, "LFW-EC_RO trailer")
+ offset += len(trailer)
+
+ debug_print("SPI offset after trailer = 0x{0:08x}".format(offset))
+
+ # EC_RW location
+ rw_offset = int(spi_size // 2)
+ if args.rw_loc >= 0:
+ rw_offset = args.rw_loc
+
+ debug_print("rw_offset = 0x{0:08x}".format(rw_offset))
+
+ # spi_list.append((rw_offset, ecrw, "ecrw"))
+ spi_list_append(spi_list, rw_offset, ecrw, "EC_RW")
+ offset = rw_offset + len(ecrw)
+
+ spi_list = sorted(spi_list)
+
+ debug_print("Display spi_list:")
+ dumpsects(spi_list)
+
+ #
+ # MEC172x Boot-ROM locates TAG0/1 at SPI offset 0
+ # instead of end of SPI.
+ #
+ with open(args.output, "wb") as f:
+ debug_print("Write spi list to file", args.output)
+ addr = 0
+ for s in spi_list:
+ if addr < s[0]:
+ debug_print(
+ "Offset ", hex(addr), " Length", hex(s[0] - addr), "fill with 0xff"
+ )
+ f.write(b"\xff" * (s[0] - addr))
+ addr = s[0]
+ debug_print(
+ "Offset ", hex(addr), " Length", hex(len(s[1])), "write data"
+ )
+
+ f.write(s[1])
+ addr += len(s[1])
+
+ if addr < spi_size:
+ debug_print(
+ "Offset ", hex(addr), " Length", hex(spi_size - addr), "fill with 0xff"
+ )
+ f.write(b"\xff" * (spi_size - addr))
+
+ f.flush()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/chip/mec1322/util/pack_ec.py b/chip/mec1322/util/pack_ec.py
index 9783ffb2d5..736f9efcac 100755
--- a/chip/mec1322/util/pack_ec.py
+++ b/chip/mec1322/util/pack_ec.py
@@ -21,228 +21,325 @@ import tempfile
LOAD_ADDR = 0x100000
HEADER_SIZE = 0x140
SPI_CLOCK_LIST = [48, 24, 12, 8]
-SPI_READ_CMD_LIST = [0x3, 0xb, 0x3b]
+SPI_READ_CMD_LIST = [0x3, 0xB, 0x3B]
+
+CRC_TABLE = [
+ 0x00,
+ 0x07,
+ 0x0E,
+ 0x09,
+ 0x1C,
+ 0x1B,
+ 0x12,
+ 0x15,
+ 0x38,
+ 0x3F,
+ 0x36,
+ 0x31,
+ 0x24,
+ 0x23,
+ 0x2A,
+ 0x2D,
+]
-CRC_TABLE = [0x00, 0x07, 0x0e, 0x09, 0x1c, 0x1b, 0x12, 0x15,
- 0x38, 0x3f, 0x36, 0x31, 0x24, 0x23, 0x2a, 0x2d]
def Crc8(crc, data):
- """Update CRC8 value."""
- for v in data:
- crc = ((crc << 4) & 0xff) ^ (CRC_TABLE[(crc >> 4) ^ (v >> 4)]);
- crc = ((crc << 4) & 0xff) ^ (CRC_TABLE[(crc >> 4) ^ (v & 0xf)]);
- return crc ^ 0x55
+ """Update CRC8 value."""
+ for v in data:
+ crc = ((crc << 4) & 0xFF) ^ (CRC_TABLE[(crc >> 4) ^ (v >> 4)])
+ crc = ((crc << 4) & 0xFF) ^ (CRC_TABLE[(crc >> 4) ^ (v & 0xF)])
+ return crc ^ 0x55
+
def GetEntryPoint(payload_file):
- """Read entry point from payload EC image."""
- with open(payload_file, 'rb') as f:
- f.seek(4)
- s = f.read(4)
- return struct.unpack('<I', s)[0]
-
-def GetPayloadFromOffset(payload_file,offset):
- """Read payload and pad it to 64-byte aligned."""
- with open(payload_file, 'rb') as f:
- f.seek(offset)
- payload = bytearray(f.read())
- rem_len = len(payload) % 64
- if rem_len:
- payload += b'\0' * (64 - rem_len)
- return payload
+ """Read entry point from payload EC image."""
+ with open(payload_file, "rb") as f:
+ f.seek(4)
+ s = f.read(4)
+ return struct.unpack("<I", s)[0]
+
+
+def GetPayloadFromOffset(payload_file, offset):
+ """Read payload and pad it to 64-byte aligned."""
+ with open(payload_file, "rb") as f:
+ f.seek(offset)
+ payload = bytearray(f.read())
+ rem_len = len(payload) % 64
+ if rem_len:
+ payload += b"\0" * (64 - rem_len)
+ return payload
+
def GetPayload(payload_file):
- """Read payload and pad it to 64-byte aligned."""
- return GetPayloadFromOffset(payload_file, 0)
+ """Read payload and pad it to 64-byte aligned."""
+ return GetPayloadFromOffset(payload_file, 0)
+
def GetPublicKey(pem_file):
- """Extract public exponent and modulus from PEM file."""
- result = subprocess.run(['openssl', 'rsa', '-in', pem_file, '-text',
- '-noout'], stdout=subprocess.PIPE, encoding='utf-8')
- modulus_raw = []
- in_modulus = False
- for line in result.stdout.splitlines():
- if line.startswith('modulus'):
- in_modulus = True
- elif not line.startswith(' '):
- in_modulus = False
- elif in_modulus:
- modulus_raw.extend(line.strip().strip(':').split(':'))
- if line.startswith('publicExponent'):
- exp = int(line.split(' ')[1], 10)
- modulus_raw.reverse()
- modulus = bytearray((int(x, 16) for x in modulus_raw[:256]))
- return struct.pack('<Q', exp), modulus
+ """Extract public exponent and modulus from PEM file."""
+ result = subprocess.run(
+ ["openssl", "rsa", "-in", pem_file, "-text", "-noout"],
+ stdout=subprocess.PIPE,
+ encoding="utf-8",
+ )
+ modulus_raw = []
+ in_modulus = False
+ for line in result.stdout.splitlines():
+ if line.startswith("modulus"):
+ in_modulus = True
+ elif not line.startswith(" "):
+ in_modulus = False
+ elif in_modulus:
+ modulus_raw.extend(line.strip().strip(":").split(":"))
+ if line.startswith("publicExponent"):
+ exp = int(line.split(" ")[1], 10)
+ modulus_raw.reverse()
+ modulus = bytearray((int(x, 16) for x in modulus_raw[:256]))
+ return struct.pack("<Q", exp), modulus
+
def GetSpiClockParameter(args):
- assert args.spi_clock in SPI_CLOCK_LIST, \
- "Unsupported SPI clock speed %d MHz" % args.spi_clock
- return SPI_CLOCK_LIST.index(args.spi_clock)
+ assert args.spi_clock in SPI_CLOCK_LIST, (
+ "Unsupported SPI clock speed %d MHz" % args.spi_clock
+ )
+ return SPI_CLOCK_LIST.index(args.spi_clock)
+
def GetSpiReadCmdParameter(args):
- assert args.spi_read_cmd in SPI_READ_CMD_LIST, \
- "Unsupported SPI read command 0x%x" % args.spi_read_cmd
- return SPI_READ_CMD_LIST.index(args.spi_read_cmd)
+ assert args.spi_read_cmd in SPI_READ_CMD_LIST, (
+ "Unsupported SPI read command 0x%x" % args.spi_read_cmd
+ )
+ return SPI_READ_CMD_LIST.index(args.spi_read_cmd)
+
def PadZeroTo(data, size):
- data.extend(b'\0' * (size - len(data)))
+ data.extend(b"\0" * (size - len(data)))
+
def BuildHeader(args, payload_len, rorofile):
- # Identifier and header version
- header = bytearray(b'CSMS\0')
+ # Identifier and header version
+ header = bytearray(b"CSMS\0")
+
+ PadZeroTo(header, 0x6)
+ header.append(GetSpiClockParameter(args))
+ header.append(GetSpiReadCmdParameter(args))
- PadZeroTo(header, 0x6)
- header.append(GetSpiClockParameter(args))
- header.append(GetSpiReadCmdParameter(args))
+ header.extend(struct.pack("<I", LOAD_ADDR))
+ header.extend(struct.pack("<I", GetEntryPoint(rorofile)))
+ header.append((payload_len >> 6) & 0xFF)
+ header.append((payload_len >> 14) & 0xFF)
+ PadZeroTo(header, 0x14)
+ header.extend(struct.pack("<I", args.payload_offset))
- header.extend(struct.pack('<I', LOAD_ADDR))
- header.extend(struct.pack('<I', GetEntryPoint(rorofile)))
- header.append((payload_len >> 6) & 0xff)
- header.append((payload_len >> 14) & 0xff)
- PadZeroTo(header, 0x14)
- header.extend(struct.pack('<I', args.payload_offset))
+ exp, modulus = GetPublicKey(args.payload_key)
+ PadZeroTo(header, 0x20)
+ header.extend(exp)
+ PadZeroTo(header, 0x30)
+ header.extend(modulus)
+ PadZeroTo(header, HEADER_SIZE)
- exp, modulus = GetPublicKey(args.payload_key)
- PadZeroTo(header, 0x20)
- header.extend(exp)
- PadZeroTo(header, 0x30)
- header.extend(modulus)
- PadZeroTo(header, HEADER_SIZE)
+ return header
- return header
def SignByteArray(data, pem_file):
- hash_file = tempfile.mkstemp(prefix='pack_ec.')[1]
- sign_file = tempfile.mkstemp(prefix='pack_ec.')[1]
- try:
- with open(hash_file, 'wb') as f:
- hasher = hashlib.sha256()
- hasher.update(data)
- f.write(hasher.digest())
- subprocess.run(['openssl', 'rsautl', '-sign', '-inkey', pem_file,
- '-keyform', 'PEM', '-in', hash_file, '-out', sign_file],
- check=True)
- with open(sign_file, 'rb') as f:
- signed = f.read()
- return bytearray(reversed(signed))
- finally:
- os.remove(hash_file)
- os.remove(sign_file)
+ hash_file = tempfile.mkstemp(prefix="pack_ec.")[1]
+ sign_file = tempfile.mkstemp(prefix="pack_ec.")[1]
+ try:
+ with open(hash_file, "wb") as f:
+ hasher = hashlib.sha256()
+ hasher.update(data)
+ f.write(hasher.digest())
+ subprocess.run(
+ [
+ "openssl",
+ "rsautl",
+ "-sign",
+ "-inkey",
+ pem_file,
+ "-keyform",
+ "PEM",
+ "-in",
+ hash_file,
+ "-out",
+ sign_file,
+ ],
+ check=True,
+ )
+ with open(sign_file, "rb") as f:
+ signed = f.read()
+ return bytearray(reversed(signed))
+ finally:
+ os.remove(hash_file)
+ os.remove(sign_file)
+
def BuildTag(args):
- tag = bytearray([(args.header_loc >> 8) & 0xff,
- (args.header_loc >> 16) & 0xff,
- (args.header_loc >> 24) & 0xff])
- if args.chip_select != 0:
- tag[2] |= 0x80
- tag.append(Crc8(0, tag))
- return tag
+ tag = bytearray(
+ [
+ (args.header_loc >> 8) & 0xFF,
+ (args.header_loc >> 16) & 0xFF,
+ (args.header_loc >> 24) & 0xFF,
+ ]
+ )
+ if args.chip_select != 0:
+ tag[2] |= 0x80
+ tag.append(Crc8(0, tag))
+ return tag
+
def PacklfwRoImage(rorw_file, loader_file, image_size):
- """TODO:Clean up to get rid of Temp file and just use memory
- to save data"""
- """Create a temp file with the
+ """TODO:Clean up to get rid of Temp file and just use memory
+ to save data"""
+ """Create a temp file with the
first image_size bytes from the rorw file and the
bytes from the loader_file appended
return the filename"""
- fo=tempfile.NamedTemporaryFile(delete=False) # Need to keep file around
- with open(loader_file,'rb') as fin1:
- pro = fin1.read()
- fo.write(pro)
- with open(rorw_file, 'rb') as fin:
- ro = fin.read(image_size)
- fo.write(ro)
- fo.close()
- return fo.name
+ fo = tempfile.NamedTemporaryFile(delete=False) # Need to keep file around
+ with open(loader_file, "rb") as fin1:
+ pro = fin1.read()
+ fo.write(pro)
+ with open(rorw_file, "rb") as fin:
+ ro = fin.read(image_size)
+ fo.write(ro)
+ fo.close()
+ return fo.name
+
def parseargs():
- parser = argparse.ArgumentParser()
- parser.add_argument("-i", "--input",
- help="EC binary to pack, usually ec.bin or ec.RO.flat.",
- metavar="EC_BIN", default="ec.bin")
- parser.add_argument("-o", "--output",
- help="Output flash binary file",
- metavar="EC_SPI_FLASH", default="ec.packed.bin")
- parser.add_argument("--header_key",
- help="PEM key file for signing header",
- default="rsakey_sign_header.pem")
- parser.add_argument("--payload_key",
- help="PEM key file for signing payload",
- default="rsakey_sign_payload.pem")
- parser.add_argument("--loader_file",
- help="EC loader binary",
- default="ecloader.bin")
- parser.add_argument("-s", "--spi_size", type=int,
- help="Size of the SPI flash in MB",
- default=4)
- parser.add_argument("-l", "--header_loc", type=int,
- help="Location of header in SPI flash",
- default=0x170000)
- parser.add_argument("-p", "--payload_offset", type=int,
- help="The offset of payload from the header",
- default=0x240)
- parser.add_argument("-r", "--rwpayload_loc", type=int,
- help="The offset of payload from the header",
- default=0x190000)
- parser.add_argument("-z", "--romstart", type=int,
- help="The first location to output of the rom",
- default=0)
- parser.add_argument("-c", "--chip_select", type=int,
- help="Chip select signal to use, either 0 or 1.",
- default=0)
- parser.add_argument("--spi_clock", type=int,
- help="SPI clock speed. 8, 12, 24, or 48 MHz.",
- default=24)
- parser.add_argument("--spi_read_cmd", type=int,
- help="SPI read command. 0x3, 0xB, or 0x3B.",
- default=0xb)
- parser.add_argument("--image_size", type=int,
- help="Size of a single image.",
- default=(96 * 1024))
- return parser.parse_args()
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-i",
+ "--input",
+ help="EC binary to pack, usually ec.bin or ec.RO.flat.",
+ metavar="EC_BIN",
+ default="ec.bin",
+ )
+ parser.add_argument(
+ "-o",
+ "--output",
+ help="Output flash binary file",
+ metavar="EC_SPI_FLASH",
+ default="ec.packed.bin",
+ )
+ parser.add_argument(
+ "--header_key",
+ help="PEM key file for signing header",
+ default="rsakey_sign_header.pem",
+ )
+ parser.add_argument(
+ "--payload_key",
+ help="PEM key file for signing payload",
+ default="rsakey_sign_payload.pem",
+ )
+ parser.add_argument(
+ "--loader_file", help="EC loader binary", default="ecloader.bin"
+ )
+ parser.add_argument(
+ "-s", "--spi_size", type=int, help="Size of the SPI flash in MB", default=4
+ )
+ parser.add_argument(
+ "-l",
+ "--header_loc",
+ type=int,
+ help="Location of header in SPI flash",
+ default=0x170000,
+ )
+ parser.add_argument(
+ "-p",
+ "--payload_offset",
+ type=int,
+ help="The offset of payload from the header",
+ default=0x240,
+ )
+ parser.add_argument(
+ "-r",
+ "--rwpayload_loc",
+ type=int,
+ help="The offset of payload from the header",
+ default=0x190000,
+ )
+ parser.add_argument(
+ "-z",
+ "--romstart",
+ type=int,
+ help="The first location to output of the rom",
+ default=0,
+ )
+ parser.add_argument(
+ "-c",
+ "--chip_select",
+ type=int,
+ help="Chip select signal to use, either 0 or 1.",
+ default=0,
+ )
+ parser.add_argument(
+ "--spi_clock",
+ type=int,
+ help="SPI clock speed. 8, 12, 24, or 48 MHz.",
+ default=24,
+ )
+ parser.add_argument(
+ "--spi_read_cmd",
+ type=int,
+ help="SPI read command. 0x3, 0xB, or 0x3B.",
+ default=0xB,
+ )
+ parser.add_argument(
+ "--image_size", type=int, help="Size of a single image.", default=(96 * 1024)
+ )
+ return parser.parse_args()
+
def main():
- args = parseargs()
-
- spi_size = args.spi_size * 1024
- args.header_loc = spi_size - (128 * 1024)
- args.rwpayload_loc = spi_size - (256 * 1024)
- args.romstart = spi_size - (256 * 1024)
-
- spi_list = []
-
- rorofile=PacklfwRoImage(args.input, args.loader_file, args.image_size)
- payload = GetPayload(rorofile)
- payload_len = len(payload)
- payload_signature = SignByteArray(payload, args.payload_key)
- header = BuildHeader(args, payload_len, rorofile)
- header_signature = SignByteArray(header, args.header_key)
- tag = BuildTag(args)
- # truncate the RW to 128k
- payloadrw = GetPayloadFromOffset(args.input,args.image_size)[:128*1024]
- os.remove(rorofile) # clean up the temp file
-
- spi_list.append((args.header_loc, header, "header"))
- spi_list.append((args.header_loc + HEADER_SIZE, header_signature, "header_signature"))
- spi_list.append((args.header_loc + args.payload_offset, payload, "payload"))
- spi_list.append((args.header_loc + args.payload_offset + payload_len,
- payload_signature, "payload_signature"))
- spi_list.append((spi_size - 256, tag, "tag"))
- spi_list.append((args.rwpayload_loc, payloadrw, "payloadrw"))
-
-
- spi_list = sorted(spi_list)
-
- with open(args.output, 'wb') as f:
- addr = args.romstart
- for s in spi_list:
- assert addr <= s[0]
- if addr < s[0]:
- f.write(b'\xff' * (s[0] - addr))
- addr = s[0]
- f.write(s[1])
- addr += len(s[1])
- if addr < spi_size:
- f.write(b'\xff' * (spi_size - addr))
-
-if __name__ == '__main__':
- main()
+ args = parseargs()
+
+ spi_size = args.spi_size * 1024
+ args.header_loc = spi_size - (128 * 1024)
+ args.rwpayload_loc = spi_size - (256 * 1024)
+ args.romstart = spi_size - (256 * 1024)
+
+ spi_list = []
+
+ rorofile = PacklfwRoImage(args.input, args.loader_file, args.image_size)
+ payload = GetPayload(rorofile)
+ payload_len = len(payload)
+ payload_signature = SignByteArray(payload, args.payload_key)
+ header = BuildHeader(args, payload_len, rorofile)
+ header_signature = SignByteArray(header, args.header_key)
+ tag = BuildTag(args)
+ # truncate the RW to 128k
+ payloadrw = GetPayloadFromOffset(args.input, args.image_size)[: 128 * 1024]
+ os.remove(rorofile) # clean up the temp file
+
+ spi_list.append((args.header_loc, header, "header"))
+ spi_list.append(
+ (args.header_loc + HEADER_SIZE, header_signature, "header_signature")
+ )
+ spi_list.append((args.header_loc + args.payload_offset, payload, "payload"))
+ spi_list.append(
+ (
+ args.header_loc + args.payload_offset + payload_len,
+ payload_signature,
+ "payload_signature",
+ )
+ )
+ spi_list.append((spi_size - 256, tag, "tag"))
+ spi_list.append((args.rwpayload_loc, payloadrw, "payloadrw"))
+
+ spi_list = sorted(spi_list)
+
+ with open(args.output, "wb") as f:
+ addr = args.romstart
+ for s in spi_list:
+ assert addr <= s[0]
+ if addr < s[0]:
+ f.write(b"\xff" * (s[0] - addr))
+ addr = s[0]
+ f.write(s[1])
+ addr += len(s[1])
+ if addr < spi_size:
+ f.write(b"\xff" * (spi_size - addr))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/cts/common/board.py b/cts/common/board.py
index d2c8e02b04..f62d7bdfc5 100644
--- a/cts/common/board.py
+++ b/cts/common/board.py
@@ -10,379 +10,401 @@
from __future__ import print_function
-from abc import ABCMeta
-from abc import abstractmethod
import os
import shutil
import subprocess as sp
-import serial
+from abc import ABCMeta, abstractmethod
+import serial
import six
-
-OCD_SCRIPT_DIR = '/usr/share/openocd/scripts'
+OCD_SCRIPT_DIR = "/usr/share/openocd/scripts"
OPENOCD_CONFIGS = {
- 'stm32l476g-eval': 'board/stm32l4discovery.cfg',
- 'nucleo-f072rb': 'board/st_nucleo_f0.cfg',
- 'nucleo-f411re': 'board/st_nucleo_f4.cfg',
+ "stm32l476g-eval": "board/stm32l4discovery.cfg",
+ "nucleo-f072rb": "board/st_nucleo_f0.cfg",
+ "nucleo-f411re": "board/st_nucleo_f4.cfg",
}
FLASH_OFFSETS = {
- 'stm32l476g-eval': '0x08000000',
- 'nucleo-f072rb': '0x08000000',
- 'nucleo-f411re': '0x08000000',
+ "stm32l476g-eval": "0x08000000",
+ "nucleo-f072rb": "0x08000000",
+ "nucleo-f411re": "0x08000000",
}
-REBOOT_MARKER = 'UART initialized after reboot'
+REBOOT_MARKER = "UART initialized after reboot"
def get_subprocess_args():
- if six.PY3:
- return {'encoding': 'utf-8'}
- return {}
+ if six.PY3:
+ return {"encoding": "utf-8"}
+ return {}
class Board(six.with_metaclass(ABCMeta, object)):
- """Class representing a single board connected to a host machine.
-
- Attributes:
- board: String containing actual type of board, i.e. nucleo-f072rb
- config: Directory of board config file relative to openocd's
- scripts directory
- hla_serial: String containing board's hla_serial number (if board
- is an stm32 board)
- tty_port: String that is the path to the tty port which board's
- UART outputs to
- tty: String of file descriptor for tty_port
- """
-
- def __init__(self, board, module, hla_serial=None):
- """Initializes a board object with given attributes.
-
- Args:
- board: String containing board name
- module: String of the test module you are building,
- i.e. gpio, timer, etc.
- hla_serial: Serial number if board's adaptor is an HLA
-
- Raises:
- RuntimeError: Board is not supported
- """
- if board not in OPENOCD_CONFIGS:
- msg = 'OpenOcd configuration not found for ' + board
- raise RuntimeError(msg)
- if board not in FLASH_OFFSETS:
- msg = 'Flash offset not found for ' + board
- raise RuntimeError(msg)
- self.board = board
- self.flash_offset = FLASH_OFFSETS[self.board]
- self.openocd_config = OPENOCD_CONFIGS[self.board]
- self.module = module
- self.hla_serial = hla_serial
- self.tty_port = None
- self.tty = None
-
- def reset_log_dir(self):
- """Reset log directory."""
- if os.path.isdir(self.log_dir):
- shutil.rmtree(self.log_dir)
- os.makedirs(self.log_dir)
-
- @staticmethod
- def get_stlink_serials():
- """Gets serial numbers of all st-link v2.1 board attached to host.
-
- Returns:
- List of serials
+ """Class representing a single board connected to a host machine.
+
+ Attributes:
+ board: String containing actual type of board, i.e. nucleo-f072rb
+ config: Directory of board config file relative to openocd's
+ scripts directory
+ hla_serial: String containing board's hla_serial number (if board
+ is an stm32 board)
+ tty_port: String that is the path to the tty port which board's
+ UART outputs to
+ tty: String of file descriptor for tty_port
"""
- usb_args = ['sudo', 'lsusb', '-v', '-d', '0x0483:0x374b']
- st_link_info = sp.check_output(usb_args, **get_subprocess_args())
- st_serials = []
- for line in st_link_info.split('\n'):
- if 'iSerial' not in line:
- continue
- words = line.split()
- if len(words) <= 2:
- continue
- st_serials.append(words[2].strip())
- return st_serials
-
- @abstractmethod
- def get_serial(self):
- """Subclass should implement this."""
- pass
-
- def send_openocd_commands(self, commands):
- """Send a command to the board via openocd.
-
- Args:
- commands: A list of commands to send
-
- Returns:
- True if execution is successful or False otherwise.
- """
- args = ['sudo', 'openocd', '-s', OCD_SCRIPT_DIR,
- '-f', self.openocd_config, '-c', 'hla_serial ' + self.hla_serial]
-
- for cmd in commands:
- args += ['-c', cmd]
- args += ['-c', 'shutdown']
-
- rv = 1
- with open(self.openocd_log, 'a') as output:
- rv = sp.call(args, stdout=output, stderr=sp.STDOUT)
- if rv != 0:
- self.dump_openocd_log()
-
- return rv == 0
-
- def dump_openocd_log(self):
- with open(self.openocd_log) as log:
- print(log.read())
+ def __init__(self, board, module, hla_serial=None):
+ """Initializes a board object with given attributes.
+
+ Args:
+ board: String containing board name
+ module: String of the test module you are building,
+ i.e. gpio, timer, etc.
+ hla_serial: Serial number if board's adaptor is an HLA
+
+ Raises:
+ RuntimeError: Board is not supported
+ """
+ if board not in OPENOCD_CONFIGS:
+ msg = "OpenOcd configuration not found for " + board
+ raise RuntimeError(msg)
+ if board not in FLASH_OFFSETS:
+ msg = "Flash offset not found for " + board
+ raise RuntimeError(msg)
+ self.board = board
+ self.flash_offset = FLASH_OFFSETS[self.board]
+ self.openocd_config = OPENOCD_CONFIGS[self.board]
+ self.module = module
+ self.hla_serial = hla_serial
+ self.tty_port = None
+ self.tty = None
+
+ def reset_log_dir(self):
+ """Reset log directory."""
+ if os.path.isdir(self.log_dir):
+ shutil.rmtree(self.log_dir)
+ os.makedirs(self.log_dir)
+
+ @staticmethod
+ def get_stlink_serials():
+ """Gets serial numbers of all st-link v2.1 board attached to host.
+
+ Returns:
+ List of serials
+ """
+ usb_args = ["sudo", "lsusb", "-v", "-d", "0x0483:0x374b"]
+ st_link_info = sp.check_output(usb_args, **get_subprocess_args())
+ st_serials = []
+ for line in st_link_info.split("\n"):
+ if "iSerial" not in line:
+ continue
+ words = line.split()
+ if len(words) <= 2:
+ continue
+ st_serials.append(words[2].strip())
+ return st_serials
+
+ @abstractmethod
+ def get_serial(self):
+ """Subclass should implement this."""
+ pass
+
+ def send_openocd_commands(self, commands):
+ """Send a command to the board via openocd.
+
+ Args:
+ commands: A list of commands to send
+
+ Returns:
+ True if execution is successful or False otherwise.
+ """
+ args = [
+ "sudo",
+ "openocd",
+ "-s",
+ OCD_SCRIPT_DIR,
+ "-f",
+ self.openocd_config,
+ "-c",
+ "hla_serial " + self.hla_serial,
+ ]
+
+ for cmd in commands:
+ args += ["-c", cmd]
+ args += ["-c", "shutdown"]
+
+ rv = 1
+ with open(self.openocd_log, "a") as output:
+ rv = sp.call(args, stdout=output, stderr=sp.STDOUT)
+
+ if rv != 0:
+ self.dump_openocd_log()
+
+ return rv == 0
+
+ def dump_openocd_log(self):
+ with open(self.openocd_log) as log:
+ print(log.read())
+
+ def build(self, ec_dir):
+ """Builds test suite module for board.
+
+ Args:
+ ec_dir: String of the ec directory path
+
+ Returns:
+ True if build is successful or False otherwise.
+ """
+ cmds = [
+ "make",
+ "--directory=" + ec_dir,
+ "BOARD=" + self.board,
+ "CTS_MODULE=" + self.module,
+ "-j",
+ ]
+
+ rv = 1
+ with open(self.build_log, "a") as output:
+ rv = sp.call(cmds, stdout=output, stderr=sp.STDOUT)
+
+ if rv != 0:
+ self.dump_build_log()
+
+ return rv == 0
+
+ def dump_build_log(self):
+ with open(self.build_log) as log:
+ print(log.read())
+
+ def flash(self, image_path):
+ """Flashes board with most recent build ec.bin."""
+ cmd = [
+ "reset_config connect_assert_srst",
+ "init",
+ "reset init",
+ "flash write_image erase %s %s" % (image_path, self.flash_offset),
+ ]
+ return self.send_openocd_commands(cmd)
+
+ def to_string(self):
+ s = (
+ "Type: Board\n"
+ "board: " + self.board + "\n"
+ "hla_serial: " + self.hla_serial + "\n"
+ "openocd_config: " + self.openocd_config + "\n"
+ "tty_port: " + self.tty_port + "\n"
+ "tty: " + str(self.tty) + "\n"
+ )
+ return s
+
+ def reset_halt(self):
+ """Reset then halt board."""
+ return self.send_openocd_commands(["init", "reset halt"])
+
+ def resume(self):
+ """Resume halting board."""
+ return self.send_openocd_commands(["init", "resume"])
+
+ def setup_tty(self):
+ """Call this before calling read_tty for the first time.
+
+ This is not in the initialization because caller only should call
+ this function after serial numbers are setup
+ """
+ self.get_serial()
+ self.reset_halt()
+ self.identify_tty_port()
+
+ tty = None
+ try:
+ tty = serial.Serial(self.tty_port, 115200, timeout=1)
+ except serial.SerialException:
+ raise ValueError(
+ "Failed to open "
+ + self.tty_port
+ + " of "
+ + self.board
+ + ". Please make sure the port is available and you have"
+ + " permission to read it. Create dialout group and run:"
+ + " sudo usermod -a -G dialout <username>."
+ )
+ self.tty = tty
+
+ def read_tty(self, max_boot_count=1):
+ """Read info from a serial port described by a file descriptor.
+
+ Args:
+ max_boot_count: Stop reading if boot count exceeds this number
+
+ Returns:
+ result: characters read from tty
+ boot: boot counts
+ """
+ buf = []
+ line = []
+ boot = 0
+ while True:
+ c = self.tty.read().decode("utf-8")
+ if not c:
+ break
+ line.append(c)
+ if c == "\n":
+ l = "".join(line)
+ buf.append(l)
+ if REBOOT_MARKER in l:
+ boot += 1
+ line = []
+ if boot > max_boot_count:
+ break
+
+ l = "".join(line)
+ buf.append(l)
+ result = "".join(buf)
- def build(self, ec_dir):
- """Builds test suite module for board.
+ return result, boot
- Args:
- ec_dir: String of the ec directory path
+ def identify_tty_port(self):
+ """Saves this board's serial port."""
+ dev_dir = "/dev"
+ id_prefix = "ID_SERIAL_SHORT="
+ com_devices = [f for f in os.listdir(dev_dir) if f.startswith("ttyACM")]
- Returns:
- True if build is successful or False otherwise.
- """
- cmds = ['make',
- '--directory=' + ec_dir,
- 'BOARD=' + self.board,
- 'CTS_MODULE=' + self.module,
- '-j']
-
- rv = 1
- with open(self.build_log, 'a') as output:
- rv = sp.call(cmds, stdout=output, stderr=sp.STDOUT)
-
- if rv != 0:
- self.dump_build_log()
-
- return rv == 0
-
- def dump_build_log(self):
- with open(self.build_log) as log:
- print(log.read())
-
- def flash(self, image_path):
- """Flashes board with most recent build ec.bin."""
- cmd = ['reset_config connect_assert_srst',
- 'init',
- 'reset init',
- 'flash write_image erase %s %s' % (image_path, self.flash_offset)]
- return self.send_openocd_commands(cmd)
-
- def to_string(self):
- s = ('Type: Board\n'
- 'board: ' + self.board + '\n'
- 'hla_serial: ' + self.hla_serial + '\n'
- 'openocd_config: ' + self.openocd_config + '\n'
- 'tty_port: ' + self.tty_port + '\n'
- 'tty: ' + str(self.tty) + '\n')
- return s
-
- def reset_halt(self):
- """Reset then halt board."""
- return self.send_openocd_commands(['init', 'reset halt'])
-
- def resume(self):
- """Resume halting board."""
- return self.send_openocd_commands(['init', 'resume'])
-
- def setup_tty(self):
- """Call this before calling read_tty for the first time.
-
- This is not in the initialization because caller only should call
- this function after serial numbers are setup
- """
- self.get_serial()
- self.reset_halt()
- self.identify_tty_port()
-
- tty = None
- try:
- tty = serial.Serial(self.tty_port, 115200, timeout=1)
- except serial.SerialException:
- raise ValueError('Failed to open ' + self.tty_port + ' of ' + self.board +
- '. Please make sure the port is available and you have' +
- ' permission to read it. Create dialout group and run:' +
- ' sudo usermod -a -G dialout <username>.')
- self.tty = tty
-
- def read_tty(self, max_boot_count=1):
- """Read info from a serial port described by a file descriptor.
-
- Args:
- max_boot_count: Stop reading if boot count exceeds this number
-
- Returns:
- result: characters read from tty
- boot: boot counts
- """
- buf = []
- line = []
- boot = 0
- while True:
- c = self.tty.read().decode('utf-8')
- if not c:
- break
- line.append(c)
- if c == '\n':
- l = ''.join(line)
- buf.append(l)
- if REBOOT_MARKER in l:
- boot += 1
- line = []
- if boot > max_boot_count:
- break
-
- l = ''.join(line)
- buf.append(l)
- result = ''.join(buf)
-
- return result, boot
-
- def identify_tty_port(self):
- """Saves this board's serial port."""
- dev_dir = '/dev'
- id_prefix = 'ID_SERIAL_SHORT='
- com_devices = [f for f in os.listdir(dev_dir) if f.startswith('ttyACM')]
-
- for device in com_devices:
- self.tty_port = os.path.join(dev_dir, device)
- properties = sp.check_output(
- ['udevadm', 'info', '-a', '-n', self.tty_port, '--query=property'],
- **get_subprocess_args())
- for line in [l.strip() for l in properties.split('\n')]:
- if line.startswith(id_prefix):
- if self.hla_serial == line[len(id_prefix):]:
- return
+ for device in com_devices:
+ self.tty_port = os.path.join(dev_dir, device)
+ properties = sp.check_output(
+ ["udevadm", "info", "-a", "-n", self.tty_port, "--query=property"],
+ **get_subprocess_args()
+ )
+ for line in [l.strip() for l in properties.split("\n")]:
+ if line.startswith(id_prefix):
+ if self.hla_serial == line[len(id_prefix) :]:
+ return
- # If we get here without returning, something is wrong
- raise RuntimeError('The device dev path could not be found')
+ # If we get here without returning, something is wrong
+ raise RuntimeError("The device dev path could not be found")
- def close_tty(self):
- """Close tty."""
- self.tty.close()
+ def close_tty(self):
+ """Close tty."""
+ self.tty.close()
class TestHarness(Board):
- """Subclass of Board representing a Test Harness.
+ """Subclass of Board representing a Test Harness.
- Attributes:
- serial_path: Path to file containing serial number
- """
-
- def __init__(self, board, module, log_dir, serial_path):
- """Initializes a board object with given attributes.
-
- Args:
- board: board name
- module: module name
- log_dir: Directory where log file is stored
+ Attributes:
serial_path: Path to file containing serial number
"""
- Board.__init__(self, board, module)
- self.log_dir = log_dir
- self.openocd_log = os.path.join(log_dir, 'openocd_th.log')
- self.build_log = os.path.join(log_dir, 'build_th.log')
- self.serial_path = serial_path
- self.reset_log_dir()
-
- def get_serial(self):
- """Loads serial number from saved location."""
- if self.hla_serial:
- return # serial was already loaded
- try:
- with open(self.serial_path, mode='r') as f:
- s = f.read()
- self.hla_serial = s.strip()
+
+ def __init__(self, board, module, log_dir, serial_path):
+ """Initializes a board object with given attributes.
+
+ Args:
+ board: board name
+ module: module name
+ log_dir: Directory where log file is stored
+ serial_path: Path to file containing serial number
+ """
+ Board.__init__(self, board, module)
+ self.log_dir = log_dir
+ self.openocd_log = os.path.join(log_dir, "openocd_th.log")
+ self.build_log = os.path.join(log_dir, "build_th.log")
+ self.serial_path = serial_path
+ self.reset_log_dir()
+
+ def get_serial(self):
+ """Loads serial number from saved location."""
+ if self.hla_serial:
+ return # serial was already loaded
+ try:
+ with open(self.serial_path, mode="r") as f:
+ s = f.read()
+ self.hla_serial = s.strip()
+ return
+ except IOError:
+ msg = (
+ "Your TH board has not been identified.\n"
+ "Connect only TH and run the script --setup, then try again."
+ )
+ raise RuntimeError(msg)
+
+ def save_serial(self):
+ """Saves the TH serial number to a file."""
+ serials = Board.get_stlink_serials()
+ if len(serials) > 1:
+ msg = (
+ "There are more than one test board connected to the host."
+ "\nConnect only the test harness and remove other boards."
+ )
+ raise RuntimeError(msg)
+ if len(serials) < 1:
+ msg = "No test boards were found.\n" "Check boards are connected."
+ raise RuntimeError(msg)
+
+ s = serials[0]
+ serial_dir = os.path.dirname(self.serial_path)
+ if not os.path.exists(serial_dir):
+ os.makedirs(serial_dir)
+ with open(self.serial_path, mode="w") as f:
+ f.write(s)
+ self.hla_serial = s
+
+ print("Your TH serial", s, "has been saved as", self.serial_path)
return
- except IOError:
- msg = ('Your TH board has not been identified.\n'
- 'Connect only TH and run the script --setup, then try again.')
- raise RuntimeError(msg)
-
- def save_serial(self):
- """Saves the TH serial number to a file."""
- serials = Board.get_stlink_serials()
- if len(serials) > 1:
- msg = ('There are more than one test board connected to the host.'
- '\nConnect only the test harness and remove other boards.')
- raise RuntimeError(msg)
- if len(serials) < 1:
- msg = ('No test boards were found.\n'
- 'Check boards are connected.')
- raise RuntimeError(msg)
-
- s = serials[0]
- serial_dir = os.path.dirname(self.serial_path)
- if not os.path.exists(serial_dir):
- os.makedirs(serial_dir)
- with open(self.serial_path, mode='w') as f:
- f.write(s)
- self.hla_serial = s
-
- print('Your TH serial', s, 'has been saved as', self.serial_path)
- return
class DeviceUnderTest(Board):
- """Subclass of Board representing a DUT board.
+ """Subclass of Board representing a DUT board.
- Attributes:
- th: Reference to test harness board to which this DUT is attached
- """
-
- def __init__(self, board, th, module, log_dir, hla_ser=None):
- """Initializes a DUT object.
-
- Args:
- board: String containing board name
+ Attributes:
th: Reference to test harness board to which this DUT is attached
- module: module name
- log_dir: Directory where log file is stored
- hla_ser: Serial number if board uses an HLA adaptor
"""
- Board.__init__(self, board, module, hla_serial=hla_ser)
- self.th = th
- self.log_dir = log_dir
- self.openocd_log = os.path.join(log_dir, 'openocd_dut.log')
- self.build_log = os.path.join(log_dir, 'build_dut.log')
- self.reset_log_dir()
-
- def get_serial(self):
- """Get serial number.
- Precondition: The DUT and TH must both be connected, and th.hla_serial
- must hold the correct value (the th's serial #)
+ def __init__(self, board, th, module, log_dir, hla_ser=None):
+ """Initializes a DUT object.
+
+ Args:
+ board: String containing board name
+ th: Reference to test harness board to which this DUT is attached
+ module: module name
+ log_dir: Directory where log file is stored
+ hla_ser: Serial number if board uses an HLA adaptor
+ """
+ Board.__init__(self, board, module, hla_serial=hla_ser)
+ self.th = th
+ self.log_dir = log_dir
+ self.openocd_log = os.path.join(log_dir, "openocd_dut.log")
+ self.build_log = os.path.join(log_dir, "build_dut.log")
+ self.reset_log_dir()
+
+ def get_serial(self):
+ """Get serial number.
+
+ Precondition: The DUT and TH must both be connected, and th.hla_serial
+ must hold the correct value (the th's serial #)
+
+ Raises:
+ RuntimeError: DUT isn't found or multiple DUTs are found.
+ """
+ if self.hla_serial is not None:
+ # serial was already set ('' is a valid serial)
+ return
- Raises:
- RuntimeError: DUT isn't found or multiple DUTs are found.
- """
- if self.hla_serial is not None:
- # serial was already set ('' is a valid serial)
- return
-
- serials = Board.get_stlink_serials()
- dut = [s for s in serials if self.th.hla_serial != s]
-
- # If len(dut) is 0 then your dut doesn't use an st-link device, so we
- # don't have to worry about its serial number
- if not dut:
- msg = ('Failed to find serial for DUT.\n'
- 'Is ' + self.board + ' connected?')
- raise RuntimeError(msg)
- if len(dut) > 1:
- msg = ('Found multiple DUTs.\n'
- 'You can connect only one DUT at a time. This may be caused by\n'
- 'an incorrect TH serial. Check if ' + self.th.serial_path + '\n'
- 'contains a correct serial.')
- raise RuntimeError(msg)
-
- # Found your other st-link device serial!
- self.hla_serial = dut[0]
- return
+ serials = Board.get_stlink_serials()
+ dut = [s for s in serials if self.th.hla_serial != s]
+
+ # If len(dut) is 0 then your dut doesn't use an st-link device, so we
+ # don't have to worry about its serial number
+ if not dut:
+ msg = "Failed to find serial for DUT.\n" "Is " + self.board + " connected?"
+ raise RuntimeError(msg)
+ if len(dut) > 1:
+ msg = (
+ "Found multiple DUTs.\n"
+ "You can connect only one DUT at a time. This may be caused by\n"
+ "an incorrect TH serial. Check if " + self.th.serial_path + "\n"
+ "contains a correct serial."
+ )
+ raise RuntimeError(msg)
+
+ # Found your other st-link device serial!
+ self.hla_serial = dut[0]
+ return
diff --git a/cts/cts.py b/cts/cts.py
index c3e0335cab..ebc526c701 100755
--- a/cts/cts.py
+++ b/cts/cts.py
@@ -28,416 +28,424 @@ import argparse
import os
import shutil
import time
-import common.board as board
+import common.board as board
-CTS_RC_PREFIX = 'CTS_RC_'
-DEFAULT_TH = 'stm32l476g-eval'
-DEFAULT_DUT = 'nucleo-f072rb'
+CTS_RC_PREFIX = "CTS_RC_"
+DEFAULT_TH = "stm32l476g-eval"
+DEFAULT_DUT = "nucleo-f072rb"
MAX_SUITE_TIME_SEC = 5
-CTS_TEST_RESULT_DIR = '/tmp/ects'
+CTS_TEST_RESULT_DIR = "/tmp/ects"
# Host only return codes. Make sure they match values in cts.rc
-CTS_RC_DID_NOT_START = -1 # test did not run.
-CTS_RC_DID_NOT_END = -2 # test did not run.
-CTS_RC_DUPLICATE_RUN = -3 # test was run multiple times.
-CTS_RC_INVALID_RETURN_CODE = -4 # failed to parse return code
+CTS_RC_DID_NOT_START = -1 # test did not run.
+CTS_RC_DID_NOT_END = -2 # test did not run.
+CTS_RC_DUPLICATE_RUN = -3 # test was run multiple times.
+CTS_RC_INVALID_RETURN_CODE = -4 # failed to parse return code
class Cts(object):
- """Class that represents a eCTS run.
-
- Attributes:
- dut: DeviceUnderTest object representing DUT
- th: TestHarness object representing a test harness
- module: Name of module to build/run tests for
- testlist: List of strings of test names contained in given module
- return_codes: Dict of strings of return codes, with a code's integer
- value being the index for the corresponding string representation
- """
-
- def __init__(self, ec_dir, th, dut, module):
- """Initializes cts class object with given arguments.
-
- Args:
- ec_dir: Path to ec directory
- th: Name of the test harness board
- dut: Name of the device under test board
- module: Name of module to build/run tests for (e.g. gpio, interrupt)
- """
- self.results_dir = os.path.join(CTS_TEST_RESULT_DIR, dut, module)
- if os.path.isdir(self.results_dir):
- shutil.rmtree(self.results_dir)
- else:
- os.makedirs(self.results_dir)
- self.ec_dir = ec_dir
- self.module = module
- serial_path = os.path.join(CTS_TEST_RESULT_DIR, 'th_serial')
- self.th = board.TestHarness(th, module, self.results_dir, serial_path)
- self.dut = board.DeviceUnderTest(dut, self.th, module, self.results_dir)
- cts_dir = os.path.join(self.ec_dir, 'cts')
- testlist_path = os.path.join(cts_dir, self.module, 'cts.testlist')
- return_codes_path = os.path.join(cts_dir, 'common', 'cts.rc')
- self.get_return_codes(return_codes_path)
- self.testlist = self.get_macro_args(testlist_path, 'CTS_TEST')
-
- def build(self):
- """Build images for DUT and TH."""
- print('Building DUT image...')
- if not self.dut.build(self.ec_dir):
- raise RuntimeError('Building module %s for DUT failed' % (self.module))
- print('Building TH image...')
- if not self.th.build(self.ec_dir):
- raise RuntimeError('Building module %s for TH failed' % (self.module))
-
- def flash_boards(self):
- """Flashes TH and DUT with their most recently built ec.bin."""
- cts_module = 'cts_' + self.module
- image_path = os.path.join('build', self.th.board, cts_module, 'ec.bin')
- self.identify_boards()
- print('Flashing TH with', image_path)
- if not self.th.flash(image_path):
- raise RuntimeError('Flashing TH failed')
- image_path = os.path.join('build', self.dut.board, cts_module, 'ec.bin')
- print('Flashing DUT with', image_path)
- if not self.dut.flash(image_path):
- raise RuntimeError('Flashing DUT failed')
-
- def setup(self):
- """Setup boards."""
- self.th.save_serial()
-
- def identify_boards(self):
- """Updates serials of TH and DUT in that order (order matters)."""
- self.th.get_serial()
- self.dut.get_serial()
-
- def get_macro_args(self, filepath, macro):
- """Get list of args of a macro in a file when macro.
-
- Args:
- filepath: String containing absolute path to the file
- macro: String containing text of macro to get args of
-
- Returns:
- List of dictionaries where each entry is:
- 'name': Test name,
- 'th_string': Expected string from TH,
- 'dut_string': Expected string from DUT,
- """
- tests = []
- with open(filepath, 'r') as f:
- lines = f.readlines()
- joined = ''.join(lines).replace('\\\n', '').splitlines()
- for l in joined:
- if not l.strip().startswith(macro):
- continue
- d = {}
- l = l.strip()[len(macro):]
- l = l.strip('()').split(',')
- d['name'] = l[0].strip()
- d['th_rc'] = self.get_return_code_value(l[1].strip().strip('"'))
- d['th_string'] = l[2].strip().strip('"')
- d['dut_rc'] = self.get_return_code_value(l[3].strip().strip('"'))
- d['dut_string'] = l[4].strip().strip('"')
- tests.append(d)
- return tests
-
- def get_return_codes(self, filepath):
- """Read return code names from the return code definition file."""
- self.return_codes = {}
- val = 0
- with open(filepath, 'r') as f:
- for line in f:
- line = line.strip()
- if not line.startswith(CTS_RC_PREFIX):
- continue
- line = line.split(',')[0]
- if '=' in line:
- tokens = line.split('=')
- line = tokens[0].strip()
- val = int(tokens[1].strip())
- self.return_codes[line] = val
- val += 1
-
- def parse_output(self, output):
- """Parse console output from DUT or TH.
-
- Args:
- output: String containing consoule output
-
- Returns:
- List of dictionaries where each key and value are:
- name = 'ects_test_x',
- started = True/False,
- ended = True/False,
- rc = CTS_RC_*,
- output = All text between 'ects_test_x start' and 'ects_test_x end'
- """
- results = []
- i = 0
- for test in self.testlist:
- results.append({})
- results[i]['name'] = test['name']
- results[i]['started'] = False
- results[i]['rc'] = CTS_RC_DID_NOT_START
- results[i]['string'] = False
- results[i]['output'] = []
- i += 1
-
- i = 0
- for ln in [ln.strip() for ln in output.split('\n')]:
- if i + 1 > len(results):
- break
- tokens = ln.split()
- if len(tokens) >= 2:
- if tokens[0].strip() == results[i]['name']:
- if tokens[1].strip() == 'start':
- # start line found
- if results[i]['started']: # Already started
- results[i]['rc'] = CTS_RC_DUPLICATE_RUN
- else:
- results[i]['rc'] = CTS_RC_DID_NOT_END
- results[i]['started'] = True
- continue
- elif results[i]['started'] and tokens[1].strip() == 'end':
- # end line found
- results[i]['rc'] = CTS_RC_INVALID_RETURN_CODE
- if len(tokens) == 3:
- try:
- results[i]['rc'] = int(tokens[2].strip())
- except ValueError:
- pass
- # Since index is incremented when 'end' is encountered, we don't
- # need to check duplicate 'end'.
- i += 1
- continue
- if results[i]['started']:
- results[i]['output'].append(ln)
-
- return results
-
- def get_return_code_name(self, code, strip_prefix=False):
- name = ''
- for k, v in self.return_codes.items():
- if v == code:
- if strip_prefix:
- name = k[len(CTS_RC_PREFIX):]
- else:
- name = k
- return name
-
- def get_return_code_value(self, name):
- if name:
- return self.return_codes[name]
- return 0
-
- def evaluate_run(self, dut_output, th_output):
- """Parse outputs to derive test results.
-
- Args:
- dut_output: String output of DUT
- th_output: String output of TH
-
- Returns:
- th_results: list of test results for TH
- dut_results: list of test results for DUT
+ """Class that represents a eCTS run.
+
+ Attributes:
+ dut: DeviceUnderTest object representing DUT
+ th: TestHarness object representing a test harness
+ module: Name of module to build/run tests for
+ testlist: List of strings of test names contained in given module
+ return_codes: Dict of strings of return codes, with a code's integer
+ value being the index for the corresponding string representation
"""
- th_results = self.parse_output(th_output)
- dut_results = self.parse_output(dut_output)
- # Search for expected string in each output
- for i, v in enumerate(self.testlist):
- if v['th_string'] in th_results[i]['output'] or not v['th_string']:
- th_results[i]['string'] = True
- if v['dut_string'] in dut_results[i]['output'] or not v['dut_string']:
- dut_results[i]['string'] = True
+ def __init__(self, ec_dir, th, dut, module):
+ """Initializes cts class object with given arguments.
+
+ Args:
+ ec_dir: Path to ec directory
+ th: Name of the test harness board
+ dut: Name of the device under test board
+ module: Name of module to build/run tests for (e.g. gpio, interrupt)
+ """
+ self.results_dir = os.path.join(CTS_TEST_RESULT_DIR, dut, module)
+ if os.path.isdir(self.results_dir):
+ shutil.rmtree(self.results_dir)
+ else:
+ os.makedirs(self.results_dir)
+ self.ec_dir = ec_dir
+ self.module = module
+ serial_path = os.path.join(CTS_TEST_RESULT_DIR, "th_serial")
+ self.th = board.TestHarness(th, module, self.results_dir, serial_path)
+ self.dut = board.DeviceUnderTest(dut, self.th, module, self.results_dir)
+ cts_dir = os.path.join(self.ec_dir, "cts")
+ testlist_path = os.path.join(cts_dir, self.module, "cts.testlist")
+ return_codes_path = os.path.join(cts_dir, "common", "cts.rc")
+ self.get_return_codes(return_codes_path)
+ self.testlist = self.get_macro_args(testlist_path, "CTS_TEST")
+
+ def build(self):
+ """Build images for DUT and TH."""
+ print("Building DUT image...")
+ if not self.dut.build(self.ec_dir):
+ raise RuntimeError("Building module %s for DUT failed" % (self.module))
+ print("Building TH image...")
+ if not self.th.build(self.ec_dir):
+ raise RuntimeError("Building module %s for TH failed" % (self.module))
+
+ def flash_boards(self):
+ """Flashes TH and DUT with their most recently built ec.bin."""
+ cts_module = "cts_" + self.module
+ image_path = os.path.join("build", self.th.board, cts_module, "ec.bin")
+ self.identify_boards()
+ print("Flashing TH with", image_path)
+ if not self.th.flash(image_path):
+ raise RuntimeError("Flashing TH failed")
+ image_path = os.path.join("build", self.dut.board, cts_module, "ec.bin")
+ print("Flashing DUT with", image_path)
+ if not self.dut.flash(image_path):
+ raise RuntimeError("Flashing DUT failed")
+
+ def setup(self):
+ """Setup boards."""
+ self.th.save_serial()
+
+ def identify_boards(self):
+ """Updates serials of TH and DUT in that order (order matters)."""
+ self.th.get_serial()
+ self.dut.get_serial()
+
+ def get_macro_args(self, filepath, macro):
+ """Get list of args of a macro in a file when macro.
+
+ Args:
+ filepath: String containing absolute path to the file
+ macro: String containing text of macro to get args of
+
+ Returns:
+ List of dictionaries where each entry is:
+ 'name': Test name,
+ 'th_string': Expected string from TH,
+ 'dut_string': Expected string from DUT,
+ """
+ tests = []
+ with open(filepath, "r") as f:
+ lines = f.readlines()
+ joined = "".join(lines).replace("\\\n", "").splitlines()
+ for l in joined:
+ if not l.strip().startswith(macro):
+ continue
+ d = {}
+ l = l.strip()[len(macro) :]
+ l = l.strip("()").split(",")
+ d["name"] = l[0].strip()
+ d["th_rc"] = self.get_return_code_value(l[1].strip().strip('"'))
+ d["th_string"] = l[2].strip().strip('"')
+ d["dut_rc"] = self.get_return_code_value(l[3].strip().strip('"'))
+ d["dut_string"] = l[4].strip().strip('"')
+ tests.append(d)
+ return tests
+
+ def get_return_codes(self, filepath):
+ """Read return code names from the return code definition file."""
+ self.return_codes = {}
+ val = 0
+ with open(filepath, "r") as f:
+ for line in f:
+ line = line.strip()
+ if not line.startswith(CTS_RC_PREFIX):
+ continue
+ line = line.split(",")[0]
+ if "=" in line:
+ tokens = line.split("=")
+ line = tokens[0].strip()
+ val = int(tokens[1].strip())
+ self.return_codes[line] = val
+ val += 1
+
+ def parse_output(self, output):
+ """Parse console output from DUT or TH.
+
+ Args:
+            output: String containing console output
+
+ Returns:
+ List of dictionaries where each key and value are:
+ name = 'ects_test_x',
+ started = True/False,
+ ended = True/False,
+ rc = CTS_RC_*,
+ output = All text between 'ects_test_x start' and 'ects_test_x end'
+ """
+ results = []
+ i = 0
+ for test in self.testlist:
+ results.append({})
+ results[i]["name"] = test["name"]
+ results[i]["started"] = False
+ results[i]["rc"] = CTS_RC_DID_NOT_START
+ results[i]["string"] = False
+ results[i]["output"] = []
+ i += 1
- return th_results, dut_results
+ i = 0
+ for ln in [ln.strip() for ln in output.split("\n")]:
+ if i + 1 > len(results):
+ break
+ tokens = ln.split()
+ if len(tokens) >= 2:
+ if tokens[0].strip() == results[i]["name"]:
+ if tokens[1].strip() == "start":
+ # start line found
+ if results[i]["started"]: # Already started
+ results[i]["rc"] = CTS_RC_DUPLICATE_RUN
+ else:
+ results[i]["rc"] = CTS_RC_DID_NOT_END
+ results[i]["started"] = True
+ continue
+ elif results[i]["started"] and tokens[1].strip() == "end":
+ # end line found
+ results[i]["rc"] = CTS_RC_INVALID_RETURN_CODE
+ if len(tokens) == 3:
+ try:
+ results[i]["rc"] = int(tokens[2].strip())
+ except ValueError:
+ pass
+ # Since index is incremented when 'end' is encountered, we don't
+ # need to check duplicate 'end'.
+ i += 1
+ continue
+ if results[i]["started"]:
+ results[i]["output"].append(ln)
+
+ return results
+
+ def get_return_code_name(self, code, strip_prefix=False):
+ name = ""
+ for k, v in self.return_codes.items():
+ if v == code:
+ if strip_prefix:
+ name = k[len(CTS_RC_PREFIX) :]
+ else:
+ name = k
+ return name
+
+ def get_return_code_value(self, name):
+ if name:
+ return self.return_codes[name]
+ return 0
+
+ def evaluate_run(self, dut_output, th_output):
+ """Parse outputs to derive test results.
+
+ Args:
+ dut_output: String output of DUT
+ th_output: String output of TH
+
+ Returns:
+ th_results: list of test results for TH
+ dut_results: list of test results for DUT
+ """
+ th_results = self.parse_output(th_output)
+ dut_results = self.parse_output(dut_output)
+
+ # Search for expected string in each output
+ for i, v in enumerate(self.testlist):
+ if v["th_string"] in th_results[i]["output"] or not v["th_string"]:
+ th_results[i]["string"] = True
+ if v["dut_string"] in dut_results[i]["output"] or not v["dut_string"]:
+ dut_results[i]["string"] = True
+
+ return th_results, dut_results
+
+ def print_result(self, th_results, dut_results):
+ """Print results to the screen.
+
+ Args:
+ th_results: list of test results for TH
+ dut_results: list of test results for DUT
+ """
+ len_test_name = max(len(s["name"]) for s in self.testlist)
+ len_code_name = max(
+ len(self.get_return_code_name(v, True)) for v in self.return_codes.values()
+ )
+
+ head = "{:^" + str(len_test_name) + "} "
+ head += "{:^" + str(len_code_name) + "} "
+ head += "{:^" + str(len_code_name) + "}"
+ head += "{:^" + str(len(" TH_STR")) + "}"
+ head += "{:^" + str(len(" DUT_STR")) + "}"
+ head += "{:^" + str(len(" RESULT")) + "}\n"
+ fmt = "{:" + str(len_test_name) + "} "
+ fmt += "{:>" + str(len_code_name) + "} "
+ fmt += "{:>" + str(len_code_name) + "}"
+ fmt += "{:>" + str(len(" TH_STR")) + "}"
+ fmt += "{:>" + str(len(" DUT_STR")) + "}"
+ fmt += "{:>" + str(len(" RESULT")) + "}\n"
+
+ self.formatted_results = head.format(
+ "TEST NAME", "TH_RC", "DUT_RC", " TH_STR", " DUT_STR", " RESULT"
+ )
+ for i, d in enumerate(dut_results):
+ th_cn = self.get_return_code_name(th_results[i]["rc"], True)
+ dut_cn = self.get_return_code_name(dut_results[i]["rc"], True)
+ th_res = self.evaluate_result(
+ th_results[i], self.testlist[i]["th_rc"], self.testlist[i]["th_string"]
+ )
+ dut_res = self.evaluate_result(
+ dut_results[i],
+ self.testlist[i]["dut_rc"],
+ self.testlist[i]["dut_string"],
+ )
+ self.formatted_results += fmt.format(
+ d["name"],
+ th_cn,
+ dut_cn,
+ "YES" if th_results[i]["string"] else "NO",
+ "YES" if dut_results[i]["string"] else "NO",
+ "PASS" if th_res and dut_res else "FAIL",
+ )
+
+ def evaluate_result(self, result, expected_rc, expected_string):
+ if result["rc"] != expected_rc:
+ return False
+ if expected_string and expected_string not in result["output"]:
+ return False
+ return True
+
+ def run(self):
+ """Resets boards, records test results in results dir."""
+ print("Reading serials...")
+ self.identify_boards()
+ print("Opening DUT tty...")
+ self.dut.setup_tty()
+ print("Opening TH tty...")
+ self.th.setup_tty()
+
+ # Boards might be still writing to tty. Wait a few seconds before flashing.
+ time.sleep(3)
+
+ # clear buffers
+ print("Clearing DUT tty...")
+ self.dut.read_tty()
+ print("Clearing TH tty...")
+ self.th.read_tty()
+
+ # Resets the boards and allows them to run tests
+        # Due to the current (7/27/16) version of the sync function,
+        # both boards must be reset and halted, with the TH
+        # resuming first, in order for the test suite to run in sync
+ print("Halting TH...")
+ if not self.th.reset_halt():
+ raise RuntimeError("Failed to halt TH")
+ print("Halting DUT...")
+ if not self.dut.reset_halt():
+ raise RuntimeError("Failed to halt DUT")
+ print("Resuming TH...")
+ if not self.th.resume():
+ raise RuntimeError("Failed to resume TH")
+ print("Resuming DUT...")
+ if not self.dut.resume():
+ raise RuntimeError("Failed to resume DUT")
+
+ time.sleep(MAX_SUITE_TIME_SEC)
+
+ print("Reading DUT tty...")
+ dut_output, _ = self.dut.read_tty()
+ self.dut.close_tty()
+ print("Reading TH tty...")
+ th_output, _ = self.th.read_tty()
+ self.th.close_tty()
+
+ print("Halting TH...")
+ if not self.th.reset_halt():
+ raise RuntimeError("Failed to halt TH")
+ print("Halting DUT...")
+ if not self.dut.reset_halt():
+ raise RuntimeError("Failed to halt DUT")
+
+ if not dut_output or not th_output:
+ raise ValueError(
+ "Output missing from boards. If you have a process "
+ "reading ttyACMx, please kill that process and try "
+ "again."
+ )
+
+ print("Pursing results...")
+ th_results, dut_results = self.evaluate_run(dut_output, th_output)
+
+ # Print out results
+ self.print_result(th_results, dut_results)
+
+ # Write results
+ dest = os.path.join(self.results_dir, "results.log")
+ with open(dest, "w") as fl:
+ fl.write(self.formatted_results)
+
+ # Write UART outputs
+ dest = os.path.join(self.results_dir, "uart_th.log")
+ with open(dest, "w") as fl:
+ fl.write(th_output)
+ dest = os.path.join(self.results_dir, "uart_dut.log")
+ with open(dest, "w") as fl:
+ fl.write(dut_output)
+
+ print(self.formatted_results)
+
+ # TODO(chromium:735652): Should set exit code for the shell
- def print_result(self, th_results, dut_results):
- """Print results to the screen.
- Args:
- th_results: list of test results for TH
- dut_results: list of test results for DUT
- """
- len_test_name = max(len(s['name']) for s in self.testlist)
- len_code_name = max(len(self.get_return_code_name(v, True))
- for v in self.return_codes.values())
-
- head = '{:^' + str(len_test_name) + '} '
- head += '{:^' + str(len_code_name) + '} '
- head += '{:^' + str(len_code_name) + '}'
- head += '{:^' + str(len(' TH_STR')) + '}'
- head += '{:^' + str(len(' DUT_STR')) + '}'
- head += '{:^' + str(len(' RESULT')) + '}\n'
- fmt = '{:' + str(len_test_name) + '} '
- fmt += '{:>' + str(len_code_name) + '} '
- fmt += '{:>' + str(len_code_name) + '}'
- fmt += '{:>' + str(len(' TH_STR')) + '}'
- fmt += '{:>' + str(len(' DUT_STR')) + '}'
- fmt += '{:>' + str(len(' RESULT')) + '}\n'
-
- self.formatted_results = head.format(
- 'TEST NAME', 'TH_RC', 'DUT_RC',
- ' TH_STR', ' DUT_STR', ' RESULT')
- for i, d in enumerate(dut_results):
- th_cn = self.get_return_code_name(th_results[i]['rc'], True)
- dut_cn = self.get_return_code_name(dut_results[i]['rc'], True)
- th_res = self.evaluate_result(th_results[i],
- self.testlist[i]['th_rc'],
- self.testlist[i]['th_string'])
- dut_res = self.evaluate_result(dut_results[i],
- self.testlist[i]['dut_rc'],
- self.testlist[i]['dut_string'])
- self.formatted_results += fmt.format(
- d['name'], th_cn, dut_cn,
- 'YES' if th_results[i]['string'] else 'NO',
- 'YES' if dut_results[i]['string'] else 'NO',
- 'PASS' if th_res and dut_res else 'FAIL')
-
- def evaluate_result(self, result, expected_rc, expected_string):
- if result['rc'] != expected_rc:
- return False
- if expected_string and expected_string not in result['output']:
- return False
- return True
-
- def run(self):
- """Resets boards, records test results in results dir."""
- print('Reading serials...')
- self.identify_boards()
- print('Opening DUT tty...')
- self.dut.setup_tty()
- print('Opening TH tty...')
- self.th.setup_tty()
-
- # Boards might be still writing to tty. Wait a few seconds before flashing.
- time.sleep(3)
-
- # clear buffers
- print('Clearing DUT tty...')
- self.dut.read_tty()
- print('Clearing TH tty...')
- self.th.read_tty()
-
- # Resets the boards and allows them to run tests
- # Due to current (7/27/16) version of sync function,
- # both boards must be rest and halted, with the th
- # resuming first, in order for the test suite to run in sync
- print('Halting TH...')
- if not self.th.reset_halt():
- raise RuntimeError('Failed to halt TH')
- print('Halting DUT...')
- if not self.dut.reset_halt():
- raise RuntimeError('Failed to halt DUT')
- print('Resuming TH...')
- if not self.th.resume():
- raise RuntimeError('Failed to resume TH')
- print('Resuming DUT...')
- if not self.dut.resume():
- raise RuntimeError('Failed to resume DUT')
-
- time.sleep(MAX_SUITE_TIME_SEC)
-
- print('Reading DUT tty...')
- dut_output, _ = self.dut.read_tty()
- self.dut.close_tty()
- print('Reading TH tty...')
- th_output, _ = self.th.read_tty()
- self.th.close_tty()
-
- print('Halting TH...')
- if not self.th.reset_halt():
- raise RuntimeError('Failed to halt TH')
- print('Halting DUT...')
- if not self.dut.reset_halt():
- raise RuntimeError('Failed to halt DUT')
-
- if not dut_output or not th_output:
- raise ValueError('Output missing from boards. If you have a process '
- 'reading ttyACMx, please kill that process and try '
- 'again.')
-
- print('Pursing results...')
- th_results, dut_results = self.evaluate_run(dut_output, th_output)
-
- # Print out results
- self.print_result(th_results, dut_results)
-
- # Write results
- dest = os.path.join(self.results_dir, 'results.log')
- with open(dest, 'w') as fl:
- fl.write(self.formatted_results)
-
- # Write UART outputs
- dest = os.path.join(self.results_dir, 'uart_th.log')
- with open(dest, 'w') as fl:
- fl.write(th_output)
- dest = os.path.join(self.results_dir, 'uart_dut.log')
- with open(dest, 'w') as fl:
- fl.write(dut_output)
-
- print(self.formatted_results)
-
- # TODO(chromium:735652): Should set exit code for the shell
+def main():
+ ec_dir = os.path.realpath(
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
+ )
+ os.chdir(ec_dir)
+
+ dut = DEFAULT_DUT
+ module = "meta"
+
+ parser = argparse.ArgumentParser(description="Used to build/flash boards")
+ parser.add_argument("-d", "--dut", help="Specify DUT you want to build/flash")
+ parser.add_argument("-m", "--module", help="Specify module you want to build/flash")
+ parser.add_argument(
+ "-s",
+ "--setup",
+ action="store_true",
+ help="Connect only the TH to save its serial",
+ )
+ parser.add_argument(
+ "-b", "--build", action="store_true", help="Build test suite (no flashing)"
+ )
+ parser.add_argument(
+ "-f",
+ "--flash",
+ action="store_true",
+ help="Flash boards with most recent images",
+ )
+ parser.add_argument(
+ "-r", "--run", action="store_true", help="Run tests without flashing"
+ )
+
+ args = parser.parse_args()
+
+ if args.module:
+ module = args.module
+
+ if args.dut:
+ dut = args.dut
+
+ cts = Cts(ec_dir, DEFAULT_TH, dut=dut, module=module)
+
+ if args.setup:
+ cts.setup()
+ elif args.build:
+ cts.build()
+ elif args.flash:
+ cts.flash_boards()
+ elif args.run:
+ cts.run()
+ else:
+ cts.build()
+ cts.flash_boards()
+ cts.run()
-def main():
- ec_dir = os.path.realpath(os.path.join(
- os.path.dirname(os.path.abspath(__file__)), '..'))
- os.chdir(ec_dir)
-
- dut = DEFAULT_DUT
- module = 'meta'
-
- parser = argparse.ArgumentParser(description='Used to build/flash boards')
- parser.add_argument('-d',
- '--dut',
- help='Specify DUT you want to build/flash')
- parser.add_argument('-m',
- '--module',
- help='Specify module you want to build/flash')
- parser.add_argument('-s',
- '--setup',
- action='store_true',
- help='Connect only the TH to save its serial')
- parser.add_argument('-b',
- '--build',
- action='store_true',
- help='Build test suite (no flashing)')
- parser.add_argument('-f',
- '--flash',
- action='store_true',
- help='Flash boards with most recent images')
- parser.add_argument('-r',
- '--run',
- action='store_true',
- help='Run tests without flashing')
-
- args = parser.parse_args()
-
- if args.module:
- module = args.module
-
- if args.dut:
- dut = args.dut
-
- cts = Cts(ec_dir, DEFAULT_TH, dut=dut, module=module)
-
- if args.setup:
- cts.setup()
- elif args.build:
- cts.build()
- elif args.flash:
- cts.flash_boards()
- elif args.run:
- cts.run()
- else:
- cts.build()
- cts.flash_boards()
- cts.run()
-
-if __name__ == '__main__':
- main()
+if __name__ == "__main__":
+ main()
diff --git a/extra/cr50_rma_open/cr50_rma_open.py b/extra/cr50_rma_open/cr50_rma_open.py
index 42ddbbac2d..b77b8f3dbb 100755
--- a/extra/cr50_rma_open/cr50_rma_open.py
+++ b/extra/cr50_rma_open/cr50_rma_open.py
@@ -57,16 +57,17 @@ CCD_IS_UNRESTRICTED = 1 << 0
WP_IS_DISABLED = 1 << 1
TESTLAB_IS_ENABLED = 1 << 2
RMA_OPENED = CCD_IS_UNRESTRICTED | WP_IS_DISABLED
-URL = ('https://www.google.com/chromeos/partner/console/cr50reset?'
- 'challenge=%s&hwid=%s')
-RMA_SUPPORT_PROD = '0.3.3'
-RMA_SUPPORT_PREPVT = '0.4.5'
-DEV_MODE_OPEN_PROD = '0.3.9'
-DEV_MODE_OPEN_PREPVT = '0.4.7'
-TESTLAB_PROD = '0.3.10'
-CR50_USB = '18d1:5014'
-CR50_LSUSB_CMD = ['lsusb', '-vd', CR50_USB]
-ERASED_BID = 'ffffffff'
+URL = (
+ "https://www.google.com/chromeos/partner/console/cr50reset?" "challenge=%s&hwid=%s"
+)
+RMA_SUPPORT_PROD = "0.3.3"
+RMA_SUPPORT_PREPVT = "0.4.5"
+DEV_MODE_OPEN_PROD = "0.3.9"
+DEV_MODE_OPEN_PREPVT = "0.4.7"
+TESTLAB_PROD = "0.3.10"
+CR50_USB = "18d1:5014"
+CR50_LSUSB_CMD = ["lsusb", "-vd", CR50_USB]
+ERASED_BID = "ffffffff"
DEBUG_MISSING_USB = """
Unable to find Cr50 Device 18d1:5014
@@ -128,13 +129,14 @@ DEBUG_DUT_CONTROL_OSERROR = """
Run from chroot if you are trying to use a /dev/pts ccd servo console
"""
+
class RMAOpen(object):
"""Used to find the cr50 console and run RMA open"""
- ENABLE_TESTLAB_CMD = 'ccd testlab enabled\n'
+ ENABLE_TESTLAB_CMD = "ccd testlab enabled\n"
def __init__(self, device=None, usb_serial=None, servo_port=None, ip=None):
- self.servo_port = servo_port if servo_port else '9999'
+ self.servo_port = servo_port if servo_port else "9999"
self.ip = ip
if device:
self.set_cr50_device(device)
@@ -142,18 +144,18 @@ class RMAOpen(object):
self.find_cr50_servo_uart()
else:
self.find_cr50_device(usb_serial)
- logging.info('DEVICE: %s', self.device)
+ logging.info("DEVICE: %s", self.device)
self.check_version()
self.print_platform_info()
- logging.info('Cr50 setup ok')
+ logging.info("Cr50 setup ok")
self.update_ccd_state()
self.using_ccd = self.device_is_running_with_servo_ccd()
def _dut_control(self, control):
"""Run dut-control and return the response"""
try:
- cmd = ['dut-control', '-p', self.servo_port, control]
- return subprocess.check_output(cmd, encoding='utf-8').strip()
+ cmd = ["dut-control", "-p", self.servo_port, control]
+ return subprocess.check_output(cmd, encoding="utf-8").strip()
except OSError:
logging.warning(DEBUG_DUT_CONTROL_OSERROR)
raise
@@ -163,8 +165,8 @@ class RMAOpen(object):
Find the console and configure it, so it can be used with this script.
"""
- self._dut_control('cr50_uart_timestamp:off')
- self.device = self._dut_control('cr50_uart_pty').split(':')[-1]
+ self._dut_control("cr50_uart_timestamp:off")
+ self.device = self._dut_control("cr50_uart_pty").split(":")[-1]
def set_cr50_device(self, device):
"""Save the device used for the console"""
@@ -183,38 +185,38 @@ class RMAOpen(object):
try:
ser = serial.Serial(self.device, timeout=1)
except OSError:
- logging.warning('Permission denied %s', self.device)
- logging.warning('Try running cr50_rma_open with sudo')
+ logging.warning("Permission denied %s", self.device)
+ logging.warning("Try running cr50_rma_open with sudo")
raise
- write_cmd = cmd + '\n\n'
- ser.write(write_cmd.encode('utf-8'))
+ write_cmd = cmd + "\n\n"
+ ser.write(write_cmd.encode("utf-8"))
if nbytes:
output = ser.read(nbytes)
else:
output = ser.readall()
ser.close()
- output = output.decode('utf-8').strip() if output else ''
+ output = output.decode("utf-8").strip() if output else ""
# Return only the command output
- split_cmd = cmd + '\r'
+ split_cmd = cmd + "\r"
if cmd and split_cmd in output:
- return ''.join(output.rpartition(split_cmd)[1::]).split('>')[0]
+ return "".join(output.rpartition(split_cmd)[1::]).split(">")[0]
return output
def device_is_running_with_servo_ccd(self):
"""Return True if the device is a servod ccd console"""
# servod uses /dev/pts consoles. Non-servod uses /dev/ttyUSBX
- if '/dev/pts' not in self.device:
+ if "/dev/pts" not in self.device:
return False
# If cr50 doesn't show rdd is connected, cr50 the device must not be
# a ccd device
- if 'Rdd: connected' not in self.send_cmd_get_output('ccdstate'):
+ if "Rdd: connected" not in self.send_cmd_get_output("ccdstate"):
return False
# Check if the servod is running with ccd. This requires the script
# is run in the chroot, so run it last.
- if 'ccd_cr50' not in self._dut_control('servo_type'):
+ if "ccd_cr50" not in self._dut_control("servo_type"):
return False
- logging.info('running through servod ccd')
+ logging.info("running through servod ccd")
return True
def get_rma_challenge(self):
@@ -239,14 +241,14 @@ class RMAOpen(object):
Returns:
The RMA challenge with all whitespace removed.
"""
- output = self.send_cmd_get_output('rma_auth').strip()
- logging.info('rma_auth output:\n%s', output)
+ output = self.send_cmd_get_output("rma_auth").strip()
+ logging.info("rma_auth output:\n%s", output)
# Extract the challenge from the console output
- if 'generated challenge:' in output:
- return output.split('generated challenge:')[-1].strip()
- challenge = ''.join(re.findall(r' \S{5}' * 4, output))
+ if "generated challenge:" in output:
+ return output.split("generated challenge:")[-1].strip()
+ challenge = "".join(re.findall(r" \S{5}" * 4, output))
# Remove all whitespace
- return re.sub(r'\s', '', challenge)
+ return re.sub(r"\s", "", challenge)
def generate_challenge_url(self, hwid):
"""Get the rma_auth challenge
@@ -257,12 +259,14 @@ class RMAOpen(object):
challenge = self.get_rma_challenge()
self.print_platform_info()
- logging.info('CHALLENGE: %s', challenge)
- logging.info('HWID: %s', hwid)
+ logging.info("CHALLENGE: %s", challenge)
+ logging.info("HWID: %s", hwid)
url = URL % (challenge, hwid)
- logging.info('GOTO:\n %s', url)
- logging.info('If the server fails to debug the challenge make sure the '
- 'RLZ is allowlisted')
+ logging.info("GOTO:\n %s", url)
+ logging.info(
+ "If the server fails to debug the challenge make sure the "
+ "RLZ is allowlisted"
+ )
def try_authcode(self, authcode):
"""Try opening cr50 with the authcode
@@ -272,48 +276,48 @@ class RMAOpen(object):
"""
# rma_auth may cause the system to reboot. Don't wait to read all that
# output. Read the first 300 bytes and call it a day.
- output = self.send_cmd_get_output('rma_auth ' + authcode, nbytes=300)
- logging.info('CR50 RESPONSE: %s', output)
- logging.info('waiting for cr50 reboot')
+ output = self.send_cmd_get_output("rma_auth " + authcode, nbytes=300)
+ logging.info("CR50 RESPONSE: %s", output)
+ logging.info("waiting for cr50 reboot")
# Cr50 may be rebooting. Wait a bit
time.sleep(5)
if self.using_ccd:
# After reboot, reset the ccd endpoints
- self._dut_control('power_state:ccd_reset')
+ self._dut_control("power_state:ccd_reset")
# Update the ccd state after the authcode attempt
self.update_ccd_state()
- authcode_match = 'process_response: success!' in output
+ authcode_match = "process_response: success!" in output
if not self.check(CCD_IS_UNRESTRICTED):
if not authcode_match:
logging.warning(DEBUG_AUTHCODE_MISMATCH)
- message = 'Authcode mismatch. Check args and url'
+ message = "Authcode mismatch. Check args and url"
else:
- message = 'Could not set all capability privileges to Always'
+ message = "Could not set all capability privileges to Always"
raise ValueError(message)
def wp_is_force_disabled(self):
"""Returns True if write protect is forced disabled"""
- output = self.send_cmd_get_output('wp')
- wp_state = output.split('Flash WP:', 1)[-1].split('\n', 1)[0].strip()
- logging.info('wp: %s', wp_state)
- return wp_state == 'forced disabled'
+ output = self.send_cmd_get_output("wp")
+ wp_state = output.split("Flash WP:", 1)[-1].split("\n", 1)[0].strip()
+ logging.info("wp: %s", wp_state)
+ return wp_state == "forced disabled"
def testlab_is_enabled(self):
"""Returns True if testlab mode is enabled"""
- output = self.send_cmd_get_output('ccd testlab')
- testlab_state = output.split('mode')[-1].strip().lower()
- logging.info('testlab: %s', testlab_state)
- return testlab_state == 'enabled'
+ output = self.send_cmd_get_output("ccd testlab")
+ testlab_state = output.split("mode")[-1].strip().lower()
+ logging.info("testlab: %s", testlab_state)
+ return testlab_state == "enabled"
def ccd_is_restricted(self):
"""Returns True if any of the capabilities are still restricted"""
- output = self.send_cmd_get_output('ccd')
- if 'Capabilities' not in output:
- raise ValueError('Could not get ccd output')
- logging.debug('CURRENT CCD SETTINGS:\n%s', output)
- restricted = 'IfOpened' in output or 'IfUnlocked' in output
- logging.info('ccd: %srestricted', '' if restricted else 'Un')
+ output = self.send_cmd_get_output("ccd")
+ if "Capabilities" not in output:
+ raise ValueError("Could not get ccd output")
+ logging.debug("CURRENT CCD SETTINGS:\n%s", output)
+ restricted = "IfOpened" in output or "IfUnlocked" in output
+ logging.info("ccd: %srestricted", "" if restricted else "Un")
return restricted
def update_ccd_state(self):
@@ -339,9 +343,10 @@ class RMAOpen(object):
def _capabilities_allow_open_from_console(self):
"""Return True if ccd open is Always allowed from usb"""
- output = self.send_cmd_get_output('ccd')
- return (re.search('OpenNoDevMode.*Always', output) and
- re.search('OpenFromUSB.*Always', output))
+ output = self.send_cmd_get_output("ccd")
+ return re.search("OpenNoDevMode.*Always", output) and re.search(
+ "OpenFromUSB.*Always", output
+ )
def _requires_dev_mode_open(self):
"""Return True if the image requires dev mode to open"""
@@ -354,78 +359,81 @@ class RMAOpen(object):
def _run_on_dut(self, command):
"""Run the command on the DUT."""
- return subprocess.check_output(['ssh', self.ip, command],
- encoding='utf-8')
+ return subprocess.check_output(["ssh", self.ip, command], encoding="utf-8")
def _open_in_dev_mode(self):
"""Open Cr50 when it's in dev mode"""
- output = self.send_cmd_get_output('ccd')
+ output = self.send_cmd_get_output("ccd")
# If the device is already open, nothing needs to be done.
- if 'State: Open' not in output:
+ if "State: Open" not in output:
# Verify the device is in devmode before trying to run open.
- if 'dev_mode' not in output:
- logging.warning('Enter dev mode to open ccd or update to %s',
- TESTLAB_PROD)
- raise ValueError('DUT not in dev mode')
+ if "dev_mode" not in output:
+ logging.warning(
+ "Enter dev mode to open ccd or update to %s", TESTLAB_PROD
+ )
+ raise ValueError("DUT not in dev mode")
if not self.ip:
- logging.warning("If your DUT doesn't have ssh support, run "
- "'gsctool -a -o' from the AP")
- raise ValueError('Cannot run ccd open without dut ip')
- self._run_on_dut('gsctool -a -o')
+ logging.warning(
+ "If your DUT doesn't have ssh support, run "
+ "'gsctool -a -o' from the AP"
+ )
+ raise ValueError("Cannot run ccd open without dut ip")
+ self._run_on_dut("gsctool -a -o")
# Wait >1 second for cr50 to update ccd state
time.sleep(3)
- output = self.send_cmd_get_output('ccd')
- if 'State: Open' not in output:
- raise ValueError('Could not open cr50')
- logging.info('ccd is open')
+ output = self.send_cmd_get_output("ccd")
+ if "State: Open" not in output:
+ raise ValueError("Could not open cr50")
+ logging.info("ccd is open")
def enable_testlab(self):
"""Disable write protect"""
if not self._has_testlab_support():
- logging.warning('Testlab mode is not supported in prod iamges')
+ logging.warning("Testlab mode is not supported in prod iamges")
return
# Some cr50 images need to be in dev mode before they can be opened.
if self._requires_dev_mode_open():
self._open_in_dev_mode()
else:
- self.send_cmd_get_output('ccd open')
- logging.info('Enabling testlab mode reqires pressing the power button.')
- logging.info('Once the process starts keep tapping the power button '
- 'for 10 seconds.')
+ self.send_cmd_get_output("ccd open")
+ logging.info("Enabling testlab mode reqires pressing the power button.")
+ logging.info(
+ "Once the process starts keep tapping the power button " "for 10 seconds."
+ )
input("Press Enter when you're ready to start...")
end_time = time.time() + 15
ser = serial.Serial(self.device, timeout=1)
- printed_lines = ''
- output = ''
+ printed_lines = ""
+ output = ""
# start ccd testlab enable
- ser.write(self.ENABLE_TESTLAB_CMD.encode('utf-8'))
- logging.info('start pressing the power button\n\n')
+ ser.write(self.ENABLE_TESTLAB_CMD.encode("utf-8"))
+ logging.info("start pressing the power button\n\n")
# Print all of the cr50 output as we get it, so the user will have more
# information about pressing the power button. Tapping the power button
# a couple of times should do it, but this will give us more confidence
# the process is still running/worked.
try:
while time.time() < end_time:
- output += ser.read(100).decode('utf-8')
- full_lines = output.rsplit('\n', 1)[0]
+ output += ser.read(100).decode("utf-8")
+ full_lines = output.rsplit("\n", 1)[0]
new_lines = full_lines
if printed_lines:
new_lines = full_lines.split(printed_lines, 1)[-1].strip()
- logging.info('\n%s', new_lines)
+ logging.info("\n%s", new_lines)
printed_lines = full_lines
# Make sure the process hasn't ended. If it has, print the last
# of the output and exit.
new_lines = output.split(printed_lines, 1)[-1]
- if 'CCD test lab mode enabled' in output:
+ if "CCD test lab mode enabled" in output:
# print the last of the ou
logging.info(new_lines)
break
- elif 'Physical presence check timeout' in output:
+ elif "Physical presence check timeout" in output:
logging.info(new_lines)
- logging.warning('Did not detect power button press in time')
- raise ValueError('Could not enable testlab mode try again')
+ logging.warning("Did not detect power button press in time")
+ raise ValueError("Could not enable testlab mode try again")
finally:
ser.close()
# Wait for the ccd hook to update things
@@ -433,44 +441,48 @@ class RMAOpen(object):
# Update the state after attempting to disable write protect
self.update_ccd_state()
if not self.check(TESTLAB_IS_ENABLED):
- raise ValueError('Could not enable testlab mode try again')
+ raise ValueError("Could not enable testlab mode try again")
def wp_disable(self):
"""Disable write protect"""
- logging.info('Disabling write protect')
- self.send_cmd_get_output('wp disable')
+ logging.info("Disabling write protect")
+ self.send_cmd_get_output("wp disable")
# Update the state after attempting to disable write protect
self.update_ccd_state()
if not self.check(WP_IS_DISABLED):
- raise ValueError('Could not disable write protect')
+ raise ValueError("Could not disable write protect")
def check_version(self):
"""Make sure cr50 is running a version that supports RMA Open"""
- output = self.send_cmd_get_output('version')
+ output = self.send_cmd_get_output("version")
if not output.strip():
logging.warning(DEBUG_DEVICE, self.device)
- raise ValueError('Could not communicate with %s' % self.device)
+ raise ValueError("Could not communicate with %s" % self.device)
- version = re.search(r'RW.*\* ([\d\.]+)/', output).group(1)
- logging.info('Running Cr50 Version: %s', version)
- self.running_ver_fields = [int(field) for field in version.split('.')]
+ version = re.search(r"RW.*\* ([\d\.]+)/", output).group(1)
+ logging.info("Running Cr50 Version: %s", version)
+ self.running_ver_fields = [int(field) for field in version.split(".")]
# prePVT images have even major versions. Prod have odd
self.is_prepvt = self.running_ver_fields[1] % 2 == 0
rma_support = RMA_SUPPORT_PREPVT if self.is_prepvt else RMA_SUPPORT_PROD
- logging.info('%s RMA support added in: %s',
- 'prePVT' if self.is_prepvt else 'prod', rma_support)
+ logging.info(
+ "%s RMA support added in: %s",
+ "prePVT" if self.is_prepvt else "prod",
+ rma_support,
+ )
if not self.is_prepvt and self._running_version_is_older(TESTLAB_PROD):
- raise ValueError('Update cr50. No testlab support in old prod '
- 'images.')
+ raise ValueError("Update cr50. No testlab support in old prod " "images.")
if self._running_version_is_older(rma_support):
- raise ValueError('%s does not have RMA support. Update to at '
- 'least %s' % (version, rma_support))
+ raise ValueError(
+ "%s does not have RMA support. Update to at "
+ "least %s" % (version, rma_support)
+ )
def _running_version_is_older(self, target_ver):
"""Returns True if running version is older than target_ver."""
- target_ver_fields = [int(field) for field in target_ver.split('.')]
+ target_ver_fields = [int(field) for field in target_ver.split(".")]
for i, field in enumerate(self.running_ver_fields):
if field > int(target_ver_fields[i]):
return False
@@ -486,11 +498,11 @@ class RMAOpen(object):
is no output or sysinfo doesn't contain the devid.
"""
self.set_cr50_device(device)
- sysinfo = self.send_cmd_get_output('sysinfo')
+ sysinfo = self.send_cmd_get_output("sysinfo")
# Make sure there is some output, and it shows it's from Cr50
- if not sysinfo or 'cr50' not in sysinfo:
+ if not sysinfo or "cr50" not in sysinfo:
return False
- logging.debug('Sysinfo output: %s', sysinfo)
+ logging.debug("Sysinfo output: %s", sysinfo)
# The cr50 device id should be in the sysinfo output, if we found
# the right console. Make sure it is
return devid in sysinfo
@@ -508,104 +520,137 @@ class RMAOpen(object):
ValueError if the console can't be found with the given serialname
"""
usb_serial = self.find_cr50_usb(usb_serial)
- logging.info('SERIALNAME: %s', usb_serial)
- devid = '0x' + ' 0x'.join(usb_serial.lower().split('-'))
- logging.info('DEVID: %s', devid)
+ logging.info("SERIALNAME: %s", usb_serial)
+ devid = "0x" + " 0x".join(usb_serial.lower().split("-"))
+ logging.info("DEVID: %s", devid)
# Get all the usb devices
- devices = glob.glob('/dev/ttyUSB*')
+ devices = glob.glob("/dev/ttyUSB*")
# Typically Cr50 has the lowest number. Sort the devices, so we're more
# likely to try the cr50 console first.
devices.sort()
# Find the one that is the cr50 console
for device in devices:
- logging.info('testing %s', device)
+ logging.info("testing %s", device)
if self.device_matches_devid(devid, device):
- logging.info('found device: %s', device)
+ logging.info("found device: %s", device)
return
logging.warning(DEBUG_CONNECTION)
- raise ValueError('Found USB device, but could not communicate with '
- 'cr50 console')
+ raise ValueError(
+ "Found USB device, but could not communicate with " "cr50 console"
+ )
def print_platform_info(self):
"""Print the cr50 BID RLZ code"""
- bid_output = self.send_cmd_get_output('bid')
- bid = re.search(r'Board ID: (\S+?)[:,]', bid_output).group(1)
+ bid_output = self.send_cmd_get_output("bid")
+ bid = re.search(r"Board ID: (\S+?)[:,]", bid_output).group(1)
if bid == ERASED_BID:
logging.warning(DEBUG_ERASED_BOARD_ID)
- raise ValueError('Cannot run RMA Open when board id is erased')
+ raise ValueError("Cannot run RMA Open when board id is erased")
bid = int(bid, 16)
- chrs = [chr((bid >> (8 * i)) & 0xff) for i in range(4)]
- logging.info('RLZ: %s', ''.join(chrs[::-1]))
+ chrs = [chr((bid >> (8 * i)) & 0xFF) for i in range(4)]
+ logging.info("RLZ: %s", "".join(chrs[::-1]))
@staticmethod
def find_cr50_usb(usb_serial):
"""Make sure the Cr50 USB device exists"""
try:
- output = subprocess.check_output(CR50_LSUSB_CMD, encoding='utf-8')
+ output = subprocess.check_output(CR50_LSUSB_CMD, encoding="utf-8")
except:
logging.warning(DEBUG_MISSING_USB)
- raise ValueError('Could not find Cr50 USB device')
- serialnames = re.findall(r'iSerial +\d+ (\S+)\s', output)
+ raise ValueError("Could not find Cr50 USB device")
+ serialnames = re.findall(r"iSerial +\d+ (\S+)\s", output)
if usb_serial:
if usb_serial not in serialnames:
logging.warning(DEBUG_SERIALNAME)
raise ValueError('Could not find usb device "%s"' % usb_serial)
return usb_serial
if len(serialnames) > 1:
- logging.info('Found Cr50 device serialnames %s',
- ', '.join(serialnames))
+ logging.info("Found Cr50 device serialnames %s", ", ".join(serialnames))
logging.warning(DEBUG_TOO_MANY_USB_DEVICES)
- raise ValueError('Too many cr50 usb devices')
+ raise ValueError("Too many cr50 usb devices")
return serialnames[0]
def print_dut_state(self):
"""Print CCD RMA and testlab mode state."""
if not self.check(CCD_IS_UNRESTRICTED):
- logging.info('CCD is still restricted.')
- logging.info('Run cr50_rma_open.py -g -i $HWID to generate a url')
- logging.info('Run cr50_rma_open.py -a $AUTHCODE to open cr50 with '
- 'an authcode')
+ logging.info("CCD is still restricted.")
+ logging.info("Run cr50_rma_open.py -g -i $HWID to generate a url")
+ logging.info(
+ "Run cr50_rma_open.py -a $AUTHCODE to open cr50 with " "an authcode"
+ )
elif not self.check(WP_IS_DISABLED):
- logging.info('WP is still enabled.')
- logging.info('Run cr50_rma_open.py -w to disable write protect')
+ logging.info("WP is still enabled.")
+ logging.info("Run cr50_rma_open.py -w to disable write protect")
if self.check(RMA_OPENED):
- logging.info('RMA Open complete')
+ logging.info("RMA Open complete")
if not self.check(TESTLAB_IS_ENABLED) and self.is_prepvt:
- logging.info('testlab mode is disabled.')
- logging.info('If you are prepping a device for the testlab, you '
- 'should enable testlab mode.')
- logging.info('Run cr50_rma_open.py -t to enable testlab mode')
+ logging.info("testlab mode is disabled.")
+ logging.info(
+ "If you are prepping a device for the testlab, you "
+ "should enable testlab mode."
+ )
+ logging.info("Run cr50_rma_open.py -t to enable testlab mode")
def parse_args(argv):
"""Get cr50_rma_open args."""
parser = argparse.ArgumentParser(
- description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
- parser.add_argument('-g', '--generate_challenge', action='store_true',
- help='Generate Cr50 challenge. Must be used with -i')
- parser.add_argument('-t', '--enable_testlab', action='store_true',
- help='enable testlab mode')
- parser.add_argument('-w', '--wp_disable', action='store_true',
- help='Disable write protect')
- parser.add_argument('-c', '--check_connection', action='store_true',
- help='Check cr50 console connection works')
- parser.add_argument('-s', '--serialname', type=str, default='',
- help='The cr50 usb serialname')
- parser.add_argument('-D', '--debug', action='store_true',
- help='print debug messages')
- parser.add_argument('-d', '--device', type=str, default='',
- help='cr50 console device ex /dev/ttyUSB0')
- parser.add_argument('-i', '--hwid', type=str, default='',
- help='The board hwid. Needed to generate a challenge')
- parser.add_argument('-a', '--authcode', type=str, default='',
- help='The authcode string from the challenge url')
- parser.add_argument('-P', '--servo_port', type=str, default='',
- help='the servo port')
- parser.add_argument('-I', '--ip', type=str, default='',
- help='The DUT IP. Necessary to do ccd open')
+ description=__doc__, formatter_class=argparse.RawTextHelpFormatter
+ )
+ parser.add_argument(
+ "-g",
+ "--generate_challenge",
+ action="store_true",
+ help="Generate Cr50 challenge. Must be used with -i",
+ )
+ parser.add_argument(
+ "-t", "--enable_testlab", action="store_true", help="enable testlab mode"
+ )
+ parser.add_argument(
+ "-w", "--wp_disable", action="store_true", help="Disable write protect"
+ )
+ parser.add_argument(
+ "-c",
+ "--check_connection",
+ action="store_true",
+ help="Check cr50 console connection works",
+ )
+ parser.add_argument(
+ "-s", "--serialname", type=str, default="", help="The cr50 usb serialname"
+ )
+ parser.add_argument(
+ "-D", "--debug", action="store_true", help="print debug messages"
+ )
+ parser.add_argument(
+ "-d",
+ "--device",
+ type=str,
+ default="",
+ help="cr50 console device ex /dev/ttyUSB0",
+ )
+ parser.add_argument(
+ "-i",
+ "--hwid",
+ type=str,
+ default="",
+ help="The board hwid. Needed to generate a challenge",
+ )
+ parser.add_argument(
+ "-a",
+ "--authcode",
+ type=str,
+ default="",
+ help="The authcode string from the challenge url",
+ )
+ parser.add_argument(
+ "-P", "--servo_port", type=str, default="", help="the servo port"
+ )
+ parser.add_argument(
+ "-I", "--ip", type=str, default="", help="The DUT IP. Necessary to do ccd open"
+ )
return parser.parse_args(argv)
@@ -614,52 +659,53 @@ def main(argv):
opts = parse_args(argv)
loglevel = logging.INFO
- log_format = '%(levelname)7s'
+ log_format = "%(levelname)7s"
if opts.debug:
loglevel = logging.DEBUG
- log_format += ' - %(lineno)3d:%(funcName)-15s'
- log_format += ' - %(message)s'
+ log_format += " - %(lineno)3d:%(funcName)-15s"
+ log_format += " - %(message)s"
logging.basicConfig(level=loglevel, format=log_format)
tried_authcode = False
- logging.info('Running cr50_rma_open version %s', SCRIPT_VERSION)
+ logging.info("Running cr50_rma_open version %s", SCRIPT_VERSION)
- cr50_rma_open = RMAOpen(opts.device, opts.serialname, opts.servo_port,
- opts.ip)
+ cr50_rma_open = RMAOpen(opts.device, opts.serialname, opts.servo_port, opts.ip)
if opts.check_connection:
sys.exit(0)
if not cr50_rma_open.check(CCD_IS_UNRESTRICTED):
if opts.generate_challenge:
if not opts.hwid:
- logging.warning('--hwid necessary to generate challenge url')
+ logging.warning("--hwid necessary to generate challenge url")
sys.exit(0)
cr50_rma_open.generate_challenge_url(opts.hwid)
sys.exit(0)
elif opts.authcode:
- logging.info('Using authcode: %s', opts.authcode)
+ logging.info("Using authcode: %s", opts.authcode)
cr50_rma_open.try_authcode(opts.authcode)
tried_authcode = True
- if not cr50_rma_open.check(WP_IS_DISABLED) and (tried_authcode or
- opts.wp_disable):
+ if not cr50_rma_open.check(WP_IS_DISABLED) and (tried_authcode or opts.wp_disable):
if not cr50_rma_open.check(CCD_IS_UNRESTRICTED):
- raise ValueError("Can't disable write protect unless ccd is "
- "open. Run through the rma open process first")
+ raise ValueError(
+ "Can't disable write protect unless ccd is "
+ "open. Run through the rma open process first"
+ )
if tried_authcode:
- logging.warning('RMA Open did not disable write protect. File a '
- 'bug')
- logging.warning('Trying to disable it manually')
+ logging.warning("RMA Open did not disable write protect. File a " "bug")
+ logging.warning("Trying to disable it manually")
cr50_rma_open.wp_disable()
if not cr50_rma_open.check(TESTLAB_IS_ENABLED) and opts.enable_testlab:
if not cr50_rma_open.check(CCD_IS_UNRESTRICTED):
- raise ValueError("Can't enable testlab mode unless ccd is open."
- "Run through the rma open process first")
+ raise ValueError(
+ "Can't enable testlab mode unless ccd is open."
+ "Run through the rma open process first"
+ )
cr50_rma_open.enable_testlab()
cr50_rma_open.print_dut_state()
-if __name__ == '__main__':
+if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
diff --git a/extra/stack_analyzer/stack_analyzer.py b/extra/stack_analyzer/stack_analyzer.py
index 77d16d5450..17b2651972 100755
--- a/extra/stack_analyzer/stack_analyzer.py
+++ b/extra/stack_analyzer/stack_analyzer.py
@@ -25,1848 +25,1992 @@ import ctypes
import os
import re
import subprocess
-import yaml
+import yaml
-SECTION_RO = 'RO'
-SECTION_RW = 'RW'
+SECTION_RO = "RO"
+SECTION_RW = "RW"
# Default size of extra stack frame needed by exception context switch.
# This value is for cortex-m with FPU enabled.
DEFAULT_EXCEPTION_FRAME_SIZE = 224
class StackAnalyzerError(Exception):
- """Exception class for stack analyzer utility."""
+ """Exception class for stack analyzer utility."""
class TaskInfo(ctypes.Structure):
- """Taskinfo ctypes structure.
+ """Taskinfo ctypes structure.
- The structure definition is corresponding to the "struct taskinfo"
- in "util/export_taskinfo.so.c".
- """
- _fields_ = [('name', ctypes.c_char_p),
- ('routine', ctypes.c_char_p),
- ('stack_size', ctypes.c_uint32)]
+ The structure definition is corresponding to the "struct taskinfo"
+ in "util/export_taskinfo.so.c".
+ """
+ _fields_ = [
+ ("name", ctypes.c_char_p),
+ ("routine", ctypes.c_char_p),
+ ("stack_size", ctypes.c_uint32),
+ ]
-class Task(object):
- """Task information.
- Attributes:
- name: Task name.
- routine_name: Routine function name.
- stack_max_size: Max stack size.
- routine_address: Resolved routine address. None if it hasn't been resolved.
- """
-
- def __init__(self, name, routine_name, stack_max_size, routine_address=None):
- """Constructor.
+class Task(object):
+ """Task information.
- Args:
+ Attributes:
name: Task name.
routine_name: Routine function name.
stack_max_size: Max stack size.
- routine_address: Resolved routine address.
+ routine_address: Resolved routine address. None if it hasn't been resolved.
"""
- self.name = name
- self.routine_name = routine_name
- self.stack_max_size = stack_max_size
- self.routine_address = routine_address
- def __eq__(self, other):
- """Task equality.
+ def __init__(self, name, routine_name, stack_max_size, routine_address=None):
+ """Constructor.
- Args:
- other: The compared object.
+ Args:
+ name: Task name.
+ routine_name: Routine function name.
+ stack_max_size: Max stack size.
+ routine_address: Resolved routine address.
+ """
+ self.name = name
+ self.routine_name = routine_name
+ self.stack_max_size = stack_max_size
+ self.routine_address = routine_address
- Returns:
- True if equal, False if not.
- """
- if not isinstance(other, Task):
- return False
+ def __eq__(self, other):
+ """Task equality.
- return (self.name == other.name and
- self.routine_name == other.routine_name and
- self.stack_max_size == other.stack_max_size and
- self.routine_address == other.routine_address)
+ Args:
+ other: The compared object.
+ Returns:
+ True if equal, False if not.
+ """
+ if not isinstance(other, Task):
+ return False
-class Symbol(object):
- """Symbol information.
+ return (
+ self.name == other.name
+ and self.routine_name == other.routine_name
+ and self.stack_max_size == other.stack_max_size
+ and self.routine_address == other.routine_address
+ )
- Attributes:
- address: Symbol address.
- symtype: Symbol type, 'O' (data, object) or 'F' (function).
- size: Symbol size.
- name: Symbol name.
- """
- def __init__(self, address, symtype, size, name):
- """Constructor.
+class Symbol(object):
+ """Symbol information.
- Args:
+ Attributes:
address: Symbol address.
- symtype: Symbol type.
+ symtype: Symbol type, 'O' (data, object) or 'F' (function).
size: Symbol size.
name: Symbol name.
"""
- assert symtype in ['O', 'F']
- self.address = address
- self.symtype = symtype
- self.size = size
- self.name = name
- def __eq__(self, other):
- """Symbol equality.
-
- Args:
- other: The compared object.
-
- Returns:
- True if equal, False if not.
- """
- if not isinstance(other, Symbol):
- return False
-
- return (self.address == other.address and
- self.symtype == other.symtype and
- self.size == other.size and
- self.name == other.name)
+ def __init__(self, address, symtype, size, name):
+ """Constructor.
+
+ Args:
+ address: Symbol address.
+ symtype: Symbol type.
+ size: Symbol size.
+ name: Symbol name.
+ """
+ assert symtype in ["O", "F"]
+ self.address = address
+ self.symtype = symtype
+ self.size = size
+ self.name = name
+
+ def __eq__(self, other):
+ """Symbol equality.
+
+ Args:
+ other: The compared object.
+
+ Returns:
+ True if equal, False if not.
+ """
+ if not isinstance(other, Symbol):
+ return False
+
+ return (
+ self.address == other.address
+ and self.symtype == other.symtype
+ and self.size == other.size
+ and self.name == other.name
+ )
class Callsite(object):
- """Function callsite.
-
- Attributes:
- address: Address of callsite location. None if it is unknown.
- target: Callee address. None if it is unknown.
- is_tail: A bool indicates that it is a tailing call.
- callee: Resolved callee function. None if it hasn't been resolved.
- """
-
- def __init__(self, address, target, is_tail, callee=None):
- """Constructor.
+ """Function callsite.
- Args:
+ Attributes:
address: Address of callsite location. None if it is unknown.
target: Callee address. None if it is unknown.
- is_tail: A bool indicates that it is a tailing call. (function jump to
- another function without restoring the stack frame)
- callee: Resolved callee function.
+ is_tail: A bool indicates that it is a tailing call.
+ callee: Resolved callee function. None if it hasn't been resolved.
"""
- # It makes no sense that both address and target are unknown.
- assert not (address is None and target is None)
- self.address = address
- self.target = target
- self.is_tail = is_tail
- self.callee = callee
-
- def __eq__(self, other):
- """Callsite equality.
-
- Args:
- other: The compared object.
- Returns:
- True if equal, False if not.
- """
- if not isinstance(other, Callsite):
- return False
-
- if not (self.address == other.address and
- self.target == other.target and
- self.is_tail == other.is_tail):
- return False
-
- if self.callee is None:
- return other.callee is None
- elif other.callee is None:
- return False
-
- # Assume the addresses of functions are unique.
- return self.callee.address == other.callee.address
+ def __init__(self, address, target, is_tail, callee=None):
+ """Constructor.
+
+ Args:
+ address: Address of callsite location. None if it is unknown.
+ target: Callee address. None if it is unknown.
+ is_tail: A bool indicates that it is a tailing call. (function jump to
+ another function without restoring the stack frame)
+ callee: Resolved callee function.
+ """
+ # It makes no sense that both address and target are unknown.
+ assert not (address is None and target is None)
+ self.address = address
+ self.target = target
+ self.is_tail = is_tail
+ self.callee = callee
+
+ def __eq__(self, other):
+ """Callsite equality.
+
+ Args:
+ other: The compared object.
+
+ Returns:
+ True if equal, False if not.
+ """
+ if not isinstance(other, Callsite):
+ return False
+
+ if not (
+ self.address == other.address
+ and self.target == other.target
+ and self.is_tail == other.is_tail
+ ):
+ return False
+
+ if self.callee is None:
+ return other.callee is None
+ elif other.callee is None:
+ return False
+
+ # Assume the addresses of functions are unique.
+ return self.callee.address == other.callee.address
class Function(object):
- """Function.
-
- Attributes:
- address: Address of function.
- name: Name of function from its symbol.
- stack_frame: Size of stack frame.
- callsites: Callsite list.
- stack_max_usage: Max stack usage. None if it hasn't been analyzed.
- stack_max_path: Max stack usage path. None if it hasn't been analyzed.
- """
+ """Function.
- def __init__(self, address, name, stack_frame, callsites):
- """Constructor.
-
- Args:
+ Attributes:
address: Address of function.
name: Name of function from its symbol.
stack_frame: Size of stack frame.
callsites: Callsite list.
+ stack_max_usage: Max stack usage. None if it hasn't been analyzed.
+ stack_max_path: Max stack usage path. None if it hasn't been analyzed.
"""
- self.address = address
- self.name = name
- self.stack_frame = stack_frame
- self.callsites = callsites
- self.stack_max_usage = None
- self.stack_max_path = None
-
- def __eq__(self, other):
- """Function equality.
-
- Args:
- other: The compared object.
-
- Returns:
- True if equal, False if not.
- """
- if not isinstance(other, Function):
- return False
-
- if not (self.address == other.address and
- self.name == other.name and
- self.stack_frame == other.stack_frame and
- self.callsites == other.callsites and
- self.stack_max_usage == other.stack_max_usage):
- return False
- if self.stack_max_path is None:
- return other.stack_max_path is None
- elif other.stack_max_path is None:
- return False
+ def __init__(self, address, name, stack_frame, callsites):
+ """Constructor.
+
+ Args:
+ address: Address of function.
+ name: Name of function from its symbol.
+ stack_frame: Size of stack frame.
+ callsites: Callsite list.
+ """
+ self.address = address
+ self.name = name
+ self.stack_frame = stack_frame
+ self.callsites = callsites
+ self.stack_max_usage = None
+ self.stack_max_path = None
+
+ def __eq__(self, other):
+ """Function equality.
+
+ Args:
+ other: The compared object.
+
+ Returns:
+ True if equal, False if not.
+ """
+ if not isinstance(other, Function):
+ return False
+
+ if not (
+ self.address == other.address
+ and self.name == other.name
+ and self.stack_frame == other.stack_frame
+ and self.callsites == other.callsites
+ and self.stack_max_usage == other.stack_max_usage
+ ):
+ return False
+
+ if self.stack_max_path is None:
+ return other.stack_max_path is None
+ elif other.stack_max_path is None:
+ return False
+
+ if len(self.stack_max_path) != len(other.stack_max_path):
+ return False
+
+ for self_func, other_func in zip(self.stack_max_path, other.stack_max_path):
+ # Assume the addresses of functions are unique.
+ if self_func.address != other_func.address:
+ return False
+
+ return True
+
+ def __hash__(self):
+ return id(self)
- if len(self.stack_max_path) != len(other.stack_max_path):
- return False
-
- for self_func, other_func in zip(self.stack_max_path, other.stack_max_path):
- # Assume the addresses of functions are unique.
- if self_func.address != other_func.address:
- return False
-
- return True
-
- def __hash__(self):
- return id(self)
class AndesAnalyzer(object):
- """Disassembly analyzer for Andes architecture.
-
- Public Methods:
- AnalyzeFunction: Analyze stack frame and callsites of the function.
- """
-
- GENERAL_PURPOSE_REGISTER_SIZE = 4
-
- # Possible condition code suffixes.
- CONDITION_CODES = [ 'eq', 'eqz', 'gez', 'gtz', 'lez', 'ltz', 'ne', 'nez',
- 'eqc', 'nec', 'nezs', 'nes', 'eqs']
- CONDITION_CODES_RE = '({})'.format('|'.join(CONDITION_CODES))
-
- IMM_ADDRESS_RE = r'([0-9A-Fa-f]+)\s+<([^>]+)>'
- # Branch instructions.
- JUMP_OPCODE_RE = re.compile(r'^(b{0}|j|jr|jr.|jrnez)(\d?|\d\d)$' \
- .format(CONDITION_CODES_RE))
- # Call instructions.
- CALL_OPCODE_RE = re.compile \
- (r'^(jal|jral|jral.|jralnez|beqzal|bltzal|bgezal)(\d)?$')
- CALL_OPERAND_RE = re.compile(r'^{}$'.format(IMM_ADDRESS_RE))
- # Ignore lp register because it's for return.
- INDIRECT_CALL_OPERAND_RE = re.compile \
- (r'^\$r\d{1,}$|\$fp$|\$gp$|\$ta$|\$sp$|\$pc$')
- # TODO: Handle other kinds of store instructions.
- PUSH_OPCODE_RE = re.compile(r'^push(\d{1,})$')
- PUSH_OPERAND_RE = re.compile(r'^\$r\d{1,}, \#\d{1,} \! \{([^\]]+)\}')
- SMW_OPCODE_RE = re.compile(r'^smw(\.\w\w|\.\w\w\w)$')
- SMW_OPERAND_RE = re.compile(r'^(\$r\d{1,}|\$\wp), \[\$\wp\], '
- r'(\$r\d{1,}|\$\wp), \#\d\w\d \! \{([^\]]+)\}')
- OPERANDGROUP_RE = re.compile(r'^\$r\d{1,}\~\$r\d{1,}')
-
- LWI_OPCODE_RE = re.compile(r'^lwi(\.\w\w)$')
- LWI_PC_OPERAND_RE = re.compile(r'^\$pc, \[([^\]]+)\]')
- # Example: "34280: 3f c8 0f ec addi.gp $fp, #0xfec"
- # Assume there is always a "\t" after the hex data.
- DISASM_REGEX_RE = re.compile(r'^(?P<address>[0-9A-Fa-f]+):\s+'
- r'(?P<words>[0-9A-Fa-f ]+)'
- r'\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?')
-
- def ParseInstruction(self, line, function_end):
- """Parse the line of instruction.
+ """Disassembly analyzer for Andes architecture.
- Args:
- line: Text of disassembly.
- function_end: End address of the current function. None if unknown.
-
- Returns:
- (address, words, opcode, operand_text): The instruction address, words,
- opcode, and the text of operands.
- None if it isn't an instruction line.
+ Public Methods:
+ AnalyzeFunction: Analyze stack frame and callsites of the function.
"""
- result = self.DISASM_REGEX_RE.match(line)
- if result is None:
- return None
-
- address = int(result.group('address'), 16)
- # Check if it's out of bound.
- if function_end is not None and address >= function_end:
- return None
-
- opcode = result.group('opcode').strip()
- operand_text = result.group('operand')
- words = result.group('words')
- if operand_text is None:
- operand_text = ''
- else:
- operand_text = operand_text.strip()
-
- return (address, words, opcode, operand_text)
-
- def AnalyzeFunction(self, function_symbol, instructions):
-
- stack_frame = 0
- callsites = []
- for address, words, opcode, operand_text in instructions:
- is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None
- is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None
-
- if is_jump_opcode or is_call_opcode:
- is_tail = is_jump_opcode
-
- result = self.CALL_OPERAND_RE.match(operand_text)
+ GENERAL_PURPOSE_REGISTER_SIZE = 4
+
+ # Possible condition code suffixes.
+ CONDITION_CODES = [
+ "eq",
+ "eqz",
+ "gez",
+ "gtz",
+ "lez",
+ "ltz",
+ "ne",
+ "nez",
+ "eqc",
+ "nec",
+ "nezs",
+ "nes",
+ "eqs",
+ ]
+ CONDITION_CODES_RE = "({})".format("|".join(CONDITION_CODES))
+
+ IMM_ADDRESS_RE = r"([0-9A-Fa-f]+)\s+<([^>]+)>"
+ # Branch instructions.
+ JUMP_OPCODE_RE = re.compile(
+ r"^(b{0}|j|jr|jr.|jrnez)(\d?|\d\d)$".format(CONDITION_CODES_RE)
+ )
+ # Call instructions.
+ CALL_OPCODE_RE = re.compile(r"^(jal|jral|jral.|jralnez|beqzal|bltzal|bgezal)(\d)?$")
+ CALL_OPERAND_RE = re.compile(r"^{}$".format(IMM_ADDRESS_RE))
+ # Ignore lp register because it's for return.
+ INDIRECT_CALL_OPERAND_RE = re.compile(r"^\$r\d{1,}$|\$fp$|\$gp$|\$ta$|\$sp$|\$pc$")
+ # TODO: Handle other kinds of store instructions.
+ PUSH_OPCODE_RE = re.compile(r"^push(\d{1,})$")
+ PUSH_OPERAND_RE = re.compile(r"^\$r\d{1,}, \#\d{1,} \! \{([^\]]+)\}")
+ SMW_OPCODE_RE = re.compile(r"^smw(\.\w\w|\.\w\w\w)$")
+ SMW_OPERAND_RE = re.compile(
+ r"^(\$r\d{1,}|\$\wp), \[\$\wp\], "
+ r"(\$r\d{1,}|\$\wp), \#\d\w\d \! \{([^\]]+)\}"
+ )
+ OPERANDGROUP_RE = re.compile(r"^\$r\d{1,}\~\$r\d{1,}")
+
+ LWI_OPCODE_RE = re.compile(r"^lwi(\.\w\w)$")
+ LWI_PC_OPERAND_RE = re.compile(r"^\$pc, \[([^\]]+)\]")
+ # Example: "34280: 3f c8 0f ec addi.gp $fp, #0xfec"
+ # Assume there is always a "\t" after the hex data.
+ DISASM_REGEX_RE = re.compile(
+ r"^(?P<address>[0-9A-Fa-f]+):\s+"
+ r"(?P<words>[0-9A-Fa-f ]+)"
+ r"\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?"
+ )
+
+ def ParseInstruction(self, line, function_end):
+ """Parse the line of instruction.
+
+ Args:
+ line: Text of disassembly.
+ function_end: End address of the current function. None if unknown.
+
+ Returns:
+ (address, words, opcode, operand_text): The instruction address, words,
+ opcode, and the text of operands.
+ None if it isn't an instruction line.
+ """
+ result = self.DISASM_REGEX_RE.match(line)
if result is None:
- if (self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None):
- # Found an indirect call.
- callsites.append(Callsite(address, None, is_tail))
-
+ return None
+
+ address = int(result.group("address"), 16)
+ # Check if it's out of bound.
+ if function_end is not None and address >= function_end:
+ return None
+
+ opcode = result.group("opcode").strip()
+ operand_text = result.group("operand")
+ words = result.group("words")
+ if operand_text is None:
+ operand_text = ""
else:
- target_address = int(result.group(1), 16)
- # Filter out the in-function target (branches and in-function calls,
- # which are actually branches).
- if not (function_symbol.size > 0 and
- function_symbol.address < target_address <
- (function_symbol.address + function_symbol.size)):
- # Maybe it is a callsite.
- callsites.append(Callsite(address, target_address, is_tail))
-
- elif self.LWI_OPCODE_RE.match(opcode) is not None:
- result = self.LWI_PC_OPERAND_RE.match(operand_text)
- if result is not None:
- # Ignore "lwi $pc, [$sp], xx" because it's usually a return.
- if result.group(1) != '$sp':
- # Found an indirect call.
- callsites.append(Callsite(address, None, True))
-
- elif self.PUSH_OPCODE_RE.match(opcode) is not None:
- # Example: fc 20 push25 $r8, #0 ! {$r6~$r8, $fp, $gp, $lp}
- if self.PUSH_OPERAND_RE.match(operand_text) is not None:
- # capture fc 20
- imm5u = int(words.split(' ')[1], 16)
- # sp = sp - (imm5u << 3)
- imm8u = (imm5u<<3) & 0xff
- stack_frame += imm8u
-
- result = self.PUSH_OPERAND_RE.match(operand_text)
- operandgroup_text = result.group(1)
- # capture $rx~$ry
- if self.OPERANDGROUP_RE.match(operandgroup_text) is not None:
- # capture number & transfer string to integer
- oprandgrouphead = operandgroup_text.split(',')[0]
- rx=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[0])))
- ry=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[1])))
-
- stack_frame += ((len(operandgroup_text.split(','))+ry-rx) *
- self.GENERAL_PURPOSE_REGISTER_SIZE)
- else:
- stack_frame += (len(operandgroup_text.split(',')) *
- self.GENERAL_PURPOSE_REGISTER_SIZE)
-
- elif self.SMW_OPCODE_RE.match(opcode) is not None:
- # Example: smw.adm $r6, [$sp], $r10, #0x2 ! {$r6~$r10, $lp}
- if self.SMW_OPERAND_RE.match(operand_text) is not None:
- result = self.SMW_OPERAND_RE.match(operand_text)
- operandgroup_text = result.group(3)
- # capture $rx~$ry
- if self.OPERANDGROUP_RE.match(operandgroup_text) is not None:
- # capture number & transfer string to integer
- oprandgrouphead = operandgroup_text.split(',')[0]
- rx=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[0])))
- ry=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[1])))
-
- stack_frame += ((len(operandgroup_text.split(','))+ry-rx) *
- self.GENERAL_PURPOSE_REGISTER_SIZE)
- else:
- stack_frame += (len(operandgroup_text.split(',')) *
- self.GENERAL_PURPOSE_REGISTER_SIZE)
-
- return (stack_frame, callsites)
-
-class ArmAnalyzer(object):
- """Disassembly analyzer for ARM architecture.
-
- Public Methods:
- AnalyzeFunction: Analyze stack frame and callsites of the function.
- """
-
- GENERAL_PURPOSE_REGISTER_SIZE = 4
-
- # Possible condition code suffixes.
- CONDITION_CODES = ['', 'eq', 'ne', 'cs', 'hs', 'cc', 'lo', 'mi', 'pl', 'vs',
- 'vc', 'hi', 'ls', 'ge', 'lt', 'gt', 'le']
- CONDITION_CODES_RE = '({})'.format('|'.join(CONDITION_CODES))
- # Assume there is no function name containing ">".
- IMM_ADDRESS_RE = r'([0-9A-Fa-f]+)\s+<([^>]+)>'
-
- # Fuzzy regular expressions for instruction and operand parsing.
- # Branch instructions.
- JUMP_OPCODE_RE = re.compile(
- r'^(b{0}|bx{0})(\.\w)?$'.format(CONDITION_CODES_RE))
- # Call instructions.
- CALL_OPCODE_RE = re.compile(
- r'^(bl{0}|blx{0})(\.\w)?$'.format(CONDITION_CODES_RE))
- CALL_OPERAND_RE = re.compile(r'^{}$'.format(IMM_ADDRESS_RE))
- CBZ_CBNZ_OPCODE_RE = re.compile(r'^(cbz|cbnz)(\.\w)?$')
- # Example: "r0, 1009bcbe <host_cmd_motion_sense+0x1d2>"
- CBZ_CBNZ_OPERAND_RE = re.compile(r'^[^,]+,\s+{}$'.format(IMM_ADDRESS_RE))
- # Ignore lr register because it's for return.
- INDIRECT_CALL_OPERAND_RE = re.compile(r'^r\d+|sb|sl|fp|ip|sp|pc$')
- # TODO(cheyuw): Handle conditional versions of following
- # instructions.
- # TODO(cheyuw): Handle other kinds of pc modifying instructions (e.g. mov pc).
- LDR_OPCODE_RE = re.compile(r'^ldr(\.\w)?$')
- # Example: "pc, [sp], #4"
- LDR_PC_OPERAND_RE = re.compile(r'^pc, \[([^\]]+)\]')
- # TODO(cheyuw): Handle other kinds of stm instructions.
- PUSH_OPCODE_RE = re.compile(r'^push$')
- STM_OPCODE_RE = re.compile(r'^stmdb$')
- # Stack subtraction instructions.
- SUB_OPCODE_RE = re.compile(r'^sub(s|w)?(\.\w)?$')
- SUB_OPERAND_RE = re.compile(r'^sp[^#]+#(\d+)')
- # Example: "44d94: f893 0068 ldrb.w r0, [r3, #104] ; 0x68"
- # Assume there is always a "\t" after the hex data.
- DISASM_REGEX_RE = re.compile(r'^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+'
- r'\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?')
-
- def ParseInstruction(self, line, function_end):
- """Parse the line of instruction.
+ operand_text = operand_text.strip()
+
+ return (address, words, opcode, operand_text)
+
+ def AnalyzeFunction(self, function_symbol, instructions):
+
+ stack_frame = 0
+ callsites = []
+ for address, words, opcode, operand_text in instructions:
+ is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None
+ is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None
+
+ if is_jump_opcode or is_call_opcode:
+ is_tail = is_jump_opcode
+
+ result = self.CALL_OPERAND_RE.match(operand_text)
+
+ if result is None:
+ if self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None:
+ # Found an indirect call.
+ callsites.append(Callsite(address, None, is_tail))
+
+ else:
+ target_address = int(result.group(1), 16)
+ # Filter out the in-function target (branches and in-function calls,
+ # which are actually branches).
+ if not (
+ function_symbol.size > 0
+ and function_symbol.address
+ < target_address
+ < (function_symbol.address + function_symbol.size)
+ ):
+ # Maybe it is a callsite.
+ callsites.append(Callsite(address, target_address, is_tail))
+
+ elif self.LWI_OPCODE_RE.match(opcode) is not None:
+ result = self.LWI_PC_OPERAND_RE.match(operand_text)
+ if result is not None:
+ # Ignore "lwi $pc, [$sp], xx" because it's usually a return.
+ if result.group(1) != "$sp":
+ # Found an indirect call.
+ callsites.append(Callsite(address, None, True))
+
+ elif self.PUSH_OPCODE_RE.match(opcode) is not None:
+ # Example: fc 20 push25 $r8, #0 ! {$r6~$r8, $fp, $gp, $lp}
+ if self.PUSH_OPERAND_RE.match(operand_text) is not None:
+ # capture fc 20
+ imm5u = int(words.split(" ")[1], 16)
+ # sp = sp - (imm5u << 3)
+ imm8u = (imm5u << 3) & 0xFF
+ stack_frame += imm8u
+
+ result = self.PUSH_OPERAND_RE.match(operand_text)
+ operandgroup_text = result.group(1)
+ # capture $rx~$ry
+ if self.OPERANDGROUP_RE.match(operandgroup_text) is not None:
+ # capture number & transfer string to integer
+ oprandgrouphead = operandgroup_text.split(",")[0]
+ rx = int(
+ "".join(filter(str.isdigit, oprandgrouphead.split("~")[0]))
+ )
+ ry = int(
+ "".join(filter(str.isdigit, oprandgrouphead.split("~")[1]))
+ )
+
+ stack_frame += (
+ len(operandgroup_text.split(",")) + ry - rx
+ ) * self.GENERAL_PURPOSE_REGISTER_SIZE
+ else:
+ stack_frame += (
+ len(operandgroup_text.split(","))
+ * self.GENERAL_PURPOSE_REGISTER_SIZE
+ )
+
+ elif self.SMW_OPCODE_RE.match(opcode) is not None:
+ # Example: smw.adm $r6, [$sp], $r10, #0x2 ! {$r6~$r10, $lp}
+ if self.SMW_OPERAND_RE.match(operand_text) is not None:
+ result = self.SMW_OPERAND_RE.match(operand_text)
+ operandgroup_text = result.group(3)
+ # capture $rx~$ry
+ if self.OPERANDGROUP_RE.match(operandgroup_text) is not None:
+ # capture number & transfer string to integer
+ oprandgrouphead = operandgroup_text.split(",")[0]
+ rx = int(
+ "".join(filter(str.isdigit, oprandgrouphead.split("~")[0]))
+ )
+ ry = int(
+ "".join(filter(str.isdigit, oprandgrouphead.split("~")[1]))
+ )
+
+ stack_frame += (
+ len(operandgroup_text.split(",")) + ry - rx
+ ) * self.GENERAL_PURPOSE_REGISTER_SIZE
+ else:
+ stack_frame += (
+ len(operandgroup_text.split(","))
+ * self.GENERAL_PURPOSE_REGISTER_SIZE
+ )
+
+ return (stack_frame, callsites)
- Args:
- line: Text of disassembly.
- function_end: End address of the current function. None if unknown.
- Returns:
- (address, opcode, operand_text): The instruction address, opcode,
- and the text of operands. None if it
- isn't an instruction line.
- """
- result = self.DISASM_REGEX_RE.match(line)
- if result is None:
- return None
-
- address = int(result.group('address'), 16)
- # Check if it's out of bound.
- if function_end is not None and address >= function_end:
- return None
-
- opcode = result.group('opcode').strip()
- operand_text = result.group('operand')
- if operand_text is None:
- operand_text = ''
- else:
- operand_text = operand_text.strip()
-
- return (address, opcode, operand_text)
-
- def AnalyzeFunction(self, function_symbol, instructions):
- """Analyze function, resolve the size of stack frame and callsites.
-
- Args:
- function_symbol: Function symbol.
- instructions: Instruction list.
+class ArmAnalyzer(object):
+ """Disassembly analyzer for ARM architecture.
- Returns:
- (stack_frame, callsites): Size of stack frame, callsite list.
+ Public Methods:
+ AnalyzeFunction: Analyze stack frame and callsites of the function.
"""
- stack_frame = 0
- callsites = []
- for address, opcode, operand_text in instructions:
- is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None
- is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None
- is_cbz_cbnz_opcode = self.CBZ_CBNZ_OPCODE_RE.match(opcode) is not None
- if is_jump_opcode or is_call_opcode or is_cbz_cbnz_opcode:
- is_tail = is_jump_opcode or is_cbz_cbnz_opcode
-
- if is_cbz_cbnz_opcode:
- result = self.CBZ_CBNZ_OPERAND_RE.match(operand_text)
- else:
- result = self.CALL_OPERAND_RE.match(operand_text)
+ GENERAL_PURPOSE_REGISTER_SIZE = 4
+
+ # Possible condition code suffixes.
+ CONDITION_CODES = [
+ "",
+ "eq",
+ "ne",
+ "cs",
+ "hs",
+ "cc",
+ "lo",
+ "mi",
+ "pl",
+ "vs",
+ "vc",
+ "hi",
+ "ls",
+ "ge",
+ "lt",
+ "gt",
+ "le",
+ ]
+ CONDITION_CODES_RE = "({})".format("|".join(CONDITION_CODES))
+ # Assume there is no function name containing ">".
+ IMM_ADDRESS_RE = r"([0-9A-Fa-f]+)\s+<([^>]+)>"
+
+ # Fuzzy regular expressions for instruction and operand parsing.
+ # Branch instructions.
+ JUMP_OPCODE_RE = re.compile(r"^(b{0}|bx{0})(\.\w)?$".format(CONDITION_CODES_RE))
+ # Call instructions.
+ CALL_OPCODE_RE = re.compile(r"^(bl{0}|blx{0})(\.\w)?$".format(CONDITION_CODES_RE))
+ CALL_OPERAND_RE = re.compile(r"^{}$".format(IMM_ADDRESS_RE))
+ CBZ_CBNZ_OPCODE_RE = re.compile(r"^(cbz|cbnz)(\.\w)?$")
+ # Example: "r0, 1009bcbe <host_cmd_motion_sense+0x1d2>"
+ CBZ_CBNZ_OPERAND_RE = re.compile(r"^[^,]+,\s+{}$".format(IMM_ADDRESS_RE))
+ # Ignore lr register because it's for return.
+ INDIRECT_CALL_OPERAND_RE = re.compile(r"^r\d+|sb|sl|fp|ip|sp|pc$")
+ # TODO(cheyuw): Handle conditional versions of following
+ # instructions.
+ # TODO(cheyuw): Handle other kinds of pc modifying instructions (e.g. mov pc).
+ LDR_OPCODE_RE = re.compile(r"^ldr(\.\w)?$")
+ # Example: "pc, [sp], #4"
+ LDR_PC_OPERAND_RE = re.compile(r"^pc, \[([^\]]+)\]")
+ # TODO(cheyuw): Handle other kinds of stm instructions.
+ PUSH_OPCODE_RE = re.compile(r"^push$")
+ STM_OPCODE_RE = re.compile(r"^stmdb$")
+ # Stack subtraction instructions.
+ SUB_OPCODE_RE = re.compile(r"^sub(s|w)?(\.\w)?$")
+ SUB_OPERAND_RE = re.compile(r"^sp[^#]+#(\d+)")
+ # Example: "44d94: f893 0068 ldrb.w r0, [r3, #104] ; 0x68"
+ # Assume there is always a "\t" after the hex data.
+ DISASM_REGEX_RE = re.compile(
+ r"^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+"
+ r"\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?"
+ )
+
+ def ParseInstruction(self, line, function_end):
+ """Parse the line of instruction.
+
+ Args:
+ line: Text of disassembly.
+ function_end: End address of the current function. None if unknown.
+
+ Returns:
+ (address, opcode, operand_text): The instruction address, opcode,
+ and the text of operands. None if it
+ isn't an instruction line.
+ """
+ result = self.DISASM_REGEX_RE.match(line)
if result is None:
- # Failed to match immediate address, maybe it is an indirect call.
- # CBZ and CBNZ can't be indirect calls.
- if (not is_cbz_cbnz_opcode and
- self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None):
- # Found an indirect call.
- callsites.append(Callsite(address, None, is_tail))
+ return None
- else:
- target_address = int(result.group(1), 16)
- # Filter out the in-function target (branches and in-function calls,
- # which are actually branches).
- if not (function_symbol.size > 0 and
- function_symbol.address < target_address <
- (function_symbol.address + function_symbol.size)):
- # Maybe it is a callsite.
- callsites.append(Callsite(address, target_address, is_tail))
-
- elif self.LDR_OPCODE_RE.match(opcode) is not None:
- result = self.LDR_PC_OPERAND_RE.match(operand_text)
- if result is not None:
- # Ignore "ldr pc, [sp], xx" because it's usually a return.
- if result.group(1) != 'sp':
- # Found an indirect call.
- callsites.append(Callsite(address, None, True))
-
- elif self.PUSH_OPCODE_RE.match(opcode) is not None:
- # Example: "{r4, r5, r6, r7, lr}"
- stack_frame += (len(operand_text.split(',')) *
- self.GENERAL_PURPOSE_REGISTER_SIZE)
- elif self.SUB_OPCODE_RE.match(opcode) is not None:
- result = self.SUB_OPERAND_RE.match(operand_text)
- if result is not None:
- stack_frame += int(result.group(1))
- else:
- # Unhandled stack register subtraction.
- assert not operand_text.startswith('sp')
+ address = int(result.group("address"), 16)
+ # Check if it's out of bound.
+ if function_end is not None and address >= function_end:
+ return None
- elif self.STM_OPCODE_RE.match(opcode) is not None:
- if operand_text.startswith('sp!'):
- # Subtract and writeback to stack register.
- # Example: "sp!, {r4, r5, r6, r7, r8, r9, lr}"
- # Get the text of pushed register list.
- unused_sp, unused_sep, parameter_text = operand_text.partition(',')
- stack_frame += (len(parameter_text.split(',')) *
- self.GENERAL_PURPOSE_REGISTER_SIZE)
+ opcode = result.group("opcode").strip()
+ operand_text = result.group("operand")
+ if operand_text is None:
+ operand_text = ""
+ else:
+ operand_text = operand_text.strip()
+
+ return (address, opcode, operand_text)
+
+ def AnalyzeFunction(self, function_symbol, instructions):
+ """Analyze function, resolve the size of stack frame and callsites.
+
+ Args:
+ function_symbol: Function symbol.
+ instructions: Instruction list.
+
+ Returns:
+ (stack_frame, callsites): Size of stack frame, callsite list.
+ """
+ stack_frame = 0
+ callsites = []
+ for address, opcode, operand_text in instructions:
+ is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None
+ is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None
+ is_cbz_cbnz_opcode = self.CBZ_CBNZ_OPCODE_RE.match(opcode) is not None
+ if is_jump_opcode or is_call_opcode or is_cbz_cbnz_opcode:
+ is_tail = is_jump_opcode or is_cbz_cbnz_opcode
+
+ if is_cbz_cbnz_opcode:
+ result = self.CBZ_CBNZ_OPERAND_RE.match(operand_text)
+ else:
+ result = self.CALL_OPERAND_RE.match(operand_text)
+
+ if result is None:
+ # Failed to match immediate address, maybe it is an indirect call.
+ # CBZ and CBNZ can't be indirect calls.
+ if (
+ not is_cbz_cbnz_opcode
+ and self.INDIRECT_CALL_OPERAND_RE.match(operand_text)
+ is not None
+ ):
+ # Found an indirect call.
+ callsites.append(Callsite(address, None, is_tail))
+
+ else:
+ target_address = int(result.group(1), 16)
+ # Filter out the in-function target (branches and in-function calls,
+ # which are actually branches).
+ if not (
+ function_symbol.size > 0
+ and function_symbol.address
+ < target_address
+ < (function_symbol.address + function_symbol.size)
+ ):
+ # Maybe it is a callsite.
+ callsites.append(Callsite(address, target_address, is_tail))
+
+ elif self.LDR_OPCODE_RE.match(opcode) is not None:
+ result = self.LDR_PC_OPERAND_RE.match(operand_text)
+ if result is not None:
+ # Ignore "ldr pc, [sp], xx" because it's usually a return.
+ if result.group(1) != "sp":
+ # Found an indirect call.
+ callsites.append(Callsite(address, None, True))
+
+ elif self.PUSH_OPCODE_RE.match(opcode) is not None:
+ # Example: "{r4, r5, r6, r7, lr}"
+ stack_frame += (
+ len(operand_text.split(",")) * self.GENERAL_PURPOSE_REGISTER_SIZE
+ )
+ elif self.SUB_OPCODE_RE.match(opcode) is not None:
+ result = self.SUB_OPERAND_RE.match(operand_text)
+ if result is not None:
+ stack_frame += int(result.group(1))
+ else:
+ # Unhandled stack register subtraction.
+ assert not operand_text.startswith("sp")
+
+ elif self.STM_OPCODE_RE.match(opcode) is not None:
+ if operand_text.startswith("sp!"):
+ # Subtract and writeback to stack register.
+ # Example: "sp!, {r4, r5, r6, r7, r8, r9, lr}"
+ # Get the text of pushed register list.
+ unused_sp, unused_sep, parameter_text = operand_text.partition(",")
+ stack_frame += (
+ len(parameter_text.split(","))
+ * self.GENERAL_PURPOSE_REGISTER_SIZE
+ )
+
+ return (stack_frame, callsites)
- return (stack_frame, callsites)
class RiscvAnalyzer(object):
- """Disassembly analyzer for RISC-V architecture.
-
- Public Methods:
- AnalyzeFunction: Analyze stack frame and callsites of the function.
- """
-
- # Possible condition code suffixes.
- CONDITION_CODES = [ 'eqz', 'nez', 'lez', 'gez', 'ltz', 'gtz', 'gt', 'le',
- 'gtu', 'leu', 'eq', 'ne', 'ge', 'lt', 'ltu', 'geu']
- CONDITION_CODES_RE = '({})'.format('|'.join(CONDITION_CODES))
- # Branch instructions.
- JUMP_OPCODE_RE = re.compile(r'^(b{0}|j|jr)$'.format(CONDITION_CODES_RE))
- # Call instructions.
- CALL_OPCODE_RE = re.compile(r'^(jal|jalr)$')
- # Example: "j 8009b318 <set_state_prl_hr>" or
- # "jal ra,800a4394 <power_get_signals>" or
- # "bltu t0,t1,80080300 <data_loop>"
- JUMP_ADDRESS_RE = r'((\w(\w|\d\d),){0,2})([0-9A-Fa-f]+)\s+<([^>]+)>'
- CALL_OPERAND_RE = re.compile(r'^{}$'.format(JUMP_ADDRESS_RE))
- # Capture address, Example: 800a4394
- CAPTURE_ADDRESS = re.compile(r'[0-9A-Fa-f]{8}')
- # Indirect jump, Example: jalr a5
- INDIRECT_CALL_OPERAND_RE = re.compile(r'^t\d+|s\d+|a\d+$')
- # Example: addi
- ADDI_OPCODE_RE = re.compile(r'^addi$')
- # Allocate stack instructions.
- ADDI_OPERAND_RE = re.compile(r'^(sp,sp,-\d+)$')
- # Example: "800804b6: 1101 addi sp,sp,-32"
- DISASM_REGEX_RE = re.compile(r'^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+'
- r'\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?')
-
- def ParseInstruction(self, line, function_end):
- """Parse the line of instruction.
+ """Disassembly analyzer for RISC-V architecture.
- Args:
- line: Text of disassembly.
- function_end: End address of the current function. None if unknown.
-
- Returns:
- (address, opcode, operand_text): The instruction address, opcode,
- and the text of operands. None if it
- isn't an instruction line.
+ Public Methods:
+ AnalyzeFunction: Analyze stack frame and callsites of the function.
"""
- result = self.DISASM_REGEX_RE.match(line)
- if result is None:
- return None
-
- address = int(result.group('address'), 16)
- # Check if it's out of bound.
- if function_end is not None and address >= function_end:
- return None
-
- opcode = result.group('opcode').strip()
- operand_text = result.group('operand')
- if operand_text is None:
- operand_text = ''
- else:
- operand_text = operand_text.strip()
-
- return (address, opcode, operand_text)
-
- def AnalyzeFunction(self, function_symbol, instructions):
-
- stack_frame = 0
- callsites = []
- for address, opcode, operand_text in instructions:
- is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None
- is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None
- if is_jump_opcode or is_call_opcode:
- is_tail = is_jump_opcode
-
- result = self.CALL_OPERAND_RE.match(operand_text)
+ # Possible condition code suffixes.
+ CONDITION_CODES = [
+ "eqz",
+ "nez",
+ "lez",
+ "gez",
+ "ltz",
+ "gtz",
+ "gt",
+ "le",
+ "gtu",
+ "leu",
+ "eq",
+ "ne",
+ "ge",
+ "lt",
+ "ltu",
+ "geu",
+ ]
+ CONDITION_CODES_RE = "({})".format("|".join(CONDITION_CODES))
+ # Branch instructions.
+ JUMP_OPCODE_RE = re.compile(r"^(b{0}|j|jr)$".format(CONDITION_CODES_RE))
+ # Call instructions.
+ CALL_OPCODE_RE = re.compile(r"^(jal|jalr)$")
+ # Example: "j 8009b318 <set_state_prl_hr>" or
+ # "jal ra,800a4394 <power_get_signals>" or
+ # "bltu t0,t1,80080300 <data_loop>"
+ JUMP_ADDRESS_RE = r"((\w(\w|\d\d),){0,2})([0-9A-Fa-f]+)\s+<([^>]+)>"
+ CALL_OPERAND_RE = re.compile(r"^{}$".format(JUMP_ADDRESS_RE))
+ # Capture address, Example: 800a4394
+ CAPTURE_ADDRESS = re.compile(r"[0-9A-Fa-f]{8}")
+ # Indirect jump, Example: jalr a5
+ INDIRECT_CALL_OPERAND_RE = re.compile(r"^t\d+|s\d+|a\d+$")
+ # Example: addi
+ ADDI_OPCODE_RE = re.compile(r"^addi$")
+ # Allocate stack instructions.
+ ADDI_OPERAND_RE = re.compile(r"^(sp,sp,-\d+)$")
+ # Example: "800804b6: 1101 addi sp,sp,-32"
+ DISASM_REGEX_RE = re.compile(
+ r"^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+"
+ r"\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?"
+ )
+
+ def ParseInstruction(self, line, function_end):
+ """Parse the line of instruction.
+
+ Args:
+ line: Text of disassembly.
+ function_end: End address of the current function. None if unknown.
+
+ Returns:
+ (address, opcode, operand_text): The instruction address, opcode,
+ and the text of operands. None if it
+ isn't an instruction line.
+ """
+ result = self.DISASM_REGEX_RE.match(line)
if result is None:
- if (self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None):
- # Found an indirect call.
- callsites.append(Callsite(address, None, is_tail))
-
- else:
- # Capture address form operand_text and then convert to string
- address_str = "".join(self.CAPTURE_ADDRESS.findall(operand_text))
- # String to integer
- target_address = int(address_str, 16)
- # Filter out the in-function target (branches and in-function calls,
- # which are actually branches).
- if not (function_symbol.size > 0 and
- function_symbol.address < target_address <
- (function_symbol.address + function_symbol.size)):
- # Maybe it is a callsite.
- callsites.append(Callsite(address, target_address, is_tail))
-
- elif self.ADDI_OPCODE_RE.match(opcode) is not None:
- # Example: sp,sp,-32
- if self.ADDI_OPERAND_RE.match(operand_text) is not None:
- stack_frame += abs(int(operand_text.split(",")[2]))
-
- return (stack_frame, callsites)
-
-class StackAnalyzer(object):
- """Class to analyze stack usage.
-
- Public Methods:
- Analyze: Run the stack analysis.
- """
-
- C_FUNCTION_NAME = r'_A-Za-z0-9'
-
- # Assume there is no ":" in the path.
- # Example: "driver/accel_kionix.c:321 (discriminator 3)"
- ADDRTOLINE_RE = re.compile(
- r'^(?P<path>[^:]+):(?P<linenum>\d+)(\s+\(discriminator\s+\d+\))?$')
- # To eliminate the suffix appended by compilers, try to extract the
- # C function name from the prefix of symbol name.
- # Example: "SHA256_transform.constprop.28"
- FUNCTION_PREFIX_NAME_RE = re.compile(
- r'^(?P<name>[{0}]+)([^{0}].*)?$'.format(C_FUNCTION_NAME))
-
- # Errors of annotation resolving.
- ANNOTATION_ERROR_INVALID = 'invalid signature'
- ANNOTATION_ERROR_NOTFOUND = 'function is not found'
- ANNOTATION_ERROR_AMBIGUOUS = 'signature is ambiguous'
-
- def __init__(self, options, symbols, rodata, tasklist, annotation):
- """Constructor.
-
- Args:
- options: Namespace from argparse.parse_args().
- symbols: Symbol list.
- rodata: Content of .rodata section (offset, data)
- tasklist: Task list.
- annotation: Annotation config.
- """
- self.options = options
- self.symbols = symbols
- self.rodata_offset = rodata[0]
- self.rodata = rodata[1]
- self.tasklist = tasklist
- self.annotation = annotation
- self.address_to_line_cache = {}
-
- def AddressToLine(self, address, resolve_inline=False):
- """Convert address to line.
-
- Args:
- address: Target address.
- resolve_inline: Output the stack of inlining.
-
- Returns:
- lines: List of the corresponding lines.
-
- Raises:
- StackAnalyzerError: If addr2line is failed.
- """
- cache_key = (address, resolve_inline)
- if cache_key in self.address_to_line_cache:
- return self.address_to_line_cache[cache_key]
-
- try:
- args = [self.options.addr2line,
- '-f',
- '-e',
- self.options.elf_path,
- '{:x}'.format(address)]
- if resolve_inline:
- args.append('-i')
-
- line_text = subprocess.check_output(args, encoding='utf-8')
- except subprocess.CalledProcessError:
- raise StackAnalyzerError('addr2line failed to resolve lines.')
- except OSError:
- raise StackAnalyzerError('Failed to run addr2line.')
-
- lines = [line.strip() for line in line_text.splitlines()]
- # Assume the output has at least one pair like "function\nlocation\n", and
- # they always show up in pairs.
- # Example: "handle_request\n
- # common/usb_pd_protocol.c:1191\n"
- assert len(lines) >= 2 and len(lines) % 2 == 0
-
- line_infos = []
- for index in range(0, len(lines), 2):
- (function_name, line_text) = lines[index:index + 2]
- if line_text in ['??:0', ':?']:
- line_infos.append(None)
- else:
- result = self.ADDRTOLINE_RE.match(line_text)
- # Assume the output is always well-formed.
- assert result is not None
- line_infos.append((function_name.strip(),
- os.path.realpath(result.group('path').strip()),
- int(result.group('linenum'))))
-
- self.address_to_line_cache[cache_key] = line_infos
- return line_infos
-
- def AnalyzeDisassembly(self, disasm_text):
- """Parse the disassembly text, analyze, and build a map of all functions.
-
- Args:
- disasm_text: Disassembly text.
-
- Returns:
- function_map: Dict of functions.
- """
- disasm_lines = [line.strip() for line in disasm_text.splitlines()]
-
- if 'nds' in disasm_lines[1]:
- analyzer = AndesAnalyzer()
- elif 'arm' in disasm_lines[1]:
- analyzer = ArmAnalyzer()
- elif 'riscv' in disasm_lines[1]:
- analyzer = RiscvAnalyzer()
- else:
- raise StackAnalyzerError('Unsupported architecture.')
-
- # Example: "08028c8c <motion_lid_calc>:"
- function_signature_regex = re.compile(
- r'^(?P<address>[0-9A-Fa-f]+)\s+<(?P<name>[^>]+)>:$')
-
- def DetectFunctionHead(line):
- """Check if the line is a function head.
-
- Args:
- line: Text of disassembly.
-
- Returns:
- symbol: Function symbol. None if it isn't a function head.
- """
- result = function_signature_regex.match(line)
- if result is None:
- return None
-
- address = int(result.group('address'), 16)
- symbol = symbol_map.get(address)
-
- # Check if the function exists and matches.
- if symbol is None or symbol.symtype != 'F':
- return None
-
- return symbol
-
- # Build symbol map, indexed by symbol address.
- symbol_map = {}
- for symbol in self.symbols:
- # If there are multiple symbols with same address, keeping any of them is
- # good enough.
- symbol_map[symbol.address] = symbol
-
- # Parse the disassembly text. We update the variable "line" to next line
- # when needed. There are two steps of parser:
- #
- # Step 1: Searching for the function head. Once reach the function head,
- # move to the next line, which is the first line of function body.
- #
- # Step 2: Parsing each instruction line of function body. Once reach a
- # non-instruction line, stop parsing and analyze the parsed instructions.
- #
- # Finally turn back to the step 1 without updating the line, because the
- # current non-instruction line can be another function head.
- function_map = {}
- # The following three variables are the states of the parsing processing.
- # They will be initialized properly during the state changes.
- function_symbol = None
- function_end = None
- instructions = []
-
- # Remove heading and tailing spaces for each line.
- line_index = 0
- while line_index < len(disasm_lines):
- # Get the current line.
- line = disasm_lines[line_index]
-
- if function_symbol is None:
- # Step 1: Search for the function head.
-
- function_symbol = DetectFunctionHead(line)
- if function_symbol is not None:
- # Assume there is no empty function. If the function head is followed
- # by EOF, it is an empty function.
- assert line_index + 1 < len(disasm_lines)
-
- # Found the function head, initialize and turn to the step 2.
- instructions = []
- # If symbol size exists, use it as a hint of function size.
- if function_symbol.size > 0:
- function_end = function_symbol.address + function_symbol.size
- else:
- function_end = None
-
- else:
- # Step 2: Parse the function body.
-
- instruction = analyzer.ParseInstruction(line, function_end)
- if instruction is not None:
- instructions.append(instruction)
-
- if instruction is None or line_index + 1 == len(disasm_lines):
- # Either the invalid instruction or EOF indicates the end of the
- # function, finalize the function analysis.
-
- # Assume there is no empty function.
- assert len(instructions) > 0
-
- (stack_frame, callsites) = analyzer.AnalyzeFunction(function_symbol,
- instructions)
- # Assume the function addresses are unique in the disassembly.
- assert function_symbol.address not in function_map
- function_map[function_symbol.address] = Function(
- function_symbol.address,
- function_symbol.name,
- stack_frame,
- callsites)
-
- # Initialize and turn back to the step 1.
- function_symbol = None
-
- # If the current line isn't an instruction, it can be another function
- # head, skip moving to the next line.
- if instruction is None:
- continue
-
- # Move to the next line.
- line_index += 1
+ return None
- # Resolve callees of functions.
- for function in function_map.values():
- for callsite in function.callsites:
- if callsite.target is not None:
- # Remain the callee as None if we can't resolve it.
- callsite.callee = function_map.get(callsite.target)
+ address = int(result.group("address"), 16)
+ # Check if it's out of bound.
+ if function_end is not None and address >= function_end:
+ return None
- return function_map
+ opcode = result.group("opcode").strip()
+ operand_text = result.group("operand")
+ if operand_text is None:
+ operand_text = ""
+ else:
+ operand_text = operand_text.strip()
+
+ return (address, opcode, operand_text)
+
+ def AnalyzeFunction(self, function_symbol, instructions):
+
+ stack_frame = 0
+ callsites = []
+ for address, opcode, operand_text in instructions:
+ is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None
+ is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None
+
+ if is_jump_opcode or is_call_opcode:
+ is_tail = is_jump_opcode
+
+ result = self.CALL_OPERAND_RE.match(operand_text)
+ if result is None:
+ if self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None:
+ # Found an indirect call.
+ callsites.append(Callsite(address, None, is_tail))
+
+ else:
+                    # Capture address from operand_text and then convert to string
+ address_str = "".join(self.CAPTURE_ADDRESS.findall(operand_text))
+ # String to integer
+ target_address = int(address_str, 16)
+ # Filter out the in-function target (branches and in-function calls,
+ # which are actually branches).
+ if not (
+ function_symbol.size > 0
+ and function_symbol.address
+ < target_address
+ < (function_symbol.address + function_symbol.size)
+ ):
+ # Maybe it is a callsite.
+ callsites.append(Callsite(address, target_address, is_tail))
+
+ elif self.ADDI_OPCODE_RE.match(opcode) is not None:
+ # Example: sp,sp,-32
+ if self.ADDI_OPERAND_RE.match(operand_text) is not None:
+ stack_frame += abs(int(operand_text.split(",")[2]))
+
+ return (stack_frame, callsites)
- def MapAnnotation(self, function_map, signature_set):
- """Map annotation signatures to functions.
- Args:
- function_map: Function map.
- signature_set: Set of annotation signatures.
+class StackAnalyzer(object):
+ """Class to analyze stack usage.
- Returns:
- Map of signatures to functions, map of signatures which can't be resolved.
+ Public Methods:
+ Analyze: Run the stack analysis.
"""
- # Build the symbol map indexed by symbol name. If there are multiple symbols
- # with the same name, add them into a set. (e.g. symbols of static function
- # with the same name)
- symbol_map = collections.defaultdict(set)
- for symbol in self.symbols:
- if symbol.symtype == 'F':
- # Function symbol.
- result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name)
- if result is not None:
- function = function_map.get(symbol.address)
- # Ignore the symbol not in disassembly.
- if function is not None:
- # If there are multiple symbol with the same name and point to the
- # same function, the set will deduplicate them.
- symbol_map[result.group('name').strip()].add(function)
-
- # Build the signature map indexed by annotation signature.
- signature_map = {}
- sig_error_map = {}
- symbol_path_map = {}
- for sig in signature_set:
- (name, path, _) = sig
-
- functions = symbol_map.get(name)
- if functions is None:
- sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND
- continue
-
- if name not in symbol_path_map:
- # Lazy symbol path resolving. Since the addr2line isn't fast, only
- # resolve needed symbol paths.
- group_map = collections.defaultdict(list)
- for function in functions:
- line_info = self.AddressToLine(function.address)[0]
- if line_info is None:
- continue
- (_, symbol_path, _) = line_info
+ C_FUNCTION_NAME = r"_A-Za-z0-9"
- # Group the functions with the same symbol signature (symbol name +
- # symbol path). Assume they are the same copies and do the same
- # annotation operations of them because we don't know which copy is
- # indicated by the users.
- group_map[symbol_path].append(function)
+ # Assume there is no ":" in the path.
+ # Example: "driver/accel_kionix.c:321 (discriminator 3)"
+ ADDRTOLINE_RE = re.compile(
+ r"^(?P<path>[^:]+):(?P<linenum>\d+)(\s+\(discriminator\s+\d+\))?$"
+ )
+ # To eliminate the suffix appended by compilers, try to extract the
+ # C function name from the prefix of symbol name.
+ # Example: "SHA256_transform.constprop.28"
+ FUNCTION_PREFIX_NAME_RE = re.compile(
+ r"^(?P<name>[{0}]+)([^{0}].*)?$".format(C_FUNCTION_NAME)
+ )
+
+ # Errors of annotation resolving.
+ ANNOTATION_ERROR_INVALID = "invalid signature"
+ ANNOTATION_ERROR_NOTFOUND = "function is not found"
+ ANNOTATION_ERROR_AMBIGUOUS = "signature is ambiguous"
+
+ def __init__(self, options, symbols, rodata, tasklist, annotation):
+ """Constructor.
+
+ Args:
+ options: Namespace from argparse.parse_args().
+ symbols: Symbol list.
+ rodata: Content of .rodata section (offset, data)
+ tasklist: Task list.
+ annotation: Annotation config.
+ """
+ self.options = options
+ self.symbols = symbols
+ self.rodata_offset = rodata[0]
+ self.rodata = rodata[1]
+ self.tasklist = tasklist
+ self.annotation = annotation
+ self.address_to_line_cache = {}
+
+ def AddressToLine(self, address, resolve_inline=False):
+ """Convert address to line.
+
+ Args:
+ address: Target address.
+ resolve_inline: Output the stack of inlining.
+
+ Returns:
+ lines: List of the corresponding lines.
+
+ Raises:
+ StackAnalyzerError: If addr2line is failed.
+ """
+ cache_key = (address, resolve_inline)
+ if cache_key in self.address_to_line_cache:
+ return self.address_to_line_cache[cache_key]
+
+ try:
+ args = [
+ self.options.addr2line,
+ "-f",
+ "-e",
+ self.options.elf_path,
+ "{:x}".format(address),
+ ]
+ if resolve_inline:
+ args.append("-i")
+
+ line_text = subprocess.check_output(args, encoding="utf-8")
+ except subprocess.CalledProcessError:
+ raise StackAnalyzerError("addr2line failed to resolve lines.")
+ except OSError:
+ raise StackAnalyzerError("Failed to run addr2line.")
+
+ lines = [line.strip() for line in line_text.splitlines()]
+ # Assume the output has at least one pair like "function\nlocation\n", and
+ # they always show up in pairs.
+ # Example: "handle_request\n
+ # common/usb_pd_protocol.c:1191\n"
+ assert len(lines) >= 2 and len(lines) % 2 == 0
+
+ line_infos = []
+ for index in range(0, len(lines), 2):
+ (function_name, line_text) = lines[index : index + 2]
+ if line_text in ["??:0", ":?"]:
+ line_infos.append(None)
+ else:
+ result = self.ADDRTOLINE_RE.match(line_text)
+ # Assume the output is always well-formed.
+ assert result is not None
+ line_infos.append(
+ (
+ function_name.strip(),
+ os.path.realpath(result.group("path").strip()),
+ int(result.group("linenum")),
+ )
+ )
+
+ self.address_to_line_cache[cache_key] = line_infos
+ return line_infos
+
+ def AnalyzeDisassembly(self, disasm_text):
+ """Parse the disassembly text, analyze, and build a map of all functions.
+
+ Args:
+ disasm_text: Disassembly text.
+
+ Returns:
+ function_map: Dict of functions.
+ """
+ disasm_lines = [line.strip() for line in disasm_text.splitlines()]
+
+ if "nds" in disasm_lines[1]:
+ analyzer = AndesAnalyzer()
+ elif "arm" in disasm_lines[1]:
+ analyzer = ArmAnalyzer()
+ elif "riscv" in disasm_lines[1]:
+ analyzer = RiscvAnalyzer()
+ else:
+ raise StackAnalyzerError("Unsupported architecture.")
- symbol_path_map[name] = group_map
+ # Example: "08028c8c <motion_lid_calc>:"
+ function_signature_regex = re.compile(
+ r"^(?P<address>[0-9A-Fa-f]+)\s+<(?P<name>[^>]+)>:$"
+ )
- # Symbol matching.
- function_group = None
- group_map = symbol_path_map[name]
- if len(group_map) > 0:
- if path is None:
- if len(group_map) > 1:
- # There is ambiguity but the path isn't specified.
- sig_error_map[sig] = self.ANNOTATION_ERROR_AMBIGUOUS
- continue
+ def DetectFunctionHead(line):
+ """Check if the line is a function head.
- # No path signature but all symbol signatures of functions are same.
- # Assume they are the same functions, so there is no ambiguity.
- (function_group,) = group_map.values()
- else:
- function_group = group_map.get(path)
+ Args:
+ line: Text of disassembly.
- if function_group is None:
- sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND
- continue
+ Returns:
+ symbol: Function symbol. None if it isn't a function head.
+ """
+ result = function_signature_regex.match(line)
+ if result is None:
+ return None
- # The function_group is a list of all the same functions (according to
- # our assumption) which should be annotated together.
- signature_map[sig] = function_group
+ address = int(result.group("address"), 16)
+ symbol = symbol_map.get(address)
- return (signature_map, sig_error_map)
+ # Check if the function exists and matches.
+ if symbol is None or symbol.symtype != "F":
+ return None
- def LoadAnnotation(self):
- """Load annotation rules.
+ return symbol
- Returns:
- Map of add rules, set of remove rules, set of text signatures which can't
- be parsed.
- """
- # Assume there is no ":" in the path.
- # Example: "get_range.lto.2501[driver/accel_kionix.c:327]"
- annotation_signature_regex = re.compile(
- r'^(?P<name>[^\[]+)(\[(?P<path>[^:]+)(:(?P<linenum>\d+))?\])?$')
-
- def NormalizeSignature(signature_text):
- """Parse and normalize the annotation signature.
-
- Args:
- signature_text: Text of the annotation signature.
-
- Returns:
- (function name, path, line number) of the signature. The path and line
- number can be None if not exist. None if failed to parse.
- """
- result = annotation_signature_regex.match(signature_text.strip())
- if result is None:
- return None
-
- name_result = self.FUNCTION_PREFIX_NAME_RE.match(
- result.group('name').strip())
- if name_result is None:
- return None
-
- path = result.group('path')
- if path is not None:
- path = os.path.realpath(path.strip())
-
- linenum = result.group('linenum')
- if linenum is not None:
- linenum = int(linenum.strip())
-
- return (name_result.group('name').strip(), path, linenum)
-
- def ExpandArray(dic):
- """Parse and expand a symbol array
-
- Args:
- dic: Dictionary for the array annotation
-
- Returns:
- array of (symbol name, None, None).
- """
- # TODO(drinkcat): This function is quite inefficient, as it goes through
- # the symbol table multiple times.
-
- begin_name = dic['name']
- end_name = dic['name'] + "_end"
- offset = dic['offset'] if 'offset' in dic else 0
- stride = dic['stride']
-
- begin_address = None
- end_address = None
-
- for symbol in self.symbols:
- if (symbol.name == begin_name):
- begin_address = symbol.address
- if (symbol.name == end_name):
- end_address = symbol.address
-
- if (not begin_address or not end_address):
- return None
-
- output = []
- # TODO(drinkcat): This is inefficient as we go from address to symbol
- # object then to symbol name, and later on we'll go back from symbol name
- # to symbol object.
- for addr in range(begin_address+offset, end_address, stride):
- # TODO(drinkcat): Not all architectures need to drop the first bit.
- val = self.rodata[(addr-self.rodata_offset) // 4] & 0xfffffffe
- name = None
+ # Build symbol map, indexed by symbol address.
+ symbol_map = {}
for symbol in self.symbols:
- if (symbol.address == val):
- result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name)
- name = result.group('name')
- break
-
- if not name:
- raise StackAnalyzerError('Cannot find function for address %s.',
- hex(val))
-
- output.append((name, None, None))
-
- return output
-
- add_rules = collections.defaultdict(set)
- remove_rules = list()
- invalid_sigtxts = set()
-
- if 'add' in self.annotation and self.annotation['add'] is not None:
- for src_sigtxt, dst_sigtxts in self.annotation['add'].items():
- src_sig = NormalizeSignature(src_sigtxt)
- if src_sig is None:
- invalid_sigtxts.add(src_sigtxt)
- continue
-
- for dst_sigtxt in dst_sigtxts:
- if isinstance(dst_sigtxt, dict):
- dst_sig = ExpandArray(dst_sigtxt)
- if dst_sig is None:
- invalid_sigtxts.add(str(dst_sigtxt))
+ # If there are multiple symbols with same address, keeping any of them is
+ # good enough.
+ symbol_map[symbol.address] = symbol
+
+ # Parse the disassembly text. We update the variable "line" to next line
+ # when needed. There are two steps of parser:
+ #
+ # Step 1: Searching for the function head. Once reach the function head,
+ # move to the next line, which is the first line of function body.
+ #
+ # Step 2: Parsing each instruction line of function body. Once reach a
+ # non-instruction line, stop parsing and analyze the parsed instructions.
+ #
+ # Finally turn back to the step 1 without updating the line, because the
+ # current non-instruction line can be another function head.
+ function_map = {}
+ # The following three variables are the states of the parsing processing.
+ # They will be initialized properly during the state changes.
+ function_symbol = None
+ function_end = None
+ instructions = []
+
+        # Remove leading and trailing spaces for each line.
+ line_index = 0
+ while line_index < len(disasm_lines):
+ # Get the current line.
+ line = disasm_lines[line_index]
+
+ if function_symbol is None:
+ # Step 1: Search for the function head.
+
+ function_symbol = DetectFunctionHead(line)
+ if function_symbol is not None:
+                    # Assume there is no empty function: a function head
+                    # followed directly by EOF would mean an empty function.
+ assert line_index + 1 < len(disasm_lines)
+
+ # Found the function head, initialize and turn to the step 2.
+ instructions = []
+ # If symbol size exists, use it as a hint of function size.
+ if function_symbol.size > 0:
+ function_end = function_symbol.address + function_symbol.size
+ else:
+ function_end = None
+
else:
- add_rules[src_sig].update(dst_sig)
- else:
- dst_sig = NormalizeSignature(dst_sigtxt)
- if dst_sig is None:
- invalid_sigtxts.add(dst_sigtxt)
+ # Step 2: Parse the function body.
+
+ instruction = analyzer.ParseInstruction(line, function_end)
+ if instruction is not None:
+ instructions.append(instruction)
+
+ if instruction is None or line_index + 1 == len(disasm_lines):
+ # Either the invalid instruction or EOF indicates the end of the
+ # function, finalize the function analysis.
+
+ # Assume there is no empty function.
+ assert len(instructions) > 0
+
+ (stack_frame, callsites) = analyzer.AnalyzeFunction(
+ function_symbol, instructions
+ )
+ # Assume the function addresses are unique in the disassembly.
+ assert function_symbol.address not in function_map
+ function_map[function_symbol.address] = Function(
+ function_symbol.address,
+ function_symbol.name,
+ stack_frame,
+ callsites,
+ )
+
+ # Initialize and turn back to the step 1.
+ function_symbol = None
+
+ # If the current line isn't an instruction, it can be another function
+ # head, skip moving to the next line.
+ if instruction is None:
+ continue
+
+ # Move to the next line.
+ line_index += 1
+
+ # Resolve callees of functions.
+ for function in function_map.values():
+ for callsite in function.callsites:
+ if callsite.target is not None:
+                    # Leave the callee as None if we can't resolve it.
+ callsite.callee = function_map.get(callsite.target)
+
+ return function_map
+
+ def MapAnnotation(self, function_map, signature_set):
+ """Map annotation signatures to functions.
+
+ Args:
+ function_map: Function map.
+ signature_set: Set of annotation signatures.
+
+ Returns:
+ Map of signatures to functions, map of signatures which can't be resolved.
+ """
+ # Build the symbol map indexed by symbol name. If there are multiple symbols
+ # with the same name, add them into a set. (e.g. symbols of static function
+ # with the same name)
+ symbol_map = collections.defaultdict(set)
+ for symbol in self.symbols:
+ if symbol.symtype == "F":
+ # Function symbol.
+ result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name)
+ if result is not None:
+ function = function_map.get(symbol.address)
+ # Ignore the symbol not in disassembly.
+ if function is not None:
+ # If there are multiple symbol with the same name and point to the
+ # same function, the set will deduplicate them.
+ symbol_map[result.group("name").strip()].add(function)
+
+ # Build the signature map indexed by annotation signature.
+ signature_map = {}
+ sig_error_map = {}
+ symbol_path_map = {}
+ for sig in signature_set:
+ (name, path, _) = sig
+
+ functions = symbol_map.get(name)
+ if functions is None:
+ sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND
+ continue
+
+ if name not in symbol_path_map:
+ # Lazy symbol path resolving. Since the addr2line isn't fast, only
+ # resolve needed symbol paths.
+ group_map = collections.defaultdict(list)
+ for function in functions:
+ line_info = self.AddressToLine(function.address)[0]
+ if line_info is None:
+ continue
+
+ (_, symbol_path, _) = line_info
+
+ # Group the functions with the same symbol signature (symbol name +
+ # symbol path). Assume they are the same copies and do the same
+ # annotation operations of them because we don't know which copy is
+ # indicated by the users.
+ group_map[symbol_path].append(function)
+
+ symbol_path_map[name] = group_map
+
+ # Symbol matching.
+ function_group = None
+ group_map = symbol_path_map[name]
+ if len(group_map) > 0:
+ if path is None:
+ if len(group_map) > 1:
+ # There is ambiguity but the path isn't specified.
+ sig_error_map[sig] = self.ANNOTATION_ERROR_AMBIGUOUS
+ continue
+
+ # No path signature but all symbol signatures of functions are same.
+ # Assume they are the same functions, so there is no ambiguity.
+ (function_group,) = group_map.values()
+ else:
+ function_group = group_map.get(path)
+
+ if function_group is None:
+ sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND
+ continue
+
+ # The function_group is a list of all the same functions (according to
+ # our assumption) which should be annotated together.
+ signature_map[sig] = function_group
+
+ return (signature_map, sig_error_map)
+
+ def LoadAnnotation(self):
+ """Load annotation rules.
+
+ Returns:
+ Map of add rules, set of remove rules, set of text signatures which can't
+ be parsed.
+ """
+ # Assume there is no ":" in the path.
+ # Example: "get_range.lto.2501[driver/accel_kionix.c:327]"
+ annotation_signature_regex = re.compile(
+ r"^(?P<name>[^\[]+)(\[(?P<path>[^:]+)(:(?P<linenum>\d+))?\])?$"
+ )
+
+ def NormalizeSignature(signature_text):
+ """Parse and normalize the annotation signature.
+
+ Args:
+ signature_text: Text of the annotation signature.
+
+ Returns:
+ (function name, path, line number) of the signature. The path and line
+ number can be None if not exist. None if failed to parse.
+ """
+ result = annotation_signature_regex.match(signature_text.strip())
+ if result is None:
+ return None
+
+ name_result = self.FUNCTION_PREFIX_NAME_RE.match(
+ result.group("name").strip()
+ )
+ if name_result is None:
+ return None
+
+ path = result.group("path")
+ if path is not None:
+ path = os.path.realpath(path.strip())
+
+ linenum = result.group("linenum")
+ if linenum is not None:
+ linenum = int(linenum.strip())
+
+ return (name_result.group("name").strip(), path, linenum)
+
+ def ExpandArray(dic):
+ """Parse and expand a symbol array
+
+ Args:
+ dic: Dictionary for the array annotation
+
+ Returns:
+ array of (symbol name, None, None).
+ """
+ # TODO(drinkcat): This function is quite inefficient, as it goes through
+ # the symbol table multiple times.
+
+ begin_name = dic["name"]
+ end_name = dic["name"] + "_end"
+ offset = dic["offset"] if "offset" in dic else 0
+ stride = dic["stride"]
+
+ begin_address = None
+ end_address = None
+
+ for symbol in self.symbols:
+ if symbol.name == begin_name:
+ begin_address = symbol.address
+ if symbol.name == end_name:
+ end_address = symbol.address
+
+ if not begin_address or not end_address:
+ return None
+
+ output = []
+ # TODO(drinkcat): This is inefficient as we go from address to symbol
+ # object then to symbol name, and later on we'll go back from symbol name
+ # to symbol object.
+ for addr in range(begin_address + offset, end_address, stride):
+ # TODO(drinkcat): Not all architectures need to drop the first bit.
+ val = self.rodata[(addr - self.rodata_offset) // 4] & 0xFFFFFFFE
+ name = None
+ for symbol in self.symbols:
+ if symbol.address == val:
+ result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name)
+ name = result.group("name")
+ break
+
+ if not name:
+ raise StackAnalyzerError(
+ "Cannot find function for address %s.", hex(val)
+ )
+
+ output.append((name, None, None))
+
+ return output
+
+ add_rules = collections.defaultdict(set)
+ remove_rules = list()
+ invalid_sigtxts = set()
+
+ if "add" in self.annotation and self.annotation["add"] is not None:
+ for src_sigtxt, dst_sigtxts in self.annotation["add"].items():
+ src_sig = NormalizeSignature(src_sigtxt)
+ if src_sig is None:
+ invalid_sigtxts.add(src_sigtxt)
+ continue
+
+ for dst_sigtxt in dst_sigtxts:
+ if isinstance(dst_sigtxt, dict):
+ dst_sig = ExpandArray(dst_sigtxt)
+ if dst_sig is None:
+ invalid_sigtxts.add(str(dst_sigtxt))
+ else:
+ add_rules[src_sig].update(dst_sig)
+ else:
+ dst_sig = NormalizeSignature(dst_sigtxt)
+ if dst_sig is None:
+ invalid_sigtxts.add(dst_sigtxt)
+ else:
+ add_rules[src_sig].add(dst_sig)
+
+ if "remove" in self.annotation and self.annotation["remove"] is not None:
+ for sigtxt_path in self.annotation["remove"]:
+ if isinstance(sigtxt_path, str):
+ # The path has only one vertex.
+ sigtxt_path = [sigtxt_path]
+
+ if len(sigtxt_path) == 0:
+ continue
+
+ # Generate multiple remove paths from all the combinations of the
+ # signatures of each vertex.
+ sig_paths = [[]]
+ broken_flag = False
+ for sigtxt_node in sigtxt_path:
+ if isinstance(sigtxt_node, str):
+ # The vertex has only one signature.
+ sigtxt_set = {sigtxt_node}
+ elif isinstance(sigtxt_node, list):
+ # The vertex has multiple signatures.
+ sigtxt_set = set(sigtxt_node)
+ else:
+ # Assume the format of annotation is verified. There should be no
+ # invalid case.
+ assert False
+
+ sig_set = set()
+ for sigtxt in sigtxt_set:
+ sig = NormalizeSignature(sigtxt)
+ if sig is None:
+ invalid_sigtxts.add(sigtxt)
+ broken_flag = True
+ elif not broken_flag:
+ sig_set.add(sig)
+
+ if broken_flag:
+ continue
+
+ # Append each signature of the current node to the all previous
+ # remove paths.
+ sig_paths = [path + [sig] for path in sig_paths for sig in sig_set]
+
+ if not broken_flag:
+ # All signatures are normalized. The remove path has no error.
+ remove_rules.extend(sig_paths)
+
+ return (add_rules, remove_rules, invalid_sigtxts)
+
+ def ResolveAnnotation(self, function_map):
+ """Resolve annotation.
+
+ Args:
+ function_map: Function map.
+
+ Returns:
+ Set of added call edges, list of remove paths, set of eliminated
+ callsite addresses, set of annotation signatures which can't be resolved.
+ """
+
+ def StringifySignature(signature):
+ """Stringify the tupled signature.
+
+ Args:
+ signature: Tupled signature.
+
+ Returns:
+ Signature string.
+ """
+ (name, path, linenum) = signature
+ bracket_text = ""
+ if path is not None:
+ path = os.path.relpath(path)
+ if linenum is None:
+ bracket_text = "[{}]".format(path)
+ else:
+ bracket_text = "[{}:{}]".format(path, linenum)
+
+ return name + bracket_text
+
+ (add_rules, remove_rules, invalid_sigtxts) = self.LoadAnnotation()
+
+ signature_set = set()
+ for src_sig, dst_sigs in add_rules.items():
+ signature_set.add(src_sig)
+ signature_set.update(dst_sigs)
+
+ for remove_sigs in remove_rules:
+ signature_set.update(remove_sigs)
+
+ # Map signatures to functions.
+ (signature_map, sig_error_map) = self.MapAnnotation(function_map, signature_set)
+
+ # Build the indirect callsite map indexed by callsite signature.
+ indirect_map = collections.defaultdict(set)
+ for function in function_map.values():
+ for callsite in function.callsites:
+ if callsite.target is not None:
+ continue
+
+ # Found an indirect callsite.
+ line_info = self.AddressToLine(callsite.address)[0]
+ if line_info is None:
+ continue
+
+ (name, path, linenum) = line_info
+ result = self.FUNCTION_PREFIX_NAME_RE.match(name)
+ if result is None:
+ continue
+
+ indirect_map[(result.group("name").strip(), path, linenum)].add(
+ (function, callsite.address)
+ )
+
+ # Generate the annotation sets.
+ add_set = set()
+ remove_list = list()
+ eliminated_addrs = set()
+
+ for src_sig, dst_sigs in add_rules.items():
+ src_funcs = set(signature_map.get(src_sig, []))
+ # Try to match the source signature to the indirect callsites. Even if it
+ # can't be found in disassembly.
+ indirect_calls = indirect_map.get(src_sig)
+ if indirect_calls is not None:
+ for function, callsite_address in indirect_calls:
+ # Add the caller of the indirect callsite to the source functions.
+ src_funcs.add(function)
+ # Assume each callsite can be represented by a unique address.
+ eliminated_addrs.add(callsite_address)
+
+ if src_sig in sig_error_map:
+                    # Assume the error is always the not-found error. Since a
+                    # signature found in the indirect callsite map must be a full
+                    # signature, the ambiguous error can't happen.
+ assert sig_error_map[src_sig] == self.ANNOTATION_ERROR_NOTFOUND
+ # Found in inline stack, remove the not found error.
+ del sig_error_map[src_sig]
+
+ for dst_sig in dst_sigs:
+ dst_funcs = signature_map.get(dst_sig)
+ if dst_funcs is None:
+ continue
+
+ # Duplicate the call edge for all the same source and destination
+ # functions.
+ for src_func in src_funcs:
+ for dst_func in dst_funcs:
+ add_set.add((src_func, dst_func))
+
+ for remove_sigs in remove_rules:
+ # Since each signature can be mapped to multiple functions, generate
+ # multiple remove paths from all the combinations of these functions.
+ remove_paths = [[]]
+ skip_flag = False
+ for remove_sig in remove_sigs:
+ # Transform each signature to the corresponding functions.
+ remove_funcs = signature_map.get(remove_sig)
+ if remove_funcs is None:
+ # There is an unresolved signature in the remove path. Ignore the
+ # whole broken remove path.
+ skip_flag = True
+ break
+ else:
+ # Append each function of the current signature to the all previous
+ # remove paths.
+ remove_paths = [p + [f] for p in remove_paths for f in remove_funcs]
+
+ if skip_flag:
+ # Ignore the broken remove path.
+ continue
+
+ for remove_path in remove_paths:
+ # Deduplicate the remove paths.
+ if remove_path not in remove_list:
+ remove_list.append(remove_path)
+
+ # Format the error messages.
+ failed_sigtxts = set()
+ for sigtxt in invalid_sigtxts:
+ failed_sigtxts.add((sigtxt, self.ANNOTATION_ERROR_INVALID))
+
+ for sig, error in sig_error_map.items():
+ failed_sigtxts.add((StringifySignature(sig), error))
+
+ return (add_set, remove_list, eliminated_addrs, failed_sigtxts)
+
+ def PreprocessAnnotation(
+ self, function_map, add_set, remove_list, eliminated_addrs
+ ):
+ """Preprocess the annotation and callgraph.
+
+ Add the missing call edges, and delete simple remove paths (the paths have
+ one or two vertices) from the function_map.
+
+ Eliminate the annotated indirect callsites.
+
+ Return the remaining remove list.
+
+ Args:
+ function_map: Function map.
+ add_set: Set of missing call edges.
+ remove_list: List of remove paths.
+ eliminated_addrs: Set of eliminated callsite addresses.
+
+ Returns:
+ List of remaining remove paths.
+ """
+
+ def CheckEdge(path):
+ """Check if all edges of the path are on the callgraph.
+
+ Args:
+ path: Path.
+
+ Returns:
+ True or False.
+ """
+ for index in range(len(path) - 1):
+ if (path[index], path[index + 1]) not in edge_set:
+ return False
+
+ return True
+
+ for src_func, dst_func in add_set:
+            # TODO(cheyuw): Support tail-call annotation.
+ src_func.callsites.append(Callsite(None, dst_func.address, False, dst_func))
+
+ # Delete simple remove paths.
+ remove_simple = set(tuple(p) for p in remove_list if len(p) <= 2)
+ edge_set = set()
+ for function in function_map.values():
+ cleaned_callsites = []
+ for callsite in function.callsites:
+ if (callsite.callee,) in remove_simple or (
+ function,
+ callsite.callee,
+ ) in remove_simple:
+ continue
+
+ if callsite.target is None and callsite.address in eliminated_addrs:
+ continue
+
+ cleaned_callsites.append(callsite)
+ if callsite.callee is not None:
+ edge_set.add((function, callsite.callee))
+
+ function.callsites = cleaned_callsites
+
+ return [p for p in remove_list if len(p) >= 3 and CheckEdge(p)]
+
+ def AnalyzeCallGraph(self, function_map, remove_list):
+ """Analyze callgraph.
+
+ It will update the max stack size and path for each function.
+
+ Args:
+ function_map: Function map.
+ remove_list: List of remove paths.
+
+ Returns:
+ List of function cycles.
+ """
+
+ def Traverse(curr_state):
+ """Traverse the callgraph and calculate the max stack usages of functions.
+
+ Args:
+ curr_state: Current state.
+
+ Returns:
+ SCC lowest link.
+ """
+ scc_index = scc_index_counter[0]
+ scc_index_counter[0] += 1
+ scc_index_map[curr_state] = scc_index
+ scc_lowlink = scc_index
+ scc_stack.append(curr_state)
+ # Push the current state in the stack. We can use a set to maintain this
+ # because the stacked states are unique; otherwise we will find a cycle
+ # first.
+ stacked_states.add(curr_state)
+
+ (curr_address, curr_positions) = curr_state
+ curr_func = function_map[curr_address]
+
+ invalid_flag = False
+ new_positions = list(curr_positions)
+ for index, position in enumerate(curr_positions):
+ remove_path = remove_list[index]
+
+ # The position of each remove path in the state is the length of the
+ # longest matching path between the prefix of the remove path and the
+ # suffix of the current traversing path. We maintain this length when
+ # appending the next callee to the traversing path. And it can be used
+ # to check if the remove path appears in the traversing path.
+
+ # TODO(cheyuw): Implement KMP algorithm to match remove paths
+ # efficiently.
+ if remove_path[position] is curr_func:
+ # Matches the current function, extend the length.
+ new_positions[index] = position + 1
+ if new_positions[index] == len(remove_path):
+ # The length of the longest matching path is equal to the length of
+ # the remove path, which means the suffix of the current traversing
+ # path matches the remove path.
+ invalid_flag = True
+ break
+
+ else:
+ # We can't get the new longest matching path by extending the previous
+ # one directly. Fallback to search the new longest matching path.
+
+ # If we can't find any matching path in the following search, reset
+ # the matching length to 0.
+ new_positions[index] = 0
+
+ # We want to find the new longest matching prefix of remove path with
+ # the suffix of the current traversing path. Because the new longest
+                # matching path won't be longer than the previous one now, and part of
+ # the suffix matches the prefix of remove path, we can get the needed
+ # suffix from the previous matching prefix of the invalid path.
+ suffix = remove_path[:position] + [curr_func]
+ for offset in range(1, len(suffix)):
+ length = position - offset
+ if remove_path[:length] == suffix[offset:]:
+ new_positions[index] = length
+ break
+
+ new_positions = tuple(new_positions)
+
+ # If the current suffix is invalid, set the max stack usage to 0.
+ max_stack_usage = 0
+ max_callee_state = None
+ self_loop = False
+
+ if not invalid_flag:
+ # Max stack usage is at least equal to the stack frame.
+ max_stack_usage = curr_func.stack_frame
+ for callsite in curr_func.callsites:
+ callee = callsite.callee
+ if callee is None:
+ continue
+
+ callee_state = (callee.address, new_positions)
+ if callee_state not in scc_index_map:
+ # Unvisited state.
+ scc_lowlink = min(scc_lowlink, Traverse(callee_state))
+ elif callee_state in stacked_states:
+ # The state is shown in the stack. There is a cycle.
+ sub_stack_usage = 0
+ scc_lowlink = min(scc_lowlink, scc_index_map[callee_state])
+ if callee_state == curr_state:
+ self_loop = True
+
+ done_result = done_states.get(callee_state)
+ if done_result is not None:
+ # Already done this state and use its result. If the state reaches a
+ # cycle, reusing the result will cause inaccuracy (the stack usage
+ # of cycle depends on where the entrance is). But it's fine since we
+ # can't get accurate stack usage under this situation, and we rely
+ # on user-provided annotations to break the cycle, after which the
+ # result will be accurate again.
+ (sub_stack_usage, _) = done_result
+
+ if callsite.is_tail:
+ # For tailing call, since the callee reuses the stack frame of the
+ # caller, choose the larger one directly.
+ stack_usage = max(curr_func.stack_frame, sub_stack_usage)
+ else:
+ stack_usage = curr_func.stack_frame + sub_stack_usage
+
+ if stack_usage > max_stack_usage:
+ max_stack_usage = stack_usage
+ max_callee_state = callee_state
+
+ if scc_lowlink == scc_index:
+ group = []
+ while scc_stack[-1] != curr_state:
+ scc_state = scc_stack.pop()
+ stacked_states.remove(scc_state)
+ group.append(scc_state)
+
+ scc_stack.pop()
+ stacked_states.remove(curr_state)
+
+ # If the cycle is not empty, record it.
+ if len(group) > 0 or self_loop:
+ group.append(curr_state)
+ cycle_groups.append(group)
+
+ # Store the done result.
+ done_states[curr_state] = (max_stack_usage, max_callee_state)
+
+ if curr_positions == initial_positions:
+ # If the current state is initial state, we traversed the callgraph by
+ # using the current function as start point. Update the stack usage of
+ # the function.
+ # If the function matches a single vertex remove path, this will set its
+ # max stack usage to 0, which is not expected (we still calculate its
+ # max stack usage, but prevent any function from calling it). However,
+ # all the single vertex remove paths have been preprocessed and removed.
+ curr_func.stack_max_usage = max_stack_usage
+
+ # Reconstruct the max stack path by traversing the state transitions.
+ max_stack_path = [curr_func]
+ callee_state = max_callee_state
+ while callee_state is not None:
+ # The first element of state tuple is function address.
+ max_stack_path.append(function_map[callee_state[0]])
+ done_result = done_states.get(callee_state)
+ # All of the descendants should be done.
+ assert done_result is not None
+ (_, callee_state) = done_result
+
+ curr_func.stack_max_path = max_stack_path
+
+ return scc_lowlink
+
+ # The state is the concatenation of the current function address and the
+ # state of matching position.
+ initial_positions = (0,) * len(remove_list)
+ done_states = {}
+ stacked_states = set()
+ scc_index_counter = [0]
+ scc_index_map = {}
+ scc_stack = []
+ cycle_groups = []
+ for function in function_map.values():
+ if function.stack_max_usage is None:
+ Traverse((function.address, initial_positions))
+
+ cycle_functions = []
+ for group in cycle_groups:
+ cycle = set(function_map[state[0]] for state in group)
+ if cycle not in cycle_functions:
+ cycle_functions.append(cycle)
+
+ return cycle_functions
+
+ def Analyze(self):
+ """Run the stack analysis.
+
+ Raises:
+ StackAnalyzerError: If disassembly fails.
+ """
+
+ def OutputInlineStack(address, prefix=""):
+ """Output beautiful inline stack.
+
+ Args:
+ address: Address.
+ prefix: Prefix of each line.
+
+ Returns:
+ Key for sorting, output text
+ """
+ line_infos = self.AddressToLine(address, True)
+
+ if line_infos[0] is None:
+ order_key = (None, None)
else:
- add_rules[src_sig].add(dst_sig)
-
- if 'remove' in self.annotation and self.annotation['remove'] is not None:
- for sigtxt_path in self.annotation['remove']:
- if isinstance(sigtxt_path, str):
- # The path has only one vertex.
- sigtxt_path = [sigtxt_path]
-
- if len(sigtxt_path) == 0:
- continue
-
- # Generate multiple remove paths from all the combinations of the
- # signatures of each vertex.
- sig_paths = [[]]
- broken_flag = False
- for sigtxt_node in sigtxt_path:
- if isinstance(sigtxt_node, str):
- # The vertex has only one signature.
- sigtxt_set = {sigtxt_node}
- elif isinstance(sigtxt_node, list):
- # The vertex has multiple signatures.
- sigtxt_set = set(sigtxt_node)
- else:
- # Assume the format of annotation is verified. There should be no
- # invalid case.
- assert False
-
- sig_set = set()
- for sigtxt in sigtxt_set:
- sig = NormalizeSignature(sigtxt)
- if sig is None:
- invalid_sigtxts.add(sigtxt)
- broken_flag = True
- elif not broken_flag:
- sig_set.add(sig)
-
- if broken_flag:
- continue
-
- # Append each signature of the current node to the all previous
- # remove paths.
- sig_paths = [path + [sig] for path in sig_paths for sig in sig_set]
-
- if not broken_flag:
- # All signatures are normalized. The remove path has no error.
- remove_rules.extend(sig_paths)
-
- return (add_rules, remove_rules, invalid_sigtxts)
+ (_, path, linenum) = line_infos[0]
+ order_key = (linenum, path)
+
+ line_texts = []
+ for line_info in reversed(line_infos):
+ if line_info is None:
+ (function_name, path, linenum) = ("??", "??", 0)
+ else:
+ (function_name, path, linenum) = line_info
+
+ line_texts.append(
+ "{}[{}:{}]".format(function_name, os.path.relpath(path), linenum)
+ )
+
+ output = "{}-> {} {:x}\n".format(prefix, line_texts[0], address)
+ for depth, line_text in enumerate(line_texts[1:]):
+ output += "{} {}- {}\n".format(prefix, " " * depth, line_text)
+
+ # Remove the last newline character.
+ return (order_key, output.rstrip("\n"))
+
+ # Analyze disassembly.
+ try:
+ disasm_text = subprocess.check_output(
+ [self.options.objdump, "-d", self.options.elf_path], encoding="utf-8"
+ )
+ except subprocess.CalledProcessError:
+ raise StackAnalyzerError("objdump failed to disassemble.")
+ except OSError:
+ raise StackAnalyzerError("Failed to run objdump.")
+
+ function_map = self.AnalyzeDisassembly(disasm_text)
+ result = self.ResolveAnnotation(function_map)
+ (add_set, remove_list, eliminated_addrs, failed_sigtxts) = result
+ remove_list = self.PreprocessAnnotation(
+ function_map, add_set, remove_list, eliminated_addrs
+ )
+ cycle_functions = self.AnalyzeCallGraph(function_map, remove_list)
+
+ # Print the results of task-aware stack analysis.
+ extra_stack_frame = self.annotation.get(
+ "exception_frame_size", DEFAULT_EXCEPTION_FRAME_SIZE
+ )
+ for task in self.tasklist:
+ routine_func = function_map[task.routine_address]
+ print(
+ "Task: {}, Max size: {} ({} + {}), Allocated size: {}".format(
+ task.name,
+ routine_func.stack_max_usage + extra_stack_frame,
+ routine_func.stack_max_usage,
+ extra_stack_frame,
+ task.stack_max_size,
+ )
+ )
+
+ print("Call Trace:")
+ max_stack_path = routine_func.stack_max_path
+ # Assume the routine function is resolved.
+ assert max_stack_path is not None
+ for depth, curr_func in enumerate(max_stack_path):
+ line_info = self.AddressToLine(curr_func.address)[0]
+ if line_info is None:
+ (path, linenum) = ("??", 0)
+ else:
+ (_, path, linenum) = line_info
+
+ print(
+ " {} ({}) [{}:{}] {:x}".format(
+ curr_func.name,
+ curr_func.stack_frame,
+ os.path.relpath(path),
+ linenum,
+ curr_func.address,
+ )
+ )
+
+ if depth + 1 < len(max_stack_path):
+ succ_func = max_stack_path[depth + 1]
+ text_list = []
+ for callsite in curr_func.callsites:
+ if callsite.callee is succ_func:
+ indent_prefix = " "
+ if callsite.address is None:
+ order_text = (
+ None,
+ "{}-> [annotation]".format(indent_prefix),
+ )
+ else:
+ order_text = OutputInlineStack(
+ callsite.address, indent_prefix
+ )
+
+ text_list.append(order_text)
+
+ for _, text in sorted(text_list, key=lambda item: item[0]):
+ print(text)
+
+ print("Unresolved indirect callsites:")
+ for function in function_map.values():
+ indirect_callsites = []
+ for callsite in function.callsites:
+ if callsite.target is None:
+ indirect_callsites.append(callsite.address)
+
+ if len(indirect_callsites) > 0:
+ print(" In function {}:".format(function.name))
+ text_list = []
+ for address in indirect_callsites:
+ text_list.append(OutputInlineStack(address, " "))
+
+ for _, text in sorted(text_list, key=lambda item: item[0]):
+ print(text)
+
+ print("Unresolved annotation signatures:")
+ for sigtxt, error in failed_sigtxts:
+ print(" {}: {}".format(sigtxt, error))
+
+ if len(cycle_functions) > 0:
+ print("There are cycles in the following function sets:")
+ for functions in cycle_functions:
+ print("[{}]".format(", ".join(function.name for function in functions)))
- def ResolveAnnotation(self, function_map):
- """Resolve annotation.
- Args:
- function_map: Function map.
+def ParseArgs():
+ """Parse commandline arguments.
Returns:
- Set of added call edges, list of remove paths, set of eliminated
- callsite addresses, set of annotation signatures which can't be resolved.
+ options: Namespace from argparse.parse_args().
"""
- def StringifySignature(signature):
- """Stringify the tupled signature.
-
- Args:
- signature: Tupled signature.
-
- Returns:
- Signature string.
- """
- (name, path, linenum) = signature
- bracket_text = ''
- if path is not None:
- path = os.path.relpath(path)
- if linenum is None:
- bracket_text = '[{}]'.format(path)
- else:
- bracket_text = '[{}:{}]'.format(path, linenum)
-
- return name + bracket_text
-
- (add_rules, remove_rules, invalid_sigtxts) = self.LoadAnnotation()
-
- signature_set = set()
- for src_sig, dst_sigs in add_rules.items():
- signature_set.add(src_sig)
- signature_set.update(dst_sigs)
-
- for remove_sigs in remove_rules:
- signature_set.update(remove_sigs)
-
- # Map signatures to functions.
- (signature_map, sig_error_map) = self.MapAnnotation(function_map,
- signature_set)
-
- # Build the indirect callsite map indexed by callsite signature.
- indirect_map = collections.defaultdict(set)
- for function in function_map.values():
- for callsite in function.callsites:
- if callsite.target is not None:
- continue
-
- # Found an indirect callsite.
- line_info = self.AddressToLine(callsite.address)[0]
- if line_info is None:
- continue
-
- (name, path, linenum) = line_info
- result = self.FUNCTION_PREFIX_NAME_RE.match(name)
- if result is None:
- continue
-
- indirect_map[(result.group('name').strip(), path, linenum)].add(
- (function, callsite.address))
-
- # Generate the annotation sets.
- add_set = set()
- remove_list = list()
- eliminated_addrs = set()
-
- for src_sig, dst_sigs in add_rules.items():
- src_funcs = set(signature_map.get(src_sig, []))
- # Try to match the source signature to the indirect callsites. Even if it
- # can't be found in disassembly.
- indirect_calls = indirect_map.get(src_sig)
- if indirect_calls is not None:
- for function, callsite_address in indirect_calls:
- # Add the caller of the indirect callsite to the source functions.
- src_funcs.add(function)
- # Assume each callsite can be represented by a unique address.
- eliminated_addrs.add(callsite_address)
-
- if src_sig in sig_error_map:
- # Assume the error is always the not found error. Since the signature
- # found in indirect callsite map must be a full signature, it can't
- # happen the ambiguous error.
- assert sig_error_map[src_sig] == self.ANNOTATION_ERROR_NOTFOUND
- # Found in inline stack, remove the not found error.
- del sig_error_map[src_sig]
-
- for dst_sig in dst_sigs:
- dst_funcs = signature_map.get(dst_sig)
- if dst_funcs is None:
- continue
-
- # Duplicate the call edge for all the same source and destination
- # functions.
- for src_func in src_funcs:
- for dst_func in dst_funcs:
- add_set.add((src_func, dst_func))
-
- for remove_sigs in remove_rules:
- # Since each signature can be mapped to multiple functions, generate
- # multiple remove paths from all the combinations of these functions.
- remove_paths = [[]]
- skip_flag = False
- for remove_sig in remove_sigs:
- # Transform each signature to the corresponding functions.
- remove_funcs = signature_map.get(remove_sig)
- if remove_funcs is None:
- # There is an unresolved signature in the remove path. Ignore the
- # whole broken remove path.
- skip_flag = True
- break
- else:
- # Append each function of the current signature to the all previous
- # remove paths.
- remove_paths = [p + [f] for p in remove_paths for f in remove_funcs]
-
- if skip_flag:
- # Ignore the broken remove path.
- continue
-
- for remove_path in remove_paths:
- # Deduplicate the remove paths.
- if remove_path not in remove_list:
- remove_list.append(remove_path)
-
- # Format the error messages.
- failed_sigtxts = set()
- for sigtxt in invalid_sigtxts:
- failed_sigtxts.add((sigtxt, self.ANNOTATION_ERROR_INVALID))
-
- for sig, error in sig_error_map.items():
- failed_sigtxts.add((StringifySignature(sig), error))
-
- return (add_set, remove_list, eliminated_addrs, failed_sigtxts)
-
- def PreprocessAnnotation(self, function_map, add_set, remove_list,
- eliminated_addrs):
- """Preprocess the annotation and callgraph.
+ parser = argparse.ArgumentParser(description="EC firmware stack analyzer.")
+ parser.add_argument("elf_path", help="the path of EC firmware ELF")
+ parser.add_argument(
+ "--export_taskinfo",
+ required=True,
+ help="the path of export_taskinfo.so utility",
+ )
+ parser.add_argument(
+ "--section",
+ required=True,
+ help="the section.",
+ choices=[SECTION_RO, SECTION_RW],
+ )
+ parser.add_argument("--objdump", default="objdump", help="the path of objdump")
+ parser.add_argument(
+ "--addr2line", default="addr2line", help="the path of addr2line"
+ )
+ parser.add_argument(
+ "--annotation", default=None, help="the path of annotation file"
+ )
+
+ # TODO(cheyuw): Add an option for dumping stack usage of all functions.
+
+ return parser.parse_args()
- Add the missing call edges, and delete simple remove paths (the paths have
- one or two vertices) from the function_map.
- Eliminate the annotated indirect callsites.
-
- Return the remaining remove list.
+def ParseSymbolText(symbol_text):
+ """Parse the content of the symbol text.
Args:
- function_map: Function map.
- add_set: Set of missing call edges.
- remove_list: List of remove paths.
- eliminated_addrs: Set of eliminated callsite addresses.
+ symbol_text: Text of the symbols.
Returns:
- List of remaining remove paths.
+ symbols: Symbol list.
"""
- def CheckEdge(path):
- """Check if all edges of the path are on the callgraph.
-
- Args:
- path: Path.
-
- Returns:
- True or False.
- """
- for index in range(len(path) - 1):
- if (path[index], path[index + 1]) not in edge_set:
- return False
-
- return True
-
- for src_func, dst_func in add_set:
- # TODO(cheyuw): Support tailing call annotation.
- src_func.callsites.append(
- Callsite(None, dst_func.address, False, dst_func))
-
- # Delete simple remove paths.
- remove_simple = set(tuple(p) for p in remove_list if len(p) <= 2)
- edge_set = set()
- for function in function_map.values():
- cleaned_callsites = []
- for callsite in function.callsites:
- if ((callsite.callee,) in remove_simple or
- (function, callsite.callee) in remove_simple):
- continue
-
- if callsite.target is None and callsite.address in eliminated_addrs:
- continue
-
- cleaned_callsites.append(callsite)
- if callsite.callee is not None:
- edge_set.add((function, callsite.callee))
+ # Example: "10093064 g F .text 0000015c .hidden hook_task"
+ symbol_regex = re.compile(
+ r"^(?P<address>[0-9A-Fa-f]+)\s+[lwg]\s+"
+ r"((?P<type>[OF])\s+)?\S+\s+"
+ r"(?P<size>[0-9A-Fa-f]+)\s+"
+ r"(\S+\s+)?(?P<name>\S+)$"
+ )
+
+ symbols = []
+ for line in symbol_text.splitlines():
+ line = line.strip()
+ result = symbol_regex.match(line)
+ if result is not None:
+ address = int(result.group("address"), 16)
+ symtype = result.group("type")
+ if symtype is None:
+ symtype = "O"
- function.callsites = cleaned_callsites
+ size = int(result.group("size"), 16)
+ name = result.group("name")
+ symbols.append(Symbol(address, symtype, size, name))
- return [p for p in remove_list if len(p) >= 3 and CheckEdge(p)]
+ return symbols
- def AnalyzeCallGraph(self, function_map, remove_list):
- """Analyze callgraph.
- It will update the max stack size and path for each function.
+def ParseRoDataText(rodata_text):
+    """Parse the content of rodata.
Args:
- function_map: Function map.
- remove_list: List of remove paths.
+        rodata_text: Text of the rodata dump.
Returns:
- List of function cycles.
+ symbols: Symbol list.
"""
- def Traverse(curr_state):
- """Traverse the callgraph and calculate the max stack usages of functions.
-
- Args:
- curr_state: Current state.
-
- Returns:
- SCC lowest link.
- """
- scc_index = scc_index_counter[0]
- scc_index_counter[0] += 1
- scc_index_map[curr_state] = scc_index
- scc_lowlink = scc_index
- scc_stack.append(curr_state)
- # Push the current state in the stack. We can use a set to maintain this
- # because the stacked states are unique; otherwise we will find a cycle
- # first.
- stacked_states.add(curr_state)
-
- (curr_address, curr_positions) = curr_state
- curr_func = function_map[curr_address]
-
- invalid_flag = False
- new_positions = list(curr_positions)
- for index, position in enumerate(curr_positions):
- remove_path = remove_list[index]
-
- # The position of each remove path in the state is the length of the
- # longest matching path between the prefix of the remove path and the
- # suffix of the current traversing path. We maintain this length when
- # appending the next callee to the traversing path. And it can be used
- # to check if the remove path appears in the traversing path.
-
- # TODO(cheyuw): Implement KMP algorithm to match remove paths
- # efficiently.
- if remove_path[position] is curr_func:
- # Matches the current function, extend the length.
- new_positions[index] = position + 1
- if new_positions[index] == len(remove_path):
- # The length of the longest matching path is equal to the length of
- # the remove path, which means the suffix of the current traversing
- # path matches the remove path.
- invalid_flag = True
- break
-
- else:
- # We can't get the new longest matching path by extending the previous
- # one directly. Fallback to search the new longest matching path.
-
- # If we can't find any matching path in the following search, reset
- # the matching length to 0.
- new_positions[index] = 0
-
- # We want to find the new longest matching prefix of remove path with
- # the suffix of the current traversing path. Because the new longest
- # matching path won't be longer than the prevous one now, and part of
- # the suffix matches the prefix of remove path, we can get the needed
- # suffix from the previous matching prefix of the invalid path.
- suffix = remove_path[:position] + [curr_func]
- for offset in range(1, len(suffix)):
- length = position - offset
- if remove_path[:length] == suffix[offset:]:
- new_positions[index] = length
- break
-
- new_positions = tuple(new_positions)
-
- # If the current suffix is invalid, set the max stack usage to 0.
- max_stack_usage = 0
- max_callee_state = None
- self_loop = False
-
- if not invalid_flag:
- # Max stack usage is at least equal to the stack frame.
- max_stack_usage = curr_func.stack_frame
- for callsite in curr_func.callsites:
- callee = callsite.callee
- if callee is None:
+ # Examples: 8018ab0 00040048 00010000 10020000 4b8e0108 ...H........K...
+ # 100a7294 00000000 00000000 01000000 ............
+
+ base_offset = None
+ offset = None
+ rodata = []
+ for line in rodata_text.splitlines():
+ line = line.strip()
+ space = line.find(" ")
+ if space < 0:
+ continue
+ try:
+ address = int(line[0:space], 16)
+ except ValueError:
continue
- callee_state = (callee.address, new_positions)
- if callee_state not in scc_index_map:
- # Unvisited state.
- scc_lowlink = min(scc_lowlink, Traverse(callee_state))
- elif callee_state in stacked_states:
- # The state is shown in the stack. There is a cycle.
- sub_stack_usage = 0
- scc_lowlink = min(scc_lowlink, scc_index_map[callee_state])
- if callee_state == curr_state:
- self_loop = True
-
- done_result = done_states.get(callee_state)
- if done_result is not None:
- # Already done this state and use its result. If the state reaches a
- # cycle, reusing the result will cause inaccuracy (the stack usage
- # of cycle depends on where the entrance is). But it's fine since we
- # can't get accurate stack usage under this situation, and we rely
- # on user-provided annotations to break the cycle, after which the
- # result will be accurate again.
- (sub_stack_usage, _) = done_result
-
- if callsite.is_tail:
- # For tailing call, since the callee reuses the stack frame of the
- # caller, choose the larger one directly.
- stack_usage = max(curr_func.stack_frame, sub_stack_usage)
- else:
- stack_usage = curr_func.stack_frame + sub_stack_usage
-
- if stack_usage > max_stack_usage:
- max_stack_usage = stack_usage
- max_callee_state = callee_state
-
- if scc_lowlink == scc_index:
- group = []
- while scc_stack[-1] != curr_state:
- scc_state = scc_stack.pop()
- stacked_states.remove(scc_state)
- group.append(scc_state)
-
- scc_stack.pop()
- stacked_states.remove(curr_state)
-
- # If the cycle is not empty, record it.
- if len(group) > 0 or self_loop:
- group.append(curr_state)
- cycle_groups.append(group)
-
- # Store the done result.
- done_states[curr_state] = (max_stack_usage, max_callee_state)
-
- if curr_positions == initial_positions:
- # If the current state is initial state, we traversed the callgraph by
- # using the current function as start point. Update the stack usage of
- # the function.
- # If the function matches a single vertex remove path, this will set its
- # max stack usage to 0, which is not expected (we still calculate its
- # max stack usage, but prevent any function from calling it). However,
- # all the single vertex remove paths have been preprocessed and removed.
- curr_func.stack_max_usage = max_stack_usage
-
- # Reconstruct the max stack path by traversing the state transitions.
- max_stack_path = [curr_func]
- callee_state = max_callee_state
- while callee_state is not None:
- # The first element of state tuple is function address.
- max_stack_path.append(function_map[callee_state[0]])
- done_result = done_states.get(callee_state)
- # All of the descendants should be done.
- assert done_result is not None
- (_, callee_state) = done_result
-
- curr_func.stack_max_path = max_stack_path
-
- return scc_lowlink
-
- # The state is the concatenation of the current function address and the
- # state of matching position.
- initial_positions = (0,) * len(remove_list)
- done_states = {}
- stacked_states = set()
- scc_index_counter = [0]
- scc_index_map = {}
- scc_stack = []
- cycle_groups = []
- for function in function_map.values():
- if function.stack_max_usage is None:
- Traverse((function.address, initial_positions))
-
- cycle_functions = []
- for group in cycle_groups:
- cycle = set(function_map[state[0]] for state in group)
- if cycle not in cycle_functions:
- cycle_functions.append(cycle)
-
- return cycle_functions
-
- def Analyze(self):
- """Run the stack analysis.
-
- Raises:
- StackAnalyzerError: If disassembly fails.
- """
- def OutputInlineStack(address, prefix=''):
- """Output beautiful inline stack.
-
- Args:
- address: Address.
- prefix: Prefix of each line.
-
- Returns:
- Key for sorting, output text
- """
- line_infos = self.AddressToLine(address, True)
-
- if line_infos[0] is None:
- order_key = (None, None)
- else:
- (_, path, linenum) = line_infos[0]
- order_key = (linenum, path)
-
- line_texts = []
- for line_info in reversed(line_infos):
- if line_info is None:
- (function_name, path, linenum) = ('??', '??', 0)
- else:
- (function_name, path, linenum) = line_info
-
- line_texts.append('{}[{}:{}]'.format(function_name,
- os.path.relpath(path),
- linenum))
-
- output = '{}-> {} {:x}\n'.format(prefix, line_texts[0], address)
- for depth, line_text in enumerate(line_texts[1:]):
- output += '{} {}- {}\n'.format(prefix, ' ' * depth, line_text)
-
- # Remove the last newline character.
- return (order_key, output.rstrip('\n'))
-
- # Analyze disassembly.
- try:
- disasm_text = subprocess.check_output([self.options.objdump,
- '-d',
- self.options.elf_path],
- encoding='utf-8')
- except subprocess.CalledProcessError:
- raise StackAnalyzerError('objdump failed to disassemble.')
- except OSError:
- raise StackAnalyzerError('Failed to run objdump.')
-
- function_map = self.AnalyzeDisassembly(disasm_text)
- result = self.ResolveAnnotation(function_map)
- (add_set, remove_list, eliminated_addrs, failed_sigtxts) = result
- remove_list = self.PreprocessAnnotation(function_map,
- add_set,
- remove_list,
- eliminated_addrs)
- cycle_functions = self.AnalyzeCallGraph(function_map, remove_list)
-
- # Print the results of task-aware stack analysis.
- extra_stack_frame = self.annotation.get('exception_frame_size',
- DEFAULT_EXCEPTION_FRAME_SIZE)
- for task in self.tasklist:
- routine_func = function_map[task.routine_address]
- print('Task: {}, Max size: {} ({} + {}), Allocated size: {}'.format(
- task.name,
- routine_func.stack_max_usage + extra_stack_frame,
- routine_func.stack_max_usage,
- extra_stack_frame,
- task.stack_max_size))
-
- print('Call Trace:')
- max_stack_path = routine_func.stack_max_path
- # Assume the routine function is resolved.
- assert max_stack_path is not None
- for depth, curr_func in enumerate(max_stack_path):
- line_info = self.AddressToLine(curr_func.address)[0]
- if line_info is None:
- (path, linenum) = ('??', 0)
- else:
- (_, path, linenum) = line_info
-
- print(' {} ({}) [{}:{}] {:x}'.format(curr_func.name,
- curr_func.stack_frame,
- os.path.relpath(path),
- linenum,
- curr_func.address))
-
- if depth + 1 < len(max_stack_path):
- succ_func = max_stack_path[depth + 1]
- text_list = []
- for callsite in curr_func.callsites:
- if callsite.callee is succ_func:
- indent_prefix = ' '
- if callsite.address is None:
- order_text = (None, '{}-> [annotation]'.format(indent_prefix))
- else:
- order_text = OutputInlineStack(callsite.address, indent_prefix)
-
- text_list.append(order_text)
-
- for _, text in sorted(text_list, key=lambda item: item[0]):
- print(text)
-
- print('Unresolved indirect callsites:')
- for function in function_map.values():
- indirect_callsites = []
- for callsite in function.callsites:
- if callsite.target is None:
- indirect_callsites.append(callsite.address)
-
- if len(indirect_callsites) > 0:
- print(' In function {}:'.format(function.name))
- text_list = []
- for address in indirect_callsites:
- text_list.append(OutputInlineStack(address, ' '))
-
- for _, text in sorted(text_list, key=lambda item: item[0]):
- print(text)
-
- print('Unresolved annotation signatures:')
- for sigtxt, error in failed_sigtxts:
- print(' {}: {}'.format(sigtxt, error))
-
- if len(cycle_functions) > 0:
- print('There are cycles in the following function sets:')
- for functions in cycle_functions:
- print('[{}]'.format(', '.join(function.name for function in functions)))
-
-
-def ParseArgs():
- """Parse commandline arguments.
-
- Returns:
- options: Namespace from argparse.parse_args().
- """
- parser = argparse.ArgumentParser(description="EC firmware stack analyzer.")
- parser.add_argument('elf_path', help="the path of EC firmware ELF")
- parser.add_argument('--export_taskinfo', required=True,
- help="the path of export_taskinfo.so utility")
- parser.add_argument('--section', required=True, help='the section.',
- choices=[SECTION_RO, SECTION_RW])
- parser.add_argument('--objdump', default='objdump',
- help='the path of objdump')
- parser.add_argument('--addr2line', default='addr2line',
- help='the path of addr2line')
- parser.add_argument('--annotation', default=None,
- help='the path of annotation file')
-
- # TODO(cheyuw): Add an option for dumping stack usage of all functions.
-
- return parser.parse_args()
-
-
-def ParseSymbolText(symbol_text):
- """Parse the content of the symbol text.
-
- Args:
- symbol_text: Text of the symbols.
-
- Returns:
- symbols: Symbol list.
- """
- # Example: "10093064 g F .text 0000015c .hidden hook_task"
- symbol_regex = re.compile(r'^(?P<address>[0-9A-Fa-f]+)\s+[lwg]\s+'
- r'((?P<type>[OF])\s+)?\S+\s+'
- r'(?P<size>[0-9A-Fa-f]+)\s+'
- r'(\S+\s+)?(?P<name>\S+)$')
-
- symbols = []
- for line in symbol_text.splitlines():
- line = line.strip()
- result = symbol_regex.match(line)
- if result is not None:
- address = int(result.group('address'), 16)
- symtype = result.group('type')
- if symtype is None:
- symtype = 'O'
-
- size = int(result.group('size'), 16)
- name = result.group('name')
- symbols.append(Symbol(address, symtype, size, name))
-
- return symbols
-
-
-def ParseRoDataText(rodata_text):
- """Parse the content of rodata
-
- Args:
- symbol_text: Text of the rodata dump.
-
- Returns:
- symbols: Symbol list.
- """
- # Examples: 8018ab0 00040048 00010000 10020000 4b8e0108 ...H........K...
- # 100a7294 00000000 00000000 01000000 ............
-
- base_offset = None
- offset = None
- rodata = []
- for line in rodata_text.splitlines():
- line = line.strip()
- space = line.find(' ')
- if space < 0:
- continue
- try:
- address = int(line[0:space], 16)
- except ValueError:
- continue
-
- if not base_offset:
- base_offset = address
- offset = address
- elif address != offset:
- raise StackAnalyzerError('objdump of rodata not contiguous.')
+ if not base_offset:
+ base_offset = address
+ offset = address
+ elif address != offset:
+ raise StackAnalyzerError("objdump of rodata not contiguous.")
- for i in range(0, 4):
- num = line[(space + 1 + i*9):(space + 9 + i*9)]
- if len(num.strip()) > 0:
- val = int(num, 16)
- else:
- val = 0
- # TODO(drinkcat): Not all platforms are necessarily big-endian
- rodata.append((val & 0x000000ff) << 24 |
- (val & 0x0000ff00) << 8 |
- (val & 0x00ff0000) >> 8 |
- (val & 0xff000000) >> 24)
+ for i in range(0, 4):
+ num = line[(space + 1 + i * 9) : (space + 9 + i * 9)]
+ if len(num.strip()) > 0:
+ val = int(num, 16)
+ else:
+ val = 0
+ # TODO(drinkcat): Not all platforms are necessarily big-endian
+ rodata.append(
+ (val & 0x000000FF) << 24
+ | (val & 0x0000FF00) << 8
+ | (val & 0x00FF0000) >> 8
+ | (val & 0xFF000000) >> 24
+ )
- offset = offset + 4*4
+ offset = offset + 4 * 4
- return (base_offset, rodata)
+ return (base_offset, rodata)
def LoadTasklist(section, export_taskinfo, symbols):
- """Load the task information.
+ """Load the task information.
- Args:
- section: Section (RO | RW).
- export_taskinfo: Handle of export_taskinfo.so.
- symbols: Symbol list.
+ Args:
+ section: Section (RO | RW).
+ export_taskinfo: Handle of export_taskinfo.so.
+ symbols: Symbol list.
- Returns:
- tasklist: Task list.
- """
+ Returns:
+ tasklist: Task list.
+ """
- TaskInfoPointer = ctypes.POINTER(TaskInfo)
- taskinfos = TaskInfoPointer()
- if section == SECTION_RO:
- get_taskinfos_func = export_taskinfo.get_ro_taskinfos
- else:
- get_taskinfos_func = export_taskinfo.get_rw_taskinfos
+ TaskInfoPointer = ctypes.POINTER(TaskInfo)
+ taskinfos = TaskInfoPointer()
+ if section == SECTION_RO:
+ get_taskinfos_func = export_taskinfo.get_ro_taskinfos
+ else:
+ get_taskinfos_func = export_taskinfo.get_rw_taskinfos
- taskinfo_num = get_taskinfos_func(ctypes.pointer(taskinfos))
+ taskinfo_num = get_taskinfos_func(ctypes.pointer(taskinfos))
- tasklist = []
- for index in range(taskinfo_num):
- taskinfo = taskinfos[index]
- tasklist.append(Task(taskinfo.name.decode('utf-8'),
- taskinfo.routine.decode('utf-8'),
- taskinfo.stack_size))
+ tasklist = []
+ for index in range(taskinfo_num):
+ taskinfo = taskinfos[index]
+ tasklist.append(
+ Task(
+ taskinfo.name.decode("utf-8"),
+ taskinfo.routine.decode("utf-8"),
+ taskinfo.stack_size,
+ )
+ )
- # Resolve routine address for each task. It's more efficient to resolve all
- # routine addresses of tasks together.
- routine_map = dict((task.routine_name, None) for task in tasklist)
+ # Resolve routine address for each task. It's more efficient to resolve all
+ # routine addresses of tasks together.
+ routine_map = dict((task.routine_name, None) for task in tasklist)
- for symbol in symbols:
- # Resolve task routine address.
- if symbol.name in routine_map:
- # Assume the symbol of routine is unique.
- assert routine_map[symbol.name] is None
- routine_map[symbol.name] = symbol.address
+ for symbol in symbols:
+ # Resolve task routine address.
+ if symbol.name in routine_map:
+ # Assume the symbol of routine is unique.
+ assert routine_map[symbol.name] is None
+ routine_map[symbol.name] = symbol.address
- for task in tasklist:
- address = routine_map[task.routine_name]
- # Assume we have resolved all routine addresses.
- assert address is not None
- task.routine_address = address
+ for task in tasklist:
+ address = routine_map[task.routine_name]
+ # Assume we have resolved all routine addresses.
+ assert address is not None
+ task.routine_address = address
- return tasklist
+ return tasklist
def main():
- """Main function."""
- try:
- options = ParseArgs()
-
- # Load annotation config.
- if options.annotation is None:
- annotation = {}
- elif not os.path.exists(options.annotation):
- print('Warning: Annotation file {} does not exist.'
- .format(options.annotation))
- annotation = {}
- else:
- try:
- with open(options.annotation, 'r') as annotation_file:
- annotation = yaml.safe_load(annotation_file)
-
- except yaml.YAMLError:
- raise StackAnalyzerError('Failed to parse annotation file {}.'
- .format(options.annotation))
- except IOError:
- raise StackAnalyzerError('Failed to open annotation file {}.'
- .format(options.annotation))
-
- # TODO(cheyuw): Do complete annotation format verification.
- if not isinstance(annotation, dict):
- raise StackAnalyzerError('Invalid annotation file {}.'
- .format(options.annotation))
-
- # Generate and parse the symbols.
+ """Main function."""
try:
- symbol_text = subprocess.check_output([options.objdump,
- '-t',
- options.elf_path],
- encoding='utf-8')
- rodata_text = subprocess.check_output([options.objdump,
- '-s',
- '-j', '.rodata',
- options.elf_path],
- encoding='utf-8')
- except subprocess.CalledProcessError:
- raise StackAnalyzerError('objdump failed to dump symbol table or rodata.')
- except OSError:
- raise StackAnalyzerError('Failed to run objdump.')
-
- symbols = ParseSymbolText(symbol_text)
- rodata = ParseRoDataText(rodata_text)
-
- # Load the tasklist.
- try:
- export_taskinfo = ctypes.CDLL(options.export_taskinfo)
- except OSError:
- raise StackAnalyzerError('Failed to load export_taskinfo.')
-
- tasklist = LoadTasklist(options.section, export_taskinfo, symbols)
-
- analyzer = StackAnalyzer(options, symbols, rodata, tasklist, annotation)
- analyzer.Analyze()
- except StackAnalyzerError as e:
- print('Error: {}'.format(e))
-
-
-if __name__ == '__main__':
- main()
+ options = ParseArgs()
+
+ # Load annotation config.
+ if options.annotation is None:
+ annotation = {}
+ elif not os.path.exists(options.annotation):
+ print(
+ "Warning: Annotation file {} does not exist.".format(options.annotation)
+ )
+ annotation = {}
+ else:
+ try:
+ with open(options.annotation, "r") as annotation_file:
+ annotation = yaml.safe_load(annotation_file)
+
+ except yaml.YAMLError:
+ raise StackAnalyzerError(
+ "Failed to parse annotation file {}.".format(options.annotation)
+ )
+ except IOError:
+ raise StackAnalyzerError(
+ "Failed to open annotation file {}.".format(options.annotation)
+ )
+
+ # TODO(cheyuw): Do complete annotation format verification.
+ if not isinstance(annotation, dict):
+ raise StackAnalyzerError(
+ "Invalid annotation file {}.".format(options.annotation)
+ )
+
+ # Generate and parse the symbols.
+ try:
+ symbol_text = subprocess.check_output(
+ [options.objdump, "-t", options.elf_path], encoding="utf-8"
+ )
+ rodata_text = subprocess.check_output(
+ [options.objdump, "-s", "-j", ".rodata", options.elf_path],
+ encoding="utf-8",
+ )
+ except subprocess.CalledProcessError:
+ raise StackAnalyzerError("objdump failed to dump symbol table or rodata.")
+ except OSError:
+ raise StackAnalyzerError("Failed to run objdump.")
+
+ symbols = ParseSymbolText(symbol_text)
+ rodata = ParseRoDataText(rodata_text)
+
+ # Load the tasklist.
+ try:
+ export_taskinfo = ctypes.CDLL(options.export_taskinfo)
+ except OSError:
+ raise StackAnalyzerError("Failed to load export_taskinfo.")
+
+ tasklist = LoadTasklist(options.section, export_taskinfo, symbols)
+
+ analyzer = StackAnalyzer(options, symbols, rodata, tasklist, annotation)
+ analyzer.Analyze()
+ except StackAnalyzerError as e:
+ print("Error: {}".format(e))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/extra/stack_analyzer/stack_analyzer_unittest.py b/extra/stack_analyzer/stack_analyzer_unittest.py
index c36fa9da45..ad2837a8a4 100755
--- a/extra/stack_analyzer/stack_analyzer_unittest.py
+++ b/extra/stack_analyzer/stack_analyzer_unittest.py
@@ -11,820 +11,924 @@
from __future__ import print_function
-import mock
import os
import subprocess
import unittest
+import mock
import stack_analyzer as sa
class ObjectTest(unittest.TestCase):
- """Tests for classes of basic objects."""
-
- def testTask(self):
- task_a = sa.Task('a', 'a_task', 1234)
- task_b = sa.Task('b', 'b_task', 5678, 0x1000)
- self.assertEqual(task_a, task_a)
- self.assertNotEqual(task_a, task_b)
- self.assertNotEqual(task_a, None)
-
- def testSymbol(self):
- symbol_a = sa.Symbol(0x1234, 'F', 32, 'a')
- symbol_b = sa.Symbol(0x234, 'O', 42, 'b')
- self.assertEqual(symbol_a, symbol_a)
- self.assertNotEqual(symbol_a, symbol_b)
- self.assertNotEqual(symbol_a, None)
-
- def testCallsite(self):
- callsite_a = sa.Callsite(0x1002, 0x3000, False)
- callsite_b = sa.Callsite(0x1002, 0x3000, True)
- self.assertEqual(callsite_a, callsite_a)
- self.assertNotEqual(callsite_a, callsite_b)
- self.assertNotEqual(callsite_a, None)
-
- def testFunction(self):
- func_a = sa.Function(0x100, 'a', 0, [])
- func_b = sa.Function(0x200, 'b', 0, [])
- self.assertEqual(func_a, func_a)
- self.assertNotEqual(func_a, func_b)
- self.assertNotEqual(func_a, None)
+ """Tests for classes of basic objects."""
+
+ def testTask(self):
+ task_a = sa.Task("a", "a_task", 1234)
+ task_b = sa.Task("b", "b_task", 5678, 0x1000)
+ self.assertEqual(task_a, task_a)
+ self.assertNotEqual(task_a, task_b)
+ self.assertNotEqual(task_a, None)
+
+ def testSymbol(self):
+ symbol_a = sa.Symbol(0x1234, "F", 32, "a")
+ symbol_b = sa.Symbol(0x234, "O", 42, "b")
+ self.assertEqual(symbol_a, symbol_a)
+ self.assertNotEqual(symbol_a, symbol_b)
+ self.assertNotEqual(symbol_a, None)
+
+ def testCallsite(self):
+ callsite_a = sa.Callsite(0x1002, 0x3000, False)
+ callsite_b = sa.Callsite(0x1002, 0x3000, True)
+ self.assertEqual(callsite_a, callsite_a)
+ self.assertNotEqual(callsite_a, callsite_b)
+ self.assertNotEqual(callsite_a, None)
+
+ def testFunction(self):
+ func_a = sa.Function(0x100, "a", 0, [])
+ func_b = sa.Function(0x200, "b", 0, [])
+ self.assertEqual(func_a, func_a)
+ self.assertNotEqual(func_a, func_b)
+ self.assertNotEqual(func_a, None)
class ArmAnalyzerTest(unittest.TestCase):
- """Tests for class ArmAnalyzer."""
-
- def AppendConditionCode(self, opcodes):
- rets = []
- for opcode in opcodes:
- rets.extend(opcode + cc for cc in sa.ArmAnalyzer.CONDITION_CODES)
-
- return rets
-
- def testInstructionMatching(self):
- jump_list = self.AppendConditionCode(['b', 'bx'])
- jump_list += (list(opcode + '.n' for opcode in jump_list) +
- list(opcode + '.w' for opcode in jump_list))
- for opcode in jump_list:
- self.assertIsNotNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match(opcode))
-
- self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match('bl'))
- self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match('blx'))
-
- cbz_list = ['cbz', 'cbnz', 'cbz.n', 'cbnz.n', 'cbz.w', 'cbnz.w']
- for opcode in cbz_list:
- self.assertIsNotNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match(opcode))
-
- self.assertIsNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match('cbn'))
-
- call_list = self.AppendConditionCode(['bl', 'blx'])
- call_list += list(opcode + '.n' for opcode in call_list)
- for opcode in call_list:
- self.assertIsNotNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match(opcode))
-
- self.assertIsNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match('ble'))
-
- result = sa.ArmAnalyzer.CALL_OPERAND_RE.match('53f90 <get_time+0x18>')
- self.assertIsNotNone(result)
- self.assertEqual(result.group(1), '53f90')
- self.assertEqual(result.group(2), 'get_time+0x18')
-
- result = sa.ArmAnalyzer.CBZ_CBNZ_OPERAND_RE.match('r6, 53f90 <get+0x0>')
- self.assertIsNotNone(result)
- self.assertEqual(result.group(1), '53f90')
- self.assertEqual(result.group(2), 'get+0x0')
-
- self.assertIsNotNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match('push'))
- self.assertIsNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match('pushal'))
- self.assertIsNotNone(sa.ArmAnalyzer.STM_OPCODE_RE.match('stmdb'))
- self.assertIsNone(sa.ArmAnalyzer.STM_OPCODE_RE.match('lstm'))
- self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('sub'))
- self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('subs'))
- self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('subw'))
- self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('sub.w'))
- self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('subs.w'))
-
- result = sa.ArmAnalyzer.SUB_OPERAND_RE.match('sp, sp, #1668 ; 0x684')
- self.assertIsNotNone(result)
- self.assertEqual(result.group(1), '1668')
- result = sa.ArmAnalyzer.SUB_OPERAND_RE.match('sp, #1668')
- self.assertIsNotNone(result)
- self.assertEqual(result.group(1), '1668')
- self.assertIsNone(sa.ArmAnalyzer.SUB_OPERAND_RE.match('sl, #1668'))
-
- def testAnalyzeFunction(self):
- analyzer = sa.ArmAnalyzer()
- symbol = sa.Symbol(0x10, 'F', 0x100, 'foo')
- instructions = [
- (0x10, 'push', '{r4, r5, r6, r7, lr}'),
- (0x12, 'subw', 'sp, sp, #16 ; 0x10'),
- (0x16, 'movs', 'lr, r1'),
- (0x18, 'beq.n', '26 <foo+0x26>'),
- (0x1a, 'bl', '30 <foo+0x30>'),
- (0x1e, 'bl', 'deadbeef <bar>'),
- (0x22, 'blx', '0 <woo>'),
- (0x26, 'push', '{r1}'),
- (0x28, 'stmdb', 'sp!, {r4, r5, r6, r7, r8, r9, lr}'),
- (0x2c, 'stmdb', 'sp!, {r4}'),
- (0x30, 'stmdb', 'sp, {r4}'),
- (0x34, 'bx.n', '10 <foo>'),
- (0x36, 'bx.n', 'r3'),
- (0x38, 'ldr', 'pc, [r10]'),
- ]
- (size, callsites) = analyzer.AnalyzeFunction(symbol, instructions)
- self.assertEqual(size, 72)
- expect_callsites = [sa.Callsite(0x1e, 0xdeadbeef, False),
- sa.Callsite(0x22, 0x0, False),
- sa.Callsite(0x34, 0x10, True),
- sa.Callsite(0x36, None, True),
- sa.Callsite(0x38, None, True)]
- self.assertEqual(callsites, expect_callsites)
+ """Tests for class ArmAnalyzer."""
+
+ def AppendConditionCode(self, opcodes):
+ rets = []
+ for opcode in opcodes:
+ rets.extend(opcode + cc for cc in sa.ArmAnalyzer.CONDITION_CODES)
+
+ return rets
+
+ def testInstructionMatching(self):
+ jump_list = self.AppendConditionCode(["b", "bx"])
+ jump_list += list(opcode + ".n" for opcode in jump_list) + list(
+ opcode + ".w" for opcode in jump_list
+ )
+ for opcode in jump_list:
+ self.assertIsNotNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match(opcode))
+
+ self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match("bl"))
+ self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match("blx"))
+
+ cbz_list = ["cbz", "cbnz", "cbz.n", "cbnz.n", "cbz.w", "cbnz.w"]
+ for opcode in cbz_list:
+ self.assertIsNotNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match(opcode))
+
+ self.assertIsNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match("cbn"))
+
+ call_list = self.AppendConditionCode(["bl", "blx"])
+ call_list += list(opcode + ".n" for opcode in call_list)
+ for opcode in call_list:
+ self.assertIsNotNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match(opcode))
+
+ self.assertIsNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match("ble"))
+
+ result = sa.ArmAnalyzer.CALL_OPERAND_RE.match("53f90 <get_time+0x18>")
+ self.assertIsNotNone(result)
+ self.assertEqual(result.group(1), "53f90")
+ self.assertEqual(result.group(2), "get_time+0x18")
+
+ result = sa.ArmAnalyzer.CBZ_CBNZ_OPERAND_RE.match("r6, 53f90 <get+0x0>")
+ self.assertIsNotNone(result)
+ self.assertEqual(result.group(1), "53f90")
+ self.assertEqual(result.group(2), "get+0x0")
+
+ self.assertIsNotNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match("push"))
+ self.assertIsNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match("pushal"))
+ self.assertIsNotNone(sa.ArmAnalyzer.STM_OPCODE_RE.match("stmdb"))
+ self.assertIsNone(sa.ArmAnalyzer.STM_OPCODE_RE.match("lstm"))
+ self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("sub"))
+ self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("subs"))
+ self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("subw"))
+ self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("sub.w"))
+ self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("subs.w"))
+
+ result = sa.ArmAnalyzer.SUB_OPERAND_RE.match("sp, sp, #1668 ; 0x684")
+ self.assertIsNotNone(result)
+ self.assertEqual(result.group(1), "1668")
+ result = sa.ArmAnalyzer.SUB_OPERAND_RE.match("sp, #1668")
+ self.assertIsNotNone(result)
+ self.assertEqual(result.group(1), "1668")
+ self.assertIsNone(sa.ArmAnalyzer.SUB_OPERAND_RE.match("sl, #1668"))
+
+ def testAnalyzeFunction(self):
+ analyzer = sa.ArmAnalyzer()
+ symbol = sa.Symbol(0x10, "F", 0x100, "foo")
+ instructions = [
+ (0x10, "push", "{r4, r5, r6, r7, lr}"),
+ (0x12, "subw", "sp, sp, #16 ; 0x10"),
+ (0x16, "movs", "lr, r1"),
+ (0x18, "beq.n", "26 <foo+0x26>"),
+ (0x1A, "bl", "30 <foo+0x30>"),
+ (0x1E, "bl", "deadbeef <bar>"),
+ (0x22, "blx", "0 <woo>"),
+ (0x26, "push", "{r1}"),
+ (0x28, "stmdb", "sp!, {r4, r5, r6, r7, r8, r9, lr}"),
+ (0x2C, "stmdb", "sp!, {r4}"),
+ (0x30, "stmdb", "sp, {r4}"),
+ (0x34, "bx.n", "10 <foo>"),
+ (0x36, "bx.n", "r3"),
+ (0x38, "ldr", "pc, [r10]"),
+ ]
+ (size, callsites) = analyzer.AnalyzeFunction(symbol, instructions)
+ self.assertEqual(size, 72)
+ expect_callsites = [
+ sa.Callsite(0x1E, 0xDEADBEEF, False),
+ sa.Callsite(0x22, 0x0, False),
+ sa.Callsite(0x34, 0x10, True),
+ sa.Callsite(0x36, None, True),
+ sa.Callsite(0x38, None, True),
+ ]
+ self.assertEqual(callsites, expect_callsites)
class StackAnalyzerTest(unittest.TestCase):
- """Tests for class StackAnalyzer."""
-
- def setUp(self):
- symbols = [sa.Symbol(0x1000, 'F', 0x15C, 'hook_task'),
- sa.Symbol(0x2000, 'F', 0x51C, 'console_task'),
- sa.Symbol(0x3200, 'O', 0x124, '__just_data'),
- sa.Symbol(0x4000, 'F', 0x11C, 'touchpad_calc'),
- sa.Symbol(0x5000, 'F', 0x12C, 'touchpad_calc.constprop.42'),
- sa.Symbol(0x12000, 'F', 0x13C, 'trackpad_range'),
- sa.Symbol(0x13000, 'F', 0x200, 'inlined_mul'),
- sa.Symbol(0x13100, 'F', 0x200, 'inlined_mul'),
- sa.Symbol(0x13100, 'F', 0x200, 'inlined_mul_alias'),
- sa.Symbol(0x20000, 'O', 0x0, '__array'),
- sa.Symbol(0x20010, 'O', 0x0, '__array_end'),
- ]
- tasklist = [sa.Task('HOOKS', 'hook_task', 2048, 0x1000),
- sa.Task('CONSOLE', 'console_task', 460, 0x2000)]
- # Array at 0x20000 that contains pointers to hook_task and console_task,
- # with stride=8, offset=4
- rodata = (0x20000, [ 0xDEAD1000, 0x00001000, 0xDEAD2000, 0x00002000 ])
- options = mock.MagicMock(elf_path='./ec.RW.elf',
- export_taskinfo='fake',
- section='RW',
- objdump='objdump',
- addr2line='addr2line',
- annotation=None)
- self.analyzer = sa.StackAnalyzer(options, symbols, rodata, tasklist, {})
-
- def testParseSymbolText(self):
- symbol_text = (
- '0 g F .text e8 Foo\n'
- '0000dead w F .text 000000e8 .hidden Bar\n'
- 'deadbeef l O .bss 00000004 .hidden Woooo\n'
- 'deadbee g O .rodata 00000008 __Hooo_ooo\n'
- 'deadbee g .rodata 00000000 __foo_doo_coo_end\n'
- )
- symbols = sa.ParseSymbolText(symbol_text)
- expect_symbols = [sa.Symbol(0x0, 'F', 0xe8, 'Foo'),
- sa.Symbol(0xdead, 'F', 0xe8, 'Bar'),
- sa.Symbol(0xdeadbeef, 'O', 0x4, 'Woooo'),
- sa.Symbol(0xdeadbee, 'O', 0x8, '__Hooo_ooo'),
- sa.Symbol(0xdeadbee, 'O', 0x0, '__foo_doo_coo_end')]
- self.assertEqual(symbols, expect_symbols)
-
- def testParseRoData(self):
- rodata_text = (
- '\n'
- 'Contents of section .rodata:\n'
- ' 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n'
- )
- rodata = sa.ParseRoDataText(rodata_text)
- expect_rodata = (0x20000,
- [ 0x0010adde, 0x00001000, 0x0020adde, 0x00002000 ])
- self.assertEqual(rodata, expect_rodata)
-
- def testLoadTasklist(self):
- def tasklist_to_taskinfos(pointer, tasklist):
- taskinfos = []
- for task in tasklist:
- taskinfos.append(sa.TaskInfo(name=task.name.encode('utf-8'),
- routine=task.routine_name.encode('utf-8'),
- stack_size=task.stack_max_size))
-
- TaskInfoArray = sa.TaskInfo * len(taskinfos)
- pointer.contents.contents = TaskInfoArray(*taskinfos)
- return len(taskinfos)
-
- def ro_taskinfos(pointer):
- return tasklist_to_taskinfos(pointer, expect_ro_tasklist)
-
- def rw_taskinfos(pointer):
- return tasklist_to_taskinfos(pointer, expect_rw_tasklist)
-
- expect_ro_tasklist = [
- sa.Task('HOOKS', 'hook_task', 2048, 0x1000),
- ]
-
- expect_rw_tasklist = [
- sa.Task('HOOKS', 'hook_task', 2048, 0x1000),
- sa.Task('WOOKS', 'hook_task', 4096, 0x1000),
- sa.Task('CONSOLE', 'console_task', 460, 0x2000),
- ]
-
- export_taskinfo = mock.MagicMock(
- get_ro_taskinfos=mock.MagicMock(side_effect=ro_taskinfos),
- get_rw_taskinfos=mock.MagicMock(side_effect=rw_taskinfos))
-
- tasklist = sa.LoadTasklist('RO', export_taskinfo, self.analyzer.symbols)
- self.assertEqual(tasklist, expect_ro_tasklist)
- tasklist = sa.LoadTasklist('RW', export_taskinfo, self.analyzer.symbols)
- self.assertEqual(tasklist, expect_rw_tasklist)
-
- def testResolveAnnotation(self):
- self.analyzer.annotation = {}
- (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
- self.assertEqual(add_rules, {})
- self.assertEqual(remove_rules, [])
- self.assertEqual(invalid_sigtxts, set())
-
- self.analyzer.annotation = {'add': None, 'remove': None}
- (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
- self.assertEqual(add_rules, {})
- self.assertEqual(remove_rules, [])
- self.assertEqual(invalid_sigtxts, set())
-
- self.analyzer.annotation = {
- 'add': None,
- 'remove': [
- [['a', 'b'], ['0', '[', '2'], 'x'],
- [['a', 'b[x:3]'], ['0', '1', '2'], 'x'],
- ],
- }
- (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
- self.assertEqual(add_rules, {})
- self.assertEqual(list.sort(remove_rules), list.sort([
- [('a', None, None), ('1', None, None), ('x', None, None)],
- [('a', None, None), ('0', None, None), ('x', None, None)],
- [('a', None, None), ('2', None, None), ('x', None, None)],
- [('b', os.path.abspath('x'), 3), ('1', None, None), ('x', None, None)],
- [('b', os.path.abspath('x'), 3), ('0', None, None), ('x', None, None)],
- [('b', os.path.abspath('x'), 3), ('2', None, None), ('x', None, None)],
- ]))
- self.assertEqual(invalid_sigtxts, {'['})
-
- self.analyzer.annotation = {
- 'add': {
- 'touchpad_calc': [ dict(name='__array', stride=8, offset=4) ],
+ """Tests for class StackAnalyzer."""
+
+ def setUp(self):
+ symbols = [
+ sa.Symbol(0x1000, "F", 0x15C, "hook_task"),
+ sa.Symbol(0x2000, "F", 0x51C, "console_task"),
+ sa.Symbol(0x3200, "O", 0x124, "__just_data"),
+ sa.Symbol(0x4000, "F", 0x11C, "touchpad_calc"),
+ sa.Symbol(0x5000, "F", 0x12C, "touchpad_calc.constprop.42"),
+ sa.Symbol(0x12000, "F", 0x13C, "trackpad_range"),
+ sa.Symbol(0x13000, "F", 0x200, "inlined_mul"),
+ sa.Symbol(0x13100, "F", 0x200, "inlined_mul"),
+ sa.Symbol(0x13100, "F", 0x200, "inlined_mul_alias"),
+ sa.Symbol(0x20000, "O", 0x0, "__array"),
+ sa.Symbol(0x20010, "O", 0x0, "__array_end"),
+ ]
+ tasklist = [
+ sa.Task("HOOKS", "hook_task", 2048, 0x1000),
+ sa.Task("CONSOLE", "console_task", 460, 0x2000),
+ ]
+ # Array at 0x20000 that contains pointers to hook_task and console_task,
+ # with stride=8, offset=4
+ rodata = (0x20000, [0xDEAD1000, 0x00001000, 0xDEAD2000, 0x00002000])
+ options = mock.MagicMock(
+ elf_path="./ec.RW.elf",
+ export_taskinfo="fake",
+ section="RW",
+ objdump="objdump",
+ addr2line="addr2line",
+ annotation=None,
+ )
+ self.analyzer = sa.StackAnalyzer(options, symbols, rodata, tasklist, {})
+
+ def testParseSymbolText(self):
+ symbol_text = (
+ "0 g F .text e8 Foo\n"
+ "0000dead w F .text 000000e8 .hidden Bar\n"
+ "deadbeef l O .bss 00000004 .hidden Woooo\n"
+ "deadbee g O .rodata 00000008 __Hooo_ooo\n"
+ "deadbee g .rodata 00000000 __foo_doo_coo_end\n"
+ )
+ symbols = sa.ParseSymbolText(symbol_text)
+ expect_symbols = [
+ sa.Symbol(0x0, "F", 0xE8, "Foo"),
+ sa.Symbol(0xDEAD, "F", 0xE8, "Bar"),
+ sa.Symbol(0xDEADBEEF, "O", 0x4, "Woooo"),
+ sa.Symbol(0xDEADBEE, "O", 0x8, "__Hooo_ooo"),
+ sa.Symbol(0xDEADBEE, "O", 0x0, "__foo_doo_coo_end"),
+ ]
+ self.assertEqual(symbols, expect_symbols)
+
+ def testParseRoData(self):
+ rodata_text = (
+ "\n"
+ "Contents of section .rodata:\n"
+ " 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n"
+ )
+ rodata = sa.ParseRoDataText(rodata_text)
+ expect_rodata = (0x20000, [0x0010ADDE, 0x00001000, 0x0020ADDE, 0x00002000])
+ self.assertEqual(rodata, expect_rodata)
+
+ def testLoadTasklist(self):
+ def tasklist_to_taskinfos(pointer, tasklist):
+ taskinfos = []
+ for task in tasklist:
+ taskinfos.append(
+ sa.TaskInfo(
+ name=task.name.encode("utf-8"),
+ routine=task.routine_name.encode("utf-8"),
+ stack_size=task.stack_max_size,
+ )
+ )
+
+ TaskInfoArray = sa.TaskInfo * len(taskinfos)
+ pointer.contents.contents = TaskInfoArray(*taskinfos)
+ return len(taskinfos)
+
+ def ro_taskinfos(pointer):
+ return tasklist_to_taskinfos(pointer, expect_ro_tasklist)
+
+ def rw_taskinfos(pointer):
+ return tasklist_to_taskinfos(pointer, expect_rw_tasklist)
+
+ expect_ro_tasklist = [
+ sa.Task("HOOKS", "hook_task", 2048, 0x1000),
+ ]
+
+ expect_rw_tasklist = [
+ sa.Task("HOOKS", "hook_task", 2048, 0x1000),
+ sa.Task("WOOKS", "hook_task", 4096, 0x1000),
+ sa.Task("CONSOLE", "console_task", 460, 0x2000),
+ ]
+
+ export_taskinfo = mock.MagicMock(
+ get_ro_taskinfos=mock.MagicMock(side_effect=ro_taskinfos),
+ get_rw_taskinfos=mock.MagicMock(side_effect=rw_taskinfos),
+ )
+
+ tasklist = sa.LoadTasklist("RO", export_taskinfo, self.analyzer.symbols)
+ self.assertEqual(tasklist, expect_ro_tasklist)
+ tasklist = sa.LoadTasklist("RW", export_taskinfo, self.analyzer.symbols)
+ self.assertEqual(tasklist, expect_rw_tasklist)
+
+ def testResolveAnnotation(self):
+ self.analyzer.annotation = {}
+ (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
+ self.assertEqual(add_rules, {})
+ self.assertEqual(remove_rules, [])
+ self.assertEqual(invalid_sigtxts, set())
+
+ self.analyzer.annotation = {"add": None, "remove": None}
+ (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
+ self.assertEqual(add_rules, {})
+ self.assertEqual(remove_rules, [])
+ self.assertEqual(invalid_sigtxts, set())
+
+ self.analyzer.annotation = {
+ "add": None,
+ "remove": [
+ [["a", "b"], ["0", "[", "2"], "x"],
+ [["a", "b[x:3]"], ["0", "1", "2"], "x"],
+ ],
+ }
+ (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
+ self.assertEqual(add_rules, {})
+ self.assertEqual(
+ list.sort(remove_rules),
+ list.sort(
+ [
+ [("a", None, None), ("1", None, None), ("x", None, None)],
+ [("a", None, None), ("0", None, None), ("x", None, None)],
+ [("a", None, None), ("2", None, None), ("x", None, None)],
+ [
+ ("b", os.path.abspath("x"), 3),
+ ("1", None, None),
+ ("x", None, None),
+ ],
+ [
+ ("b", os.path.abspath("x"), 3),
+ ("0", None, None),
+ ("x", None, None),
+ ],
+ [
+ ("b", os.path.abspath("x"), 3),
+ ("2", None, None),
+ ("x", None, None),
+ ],
+ ]
+ ),
+ )
+ self.assertEqual(invalid_sigtxts, {"["})
+
+ self.analyzer.annotation = {
+ "add": {
+ "touchpad_calc": [dict(name="__array", stride=8, offset=4)],
+ }
+ }
+ (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
+ self.assertEqual(
+ add_rules,
+ {
+ ("touchpad_calc", None, None): set(
+ [("console_task", None, None), ("hook_task", None, None)]
+ )
+ },
+ )
+
+ funcs = {
+ 0x1000: sa.Function(0x1000, "hook_task", 0, []),
+ 0x2000: sa.Function(0x2000, "console_task", 0, []),
+ 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []),
+ 0x5000: sa.Function(0x5000, "touchpad_calc.constprop.42", 0, []),
+ 0x13000: sa.Function(0x13000, "inlined_mul", 0, []),
+ 0x13100: sa.Function(0x13100, "inlined_mul", 0, []),
+ }
+ funcs[0x1000].callsites = [sa.Callsite(0x1002, None, False, None)]
+ # Set address_to_line_cache to fake the results of addr2line.
+ self.analyzer.address_to_line_cache = {
+ (0x1000, False): [("hook_task", os.path.abspath("a.c"), 10)],
+ (0x1002, False): [("toot_calc", os.path.abspath("t.c"), 1234)],
+ (0x2000, False): [("console_task", os.path.abspath("b.c"), 20)],
+ (0x4000, False): [("toudhpad_calc", os.path.abspath("a.c"), 20)],
+ (0x5000, False): [
+ ("touchpad_calc.constprop.42", os.path.abspath("b.c"), 40)
+ ],
+ (0x12000, False): [("trackpad_range", os.path.abspath("t.c"), 10)],
+ (0x13000, False): [("inlined_mul", os.path.abspath("x.c"), 12)],
+ (0x13100, False): [("inlined_mul", os.path.abspath("x.c"), 12)],
+ }
+ self.analyzer.annotation = {
+ "add": {
+ "hook_task.lto.573": ["touchpad_calc.lto.2501[a.c]"],
+ "console_task": ["touchpad_calc[b.c]", "inlined_mul_alias"],
+ "hook_task[q.c]": ["hook_task"],
+ "inlined_mul[x.c]": ["inlined_mul"],
+ "toot_calc[t.c:1234]": ["hook_task"],
+ },
+ "remove": [
+ ["touchpad?calc["],
+ "touchpad_calc",
+ ["touchpad_calc[a.c]"],
+ ["task_unk[a.c]"],
+ ["touchpad_calc[x/a.c]"],
+ ["trackpad_range"],
+ ["inlined_mul"],
+ ["inlined_mul", "console_task", "touchpad_calc[a.c]"],
+ ["inlined_mul", "inlined_mul_alias", "console_task"],
+ ["inlined_mul", "inlined_mul_alias", "console_task"],
+ ],
+ }
+ (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
+ self.assertEqual(invalid_sigtxts, {"touchpad?calc["})
+
+ signature_set = set()
+ for src_sig, dst_sigs in add_rules.items():
+ signature_set.add(src_sig)
+ signature_set.update(dst_sigs)
+
+ for remove_sigs in remove_rules:
+ signature_set.update(remove_sigs)
+
+ (signature_map, failed_sigs) = self.analyzer.MapAnnotation(funcs, signature_set)
+ result = self.analyzer.ResolveAnnotation(funcs)
+ (add_set, remove_list, eliminated_addrs, failed_sigs) = result
+
+ expect_signature_map = {
+ ("hook_task", None, None): {funcs[0x1000]},
+ ("touchpad_calc", os.path.abspath("a.c"), None): {funcs[0x4000]},
+ ("touchpad_calc", os.path.abspath("b.c"), None): {funcs[0x5000]},
+ ("console_task", None, None): {funcs[0x2000]},
+ ("inlined_mul_alias", None, None): {funcs[0x13100]},
+ ("inlined_mul", os.path.abspath("x.c"), None): {
+ funcs[0x13000],
+ funcs[0x13100],
+ },
+ ("inlined_mul", None, None): {funcs[0x13000], funcs[0x13100]},
+ }
+ self.assertEqual(len(signature_map), len(expect_signature_map))
+ for sig, funclist in signature_map.items():
+ self.assertEqual(set(funclist), expect_signature_map[sig])
+
+ self.assertEqual(
+ add_set,
+ {
+ (funcs[0x1000], funcs[0x4000]),
+ (funcs[0x1000], funcs[0x1000]),
+ (funcs[0x2000], funcs[0x5000]),
+ (funcs[0x2000], funcs[0x13100]),
+ (funcs[0x13000], funcs[0x13000]),
+ (funcs[0x13000], funcs[0x13100]),
+ (funcs[0x13100], funcs[0x13000]),
+ (funcs[0x13100], funcs[0x13100]),
+ },
+ )
+ expect_remove_list = [
+ [funcs[0x4000]],
+ [funcs[0x13000]],
+ [funcs[0x13100]],
+ [funcs[0x13000], funcs[0x2000], funcs[0x4000]],
+ [funcs[0x13100], funcs[0x2000], funcs[0x4000]],
+ [funcs[0x13000], funcs[0x13100], funcs[0x2000]],
+ [funcs[0x13100], funcs[0x13100], funcs[0x2000]],
+ ]
+ self.assertEqual(len(remove_list), len(expect_remove_list))
+ for remove_path in remove_list:
+ self.assertTrue(remove_path in expect_remove_list)
+
+ self.assertEqual(eliminated_addrs, {0x1002})
+ self.assertEqual(
+ failed_sigs,
+ {
+ ("touchpad?calc[", sa.StackAnalyzer.ANNOTATION_ERROR_INVALID),
+ ("touchpad_calc", sa.StackAnalyzer.ANNOTATION_ERROR_AMBIGUOUS),
+ ("hook_task[q.c]", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND),
+ ("task_unk[a.c]", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND),
+ ("touchpad_calc[x/a.c]", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND),
+ ("trackpad_range", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND),
+ },
+ )
+
+ def testPreprocessAnnotation(self):
+ funcs = {
+ 0x1000: sa.Function(0x1000, "hook_task", 0, []),
+ 0x2000: sa.Function(0x2000, "console_task", 0, []),
+ 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []),
+ }
+ funcs[0x1000].callsites = [sa.Callsite(0x1002, 0x1000, False, funcs[0x1000])]
+ funcs[0x2000].callsites = [
+ sa.Callsite(0x2002, 0x1000, False, funcs[0x1000]),
+ sa.Callsite(0x2006, None, True, None),
+ ]
+ add_set = {
+ (funcs[0x2000], funcs[0x2000]),
+ (funcs[0x2000], funcs[0x4000]),
+ (funcs[0x4000], funcs[0x1000]),
+ (funcs[0x4000], funcs[0x2000]),
}
- }
- (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
- self.assertEqual(add_rules, {
- ('touchpad_calc', None, None):
- set([('console_task', None, None), ('hook_task', None, None)])})
-
- funcs = {
- 0x1000: sa.Function(0x1000, 'hook_task', 0, []),
- 0x2000: sa.Function(0x2000, 'console_task', 0, []),
- 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []),
- 0x5000: sa.Function(0x5000, 'touchpad_calc.constprop.42', 0, []),
- 0x13000: sa.Function(0x13000, 'inlined_mul', 0, []),
- 0x13100: sa.Function(0x13100, 'inlined_mul', 0, []),
- }
- funcs[0x1000].callsites = [
- sa.Callsite(0x1002, None, False, None)]
- # Set address_to_line_cache to fake the results of addr2line.
- self.analyzer.address_to_line_cache = {
- (0x1000, False): [('hook_task', os.path.abspath('a.c'), 10)],
- (0x1002, False): [('toot_calc', os.path.abspath('t.c'), 1234)],
- (0x2000, False): [('console_task', os.path.abspath('b.c'), 20)],
- (0x4000, False): [('toudhpad_calc', os.path.abspath('a.c'), 20)],
- (0x5000, False): [
- ('touchpad_calc.constprop.42', os.path.abspath('b.c'), 40)],
- (0x12000, False): [('trackpad_range', os.path.abspath('t.c'), 10)],
- (0x13000, False): [('inlined_mul', os.path.abspath('x.c'), 12)],
- (0x13100, False): [('inlined_mul', os.path.abspath('x.c'), 12)],
- }
- self.analyzer.annotation = {
- 'add': {
- 'hook_task.lto.573': ['touchpad_calc.lto.2501[a.c]'],
- 'console_task': ['touchpad_calc[b.c]', 'inlined_mul_alias'],
- 'hook_task[q.c]': ['hook_task'],
- 'inlined_mul[x.c]': ['inlined_mul'],
- 'toot_calc[t.c:1234]': ['hook_task'],
- },
- 'remove': [
- ['touchpad?calc['],
- 'touchpad_calc',
- ['touchpad_calc[a.c]'],
- ['task_unk[a.c]'],
- ['touchpad_calc[x/a.c]'],
- ['trackpad_range'],
- ['inlined_mul'],
- ['inlined_mul', 'console_task', 'touchpad_calc[a.c]'],
- ['inlined_mul', 'inlined_mul_alias', 'console_task'],
- ['inlined_mul', 'inlined_mul_alias', 'console_task'],
- ],
- }
- (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation()
- self.assertEqual(invalid_sigtxts, {'touchpad?calc['})
-
- signature_set = set()
- for src_sig, dst_sigs in add_rules.items():
- signature_set.add(src_sig)
- signature_set.update(dst_sigs)
-
- for remove_sigs in remove_rules:
- signature_set.update(remove_sigs)
-
- (signature_map, failed_sigs) = self.analyzer.MapAnnotation(funcs,
- signature_set)
- result = self.analyzer.ResolveAnnotation(funcs)
- (add_set, remove_list, eliminated_addrs, failed_sigs) = result
-
- expect_signature_map = {
- ('hook_task', None, None): {funcs[0x1000]},
- ('touchpad_calc', os.path.abspath('a.c'), None): {funcs[0x4000]},
- ('touchpad_calc', os.path.abspath('b.c'), None): {funcs[0x5000]},
- ('console_task', None, None): {funcs[0x2000]},
- ('inlined_mul_alias', None, None): {funcs[0x13100]},
- ('inlined_mul', os.path.abspath('x.c'), None): {funcs[0x13000],
- funcs[0x13100]},
- ('inlined_mul', None, None): {funcs[0x13000], funcs[0x13100]},
- }
- self.assertEqual(len(signature_map), len(expect_signature_map))
- for sig, funclist in signature_map.items():
- self.assertEqual(set(funclist), expect_signature_map[sig])
-
- self.assertEqual(add_set, {
- (funcs[0x1000], funcs[0x4000]),
- (funcs[0x1000], funcs[0x1000]),
- (funcs[0x2000], funcs[0x5000]),
- (funcs[0x2000], funcs[0x13100]),
- (funcs[0x13000], funcs[0x13000]),
- (funcs[0x13000], funcs[0x13100]),
- (funcs[0x13100], funcs[0x13000]),
- (funcs[0x13100], funcs[0x13100]),
- })
- expect_remove_list = [
- [funcs[0x4000]],
- [funcs[0x13000]],
- [funcs[0x13100]],
- [funcs[0x13000], funcs[0x2000], funcs[0x4000]],
- [funcs[0x13100], funcs[0x2000], funcs[0x4000]],
- [funcs[0x13000], funcs[0x13100], funcs[0x2000]],
- [funcs[0x13100], funcs[0x13100], funcs[0x2000]],
- ]
- self.assertEqual(len(remove_list), len(expect_remove_list))
- for remove_path in remove_list:
- self.assertTrue(remove_path in expect_remove_list)
-
- self.assertEqual(eliminated_addrs, {0x1002})
- self.assertEqual(failed_sigs, {
- ('touchpad?calc[', sa.StackAnalyzer.ANNOTATION_ERROR_INVALID),
- ('touchpad_calc', sa.StackAnalyzer.ANNOTATION_ERROR_AMBIGUOUS),
- ('hook_task[q.c]', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND),
- ('task_unk[a.c]', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND),
- ('touchpad_calc[x/a.c]', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND),
- ('trackpad_range', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND),
- })
-
- def testPreprocessAnnotation(self):
- funcs = {
- 0x1000: sa.Function(0x1000, 'hook_task', 0, []),
- 0x2000: sa.Function(0x2000, 'console_task', 0, []),
- 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []),
- }
- funcs[0x1000].callsites = [
- sa.Callsite(0x1002, 0x1000, False, funcs[0x1000])]
- funcs[0x2000].callsites = [
- sa.Callsite(0x2002, 0x1000, False, funcs[0x1000]),
- sa.Callsite(0x2006, None, True, None),
- ]
- add_set = {
- (funcs[0x2000], funcs[0x2000]),
- (funcs[0x2000], funcs[0x4000]),
- (funcs[0x4000], funcs[0x1000]),
- (funcs[0x4000], funcs[0x2000]),
- }
- remove_list = [
- [funcs[0x1000]],
- [funcs[0x2000], funcs[0x2000]],
- [funcs[0x4000], funcs[0x1000]],
- [funcs[0x2000], funcs[0x4000], funcs[0x2000]],
- [funcs[0x4000], funcs[0x1000], funcs[0x4000]],
- ]
- eliminated_addrs = {0x2006}
-
- remaining_remove_list = self.analyzer.PreprocessAnnotation(funcs,
- add_set,
- remove_list,
- eliminated_addrs)
-
- expect_funcs = {
- 0x1000: sa.Function(0x1000, 'hook_task', 0, []),
- 0x2000: sa.Function(0x2000, 'console_task', 0, []),
- 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []),
- }
- expect_funcs[0x2000].callsites = [
- sa.Callsite(None, 0x4000, False, expect_funcs[0x4000])]
- expect_funcs[0x4000].callsites = [
- sa.Callsite(None, 0x2000, False, expect_funcs[0x2000])]
- self.assertEqual(funcs, expect_funcs)
- self.assertEqual(remaining_remove_list, [
- [funcs[0x2000], funcs[0x4000], funcs[0x2000]],
- ])
-
- def testAndesAnalyzeDisassembly(self):
- disasm_text = (
- '\n'
- 'build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le'
- '\n'
- 'Disassembly of section .text:\n'
- '\n'
- '00000900 <wook_task>:\n'
- ' ...\n'
- '00001000 <hook_task>:\n'
- ' 1000: fc 42\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n'
- ' 1004: 47 70\t\tmovi55 $r0, #1\n'
- ' 1006: b1 13\tbnezs8 100929de <flash_command_write>\n'
- ' 1008: 00 01 5c fc\tbne $r6, $r0, 2af6a\n'
- '00002000 <console_task>:\n'
- ' 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n'
- ' 2002: f0 0e fc c5\tjal 1000 <hook_task>\n'
- ' 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n'
- ' 200a: de ad be ef\tswi.gp $r0, [ + #-11036]\n'
- '00004000 <touchpad_calc>:\n'
- ' 4000: 47 70\t\tmovi55 $r0, #1\n'
- '00010000 <look_task>:'
- )
- function_map = self.analyzer.AnalyzeDisassembly(disasm_text)
- func_hook_task = sa.Function(0x1000, 'hook_task', 48, [
- sa.Callsite(0x1006, 0x100929de, True, None)])
- expect_funcmap = {
- 0x1000: func_hook_task,
- 0x2000: sa.Function(0x2000, 'console_task', 16,
- [sa.Callsite(0x2002, 0x1000, False, func_hook_task),
- sa.Callsite(0x2006, 0x53968, True, None)]),
- 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []),
- }
- self.assertEqual(function_map, expect_funcmap)
-
- def testArmAnalyzeDisassembly(self):
- disasm_text = (
- '\n'
- 'build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm'
- '\n'
- 'Disassembly of section .text:\n'
- '\n'
- '00000900 <wook_task>:\n'
- ' ...\n'
- '00001000 <hook_task>:\n'
- ' 1000: dead beef\tfake\n'
- ' 1004: 4770\t\tbx lr\n'
- ' 1006: b113\tcbz r3, 100929de <flash_command_write>\n'
- ' 1008: 00015cfc\t.word 0x00015cfc\n'
- '00002000 <console_task>:\n'
- ' 2000: b508\t\tpush {r3, lr} ; malformed comments,; r0, r1 \n'
- ' 2002: f00e fcc5\tbl 1000 <hook_task>\n'
- ' 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n'
- ' 200a: dead beef\tfake\n'
- '00004000 <touchpad_calc>:\n'
- ' 4000: 4770\t\tbx lr\n'
- '00010000 <look_task>:'
- )
- function_map = self.analyzer.AnalyzeDisassembly(disasm_text)
- func_hook_task = sa.Function(0x1000, 'hook_task', 0, [
- sa.Callsite(0x1006, 0x100929de, True, None)])
- expect_funcmap = {
- 0x1000: func_hook_task,
- 0x2000: sa.Function(0x2000, 'console_task', 8,
- [sa.Callsite(0x2002, 0x1000, False, func_hook_task),
- sa.Callsite(0x2006, 0x53968, True, None)]),
- 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []),
- }
- self.assertEqual(function_map, expect_funcmap)
-
- def testAnalyzeCallGraph(self):
- funcs = {
- 0x1000: sa.Function(0x1000, 'hook_task', 0, []),
- 0x2000: sa.Function(0x2000, 'console_task', 8, []),
- 0x3000: sa.Function(0x3000, 'task_a', 12, []),
- 0x4000: sa.Function(0x4000, 'task_b', 96, []),
- 0x5000: sa.Function(0x5000, 'task_c', 32, []),
- 0x6000: sa.Function(0x6000, 'task_d', 100, []),
- 0x7000: sa.Function(0x7000, 'task_e', 24, []),
- 0x8000: sa.Function(0x8000, 'task_f', 20, []),
- 0x9000: sa.Function(0x9000, 'task_g', 20, []),
- 0x10000: sa.Function(0x10000, 'task_x', 16, []),
- }
- funcs[0x1000].callsites = [
- sa.Callsite(0x1002, 0x3000, False, funcs[0x3000]),
- sa.Callsite(0x1006, 0x4000, False, funcs[0x4000])]
- funcs[0x2000].callsites = [
- sa.Callsite(0x2002, 0x5000, False, funcs[0x5000]),
- sa.Callsite(0x2006, 0x2000, False, funcs[0x2000]),
- sa.Callsite(0x200a, 0x10000, False, funcs[0x10000])]
- funcs[0x3000].callsites = [
- sa.Callsite(0x3002, 0x4000, False, funcs[0x4000]),
- sa.Callsite(0x3006, 0x1000, False, funcs[0x1000])]
- funcs[0x4000].callsites = [
- sa.Callsite(0x4002, 0x6000, True, funcs[0x6000]),
- sa.Callsite(0x4006, 0x7000, False, funcs[0x7000]),
- sa.Callsite(0x400a, 0x8000, False, funcs[0x8000])]
- funcs[0x5000].callsites = [
- sa.Callsite(0x5002, 0x4000, False, funcs[0x4000])]
- funcs[0x7000].callsites = [
- sa.Callsite(0x7002, 0x7000, False, funcs[0x7000])]
- funcs[0x8000].callsites = [
- sa.Callsite(0x8002, 0x9000, False, funcs[0x9000])]
- funcs[0x9000].callsites = [
- sa.Callsite(0x9002, 0x4000, False, funcs[0x4000])]
- funcs[0x10000].callsites = [
- sa.Callsite(0x10002, 0x2000, False, funcs[0x2000])]
-
- cycles = self.analyzer.AnalyzeCallGraph(funcs, [
- [funcs[0x2000]] * 2,
- [funcs[0x10000], funcs[0x2000]] * 3,
- [funcs[0x1000], funcs[0x3000], funcs[0x1000]]
- ])
-
- expect_func_stack = {
- 0x1000: (268, [funcs[0x1000],
- funcs[0x3000],
- funcs[0x4000],
- funcs[0x8000],
- funcs[0x9000],
- funcs[0x4000],
- funcs[0x7000]]),
- 0x2000: (208, [funcs[0x2000],
- funcs[0x10000],
- funcs[0x2000],
- funcs[0x10000],
- funcs[0x2000],
- funcs[0x5000],
- funcs[0x4000],
- funcs[0x7000]]),
- 0x3000: (280, [funcs[0x3000],
- funcs[0x1000],
- funcs[0x3000],
- funcs[0x4000],
- funcs[0x8000],
- funcs[0x9000],
- funcs[0x4000],
- funcs[0x7000]]),
- 0x4000: (120, [funcs[0x4000], funcs[0x7000]]),
- 0x5000: (152, [funcs[0x5000], funcs[0x4000], funcs[0x7000]]),
- 0x6000: (100, [funcs[0x6000]]),
- 0x7000: (24, [funcs[0x7000]]),
- 0x8000: (160, [funcs[0x8000],
- funcs[0x9000],
- funcs[0x4000],
- funcs[0x7000]]),
- 0x9000: (140, [funcs[0x9000], funcs[0x4000], funcs[0x7000]]),
- 0x10000: (200, [funcs[0x10000],
- funcs[0x2000],
- funcs[0x10000],
- funcs[0x2000],
- funcs[0x5000],
- funcs[0x4000],
- funcs[0x7000]]),
- }
- expect_cycles = [
- {funcs[0x4000], funcs[0x8000], funcs[0x9000]},
- {funcs[0x7000]},
- ]
- for func in funcs.values():
- (stack_max_usage, stack_max_path) = expect_func_stack[func.address]
- self.assertEqual(func.stack_max_usage, stack_max_usage)
- self.assertEqual(func.stack_max_path, stack_max_path)
-
- self.assertEqual(len(cycles), len(expect_cycles))
- for cycle in cycles:
- self.assertTrue(cycle in expect_cycles)
-
- @mock.patch('subprocess.check_output')
- def testAddressToLine(self, checkoutput_mock):
- checkoutput_mock.return_value = 'fake_func\n/test.c:1'
- self.assertEqual(self.analyzer.AddressToLine(0x1234),
- [('fake_func', '/test.c', 1)])
- checkoutput_mock.assert_called_once_with(
- ['addr2line', '-f', '-e', './ec.RW.elf', '1234'], encoding='utf-8')
- checkoutput_mock.reset_mock()
-
- checkoutput_mock.return_value = 'fake_func\n/a.c:1\nbake_func\n/b.c:2\n'
- self.assertEqual(self.analyzer.AddressToLine(0x1234, True),
- [('fake_func', '/a.c', 1), ('bake_func', '/b.c', 2)])
- checkoutput_mock.assert_called_once_with(
- ['addr2line', '-f', '-e', './ec.RW.elf', '1234', '-i'],
- encoding='utf-8')
- checkoutput_mock.reset_mock()
-
- checkoutput_mock.return_value = 'fake_func\n/test.c:1 (discriminator 128)'
- self.assertEqual(self.analyzer.AddressToLine(0x12345),
- [('fake_func', '/test.c', 1)])
- checkoutput_mock.assert_called_once_with(
- ['addr2line', '-f', '-e', './ec.RW.elf', '12345'], encoding='utf-8')
- checkoutput_mock.reset_mock()
-
- checkoutput_mock.return_value = '??\n:?\nbake_func\n/b.c:2\n'
- self.assertEqual(self.analyzer.AddressToLine(0x123456),
- [None, ('bake_func', '/b.c', 2)])
- checkoutput_mock.assert_called_once_with(
- ['addr2line', '-f', '-e', './ec.RW.elf', '123456'], encoding='utf-8')
- checkoutput_mock.reset_mock()
-
- with self.assertRaisesRegexp(sa.StackAnalyzerError,
- 'addr2line failed to resolve lines.'):
- checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '')
- self.analyzer.AddressToLine(0x5678)
-
- with self.assertRaisesRegexp(sa.StackAnalyzerError,
- 'Failed to run addr2line.'):
- checkoutput_mock.side_effect = OSError()
- self.analyzer.AddressToLine(0x9012)
-
- @mock.patch('subprocess.check_output')
- @mock.patch('stack_analyzer.StackAnalyzer.AddressToLine')
- def testAndesAnalyze(self, addrtoline_mock, checkoutput_mock):
- disasm_text = (
- '\n'
- 'build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le'
- '\n'
- 'Disassembly of section .text:\n'
- '\n'
- '00000900 <wook_task>:\n'
- ' ...\n'
- '00001000 <hook_task>:\n'
- ' 1000: fc 00\t\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n'
- ' 1002: 47 70\t\tmovi55 $r0, #1\n'
- ' 1006: 00 01 5c fc\tbne $r6, $r0, 2af6a\n'
- '00002000 <console_task>:\n'
- ' 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n'
- ' 2002: f0 0e fc c5\tjal 1000 <hook_task>\n'
- ' 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n'
- ' 200a: 12 34 56 78\tjral5 $r0\n'
- )
-
- addrtoline_mock.return_value = [('??', '??', 0)]
- self.analyzer.annotation = {
- 'exception_frame_size': 64,
- 'remove': [['fake_func']],
- }
-
- with mock.patch('builtins.print') as print_mock:
- checkoutput_mock.return_value = disasm_text
- self.analyzer.Analyze()
- print_mock.assert_has_calls([
- mock.call(
- 'Task: HOOKS, Max size: 96 (32 + 64), Allocated size: 2048'),
- mock.call('Call Trace:'),
- mock.call(' hook_task (32) [??:0] 1000'),
- mock.call(
- 'Task: CONSOLE, Max size: 112 (48 + 64), Allocated size: 460'),
- mock.call('Call Trace:'),
- mock.call(' console_task (16) [??:0] 2000'),
- mock.call(' -> ??[??:0] 2002'),
- mock.call(' hook_task (32) [??:0] 1000'),
- mock.call('Unresolved indirect callsites:'),
- mock.call(' In function console_task:'),
- mock.call(' -> ??[??:0] 200a'),
- mock.call('Unresolved annotation signatures:'),
- mock.call(' fake_func: function is not found'),
- ])
-
- with self.assertRaisesRegexp(sa.StackAnalyzerError,
- 'Failed to run objdump.'):
- checkoutput_mock.side_effect = OSError()
- self.analyzer.Analyze()
-
- with self.assertRaisesRegexp(sa.StackAnalyzerError,
- 'objdump failed to disassemble.'):
- checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '')
- self.analyzer.Analyze()
-
- @mock.patch('subprocess.check_output')
- @mock.patch('stack_analyzer.StackAnalyzer.AddressToLine')
- def testArmAnalyze(self, addrtoline_mock, checkoutput_mock):
- disasm_text = (
- '\n'
- 'build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm'
- '\n'
- 'Disassembly of section .text:\n'
- '\n'
- '00000900 <wook_task>:\n'
- ' ...\n'
- '00001000 <hook_task>:\n'
- ' 1000: b508\t\tpush {r3, lr}\n'
- ' 1002: 4770\t\tbx lr\n'
- ' 1006: 00015cfc\t.word 0x00015cfc\n'
- '00002000 <console_task>:\n'
- ' 2000: b508\t\tpush {r3, lr}\n'
- ' 2002: f00e fcc5\tbl 1000 <hook_task>\n'
- ' 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n'
- ' 200a: 1234 5678\tb.w sl\n'
- )
-
- addrtoline_mock.return_value = [('??', '??', 0)]
- self.analyzer.annotation = {
- 'exception_frame_size': 64,
- 'remove': [['fake_func']],
- }
-
- with mock.patch('builtins.print') as print_mock:
- checkoutput_mock.return_value = disasm_text
- self.analyzer.Analyze()
- print_mock.assert_has_calls([
- mock.call(
- 'Task: HOOKS, Max size: 72 (8 + 64), Allocated size: 2048'),
- mock.call('Call Trace:'),
- mock.call(' hook_task (8) [??:0] 1000'),
- mock.call(
- 'Task: CONSOLE, Max size: 80 (16 + 64), Allocated size: 460'),
- mock.call('Call Trace:'),
- mock.call(' console_task (8) [??:0] 2000'),
- mock.call(' -> ??[??:0] 2002'),
- mock.call(' hook_task (8) [??:0] 1000'),
- mock.call('Unresolved indirect callsites:'),
- mock.call(' In function console_task:'),
- mock.call(' -> ??[??:0] 200a'),
- mock.call('Unresolved annotation signatures:'),
- mock.call(' fake_func: function is not found'),
- ])
-
- with self.assertRaisesRegexp(sa.StackAnalyzerError,
- 'Failed to run objdump.'):
- checkoutput_mock.side_effect = OSError()
- self.analyzer.Analyze()
-
- with self.assertRaisesRegexp(sa.StackAnalyzerError,
- 'objdump failed to disassemble.'):
- checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '')
- self.analyzer.Analyze()
-
- @mock.patch('subprocess.check_output')
- @mock.patch('stack_analyzer.ParseArgs')
- def testMain(self, parseargs_mock, checkoutput_mock):
- symbol_text = ('1000 g F .text 0000015c .hidden hook_task\n'
- '2000 g F .text 0000051c .hidden console_task\n')
- rodata_text = (
- '\n'
- 'Contents of section .rodata:\n'
- ' 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n'
- )
-
- args = mock.MagicMock(elf_path='./ec.RW.elf',
- export_taskinfo='fake',
- section='RW',
- objdump='objdump',
- addr2line='addr2line',
- annotation='fake')
- parseargs_mock.return_value = args
-
- with mock.patch('os.path.exists') as path_mock:
- path_mock.return_value = False
- with mock.patch('builtins.print') as print_mock:
- with mock.patch('builtins.open', mock.mock_open()) as open_mock:
- sa.main()
- print_mock.assert_any_call(
- 'Warning: Annotation file fake does not exist.')
-
- with mock.patch('os.path.exists') as path_mock:
- path_mock.return_value = True
- with mock.patch('builtins.print') as print_mock:
- with mock.patch('builtins.open', mock.mock_open()) as open_mock:
- open_mock.side_effect = IOError()
- sa.main()
- print_mock.assert_called_once_with(
- 'Error: Failed to open annotation file fake.')
-
- with mock.patch('builtins.print') as print_mock:
- with mock.patch('builtins.open', mock.mock_open()) as open_mock:
- open_mock.return_value.read.side_effect = ['{', '']
- sa.main()
- open_mock.assert_called_once_with('fake', 'r')
- print_mock.assert_called_once_with(
- 'Error: Failed to parse annotation file fake.')
-
- with mock.patch('builtins.print') as print_mock:
- with mock.patch('builtins.open',
- mock.mock_open(read_data='')) as open_mock:
- sa.main()
- print_mock.assert_called_once_with(
- 'Error: Invalid annotation file fake.')
-
- args.annotation = None
-
- with mock.patch('builtins.print') as print_mock:
- checkoutput_mock.side_effect = [symbol_text, rodata_text]
- sa.main()
- print_mock.assert_called_once_with(
- 'Error: Failed to load export_taskinfo.')
-
- with mock.patch('builtins.print') as print_mock:
- checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '')
- sa.main()
- print_mock.assert_called_once_with(
- 'Error: objdump failed to dump symbol table or rodata.')
-
- with mock.patch('builtins.print') as print_mock:
- checkoutput_mock.side_effect = OSError()
- sa.main()
- print_mock.assert_called_once_with('Error: Failed to run objdump.')
-
-
-if __name__ == '__main__':
- unittest.main()
+ remove_list = [
+ [funcs[0x1000]],
+ [funcs[0x2000], funcs[0x2000]],
+ [funcs[0x4000], funcs[0x1000]],
+ [funcs[0x2000], funcs[0x4000], funcs[0x2000]],
+ [funcs[0x4000], funcs[0x1000], funcs[0x4000]],
+ ]
+ eliminated_addrs = {0x2006}
+
+ remaining_remove_list = self.analyzer.PreprocessAnnotation(
+ funcs, add_set, remove_list, eliminated_addrs
+ )
+
+ expect_funcs = {
+ 0x1000: sa.Function(0x1000, "hook_task", 0, []),
+ 0x2000: sa.Function(0x2000, "console_task", 0, []),
+ 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []),
+ }
+ expect_funcs[0x2000].callsites = [
+ sa.Callsite(None, 0x4000, False, expect_funcs[0x4000])
+ ]
+ expect_funcs[0x4000].callsites = [
+ sa.Callsite(None, 0x2000, False, expect_funcs[0x2000])
+ ]
+ self.assertEqual(funcs, expect_funcs)
+ self.assertEqual(
+ remaining_remove_list,
+ [
+ [funcs[0x2000], funcs[0x4000], funcs[0x2000]],
+ ],
+ )
+
+ def testAndesAnalyzeDisassembly(self):
+ disasm_text = (
+ "\n"
+ "build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le"
+ "\n"
+ "Disassembly of section .text:\n"
+ "\n"
+ "00000900 <wook_task>:\n"
+ " ...\n"
+ "00001000 <hook_task>:\n"
+ " 1000: fc 42\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n"
+ " 1004: 47 70\t\tmovi55 $r0, #1\n"
+ " 1006: b1 13\tbnezs8 100929de <flash_command_write>\n"
+ " 1008: 00 01 5c fc\tbne $r6, $r0, 2af6a\n"
+ "00002000 <console_task>:\n"
+ " 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n"
+ " 2002: f0 0e fc c5\tjal 1000 <hook_task>\n"
+ " 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n"
+ " 200a: de ad be ef\tswi.gp $r0, [ + #-11036]\n"
+ "00004000 <touchpad_calc>:\n"
+ " 4000: 47 70\t\tmovi55 $r0, #1\n"
+ "00010000 <look_task>:"
+ )
+ function_map = self.analyzer.AnalyzeDisassembly(disasm_text)
+ func_hook_task = sa.Function(
+ 0x1000, "hook_task", 48, [sa.Callsite(0x1006, 0x100929DE, True, None)]
+ )
+ expect_funcmap = {
+ 0x1000: func_hook_task,
+ 0x2000: sa.Function(
+ 0x2000,
+ "console_task",
+ 16,
+ [
+ sa.Callsite(0x2002, 0x1000, False, func_hook_task),
+ sa.Callsite(0x2006, 0x53968, True, None),
+ ],
+ ),
+ 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []),
+ }
+ self.assertEqual(function_map, expect_funcmap)
+
+ def testArmAnalyzeDisassembly(self):
+ disasm_text = (
+ "\n"
+ "build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm"
+ "\n"
+ "Disassembly of section .text:\n"
+ "\n"
+ "00000900 <wook_task>:\n"
+ " ...\n"
+ "00001000 <hook_task>:\n"
+ " 1000: dead beef\tfake\n"
+ " 1004: 4770\t\tbx lr\n"
+ " 1006: b113\tcbz r3, 100929de <flash_command_write>\n"
+ " 1008: 00015cfc\t.word 0x00015cfc\n"
+ "00002000 <console_task>:\n"
+ " 2000: b508\t\tpush {r3, lr} ; malformed comments,; r0, r1 \n"
+ " 2002: f00e fcc5\tbl 1000 <hook_task>\n"
+ " 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n"
+ " 200a: dead beef\tfake\n"
+ "00004000 <touchpad_calc>:\n"
+ " 4000: 4770\t\tbx lr\n"
+ "00010000 <look_task>:"
+ )
+ function_map = self.analyzer.AnalyzeDisassembly(disasm_text)
+ func_hook_task = sa.Function(
+ 0x1000, "hook_task", 0, [sa.Callsite(0x1006, 0x100929DE, True, None)]
+ )
+ expect_funcmap = {
+ 0x1000: func_hook_task,
+ 0x2000: sa.Function(
+ 0x2000,
+ "console_task",
+ 8,
+ [
+ sa.Callsite(0x2002, 0x1000, False, func_hook_task),
+ sa.Callsite(0x2006, 0x53968, True, None),
+ ],
+ ),
+ 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []),
+ }
+ self.assertEqual(function_map, expect_funcmap)
+
+ def testAnalyzeCallGraph(self):
+ funcs = {
+ 0x1000: sa.Function(0x1000, "hook_task", 0, []),
+ 0x2000: sa.Function(0x2000, "console_task", 8, []),
+ 0x3000: sa.Function(0x3000, "task_a", 12, []),
+ 0x4000: sa.Function(0x4000, "task_b", 96, []),
+ 0x5000: sa.Function(0x5000, "task_c", 32, []),
+ 0x6000: sa.Function(0x6000, "task_d", 100, []),
+ 0x7000: sa.Function(0x7000, "task_e", 24, []),
+ 0x8000: sa.Function(0x8000, "task_f", 20, []),
+ 0x9000: sa.Function(0x9000, "task_g", 20, []),
+ 0x10000: sa.Function(0x10000, "task_x", 16, []),
+ }
+ funcs[0x1000].callsites = [
+ sa.Callsite(0x1002, 0x3000, False, funcs[0x3000]),
+ sa.Callsite(0x1006, 0x4000, False, funcs[0x4000]),
+ ]
+ funcs[0x2000].callsites = [
+ sa.Callsite(0x2002, 0x5000, False, funcs[0x5000]),
+ sa.Callsite(0x2006, 0x2000, False, funcs[0x2000]),
+ sa.Callsite(0x200A, 0x10000, False, funcs[0x10000]),
+ ]
+ funcs[0x3000].callsites = [
+ sa.Callsite(0x3002, 0x4000, False, funcs[0x4000]),
+ sa.Callsite(0x3006, 0x1000, False, funcs[0x1000]),
+ ]
+ funcs[0x4000].callsites = [
+ sa.Callsite(0x4002, 0x6000, True, funcs[0x6000]),
+ sa.Callsite(0x4006, 0x7000, False, funcs[0x7000]),
+ sa.Callsite(0x400A, 0x8000, False, funcs[0x8000]),
+ ]
+ funcs[0x5000].callsites = [sa.Callsite(0x5002, 0x4000, False, funcs[0x4000])]
+ funcs[0x7000].callsites = [sa.Callsite(0x7002, 0x7000, False, funcs[0x7000])]
+ funcs[0x8000].callsites = [sa.Callsite(0x8002, 0x9000, False, funcs[0x9000])]
+ funcs[0x9000].callsites = [sa.Callsite(0x9002, 0x4000, False, funcs[0x4000])]
+ funcs[0x10000].callsites = [sa.Callsite(0x10002, 0x2000, False, funcs[0x2000])]
+
+ cycles = self.analyzer.AnalyzeCallGraph(
+ funcs,
+ [
+ [funcs[0x2000]] * 2,
+ [funcs[0x10000], funcs[0x2000]] * 3,
+ [funcs[0x1000], funcs[0x3000], funcs[0x1000]],
+ ],
+ )
+
+ expect_func_stack = {
+ 0x1000: (
+ 268,
+ [
+ funcs[0x1000],
+ funcs[0x3000],
+ funcs[0x4000],
+ funcs[0x8000],
+ funcs[0x9000],
+ funcs[0x4000],
+ funcs[0x7000],
+ ],
+ ),
+ 0x2000: (
+ 208,
+ [
+ funcs[0x2000],
+ funcs[0x10000],
+ funcs[0x2000],
+ funcs[0x10000],
+ funcs[0x2000],
+ funcs[0x5000],
+ funcs[0x4000],
+ funcs[0x7000],
+ ],
+ ),
+ 0x3000: (
+ 280,
+ [
+ funcs[0x3000],
+ funcs[0x1000],
+ funcs[0x3000],
+ funcs[0x4000],
+ funcs[0x8000],
+ funcs[0x9000],
+ funcs[0x4000],
+ funcs[0x7000],
+ ],
+ ),
+ 0x4000: (120, [funcs[0x4000], funcs[0x7000]]),
+ 0x5000: (152, [funcs[0x5000], funcs[0x4000], funcs[0x7000]]),
+ 0x6000: (100, [funcs[0x6000]]),
+ 0x7000: (24, [funcs[0x7000]]),
+ 0x8000: (160, [funcs[0x8000], funcs[0x9000], funcs[0x4000], funcs[0x7000]]),
+ 0x9000: (140, [funcs[0x9000], funcs[0x4000], funcs[0x7000]]),
+ 0x10000: (
+ 200,
+ [
+ funcs[0x10000],
+ funcs[0x2000],
+ funcs[0x10000],
+ funcs[0x2000],
+ funcs[0x5000],
+ funcs[0x4000],
+ funcs[0x7000],
+ ],
+ ),
+ }
+ expect_cycles = [
+ {funcs[0x4000], funcs[0x8000], funcs[0x9000]},
+ {funcs[0x7000]},
+ ]
+ for func in funcs.values():
+ (stack_max_usage, stack_max_path) = expect_func_stack[func.address]
+ self.assertEqual(func.stack_max_usage, stack_max_usage)
+ self.assertEqual(func.stack_max_path, stack_max_path)
+
+ self.assertEqual(len(cycles), len(expect_cycles))
+ for cycle in cycles:
+ self.assertTrue(cycle in expect_cycles)
+
+ @mock.patch("subprocess.check_output")
+ def testAddressToLine(self, checkoutput_mock):
+ checkoutput_mock.return_value = "fake_func\n/test.c:1"
+ self.assertEqual(
+ self.analyzer.AddressToLine(0x1234), [("fake_func", "/test.c", 1)]
+ )
+ checkoutput_mock.assert_called_once_with(
+ ["addr2line", "-f", "-e", "./ec.RW.elf", "1234"], encoding="utf-8"
+ )
+ checkoutput_mock.reset_mock()
+
+ checkoutput_mock.return_value = "fake_func\n/a.c:1\nbake_func\n/b.c:2\n"
+ self.assertEqual(
+ self.analyzer.AddressToLine(0x1234, True),
+ [("fake_func", "/a.c", 1), ("bake_func", "/b.c", 2)],
+ )
+ checkoutput_mock.assert_called_once_with(
+ ["addr2line", "-f", "-e", "./ec.RW.elf", "1234", "-i"], encoding="utf-8"
+ )
+ checkoutput_mock.reset_mock()
+
+ checkoutput_mock.return_value = "fake_func\n/test.c:1 (discriminator 128)"
+ self.assertEqual(
+ self.analyzer.AddressToLine(0x12345), [("fake_func", "/test.c", 1)]
+ )
+ checkoutput_mock.assert_called_once_with(
+ ["addr2line", "-f", "-e", "./ec.RW.elf", "12345"], encoding="utf-8"
+ )
+ checkoutput_mock.reset_mock()
+
+ checkoutput_mock.return_value = "??\n:?\nbake_func\n/b.c:2\n"
+ self.assertEqual(
+ self.analyzer.AddressToLine(0x123456), [None, ("bake_func", "/b.c", 2)]
+ )
+ checkoutput_mock.assert_called_once_with(
+ ["addr2line", "-f", "-e", "./ec.RW.elf", "123456"], encoding="utf-8"
+ )
+ checkoutput_mock.reset_mock()
+
+ with self.assertRaisesRegexp(
+ sa.StackAnalyzerError, "addr2line failed to resolve lines."
+ ):
+ checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "")
+ self.analyzer.AddressToLine(0x5678)
+
+ with self.assertRaisesRegexp(sa.StackAnalyzerError, "Failed to run addr2line."):
+ checkoutput_mock.side_effect = OSError()
+ self.analyzer.AddressToLine(0x9012)
+
+ @mock.patch("subprocess.check_output")
+ @mock.patch("stack_analyzer.StackAnalyzer.AddressToLine")
+ def testAndesAnalyze(self, addrtoline_mock, checkoutput_mock):
+ disasm_text = (
+ "\n"
+ "build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le"
+ "\n"
+ "Disassembly of section .text:\n"
+ "\n"
+ "00000900 <wook_task>:\n"
+ " ...\n"
+ "00001000 <hook_task>:\n"
+ " 1000: fc 00\t\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n"
+ " 1002: 47 70\t\tmovi55 $r0, #1\n"
+ " 1006: 00 01 5c fc\tbne $r6, $r0, 2af6a\n"
+ "00002000 <console_task>:\n"
+ " 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n"
+ " 2002: f0 0e fc c5\tjal 1000 <hook_task>\n"
+ " 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n"
+ " 200a: 12 34 56 78\tjral5 $r0\n"
+ )
+
+ addrtoline_mock.return_value = [("??", "??", 0)]
+ self.analyzer.annotation = {
+ "exception_frame_size": 64,
+ "remove": [["fake_func"]],
+ }
+
+ with mock.patch("builtins.print") as print_mock:
+ checkoutput_mock.return_value = disasm_text
+ self.analyzer.Analyze()
+ print_mock.assert_has_calls(
+ [
+ mock.call(
+ "Task: HOOKS, Max size: 96 (32 + 64), Allocated size: 2048"
+ ),
+ mock.call("Call Trace:"),
+ mock.call(" hook_task (32) [??:0] 1000"),
+ mock.call(
+ "Task: CONSOLE, Max size: 112 (48 + 64), Allocated size: 460"
+ ),
+ mock.call("Call Trace:"),
+ mock.call(" console_task (16) [??:0] 2000"),
+ mock.call(" -> ??[??:0] 2002"),
+ mock.call(" hook_task (32) [??:0] 1000"),
+ mock.call("Unresolved indirect callsites:"),
+ mock.call(" In function console_task:"),
+ mock.call(" -> ??[??:0] 200a"),
+ mock.call("Unresolved annotation signatures:"),
+ mock.call(" fake_func: function is not found"),
+ ]
+ )
+
+ with self.assertRaisesRegexp(sa.StackAnalyzerError, "Failed to run objdump."):
+ checkoutput_mock.side_effect = OSError()
+ self.analyzer.Analyze()
+
+ with self.assertRaisesRegexp(
+ sa.StackAnalyzerError, "objdump failed to disassemble."
+ ):
+ checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "")
+ self.analyzer.Analyze()
+
+ @mock.patch("subprocess.check_output")
+ @mock.patch("stack_analyzer.StackAnalyzer.AddressToLine")
+ def testArmAnalyze(self, addrtoline_mock, checkoutput_mock):
+ disasm_text = (
+ "\n"
+ "build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm"
+ "\n"
+ "Disassembly of section .text:\n"
+ "\n"
+ "00000900 <wook_task>:\n"
+ " ...\n"
+ "00001000 <hook_task>:\n"
+ " 1000: b508\t\tpush {r3, lr}\n"
+ " 1002: 4770\t\tbx lr\n"
+ " 1006: 00015cfc\t.word 0x00015cfc\n"
+ "00002000 <console_task>:\n"
+ " 2000: b508\t\tpush {r3, lr}\n"
+ " 2002: f00e fcc5\tbl 1000 <hook_task>\n"
+ " 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n"
+ " 200a: 1234 5678\tb.w sl\n"
+ )
+
+ addrtoline_mock.return_value = [("??", "??", 0)]
+ self.analyzer.annotation = {
+ "exception_frame_size": 64,
+ "remove": [["fake_func"]],
+ }
+
+ with mock.patch("builtins.print") as print_mock:
+ checkoutput_mock.return_value = disasm_text
+ self.analyzer.Analyze()
+ print_mock.assert_has_calls(
+ [
+ mock.call(
+ "Task: HOOKS, Max size: 72 (8 + 64), Allocated size: 2048"
+ ),
+ mock.call("Call Trace:"),
+ mock.call(" hook_task (8) [??:0] 1000"),
+ mock.call(
+ "Task: CONSOLE, Max size: 80 (16 + 64), Allocated size: 460"
+ ),
+ mock.call("Call Trace:"),
+ mock.call(" console_task (8) [??:0] 2000"),
+ mock.call(" -> ??[??:0] 2002"),
+ mock.call(" hook_task (8) [??:0] 1000"),
+ mock.call("Unresolved indirect callsites:"),
+ mock.call(" In function console_task:"),
+ mock.call(" -> ??[??:0] 200a"),
+ mock.call("Unresolved annotation signatures:"),
+ mock.call(" fake_func: function is not found"),
+ ]
+ )
+
+ with self.assertRaisesRegexp(sa.StackAnalyzerError, "Failed to run objdump."):
+ checkoutput_mock.side_effect = OSError()
+ self.analyzer.Analyze()
+
+ with self.assertRaisesRegexp(
+ sa.StackAnalyzerError, "objdump failed to disassemble."
+ ):
+ checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "")
+ self.analyzer.Analyze()
+
+ @mock.patch("subprocess.check_output")
+ @mock.patch("stack_analyzer.ParseArgs")
+ def testMain(self, parseargs_mock, checkoutput_mock):
+ symbol_text = (
+ "1000 g F .text 0000015c .hidden hook_task\n"
+ "2000 g F .text 0000051c .hidden console_task\n"
+ )
+ rodata_text = (
+ "\n"
+ "Contents of section .rodata:\n"
+ " 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n"
+ )
+
+ args = mock.MagicMock(
+ elf_path="./ec.RW.elf",
+ export_taskinfo="fake",
+ section="RW",
+ objdump="objdump",
+ addr2line="addr2line",
+ annotation="fake",
+ )
+ parseargs_mock.return_value = args
+
+ with mock.patch("os.path.exists") as path_mock:
+ path_mock.return_value = False
+ with mock.patch("builtins.print") as print_mock:
+ with mock.patch("builtins.open", mock.mock_open()) as open_mock:
+ sa.main()
+ print_mock.assert_any_call(
+ "Warning: Annotation file fake does not exist."
+ )
+
+ with mock.patch("os.path.exists") as path_mock:
+ path_mock.return_value = True
+ with mock.patch("builtins.print") as print_mock:
+ with mock.patch("builtins.open", mock.mock_open()) as open_mock:
+ open_mock.side_effect = IOError()
+ sa.main()
+ print_mock.assert_called_once_with(
+ "Error: Failed to open annotation file fake."
+ )
+
+ with mock.patch("builtins.print") as print_mock:
+ with mock.patch("builtins.open", mock.mock_open()) as open_mock:
+ open_mock.return_value.read.side_effect = ["{", ""]
+ sa.main()
+ open_mock.assert_called_once_with("fake", "r")
+ print_mock.assert_called_once_with(
+ "Error: Failed to parse annotation file fake."
+ )
+
+ with mock.patch("builtins.print") as print_mock:
+ with mock.patch(
+ "builtins.open", mock.mock_open(read_data="")
+ ) as open_mock:
+ sa.main()
+ print_mock.assert_called_once_with(
+ "Error: Invalid annotation file fake."
+ )
+
+ args.annotation = None
+
+ with mock.patch("builtins.print") as print_mock:
+ checkoutput_mock.side_effect = [symbol_text, rodata_text]
+ sa.main()
+ print_mock.assert_called_once_with("Error: Failed to load export_taskinfo.")
+
+ with mock.patch("builtins.print") as print_mock:
+ checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "")
+ sa.main()
+ print_mock.assert_called_once_with(
+ "Error: objdump failed to dump symbol table or rodata."
+ )
+
+ with mock.patch("builtins.print") as print_mock:
+ checkoutput_mock.side_effect = OSError()
+ sa.main()
+ print_mock.assert_called_once_with("Error: Failed to run objdump.")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/extra/tigertool/ecusb/__init__.py b/extra/tigertool/ecusb/__init__.py
index fe4dbc6749..7a48fdb360 100644
--- a/extra/tigertool/ecusb/__init__.py
+++ b/extra/tigertool/ecusb/__init__.py
@@ -6,4 +6,4 @@
# pylint: disable=bad-indentation,docstring-section-indent
# pylint: disable=docstring-trailing-quotes
-__all__ = ['tiny_servo_common', 'stm32usb', 'stm32uart', 'pty_driver']
+__all__ = ["tiny_servo_common", "stm32usb", "stm32uart", "pty_driver"]
diff --git a/extra/tigertool/ecusb/pty_driver.py b/extra/tigertool/ecusb/pty_driver.py
index 09ef8c42e4..037ff8d529 100644
--- a/extra/tigertool/ecusb/pty_driver.py
+++ b/extra/tigertool/ecusb/pty_driver.py
@@ -17,8 +17,9 @@ import ast
import errno
import fcntl
import os
-import pexpect
import time
+
+import pexpect
from pexpect import fdpexpect
# Expecting a result in 3 seconds is plenty even for slow platforms.
@@ -27,281 +28,285 @@ FLUSH_UART_TIMEOUT = 1
class ptyError(Exception):
- """Exception class for pty errors."""
+ """Exception class for pty errors."""
UART_PARAMS = {
- 'uart_cmd': None,
- 'uart_multicmd': None,
- 'uart_regexp': None,
- 'uart_timeout': DEFAULT_UART_TIMEOUT,
+ "uart_cmd": None,
+ "uart_multicmd": None,
+ "uart_regexp": None,
+ "uart_timeout": DEFAULT_UART_TIMEOUT,
}
class ptyDriver(object):
- """Automate interactive commands on a pty interface."""
- def __init__(self, interface, params, fast=False):
- """Init class variables."""
- self._child = None
- self._fd = None
- self._interface = interface
- self._pty_path = self._interface.get_pty()
- self._dict = UART_PARAMS.copy()
- self._fast = fast
-
- def __del__(self):
- self.close()
-
- def close(self):
- """Close any open files and interfaces."""
- if self._fd:
- self._close()
- self._interface.close()
-
- def _open(self):
- """Connect to serial device and create pexpect interface."""
- assert self._fd is None
- self._fd = os.open(self._pty_path, os.O_RDWR | os.O_NONBLOCK)
- # Don't allow forked processes to access.
- fcntl.fcntl(self._fd, fcntl.F_SETFD,
- fcntl.fcntl(self._fd, fcntl.F_GETFD) | fcntl.FD_CLOEXEC)
- self._child = fdpexpect.fdspawn(self._fd)
- # pexpect defaults to a 100ms delay before sending characters, to
- # work around race conditions in ssh. We don't need this feature
- # so we'll change delaybeforesend from 0.1 to 0.001 to speed things up.
- if self._fast:
- self._child.delaybeforesend = 0.001
-
- def _close(self):
- """Close serial device connection."""
- os.close(self._fd)
- self._fd = None
- self._child = None
-
- def _flush(self):
- """Flush device output to prevent previous messages interfering."""
- if self._child.sendline('') != 1:
- raise ptyError('Failed to send newline.')
- # Have a maximum timeout for the flush operation. We should have cleared
- # all data from the buffer, but if data is regularly being generated, we
- # can't guarantee it will ever stop.
- flush_end_time = time.time() + FLUSH_UART_TIMEOUT
- while time.time() <= flush_end_time:
- try:
- self._child.expect('.', timeout=0.01)
- except (pexpect.TIMEOUT, pexpect.EOF):
- break
- except OSError as e:
- # EAGAIN indicates no data available, maybe we didn't wait long enough.
- if e.errno != errno.EAGAIN:
- raise
- break
-
- def _send(self, cmds):
- """Send command to EC.
-
- This function always flushes serial device before sending, and is used as
- a wrapper function to make sure the channel is always flushed before
- sending commands.
-
- Args:
- cmds: The commands to send to the device, either a list or a string.
-
- Raises:
- ptyError: Raised when writing to the device fails.
- """
- self._flush()
- if not isinstance(cmds, list):
- cmds = [cmds]
- for cmd in cmds:
- if self._child.sendline(cmd) != len(cmd) + 1:
- raise ptyError('Failed to send command.')
-
- def _issue_cmd(self, cmds):
- """Send command to the device and do not wait for response.
-
- Args:
- cmds: The commands to send to the device, either a list or a string.
- """
- self._issue_cmd_get_results(cmds, [])
-
- def _issue_cmd_get_results(self, cmds,
- regex_list, timeout=DEFAULT_UART_TIMEOUT):
- """Send command to the device and wait for response.
-
- This function waits for response message matching a regular
- expressions.
-
- Args:
- cmds: The commands issued, either a list or a string.
- regex_list: List of Regular expressions used to match response message.
- Note1, list must be ordered.
- Note2, empty list sends and returns.
- timeout: time to wait for matching results before failing.
-
- Returns:
- List of tuples, each of which contains the entire matched string and
- all the subgroups of the match. None if not matched.
- For example:
- response of the given command:
- High temp: 37.2
- Low temp: 36.4
- regex_list:
- ['High temp: (\d+)\.(\d+)', 'Low temp: (\d+)\.(\d+)']
- returns:
- [('High temp: 37.2', '37', '2'), ('Low temp: 36.4', '36', '4')]
-
- Raises:
- ptyError: If timed out waiting for a response
- """
- result_list = []
- self._open()
- try:
- self._send(cmds)
- for regex in regex_list:
- self._child.expect(regex, timeout)
- match = self._child.match
- lastindex = match.lastindex if match and match.lastindex else 0
- # Create a tuple which contains the entire matched string and all
- # the subgroups of the match.
- result = match.group(*range(lastindex + 1)) if match else None
- if result:
- result = tuple(res.decode('utf-8') for res in result)
- result_list.append(result)
- except pexpect.TIMEOUT:
- raise ptyError('Timeout waiting for response.')
- finally:
- if not regex_list:
- # Must be longer than delaybeforesend
- time.sleep(0.1)
- self._close()
- return result_list
-
- def _issue_cmd_get_multi_results(self, cmd, regex):
- """Send command to the device and wait for multiple response.
-
- This function waits for arbitrary number of response message
- matching a regular expression.
-
- Args:
- cmd: The command issued.
- regex: Regular expression used to match response message.
-
- Returns:
- List of tuples, each of which contains the entire matched string and
- all the subgroups of the match. None if not matched.
- """
- result_list = []
- self._open()
- try:
- self._send(cmd)
- while True:
+ """Automate interactive commands on a pty interface."""
+
+ def __init__(self, interface, params, fast=False):
+ """Init class variables."""
+ self._child = None
+ self._fd = None
+ self._interface = interface
+ self._pty_path = self._interface.get_pty()
+ self._dict = UART_PARAMS.copy()
+ self._fast = fast
+
+ def __del__(self):
+ self.close()
+
+ def close(self):
+ """Close any open files and interfaces."""
+ if self._fd:
+ self._close()
+ self._interface.close()
+
+ def _open(self):
+ """Connect to serial device and create pexpect interface."""
+ assert self._fd is None
+ self._fd = os.open(self._pty_path, os.O_RDWR | os.O_NONBLOCK)
+ # Don't allow forked processes to access.
+ fcntl.fcntl(
+ self._fd,
+ fcntl.F_SETFD,
+ fcntl.fcntl(self._fd, fcntl.F_GETFD) | fcntl.FD_CLOEXEC,
+ )
+ self._child = fdpexpect.fdspawn(self._fd)
+ # pexpect defaults to a 100ms delay before sending characters, to
+ # work around race conditions in ssh. We don't need this feature
+ # so we'll change delaybeforesend from 0.1 to 0.001 to speed things up.
+ if self._fast:
+ self._child.delaybeforesend = 0.001
+
+ def _close(self):
+ """Close serial device connection."""
+ os.close(self._fd)
+ self._fd = None
+ self._child = None
+
+ def _flush(self):
+ """Flush device output to prevent previous messages interfering."""
+ if self._child.sendline("") != 1:
+ raise ptyError("Failed to send newline.")
+ # Have a maximum timeout for the flush operation. We should have cleared
+ # all data from the buffer, but if data is regularly being generated, we
+ # can't guarantee it will ever stop.
+ flush_end_time = time.time() + FLUSH_UART_TIMEOUT
+ while time.time() <= flush_end_time:
+ try:
+ self._child.expect(".", timeout=0.01)
+ except (pexpect.TIMEOUT, pexpect.EOF):
+ break
+ except OSError as e:
+ # EAGAIN indicates no data available, maybe we didn't wait long enough.
+ if e.errno != errno.EAGAIN:
+ raise
+ break
+
+ def _send(self, cmds):
+ """Send command to EC.
+
+ This function always flushes serial device before sending, and is used as
+ a wrapper function to make sure the channel is always flushed before
+ sending commands.
+
+ Args:
+ cmds: The commands to send to the device, either a list or a string.
+
+ Raises:
+ ptyError: Raised when writing to the device fails.
+ """
+ self._flush()
+ if not isinstance(cmds, list):
+ cmds = [cmds]
+ for cmd in cmds:
+ if self._child.sendline(cmd) != len(cmd) + 1:
+ raise ptyError("Failed to send command.")
+
+ def _issue_cmd(self, cmds):
+ """Send command to the device and do not wait for response.
+
+ Args:
+ cmds: The commands to send to the device, either a list or a string.
+ """
+ self._issue_cmd_get_results(cmds, [])
+
+ def _issue_cmd_get_results(self, cmds, regex_list, timeout=DEFAULT_UART_TIMEOUT):
+ """Send command to the device and wait for response.
+
+ This function waits for response message matching a regular
+ expressions.
+
+ Args:
+ cmds: The commands issued, either a list or a string.
+ regex_list: List of Regular expressions used to match response message.
+ Note1, list must be ordered.
+ Note2, empty list sends and returns.
+ timeout: time to wait for matching results before failing.
+
+ Returns:
+ List of tuples, each of which contains the entire matched string and
+ all the subgroups of the match. None if not matched.
+ For example:
+ response of the given command:
+ High temp: 37.2
+ Low temp: 36.4
+ regex_list:
+ ['High temp: (\d+)\.(\d+)', 'Low temp: (\d+)\.(\d+)']
+ returns:
+ [('High temp: 37.2', '37', '2'), ('Low temp: 36.4', '36', '4')]
+
+ Raises:
+ ptyError: If timed out waiting for a response
+ """
+ result_list = []
+ self._open()
try:
- self._child.expect(regex, timeout=0.1)
- match = self._child.match
- lastindex = match.lastindex if match and match.lastindex else 0
- # Create a tuple which contains the entire matched string and all
- # the subgroups of the match.
- result = match.group(*range(lastindex + 1)) if match else None
- if result:
- result = tuple(res.decode('utf-8') for res in result)
- result_list.append(result)
+ self._send(cmds)
+ for regex in regex_list:
+ self._child.expect(regex, timeout)
+ match = self._child.match
+ lastindex = match.lastindex if match and match.lastindex else 0
+ # Create a tuple which contains the entire matched string and all
+ # the subgroups of the match.
+ result = match.group(*range(lastindex + 1)) if match else None
+ if result:
+ result = tuple(res.decode("utf-8") for res in result)
+ result_list.append(result)
except pexpect.TIMEOUT:
- break
- finally:
- self._close()
- return result_list
-
- def _Set_uart_timeout(self, timeout):
- """Set timeout value for waiting for the device response.
-
- Args:
- timeout: Timeout value in second.
- """
- self._dict['uart_timeout'] = timeout
-
- def _Get_uart_timeout(self):
- """Get timeout value for waiting for the device response.
-
- Returns:
- Timeout value in second.
- """
- return self._dict['uart_timeout']
-
- def _Set_uart_regexp(self, regexp):
- """Set the list of regular expressions which matches the command response.
-
- Args:
- regexp: A string which contains a list of regular expressions.
- """
- if not isinstance(regexp, str):
- raise ptyError('The argument regexp should be a string.')
- self._dict['uart_regexp'] = ast.literal_eval(regexp)
-
- def _Get_uart_regexp(self):
- """Get the list of regular expressions which matches the command response.
-
- Returns:
- A string which contains a list of regular expressions.
- """
- return str(self._dict['uart_regexp'])
-
- def _Set_uart_cmd(self, cmd):
- """Set the UART command and send it to the device.
-
- If ec_uart_regexp is 'None', the command is just sent and it doesn't care
- about its response.
-
- If ec_uart_regexp is not 'None', the command is send and its response,
- which matches the regular expression of ec_uart_regexp, will be kept.
- Use its getter to obtain this result. If no match after ec_uart_timeout
- seconds, a timeout error will be raised.
-
- Args:
- cmd: A string of UART command.
- """
- if self._dict['uart_regexp']:
- self._dict['uart_cmd'] = self._issue_cmd_get_results(
- cmd, self._dict['uart_regexp'], self._dict['uart_timeout'])
- else:
- self._dict['uart_cmd'] = None
- self._issue_cmd(cmd)
-
- def _Set_uart_multicmd(self, cmds):
- """Set multiple UART commands and send them to the device.
-
- Note that ec_uart_regexp is not supported to match the results.
-
- Args:
- cmds: A semicolon-separated string of UART commands.
- """
- self._issue_cmd(cmds.split(';'))
-
- def _Get_uart_cmd(self):
- """Get the result of the latest UART command.
-
- Returns:
- A string which contains a list of tuples, each of which contains the
- entire matched string and all the subgroups of the match. 'None' if
- the ec_uart_regexp is 'None'.
- """
- return str(self._dict['uart_cmd'])
-
- def _Set_uart_capture(self, cmd):
- """Set UART capture mode (on or off).
-
- Once capture is enabled, UART output could be collected periodically by
- invoking _Get_uart_stream() below.
-
- Args:
- cmd: True for on, False for off
- """
- self._interface.set_capture_active(cmd)
-
- def _Get_uart_capture(self):
- """Get the UART capture mode (on or off)."""
- return self._interface.get_capture_active()
-
- def _Get_uart_stream(self):
- """Get uart stream generated since last time."""
- return self._interface.get_stream()
+ raise ptyError("Timeout waiting for response.")
+ finally:
+ if not regex_list:
+ # Must be longer than delaybeforesend
+ time.sleep(0.1)
+ self._close()
+ return result_list
+
+ def _issue_cmd_get_multi_results(self, cmd, regex):
+ """Send command to the device and wait for multiple response.
+
+ This function waits for arbitrary number of response message
+ matching a regular expression.
+
+ Args:
+ cmd: The command issued.
+ regex: Regular expression used to match response message.
+
+ Returns:
+ List of tuples, each of which contains the entire matched string and
+ all the subgroups of the match. None if not matched.
+ """
+ result_list = []
+ self._open()
+ try:
+ self._send(cmd)
+ while True:
+ try:
+ self._child.expect(regex, timeout=0.1)
+ match = self._child.match
+ lastindex = match.lastindex if match and match.lastindex else 0
+ # Create a tuple which contains the entire matched string and all
+ # the subgroups of the match.
+ result = match.group(*range(lastindex + 1)) if match else None
+ if result:
+ result = tuple(res.decode("utf-8") for res in result)
+ result_list.append(result)
+ except pexpect.TIMEOUT:
+ break
+ finally:
+ self._close()
+ return result_list
+
+ def _Set_uart_timeout(self, timeout):
+ """Set timeout value for waiting for the device response.
+
+ Args:
+ timeout: Timeout value in second.
+ """
+ self._dict["uart_timeout"] = timeout
+
+ def _Get_uart_timeout(self):
+ """Get timeout value for waiting for the device response.
+
+ Returns:
+ Timeout value in second.
+ """
+ return self._dict["uart_timeout"]
+
+ def _Set_uart_regexp(self, regexp):
+ """Set the list of regular expressions which matches the command response.
+
+ Args:
+ regexp: A string which contains a list of regular expressions.
+ """
+ if not isinstance(regexp, str):
+ raise ptyError("The argument regexp should be a string.")
+ self._dict["uart_regexp"] = ast.literal_eval(regexp)
+
+ def _Get_uart_regexp(self):
+ """Get the list of regular expressions which matches the command response.
+
+ Returns:
+ A string which contains a list of regular expressions.
+ """
+ return str(self._dict["uart_regexp"])
+
+ def _Set_uart_cmd(self, cmd):
+ """Set the UART command and send it to the device.
+
+ If ec_uart_regexp is 'None', the command is just sent and it doesn't care
+ about its response.
+
+ If ec_uart_regexp is not 'None', the command is send and its response,
+ which matches the regular expression of ec_uart_regexp, will be kept.
+ Use its getter to obtain this result. If no match after ec_uart_timeout
+ seconds, a timeout error will be raised.
+
+ Args:
+ cmd: A string of UART command.
+ """
+ if self._dict["uart_regexp"]:
+ self._dict["uart_cmd"] = self._issue_cmd_get_results(
+ cmd, self._dict["uart_regexp"], self._dict["uart_timeout"]
+ )
+ else:
+ self._dict["uart_cmd"] = None
+ self._issue_cmd(cmd)
+
+ def _Set_uart_multicmd(self, cmds):
+ """Set multiple UART commands and send them to the device.
+
+ Note that ec_uart_regexp is not supported to match the results.
+
+ Args:
+ cmds: A semicolon-separated string of UART commands.
+ """
+ self._issue_cmd(cmds.split(";"))
+
+ def _Get_uart_cmd(self):
+ """Get the result of the latest UART command.
+
+ Returns:
+ A string which contains a list of tuples, each of which contains the
+ entire matched string and all the subgroups of the match. 'None' if
+ the ec_uart_regexp is 'None'.
+ """
+ return str(self._dict["uart_cmd"])
+
+ def _Set_uart_capture(self, cmd):
+ """Set UART capture mode (on or off).
+
+ Once capture is enabled, UART output could be collected periodically by
+ invoking _Get_uart_stream() below.
+
+ Args:
+ cmd: True for on, False for off
+ """
+ self._interface.set_capture_active(cmd)
+
+ def _Get_uart_capture(self):
+ """Get the UART capture mode (on or off)."""
+ return self._interface.get_capture_active()
+
+ def _Get_uart_stream(self):
+ """Get uart stream generated since last time."""
+ return self._interface.get_stream()
diff --git a/extra/tigertool/ecusb/stm32uart.py b/extra/tigertool/ecusb/stm32uart.py
index 95219455a9..5794bd091b 100644
--- a/extra/tigertool/ecusb/stm32uart.py
+++ b/extra/tigertool/ecusb/stm32uart.py
@@ -17,232 +17,244 @@ import termios
import threading
import time
import tty
+
import usb
from . import stm32usb
class SuartError(Exception):
- """Class for exceptions of Suart."""
- def __init__(self, msg, value=0):
- """SuartError constructor.
+ """Class for exceptions of Suart."""
- Args:
- msg: string, message describing error in detail
- value: integer, value of error when non-zero status returned. Default=0
- """
- super(SuartError, self).__init__(msg, value)
- self.msg = msg
- self.value = value
+ def __init__(self, msg, value=0):
+ """SuartError constructor.
+ Args:
+ msg: string, message describing error in detail
+ value: integer, value of error when non-zero status returned. Default=0
+ """
+ super(SuartError, self).__init__(msg, value)
+ self.msg = msg
+ self.value = value
-class Suart(object):
- """Provide interface to stm32 serial usb endpoint."""
- def __init__(self, vendor=0x18d1, product=0x501a, interface=0,
- serialname=None, debuglog=False):
- """Suart contstructor.
-
- Initializes stm32 USB stream interface.
-
- Args:
- vendor: usb vendor id of stm32 device
- product: usb product id of stm32 device
- interface: interface number of stm32 device to use
- serialname: serial name to target. Defaults to None.
- debuglog: chatty output. Defaults to False.
-
- Raises:
- SuartError: If init fails
- """
- self._ptym = None
- self._ptys = None
- self._ptyname = None
- self._rx_thread = None
- self._tx_thread = None
- self._debuglog = debuglog
- self._susb = stm32usb.Susb(vendor=vendor, product=product,
- interface=interface, serialname=serialname)
- self._running = False
-
- def __del__(self):
- """Suart destructor."""
- self.close()
-
- def close(self):
- """Stop all running threads."""
- self._running = False
- if self._rx_thread:
- self._rx_thread.join(2)
- self._rx_thread = None
- if self._tx_thread:
- self._tx_thread.join(2)
- self._tx_thread = None
- self._susb.close()
-
- def run_rx_thread(self):
- """Background loop to pass data from USB to pty."""
- ep = select.epoll()
- ep.register(self._ptym, select.EPOLLHUP)
- try:
- while self._running:
- events = ep.poll(0)
- # Check if the pty is connected to anything, or hungup.
- if not events:
- try:
- r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS)
- if r:
- if self._debuglog:
- print(''.join([chr(x) for x in r]), end='')
- os.write(self._ptym, r)
-
- # If we miss some characters on pty disconnect, that's fine.
- # ep.read() also throws USBError on timeout, which we discard.
- except OSError:
- pass
- except usb.core.USBError:
- pass
- else:
- time.sleep(.1)
- except Exception as e:
- raise e
-
- def run_tx_thread(self):
- """Background loop to pass data from pty to USB."""
- ep = select.epoll()
- ep.register(self._ptym, select.EPOLLHUP)
- try:
- while self._running:
- events = ep.poll(0)
- # Check if the pty is connected to anything, or hungup.
- if not events:
- try:
- r = os.read(self._ptym, 64)
- # TODO(crosbug.com/936182): Remove when the servo v4/micro console
- # issues are fixed.
- time.sleep(0.001)
- if r:
- self._susb._write_ep.write(r, self._susb.TIMEOUT_MS)
-
- except OSError:
- pass
- except usb.core.USBError:
- pass
- else:
- time.sleep(.1)
- except Exception as e:
- raise e
-
- def run(self):
- """Creates pthreads to poll stm32 & PTY for data."""
- m, s = os.openpty()
- self._ptyname = os.ttyname(s)
-
- self._ptym = m
- self._ptys = s
-
- os.fchmod(s, 0o660)
-
- # Change the owner and group of the PTY to the user who started servod.
- try:
- uid = int(os.environ.get('SUDO_UID', -1))
- except TypeError:
- uid = -1
- try:
- gid = int(os.environ.get('SUDO_GID', -1))
- except TypeError:
- gid = -1
- os.fchown(s, uid, gid)
-
- tty.setraw(self._ptym, termios.TCSADRAIN)
-
- # Generate a HUP flag on pty slave fd.
- os.fdopen(s).close()
-
- self._running = True
-
- self._rx_thread = threading.Thread(target=self.run_rx_thread, args=[])
- self._rx_thread.daemon = True
- self._rx_thread.start()
-
- self._tx_thread = threading.Thread(target=self.run_tx_thread, args=[])
- self._tx_thread.daemon = True
- self._tx_thread.start()
-
- def get_uart_props(self):
- """Get the uart's properties.
-
- Returns:
- dict where:
- baudrate: integer of uarts baudrate
- bits: integer, number of bits of data Can be 5|6|7|8 inclusive
- parity: integer, parity of 0-2 inclusive where:
- 0: no parity
- 1: odd parity
- 2: even parity
- sbits: integer, number of stop bits. Can be 0|1|2 inclusive where:
- 0: 1 stop bit
- 1: 1.5 stop bits
- 2: 2 stop bits
- """
- return {
- 'baudrate': 115200,
- 'bits': 8,
- 'parity': 0,
- 'sbits': 1,
- }
-
- def set_uart_props(self, line_props):
- """Set the uart's properties.
-
- Note that Suart cannot set properties
- and will fail if the properties are not the default 115200,8n1.
-
- Args:
- line_props: dict where:
- baudrate: integer of uarts baudrate
- bits: integer, number of bits of data ( prior to stop bit)
- parity: integer, parity of 0-2 inclusive where
- 0: no parity
- 1: odd parity
- 2: even parity
- sbits: integer, number of stop bits. Can be 0|1|2 inclusive where:
- 0: 1 stop bit
- 1: 1.5 stop bits
- 2: 2 stop bits
-
- Raises:
- SuartError: If requested line properties are not the default.
- """
- curr_props = self.get_uart_props()
- for prop in line_props:
- if line_props[prop] != curr_props[prop]:
- raise SuartError('Line property %s cannot be set from %s to %s' % (
- prop, curr_props[prop], line_props[prop]))
- return True
-
- def get_pty(self):
- """Gets path to pty for communication to/from uart.
-
- Returns:
- String path to the pty connected to the uart
- """
- return self._ptyname
+class Suart(object):
+ """Provide interface to stm32 serial usb endpoint."""
+
+ def __init__(
+ self,
+ vendor=0x18D1,
+ product=0x501A,
+ interface=0,
+ serialname=None,
+ debuglog=False,
+ ):
+ """Suart contstructor.
+
+ Initializes stm32 USB stream interface.
+
+ Args:
+ vendor: usb vendor id of stm32 device
+ product: usb product id of stm32 device
+ interface: interface number of stm32 device to use
+ serialname: serial name to target. Defaults to None.
+ debuglog: chatty output. Defaults to False.
+
+ Raises:
+ SuartError: If init fails
+ """
+ self._ptym = None
+ self._ptys = None
+ self._ptyname = None
+ self._rx_thread = None
+ self._tx_thread = None
+ self._debuglog = debuglog
+ self._susb = stm32usb.Susb(
+ vendor=vendor, product=product, interface=interface, serialname=serialname
+ )
+ self._running = False
+
+ def __del__(self):
+ """Suart destructor."""
+ self.close()
+
+ def close(self):
+ """Stop all running threads."""
+ self._running = False
+ if self._rx_thread:
+ self._rx_thread.join(2)
+ self._rx_thread = None
+ if self._tx_thread:
+ self._tx_thread.join(2)
+ self._tx_thread = None
+ self._susb.close()
+
+ def run_rx_thread(self):
+ """Background loop to pass data from USB to pty."""
+ ep = select.epoll()
+ ep.register(self._ptym, select.EPOLLHUP)
+ try:
+ while self._running:
+ events = ep.poll(0)
+ # Check if the pty is connected to anything, or hungup.
+ if not events:
+ try:
+ r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS)
+ if r:
+ if self._debuglog:
+ print("".join([chr(x) for x in r]), end="")
+ os.write(self._ptym, r)
+
+ # If we miss some characters on pty disconnect, that's fine.
+ # ep.read() also throws USBError on timeout, which we discard.
+ except OSError:
+ pass
+ except usb.core.USBError:
+ pass
+ else:
+ time.sleep(0.1)
+ except Exception as e:
+ raise e
+
+ def run_tx_thread(self):
+ """Background loop to pass data from pty to USB."""
+ ep = select.epoll()
+ ep.register(self._ptym, select.EPOLLHUP)
+ try:
+ while self._running:
+ events = ep.poll(0)
+ # Check if the pty is connected to anything, or hungup.
+ if not events:
+ try:
+ r = os.read(self._ptym, 64)
+ # TODO(crosbug.com/936182): Remove when the servo v4/micro console
+ # issues are fixed.
+ time.sleep(0.001)
+ if r:
+ self._susb._write_ep.write(r, self._susb.TIMEOUT_MS)
+
+ except OSError:
+ pass
+ except usb.core.USBError:
+ pass
+ else:
+ time.sleep(0.1)
+ except Exception as e:
+ raise e
+
+ def run(self):
+ """Creates pthreads to poll stm32 & PTY for data."""
+ m, s = os.openpty()
+ self._ptyname = os.ttyname(s)
+
+ self._ptym = m
+ self._ptys = s
+
+ os.fchmod(s, 0o660)
+
+ # Change the owner and group of the PTY to the user who started servod.
+ try:
+ uid = int(os.environ.get("SUDO_UID", -1))
+ except TypeError:
+ uid = -1
+
+ try:
+ gid = int(os.environ.get("SUDO_GID", -1))
+ except TypeError:
+ gid = -1
+ os.fchown(s, uid, gid)
+
+ tty.setraw(self._ptym, termios.TCSADRAIN)
+
+ # Generate a HUP flag on pty slave fd.
+ os.fdopen(s).close()
+
+ self._running = True
+
+ self._rx_thread = threading.Thread(target=self.run_rx_thread, args=[])
+ self._rx_thread.daemon = True
+ self._rx_thread.start()
+
+ self._tx_thread = threading.Thread(target=self.run_tx_thread, args=[])
+ self._tx_thread.daemon = True
+ self._tx_thread.start()
+
+ def get_uart_props(self):
+ """Get the uart's properties.
+
+ Returns:
+ dict where:
+ baudrate: integer of uarts baudrate
+ bits: integer, number of bits of data Can be 5|6|7|8 inclusive
+ parity: integer, parity of 0-2 inclusive where:
+ 0: no parity
+ 1: odd parity
+ 2: even parity
+ sbits: integer, number of stop bits. Can be 0|1|2 inclusive where:
+ 0: 1 stop bit
+ 1: 1.5 stop bits
+ 2: 2 stop bits
+ """
+ return {
+ "baudrate": 115200,
+ "bits": 8,
+ "parity": 0,
+ "sbits": 1,
+ }
+
+ def set_uart_props(self, line_props):
+ """Set the uart's properties.
+
+ Note that Suart cannot set properties
+ and will fail if the properties are not the default 115200,8n1.
+
+ Args:
+ line_props: dict where:
+ baudrate: integer of uarts baudrate
+ bits: integer, number of bits of data ( prior to stop bit)
+ parity: integer, parity of 0-2 inclusive where
+ 0: no parity
+ 1: odd parity
+ 2: even parity
+ sbits: integer, number of stop bits. Can be 0|1|2 inclusive where:
+ 0: 1 stop bit
+ 1: 1.5 stop bits
+ 2: 2 stop bits
+
+ Raises:
+ SuartError: If requested line properties are not the default.
+ """
+ curr_props = self.get_uart_props()
+ for prop in line_props:
+ if line_props[prop] != curr_props[prop]:
+ raise SuartError(
+ "Line property %s cannot be set from %s to %s"
+ % (prop, curr_props[prop], line_props[prop])
+ )
+ return True
+
+ def get_pty(self):
+ """Gets path to pty for communication to/from uart.
+
+ Returns:
+ String path to the pty connected to the uart
+ """
+ return self._ptyname
def main():
- """Run a suart test with the default parameters."""
- try:
- sobj = Suart()
- sobj.run()
+ """Run a suart test with the default parameters."""
+ try:
+ sobj = Suart()
+ sobj.run()
- # run() is a thread so just busy wait to mimic server.
- while True:
- # Ours sleeps to eleven!
- time.sleep(11)
- except KeyboardInterrupt:
- sys.exit(0)
+ # run() is a thread so just busy wait to mimic server.
+ while True:
+ # Ours sleeps to eleven!
+ time.sleep(11)
+ except KeyboardInterrupt:
+ sys.exit(0)
-if __name__ == '__main__':
- main()
+if __name__ == "__main__":
+ main()
diff --git a/extra/tigertool/ecusb/stm32usb.py b/extra/tigertool/ecusb/stm32usb.py
index bfd5fbb1fb..4b2b23fbac 100644
--- a/extra/tigertool/ecusb/stm32usb.py
+++ b/extra/tigertool/ecusb/stm32usb.py
@@ -12,108 +12,115 @@ import usb
class SusbError(Exception):
- """Class for exceptions of Susb."""
- def __init__(self, msg, value=0):
- """SusbError constructor.
+ """Class for exceptions of Susb."""
- Args:
- msg: string, message describing error in detail
- value: integer, value of error when non-zero status returned. Default=0
- """
- super(SusbError, self).__init__(msg, value)
- self.msg = msg
- self.value = value
+ def __init__(self, msg, value=0):
+ """SusbError constructor.
+
+ Args:
+ msg: string, message describing error in detail
+ value: integer, value of error when non-zero status returned. Default=0
+ """
+ super(SusbError, self).__init__(msg, value)
+ self.msg = msg
+ self.value = value
class Susb(object):
- """Provide stm32 USB functionality.
-
- Instance Variables:
- _read_ep: pyUSB read endpoint for this interface
- _write_ep: pyUSB write endpoint for this interface
- """
- READ_ENDPOINT = 0x81
- WRITE_ENDPOINT = 0x1
- TIMEOUT_MS = 100
-
- def __init__(self, vendor=0x18d1,
- product=0x5027, interface=1, serialname=None, logger=None):
- """Susb constructor.
-
- Discovers and connects to stm32 USB endpoints.
-
- Args:
- vendor: usb vendor id of stm32 device.
- product: usb product id of stm32 device.
- interface: interface number ( 1 - 4 ) of stm32 device to use.
- serialname: string of device serialname.
- logger: none
-
- Raises:
- SusbError: An error accessing Susb object
+ """Provide stm32 USB functionality.
+
+ Instance Variables:
+ _read_ep: pyUSB read endpoint for this interface
+ _write_ep: pyUSB write endpoint for this interface
"""
- self._vendor = vendor
- self._product = product
- self._interface = interface
- self._serialname = serialname
- self._find_device()
-
- def _find_device(self):
- """Set up the usb endpoint"""
- # Find the stm32.
- dev_g = usb.core.find(idVendor=self._vendor, idProduct=self._product,
- find_all=True)
- dev_list = list(dev_g)
-
- if not dev_list:
- raise SusbError('USB device not found')
-
- # Check if we have multiple stm32s and we've specified the serial.
- dev = None
- if self._serialname:
- for d in dev_list:
- dev_serial = usb.util.get_string(d, d.iSerialNumber)
- if dev_serial == self._serialname:
- dev = d
- break
- if dev is None:
- raise SusbError('USB device(%s) not found' % self._serialname)
- else:
- try:
- dev = dev_list[0]
- except StopIteration:
- raise SusbError('USB device %04x:%04x not found' % (
- self._vendor, self._product))
-
- # If we can't set configuration, it's already been set.
- try:
- dev.set_configuration()
- except usb.core.USBError:
- pass
-
- self._dev = dev
-
- # Get an endpoint instance.
- cfg = dev.get_active_configuration()
- intf = usb.util.find_descriptor(cfg, bInterfaceNumber=self._interface)
- self._intf = intf
- if not intf:
- raise SusbError('Interface %04x:%04x - 0x%x not found' % (
- self._vendor, self._product, self._interface))
-
- # Detach raiden.ko if it is loaded. CCD endpoints support either a kernel
- # module driver that produces a ttyUSB, or direct endpoint access, but
- # can't do both at the same time.
- if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True:
- dev.detach_kernel_driver(intf.bInterfaceNumber)
-
- read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT
- read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number)
- self._read_ep = read_ep
-
- write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT
- write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number)
- self._write_ep = write_ep
-
- def close(self):
- usb.util.dispose_resources(self._dev)
+
+ READ_ENDPOINT = 0x81
+ WRITE_ENDPOINT = 0x1
+ TIMEOUT_MS = 100
+
+ def __init__(
+ self, vendor=0x18D1, product=0x5027, interface=1, serialname=None, logger=None
+ ):
+ """Susb constructor.
+
+ Discovers and connects to stm32 USB endpoints.
+
+ Args:
+ vendor: usb vendor id of stm32 device.
+ product: usb product id of stm32 device.
+ interface: interface number ( 1 - 4 ) of stm32 device to use.
+ serialname: string of device serialname.
+ logger: none
+
+ Raises:
+ SusbError: An error accessing Susb object
+ """
+ self._vendor = vendor
+ self._product = product
+ self._interface = interface
+ self._serialname = serialname
+ self._find_device()
+
+ def _find_device(self):
+ """Set up the usb endpoint"""
+ # Find the stm32.
+ dev_g = usb.core.find(
+ idVendor=self._vendor, idProduct=self._product, find_all=True
+ )
+ dev_list = list(dev_g)
+
+ if not dev_list:
+ raise SusbError("USB device not found")
+
+ # Check if we have multiple stm32s and we've specified the serial.
+ dev = None
+ if self._serialname:
+ for d in dev_list:
+ dev_serial = usb.util.get_string(d, d.iSerialNumber)
+ if dev_serial == self._serialname:
+ dev = d
+ break
+ if dev is None:
+ raise SusbError("USB device(%s) not found" % self._serialname)
+ else:
+ try:
+ dev = dev_list[0]
+ except StopIteration:
+ raise SusbError(
+ "USB device %04x:%04x not found" % (self._vendor, self._product)
+ )
+
+ # If we can't set configuration, it's already been set.
+ try:
+ dev.set_configuration()
+ except usb.core.USBError:
+ pass
+
+ self._dev = dev
+
+ # Get an endpoint instance.
+ cfg = dev.get_active_configuration()
+ intf = usb.util.find_descriptor(cfg, bInterfaceNumber=self._interface)
+ self._intf = intf
+ if not intf:
+ raise SusbError(
+ "Interface %04x:%04x - 0x%x not found"
+ % (self._vendor, self._product, self._interface)
+ )
+
+ # Detach raiden.ko if it is loaded. CCD endpoints support either a kernel
+ # module driver that produces a ttyUSB, or direct endpoint access, but
+ # can't do both at the same time.
+ if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True:
+ dev.detach_kernel_driver(intf.bInterfaceNumber)
+
+ read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT
+ read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number)
+ self._read_ep = read_ep
+
+ write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT
+ write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number)
+ self._write_ep = write_ep
+
+ def close(self):
+ usb.util.dispose_resources(self._dev)
diff --git a/extra/tigertool/ecusb/tiny_servo_common.py b/extra/tigertool/ecusb/tiny_servo_common.py
index 726e2e64b7..599bae80dc 100644
--- a/extra/tigertool/ecusb/tiny_servo_common.py
+++ b/extra/tigertool/ecusb/tiny_servo_common.py
@@ -17,215 +17,224 @@ import time
import six
import usb
-from . import pty_driver
-from . import stm32uart
+from . import pty_driver, stm32uart
def get_subprocess_args():
- if six.PY3:
- return {'encoding': 'utf-8'}
- return {}
+ if six.PY3:
+ return {"encoding": "utf-8"}
+ return {}
class TinyServoError(Exception):
- """Exceptions."""
+ """Exceptions."""
def log(output):
- """Print output to console, logfiles can be added here.
+ """Print output to console, logfiles can be added here.
+
+ Args:
+ output: string to output.
+ """
+ sys.stdout.write(output)
+ sys.stdout.write("\n")
+ sys.stdout.flush()
- Args:
- output: string to output.
- """
- sys.stdout.write(output)
- sys.stdout.write('\n')
- sys.stdout.flush()
def check_usb(vidpid, serialname=None):
- """Check if |vidpid| is present on the system's USB.
+ """Check if |vidpid| is present on the system's USB.
+
+ Args:
+ vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
+ serialname: serialname if specified.
- Args:
- vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
- serialname: serialname if specified.
+ Returns:
+ True if found, False, otherwise.
+ """
+ if get_usb_dev(vidpid, serialname):
+ return True
- Returns:
- True if found, False, otherwise.
- """
- if get_usb_dev(vidpid, serialname):
- return True
+ return False
- return False
def check_usb_sn(vidpid):
- """Return the serial number
+ """Return the serial number
- Return the serial number of the first USB device with VID:PID vidpid,
- or None if no device is found. This will not work well with two of
- the same device attached.
+ Return the serial number of the first USB device with VID:PID vidpid,
+ or None if no device is found. This will not work well with two of
+ the same device attached.
- Args:
- vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
+ Args:
+ vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
- Returns:
- string serial number if found, None otherwise.
- """
- dev = get_usb_dev(vidpid)
+ Returns:
+ string serial number if found, None otherwise.
+ """
+ dev = get_usb_dev(vidpid)
- if dev:
- dev_serial = usb.util.get_string(dev, dev.iSerialNumber)
+ if dev:
+ dev_serial = usb.util.get_string(dev, dev.iSerialNumber)
- return dev_serial
+ return dev_serial
+
+ return None
- return None
def get_usb_dev(vidpid, serialname=None):
- """Return the USB pyusb devie struct
+ """Return the USB pyusb devie struct
+
+ Return the dev struct of the first USB device with VID:PID vidpid,
+ or None if no device is found. If more than one device check serial
+ if supplied.
+
+ Args:
+ vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
+ serialname: serialname if specified.
+
+ Returns:
+ pyusb device if found, None otherwise.
+ """
+ vidpidst = vidpid.split(":")
+ vid = int(vidpidst[0], 16)
+ pid = int(vidpidst[1], 16)
+
+ dev_g = usb.core.find(idVendor=vid, idProduct=pid, find_all=True)
+ dev_list = list(dev_g)
+
+ if not dev_list:
+ return None
+
+ # Check if we have multiple devices and we've specified the serial.
+ dev = None
+ if serialname:
+ for d in dev_list:
+ dev_serial = usb.util.get_string(d, d.iSerialNumber)
+ if dev_serial == serialname:
+ dev = d
+ break
+ if dev is None:
+ return None
+ else:
+ try:
+ dev = dev_list[0]
+ except StopIteration:
+ return None
+
+ return dev
+
- Return the dev struct of the first USB device with VID:PID vidpid,
- or None if no device is found. If more than one device check serial
- if supplied.
+def check_usb_dev(vidpid, serialname=None):
+ """Return the USB dev number
+
+ Return the dev number of the first USB device with VID:PID vidpid,
+ or None if no device is found. If more than one device check serial
+ if supplied.
- Args:
- vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
- serialname: serialname if specified.
+ Args:
+ vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
+ serialname: serialname if specified.
- Returns:
- pyusb device if found, None otherwise.
- """
- vidpidst = vidpid.split(':')
- vid = int(vidpidst[0], 16)
- pid = int(vidpidst[1], 16)
+ Returns:
+ usb device number if found, None otherwise.
+ """
+ dev = get_usb_dev(vidpid, serialname=serialname)
- dev_g = usb.core.find(idVendor=vid, idProduct=pid, find_all=True)
- dev_list = list(dev_g)
+ if dev:
+ return dev.address
- if not dev_list:
return None
- # Check if we have multiple devices and we've specified the serial.
- dev = None
- if serialname:
- for d in dev_list:
- dev_serial = usb.util.get_string(d, d.iSerialNumber)
- if dev_serial == serialname:
- dev = d
- break
- if dev is None:
- return None
- else:
- try:
- dev = dev_list[0]
- except StopIteration:
- return None
-
- return dev
-def check_usb_dev(vidpid, serialname=None):
- """Return the USB dev number
+def wait_for_usb_remove(vidpid, serialname=None, timeout=None):
+ """Wait for USB device with vidpid to be removed.
- Return the dev number of the first USB device with VID:PID vidpid,
- or None if no device is found. If more than one device check serial
- if supplied.
+ Wrapper for wait_for_usb below
+ """
+ wait_for_usb(vidpid, serialname=serialname, timeout=timeout, desiredpresence=False)
- Args:
- vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
- serialname: serialname if specified.
- Returns:
- usb device number if found, None otherwise.
- """
- dev = get_usb_dev(vidpid, serialname=serialname)
+def wait_for_usb(vidpid, serialname=None, timeout=None, desiredpresence=True):
+ """Wait for usb device with vidpid to be present/absent.
- if dev:
- return dev.address
+ Args:
+ vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
+ serialname: serialname if specified.
+ timeout: timeout in seconds, None for no timeout.
+ desiredpresence: True for present, False for not present.
- return None
+ Raises:
+ TinyServoError: on timeout.
+ """
+ if timeout:
+ finish = datetime.datetime.now() + datetime.timedelta(seconds=timeout)
+ while check_usb(vidpid, serialname) != desiredpresence:
+ time.sleep(0.1)
+ if timeout:
+ if datetime.datetime.now() > finish:
+ raise TinyServoError("Timeout", "Timeout waiting for USB %s" % vidpid)
-def wait_for_usb_remove(vidpid, serialname=None, timeout=None):
- """Wait for USB device with vidpid to be removed.
+def do_serialno(serialno, pty):
+ """Set serialnumber 'serialno' via ec console 'pty'.
- Wrapper for wait_for_usb below
- """
- wait_for_usb(vidpid, serialname=serialname,
- timeout=timeout, desiredpresence=False)
+ Commands are:
+ # > serialno set 1234
+ # Saving serial number
+ # Serial number: 1234
-def wait_for_usb(vidpid, serialname=None, timeout=None, desiredpresence=True):
- """Wait for usb device with vidpid to be present/absent.
-
- Args:
- vidpid: string representation of the usb vid:pid, eg. '18d1:2001'
- serialname: serialname if specificed.
- timeout: timeout in seconds, None for no timeout.
- desiredpresence: True for present, False for not present.
-
- Raises:
- TinyServoError: on timeout.
- """
- if timeout:
- finish = datetime.datetime.now() + datetime.timedelta(seconds=timeout)
- while check_usb(vidpid, serialname) != desiredpresence:
- time.sleep(.1)
- if timeout:
- if datetime.datetime.now() > finish:
- raise TinyServoError('Timeout', 'Timeout waiting for USB %s' % vidpid)
+ Args:
+ serialno: string serial number to set.
+ pty: tinyservo console to send commands.
+
+ Raises:
+ TinyServoError: on failure to set.
+ ptyError: on command interface error.
+ """
+ cmd = r"serialno set %s" % serialno
+ regex = r"Serial number:\s+(\S+)"
+
+ results = pty._issue_cmd_get_results(cmd, [regex])[0]
+ sn = results[1].strip().strip("\n\r")
+
+ if sn == serialno:
+ log("Success !")
+ log("Serial set to %s" % sn)
+ else:
+ log("Serial number set to %s but saved as %s." % (serialno, sn))
+ raise TinyServoError(
+ "Serial Number", "Serial number set to %s but saved as %s." % (serialno, sn)
+ )
-def do_serialno(serialno, pty):
- """Set serialnumber 'serialno' via ec console 'pty'.
-
- Commands are:
- # > serialno set 1234
- # Saving serial number
- # Serial number: 1234
-
- Args:
- serialno: string serial number to set.
- pty: tinyservo console to send commands.
-
- Raises:
- TinyServoError: on failure to set.
- ptyError: on command interface error.
- """
- cmd = r'serialno set %s' % serialno
- regex = r'Serial number:\s+(\S+)'
-
- results = pty._issue_cmd_get_results(cmd, [regex])[0]
- sn = results[1].strip().strip('\n\r')
-
- if sn == serialno:
- log('Success !')
- log('Serial set to %s' % sn)
- else:
- log('Serial number set to %s but saved as %s.' % (serialno, sn))
- raise TinyServoError(
- 'Serial Number',
- 'Serial number set to %s but saved as %s.' % (serialno, sn))
def setup_tinyservod(vidpid, interface, serialname=None, debuglog=False):
- """Set up a pty
-
- Set up a pty to the ec console in order
- to send commands. Returns a pty_driver object.
-
- Args:
- vidpid: string vidpid of device to access.
- interface: not used.
- serialname: string serial name of device requested, optional.
- debuglog: chatty printout (boolean)
-
- Returns:
- pty object
-
- Raises:
- UsbError, SusbError: on device not found
- """
- vidstr, pidstr = vidpid.split(':')
- vid = int(vidstr, 16)
- pid = int(pidstr, 16)
- suart = stm32uart.Suart(vendor=vid, product=pid,
- interface=interface, serialname=serialname,
- debuglog=debuglog)
- suart.run()
- pty = pty_driver.ptyDriver(suart, [])
-
- return pty
+ """Set up a pty
+
+ Set up a pty to the ec console in order
+ to send commands. Returns a pty_driver object.
+
+ Args:
+ vidpid: string vidpid of device to access.
+ interface: not used.
+ serialname: string serial name of device requested, optional.
+ debuglog: chatty printout (boolean)
+
+ Returns:
+ pty object
+
+ Raises:
+ UsbError, SusbError: on device not found
+ """
+ vidstr, pidstr = vidpid.split(":")
+ vid = int(vidstr, 16)
+ pid = int(pidstr, 16)
+ suart = stm32uart.Suart(
+ vendor=vid,
+ product=pid,
+ interface=interface,
+ serialname=serialname,
+ debuglog=debuglog,
+ )
+ suart.run()
+ pty = pty_driver.ptyDriver(suart, [])
+
+ return pty
diff --git a/extra/tigertool/ecusb/tiny_servod.py b/extra/tigertool/ecusb/tiny_servod.py
index 632d9c3a20..ca5ca63f31 100644
--- a/extra/tigertool/ecusb/tiny_servod.py
+++ b/extra/tigertool/ecusb/tiny_servod.py
@@ -8,47 +8,48 @@
"""Helper class to facilitate communication to servo ec console."""
-from ecusb import pty_driver
-from ecusb import stm32uart
+from ecusb import pty_driver, stm32uart
class TinyServod(object):
- """Helper class to wrap a pty_driver with interface."""
-
- def __init__(self, vid, pid, interface, serialname=None, debug=False):
- """Build the driver and interface.
-
- Args:
- vid: servo device vid
- pid: servo device pid
- interface: which usb interface the servo console is on
- serialname: the servo device serial (if available)
- """
- self._vid = vid
- self._pid = pid
- self._interface = interface
- self._serial = serialname
- self._debug = debug
- self._init()
-
- def _init(self):
- self.suart = stm32uart.Suart(vendor=self._vid,
- product=self._pid,
- interface=self._interface,
- serialname=self._serial,
- debuglog=self._debug)
- self.suart.run()
- self.pty = pty_driver.ptyDriver(self.suart, [])
-
- def reinitialize(self):
- """Reinitialize the connect after a reset/disconnect/etc."""
- self.close()
- self._init()
-
- def close(self):
- """Close out the connection and release resources.
-
- Note: if another TinyServod process or servod itself needs the same device
- it's necessary to call this to ensure the usb device is available.
- """
- self.suart.close()
+ """Helper class to wrap a pty_driver with interface."""
+
+ def __init__(self, vid, pid, interface, serialname=None, debug=False):
+ """Build the driver and interface.
+
+ Args:
+ vid: servo device vid
+ pid: servo device pid
+ interface: which usb interface the servo console is on
+ serialname: the servo device serial (if available)
+ """
+ self._vid = vid
+ self._pid = pid
+ self._interface = interface
+ self._serial = serialname
+ self._debug = debug
+ self._init()
+
+ def _init(self):
+ self.suart = stm32uart.Suart(
+ vendor=self._vid,
+ product=self._pid,
+ interface=self._interface,
+ serialname=self._serial,
+ debuglog=self._debug,
+ )
+ self.suart.run()
+ self.pty = pty_driver.ptyDriver(self.suart, [])
+
+ def reinitialize(self):
+ """Reinitialize the connect after a reset/disconnect/etc."""
+ self.close()
+ self._init()
+
+ def close(self):
+ """Close out the connection and release resources.
+
+ Note: if another TinyServod process or servod itself needs the same device
+ it's necessary to call this to ensure the usb device is available.
+ """
+ self.suart.close()
diff --git a/extra/tigertool/tigertest.py b/extra/tigertool/tigertest.py
index 0cd31c8cce..8f8b2c7f03 100755
--- a/extra/tigertool/tigertest.py
+++ b/extra/tigertool/tigertest.py
@@ -13,7 +13,6 @@ import argparse
import subprocess
import sys
-
# Script to control tigertail USB-C Mux board.
#
# optional arguments:
@@ -35,58 +34,60 @@ import sys
def testCmd(cmd, expected_results):
- """Run command on console, check for success.
-
- Args:
- cmd: shell command to run.
- expected_results: a list object of strings expected in the result.
-
- Raises:
- Exception on fail.
- """
- print('run: ' + cmd)
- try:
- p = subprocess.run(cmd, shell=True, check=False, capture_output=True)
- output = p.stdout.decode('utf-8')
- error = p.stderr.decode('utf-8')
- assert p.returncode == 0
- for result in expected_results:
- output.index(result)
- except Exception as e:
- print('FAIL')
- print('cmd: ' + cmd)
- print('error: ' + str(e))
- print('stdout:\n' + output)
- print('stderr:\n' + error)
- print('expected: ' + str(expected_results))
- print('RC: ' + str(p.returncode))
- raise e
+ """Run command on console, check for success.
+
+ Args:
+ cmd: shell command to run.
+ expected_results: a list object of strings expected in the result.
+
+ Raises:
+ Exception on fail.
+ """
+ print("run: " + cmd)
+ try:
+ p = subprocess.run(cmd, shell=True, check=False, capture_output=True)
+ output = p.stdout.decode("utf-8")
+ error = p.stderr.decode("utf-8")
+ assert p.returncode == 0
+ for result in expected_results:
+ output.index(result)
+ except Exception as e:
+ print("FAIL")
+ print("cmd: " + cmd)
+ print("error: " + str(e))
+ print("stdout:\n" + output)
+ print("stderr:\n" + error)
+ print("expected: " + str(expected_results))
+ print("RC: " + str(p.returncode))
+ raise e
+
def test_sequence():
- testCmd('./tigertool.py --reboot', ['PASS'])
- testCmd('./tigertool.py --setserialno test', ['PASS'])
- testCmd('./tigertool.py --check_serial', ['test', 'PASS'])
- testCmd('./tigertool.py -s test --check_serial', ['test', 'PASS'])
- testCmd('./tigertool.py -m A', ['Mux set to A', 'PASS'])
- testCmd('./tigertool.py -m B', ['Mux set to B', 'PASS'])
- testCmd('./tigertool.py -m off', ['Mux set to off', 'PASS'])
- testCmd('./tigertool.py -p', ['PASS'])
- testCmd('./tigertool.py -r rw', ['PASS'])
- testCmd('./tigertool.py -r ro', ['PASS'])
- testCmd('./tigertool.py --check_version', ['RW', 'RO', 'PASS'])
-
- print('PASS')
+ testCmd("./tigertool.py --reboot", ["PASS"])
+ testCmd("./tigertool.py --setserialno test", ["PASS"])
+ testCmd("./tigertool.py --check_serial", ["test", "PASS"])
+ testCmd("./tigertool.py -s test --check_serial", ["test", "PASS"])
+ testCmd("./tigertool.py -m A", ["Mux set to A", "PASS"])
+ testCmd("./tigertool.py -m B", ["Mux set to B", "PASS"])
+ testCmd("./tigertool.py -m off", ["Mux set to off", "PASS"])
+ testCmd("./tigertool.py -p", ["PASS"])
+ testCmd("./tigertool.py -r rw", ["PASS"])
+ testCmd("./tigertool.py -r ro", ["PASS"])
+ testCmd("./tigertool.py --check_version", ["RW", "RO", "PASS"])
+
+ print("PASS")
+
def main(argv):
- parser = argparse.ArgumentParser(description=__doc__)
- parser.add_argument('-c', '--count', type=int, default=1,
- help='loops to run')
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument("-c", "--count", type=int, default=1, help="loops to run")
+
+ opts = parser.parse_args(argv)
- opts = parser.parse_args(argv)
+ for i in range(1, opts.count + 1):
+ print("Iteration: %d" % i)
+ test_sequence()
- for i in range(1, opts.count + 1):
- print('Iteration: %d' % i)
- test_sequence()
-if __name__ == '__main__':
+if __name__ == "__main__":
main(sys.argv[1:])
diff --git a/extra/tigertool/tigertool.py b/extra/tigertool/tigertool.py
index 6baae8abdf..37b7b01495 100755
--- a/extra/tigertool/tigertool.py
+++ b/extra/tigertool/tigertool.py
@@ -17,287 +17,308 @@ import time
import ecusb.tiny_servo_common as c
-STM_VIDPID = '18d1:5027'
-serialno = 'Uninitialized'
+STM_VIDPID = "18d1:5027"
+serialno = "Uninitialized"
+
def do_mux(mux, pty):
- """Set mux via ec console 'pty'.
+ """Set mux via ec console 'pty'.
+
+ Args:
+ mux: mux to connect to DUT, 'A', 'B', or 'off'
+ pty: a pty object connected to tigertail
- Args:
- mux: mux to connect to DUT, 'A', 'B', or 'off'
- pty: a pty object connected to tigertail
+ Commands are:
+ # > mux A
+ # TYPE-C mux is A
+ """
+ validmux = ["A", "B", "off"]
+ if mux not in validmux:
+ c.log("Mux setting %s invalid, try one of %s" % (mux, validmux))
+ return False
- Commands are:
- # > mux A
- # TYPE-C mux is A
- """
- validmux = ['A', 'B', 'off']
- if mux not in validmux:
- c.log('Mux setting %s invalid, try one of %s' % (mux, validmux))
- return False
+ cmd = "mux %s" % mux
+ regex = "TYPE\-C mux is ([^\s\r\n]*)\r"
- cmd = 'mux %s' % mux
- regex = 'TYPE\-C mux is ([^\s\r\n]*)\r'
+ results = pty._issue_cmd_get_results(cmd, [regex])[0]
+ result = results[1].strip().strip("\n\r")
- results = pty._issue_cmd_get_results(cmd, [regex])[0]
- result = results[1].strip().strip('\n\r')
+ if result != mux:
+ c.log("Mux set to %s but saved as %s." % (mux, result))
+ return False
+ c.log("Mux set to %s" % result)
+ return True
- if result != mux:
- c.log('Mux set to %s but saved as %s.' % (mux, result))
- return False
- c.log('Mux set to %s' % result)
- return True
def do_version(pty):
- """Check version via ec console 'pty'.
-
- Args:
- pty: a pty object connected to tigertail
-
- Commands are:
- # > version
- # Chip: stm stm32f07x
- # Board: 0
- # RO: tigertail_v1.1.6749-74d1a312e
- # RW: tigertail_v1.1.6749-74d1a312e
- # Build: tigertail_v1.1.6749-74d1a312e
- # 2017-07-25 20:08:34 nsanders@meatball.mtv.corp.google.com
- """
- cmd = 'version'
- regex = r'RO:\s+(\S+)\s+RW:\s+(\S+)\s+Build:\s+(\S+)\s+' \
- r'(\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d) (\S+)'
-
- results = pty._issue_cmd_get_results(cmd, [regex])[0]
- c.log('Version is %s' % results[3])
- c.log('RO: %s' % results[1])
- c.log('RW: %s' % results[2])
- c.log('Date: %s' % results[4])
- c.log('Src: %s' % results[5])
-
- return True
+ """Check version via ec console 'pty'.
+
+ Args:
+ pty: a pty object connected to tigertail
+
+ Commands are:
+ # > version
+ # Chip: stm stm32f07x
+ # Board: 0
+ # RO: tigertail_v1.1.6749-74d1a312e
+ # RW: tigertail_v1.1.6749-74d1a312e
+ # Build: tigertail_v1.1.6749-74d1a312e
+ # 2017-07-25 20:08:34 nsanders@meatball.mtv.corp.google.com
+ """
+ cmd = "version"
+ regex = (
+ r"RO:\s+(\S+)\s+RW:\s+(\S+)\s+Build:\s+(\S+)\s+"
+ r"(\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d) (\S+)"
+ )
+
+ results = pty._issue_cmd_get_results(cmd, [regex])[0]
+ c.log("Version is %s" % results[3])
+ c.log("RO: %s" % results[1])
+ c.log("RW: %s" % results[2])
+ c.log("Date: %s" % results[4])
+ c.log("Src: %s" % results[5])
+
+ return True
+
def do_check_serial(pty):
- """Check serial via ec console 'pty'.
+ """Check serial via ec console 'pty'.
- Args:
- pty: a pty object connected to tigertail
+ Args:
+ pty: a pty object connected to tigertail
- Commands are:
- # > serialno
- # Serial number: number
- """
- cmd = 'serialno'
- regex = r'Serial number: ([^\n\r]+)'
+ Commands are:
+ # > serialno
+ # Serial number: number
+ """
+ cmd = "serialno"
+ regex = r"Serial number: ([^\n\r]+)"
- results = pty._issue_cmd_get_results(cmd, [regex])[0]
- c.log('Serial is %s' % results[1])
+ results = pty._issue_cmd_get_results(cmd, [regex])[0]
+ c.log("Serial is %s" % results[1])
- return True
+ return True
def do_power(count, bus, pty):
- """Check power usage via ec console 'pty'.
-
- Args:
- count: number of samples to capture
- bus: rail to monitor, 'vbus', 'cc1', or 'cc2'
- pty: a pty object connected to tigertail
-
- Commands are:
- # > ina 0
- # Configuration: 4127
- # Shunt voltage: 02c4 => 1770 uV
- # Bus voltage : 1008 => 5130 mV
- # Power : 0019 => 625 mW
- # Current : 0082 => 130 mA
- # Calibration : 0155
- # Mask/Enable : 0008
- # Alert limit : 0000
- """
- if bus == 'vbus':
- ina = 0
- if bus == 'cc1':
- ina = 4
- if bus == 'cc2':
- ina = 1
-
- start = time.time()
-
- c.log('time,\tmV,\tmW,\tmA')
-
- cmd = 'ina %s' % ina
- regex = r'Bus voltage : \S+ \S+ (\d+) mV\s+' \
- r'Power : \S+ \S+ (\d+) mW\s+' \
- r'Current : \S+ \S+ (\d+) mA'
-
- for i in range(0, count):
- results = pty._issue_cmd_get_results(cmd, [regex])[0]
- c.log('%.2f,\t%s,\t%s\t%s' % (
- time.time() - start,
- results[1], results[2], results[3]))
+ """Check power usage via ec console 'pty'.
+
+ Args:
+ count: number of samples to capture
+ bus: rail to monitor, 'vbus', 'cc1', or 'cc2'
+ pty: a pty object connected to tigertail
+
+ Commands are:
+ # > ina 0
+ # Configuration: 4127
+ # Shunt voltage: 02c4 => 1770 uV
+ # Bus voltage : 1008 => 5130 mV
+ # Power : 0019 => 625 mW
+ # Current : 0082 => 130 mA
+ # Calibration : 0155
+ # Mask/Enable : 0008
+ # Alert limit : 0000
+ """
+ if bus == "vbus":
+ ina = 0
+ if bus == "cc1":
+ ina = 4
+ if bus == "cc2":
+ ina = 1
+
+ start = time.time()
+
+ c.log("time,\tmV,\tmW,\tmA")
+
+ cmd = "ina %s" % ina
+ regex = (
+ r"Bus voltage : \S+ \S+ (\d+) mV\s+"
+ r"Power : \S+ \S+ (\d+) mW\s+"
+ r"Current : \S+ \S+ (\d+) mA"
+ )
+
+ for i in range(0, count):
+ results = pty._issue_cmd_get_results(cmd, [regex])[0]
+ c.log(
+ "%.2f,\t%s,\t%s\t%s"
+ % (time.time() - start, results[1], results[2], results[3])
+ )
+
+ return True
- return True
def do_reboot(pty, serialname):
- """Reboot via ec console pty
-
- Args:
- pty: a pty object connected to tigertail
- serialname: serial name, can be None.
-
- Command is: reboot.
- """
- cmd = 'reboot'
-
- # Check usb dev number on current instance.
- devno = c.check_usb_dev(STM_VIDPID, serialname=serialname)
- if not devno:
- c.log('Device not found')
- return False
-
- try:
- pty._issue_cmd(cmd)
- except Exception as e:
- c.log('Failed to send command: ' + str(e))
- return False
-
- try:
- c.wait_for_usb_remove(STM_VIDPID, timeout=3., serialname=serialname)
- except Exception as e:
- # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds
- # it's not going to. This step just goes faster if it's detected.
- pass
-
- try:
- c.wait_for_usb(STM_VIDPID, timeout=3., serialname=serialname)
- except Exception as e:
- c.log('Failed to return from reboot: ' + str(e))
- return False
-
- # Check that the device had a new device number, i.e. it's
- # disconnected and reconnected.
- newdevno = c.check_usb_dev(STM_VIDPID, serialname=serialname)
- if newdevno == devno:
- c.log("Device didn't reboot")
- return False
-
- return True
+ """Reboot via ec console pty
+
+ Args:
+ pty: a pty object connected to tigertail
+ serialname: serial name, can be None.
+
+ Command is: reboot.
+ """
+ cmd = "reboot"
+
+ # Check usb dev number on current instance.
+ devno = c.check_usb_dev(STM_VIDPID, serialname=serialname)
+ if not devno:
+ c.log("Device not found")
+ return False
+
+ try:
+ pty._issue_cmd(cmd)
+ except Exception as e:
+ c.log("Failed to send command: " + str(e))
+ return False
+
+ try:
+ c.wait_for_usb_remove(STM_VIDPID, timeout=3.0, serialname=serialname)
+ except Exception as e:
+ # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds
+ # it's not going to. This step just goes faster if it's detected.
+ pass
+
+ try:
+ c.wait_for_usb(STM_VIDPID, timeout=3.0, serialname=serialname)
+ except Exception as e:
+ c.log("Failed to return from reboot: " + str(e))
+ return False
+
+ # Check that the device had a new device number, i.e. it's
+ # disconnected and reconnected.
+ newdevno = c.check_usb_dev(STM_VIDPID, serialname=serialname)
+ if newdevno == devno:
+ c.log("Device didn't reboot")
+ return False
+
+ return True
+
def do_sysjump(region, pty, serialname):
- """Set region via ec console 'pty'.
-
- Args:
- region: ec code region to execute, 'ro' or 'rw'
- pty: a pty object connected to tigertail
- serialname: serial name, can be None.
-
- Commands are:
- # > sysjump rw
- """
- validregion = ['ro', 'rw']
- if region not in validregion:
- c.log('Region setting %s invalid, try one of %s' % (
- region, validregion))
- return False
-
- cmd = 'sysjump %s' % region
- try:
- pty._issue_cmd(cmd)
- except Exception as e:
- c.log('Exception: ' + str(e))
- return False
-
- try:
- c.wait_for_usb_remove(STM_VIDPID, timeout=3., serialname=serialname)
- except Exception as e:
- # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds
- # it's not going to. This step just goes faster if it's detected.
- pass
-
- try:
- c.wait_for_usb(STM_VIDPID, timeout=3., serialname=serialname)
- except Exception as e:
- c.log('Failed to return from restart: ' + str(e))
- return False
-
- c.log('Region requested %s' % region)
- return True
+ """Set region via ec console 'pty'.
+
+ Args:
+ region: ec code region to execute, 'ro' or 'rw'
+ pty: a pty object connected to tigertail
+ serialname: serial name, can be None.
+
+ Commands are:
+ # > sysjump rw
+ """
+ validregion = ["ro", "rw"]
+ if region not in validregion:
+ c.log("Region setting %s invalid, try one of %s" % (region, validregion))
+ return False
+
+ cmd = "sysjump %s" % region
+ try:
+ pty._issue_cmd(cmd)
+ except Exception as e:
+ c.log("Exception: " + str(e))
+ return False
+
+ try:
+ c.wait_for_usb_remove(STM_VIDPID, timeout=3.0, serialname=serialname)
+ except Exception as e:
+ # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds
+ # it's not going to. This step just goes faster if it's detected.
+ pass
+
+ try:
+ c.wait_for_usb(STM_VIDPID, timeout=3.0, serialname=serialname)
+ except Exception as e:
+ c.log("Failed to return from restart: " + str(e))
+ return False
+
+ c.log("Region requested %s" % region)
+ return True
+
def get_parser():
- parser = argparse.ArgumentParser(
- description=__doc__)
- parser.add_argument('-s', '--serialno', type=str, default=None,
- help='serial number of board to use')
- parser.add_argument('-b', '--bus', type=str, default='vbus',
- help='Which rail to log: [vbus|cc1|cc2]')
- group = parser.add_mutually_exclusive_group()
- group.add_argument('--setserialno', type=str, default=None,
- help='serial number to set on the board.')
- group.add_argument('--check_serial', action='store_true',
- help='check serial number set on the board.')
- group.add_argument('-m', '--mux', type=str, default=None,
- help='mux selection')
- group.add_argument('-p', '--power', action='store_true',
- help='check VBUS')
- group.add_argument('-l', '--powerlog', type=int, default=None,
- help='log VBUS')
- group.add_argument('-r', '--sysjump', type=str, default=None,
- help='region selection')
- group.add_argument('--reboot', action='store_true',
- help='reboot tigertail')
- group.add_argument('--check_version', action='store_true',
- help='check tigertail version')
- return parser
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "-s", "--serialno", type=str, default=None, help="serial number of board to use"
+ )
+ parser.add_argument(
+ "-b",
+ "--bus",
+ type=str,
+ default="vbus",
+ help="Which rail to log: [vbus|cc1|cc2]",
+ )
+ group = parser.add_mutually_exclusive_group()
+ group.add_argument(
+ "--setserialno",
+ type=str,
+ default=None,
+ help="serial number to set on the board.",
+ )
+ group.add_argument(
+ "--check_serial",
+ action="store_true",
+ help="check serial number set on the board.",
+ )
+ group.add_argument("-m", "--mux", type=str, default=None, help="mux selection")
+ group.add_argument("-p", "--power", action="store_true", help="check VBUS")
+ group.add_argument("-l", "--powerlog", type=int, default=None, help="log VBUS")
+ group.add_argument(
+ "-r", "--sysjump", type=str, default=None, help="region selection"
+ )
+ group.add_argument("--reboot", action="store_true", help="reboot tigertail")
+ group.add_argument(
+ "--check_version", action="store_true", help="check tigertail version"
+ )
+ return parser
+
def main(argv):
- parser = get_parser()
- opts = parser.parse_args(argv)
+ parser = get_parser()
+ opts = parser.parse_args(argv)
- result = True
+ result = True
- # Let's make sure there's a tigertail
- # If nothing found in 5 seconds, fail.
- c.wait_for_usb(STM_VIDPID, timeout=5., serialname=opts.serialno)
+ # Let's make sure there's a tigertail
+ # If nothing found in 5 seconds, fail.
+ c.wait_for_usb(STM_VIDPID, timeout=5.0, serialname=opts.serialno)
- pty = c.setup_tinyservod(STM_VIDPID, 0, serialname=opts.serialno)
+ pty = c.setup_tinyservod(STM_VIDPID, 0, serialname=opts.serialno)
- if opts.bus not in ('vbus', 'cc1', 'cc2'):
- c.log('Try --bus [vbus|cc1|cc2]')
- result = False
+ if opts.bus not in ("vbus", "cc1", "cc2"):
+ c.log("Try --bus [vbus|cc1|cc2]")
+ result = False
- elif opts.setserialno:
- try:
- c.do_serialno(opts.setserialno, pty)
- except Exception:
- result = False
+ elif opts.setserialno:
+ try:
+ c.do_serialno(opts.setserialno, pty)
+ except Exception:
+ result = False
- elif opts.mux:
- result &= do_mux(opts.mux, pty)
+ elif opts.mux:
+ result &= do_mux(opts.mux, pty)
- elif opts.sysjump:
- result &= do_sysjump(opts.sysjump, pty, serialname=opts.serialno)
+ elif opts.sysjump:
+ result &= do_sysjump(opts.sysjump, pty, serialname=opts.serialno)
- elif opts.reboot:
- result &= do_reboot(pty, serialname=opts.serialno)
+ elif opts.reboot:
+ result &= do_reboot(pty, serialname=opts.serialno)
- elif opts.check_version:
- result &= do_version(pty)
+ elif opts.check_version:
+ result &= do_version(pty)
- elif opts.check_serial:
- result &= do_check_serial(pty)
+ elif opts.check_serial:
+ result &= do_check_serial(pty)
- elif opts.power:
- result &= do_power(1, opts.bus, pty)
+ elif opts.power:
+ result &= do_power(1, opts.bus, pty)
- elif opts.powerlog:
- result &= do_power(opts.powerlog, opts.bus, pty)
+ elif opts.powerlog:
+ result &= do_power(opts.powerlog, opts.bus, pty)
- if result:
- c.log('PASS')
- else:
- c.log('FAIL')
- sys.exit(-1)
+ if result:
+ c.log("PASS")
+ else:
+ c.log("FAIL")
+ sys.exit(-1)
-if __name__ == '__main__':
- sys.exit(main(sys.argv[1:]))
+if __name__ == "__main__":
+ sys.exit(main(sys.argv[1:]))
diff --git a/extra/usb_power/convert_power_log_board.py b/extra/usb_power/convert_power_log_board.py
index 8aab77ee4c..c1b25f57db 100644
--- a/extra/usb_power/convert_power_log_board.py
+++ b/extra/usb_power/convert_power_log_board.py
@@ -14,6 +14,7 @@ Program to convert sweetberry config to servod config template.
# Note: This is a py2/3 compatible file.
from __future__ import print_function
+
import json
import os
import sys
@@ -48,21 +49,25 @@ def write_to_file(file, sweetberry, inas):
inas: list of inas read from board file.
"""
- with open(file, 'w') as pyfile:
+ with open(file, "w") as pyfile:
- pyfile.write('inas = [\n')
+ pyfile.write("inas = [\n")
for rec in inas:
- if rec['sweetberry'] != sweetberry:
+ if rec["sweetberry"] != sweetberry:
continue
# EX : ('sweetberry', 0x40, 'SB_FW_CAM_2P8', 5.0, 1.000, 3, False),
- channel, i2c_addr = Spower.CHMAP[rec['channel']]
- record = (" ('sweetberry', 0x%02x, '%s', 5.0, %f, %d, 'True')"
- ",\n" % (i2c_addr, rec['name'], rec['rs'], channel))
+ channel, i2c_addr = Spower.CHMAP[rec["channel"]]
+ record = " ('sweetberry', 0x%02x, '%s', 5.0, %f, %d, 'True')" ",\n" % (
+ i2c_addr,
+ rec["name"],
+ rec["rs"],
+ channel,
+ )
pyfile.write(record)
- pyfile.write(']\n')
+ pyfile.write("]\n")
def main(argv):
@@ -76,16 +81,18 @@ def main(argv):
inas = fetch_records(inputf)
- sweetberry = set(rec['sweetberry'] for rec in inas)
+ sweetberry = set(rec["sweetberry"] for rec in inas)
if len(sweetberry) == 2:
- print("Converting %s to %s and %s" % (inputf, basename + '_a.py',
- basename + '_b.py'))
- write_to_file(basename + '_a.py', 'A', inas)
- write_to_file(basename + '_b.py', 'B', inas)
+ print(
+ "Converting %s to %s and %s"
+ % (inputf, basename + "_a.py", basename + "_b.py")
+ )
+ write_to_file(basename + "_a.py", "A", inas)
+ write_to_file(basename + "_b.py", "B", inas)
else:
- print("Converting %s to %s" % (inputf, basename + '.py'))
- write_to_file(basename + '.py', sweetberry.pop(), inas)
+ print("Converting %s to %s" % (inputf, basename + ".py"))
+ write_to_file(basename + ".py", sweetberry.pop(), inas)
if __name__ == "__main__":
diff --git a/extra/usb_power/convert_servo_ina.py b/extra/usb_power/convert_servo_ina.py
index 1c70f31aeb..6ccd474e6c 100755
--- a/extra/usb_power/convert_servo_ina.py
+++ b/extra/usb_power/convert_servo_ina.py
@@ -14,67 +14,71 @@
# Note: This is a py2/3 compatible file.
from __future__ import print_function
+
import os
import sys
def fetch_records(basename):
- """Import records from servo_ina file.
+ """Import records from servo_ina file.
- servo_ina files are python imports, and have a list of tuples with
- the INA data.
- (inatype, i2caddr, rail name, bus voltage, shunt ohms, mux, True)
+ servo_ina files are python imports, and have a list of tuples with
+ the INA data.
+ (inatype, i2caddr, rail name, bus voltage, shunt ohms, mux, True)
- Args:
- basename: python import name (filename -.py)
+ Args:
+ basename: python import name (filename -.py)
- Returns:
- list of tuples as described above.
- """
- ina_desc = __import__(basename)
- return ina_desc.inas
+ Returns:
+ list of tuples as described above.
+ """
+ ina_desc = __import__(basename)
+ return ina_desc.inas
def main(argv):
- if len(argv) != 2:
- print("usage:")
- print(" %s input.py" % argv[0])
- return
+ if len(argv) != 2:
+ print("usage:")
+ print(" %s input.py" % argv[0])
+ return
- inputf = argv[1]
- basename = os.path.splitext(inputf)[0]
- outputf = basename + '.board'
- outputs = basename + '.scenario'
+ inputf = argv[1]
+ basename = os.path.splitext(inputf)[0]
+ outputf = basename + ".board"
+ outputs = basename + ".scenario"
- print("Converting %s to %s, %s" % (inputf, outputf, outputs))
+ print("Converting %s to %s, %s" % (inputf, outputf, outputs))
- inas = fetch_records(basename)
+ inas = fetch_records(basename)
+ boardfile = open(outputf, "w")
+ scenario = open(outputs, "w")
- boardfile = open(outputf, 'w')
- scenario = open(outputs, 'w')
+ boardfile.write("[\n")
+ scenario.write("[\n")
+ start = True
- boardfile.write('[\n')
- scenario.write('[\n')
- start = True
+ for rec in inas:
+ if start:
+ start = False
+ else:
+ boardfile.write(",\n")
+ scenario.write(",\n")
- for rec in inas:
- if start:
- start = False
- else:
- boardfile.write(',\n')
- scenario.write(',\n')
+ record = ' {"name": "%s", "rs": %f, "sweetberry": "A", "channel": %d}' % (
+ rec[2],
+ rec[4],
+ rec[1] - 64,
+ )
+ boardfile.write(record)
+ scenario.write('"%s"' % rec[2])
- record = ' {"name": "%s", "rs": %f, "sweetberry": "A", "channel": %d}' % (
- rec[2], rec[4], rec[1] - 64)
- boardfile.write(record)
- scenario.write('"%s"' % rec[2])
+ boardfile.write("\n")
+ boardfile.write("]")
- boardfile.write('\n')
- boardfile.write(']')
+ scenario.write("\n")
+ scenario.write("]")
- scenario.write('\n')
- scenario.write(']')
if __name__ == "__main__":
- main(sys.argv)
+ main(sys.argv)
diff --git a/extra/usb_power/powerlog.py b/extra/usb_power/powerlog.py
index 82cce3daed..d893e5c6b9 100755
--- a/extra/usb_power/powerlog.py
+++ b/extra/usb_power/powerlog.py
@@ -14,9 +14,9 @@
# Note: This is a py2/3 compatible file.
from __future__ import print_function
+
import argparse
import array
-from distutils import sysconfig
import json
import logging
import os
@@ -25,884 +25,988 @@ import struct
import sys
import time
import traceback
+from distutils import sysconfig
import usb
-
from stats_manager import StatsManager
# Directory where hdctools installs configuration files into.
-LIB_DIR = os.path.join(sysconfig.get_python_lib(standard_lib=False), 'servo',
- 'data')
+LIB_DIR = os.path.join(sysconfig.get_python_lib(standard_lib=False), "servo", "data")
# Potential config file locations: current working directory, the same directory
# as powerlog.py file or LIB_DIR.
-CONFIG_LOCATIONS = [os.getcwd(), os.path.dirname(os.path.realpath(__file__)),
- LIB_DIR]
-
-def logoutput(msg):
- print(msg)
- sys.stdout.flush()
-
-def process_filename(filename):
- """Find the file path from the filename.
-
- If filename is already the complete path, return that directly. If filename is
- just the short name, look for the file in the current working directory, in
- the directory of the current .py file, and then in the directory installed by
- hdctools. If the file is found, return the complete path of the file.
-
- Args:
- filename: complete file path or short file name.
-
- Returns:
- a complete file path.
-
- Raises:
- IOError if filename does not exist.
- """
- # Check if filename is absolute path.
- if os.path.isabs(filename) and os.path.isfile(filename):
- return filename
- # Check if filename is relative to a known config location.
- for dirname in CONFIG_LOCATIONS:
- file_at_dir = os.path.join(dirname, filename)
- if os.path.isfile(file_at_dir):
- return file_at_dir
- raise IOError('No such file or directory: \'%s\'' % filename)
-
-
-class Spower(object):
- """Power class to access devices on the bus.
-
- Usage:
- bus = Spower()
-
- Instance Variables:
- _dev: pyUSB device object
- _read_ep: pyUSB read endpoint for this interface
- _write_ep: pyUSB write endpoint for this interface
- """
-
- # INA interface type.
- INA_POWER = 1
- INA_BUSV = 2
- INA_CURRENT = 3
- INA_SHUNTV = 4
- # INA_SUFFIX is used to differentiate multiple ina types for the same power
- # rail. No suffix for when ina type is 0 (non-existent) and when ina type is 1
- # (power, no suffix for backward compatibility).
- INA_SUFFIX = ['', '', '_busv', '_cur', '_shuntv']
-
- # usb power commands
- CMD_RESET = 0x0000
- CMD_STOP = 0x0001
- CMD_ADDINA = 0x0002
- CMD_START = 0x0003
- CMD_NEXT = 0x0004
- CMD_SETTIME = 0x0005
-
- # Map between header channel number (0-47)
- # and INA I2C bus/addr on sweetberry.
- CHMAP = {
- 0: (3, 0x40),
- 1: (1, 0x40),
- 2: (2, 0x40),
- 3: (0, 0x40),
- 4: (3, 0x41),
- 5: (1, 0x41),
- 6: (2, 0x41),
- 7: (0, 0x41),
- 8: (3, 0x42),
- 9: (1, 0x42),
- 10: (2, 0x42),
- 11: (0, 0x42),
- 12: (3, 0x43),
- 13: (1, 0x43),
- 14: (2, 0x43),
- 15: (0, 0x43),
- 16: (3, 0x44),
- 17: (1, 0x44),
- 18: (2, 0x44),
- 19: (0, 0x44),
- 20: (3, 0x45),
- 21: (1, 0x45),
- 22: (2, 0x45),
- 23: (0, 0x45),
- 24: (3, 0x46),
- 25: (1, 0x46),
- 26: (2, 0x46),
- 27: (0, 0x46),
- 28: (3, 0x47),
- 29: (1, 0x47),
- 30: (2, 0x47),
- 31: (0, 0x47),
- 32: (3, 0x48),
- 33: (1, 0x48),
- 34: (2, 0x48),
- 35: (0, 0x48),
- 36: (3, 0x49),
- 37: (1, 0x49),
- 38: (2, 0x49),
- 39: (0, 0x49),
- 40: (3, 0x4a),
- 41: (1, 0x4a),
- 42: (2, 0x4a),
- 43: (0, 0x4a),
- 44: (3, 0x4b),
- 45: (1, 0x4b),
- 46: (2, 0x4b),
- 47: (0, 0x4b),
- }
-
- def __init__(self, board, vendor=0x18d1,
- product=0x5020, interface=1, serialname=None):
- self._logger = logging.getLogger(__name__)
- self._board = board
-
- # Find the stm32.
- dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True)
- dev_list = list(dev_g)
- if dev_list is None:
- raise Exception("Power", "USB device not found")
-
- # Check if we have multiple stm32s and we've specified the serial.
- dev = None
- if serialname:
- for d in dev_list:
- dev_serial = "PyUSB dioesn't have a stable interface"
- try:
- dev_serial = usb.util.get_string(d, 256, d.iSerialNumber)
- except ValueError:
- # Incompatible pyUsb version.
- dev_serial = usb.util.get_string(d, d.iSerialNumber)
- if dev_serial == serialname:
- dev = d
- break
- if dev is None:
- raise Exception("Power", "USB device(%s) not found" % serialname)
- else:
- try:
- dev = dev_list[0]
- except TypeError:
- # Incompatible pyUsb version.
- dev = dev_list.next()
-
- self._logger.debug("Found USB device: %04x:%04x", vendor, product)
- self._dev = dev
-
- # Get an endpoint instance.
- try:
- dev.set_configuration()
- except usb.USBError:
- pass
- cfg = dev.get_active_configuration()
-
- intf = usb.util.find_descriptor(cfg, custom_match=lambda i: \
- i.bInterfaceClass==255 and i.bInterfaceSubClass==0x54)
-
- self._intf = intf
- self._logger.debug("InterfaceNumber: %s", intf.bInterfaceNumber)
-
- read_ep = usb.util.find_descriptor(
- intf,
- # match the first IN endpoint
- custom_match = \
- lambda e: \
- usb.util.endpoint_direction(e.bEndpointAddress) == \
- usb.util.ENDPOINT_IN
- )
-
- self._read_ep = read_ep
- self._logger.debug("Reader endpoint: 0x%x", read_ep.bEndpointAddress)
-
- write_ep = usb.util.find_descriptor(
- intf,
- # match the first OUT endpoint
- custom_match = \
- lambda e: \
- usb.util.endpoint_direction(e.bEndpointAddress) == \
- usb.util.ENDPOINT_OUT
- )
-
- self._write_ep = write_ep
- self._logger.debug("Writer endpoint: 0x%x", write_ep.bEndpointAddress)
-
- self.clear_ina_struct()
-
- self._logger.debug("Found power logging USB endpoint.")
-
- def clear_ina_struct(self):
- """ Clear INA description struct."""
- self._inas = []
+CONFIG_LOCATIONS = [os.getcwd(), os.path.dirname(os.path.realpath(__file__)), LIB_DIR]
- def append_ina_struct(self, name, rs, port, addr,
- data=None, ina_type=INA_POWER):
- """Add an INA descriptor into the list of active INAs.
- Args:
- name: Readable name of this channel.
- rs: Sense resistor value in ohms, floating point.
- port: I2C channel this INA is connected to.
- addr: I2C addr of this INA.
- data: Misc data for special handling, board specific.
- ina_type: INA function to use, power, voltage, etc.
- """
- ina = {}
- ina['name'] = name
- ina['rs'] = rs
- ina['port'] = port
- ina['addr'] = addr
- ina['type'] = ina_type
- # Calculate INA231 Calibration register
- # (see INA231 spec p.15)
- # CurrentLSB = uA per div = 80mV / (Rsh * 2^15)
- # CurrentLSB uA = 80000000nV / (Rsh mOhm * 0x8000)
- ina['uAscale'] = 80000000. / (rs * 0x8000);
- ina['uWscale'] = 25. * ina['uAscale'];
- ina['mVscale'] = 1.25
- ina['uVscale'] = 2.5
- ina['data'] = data
- self._inas.append(ina)
-
- def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=1000):
- """Write command to logger logic.
-
- This function writes byte command values list to stm, then reads
- byte status.
-
- Args:
- write_list: list of command byte values [0~255].
- read_count: number of status byte values to read.
-
- Interface:
- write: [command, data ... ]
- read: [status ]
-
- Returns:
- bytes read, or None on failure.
- """
- self._logger.debug("Spower.wr_command(write_list=[%s] (%d), read_count=%s)",
- list(bytearray(write_list)), len(write_list), read_count)
-
- # Clean up args from python style to correct types.
- write_length = 0
- if write_list:
- write_length = len(write_list)
- if not read_count:
- read_count = 0
-
- # Send command to stm32.
- if write_list:
- cmd = write_list
- ret = self._write_ep.write(cmd, wtimeout)
-
- self._logger.debug("RET: %s ", ret)
-
- # Read back response if necessary.
- if read_count:
- bytesread = self._read_ep.read(512, rtimeout)
- self._logger.debug("BYTES: [%s]", bytesread)
-
- if len(bytesread) != read_count:
- pass
-
- self._logger.debug("STATUS: 0x%02x", int(bytesread[0]))
- if read_count == 1:
- return bytesread[0]
- else:
- return bytesread
-
- return None
-
- def clear(self):
- """Clear pending reads on the stm32"""
- try:
- while True:
- ret = self.wr_command(b"", read_count=512, rtimeout=100, wtimeout=50)
- self._logger.debug("Try Clear: read %s",
- "success" if ret == 0 else "failure")
- except:
- pass
-
- def send_reset(self):
- """Reset the power interface on the stm32"""
- cmd = struct.pack("<H", self.CMD_RESET)
- ret = self.wr_command(cmd, rtimeout=50, wtimeout=50)
- self._logger.debug("Command RESET: %s",
- "success" if ret == 0 else "failure")
-
- def reset(self):
- """Try resetting the USB interface until success.
-
- Use linear back off strategy when encounter the error with 10ms increment.
-
- Raises:
- Exception on failure.
- """
- max_reset_retry = 100
- for count in range(1, max_reset_retry + 1):
- self.clear()
- try:
- self.send_reset()
- return
- except Exception as e:
- self.clear()
- self.clear()
- self._logger.debug("TRY %d of %d: %s", count, max_reset_retry, e)
- time.sleep(count * 0.01)
- raise Exception("Power", "Failed to reset")
-
- def stop(self):
- """Stop any active data acquisition."""
- cmd = struct.pack("<H", self.CMD_STOP)
- ret = self.wr_command(cmd)
- self._logger.debug("Command STOP: %s",
- "success" if ret == 0 else "failure")
-
- def start(self, integration_us):
- """Start data acquisition.
-
- Args:
- integration_us: int, how many us between samples, and
- how often the data block must be read.
+def logoutput(msg):
+ print(msg)
+ sys.stdout.flush()
- Returns:
- actual sampling interval in ms.
- """
- cmd = struct.pack("<HI", self.CMD_START, integration_us)
- read = self.wr_command(cmd, read_count=5)
- actual_us = 0
- if len(read) == 5:
- ret, actual_us = struct.unpack("<BI", read)
- self._logger.debug("Command START: %s %dus",
- "success" if ret == 0 else "failure", actual_us)
- else:
- self._logger.debug("Command START: FAIL")
- return actual_us
+def process_filename(filename):
+ """Find the file path from the filename.
- def add_ina_name(self, name_tuple):
- """Add INA from board config.
+ If filename is already the complete path, return that directly. If filename is
+ just the short name, look for the file in the current working directory, in
+ the directory of the current .py file, and then in the directory installed by
+ hdctools. If the file is found, return the complete path of the file.
Args:
- name_tuple: name and type of power rail in board config.
+ filename: complete file path or short file name.
Returns:
- True if INA added, False if the INA is not on this board.
+ a complete file path.
Raises:
- Exception on unexpected failure.
+ IOError if filename does not exist.
"""
- name, ina_type = name_tuple
-
- for datum in self._brdcfg:
- if datum["name"] == name:
- rs = int(float(datum["rs"]) * 1000.)
- board = datum["sweetberry"]
-
- if board == self._board:
- if 'port' in datum and 'addr' in datum:
- port = datum['port']
- addr = datum['addr']
- else:
- channel = int(datum["channel"])
- port, addr = self.CHMAP[channel]
- self.add_ina(port, ina_type, addr, 0, rs, data=datum)
- return True
- else:
- return False
- raise Exception("Power", "Failed to find INA %s" % name)
-
- def set_time(self, timestamp_us):
- """Set sweetberry time to match host time.
-
- Args:
- timestamp_us: host timestmap in us.
- """
- # 0x0005 , 8 byte timestamp
- cmd = struct.pack("<HQ", self.CMD_SETTIME, timestamp_us)
- ret = self.wr_command(cmd)
+ # Check if filename is absolute path.
+ if os.path.isabs(filename) and os.path.isfile(filename):
+ return filename
+ # Check if filename is relative to a known config location.
+ for dirname in CONFIG_LOCATIONS:
+ file_at_dir = os.path.join(dirname, filename)
+ if os.path.isfile(file_at_dir):
+ return file_at_dir
+ raise IOError("No such file or directory: '%s'" % filename)
- self._logger.debug("Command SETTIME: %s",
- "success" if ret == 0 else "failure")
- def add_ina(self, bus, ina_type, addr, extra, resistance, data=None):
- """Add an INA to the data acquisition list.
+class Spower(object):
+ """Power class to access devices on the bus.
- Args:
- bus: which i2c bus the INA is on. Same ordering as Si2c.
- ina_type: Ina interface: INA_POWER/BUSV/etc.
- addr: 7 bit i2c addr of this INA
- extra: extra data for nonstandard configs.
- resistance: int, shunt resistance in mOhm
- """
- # 0x0002, 1B: bus, 1B:INA type, 1B: INA addr, 1B: extra, 4B: Rs
- cmd = struct.pack("<HBBBBI", self.CMD_ADDINA,
- bus, ina_type, addr, extra, resistance)
- ret = self.wr_command(cmd)
- if ret == 0:
- if data:
- name = data['name']
- else:
- name = "ina%d_%02x" % (bus, addr)
- self.append_ina_struct(name, resistance, bus, addr,
- data=data, ina_type=ina_type)
- self._logger.debug("Command ADD_INA: %s",
- "success" if ret == 0 else "failure")
-
- def report_header_size(self):
- """Helper function to calculate power record header size."""
- result = 2
- timestamp = 8
- return result + timestamp
-
- def report_size(self, ina_count):
- """Helper function to calculate full power record size."""
- record = 2
-
- datasize = self.report_header_size() + ina_count * record
- # Round to multiple of 4 bytes.
- datasize = int(((datasize + 3) // 4) * 4)
-
- return datasize
-
- def read_line(self):
- """Read a line of data from the setup INAs
+ Usage:
+ bus = Spower()
- Returns:
- list of dicts of the values read by ina/type tuple, otherwise None.
- [{ts:100, (vbat, power):450}, {ts:200, (vbat, power):440}]
+ Instance Variables:
+ _dev: pyUSB device object
+ _read_ep: pyUSB read endpoint for this interface
+ _write_ep: pyUSB write endpoint for this interface
"""
- try:
- expected_bytes = self.report_size(len(self._inas))
- cmd = struct.pack("<H", self.CMD_NEXT)
- bytesread = self.wr_command(cmd, read_count=expected_bytes)
- except usb.core.USBError as e:
- self._logger.error("READ LINE FAILED %s", e)
- return None
-
- if len(bytesread) == 1:
- if bytesread[0] != 0x6:
- self._logger.debug("READ LINE FAILED bytes: %d ret: %02x",
- len(bytesread), bytesread[0])
- return None
-
- if len(bytesread) % expected_bytes != 0:
- self._logger.debug("READ LINE WARNING: expected %d, got %d",
- expected_bytes, len(bytesread))
-
- packet_count = len(bytesread) // expected_bytes
-
- values = []
- for i in range(0, packet_count):
- start = i * expected_bytes
- end = (i + 1) * expected_bytes
- record = self.interpret_line(bytesread[start:end])
- values.append(record)
-
- return values
-
- def interpret_line(self, data):
- """Interpret a power record from INAs
- Args:
- data: one single record of bytes.
+ # INA interface type.
+ INA_POWER = 1
+ INA_BUSV = 2
+ INA_CURRENT = 3
+ INA_SHUNTV = 4
+ # INA_SUFFIX is used to differentiate multiple ina types for the same power
+ # rail. No suffix for when ina type is 0 (non-existent) and when ina type is 1
+ # (power, no suffix for backward compatibility).
+ INA_SUFFIX = ["", "", "_busv", "_cur", "_shuntv"]
+
+ # usb power commands
+ CMD_RESET = 0x0000
+ CMD_STOP = 0x0001
+ CMD_ADDINA = 0x0002
+ CMD_START = 0x0003
+ CMD_NEXT = 0x0004
+ CMD_SETTIME = 0x0005
+
+ # Map between header channel number (0-47)
+ # and INA I2C bus/addr on sweetberry.
+ CHMAP = {
+ 0: (3, 0x40),
+ 1: (1, 0x40),
+ 2: (2, 0x40),
+ 3: (0, 0x40),
+ 4: (3, 0x41),
+ 5: (1, 0x41),
+ 6: (2, 0x41),
+ 7: (0, 0x41),
+ 8: (3, 0x42),
+ 9: (1, 0x42),
+ 10: (2, 0x42),
+ 11: (0, 0x42),
+ 12: (3, 0x43),
+ 13: (1, 0x43),
+ 14: (2, 0x43),
+ 15: (0, 0x43),
+ 16: (3, 0x44),
+ 17: (1, 0x44),
+ 18: (2, 0x44),
+ 19: (0, 0x44),
+ 20: (3, 0x45),
+ 21: (1, 0x45),
+ 22: (2, 0x45),
+ 23: (0, 0x45),
+ 24: (3, 0x46),
+ 25: (1, 0x46),
+ 26: (2, 0x46),
+ 27: (0, 0x46),
+ 28: (3, 0x47),
+ 29: (1, 0x47),
+ 30: (2, 0x47),
+ 31: (0, 0x47),
+ 32: (3, 0x48),
+ 33: (1, 0x48),
+ 34: (2, 0x48),
+ 35: (0, 0x48),
+ 36: (3, 0x49),
+ 37: (1, 0x49),
+ 38: (2, 0x49),
+ 39: (0, 0x49),
+ 40: (3, 0x4A),
+ 41: (1, 0x4A),
+ 42: (2, 0x4A),
+ 43: (0, 0x4A),
+ 44: (3, 0x4B),
+ 45: (1, 0x4B),
+ 46: (2, 0x4B),
+ 47: (0, 0x4B),
+ }
+
+ def __init__(
+ self, board, vendor=0x18D1, product=0x5020, interface=1, serialname=None
+ ):
+ self._logger = logging.getLogger(__name__)
+ self._board = board
+
+ # Find the stm32.
+ dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True)
+ dev_list = list(dev_g)
+ if dev_list is None:
+ raise Exception("Power", "USB device not found")
+
+ # Check if we have multiple stm32s and we've specified the serial.
+ dev = None
+ if serialname:
+ for d in dev_list:
+ dev_serial = "PyUSB dioesn't have a stable interface"
+ try:
+ dev_serial = usb.util.get_string(d, 256, d.iSerialNumber)
+ except ValueError:
+ # Incompatible pyUsb version.
+ dev_serial = usb.util.get_string(d, d.iSerialNumber)
+ if dev_serial == serialname:
+ dev = d
+ break
+ if dev is None:
+ raise Exception("Power", "USB device(%s) not found" % serialname)
+ else:
+ try:
+ dev = dev_list[0]
+ except TypeError:
+ # Incompatible pyUsb version.
+ dev = dev_list.next()
- Output:
- stdout of the record in csv format.
+ self._logger.debug("Found USB device: %04x:%04x", vendor, product)
+ self._dev = dev
- Returns:
- dict containing name, value of recorded data.
- """
- status, size = struct.unpack("<BB", data[0:2])
- if len(data) != self.report_size(size):
- self._logger.error("READ LINE FAILED st:%d size:%d expected:%d len:%d",
- status, size, self.report_size(size), len(data))
- else:
- pass
+ # Get an endpoint instance.
+ try:
+ dev.set_configuration()
+ except usb.USBError:
+ pass
+ cfg = dev.get_active_configuration()
+
+ intf = usb.util.find_descriptor(
+ cfg,
+ custom_match=lambda i: i.bInterfaceClass == 255
+ and i.bInterfaceSubClass == 0x54,
+ )
+
+ self._intf = intf
+ self._logger.debug("InterfaceNumber: %s", intf.bInterfaceNumber)
+
+ read_ep = usb.util.find_descriptor(
+ intf,
+ # match the first IN endpoint
+ custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress)
+ == usb.util.ENDPOINT_IN,
+ )
+
+ self._read_ep = read_ep
+ self._logger.debug("Reader endpoint: 0x%x", read_ep.bEndpointAddress)
+
+ write_ep = usb.util.find_descriptor(
+ intf,
+ # match the first OUT endpoint
+ custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress)
+ == usb.util.ENDPOINT_OUT,
+ )
+
+ self._write_ep = write_ep
+ self._logger.debug("Writer endpoint: 0x%x", write_ep.bEndpointAddress)
+
+ self.clear_ina_struct()
+
+ self._logger.debug("Found power logging USB endpoint.")
+
+ def clear_ina_struct(self):
+ """Clear INA description struct."""
+ self._inas = []
+
+ def append_ina_struct(self, name, rs, port, addr, data=None, ina_type=INA_POWER):
+ """Add an INA descriptor into the list of active INAs.
+
+ Args:
+ name: Readable name of this channel.
+ rs: Sense resistor value in ohms, floating point.
+ port: I2C channel this INA is connected to.
+ addr: I2C addr of this INA.
+ data: Misc data for special handling, board specific.
+ ina_type: INA function to use, power, voltage, etc.
+ """
+ ina = {}
+ ina["name"] = name
+ ina["rs"] = rs
+ ina["port"] = port
+ ina["addr"] = addr
+ ina["type"] = ina_type
+ # Calculate INA231 Calibration register
+ # (see INA231 spec p.15)
+ # CurrentLSB = uA per div = 80mV / (Rsh * 2^15)
+ # CurrentLSB uA = 80000000nV / (Rsh mOhm * 0x8000)
+ ina["uAscale"] = 80000000.0 / (rs * 0x8000)
+ ina["uWscale"] = 25.0 * ina["uAscale"]
+ ina["mVscale"] = 1.25
+ ina["uVscale"] = 2.5
+ ina["data"] = data
+ self._inas.append(ina)
+
+ def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=1000):
+ """Write command to logger logic.
+
+ This function writes byte command values list to stm, then reads
+ byte status.
+
+ Args:
+ write_list: list of command byte values [0~255].
+ read_count: number of status byte values to read.
+
+ Interface:
+ write: [command, data ... ]
+ read: [status ]
+
+ Returns:
+ bytes read, or None on failure.
+ """
+ self._logger.debug(
+ "Spower.wr_command(write_list=[%s] (%d), read_count=%s)",
+ list(bytearray(write_list)),
+ len(write_list),
+ read_count,
+ )
+
+ # Clean up args from python style to correct types.
+ write_length = 0
+ if write_list:
+ write_length = len(write_list)
+ if not read_count:
+ read_count = 0
+
+ # Send command to stm32.
+ if write_list:
+ cmd = write_list
+ ret = self._write_ep.write(cmd, wtimeout)
+
+ self._logger.debug("RET: %s ", ret)
+
+ # Read back response if necessary.
+ if read_count:
+ bytesread = self._read_ep.read(512, rtimeout)
+ self._logger.debug("BYTES: [%s]", bytesread)
+
+ if len(bytesread) != read_count:
+ pass
+
+ self._logger.debug("STATUS: 0x%02x", int(bytesread[0]))
+ if read_count == 1:
+ return bytesread[0]
+ else:
+ return bytesread
+
+ return None
+
+ def clear(self):
+ """Clear pending reads on the stm32"""
+ try:
+ while True:
+ ret = self.wr_command(b"", read_count=512, rtimeout=100, wtimeout=50)
+ self._logger.debug(
+ "Try Clear: read %s", "success" if ret == 0 else "failure"
+ )
+ except:
+ pass
+
+ def send_reset(self):
+ """Reset the power interface on the stm32"""
+ cmd = struct.pack("<H", self.CMD_RESET)
+ ret = self.wr_command(cmd, rtimeout=50, wtimeout=50)
+ self._logger.debug("Command RESET: %s", "success" if ret == 0 else "failure")
+
+ def reset(self):
+ """Try resetting the USB interface until success.
+
+    Use a linear back-off strategy (10 ms increments) when an error is encountered.
+
+ Raises:
+ Exception on failure.
+ """
+ max_reset_retry = 100
+ for count in range(1, max_reset_retry + 1):
+ self.clear()
+ try:
+ self.send_reset()
+ return
+ except Exception as e:
+ self.clear()
+ self.clear()
+ self._logger.debug("TRY %d of %d: %s", count, max_reset_retry, e)
+ time.sleep(count * 0.01)
+ raise Exception("Power", "Failed to reset")
+
+ def stop(self):
+ """Stop any active data acquisition."""
+ cmd = struct.pack("<H", self.CMD_STOP)
+ ret = self.wr_command(cmd)
+ self._logger.debug("Command STOP: %s", "success" if ret == 0 else "failure")
+
+ def start(self, integration_us):
+ """Start data acquisition.
+
+ Args:
+ integration_us: int, how many us between samples, and
+ how often the data block must be read.
+
+ Returns:
+ actual sampling interval in ms.
+ """
+ cmd = struct.pack("<HI", self.CMD_START, integration_us)
+ read = self.wr_command(cmd, read_count=5)
+ actual_us = 0
+ if len(read) == 5:
+ ret, actual_us = struct.unpack("<BI", read)
+ self._logger.debug(
+ "Command START: %s %dus",
+ "success" if ret == 0 else "failure",
+ actual_us,
+ )
+ else:
+ self._logger.debug("Command START: FAIL")
+
+ return actual_us
+
+ def add_ina_name(self, name_tuple):
+ """Add INA from board config.
+
+ Args:
+ name_tuple: name and type of power rail in board config.
+
+ Returns:
+ True if INA added, False if the INA is not on this board.
+
+ Raises:
+ Exception on unexpected failure.
+ """
+ name, ina_type = name_tuple
+
+ for datum in self._brdcfg:
+ if datum["name"] == name:
+ rs = int(float(datum["rs"]) * 1000.0)
+ board = datum["sweetberry"]
+
+ if board == self._board:
+ if "port" in datum and "addr" in datum:
+ port = datum["port"]
+ addr = datum["addr"]
+ else:
+ channel = int(datum["channel"])
+ port, addr = self.CHMAP[channel]
+ self.add_ina(port, ina_type, addr, 0, rs, data=datum)
+ return True
+ else:
+ return False
+ raise Exception("Power", "Failed to find INA %s" % name)
+
+ def set_time(self, timestamp_us):
+ """Set sweetberry time to match host time.
+
+ Args:
+            timestamp_us: host timestamp in us.
+ """
+ # 0x0005 , 8 byte timestamp
+ cmd = struct.pack("<HQ", self.CMD_SETTIME, timestamp_us)
+ ret = self.wr_command(cmd)
+
+ self._logger.debug("Command SETTIME: %s", "success" if ret == 0 else "failure")
+
+ def add_ina(self, bus, ina_type, addr, extra, resistance, data=None):
+ """Add an INA to the data acquisition list.
+
+ Args:
+ bus: which i2c bus the INA is on. Same ordering as Si2c.
+ ina_type: Ina interface: INA_POWER/BUSV/etc.
+ addr: 7 bit i2c addr of this INA
+ extra: extra data for nonstandard configs.
+ resistance: int, shunt resistance in mOhm
+ """
+ # 0x0002, 1B: bus, 1B:INA type, 1B: INA addr, 1B: extra, 4B: Rs
+ cmd = struct.pack(
+ "<HBBBBI", self.CMD_ADDINA, bus, ina_type, addr, extra, resistance
+ )
+ ret = self.wr_command(cmd)
+ if ret == 0:
+ if data:
+ name = data["name"]
+ else:
+ name = "ina%d_%02x" % (bus, addr)
+ self.append_ina_struct(
+ name, resistance, bus, addr, data=data, ina_type=ina_type
+ )
+ self._logger.debug("Command ADD_INA: %s", "success" if ret == 0 else "failure")
+
+ def report_header_size(self):
+ """Helper function to calculate power record header size."""
+ result = 2
+ timestamp = 8
+ return result + timestamp
+
+ def report_size(self, ina_count):
+ """Helper function to calculate full power record size."""
+ record = 2
+
+ datasize = self.report_header_size() + ina_count * record
+ # Round to multiple of 4 bytes.
+ datasize = int(((datasize + 3) // 4) * 4)
+
+ return datasize
+
+ def read_line(self):
+ """Read a line of data from the setup INAs
+
+ Returns:
+ list of dicts of the values read by ina/type tuple, otherwise None.
+ [{ts:100, (vbat, power):450}, {ts:200, (vbat, power):440}]
+ """
+ try:
+ expected_bytes = self.report_size(len(self._inas))
+ cmd = struct.pack("<H", self.CMD_NEXT)
+ bytesread = self.wr_command(cmd, read_count=expected_bytes)
+ except usb.core.USBError as e:
+ self._logger.error("READ LINE FAILED %s", e)
+ return None
+
+ if len(bytesread) == 1:
+ if bytesread[0] != 0x6:
+ self._logger.debug(
+ "READ LINE FAILED bytes: %d ret: %02x", len(bytesread), bytesread[0]
+ )
+ return None
+
+ if len(bytesread) % expected_bytes != 0:
+ self._logger.debug(
+ "READ LINE WARNING: expected %d, got %d", expected_bytes, len(bytesread)
+ )
+
+ packet_count = len(bytesread) // expected_bytes
+
+ values = []
+ for i in range(0, packet_count):
+ start = i * expected_bytes
+ end = (i + 1) * expected_bytes
+ record = self.interpret_line(bytesread[start:end])
+ values.append(record)
+
+ return values
+
+ def interpret_line(self, data):
+ """Interpret a power record from INAs
+
+ Args:
+ data: one single record of bytes.
+
+ Output:
+ stdout of the record in csv format.
+
+ Returns:
+ dict containing name, value of recorded data.
+ """
+ status, size = struct.unpack("<BB", data[0:2])
+ if len(data) != self.report_size(size):
+ self._logger.error(
+ "READ LINE FAILED st:%d size:%d expected:%d len:%d",
+ status,
+ size,
+ self.report_size(size),
+ len(data),
+ )
+ else:
+ pass
- timestamp = struct.unpack("<Q", data[2:10])[0]
- self._logger.debug("READ LINE: st:%d size:%d time:%dus", status, size,
- timestamp)
- ftimestamp = float(timestamp) / 1000000.
+ timestamp = struct.unpack("<Q", data[2:10])[0]
+ self._logger.debug(
+ "READ LINE: st:%d size:%d time:%dus", status, size, timestamp
+ )
+ ftimestamp = float(timestamp) / 1000000.0
- record = {"ts": ftimestamp, "status": status, "berry":self._board}
+ record = {"ts": ftimestamp, "status": status, "berry": self._board}
- for i in range(0, size):
- idx = self.report_header_size() + 2*i
- name = self._inas[i]['name']
- name_tuple = (self._inas[i]['name'], self._inas[i]['type'])
+ for i in range(0, size):
+ idx = self.report_header_size() + 2 * i
+ name = self._inas[i]["name"]
+ name_tuple = (self._inas[i]["name"], self._inas[i]["type"])
- raw_val = struct.unpack("<h", data[idx:idx+2])[0]
+ raw_val = struct.unpack("<h", data[idx : idx + 2])[0]
- if self._inas[i]['type'] == Spower.INA_POWER:
- val = raw_val * self._inas[i]['uWscale']
- elif self._inas[i]['type'] == Spower.INA_BUSV:
- val = raw_val * self._inas[i]['mVscale']
- elif self._inas[i]['type'] == Spower.INA_CURRENT:
- val = raw_val * self._inas[i]['uAscale']
- elif self._inas[i]['type'] == Spower.INA_SHUNTV:
- val = raw_val * self._inas[i]['uVscale']
+ if self._inas[i]["type"] == Spower.INA_POWER:
+ val = raw_val * self._inas[i]["uWscale"]
+ elif self._inas[i]["type"] == Spower.INA_BUSV:
+ val = raw_val * self._inas[i]["mVscale"]
+ elif self._inas[i]["type"] == Spower.INA_CURRENT:
+ val = raw_val * self._inas[i]["uAscale"]
+ elif self._inas[i]["type"] == Spower.INA_SHUNTV:
+ val = raw_val * self._inas[i]["uVscale"]
- self._logger.debug("READ %d %s: %fs: 0x%04x %f", i, name, ftimestamp,
- raw_val, val)
- record[name_tuple] = val
+ self._logger.debug(
+ "READ %d %s: %fs: 0x%04x %f", i, name, ftimestamp, raw_val, val
+ )
+ record[name_tuple] = val
- return record
+ return record
- def load_board(self, brdfile):
- """Load a board config.
+ def load_board(self, brdfile):
+ """Load a board config.
- Args:
- brdfile: Filename of a json file decribing the INA wiring of this board.
- """
- with open(process_filename(brdfile)) as data_file:
- data = json.load(data_file)
+ Args:
+            brdfile: Filename of a json file describing the INA wiring of this board.
+ """
+ with open(process_filename(brdfile)) as data_file:
+ data = json.load(data_file)
- #TODO: validate this.
- self._brdcfg = data;
- self._logger.debug(pprint.pformat(data))
+ # TODO: validate this.
+ self._brdcfg = data
+ self._logger.debug(pprint.pformat(data))
class powerlog(object):
- """Power class to log aggregated power.
-
- Usage:
- obj = powerlog()
+ """Power class to log aggregated power.
- Instance Variables:
- _data: a StatsManager object that records sweetberry readings and calculates
- statistics.
- _pwr[]: Spower objects for individual sweetberries.
- """
-
- def __init__(self, brdfile, cfgfile, serial_a=None, serial_b=None,
- sync_date=False, use_ms=False, use_mW=False, print_stats=False,
- stats_dir=None, stats_json_dir=None, print_raw_data=True,
- raw_data_dir=None):
- """Init the powerlog class and set the variables.
-
- Args:
- brdfile: string name of json file containing board layout.
- cfgfile: string name of json containing list of rails to read.
- serial_a: serial number of sweetberry A.
- serial_b: serial number of sweetberry B.
- sync_date: report timestamps synced with host datetime.
- use_ms: report timestamps in ms rather than us.
- use_mW: report power as milliwatts, otherwise default to microwatts.
- print_stats: print statistics for sweetberry readings at the end.
- stats_dir: directory to save sweetberry readings statistics; if None then
- do not save the statistics.
- stats_json_dir: directory to save means of sweetberry readings in json
- format; if None then do not save the statistics.
- print_raw_data: print sweetberry readings raw data in real time, default
- is to print.
- raw_data_dir: directory to save sweetberry readings raw data; if None then
- do not save the raw data.
- """
- self._logger = logging.getLogger(__name__)
- self._data = StatsManager()
- self._pwr = {}
- self._use_ms = use_ms
- self._use_mW = use_mW
- self._print_stats = print_stats
- self._stats_dir = stats_dir
- self._stats_json_dir = stats_json_dir
- self._print_raw_data = print_raw_data
- self._raw_data_dir = raw_data_dir
-
- if not serial_a and not serial_b:
- self._pwr['A'] = Spower('A')
- if serial_a:
- self._pwr['A'] = Spower('A', serialname=serial_a)
- if serial_b:
- self._pwr['B'] = Spower('B', serialname=serial_b)
-
- with open(process_filename(cfgfile)) as data_file:
- names = json.load(data_file)
- self._names = self.process_scenario(names)
-
- for key in self._pwr:
- self._pwr[key].load_board(brdfile)
- self._pwr[key].reset()
-
- # Allocate the rails to the appropriate boards.
- used_boards = []
- for name in self._names:
- success = False
- for key in self._pwr.keys():
- if self._pwr[key].add_ina_name(name):
- success = True
- if key not in used_boards:
- used_boards.append(key)
- if not success:
- raise Exception("Failed to add %s (maybe missing "
- "sweetberry, or bad board file?)" % name)
-
- # Evict unused boards.
- for key in list(self._pwr.keys()):
- if key not in used_boards:
- self._pwr.pop(key)
-
- for key in self._pwr.keys():
- if sync_date:
- self._pwr[key].set_time(time.time() * 1000000)
- else:
- self._pwr[key].set_time(0)
-
- def process_scenario(self, name_list):
- """Return list of tuples indicating name and type.
+ Usage:
+ obj = powerlog()
- Args:
- json originated list of names, or [name, type]
- Returns:
- list of tuples of (name, type) defaulting to type "POWER"
- Raises: exception, invalid INA type.
+ Instance Variables:
+ _data: a StatsManager object that records sweetberry readings and calculates
+ statistics.
+ _pwr[]: Spower objects for individual sweetberries.
"""
- names = []
- for entry in name_list:
- if isinstance(entry, list):
- name = entry[0]
- if entry[1] == "POWER":
- type = Spower.INA_POWER
- elif entry[1] == "BUSV":
- type = Spower.INA_BUSV
- elif entry[1] == "CURRENT":
- type = Spower.INA_CURRENT
- elif entry[1] == "SHUNTV":
- type = Spower.INA_SHUNTV
- else:
- raise Exception("Invalid INA type", "Type of %s [%s] not recognized,"
- " try one of POWER, BUSV, CURRENT" % (entry[0], entry[1]))
- else:
- name = entry
- type = Spower.INA_POWER
-
- names.append((name, type))
- return names
- def start(self, integration_us_request, seconds, sync_speed=.8):
- """Starts sampling.
+ def __init__(
+ self,
+ brdfile,
+ cfgfile,
+ serial_a=None,
+ serial_b=None,
+ sync_date=False,
+ use_ms=False,
+ use_mW=False,
+ print_stats=False,
+ stats_dir=None,
+ stats_json_dir=None,
+ print_raw_data=True,
+ raw_data_dir=None,
+ ):
+ """Init the powerlog class and set the variables.
+
+ Args:
+ brdfile: string name of json file containing board layout.
+ cfgfile: string name of json containing list of rails to read.
+ serial_a: serial number of sweetberry A.
+ serial_b: serial number of sweetberry B.
+ sync_date: report timestamps synced with host datetime.
+ use_ms: report timestamps in ms rather than us.
+ use_mW: report power as milliwatts, otherwise default to microwatts.
+ print_stats: print statistics for sweetberry readings at the end.
+ stats_dir: directory to save sweetberry readings statistics; if None then
+ do not save the statistics.
+ stats_json_dir: directory to save means of sweetberry readings in json
+ format; if None then do not save the statistics.
+ print_raw_data: print sweetberry readings raw data in real time, default
+ is to print.
+ raw_data_dir: directory to save sweetberry readings raw data; if None then
+ do not save the raw data.
+ """
+ self._logger = logging.getLogger(__name__)
+ self._data = StatsManager()
+ self._pwr = {}
+ self._use_ms = use_ms
+ self._use_mW = use_mW
+ self._print_stats = print_stats
+ self._stats_dir = stats_dir
+ self._stats_json_dir = stats_json_dir
+ self._print_raw_data = print_raw_data
+ self._raw_data_dir = raw_data_dir
+
+ if not serial_a and not serial_b:
+ self._pwr["A"] = Spower("A")
+ if serial_a:
+ self._pwr["A"] = Spower("A", serialname=serial_a)
+ if serial_b:
+ self._pwr["B"] = Spower("B", serialname=serial_b)
+
+ with open(process_filename(cfgfile)) as data_file:
+ names = json.load(data_file)
+ self._names = self.process_scenario(names)
- Args:
- integration_us_request: requested interval between sample values.
- seconds: time until exit, or None to run until cancel.
- sync_speed: A usb request is sent every [.8] * integration_us.
- """
- # We will get back the actual integration us.
- # It should be the same for all devices.
- integration_us = None
- for key in self._pwr:
- integration_us_new = self._pwr[key].start(integration_us_request)
- if integration_us:
- if integration_us != integration_us_new:
- raise Exception("FAIL",
- "Integration on A: %dus != integration on B %dus" % (
- integration_us, integration_us_new))
- integration_us = integration_us_new
-
- # CSV header
- title = "ts:%dus" % integration_us
- for name_tuple in self._names:
- name, ina_type = name_tuple
-
- if ina_type == Spower.INA_POWER:
- unit = "mW" if self._use_mW else "uW"
- elif ina_type == Spower.INA_BUSV:
- unit = "mV"
- elif ina_type == Spower.INA_CURRENT:
- unit = "uA"
- elif ina_type == Spower.INA_SHUNTV:
- unit = "uV"
-
- title += ", %s %s" % (name, unit)
- name_type = name + Spower.INA_SUFFIX[ina_type]
- self._data.SetUnit(name_type, unit)
- title += ", status"
- if self._print_raw_data:
- logoutput(title)
-
- forever = False
- if not seconds:
- forever = True
- end_time = time.time() + seconds
- try:
- pending_records = []
- while forever or end_time > time.time():
- if (integration_us > 5000):
- time.sleep((integration_us / 1000000.) * sync_speed)
for key in self._pwr:
- records = self._pwr[key].read_line()
- if not records:
- continue
-
- for record in records:
- pending_records.append(record)
-
- pending_records.sort(key=lambda r: r['ts'])
-
- aggregate_record = {"boards": set()}
- for record in pending_records:
- if record["berry"] not in aggregate_record["boards"]:
- for rkey in record.keys():
- aggregate_record[rkey] = record[rkey]
- aggregate_record["boards"].add(record["berry"])
- else:
- self._logger.info("break %s, %s", record["berry"],
- aggregate_record["boards"])
- break
-
- if aggregate_record["boards"] == set(self._pwr.keys()):
- csv = "%f" % aggregate_record["ts"]
- for name in self._names:
- if name in aggregate_record:
- multiplier = 0.001 if (self._use_mW and
- name[1]==Spower.INA_POWER) else 1
- value = aggregate_record[name] * multiplier
- csv += ", %.2f" % value
- name_type = name[0] + Spower.INA_SUFFIX[name[1]]
- self._data.AddSample(name_type, value)
- else:
- csv += ", "
- csv += ", %d" % aggregate_record["status"]
- if self._print_raw_data:
- logoutput(csv)
-
- aggregate_record = {"boards": set()}
- for r in range(0, len(self._pwr)):
- pending_records.pop(0)
-
- except KeyboardInterrupt:
- self._logger.info('\nCTRL+C caught.')
-
- finally:
- for key in self._pwr:
- self._pwr[key].stop()
- self._data.CalculateStats()
- if self._print_stats:
- print(self._data.SummaryToString())
- save_dir = 'sweetberry%s' % time.time()
- if self._stats_dir:
- stats_dir = os.path.join(self._stats_dir, save_dir)
- self._data.SaveSummary(stats_dir)
- if self._stats_json_dir:
- stats_json_dir = os.path.join(self._stats_json_dir, save_dir)
- self._data.SaveSummaryJSON(stats_json_dir)
- if self._raw_data_dir:
- raw_data_dir = os.path.join(self._raw_data_dir, save_dir)
- self._data.SaveRawData(raw_data_dir)
+ self._pwr[key].load_board(brdfile)
+ self._pwr[key].reset()
+
+ # Allocate the rails to the appropriate boards.
+ used_boards = []
+ for name in self._names:
+ success = False
+ for key in self._pwr.keys():
+ if self._pwr[key].add_ina_name(name):
+ success = True
+ if key not in used_boards:
+ used_boards.append(key)
+ if not success:
+ raise Exception(
+ "Failed to add %s (maybe missing "
+ "sweetberry, or bad board file?)" % name
+ )
+
+ # Evict unused boards.
+ for key in list(self._pwr.keys()):
+ if key not in used_boards:
+ self._pwr.pop(key)
+
+ for key in self._pwr.keys():
+ if sync_date:
+ self._pwr[key].set_time(time.time() * 1000000)
+ else:
+ self._pwr[key].set_time(0)
+
+ def process_scenario(self, name_list):
+ """Return list of tuples indicating name and type.
+
+ Args:
+ json originated list of names, or [name, type]
+ Returns:
+ list of tuples of (name, type) defaulting to type "POWER"
+ Raises: exception, invalid INA type.
+ """
+ names = []
+ for entry in name_list:
+ if isinstance(entry, list):
+ name = entry[0]
+ if entry[1] == "POWER":
+ type = Spower.INA_POWER
+ elif entry[1] == "BUSV":
+ type = Spower.INA_BUSV
+ elif entry[1] == "CURRENT":
+ type = Spower.INA_CURRENT
+ elif entry[1] == "SHUNTV":
+ type = Spower.INA_SHUNTV
+ else:
+ raise Exception(
+ "Invalid INA type",
+ "Type of %s [%s] not recognized,"
+ " try one of POWER, BUSV, CURRENT" % (entry[0], entry[1]),
+ )
+ else:
+ name = entry
+ type = Spower.INA_POWER
+
+ names.append((name, type))
+ return names
+
+ def start(self, integration_us_request, seconds, sync_speed=0.8):
+ """Starts sampling.
+
+ Args:
+ integration_us_request: requested interval between sample values.
+ seconds: time until exit, or None to run until cancel.
+ sync_speed: A usb request is sent every [.8] * integration_us.
+ """
+ # We will get back the actual integration us.
+ # It should be the same for all devices.
+ integration_us = None
+ for key in self._pwr:
+ integration_us_new = self._pwr[key].start(integration_us_request)
+ if integration_us:
+ if integration_us != integration_us_new:
+ raise Exception(
+ "FAIL",
+ "Integration on A: %dus != integration on B %dus"
+ % (integration_us, integration_us_new),
+ )
+ integration_us = integration_us_new
+
+ # CSV header
+ title = "ts:%dus" % integration_us
+ for name_tuple in self._names:
+ name, ina_type = name_tuple
+
+ if ina_type == Spower.INA_POWER:
+ unit = "mW" if self._use_mW else "uW"
+ elif ina_type == Spower.INA_BUSV:
+ unit = "mV"
+ elif ina_type == Spower.INA_CURRENT:
+ unit = "uA"
+ elif ina_type == Spower.INA_SHUNTV:
+ unit = "uV"
+
+ title += ", %s %s" % (name, unit)
+ name_type = name + Spower.INA_SUFFIX[ina_type]
+ self._data.SetUnit(name_type, unit)
+ title += ", status"
+ if self._print_raw_data:
+ logoutput(title)
+
+ forever = False
+ if not seconds:
+ forever = True
+ end_time = time.time() + seconds
+ try:
+ pending_records = []
+ while forever or end_time > time.time():
+ if integration_us > 5000:
+ time.sleep((integration_us / 1000000.0) * sync_speed)
+ for key in self._pwr:
+ records = self._pwr[key].read_line()
+ if not records:
+ continue
+
+ for record in records:
+ pending_records.append(record)
+
+ pending_records.sort(key=lambda r: r["ts"])
+
+ aggregate_record = {"boards": set()}
+ for record in pending_records:
+ if record["berry"] not in aggregate_record["boards"]:
+ for rkey in record.keys():
+ aggregate_record[rkey] = record[rkey]
+ aggregate_record["boards"].add(record["berry"])
+ else:
+ self._logger.info(
+ "break %s, %s", record["berry"], aggregate_record["boards"]
+ )
+ break
+
+ if aggregate_record["boards"] == set(self._pwr.keys()):
+ csv = "%f" % aggregate_record["ts"]
+ for name in self._names:
+ if name in aggregate_record:
+ multiplier = (
+ 0.001
+ if (self._use_mW and name[1] == Spower.INA_POWER)
+ else 1
+ )
+ value = aggregate_record[name] * multiplier
+ csv += ", %.2f" % value
+ name_type = name[0] + Spower.INA_SUFFIX[name[1]]
+ self._data.AddSample(name_type, value)
+ else:
+ csv += ", "
+ csv += ", %d" % aggregate_record["status"]
+ if self._print_raw_data:
+ logoutput(csv)
+
+ aggregate_record = {"boards": set()}
+ for r in range(0, len(self._pwr)):
+ pending_records.pop(0)
+
+ except KeyboardInterrupt:
+ self._logger.info("\nCTRL+C caught.")
+
+ finally:
+ for key in self._pwr:
+ self._pwr[key].stop()
+ self._data.CalculateStats()
+ if self._print_stats:
+ print(self._data.SummaryToString())
+ save_dir = "sweetberry%s" % time.time()
+ if self._stats_dir:
+ stats_dir = os.path.join(self._stats_dir, save_dir)
+ self._data.SaveSummary(stats_dir)
+ if self._stats_json_dir:
+ stats_json_dir = os.path.join(self._stats_json_dir, save_dir)
+ self._data.SaveSummaryJSON(stats_json_dir)
+ if self._raw_data_dir:
+ raw_data_dir = os.path.join(self._raw_data_dir, save_dir)
+ self._data.SaveRawData(raw_data_dir)
def main(argv=None):
- if argv is None:
- argv = sys.argv[1:]
- # Command line argument description.
- parser = argparse.ArgumentParser(
- description="Gather CSV data from sweetberry")
- parser.add_argument('-b', '--board', type=str,
- help="Board configuration file, eg. my.board", default="")
- parser.add_argument('-c', '--config', type=str,
- help="Rail config to monitor, eg my.scenario", default="")
- parser.add_argument('-A', '--serial', type=str,
- help="Serial number of sweetberry A", default="")
- parser.add_argument('-B', '--serial_b', type=str,
- help="Serial number of sweetberry B", default="")
- parser.add_argument('-t', '--integration_us', type=int,
- help="Target integration time for samples", default=100000)
- parser.add_argument('-s', '--seconds', type=float,
- help="Seconds to run capture", default=0.)
- parser.add_argument('--date', default=False,
- help="Sync logged timestamp to host date", action="store_true")
- parser.add_argument('--ms', default=False,
- help="Print timestamp as milliseconds", action="store_true")
- parser.add_argument('--mW', default=False,
- help="Print power as milliwatts, otherwise default to microwatts",
- action="store_true")
- parser.add_argument('--slow', default=False,
- help="Intentionally overflow", action="store_true")
- parser.add_argument('--print_stats', default=False, action="store_true",
- help="Print statistics for sweetberry readings at the end")
- parser.add_argument('--save_stats', type=str, nargs='?',
- dest='stats_dir', metavar='STATS_DIR',
- const=os.path.dirname(os.path.abspath(__file__)), default=None,
- help="Save statistics for sweetberry readings to %(metavar)s if "
- "%(metavar)s is specified, %(metavar)s will be created if it does "
- "not exist; if %(metavar)s is not specified but the flag is set, "
- "stats will be saved to where %(prog)s is located; if this flag is "
- "not set, then do not save stats")
- parser.add_argument('--save_stats_json', type=str, nargs='?',
- dest='stats_json_dir', metavar='STATS_JSON_DIR',
- const=os.path.dirname(os.path.abspath(__file__)), default=None,
- help="Save means for sweetberry readings in json to %(metavar)s if "
- "%(metavar)s is specified, %(metavar)s will be created if it does "
- "not exist; if %(metavar)s is not specified but the flag is set, "
- "stats will be saved to where %(prog)s is located; if this flag is "
- "not set, then do not save stats")
- parser.add_argument('--no_print_raw_data',
- dest='print_raw_data', default=True, action="store_false",
- help="Not print raw sweetberry readings at real time, default is to "
- "print")
- parser.add_argument('--save_raw_data', type=str, nargs='?',
- dest='raw_data_dir', metavar='RAW_DATA_DIR',
- const=os.path.dirname(os.path.abspath(__file__)), default=None,
- help="Save raw data for sweetberry readings to %(metavar)s if "
- "%(metavar)s is specified, %(metavar)s will be created if it does "
- "not exist; if %(metavar)s is not specified but the flag is set, "
- "raw data will be saved to where %(prog)s is located; if this flag "
- "is not set, then do not save raw data")
- parser.add_argument('-v', '--verbose', default=False,
- help="Very chatty printout", action="store_true")
-
- args = parser.parse_args(argv)
-
- root_logger = logging.getLogger(__name__)
- if args.verbose:
- root_logger.setLevel(logging.DEBUG)
- else:
- root_logger.setLevel(logging.INFO)
-
- # if powerlog is used through main, log to sys.stdout
- if __name__ == "__main__":
- stdout_handler = logging.StreamHandler(sys.stdout)
- stdout_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
- root_logger.addHandler(stdout_handler)
-
- integration_us_request = args.integration_us
- if not args.board:
- raise Exception("Power", "No board file selected, see board.README")
- if not args.config:
- raise Exception("Power", "No config file selected, see board.README")
-
- brdfile = args.board
- cfgfile = args.config
- seconds = args.seconds
- serial_a = args.serial
- serial_b = args.serial_b
- sync_date = args.date
- use_ms = args.ms
- use_mW = args.mW
- print_stats = args.print_stats
- stats_dir = args.stats_dir
- stats_json_dir = args.stats_json_dir
- print_raw_data = args.print_raw_data
- raw_data_dir = args.raw_data_dir
-
- boards = []
-
- sync_speed = .8
- if args.slow:
- sync_speed = 1.2
-
- # Set up logging interface.
- powerlogger = powerlog(brdfile, cfgfile, serial_a=serial_a, serial_b=serial_b,
- sync_date=sync_date, use_ms=use_ms, use_mW=use_mW,
- print_stats=print_stats, stats_dir=stats_dir,
- stats_json_dir=stats_json_dir,
- print_raw_data=print_raw_data,raw_data_dir=raw_data_dir)
-
- # Start logging.
- powerlogger.start(integration_us_request, seconds, sync_speed=sync_speed)
+ if argv is None:
+ argv = sys.argv[1:]
+ # Command line argument description.
+ parser = argparse.ArgumentParser(description="Gather CSV data from sweetberry")
+ parser.add_argument(
+ "-b",
+ "--board",
+ type=str,
+ help="Board configuration file, eg. my.board",
+ default="",
+ )
+ parser.add_argument(
+ "-c",
+ "--config",
+ type=str,
+ help="Rail config to monitor, eg my.scenario",
+ default="",
+ )
+ parser.add_argument(
+ "-A", "--serial", type=str, help="Serial number of sweetberry A", default=""
+ )
+ parser.add_argument(
+ "-B", "--serial_b", type=str, help="Serial number of sweetberry B", default=""
+ )
+ parser.add_argument(
+ "-t",
+ "--integration_us",
+ type=int,
+ help="Target integration time for samples",
+ default=100000,
+ )
+ parser.add_argument(
+ "-s", "--seconds", type=float, help="Seconds to run capture", default=0.0
+ )
+ parser.add_argument(
+ "--date",
+ default=False,
+ help="Sync logged timestamp to host date",
+ action="store_true",
+ )
+ parser.add_argument(
+ "--ms",
+ default=False,
+ help="Print timestamp as milliseconds",
+ action="store_true",
+ )
+ parser.add_argument(
+ "--mW",
+ default=False,
+ help="Print power as milliwatts, otherwise default to microwatts",
+ action="store_true",
+ )
+ parser.add_argument(
+ "--slow", default=False, help="Intentionally overflow", action="store_true"
+ )
+ parser.add_argument(
+ "--print_stats",
+ default=False,
+ action="store_true",
+ help="Print statistics for sweetberry readings at the end",
+ )
+ parser.add_argument(
+ "--save_stats",
+ type=str,
+ nargs="?",
+ dest="stats_dir",
+ metavar="STATS_DIR",
+ const=os.path.dirname(os.path.abspath(__file__)),
+ default=None,
+ help="Save statistics for sweetberry readings to %(metavar)s if "
+ "%(metavar)s is specified, %(metavar)s will be created if it does "
+ "not exist; if %(metavar)s is not specified but the flag is set, "
+ "stats will be saved to where %(prog)s is located; if this flag is "
+ "not set, then do not save stats",
+ )
+ parser.add_argument(
+ "--save_stats_json",
+ type=str,
+ nargs="?",
+ dest="stats_json_dir",
+ metavar="STATS_JSON_DIR",
+ const=os.path.dirname(os.path.abspath(__file__)),
+ default=None,
+ help="Save means for sweetberry readings in json to %(metavar)s if "
+ "%(metavar)s is specified, %(metavar)s will be created if it does "
+ "not exist; if %(metavar)s is not specified but the flag is set, "
+ "stats will be saved to where %(prog)s is located; if this flag is "
+ "not set, then do not save stats",
+ )
+ parser.add_argument(
+ "--no_print_raw_data",
+ dest="print_raw_data",
+ default=True,
+ action="store_false",
+ help="Not print raw sweetberry readings at real time, default is to " "print",
+ )
+ parser.add_argument(
+ "--save_raw_data",
+ type=str,
+ nargs="?",
+ dest="raw_data_dir",
+ metavar="RAW_DATA_DIR",
+ const=os.path.dirname(os.path.abspath(__file__)),
+ default=None,
+ help="Save raw data for sweetberry readings to %(metavar)s if "
+ "%(metavar)s is specified, %(metavar)s will be created if it does "
+ "not exist; if %(metavar)s is not specified but the flag is set, "
+ "raw data will be saved to where %(prog)s is located; if this flag "
+ "is not set, then do not save raw data",
+ )
+ parser.add_argument(
+ "-v",
+ "--verbose",
+ default=False,
+ help="Very chatty printout",
+ action="store_true",
+ )
+
+ args = parser.parse_args(argv)
+
+ root_logger = logging.getLogger(__name__)
+ if args.verbose:
+ root_logger.setLevel(logging.DEBUG)
+ else:
+ root_logger.setLevel(logging.INFO)
+
+ # if powerlog is used through main, log to sys.stdout
+ if __name__ == "__main__":
+ stdout_handler = logging.StreamHandler(sys.stdout)
+ stdout_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
+ root_logger.addHandler(stdout_handler)
+
+ integration_us_request = args.integration_us
+ if not args.board:
+ raise Exception("Power", "No board file selected, see board.README")
+ if not args.config:
+ raise Exception("Power", "No config file selected, see board.README")
+
+ brdfile = args.board
+ cfgfile = args.config
+ seconds = args.seconds
+ serial_a = args.serial
+ serial_b = args.serial_b
+ sync_date = args.date
+ use_ms = args.ms
+ use_mW = args.mW
+ print_stats = args.print_stats
+ stats_dir = args.stats_dir
+ stats_json_dir = args.stats_json_dir
+ print_raw_data = args.print_raw_data
+ raw_data_dir = args.raw_data_dir
+
+ boards = []
+
+ sync_speed = 0.8
+ if args.slow:
+ sync_speed = 1.2
+
+ # Set up logging interface.
+ powerlogger = powerlog(
+ brdfile,
+ cfgfile,
+ serial_a=serial_a,
+ serial_b=serial_b,
+ sync_date=sync_date,
+ use_ms=use_ms,
+ use_mW=use_mW,
+ print_stats=print_stats,
+ stats_dir=stats_dir,
+ stats_json_dir=stats_json_dir,
+ print_raw_data=print_raw_data,
+ raw_data_dir=raw_data_dir,
+ )
+
+ # Start logging.
+ powerlogger.start(integration_us_request, seconds, sync_speed=sync_speed)
if __name__ == "__main__":
- main()
+ main()
diff --git a/extra/usb_power/powerlog_unittest.py b/extra/usb_power/powerlog_unittest.py
index 1d0718530e..693826c16d 100644
--- a/extra/usb_power/powerlog_unittest.py
+++ b/extra/usb_power/powerlog_unittest.py
@@ -15,40 +15,42 @@ import unittest
import powerlog
+
class TestPowerlog(unittest.TestCase):
- """Test to verify powerlog util methods work as expected."""
-
- def setUp(self):
- """Set up data and create a temporary directory to save data and stats."""
- self.tempdir = tempfile.mkdtemp()
- self.filename = 'testfile'
- self.filepath = os.path.join(self.tempdir, self.filename)
- with open(self.filepath, 'w') as f:
- f.write('')
-
- def tearDown(self):
- """Delete the temporary directory and its content."""
- shutil.rmtree(self.tempdir)
-
- def test_ProcessFilenameAbsoluteFilePath(self):
- """Absolute file path is returned unchanged."""
- processed_fname = powerlog.process_filename(self.filepath)
- self.assertEqual(self.filepath, processed_fname)
-
- def test_ProcessFilenameRelativeFilePath(self):
- """Finds relative file path inside a known config location."""
- original = powerlog.CONFIG_LOCATIONS
- powerlog.CONFIG_LOCATIONS = [self.tempdir]
- processed_fname = powerlog.process_filename(self.filename)
- try:
- self.assertEqual(self.filepath, processed_fname)
- finally:
- powerlog.CONFIG_LOCATIONS = original
-
- def test_ProcessFilenameInvalid(self):
- """IOError is raised when file cannot be found by any of the four ways."""
- with self.assertRaises(IOError):
- powerlog.process_filename(self.filename)
-
-if __name__ == '__main__':
- unittest.main()
+ """Test to verify powerlog util methods work as expected."""
+
+ def setUp(self):
+ """Set up data and create a temporary directory to save data and stats."""
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = "testfile"
+ self.filepath = os.path.join(self.tempdir, self.filename)
+ with open(self.filepath, "w") as f:
+ f.write("")
+
+ def tearDown(self):
+ """Delete the temporary directory and its content."""
+ shutil.rmtree(self.tempdir)
+
+ def test_ProcessFilenameAbsoluteFilePath(self):
+ """Absolute file path is returned unchanged."""
+ processed_fname = powerlog.process_filename(self.filepath)
+ self.assertEqual(self.filepath, processed_fname)
+
+ def test_ProcessFilenameRelativeFilePath(self):
+ """Finds relative file path inside a known config location."""
+ original = powerlog.CONFIG_LOCATIONS
+ powerlog.CONFIG_LOCATIONS = [self.tempdir]
+ processed_fname = powerlog.process_filename(self.filename)
+ try:
+ self.assertEqual(self.filepath, processed_fname)
+ finally:
+ powerlog.CONFIG_LOCATIONS = original
+
+ def test_ProcessFilenameInvalid(self):
+ """IOError is raised when file cannot be found by any of the four ways."""
+ with self.assertRaises(IOError):
+ powerlog.process_filename(self.filename)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/extra/usb_power/stats_manager.py b/extra/usb_power/stats_manager.py
index 0f8c3fcb15..633312311d 100644
--- a/extra/usb_power/stats_manager.py
+++ b/extra/usb_power/stats_manager.py
@@ -20,382 +20,391 @@ import os
import numpy
-STATS_PREFIX = '@@'
-NAN_TAG = '*'
-NAN_DESCRIPTION = '%s domains contain NaN samples' % NAN_TAG
+STATS_PREFIX = "@@"
+NAN_TAG = "*"
+NAN_DESCRIPTION = "%s domains contain NaN samples" % NAN_TAG
LONG_UNIT = {
- '': 'N/A',
- 'mW': 'milliwatt',
- 'uW': 'microwatt',
- 'mV': 'millivolt',
- 'uA': 'microamp',
- 'uV': 'microvolt'
+ "": "N/A",
+ "mW": "milliwatt",
+ "uW": "microwatt",
+ "mV": "millivolt",
+ "uA": "microamp",
+ "uV": "microvolt",
}
class StatsManagerError(Exception):
- """Errors in StatsManager class."""
- pass
+ """Errors in StatsManager class."""
+ pass
-class StatsManager(object):
- """Calculates statistics for several lists of data(float).
-
- Example usage:
-
- >>> stats = StatsManager(title='Title Banner')
- >>> stats.AddSample(TIME_KEY, 50.0)
- >>> stats.AddSample(TIME_KEY, 25.0)
- >>> stats.AddSample(TIME_KEY, 40.0)
- >>> stats.AddSample(TIME_KEY, 10.0)
- >>> stats.AddSample(TIME_KEY, 10.0)
- >>> stats.AddSample('frobnicate', 11.5)
- >>> stats.AddSample('frobnicate', 9.0)
- >>> stats.AddSample('foobar', 11111.0)
- >>> stats.AddSample('foobar', 22222.0)
- >>> stats.CalculateStats()
- >>> print(stats.SummaryToString())
- ` @@--------------------------------------------------------------
- ` @@ Title Banner
- @@--------------------------------------------------------------
- @@ NAME COUNT MEAN STDDEV MAX MIN
- @@ sample_msecs 4 31.25 15.16 50.00 10.00
- @@ foobar 2 16666.50 5555.50 22222.00 11111.00
- @@ frobnicate 2 10.25 1.25 11.50 9.00
- ` @@--------------------------------------------------------------
-
- Attributes:
- _data: dict of list of readings for each domain(key)
- _unit: dict of unit for each domain(key)
- _smid: id supplied to differentiate data output to other StatsManager
- instances that potentially save to the same directory
- if smid all output files will be named |smid|_|fname|
- _title: title to add as banner to formatted summary. If no title,
- no banner gets added
- _order: list of formatting order for domains. Domains not listed are
- displayed in sorted order
- _hide_domains: collection of domains to hide when formatting summary string
- _accept_nan: flag to indicate if NaN samples are acceptable
- _nan_domains: set to keep track of which domains contain NaN samples
- _summary: dict of stats per domain (key): min, max, count, mean, stddev
- _logger = StatsManager logger
-
- Note:
- _summary is empty until CalculateStats() is called, and is updated when
- CalculateStats() is called.
- """
-
- # pylint: disable=W0102
- def __init__(self, smid='', title='', order=[], hide_domains=[],
- accept_nan=True):
- """Initialize infrastructure for data and their statistics."""
- self._title = title
- self._data = collections.defaultdict(list)
- self._unit = collections.defaultdict(str)
- self._smid = smid
- self._order = order
- self._hide_domains = hide_domains
- self._accept_nan = accept_nan
- self._nan_domains = set()
- self._summary = {}
- self._logger = logging.getLogger(type(self).__name__)
-
- def AddSample(self, domain, sample):
- """Add one sample for a domain.
-
- Args:
- domain: the domain name for the sample.
- sample: one time sample for domain, expect type float.
-
- Raises:
- StatsManagerError: if trying to add NaN and |_accept_nan| is false
- """
- try:
- sample = float(sample)
- except ValueError:
- # if we don't accept nan this will be caught below
- self._logger.debug('sample %s for domain %s is not a number. Making NaN',
- sample, domain)
- sample = float('NaN')
- if not self._accept_nan and math.isnan(sample):
- raise StatsManagerError('accept_nan is false. Cannot add NaN sample.')
- self._data[domain].append(sample)
- if math.isnan(sample):
- self._nan_domains.add(domain)
-
- def SetUnit(self, domain, unit):
- """Set the unit for a domain.
-
- There can be only one unit for each domain. Setting unit twice will
- overwrite the original unit.
-
- Args:
- domain: the domain name.
- unit: unit of the domain.
- """
- if domain in self._unit:
- self._logger.warning('overwriting the unit of %s, old unit is %s, new '
- 'unit is %s.', domain, self._unit[domain], unit)
- self._unit[domain] = unit
- def CalculateStats(self):
- """Calculate stats for all domain-data pairs.
-
- First erases all previous stats, then calculate stats for all data.
- """
- self._summary = {}
- for domain, data in self._data.items():
- data_np = numpy.array(data)
- self._summary[domain] = {
- 'mean': numpy.nanmean(data_np),
- 'min': numpy.nanmin(data_np),
- 'max': numpy.nanmax(data_np),
- 'stddev': numpy.nanstd(data_np),
- 'count': data_np.size,
- }
-
- @property
- def DomainsToDisplay(self):
- """List of domains that the manager will output in summaries."""
- return set(self._summary.keys()) - set(self._hide_domains)
-
- @property
- def NanInOutput(self):
- """Return whether any of the domains to display have NaN values."""
- return bool(len(set(self._nan_domains) & self.DomainsToDisplay))
-
- def _SummaryTable(self):
- """Generate the matrix to output as a summary.
-
- Returns:
- A 2d matrix of headers and their data for each domain
- e.g.
- [[NAME, COUNT, MEAN, STDDEV, MAX, MIN],
- [pp5000_mw, 10, 50, 0, 50, 50]]
- """
- headers = ('NAME', 'COUNT', 'MEAN', 'STDDEV', 'MAX', 'MIN')
- table = [headers]
- # determine what domains to display & and the order
- domains_to_display = self.DomainsToDisplay
- display_order = [key for key in self._order if key in domains_to_display]
- domains_to_display -= set(display_order)
- display_order.extend(sorted(domains_to_display))
- for domain in display_order:
- stats = self._summary[domain]
- if not domain.endswith(self._unit[domain]):
- domain = '%s_%s' % (domain, self._unit[domain])
- if domain in self._nan_domains:
- domain = '%s%s' % (domain, NAN_TAG)
- row = [domain]
- row.append(str(stats['count']))
- for entry in headers[2:]:
- row.append('%.2f' % stats[entry.lower()])
- table.append(row)
- return table
-
- def SummaryToMarkdownString(self):
- """Format the summary into a b/ compatible markdown table string.
-
- This requires this sort of output format
-
- | header1 | header2 | header3 | ...
- | --------- | --------- | --------- | ...
- | sample1h1 | sample1h2 | sample1h3 | ...
- .
- .
- .
-
- Returns:
- formatted summary string.
- """
- # All we need to do before processing is insert a row of '-' between
- # the headers, and the data
- table = self._SummaryTable()
- columns = len(table[0])
- # Using '-:' to allow the numbers to be right aligned
- sep_row = ['-'] + ['-:'] * (columns - 1)
- table.insert(1, sep_row)
- text_rows = ['|'.join(r) for r in table]
- body = '\n'.join(['|%s|' % r for r in text_rows])
- if self._title:
- title_section = '**%s** \n\n' % self._title
- body = title_section + body
- # Make sure that the body is terminated with a newline.
- return body + '\n'
-
- def SummaryToString(self, prefix=STATS_PREFIX):
- """Format summary into a string, ready for pretty print.
-
- See class description for format example.
-
- Args:
- prefix: start every row in summary string with prefix, for easier reading.
-
- Returns:
- formatted summary string.
- """
- table = self._SummaryTable()
- max_col_width = []
- for col_idx in range(len(table[0])):
- col_item_widths = [len(row[col_idx]) for row in table]
- max_col_width.append(max(col_item_widths))
-
- formatted_lines = []
- for row in table:
- formatted_row = prefix + ' '
- for i in range(len(row)):
- formatted_row += row[i].rjust(max_col_width[i] + 2)
- formatted_lines.append(formatted_row)
- if self.NanInOutput:
- formatted_lines.append('%s %s' % (prefix, NAN_DESCRIPTION))
-
- if self._title:
- line_length = len(formatted_lines[0])
- dec_length = len(prefix)
- # trim title to be at most as long as the longest line without the prefix
- title = self._title[:(line_length - dec_length)]
- # line is a seperator line consisting of -----
- line = '%s%s' % (prefix, '-' * (line_length - dec_length))
- # prepend the prefix to the centered title
- padded_title = '%s%s' % (prefix, title.center(line_length)[dec_length:])
- formatted_lines = [line, padded_title, line] + formatted_lines + [line]
- formatted_output = '\n'.join(formatted_lines)
- return formatted_output
-
- def GetSummary(self):
- """Getter for summary."""
- return self._summary
-
- def _MakeUniqueFName(self, fname):
- """prepend |_smid| to fname & rotate fname to ensure uniqueness.
-
- Before saving a file through the StatsManager, make sure that the filename
- is unique, first by prepending the smid if any and otherwise by appending
- increasing integer suffixes until the filename is unique.
-
- If |smid| is defined /path/to/example/file.txt becomes
- /path/to/example/{smid}_file.txt.
-
- The rotation works by changing /path/to/example/somename.txt to
- /path/to/example/somename1.txt if the first one already exists on the
- system.
-
- Note: this is not thread-safe. While it makes sense to use StatsManager
- in a threaded data-collection, the data retrieval should happen in a
- single threaded environment to ensure files don't get potentially clobbered.
-
- Args:
- fname: filename to ensure uniqueness.
-
- Returns:
- {smid_}fname{tag}.[b].ext
- the smid portion gets prepended if |smid| is defined
- the tag portion gets appended if necessary to ensure unique fname
- """
- fdir = os.path.dirname(fname)
- base, ext = os.path.splitext(os.path.basename(fname))
- if self._smid:
- base = '%s_%s' % (self._smid, base)
- unique_fname = os.path.join(fdir, '%s%s' % (base, ext))
- tag = 0
- while os.path.exists(unique_fname):
- old_fname = unique_fname
- unique_fname = os.path.join(fdir, '%s%d%s' % (base, tag, ext))
- self._logger.warning('Attempted to store stats information at %s, but '
- 'file already exists. Attempting to store at %s '
- 'now.', old_fname, unique_fname)
- tag += 1
- return unique_fname
-
- def SaveSummary(self, directory, fname='summary.txt', prefix=STATS_PREFIX):
- """Save summary to file.
-
- Args:
- directory: directory to save the summary in.
- fname: filename to save summary under.
- prefix: start every row in summary string with prefix, for easier reading.
-
- Returns:
- full path of summary save location
- """
- summary_str = self.SummaryToString(prefix=prefix) + '\n'
- return self._SaveSummary(summary_str, directory, fname)
-
- def SaveSummaryJSON(self, directory, fname='summary.json'):
- """Save summary (only MEAN) into a JSON file.
-
- Args:
- directory: directory to save the JSON summary in.
- fname: filename to save summary under.
-
- Returns:
- full path of summary save location
- """
- data = {}
- for domain in self._summary:
- unit = LONG_UNIT.get(self._unit[domain], self._unit[domain])
- data_entry = {'mean': self._summary[domain]['mean'], 'unit': unit}
- data[domain] = data_entry
- summary_str = json.dumps(data, indent=2)
- return self._SaveSummary(summary_str, directory, fname)
-
- def SaveSummaryMD(self, directory, fname='summary.md'):
- """Save summary into a MD file to paste into b/.
-
- Args:
- directory: directory to save the MD summary in.
- fname: filename to save summary under.
-
- Returns:
- full path of summary save location
+class StatsManager(object):
+ """Calculates statistics for several lists of data(float).
+
+ Example usage:
+
+ >>> stats = StatsManager(title='Title Banner')
+ >>> stats.AddSample(TIME_KEY, 50.0)
+ >>> stats.AddSample(TIME_KEY, 25.0)
+ >>> stats.AddSample(TIME_KEY, 40.0)
+ >>> stats.AddSample(TIME_KEY, 10.0)
+ >>> stats.AddSample(TIME_KEY, 10.0)
+ >>> stats.AddSample('frobnicate', 11.5)
+ >>> stats.AddSample('frobnicate', 9.0)
+ >>> stats.AddSample('foobar', 11111.0)
+ >>> stats.AddSample('foobar', 22222.0)
+ >>> stats.CalculateStats()
+ >>> print(stats.SummaryToString())
+ ` @@--------------------------------------------------------------
+ ` @@ Title Banner
+ @@--------------------------------------------------------------
+ @@ NAME COUNT MEAN STDDEV MAX MIN
+ @@ sample_msecs 4 31.25 15.16 50.00 10.00
+ @@ foobar 2 16666.50 5555.50 22222.00 11111.00
+ @@ frobnicate 2 10.25 1.25 11.50 9.00
+ ` @@--------------------------------------------------------------
+
+ Attributes:
+ _data: dict of list of readings for each domain(key)
+ _unit: dict of unit for each domain(key)
+ _smid: id supplied to differentiate data output to other StatsManager
+ instances that potentially save to the same directory
+ if smid all output files will be named |smid|_|fname|
+ _title: title to add as banner to formatted summary. If no title,
+ no banner gets added
+ _order: list of formatting order for domains. Domains not listed are
+ displayed in sorted order
+ _hide_domains: collection of domains to hide when formatting summary string
+ _accept_nan: flag to indicate if NaN samples are acceptable
+ _nan_domains: set to keep track of which domains contain NaN samples
+ _summary: dict of stats per domain (key): min, max, count, mean, stddev
+ _logger = StatsManager logger
+
+ Note:
+ _summary is empty until CalculateStats() is called, and is updated when
+ CalculateStats() is called.
"""
- summary_str = self.SummaryToMarkdownString()
- return self._SaveSummary(summary_str, directory, fname)
- def _SaveSummary(self, output_str, directory, fname):
- """Wrote |output_str| to |fname|.
-
- Args:
- output_str: formatted output string
- directory: directory to save the summary in.
- fname: filename to save summary under.
-
- Returns:
- full path of summary save location
- """
- if not os.path.exists(directory):
- os.makedirs(directory)
- fname = self._MakeUniqueFName(os.path.join(directory, fname))
- with open(fname, 'w') as f:
- f.write(output_str)
- return fname
-
- def GetRawData(self):
- """Getter for all raw_data."""
- return self._data
-
- def SaveRawData(self, directory, dirname='raw_data'):
- """Save raw data to file.
-
- Args:
- directory: directory to create the raw data folder in.
- dirname: folder in which raw data live.
-
- Returns:
- list of full path of each domain's raw data save location
- """
- if not os.path.exists(directory):
- os.makedirs(directory)
- dirname = os.path.join(directory, dirname)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
- fnames = []
- for domain, data in self._data.items():
- if not domain.endswith(self._unit[domain]):
- domain = '%s_%s' % (domain, self._unit[domain])
- fname = self._MakeUniqueFName(os.path.join(dirname, '%s.txt' % domain))
- with open(fname, 'w') as f:
- f.write('\n'.join('%.2f' % sample for sample in data) + '\n')
- fnames.append(fname)
- return fnames
+ # pylint: disable=W0102
+ def __init__(self, smid="", title="", order=[], hide_domains=[], accept_nan=True):
+ """Initialize infrastructure for data and their statistics."""
+ self._title = title
+ self._data = collections.defaultdict(list)
+ self._unit = collections.defaultdict(str)
+ self._smid = smid
+ self._order = order
+ self._hide_domains = hide_domains
+ self._accept_nan = accept_nan
+ self._nan_domains = set()
+ self._summary = {}
+ self._logger = logging.getLogger(type(self).__name__)
+
+ def AddSample(self, domain, sample):
+ """Add one sample for a domain.
+
+ Args:
+ domain: the domain name for the sample.
+ sample: one time sample for domain, expect type float.
+
+ Raises:
+ StatsManagerError: if trying to add NaN and |_accept_nan| is false
+ """
+ try:
+ sample = float(sample)
+ except ValueError:
+ # if we don't accept nan this will be caught below
+ self._logger.debug(
+ "sample %s for domain %s is not a number. Making NaN", sample, domain
+ )
+ sample = float("NaN")
+ if not self._accept_nan and math.isnan(sample):
+ raise StatsManagerError("accept_nan is false. Cannot add NaN sample.")
+ self._data[domain].append(sample)
+ if math.isnan(sample):
+ self._nan_domains.add(domain)
+
+ def SetUnit(self, domain, unit):
+ """Set the unit for a domain.
+
+ There can be only one unit for each domain. Setting unit twice will
+ overwrite the original unit.
+
+ Args:
+ domain: the domain name.
+ unit: unit of the domain.
+ """
+ if domain in self._unit:
+ self._logger.warning(
+ "overwriting the unit of %s, old unit is %s, new " "unit is %s.",
+ domain,
+ self._unit[domain],
+ unit,
+ )
+ self._unit[domain] = unit
+
+ def CalculateStats(self):
+ """Calculate stats for all domain-data pairs.
+
+ First erases all previous stats, then calculate stats for all data.
+ """
+ self._summary = {}
+ for domain, data in self._data.items():
+ data_np = numpy.array(data)
+ self._summary[domain] = {
+ "mean": numpy.nanmean(data_np),
+ "min": numpy.nanmin(data_np),
+ "max": numpy.nanmax(data_np),
+ "stddev": numpy.nanstd(data_np),
+ "count": data_np.size,
+ }
+
+ @property
+ def DomainsToDisplay(self):
+ """List of domains that the manager will output in summaries."""
+ return set(self._summary.keys()) - set(self._hide_domains)
+
+ @property
+ def NanInOutput(self):
+ """Return whether any of the domains to display have NaN values."""
+ return bool(len(set(self._nan_domains) & self.DomainsToDisplay))
+
+ def _SummaryTable(self):
+ """Generate the matrix to output as a summary.
+
+ Returns:
+ A 2d matrix of headers and their data for each domain
+ e.g.
+ [[NAME, COUNT, MEAN, STDDEV, MAX, MIN],
+ [pp5000_mw, 10, 50, 0, 50, 50]]
+ """
+ headers = ("NAME", "COUNT", "MEAN", "STDDEV", "MAX", "MIN")
+ table = [headers]
+ # determine what domains to display & and the order
+ domains_to_display = self.DomainsToDisplay
+ display_order = [key for key in self._order if key in domains_to_display]
+ domains_to_display -= set(display_order)
+ display_order.extend(sorted(domains_to_display))
+ for domain in display_order:
+ stats = self._summary[domain]
+ if not domain.endswith(self._unit[domain]):
+ domain = "%s_%s" % (domain, self._unit[domain])
+ if domain in self._nan_domains:
+ domain = "%s%s" % (domain, NAN_TAG)
+ row = [domain]
+ row.append(str(stats["count"]))
+ for entry in headers[2:]:
+ row.append("%.2f" % stats[entry.lower()])
+ table.append(row)
+ return table
+
+ def SummaryToMarkdownString(self):
+ """Format the summary into a b/ compatible markdown table string.
+
+ This requires this sort of output format
+
+ | header1 | header2 | header3 | ...
+ | --------- | --------- | --------- | ...
+ | sample1h1 | sample1h2 | sample1h3 | ...
+ .
+ .
+ .
+
+ Returns:
+ formatted summary string.
+ """
+ # All we need to do before processing is insert a row of '-' between
+ # the headers, and the data
+ table = self._SummaryTable()
+ columns = len(table[0])
+ # Using '-:' to allow the numbers to be right aligned
+ sep_row = ["-"] + ["-:"] * (columns - 1)
+ table.insert(1, sep_row)
+ text_rows = ["|".join(r) for r in table]
+ body = "\n".join(["|%s|" % r for r in text_rows])
+ if self._title:
+ title_section = "**%s** \n\n" % self._title
+ body = title_section + body
+ # Make sure that the body is terminated with a newline.
+ return body + "\n"
+
+ def SummaryToString(self, prefix=STATS_PREFIX):
+ """Format summary into a string, ready for pretty print.
+
+ See class description for format example.
+
+ Args:
+ prefix: start every row in summary string with prefix, for easier reading.
+
+ Returns:
+ formatted summary string.
+ """
+ table = self._SummaryTable()
+ max_col_width = []
+ for col_idx in range(len(table[0])):
+ col_item_widths = [len(row[col_idx]) for row in table]
+ max_col_width.append(max(col_item_widths))
+
+ formatted_lines = []
+ for row in table:
+ formatted_row = prefix + " "
+ for i in range(len(row)):
+ formatted_row += row[i].rjust(max_col_width[i] + 2)
+ formatted_lines.append(formatted_row)
+ if self.NanInOutput:
+ formatted_lines.append("%s %s" % (prefix, NAN_DESCRIPTION))
+
+ if self._title:
+ line_length = len(formatted_lines[0])
+ dec_length = len(prefix)
+ # trim title to be at most as long as the longest line without the prefix
+ title = self._title[: (line_length - dec_length)]
+ # line is a seperator line consisting of -----
+ line = "%s%s" % (prefix, "-" * (line_length - dec_length))
+ # prepend the prefix to the centered title
+ padded_title = "%s%s" % (prefix, title.center(line_length)[dec_length:])
+ formatted_lines = [line, padded_title, line] + formatted_lines + [line]
+ formatted_output = "\n".join(formatted_lines)
+ return formatted_output
+
+ def GetSummary(self):
+ """Getter for summary."""
+ return self._summary
+
+ def _MakeUniqueFName(self, fname):
+ """prepend |_smid| to fname & rotate fname to ensure uniqueness.
+
+ Before saving a file through the StatsManager, make sure that the filename
+ is unique, first by prepending the smid if any and otherwise by appending
+ increasing integer suffixes until the filename is unique.
+
+ If |smid| is defined /path/to/example/file.txt becomes
+ /path/to/example/{smid}_file.txt.
+
+ The rotation works by changing /path/to/example/somename.txt to
+ /path/to/example/somename1.txt if the first one already exists on the
+ system.
+
+ Note: this is not thread-safe. While it makes sense to use StatsManager
+ in a threaded data-collection, the data retrieval should happen in a
+ single threaded environment to ensure files don't get potentially clobbered.
+
+ Args:
+ fname: filename to ensure uniqueness.
+
+ Returns:
+ {smid_}fname{tag}.[b].ext
+ the smid portion gets prepended if |smid| is defined
+ the tag portion gets appended if necessary to ensure unique fname
+ """
+ fdir = os.path.dirname(fname)
+ base, ext = os.path.splitext(os.path.basename(fname))
+ if self._smid:
+ base = "%s_%s" % (self._smid, base)
+ unique_fname = os.path.join(fdir, "%s%s" % (base, ext))
+ tag = 0
+ while os.path.exists(unique_fname):
+ old_fname = unique_fname
+ unique_fname = os.path.join(fdir, "%s%d%s" % (base, tag, ext))
+ self._logger.warning(
+ "Attempted to store stats information at %s, but "
+ "file already exists. Attempting to store at %s "
+ "now.",
+ old_fname,
+ unique_fname,
+ )
+ tag += 1
+ return unique_fname
+
+ def SaveSummary(self, directory, fname="summary.txt", prefix=STATS_PREFIX):
+ """Save summary to file.
+
+ Args:
+ directory: directory to save the summary in.
+ fname: filename to save summary under.
+ prefix: start every row in summary string with prefix, for easier reading.
+
+ Returns:
+ full path of summary save location
+ """
+ summary_str = self.SummaryToString(prefix=prefix) + "\n"
+ return self._SaveSummary(summary_str, directory, fname)
+
+ def SaveSummaryJSON(self, directory, fname="summary.json"):
+ """Save summary (only MEAN) into a JSON file.
+
+ Args:
+ directory: directory to save the JSON summary in.
+ fname: filename to save summary under.
+
+ Returns:
+ full path of summary save location
+ """
+ data = {}
+ for domain in self._summary:
+ unit = LONG_UNIT.get(self._unit[domain], self._unit[domain])
+ data_entry = {"mean": self._summary[domain]["mean"], "unit": unit}
+ data[domain] = data_entry
+ summary_str = json.dumps(data, indent=2)
+ return self._SaveSummary(summary_str, directory, fname)
+
+ def SaveSummaryMD(self, directory, fname="summary.md"):
+ """Save summary into a MD file to paste into b/.
+
+ Args:
+ directory: directory to save the MD summary in.
+ fname: filename to save summary under.
+
+ Returns:
+ full path of summary save location
+ """
+ summary_str = self.SummaryToMarkdownString()
+ return self._SaveSummary(summary_str, directory, fname)
+
+ def _SaveSummary(self, output_str, directory, fname):
+ """Wrote |output_str| to |fname|.
+
+ Args:
+ output_str: formatted output string
+ directory: directory to save the summary in.
+ fname: filename to save summary under.
+
+ Returns:
+ full path of summary save location
+ """
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+ fname = self._MakeUniqueFName(os.path.join(directory, fname))
+ with open(fname, "w") as f:
+ f.write(output_str)
+ return fname
+
+ def GetRawData(self):
+ """Getter for all raw_data."""
+ return self._data
+
+ def SaveRawData(self, directory, dirname="raw_data"):
+ """Save raw data to file.
+
+ Args:
+ directory: directory to create the raw data folder in.
+ dirname: folder in which raw data live.
+
+ Returns:
+ list of full path of each domain's raw data save location
+ """
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+ dirname = os.path.join(directory, dirname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+ fnames = []
+ for domain, data in self._data.items():
+ if not domain.endswith(self._unit[domain]):
+ domain = "%s_%s" % (domain, self._unit[domain])
+ fname = self._MakeUniqueFName(os.path.join(dirname, "%s.txt" % domain))
+ with open(fname, "w") as f:
+ f.write("\n".join("%.2f" % sample for sample in data) + "\n")
+ fnames.append(fname)
+ return fnames
diff --git a/extra/usb_power/stats_manager_unittest.py b/extra/usb_power/stats_manager_unittest.py
index beb9984b93..37a472af98 100644
--- a/extra/usb_power/stats_manager_unittest.py
+++ b/extra/usb_power/stats_manager_unittest.py
@@ -9,6 +9,7 @@
"""Unit tests for StatsManager."""
from __future__ import print_function
+
import json
import os
import re
@@ -20,296 +21,304 @@ import stats_manager
class TestStatsManager(unittest.TestCase):
- """Test to verify StatsManager methods work as expected.
-
- StatsManager should collect raw data, calculate their statistics, and save
- them in expected format.
- """
-
- def _populate_mock_stats(self):
- """Create a populated & processed StatsManager to test data retrieval."""
- self.data.AddSample('A', 99999.5)
- self.data.AddSample('A', 100000.5)
- self.data.SetUnit('A', 'uW')
- self.data.SetUnit('A', 'mW')
- self.data.AddSample('B', 1.5)
- self.data.AddSample('B', 2.5)
- self.data.AddSample('B', 3.5)
- self.data.SetUnit('B', 'mV')
- self.data.CalculateStats()
-
- def _populate_mock_stats_no_unit(self):
- self.data.AddSample('B', 1000)
- self.data.AddSample('A', 200)
- self.data.SetUnit('A', 'blue')
-
- def setUp(self):
- """Set up StatsManager and create a temporary directory for test."""
- self.tempdir = tempfile.mkdtemp()
- self.data = stats_manager.StatsManager()
-
- def tearDown(self):
- """Delete the temporary directory and its content."""
- shutil.rmtree(self.tempdir)
-
- def test_AddSample(self):
- """Adding a sample successfully adds a sample."""
- self.data.AddSample('Test', 1000)
- self.data.SetUnit('Test', 'test')
- self.data.CalculateStats()
- summary = self.data.GetSummary()
- self.assertEqual(1, summary['Test']['count'])
-
- def test_AddSampleNoFloatAcceptNaN(self):
- """Adding a non-number adds 'NaN' and doesn't raise an exception."""
- self.data.AddSample('Test', 10)
- self.data.AddSample('Test', 20)
- # adding a fake NaN: one that gets converted into NaN internally
- self.data.AddSample('Test', 'fiesta')
- # adding a real NaN
- self.data.AddSample('Test', float('NaN'))
- self.data.SetUnit('Test', 'test')
- self.data.CalculateStats()
- summary = self.data.GetSummary()
- # assert that 'NaN' as added.
- self.assertEqual(4, summary['Test']['count'])
- # assert that mean, min, and max calculatings ignore the 'NaN'
- self.assertEqual(10, summary['Test']['min'])
- self.assertEqual(20, summary['Test']['max'])
- self.assertEqual(15, summary['Test']['mean'])
-
- def test_AddSampleNoFloatNotAcceptNaN(self):
- """Adding a non-number raises a StatsManagerError if accept_nan is False."""
- self.data = stats_manager.StatsManager(accept_nan=False)
- with self.assertRaisesRegexp(stats_manager.StatsManagerError,
- 'accept_nan is false. Cannot add NaN sample.'):
- # adding a fake NaN: one that gets converted into NaN internally
- self.data.AddSample('Test', 'fiesta')
- with self.assertRaisesRegexp(stats_manager.StatsManagerError,
- 'accept_nan is false. Cannot add NaN sample.'):
- # adding a real NaN
- self.data.AddSample('Test', float('NaN'))
-
- def test_AddSampleNoUnit(self):
- """Not adding a unit does not cause an exception on CalculateStats()."""
- self.data.AddSample('Test', 17)
- self.data.CalculateStats()
- summary = self.data.GetSummary()
- self.assertEqual(1, summary['Test']['count'])
-
- def test_UnitSuffix(self):
- """Unit gets appended as a suffix in the displayed summary."""
- self.data.AddSample('test', 250)
- self.data.SetUnit('test', 'mw')
- self.data.CalculateStats()
- summary_str = self.data.SummaryToString()
- self.assertIn('test_mw', summary_str)
-
- def test_DoubleUnitSuffix(self):
- """If domain already ends in unit, verify that unit doesn't get appended."""
- self.data.AddSample('test_mw', 250)
- self.data.SetUnit('test_mw', 'mw')
- self.data.CalculateStats()
- summary_str = self.data.SummaryToString()
- self.assertIn('test_mw', summary_str)
- self.assertNotIn('test_mw_mw', summary_str)
-
- def test_GetRawData(self):
- """GetRawData returns exact same data as fed in."""
- self._populate_mock_stats()
- raw_data = self.data.GetRawData()
- self.assertListEqual([99999.5, 100000.5], raw_data['A'])
- self.assertListEqual([1.5, 2.5, 3.5], raw_data['B'])
-
- def test_GetSummary(self):
- """GetSummary returns expected stats about the data fed in."""
- self._populate_mock_stats()
- summary = self.data.GetSummary()
- self.assertEqual(2, summary['A']['count'])
- self.assertAlmostEqual(100000.5, summary['A']['max'])
- self.assertAlmostEqual(99999.5, summary['A']['min'])
- self.assertAlmostEqual(0.5, summary['A']['stddev'])
- self.assertAlmostEqual(100000.0, summary['A']['mean'])
- self.assertEqual(3, summary['B']['count'])
- self.assertAlmostEqual(3.5, summary['B']['max'])
- self.assertAlmostEqual(1.5, summary['B']['min'])
- self.assertAlmostEqual(0.81649658092773, summary['B']['stddev'])
- self.assertAlmostEqual(2.5, summary['B']['mean'])
-
- def test_SaveRawData(self):
- """SaveRawData stores same data as fed in."""
- self._populate_mock_stats()
- dirname = 'unittest_raw_data'
- expected_files = set(['A_mW.txt', 'B_mV.txt'])
- fnames = self.data.SaveRawData(self.tempdir, dirname)
- files_returned = set([os.path.basename(f) for f in fnames])
- # Assert that only the expected files got returned.
- self.assertEqual(expected_files, files_returned)
- # Assert that only the returned files are in the outdir.
- self.assertEqual(set(os.listdir(os.path.join(self.tempdir, dirname))),
- files_returned)
- for fname in fnames:
- with open(fname, 'r') as f:
- if 'A_mW' in fname:
- self.assertEqual('99999.50', f.readline().strip())
- self.assertEqual('100000.50', f.readline().strip())
- if 'B_mV' in fname:
- self.assertEqual('1.50', f.readline().strip())
- self.assertEqual('2.50', f.readline().strip())
- self.assertEqual('3.50', f.readline().strip())
-
- def test_SaveRawDataNoUnit(self):
- """SaveRawData appends no unit suffix if the unit is not specified."""
- self._populate_mock_stats_no_unit()
- self.data.CalculateStats()
- outdir = 'unittest_raw_data'
- files = self.data.SaveRawData(self.tempdir, outdir)
- files = [os.path.basename(f) for f in files]
- # Verify nothing gets appended to domain for filename if no unit exists.
- self.assertIn('B.txt', files)
-
- def test_SaveRawDataSMID(self):
- """SaveRawData uses the smid when creating output filename."""
- identifier = 'ec'
- self.data = stats_manager.StatsManager(smid=identifier)
- self._populate_mock_stats()
- files = self.data.SaveRawData(self.tempdir)
- for fname in files:
- self.assertTrue(os.path.basename(fname).startswith(identifier))
-
- def test_SummaryToStringNaNHelp(self):
- """NaN containing row gets tagged with *, help banner gets added."""
- help_banner_exp = '%s %s' % (stats_manager.STATS_PREFIX,
- stats_manager.NAN_DESCRIPTION)
- nan_domain = 'A-domain'
- nan_domain_exp = '%s%s' % (nan_domain, stats_manager.NAN_TAG)
- # NaN helper banner is added when a NaN domain is found & domain gets tagged
- data = stats_manager.StatsManager()
- data.AddSample(nan_domain, float('NaN'))
- data.AddSample(nan_domain, 17)
- data.AddSample('B-domain', 17)
- data.CalculateStats()
- summarystr = data.SummaryToString()
- self.assertIn(help_banner_exp, summarystr)
- self.assertIn(nan_domain_exp, summarystr)
- # NaN helper banner is not added when no NaN domain output, no tagging
- data = stats_manager.StatsManager()
- # nan_domain in this scenario does not contain any NaN
- data.AddSample(nan_domain, 19)
- data.AddSample('B-domain', 17)
- data.CalculateStats()
- summarystr = data.SummaryToString()
- self.assertNotIn(help_banner_exp, summarystr)
- self.assertNotIn(nan_domain_exp, summarystr)
-
- def test_SummaryToStringTitle(self):
- """Title shows up in SummaryToString if title specified."""
- title = 'titulo'
- data = stats_manager.StatsManager(title=title)
- self._populate_mock_stats()
- summary_str = data.SummaryToString()
- self.assertIn(title, summary_str)
-
- def test_SummaryToStringHideDomains(self):
- """Keys indicated in hide_domains are not printed in the summary."""
- data = stats_manager.StatsManager(hide_domains=['A-domain'])
- data.AddSample('A-domain', 17)
- data.AddSample('B-domain', 17)
- data.CalculateStats()
- summary_str = data.SummaryToString()
- self.assertIn('B-domain', summary_str)
- self.assertNotIn('A-domain', summary_str)
-
- def test_SummaryToStringOrder(self):
- """Order passed into StatsManager is honoured when formatting summary."""
- # StatsManager that should print D & B first, and the subsequent elements
- # are sorted.
- d_b_a_c_regexp = re.compile('D-domain.*B-domain.*A-domain.*C-domain',
- re.DOTALL)
- data = stats_manager.StatsManager(order=['D-domain', 'B-domain'])
- data.AddSample('A-domain', 17)
- data.AddSample('B-domain', 17)
- data.AddSample('C-domain', 17)
- data.AddSample('D-domain', 17)
- data.CalculateStats()
- summary_str = data.SummaryToString()
- self.assertRegexpMatches(summary_str, d_b_a_c_regexp)
-
- def test_MakeUniqueFName(self):
- data = stats_manager.StatsManager()
- testfile = os.path.join(self.tempdir, 'testfile.txt')
- with open(testfile, 'w') as f:
- f.write('')
- expected_fname = os.path.join(self.tempdir, 'testfile0.txt')
- self.assertEqual(expected_fname, data._MakeUniqueFName(testfile))
-
- def test_SaveSummary(self):
- """SaveSummary properly dumps the summary into a file."""
- self._populate_mock_stats()
- fname = 'unittest_summary.txt'
- expected_fname = os.path.join(self.tempdir, fname)
- fname = self.data.SaveSummary(self.tempdir, fname)
- # Assert the reported fname is the same as the expected fname
- self.assertEqual(expected_fname, fname)
- # Assert only the reported fname is output (in the tempdir)
- self.assertEqual(set([os.path.basename(fname)]),
- set(os.listdir(self.tempdir)))
- with open(fname, 'r') as f:
- self.assertEqual(
- '@@ NAME COUNT MEAN STDDEV MAX MIN\n',
- f.readline())
- self.assertEqual(
- '@@ A_mW 2 100000.00 0.50 100000.50 99999.50\n',
- f.readline())
- self.assertEqual(
- '@@ B_mV 3 2.50 0.82 3.50 1.50\n',
- f.readline())
-
- def test_SaveSummarySMID(self):
- """SaveSummary uses the smid when creating output filename."""
- identifier = 'ec'
- self.data = stats_manager.StatsManager(smid=identifier)
- self._populate_mock_stats()
- fname = os.path.basename(self.data.SaveSummary(self.tempdir))
- self.assertTrue(fname.startswith(identifier))
-
- def test_SaveSummaryJSON(self):
- """SaveSummaryJSON saves the added data properly in JSON format."""
- self._populate_mock_stats()
- fname = 'unittest_summary.json'
- expected_fname = os.path.join(self.tempdir, fname)
- fname = self.data.SaveSummaryJSON(self.tempdir, fname)
- # Assert the reported fname is the same as the expected fname
- self.assertEqual(expected_fname, fname)
- # Assert only the reported fname is output (in the tempdir)
- self.assertEqual(set([os.path.basename(fname)]),
- set(os.listdir(self.tempdir)))
- with open(fname, 'r') as f:
- summary = json.load(f)
- self.assertAlmostEqual(100000.0, summary['A']['mean'])
- self.assertEqual('milliwatt', summary['A']['unit'])
- self.assertAlmostEqual(2.5, summary['B']['mean'])
- self.assertEqual('millivolt', summary['B']['unit'])
-
- def test_SaveSummaryJSONSMID(self):
- """SaveSummaryJSON uses the smid when creating output filename."""
- identifier = 'ec'
- self.data = stats_manager.StatsManager(smid=identifier)
- self._populate_mock_stats()
- fname = os.path.basename(self.data.SaveSummaryJSON(self.tempdir))
- self.assertTrue(fname.startswith(identifier))
-
- def test_SaveSummaryJSONNoUnit(self):
- """SaveSummaryJSON marks unknown units properly as N/A."""
- self._populate_mock_stats_no_unit()
- self.data.CalculateStats()
- fname = 'unittest_summary.json'
- fname = self.data.SaveSummaryJSON(self.tempdir, fname)
- with open(fname, 'r') as f:
- summary = json.load(f)
- self.assertEqual('blue', summary['A']['unit'])
- # if no unit is specified, JSON should save 'N/A' as the unit.
- self.assertEqual('N/A', summary['B']['unit'])
-
-if __name__ == '__main__':
- unittest.main()
+ """Test to verify StatsManager methods work as expected.
+
+ StatsManager should collect raw data, calculate their statistics, and save
+ them in expected format.
+ """
+
+ def _populate_mock_stats(self):
+ """Create a populated & processed StatsManager to test data retrieval."""
+ self.data.AddSample("A", 99999.5)
+ self.data.AddSample("A", 100000.5)
+ self.data.SetUnit("A", "uW")
+ self.data.SetUnit("A", "mW")
+ self.data.AddSample("B", 1.5)
+ self.data.AddSample("B", 2.5)
+ self.data.AddSample("B", 3.5)
+ self.data.SetUnit("B", "mV")
+ self.data.CalculateStats()
+
+ def _populate_mock_stats_no_unit(self):
+ self.data.AddSample("B", 1000)
+ self.data.AddSample("A", 200)
+ self.data.SetUnit("A", "blue")
+
+ def setUp(self):
+ """Set up StatsManager and create a temporary directory for test."""
+ self.tempdir = tempfile.mkdtemp()
+ self.data = stats_manager.StatsManager()
+
+ def tearDown(self):
+ """Delete the temporary directory and its content."""
+ shutil.rmtree(self.tempdir)
+
+ def test_AddSample(self):
+ """Adding a sample successfully adds a sample."""
+ self.data.AddSample("Test", 1000)
+ self.data.SetUnit("Test", "test")
+ self.data.CalculateStats()
+ summary = self.data.GetSummary()
+ self.assertEqual(1, summary["Test"]["count"])
+
+ def test_AddSampleNoFloatAcceptNaN(self):
+ """Adding a non-number adds 'NaN' and doesn't raise an exception."""
+ self.data.AddSample("Test", 10)
+ self.data.AddSample("Test", 20)
+ # adding a fake NaN: one that gets converted into NaN internally
+ self.data.AddSample("Test", "fiesta")
+ # adding a real NaN
+ self.data.AddSample("Test", float("NaN"))
+ self.data.SetUnit("Test", "test")
+ self.data.CalculateStats()
+ summary = self.data.GetSummary()
+ # assert that 'NaN' as added.
+ self.assertEqual(4, summary["Test"]["count"])
+ # assert that mean, min, and max calculatings ignore the 'NaN'
+ self.assertEqual(10, summary["Test"]["min"])
+ self.assertEqual(20, summary["Test"]["max"])
+ self.assertEqual(15, summary["Test"]["mean"])
+
+ def test_AddSampleNoFloatNotAcceptNaN(self):
+ """Adding a non-number raises a StatsManagerError if accept_nan is False."""
+ self.data = stats_manager.StatsManager(accept_nan=False)
+ with self.assertRaisesRegexp(
+ stats_manager.StatsManagerError,
+ "accept_nan is false. Cannot add NaN sample.",
+ ):
+ # adding a fake NaN: one that gets converted into NaN internally
+ self.data.AddSample("Test", "fiesta")
+ with self.assertRaisesRegexp(
+ stats_manager.StatsManagerError,
+ "accept_nan is false. Cannot add NaN sample.",
+ ):
+ # adding a real NaN
+ self.data.AddSample("Test", float("NaN"))
+
+ def test_AddSampleNoUnit(self):
+ """Not adding a unit does not cause an exception on CalculateStats()."""
+ self.data.AddSample("Test", 17)
+ self.data.CalculateStats()
+ summary = self.data.GetSummary()
+ self.assertEqual(1, summary["Test"]["count"])
+
+ def test_UnitSuffix(self):
+ """Unit gets appended as a suffix in the displayed summary."""
+ self.data.AddSample("test", 250)
+ self.data.SetUnit("test", "mw")
+ self.data.CalculateStats()
+ summary_str = self.data.SummaryToString()
+ self.assertIn("test_mw", summary_str)
+
+ def test_DoubleUnitSuffix(self):
+ """If domain already ends in unit, verify that unit doesn't get appended."""
+ self.data.AddSample("test_mw", 250)
+ self.data.SetUnit("test_mw", "mw")
+ self.data.CalculateStats()
+ summary_str = self.data.SummaryToString()
+ self.assertIn("test_mw", summary_str)
+ self.assertNotIn("test_mw_mw", summary_str)
+
+ def test_GetRawData(self):
+ """GetRawData returns exact same data as fed in."""
+ self._populate_mock_stats()
+ raw_data = self.data.GetRawData()
+ self.assertListEqual([99999.5, 100000.5], raw_data["A"])
+ self.assertListEqual([1.5, 2.5, 3.5], raw_data["B"])
+
+ def test_GetSummary(self):
+ """GetSummary returns expected stats about the data fed in."""
+ self._populate_mock_stats()
+ summary = self.data.GetSummary()
+ self.assertEqual(2, summary["A"]["count"])
+ self.assertAlmostEqual(100000.5, summary["A"]["max"])
+ self.assertAlmostEqual(99999.5, summary["A"]["min"])
+ self.assertAlmostEqual(0.5, summary["A"]["stddev"])
+ self.assertAlmostEqual(100000.0, summary["A"]["mean"])
+ self.assertEqual(3, summary["B"]["count"])
+ self.assertAlmostEqual(3.5, summary["B"]["max"])
+ self.assertAlmostEqual(1.5, summary["B"]["min"])
+ self.assertAlmostEqual(0.81649658092773, summary["B"]["stddev"])
+ self.assertAlmostEqual(2.5, summary["B"]["mean"])
+
+ def test_SaveRawData(self):
+ """SaveRawData stores same data as fed in."""
+ self._populate_mock_stats()
+ dirname = "unittest_raw_data"
+ expected_files = set(["A_mW.txt", "B_mV.txt"])
+ fnames = self.data.SaveRawData(self.tempdir, dirname)
+ files_returned = set([os.path.basename(f) for f in fnames])
+ # Assert that only the expected files got returned.
+ self.assertEqual(expected_files, files_returned)
+ # Assert that only the returned files are in the outdir.
+ self.assertEqual(
+ set(os.listdir(os.path.join(self.tempdir, dirname))), files_returned
+ )
+ for fname in fnames:
+ with open(fname, "r") as f:
+ if "A_mW" in fname:
+ self.assertEqual("99999.50", f.readline().strip())
+ self.assertEqual("100000.50", f.readline().strip())
+ if "B_mV" in fname:
+ self.assertEqual("1.50", f.readline().strip())
+ self.assertEqual("2.50", f.readline().strip())
+ self.assertEqual("3.50", f.readline().strip())
+
+ def test_SaveRawDataNoUnit(self):
+ """SaveRawData appends no unit suffix if the unit is not specified."""
+ self._populate_mock_stats_no_unit()
+ self.data.CalculateStats()
+ outdir = "unittest_raw_data"
+ files = self.data.SaveRawData(self.tempdir, outdir)
+ files = [os.path.basename(f) for f in files]
+ # Verify nothing gets appended to domain for filename if no unit exists.
+ self.assertIn("B.txt", files)
+
+ def test_SaveRawDataSMID(self):
+ """SaveRawData uses the smid when creating output filename."""
+ identifier = "ec"
+ self.data = stats_manager.StatsManager(smid=identifier)
+ self._populate_mock_stats()
+ files = self.data.SaveRawData(self.tempdir)
+ for fname in files:
+ self.assertTrue(os.path.basename(fname).startswith(identifier))
+
+ def test_SummaryToStringNaNHelp(self):
+ """NaN containing row gets tagged with *, help banner gets added."""
+ help_banner_exp = "%s %s" % (
+ stats_manager.STATS_PREFIX,
+ stats_manager.NAN_DESCRIPTION,
+ )
+ nan_domain = "A-domain"
+ nan_domain_exp = "%s%s" % (nan_domain, stats_manager.NAN_TAG)
+ # NaN helper banner is added when a NaN domain is found & domain gets tagged
+ data = stats_manager.StatsManager()
+ data.AddSample(nan_domain, float("NaN"))
+ data.AddSample(nan_domain, 17)
+ data.AddSample("B-domain", 17)
+ data.CalculateStats()
+ summarystr = data.SummaryToString()
+ self.assertIn(help_banner_exp, summarystr)
+ self.assertIn(nan_domain_exp, summarystr)
+ # NaN helper banner is not added when no NaN domain output, no tagging
+ data = stats_manager.StatsManager()
+ # nan_domain in this scenario does not contain any NaN
+ data.AddSample(nan_domain, 19)
+ data.AddSample("B-domain", 17)
+ data.CalculateStats()
+ summarystr = data.SummaryToString()
+ self.assertNotIn(help_banner_exp, summarystr)
+ self.assertNotIn(nan_domain_exp, summarystr)
+
+ def test_SummaryToStringTitle(self):
+ """Title shows up in SummaryToString if title specified."""
+ title = "titulo"
+ data = stats_manager.StatsManager(title=title)
+ self._populate_mock_stats()
+ summary_str = data.SummaryToString()
+ self.assertIn(title, summary_str)
+
+ def test_SummaryToStringHideDomains(self):
+ """Keys indicated in hide_domains are not printed in the summary."""
+ data = stats_manager.StatsManager(hide_domains=["A-domain"])
+ data.AddSample("A-domain", 17)
+ data.AddSample("B-domain", 17)
+ data.CalculateStats()
+ summary_str = data.SummaryToString()
+ self.assertIn("B-domain", summary_str)
+ self.assertNotIn("A-domain", summary_str)
+
+ def test_SummaryToStringOrder(self):
+ """Order passed into StatsManager is honoured when formatting summary."""
+ # StatsManager that should print D & B first, and the subsequent elements
+ # are sorted.
+ d_b_a_c_regexp = re.compile("D-domain.*B-domain.*A-domain.*C-domain", re.DOTALL)
+ data = stats_manager.StatsManager(order=["D-domain", "B-domain"])
+ data.AddSample("A-domain", 17)
+ data.AddSample("B-domain", 17)
+ data.AddSample("C-domain", 17)
+ data.AddSample("D-domain", 17)
+ data.CalculateStats()
+ summary_str = data.SummaryToString()
+ self.assertRegexpMatches(summary_str, d_b_a_c_regexp)
+
+ def test_MakeUniqueFName(self):
+ data = stats_manager.StatsManager()
+ testfile = os.path.join(self.tempdir, "testfile.txt")
+ with open(testfile, "w") as f:
+ f.write("")
+ expected_fname = os.path.join(self.tempdir, "testfile0.txt")
+ self.assertEqual(expected_fname, data._MakeUniqueFName(testfile))
+
+ def test_SaveSummary(self):
+ """SaveSummary properly dumps the summary into a file."""
+ self._populate_mock_stats()
+ fname = "unittest_summary.txt"
+ expected_fname = os.path.join(self.tempdir, fname)
+ fname = self.data.SaveSummary(self.tempdir, fname)
+ # Assert the reported fname is the same as the expected fname
+ self.assertEqual(expected_fname, fname)
+ # Assert only the reported fname is output (in the tempdir)
+ self.assertEqual(set([os.path.basename(fname)]), set(os.listdir(self.tempdir)))
+ with open(fname, "r") as f:
+ self.assertEqual(
+ "@@ NAME COUNT MEAN STDDEV MAX MIN\n",
+ f.readline(),
+ )
+ self.assertEqual(
+ "@@ A_mW 2 100000.00 0.50 100000.50 99999.50\n",
+ f.readline(),
+ )
+ self.assertEqual(
+ "@@ B_mV 3 2.50 0.82 3.50 1.50\n",
+ f.readline(),
+ )
+
+ def test_SaveSummarySMID(self):
+ """SaveSummary uses the smid when creating output filename."""
+ identifier = "ec"
+ self.data = stats_manager.StatsManager(smid=identifier)
+ self._populate_mock_stats()
+ fname = os.path.basename(self.data.SaveSummary(self.tempdir))
+ self.assertTrue(fname.startswith(identifier))
+
+ def test_SaveSummaryJSON(self):
+ """SaveSummaryJSON saves the added data properly in JSON format."""
+ self._populate_mock_stats()
+ fname = "unittest_summary.json"
+ expected_fname = os.path.join(self.tempdir, fname)
+ fname = self.data.SaveSummaryJSON(self.tempdir, fname)
+ # Assert the reported fname is the same as the expected fname
+ self.assertEqual(expected_fname, fname)
+ # Assert only the reported fname is output (in the tempdir)
+ self.assertEqual(set([os.path.basename(fname)]), set(os.listdir(self.tempdir)))
+ with open(fname, "r") as f:
+ summary = json.load(f)
+ self.assertAlmostEqual(100000.0, summary["A"]["mean"])
+ self.assertEqual("milliwatt", summary["A"]["unit"])
+ self.assertAlmostEqual(2.5, summary["B"]["mean"])
+ self.assertEqual("millivolt", summary["B"]["unit"])
+
+ def test_SaveSummaryJSONSMID(self):
+ """SaveSummaryJSON uses the smid when creating output filename."""
+ identifier = "ec"
+ self.data = stats_manager.StatsManager(smid=identifier)
+ self._populate_mock_stats()
+ fname = os.path.basename(self.data.SaveSummaryJSON(self.tempdir))
+ self.assertTrue(fname.startswith(identifier))
+
+ def test_SaveSummaryJSONNoUnit(self):
+ """SaveSummaryJSON marks unknown units properly as N/A."""
+ self._populate_mock_stats_no_unit()
+ self.data.CalculateStats()
+ fname = "unittest_summary.json"
+ fname = self.data.SaveSummaryJSON(self.tempdir, fname)
+ with open(fname, "r") as f:
+ summary = json.load(f)
+ self.assertEqual("blue", summary["A"]["unit"])
+ # if no unit is specified, JSON should save 'N/A' as the unit.
+ self.assertEqual("N/A", summary["B"]["unit"])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/extra/usb_serial/console.py b/extra/usb_serial/console.py
index 7211dceff6..c08cb72092 100755
--- a/extra/usb_serial/console.py
+++ b/extra/usb_serial/console.py
@@ -12,6 +12,7 @@
# Note: This is a py2/3 compatible file.
from __future__ import print_function
+
import argparse
import array
import os
@@ -20,23 +21,24 @@ import termios
import threading
import time
import tty
+
try:
- import usb
+ import usb
except ModuleNotFoundError:
- print('import usb failed')
- print('try running these commands:')
- print(' sudo apt-get install python-pip')
- print(' sudo pip install --pre pyusb')
- print()
- sys.exit(-1)
+ print("import usb failed")
+ print("try running these commands:")
+ print(" sudo apt-get install python-pip")
+ print(" sudo pip install --pre pyusb")
+ print()
+ sys.exit(-1)
import six
def GetBuffer(stream):
- if six.PY3:
- return stream.buffer
- return stream
+ if six.PY3:
+ return stream.buffer
+ return stream
"""Class Susb covers USB device discovery and initialization.
@@ -45,99 +47,104 @@ def GetBuffer(stream):
and interface number.
"""
+
class SusbError(Exception):
- """Class for exceptions of Susb."""
- def __init__(self, msg, value=0):
- """SusbError constructor.
+ """Class for exceptions of Susb."""
- Args:
- msg: string, message describing error in detail
- value: integer, value of error when non-zero status returned. Default=0
- """
- super(SusbError, self).__init__(msg, value)
- self.msg = msg
- self.value = value
-
-class Susb():
- """Provide USB functionality.
-
- Instance Variables:
- _read_ep: pyUSB read endpoint for this interface
- _write_ep: pyUSB write endpoint for this interface
- """
- READ_ENDPOINT = 0x81
- WRITE_ENDPOINT = 0x1
- TIMEOUT_MS = 100
-
- def __init__(self, vendor=0x18d1,
- product=0x500f, interface=1, serialname=None):
- """Susb constructor.
-
- Discovers and connects to USB endpoints.
-
- Args:
- vendor: usb vendor id of device
- product: usb product id of device
- interface: interface number ( 1 - 8 ) of device to use
- serialname: string of device serialnumber.
-
- Raises:
- SusbError: An error accessing Susb object
+ def __init__(self, msg, value=0):
+ """SusbError constructor.
+
+ Args:
+ msg: string, message describing error in detail
+ value: integer, value of error when non-zero status returned. Default=0
+ """
+ super(SusbError, self).__init__(msg, value)
+ self.msg = msg
+ self.value = value
+
+
+class Susb:
+ """Provide USB functionality.
+
+ Instance Variables:
+ _read_ep: pyUSB read endpoint for this interface
+ _write_ep: pyUSB write endpoint for this interface
"""
- # Find the device.
- dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True)
- dev_list = list(dev_g)
- if dev_list is None:
- raise SusbError('USB device not found')
-
- # Check if we have multiple devices.
- dev = None
- if serialname:
- for d in dev_list:
- dev_serial = "PyUSB doesn't have a stable interface"
- try:
- dev_serial = usb.util.get_string(d, 256, d.iSerialNumber)
- except:
- dev_serial = usb.util.get_string(d, d.iSerialNumber)
- if dev_serial == serialname:
- dev = d
- break
- if dev is None:
- raise SusbError('USB device(%s) not found' % (serialname,))
- else:
- try:
- dev = dev_list[0]
- except:
- try:
- dev = dev_list.next()
- except:
- raise SusbError('USB device %04x:%04x not found' % (vendor, product))
- # If we can't set configuration, it's already been set.
- try:
- dev.set_configuration()
- except usb.core.USBError:
- pass
+ READ_ENDPOINT = 0x81
+ WRITE_ENDPOINT = 0x1
+ TIMEOUT_MS = 100
+
+ def __init__(self, vendor=0x18D1, product=0x500F, interface=1, serialname=None):
+ """Susb constructor.
+
+ Discovers and connects to USB endpoints.
+
+ Args:
+ vendor: usb vendor id of device
+ product: usb product id of device
+ interface: interface number ( 1 - 8 ) of device to use
+ serialname: string of device serialnumber.
+
+ Raises:
+ SusbError: An error accessing Susb object
+ """
+ # Find the device.
+ dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True)
+ dev_list = list(dev_g)
+ if dev_list is None:
+ raise SusbError("USB device not found")
+
+ # Check if we have multiple devices.
+ dev = None
+ if serialname:
+ for d in dev_list:
+ dev_serial = "PyUSB doesn't have a stable interface"
+ try:
+ dev_serial = usb.util.get_string(d, 256, d.iSerialNumber)
+ except:
+ dev_serial = usb.util.get_string(d, d.iSerialNumber)
+ if dev_serial == serialname:
+ dev = d
+ break
+ if dev is None:
+ raise SusbError("USB device(%s) not found" % (serialname,))
+ else:
+ try:
+ dev = dev_list[0]
+ except:
+ try:
+ dev = dev_list.next()
+ except:
+ raise SusbError(
+ "USB device %04x:%04x not found" % (vendor, product)
+ )
+
+ # If we can't set configuration, it's already been set.
+ try:
+ dev.set_configuration()
+ except usb.core.USBError:
+ pass
- # Get an endpoint instance.
- cfg = dev.get_active_configuration()
- intf = usb.util.find_descriptor(cfg, bInterfaceNumber=interface)
- self._intf = intf
+ # Get an endpoint instance.
+ cfg = dev.get_active_configuration()
+ intf = usb.util.find_descriptor(cfg, bInterfaceNumber=interface)
+ self._intf = intf
- if not intf:
- raise SusbError('Interface not found')
+ if not intf:
+ raise SusbError("Interface not found")
- # Detach raiden.ko if it is loaded.
- if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True:
+ # Detach raiden.ko if it is loaded.
+ if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True:
dev.detach_kernel_driver(intf.bInterfaceNumber)
- read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT
- read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number)
- self._read_ep = read_ep
+ read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT
+ read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number)
+ self._read_ep = read_ep
- write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT
- write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number)
- self._write_ep = write_ep
+ write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT
+ write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number)
+ self._write_ep = write_ep
"""Suart class implements a stream interface, to access Google's USB class.
@@ -146,89 +153,91 @@ class Susb():
and forwards them across. This particular class is hardcoded to stdin/out.
"""
-class SuartError(Exception):
- """Class for exceptions of Suart."""
- def __init__(self, msg, value=0):
- """SuartError constructor.
-
- Args:
- msg: string, message describing error in detail
- value: integer, value of error when non-zero status returned. Default=0
- """
- super(SuartError, self).__init__(msg, value)
- self.msg = msg
- self.value = value
-
-
-class Suart():
- """Provide interface to serial usb endpoint."""
- def __init__(self, vendor=0x18d1, product=0x501c, interface=0,
- serialname=None):
- """Suart contstructor.
+class SuartError(Exception):
+ """Class for exceptions of Suart."""
- Initializes USB stream interface.
+ def __init__(self, msg, value=0):
+ """SuartError constructor.
- Args:
- vendor: usb vendor id of device
- product: usb product id of device
- interface: interface number of device to use
- serialname: Defaults to None.
+ Args:
+ msg: string, message describing error in detail
+ value: integer, value of error when non-zero status returned. Default=0
+ """
+ super(SuartError, self).__init__(msg, value)
+ self.msg = msg
+ self.value = value
- Raises:
- SuartError: If init fails
- """
- self._done = threading.Event()
- self._susb = Susb(vendor=vendor, product=product,
- interface=interface, serialname=serialname)
- def wait_until_done(self, timeout=None):
- return self._done.wait(timeout=timeout)
+class Suart:
+ """Provide interface to serial usb endpoint."""
- def run_rx_thread(self):
- try:
- while True:
- try:
- r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS)
- if r:
- GetBuffer(sys.stdout).write(r.tobytes())
- GetBuffer(sys.stdout).flush()
-
- except Exception as e:
- # If we miss some characters on pty disconnect, that's fine.
- # ep.read() also throws USBError on timeout, which we discard.
- if not isinstance(e, (OSError, usb.core.USBError)):
- print('rx %s' % e)
- finally:
- self._done.set()
+ def __init__(self, vendor=0x18D1, product=0x501C, interface=0, serialname=None):
+ """Suart contstructor.
- def run_tx_thread(self):
- try:
- while True:
- try:
- r = GetBuffer(sys.stdin).read(1)
- if not r or r == b'\x03':
- break
- if r:
- self._susb._write_ep.write(array.array('B', r),
- self._susb.TIMEOUT_MS)
- except Exception as e:
- print('tx %s' % e)
- finally:
- self._done.set()
+ Initializes USB stream interface.
- def run(self):
- """Creates pthreads to poll USB & PTY for data."""
- self._exit = False
+ Args:
+ vendor: usb vendor id of device
+ product: usb product id of device
+ interface: interface number of device to use
+ serialname: Defaults to None.
- self._rx_thread = threading.Thread(target=self.run_rx_thread)
- self._rx_thread.daemon = True
- self._rx_thread.start()
+ Raises:
+ SuartError: If init fails
+ """
+ self._done = threading.Event()
+ self._susb = Susb(
+ vendor=vendor, product=product, interface=interface, serialname=serialname
+ )
- self._tx_thread = threading.Thread(target=self.run_tx_thread)
- self._tx_thread.daemon = True
- self._tx_thread.start()
+ def wait_until_done(self, timeout=None):
+ return self._done.wait(timeout=timeout)
+ def run_rx_thread(self):
+ try:
+ while True:
+ try:
+ r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS)
+ if r:
+ GetBuffer(sys.stdout).write(r.tobytes())
+ GetBuffer(sys.stdout).flush()
+
+ except Exception as e:
+ # If we miss some characters on pty disconnect, that's fine.
+ # ep.read() also throws USBError on timeout, which we discard.
+ if not isinstance(e, (OSError, usb.core.USBError)):
+ print("rx %s" % e)
+ finally:
+ self._done.set()
+
+ def run_tx_thread(self):
+ try:
+ while True:
+ try:
+ r = GetBuffer(sys.stdin).read(1)
+ if not r or r == b"\x03":
+ break
+ if r:
+ self._susb._write_ep.write(
+ array.array("B", r), self._susb.TIMEOUT_MS
+ )
+ except Exception as e:
+ print("tx %s" % e)
+ finally:
+ self._done.set()
+
+ def run(self):
+ """Creates pthreads to poll USB & PTY for data."""
+ self._exit = False
+
+ self._rx_thread = threading.Thread(target=self.run_rx_thread)
+ self._rx_thread.daemon = True
+ self._rx_thread.start()
+
+ self._tx_thread = threading.Thread(target=self.run_tx_thread)
+ self._tx_thread.daemon = True
+ self._tx_thread.start()
"""Command line functionality
@@ -237,59 +246,67 @@ class Suart():
Ctrl-C exits.
"""
-parser = argparse.ArgumentParser(description='Open a console to a USB device')
-parser.add_argument('-d', '--device', type=str,
- help='vid:pid of target device', default='18d1:501c')
-parser.add_argument('-i', '--interface', type=int,
- help='interface number of console', default=0)
-parser.add_argument('-s', '--serialno', type=str,
- help='serial number of device', default='')
-parser.add_argument('-S', '--notty-exit-sleep', type=float, default=0.2,
- help='When stdin is *not* a TTY, wait this many seconds '
- 'after EOF from stdin before exiting, to give time for '
- 'receiving a reply from the USB device.')
+parser = argparse.ArgumentParser(description="Open a console to a USB device")
+parser.add_argument(
+ "-d", "--device", type=str, help="vid:pid of target device", default="18d1:501c"
+)
+parser.add_argument(
+ "-i", "--interface", type=int, help="interface number of console", default=0
+)
+parser.add_argument(
+ "-s", "--serialno", type=str, help="serial number of device", default=""
+)
+parser.add_argument(
+ "-S",
+ "--notty-exit-sleep",
+ type=float,
+ default=0.2,
+ help="When stdin is *not* a TTY, wait this many seconds "
+ "after EOF from stdin before exiting, to give time for "
+ "receiving a reply from the USB device.",
+)
+
def runconsole():
- """Run the usb console code
+ """Run the usb console code
- Starts the pty thread, and idles until a ^C is caught.
- """
- args = parser.parse_args()
+ Starts the pty thread, and idles until a ^C is caught.
+ """
+ args = parser.parse_args()
- vidstr, pidstr = args.device.split(':')
- vid = int(vidstr, 16)
- pid = int(pidstr, 16)
+ vidstr, pidstr = args.device.split(":")
+ vid = int(vidstr, 16)
+ pid = int(pidstr, 16)
- serialno = args.serialno
- interface = args.interface
+ serialno = args.serialno
+ interface = args.interface
- sobj = Suart(vendor=vid, product=pid, interface=interface,
- serialname=serialno)
- if sys.stdin.isatty():
- tty.setraw(sys.stdin.fileno())
- sobj.run()
- sobj.wait_until_done()
- if not sys.stdin.isatty() and args.notty_exit_sleep > 0:
- time.sleep(args.notty_exit_sleep)
+ sobj = Suart(vendor=vid, product=pid, interface=interface, serialname=serialno)
+ if sys.stdin.isatty():
+ tty.setraw(sys.stdin.fileno())
+ sobj.run()
+ sobj.wait_until_done()
+ if not sys.stdin.isatty() and args.notty_exit_sleep > 0:
+ time.sleep(args.notty_exit_sleep)
def main():
- stdin_isatty = sys.stdin.isatty()
- if stdin_isatty:
- fd = sys.stdin.fileno()
- os.system('stty -echo')
- old_settings = termios.tcgetattr(fd)
-
- try:
- runconsole()
- finally:
+ stdin_isatty = sys.stdin.isatty()
if stdin_isatty:
- termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
- os.system('stty echo')
- # Avoid having the user's shell prompt start mid-line after the final output
- # from this program.
- print()
+ fd = sys.stdin.fileno()
+ os.system("stty -echo")
+ old_settings = termios.tcgetattr(fd)
+
+ try:
+ runconsole()
+ finally:
+ if stdin_isatty:
+ termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
+ os.system("stty echo")
+ # Avoid having the user's shell prompt start mid-line after the final output
+ # from this program.
+ print()
-if __name__ == '__main__':
- main()
+if __name__ == "__main__":
+ main()
diff --git a/extra/usb_updater/fw_update.py b/extra/usb_updater/fw_update.py
index 0d7a570fc3..f05797bfb6 100755
--- a/extra/usb_updater/fw_update.py
+++ b/extra/usb_updater/fw_update.py
@@ -20,407 +20,425 @@ import struct
import sys
import time
from pprint import pprint
-import usb
+import usb
debug = False
-def debuglog(msg):
- if debug:
- print(msg)
-
-def log(msg):
- print(msg)
- sys.stdout.flush()
-
-
-"""Sends firmware update to CROS EC usb endpoint."""
-
-class Supdate(object):
- """Class to access firmware update endpoints.
-
- Usage:
- d = Supdate()
- Instance Variables:
- _dev: pyUSB device object
- _read_ep: pyUSB read endpoint for this interface
- _write_ep: pyUSB write endpoint for this interface
- """
- USB_SUBCLASS_GOOGLE_UPDATE = 0x53
- USB_CLASS_VENDOR = 0xFF
- def __init__(self):
- pass
-
-
- def connect_usb(self, serialname=None):
- """Initial discovery and connection to USB endpoint.
-
- This searches for a USB device matching the VID:PID specified
- in the config file, optionally matching a specified serialname.
-
- Args:
- serialname: Find the device with this serial, in case multiple
- devices are attached.
-
- Returns:
- True on success.
- Raises:
- Exception on error.
- """
- # Find the stm32.
- vendor = self._brdcfg['vid']
- product = self._brdcfg['pid']
-
- dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True)
- dev_list = list(dev_g)
- if dev_list is None:
- raise Exception("Update", "USB device not found")
-
- # Check if we have multiple stm32s and we've specified the serial.
- dev = None
- if serialname:
- for d in dev_list:
- if usb.util.get_string(d, d.iSerialNumber) == serialname:
- dev = d
- break
- if dev is None:
- raise SusbError("USB device(%s) not found" % serialname)
- else:
- try:
- dev = dev_list[0]
- except:
- dev = dev_list.next()
-
- debuglog("Found stm32: %04x:%04x" % (vendor, product))
- self._dev = dev
-
- # Get an endpoint instance.
- try:
- dev.set_configuration()
- except:
- pass
- cfg = dev.get_active_configuration()
-
- intf = usb.util.find_descriptor(cfg, custom_match=lambda i: \
- i.bInterfaceClass==self.USB_CLASS_VENDOR and \
- i.bInterfaceSubClass==self.USB_SUBCLASS_GOOGLE_UPDATE)
-
- self._intf = intf
- debuglog("Interface: %s" % intf)
- debuglog("InterfaceNumber: %s" % intf.bInterfaceNumber)
-
- read_ep = usb.util.find_descriptor(
- intf,
- # match the first IN endpoint
- custom_match = \
- lambda e: \
- usb.util.endpoint_direction(e.bEndpointAddress) == \
- usb.util.ENDPOINT_IN
- )
-
- self._read_ep = read_ep
- debuglog("Reader endpoint: 0x%x" % read_ep.bEndpointAddress)
-
- write_ep = usb.util.find_descriptor(
- intf,
- # match the first OUT endpoint
- custom_match = \
- lambda e: \
- usb.util.endpoint_direction(e.bEndpointAddress) == \
- usb.util.ENDPOINT_OUT
- )
-
- self._write_ep = write_ep
- debuglog("Writer endpoint: 0x%x" % write_ep.bEndpointAddress)
-
- return True
-
-
- def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=2000):
- """Write command to logger logic..
-
- This function writes byte command values list to stm, then reads
- byte status.
-
- Args:
- write_list: list of command byte values [0~255].
- read_count: number of status byte values to read.
- wtimeout: mS to wait for write success
- rtimeout: mS to wait for read success
-
- Returns:
- status byte, if one byte is read,
- byte list, if multiple bytes are read,
- None, if no bytes are read.
-
- Interface:
- write: [command, data ... ]
- read: [status ]
- """
- debuglog("wr_command(write_list=[%s] (%d), read_count=%s)" % (
- list(bytearray(write_list)), len(write_list), read_count))
-
- # Clean up args from python style to correct types.
- write_length = 0
- if write_list:
- write_length = len(write_list)
- if not read_count:
- read_count = 0
-
- # Send command to stm32.
- if write_list:
- cmd = write_list
- ret = self._write_ep.write(cmd, wtimeout)
- debuglog("RET: %s " % ret)
-
- # Read back response if necessary.
- if read_count:
- bytesread = self._read_ep.read(512, rtimeout)
- debuglog("BYTES: [%s]" % bytesread)
-
- if len(bytesread) != read_count:
- debuglog("Unexpected bytes read: %d, expected: %d" % (len(bytesread), read_count))
- pass
-
- debuglog("STATUS: 0x%02x" % int(bytesread[0]))
- if read_count == 1:
- return bytesread[0]
- else:
- return bytesread
-
- return None
+def debuglog(msg):
+ if debug:
+ print(msg)
- def stop(self):
- """Finalize system flash and exit."""
- cmd = struct.pack(">I", 0xB007AB1E)
- read = self.wr_command(cmd, read_count=4)
- if len(read) == 4:
- log("Finished flashing")
- return
+def log(msg):
+ print(msg)
+ sys.stdout.flush()
- raise Exception("Update", "Stop failed [%s]" % read)
+"""Sends firmware update to CROS EC usb endpoint."""
- def write_file(self):
- """Write the update region packet by packet to USB
- This sends write packets of size 128B out, in 32B chunks.
- Overall, this will write all data in the inactive code region.
+class Supdate(object):
+ """Class to access firmware update endpoints.
- Raises:
- Exception if write failed or address out of bounds.
- """
- region = self._region
- flash_base = self._brdcfg["flash"]
- offset = self._base - flash_base
- if offset != self._brdcfg['regions'][region][0]:
- raise Exception("Update", "Region %s offset 0x%x != available offset 0x%x" % (
- region, self._brdcfg['regions'][region][0], offset))
-
- length = self._brdcfg['regions'][region][1]
- log("Sending")
-
- # Go to the correct region in the ec.bin file.
- self._binfile.seek(offset)
-
- # Send 32 bytes at a time. Must be less than the endpoint's max packet size.
- maxpacket = 32
-
- # While data is left, create update packets.
- while length > 0:
- # Update packets are 128B. We can use any number
- # but the micro must malloc this memory.
- pagesize = min(length, 128)
-
- # Packet is:
- # packet size: page bytes transferred plus 3 x 32b values header.
- # cmd: n/a
- # base: flash address to write this packet.
- # data: 128B of data to write into flash_base
- cmd = struct.pack(">III", pagesize + 12, 0, offset + flash_base)
- read = self.wr_command(cmd, read_count=0)
-
- # Push 'todo' bytes out the pipe.
- todo = pagesize
- while todo > 0:
- packetsize = min(maxpacket, todo)
- data = self._binfile.read(packetsize)
- if len(data) != packetsize:
- raise Exception("Update", "No more data from file")
- for i in range(0, 10):
- try:
- self.wr_command(data, read_count=0)
- break
- except:
- log("Timeout fail")
- todo -= packetsize
- # Done with this packet, move to the next one.
- length -= pagesize
- offset += pagesize
-
- # Validate that the micro thinks it successfully wrote the data.
- read = self.wr_command(''.encode(), read_count=4)
- result = struct.unpack("<I", read)
- result = result[0]
- if result != 0:
- raise Exception("Update", "Upload failed with rc: 0x%x" % result)
-
-
- def start(self):
- """Start a transaction and erase currently inactive region.
-
- This function sends a start command, and receives the base of the
- preferred inactive region. This could be RW, RW_B,
- or RO (if there's no RW_B)
-
- Note that the region is erased here, so you'd better program the RO if
- you just erased it. TODO(nsanders): Modify the protocol to allow active
- region select or query before erase.
- """
+ Usage:
+ d = Supdate()
- # Size is 3 uint32 fields
- # packet: [packetsize, cmd, base]
- size = 4 + 4 + 4
- # Return value is [status, base_addr]
- expected = 4 + 4
-
- cmd = struct.pack("<III", size, 0, 0)
- read = self.wr_command(cmd, read_count=expected)
-
- if len(read) == 4:
- raise Exception("Update", "Protocol version 0 not supported")
- elif len(read) == expected:
- base, version = struct.unpack(">II", read)
- log("Update protocol v. %d" % version)
- log("Available flash region base: %x" % base)
- else:
- raise Exception("Update", "Start command returned %d bytes" % len(read))
-
- if base < 256:
- raise Exception("Update", "Start returned error code 0x%x" % base)
-
- self._base = base
- flash_base = self._brdcfg["flash"]
- self._offset = self._base - flash_base
-
- # Find our active region.
- for region in self._brdcfg['regions']:
- if (self._offset >= self._brdcfg['regions'][region][0]) and \
- (self._offset < (self._brdcfg['regions'][region][0] + \
- self._brdcfg['regions'][region][1])):
- log("Active region: %s" % region)
- self._region = region
-
-
- def load_board(self, brdfile):
- """Load firmware layout file.
-
- example as follows:
- {
- "board": "servo micro",
- "vid": 6353,
- "pid": 20506,
- "flash": 134217728,
- "regions": {
- "RW": [65536, 65536],
- "PSTATE": [61440, 4096],
- "RO": [0, 61440]
- }
- }
-
- Args:
- brdfile: path to board description file.
+ Instance Variables:
+ _dev: pyUSB device object
+ _read_ep: pyUSB read endpoint for this interface
+ _write_ep: pyUSB write endpoint for this interface
"""
- with open(brdfile) as data_file:
- data = json.load(data_file)
-
- # TODO(nsanders): validate this data before moving on.
- self._brdcfg = data;
- if debug:
- pprint(data)
-
- log("Board is %s" % self._brdcfg['board'])
- # Cast hex strings to int.
- self._brdcfg['flash'] = int(self._brdcfg['flash'], 0)
- self._brdcfg['vid'] = int(self._brdcfg['vid'], 0)
- self._brdcfg['pid'] = int(self._brdcfg['pid'], 0)
- log("Flash Base is %x" % self._brdcfg['flash'])
- self._flashsize = 0
- for region in self._brdcfg['regions']:
- base = int(self._brdcfg['regions'][region][0], 0)
- length = int(self._brdcfg['regions'][region][1], 0)
- log("region %s\tbase:0x%08x size:0x%08x" % (
- region, base, length))
- self._flashsize += length
+ USB_SUBCLASS_GOOGLE_UPDATE = 0x53
+ USB_CLASS_VENDOR = 0xFF
- # Convert these to int because json doesn't support hex.
- self._brdcfg['regions'][region][0] = base
- self._brdcfg['regions'][region][1] = length
-
- log("Flash Size: 0x%x" % self._flashsize)
-
- def load_file(self, binfile):
- """Open and verify size of the target ec.bin file.
-
- Args:
- binfile: path to ec.bin
-
- Raises:
- Exception on file not found or filesize not matching.
- """
- self._filesize = os.path.getsize(binfile)
- self._binfile = open(binfile, 'rb')
-
- if self._filesize != self._flashsize:
- raise Exception("Update", "Flash size 0x%x != file size 0x%x" % (self._flashsize, self._filesize))
+ def __init__(self):
+ pass
+ def connect_usb(self, serialname=None):
+ """Initial discovery and connection to USB endpoint.
+
+ This searches for a USB device matching the VID:PID specified
+ in the config file, optionally matching a specified serialname.
+
+ Args:
+ serialname: Find the device with this serial, in case multiple
+ devices are attached.
+
+ Returns:
+ True on success.
+ Raises:
+ Exception on error.
+ """
+ # Find the stm32.
+ vendor = self._brdcfg["vid"]
+ product = self._brdcfg["pid"]
+
+ dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True)
+ dev_list = list(dev_g)
+ if dev_list is None:
+ raise Exception("Update", "USB device not found")
+
+ # Check if we have multiple stm32s and we've specified the serial.
+ dev = None
+ if serialname:
+ for d in dev_list:
+ if usb.util.get_string(d, d.iSerialNumber) == serialname:
+ dev = d
+ break
+ if dev is None:
+ raise SusbError("USB device(%s) not found" % serialname)
+ else:
+ try:
+ dev = dev_list[0]
+ except:
+ dev = dev_list.next()
+
+ debuglog("Found stm32: %04x:%04x" % (vendor, product))
+ self._dev = dev
+
+ # Get an endpoint instance.
+ try:
+ dev.set_configuration()
+ except:
+ pass
+ cfg = dev.get_active_configuration()
+
+ intf = usb.util.find_descriptor(
+ cfg,
+ custom_match=lambda i: i.bInterfaceClass == self.USB_CLASS_VENDOR
+ and i.bInterfaceSubClass == self.USB_SUBCLASS_GOOGLE_UPDATE,
+ )
+
+ self._intf = intf
+ debuglog("Interface: %s" % intf)
+ debuglog("InterfaceNumber: %s" % intf.bInterfaceNumber)
+
+ read_ep = usb.util.find_descriptor(
+ intf,
+ # match the first IN endpoint
+ custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress)
+ == usb.util.ENDPOINT_IN,
+ )
+
+ self._read_ep = read_ep
+ debuglog("Reader endpoint: 0x%x" % read_ep.bEndpointAddress)
+
+ write_ep = usb.util.find_descriptor(
+ intf,
+ # match the first OUT endpoint
+ custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress)
+ == usb.util.ENDPOINT_OUT,
+ )
+
+ self._write_ep = write_ep
+ debuglog("Writer endpoint: 0x%x" % write_ep.bEndpointAddress)
+
+ return True
+
+ def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=2000):
+ """Write command to logger logic..
+
+ This function writes byte command values list to stm, then reads
+ byte status.
+
+ Args:
+ write_list: list of command byte values [0~255].
+ read_count: number of status byte values to read.
+ wtimeout: mS to wait for write success
+ rtimeout: mS to wait for read success
+
+ Returns:
+ status byte, if one byte is read,
+ byte list, if multiple bytes are read,
+ None, if no bytes are read.
+
+ Interface:
+ write: [command, data ... ]
+ read: [status ]
+ """
+ debuglog(
+ "wr_command(write_list=[%s] (%d), read_count=%s)"
+ % (list(bytearray(write_list)), len(write_list), read_count)
+ )
+
+ # Clean up args from python style to correct types.
+ write_length = 0
+ if write_list:
+ write_length = len(write_list)
+ if not read_count:
+ read_count = 0
+
+ # Send command to stm32.
+ if write_list:
+ cmd = write_list
+ ret = self._write_ep.write(cmd, wtimeout)
+ debuglog("RET: %s " % ret)
+
+ # Read back response if necessary.
+ if read_count:
+ bytesread = self._read_ep.read(512, rtimeout)
+ debuglog("BYTES: [%s]" % bytesread)
+
+ if len(bytesread) != read_count:
+ debuglog(
+ "Unexpected bytes read: %d, expected: %d"
+ % (len(bytesread), read_count)
+ )
+ pass
+
+ debuglog("STATUS: 0x%02x" % int(bytesread[0]))
+ if read_count == 1:
+ return bytesread[0]
+ else:
+ return bytesread
+
+ return None
+
+ def stop(self):
+ """Finalize system flash and exit."""
+ cmd = struct.pack(">I", 0xB007AB1E)
+ read = self.wr_command(cmd, read_count=4)
+
+ if len(read) == 4:
+ log("Finished flashing")
+ return
+
+ raise Exception("Update", "Stop failed [%s]" % read)
+
+ def write_file(self):
+ """Write the update region packet by packet to USB
+
+ This sends write packets of size 128B out, in 32B chunks.
+ Overall, this will write all data in the inactive code region.
+
+ Raises:
+ Exception if write failed or address out of bounds.
+ """
+ region = self._region
+ flash_base = self._brdcfg["flash"]
+ offset = self._base - flash_base
+ if offset != self._brdcfg["regions"][region][0]:
+ raise Exception(
+ "Update",
+ "Region %s offset 0x%x != available offset 0x%x"
+ % (region, self._brdcfg["regions"][region][0], offset),
+ )
+
+ length = self._brdcfg["regions"][region][1]
+ log("Sending")
+
+ # Go to the correct region in the ec.bin file.
+ self._binfile.seek(offset)
+
+ # Send 32 bytes at a time. Must be less than the endpoint's max packet size.
+ maxpacket = 32
+
+ # While data is left, create update packets.
+ while length > 0:
+ # Update packets are 128B. We can use any number
+ # but the micro must malloc this memory.
+ pagesize = min(length, 128)
+
+ # Packet is:
+ # packet size: page bytes transferred plus 3 x 32b values header.
+ # cmd: n/a
+ # base: flash address to write this packet.
+ # data: 128B of data to write into flash_base
+ cmd = struct.pack(">III", pagesize + 12, 0, offset + flash_base)
+ read = self.wr_command(cmd, read_count=0)
+
+ # Push 'todo' bytes out the pipe.
+ todo = pagesize
+ while todo > 0:
+ packetsize = min(maxpacket, todo)
+ data = self._binfile.read(packetsize)
+ if len(data) != packetsize:
+ raise Exception("Update", "No more data from file")
+ for i in range(0, 10):
+ try:
+ self.wr_command(data, read_count=0)
+ break
+ except:
+ log("Timeout fail")
+ todo -= packetsize
+ # Done with this packet, move to the next one.
+ length -= pagesize
+ offset += pagesize
+
+ # Validate that the micro thinks it successfully wrote the data.
+ read = self.wr_command("".encode(), read_count=4)
+ result = struct.unpack("<I", read)
+ result = result[0]
+ if result != 0:
+ raise Exception("Update", "Upload failed with rc: 0x%x" % result)
+
+ def start(self):
+ """Start a transaction and erase currently inactive region.
+
+ This function sends a start command, and receives the base of the
+ preferred inactive region. This could be RW, RW_B,
+ or RO (if there's no RW_B)
+
+ Note that the region is erased here, so you'd better program the RO if
+ you just erased it. TODO(nsanders): Modify the protocol to allow active
+ region select or query before erase.
+ """
+
+ # Size is 3 uint32 fields
+ # packet: [packetsize, cmd, base]
+ size = 4 + 4 + 4
+ # Return value is [status, base_addr]
+ expected = 4 + 4
+
+ cmd = struct.pack("<III", size, 0, 0)
+ read = self.wr_command(cmd, read_count=expected)
+
+ if len(read) == 4:
+ raise Exception("Update", "Protocol version 0 not supported")
+ elif len(read) == expected:
+ base, version = struct.unpack(">II", read)
+ log("Update protocol v. %d" % version)
+ log("Available flash region base: %x" % base)
+ else:
+ raise Exception("Update", "Start command returned %d bytes" % len(read))
+
+ if base < 256:
+ raise Exception("Update", "Start returned error code 0x%x" % base)
+
+ self._base = base
+ flash_base = self._brdcfg["flash"]
+ self._offset = self._base - flash_base
+
+ # Find our active region.
+ for region in self._brdcfg["regions"]:
+ if (self._offset >= self._brdcfg["regions"][region][0]) and (
+ self._offset
+ < (
+ self._brdcfg["regions"][region][0]
+ + self._brdcfg["regions"][region][1]
+ )
+ ):
+ log("Active region: %s" % region)
+ self._region = region
+
+ def load_board(self, brdfile):
+ """Load firmware layout file.
+
+ example as follows:
+ {
+ "board": "servo micro",
+ "vid": 6353,
+ "pid": 20506,
+ "flash": 134217728,
+ "regions": {
+ "RW": [65536, 65536],
+ "PSTATE": [61440, 4096],
+ "RO": [0, 61440]
+ }
+ }
+
+ Args:
+ brdfile: path to board description file.
+ """
+ with open(brdfile) as data_file:
+ data = json.load(data_file)
+
+ # TODO(nsanders): validate this data before moving on.
+ self._brdcfg = data
+ if debug:
+ pprint(data)
+
+ log("Board is %s" % self._brdcfg["board"])
+ # Cast hex strings to int.
+ self._brdcfg["flash"] = int(self._brdcfg["flash"], 0)
+ self._brdcfg["vid"] = int(self._brdcfg["vid"], 0)
+ self._brdcfg["pid"] = int(self._brdcfg["pid"], 0)
+
+ log("Flash Base is %x" % self._brdcfg["flash"])
+ self._flashsize = 0
+ for region in self._brdcfg["regions"]:
+ base = int(self._brdcfg["regions"][region][0], 0)
+ length = int(self._brdcfg["regions"][region][1], 0)
+ log("region %s\tbase:0x%08x size:0x%08x" % (region, base, length))
+ self._flashsize += length
+
+ # Convert these to int because json doesn't support hex.
+ self._brdcfg["regions"][region][0] = base
+ self._brdcfg["regions"][region][1] = length
+
+ log("Flash Size: 0x%x" % self._flashsize)
+
+ def load_file(self, binfile):
+ """Open and verify size of the target ec.bin file.
+
+ Args:
+ binfile: path to ec.bin
+
+ Raises:
+ Exception on file not found or filesize not matching.
+ """
+ self._filesize = os.path.getsize(binfile)
+ self._binfile = open(binfile, "rb")
+
+ if self._filesize != self._flashsize:
+ raise Exception(
+ "Update",
+ "Flash size 0x%x != file size 0x%x" % (self._flashsize, self._filesize),
+ )
# Generate command line arguments
parser = argparse.ArgumentParser(description="Update firmware over usb")
-parser.add_argument('-b', '--board', type=str, help="Board configuration json file", default="board.json")
-parser.add_argument('-f', '--file', type=str, help="Complete ec.bin file", default="ec.bin")
-parser.add_argument('-s', '--serial', type=str, help="Serial number", default="")
-parser.add_argument('-l', '--list', action="store_true", help="List regions")
-parser.add_argument('-v', '--verbose', action="store_true", help="Chatty output")
+parser.add_argument(
+ "-b",
+ "--board",
+ type=str,
+ help="Board configuration json file",
+ default="board.json",
+)
+parser.add_argument(
+ "-f", "--file", type=str, help="Complete ec.bin file", default="ec.bin"
+)
+parser.add_argument("-s", "--serial", type=str, help="Serial number", default="")
+parser.add_argument("-l", "--list", action="store_true", help="List regions")
+parser.add_argument("-v", "--verbose", action="store_true", help="Chatty output")
+
def main():
- global debug
- args = parser.parse_args()
+ global debug
+ args = parser.parse_args()
+ brdfile = args.board
+ serial = args.serial
+ binfile = args.file
+ if args.verbose:
+ debug = True
- brdfile = args.board
- serial = args.serial
- binfile = args.file
- if args.verbose:
- debug = True
+ with open(brdfile) as data_file:
+ names = json.load(data_file)
- with open(brdfile) as data_file:
- names = json.load(data_file)
+ p = Supdate()
+ p.load_board(brdfile)
+ p.connect_usb(serialname=serial)
+ p.load_file(binfile)
- p = Supdate()
- p.load_board(brdfile)
- p.connect_usb(serialname=serial)
- p.load_file(binfile)
+ # List solely prints the config.
+ if args.list:
+ return
- # List solely prints the config.
- if (args.list):
- return
+ # Start transfer and erase.
+ p.start()
+ # Upload the bin file
+ log("Uploading %s" % binfile)
+ p.write_file()
- # Start transfer and erase.
- p.start()
- # Upload the bin file
- log("Uploading %s" % binfile)
- p.write_file()
+ # Finalize
+ log("Done. Finalizing.")
+ p.stop()
- # Finalize
- log("Done. Finalizing.")
- p.stop()
if __name__ == "__main__":
- main()
-
-
+ main()
diff --git a/extra/usb_updater/servo_updater.py b/extra/usb_updater/servo_updater.py
index 432ee120de..5402af70aa 100755
--- a/extra/usb_updater/servo_updater.py
+++ b/extra/usb_updater/servo_updater.py
@@ -14,40 +14,46 @@
from __future__ import print_function
import argparse
+import json
import os
import re
import subprocess
import time
-import json
-
-import fw_update
import ecusb.tiny_servo_common as c
+import fw_update
from ecusb import tiny_servod
+
class ServoUpdaterException(Exception):
- """Raised on exceptions generated by servo_updater."""
+ """Raised on exceptions generated by servo_updater."""
-BOARD_C2D2 = 'c2d2'
-BOARD_SERVO_MICRO = 'servo_micro'
-BOARD_SERVO_V4 = 'servo_v4'
-BOARD_SERVO_V4P1 = 'servo_v4p1'
-BOARD_SWEETBERRY = 'sweetberry'
+
+BOARD_C2D2 = "c2d2"
+BOARD_SERVO_MICRO = "servo_micro"
+BOARD_SERVO_V4 = "servo_v4"
+BOARD_SERVO_V4P1 = "servo_v4p1"
+BOARD_SWEETBERRY = "sweetberry"
DEFAULT_BOARD = BOARD_SERVO_V4
# These lists are to facilitate exposing choices in the command-line tool
# below.
-BOARDS = [BOARD_C2D2, BOARD_SERVO_MICRO, BOARD_SERVO_V4, BOARD_SERVO_V4P1,
- BOARD_SWEETBERRY]
+BOARDS = [
+ BOARD_C2D2,
+ BOARD_SERVO_MICRO,
+ BOARD_SERVO_V4,
+ BOARD_SERVO_V4P1,
+ BOARD_SWEETBERRY,
+]
# Servo firmware bundles four channels of firmware. We need to make sure the
# user does not request a non-existing channel, so keep the lists around to
# guard on command-line usage.
-DEFAULT_CHANNEL = STABLE_CHANNEL = 'stable'
+DEFAULT_CHANNEL = STABLE_CHANNEL = "stable"
-PREV_CHANNEL = 'prev'
+PREV_CHANNEL = "prev"
# The ordering here matters. From left to right it's the channel that the user
# is most likely to be running. This is used to inform and warn the user if
@@ -55,403 +61,436 @@ PREV_CHANNEL = 'prev'
# user know they are running the 'stable' version before letting them know they
# are running 'dev' or even 'alpah' which (while true) might cause confusion.
-CHANNELS = [DEFAULT_CHANNEL, PREV_CHANNEL, 'dev', 'alpha']
+CHANNELS = [DEFAULT_CHANNEL, PREV_CHANNEL, "dev", "alpha"]
-DEFAULT_BASE_PATH = '/usr/'
-TEST_IMAGE_BASE_PATH = '/usr/local/'
+DEFAULT_BASE_PATH = "/usr/"
+TEST_IMAGE_BASE_PATH = "/usr/local/"
-COMMON_PATH = 'share/servo_updater'
+COMMON_PATH = "share/servo_updater"
-FIRMWARE_DIR = 'firmware/'
-CONFIGS_DIR = 'configs/'
+FIRMWARE_DIR = "firmware/"
+CONFIGS_DIR = "configs/"
RETRIES_COUNT = 10
RETRIES_DELAY = 1
+
def do_with_retries(func, *args):
- """Try a function several times
-
- Call function passed as argument and check if no error happened.
- If exception was raised by function,
- it will be retried up to RETRIES_COUNT times.
-
- Args:
- func: function that will be called
- args: arguments passed to 'func'
-
- Returns:
- If call to function was successful, its result will be returned.
- If retries count was exceeded, exception will be raised.
- """
-
- retry = 0
- while retry < RETRIES_COUNT:
- try:
- return func(*args)
- except Exception as e:
- print('Retrying function %s: %s' % (func.__name__, e))
- retry = retry + 1
- time.sleep(RETRIES_DELAY)
- continue
-
- raise Exception("'{}' failed after {} retries".format(
- func.__name__, RETRIES_COUNT))
+ """Try a function several times
+
+ Call function passed as argument and check if no error happened.
+ If exception was raised by function,
+ it will be retried up to RETRIES_COUNT times.
+
+ Args:
+ func: function that will be called
+ args: arguments passed to 'func'
+
+ Returns:
+ If call to function was successful, its result will be returned.
+ If retries count was exceeded, exception will be raised.
+ """
+
+ retry = 0
+ while retry < RETRIES_COUNT:
+ try:
+ return func(*args)
+ except Exception as e:
+ print("Retrying function %s: %s" % (func.__name__, e))
+ retry = retry + 1
+ time.sleep(RETRIES_DELAY)
+ continue
+
+ raise Exception("'{}' failed after {} retries".format(func.__name__, RETRIES_COUNT))
+
def flash(brdfile, serialno, binfile):
- """Call fw_update to upload to updater USB endpoint.
+ """Call fw_update to upload to updater USB endpoint.
+
+ Args:
+ brdfile: path to board configuration file
+ serialno: device serial number
+ binfile: firmware file
+ """
- Args:
- brdfile: path to board configuration file
- serialno: device serial number
- binfile: firmware file
- """
+ p = fw_update.Supdate()
+ p.load_board(brdfile)
+ p.connect_usb(serialname=serialno)
+ p.load_file(binfile)
- p = fw_update.Supdate()
- p.load_board(brdfile)
- p.connect_usb(serialname=serialno)
- p.load_file(binfile)
+ # Start transfer and erase.
+ p.start()
+ # Upload the bin file
+ print("Uploading %s" % binfile)
+ p.write_file()
- # Start transfer and erase.
- p.start()
- # Upload the bin file
- print('Uploading %s' % binfile)
- p.write_file()
+ # Finalize
+ print("Done. Finalizing.")
+ p.stop()
- # Finalize
- print('Done. Finalizing.')
- p.stop()
def flash2(vidpid, serialno, binfile):
- """Call fw update via usb_updater2 commandline.
-
- Args:
- vidpid: vendor id and product id of device
- serialno: device serial number (optional)
- binfile: firmware file
- """
-
- tool = 'usb_updater2'
- cmd = '%s -d %s' % (tool, vidpid)
- if serialno:
- cmd += ' -S %s' % serialno
- cmd += ' -n'
- cmd += ' %s' % binfile
-
- print(cmd)
- help_cmd = '%s --help' % tool
- with open('/dev/null') as devnull:
- valid_check = subprocess.call(help_cmd.split(), stdout=devnull,
- stderr=devnull)
- if valid_check:
- raise ServoUpdaterException('%s exit with res = %d. Make sure the tool '
- 'is available on the device.' % (help_cmd,
- valid_check))
- res = subprocess.call(cmd.split())
-
- if res in (0, 1, 2):
- return res
- else:
- raise ServoUpdaterException('%s exit with res = %d' % (cmd, res))
+ """Call fw update via usb_updater2 commandline.
+
+ Args:
+ vidpid: vendor id and product id of device
+ serialno: device serial number (optional)
+ binfile: firmware file
+ """
+
+ tool = "usb_updater2"
+ cmd = "%s -d %s" % (tool, vidpid)
+ if serialno:
+ cmd += " -S %s" % serialno
+ cmd += " -n"
+ cmd += " %s" % binfile
+
+ print(cmd)
+ help_cmd = "%s --help" % tool
+ with open("/dev/null") as devnull:
+ valid_check = subprocess.call(help_cmd.split(), stdout=devnull, stderr=devnull)
+ if valid_check:
+ raise ServoUpdaterException(
+ "%s exit with res = %d. Make sure the tool "
+ "is available on the device." % (help_cmd, valid_check)
+ )
+ res = subprocess.call(cmd.split())
+
+ if res in (0, 1, 2):
+ return res
+ else:
+ raise ServoUpdaterException("%s exit with res = %d" % (cmd, res))
+
def select(tinys, region):
- """Jump to specified boot region
+ """Jump to specified boot region
- Ensure the servo is in the expected ro/rw region.
- This function jumps to the required region and verify if jump was
- successful by executing 'sysinfo' command and reading current region.
- If response was not received or region is invalid, exception is raised.
+ Ensure the servo is in the expected ro/rw region.
+ This function jumps to the required region and verify if jump was
+ successful by executing 'sysinfo' command and reading current region.
+ If response was not received or region is invalid, exception is raised.
- Args:
- tinys: TinyServod object
- region: region to jump to, only "rw" and "ro" is allowed
- """
+ Args:
+ tinys: TinyServod object
+ region: region to jump to, only "rw" and "ro" is allowed
+ """
- if region not in ['rw', 'ro']:
- raise Exception('Region must be ro or rw')
+ if region not in ["rw", "ro"]:
+ raise Exception("Region must be ro or rw")
- if region == 'ro':
- cmd = 'reboot'
- else:
- cmd = 'sysjump %s' % region
+ if region == "ro":
+ cmd = "reboot"
+ else:
+ cmd = "sysjump %s" % region
- tinys.pty._issue_cmd(cmd)
+ tinys.pty._issue_cmd(cmd)
- tinys.close()
- time.sleep(2)
- tinys.reinitialize()
+ tinys.close()
+ time.sleep(2)
+ tinys.reinitialize()
+
+ res = tinys.pty._issue_cmd_get_results("sysinfo", [r"Copy:[\s]+(RO|RW)"])
+ current_region = res[0][1].lower()
+ if current_region != region:
+ raise Exception("Invalid region: %s/%s" % (current_region, region))
- res = tinys.pty._issue_cmd_get_results('sysinfo', [r'Copy:[\s]+(RO|RW)'])
- current_region = res[0][1].lower()
- if current_region != region:
- raise Exception('Invalid region: %s/%s' % (current_region, region))
def do_version(tinys):
- """Check version via ec console 'pty'.
+ """Check version via ec console 'pty'.
+
+ Args:
+ tinys: TinyServod object
- Args:
- tinys: TinyServod object
+ Returns:
+ detected version number
- Returns:
- detected version number
+ Commands are:
+ # > version
+ # ...
+ # Build: tigertail_v1.1.6749-74d1a312e
+ """
+ cmd = "version"
+ regex = r"Build:\s+(\S+)[\r\n]+"
- Commands are:
- # > version
- # ...
- # Build: tigertail_v1.1.6749-74d1a312e
- """
- cmd = 'version'
- regex = r'Build:\s+(\S+)[\r\n]+'
+ results = tinys.pty._issue_cmd_get_results(cmd, [regex])[0]
- results = tinys.pty._issue_cmd_get_results(cmd, [regex])[0]
+ return results[1].strip(" \t\r\n\0")
- return results[1].strip(' \t\r\n\0')
def do_updater_version(tinys):
- """Check whether this uses python updater or c++ updater
-
- Args:
- tinys: TinyServod object
-
- Returns:
- updater version number. 2 or 6.
- """
- vers = do_version(tinys)
-
- # Servo versions below 58 are from servo-9040.B. Versions starting with _v2
- # are newer than anything _v1, no need to check the exact number. Updater
- # version is not directly queryable.
- if re.search(r'_v[2-9]\.\d', vers):
- return 6
- m = re.search(r'_v1\.1\.(\d\d\d\d)', vers)
- if m:
- version_number = int(m.group(1))
- if version_number < 5800:
- return 2
- else:
- return 6
- raise ServoUpdaterException(
- "Can't determine updater target from vers: [%s]" % vers)
+ """Check whether this uses python updater or c++ updater
-def _extract_version(boardname, binfile):
- """Find the version string from |binfile|.
+ Args:
+ tinys: TinyServod object
- Args:
- boardname: the name of the board, eg. "servo_micro"
- binfile: path to the binary to search
+ Returns:
+ updater version number. 2 or 6.
+ """
+ vers = do_version(tinys)
- Returns:
- the version string.
- """
- if boardname is None:
- # cannot extract the version if the name is None
- return None
- rawstrings = subprocess.check_output(
- ['cbfstool', binfile, 'read', '-r', 'RO_FRID', '-f', '/dev/stdout'],
- **c.get_subprocess_args())
- m = re.match(r'%s_v\S+' % boardname, rawstrings)
- if m:
- newvers = m.group(0).strip(' \t\r\n\0')
- else:
- raise ServoUpdaterException("Can't find version from file: %s." % binfile)
+ # Servo versions below 58 are from servo-9040.B. Versions starting with _v2
+ # are newer than anything _v1, no need to check the exact number. Updater
+ # version is not directly queryable.
+ if re.search(r"_v[2-9]\.\d", vers):
+ return 6
+ m = re.search(r"_v1\.1\.(\d\d\d\d)", vers)
+ if m:
+ version_number = int(m.group(1))
+ if version_number < 5800:
+ return 2
+ else:
+ return 6
+ raise ServoUpdaterException("Can't determine updater target from vers: [%s]" % vers)
+
+
+def _extract_version(boardname, binfile):
+ """Find the version string from |binfile|.
+
+ Args:
+ boardname: the name of the board, eg. "servo_micro"
+ binfile: path to the binary to search
+
+ Returns:
+ the version string.
+ """
+ if boardname is None:
+ # cannot extract the version if the name is None
+ return None
+ rawstrings = subprocess.check_output(
+ ["cbfstool", binfile, "read", "-r", "RO_FRID", "-f", "/dev/stdout"],
+ **c.get_subprocess_args()
+ )
+ m = re.match(r"%s_v\S+" % boardname, rawstrings)
+ if m:
+ newvers = m.group(0).strip(" \t\r\n\0")
+ else:
+ raise ServoUpdaterException("Can't find version from file: %s." % binfile)
+
+ return newvers
- return newvers
def get_firmware_channel(bname, version):
- """Find out which channel |version| for |bname| came from.
-
- Args:
- bname: board name
- version: current version string
-
- Returns:
- one of the channel names if |version| came from one of those, or None
- """
- for channel in CHANNELS:
- # Pass |bname| as cname to find the board specific file, and pass None as
- # fname to ensure the default directory is searched
- _, _, vers = get_files_and_version(bname, None, channel=channel)
- if version == vers:
- return channel
- # None of the channels matched. This firmware is currently unknown.
- return None
+ """Find out which channel |version| for |bname| came from.
+
+ Args:
+ bname: board name
+ version: current version string
+
+ Returns:
+ one of the channel names if |version| came from one of those, or None
+ """
+ for channel in CHANNELS:
+ # Pass |bname| as cname to find the board specific file, and pass None as
+ # fname to ensure the default directory is searched
+ _, _, vers = get_files_and_version(bname, None, channel=channel)
+ if version == vers:
+ return channel
+ # None of the channels matched. This firmware is currently unknown.
+ return None
+
def get_files_and_version(cname, fname=None, channel=DEFAULT_CHANNEL):
- """Select config and firmware binary files.
-
- This checks default file names and paths.
- In: /usr/share/servo_updater/[firmware|configs]
- check for board.json, board.bin
-
- Args:
- cname: board name, or config name. eg. "servo_v4" or "servo_v4.json"
- fname: firmware binary name. Can be None to try default.
- channel: the channel requested for servo firmware. See |CHANNELS| above.
-
- Returns:
- cname, fname, version: validated filenames selected from the path.
- """
- for p in (DEFAULT_BASE_PATH, TEST_IMAGE_BASE_PATH):
- updater_path = os.path.join(p, COMMON_PATH)
- if os.path.exists(updater_path):
- break
- else:
- raise ServoUpdaterException('servo_updater/ dir not found in known spots.')
-
- firmware_path = os.path.join(updater_path, FIRMWARE_DIR)
- configs_path = os.path.join(updater_path, CONFIGS_DIR)
-
- for p in (firmware_path, configs_path):
- if not os.path.exists(p):
- raise ServoUpdaterException('Could not find required path %r' % p)
-
- if not os.path.isfile(cname):
- # If not an existing file, try checking on the default path.
- newname = os.path.join(configs_path, cname)
- if os.path.isfile(newname):
- cname = newname
+ """Select config and firmware binary files.
+
+ This checks default file names and paths.
+ In: /usr/share/servo_updater/[firmware|configs]
+ check for board.json, board.bin
+
+ Args:
+ cname: board name, or config name. eg. "servo_v4" or "servo_v4.json"
+ fname: firmware binary name. Can be None to try default.
+ channel: the channel requested for servo firmware. See |CHANNELS| above.
+
+ Returns:
+ cname, fname, version: validated filenames selected from the path.
+ """
+ for p in (DEFAULT_BASE_PATH, TEST_IMAGE_BASE_PATH):
+ updater_path = os.path.join(p, COMMON_PATH)
+ if os.path.exists(updater_path):
+ break
else:
- # Try appending ".json" to convert board name to config file.
- cname = newname + '.json'
+ raise ServoUpdaterException("servo_updater/ dir not found in known spots.")
+
+ firmware_path = os.path.join(updater_path, FIRMWARE_DIR)
+ configs_path = os.path.join(updater_path, CONFIGS_DIR)
+
+ for p in (firmware_path, configs_path):
+ if not os.path.exists(p):
+ raise ServoUpdaterException("Could not find required path %r" % p)
+
if not os.path.isfile(cname):
- raise ServoUpdaterException("Can't find config file: %s." % cname)
-
- # Always retrieve the boardname
- with open(cname) as data_file:
- data = json.load(data_file)
- boardname = data['board']
-
- if not fname:
- # If no |fname| supplied, look for the default locations with the board
- # and channel requested.
- binary_file = '%s.%s.bin' % (boardname, channel)
- newname = os.path.join(firmware_path, binary_file)
- if os.path.isfile(newname):
- fname = newname
- else:
- raise ServoUpdaterException("Can't find firmware binary: %s." %
- binary_file)
- elif not os.path.isfile(fname):
- # If a name is specified but not found, try the default path.
- newname = os.path.join(firmware_path, fname)
- if os.path.isfile(newname):
- fname = newname
+ # If not an existing file, try checking on the default path.
+ newname = os.path.join(configs_path, cname)
+ if os.path.isfile(newname):
+ cname = newname
+ else:
+ # Try appending ".json" to convert board name to config file.
+ cname = newname + ".json"
+ if not os.path.isfile(cname):
+ raise ServoUpdaterException("Can't find config file: %s." % cname)
+
+ # Always retrieve the boardname
+ with open(cname) as data_file:
+ data = json.load(data_file)
+ boardname = data["board"]
+
+ if not fname:
+ # If no |fname| supplied, look for the default locations with the board
+ # and channel requested.
+ binary_file = "%s.%s.bin" % (boardname, channel)
+ newname = os.path.join(firmware_path, binary_file)
+ if os.path.isfile(newname):
+ fname = newname
+ else:
+ raise ServoUpdaterException("Can't find firmware binary: %s." % binary_file)
+ elif not os.path.isfile(fname):
+ # If a name is specified but not found, try the default path.
+ newname = os.path.join(firmware_path, fname)
+ if os.path.isfile(newname):
+ fname = newname
+ else:
+ raise ServoUpdaterException("Can't find file: %s." % fname)
+
+ # Lastly, retrieve the version as well for decision making, debug, and
+ # informational purposes.
+ binvers = _extract_version(boardname, fname)
+
+ return cname, fname, binvers
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Image a servo device")
+ parser.add_argument(
+ "-p",
+ "--print",
+ dest="print_only",
+ action="store_true",
+ default=False,
+ help="only print available firmware for board/channel",
+ )
+ parser.add_argument(
+ "-s", "--serialno", type=str, help="serial number to program", default=None
+ )
+ parser.add_argument(
+ "-b",
+ "--board",
+ type=str,
+ help="Board configuration json file",
+ default=DEFAULT_BOARD,
+ choices=BOARDS,
+ )
+ parser.add_argument(
+ "-c",
+ "--channel",
+ type=str,
+ help="Firmware channel to use",
+ default=DEFAULT_CHANNEL,
+ choices=CHANNELS,
+ )
+ parser.add_argument(
+ "-f", "--file", type=str, help="Complete ec.bin file", default=None
+ )
+ parser.add_argument(
+ "--force",
+ action="store_true",
+ help="Update even if version match",
+ default=False,
+ )
+ parser.add_argument("-v", "--verbose", action="store_true", help="Chatty output")
+ parser.add_argument(
+ "-r", "--reboot", action="store_true", help="Always reboot, even after probe."
+ )
+
+ args = parser.parse_args()
+
+ brdfile, binfile, newvers = get_files_and_version(
+ args.board, args.file, args.channel
+ )
+
+ # If the user only cares about the information then just print it here,
+ # and exit.
+ if args.print_only:
+ output = ("board: %s\n" "channel: %s\n" "firmware: %s") % (
+ args.board,
+ args.channel,
+ newvers,
+ )
+ print(output)
+ return
+
+ serialno = args.serialno
+
+ with open(brdfile) as data_file:
+ data = json.load(data_file)
+ vid, pid = int(data["vid"], 0), int(data["pid"], 0)
+ vidpid = "%04x:%04x" % (vid, pid)
+ iface = int(data["console"], 0)
+ boardname = data["board"]
+
+ # Make sure device is up.
+ print("===== Waiting for USB device =====")
+ c.wait_for_usb(vidpid, serialname=serialno)
+ # We need a tiny_servod to query some information. Set it up first.
+ tinys = tiny_servod.TinyServod(vid, pid, iface, serialno, args.verbose)
+
+ if not args.force:
+ vers = do_version(tinys)
+ print("Current %s version is %s" % (boardname, vers))
+ print("Available %s version is %s" % (boardname, newvers))
+
+ if newvers == vers:
+ print("No version update needed")
+ if args.reboot:
+ select(tinys, "ro")
+ return
+ else:
+ print("Updating to recommended version.")
+
+ # Make sure the servo MCU is in RO
+ print("===== Jumping to RO =====")
+ do_with_retries(select, tinys, "ro")
+
+ print("===== Flashing RW =====")
+ vers = do_with_retries(do_updater_version, tinys)
+ # To make sure that the tiny_servod here does not interfere with other
+ # processes, close it out.
+ tinys.close()
+
+ if vers == 2:
+ flash(brdfile, serialno, binfile)
+ elif vers == 6:
+ do_with_retries(flash2, vidpid, serialno, binfile)
else:
- raise ServoUpdaterException("Can't find file: %s." % fname)
+ raise ServoUpdaterException("Can't detect updater version")
- # Lastly, retrieve the version as well for decision making, debug, and
- # informational purposes.
- binvers = _extract_version(boardname, fname)
+ # Make sure device is up.
+ c.wait_for_usb(vidpid, serialname=serialno)
+ # After we have made sure that it's back/available, reconnect the tiny servod.
+ tinys.reinitialize()
- return cname, fname, binvers
+ # Make sure the servo MCU is in RW
+ print("===== Jumping to RW =====")
+ do_with_retries(select, tinys, "rw")
-def main():
- parser = argparse.ArgumentParser(description='Image a servo device')
- parser.add_argument('-p', '--print', dest='print_only', action='store_true',
- default=False,
- help='only print available firmware for board/channel')
- parser.add_argument('-s', '--serialno', type=str,
- help='serial number to program', default=None)
- parser.add_argument('-b', '--board', type=str,
- help='Board configuration json file',
- default=DEFAULT_BOARD, choices=BOARDS)
- parser.add_argument('-c', '--channel', type=str,
- help='Firmware channel to use',
- default=DEFAULT_CHANNEL, choices=CHANNELS)
- parser.add_argument('-f', '--file', type=str,
- help='Complete ec.bin file', default=None)
- parser.add_argument('--force', action='store_true',
- help='Update even if version match', default=False)
- parser.add_argument('-v', '--verbose', action='store_true',
- help='Chatty output')
- parser.add_argument('-r', '--reboot', action='store_true',
- help='Always reboot, even after probe.')
-
- args = parser.parse_args()
-
- brdfile, binfile, newvers = get_files_and_version(args.board, args.file,
- args.channel)
-
- # If the user only cares about the information then just print it here,
- # and exit.
- if args.print_only:
- output = ('board: %s\n'
- 'channel: %s\n'
- 'firmware: %s') % (args.board, args.channel, newvers)
- print(output)
- return
-
- serialno = args.serialno
-
- with open(brdfile) as data_file:
- data = json.load(data_file)
- vid, pid = int(data['vid'], 0), int(data['pid'], 0)
- vidpid = '%04x:%04x' % (vid, pid)
- iface = int(data['console'], 0)
- boardname = data['board']
-
- # Make sure device is up.
- print('===== Waiting for USB device =====')
- c.wait_for_usb(vidpid, serialname=serialno)
- # We need a tiny_servod to query some information. Set it up first.
- tinys = tiny_servod.TinyServod(vid, pid, iface, serialno, args.verbose)
-
- if not args.force:
- vers = do_version(tinys)
- print('Current %s version is %s' % (boardname, vers))
- print('Available %s version is %s' % (boardname, newvers))
-
- if newvers == vers:
- print('No version update needed')
- if args.reboot:
- select(tinys, 'ro')
- return
+ print("===== Flashing RO =====")
+ vers = do_with_retries(do_updater_version, tinys)
+
+ if vers == 2:
+ flash(brdfile, serialno, binfile)
+ elif vers == 6:
+ do_with_retries(flash2, vidpid, serialno, binfile)
else:
- print('Updating to recommended version.')
-
- # Make sure the servo MCU is in RO
- print('===== Jumping to RO =====')
- do_with_retries(select, tinys, 'ro')
-
- print('===== Flashing RW =====')
- vers = do_with_retries(do_updater_version, tinys)
- # To make sure that the tiny_servod here does not interfere with other
- # processes, close it out.
- tinys.close()
-
- if vers == 2:
- flash(brdfile, serialno, binfile)
- elif vers == 6:
- do_with_retries(flash2, vidpid, serialno, binfile)
- else:
- raise ServoUpdaterException("Can't detect updater version")
-
- # Make sure device is up.
- c.wait_for_usb(vidpid, serialname=serialno)
- # After we have made sure that it's back/available, reconnect the tiny servod.
- tinys.reinitialize()
-
- # Make sure the servo MCU is in RW
- print('===== Jumping to RW =====')
- do_with_retries(select, tinys, 'rw')
-
- print('===== Flashing RO =====')
- vers = do_with_retries(do_updater_version, tinys)
-
- if vers == 2:
- flash(brdfile, serialno, binfile)
- elif vers == 6:
- do_with_retries(flash2, vidpid, serialno, binfile)
- else:
- raise ServoUpdaterException("Can't detect updater version")
-
- # Make sure the servo MCU is in RO
- print('===== Rebooting =====')
- do_with_retries(select, tinys, 'ro')
- # Perform additional reboot to free USB/UART resources, taken by tiny servod.
- # See https://issuetracker.google.com/196021317 for background.
- tinys.pty._issue_cmd('reboot')
-
- print('===== Finished =====')
-
-if __name__ == '__main__':
- main()
+ raise ServoUpdaterException("Can't detect updater version")
+
+ # Make sure the servo MCU is in RO
+ print("===== Rebooting =====")
+ do_with_retries(select, tinys, "ro")
+ # Perform additional reboot to free USB/UART resources, taken by tiny servod.
+ # See https://issuetracker.google.com/196021317 for background.
+ tinys.pty._issue_cmd("reboot")
+
+ print("===== Finished =====")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/firmware_builder.py b/firmware_builder.py
index 78b7614190..d2548c3f5f 100755
--- a/firmware_builder.py
+++ b/firmware_builder.py
@@ -16,21 +16,20 @@ import pathlib
import subprocess
import sys
-# pylint: disable=import-error
-from google.protobuf import json_format
-
from chromite.api.gen_sdk.chromite.api import firmware_pb2
+# pylint: disable=import-error
+from google.protobuf import json_format
-DEFAULT_BUNDLE_DIRECTORY = '/tmp/artifact_bundles'
-DEFAULT_BUNDLE_METADATA_FILE = '/tmp/artifact_bundle_metadata'
+DEFAULT_BUNDLE_DIRECTORY = "/tmp/artifact_bundles"
+DEFAULT_BUNDLE_METADATA_FILE = "/tmp/artifact_bundle_metadata"
# The the list of boards whose on-device unit tests we will verify compilation.
# TODO(b/172501728) On-device unit tests should build for all boards, but
# they've bit rotted, so we only build the ones that compile.
BOARDS_UNIT_TEST = [
- 'bloonchipper',
- 'dartmonkey',
+ "bloonchipper",
+ "dartmonkey",
]
@@ -51,59 +50,57 @@ def build(opts):
"When --code-coverage is selected, 'build' is a no-op. "
"Run 'test' with --code-coverage instead."
)
- with open(opts.metrics, 'w') as f:
+ with open(opts.metrics, "w") as f:
f.write(json_format.MessageToJson(metric_list))
return
ec_dir = pathlib.Path(__file__).parent
subprocess.run([ec_dir / "util" / "check_clang_format.py"], check=True)
- cmd = ['make', 'buildall_only', f'-j{opts.cpus}']
+ cmd = ["make", "buildall_only", f"-j{opts.cpus}"]
print(f"# Running {' '.join(cmd)}.")
subprocess.run(cmd, cwd=os.path.dirname(__file__), check=True)
ec_dir = os.path.dirname(__file__)
- build_dir = os.path.join(ec_dir, 'build')
+ build_dir = os.path.join(ec_dir, "build")
for build_target in sorted(os.listdir(build_dir)):
metric = metric_list.value.add()
metric.target_name = build_target
- metric.platform_name = 'ec'
- for variant in ['RO', 'RW']:
+ metric.platform_name = "ec"
+ for variant in ["RO", "RW"]:
memsize_file = (
pathlib.Path(build_dir)
/ build_target
/ variant
- / f'ec.{variant}.elf.memsize.txt'
+ / f"ec.{variant}.elf.memsize.txt"
)
if memsize_file.exists():
parse_memsize(memsize_file, metric, variant)
- with open(opts.metrics, 'w') as f:
+ with open(opts.metrics, "w") as f:
f.write(json_format.MessageToJson(metric_list))
# Ensure that there are no regressions for boards that build successfully
# with clang: b/172020503.
- cmd = ['./util/build_with_clang.py']
+ cmd = ["./util/build_with_clang.py"]
print(f'# Running {" ".join(cmd)}.')
- subprocess.run(cmd,
- cwd=os.path.dirname(__file__),
- check=True)
+ subprocess.run(cmd, cwd=os.path.dirname(__file__), check=True)
UNITS = {
- 'B': 1,
- 'KB': 1024,
- 'MB': 1024 * 1024,
- 'GB': 1024 * 1024 * 1024,
+ "B": 1,
+ "KB": 1024,
+ "MB": 1024 * 1024,
+ "GB": 1024 * 1024 * 1024,
}
def parse_memsize(filename, metric, variant):
- with open(filename, 'r') as infile:
+ with open(filename, "r") as infile:
# Skip header line
infile.readline()
for line in infile.readlines():
parts = line.split()
fw_section = metric.fw_section.add()
- fw_section.region = variant + '_' + parts[0][:-1]
+ fw_section.region = variant + "_" + parts[0][:-1]
fw_section.used = int(parts[1]) * UNITS[parts[2]]
fw_section.total = int(parts[3]) * UNITS[parts[4]]
fw_section.track_on_gerrit = False
@@ -135,7 +132,7 @@ def write_metadata(opts, info):
bundle_metadata_file = (
opts.metadata if opts.metadata else DEFAULT_BUNDLE_METADATA_FILE
)
- with open(bundle_metadata_file, 'w') as f:
+ with open(bundle_metadata_file, "w") as f:
f.write(json_format.MessageToJson(info))
@@ -145,10 +142,10 @@ def bundle_coverage(opts):
info.bcs_version_info.version_string = opts.bcs_version
bundle_dir = get_bundle_dir(opts)
ec_dir = os.path.dirname(__file__)
- tarball_name = 'coverage.tbz2'
+ tarball_name = "coverage.tbz2"
tarball_path = os.path.join(bundle_dir, tarball_name)
- cmd = ['tar', 'cvfj', tarball_path, 'lcov.info']
- subprocess.run(cmd, cwd=os.path.join(ec_dir, 'build/coverage'), check=True)
+ cmd = ["tar", "cvfj", tarball_path, "lcov.info"]
+ subprocess.run(cmd, cwd=os.path.join(ec_dir, "build/coverage"), check=True)
meta = info.objects.add()
meta.file_name = tarball_name
meta.lcov_info.type = (
@@ -164,16 +161,20 @@ def bundle_firmware(opts):
info.bcs_version_info.version_string = opts.bcs_version
bundle_dir = get_bundle_dir(opts)
ec_dir = os.path.dirname(__file__)
- for build_target in sorted(os.listdir(os.path.join(ec_dir, 'build'))):
- tarball_name = ''.join([build_target, '.firmware.tbz2'])
+ for build_target in sorted(os.listdir(os.path.join(ec_dir, "build"))):
+ tarball_name = "".join([build_target, ".firmware.tbz2"])
tarball_path = os.path.join(bundle_dir, tarball_name)
cmd = [
- 'tar', 'cvfj', tarball_path,
- '--exclude=*.o.d', '--exclude=*.o', '.',
+ "tar",
+ "cvfj",
+ tarball_path,
+ "--exclude=*.o.d",
+ "--exclude=*.o",
+ ".",
]
subprocess.run(
cmd,
- cwd=os.path.join(ec_dir, 'build', build_target),
+ cwd=os.path.join(ec_dir, "build", build_target),
check=True,
)
meta = info.objects.add()
@@ -191,7 +192,7 @@ def test(opts):
"""Runs all of the unit tests for EC firmware"""
# TODO(b/169178847): Add appropriate metric information
metrics = firmware_pb2.FwTestMetricList()
- with open(opts.metrics, 'w') as f:
+ with open(opts.metrics, "w") as f:
f.write(json_format.MessageToJson(metrics))
# If building for code coverage, build the 'coverage' target, which
@@ -200,8 +201,8 @@ def test(opts):
#
# Otherwise, build the 'runtests' target, which verifies all
# posix-based unit tests build and pass.
- target = 'coverage' if opts.code_coverage else 'runtests'
- cmd = ['make', target, f'-j{opts.cpus}']
+ target = "coverage" if opts.code_coverage else "runtests"
+ cmd = ["make", target, f"-j{opts.cpus}"]
print(f"# Running {' '.join(cmd)}.")
subprocess.run(cmd, cwd=os.path.dirname(__file__), check=True)
@@ -209,13 +210,13 @@ def test(opts):
# Verify compilation of the on-device unit test binaries.
# TODO(b/172501728) These should build for all boards, but they've bit
# rotted, so we only build the ones that compile.
- cmd = ['make', f'-j{opts.cpus}']
- cmd.extend(['tests-' + b for b in BOARDS_UNIT_TEST])
+ cmd = ["make", f"-j{opts.cpus}"]
+ cmd.extend(["tests-" + b for b in BOARDS_UNIT_TEST])
print(f"# Running {' '.join(cmd)}.")
subprocess.run(cmd, cwd=os.path.dirname(__file__), check=True)
# Verify the tests pass with ASan also
- cmd = ['make', 'TEST_ASAN=y', target, f'-j{opts.cpus}']
+ cmd = ["make", "TEST_ASAN=y", target, f"-j{opts.cpus}"]
print(f"# Running {' '.join(cmd)}.")
subprocess.run(cmd, cwd=os.path.dirname(__file__), check=True)
@@ -227,8 +228,8 @@ def main(args):
"""
opts = parse_args(args)
- if not hasattr(opts, 'func'):
- print('Must select a valid sub command!')
+ if not hasattr(opts, "func"):
+ print("Must select a valid sub command!")
return -1
# Run selected sub command function
@@ -244,66 +245,64 @@ def parse_args(args):
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
- '--cpus',
+ "--cpus",
default=multiprocessing.cpu_count(),
- help='The number of cores to use.',
+ help="The number of cores to use.",
)
parser.add_argument(
- '--metrics',
- dest='metrics',
+ "--metrics",
+ dest="metrics",
required=True,
- help='File to write the json-encoded MetricsList proto message.',
+ help="File to write the json-encoded MetricsList proto message.",
)
parser.add_argument(
- '--metadata',
+ "--metadata",
required=False,
- help='Full pathname for the file in which to write build artifact '
- 'metadata.',
+ help="Full pathname for the file in which to write build artifact " "metadata.",
)
parser.add_argument(
- '--output-dir',
+ "--output-dir",
required=False,
- help='Full pathanme for the directory in which to bundle build '
- 'artifacts.',
+ help="Full pathanme for the directory in which to bundle build " "artifacts.",
)
parser.add_argument(
- '--code-coverage',
+ "--code-coverage",
required=False,
- action='store_true',
- help='Build host-based unit tests for code coverage.',
+ action="store_true",
+ help="Build host-based unit tests for code coverage.",
)
parser.add_argument(
- '--bcs-version',
- dest='bcs_version',
- default='',
+ "--bcs-version",
+ dest="bcs_version",
+ default="",
required=False,
# TODO(b/180008931): make this required=True.
- help='BCS version to include in metadata.',
+ help="BCS version to include in metadata.",
)
# Would make this required=True, but not available until 3.7
sub_cmds = parser.add_subparsers()
- build_cmd = sub_cmds.add_parser('build', help='Builds all firmware targets')
+ build_cmd = sub_cmds.add_parser("build", help="Builds all firmware targets")
build_cmd.set_defaults(func=build)
build_cmd = sub_cmds.add_parser(
- 'bundle',
- help='Creates a tarball containing build '
- 'artifacts from all firmware targets',
+ "bundle",
+ help="Creates a tarball containing build "
+ "artifacts from all firmware targets",
)
build_cmd.set_defaults(func=bundle)
- test_cmd = sub_cmds.add_parser('test', help='Runs all firmware unit tests')
+ test_cmd = sub_cmds.add_parser("test", help="Runs all firmware unit tests")
test_cmd.set_defaults(func=test)
return parser.parse_args(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
diff --git a/setup.py b/setup.py
index fc6c5d396b..2274d95c2f 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
author="Aseda Aboagye",
author_email="aaboagye@chromium.org",
url="https://www.chromium.org/chromium-os/ec-development",
- package_dir={"" : "util"},
+ package_dir={"": "util"},
packages=["ec3po"],
py_modules=["ec3po.console", "ec3po.interpreter"],
description="EC console interpreter.",
@@ -22,7 +22,7 @@ setup(
author="Nick Sanders",
author_email="nsanders@chromium.org",
url="https://www.chromium.org/chromium-os/ec-development",
- package_dir={"" : "extra/tigertool"},
+ package_dir={"": "extra/tigertool"},
packages=["ecusb"],
description="Tiny implementation of servod.",
)
@@ -33,17 +33,23 @@ setup(
author="Nick Sanders",
author_email="nsanders@chromium.org",
url="https://www.chromium.org/chromium-os/ec-development",
- package_dir={"" : "extra/usb_updater"},
+ package_dir={"": "extra/usb_updater"},
py_modules=["servo_updater", "fw_update"],
- entry_points = {
+ entry_points={
"console_scripts": ["servo_updater=servo_updater:main"],
},
- data_files=[("share/servo_updater/configs",
- ["extra/usb_updater/c2d2.json",
- "extra/usb_updater/servo_v4.json",
- "extra/usb_updater/servo_v4p1.json",
- "extra/usb_updater/servo_micro.json",
- "extra/usb_updater/sweetberry.json"])],
+ data_files=[
+ (
+ "share/servo_updater/configs",
+ [
+ "extra/usb_updater/c2d2.json",
+ "extra/usb_updater/servo_v4.json",
+ "extra/usb_updater/servo_v4p1.json",
+ "extra/usb_updater/servo_micro.json",
+ "extra/usb_updater/sweetberry.json",
+ ],
+ )
+ ],
description="Servo usb updater.",
)
@@ -53,9 +59,9 @@ setup(
author="Nick Sanders",
author_email="nsanders@chromium.org",
url="https://www.chromium.org/chromium-os/ec-development",
- package_dir={"" : "extra/usb_power"},
+ package_dir={"": "extra/usb_power"},
py_modules=["powerlog", "stats_manager"],
- entry_points = {
+ entry_points={
"console_scripts": ["powerlog=powerlog:main"],
},
description="Sweetberry power logger.",
@@ -67,9 +73,9 @@ setup(
author="Nick Sanders",
author_email="nsanders@chromium.org",
url="https://www.chromium.org/chromium-os/ec-development",
- package_dir={"" : "extra/usb_serial"},
+ package_dir={"": "extra/usb_serial"},
py_modules=["console"],
- entry_points = {
+ entry_points={
"console_scripts": ["usb_console=console:main"],
},
description="Tool to open the usb console on servo, cr50.",
@@ -81,11 +87,10 @@ setup(
author="Wei-Han Chen",
author_email="stimim@chromium.org",
url="https://www.chromium.org/chromium-os/ec-development",
- package_dir={"" : "util"},
+ package_dir={"": "util"},
py_modules=["unpack_ftb"],
- entry_points = {
+ entry_points={
"console_scripts": ["unpack_ftb=unpack_ftb:main"],
},
description="Tool to convert ST touchpad .ftb file to .bin",
)
-
diff --git a/test/run_device_tests.py b/test/run_device_tests.py
index 8de2fa417c..09f255f43e 100755
--- a/test/run_device_tests.py
+++ b/test/run_device_tests.py
@@ -51,63 +51,74 @@ import time
from concurrent.futures.thread import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
-from typing import Optional, BinaryIO, List
+from typing import BinaryIO, List, Optional
# pylint: disable=import-error
import colorama # type: ignore[import]
-from contextlib2 import ExitStack
import fmap
+from contextlib2 import ExitStack
+
# pylint: enable=import-error
EC_DIR = Path(os.path.dirname(os.path.realpath(__file__))).parent
-JTRACE_FLASH_SCRIPT = os.path.join(EC_DIR, 'util/flash_jlink.py')
-SERVO_MICRO_FLASH_SCRIPT = os.path.join(EC_DIR, 'util/flash_ec')
+JTRACE_FLASH_SCRIPT = os.path.join(EC_DIR, "util/flash_jlink.py")
+SERVO_MICRO_FLASH_SCRIPT = os.path.join(EC_DIR, "util/flash_ec")
-ALL_TESTS_PASSED_REGEX = re.compile(r'Pass!\r\n')
-ALL_TESTS_FAILED_REGEX = re.compile(r'Fail! \(\d+ tests\)\r\n')
+ALL_TESTS_PASSED_REGEX = re.compile(r"Pass!\r\n")
+ALL_TESTS_FAILED_REGEX = re.compile(r"Fail! \(\d+ tests\)\r\n")
-SINGLE_CHECK_PASSED_REGEX = re.compile(r'Pass: .*')
-SINGLE_CHECK_FAILED_REGEX = re.compile(r'.*failed:.*')
+SINGLE_CHECK_PASSED_REGEX = re.compile(r"Pass: .*")
+SINGLE_CHECK_FAILED_REGEX = re.compile(r".*failed:.*")
-ASSERTION_FAILURE_REGEX = re.compile(r'ASSERTION FAILURE.*')
+ASSERTION_FAILURE_REGEX = re.compile(r"ASSERTION FAILURE.*")
DATA_ACCESS_VIOLATION_8020000_REGEX = re.compile(
- r'Data access violation, mfar = 8020000\r\n')
+ r"Data access violation, mfar = 8020000\r\n"
+)
DATA_ACCESS_VIOLATION_8040000_REGEX = re.compile(
- r'Data access violation, mfar = 8040000\r\n')
+ r"Data access violation, mfar = 8040000\r\n"
+)
DATA_ACCESS_VIOLATION_80C0000_REGEX = re.compile(
- r'Data access violation, mfar = 80c0000\r\n')
+ r"Data access violation, mfar = 80c0000\r\n"
+)
DATA_ACCESS_VIOLATION_80E0000_REGEX = re.compile(
- r'Data access violation, mfar = 80e0000\r\n')
+ r"Data access violation, mfar = 80e0000\r\n"
+)
DATA_ACCESS_VIOLATION_20000000_REGEX = re.compile(
- r'Data access violation, mfar = 20000000\r\n')
+ r"Data access violation, mfar = 20000000\r\n"
+)
DATA_ACCESS_VIOLATION_24000000_REGEX = re.compile(
- r'Data access violation, mfar = 24000000\r\n')
+ r"Data access violation, mfar = 24000000\r\n"
+)
-BLOONCHIPPER = 'bloonchipper'
-DARTMONKEY = 'dartmonkey'
+BLOONCHIPPER = "bloonchipper"
+DARTMONKEY = "dartmonkey"
-JTRACE = 'jtrace'
-SERVO_MICRO = 'servo_micro'
+JTRACE = "jtrace"
+SERVO_MICRO = "servo_micro"
-GCC = 'gcc'
-CLANG = 'clang'
+GCC = "gcc"
+CLANG = "clang"
-TEST_ASSETS_BUCKET = 'gs://chromiumos-test-assets-public/fpmcu/RO'
+TEST_ASSETS_BUCKET = "gs://chromiumos-test-assets-public/fpmcu/RO"
DARTMONKEY_IMAGE_PATH = os.path.join(
- TEST_ASSETS_BUCKET, 'dartmonkey_v2.0.2887-311310808.bin')
+ TEST_ASSETS_BUCKET, "dartmonkey_v2.0.2887-311310808.bin"
+)
NOCTURNE_FP_IMAGE_PATH = os.path.join(
- TEST_ASSETS_BUCKET, 'nocturne_fp_v2.2.64-58cf5974e.bin')
-NAMI_FP_IMAGE_PATH = os.path.join(
- TEST_ASSETS_BUCKET, 'nami_fp_v2.2.144-7a08e07eb.bin')
+ TEST_ASSETS_BUCKET, "nocturne_fp_v2.2.64-58cf5974e.bin"
+)
+NAMI_FP_IMAGE_PATH = os.path.join(TEST_ASSETS_BUCKET, "nami_fp_v2.2.144-7a08e07eb.bin")
BLOONCHIPPER_V4277_IMAGE_PATH = os.path.join(
- TEST_ASSETS_BUCKET, 'bloonchipper_v2.0.4277-9f652bb3.bin')
+ TEST_ASSETS_BUCKET, "bloonchipper_v2.0.4277-9f652bb3.bin"
+)
BLOONCHIPPER_V5938_IMAGE_PATH = os.path.join(
- TEST_ASSETS_BUCKET, 'bloonchipper_v2.0.5938-197506c1.bin')
+ TEST_ASSETS_BUCKET, "bloonchipper_v2.0.5938-197506c1.bin"
+)
class ImageType(Enum):
"""EC Image type to use for the test."""
+
RO = 1
RW = 2
@@ -115,9 +126,16 @@ class ImageType(Enum):
class BoardConfig:
"""Board-specific configuration."""
- def __init__(self, name, servo_uart_name, servo_power_enable,
- rollback_region0_regex, rollback_region1_regex, mpu_regex,
- variants):
+ def __init__(
+ self,
+ name,
+ servo_uart_name,
+ servo_power_enable,
+ rollback_region0_regex,
+ rollback_region1_regex,
+ mpu_regex,
+ variants,
+ ):
self.name = name
self.servo_uart_name = servo_uart_name
self.servo_power_enable = servo_power_enable
@@ -130,18 +148,31 @@ class BoardConfig:
class TestConfig:
"""Configuration for a given test."""
- def __init__(self, test_name, image_to_use=ImageType.RW,
- finish_regexes=None, fail_regexes=None, toggle_power=False,
- test_args=None, num_flash_attempts=2, timeout_secs=10,
- enable_hw_write_protect=False, ro_image=None, build_board=None,
- config_name=None):
+ def __init__(
+ self,
+ test_name,
+ image_to_use=ImageType.RW,
+ finish_regexes=None,
+ fail_regexes=None,
+ toggle_power=False,
+ test_args=None,
+ num_flash_attempts=2,
+ timeout_secs=10,
+ enable_hw_write_protect=False,
+ ro_image=None,
+ build_board=None,
+ config_name=None,
+ ):
if test_args is None:
test_args = []
if finish_regexes is None:
finish_regexes = [ALL_TESTS_PASSED_REGEX, ALL_TESTS_FAILED_REGEX]
if fail_regexes is None:
- fail_regexes = [SINGLE_CHECK_FAILED_REGEX, ALL_TESTS_FAILED_REGEX,
- ASSERTION_FAILURE_REGEX]
+ fail_regexes = [
+ SINGLE_CHECK_FAILED_REGEX,
+ ALL_TESTS_FAILED_REGEX,
+ ASSERTION_FAILURE_REGEX,
+ ]
if config_name is None:
config_name = test_name
@@ -177,68 +208,104 @@ class AllTests:
@staticmethod
def get_public_tests(board_config: BoardConfig) -> List[TestConfig]:
tests = [
- TestConfig(test_name='aes'),
- TestConfig(test_name='cec'),
- TestConfig(test_name='cortexm_fpu'),
- TestConfig(test_name='crc'),
- TestConfig(test_name='flash_physical', image_to_use=ImageType.RO,
- toggle_power=True),
- TestConfig(test_name='flash_write_protect',
- image_to_use=ImageType.RO,
- toggle_power=True, enable_hw_write_protect=True),
- TestConfig(test_name='fpsensor_hw'),
- TestConfig(config_name='fpsensor_spi_ro', test_name='fpsensor',
- image_to_use=ImageType.RO, test_args=['spi']),
- TestConfig(config_name='fpsensor_spi_rw', test_name='fpsensor',
- test_args=['spi']),
- TestConfig(config_name='fpsensor_uart_ro', test_name='fpsensor',
- image_to_use=ImageType.RO, test_args=['uart']),
- TestConfig(config_name='fpsensor_uart_rw', test_name='fpsensor',
- test_args=['uart']),
- TestConfig(config_name='mpu_ro', test_name='mpu',
- image_to_use=ImageType.RO,
- finish_regexes=[board_config.mpu_regex]),
- TestConfig(config_name='mpu_rw', test_name='mpu',
- finish_regexes=[board_config.mpu_regex]),
- TestConfig(test_name='mutex'),
- TestConfig(test_name='pingpong'),
- TestConfig(test_name='printf'),
- TestConfig(test_name='queue'),
- TestConfig(config_name='rollback_region0', test_name='rollback',
- finish_regexes=[board_config.rollback_region0_regex],
- test_args=['region0']),
- TestConfig(config_name='rollback_region1', test_name='rollback',
- finish_regexes=[board_config.rollback_region1_regex],
- test_args=['region1']),
- TestConfig(test_name='rollback_entropy', image_to_use=ImageType.RO),
- TestConfig(test_name='rtc'),
- TestConfig(test_name='sha256'),
- TestConfig(test_name='sha256_unrolled'),
- TestConfig(test_name='static_if'),
- TestConfig(test_name='stdlib'),
- TestConfig(config_name='system_is_locked_wp_on',
- test_name='system_is_locked', test_args=['wp_on'],
- toggle_power=True, enable_hw_write_protect=True),
- TestConfig(config_name='system_is_locked_wp_off',
- test_name='system_is_locked', test_args=['wp_off'],
- toggle_power=True, enable_hw_write_protect=False),
- TestConfig(test_name='timer_dos'),
- TestConfig(test_name='utils', timeout_secs=20),
- TestConfig(test_name='utils_str'),
+ TestConfig(test_name="aes"),
+ TestConfig(test_name="cec"),
+ TestConfig(test_name="cortexm_fpu"),
+ TestConfig(test_name="crc"),
+ TestConfig(
+ test_name="flash_physical", image_to_use=ImageType.RO, toggle_power=True
+ ),
+ TestConfig(
+ test_name="flash_write_protect",
+ image_to_use=ImageType.RO,
+ toggle_power=True,
+ enable_hw_write_protect=True,
+ ),
+ TestConfig(test_name="fpsensor_hw"),
+ TestConfig(
+ config_name="fpsensor_spi_ro",
+ test_name="fpsensor",
+ image_to_use=ImageType.RO,
+ test_args=["spi"],
+ ),
+ TestConfig(
+ config_name="fpsensor_spi_rw", test_name="fpsensor", test_args=["spi"]
+ ),
+ TestConfig(
+ config_name="fpsensor_uart_ro",
+ test_name="fpsensor",
+ image_to_use=ImageType.RO,
+ test_args=["uart"],
+ ),
+ TestConfig(
+ config_name="fpsensor_uart_rw", test_name="fpsensor", test_args=["uart"]
+ ),
+ TestConfig(
+ config_name="mpu_ro",
+ test_name="mpu",
+ image_to_use=ImageType.RO,
+ finish_regexes=[board_config.mpu_regex],
+ ),
+ TestConfig(
+ config_name="mpu_rw",
+ test_name="mpu",
+ finish_regexes=[board_config.mpu_regex],
+ ),
+ TestConfig(test_name="mutex"),
+ TestConfig(test_name="pingpong"),
+ TestConfig(test_name="printf"),
+ TestConfig(test_name="queue"),
+ TestConfig(
+ config_name="rollback_region0",
+ test_name="rollback",
+ finish_regexes=[board_config.rollback_region0_regex],
+ test_args=["region0"],
+ ),
+ TestConfig(
+ config_name="rollback_region1",
+ test_name="rollback",
+ finish_regexes=[board_config.rollback_region1_regex],
+ test_args=["region1"],
+ ),
+ TestConfig(test_name="rollback_entropy", image_to_use=ImageType.RO),
+ TestConfig(test_name="rtc"),
+ TestConfig(test_name="sha256"),
+ TestConfig(test_name="sha256_unrolled"),
+ TestConfig(test_name="static_if"),
+ TestConfig(test_name="stdlib"),
+ TestConfig(
+ config_name="system_is_locked_wp_on",
+ test_name="system_is_locked",
+ test_args=["wp_on"],
+ toggle_power=True,
+ enable_hw_write_protect=True,
+ ),
+ TestConfig(
+ config_name="system_is_locked_wp_off",
+ test_name="system_is_locked",
+ test_args=["wp_off"],
+ toggle_power=True,
+ enable_hw_write_protect=False,
+ ),
+ TestConfig(test_name="timer_dos"),
+ TestConfig(test_name="utils", timeout_secs=20),
+ TestConfig(test_name="utils_str"),
]
if board_config.name == BLOONCHIPPER:
- tests.append(TestConfig(test_name='stm32f_rtc'))
+ tests.append(TestConfig(test_name="stm32f_rtc"))
# Run panic data tests for all boards and RO versions.
for variant_name, variant_info in board_config.variants.items():
tests.append(
- TestConfig(config_name='panic_data_' + variant_name,
- test_name='panic_data',
- fail_regexes=[SINGLE_CHECK_FAILED_REGEX,
- ALL_TESTS_FAILED_REGEX],
- ro_image=variant_info.get('ro_image_path'),
- build_board=variant_info.get('build_board')))
+ TestConfig(
+ config_name="panic_data_" + variant_name,
+ test_name="panic_data",
+ fail_regexes=[SINGLE_CHECK_FAILED_REGEX, ALL_TESTS_FAILED_REGEX],
+ ro_image=variant_info.get("ro_image_path"),
+ build_board=variant_info.get("build_board"),
+ )
+ )
return tests
@@ -248,75 +315,72 @@ class AllTests:
tests = []
try:
current_dir = os.path.dirname(__file__)
- private_dir = os.path.join(current_dir, os.pardir, 'private/test')
+ private_dir = os.path.join(current_dir, os.pardir, "private/test")
have_private = os.path.isdir(private_dir)
if not have_private:
return []
sys.path.append(private_dir)
import private_tests # pylint: disable=import-error
+
for test_args in private_tests.tests:
tests.append(TestConfig(**test_args))
# Catch all exceptions to avoid disruptions in public repo
except BaseException as e:
- logging.debug('Failed to get list of private tests: %s', str(e))
- logging.debug('Ignore error and continue.')
+ logging.debug("Failed to get list of private tests: %s", str(e))
+ logging.debug("Ignore error and continue.")
return []
return tests
BLOONCHIPPER_CONFIG = BoardConfig(
name=BLOONCHIPPER,
- servo_uart_name='raw_fpmcu_console_uart_pty',
- servo_power_enable='fpmcu_pp3300',
+ servo_uart_name="raw_fpmcu_console_uart_pty",
+ servo_power_enable="fpmcu_pp3300",
rollback_region0_regex=DATA_ACCESS_VIOLATION_8020000_REGEX,
rollback_region1_regex=DATA_ACCESS_VIOLATION_8040000_REGEX,
mpu_regex=DATA_ACCESS_VIOLATION_20000000_REGEX,
variants={
- 'bloonchipper_v2.0.4277': {
- 'ro_image_path': BLOONCHIPPER_V4277_IMAGE_PATH
- },
- 'bloonchipper_v2.0.5938': {
- 'ro_image_path': BLOONCHIPPER_V5938_IMAGE_PATH
- }
- }
+ "bloonchipper_v2.0.4277": {"ro_image_path": BLOONCHIPPER_V4277_IMAGE_PATH},
+ "bloonchipper_v2.0.5938": {"ro_image_path": BLOONCHIPPER_V5938_IMAGE_PATH},
+ },
)
DARTMONKEY_CONFIG = BoardConfig(
name=DARTMONKEY,
- servo_uart_name='raw_fpmcu_console_uart_pty',
- servo_power_enable='fpmcu_pp3300',
+ servo_uart_name="raw_fpmcu_console_uart_pty",
+ servo_power_enable="fpmcu_pp3300",
rollback_region0_regex=DATA_ACCESS_VIOLATION_80C0000_REGEX,
rollback_region1_regex=DATA_ACCESS_VIOLATION_80E0000_REGEX,
mpu_regex=DATA_ACCESS_VIOLATION_24000000_REGEX,
# For dartmonkey board, run panic data test also on nocturne_fp and
# nami_fp boards with appropriate RO image.
variants={
- 'dartmonkey_v2.0.2887': {
- 'ro_image_path': DARTMONKEY_IMAGE_PATH
+ "dartmonkey_v2.0.2887": {"ro_image_path": DARTMONKEY_IMAGE_PATH},
+ "nocturne_fp_v2.2.64": {
+ "ro_image_path": NOCTURNE_FP_IMAGE_PATH,
+ "build_board": "nocturne_fp",
},
- 'nocturne_fp_v2.2.64': {
- 'ro_image_path': NOCTURNE_FP_IMAGE_PATH,
- 'build_board': 'nocturne_fp'
+ "nami_fp_v2.2.144": {
+ "ro_image_path": NAMI_FP_IMAGE_PATH,
+ "build_board": "nami_fp",
},
- 'nami_fp_v2.2.144': {
- 'ro_image_path': NAMI_FP_IMAGE_PATH,
- 'build_board': 'nami_fp'
- }
- }
+ },
)
BOARD_CONFIGS = {
- 'bloonchipper': BLOONCHIPPER_CONFIG,
- 'dartmonkey': DARTMONKEY_CONFIG,
+ "bloonchipper": BLOONCHIPPER_CONFIG,
+ "dartmonkey": DARTMONKEY_CONFIG,
}
def read_file_gsutil(path: str) -> bytes:
"""Get data from bucket, using gsutil tool"""
- cmd = ['gsutil', 'cat', path]
+ cmd = ["gsutil", "cat", path]
- logging.debug('Running command: "%s"', ' '.join(cmd))
- gsutil = subprocess.run(cmd, stdout=subprocess.PIPE) # pylint: disable=subprocess-run-check
+ logging.debug('Running command: "%s"', " ".join(cmd))
+ gsutil = subprocess.run(
+ cmd, stdout=subprocess.PIPE
+ ) # pylint: disable=subprocess-run-check
gsutil.check_returncode()
return gsutil.stdout
@@ -324,9 +388,9 @@ def read_file_gsutil(path: str) -> bytes:
def find_section_offset_size(section: str, image: bytes) -> (int, int):
"""Get offset and size of the section in image"""
- areas = fmap.fmap_decode(image)['areas']
- area = next(area for area in areas if area['name'] == section)
- return area['offset'], area['size']
+ areas = fmap.fmap_decode(image)["areas"]
+ area = next(area for area in areas if area["name"] == section)
+ return area["offset"], area["size"]
def read_section(src: bytes, section: str) -> bytes:
@@ -341,10 +405,10 @@ def write_section(data: bytes, image: bytearray, section: str):
(section_start, section_size) = find_section_offset_size(section, image)
if section_size < len(data):
- raise ValueError(section + ' section size is not enough to store data')
+ raise ValueError(section + " section size is not enough to store data")
section_end = section_start + section_size
- filling = bytes([0xff for _ in range(section_size - len(data))])
+ filling = bytes([0xFF for _ in range(section_size - len(data))])
image[section_start:section_end] = data + filling
@@ -355,12 +419,14 @@ def copy_section(src: bytes, dst: bytearray, section: str):
(dst_start, dst_size) = find_section_offset_size(section, dst)
if dst_size < src_size:
- raise ValueError('Section ' + section + ' from source image has '
- 'greater size than the section in destination image')
+ raise ValueError(
+ "Section " + section + " from source image has "
+ "greater size than the section in destination image"
+ )
src_end = src_start + src_size
dst_end = dst_start + dst_size
- filling = bytes([0xff for _ in range(dst_size - src_size)])
+ filling = bytes([0xFF for _ in range(dst_size - src_size)])
dst[dst_start:dst_end] = src[src_start:src_end] + filling
@@ -368,28 +434,28 @@ def copy_section(src: bytes, dst: bytearray, section: str):
def replace_ro(image: bytearray, ro: bytes):
"""Replace RO in image with provided one"""
# Backup RO public key since its private part was used to sign RW.
- ro_pubkey = read_section(image, 'KEY_RO')
+ ro_pubkey = read_section(image, "KEY_RO")
# Copy RO part of the firmware to the image. Please note that RO public key
# is copied too since EC_RO area includes KEY_RO area.
- copy_section(ro, image, 'EC_RO')
+ copy_section(ro, image, "EC_RO")
# Restore RO public key.
- write_section(ro_pubkey, image, 'KEY_RO')
+ write_section(ro_pubkey, image, "KEY_RO")
def get_console(board_config: BoardConfig) -> Optional[str]:
"""Get the name of the console for a given board."""
cmd = [
- 'dut-control',
+ "dut-control",
board_config.servo_uart_name,
]
- logging.debug('Running command: "%s"', ' '.join(cmd))
+ logging.debug('Running command: "%s"', " ".join(cmd))
with subprocess.Popen(cmd, stdout=subprocess.PIPE) as proc:
for line in io.TextIOWrapper(proc.stdout): # type: ignore[arg-type]
logging.debug(line)
- pty = line.split(':')
+ pty = line.split(":")
if len(pty) == 2 and pty[0] == board_config.servo_uart_name:
return pty[1].strip()
@@ -399,77 +465,82 @@ def get_console(board_config: BoardConfig) -> Optional[str]:
def power(board_config: BoardConfig, on: bool) -> None:
"""Turn power to board on/off."""
if on:
- state = 'pp3300'
+ state = "pp3300"
else:
- state = 'off'
+ state = "off"
cmd = [
- 'dut-control',
- board_config.servo_power_enable + ':' + state,
+ "dut-control",
+ board_config.servo_power_enable + ":" + state,
]
- logging.debug('Running command: "%s"', ' '.join(cmd))
+ logging.debug('Running command: "%s"', " ".join(cmd))
subprocess.run(cmd).check_returncode() # pylint: disable=subprocess-run-check
def hw_write_protect(enable: bool) -> None:
"""Enable/disable hardware write protect."""
if enable:
- state = 'force_on'
+ state = "force_on"
else:
- state = 'force_off'
+ state = "force_off"
cmd = [
- 'dut-control',
- 'fw_wp_state:' + state,
- ]
- logging.debug('Running command: "%s"', ' '.join(cmd))
+ "dut-control",
+ "fw_wp_state:" + state,
+ ]
+ logging.debug('Running command: "%s"', " ".join(cmd))
subprocess.run(cmd).check_returncode() # pylint: disable=subprocess-run-check
def build(test_name: str, board_name: str, compiler: str) -> None:
"""Build specified test for specified board."""
- cmd = ['make']
+ cmd = ["make"]
if compiler == CLANG:
- cmd = cmd + ['CC=arm-none-eabi-clang']
+ cmd = cmd + ["CC=arm-none-eabi-clang"]
cmd = cmd + [
- 'BOARD=' + board_name,
- 'test-' + test_name,
- '-j',
+ "BOARD=" + board_name,
+ "test-" + test_name,
+ "-j",
]
- logging.debug('Running command: "%s"', ' '.join(cmd))
+ logging.debug('Running command: "%s"', " ".join(cmd))
subprocess.run(cmd).check_returncode() # pylint: disable=subprocess-run-check
-def flash(image_path: str, board: str, flasher: str, remote_ip: str,
- remote_port: int) -> bool:
+def flash(
+ image_path: str, board: str, flasher: str, remote_ip: str, remote_port: int
+) -> bool:
"""Flash specified test to specified board."""
- logging.info('Flashing test')
+ logging.info("Flashing test")
cmd = []
if flasher == JTRACE:
cmd.append(JTRACE_FLASH_SCRIPT)
if remote_ip:
- cmd.extend(['--remote', remote_ip + ':' + str(remote_port)])
+ cmd.extend(["--remote", remote_ip + ":" + str(remote_port)])
elif flasher == SERVO_MICRO:
cmd.append(SERVO_MICRO_FLASH_SCRIPT)
else:
logging.error('Unknown flasher: "%s"', flasher)
return False
- cmd.extend([
- '--board', board,
- '--image', image_path,
- ])
- logging.debug('Running command: "%s"', ' '.join(cmd))
+ cmd.extend(
+ [
+ "--board",
+ board,
+ "--image",
+ image_path,
+ ]
+ )
+ logging.debug('Running command: "%s"', " ".join(cmd))
completed_process = subprocess.run(cmd) # pylint: disable=subprocess-run-check
return completed_process.returncode == 0
def patch_image(test: TestConfig, image_path: str):
"""Replace RO part of the firmware with provided one."""
- with open(image_path, 'rb+') as f:
+ with open(image_path, "rb+") as f:
image = bytearray(f.read())
ro = read_file_gsutil(test.ro_image)
replace_ro(image, ro)
@@ -478,8 +549,9 @@ def patch_image(test: TestConfig, image_path: str):
f.truncate()
-def readline(executor: ThreadPoolExecutor, f: BinaryIO, timeout_secs: int) -> \
- Optional[bytes]:
+def readline(
+ executor: ThreadPoolExecutor, f: BinaryIO, timeout_secs: int
+) -> Optional[bytes]:
"""Read a line with timeout."""
a = executor.submit(f.readline)
try:
@@ -488,8 +560,7 @@ def readline(executor: ThreadPoolExecutor, f: BinaryIO, timeout_secs: int) -> \
return None
-def readlines_until_timeout(executor, f: BinaryIO, timeout_secs: int) -> \
- List[bytes]:
+def readlines_until_timeout(executor, f: BinaryIO, timeout_secs: int) -> List[bytes]:
"""Continuously read lines for timeout_secs."""
lines: List[bytes] = []
while True:
@@ -519,19 +590,20 @@ def process_console_output_line(line: bytes, test: TestConfig):
return None
-def run_test(test: TestConfig, console: io.FileIO,
- executor: ThreadPoolExecutor) -> bool:
+def run_test(
+ test: TestConfig, console: io.FileIO, executor: ThreadPoolExecutor
+) -> bool:
"""Run specified test."""
start = time.time()
# Wait for boot to finish
time.sleep(1)
- console.write('\n'.encode())
+ console.write("\n".encode())
if test.image_to_use == ImageType.RO:
- console.write('reboot ro\n'.encode())
+ console.write("reboot ro\n".encode())
time.sleep(1)
- test_cmd = 'runtest ' + ' '.join(test.test_args) + '\n'
+ test_cmd = "runtest " + " ".join(test.test_args) + "\n"
console.write(test_cmd.encode())
while True:
@@ -540,7 +612,7 @@ def run_test(test: TestConfig, console: io.FileIO,
if not line:
now = time.time()
if now - start > test.timeout_secs:
- logging.debug('Test timed out')
+ logging.debug("Test timed out")
return False
continue
@@ -569,15 +641,18 @@ def run_test(test: TestConfig, console: io.FileIO,
def get_test_list(config: BoardConfig, test_args) -> List[TestConfig]:
"""Get a list of tests to run."""
- if test_args == 'all':
+ if test_args == "all":
return AllTests.get(config)
test_list = []
for t in test_args:
- logging.debug('test: %s', t)
+ logging.debug("test: %s", t)
test_regex = re.compile(t)
- tests = [test for test in AllTests.get(config)
- if test_regex.fullmatch(test.config_name)]
+ tests = [
+ test
+ for test in AllTests.get(config)
+ if test_regex.fullmatch(test.config_name)
+ ]
if not tests:
logging.error('Unable to find test config for "%s"', t)
sys.exit(1)
@@ -588,7 +663,7 @@ def get_test_list(config: BoardConfig, test_args) -> List[TestConfig]:
def parse_remote_arg(remote: str) -> str:
if not remote:
- return ''
+ return ""
try:
ip = socket.gethostbyname(remote)
@@ -601,67 +676,69 @@ def parse_remote_arg(remote: str) -> str:
def main():
parser = argparse.ArgumentParser()
- default_board = 'bloonchipper'
- parser.add_argument(
- '--board', '-b',
- help='Board (default: ' + default_board + ')',
- default=default_board)
-
- default_tests = 'all'
+ default_board = "bloonchipper"
parser.add_argument(
- '--tests', '-t',
- nargs='+',
- help='Tests (default: ' + default_tests + ')',
- default=default_tests)
+ "--board",
+ "-b",
+ help="Board (default: " + default_board + ")",
+ default=default_board,
+ )
- log_level_choices = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
+ default_tests = "all"
parser.add_argument(
- '--log_level', '-l',
- choices=log_level_choices,
- default='DEBUG'
+ "--tests",
+ "-t",
+ nargs="+",
+ help="Tests (default: " + default_tests + ")",
+ default=default_tests,
)
+ log_level_choices = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
+ parser.add_argument("--log_level", "-l", choices=log_level_choices, default="DEBUG")
+
flasher_choices = [SERVO_MICRO, JTRACE]
- parser.add_argument(
- '--flasher', '-f',
- choices=flasher_choices,
- default=JTRACE
- )
+ parser.add_argument("--flasher", "-f", choices=flasher_choices, default=JTRACE)
compiler_options = [GCC, CLANG]
- parser.add_argument('--compiler', '-c',
- choices=compiler_options,
- default=GCC)
+ parser.add_argument("--compiler", "-c", choices=compiler_options, default=GCC)
# This might be expanded to serve as a "remote" for flash_ec also, so
# we will leave it generic.
parser.add_argument(
- '--remote', '-n',
- help='The remote host connected to one or both of: J-Link and Servo.',
+ "--remote",
+ "-n",
+ help="The remote host connected to one or both of: J-Link and Servo.",
)
- parser.add_argument('--jlink_port', '-j',
- type=int,
- help='The port to use when connecting to JLink.')
- parser.add_argument('--console_port', '-p',
- type=int,
- help='The port connected to the FPMCU console.')
+ parser.add_argument(
+ "--jlink_port", "-j", type=int, help="The port to use when connecting to JLink."
+ )
+ parser.add_argument(
+ "--console_port",
+ "-p",
+ type=int,
+ help="The port connected to the FPMCU console.",
+ )
args = parser.parse_args()
logging.basicConfig(level=args.log_level)
if args.jlink_port and not args.flasher == JTRACE:
- logging.error('jlink_port specified, but flasher is not set to J-Link.')
+ logging.error("jlink_port specified, but flasher is not set to J-Link.")
sys.exit(1)
if args.remote and not (args.jlink_port or args.console_port):
- logging.error('jlink_port or console_port must be specified when using '
- 'the remote option.')
+ logging.error(
+ "jlink_port or console_port must be specified when using "
+ "the remote option."
+ )
sys.exit(1)
if (args.jlink_port or args.console_port) and not args.remote:
- logging.error('The remote option must be specified when using the '
- 'jlink_port or console_port options.')
+ logging.error(
+ "The remote option must be specified when using the "
+ "jlink_port or console_port options."
+ )
sys.exit(1)
if args.board not in BOARD_CONFIGS:
@@ -675,9 +752,7 @@ def main():
e = ThreadPoolExecutor(max_workers=1)
test_list = get_test_list(board_config, args.tests)
- logging.debug(
- 'Running tests: %s', [
- test.config_name for test in test_list])
+ logging.debug("Running tests: %s", [test.config_name for test in test_list])
for test in test_list:
build_board = args.board
@@ -689,15 +764,17 @@ def main():
# build test binary
build(test.test_name, build_board, args.compiler)
- image_path = os.path.join(EC_DIR, 'build', build_board, test.test_name,
- test.test_name + '.bin')
+ image_path = os.path.join(
+ EC_DIR, "build", build_board, test.test_name, test.test_name + ".bin"
+ )
if test.ro_image is not None:
try:
patch_image(test, image_path)
except Exception as exception:
- logging.warning('An exception occurred while patching '
- 'image: %s', exception)
+ logging.warning(
+ "An exception occurred while patching " "image: %s", exception
+ )
test.passed = False
continue
@@ -706,16 +783,16 @@ def main():
# flash_write_protect test is run; works after second attempt.
flash_succeeded = False
for i in range(0, test.num_flash_attempts):
- logging.debug('Flash attempt %d', i + 1)
- if flash(image_path, args.board, args.flasher, remote_ip,
- args.jlink_port):
+ logging.debug("Flash attempt %d", i + 1)
+ if flash(image_path, args.board, args.flasher, remote_ip, args.jlink_port):
flash_succeeded = True
break
time.sleep(1)
if not flash_succeeded:
- logging.debug('Flashing failed after max attempts: %d',
- test.num_flash_attempts)
+ logging.debug(
+ "Flashing failed after max attempts: %d", test.num_flash_attempts
+ )
test.passed = False
continue
@@ -733,11 +810,11 @@ def main():
if remote_ip and args.console_port:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect((remote_ip, args.console_port))
- console = stack.enter_context(
- s.makefile(mode='rwb', buffering=0))
+ console = stack.enter_context(s.makefile(mode="rwb", buffering=0))
else:
console = stack.enter_context(
- open(get_console(board_config), 'wb+', buffering=0))
+ open(get_console(board_config), "wb+", buffering=0)
+ )
test.passed = run_test(test, console, executor=e)
@@ -745,11 +822,11 @@ def main():
exit_code = 0
for test in test_list:
# print results
- print('Test "' + test.config_name + '": ', end='')
+ print('Test "' + test.config_name + '": ', end="")
if test.passed:
- print(colorama.Fore.GREEN + 'PASSED')
+ print(colorama.Fore.GREEN + "PASSED")
else:
- print(colorama.Fore.RED + 'FAILED')
+ print(colorama.Fore.RED + "FAILED")
exit_code = 1
print(colorama.Style.RESET_ALL)
@@ -758,5 +835,5 @@ def main():
sys.exit(exit_code)
-if __name__ == '__main__':
+if __name__ == "__main__":
sys.exit(main())
diff --git a/test/timer_calib.py b/test/timer_calib.py
index 2a625d80c7..c7ac9fc23b 100644
--- a/test/timer_calib.py
+++ b/test/timer_calib.py
@@ -7,48 +7,52 @@
import time
-def one_pass(helper):
- helper.wait_output("=== Timer calibration ===")
- res = helper.wait_output("back-to-back get_time : (?P<lat>[0-9]+) us",
- use_re=True)["lat"]
- minlat = int(res)
- helper.trace("get_time latency %d us\n" % minlat)
-
- helper.wait_output("sleep 1s")
- t0 = time.time()
- second = helper.wait_output("done. delay = (?P<second>[0-9]+) us",
- use_re=True)["second"]
- t1 = time.time()
- secondreal = t1 - t0
- secondlat = int(second) - 1000000
- helper.trace("1s timer latency %d us / real time %f s\n" % (secondlat,
- secondreal))
-
- us = {}
- for pow2 in range(7):
- delay = 1 << (7-pow2)
- us[delay] = helper.wait_output("%d us => (?P<us>[0-9]+) us" % delay,
- use_re=True)["us"]
- helper.wait_output("Done.")
-
- return minlat, secondlat, secondreal
+def one_pass(helper):
+ helper.wait_output("=== Timer calibration ===")
+ res = helper.wait_output("back-to-back get_time : (?P<lat>[0-9]+) us", use_re=True)[
+ "lat"
+ ]
+ minlat = int(res)
+ helper.trace("get_time latency %d us\n" % minlat)
+
+ helper.wait_output("sleep 1s")
+ t0 = time.time()
+ second = helper.wait_output("done. delay = (?P<second>[0-9]+) us", use_re=True)[
+ "second"
+ ]
+ t1 = time.time()
+ secondreal = t1 - t0
+ secondlat = int(second) - 1000000
+ helper.trace("1s timer latency %d us / real time %f s\n" % (secondlat, secondreal))
+
+ us = {}
+ for pow2 in range(7):
+ delay = 1 << (7 - pow2)
+ us[delay] = helper.wait_output(
+ "%d us => (?P<us>[0-9]+) us" % delay, use_re=True
+ )["us"]
+ helper.wait_output("Done.")
+
+ return minlat, secondlat, secondreal
def test(helper):
- one_pass(helper)
+ one_pass(helper)
- helper.ec_command("reboot")
- helper.wait_output("--- UART initialized")
+ helper.ec_command("reboot")
+ helper.wait_output("--- UART initialized")
- # get the timing results on the second pass
- # to avoid binary translation overhead
- minlat, secondlat, secondreal = one_pass(helper)
+ # get the timing results on the second pass
+ # to avoid binary translation overhead
+ minlat, secondlat, secondreal = one_pass(helper)
- # check that the timings somewhat make sense
- if minlat > 220 or secondlat > 500 or abs(secondreal-1.0) > 0.200:
- helper.fail("imprecise timings " +
- "(get_time %d us sleep %d us / real time %.3f s)" %
- (minlat, secondlat, secondreal))
+ # check that the timings somewhat make sense
+ if minlat > 220 or secondlat > 500 or abs(secondreal - 1.0) > 0.200:
+ helper.fail(
+ "imprecise timings "
+ + "(get_time %d us sleep %d us / real time %.3f s)"
+ % (minlat, secondlat, secondreal)
+ )
- return True # PASS !
+ return True # PASS !
diff --git a/test/timer_jump.py b/test/timer_jump.py
index f506a69fcf..2801c3b3fa 100644
--- a/test/timer_jump.py
+++ b/test/timer_jump.py
@@ -10,22 +10,25 @@ import time
DELAY = 5
ERROR_MARGIN = 0.5
+
def test(helper):
- helper.wait_output("idle task started")
- helper.ec_command("sysinfo")
- copy = helper.wait_output("Copy:\s+(?P<c>\S+)", use_re=True)["c"]
- if copy != "RO":
- helper.ec_command("sysjump ro")
- helper.wait_output("idle task started")
- helper.ec_command("gettime")
- ec_start_time = helper.wait_output("Time: 0x[0-9a-f]* = (?P<t>[\d\.]+) s",
- use_re=True)["t"]
- time.sleep(DELAY)
- helper.ec_command("sysjump a")
- helper.wait_output("idle task started")
- helper.ec_command("gettime")
- ec_end_time = helper.wait_output("Time: 0x[0-9a-f]* = (?P<t>[\d\.]+) s",
- use_re=True)["t"]
+ helper.wait_output("idle task started")
+ helper.ec_command("sysinfo")
+ copy = helper.wait_output("Copy:\s+(?P<c>\S+)", use_re=True)["c"]
+ if copy != "RO":
+ helper.ec_command("sysjump ro")
+ helper.wait_output("idle task started")
+ helper.ec_command("gettime")
+ ec_start_time = helper.wait_output(
+ "Time: 0x[0-9a-f]* = (?P<t>[\d\.]+) s", use_re=True
+ )["t"]
+ time.sleep(DELAY)
+ helper.ec_command("sysjump a")
+ helper.wait_output("idle task started")
+ helper.ec_command("gettime")
+ ec_end_time = helper.wait_output(
+ "Time: 0x[0-9a-f]* = (?P<t>[\d\.]+) s", use_re=True
+ )["t"]
- time_diff = float(ec_end_time) - float(ec_start_time)
- return time_diff >= DELAY and time_diff <= DELAY + ERROR_MARGIN
+ time_diff = float(ec_end_time) - float(ec_start_time)
+ return time_diff >= DELAY and time_diff <= DELAY + ERROR_MARGIN
diff --git a/util/build_with_clang.py b/util/build_with_clang.py
index a38ade2cb8..98da942152 100755
--- a/util/build_with_clang.py
+++ b/util/build_with_clang.py
@@ -12,15 +12,14 @@ import multiprocessing
import os
import subprocess
import sys
-
from concurrent.futures import ThreadPoolExecutor
# Add to this list as compilation errors are fixed for boards.
BOARDS_THAT_COMPILE_SUCCESSFULLY_WITH_CLANG = [
- 'dartmonkey',
- 'bloonchipper',
- 'nucleo-f412zg',
- 'nucleo-h743zi',
+ "dartmonkey",
+ "bloonchipper",
+ "nucleo-f412zg",
+ "nucleo-h743zi",
]
@@ -29,35 +28,29 @@ def build(board_name: str) -> None:
logging.debug('Building board: "%s"', board_name)
cmd = [
- 'make',
- 'BOARD=' + board_name,
- '-j',
+ "make",
+ "BOARD=" + board_name,
+ "-j",
]
- logging.debug('Running command: "%s"', ' '.join(cmd))
- subprocess.run(cmd, env=dict(os.environ, CC='clang'), check=True)
+ logging.debug('Running command: "%s"', " ".join(cmd))
+ subprocess.run(cmd, env=dict(os.environ, CC="clang"), check=True)
def main() -> int:
parser = argparse.ArgumentParser()
- log_level_choices = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
- parser.add_argument(
- '--log_level', '-l',
- choices=log_level_choices,
- default='DEBUG'
- )
+ log_level_choices = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
+ parser.add_argument("--log_level", "-l", choices=log_level_choices, default="DEBUG")
parser.add_argument(
- '--num_threads', '-j',
- type=int,
- default=multiprocessing.cpu_count()
+ "--num_threads", "-j", type=int, default=multiprocessing.cpu_count()
)
args = parser.parse_args()
logging.basicConfig(level=args.log_level)
- logging.debug('Building with %d threads', args.num_threads)
+ logging.debug("Building with %d threads", args.num_threads)
failed_boards = []
with ThreadPoolExecutor(max_workers=args.num_threads) as executor:
@@ -73,13 +66,14 @@ def main() -> int:
failed_boards.append(board)
if len(failed_boards) > 0:
- logging.error('The following boards failed to compile:\n%s',
- '\n'.join(failed_boards))
+ logging.error(
+ "The following boards failed to compile:\n%s", "\n".join(failed_boards)
+ )
return 1
- logging.info('All boards compiled successfully!')
+ logging.info("All boards compiled successfully!")
return 0
-if __name__ == '__main__':
+if __name__ == "__main__":
sys.exit(main())
diff --git a/util/chargen b/util/chargen
index 9ba14d3d6a..a1f7947a14 100644
--- a/util/chargen
+++ b/util/chargen
@@ -5,6 +5,7 @@
import sys
+
def chargen(modulo, max_chars):
"""Generate a stream of characters on the console.
@@ -18,7 +19,7 @@ def chargen(modulo, max_chars):
zero, if zero - print indefinitely
"""
- base = '0'
+ base = "0"
c = base
counter = 0
while True:
@@ -26,25 +27,25 @@ def chargen(modulo, max_chars):
counter = counter + 1
if (max_chars != 0) and (counter == max_chars):
- sys.stdout.write('\n')
+ sys.stdout.write("\n")
return
if modulo and ((counter % modulo) == 0):
c = base
continue
- if c == 'z':
+ if c == "z":
c = base
- elif c == 'Z':
- c = 'a'
- elif c == '9':
- c = 'A'
+ elif c == "Z":
+ c = "a"
+ elif c == "9":
+ c = "A"
else:
- c = '%c' % (ord(c) + 1)
+ c = "%c" % (ord(c) + 1)
def main(args):
- '''Process command line arguments and invoke chargen if args are valid'''
+ """Process command line arguments and invoke chargen if args are valid"""
modulo = 0
max_chars = 0
@@ -55,8 +56,7 @@ def main(args):
if len(args) > 1:
max_chars = int(args[1])
except ValueError:
- sys.stderr.write('usage %s:'
- "['seq_length' ['max_chars']]\n")
+ sys.stderr.write("usage %s:" "['seq_length' ['max_chars']]\n")
sys.exit(1)
try:
@@ -64,6 +64,7 @@ def main(args):
except KeyboardInterrupt:
print()
-if __name__ == '__main__':
+
+if __name__ == "__main__":
main(sys.argv[1:])
sys.exit(0)
diff --git a/util/config_option_check.py b/util/config_option_check.py
index 8bd8ecb1f0..0b05b64091 100755
--- a/util/config_option_check.py
+++ b/util/config_option_check.py
@@ -13,6 +13,7 @@ Script to ensure that all configuration options for the Chrome EC are defined
in config.h.
"""
from __future__ import print_function
+
import enum
import os
import re
@@ -21,368 +22,392 @@ import sys
class Line(object):
- """Class for each changed line in diff output.
+ """Class for each changed line in diff output.
- Attributes:
- line_num: The integer line number that this line appears in the file.
- string: The literal string of this line.
- line_type: '+' or '-' indicating if this line was an addition or
- deletion.
- """
+ Attributes:
+ line_num: The integer line number that this line appears in the file.
+ string: The literal string of this line.
+ line_type: '+' or '-' indicating if this line was an addition or
+ deletion.
+ """
- def __init__(self, line_num, string, line_type):
- """Inits Line with the line number and the actual string."""
- self.line_num = line_num
- self.string = string
- self.line_type = line_type
+ def __init__(self, line_num, string, line_type):
+ """Inits Line with the line number and the actual string."""
+ self.line_num = line_num
+ self.string = string
+ self.line_type = line_type
class Hunk(object):
- """Class for a git diff hunk.
+ """Class for a git diff hunk.
- Attributes:
- filename: The name of the file that this hunk belongs to.
- lines: A list of Line objects that are a part of this hunk.
- """
+ Attributes:
+ filename: The name of the file that this hunk belongs to.
+ lines: A list of Line objects that are a part of this hunk.
+ """
- def __init__(self, filename, lines):
- """Inits Hunk with the filename and the list of lines of the hunk."""
- self.filename = filename
- self.lines = lines
+ def __init__(self, filename, lines):
+ """Inits Hunk with the filename and the list of lines of the hunk."""
+ self.filename = filename
+ self.lines = lines
# Master file which is supposed to include all CONFIG_xxxx descriptions.
-CONFIG_FILE = 'include/config.h'
+CONFIG_FILE = "include/config.h"
# Specific files which the checker should ignore.
-ALLOWLIST = [CONFIG_FILE, 'util/config_option_check.py']
+ALLOWLIST = [CONFIG_FILE, "util/config_option_check.py"]
# Specific directories which the checker should ignore.
-ALLOW_PATTERN = re.compile('zephyr/.*')
+ALLOW_PATTERN = re.compile("zephyr/.*")
# Specific CONFIG_* flags which the checker should ignore.
-ALLOWLIST_CONFIGS = ['CONFIG_ZTEST']
+ALLOWLIST_CONFIGS = ["CONFIG_ZTEST"]
+
def obtain_current_config_options():
- """Obtains current config options from include/config.h.
-
- Scans through the main config file defined in CONFIG_FILE for all CONFIG_*
- options.
-
- Returns:
- config_options: A list of all the config options in the main CONFIG_FILE.
- """
-
- config_options = []
- config_option_re = re.compile(r'^#(define|undef)\s+(CONFIG_[A-Z0-9_]+)')
- with open(CONFIG_FILE, 'r') as config_file:
- for line in config_file:
- result = config_option_re.search(line)
- if not result:
- continue
- word = result.groups()[1]
- if word not in config_options:
- config_options.append(word)
- return config_options
+ """Obtains current config options from include/config.h.
+
+ Scans through the main config file defined in CONFIG_FILE for all CONFIG_*
+ options.
+
+ Returns:
+ config_options: A list of all the config options in the main CONFIG_FILE.
+ """
+
+ config_options = []
+ config_option_re = re.compile(r"^#(define|undef)\s+(CONFIG_[A-Z0-9_]+)")
+ with open(CONFIG_FILE, "r") as config_file:
+ for line in config_file:
+ result = config_option_re.search(line)
+ if not result:
+ continue
+ word = result.groups()[1]
+ if word not in config_options:
+ config_options.append(word)
+ return config_options
+
def obtain_config_options_in_use():
- """Obtains all the config options in use in the repo.
-
- Scans through the entire repo looking for all CONFIG_* options actively used.
-
- Returns:
- options_in_use: A set of all the config options in use in the repo.
- """
- file_list = []
- cwd = os.getcwd()
- config_option_re = re.compile(r'\b(CONFIG_[a-zA-Z0-9_]+)')
- config_debug_option_re = re.compile(r'\b(CONFIG_DEBUG_[a-zA-Z0-9_]+)')
- options_in_use = set()
- for (dirpath, dirnames, filenames) in os.walk(cwd, topdown=True):
- # Ignore the build and private directories (taken from .gitignore)
- if 'build' in dirnames:
- dirnames.remove('build')
- if 'private' in dirnames:
- dirnames.remove('private')
- for f in filenames:
- # Ignore hidden files.
- if f.startswith('.'):
- continue
- # Only consider C source, assembler, and Make-style files.
- if (os.path.splitext(f)[1] in ('.c', '.h', '.inc', '.S', '.mk') or
- 'Makefile' in f):
- file_list.append(os.path.join(dirpath, f))
-
- # Search through each file and build a set of the CONFIG_* options being
- # used.
-
- for f in file_list:
- if CONFIG_FILE in f:
- continue
- with open(f, 'r') as cur_file:
- for line in cur_file:
- match = config_option_re.findall(line)
- if match:
- for option in match:
- if not in_comment(f, line, option):
- if option not in options_in_use:
- options_in_use.add(option)
-
- # Since debug options can be turned on at any time, assume that they are
- # always in use in case any aren't being used.
-
- with open(CONFIG_FILE, 'r') as config_file:
- for line in config_file:
- match = config_debug_option_re.findall(line)
- if match:
- for option in match:
- if not in_comment(CONFIG_FILE, line, option):
- if option not in options_in_use:
- options_in_use.add(option)
-
- return options_in_use
+ """Obtains all the config options in use in the repo.
+
+ Scans through the entire repo looking for all CONFIG_* options actively used.
+
+ Returns:
+ options_in_use: A set of all the config options in use in the repo.
+ """
+ file_list = []
+ cwd = os.getcwd()
+ config_option_re = re.compile(r"\b(CONFIG_[a-zA-Z0-9_]+)")
+ config_debug_option_re = re.compile(r"\b(CONFIG_DEBUG_[a-zA-Z0-9_]+)")
+ options_in_use = set()
+ for (dirpath, dirnames, filenames) in os.walk(cwd, topdown=True):
+ # Ignore the build and private directories (taken from .gitignore)
+ if "build" in dirnames:
+ dirnames.remove("build")
+ if "private" in dirnames:
+ dirnames.remove("private")
+ for f in filenames:
+ # Ignore hidden files.
+ if f.startswith("."):
+ continue
+ # Only consider C source, assembler, and Make-style files.
+ if (
+ os.path.splitext(f)[1] in (".c", ".h", ".inc", ".S", ".mk")
+ or "Makefile" in f
+ ):
+ file_list.append(os.path.join(dirpath, f))
+
+ # Search through each file and build a set of the CONFIG_* options being
+ # used.
+
+ for f in file_list:
+ if CONFIG_FILE in f:
+ continue
+ with open(f, "r") as cur_file:
+ for line in cur_file:
+ match = config_option_re.findall(line)
+ if match:
+ for option in match:
+ if not in_comment(f, line, option):
+ if option not in options_in_use:
+ options_in_use.add(option)
+
+ # Since debug options can be turned on at any time, assume that they are
+ # always in use in case any aren't being used.
+
+ with open(CONFIG_FILE, "r") as config_file:
+ for line in config_file:
+ match = config_debug_option_re.findall(line)
+ if match:
+ for option in match:
+ if not in_comment(CONFIG_FILE, line, option):
+ if option not in options_in_use:
+ options_in_use.add(option)
+
+ return options_in_use
+
def print_missing_config_options(hunks, config_options):
- """Searches thru all the changes in hunks for missing options and prints them.
-
- Args:
- hunks: A list of Hunk objects which represent the hunks from the git
- diff output.
- config_options: A list of all the config options in the main CONFIG_FILE.
-
- Returns:
- missing_config_option: A boolean indicating if any CONFIG_* options
- are missing from the main CONFIG_FILE in this commit or if any CONFIG_*
- options removed are no longer being used in the repo.
- """
- missing_config_option = False
- print_banner = True
- deprecated_options = set()
- # Determine longest CONFIG_* length to be used for formatting.
- max_option_length = max(len(option) for option in config_options)
- config_option_re = re.compile(r'\b(CONFIG_[a-zA-Z0-9_]+)')
-
- # Search for all CONFIG_* options in use in the repo.
- options_in_use = obtain_config_options_in_use()
-
- # Check each hunk's line for a missing config option.
- for h in hunks:
- for l in h.lines:
- # Check for the existence of a CONFIG_* in the line.
- match = filter(lambda opt: opt in ALLOWLIST_CONFIGS,
- config_option_re.findall(l.string))
- if not match:
- continue
-
- # At this point, an option was found in the line. However, we need to
- # verify that it is not within a comment.
- violations = set()
-
- for option in match:
- if not in_comment(h.filename, l.string, option):
- # Since the CONFIG_* option is not within a comment, we've found a
- # violation. We now need to determine if this line is a deletion or
- # not. For deletions, we will need to verify if this CONFIG_* option
- # is no longer being used in the entire repo.
-
- if l.line_type == '-':
- if option not in options_in_use and option in config_options:
- deprecated_options.add(option)
- else:
- violations.add(option)
-
- # Check to see if the CONFIG_* option is in the config file and print the
- # violations.
- for option in match:
- if option not in config_options and option in violations:
- # Print the banner once.
- if print_banner:
- print('The following config options were found to be missing '
- 'from %s.\n'
- 'Please add new config options there along with '
- 'descriptions.\n\n' % CONFIG_FILE)
- print_banner = False
+ """Searches thru all the changes in hunks for missing options and prints them.
+
+ Args:
+ hunks: A list of Hunk objects which represent the hunks from the git
+ diff output.
+ config_options: A list of all the config options in the main CONFIG_FILE.
+
+ Returns:
+ missing_config_option: A boolean indicating if any CONFIG_* options
+ are missing from the main CONFIG_FILE in this commit or if any CONFIG_*
+ options removed are no longer being used in the repo.
+ """
+ missing_config_option = False
+ print_banner = True
+ deprecated_options = set()
+ # Determine longest CONFIG_* length to be used for formatting.
+ max_option_length = max(len(option) for option in config_options)
+ config_option_re = re.compile(r"\b(CONFIG_[a-zA-Z0-9_]+)")
+
+ # Search for all CONFIG_* options in use in the repo.
+ options_in_use = obtain_config_options_in_use()
+
+ # Check each hunk's line for a missing config option.
+ for h in hunks:
+ for l in h.lines:
+ # Check for the existence of a CONFIG_* in the line.
+ match = filter(
+ lambda opt: opt in ALLOWLIST_CONFIGS, config_option_re.findall(l.string)
+ )
+ if not match:
+ continue
+
+ # At this point, an option was found in the line. However, we need to
+ # verify that it is not within a comment.
+ violations = set()
+
+ for option in match:
+ if not in_comment(h.filename, l.string, option):
+ # Since the CONFIG_* option is not within a comment, we've found a
+ # violation. We now need to determine if this line is a deletion or
+ # not. For deletions, we will need to verify if this CONFIG_* option
+ # is no longer being used in the entire repo.
+
+ if l.line_type == "-":
+ if option not in options_in_use and option in config_options:
+ deprecated_options.add(option)
+ else:
+ violations.add(option)
+
+ # Check to see if the CONFIG_* option is in the config file and print the
+ # violations.
+ for option in match:
+ if option not in config_options and option in violations:
+ # Print the banner once.
+ if print_banner:
+ print(
+ "The following config options were found to be missing "
+ "from %s.\n"
+ "Please add new config options there along with "
+ "descriptions.\n\n" % CONFIG_FILE
+ )
+ print_banner = False
+ missing_config_option = True
+ # Print the misssing config option.
+ print(
+ "> %-*s %s:%s"
+ % (max_option_length, option, h.filename, l.line_num)
+ )
+
+ if deprecated_options:
+ print(
+ "\n\nThe following config options are being removed and also appear"
+ " to be the last uses\nof that option. Please remove these "
+ "options from %s.\n\n" % CONFIG_FILE
+ )
+ for option in deprecated_options:
+ print("> %s" % option)
missing_config_option = True
- # Print the misssing config option.
- print('> %-*s %s:%s' % (max_option_length, option,
- h.filename,
- l.line_num))
- if deprecated_options:
- print('\n\nThe following config options are being removed and also appear'
- ' to be the last uses\nof that option. Please remove these '
- 'options from %s.\n\n' % CONFIG_FILE)
- for option in deprecated_options:
- print('> %s' % option)
- missing_config_option = True
+ return missing_config_option
- return missing_config_option
def in_comment(filename, line, substr):
- """Checks if given substring appears in a comment.
-
- Args:
- filename: The filename where this line is from. This is used to determine
- what kind of comments to look for.
- line: String of line to search in.
- substr: Substring to search for in the line.
-
- Returns:
- is_in_comment: Boolean indicating if substr was in a comment.
- """
-
- c_style_ext = ('.c', '.h', '.inc', '.S')
- make_style_ext = ('.mk')
- is_in_comment = False
-
- extension = os.path.splitext(filename)[1]
- substr_idx = line.find(substr)
-
- # Different files have different comment syntax; Handle appropriately.
- if extension in c_style_ext:
- beg_comment_idx = line.find('/*')
- end_comment_idx = line.find('*/')
- if end_comment_idx == -1:
- end_comment_idx = len(line)
-
- if beg_comment_idx == -1:
- # Check to see if this line is from a multi-line comment.
- if line.lstrip().startswith('* '):
- # It _seems_ like it is.
- is_in_comment = True
- else:
- # Check to see if its actually inside the comment.
- if beg_comment_idx < substr_idx < end_comment_idx:
- is_in_comment = True
- elif extension in make_style_ext or 'Makefile' in filename:
- beg_comment_idx = line.find('#')
- # Ignore everything to the right of the hash.
- if beg_comment_idx < substr_idx and beg_comment_idx != -1:
- is_in_comment = True
- return is_in_comment
+ """Checks if given substring appears in a comment.
+
+ Args:
+ filename: The filename where this line is from. This is used to determine
+ what kind of comments to look for.
+ line: String of line to search in.
+ substr: Substring to search for in the line.
+
+ Returns:
+ is_in_comment: Boolean indicating if substr was in a comment.
+ """
+
+ c_style_ext = (".c", ".h", ".inc", ".S")
+ make_style_ext = ".mk"
+ is_in_comment = False
+
+ extension = os.path.splitext(filename)[1]
+ substr_idx = line.find(substr)
+
+ # Different files have different comment syntax; Handle appropriately.
+ if extension in c_style_ext:
+ beg_comment_idx = line.find("/*")
+ end_comment_idx = line.find("*/")
+ if end_comment_idx == -1:
+ end_comment_idx = len(line)
+
+ if beg_comment_idx == -1:
+ # Check to see if this line is from a multi-line comment.
+ if line.lstrip().startswith("* "):
+ # It _seems_ like it is.
+ is_in_comment = True
+ else:
+ # Check to see if its actually inside the comment.
+ if beg_comment_idx < substr_idx < end_comment_idx:
+ is_in_comment = True
+ elif extension in make_style_ext or "Makefile" in filename:
+ beg_comment_idx = line.find("#")
+ # Ignore everything to the right of the hash.
+ if beg_comment_idx < substr_idx and beg_comment_idx != -1:
+ is_in_comment = True
+ return is_in_comment
+
def get_hunks():
- """Gets the hunks of the most recent commit.
-
- States:
- new_file: Searching for a new file in the git diff.
- filename_search: Searching for the filename of this hunk.
- hunk: Searching for the beginning of a new hunk.
- lines: Counting line numbers and searching for changes.
-
- Returns:
- hunks: A list of Hunk objects which represent the hunks in the git diff
- output.
- """
-
- diff = []
- hunks = []
- hunk_lines = []
- line = ''
- filename = ''
- i = 0
- line_num = 0
-
- # Regex patterns
- new_file_re = re.compile(r'^diff --git')
- filename_re = re.compile(r'^[+]{3} (.*)')
- hunk_line_num_re = re.compile(r'^@@ -[0-9]+,[0-9]+ \+([0-9]+),[0-9]+ @@.*')
- line_re = re.compile(r'^([+| |-])(.*)')
-
- # Get the diff output.
- proc = subprocess.run(['git', 'diff', '--cached', '-GCONFIG_*', '--no-prefix',
- '--no-ext-diff', 'HEAD~1'],
- stdout=subprocess.PIPE,
- encoding='utf-8',
- check=True)
- diff = proc.stdout.splitlines()
- if not diff:
- return []
- line = diff[0]
-
- state = enum.Enum('state', 'NEW_FILE FILENAME_SEARCH HUNK LINES')
- current_state = state.NEW_FILE
-
- while True:
- # Search for the beginning of a new file.
- if current_state is state.NEW_FILE:
- match = new_file_re.search(line)
- if match:
- current_state = state.FILENAME_SEARCH
-
- # Search the diff output for a file name.
- elif current_state is state.FILENAME_SEARCH:
- # Search for a file name.
- match = filename_re.search(line)
- if match:
- filename = match.groups(1)[0]
- if filename in ALLOWLIST or ALLOW_PATTERN.match(filename):
- # Skip the file if it's allowlisted.
- current_state = state.NEW_FILE
- else:
- current_state = state.HUNK
-
- # Search for a hunk. Each hunk starts with a line describing the line
- # numbers in the file.
- elif current_state is state.HUNK:
- hunk_lines = []
- match = hunk_line_num_re.search(line)
- if match:
- # Extract the line number offset.
- line_num = int(match.groups(1)[0])
- current_state = state.LINES
-
- # Start looking for changes.
- elif current_state is state.LINES:
- # Check if state needs updating.
- new_hunk = hunk_line_num_re.search(line)
- new_file = new_file_re.search(line)
- if new_hunk:
- current_state = state.HUNK
- hunks.append(Hunk(filename, hunk_lines))
- continue
- elif new_file:
- current_state = state.NEW_FILE
- hunks.append(Hunk(filename, hunk_lines))
- continue
-
- match = line_re.search(line)
- if match:
- line_type = match.groups(1)[0]
- # We only care about modifications.
- if line_type != ' ':
- hunk_lines.append(Line(line_num, match.groups(2)[1], line_type))
- # Deletions don't count towards the line numbers.
- if line_type != '-':
- line_num += 1
-
- # Advance to the next line
- try:
- i += 1
- line = diff[i]
- except IndexError:
- # We've reached the end of the diff. Return what we have.
- if hunk_lines:
- hunks.append(Hunk(filename, hunk_lines))
- return hunks
+ """Gets the hunks of the most recent commit.
+
+ States:
+ new_file: Searching for a new file in the git diff.
+ filename_search: Searching for the filename of this hunk.
+ hunk: Searching for the beginning of a new hunk.
+ lines: Counting line numbers and searching for changes.
+
+ Returns:
+ hunks: A list of Hunk objects which represent the hunks in the git diff
+ output.
+ """
+
+ diff = []
+ hunks = []
+ hunk_lines = []
+ line = ""
+ filename = ""
+ i = 0
+ line_num = 0
+
+ # Regex patterns
+ new_file_re = re.compile(r"^diff --git")
+ filename_re = re.compile(r"^[+]{3} (.*)")
+ hunk_line_num_re = re.compile(r"^@@ -[0-9]+,[0-9]+ \+([0-9]+),[0-9]+ @@.*")
+ line_re = re.compile(r"^([+| |-])(.*)")
+
+ # Get the diff output.
+ proc = subprocess.run(
+ [
+ "git",
+ "diff",
+ "--cached",
+ "-GCONFIG_*",
+ "--no-prefix",
+ "--no-ext-diff",
+ "HEAD~1",
+ ],
+ stdout=subprocess.PIPE,
+ encoding="utf-8",
+ check=True,
+ )
+ diff = proc.stdout.splitlines()
+ if not diff:
+ return []
+ line = diff[0]
+
+ state = enum.Enum("state", "NEW_FILE FILENAME_SEARCH HUNK LINES")
+ current_state = state.NEW_FILE
+
+ while True:
+ # Search for the beginning of a new file.
+ if current_state is state.NEW_FILE:
+ match = new_file_re.search(line)
+ if match:
+ current_state = state.FILENAME_SEARCH
+
+ # Search the diff output for a file name.
+ elif current_state is state.FILENAME_SEARCH:
+ # Search for a file name.
+ match = filename_re.search(line)
+ if match:
+ filename = match.groups(1)[0]
+ if filename in ALLOWLIST or ALLOW_PATTERN.match(filename):
+ # Skip the file if it's allowlisted.
+ current_state = state.NEW_FILE
+ else:
+ current_state = state.HUNK
+
+ # Search for a hunk. Each hunk starts with a line describing the line
+ # numbers in the file.
+ elif current_state is state.HUNK:
+ hunk_lines = []
+ match = hunk_line_num_re.search(line)
+ if match:
+ # Extract the line number offset.
+ line_num = int(match.groups(1)[0])
+ current_state = state.LINES
+
+ # Start looking for changes.
+ elif current_state is state.LINES:
+ # Check if state needs updating.
+ new_hunk = hunk_line_num_re.search(line)
+ new_file = new_file_re.search(line)
+ if new_hunk:
+ current_state = state.HUNK
+ hunks.append(Hunk(filename, hunk_lines))
+ continue
+ elif new_file:
+ current_state = state.NEW_FILE
+ hunks.append(Hunk(filename, hunk_lines))
+ continue
+
+ match = line_re.search(line)
+ if match:
+ line_type = match.groups(1)[0]
+ # We only care about modifications.
+ if line_type != " ":
+ hunk_lines.append(Line(line_num, match.groups(2)[1], line_type))
+ # Deletions don't count towards the line numbers.
+ if line_type != "-":
+ line_num += 1
+
+ # Advance to the next line
+ try:
+ i += 1
+ line = diff[i]
+ except IndexError:
+ # We've reached the end of the diff. Return what we have.
+ if hunk_lines:
+ hunks.append(Hunk(filename, hunk_lines))
+ return hunks
+
def main():
- """Searches through committed changes for missing config options.
-
- Checks through committed changes for CONFIG_* options. Then checks to make
- sure that all CONFIG_* options used are defined in include/config.h. Finally,
- reports any missing config options.
- """
- # Obtain the hunks of the commit to search through.
- hunks = get_hunks()
- # Obtain config options from include/config.h.
- config_options = obtain_current_config_options()
- # Find any missing config options from the hunks and print them.
- missing_opts = print_missing_config_options(hunks, config_options)
-
- if missing_opts:
- print('\nIt may also be possible that you have a typo.')
- sys.exit(1)
-
-if __name__ == '__main__':
- main()
+ """Searches through committed changes for missing config options.
+
+ Checks through committed changes for CONFIG_* options. Then checks to make
+ sure that all CONFIG_* options used are defined in include/config.h. Finally,
+ reports any missing config options.
+ """
+ # Obtain the hunks of the commit to search through.
+ hunks = get_hunks()
+ # Obtain config options from include/config.h.
+ config_options = obtain_current_config_options()
+ # Find any missing config options from the hunks and print them.
+ missing_opts = print_missing_config_options(hunks, config_options)
+
+ if missing_opts:
+ print("\nIt may also be possible that you have a typo.")
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/util/ec3po/console.py b/util/ec3po/console.py
index e71216e3f2..33fa5a6775 100755
--- a/util/ec3po/console.py
+++ b/util/ec3po/console.py
@@ -17,7 +17,6 @@ from __future__ import print_function
import argparse
import binascii
import ctypes
-from datetime import datetime
import logging
import os
import pty
@@ -26,26 +25,25 @@ import select
import stat
import sys
import traceback
+from datetime import datetime
import six
+from ec3po import interpreter, threadproc_shim
-from ec3po import interpreter
-from ec3po import threadproc_shim
-
-
-PROMPT = b'> '
+PROMPT = b"> "
CONSOLE_INPUT_LINE_SIZE = 80 # Taken from the CONFIG_* with the same name.
CONSOLE_MAX_READ = 100 # Max bytes to read at a time from the user.
LOOK_BUFFER_SIZE = 256 # Size of search window when looking for the enhanced EC
- # image string.
+# image string.
# In console_init(), the EC will print a string saying that the EC console is
# enabled. Enhanced images will print a slightly different string. These
# regular expressions are used to determine at reboot whether the EC image is
# enhanced or not.
-ENHANCED_IMAGE_RE = re.compile(br'Enhanced Console is enabled '
- br'\(v([0-9]+\.[0-9]+\.[0-9]+)\)')
-NON_ENHANCED_IMAGE_RE = re.compile(br'Console is enabled; ')
+ENHANCED_IMAGE_RE = re.compile(
+ rb"Enhanced Console is enabled " rb"\(v([0-9]+\.[0-9]+\.[0-9]+)\)"
+)
+NON_ENHANCED_IMAGE_RE = re.compile(rb"Console is enabled; ")
# The timeouts are really only useful for enhanced EC images, but otherwise just
# serve as a delay for non-enhanced EC images. Therefore, we can keep this
@@ -54,1118 +52,1173 @@ NON_ENHANCED_IMAGE_RE = re.compile(br'Console is enabled; ')
# EC image, we can increase the timeout for stability just in case it takes a
# bit longer to receive an ACK for some reason.
NON_ENHANCED_EC_INTERROGATION_TIMEOUT = 0.3 # Maximum number of seconds to wait
- # for a response to an
- # interrogation of a non-enhanced
- # EC image.
+# for a response to an
+# interrogation of a non-enhanced
+# EC image.
ENHANCED_EC_INTERROGATION_TIMEOUT = 1.0 # Maximum number of seconds to wait for
- # a response to an interrogation of an
- # enhanced EC image.
+# a response to an interrogation of an
+# enhanced EC image.
# List of modes which control when interrogations are performed with the EC.
-INTERROGATION_MODES = [b'never', b'always', b'auto']
+INTERROGATION_MODES = [b"never", b"always", b"auto"]
# Format for printing host timestamp
-HOST_STRFTIME="%y-%m-%d %H:%M:%S.%f"
+HOST_STRFTIME = "%y-%m-%d %H:%M:%S.%f"
class EscState(object):
- """Class which contains an enumeration for states of ESC sequences."""
- ESC_START = 1
- ESC_BRACKET = 2
- ESC_BRACKET_1 = 3
- ESC_BRACKET_3 = 4
- ESC_BRACKET_8 = 5
-
+ """Class which contains an enumeration for states of ESC sequences."""
-class ControlKey(object):
- """Class which contains codes for various control keys."""
- BACKSPACE = 0x08
- CTRL_A = 0x01
- CTRL_B = 0x02
- CTRL_D = 0x04
- CTRL_E = 0x05
- CTRL_F = 0x06
- CTRL_K = 0x0b
- CTRL_N = 0xe
- CTRL_P = 0x10
- CARRIAGE_RETURN = 0x0d
- ESC = 0x1b
+ ESC_START = 1
+ ESC_BRACKET = 2
+ ESC_BRACKET_1 = 3
+ ESC_BRACKET_3 = 4
+ ESC_BRACKET_8 = 5
-class Console(object):
- """Class which provides the console interface between the EC and the user.
-
- This class essentially represents the console interface between the user and
- the EC. It handles all of the console editing behaviour
-
- Attributes:
- logger: A logger for this module.
- controller_pty: File descriptor to the controller side of the PTY. Used for
- driving output to the user and receiving user input.
- user_pty: A string representing the PTY name of the served console.
- cmd_pipe: A socket.socket or multiprocessing.Connection object which
- represents the console side of the command pipe. This must be a
- bidirectional pipe. Console commands and responses utilize this pipe.
- dbg_pipe: A socket.socket or multiprocessing.Connection object which
- represents the console's read-only side of the debug pipe. This must be a
- unidirectional pipe attached to the intepreter. EC debug messages use
- this pipe.
- oobm_queue: A queue.Queue or multiprocessing.Queue which is used for out of
- band management for the interactive console.
- input_buffer: A string representing the current input command.
- input_buffer_pos: An integer representing the current position in the buffer
- to insert a char.
- partial_cmd: A string representing the command entered on a line before
- pressing the up arrow keys.
- esc_state: An integer represeting the current state within an escape
- sequence.
- line_limit: An integer representing the maximum number of characters on a
- line.
- history: A list of strings containing the past entered console commands.
- history_pos: An integer representing the current history buffer position.
- This index is used to show previous commands.
- prompt: A string representing the console prompt displayed to the user.
- enhanced_ec: A boolean indicating if the EC image that we are currently
- communicating with is enhanced or not. Enhanced EC images will support
- packed commands and host commands over the UART. This defaults to False
- until we perform some handshaking.
- interrogation_timeout: A float representing the current maximum seconds to
- wait for a response to an interrogation.
- receiving_oobm_cmd: A boolean indicating whether or not the console is in
- the middle of receiving an out of band command.
- pending_oobm_cmd: A string containing the pending OOBM command.
- interrogation_mode: A string containing the current mode of whether
- interrogations are performed with the EC or not and how often.
- raw_debug: Flag to indicate whether per interrupt data should be logged to
- debug
- output_line_log_buffer: buffer for lines coming from the EC to log to debug
- """
-
- def __init__(self, controller_pty, user_pty, interface_pty, cmd_pipe, dbg_pipe,
- name=None):
- """Initalises a Console object with the provided arguments.
+class ControlKey(object):
+ """Class which contains codes for various control keys."""
- Args:
- controller_pty: File descriptor to the controller side of the PTY. Used for
- driving output to the user and receiving user input.
- user_pty: A string representing the PTY name of the served console.
- interface_pty: A string representing the PTY name of the served command
- interface.
- cmd_pipe: A socket.socket or multiprocessing.Connection object which
- represents the console side of the command pipe. This must be a
- bidirectional pipe. Console commands and responses utilize this pipe.
- dbg_pipe: A socket.socket or multiprocessing.Connection object which
- represents the console's read-only side of the debug pipe. This must be a
- unidirectional pipe attached to the intepreter. EC debug messages use
- this pipe.
- name: the console source name
- """
- # Create a unique logger based on the console name
- console_prefix = ('%s - ' % name) if name else ''
- logger = logging.getLogger('%sEC3PO.Console' % console_prefix)
- self.logger = interpreter.LoggerAdapter(logger, {'pty': user_pty})
- self.controller_pty = controller_pty
- self.user_pty = user_pty
- self.interface_pty = interface_pty
- self.cmd_pipe = cmd_pipe
- self.dbg_pipe = dbg_pipe
- self.oobm_queue = threadproc_shim.Queue()
- self.input_buffer = b''
- self.input_buffer_pos = 0
- self.partial_cmd = b''
- self.esc_state = 0
- self.line_limit = CONSOLE_INPUT_LINE_SIZE
- self.history = []
- self.history_pos = 0
- self.prompt = PROMPT
- self.enhanced_ec = False
- self.interrogation_timeout = NON_ENHANCED_EC_INTERROGATION_TIMEOUT
- self.receiving_oobm_cmd = False
- self.pending_oobm_cmd = b''
- self.interrogation_mode = b'auto'
- self.timestamp_enabled = True
- self.look_buffer = b''
- self.raw_debug = False
- self.output_line_log_buffer = []
-
- def __str__(self):
- """Show internal state of Console object as a string."""
- string = []
- string.append('controller_pty: %s' % self.controller_pty)
- string.append('user_pty: %s' % self.user_pty)
- string.append('interface_pty: %s' % self.interface_pty)
- string.append('cmd_pipe: %s' % self.cmd_pipe)
- string.append('dbg_pipe: %s' % self.dbg_pipe)
- string.append('oobm_queue: %s' % self.oobm_queue)
- string.append('input_buffer: %s' % self.input_buffer)
- string.append('input_buffer_pos: %d' % self.input_buffer_pos)
- string.append('esc_state: %d' % self.esc_state)
- string.append('line_limit: %d' % self.line_limit)
- string.append('history: %r' % self.history)
- string.append('history_pos: %d' % self.history_pos)
- string.append('prompt: %r' % self.prompt)
- string.append('partial_cmd: %r'% self.partial_cmd)
- string.append('interrogation_mode: %r' % self.interrogation_mode)
- string.append('look_buffer: %r' % self.look_buffer)
- return '\n'.join(string)
-
- def LogConsoleOutput(self, data):
- """Log to debug user MCU output to controller_pty when line is filled.
-
- The logging also suppresses the Cr50 spinner lines by removing characters
- when it sees backspaces.
+ BACKSPACE = 0x08
+ CTRL_A = 0x01
+ CTRL_B = 0x02
+ CTRL_D = 0x04
+ CTRL_E = 0x05
+ CTRL_F = 0x06
+ CTRL_K = 0x0B
+ CTRL_N = 0xE
+ CTRL_P = 0x10
+ CARRIAGE_RETURN = 0x0D
+ ESC = 0x1B
- Args:
- data: bytes - string received from MCU
- """
- data = list(data)
- # For compatibility with python2 and python3, standardize on the data
- # being a list of integers. This requires one more transformation in py2
- if not isinstance(data[0], int):
- data = [ord(c) for c in data]
-
- # This is a list of already filtered characters (or placeholders).
- line = self.output_line_log_buffer
-
- # TODO(b/177480273): use raw strings here
- symbols = {
- ord(b'\n'): u'\\n',
- ord(b'\r'): u'\\r',
- ord(b'\t'): u'\\t'
- }
- # self.logger.debug(u'%s + %r', u''.join(line), ''.join(data))
- while data:
- # Recall, data is a list of integers, namely the byte values sent by
- # the MCU.
- byte = data.pop(0)
- # This means that |byte| is an int.
- if byte == ord('\n'):
- line.append(symbols[byte])
- if line:
- self.logger.debug(u'%s', ''.join(line))
- line = []
- elif byte == ord('\b'):
- # Backspace: trim the last character off the buffer
- if line:
- line.pop(-1)
- elif byte in symbols:
- line.append(symbols[byte])
- elif byte < ord(' ') or byte > ord('~'):
- # Turn any character that isn't printable ASCII into escaped hex.
- # ' ' is chr(20), and 0-19 are unprintable control characters.
- # '~' is chr(126), and 127 is DELETE. 128-255 are control and Latin-1.
- line.append(u'\\x%02x' % byte)
- else:
- # byte is printable. Thus it is safe to use chr() to get the printable
- # character out of it again.
- line.append(u'%s' % chr(byte))
- self.output_line_log_buffer = line
-
- def PrintHistory(self):
- """Print the history of entered commands."""
- fd = self.controller_pty
- # Make it pretty by figuring out how wide to pad the numbers.
- wide = (len(self.history) // 10) + 1
- for i in range(len(self.history)):
- line = b' %*d %s\r\n' % (wide, i, self.history[i])
- os.write(fd, line)
-
- def ShowPreviousCommand(self):
- """Shows the previous command from the history list."""
- # There's nothing to do if there's no history at all.
- if not self.history:
- self.logger.debug('No history to print.')
- return
-
- # Don't do anything if there's no more history to show.
- if self.history_pos == 0:
- self.logger.debug('No more history to show.')
- return
-
- self.logger.debug('current history position: %d.', self.history_pos)
-
- # Decrement the history buffer position.
- self.history_pos -= 1
- self.logger.debug('new history position.: %d', self.history_pos)
-
- # Save the text entered on the console if any.
- if self.history_pos == len(self.history)-1:
- self.logger.debug('saving partial_cmd: %r', self.input_buffer)
- self.partial_cmd = self.input_buffer
-
- # Backspace the line.
- for _ in range(self.input_buffer_pos):
- self.SendBackspace()
-
- # Print the last entry in the history buffer.
- self.logger.debug('printing previous entry %d - %s', self.history_pos,
- self.history[self.history_pos])
- fd = self.controller_pty
- prev_cmd = self.history[self.history_pos]
- os.write(fd, prev_cmd)
- # Update the input buffer.
- self.input_buffer = prev_cmd
- self.input_buffer_pos = len(prev_cmd)
-
- def ShowNextCommand(self):
- """Shows the next command from the history list."""
- # Don't do anything if there's no history at all.
- if not self.history:
- self.logger.debug('History buffer is empty.')
- return
-
- fd = self.controller_pty
-
- self.logger.debug('current history position: %d', self.history_pos)
- # Increment the history position.
- self.history_pos += 1
-
- # Restore the partial cmd.
- if self.history_pos == len(self.history):
- self.logger.debug('Restoring partial command of %r', self.partial_cmd)
- # Backspace the line.
- for _ in range(self.input_buffer_pos):
- self.SendBackspace()
- # Print the partially entered command if any.
- os.write(fd, self.partial_cmd)
- self.input_buffer = self.partial_cmd
- self.input_buffer_pos = len(self.input_buffer)
- # Now that we've printed it, clear the partial cmd storage.
- self.partial_cmd = b''
- # Reset history position.
- self.history_pos = len(self.history)
- return
-
- self.logger.debug('new history position: %d', self.history_pos)
- if self.history_pos > len(self.history)-1:
- self.logger.debug('No more history to show.')
- self.history_pos -= 1
- self.logger.debug('Reset history position to %d', self.history_pos)
- return
-
- # Backspace the line.
- for _ in range(self.input_buffer_pos):
- self.SendBackspace()
-
- # Print the newer entry from the history buffer.
- self.logger.debug('printing next entry %d - %s', self.history_pos,
- self.history[self.history_pos])
- next_cmd = self.history[self.history_pos]
- os.write(fd, next_cmd)
- # Update the input buffer.
- self.input_buffer = next_cmd
- self.input_buffer_pos = len(next_cmd)
- self.logger.debug('new history position: %d.', self.history_pos)
-
- def SliceOutChar(self):
- """Remove a char from the line and shift everything over 1 column."""
- fd = self.controller_pty
- # Remove the character at the input_buffer_pos by slicing it out.
- self.input_buffer = self.input_buffer[0:self.input_buffer_pos] + \
- self.input_buffer[self.input_buffer_pos+1:]
- # Write the rest of the line
- moved_col = os.write(fd, self.input_buffer[self.input_buffer_pos:])
- # Write a space to clear out the last char
- moved_col += os.write(fd, b' ')
- # Update the input buffer position.
- self.input_buffer_pos += moved_col
- # Reset the cursor
- self.MoveCursor('left', moved_col)
-
- def HandleEsc(self, byte):
- """HandleEsc processes escape sequences.
- Args:
- byte: An integer representing the current byte in the sequence.
+class Console(object):
+ """Class which provides the console interface between the EC and the user.
+
+ This class essentially represents the console interface between the user and
+ the EC. It handles all of the console editing behaviour
+
+ Attributes:
+ logger: A logger for this module.
+ controller_pty: File descriptor to the controller side of the PTY. Used for
+ driving output to the user and receiving user input.
+ user_pty: A string representing the PTY name of the served console.
+ cmd_pipe: A socket.socket or multiprocessing.Connection object which
+ represents the console side of the command pipe. This must be a
+ bidirectional pipe. Console commands and responses utilize this pipe.
+ dbg_pipe: A socket.socket or multiprocessing.Connection object which
+ represents the console's read-only side of the debug pipe. This must be a
+ unidirectional pipe attached to the intepreter. EC debug messages use
+ this pipe.
+ oobm_queue: A queue.Queue or multiprocessing.Queue which is used for out of
+ band management for the interactive console.
+ input_buffer: A string representing the current input command.
+ input_buffer_pos: An integer representing the current position in the buffer
+ to insert a char.
+ partial_cmd: A string representing the command entered on a line before
+ pressing the up arrow keys.
+ esc_state: An integer represeting the current state within an escape
+ sequence.
+ line_limit: An integer representing the maximum number of characters on a
+ line.
+ history: A list of strings containing the past entered console commands.
+ history_pos: An integer representing the current history buffer position.
+ This index is used to show previous commands.
+ prompt: A string representing the console prompt displayed to the user.
+ enhanced_ec: A boolean indicating if the EC image that we are currently
+ communicating with is enhanced or not. Enhanced EC images will support
+ packed commands and host commands over the UART. This defaults to False
+ until we perform some handshaking.
+ interrogation_timeout: A float representing the current maximum seconds to
+ wait for a response to an interrogation.
+ receiving_oobm_cmd: A boolean indicating whether or not the console is in
+ the middle of receiving an out of band command.
+ pending_oobm_cmd: A string containing the pending OOBM command.
+ interrogation_mode: A string containing the current mode of whether
+ interrogations are performed with the EC or not and how often.
+ raw_debug: Flag to indicate whether per interrupt data should be logged to
+ debug
+ output_line_log_buffer: buffer for lines coming from the EC to log to debug
"""
- # We shouldn't be handling an escape sequence if we haven't seen one.
- assert self.esc_state != 0
-
- if self.esc_state is EscState.ESC_START:
- self.logger.debug('ESC_START')
- if byte == ord('['):
- self.esc_state = EscState.ESC_BRACKET
- return
- else:
- self.logger.error('Unexpected sequence. %c', byte)
- self.esc_state = 0
-
- elif self.esc_state is EscState.ESC_BRACKET:
- self.logger.debug('ESC_BRACKET')
- # Left Arrow key was pressed.
- if byte == ord('D'):
- self.logger.debug('Left arrow key pressed.')
- self.MoveCursor('left', 1)
- self.esc_state = 0 # Reset the state.
- return
-
- # Right Arrow key.
- elif byte == ord('C'):
- self.logger.debug('Right arrow key pressed.')
- self.MoveCursor('right', 1)
- self.esc_state = 0 # Reset the state.
- return
-
- # Up Arrow key.
- elif byte == ord('A'):
- self.logger.debug('Up arrow key pressed.')
- self.ShowPreviousCommand()
- # Reset the state.
- self.esc_state = 0 # Reset the state.
- return
-
- # Down Arrow key.
- elif byte == ord('B'):
- self.logger.debug('Down arrow key pressed.')
- self.ShowNextCommand()
- # Reset the state.
- self.esc_state = 0 # Reset the state.
- return
-
- # For some reason, minicom sends a 1 instead of 7. /shrug
- # TODO(aaboagye): Figure out why this happens.
- elif byte == ord('1') or byte == ord('7'):
- self.esc_state = EscState.ESC_BRACKET_1
-
- elif byte == ord('3'):
- self.esc_state = EscState.ESC_BRACKET_3
-
- elif byte == ord('8'):
- self.esc_state = EscState.ESC_BRACKET_8
-
- else:
- self.logger.error(r'Bad or unhandled escape sequence. got ^[%c\(%d)',
- chr(byte), byte)
- self.esc_state = 0
- return
-
- elif self.esc_state is EscState.ESC_BRACKET_1:
- self.logger.debug('ESC_BRACKET_1')
- # HOME key.
- if byte == ord('~'):
- self.logger.debug('Home key pressed.')
- self.MoveCursor('left', self.input_buffer_pos)
- self.esc_state = 0 # Reset the state.
- self.logger.debug('ESC sequence complete.')
- return
-
- elif self.esc_state is EscState.ESC_BRACKET_3:
- self.logger.debug('ESC_BRACKET_3')
- # DEL key.
- if byte == ord('~'):
- self.logger.debug('Delete key pressed.')
- if self.input_buffer_pos != len(self.input_buffer):
- self.SliceOutChar()
- self.esc_state = 0 # Reset the state.
-
- elif self.esc_state is EscState.ESC_BRACKET_8:
- self.logger.debug('ESC_BRACKET_8')
- # END key.
- if byte == ord('~'):
- self.logger.debug('End key pressed.')
- self.MoveCursor('right',
- len(self.input_buffer) - self.input_buffer_pos)
- self.esc_state = 0 # Reset the state.
- self.logger.debug('ESC sequence complete.')
- return
-
- else:
- self.logger.error('Unexpected sequence. %c', byte)
+    def __init__(
+        self, controller_pty, user_pty, interface_pty, cmd_pipe, dbg_pipe, name=None
+    ):
+        """Initalises a Console object with the provided arguments.
+
+        Args:
+          controller_pty: File descriptor to the controller side of the PTY. Used for
+            driving output to the user and receiving user input.
+          user_pty: A string representing the PTY name of the served console.
+          interface_pty: A string representing the PTY name of the served command
+            interface.
+          cmd_pipe: A socket.socket or multiprocessing.Connection object which
+            represents the console side of the command pipe. This must be a
+            bidirectional pipe. Console commands and responses utilize this pipe.
+          dbg_pipe: A socket.socket or multiprocessing.Connection object which
+            represents the console's read-only side of the debug pipe. This must be a
+            unidirectional pipe attached to the intepreter. EC debug messages use
+            this pipe.
+          name: the console source name
+        """
+        # Create a unique logger based on the console name
+        console_prefix = ("%s - " % name) if name else ""
+        logger = logging.getLogger("%sEC3PO.Console" % console_prefix)
+        self.logger = interpreter.LoggerAdapter(logger, {"pty": user_pty})
+        self.controller_pty = controller_pty
+        self.user_pty = user_pty
+        self.interface_pty = interface_pty
+        self.cmd_pipe = cmd_pipe
+        self.dbg_pipe = dbg_pipe
+        self.oobm_queue = threadproc_shim.Queue()
+        # Line-editing state: the command being typed and the cursor position.
+        self.input_buffer = b""
+        self.input_buffer_pos = 0
+        self.partial_cmd = b""
        self.esc_state = 0
+        self.line_limit = CONSOLE_INPUT_LINE_SIZE
+        # Command history ring and current browse position (see Show*Command).
+        self.history = []
+        self.history_pos = 0
+        self.prompt = PROMPT
+        # Start pessimistic: assume a non-enhanced EC image (and the shorter
+        # timeout) until CheckForEnhancedECImage() proves otherwise.
+        self.enhanced_ec = False
+        self.interrogation_timeout = NON_ENHANCED_EC_INTERROGATION_TIMEOUT
+        self.receiving_oobm_cmd = False
+        self.pending_oobm_cmd = b""
+        self.interrogation_mode = b"auto"
+        self.timestamp_enabled = True
+        self.look_buffer = b""
+        self.raw_debug = False
+        self.output_line_log_buffer = []
+
+    def __str__(self):
+        """Show internal state of Console object as a string."""
+        # NOTE(review): this is a debug dump; it intentionally lists only a
+        # subset of the attributes set in __init__ (e.g. enhanced_ec and
+        # timestamp_enabled are omitted).
+        string = []
+        string.append("controller_pty: %s" % self.controller_pty)
+        string.append("user_pty: %s" % self.user_pty)
+        string.append("interface_pty: %s" % self.interface_pty)
+        string.append("cmd_pipe: %s" % self.cmd_pipe)
+        string.append("dbg_pipe: %s" % self.dbg_pipe)
+        string.append("oobm_queue: %s" % self.oobm_queue)
+        string.append("input_buffer: %s" % self.input_buffer)
+        string.append("input_buffer_pos: %d" % self.input_buffer_pos)
+        string.append("esc_state: %d" % self.esc_state)
+        string.append("line_limit: %d" % self.line_limit)
+        string.append("history: %r" % self.history)
+        string.append("history_pos: %d" % self.history_pos)
+        string.append("prompt: %r" % self.prompt)
+        string.append("partial_cmd: %r" % self.partial_cmd)
+        string.append("interrogation_mode: %r" % self.interrogation_mode)
+        string.append("look_buffer: %r" % self.look_buffer)
+        return "\n".join(string)
+
+    def LogConsoleOutput(self, data):
+        """Log to debug user MCU output to controller_pty when line is filled.
+
+        The logging also suppresses the Cr50 spinner lines by removing characters
+        when it sees backspaces.
+
+        Args:
+          data: bytes - string received from MCU
+        """
+        data = list(data)
+        # NOTE(review): assumes data is non-empty — data[0] below would raise
+        # IndexError on b""; presumably callers only pass received bytes — confirm.
+        # For compatibility with python2 and python3, standardize on the data
+        # being a list of integers. This requires one more transformation in py2
+        if not isinstance(data[0], int):
+            data = [ord(c) for c in data]
+
+        # This is a list of already filtered characters (or placeholders).
+        # Carried across calls so a line split over multiple reads is logged once.
+        line = self.output_line_log_buffer
+
+        # TODO(b/177480273): use raw strings here
+        symbols = {ord(b"\n"): "\\n", ord(b"\r"): "\\r", ord(b"\t"): "\\t"}
+        # self.logger.debug(u'%s + %r', u''.join(line), ''.join(data))
+        while data:
+            # Recall, data is a list of integers, namely the byte values sent by
+            # the MCU.
+            byte = data.pop(0)
+            # This means that |byte| is an int.
+            if byte == ord("\n"):
+                line.append(symbols[byte])
+                if line:
+                    self.logger.debug("%s", "".join(line))
+                    line = []
+            elif byte == ord("\b"):
+                # Backspace: trim the last character off the buffer
+                # (this is what suppresses the Cr50 spinner animation).
+                if line:
+                    line.pop(-1)
+            elif byte in symbols:
+                line.append(symbols[byte])
+            elif byte < ord(" ") or byte > ord("~"):
+                # Turn any character that isn't printable ASCII into escaped hex.
+                # ' ' is chr(20), and 0-19 are unprintable control characters.
+                # '~' is chr(126), and 127 is DELETE. 128-255 are control and Latin-1.
+                line.append("\\x%02x" % byte)
+            else:
+                # byte is printable. Thus it is safe to use chr() to get the printable
+                # character out of it again.
+                line.append("%s" % chr(byte))
+        self.output_line_log_buffer = line
+
+    def PrintHistory(self):
+        """Print the history of entered commands."""
+        fd = self.controller_pty
+        # Make it pretty by figuring out how wide to pad the numbers.
+        # NOTE(review): len//10 + 1 over-widens past 100 entries (digit count
+        # would be len(str(len-1))) — cosmetic only.
+        wide = (len(self.history) // 10) + 1
+        for i in range(len(self.history)):
+            # bytes %-formatting: requires Python 3.5+ (PEP 461) or Python 2.
+            line = b" %*d %s\r\n" % (wide, i, self.history[i])
+            os.write(fd, line)
+
+    def ShowPreviousCommand(self):
+        """Shows the previous command from the history list."""
+        # There's nothing to do if there's no history at all.
+        if not self.history:
+            self.logger.debug("No history to print.")
+            return
+
+        # Don't do anything if there's no more history to show.
+        if self.history_pos == 0:
+            self.logger.debug("No more history to show.")
+            return
+
+        self.logger.debug("current history position: %d.", self.history_pos)
+
+        # Decrement the history buffer position.
+        self.history_pos -= 1
+        self.logger.debug("new history position.: %d", self.history_pos)
+
+        # Save the text entered on the console if any.
+        # (Only on the first Up-press: we were sitting one past the newest entry.)
+        if self.history_pos == len(self.history) - 1:
+            self.logger.debug("saving partial_cmd: %r", self.input_buffer)
+            self.partial_cmd = self.input_buffer
+
+        # Backspace the line.
+        for _ in range(self.input_buffer_pos):
+            self.SendBackspace()
+
+        # Print the last entry in the history buffer.
+        self.logger.debug(
+            "printing previous entry %d - %s",
+            self.history_pos,
+            self.history[self.history_pos],
+        )
+        fd = self.controller_pty
+        prev_cmd = self.history[self.history_pos]
+        os.write(fd, prev_cmd)
+        # Update the input buffer.
+        self.input_buffer = prev_cmd
+        self.input_buffer_pos = len(prev_cmd)
+
+    def ShowNextCommand(self):
+        """Shows the next command from the history list."""
+        # Don't do anything if there's no history at all.
+        if not self.history:
+            self.logger.debug("History buffer is empty.")
+            return
+
+        fd = self.controller_pty
+
+        self.logger.debug("current history position: %d", self.history_pos)
+        # Increment the history position.
+        self.history_pos += 1
+
+        # Restore the partial cmd.
+        # (Position == len(history) means we walked back past the newest entry,
+        # so re-display whatever was typed before browsing started.)
+        if self.history_pos == len(self.history):
+            self.logger.debug("Restoring partial command of %r", self.partial_cmd)
+            # Backspace the line.
+            for _ in range(self.input_buffer_pos):
+                self.SendBackspace()
+            # Print the partially entered command if any.
+            os.write(fd, self.partial_cmd)
+            self.input_buffer = self.partial_cmd
+            self.input_buffer_pos = len(self.input_buffer)
+            # Now that we've printed it, clear the partial cmd storage.
+            self.partial_cmd = b""
+            # Reset history position.
+            self.history_pos = len(self.history)
+            return
+
+        self.logger.debug("new history position: %d", self.history_pos)
+        if self.history_pos > len(self.history) - 1:
+            self.logger.debug("No more history to show.")
+            self.history_pos -= 1
+            self.logger.debug("Reset history position to %d", self.history_pos)
+            return
+
+        # Backspace the line.
+        for _ in range(self.input_buffer_pos):
+            self.SendBackspace()
+
+        # Print the newer entry from the history buffer.
+        self.logger.debug(
+            "printing next entry %d - %s",
+            self.history_pos,
+            self.history[self.history_pos],
+        )
+        next_cmd = self.history[self.history_pos]
+        os.write(fd, next_cmd)
+        # Update the input buffer.
+        self.input_buffer = next_cmd
+        self.input_buffer_pos = len(next_cmd)
+        self.logger.debug("new history position: %d.", self.history_pos)
+
+    def SliceOutChar(self):
+        """Remove a char from the line and shift everything over 1 column."""
+        fd = self.controller_pty
+        # Remove the character at the input_buffer_pos by slicing it out.
+        self.input_buffer = (
+            self.input_buffer[0 : self.input_buffer_pos]
+            + self.input_buffer[self.input_buffer_pos + 1 :]
+        )
+        # Write the rest of the line
+        # (os.write returns the number of bytes written; used to track how far
+        # the cursor advanced so it can be moved back below).
+        moved_col = os.write(fd, self.input_buffer[self.input_buffer_pos :])
+        # Write a space to clear out the last char
+        moved_col += os.write(fd, b" ")
+        # Update the input buffer position.
+        self.input_buffer_pos += moved_col
+        # Reset the cursor
+        self.MoveCursor("left", moved_col)
+
+    def HandleEsc(self, byte):
+        """HandleEsc processes escape sequences.
+
+        State machine over EscState: each call consumes one byte of an ANSI
+        escape sequence and either acts on a completed sequence or advances
+        esc_state for the next byte.
+
+        Args:
+          byte: An integer representing the current byte in the sequence.
+        """
+        # We shouldn't be handling an escape sequence if we haven't seen one.
+        assert self.esc_state != 0
+
+        if self.esc_state is EscState.ESC_START:
+            self.logger.debug("ESC_START")
+            if byte == ord("["):
+                self.esc_state = EscState.ESC_BRACKET
+                return
+
+            else:
+                self.logger.error("Unexpected sequence. %c", byte)
+                self.esc_state = 0
+
+        elif self.esc_state is EscState.ESC_BRACKET:
+            self.logger.debug("ESC_BRACKET")
+            # Left Arrow key was pressed.
+            if byte == ord("D"):
+                self.logger.debug("Left arrow key pressed.")
+                self.MoveCursor("left", 1)
+                self.esc_state = 0  # Reset the state.
+                return
+
+            # Right Arrow key.
+            elif byte == ord("C"):
+                self.logger.debug("Right arrow key pressed.")
+                self.MoveCursor("right", 1)
+                self.esc_state = 0  # Reset the state.
+                return
+
+            # Up Arrow key.
+            elif byte == ord("A"):
+                self.logger.debug("Up arrow key pressed.")
+                self.ShowPreviousCommand()
+                # Reset the state.
+                self.esc_state = 0  # Reset the state.
+                return
+
+            # Down Arrow key.
+            elif byte == ord("B"):
+                self.logger.debug("Down arrow key pressed.")
+                self.ShowNextCommand()
+                # Reset the state.
+                self.esc_state = 0  # Reset the state.
+                return
+
+            # For some reason, minicom sends a 1 instead of 7. /shrug
+            # TODO(aaboagye): Figure out why this happens.
+            elif byte == ord("1") or byte == ord("7"):
+                self.esc_state = EscState.ESC_BRACKET_1
+
+            elif byte == ord("3"):
+                self.esc_state = EscState.ESC_BRACKET_3
+
+            elif byte == ord("8"):
+                self.esc_state = EscState.ESC_BRACKET_8
+
+            else:
+                self.logger.error(
+                    r"Bad or unhandled escape sequence. got ^[%c\(%d)", chr(byte), byte
+                )
+                self.esc_state = 0
+                return
+
+        elif self.esc_state is EscState.ESC_BRACKET_1:
+            self.logger.debug("ESC_BRACKET_1")
+            # HOME key.
+            if byte == ord("~"):
+                self.logger.debug("Home key pressed.")
+                self.MoveCursor("left", self.input_buffer_pos)
+                self.esc_state = 0  # Reset the state.
+                self.logger.debug("ESC sequence complete.")
+                return
+
+        elif self.esc_state is EscState.ESC_BRACKET_3:
+            self.logger.debug("ESC_BRACKET_3")
+            # DEL key.
+            if byte == ord("~"):
+                self.logger.debug("Delete key pressed.")
+                # Only slice when the cursor is not already at end-of-line.
+                if self.input_buffer_pos != len(self.input_buffer):
+                    self.SliceOutChar()
+                self.esc_state = 0  # Reset the state.
+
+        elif self.esc_state is EscState.ESC_BRACKET_8:
+            self.logger.debug("ESC_BRACKET_8")
+            # END key.
+            if byte == ord("~"):
+                self.logger.debug("End key pressed.")
+                self.MoveCursor("right", len(self.input_buffer) - self.input_buffer_pos)
+                self.esc_state = 0  # Reset the state.
+                self.logger.debug("ESC sequence complete.")
+                return
+
+            else:
+                self.logger.error("Unexpected sequence. %c", byte)
+                self.esc_state = 0
+
+        else:
+            self.logger.error("Unexpected sequence. %c", byte)
+            self.esc_state = 0
+
+ def ProcessInput(self):
+ """Captures the input determines what actions to take."""
+ # There's nothing to do if the input buffer is empty.
+ if len(self.input_buffer) == 0:
+ return
+
+ # Don't store 2 consecutive identical commands in the history.
+ if self.history and self.history[-1] != self.input_buffer or not self.history:
+ self.history.append(self.input_buffer)
+
+ # Split the command up by spaces.
+ line = self.input_buffer.split(b" ")
+ self.logger.debug("cmd: %s", self.input_buffer)
+ cmd = line[0].lower()
+
+ # The 'history' command is a special case that we handle locally.
+ if cmd == "history":
+ self.PrintHistory()
+ return
+
+ # Send the command to the interpreter.
+ self.logger.debug("Sending command to interpreter.")
+ self.cmd_pipe.send(self.input_buffer)
+
+    def CheckForEnhancedECImage(self):
+        """Performs an interrogation of the EC image.
+
+        Send a SYN and expect an ACK. If no ACK or the response is incorrect, then
+        assume that the current EC image that we are talking to is not enhanced.
+
+        Returns:
+          is_enhanced: A boolean indicating whether the EC responded to the
+            interrogation correctly.
+
+        Raises:
+          EOFError: Allowed to propagate through from self.dbg_pipe.recv().
+        """
+        # Send interrogation byte and wait for the response.
+        self.logger.debug("Performing interrogation.")
+        self.cmd_pipe.send(interpreter.EC_SYN)
+
+        # NOTE(review): "" is a str sentinel while EC_ACK is presumably bytes,
+        # so the timeout path compares str vs bytes (always False — the intended
+        # result, but b"" would be the type-consistent sentinel). Confirm
+        # EC_ACK's type in interpreter.
+        response = ""
+        if self.dbg_pipe.poll(self.interrogation_timeout):
+            response = self.dbg_pipe.recv()
+            self.logger.debug("response: %r", binascii.hexlify(response))
+        else:
+            self.logger.debug("Timed out waiting for EC_ACK")
+
+        # Verify the acknowledgment.
+        is_enhanced = response == interpreter.EC_ACK
+
+        if is_enhanced:
+            # Increase the interrogation timeout for stability purposes.
+            self.interrogation_timeout = ENHANCED_EC_INTERROGATION_TIMEOUT
+            self.logger.debug(
+                "Increasing interrogation timeout to %rs.", self.interrogation_timeout
+            )
+        else:
+            # Reduce the timeout in order to reduce the perceivable delay.
+            self.interrogation_timeout = NON_ENHANCED_EC_INTERROGATION_TIMEOUT
+            self.logger.debug(
+                "Reducing interrogation timeout to %rs.", self.interrogation_timeout
+            )
+
+        return is_enhanced
+
+ def HandleChar(self, byte):
+ """HandleChar does a certain action when it receives a character.
+
+ Args:
+ byte: An integer representing the character received from the user.
+
+ Raises:
+ EOFError: Allowed to propagate through from self.CheckForEnhancedECImage()
+ i.e. from self.dbg_pipe.recv().
+ """
+ fd = self.controller_pty
+
+ # Enter the OOBM prompt mode if the user presses '%'.
+ if byte == ord("%"):
+ self.logger.debug("Begin OOBM command.")
+ self.receiving_oobm_cmd = True
+ # Print a "prompt".
+ os.write(self.controller_pty, b"\r\n% ")
+ return
+
+ # Add chars to the pending OOBM command if we're currently receiving one.
+ if self.receiving_oobm_cmd and byte != ControlKey.CARRIAGE_RETURN:
+ tmp_bytes = six.int2byte(byte)
+ self.pending_oobm_cmd += tmp_bytes
+ self.logger.debug("%s", tmp_bytes)
+ os.write(self.controller_pty, tmp_bytes)
+ return
+
+ if byte == ControlKey.CARRIAGE_RETURN:
+ if self.receiving_oobm_cmd:
+ # Terminate the command and place it in the OOBM queue.
+ self.logger.debug("End OOBM command.")
+ if self.pending_oobm_cmd:
+ self.oobm_queue.put(self.pending_oobm_cmd)
+ self.logger.debug(
+ "Placed %r into OOBM command queue.", self.pending_oobm_cmd
+ )
+
+ # Reset the state.
+ os.write(self.controller_pty, b"\r\n" + self.prompt)
+ self.input_buffer = b""
+ self.input_buffer_pos = 0
+ self.receiving_oobm_cmd = False
+ self.pending_oobm_cmd = b""
+ return
+
+ if self.interrogation_mode == b"never":
+ self.logger.debug(
+ "Skipping interrogation because interrogation mode"
+ " is set to never."
+ )
+ elif self.interrogation_mode == b"always":
+ # Only interrogate the EC if the interrogation mode is set to 'always'.
+ self.enhanced_ec = self.CheckForEnhancedECImage()
+ self.logger.debug("Enhanced EC image? %r", self.enhanced_ec)
+
+ if not self.enhanced_ec:
+ # Send everything straight to the EC to handle.
+ self.cmd_pipe.send(six.int2byte(byte))
+ # Reset the input buffer.
+ self.input_buffer = b""
+ self.input_buffer_pos = 0
+ self.logger.log(1, "Reset input buffer.")
+ return
+
+ # Keep handling the ESC sequence if we're in the middle of it.
+ if self.esc_state != 0:
+ self.HandleEsc(byte)
+ return
+
+ # When we're at the end of the line, we should only allow going backwards,
+ # backspace, carriage return, up, or down. The arrow keys are escape
+ # sequences, so we let the escape...escape.
+ if self.input_buffer_pos >= self.line_limit and byte not in [
+ ControlKey.CTRL_B,
+ ControlKey.ESC,
+ ControlKey.BACKSPACE,
+ ControlKey.CTRL_A,
+ ControlKey.CARRIAGE_RETURN,
+ ControlKey.CTRL_P,
+ ControlKey.CTRL_N,
+ ]:
+ return
+
+ # If the input buffer is full we can't accept new chars.
+ buffer_full = len(self.input_buffer) >= self.line_limit
+
+ # Carriage_Return/Enter
+ if byte == ControlKey.CARRIAGE_RETURN:
+ self.logger.debug("Enter key pressed.")
+ # Put a carriage return/newline and then print the prompt.
+ os.write(fd, b"\r\n")
+
+ # TODO(aaboagye): When we control the printing of all output, print the
+ # prompt AFTER printing all the output. We can't do it yet because we
+ # don't know how much is coming from the EC.
+
+ # Print the prompt.
+ os.write(fd, self.prompt)
+ # Process the input.
+ self.ProcessInput()
+ # Now, clear the buffer.
+ self.input_buffer = b""
+ self.input_buffer_pos = 0
+ # Reset history buffer pos.
+ self.history_pos = len(self.history)
+ # Clear partial command.
+ self.partial_cmd = b""
+
+ # Backspace
+ elif byte == ControlKey.BACKSPACE:
+ self.logger.debug("Backspace pressed.")
+ if self.input_buffer_pos > 0:
+ # Move left 1 column.
+ self.MoveCursor("left", 1)
+ # Remove the character at the input_buffer_pos by slicing it out.
+ self.SliceOutChar()
+
+ self.logger.debug("input_buffer_pos: %d", self.input_buffer_pos)
+
+ # Ctrl+A. Move cursor to beginning of the line
+ elif byte == ControlKey.CTRL_A:
+ self.logger.debug("Control+A pressed.")
+ self.MoveCursor("left", self.input_buffer_pos)
+
+ # Ctrl+B. Move cursor left 1 column.
+ elif byte == ControlKey.CTRL_B:
+ self.logger.debug("Control+B pressed.")
+ self.MoveCursor("left", 1)
+
+ # Ctrl+D. Delete a character.
+ elif byte == ControlKey.CTRL_D:
+ self.logger.debug("Control+D pressed.")
+ if self.input_buffer_pos != len(self.input_buffer):
+ # Remove the character by slicing it out.
+ self.SliceOutChar()
+
+ # Ctrl+E. Move cursor to end of the line.
+ elif byte == ControlKey.CTRL_E:
+ self.logger.debug("Control+E pressed.")
+ self.MoveCursor("right", len(self.input_buffer) - self.input_buffer_pos)
+
+ # Ctrl+F. Move cursor right 1 column.
+ elif byte == ControlKey.CTRL_F:
+ self.logger.debug("Control+F pressed.")
+ self.MoveCursor("right", 1)
+
+ # Ctrl+K. Kill line.
+ elif byte == ControlKey.CTRL_K:
+ self.logger.debug("Control+K pressed.")
+ self.KillLine()
+
+ # Ctrl+N. Next line.
+ elif byte == ControlKey.CTRL_N:
+ self.logger.debug("Control+N pressed.")
+ self.ShowNextCommand()
+
+ # Ctrl+P. Previous line.
+ elif byte == ControlKey.CTRL_P:
+ self.logger.debug("Control+P pressed.")
+ self.ShowPreviousCommand()
+
+ # ESC sequence
+ elif byte == ControlKey.ESC:
+ # Starting an ESC sequence
+ self.esc_state = EscState.ESC_START
+
+ # Only print printable chars.
+ elif IsPrintable(byte):
+ # Drop the character if we're full.
+ if buffer_full:
+ self.logger.debug("Dropped char: %c(%d)", byte, byte)
+ return
+ # Print the character.
+ os.write(fd, six.int2byte(byte))
+ # Print the rest of the line (if any).
+ extra_bytes_written = os.write(
+ fd, self.input_buffer[self.input_buffer_pos :]
+ )
+
+ # Recreate the input buffer.
+ self.input_buffer = (
+ self.input_buffer[0 : self.input_buffer_pos]
+ + six.int2byte(byte)
+ + self.input_buffer[self.input_buffer_pos :]
+ )
+ # Update the input buffer position.
+ self.input_buffer_pos += 1 + extra_bytes_written
+
+ # Reset the cursor if we wrote any extra bytes.
+ if extra_bytes_written:
+ self.MoveCursor("left", extra_bytes_written)
+
+ self.logger.debug("input_buffer_pos: %d", self.input_buffer_pos)
+
+ def MoveCursor(self, direction, count):
+ """MoveCursor moves the cursor left or right by count columns.
+
+ Args:
+ direction: A string that should be either 'left' or 'right' representing
+ the direction to move the cursor on the console.
+ count: An integer representing how many columns the cursor should be
+ moved.
+
+ Raises:
+ AssertionError: If the direction is not equal to 'left' or 'right'.
+ """
+ # If there's nothing to move, we're done.
+ if not count:
+ return
+ fd = self.controller_pty
+ seq = b"\033[" + str(count).encode("ascii")
+ if direction == "left":
+ # Bind the movement.
+ if count > self.input_buffer_pos:
+ count = self.input_buffer_pos
+ seq += b"D"
+ self.logger.debug("move cursor left %d", count)
+ self.input_buffer_pos -= count
+
+ elif direction == "right":
+ # Bind the movement.
+ if (count + self.input_buffer_pos) > len(self.input_buffer):
+ count = 0
+ seq += b"C"
+ self.logger.debug("move cursor right %d", count)
+ self.input_buffer_pos += count
+
+ else:
+ raise AssertionError(
+ ("The only valid directions are 'left' and " "'right'")
+ )
+
+ self.logger.debug("input_buffer_pos: %d", self.input_buffer_pos)
+ # Move the cursor.
+ if count != 0:
+ os.write(fd, seq)
+
+ def KillLine(self):
+ """Kill the rest of the line based on the input buffer position."""
+ # Killing the line is killing all the text to the right.
+ diff = len(self.input_buffer) - self.input_buffer_pos
+ self.logger.debug("diff: %d", diff)
+ # Diff shouldn't be negative, but if it is for some reason, let's try to
+ # correct the cursor.
+ if diff < 0:
+ self.logger.warning(
+ "Resetting input buffer position to %d...", len(self.input_buffer)
+ )
+ self.MoveCursor("left", -diff)
+ return
+ if diff:
+ self.MoveCursor("right", diff)
+ for _ in range(diff):
+ self.SendBackspace()
+ self.input_buffer_pos -= diff
+ self.input_buffer = self.input_buffer[0 : self.input_buffer_pos]
+
+ def SendBackspace(self):
+ """Backspace a character on the console."""
+ os.write(self.controller_pty, b"\033[1D \033[1D")
+
+ def ProcessOOBMQueue(self):
+ """Retrieve an item from the OOBM queue and process it."""
+ item = self.oobm_queue.get()
+ self.logger.debug("OOBM cmd: %r", item)
+ cmd = item.split(b" ")
+
+ if cmd[0] == b"loglevel":
+ # An integer is required in order to set the log level.
+ if len(cmd) < 2:
+ self.logger.debug("Insufficient args")
+ self.PrintOOBMHelp()
+ return
+ try:
+ self.logger.debug("Log level change request.")
+ new_log_level = int(cmd[1])
+ self.logger.logger.setLevel(new_log_level)
+ self.logger.info("Log level changed to %d.", new_log_level)
+
+ # Forward the request to the interpreter as well.
+ self.cmd_pipe.send(item)
+ except ValueError:
+ # Ignoring the request if an integer was not provided.
+ self.PrintOOBMHelp()
+
+ elif cmd[0] == b"timestamp":
+ mode = cmd[1].lower()
+ self.timestamp_enabled = mode == b"on"
+ self.logger.info(
+ "%sabling uart timestamps.", "En" if self.timestamp_enabled else "Dis"
+ )
+
+ elif cmd[0] == b"rawdebug":
+ mode = cmd[1].lower()
+ self.raw_debug = mode == b"on"
+ self.logger.info(
+ "%sabling per interrupt debug logs.", "En" if self.raw_debug else "Dis"
+ )
+
+ elif cmd[0] == b"interrogate" and len(cmd) >= 2:
+ enhanced = False
+ mode = cmd[1]
+ if len(cmd) >= 3 and cmd[2] == b"enhanced":
+ enhanced = True
+
+ # Set the mode if correct.
+ if mode in INTERROGATION_MODES:
+ self.interrogation_mode = mode
+ self.logger.debug("Updated interrogation mode to %s.", mode)
+
+ # Update the assumptions of the EC image.
+ self.enhanced_ec = enhanced
+ self.logger.debug("Enhanced EC image is now %r", self.enhanced_ec)
+
+ # Send command to interpreter as well.
+ self.cmd_pipe.send(b"enhanced " + str(self.enhanced_ec).encode("ascii"))
+ else:
+ self.PrintOOBMHelp()
+
+ else:
+ self.PrintOOBMHelp()
+
+ def PrintOOBMHelp(self):
+ """Prints out the OOBM help."""
+ # Print help syntax.
+ os.write(self.controller_pty, b"\r\n" + b"Known OOBM commands:\r\n")
+ os.write(
+ self.controller_pty,
+ b" interrogate <never | always | auto> " b"[enhanced]\r\n",
+ )
+ os.write(self.controller_pty, b" loglevel <int>\r\n")
+
+ def CheckBufferForEnhancedImage(self, data):
+ """Adds data to a look buffer and checks to see for enhanced EC image.
+
+ The EC's console task prints a string upon initialization which says that
+ "Console is enabled; type HELP for help.". The enhanced EC images print a
+ different string as a part of their init. This function searches through a
+ "look" buffer, scanning for the presence of either of those strings and
+ updating the enhanced_ec state accordingly.
+
+ Args:
+ data: A string containing the data sent from the interpreter.
+ """
+ self.look_buffer += data
+
+ # Search the buffer for any of the EC image strings.
+ enhanced_match = re.search(ENHANCED_IMAGE_RE, self.look_buffer)
+ non_enhanced_match = re.search(NON_ENHANCED_IMAGE_RE, self.look_buffer)
+
+ # Update the state if any matches were found.
+ if enhanced_match or non_enhanced_match:
+ if enhanced_match:
+ self.enhanced_ec = True
+ elif non_enhanced_match:
+ self.enhanced_ec = False
+
+ # Inform the interpreter of the result.
+ self.cmd_pipe.send(b"enhanced " + str(self.enhanced_ec).encode("ascii"))
+ self.logger.debug("Enhanced EC image? %r", self.enhanced_ec)
+
+ # Clear look buffer since a match was found.
+ self.look_buffer = b""
+
+ # Move the sliding window.
+ self.look_buffer = self.look_buffer[-LOOK_BUFFER_SIZE:]
- else:
- self.logger.error('Unexpected sequence. %c', byte)
- self.esc_state = 0
-
- def ProcessInput(self):
- """Captures the input determines what actions to take."""
- # There's nothing to do if the input buffer is empty.
- if len(self.input_buffer) == 0:
- return
-
- # Don't store 2 consecutive identical commands in the history.
- if (self.history and self.history[-1] != self.input_buffer
- or not self.history):
- self.history.append(self.input_buffer)
-
- # Split the command up by spaces.
- line = self.input_buffer.split(b' ')
- self.logger.debug('cmd: %s', self.input_buffer)
- cmd = line[0].lower()
-
- # The 'history' command is a special case that we handle locally.
- if cmd == 'history':
- self.PrintHistory()
- return
-
- # Send the command to the interpreter.
- self.logger.debug('Sending command to interpreter.')
- self.cmd_pipe.send(self.input_buffer)
- def CheckForEnhancedECImage(self):
- """Performs an interrogation of the EC image.
+def CanonicalizeTimeString(timestr):
+ """Canonicalize the timestamp string.
- Send a SYN and expect an ACK. If no ACK or the response is incorrect, then
- assume that the current EC image that we are talking to is not enhanced.
+ Args:
+ timestr: A timestamp string ending with a 6-digit msec field.
Returns:
- is_enhanced: A boolean indicating whether the EC responded to the
- interrogation correctly.
-
- Raises:
- EOFError: Allowed to propagate through from self.dbg_pipe.recv().
+ A string with 3 digits msec and an extra space.
"""
- # Send interrogation byte and wait for the response.
- self.logger.debug('Performing interrogation.')
- self.cmd_pipe.send(interpreter.EC_SYN)
-
- response = ''
- if self.dbg_pipe.poll(self.interrogation_timeout):
- response = self.dbg_pipe.recv()
- self.logger.debug('response: %r', binascii.hexlify(response))
- else:
- self.logger.debug('Timed out waiting for EC_ACK')
+ return timestr[:-3].encode("ascii") + b" "
- # Verify the acknowledgment.
- is_enhanced = response == interpreter.EC_ACK
-
- if is_enhanced:
- # Increase the interrogation timeout for stability purposes.
- self.interrogation_timeout = ENHANCED_EC_INTERROGATION_TIMEOUT
- self.logger.debug('Increasing interrogation timeout to %rs.',
- self.interrogation_timeout)
- else:
- # Reduce the timeout in order to reduce the perceivable delay.
- self.interrogation_timeout = NON_ENHANCED_EC_INTERROGATION_TIMEOUT
- self.logger.debug('Reducing interrogation timeout to %rs.',
- self.interrogation_timeout)
-
- return is_enhanced
-
- def HandleChar(self, byte):
- """HandleChar does a certain action when it receives a character.
-
- Args:
- byte: An integer representing the character received from the user.
- Raises:
- EOFError: Allowed to propagate through from self.CheckForEnhancedECImage()
- i.e. from self.dbg_pipe.recv().
- """
- fd = self.controller_pty
-
- # Enter the OOBM prompt mode if the user presses '%'.
- if byte == ord('%'):
- self.logger.debug('Begin OOBM command.')
- self.receiving_oobm_cmd = True
- # Print a "prompt".
- os.write(self.controller_pty, b'\r\n% ')
- return
-
- # Add chars to the pending OOBM command if we're currently receiving one.
- if self.receiving_oobm_cmd and byte != ControlKey.CARRIAGE_RETURN:
- tmp_bytes = six.int2byte(byte)
- self.pending_oobm_cmd += tmp_bytes
- self.logger.debug('%s', tmp_bytes)
- os.write(self.controller_pty, tmp_bytes)
- return
-
- if byte == ControlKey.CARRIAGE_RETURN:
- if self.receiving_oobm_cmd:
- # Terminate the command and place it in the OOBM queue.
- self.logger.debug('End OOBM command.')
- if self.pending_oobm_cmd:
- self.oobm_queue.put(self.pending_oobm_cmd)
- self.logger.debug('Placed %r into OOBM command queue.',
- self.pending_oobm_cmd)
-
- # Reset the state.
- os.write(self.controller_pty, b'\r\n' + self.prompt)
- self.input_buffer = b''
- self.input_buffer_pos = 0
- self.receiving_oobm_cmd = False
- self.pending_oobm_cmd = b''
- return
-
- if self.interrogation_mode == b'never':
- self.logger.debug('Skipping interrogation because interrogation mode'
- ' is set to never.')
- elif self.interrogation_mode == b'always':
- # Only interrogate the EC if the interrogation mode is set to 'always'.
- self.enhanced_ec = self.CheckForEnhancedECImage()
- self.logger.debug('Enhanced EC image? %r', self.enhanced_ec)
-
- if not self.enhanced_ec:
- # Send everything straight to the EC to handle.
- self.cmd_pipe.send(six.int2byte(byte))
- # Reset the input buffer.
- self.input_buffer = b''
- self.input_buffer_pos = 0
- self.logger.log(1, 'Reset input buffer.')
- return
-
- # Keep handling the ESC sequence if we're in the middle of it.
- if self.esc_state != 0:
- self.HandleEsc(byte)
- return
-
- # When we're at the end of the line, we should only allow going backwards,
- # backspace, carriage return, up, or down. The arrow keys are escape
- # sequences, so we let the escape...escape.
- if (self.input_buffer_pos >= self.line_limit and
- byte not in [ControlKey.CTRL_B, ControlKey.ESC, ControlKey.BACKSPACE,
- ControlKey.CTRL_A, ControlKey.CARRIAGE_RETURN,
- ControlKey.CTRL_P, ControlKey.CTRL_N]):
- return
-
- # If the input buffer is full we can't accept new chars.
- buffer_full = len(self.input_buffer) >= self.line_limit
-
-
- # Carriage_Return/Enter
- if byte == ControlKey.CARRIAGE_RETURN:
- self.logger.debug('Enter key pressed.')
- # Put a carriage return/newline and the print the prompt.
- os.write(fd, b'\r\n')
-
- # TODO(aaboagye): When we control the printing of all output, print the
- # prompt AFTER printing all the output. We can't do it yet because we
- # don't know how much is coming from the EC.
-
- # Print the prompt.
- os.write(fd, self.prompt)
- # Process the input.
- self.ProcessInput()
- # Now, clear the buffer.
- self.input_buffer = b''
- self.input_buffer_pos = 0
- # Reset history buffer pos.
- self.history_pos = len(self.history)
- # Clear partial command.
- self.partial_cmd = b''
-
- # Backspace
- elif byte == ControlKey.BACKSPACE:
- self.logger.debug('Backspace pressed.')
- if self.input_buffer_pos > 0:
- # Move left 1 column.
- self.MoveCursor('left', 1)
- # Remove the character at the input_buffer_pos by slicing it out.
- self.SliceOutChar()
-
- self.logger.debug('input_buffer_pos: %d', self.input_buffer_pos)
-
- # Ctrl+A. Move cursor to beginning of the line
- elif byte == ControlKey.CTRL_A:
- self.logger.debug('Control+A pressed.')
- self.MoveCursor('left', self.input_buffer_pos)
-
- # Ctrl+B. Move cursor left 1 column.
- elif byte == ControlKey.CTRL_B:
- self.logger.debug('Control+B pressed.')
- self.MoveCursor('left', 1)
-
- # Ctrl+D. Delete a character.
- elif byte == ControlKey.CTRL_D:
- self.logger.debug('Control+D pressed.')
- if self.input_buffer_pos != len(self.input_buffer):
- # Remove the character by slicing it out.
- self.SliceOutChar()
-
- # Ctrl+E. Move cursor to end of the line.
- elif byte == ControlKey.CTRL_E:
- self.logger.debug('Control+E pressed.')
- self.MoveCursor('right',
- len(self.input_buffer) - self.input_buffer_pos)
-
- # Ctrl+F. Move cursor right 1 column.
- elif byte == ControlKey.CTRL_F:
- self.logger.debug('Control+F pressed.')
- self.MoveCursor('right', 1)
-
- # Ctrl+K. Kill line.
- elif byte == ControlKey.CTRL_K:
- self.logger.debug('Control+K pressed.')
- self.KillLine()
-
- # Ctrl+N. Next line.
- elif byte == ControlKey.CTRL_N:
- self.logger.debug('Control+N pressed.')
- self.ShowNextCommand()
-
- # Ctrl+P. Previous line.
- elif byte == ControlKey.CTRL_P:
- self.logger.debug('Control+P pressed.')
- self.ShowPreviousCommand()
-
- # ESC sequence
- elif byte == ControlKey.ESC:
- # Starting an ESC sequence
- self.esc_state = EscState.ESC_START
-
- # Only print printable chars.
- elif IsPrintable(byte):
- # Drop the character if we're full.
- if buffer_full:
- self.logger.debug('Dropped char: %c(%d)', byte, byte)
- return
- # Print the character.
- os.write(fd, six.int2byte(byte))
- # Print the rest of the line (if any).
- extra_bytes_written = os.write(fd,
- self.input_buffer[self.input_buffer_pos:])
-
- # Recreate the input buffer.
- self.input_buffer = (self.input_buffer[0:self.input_buffer_pos] +
- six.int2byte(byte) +
- self.input_buffer[self.input_buffer_pos:])
- # Update the input buffer position.
- self.input_buffer_pos += 1 + extra_bytes_written
-
- # Reset the cursor if we wrote any extra bytes.
- if extra_bytes_written:
- self.MoveCursor('left', extra_bytes_written)
-
- self.logger.debug('input_buffer_pos: %d', self.input_buffer_pos)
-
- def MoveCursor(self, direction, count):
- """MoveCursor moves the cursor left or right by count columns.
+def IsPrintable(byte):
+ """Determines if a byte is printable.
Args:
- direction: A string that should be either 'left' or 'right' representing
- the direction to move the cursor on the console.
- count: An integer representing how many columns the cursor should be
- moved.
+ byte: An integer potentially representing a printable character.
- Raises:
- AssertionError: If the direction is not equal to 'left' or 'right'.
+ Returns:
+ A boolean indicating whether the byte is a printable character.
"""
- # If there's nothing to move, we're done.
- if not count:
- return
- fd = self.controller_pty
- seq = b'\033[' + str(count).encode('ascii')
- if direction == 'left':
- # Bind the movement.
- if count > self.input_buffer_pos:
- count = self.input_buffer_pos
- seq += b'D'
- self.logger.debug('move cursor left %d', count)
- self.input_buffer_pos -= count
-
- elif direction == 'right':
- # Bind the movement.
- if (count + self.input_buffer_pos) > len(self.input_buffer):
- count = 0
- seq += b'C'
- self.logger.debug('move cursor right %d', count)
- self.input_buffer_pos += count
-
- else:
- raise AssertionError(('The only valid directions are \'left\' and '
- '\'right\''))
-
- self.logger.debug('input_buffer_pos: %d', self.input_buffer_pos)
- # Move the cursor.
- if count != 0:
- os.write(fd, seq)
-
- def KillLine(self):
- """Kill the rest of the line based on the input buffer position."""
- # Killing the line is killing all the text to the right.
- diff = len(self.input_buffer) - self.input_buffer_pos
- self.logger.debug('diff: %d', diff)
- # Diff shouldn't be negative, but if it is for some reason, let's try to
- # correct the cursor.
- if diff < 0:
- self.logger.warning('Resetting input buffer position to %d...',
- len(self.input_buffer))
- self.MoveCursor('left', -diff)
- return
- if diff:
- self.MoveCursor('right', diff)
- for _ in range(diff):
- self.SendBackspace()
- self.input_buffer_pos -= diff
- self.input_buffer = self.input_buffer[0:self.input_buffer_pos]
-
- def SendBackspace(self):
- """Backspace a character on the console."""
- os.write(self.controller_pty, b'\033[1D \033[1D')
-
- def ProcessOOBMQueue(self):
- """Retrieve an item from the OOBM queue and process it."""
- item = self.oobm_queue.get()
- self.logger.debug('OOBM cmd: %r', item)
- cmd = item.split(b' ')
-
- if cmd[0] == b'loglevel':
- # An integer is required in order to set the log level.
- if len(cmd) < 2:
- self.logger.debug('Insufficient args')
- self.PrintOOBMHelp()
- return
- try:
- self.logger.debug('Log level change request.')
- new_log_level = int(cmd[1])
- self.logger.logger.setLevel(new_log_level)
- self.logger.info('Log level changed to %d.', new_log_level)
-
- # Forward the request to the interpreter as well.
- self.cmd_pipe.send(item)
- except ValueError:
- # Ignoring the request if an integer was not provided.
- self.PrintOOBMHelp()
-
- elif cmd[0] == b'timestamp':
- mode = cmd[1].lower()
- self.timestamp_enabled = (mode == b'on')
- self.logger.info('%sabling uart timestamps.',
- 'En' if self.timestamp_enabled else 'Dis')
-
- elif cmd[0] == b'rawdebug':
- mode = cmd[1].lower()
- self.raw_debug = (mode == b'on')
- self.logger.info('%sabling per interrupt debug logs.',
- 'En' if self.raw_debug else 'Dis')
-
- elif cmd[0] == b'interrogate' and len(cmd) >= 2:
- enhanced = False
- mode = cmd[1]
- if len(cmd) >= 3 and cmd[2] == b'enhanced':
- enhanced = True
-
- # Set the mode if correct.
- if mode in INTERROGATION_MODES:
- self.interrogation_mode = mode
- self.logger.debug('Updated interrogation mode to %s.', mode)
-
- # Update the assumptions of the EC image.
- self.enhanced_ec = enhanced
- self.logger.debug('Enhanced EC image is now %r', self.enhanced_ec)
-
- # Send command to interpreter as well.
- self.cmd_pipe.send(b'enhanced ' + str(self.enhanced_ec).encode('ascii'))
- else:
- self.PrintOOBMHelp()
-
- else:
- self.PrintOOBMHelp()
+ return byte >= ord(" ") and byte <= ord("~")
- def PrintOOBMHelp(self):
- """Prints out the OOBM help."""
- # Print help syntax.
- os.write(self.controller_pty, b'\r\n' + b'Known OOBM commands:\r\n')
- os.write(self.controller_pty, b' interrogate <never | always | auto> '
- b'[enhanced]\r\n')
- os.write(self.controller_pty, b' loglevel <int>\r\n')
- def CheckBufferForEnhancedImage(self, data):
- """Adds data to a look buffer and checks to see for enhanced EC image.
-
- The EC's console task prints a string upon initialization which says that
- "Console is enabled; type HELP for help.". The enhanced EC images print a
- different string as a part of their init. This function searches through a
- "look" buffer, scanning for the presence of either of those strings and
- updating the enhanced_ec state accordingly.
+def StartLoop(console, command_active, shutdown_pipe=None):
+ """Starts the infinite loop of console processing.
Args:
- data: A string containing the data sent from the interpreter.
+ console: A Console object that has been properly initialized.
+ command_active: ctypes data object or multiprocessing.Value indicating if
+ servod owns the console, or user owns the console. This prevents input
+ collisions.
+ shutdown_pipe: A file object for a pipe or equivalent that becomes readable
+ (not blocked) to indicate that the loop should exit. Can be None to never
+ exit the loop.
"""
- self.look_buffer += data
-
- # Search the buffer for any of the EC image strings.
- enhanced_match = re.search(ENHANCED_IMAGE_RE, self.look_buffer)
- non_enhanced_match = re.search(NON_ENHANCED_IMAGE_RE, self.look_buffer)
-
- # Update the state if any matches were found.
- if enhanced_match or non_enhanced_match:
- if enhanced_match:
- self.enhanced_ec = True
- elif non_enhanced_match:
- self.enhanced_ec = False
-
- # Inform the interpreter of the result.
- self.cmd_pipe.send(b'enhanced ' + str(self.enhanced_ec).encode('ascii'))
- self.logger.debug('Enhanced EC image? %r', self.enhanced_ec)
-
- # Clear look buffer since a match was found.
- self.look_buffer = b''
-
- # Move the sliding window.
- self.look_buffer = self.look_buffer[-LOOK_BUFFER_SIZE:]
-
-
-def CanonicalizeTimeString(timestr):
- """Canonicalize the timestamp string.
-
- Args:
- timestr: A timestamp string ended with 6 digits msec.
-
- Returns:
- A string with 3 digits msec and an extra space.
- """
- return timestr[:-3].encode('ascii') + b' '
-
-
-def IsPrintable(byte):
- """Determines if a byte is printable.
-
- Args:
- byte: An integer potentially representing a printable character.
-
- Returns:
- A boolean indicating whether the byte is a printable character.
- """
- return byte >= ord(' ') and byte <= ord('~')
-
-
-def StartLoop(console, command_active, shutdown_pipe=None):
- """Starts the infinite loop of console processing.
-
- Args:
- console: A Console object that has been properly initialzed.
- command_active: ctypes data object or multiprocessing.Value indicating if
- servod owns the console, or user owns the console. This prevents input
- collisions.
- shutdown_pipe: A file object for a pipe or equivalent that becomes readable
- (not blocked) to indicate that the loop should exit. Can be None to never
- exit the loop.
- """
- try:
- console.logger.debug('Console is being served on %s.', console.user_pty)
- console.logger.debug('Console controller is on %s.', console.controller_pty)
- console.logger.debug('Command interface is being served on %s.',
- console.interface_pty)
- console.logger.debug(console)
-
- # This checks for HUP to indicate if the user has connected to the pty.
- ep = select.epoll()
- ep.register(console.controller_pty, select.EPOLLHUP)
-
- # This is used instead of "break" to avoid exiting the loop in the middle of
- # an iteration.
- continue_looping = True
-
- # Used for determining when to print host timestamps
- tm_req = True
-
- while continue_looping:
- # Check to see if pts is connected to anything
- events = ep.poll(0)
- controller_connected = not events
-
- # Check to see if pipes or the console are ready for reading.
- read_list = [console.interface_pty,
- console.cmd_pipe, console.dbg_pipe]
- if controller_connected:
- read_list.append(console.controller_pty)
- if shutdown_pipe is not None:
- read_list.append(shutdown_pipe)
-
- # Check if any input is ready, or wait for .1 sec and re-poll if
- # a user has connected to the pts.
- select_output = select.select(read_list, [], [], .1)
- if not select_output:
- continue
- ready_for_reading = select_output[0]
-
- for obj in ready_for_reading:
- if obj is console.controller_pty:
- if not command_active.value:
- # Convert to bytes so we can look for non-printable chars such as
- # Ctrl+A, Ctrl+E, etc.
- try:
- line = bytearray(os.read(console.controller_pty, CONSOLE_MAX_READ))
- console.logger.debug('Input from user: %s, locked:%s',
- str(line).strip(), command_active.value)
- for i in line:
- try:
- # Handle each character as it arrives.
- console.HandleChar(i)
- except EOFError:
- console.logger.debug(
- 'ec3po console received EOF from dbg_pipe in HandleChar()'
- ' while reading console.controller_pty')
- continue_looping = False
- break
- except OSError:
- console.logger.debug('Ptm read failed, probably user disconnect.')
-
- elif obj is console.interface_pty:
- if command_active.value:
- # Convert to bytes so we can look for non-printable chars such as
- # Ctrl+A, Ctrl+E, etc.
- line = bytearray(os.read(console.interface_pty, CONSOLE_MAX_READ))
- console.logger.debug('Input from interface: %s, locked:%s',
- str(line).strip(), command_active.value)
- for i in line:
- try:
- # Handle each character as it arrives.
- console.HandleChar(i)
- except EOFError:
- console.logger.debug(
- 'ec3po console received EOF from dbg_pipe in HandleChar()'
- ' while reading console.interface_pty')
- continue_looping = False
- break
-
- elif obj is console.cmd_pipe:
- try:
- data = console.cmd_pipe.recv()
- except EOFError:
- console.logger.debug('ec3po console received EOF from cmd_pipe')
- continue_looping = False
- else:
- # Write it to the user console.
- if console.raw_debug:
- console.logger.debug('|CMD|-%s->%r',
- ('u' if controller_connected else '') +
- ('i' if command_active.value else ''),
- data.strip())
+ try:
+ console.logger.debug("Console is being served on %s.", console.user_pty)
+ console.logger.debug("Console controller is on %s.", console.controller_pty)
+ console.logger.debug(
+ "Command interface is being served on %s.", console.interface_pty
+ )
+ console.logger.debug(console)
+
+ # This checks for HUP to indicate if the user has connected to the pty.
+ ep = select.epoll()
+ ep.register(console.controller_pty, select.EPOLLHUP)
+
+ # This is used instead of "break" to avoid exiting the loop in the middle of
+ # an iteration.
+ continue_looping = True
+
+ # Used for determining when to print host timestamps
+ tm_req = True
+
+ while continue_looping:
+ # Check to see if pts is connected to anything
+ events = ep.poll(0)
+ controller_connected = not events
+
+ # Check to see if pipes or the console are ready for reading.
+ read_list = [console.interface_pty, console.cmd_pipe, console.dbg_pipe]
if controller_connected:
- os.write(console.controller_pty, data)
- if command_active.value:
- os.write(console.interface_pty, data)
-
- elif obj is console.dbg_pipe:
- try:
- data = console.dbg_pipe.recv()
- except EOFError:
- console.logger.debug('ec3po console received EOF from dbg_pipe')
- continue_looping = False
- else:
- if console.interrogation_mode == b'auto':
- # Search look buffer for enhanced EC image string.
- console.CheckBufferForEnhancedImage(data)
- # Write it to the user console.
- if len(data) > 1 and console.raw_debug:
- console.logger.debug('|DBG|-%s->%r',
- ('u' if controller_connected else '') +
- ('i' if command_active.value else ''),
- data.strip())
- console.LogConsoleOutput(data)
- if controller_connected:
- end = len(data) - 1
- if console.timestamp_enabled:
- # A timestamp is required at the beginning of this line
- if tm_req is True:
- now = datetime.now()
- tm = CanonicalizeTimeString(now.strftime(HOST_STRFTIME))
- os.write(console.controller_pty, tm)
- tm_req = False
-
- # Insert timestamps into the middle where appropriate
- # except if the last character is a newline
- nls_found = data.count(b'\n', 0, end)
- now = datetime.now()
- tm = CanonicalizeTimeString(now.strftime('\n' + HOST_STRFTIME))
- data_tm = data.replace(b'\n', tm, nls_found)
- else:
- data_tm = data
-
- # timestamp required on next input
- if data[end] == b'\n'[0]:
- tm_req = True
- os.write(console.controller_pty, data_tm)
- if command_active.value:
- os.write(console.interface_pty, data)
-
- elif obj is shutdown_pipe:
- console.logger.debug(
- 'ec3po console received shutdown pipe unblocked notification')
- continue_looping = False
-
- while not console.oobm_queue.empty():
- console.logger.debug('OOBM queue ready for reading.')
- console.ProcessOOBMQueue()
-
- except KeyboardInterrupt:
- pass
-
- finally:
- ep.unregister(console.controller_pty)
- console.dbg_pipe.close()
- console.cmd_pipe.close()
- os.close(console.controller_pty)
- os.close(console.interface_pty)
- if shutdown_pipe is not None:
- shutdown_pipe.close()
- console.logger.debug('Exit ec3po console loop for %s', console.user_pty)
+ read_list.append(console.controller_pty)
+ if shutdown_pipe is not None:
+ read_list.append(shutdown_pipe)
+
+ # Check if any input is ready, or wait for .1 sec and re-poll if
+ # a user has connected to the pts.
+ select_output = select.select(read_list, [], [], 0.1)
+ if not select_output:
+ continue
+ ready_for_reading = select_output[0]
+
+ for obj in ready_for_reading:
+ if obj is console.controller_pty:
+ if not command_active.value:
+ # Convert to bytes so we can look for non-printable chars such as
+ # Ctrl+A, Ctrl+E, etc.
+ try:
+ line = bytearray(
+ os.read(console.controller_pty, CONSOLE_MAX_READ)
+ )
+ console.logger.debug(
+ "Input from user: %s, locked:%s",
+ str(line).strip(),
+ command_active.value,
+ )
+ for i in line:
+ try:
+ # Handle each character as it arrives.
+ console.HandleChar(i)
+ except EOFError:
+ console.logger.debug(
+ "ec3po console received EOF from dbg_pipe in HandleChar()"
+ " while reading console.controller_pty"
+ )
+ continue_looping = False
+ break
+ except OSError:
+ console.logger.debug(
+ "Ptm read failed, probably user disconnect."
+ )
+
+ elif obj is console.interface_pty:
+ if command_active.value:
+ # Convert to bytes so we can look for non-printable chars such as
+ # Ctrl+A, Ctrl+E, etc.
+ line = bytearray(
+ os.read(console.interface_pty, CONSOLE_MAX_READ)
+ )
+ console.logger.debug(
+ "Input from interface: %s, locked:%s",
+ str(line).strip(),
+ command_active.value,
+ )
+ for i in line:
+ try:
+ # Handle each character as it arrives.
+ console.HandleChar(i)
+ except EOFError:
+ console.logger.debug(
+ "ec3po console received EOF from dbg_pipe in HandleChar()"
+ " while reading console.interface_pty"
+ )
+ continue_looping = False
+ break
+
+ elif obj is console.cmd_pipe:
+ try:
+ data = console.cmd_pipe.recv()
+ except EOFError:
+ console.logger.debug("ec3po console received EOF from cmd_pipe")
+ continue_looping = False
+ else:
+ # Write it to the user console.
+ if console.raw_debug:
+ console.logger.debug(
+ "|CMD|-%s->%r",
+ ("u" if controller_connected else "")
+ + ("i" if command_active.value else ""),
+ data.strip(),
+ )
+ if controller_connected:
+ os.write(console.controller_pty, data)
+ if command_active.value:
+ os.write(console.interface_pty, data)
+
+ elif obj is console.dbg_pipe:
+ try:
+ data = console.dbg_pipe.recv()
+ except EOFError:
+ console.logger.debug("ec3po console received EOF from dbg_pipe")
+ continue_looping = False
+ else:
+ if console.interrogation_mode == b"auto":
+ # Search look buffer for enhanced EC image string.
+ console.CheckBufferForEnhancedImage(data)
+ # Write it to the user console.
+ if len(data) > 1 and console.raw_debug:
+ console.logger.debug(
+ "|DBG|-%s->%r",
+ ("u" if controller_connected else "")
+ + ("i" if command_active.value else ""),
+ data.strip(),
+ )
+ console.LogConsoleOutput(data)
+ if controller_connected:
+ end = len(data) - 1
+ if console.timestamp_enabled:
+ # A timestamp is required at the beginning of this line
+ if tm_req is True:
+ now = datetime.now()
+ tm = CanonicalizeTimeString(
+ now.strftime(HOST_STRFTIME)
+ )
+ os.write(console.controller_pty, tm)
+ tm_req = False
+
+ # Insert timestamps into the middle where appropriate
+ # except if the last character is a newline
+ nls_found = data.count(b"\n", 0, end)
+ now = datetime.now()
+ tm = CanonicalizeTimeString(
+ now.strftime("\n" + HOST_STRFTIME)
+ )
+ data_tm = data.replace(b"\n", tm, nls_found)
+ else:
+ data_tm = data
+
+ # timestamp required on next input
+ if data[end] == b"\n"[0]:
+ tm_req = True
+ os.write(console.controller_pty, data_tm)
+ if command_active.value:
+ os.write(console.interface_pty, data)
+
+ elif obj is shutdown_pipe:
+ console.logger.debug(
+ "ec3po console received shutdown pipe unblocked notification"
+ )
+ continue_looping = False
+
+ while not console.oobm_queue.empty():
+ console.logger.debug("OOBM queue ready for reading.")
+ console.ProcessOOBMQueue()
+
+ except KeyboardInterrupt:
+ pass
+
+ finally:
+ ep.unregister(console.controller_pty)
+ console.dbg_pipe.close()
+ console.cmd_pipe.close()
+ os.close(console.controller_pty)
+ os.close(console.interface_pty)
+ if shutdown_pipe is not None:
+ shutdown_pipe.close()
+ console.logger.debug("Exit ec3po console loop for %s", console.user_pty)
def main(argv):
- """Kicks off the EC-3PO interactive console interface and interpreter.
-
- We create some pipes to communicate with an interpreter, instantiate an
- interpreter, create a PTY pair, and begin serving the console interface.
-
- Args:
- argv: A list of strings containing the arguments this module was called
- with.
- """
- # Set up argument parser.
- parser = argparse.ArgumentParser(description=('Start interactive EC console '
- 'and interpreter.'))
- parser.add_argument('ec_uart_pty',
- help=('The full PTY name that the EC UART'
- ' is present on. eg: /dev/pts/12'))
- parser.add_argument('--log-level',
- default='info',
- help='info, debug, warning, error, or critical')
-
- # Parse arguments.
- opts = parser.parse_args(argv)
-
- # Set logging level.
- opts.log_level = opts.log_level.lower()
- if opts.log_level == 'info':
- log_level = logging.INFO
- elif opts.log_level == 'debug':
- log_level = logging.DEBUG
- elif opts.log_level == 'warning':
- log_level = logging.WARNING
- elif opts.log_level == 'error':
- log_level = logging.ERROR
- elif opts.log_level == 'critical':
- log_level = logging.CRITICAL
- else:
- parser.error('Invalid log level. (info, debug, warning, error, critical)')
-
- # Start logging with a timestamp, module, and log level shown in each log
- # entry.
- logging.basicConfig(level=log_level, format=('%(asctime)s - %(module)s -'
- ' %(levelname)s - %(message)s'))
-
- # Create some pipes to communicate between the interpreter and the console.
- # The command pipe is bidirectional.
- cmd_pipe_interactive, cmd_pipe_interp = threadproc_shim.Pipe()
- # The debug pipe is unidirectional from interpreter to console only.
- dbg_pipe_interactive, dbg_pipe_interp = threadproc_shim.Pipe(duplex=False)
-
- # Create an interpreter instance.
- itpr = interpreter.Interpreter(opts.ec_uart_pty, cmd_pipe_interp,
- dbg_pipe_interp, log_level)
-
- # Spawn an interpreter process.
- itpr_process = threadproc_shim.ThreadOrProcess(
- target=interpreter.StartLoop, args=(itpr,))
- # Make sure to kill the interpreter when we terminate.
- itpr_process.daemon = True
- # Start the interpreter.
- itpr_process.start()
-
- # Open a new pseudo-terminal pair
- (controller_pty, user_pty) = pty.openpty()
- # Set the permissions to 660.
- os.chmod(os.ttyname(user_pty), (stat.S_IRGRP | stat.S_IWGRP |
- stat.S_IRUSR | stat.S_IWUSR))
- # Create a console.
- console = Console(controller_pty, os.ttyname(user_pty), cmd_pipe_interactive,
- dbg_pipe_interactive)
- # Start serving the console.
- v = threadproc_shim.Value(ctypes.c_bool, False)
- StartLoop(console, v)
-
-
-if __name__ == '__main__':
- main(sys.argv[1:])
+ """Kicks off the EC-3PO interactive console interface and interpreter.
+
+ We create some pipes to communicate with an interpreter, instantiate an
+ interpreter, create a PTY pair, and begin serving the console interface.
+
+ Args:
+ argv: A list of strings containing the arguments this module was called
+ with.
+ """
+ # Set up argument parser.
+ parser = argparse.ArgumentParser(
+ description=("Start interactive EC console " "and interpreter.")
+ )
+ parser.add_argument(
+ "ec_uart_pty",
+ help=("The full PTY name that the EC UART" " is present on. eg: /dev/pts/12"),
+ )
+ parser.add_argument(
+ "--log-level", default="info", help="info, debug, warning, error, or critical"
+ )
+
+ # Parse arguments.
+ opts = parser.parse_args(argv)
+
+ # Set logging level.
+ opts.log_level = opts.log_level.lower()
+ if opts.log_level == "info":
+ log_level = logging.INFO
+ elif opts.log_level == "debug":
+ log_level = logging.DEBUG
+ elif opts.log_level == "warning":
+ log_level = logging.WARNING
+ elif opts.log_level == "error":
+ log_level = logging.ERROR
+ elif opts.log_level == "critical":
+ log_level = logging.CRITICAL
+ else:
+ parser.error("Invalid log level. (info, debug, warning, error, critical)")
+
+ # Start logging with a timestamp, module, and log level shown in each log
+ # entry.
+ logging.basicConfig(
+ level=log_level,
+ format=("%(asctime)s - %(module)s -" " %(levelname)s - %(message)s"),
+ )
+
+ # Create some pipes to communicate between the interpreter and the console.
+ # The command pipe is bidirectional.
+ cmd_pipe_interactive, cmd_pipe_interp = threadproc_shim.Pipe()
+ # The debug pipe is unidirectional from interpreter to console only.
+ dbg_pipe_interactive, dbg_pipe_interp = threadproc_shim.Pipe(duplex=False)
+
+ # Create an interpreter instance.
+ itpr = interpreter.Interpreter(
+ opts.ec_uart_pty, cmd_pipe_interp, dbg_pipe_interp, log_level
+ )
+
+ # Spawn an interpreter process.
+ itpr_process = threadproc_shim.ThreadOrProcess(
+ target=interpreter.StartLoop, args=(itpr,)
+ )
+ # Make sure to kill the interpreter when we terminate.
+ itpr_process.daemon = True
+ # Start the interpreter.
+ itpr_process.start()
+
+ # Open a new pseudo-terminal pair
+ (controller_pty, user_pty) = pty.openpty()
+ # Set the permissions to 660.
+ os.chmod(
+ os.ttyname(user_pty),
+ (stat.S_IRGRP | stat.S_IWGRP | stat.S_IRUSR | stat.S_IWUSR),
+ )
+ # Create a console.
+ console = Console(
+ controller_pty, os.ttyname(user_pty), cmd_pipe_interactive, dbg_pipe_interactive
+ )
+ # Start serving the console.
+ v = threadproc_shim.Value(ctypes.c_bool, False)
+ StartLoop(console, v)
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
diff --git a/util/ec3po/console_unittest.py b/util/ec3po/console_unittest.py
index 7e341e7e8d..41ae324ef4 100755
--- a/util/ec3po/console_unittest.py
+++ b/util/ec3po/console_unittest.py
@@ -11,1262 +11,1310 @@ from __future__ import print_function
import binascii
import logging
-import mock
import tempfile
import unittest
+import mock
import six
-
-from ec3po import console
-from ec3po import interpreter
-from ec3po import threadproc_shim
+from ec3po import console, interpreter, threadproc_shim
ESC_STRING = six.int2byte(console.ControlKey.ESC)
+
class Keys(object):
- """A class that contains the escape sequences for special keys."""
- LEFT_ARROW = [console.ControlKey.ESC, ord('['), ord('D')]
- RIGHT_ARROW = [console.ControlKey.ESC, ord('['), ord('C')]
- UP_ARROW = [console.ControlKey.ESC, ord('['), ord('A')]
- DOWN_ARROW = [console.ControlKey.ESC, ord('['), ord('B')]
- HOME = [console.ControlKey.ESC, ord('['), ord('1'), ord('~')]
- END = [console.ControlKey.ESC, ord('['), ord('8'), ord('~')]
- DEL = [console.ControlKey.ESC, ord('['), ord('3'), ord('~')]
+ """A class that contains the escape sequences for special keys."""
+
+ LEFT_ARROW = [console.ControlKey.ESC, ord("["), ord("D")]
+ RIGHT_ARROW = [console.ControlKey.ESC, ord("["), ord("C")]
+ UP_ARROW = [console.ControlKey.ESC, ord("["), ord("A")]
+ DOWN_ARROW = [console.ControlKey.ESC, ord("["), ord("B")]
+ HOME = [console.ControlKey.ESC, ord("["), ord("1"), ord("~")]
+ END = [console.ControlKey.ESC, ord("["), ord("8"), ord("~")]
+ DEL = [console.ControlKey.ESC, ord("["), ord("3"), ord("~")]
+
class OutputStream(object):
- """A class that has methods which return common console output."""
+ """A class that has methods which return common console output."""
- @staticmethod
- def MoveCursorLeft(count):
- """Produces what would be printed to the console if the cursor moved left.
+ @staticmethod
+ def MoveCursorLeft(count):
+ """Produces what would be printed to the console if the cursor moved left.
- Args:
- count: An integer representing how many columns to move left.
+ Args:
+ count: An integer representing how many columns to move left.
- Returns:
- string: A string which contains what would be printed to the console if
- the cursor moved left.
- """
- string = ESC_STRING
- string += b'[' + str(count).encode('ascii') + b'D'
- return string
+ Returns:
+ string: A string which contains what would be printed to the console if
+ the cursor moved left.
+ """
+ string = ESC_STRING
+ string += b"[" + str(count).encode("ascii") + b"D"
+ return string
- @staticmethod
- def MoveCursorRight(count):
- """Produces what would be printed to the console if the cursor moved right.
+ @staticmethod
+ def MoveCursorRight(count):
+ """Produces what would be printed to the console if the cursor moved right.
- Args:
- count: An integer representing how many columns to move right.
+ Args:
+ count: An integer representing how many columns to move right.
- Returns:
- string: A string which contains what would be printed to the console if
- the cursor moved right.
- """
- string = ESC_STRING
- string += b'[' + str(count).encode('ascii') + b'C'
- return string
+ Returns:
+ string: A string which contains what would be printed to the console if
+ the cursor moved right.
+ """
+ string = ESC_STRING
+ string += b"[" + str(count).encode("ascii") + b"C"
+ return string
-BACKSPACE_STRING = b''
+
+BACKSPACE_STRING = b""
# Move cursor left 1 column.
BACKSPACE_STRING += OutputStream.MoveCursorLeft(1)
# Write a space.
-BACKSPACE_STRING += b' '
+BACKSPACE_STRING += b" "
# Move cursor left 1 column.
BACKSPACE_STRING += OutputStream.MoveCursorLeft(1)
+
def BytesToByteList(string):
- """Converts a bytes string to list of bytes.
+ """Converts a bytes string to list of bytes.
+
+ Args:
+ string: A literal bytes to turn into a list of bytes.
- Args:
- string: A literal bytes to turn into a list of bytes.
+ Returns:
+ A list of integers representing the byte value of each character in the
+ string.
+ """
+ if six.PY3:
+ return [c for c in string]
+ return [ord(c) for c in string]
- Returns:
- A list of integers representing the byte value of each character in the
- string.
- """
- if six.PY3:
- return [c for c in string]
- return [ord(c) for c in string]
def CheckConsoleOutput(test_case, exp_console_out):
- """Verify what was sent out the console matches what we expect.
+ """Verify what was sent out the console matches what we expect.
- Args:
- test_case: A unittest.TestCase object representing the current unit test.
- exp_console_out: A string representing the console output stream.
- """
- # Read what was sent out the console.
- test_case.tempfile.seek(0)
- console_out = test_case.tempfile.read()
+ Args:
+ test_case: A unittest.TestCase object representing the current unit test.
+ exp_console_out: A string representing the console output stream.
+ """
+ # Read what was sent out the console.
+ test_case.tempfile.seek(0)
+ console_out = test_case.tempfile.read()
- test_case.assertEqual(exp_console_out, console_out)
+ test_case.assertEqual(exp_console_out, console_out)
-def CheckInputBuffer(test_case, exp_input_buffer):
- """Verify that the input buffer contains what we expect.
-
- Args:
- test_case: A unittest.TestCase object representing the current unit test.
- exp_input_buffer: A string containing the contents of the current input
- buffer.
- """
- test_case.assertEqual(exp_input_buffer, test_case.console.input_buffer,
- (b'input buffer does not match expected.\n'
- b'expected: |' + exp_input_buffer + b'|\n'
- b'got: |' + test_case.console.input_buffer +
- b'|\n' + str(test_case.console).encode('ascii')))
-def CheckInputBufferPosition(test_case, exp_pos):
- """Verify the input buffer position.
+def CheckInputBuffer(test_case, exp_input_buffer):
+ """Verify that the input buffer contains what we expect.
- Args:
- test_case: A unittest.TestCase object representing the current unit test.
- exp_pos: An integer representing the expected input buffer position.
- """
- test_case.assertEqual(exp_pos, test_case.console.input_buffer_pos,
- 'input buffer position is incorrect.\ngot: ' +
- str(test_case.console.input_buffer_pos) + '\nexp: ' +
- str(exp_pos) + '\n' + str(test_case.console))
+ Args:
+ test_case: A unittest.TestCase object representing the current unit test.
+ exp_input_buffer: A string containing the contents of the current input
+ buffer.
+ """
+ test_case.assertEqual(
+ exp_input_buffer,
+ test_case.console.input_buffer,
+ (
+ b"input buffer does not match expected.\n"
+ b"expected: |" + exp_input_buffer + b"|\n"
+ b"got: |"
+ + test_case.console.input_buffer
+ + b"|\n"
+ + str(test_case.console).encode("ascii")
+ ),
+ )
-def CheckHistoryBuffer(test_case, exp_history):
- """Verify that the items in the history buffer are what we expect.
-
- Args:
- test_case: A unittest.TestCase object representing the current unit test.
- exp_history: A list of strings representing the expected contents of the
- history buffer.
- """
- # First, check to see if the length is what we expect.
- test_case.assertEqual(len(exp_history), len(test_case.console.history),
- ('The number of items in the history is unexpected.\n'
- 'exp: ' + str(len(exp_history)) + '\n'
- 'got: ' + str(len(test_case.console.history)) + '\n'
- 'internal state:\n' + str(test_case.console)))
-
- # Next, check the actual contents of the history buffer.
- for i in range(len(exp_history)):
- test_case.assertEqual(exp_history[i], test_case.console.history[i],
- (b'history buffer contents are incorrect.\n'
- b'exp: ' + exp_history[i] + b'\n'
- b'got: ' + test_case.console.history[i] + b'\n'
- b'internal state:\n' +
- str(test_case.console).encode('ascii')))
+def CheckInputBufferPosition(test_case, exp_pos):
+ """Verify the input buffer position.
-class TestConsoleEditingMethods(unittest.TestCase):
- """Test case to verify all console editing methods."""
-
- def setUp(self):
- """Setup the test harness."""
- # Setup logging with a timestamp, the module, and the log level.
- logging.basicConfig(level=logging.DEBUG,
- format=('%(asctime)s - %(module)s -'
- ' %(levelname)s - %(message)s'))
-
- # Create a temp file and set both the controller and peripheral PTYs to the
- # file to create a loopback.
- self.tempfile = tempfile.TemporaryFile()
-
- # Create some mock pipes. These won't be used since we'll mock out sends
- # to the interpreter.
- mock_pipe_end_0, mock_pipe_end_1 = threadproc_shim.Pipe()
- self.console = console.Console(self.tempfile.fileno(), self.tempfile,
- tempfile.TemporaryFile(),
- mock_pipe_end_0, mock_pipe_end_1, "EC")
-
- # Console editing methods are only valid for enhanced EC images, therefore
- # we have to assume that the "EC" we're talking to is enhanced. By default,
- # the console believes that the EC it's communicating with is NOT enhanced
- # which is why we have to override it here.
- self.console.enhanced_ec = True
- self.console.CheckForEnhancedECImage = mock.MagicMock(return_value=True)
-
- def test_EnteringChars(self):
- """Verify that characters are echoed onto the console."""
- test_str = b'abc'
- input_stream = BytesToByteList(test_str)
-
- # Send the characters in.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Check the input position.
- exp_pos = len(test_str)
- CheckInputBufferPosition(self, exp_pos)
-
- # Verify that the input buffer is correct.
- expected_buffer = test_str
- CheckInputBuffer(self, expected_buffer)
-
- # Check console output
- exp_console_out = test_str
- CheckConsoleOutput(self, exp_console_out)
-
- def test_EnteringDeletingMoreCharsThanEntered(self):
- """Verify that we can press backspace more than we have entered chars."""
- test_str = b'spamspam'
- input_stream = BytesToByteList(test_str)
-
- # Send the characters in.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Now backspace 1 more than what we sent.
- input_stream = []
- for _ in range(len(test_str) + 1):
- input_stream.append(console.ControlKey.BACKSPACE)
-
- # Send that sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # First, verify that input buffer position is 0.
- CheckInputBufferPosition(self, 0)
-
- # Next, examine the output stream for the correct sequence.
- exp_console_out = test_str
- for _ in range(len(test_str)):
- exp_console_out += BACKSPACE_STRING
-
- # Now, verify that we got what we expected.
- CheckConsoleOutput(self, exp_console_out)
-
- def test_EnteringMoreThanCharLimit(self):
- """Verify that we drop characters when the line is too long."""
- test_str = self.console.line_limit * b'o' # All allowed.
- test_str += 5 * b'x' # All should be dropped.
- input_stream = BytesToByteList(test_str)
-
- # Send the characters in.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # First, we expect that input buffer position should be equal to the line
- # limit.
- exp_pos = self.console.line_limit
- CheckInputBufferPosition(self, exp_pos)
-
- # The input buffer should only hold until the line limit.
- exp_buffer = test_str[0:self.console.line_limit]
- CheckInputBuffer(self, exp_buffer)
-
- # Lastly, check that the extra characters are not printed.
- exp_console_out = exp_buffer
- CheckConsoleOutput(self, exp_console_out)
-
- def test_ValidKeysOnLongLine(self):
- """Verify that we can still press valid keys if the line is too long."""
- # Fill the line.
- test_str = self.console.line_limit * b'o'
- exp_console_out = test_str
- # Try to fill it even more; these should all be dropped.
- test_str += 5 * b'x'
- input_stream = BytesToByteList(test_str)
-
- # We should be able to press the following keys:
- # - Backspace
- # - Arrow Keys/CTRL+B/CTRL+F/CTRL+P/CTRL+N
- # - Delete
- # - Home/CTRL+A
- # - End/CTRL+E
- # - Carriage Return
-
- # Backspace 1 character
- input_stream.append(console.ControlKey.BACKSPACE)
- exp_console_out += BACKSPACE_STRING
- # Refill the line.
- input_stream.extend(BytesToByteList(b'o'))
- exp_console_out += b'o'
-
- # Left arrow key.
- input_stream.extend(Keys.LEFT_ARROW)
- exp_console_out += OutputStream.MoveCursorLeft(1)
-
- # Right arrow key.
- input_stream.extend(Keys.RIGHT_ARROW)
- exp_console_out += OutputStream.MoveCursorRight(1)
-
- # CTRL+B
- input_stream.append(console.ControlKey.CTRL_B)
- exp_console_out += OutputStream.MoveCursorLeft(1)
-
- # CTRL+F
- input_stream.append(console.ControlKey.CTRL_F)
- exp_console_out += OutputStream.MoveCursorRight(1)
-
- # Let's press enter now so we can test up and down.
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
- exp_console_out += b'\r\n' + self.console.prompt
-
- # Up arrow key.
- input_stream.extend(Keys.UP_ARROW)
- exp_console_out += test_str[:self.console.line_limit]
-
- # Down arrow key.
- input_stream.extend(Keys.DOWN_ARROW)
- # Since the line was blank, we have to backspace the entire line.
- exp_console_out += self.console.line_limit * BACKSPACE_STRING
-
- # CTRL+P
- input_stream.append(console.ControlKey.CTRL_P)
- exp_console_out += test_str[:self.console.line_limit]
-
- # CTRL+N
- input_stream.append(console.ControlKey.CTRL_N)
- # Since the line was blank, we have to backspace the entire line.
- exp_console_out += self.console.line_limit * BACKSPACE_STRING
-
- # Press the Up arrow key to reprint the long line.
- input_stream.extend(Keys.UP_ARROW)
- exp_console_out += test_str[:self.console.line_limit]
-
- # Press the Home key to jump to the beginning of the line.
- input_stream.extend(Keys.HOME)
- exp_console_out += OutputStream.MoveCursorLeft(self.console.line_limit)
-
- # Press the End key to jump to the end of the line.
- input_stream.extend(Keys.END)
- exp_console_out += OutputStream.MoveCursorRight(self.console.line_limit)
-
- # Press CTRL+A to jump to the beginning of the line.
- input_stream.append(console.ControlKey.CTRL_A)
- exp_console_out += OutputStream.MoveCursorLeft(self.console.line_limit)
-
- # Press CTRL+E to jump to the end of the line.
- input_stream.extend(Keys.END)
- exp_console_out += OutputStream.MoveCursorRight(self.console.line_limit)
-
- # Move left one column so we can delete a character.
- input_stream.extend(Keys.LEFT_ARROW)
- exp_console_out += OutputStream.MoveCursorLeft(1)
-
- # Press the delete key.
- input_stream.extend(Keys.DEL)
- # This should look like a space, and then move cursor left 1 column since
- # we're at the end of line.
- exp_console_out += b' ' + OutputStream.MoveCursorLeft(1)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Verify everything happened correctly.
- CheckConsoleOutput(self, exp_console_out)
-
- def test_BackspaceOnEmptyLine(self):
- """Verify that we can backspace on an empty line with no bad effects."""
- # Send a single backspace.
- test_str = [console.ControlKey.BACKSPACE]
-
- # Send the characters in.
- for byte in test_str:
- self.console.HandleChar(byte)
-
- # Check the input position.
- exp_pos = 0
- CheckInputBufferPosition(self, exp_pos)
-
- # Check that buffer is empty.
- exp_input_buffer = b''
- CheckInputBuffer(self, exp_input_buffer)
-
- # Check that the console output is empty.
- exp_console_out = b''
- CheckConsoleOutput(self, exp_console_out)
-
- def test_BackspaceWithinLine(self):
- """Verify that we shift the chars over when backspacing within a line."""
- # Misspell 'help'
- test_str = b'heelp'
- input_stream = BytesToByteList(test_str)
- # Use the arrow key to go back to fix it.
- # Move cursor left 1 column.
- input_stream.extend(2*Keys.LEFT_ARROW)
- # Backspace once to remove the extra 'e'.
- input_stream.append(console.ControlKey.BACKSPACE)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Verify the input buffer
- exp_input_buffer = b'help'
- CheckInputBuffer(self, exp_input_buffer)
-
- # Verify the input buffer position. It should be at 2 (cursor over the 'l')
- CheckInputBufferPosition(self, 2)
-
- # We expect the console output to be the test string, with two moves to the
- # left, another move left, and then the rest of the line followed by a
- # space.
- exp_console_out = test_str
- exp_console_out += 2 * OutputStream.MoveCursorLeft(1)
-
- # Move cursor left 1 column.
- exp_console_out += OutputStream.MoveCursorLeft(1)
- # Rest of the line and a space. (test_str in this case)
- exp_console_out += b'lp '
- # Reset the cursor 2 + 1 to the left.
- exp_console_out += OutputStream.MoveCursorLeft(3)
-
- # Verify console output.
- CheckConsoleOutput(self, exp_console_out)
-
- def test_JumpToBeginningOfLineViaCtrlA(self):
- """Verify that we can jump to the beginning of a line with Ctrl+A."""
- # Enter some chars and press CTRL+A
- test_str = b'abc'
- input_stream = BytesToByteList(test_str) + [console.ControlKey.CTRL_A]
-
- # Send the characters in.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # We expect to see our test string followed by a move cursor left.
- exp_console_out = test_str
- exp_console_out += OutputStream.MoveCursorLeft(len(test_str))
-
- # Check to see what whas printed on the console.
- CheckConsoleOutput(self, exp_console_out)
-
- # Check that the input buffer position is now 0.
- CheckInputBufferPosition(self, 0)
-
- # Check input buffer still contains our test string.
- CheckInputBuffer(self, test_str)
-
- def test_JumpToBeginningOfLineViaHomeKey(self):
- """Jump to beginning of line via HOME key."""
- test_str = b'version'
- input_stream = BytesToByteList(test_str)
- input_stream.extend(Keys.HOME)
-
- # Send out the stream.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # First, verify that input buffer position is now 0.
- CheckInputBufferPosition(self, 0)
-
- # Next, verify that the input buffer did not change.
- CheckInputBuffer(self, test_str)
-
- # Lastly, check that the cursor moved correctly.
- exp_console_out = test_str
- exp_console_out += OutputStream.MoveCursorLeft(len(test_str))
- CheckConsoleOutput(self, exp_console_out)
-
- def test_JumpToEndOfLineViaEndKey(self):
- """Jump to the end of the line using the END key."""
- test_str = b'version'
- input_stream = BytesToByteList(test_str)
- input_stream += [console.ControlKey.CTRL_A]
- # Now, jump to the end of the line.
- input_stream.extend(Keys.END)
-
- # Send out the stream.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Verify that the input buffer position is correct. This should be at the
- # end of the test string.
- CheckInputBufferPosition(self, len(test_str))
-
- # The expected output should be the test string, followed by a jump to the
- # beginning of the line, and lastly a jump to the end of the line.
- exp_console_out = test_str
- exp_console_out += OutputStream.MoveCursorLeft(len(test_str))
- # Now the jump back to the end of the line.
- exp_console_out += OutputStream.MoveCursorRight(len(test_str))
-
- # Verify console output stream.
- CheckConsoleOutput(self, exp_console_out)
-
- def test_JumpToEndOfLineViaCtrlE(self):
- """Enter some chars and then try to jump to the end. (Should be a no-op)"""
- test_str = b'sysinfo'
- input_stream = BytesToByteList(test_str)
- input_stream.append(console.ControlKey.CTRL_E)
-
- # Send out the stream
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Verify that the input buffer position isn't any further than we expect.
- # At this point, the position should be at the end of the test string.
- CheckInputBufferPosition(self, len(test_str))
-
- # Now, let's try to jump to the beginning and then jump back to the end.
- input_stream = [console.ControlKey.CTRL_A, console.ControlKey.CTRL_E]
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Perform the same verification.
- CheckInputBufferPosition(self, len(test_str))
-
- # Lastly try to jump again, beyond the end.
- input_stream = [console.ControlKey.CTRL_E]
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Perform the same verification.
- CheckInputBufferPosition(self, len(test_str))
-
- # We expect to see the test string, a jump to the beginning of the line, and
- # one jump to the end of the line.
- exp_console_out = test_str
- # Jump to beginning.
- exp_console_out += OutputStream.MoveCursorLeft(len(test_str))
- # Jump back to end.
- exp_console_out += OutputStream.MoveCursorRight(len(test_str))
-
- # Verify the console output.
- CheckConsoleOutput(self, exp_console_out)
-
- def test_MoveLeftWithArrowKey(self):
- """Move cursor left one column with arrow key."""
- test_str = b'tastyspam'
- input_stream = BytesToByteList(test_str)
- input_stream.extend(Keys.LEFT_ARROW)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Verify that the input buffer position is 1 less than the length.
- CheckInputBufferPosition(self, len(test_str) - 1)
-
- # Also, verify that the input buffer is not modified.
- CheckInputBuffer(self, test_str)
-
- # We expect the test string, followed by a one column move left.
- exp_console_out = test_str + OutputStream.MoveCursorLeft(1)
-
- # Verify console output.
- CheckConsoleOutput(self, exp_console_out)
-
- def test_MoveLeftWithCtrlB(self):
- """Move cursor back one column with Ctrl+B."""
- test_str = b'tastyspam'
- input_stream = BytesToByteList(test_str)
- input_stream.append(console.ControlKey.CTRL_B)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Verify that the input buffer position is 1 less than the length.
- CheckInputBufferPosition(self, len(test_str) - 1)
+ Args:
+ test_case: A unittest.TestCase object representing the current unit test.
+ exp_pos: An integer representing the expected input buffer position.
+ """
+ test_case.assertEqual(
+ exp_pos,
+ test_case.console.input_buffer_pos,
+ "input buffer position is incorrect.\ngot: "
+ + str(test_case.console.input_buffer_pos)
+ + "\nexp: "
+ + str(exp_pos)
+ + "\n"
+ + str(test_case.console),
+ )
- # Also, verify that the input buffer is not modified.
- CheckInputBuffer(self, test_str)
- # We expect the test string, followed by a one column move left.
- exp_console_out = test_str + OutputStream.MoveCursorLeft(1)
def CheckHistoryBuffer(test_case, exp_history):
    """Verify that the items in the history buffer are what we expect.

    Args:
      test_case: A unittest.TestCase object representing the current unit test.
      exp_history: A list of strings representing the expected contents of the
        history buffer.
    """
    actual_history = test_case.console.history

    # The lengths must agree before the contents can be compared.
    test_case.assertEqual(
        len(exp_history),
        len(actual_history),
        (
            "The number of items in the history is unexpected.\n"
            "exp: " + str(len(exp_history)) + "\n"
            "got: " + str(len(actual_history)) + "\n"
            "internal state:\n" + str(test_case.console)
        ),
    )

    # Compare each history entry against its expected value.
    for idx, want in enumerate(exp_history):
        got = actual_history[idx]
        test_case.assertEqual(
            want,
            got,
            (
                b"history buffer contents are incorrect.\n"
                b"exp: " + want + b"\n"
                b"got: " + got + b"\n"
                b"internal state:\n" + str(test_case.console).encode("ascii")
            ),
        )
- def test_MoveRightWithArrowKey(self):
- """Move cursor one column to the right with the arrow key."""
- test_str = b'version'
- input_stream = BytesToByteList(test_str)
- # Jump to beginning of line.
- input_stream.append(console.ControlKey.CTRL_A)
- # Press right arrow key.
- input_stream.extend(Keys.RIGHT_ARROW)
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
+class TestConsoleEditingMethods(unittest.TestCase):
+ """Test case to verify all console editing methods."""
+
+ def setUp(self):
+ """Setup the test harness."""
+ # Setup logging with a timestamp, the module, and the log level.
+ logging.basicConfig(
+ level=logging.DEBUG,
+ format=("%(asctime)s - %(module)s -" " %(levelname)s - %(message)s"),
+ )
+
+ # Create a temp file and set both the controller and peripheral PTYs to the
+ # file to create a loopback.
+ self.tempfile = tempfile.TemporaryFile()
+
+ # Create some mock pipes. These won't be used since we'll mock out sends
+ # to the interpreter.
+ mock_pipe_end_0, mock_pipe_end_1 = threadproc_shim.Pipe()
+ self.console = console.Console(
+ self.tempfile.fileno(),
+ self.tempfile,
+ tempfile.TemporaryFile(),
+ mock_pipe_end_0,
+ mock_pipe_end_1,
+ "EC",
+ )
+
+ # Console editing methods are only valid for enhanced EC images, therefore
+ # we have to assume that the "EC" we're talking to is enhanced. By default,
+ # the console believes that the EC it's communicating with is NOT enhanced
+ # which is why we have to override it here.
+ self.console.enhanced_ec = True
+ self.console.CheckForEnhancedECImage = mock.MagicMock(return_value=True)
+
+ def test_EnteringChars(self):
+ """Verify that characters are echoed onto the console."""
+ test_str = b"abc"
+ input_stream = BytesToByteList(test_str)
+
+ # Send the characters in.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Check the input position.
+ exp_pos = len(test_str)
+ CheckInputBufferPosition(self, exp_pos)
+
+ # Verify that the input buffer is correct.
+ expected_buffer = test_str
+ CheckInputBuffer(self, expected_buffer)
+
+ # Check console output
+ exp_console_out = test_str
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_EnteringDeletingMoreCharsThanEntered(self):
+ """Verify that we can press backspace more than we have entered chars."""
+ test_str = b"spamspam"
+ input_stream = BytesToByteList(test_str)
+
+ # Send the characters in.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Now backspace 1 more than what we sent.
+ input_stream = []
+ for _ in range(len(test_str) + 1):
+ input_stream.append(console.ControlKey.BACKSPACE)
+
+ # Send that sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # First, verify that input buffer position is 0.
+ CheckInputBufferPosition(self, 0)
+
+ # Next, examine the output stream for the correct sequence.
+ exp_console_out = test_str
+ for _ in range(len(test_str)):
+ exp_console_out += BACKSPACE_STRING
+
+ # Now, verify that we got what we expected.
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_EnteringMoreThanCharLimit(self):
+ """Verify that we drop characters when the line is too long."""
+ test_str = self.console.line_limit * b"o" # All allowed.
+ test_str += 5 * b"x" # All should be dropped.
+ input_stream = BytesToByteList(test_str)
+
+ # Send the characters in.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # First, we expect that input buffer position should be equal to the line
+ # limit.
+ exp_pos = self.console.line_limit
+ CheckInputBufferPosition(self, exp_pos)
+
+ # The input buffer should only hold until the line limit.
+ exp_buffer = test_str[0 : self.console.line_limit]
+ CheckInputBuffer(self, exp_buffer)
+
+ # Lastly, check that the extra characters are not printed.
+ exp_console_out = exp_buffer
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_ValidKeysOnLongLine(self):
+ """Verify that we can still press valid keys if the line is too long."""
+ # Fill the line.
+ test_str = self.console.line_limit * b"o"
+ exp_console_out = test_str
+ # Try to fill it even more; these should all be dropped.
+ test_str += 5 * b"x"
+ input_stream = BytesToByteList(test_str)
+
+ # We should be able to press the following keys:
+ # - Backspace
+ # - Arrow Keys/CTRL+B/CTRL+F/CTRL+P/CTRL+N
+ # - Delete
+ # - Home/CTRL+A
+ # - End/CTRL+E
+ # - Carriage Return
+
+ # Backspace 1 character
+ input_stream.append(console.ControlKey.BACKSPACE)
+ exp_console_out += BACKSPACE_STRING
+ # Refill the line.
+ input_stream.extend(BytesToByteList(b"o"))
+ exp_console_out += b"o"
+
+ # Left arrow key.
+ input_stream.extend(Keys.LEFT_ARROW)
+ exp_console_out += OutputStream.MoveCursorLeft(1)
+
+ # Right arrow key.
+ input_stream.extend(Keys.RIGHT_ARROW)
+ exp_console_out += OutputStream.MoveCursorRight(1)
+
+ # CTRL+B
+ input_stream.append(console.ControlKey.CTRL_B)
+ exp_console_out += OutputStream.MoveCursorLeft(1)
+
+ # CTRL+F
+ input_stream.append(console.ControlKey.CTRL_F)
+ exp_console_out += OutputStream.MoveCursorRight(1)
+
+ # Let's press enter now so we can test up and down.
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+ exp_console_out += b"\r\n" + self.console.prompt
+
+ # Up arrow key.
+ input_stream.extend(Keys.UP_ARROW)
+ exp_console_out += test_str[: self.console.line_limit]
+
+ # Down arrow key.
+ input_stream.extend(Keys.DOWN_ARROW)
+ # Since the line was blank, we have to backspace the entire line.
+ exp_console_out += self.console.line_limit * BACKSPACE_STRING
+
+ # CTRL+P
+ input_stream.append(console.ControlKey.CTRL_P)
+ exp_console_out += test_str[: self.console.line_limit]
+
+ # CTRL+N
+ input_stream.append(console.ControlKey.CTRL_N)
+ # Since the line was blank, we have to backspace the entire line.
+ exp_console_out += self.console.line_limit * BACKSPACE_STRING
+
+ # Press the Up arrow key to reprint the long line.
+ input_stream.extend(Keys.UP_ARROW)
+ exp_console_out += test_str[: self.console.line_limit]
+
+ # Press the Home key to jump to the beginning of the line.
+ input_stream.extend(Keys.HOME)
+ exp_console_out += OutputStream.MoveCursorLeft(self.console.line_limit)
+
+ # Press the End key to jump to the end of the line.
+ input_stream.extend(Keys.END)
+ exp_console_out += OutputStream.MoveCursorRight(self.console.line_limit)
+
+ # Press CTRL+A to jump to the beginning of the line.
+ input_stream.append(console.ControlKey.CTRL_A)
+ exp_console_out += OutputStream.MoveCursorLeft(self.console.line_limit)
+
+ # Press CTRL+E to jump to the end of the line.
+ input_stream.extend(Keys.END)
+ exp_console_out += OutputStream.MoveCursorRight(self.console.line_limit)
+
+ # Move left one column so we can delete a character.
+ input_stream.extend(Keys.LEFT_ARROW)
+ exp_console_out += OutputStream.MoveCursorLeft(1)
+
+ # Press the delete key.
+ input_stream.extend(Keys.DEL)
+ # This should look like a space, and then move cursor left 1 column since
+ # we're at the end of line.
+ exp_console_out += b" " + OutputStream.MoveCursorLeft(1)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Verify everything happened correctly.
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_BackspaceOnEmptyLine(self):
+ """Verify that we can backspace on an empty line with no bad effects."""
+ # Send a single backspace.
+ test_str = [console.ControlKey.BACKSPACE]
+
+ # Send the characters in.
+ for byte in test_str:
+ self.console.HandleChar(byte)
+
+ # Check the input position.
+ exp_pos = 0
+ CheckInputBufferPosition(self, exp_pos)
+
+ # Check that buffer is empty.
+ exp_input_buffer = b""
+ CheckInputBuffer(self, exp_input_buffer)
+
+ # Check that the console output is empty.
+ exp_console_out = b""
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_BackspaceWithinLine(self):
+ """Verify that we shift the chars over when backspacing within a line."""
+ # Misspell 'help'
+ test_str = b"heelp"
+ input_stream = BytesToByteList(test_str)
+ # Use the arrow key to go back to fix it.
+ # Move cursor left 1 column.
+ input_stream.extend(2 * Keys.LEFT_ARROW)
+ # Backspace once to remove the extra 'e'.
+ input_stream.append(console.ControlKey.BACKSPACE)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Verify the input buffer
+ exp_input_buffer = b"help"
+ CheckInputBuffer(self, exp_input_buffer)
+
+ # Verify the input buffer position. It should be at 2 (cursor over the 'l')
+ CheckInputBufferPosition(self, 2)
+
+ # We expect the console output to be the test string, with two moves to the
+ # left, another move left, and then the rest of the line followed by a
+ # space.
+ exp_console_out = test_str
+ exp_console_out += 2 * OutputStream.MoveCursorLeft(1)
+
+ # Move cursor left 1 column.
+ exp_console_out += OutputStream.MoveCursorLeft(1)
+ # Rest of the line and a space. (test_str in this case)
+ exp_console_out += b"lp "
+ # Reset the cursor 2 + 1 to the left.
+ exp_console_out += OutputStream.MoveCursorLeft(3)
+
+ # Verify console output.
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_JumpToBeginningOfLineViaCtrlA(self):
+ """Verify that we can jump to the beginning of a line with Ctrl+A."""
+ # Enter some chars and press CTRL+A
+ test_str = b"abc"
+ input_stream = BytesToByteList(test_str) + [console.ControlKey.CTRL_A]
+
+ # Send the characters in.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # We expect to see our test string followed by a move cursor left.
+ exp_console_out = test_str
+ exp_console_out += OutputStream.MoveCursorLeft(len(test_str))
+
+ # Check to see what whas printed on the console.
+ CheckConsoleOutput(self, exp_console_out)
+
+ # Check that the input buffer position is now 0.
+ CheckInputBufferPosition(self, 0)
+
+ # Check input buffer still contains our test string.
+ CheckInputBuffer(self, test_str)
+
+ def test_JumpToBeginningOfLineViaHomeKey(self):
+ """Jump to beginning of line via HOME key."""
+ test_str = b"version"
+ input_stream = BytesToByteList(test_str)
+ input_stream.extend(Keys.HOME)
+
+ # Send out the stream.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # First, verify that input buffer position is now 0.
+ CheckInputBufferPosition(self, 0)
+
+ # Next, verify that the input buffer did not change.
+ CheckInputBuffer(self, test_str)
+
+ # Lastly, check that the cursor moved correctly.
+ exp_console_out = test_str
+ exp_console_out += OutputStream.MoveCursorLeft(len(test_str))
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_JumpToEndOfLineViaEndKey(self):
+ """Jump to the end of the line using the END key."""
+ test_str = b"version"
+ input_stream = BytesToByteList(test_str)
+ input_stream += [console.ControlKey.CTRL_A]
+ # Now, jump to the end of the line.
+ input_stream.extend(Keys.END)
+
+ # Send out the stream.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Verify that the input buffer position is correct. This should be at the
+ # end of the test string.
+ CheckInputBufferPosition(self, len(test_str))
+
+ # The expected output should be the test string, followed by a jump to the
+ # beginning of the line, and lastly a jump to the end of the line.
+ exp_console_out = test_str
+ exp_console_out += OutputStream.MoveCursorLeft(len(test_str))
+ # Now the jump back to the end of the line.
+ exp_console_out += OutputStream.MoveCursorRight(len(test_str))
+
+ # Verify console output stream.
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_JumpToEndOfLineViaCtrlE(self):
+ """Enter some chars and then try to jump to the end. (Should be a no-op)"""
+ test_str = b"sysinfo"
+ input_stream = BytesToByteList(test_str)
+ input_stream.append(console.ControlKey.CTRL_E)
+
+ # Send out the stream
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Verify that the input buffer position isn't any further than we expect.
+ # At this point, the position should be at the end of the test string.
+ CheckInputBufferPosition(self, len(test_str))
+
+ # Now, let's try to jump to the beginning and then jump back to the end.
+ input_stream = [console.ControlKey.CTRL_A, console.ControlKey.CTRL_E]
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Perform the same verification.
+ CheckInputBufferPosition(self, len(test_str))
+
+ # Lastly try to jump again, beyond the end.
+ input_stream = [console.ControlKey.CTRL_E]
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Perform the same verification.
+ CheckInputBufferPosition(self, len(test_str))
+
+ # We expect to see the test string, a jump to the beginning of the line, and
+ # one jump to the end of the line.
+ exp_console_out = test_str
+ # Jump to beginning.
+ exp_console_out += OutputStream.MoveCursorLeft(len(test_str))
+ # Jump back to end.
+ exp_console_out += OutputStream.MoveCursorRight(len(test_str))
+
+ # Verify the console output.
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_MoveLeftWithArrowKey(self):
+ """Move cursor left one column with arrow key."""
+ test_str = b"tastyspam"
+ input_stream = BytesToByteList(test_str)
+ input_stream.extend(Keys.LEFT_ARROW)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Verify that the input buffer position is 1 less than the length.
+ CheckInputBufferPosition(self, len(test_str) - 1)
+
+ # Also, verify that the input buffer is not modified.
+ CheckInputBuffer(self, test_str)
+
+ # We expect the test string, followed by a one column move left.
+ exp_console_out = test_str + OutputStream.MoveCursorLeft(1)
+
+ # Verify console output.
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_MoveLeftWithCtrlB(self):
+ """Move cursor back one column with Ctrl+B."""
+ test_str = b"tastyspam"
+ input_stream = BytesToByteList(test_str)
+ input_stream.append(console.ControlKey.CTRL_B)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Verify that the input buffer position is 1 less than the length.
+ CheckInputBufferPosition(self, len(test_str) - 1)
- # Verify that the input buffer position is 1.
- CheckInputBufferPosition(self, 1)
+ # Also, verify that the input buffer is not modified.
+ CheckInputBuffer(self, test_str)
- # Also, verify that the input buffer is not modified.
- CheckInputBuffer(self, test_str)
+ # We expect the test string, followed by a one column move left.
+ exp_console_out = test_str + OutputStream.MoveCursorLeft(1)
- # We expect the test string, followed by a jump to the beginning of the
- # line, and finally a move right 1.
- exp_console_out = test_str + OutputStream.MoveCursorLeft(len((test_str)))
-
- # A move right 1 column.
- exp_console_out += OutputStream.MoveCursorRight(1)
-
- # Verify console output.
- CheckConsoleOutput(self, exp_console_out)
-
- def test_MoveRightWithCtrlF(self):
- """Move cursor forward one column with Ctrl+F."""
- test_str = b'panicinfo'
- input_stream = BytesToByteList(test_str)
- input_stream.append(console.ControlKey.CTRL_A)
- # Now, move right one column.
- input_stream.append(console.ControlKey.CTRL_F)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Verify that the input buffer position is 1.
- CheckInputBufferPosition(self, 1)
-
- # Also, verify that the input buffer is not modified.
- CheckInputBuffer(self, test_str)
-
- # We expect the test string, followed by a jump to the beginning of the
- # line, and finally a move right 1.
- exp_console_out = test_str + OutputStream.MoveCursorLeft(len((test_str)))
-
- # A move right 1 column.
- exp_console_out += OutputStream.MoveCursorRight(1)
-
- # Verify console output.
- CheckConsoleOutput(self, exp_console_out)
-
- def test_ImpossibleMoveLeftWithArrowKey(self):
- """Verify that we can't move left at the beginning of the line."""
- # We shouldn't be able to move left if we're at the beginning of the line.
- input_stream = Keys.LEFT_ARROW
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Nothing should have been output.
- exp_console_output = b''
- CheckConsoleOutput(self, exp_console_output)
-
- # The input buffer position should still be 0.
- CheckInputBufferPosition(self, 0)
-
- # The input buffer itself should be empty.
- CheckInputBuffer(self, b'')
-
- def test_ImpossibleMoveRightWithArrowKey(self):
- """Verify that we can't move right at the end of the line."""
- # We shouldn't be able to move right if we're at the end of the line.
- input_stream = Keys.RIGHT_ARROW
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Nothing should have been output.
- exp_console_output = b''
- CheckConsoleOutput(self, exp_console_output)
-
- # The input buffer position should still be 0.
- CheckInputBufferPosition(self, 0)
-
- # The input buffer itself should be empty.
- CheckInputBuffer(self, b'')
-
- def test_KillEntireLine(self):
- """Verify that we can kill an entire line with Ctrl+K."""
- test_str = b'accelinfo on'
- input_stream = BytesToByteList(test_str)
- # Jump to beginning of line and then kill it with Ctrl+K.
- input_stream.extend([console.ControlKey.CTRL_A, console.ControlKey.CTRL_K])
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # First, we expect that the input buffer is empty.
- CheckInputBuffer(self, b'')
-
- # The buffer position should be 0.
- CheckInputBufferPosition(self, 0)
-
- # What we expect to see on the console stream should be the following. The
- # test string, a jump to the beginning of the line, then jump back to the
- # end of the line and replace the line with spaces.
- exp_console_out = test_str
- # Jump to beginning of line.
- exp_console_out += OutputStream.MoveCursorLeft(len(test_str))
- # Jump to end of line.
- exp_console_out += OutputStream.MoveCursorRight(len(test_str))
- # Replace line with spaces, which looks like backspaces.
- for _ in range(len(test_str)):
- exp_console_out += BACKSPACE_STRING
-
- # Verify the console output.
- CheckConsoleOutput(self, exp_console_out)
-
- def test_KillPartialLine(self):
- """Verify that we can kill a portion of a line."""
- test_str = b'accelread 0 1'
- input_stream = BytesToByteList(test_str)
- len_to_kill = 5
- for _ in range(len_to_kill):
- # Move cursor left
- input_stream.extend(Keys.LEFT_ARROW)
- # Now kill
- input_stream.append(console.ControlKey.CTRL_K)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # First, check that the input buffer was truncated.
- exp_input_buffer = test_str[:-len_to_kill]
- CheckInputBuffer(self, exp_input_buffer)
-
- # Verify the input buffer position.
- CheckInputBufferPosition(self, len(test_str) - len_to_kill)
-
- # The console output stream that we expect is the test string followed by a
- # move left of len_to_kill, then a jump to the end of the line and backspace
- # of len_to_kill.
- exp_console_out = test_str
- for _ in range(len_to_kill):
- # Move left 1 column.
- exp_console_out += OutputStream.MoveCursorLeft(1)
- # Then jump to the end of the line
- exp_console_out += OutputStream.MoveCursorRight(len_to_kill)
- # Backspace of len_to_kill
- for _ in range(len_to_kill):
- exp_console_out += BACKSPACE_STRING
-
- # Verify console output.
- CheckConsoleOutput(self, exp_console_out)
-
- def test_InsertingCharacters(self):
- """Verify that we can insert characters within the line."""
- test_str = b'accel 0 1' # Here we forgot the 'read' part in 'accelread'
- input_stream = BytesToByteList(test_str)
- # We need to move over to the 'l' and add read.
- insertion_point = test_str.find(b'l') + 1
- for i in range(len(test_str) - insertion_point):
- # Move cursor left.
- input_stream.extend(Keys.LEFT_ARROW)
- # Now, add in 'read'
- added_str = b'read'
- input_stream.extend(BytesToByteList(added_str))
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # First, verify that the input buffer is correct.
- exp_input_buffer = test_str[:insertion_point] + added_str
- exp_input_buffer += test_str[insertion_point:]
- CheckInputBuffer(self, exp_input_buffer)
-
- # Verify that the input buffer position is correct.
- exp_input_buffer_pos = insertion_point + len(added_str)
- CheckInputBufferPosition(self, exp_input_buffer_pos)
-
- # The console output stream that we expect is the test string, followed by
- # move cursor left until the 'l' was found, the added test string while
- # shifting characters around.
- exp_console_out = test_str
- for i in range(len(test_str) - insertion_point):
- # Move cursor left.
- exp_console_out += OutputStream.MoveCursorLeft(1)
-
- # Now for each character, write the rest of the line will be shifted to the
- # right one column.
- for i in range(len(added_str)):
- # Printed character.
- exp_console_out += added_str[i:i+1]
- # The rest of the line
- exp_console_out += test_str[insertion_point:]
- # Reset the cursor back left
- reset_dist = len(test_str[insertion_point:])
- exp_console_out += OutputStream.MoveCursorLeft(reset_dist)
-
- # Verify the console output.
- CheckConsoleOutput(self, exp_console_out)
-
- def test_StoreCommandHistory(self):
- """Verify that entered commands are stored in the history."""
- test_commands = []
- test_commands.append(b'help')
- test_commands.append(b'version')
- test_commands.append(b'accelread 0 1')
- input_stream = []
- for c in test_commands:
- input_stream.extend(BytesToByteList(c))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # We expect to have the test commands in the history buffer.
- exp_history_buf = test_commands
- CheckHistoryBuffer(self, exp_history_buf)
-
- def test_CycleUpThruCommandHistory(self):
- """Verify that the UP arrow key will print itmes in the history buffer."""
- # Enter some commands.
- test_commands = [b'version', b'accelrange 0', b'battery', b'gettime']
- input_stream = []
- for command in test_commands:
- input_stream.extend(BytesToByteList(command))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Now, hit the UP arrow key to print the previous entries.
- for i in range(len(test_commands)):
- input_stream.extend(Keys.UP_ARROW)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # The expected output should be test commands with prompts printed in
- # between, followed by line kills with the previous test commands printed.
- exp_console_out = b''
- for i in range(len(test_commands)):
- exp_console_out += test_commands[i] + b'\r\n' + self.console.prompt
-
- # When we press up, the line should be cleared and print the previous buffer
- # entry.
- for i in range(len(test_commands)-1, 0, -1):
- exp_console_out += test_commands[i]
- # Backspace to the beginning.
- for i in range(len(test_commands[i])):
- exp_console_out += BACKSPACE_STRING
+ # Verify console output.
+ CheckConsoleOutput(self, exp_console_out)
- # The last command should just be printed out with no backspacing.
- exp_console_out += test_commands[0]
-
- # Now, verify.
- CheckConsoleOutput(self, exp_console_out)
-
- def test_UpArrowOnEmptyHistory(self):
- """Ensure nothing happens if the history is empty."""
- # Press the up arrow key twice.
- input_stream = 2 * Keys.UP_ARROW
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # We expect nothing to have happened.
- exp_console_out = b''
- exp_input_buffer = b''
- exp_input_buffer_pos = 0
- exp_history_buf = []
-
- # Verify.
- CheckConsoleOutput(self, exp_console_out)
- CheckInputBufferPosition(self, exp_input_buffer_pos)
- CheckInputBuffer(self, exp_input_buffer)
- CheckHistoryBuffer(self, exp_history_buf)
-
- def test_UpArrowDoesNotGoOutOfBounds(self):
- """Verify that pressing the up arrow many times won't go out of bounds."""
- # Enter one command.
- test_str = b'help version'
- input_stream = BytesToByteList(test_str)
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
- # Then press the up arrow key twice.
- input_stream.extend(2 * Keys.UP_ARROW)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Verify that the history buffer is correct.
- exp_history_buf = [test_str]
- CheckHistoryBuffer(self, exp_history_buf)
-
- # We expect that the console output should only contain our entered command,
- # a new prompt, and then our command aggain.
- exp_console_out = test_str + b'\r\n' + self.console.prompt
- # Pressing up should reprint the command we entered.
- exp_console_out += test_str
-
- # Verify.
- CheckConsoleOutput(self, exp_console_out)
-
- def test_CycleDownThruCommandHistory(self):
- """Verify that we can select entries by hitting the down arrow."""
- # Enter at least 4 commands.
- test_commands = [b'version', b'accelrange 0', b'battery', b'gettime']
- input_stream = []
- for command in test_commands:
- input_stream.extend(BytesToByteList(command))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Now, hit the UP arrow key twice to print the previous two entries.
- for i in range(2):
- input_stream.extend(Keys.UP_ARROW)
-
- # Now, hit the DOWN arrow key twice to print the newer entries.
- input_stream.extend(2*Keys.DOWN_ARROW)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # The expected output should be commands that we entered, followed by
- # prompts, then followed by our last two commands in reverse. Then, we
- # should see the last entry in the list, followed by the saved partial cmd
- # of a blank line.
- exp_console_out = b''
- for i in range(len(test_commands)):
- exp_console_out += test_commands[i] + b'\r\n' + self.console.prompt
-
- # When we press up, the line should be cleared and print the previous buffer
- # entry.
- for i in range(len(test_commands)-1, 1, -1):
- exp_console_out += test_commands[i]
- # Backspace to the beginning.
- for i in range(len(test_commands[i])):
- exp_console_out += BACKSPACE_STRING
+ def test_MoveRightWithArrowKey(self):
+ """Move cursor one column to the right with the arrow key."""
+ test_str = b"version"
+ input_stream = BytesToByteList(test_str)
+ # Jump to beginning of line.
+ input_stream.append(console.ControlKey.CTRL_A)
+ # Press right arrow key.
+ input_stream.extend(Keys.RIGHT_ARROW)
- # When we press down, it should have cleared the last command (which we
- # covered with the previous for loop), and then prints the next command.
- exp_console_out += test_commands[3]
- for i in range(len(test_commands[3])):
- exp_console_out += BACKSPACE_STRING
-
- # Verify console output.
- CheckConsoleOutput(self, exp_console_out)
-
- # Verify input buffer.
- exp_input_buffer = b'' # Empty because our partial command was empty.
- exp_input_buffer_pos = len(exp_input_buffer)
- CheckInputBuffer(self, exp_input_buffer)
- CheckInputBufferPosition(self, exp_input_buffer_pos)
-
- def test_SavingPartialCommandWhenNavigatingHistory(self):
- """Verify that partial commands are saved when navigating history."""
- # Enter a command.
- test_str = b'accelinfo'
- input_stream = BytesToByteList(test_str)
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Enter a partial command.
- partial_cmd = b'ver'
- input_stream.extend(BytesToByteList(partial_cmd))
-
- # Hit the UP arrow key.
- input_stream.extend(Keys.UP_ARROW)
- # Then, the DOWN arrow key.
- input_stream.extend(Keys.DOWN_ARROW)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # The expected output should be the command we entered, a prompt, the
- # partial command, clearing of the partial command, the command entered,
- # clearing of the command entered, and then the partial command.
- exp_console_out = test_str + b'\r\n' + self.console.prompt
- exp_console_out += partial_cmd
- for _ in range(len(partial_cmd)):
- exp_console_out += BACKSPACE_STRING
- exp_console_out += test_str
- for _ in range(len(test_str)):
- exp_console_out += BACKSPACE_STRING
- exp_console_out += partial_cmd
-
- # Verify console output.
- CheckConsoleOutput(self, exp_console_out)
-
- # Verify input buffer.
- exp_input_buffer = partial_cmd
- exp_input_buffer_pos = len(exp_input_buffer)
- CheckInputBuffer(self, exp_input_buffer)
- CheckInputBufferPosition(self, exp_input_buffer_pos)
-
- def test_DownArrowOnEmptyHistory(self):
- """Ensure nothing happens if the history is empty."""
- # Then press the up down arrow twice.
- input_stream = 2 * Keys.DOWN_ARROW
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # We expect nothing to have happened.
- exp_console_out = b''
- exp_input_buffer = b''
- exp_input_buffer_pos = 0
- exp_history_buf = []
-
- # Verify.
- CheckConsoleOutput(self, exp_console_out)
- CheckInputBufferPosition(self, exp_input_buffer_pos)
- CheckInputBuffer(self, exp_input_buffer)
- CheckHistoryBuffer(self, exp_history_buf)
-
- def test_DeleteCharsUsingDELKey(self):
- """Verify that we can delete characters using the DEL key."""
- test_str = b'version'
- input_stream = BytesToByteList(test_str)
-
- # Hit the left arrow key 2 times.
- input_stream.extend(2 * Keys.LEFT_ARROW)
-
- # Press the DEL key.
- input_stream.extend(Keys.DEL)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # The expected output should be the command we entered, 2 individual cursor
- # moves to the left, and then removing a char and shifting everything to the
- # left one column.
- exp_console_out = test_str
- exp_console_out += 2 * OutputStream.MoveCursorLeft(1)
-
- # Remove the char by shifting everything to the left one, slicing out the
- # remove char.
- exp_console_out += test_str[-1:] + b' '
-
- # Reset the cursor by moving back 2 columns because of the 'n' and space.
- exp_console_out += OutputStream.MoveCursorLeft(2)
-
- # Verify console output.
- CheckConsoleOutput(self, exp_console_out)
-
- # Verify input buffer. The input buffer should have the char sliced out and
- # be positioned where the char was removed.
- exp_input_buffer = test_str[:-2] + test_str[-1:]
- exp_input_buffer_pos = len(exp_input_buffer) - 1
- CheckInputBuffer(self, exp_input_buffer)
- CheckInputBufferPosition(self, exp_input_buffer_pos)
-
- def test_RepeatedCommandInHistory(self):
- """Verify that we don't store 2 consecutive identical commands in history"""
- # Enter a few commands.
- test_commands = [b'version', b'accelrange 0', b'battery', b'gettime']
- # Repeat the last command.
- test_commands.append(test_commands[len(test_commands)-1])
-
- input_stream = []
- for command in test_commands:
- input_stream.extend(BytesToByteList(command))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Verify that the history buffer is correct. The last command, since
- # it was repeated, should not have been added to the history.
- exp_history_buf = test_commands[0:len(test_commands)-1]
- CheckHistoryBuffer(self, exp_history_buf)
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+ # Verify that the input buffer position is 1.
+ CheckInputBufferPosition(self, 1)
-class TestConsoleCompatibility(unittest.TestCase):
- """Verify that console can speak to enhanced and non-enhanced EC images."""
- def setUp(self):
- """Setup the test harness."""
- # Setup logging with a timestamp, the module, and the log level.
- logging.basicConfig(level=logging.DEBUG,
- format=('%(asctime)s - %(module)s -'
- ' %(levelname)s - %(message)s'))
- # Create a temp file and set both the controller and peripheral PTYs to the
- # file to create a loopback.
- self.tempfile = tempfile.TemporaryFile()
-
- # Mock out the pipes.
- mock_pipe_end_0, mock_pipe_end_1 = mock.MagicMock(), mock.MagicMock()
- self.console = console.Console(self.tempfile.fileno(), self.tempfile,
- tempfile.TemporaryFile(),
- mock_pipe_end_0, mock_pipe_end_1, "EC")
-
- @mock.patch('ec3po.console.Console.CheckForEnhancedECImage')
- def test_ActAsPassThruInNonEnhancedMode(self, mock_check):
- """Verify we simply pass everything thru to non-enhanced ECs.
+ # Also, verify that the input buffer is not modified.
+ CheckInputBuffer(self, test_str)
- Args:
- mock_check: A MagicMock object replacing the CheckForEnhancedECImage()
- method.
- """
- # Set the interrogation mode to always so that we actually interrogate.
- self.console.interrogation_mode = b'always'
-
- # Assume EC interrogations indicate that the image is non-enhanced.
- mock_check.return_value = False
-
- # Press enter, followed by the command, and another enter.
- input_stream = []
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
- test_command = b'version'
- input_stream.extend(BytesToByteList(test_command))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Expected calls to send down the pipe would be each character of the test
- # command.
- expected_calls = []
- expected_calls.append(mock.call(
- six.int2byte(console.ControlKey.CARRIAGE_RETURN)))
- for char in test_command:
- if six.PY3:
- expected_calls.append(mock.call(bytes([char])))
- else:
- expected_calls.append(mock.call(char))
- expected_calls.append(mock.call(
- six.int2byte(console.ControlKey.CARRIAGE_RETURN)))
-
- # Verify that the calls happened.
- self.console.cmd_pipe.send.assert_has_calls(expected_calls)
-
- # Since we're acting as a pass-thru, the input buffer should be empty and
- # input_buffer_pos is 0.
- CheckInputBuffer(self, b'')
- CheckInputBufferPosition(self, 0)
-
- @mock.patch('ec3po.console.Console.CheckForEnhancedECImage')
- def test_TransitionFromNonEnhancedToEnhanced(self, mock_check):
- """Verify that we transition correctly to enhanced mode.
+ # We expect the test string, followed by a jump to the beginning of the
+ # line, and finally a move right 1.
+ exp_console_out = test_str + OutputStream.MoveCursorLeft(len((test_str)))
+
+ # A move right 1 column.
+ exp_console_out += OutputStream.MoveCursorRight(1)
+
+ # Verify console output.
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_MoveRightWithCtrlF(self):
+ """Move cursor forward one column with Ctrl+F."""
+ test_str = b"panicinfo"
+ input_stream = BytesToByteList(test_str)
+ input_stream.append(console.ControlKey.CTRL_A)
+ # Now, move right one column.
+ input_stream.append(console.ControlKey.CTRL_F)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Verify that the input buffer position is 1.
+ CheckInputBufferPosition(self, 1)
+
+ # Also, verify that the input buffer is not modified.
+ CheckInputBuffer(self, test_str)
+
+ # We expect the test string, followed by a jump to the beginning of the
+ # line, and finally a move right 1.
+ exp_console_out = test_str + OutputStream.MoveCursorLeft(len((test_str)))
+
+ # A move right 1 column.
+ exp_console_out += OutputStream.MoveCursorRight(1)
+
+ # Verify console output.
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_ImpossibleMoveLeftWithArrowKey(self):
+ """Verify that we can't move left at the beginning of the line."""
+ # We shouldn't be able to move left if we're at the beginning of the line.
+ input_stream = Keys.LEFT_ARROW
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Nothing should have been output.
+ exp_console_output = b""
+ CheckConsoleOutput(self, exp_console_output)
+
+ # The input buffer position should still be 0.
+ CheckInputBufferPosition(self, 0)
+
+ # The input buffer itself should be empty.
+ CheckInputBuffer(self, b"")
+
+ def test_ImpossibleMoveRightWithArrowKey(self):
+ """Verify that we can't move right at the end of the line."""
+ # We shouldn't be able to move right if we're at the end of the line.
+ input_stream = Keys.RIGHT_ARROW
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Nothing should have been output.
+ exp_console_output = b""
+ CheckConsoleOutput(self, exp_console_output)
+
+ # The input buffer position should still be 0.
+ CheckInputBufferPosition(self, 0)
+
+ # The input buffer itself should be empty.
+ CheckInputBuffer(self, b"")
+
+ def test_KillEntireLine(self):
+ """Verify that we can kill an entire line with Ctrl+K."""
+ test_str = b"accelinfo on"
+ input_stream = BytesToByteList(test_str)
+ # Jump to beginning of line and then kill it with Ctrl+K.
+ input_stream.extend([console.ControlKey.CTRL_A, console.ControlKey.CTRL_K])
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # First, we expect that the input buffer is empty.
+ CheckInputBuffer(self, b"")
+
+ # The buffer position should be 0.
+ CheckInputBufferPosition(self, 0)
+
+ # What we expect to see on the console stream should be the following. The
+ # test string, a jump to the beginning of the line, then jump back to the
+ # end of the line and replace the line with spaces.
+ exp_console_out = test_str
+ # Jump to beginning of line.
+ exp_console_out += OutputStream.MoveCursorLeft(len(test_str))
+ # Jump to end of line.
+ exp_console_out += OutputStream.MoveCursorRight(len(test_str))
+ # Replace line with spaces, which looks like backspaces.
+ for _ in range(len(test_str)):
+ exp_console_out += BACKSPACE_STRING
+
+ # Verify the console output.
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_KillPartialLine(self):
+ """Verify that we can kill a portion of a line."""
+ test_str = b"accelread 0 1"
+ input_stream = BytesToByteList(test_str)
+ len_to_kill = 5
+ for _ in range(len_to_kill):
+ # Move cursor left
+ input_stream.extend(Keys.LEFT_ARROW)
+ # Now kill
+ input_stream.append(console.ControlKey.CTRL_K)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # First, check that the input buffer was truncated.
+ exp_input_buffer = test_str[:-len_to_kill]
+ CheckInputBuffer(self, exp_input_buffer)
+
+ # Verify the input buffer position.
+ CheckInputBufferPosition(self, len(test_str) - len_to_kill)
+
+ # The console output stream that we expect is the test string followed by a
+ # move left of len_to_kill, then a jump to the end of the line and backspace
+ # of len_to_kill.
+ exp_console_out = test_str
+ for _ in range(len_to_kill):
+ # Move left 1 column.
+ exp_console_out += OutputStream.MoveCursorLeft(1)
+ # Then jump to the end of the line
+ exp_console_out += OutputStream.MoveCursorRight(len_to_kill)
+ # Backspace of len_to_kill
+ for _ in range(len_to_kill):
+ exp_console_out += BACKSPACE_STRING
+
+ # Verify console output.
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_InsertingCharacters(self):
+ """Verify that we can insert characters within the line."""
+ test_str = b"accel 0 1" # Here we forgot the 'read' part in 'accelread'
+ input_stream = BytesToByteList(test_str)
+ # We need to move over to the 'l' and add read.
+ insertion_point = test_str.find(b"l") + 1
+ for i in range(len(test_str) - insertion_point):
+ # Move cursor left.
+ input_stream.extend(Keys.LEFT_ARROW)
+ # Now, add in 'read'
+ added_str = b"read"
+ input_stream.extend(BytesToByteList(added_str))
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # First, verify that the input buffer is correct.
+ exp_input_buffer = test_str[:insertion_point] + added_str
+ exp_input_buffer += test_str[insertion_point:]
+ CheckInputBuffer(self, exp_input_buffer)
+
+ # Verify that the input buffer position is correct.
+ exp_input_buffer_pos = insertion_point + len(added_str)
+ CheckInputBufferPosition(self, exp_input_buffer_pos)
+
+ # The console output stream that we expect is the test string, followed by
+ # move cursor left until the 'l' was found, the added test string while
+ # shifting characters around.
+ exp_console_out = test_str
+ for i in range(len(test_str) - insertion_point):
+ # Move cursor left.
+ exp_console_out += OutputStream.MoveCursorLeft(1)
+
+ # Now for each character, write the rest of the line will be shifted to the
+ # right one column.
+ for i in range(len(added_str)):
+ # Printed character.
+ exp_console_out += added_str[i : i + 1]
+ # The rest of the line
+ exp_console_out += test_str[insertion_point:]
+ # Reset the cursor back left
+ reset_dist = len(test_str[insertion_point:])
+ exp_console_out += OutputStream.MoveCursorLeft(reset_dist)
+
+ # Verify the console output.
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_StoreCommandHistory(self):
+ """Verify that entered commands are stored in the history."""
+ test_commands = []
+ test_commands.append(b"help")
+ test_commands.append(b"version")
+ test_commands.append(b"accelread 0 1")
+ input_stream = []
+ for c in test_commands:
+ input_stream.extend(BytesToByteList(c))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # We expect to have the test commands in the history buffer.
+ exp_history_buf = test_commands
+ CheckHistoryBuffer(self, exp_history_buf)
+
+ def test_CycleUpThruCommandHistory(self):
+ """Verify that the UP arrow key will print itmes in the history buffer."""
+ # Enter some commands.
+ test_commands = [b"version", b"accelrange 0", b"battery", b"gettime"]
+ input_stream = []
+ for command in test_commands:
+ input_stream.extend(BytesToByteList(command))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Now, hit the UP arrow key to print the previous entries.
+ for i in range(len(test_commands)):
+ input_stream.extend(Keys.UP_ARROW)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # The expected output should be test commands with prompts printed in
+ # between, followed by line kills with the previous test commands printed.
+ exp_console_out = b""
+ for i in range(len(test_commands)):
+ exp_console_out += test_commands[i] + b"\r\n" + self.console.prompt
+
+ # When we press up, the line should be cleared and print the previous buffer
+ # entry.
+ for i in range(len(test_commands) - 1, 0, -1):
+ exp_console_out += test_commands[i]
+ # Backspace to the beginning.
+ for i in range(len(test_commands[i])):
+ exp_console_out += BACKSPACE_STRING
+
+ # The last command should just be printed out with no backspacing.
+ exp_console_out += test_commands[0]
+
+ # Now, verify.
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_UpArrowOnEmptyHistory(self):
+ """Ensure nothing happens if the history is empty."""
+ # Press the up arrow key twice.
+ input_stream = 2 * Keys.UP_ARROW
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # We expect nothing to have happened.
+ exp_console_out = b""
+ exp_input_buffer = b""
+ exp_input_buffer_pos = 0
+ exp_history_buf = []
+
+ # Verify.
+ CheckConsoleOutput(self, exp_console_out)
+ CheckInputBufferPosition(self, exp_input_buffer_pos)
+ CheckInputBuffer(self, exp_input_buffer)
+ CheckHistoryBuffer(self, exp_history_buf)
+
+ def test_UpArrowDoesNotGoOutOfBounds(self):
+ """Verify that pressing the up arrow many times won't go out of bounds."""
+ # Enter one command.
+ test_str = b"help version"
+ input_stream = BytesToByteList(test_str)
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+ # Then press the up arrow key twice.
+ input_stream.extend(2 * Keys.UP_ARROW)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Verify that the history buffer is correct.
+ exp_history_buf = [test_str]
+ CheckHistoryBuffer(self, exp_history_buf)
+
+ # We expect that the console output should only contain our entered command,
+ # a new prompt, and then our command aggain.
+ exp_console_out = test_str + b"\r\n" + self.console.prompt
+ # Pressing up should reprint the command we entered.
+ exp_console_out += test_str
+
+ # Verify.
+ CheckConsoleOutput(self, exp_console_out)
+
+ def test_CycleDownThruCommandHistory(self):
+ """Verify that we can select entries by hitting the down arrow."""
+ # Enter at least 4 commands.
+ test_commands = [b"version", b"accelrange 0", b"battery", b"gettime"]
+ input_stream = []
+ for command in test_commands:
+ input_stream.extend(BytesToByteList(command))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Now, hit the UP arrow key twice to print the previous two entries.
+ for i in range(2):
+ input_stream.extend(Keys.UP_ARROW)
+
+ # Now, hit the DOWN arrow key twice to print the newer entries.
+ input_stream.extend(2 * Keys.DOWN_ARROW)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # The expected output should be commands that we entered, followed by
+ # prompts, then followed by our last two commands in reverse. Then, we
+ # should see the last entry in the list, followed by the saved partial cmd
+ # of a blank line.
+ exp_console_out = b""
+ for i in range(len(test_commands)):
+ exp_console_out += test_commands[i] + b"\r\n" + self.console.prompt
+
+ # When we press up, the line should be cleared and print the previous buffer
+ # entry.
+ for i in range(len(test_commands) - 1, 1, -1):
+ exp_console_out += test_commands[i]
+ # Backspace to the beginning.
+ for i in range(len(test_commands[i])):
+ exp_console_out += BACKSPACE_STRING
+
+ # When we press down, it should have cleared the last command (which we
+ # covered with the previous for loop), and then prints the next command.
+ exp_console_out += test_commands[3]
+ for i in range(len(test_commands[3])):
+ exp_console_out += BACKSPACE_STRING
+
+ # Verify console output.
+ CheckConsoleOutput(self, exp_console_out)
+
+ # Verify input buffer.
+ exp_input_buffer = b"" # Empty because our partial command was empty.
+ exp_input_buffer_pos = len(exp_input_buffer)
+ CheckInputBuffer(self, exp_input_buffer)
+ CheckInputBufferPosition(self, exp_input_buffer_pos)
+
+ def test_SavingPartialCommandWhenNavigatingHistory(self):
+ """Verify that partial commands are saved when navigating history."""
+ # Enter a command.
+ test_str = b"accelinfo"
+ input_stream = BytesToByteList(test_str)
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Enter a partial command.
+ partial_cmd = b"ver"
+ input_stream.extend(BytesToByteList(partial_cmd))
+
+ # Hit the UP arrow key.
+ input_stream.extend(Keys.UP_ARROW)
+ # Then, the DOWN arrow key.
+ input_stream.extend(Keys.DOWN_ARROW)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # The expected output should be the command we entered, a prompt, the
+ # partial command, clearing of the partial command, the command entered,
+ # clearing of the command entered, and then the partial command.
+ exp_console_out = test_str + b"\r\n" + self.console.prompt
+ exp_console_out += partial_cmd
+ for _ in range(len(partial_cmd)):
+ exp_console_out += BACKSPACE_STRING
+ exp_console_out += test_str
+ for _ in range(len(test_str)):
+ exp_console_out += BACKSPACE_STRING
+ exp_console_out += partial_cmd
+
+ # Verify console output.
+ CheckConsoleOutput(self, exp_console_out)
+
+ # Verify input buffer.
+ exp_input_buffer = partial_cmd
+ exp_input_buffer_pos = len(exp_input_buffer)
+ CheckInputBuffer(self, exp_input_buffer)
+ CheckInputBufferPosition(self, exp_input_buffer_pos)
+
+ def test_DownArrowOnEmptyHistory(self):
+ """Ensure nothing happens if the history is empty."""
+ # Then press the up down arrow twice.
+ input_stream = 2 * Keys.DOWN_ARROW
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # We expect nothing to have happened.
+ exp_console_out = b""
+ exp_input_buffer = b""
+ exp_input_buffer_pos = 0
+ exp_history_buf = []
+
+ # Verify.
+ CheckConsoleOutput(self, exp_console_out)
+ CheckInputBufferPosition(self, exp_input_buffer_pos)
+ CheckInputBuffer(self, exp_input_buffer)
+ CheckHistoryBuffer(self, exp_history_buf)
+
+ def test_DeleteCharsUsingDELKey(self):
+ """Verify that we can delete characters using the DEL key."""
+ test_str = b"version"
+ input_stream = BytesToByteList(test_str)
+
+ # Hit the left arrow key 2 times.
+ input_stream.extend(2 * Keys.LEFT_ARROW)
+
+ # Press the DEL key.
+ input_stream.extend(Keys.DEL)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # The expected output should be the command we entered, 2 individual cursor
+ # moves to the left, and then removing a char and shifting everything to the
+ # left one column.
+ exp_console_out = test_str
+ exp_console_out += 2 * OutputStream.MoveCursorLeft(1)
+
+ # Remove the char by shifting everything to the left one, slicing out the
+ # remove char.
+ exp_console_out += test_str[-1:] + b" "
+
+ # Reset the cursor by moving back 2 columns because of the 'n' and space.
+ exp_console_out += OutputStream.MoveCursorLeft(2)
+
+ # Verify console output.
+ CheckConsoleOutput(self, exp_console_out)
+
+ # Verify input buffer. The input buffer should have the char sliced out and
+ # be positioned where the char was removed.
+ exp_input_buffer = test_str[:-2] + test_str[-1:]
+ exp_input_buffer_pos = len(exp_input_buffer) - 1
+ CheckInputBuffer(self, exp_input_buffer)
+ CheckInputBufferPosition(self, exp_input_buffer_pos)
+
+ def test_RepeatedCommandInHistory(self):
+ """Verify that we don't store 2 consecutive identical commands in history"""
+ # Enter a few commands.
+ test_commands = [b"version", b"accelrange 0", b"battery", b"gettime"]
+ # Repeat the last command.
+ test_commands.append(test_commands[len(test_commands) - 1])
+
+ input_stream = []
+ for command in test_commands:
+ input_stream.extend(BytesToByteList(command))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Verify that the history buffer is correct. The last command, since
+ # it was repeated, should not have been added to the history.
+ exp_history_buf = test_commands[0 : len(test_commands) - 1]
+ CheckHistoryBuffer(self, exp_history_buf)
- Args:
- mock_check: A MagicMock object replacing the CheckForEnhancedECImage()
- method.
- """
- # Set the interrogation mode to always so that we actually interrogate.
- self.console.interrogation_mode = b'always'
-
- # First, assume that the EC interrogations indicate an enhanced EC image.
- mock_check.return_value = True
- # But our current knowledge of the EC image (which was actually the
- # 'previous' EC) was a non-enhanced image.
- self.console.enhanced_ec = False
-
- test_command = b'sysinfo'
- input_stream = []
- input_stream.extend(BytesToByteList(test_command))
-
- expected_calls = []
- # All keystrokes to the console should be directed straight through to the
- # EC until we press the enter key.
- for char in test_command:
- if six.PY3:
- expected_calls.append(mock.call(bytes([char])))
- else:
- expected_calls.append(mock.call(char))
-
- # Press the enter key.
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
- # The enter key should not be sent to the pipe since we should negotiate
- # to an enhanced EC image.
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # At this point, we should have negotiated to enhanced.
- self.assertTrue(self.console.enhanced_ec, msg=('Did not negotiate to '
- 'enhanced EC image.'))
-
- # The command would have been dropped however, so verify this...
- CheckInputBuffer(self, b'')
- CheckInputBufferPosition(self, 0)
- # ...and repeat the command.
- input_stream = BytesToByteList(test_command)
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Since we're enhanced now, we should have sent the entire command as one
- # string with no trailing carriage return
- expected_calls.append(mock.call(test_command))
-
- # Verify all of the calls.
- self.console.cmd_pipe.send.assert_has_calls(expected_calls)
-
- @mock.patch('ec3po.console.Console.CheckForEnhancedECImage')
- def test_TransitionFromEnhancedToNonEnhanced(self, mock_check):
- """Verify that we transition correctly to non-enhanced mode.
- Args:
- mock_check: A MagicMock object replacing the CheckForEnhancedECImage()
- method.
- """
- # Set the interrogation mode to always so that we actually interrogate.
- self.console.interrogation_mode = b'always'
-
- # First, assume that the EC interrogations indicate an non-enhanced EC
- # image.
- mock_check.return_value = False
- # But our current knowledge of the EC image (which was actually the
- # 'previous' EC) was an enhanced image.
- self.console.enhanced_ec = True
-
- test_command = b'sysinfo'
- input_stream = []
- input_stream.extend(BytesToByteList(test_command))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # But, we will negotiate to non-enhanced however, dropping this command.
- # Verify this.
- self.assertFalse(self.console.enhanced_ec, msg=('Did not negotiate to'
- 'non-enhanced EC image.'))
- CheckInputBuffer(self, b'')
- CheckInputBufferPosition(self, 0)
-
- # The carriage return should have passed through though.
- expected_calls = []
- expected_calls.append(mock.call(
- six.int2byte(console.ControlKey.CARRIAGE_RETURN)))
-
- # Since the command was dropped, repeat the command.
- input_stream = BytesToByteList(test_command)
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # Since we're not enhanced now, we should have sent each character in the
- # entire command separately and a carriage return.
- for char in test_command:
- if six.PY3:
- expected_calls.append(mock.call(bytes([char])))
- else:
- expected_calls.append(mock.call(char))
- expected_calls.append(mock.call(
- six.int2byte(console.ControlKey.CARRIAGE_RETURN)))
-
- # Verify all of the calls.
- self.console.cmd_pipe.send.assert_has_calls(expected_calls)
-
- def test_EnhancedCheckIfTimedOut(self):
- """Verify that the check returns false if it times out."""
- # Make the debug pipe "time out".
- self.console.dbg_pipe.poll.return_value = False
- self.assertFalse(self.console.CheckForEnhancedECImage())
-
- def test_EnhancedCheckIfACKReceived(self):
- """Verify that the check returns true if the ACK is received."""
- # Make the debug pipe return EC_ACK.
- self.console.dbg_pipe.poll.return_value = True
- self.console.dbg_pipe.recv.return_value = interpreter.EC_ACK
- self.assertTrue(self.console.CheckForEnhancedECImage())
-
- def test_EnhancedCheckIfWrong(self):
- """Verify that the check returns false if byte received is wrong."""
- # Make the debug pipe return the wrong byte.
- self.console.dbg_pipe.poll.return_value = True
- self.console.dbg_pipe.recv.return_value = b'\xff'
- self.assertFalse(self.console.CheckForEnhancedECImage())
-
- def test_EnhancedCheckUsingBuffer(self):
- """Verify that given reboot output, enhanced EC images are detected."""
- enhanced_output_stream = b"""
+class TestConsoleCompatibility(unittest.TestCase):
+ """Verify that console can speak to enhanced and non-enhanced EC images."""
+
+ def setUp(self):
+ """Setup the test harness."""
+ # Setup logging with a timestamp, the module, and the log level.
+ logging.basicConfig(
+ level=logging.DEBUG,
+ format=("%(asctime)s - %(module)s -" " %(levelname)s - %(message)s"),
+ )
+ # Create a temp file and set both the controller and peripheral PTYs to the
+ # file to create a loopback.
+ self.tempfile = tempfile.TemporaryFile()
+
+ # Mock out the pipes.
+ mock_pipe_end_0, mock_pipe_end_1 = mock.MagicMock(), mock.MagicMock()
+ self.console = console.Console(
+ self.tempfile.fileno(),
+ self.tempfile,
+ tempfile.TemporaryFile(),
+ mock_pipe_end_0,
+ mock_pipe_end_1,
+ "EC",
+ )
+
+ @mock.patch("ec3po.console.Console.CheckForEnhancedECImage")
+ def test_ActAsPassThruInNonEnhancedMode(self, mock_check):
+ """Verify we simply pass everything thru to non-enhanced ECs.
+
+ Args:
+ mock_check: A MagicMock object replacing the CheckForEnhancedECImage()
+ method.
+ """
+ # Set the interrogation mode to always so that we actually interrogate.
+ self.console.interrogation_mode = b"always"
+
+ # Assume EC interrogations indicate that the image is non-enhanced.
+ mock_check.return_value = False
+
+ # Press enter, followed by the command, and another enter.
+ input_stream = []
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+ test_command = b"version"
+ input_stream.extend(BytesToByteList(test_command))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Expected calls to send down the pipe would be each character of the test
+ # command.
+ expected_calls = []
+ expected_calls.append(
+ mock.call(six.int2byte(console.ControlKey.CARRIAGE_RETURN))
+ )
+ for char in test_command:
+ if six.PY3:
+ expected_calls.append(mock.call(bytes([char])))
+ else:
+ expected_calls.append(mock.call(char))
+ expected_calls.append(
+ mock.call(six.int2byte(console.ControlKey.CARRIAGE_RETURN))
+ )
+
+ # Verify that the calls happened.
+ self.console.cmd_pipe.send.assert_has_calls(expected_calls)
+
+ # Since we're acting as a pass-thru, the input buffer should be empty and
+ # input_buffer_pos is 0.
+ CheckInputBuffer(self, b"")
+ CheckInputBufferPosition(self, 0)
+
+ @mock.patch("ec3po.console.Console.CheckForEnhancedECImage")
+ def test_TransitionFromNonEnhancedToEnhanced(self, mock_check):
+ """Verify that we transition correctly to enhanced mode.
+
+ Args:
+ mock_check: A MagicMock object replacing the CheckForEnhancedECImage()
+ method.
+ """
+ # Set the interrogation mode to always so that we actually interrogate.
+ self.console.interrogation_mode = b"always"
+
+ # First, assume that the EC interrogations indicate an enhanced EC image.
+ mock_check.return_value = True
+ # But our current knowledge of the EC image (which was actually the
+ # 'previous' EC) was a non-enhanced image.
+ self.console.enhanced_ec = False
+
+ test_command = b"sysinfo"
+ input_stream = []
+ input_stream.extend(BytesToByteList(test_command))
+
+ expected_calls = []
+ # All keystrokes to the console should be directed straight through to the
+ # EC until we press the enter key.
+ for char in test_command:
+ if six.PY3:
+ expected_calls.append(mock.call(bytes([char])))
+ else:
+ expected_calls.append(mock.call(char))
+
+ # Press the enter key.
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+ # The enter key should not be sent to the pipe since we should negotiate
+ # to an enhanced EC image.
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # At this point, we should have negotiated to enhanced.
+ self.assertTrue(
+ self.console.enhanced_ec, msg=("Did not negotiate to " "enhanced EC image.")
+ )
+
+ # The command would have been dropped however, so verify this...
+ CheckInputBuffer(self, b"")
+ CheckInputBufferPosition(self, 0)
+ # ...and repeat the command.
+ input_stream = BytesToByteList(test_command)
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Since we're enhanced now, we should have sent the entire command as one
+ # string with no trailing carriage return
+ expected_calls.append(mock.call(test_command))
+
+ # Verify all of the calls.
+ self.console.cmd_pipe.send.assert_has_calls(expected_calls)
+
+ @mock.patch("ec3po.console.Console.CheckForEnhancedECImage")
+ def test_TransitionFromEnhancedToNonEnhanced(self, mock_check):
+ """Verify that we transition correctly to non-enhanced mode.
+
+ Args:
+ mock_check: A MagicMock object replacing the CheckForEnhancedECImage()
+ method.
+ """
+ # Set the interrogation mode to always so that we actually interrogate.
+ self.console.interrogation_mode = b"always"
+
+ # First, assume that the EC interrogations indicate an non-enhanced EC
+ # image.
+ mock_check.return_value = False
+ # But our current knowledge of the EC image (which was actually the
+ # 'previous' EC) was an enhanced image.
+ self.console.enhanced_ec = True
+
+ test_command = b"sysinfo"
+ input_stream = []
+ input_stream.extend(BytesToByteList(test_command))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # But, we will negotiate to non-enhanced however, dropping this command.
+ # Verify this.
+ self.assertFalse(
+ self.console.enhanced_ec,
+ msg=("Did not negotiate to" "non-enhanced EC image."),
+ )
+ CheckInputBuffer(self, b"")
+ CheckInputBufferPosition(self, 0)
+
+ # The carriage return should have passed through though.
+ expected_calls = []
+ expected_calls.append(
+ mock.call(six.int2byte(console.ControlKey.CARRIAGE_RETURN))
+ )
+
+ # Since the command was dropped, repeat the command.
+ input_stream = BytesToByteList(test_command)
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # Since we're not enhanced now, we should have sent each character in the
+ # entire command separately and a carriage return.
+ for char in test_command:
+ if six.PY3:
+ expected_calls.append(mock.call(bytes([char])))
+ else:
+ expected_calls.append(mock.call(char))
+ expected_calls.append(
+ mock.call(six.int2byte(console.ControlKey.CARRIAGE_RETURN))
+ )
+
+ # Verify all of the calls.
+ self.console.cmd_pipe.send.assert_has_calls(expected_calls)
+
+ def test_EnhancedCheckIfTimedOut(self):
+ """Verify that the check returns false if it times out."""
+ # Make the debug pipe "time out".
+ self.console.dbg_pipe.poll.return_value = False
+ self.assertFalse(self.console.CheckForEnhancedECImage())
+
+ def test_EnhancedCheckIfACKReceived(self):
+ """Verify that the check returns true if the ACK is received."""
+ # Make the debug pipe return EC_ACK.
+ self.console.dbg_pipe.poll.return_value = True
+ self.console.dbg_pipe.recv.return_value = interpreter.EC_ACK
+ self.assertTrue(self.console.CheckForEnhancedECImage())
+
+ def test_EnhancedCheckIfWrong(self):
+ """Verify that the check returns false if byte received is wrong."""
+ # Make the debug pipe return the wrong byte.
+ self.console.dbg_pipe.poll.return_value = True
+ self.console.dbg_pipe.recv.return_value = b"\xff"
+ self.assertFalse(self.console.CheckForEnhancedECImage())
+
+ def test_EnhancedCheckUsingBuffer(self):
+ """Verify that given reboot output, enhanced EC images are detected."""
+ enhanced_output_stream = b"""
--- UART initialized after reboot ---
[Reset cause: reset-pin soft]
[Image: RO, jerry_v1.1.4363-2af8572-dirty 2016-02-23 13:26:20 aaboagye@lithium.mtv.corp.google.com]
@@ -1295,19 +1343,19 @@ Enhanced Console is enabled (v1.0.0); type HELP for help.
[0.224060 hash done 41dac382e3a6e3d2ea5b4d789c1bc46525cae7cc5ff6758f0de8d8369b506f57]
[0.375150 POWER_GOOD seen]
"""
- for line in enhanced_output_stream.split(b'\n'):
- self.console.CheckBufferForEnhancedImage(line)
+ for line in enhanced_output_stream.split(b"\n"):
+ self.console.CheckBufferForEnhancedImage(line)
- # Since the enhanced console string was present in the output, the console
- # should have caught it.
- self.assertTrue(self.console.enhanced_ec)
+ # Since the enhanced console string was present in the output, the console
+ # should have caught it.
+ self.assertTrue(self.console.enhanced_ec)
- # Also should check that the command was sent to the interpreter.
- self.console.cmd_pipe.send.assert_called_once_with(b'enhanced True')
+ # Also should check that the command was sent to the interpreter.
+ self.console.cmd_pipe.send.assert_called_once_with(b"enhanced True")
- # Now test the non-enhanced EC image.
- self.console.cmd_pipe.reset_mock()
- non_enhanced_output_stream = b"""
+ # Now test the non-enhanced EC image.
+ self.console.cmd_pipe.reset_mock()
+ non_enhanced_output_stream = b"""
--- UART initialized after reboot ---
[Reset cause: reset-pin soft]
[Image: RO, jerry_v1.1.4363-2af8572-dirty 2016-02-23 13:03:15 aaboagye@lithium.mtv.corp.google.com]
@@ -1331,239 +1379,253 @@ Console is enabled; type HELP for help.
[0.010285 power on 2]
[0.010385 power state 5 = S5->S3, in 0x0000]
"""
- for line in non_enhanced_output_stream.split(b'\n'):
- self.console.CheckBufferForEnhancedImage(line)
+ for line in non_enhanced_output_stream.split(b"\n"):
+ self.console.CheckBufferForEnhancedImage(line)
- # Since the default console string is present in the output, it should be
- # determined to be non enhanced now.
- self.assertFalse(self.console.enhanced_ec)
+ # Since the default console string is present in the output, it should be
+ # determined to be non enhanced now.
+ self.assertFalse(self.console.enhanced_ec)
- # Check that command was also sent to the interpreter.
- self.console.cmd_pipe.send.assert_called_once_with(b'enhanced False')
+ # Check that command was also sent to the interpreter.
+ self.console.cmd_pipe.send.assert_called_once_with(b"enhanced False")
class TestOOBMConsoleCommands(unittest.TestCase):
- """Verify that OOBM console commands work correctly."""
- def setUp(self):
- """Setup the test harness."""
- # Setup logging with a timestamp, the module, and the log level.
- logging.basicConfig(level=logging.DEBUG,
- format=('%(asctime)s - %(module)s -'
- ' %(levelname)s - %(message)s'))
- # Create a temp file and set both the controller and peripheral PTYs to the
- # file to create a loopback.
- self.tempfile = tempfile.TemporaryFile()
-
- # Mock out the pipes.
- mock_pipe_end_0, mock_pipe_end_1 = mock.MagicMock(), mock.MagicMock()
- self.console = console.Console(self.tempfile.fileno(), self.tempfile,
- tempfile.TemporaryFile(),
- mock_pipe_end_0, mock_pipe_end_1, "EC")
- self.console.oobm_queue = mock.MagicMock()
-
- @mock.patch('ec3po.console.Console.CheckForEnhancedECImage')
- def test_InterrogateCommand(self, mock_check):
- """Verify that 'interrogate' command works as expected.
-
- Args:
- mock_check: A MagicMock object replacing the CheckForEnhancedECIMage()
- method.
- """
- input_stream = []
- expected_calls = []
- mock_check.side_effect = [False]
-
- # 'interrogate never' should disable the interrogation from happening at
- # all.
- cmd = b'interrogate never'
- # Enter the OOBM prompt.
- input_stream.extend(BytesToByteList(b'%'))
- # Type the command
- input_stream.extend(BytesToByteList(cmd))
- # Press enter.
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- input_stream = []
-
- # The OOBM queue should have been called with the command being put.
- expected_calls.append(mock.call.put(cmd))
- self.console.oobm_queue.assert_has_calls(expected_calls)
-
- # Process the OOBM queue.
- self.console.oobm_queue.get.side_effect = [cmd]
- self.console.ProcessOOBMQueue()
-
- # Type out a few commands.
- input_stream.extend(BytesToByteList(b'version'))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
- input_stream.extend(BytesToByteList(b'flashinfo'))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
- input_stream.extend(BytesToByteList(b'sysinfo'))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # The Check function should NOT have been called at all.
- mock_check.assert_not_called()
-
- # The EC image should be assumed to be not enhanced.
- self.assertFalse(self.console.enhanced_ec, 'The image should be assumed to'
- ' be NOT enhanced.')
-
- # Reset the mocks.
- mock_check.reset_mock()
- self.console.oobm_queue.reset_mock()
-
- # 'interrogate auto' should not interrogate at all. It should only be
- # scanning the output stream for the 'console is enabled' strings.
- cmd = b'interrogate auto'
- # Enter the OOBM prompt.
- input_stream.extend(BytesToByteList(b'%'))
- # Type the command
- input_stream.extend(BytesToByteList(cmd))
- # Press enter.
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- input_stream = []
- expected_calls = []
-
- # The OOBM queue should have been called with the command being put.
- expected_calls.append(mock.call.put(cmd))
- self.console.oobm_queue.assert_has_calls(expected_calls)
-
- # Process the OOBM queue.
- self.console.oobm_queue.get.side_effect = [cmd]
- self.console.ProcessOOBMQueue()
-
- # Type out a few commands.
- input_stream.extend(BytesToByteList(b'version'))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
- input_stream.extend(BytesToByteList(b'flashinfo'))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
- input_stream.extend(BytesToByteList(b'sysinfo'))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # The Check function should NOT have been called at all.
- mock_check.assert_not_called()
-
- # The EC image should be assumed to be not enhanced.
- self.assertFalse(self.console.enhanced_ec, 'The image should be assumed to'
- ' be NOT enhanced.')
-
- # Reset the mocks.
- mock_check.reset_mock()
- self.console.oobm_queue.reset_mock()
-
- # 'interrogate always' should, like its name implies, interrogate always
- # after each press of the enter key. This was the former way of doing
- # interrogation.
- cmd = b'interrogate always'
- # Enter the OOBM prompt.
- input_stream.extend(BytesToByteList(b'%'))
- # Type the command
- input_stream.extend(BytesToByteList(cmd))
- # Press enter.
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- input_stream = []
- expected_calls = []
-
- # The OOBM queue should have been called with the command being put.
- expected_calls.append(mock.call.put(cmd))
- self.console.oobm_queue.assert_has_calls(expected_calls)
-
- # Process the OOBM queue.
- self.console.oobm_queue.get.side_effect = [cmd]
- self.console.ProcessOOBMQueue()
-
- # The Check method should be called 3 times here.
- mock_check.side_effect = [False, False, False]
-
- # Type out a few commands.
- input_stream.extend(BytesToByteList(b'help list'))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
- input_stream.extend(BytesToByteList(b'taskinfo'))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
- input_stream.extend(BytesToByteList(b'hibdelay'))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # The Check method should have been called 3 times here.
- expected_calls = [mock.call(), mock.call(), mock.call()]
- mock_check.assert_has_calls(expected_calls)
-
- # The EC image should be assumed to be not enhanced.
- self.assertFalse(self.console.enhanced_ec, 'The image should be assumed to'
- ' be NOT enhanced.')
-
- # Now, let's try to assume that the image is enhanced while still disabling
- # interrogation.
- mock_check.reset_mock()
- self.console.oobm_queue.reset_mock()
- input_stream = []
- cmd = b'interrogate never enhanced'
- # Enter the OOBM prompt.
- input_stream.extend(BytesToByteList(b'%'))
- # Type the command
- input_stream.extend(BytesToByteList(cmd))
- # Press enter.
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- input_stream = []
- expected_calls = []
-
- # The OOBM queue should have been called with the command being put.
- expected_calls.append(mock.call.put(cmd))
- self.console.oobm_queue.assert_has_calls(expected_calls)
-
- # Process the OOBM queue.
- self.console.oobm_queue.get.side_effect = [cmd]
- self.console.ProcessOOBMQueue()
-
- # Type out a few commands.
- input_stream.extend(BytesToByteList(b'chgstate'))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
- input_stream.extend(BytesToByteList(b'hash'))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
- input_stream.extend(BytesToByteList(b'sysjump rw'))
- input_stream.append(console.ControlKey.CARRIAGE_RETURN)
-
- # Send the sequence out.
- for byte in input_stream:
- self.console.HandleChar(byte)
-
- # The check method should have never been called.
- mock_check.assert_not_called()
-
- # The EC image should be assumed to be enhanced.
- self.assertTrue(self.console.enhanced_ec, 'The image should be'
- ' assumed to be enhanced.')
-
-
-if __name__ == '__main__':
- unittest.main()
+ """Verify that OOBM console commands work correctly."""
+
+ def setUp(self):
+ """Setup the test harness."""
+ # Setup logging with a timestamp, the module, and the log level.
+ logging.basicConfig(
+ level=logging.DEBUG,
+ format=("%(asctime)s - %(module)s -" " %(levelname)s - %(message)s"),
+ )
+ # Create a temp file and set both the controller and peripheral PTYs to the
+ # file to create a loopback.
+ self.tempfile = tempfile.TemporaryFile()
+
+ # Mock out the pipes.
+ mock_pipe_end_0, mock_pipe_end_1 = mock.MagicMock(), mock.MagicMock()
+ self.console = console.Console(
+ self.tempfile.fileno(),
+ self.tempfile,
+ tempfile.TemporaryFile(),
+ mock_pipe_end_0,
+ mock_pipe_end_1,
+ "EC",
+ )
+ self.console.oobm_queue = mock.MagicMock()
+
+ @mock.patch("ec3po.console.Console.CheckForEnhancedECImage")
+ def test_InterrogateCommand(self, mock_check):
+ """Verify that 'interrogate' command works as expected.
+
+ Args:
+        mock_check: A MagicMock object replacing the CheckForEnhancedECImage()
+ method.
+ """
+ input_stream = []
+ expected_calls = []
+ mock_check.side_effect = [False]
+
+ # 'interrogate never' should disable the interrogation from happening at
+ # all.
+ cmd = b"interrogate never"
+ # Enter the OOBM prompt.
+ input_stream.extend(BytesToByteList(b"%"))
+ # Type the command
+ input_stream.extend(BytesToByteList(cmd))
+ # Press enter.
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ input_stream = []
+
+ # The OOBM queue should have been called with the command being put.
+ expected_calls.append(mock.call.put(cmd))
+ self.console.oobm_queue.assert_has_calls(expected_calls)
+
+ # Process the OOBM queue.
+ self.console.oobm_queue.get.side_effect = [cmd]
+ self.console.ProcessOOBMQueue()
+
+ # Type out a few commands.
+ input_stream.extend(BytesToByteList(b"version"))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+ input_stream.extend(BytesToByteList(b"flashinfo"))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+ input_stream.extend(BytesToByteList(b"sysinfo"))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # The Check function should NOT have been called at all.
+ mock_check.assert_not_called()
+
+ # The EC image should be assumed to be not enhanced.
+ self.assertFalse(
+ self.console.enhanced_ec,
+ "The image should be assumed to" " be NOT enhanced.",
+ )
+
+ # Reset the mocks.
+ mock_check.reset_mock()
+ self.console.oobm_queue.reset_mock()
+
+ # 'interrogate auto' should not interrogate at all. It should only be
+ # scanning the output stream for the 'console is enabled' strings.
+ cmd = b"interrogate auto"
+ # Enter the OOBM prompt.
+ input_stream.extend(BytesToByteList(b"%"))
+ # Type the command
+ input_stream.extend(BytesToByteList(cmd))
+ # Press enter.
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ input_stream = []
+ expected_calls = []
+
+ # The OOBM queue should have been called with the command being put.
+ expected_calls.append(mock.call.put(cmd))
+ self.console.oobm_queue.assert_has_calls(expected_calls)
+
+ # Process the OOBM queue.
+ self.console.oobm_queue.get.side_effect = [cmd]
+ self.console.ProcessOOBMQueue()
+
+ # Type out a few commands.
+ input_stream.extend(BytesToByteList(b"version"))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+ input_stream.extend(BytesToByteList(b"flashinfo"))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+ input_stream.extend(BytesToByteList(b"sysinfo"))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # The Check function should NOT have been called at all.
+ mock_check.assert_not_called()
+
+ # The EC image should be assumed to be not enhanced.
+ self.assertFalse(
+ self.console.enhanced_ec,
+ "The image should be assumed to" " be NOT enhanced.",
+ )
+
+ # Reset the mocks.
+ mock_check.reset_mock()
+ self.console.oobm_queue.reset_mock()
+
+ # 'interrogate always' should, like its name implies, interrogate always
+ # after each press of the enter key. This was the former way of doing
+ # interrogation.
+ cmd = b"interrogate always"
+ # Enter the OOBM prompt.
+ input_stream.extend(BytesToByteList(b"%"))
+ # Type the command
+ input_stream.extend(BytesToByteList(cmd))
+ # Press enter.
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ input_stream = []
+ expected_calls = []
+
+ # The OOBM queue should have been called with the command being put.
+ expected_calls.append(mock.call.put(cmd))
+ self.console.oobm_queue.assert_has_calls(expected_calls)
+
+ # Process the OOBM queue.
+ self.console.oobm_queue.get.side_effect = [cmd]
+ self.console.ProcessOOBMQueue()
+
+ # The Check method should be called 3 times here.
+ mock_check.side_effect = [False, False, False]
+
+ # Type out a few commands.
+ input_stream.extend(BytesToByteList(b"help list"))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+ input_stream.extend(BytesToByteList(b"taskinfo"))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+ input_stream.extend(BytesToByteList(b"hibdelay"))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # The Check method should have been called 3 times here.
+ expected_calls = [mock.call(), mock.call(), mock.call()]
+ mock_check.assert_has_calls(expected_calls)
+
+ # The EC image should be assumed to be not enhanced.
+ self.assertFalse(
+ self.console.enhanced_ec,
+ "The image should be assumed to" " be NOT enhanced.",
+ )
+
+ # Now, let's try to assume that the image is enhanced while still disabling
+ # interrogation.
+ mock_check.reset_mock()
+ self.console.oobm_queue.reset_mock()
+ input_stream = []
+ cmd = b"interrogate never enhanced"
+ # Enter the OOBM prompt.
+ input_stream.extend(BytesToByteList(b"%"))
+ # Type the command
+ input_stream.extend(BytesToByteList(cmd))
+ # Press enter.
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ input_stream = []
+ expected_calls = []
+
+ # The OOBM queue should have been called with the command being put.
+ expected_calls.append(mock.call.put(cmd))
+ self.console.oobm_queue.assert_has_calls(expected_calls)
+
+ # Process the OOBM queue.
+ self.console.oobm_queue.get.side_effect = [cmd]
+ self.console.ProcessOOBMQueue()
+
+ # Type out a few commands.
+ input_stream.extend(BytesToByteList(b"chgstate"))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+ input_stream.extend(BytesToByteList(b"hash"))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+ input_stream.extend(BytesToByteList(b"sysjump rw"))
+ input_stream.append(console.ControlKey.CARRIAGE_RETURN)
+
+ # Send the sequence out.
+ for byte in input_stream:
+ self.console.HandleChar(byte)
+
+ # The check method should have never been called.
+ mock_check.assert_not_called()
+
+ # The EC image should be assumed to be enhanced.
+ self.assertTrue(
+ self.console.enhanced_ec, "The image should be" " assumed to be enhanced."
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/util/ec3po/interpreter.py b/util/ec3po/interpreter.py
index 4e151083bd..591603e038 100644
--- a/util/ec3po/interpreter.py
+++ b/util/ec3po/interpreter.py
@@ -25,443 +25,447 @@ import traceback
import six
-
COMMAND_RETRIES = 3 # Number of attempts to retry a command.
EC_MAX_READ = 1024 # Max bytes to read at a time from the EC.
-EC_SYN = b'\xec' # Byte indicating EC interrogation.
-EC_ACK = b'\xc0' # Byte representing correct EC response to interrogation.
+EC_SYN = b"\xec" # Byte indicating EC interrogation.
+EC_ACK = b"\xc0" # Byte representing correct EC response to interrogation.
class LoggerAdapter(logging.LoggerAdapter):
- """Class which provides a small adapter for the logger."""
+ """Class which provides a small adapter for the logger."""
- def process(self, msg, kwargs):
- """Prepends the served PTY to the beginning of the log message."""
- return '%s - %s' % (self.extra['pty'], msg), kwargs
+ def process(self, msg, kwargs):
+ """Prepends the served PTY to the beginning of the log message."""
+ return "%s - %s" % (self.extra["pty"], msg), kwargs
class Interpreter(object):
- """Class which provides the interpretation layer between the EC and user.
-
- This class essentially performs all of the intepretation for the EC and the
- user. It handles all of the automatic command retrying as well as the
- formation of commands for EC images which support that.
-
- Attributes:
- logger: A logger for this module.
- ec_uart_pty: An opened file object to the raw EC UART PTY.
- ec_uart_pty_name: A string containing the name of the raw EC UART PTY.
- cmd_pipe: A socket.socket or multiprocessing.Connection object which
- represents the Interpreter side of the command pipe. This must be a
- bidirectional pipe. Commands and responses will utilize this pipe.
- dbg_pipe: A socket.socket or multiprocessing.Connection object which
- represents the Interpreter side of the debug pipe. This must be a
- unidirectional pipe with write capabilities. EC debug output will utilize
- this pipe.
- cmd_retries: An integer representing the number of attempts the console
- should retry commands if it receives an error.
- log_level: An integer representing the numeric value of the log level.
- inputs: A list of objects that the intpreter selects for reading.
- Initially, these are the EC UART and the command pipe.
- outputs: A list of objects that the interpreter selects for writing.
- ec_cmd_queue: A FIFO queue used for sending commands down to the EC UART.
- last_cmd: A string that represents the last command sent to the EC. If an
- error is encountered, the interpreter will attempt to retry this command
- up to COMMAND_RETRIES.
- enhanced_ec: A boolean indicating if the EC image that we are currently
- communicating with is enhanced or not. Enhanced EC images will support
- packed commands and host commands over the UART. This defaults to False
- and is changed depending on the result of an interrogation.
- interrogating: A boolean indicating if we are in the middle of interrogating
- the EC.
- connected: A boolean indicating if the interpreter is actually connected to
- the UART and listening.
- """
- def __init__(self, ec_uart_pty, cmd_pipe, dbg_pipe, log_level=logging.INFO,
- name=None):
- """Intializes an Interpreter object with the provided args.
+ """Class which provides the interpretation layer between the EC and user.
- Args:
- ec_uart_pty: A string representing the EC UART to connect to.
+    This class essentially performs all of the interpretation for the EC and the
+ user. It handles all of the automatic command retrying as well as the
+ formation of commands for EC images which support that.
+
+ Attributes:
+ logger: A logger for this module.
+ ec_uart_pty: An opened file object to the raw EC UART PTY.
+ ec_uart_pty_name: A string containing the name of the raw EC UART PTY.
cmd_pipe: A socket.socket or multiprocessing.Connection object which
represents the Interpreter side of the command pipe. This must be a
bidirectional pipe. Commands and responses will utilize this pipe.
dbg_pipe: A socket.socket or multiprocessing.Connection object which
represents the Interpreter side of the debug pipe. This must be a
- unidirectional pipe with write capabilities. EC debug output will
- utilize this pipe.
+ unidirectional pipe with write capabilities. EC debug output will utilize
+ this pipe.
cmd_retries: An integer representing the number of attempts the console
should retry commands if it receives an error.
- log_level: An optional integer representing the numeric value of the log
- level. By default, the log level will be logging.INFO (20).
- name: the console source name
- """
- # Create a unique logger based on the interpreter name
- interpreter_prefix = ('%s - ' % name) if name else ''
- logger = logging.getLogger('%sEC3PO.Interpreter' % interpreter_prefix)
- self.logger = LoggerAdapter(logger, {'pty': ec_uart_pty})
- # TODO(https://crbug.com/1162189): revist the 2 TODOs below
- # TODO(https://bugs.python.org/issue27805, python3.7+): revert to ab+
- # TODO(https://bugs.python.org/issue20074): removing buffering=0 if/when
- # that gets fixed, or keep two pty: one for reading and one for writing
- self.ec_uart_pty = open(ec_uart_pty, 'r+b', buffering=0)
- self.ec_uart_pty_name = ec_uart_pty
- self.cmd_pipe = cmd_pipe
- self.dbg_pipe = dbg_pipe
- self.cmd_retries = COMMAND_RETRIES
- self.log_level = log_level
- self.inputs = [self.ec_uart_pty, self.cmd_pipe]
- self.outputs = []
- self.ec_cmd_queue = six.moves.queue.Queue()
- self.last_cmd = b''
- self.enhanced_ec = False
- self.interrogating = False
- self.connected = True
-
- def __str__(self):
- """Show internal state of the Interpreter object.
-
- Returns:
- A string that shows the values of the attributes.
- """
- string = []
- string.append('%r' % self)
- string.append('ec_uart_pty: %s' % self.ec_uart_pty)
- string.append('cmd_pipe: %r' % self.cmd_pipe)
- string.append('dbg_pipe: %r' % self.dbg_pipe)
- string.append('cmd_retries: %d' % self.cmd_retries)
- string.append('log_level: %d' % self.log_level)
- string.append('inputs: %r' % self.inputs)
- string.append('outputs: %r' % self.outputs)
- string.append('ec_cmd_queue: %r' % self.ec_cmd_queue)
- string.append('last_cmd: \'%s\'' % self.last_cmd)
- string.append('enhanced_ec: %r' % self.enhanced_ec)
- string.append('interrogating: %r' % self.interrogating)
- return '\n'.join(string)
-
- def EnqueueCmd(self, command):
- """Enqueue a command to be sent to the EC UART.
-
- Args:
- command: A string which contains the command to be sent.
+ log_level: An integer representing the numeric value of the log level.
+      inputs: A list of objects that the interpreter selects for reading.
+ Initially, these are the EC UART and the command pipe.
+ outputs: A list of objects that the interpreter selects for writing.
+ ec_cmd_queue: A FIFO queue used for sending commands down to the EC UART.
+ last_cmd: A string that represents the last command sent to the EC. If an
+ error is encountered, the interpreter will attempt to retry this command
+ up to COMMAND_RETRIES.
+ enhanced_ec: A boolean indicating if the EC image that we are currently
+ communicating with is enhanced or not. Enhanced EC images will support
+ packed commands and host commands over the UART. This defaults to False
+ and is changed depending on the result of an interrogation.
+ interrogating: A boolean indicating if we are in the middle of interrogating
+ the EC.
+ connected: A boolean indicating if the interpreter is actually connected to
+ the UART and listening.
"""
- self.ec_cmd_queue.put(command)
- self.logger.log(1, 'Commands now in queue: %d', self.ec_cmd_queue.qsize())
- # Add the EC UART as an output to be serviced.
- if self.connected and self.ec_uart_pty not in self.outputs:
- self.outputs.append(self.ec_uart_pty)
-
- def PackCommand(self, raw_cmd):
- r"""Packs a command for use with error checking.
-
- For error checking, we pack console commands in a particular format. The
- format is as follows:
-
- &&[x][x][x][x]&{cmd}\n\n
- ^ ^ ^^ ^^ ^ ^-- 2 newlines.
- | | || || |-- the raw console command.
- | | || ||-- 1 ampersand.
- | | ||____|--- 2 hex digits representing the CRC8 of cmd.
- | |____|-- 2 hex digits reprsenting the length of cmd.
- |-- 2 ampersands
-
- Args:
- raw_cmd: A pre-packed string which contains the raw command.
-
- Returns:
- A string which contains the packed command.
- """
- # Don't pack a single carriage return.
- if raw_cmd != b'\r':
- # The command format is as follows.
- # &&[x][x][x][x]&{cmd}\n\n
- packed_cmd = []
- packed_cmd.append(b'&&')
- # The first pair of hex digits are the length of the command.
- packed_cmd.append(b'%02x' % len(raw_cmd))
- # Then the CRC8 of cmd.
- packed_cmd.append(b'%02x' % Crc8(raw_cmd))
- packed_cmd.append(b'&')
- # Now, the raw command followed by 2 newlines.
- packed_cmd.append(raw_cmd)
- packed_cmd.append(b'\n\n')
- return b''.join(packed_cmd)
- else:
- return raw_cmd
-
- def ProcessCommand(self, command):
- """Captures the input determines what actions to take.
-
- Args:
- command: A string representing the command sent by the user.
- """
- if command == b'disconnect':
- if self.connected:
- self.logger.debug('UART disconnect request.')
- # Drop all pending commands if any.
- while not self.ec_cmd_queue.empty():
- c = self.ec_cmd_queue.get()
- self.logger.debug('dropped: \'%s\'', c)
- if self.enhanced_ec:
- # Reset retry state.
- self.cmd_retries = COMMAND_RETRIES
- self.last_cmd = b''
- # Get the UART that the interpreter is attached to.
- fileobj = self.ec_uart_pty
- self.logger.debug('fileobj: %r', fileobj)
- # Remove the descriptor from the inputs and outputs.
- self.inputs.remove(fileobj)
- if fileobj in self.outputs:
- self.outputs.remove(fileobj)
- self.logger.debug('Removed fileobj. Remaining inputs: %r', self.inputs)
- # Close the file.
- fileobj.close()
- # Mark the interpreter as disconnected now.
- self.connected = False
- self.logger.debug('Disconnected from %s.', self.ec_uart_pty_name)
- return
-
- elif command == b'reconnect':
- if not self.connected:
- self.logger.debug('UART reconnect request.')
- # Reopen the PTY.
+ def __init__(
+ self, ec_uart_pty, cmd_pipe, dbg_pipe, log_level=logging.INFO, name=None
+ ):
+        """Initializes an Interpreter object with the provided args.
+
+ Args:
+ ec_uart_pty: A string representing the EC UART to connect to.
+ cmd_pipe: A socket.socket or multiprocessing.Connection object which
+ represents the Interpreter side of the command pipe. This must be a
+ bidirectional pipe. Commands and responses will utilize this pipe.
+ dbg_pipe: A socket.socket or multiprocessing.Connection object which
+ represents the Interpreter side of the debug pipe. This must be a
+ unidirectional pipe with write capabilities. EC debug output will
+ utilize this pipe.
+ cmd_retries: An integer representing the number of attempts the console
+ should retry commands if it receives an error.
+ log_level: An optional integer representing the numeric value of the log
+ level. By default, the log level will be logging.INFO (20).
+ name: the console source name
+ """
+ # Create a unique logger based on the interpreter name
+ interpreter_prefix = ("%s - " % name) if name else ""
+ logger = logging.getLogger("%sEC3PO.Interpreter" % interpreter_prefix)
+ self.logger = LoggerAdapter(logger, {"pty": ec_uart_pty})
+        # TODO(https://crbug.com/1162189): revisit the 2 TODOs below
# TODO(https://bugs.python.org/issue27805, python3.7+): revert to ab+
# TODO(https://bugs.python.org/issue20074): removing buffering=0 if/when
# that gets fixed, or keep two pty: one for reading and one for writing
- fileobj = open(self.ec_uart_pty_name, 'r+b', buffering=0)
- self.logger.debug('fileobj: %r', fileobj)
- self.ec_uart_pty = fileobj
- # Add the descriptor to the inputs.
- self.inputs.append(fileobj)
- self.logger.debug('fileobj added. curr inputs: %r', self.inputs)
- # Mark the interpreter as connected now.
- self.connected = True
- self.logger.debug('Connected to %s.', self.ec_uart_pty_name)
- return
-
- elif command.startswith(b'enhanced'):
- self.enhanced_ec = command.split(b' ')[1] == b'True'
- return
-
- # Ignore any other commands while in the disconnected state.
- self.logger.log(1, 'command: \'%s\'', command)
- if not self.connected:
- self.logger.debug('Ignoring command because currently disconnected.')
- return
-
- # Remove leading and trailing spaces only if this is an enhanced EC image.
- # For non-enhanced EC images, commands will be single characters at a time
- # and can be spaces.
- if self.enhanced_ec:
- command = command.strip(b' ')
-
- # There's nothing to do if the command is empty.
- if len(command) == 0:
- return
-
- # Handle log level change requests.
- if command.startswith(b'loglevel'):
- self.logger.debug('Log level change request.')
- new_log_level = int(command.split(b' ')[1])
- self.logger.logger.setLevel(new_log_level)
- self.logger.info('Log level changed to %d.', new_log_level)
- return
-
- # Check for interrogation command.
- if command == EC_SYN:
- # User is requesting interrogation. Send SYN as is.
- self.logger.debug('User requesting interrogation.')
- self.interrogating = True
- # Assume the EC isn't enhanced until we get a response.
- self.enhanced_ec = False
- elif self.enhanced_ec:
- # Enhanced EC images require the plaintext commands to be packed.
- command = self.PackCommand(command)
- # TODO(aaboagye): Make a dict of commands and keys and eventually,
- # handle partial matching based on unique prefixes.
-
- self.EnqueueCmd(command)
-
- def HandleCmdRetries(self):
- """Attempts to retry commands if possible."""
- if self.cmd_retries > 0:
- # The EC encountered an error. We'll have to retry again.
- self.logger.warning('Retrying command...')
- self.cmd_retries -= 1
- self.logger.warning('Retries remaining: %d', self.cmd_retries)
- # Retry the command and add the EC UART to the writers again.
- self.EnqueueCmd(self.last_cmd)
- self.outputs.append(self.ec_uart_pty)
- else:
- # We're out of retries, so just give up.
- self.logger.error('Command failed. No retries left.')
- # Clear the command in progress.
- self.last_cmd = b''
- # Reset the retry count.
- self.cmd_retries = COMMAND_RETRIES
-
- def SendCmdToEC(self):
- """Sends a command to the EC."""
- # If we're retrying a command, just try to send it again.
- if self.cmd_retries < COMMAND_RETRIES:
- cmd = self.last_cmd
- else:
- # If we're not retrying, we should not be writing to the EC if we have no
- # items in our command queue.
- assert not self.ec_cmd_queue.empty()
- # Get the command to send.
- cmd = self.ec_cmd_queue.get()
-
- # Send the command.
- self.ec_uart_pty.write(cmd)
- self.ec_uart_pty.flush()
- self.logger.log(1, 'Sent command to EC.')
-
- if self.enhanced_ec and cmd != EC_SYN:
- # Now, that we've sent the command, store the current command as the last
- # command sent. If we encounter an error string, we will attempt to retry
- # this command.
- if cmd != self.last_cmd:
- self.last_cmd = cmd
- # Reset the retry count.
+ self.ec_uart_pty = open(ec_uart_pty, "r+b", buffering=0)
+ self.ec_uart_pty_name = ec_uart_pty
+ self.cmd_pipe = cmd_pipe
+ self.dbg_pipe = dbg_pipe
self.cmd_retries = COMMAND_RETRIES
+ self.log_level = log_level
+ self.inputs = [self.ec_uart_pty, self.cmd_pipe]
+ self.outputs = []
+ self.ec_cmd_queue = six.moves.queue.Queue()
+ self.last_cmd = b""
+ self.enhanced_ec = False
+ self.interrogating = False
+ self.connected = True
- # If no command is pending to be sent, then we can remove the EC UART from
- # writers. Might need better checking for command retry logic in here.
- if self.ec_cmd_queue.empty():
- # Remove the EC UART from the writers while we wait for a response.
- self.logger.debug('Removing EC UART from writers.')
- self.outputs.remove(self.ec_uart_pty)
-
- def HandleECData(self):
- """Handle any debug prints from the EC."""
- self.logger.log(1, 'EC has data')
- # Read what the EC sent us.
- data = os.read(self.ec_uart_pty.fileno(), EC_MAX_READ)
- self.logger.log(1, 'got: \'%s\'', binascii.hexlify(data))
- if b'&E' in data and self.enhanced_ec:
- # We received an error, so we should retry it if possible.
- self.logger.warning('Error string found in data.')
- self.HandleCmdRetries()
- return
-
- # If we were interrogating, check the response and update our knowledge
- # of the current EC image.
- if self.interrogating:
- self.enhanced_ec = data == EC_ACK
- if self.enhanced_ec:
- self.logger.debug('The current EC image seems enhanced.')
- else:
- self.logger.debug('The current EC image does NOT seem enhanced.')
- # Done interrogating.
- self.interrogating = False
- # For now, just forward everything the EC sends us.
- self.logger.log(1, 'Forwarding to user...')
- self.dbg_pipe.send(data)
-
- def HandleUserData(self):
- """Handle any incoming commands from the user.
-
- Raises:
- EOFError: Allowed to propagate through from self.cmd_pipe.recv().
- """
- self.logger.log(1, 'Command data available. Begin processing.')
- data = self.cmd_pipe.recv()
- # Process the command.
- self.ProcessCommand(data)
+ def __str__(self):
+ """Show internal state of the Interpreter object.
+
+ Returns:
+ A string that shows the values of the attributes.
+ """
+ string = []
+ string.append("%r" % self)
+ string.append("ec_uart_pty: %s" % self.ec_uart_pty)
+ string.append("cmd_pipe: %r" % self.cmd_pipe)
+ string.append("dbg_pipe: %r" % self.dbg_pipe)
+ string.append("cmd_retries: %d" % self.cmd_retries)
+ string.append("log_level: %d" % self.log_level)
+ string.append("inputs: %r" % self.inputs)
+ string.append("outputs: %r" % self.outputs)
+ string.append("ec_cmd_queue: %r" % self.ec_cmd_queue)
+ string.append("last_cmd: '%s'" % self.last_cmd)
+ string.append("enhanced_ec: %r" % self.enhanced_ec)
+ string.append("interrogating: %r" % self.interrogating)
+ return "\n".join(string)
+
+ def EnqueueCmd(self, command):
+ """Enqueue a command to be sent to the EC UART.
+
+ Args:
+ command: A string which contains the command to be sent.
+ """
+ self.ec_cmd_queue.put(command)
+ self.logger.log(1, "Commands now in queue: %d", self.ec_cmd_queue.qsize())
+
+ # Add the EC UART as an output to be serviced.
+ if self.connected and self.ec_uart_pty not in self.outputs:
+ self.outputs.append(self.ec_uart_pty)
+
+ def PackCommand(self, raw_cmd):
+ r"""Packs a command for use with error checking.
+
+ For error checking, we pack console commands in a particular format. The
+ format is as follows:
+
+ &&[x][x][x][x]&{cmd}\n\n
+ ^ ^ ^^ ^^ ^ ^-- 2 newlines.
+ | | || || |-- the raw console command.
+ | | || ||-- 1 ampersand.
+ | | ||____|--- 2 hex digits representing the CRC8 of cmd.
+ | |____|-- 2 hex digits representing the length of cmd.
+ |-- 2 ampersands
+
+ Args:
+ raw_cmd: A pre-packed string which contains the raw command.
+
+ Returns:
+ A string which contains the packed command.
+ """
+ # Don't pack a single carriage return.
+ if raw_cmd != b"\r":
+ # The command format is as follows.
+ # &&[x][x][x][x]&{cmd}\n\n
+ packed_cmd = []
+ packed_cmd.append(b"&&")
+ # The first pair of hex digits are the length of the command.
+ packed_cmd.append(b"%02x" % len(raw_cmd))
+ # Then the CRC8 of cmd.
+ packed_cmd.append(b"%02x" % Crc8(raw_cmd))
+ packed_cmd.append(b"&")
+ # Now, the raw command followed by 2 newlines.
+ packed_cmd.append(raw_cmd)
+ packed_cmd.append(b"\n\n")
+ return b"".join(packed_cmd)
+ else:
+ return raw_cmd
+
+ def ProcessCommand(self, command):
+ """Captures the input and determines what actions to take.
+
+ Args:
+ command: A string representing the command sent by the user.
+ """
+ if command == b"disconnect":
+ if self.connected:
+ self.logger.debug("UART disconnect request.")
+ # Drop all pending commands if any.
+ while not self.ec_cmd_queue.empty():
+ c = self.ec_cmd_queue.get()
+ self.logger.debug("dropped: '%s'", c)
+ if self.enhanced_ec:
+ # Reset retry state.
+ self.cmd_retries = COMMAND_RETRIES
+ self.last_cmd = b""
+ # Get the UART that the interpreter is attached to.
+ fileobj = self.ec_uart_pty
+ self.logger.debug("fileobj: %r", fileobj)
+ # Remove the descriptor from the inputs and outputs.
+ self.inputs.remove(fileobj)
+ if fileobj in self.outputs:
+ self.outputs.remove(fileobj)
+ self.logger.debug("Removed fileobj. Remaining inputs: %r", self.inputs)
+ # Close the file.
+ fileobj.close()
+ # Mark the interpreter as disconnected now.
+ self.connected = False
+ self.logger.debug("Disconnected from %s.", self.ec_uart_pty_name)
+ return
+
+ elif command == b"reconnect":
+ if not self.connected:
+ self.logger.debug("UART reconnect request.")
+ # Reopen the PTY.
+ # TODO(https://bugs.python.org/issue27805, python3.7+): revert to ab+
+ # TODO(https://bugs.python.org/issue20074): removing buffering=0 if/when
+ # that gets fixed, or keep two pty: one for reading and one for writing
+ fileobj = open(self.ec_uart_pty_name, "r+b", buffering=0)
+ self.logger.debug("fileobj: %r", fileobj)
+ self.ec_uart_pty = fileobj
+ # Add the descriptor to the inputs.
+ self.inputs.append(fileobj)
+ self.logger.debug("fileobj added. curr inputs: %r", self.inputs)
+ # Mark the interpreter as connected now.
+ self.connected = True
+ self.logger.debug("Connected to %s.", self.ec_uart_pty_name)
+ return
+
+ elif command.startswith(b"enhanced"):
+ self.enhanced_ec = command.split(b" ")[1] == b"True"
+ return
+
+ # Ignore any other commands while in the disconnected state.
+ self.logger.log(1, "command: '%s'", command)
+ if not self.connected:
+ self.logger.debug("Ignoring command because currently disconnected.")
+ return
+
+ # Remove leading and trailing spaces only if this is an enhanced EC image.
+ # For non-enhanced EC images, commands will be single characters at a time
+ # and can be spaces.
+ if self.enhanced_ec:
+ command = command.strip(b" ")
+
+ # There's nothing to do if the command is empty.
+ if len(command) == 0:
+ return
+
+ # Handle log level change requests.
+ if command.startswith(b"loglevel"):
+ self.logger.debug("Log level change request.")
+ new_log_level = int(command.split(b" ")[1])
+ self.logger.logger.setLevel(new_log_level)
+ self.logger.info("Log level changed to %d.", new_log_level)
+ return
+
+ # Check for interrogation command.
+ if command == EC_SYN:
+ # User is requesting interrogation. Send SYN as is.
+ self.logger.debug("User requesting interrogation.")
+ self.interrogating = True
+ # Assume the EC isn't enhanced until we get a response.
+ self.enhanced_ec = False
+ elif self.enhanced_ec:
+ # Enhanced EC images require the plaintext commands to be packed.
+ command = self.PackCommand(command)
+ # TODO(aaboagye): Make a dict of commands and keys and eventually,
+ # handle partial matching based on unique prefixes.
+
+ self.EnqueueCmd(command)
+
+ def HandleCmdRetries(self):
+ """Attempts to retry commands if possible."""
+ if self.cmd_retries > 0:
+ # The EC encountered an error. We'll have to retry again.
+ self.logger.warning("Retrying command...")
+ self.cmd_retries -= 1
+ self.logger.warning("Retries remaining: %d", self.cmd_retries)
+ # Retry the command and add the EC UART to the writers again.
+ self.EnqueueCmd(self.last_cmd)
+ self.outputs.append(self.ec_uart_pty)
+ else:
+ # We're out of retries, so just give up.
+ self.logger.error("Command failed. No retries left.")
+ # Clear the command in progress.
+ self.last_cmd = b""
+ # Reset the retry count.
+ self.cmd_retries = COMMAND_RETRIES
+
+ def SendCmdToEC(self):
+ """Sends a command to the EC."""
+ # If we're retrying a command, just try to send it again.
+ if self.cmd_retries < COMMAND_RETRIES:
+ cmd = self.last_cmd
+ else:
+ # If we're not retrying, we should not be writing to the EC if we have no
+ # items in our command queue.
+ assert not self.ec_cmd_queue.empty()
+ # Get the command to send.
+ cmd = self.ec_cmd_queue.get()
+
+ # Send the command.
+ self.ec_uart_pty.write(cmd)
+ self.ec_uart_pty.flush()
+ self.logger.log(1, "Sent command to EC.")
+
+ if self.enhanced_ec and cmd != EC_SYN:
+ # Now, that we've sent the command, store the current command as the last
+ # command sent. If we encounter an error string, we will attempt to retry
+ # this command.
+ if cmd != self.last_cmd:
+ self.last_cmd = cmd
+ # Reset the retry count.
+ self.cmd_retries = COMMAND_RETRIES
+
+ # If no command is pending to be sent, then we can remove the EC UART from
+ # writers. Might need better checking for command retry logic in here.
+ if self.ec_cmd_queue.empty():
+ # Remove the EC UART from the writers while we wait for a response.
+ self.logger.debug("Removing EC UART from writers.")
+ self.outputs.remove(self.ec_uart_pty)
+
+ def HandleECData(self):
+ """Handle any debug prints from the EC."""
+ self.logger.log(1, "EC has data")
+ # Read what the EC sent us.
+ data = os.read(self.ec_uart_pty.fileno(), EC_MAX_READ)
+ self.logger.log(1, "got: '%s'", binascii.hexlify(data))
+ if b"&E" in data and self.enhanced_ec:
+ # We received an error, so we should retry it if possible.
+ self.logger.warning("Error string found in data.")
+ self.HandleCmdRetries()
+ return
+
+ # If we were interrogating, check the response and update our knowledge
+ # of the current EC image.
+ if self.interrogating:
+ self.enhanced_ec = data == EC_ACK
+ if self.enhanced_ec:
+ self.logger.debug("The current EC image seems enhanced.")
+ else:
+ self.logger.debug("The current EC image does NOT seem enhanced.")
+ # Done interrogating.
+ self.interrogating = False
+ # For now, just forward everything the EC sends us.
+ self.logger.log(1, "Forwarding to user...")
+ self.dbg_pipe.send(data)
+
+ def HandleUserData(self):
+ """Handle any incoming commands from the user.
+
+ Raises:
+ EOFError: Allowed to propagate through from self.cmd_pipe.recv().
+ """
+ self.logger.log(1, "Command data available. Begin processing.")
+ data = self.cmd_pipe.recv()
+ # Process the command.
+ self.ProcessCommand(data)
def Crc8(data):
- """Calculates the CRC8 of data.
+ """Calculates the CRC8 of data.
- The generator polynomial used is: x^8 + x^2 + x + 1.
- This is the same implementation that is used in the EC.
+ The generator polynomial used is: x^8 + x^2 + x + 1.
+ This is the same implementation that is used in the EC.
- Args:
- data: A string of data that we wish to calculate the CRC8 on.
+ Args:
+ data: A string of data that we wish to calculate the CRC8 on.
- Returns:
- crc >> 8: An integer representing the CRC8 value.
- """
- crc = 0
- for byte in six.iterbytes(data):
- crc ^= (byte << 8)
- for _ in range(8):
- if crc & 0x8000:
- crc ^= (0x1070 << 3)
- crc <<= 1
- return crc >> 8
+ Returns:
+ crc >> 8: An integer representing the CRC8 value.
+ """
+ crc = 0
+ for byte in six.iterbytes(data):
+ crc ^= byte << 8
+ for _ in range(8):
+ if crc & 0x8000:
+ crc ^= 0x1070 << 3
+ crc <<= 1
+ return crc >> 8
def StartLoop(interp, shutdown_pipe=None):
- """Starts an infinite loop of servicing the user and the EC.
-
- StartLoop checks to see if there are any commands to process, processing them
- if any, and forwards EC output to the user.
-
- When sending a command to the EC, we send the command once and check the
- response to see if the EC encountered an error when receiving the command. An
- error condition is reported to the interpreter by a string with at least one
- '&' and 'E'. The full string is actually '&&EE', however it's possible that
- the leading ampersand or trailing 'E' could be dropped. If an error is
- encountered, the interpreter will retry up to the amount configured.
-
- Args:
- interp: An Interpreter object that has been properly initialised.
- shutdown_pipe: A file object for a pipe or equivalent that becomes readable
- (not blocked) to indicate that the loop should exit. Can be None to never
- exit the loop.
- """
- try:
- # This is used instead of "break" to avoid exiting the loop in the middle of
- # an iteration.
- continue_looping = True
-
- while continue_looping:
- # The inputs list is created anew in each loop iteration because the
- # Interpreter class sometimes modifies the interp.inputs list.
- if shutdown_pipe is None:
- inputs = interp.inputs
- else:
- inputs = list(interp.inputs)
- inputs.append(shutdown_pipe)
-
- readable, writeable, _ = select.select(inputs, interp.outputs, [])
-
- for obj in readable:
- # Handle any debug prints from the EC.
- if obj is interp.ec_uart_pty:
- interp.HandleECData()
-
- # Handle any commands from the user.
- elif obj is interp.cmd_pipe:
- try:
- interp.HandleUserData()
- except EOFError:
- interp.logger.debug(
- 'ec3po interpreter received EOF from cmd_pipe in '
- 'HandleUserData()')
- continue_looping = False
-
- elif obj is shutdown_pipe:
- interp.logger.debug(
- 'ec3po interpreter received shutdown pipe unblocked notification')
- continue_looping = False
-
- for obj in writeable:
- # Send a command to the EC.
- if obj is interp.ec_uart_pty:
- interp.SendCmdToEC()
-
- except KeyboardInterrupt:
- pass
-
- finally:
- interp.cmd_pipe.close()
- interp.dbg_pipe.close()
- interp.ec_uart_pty.close()
- if shutdown_pipe is not None:
- shutdown_pipe.close()
- interp.logger.debug('Exit ec3po interpreter loop for %s',
- interp.ec_uart_pty_name)
+ """Starts an infinite loop of servicing the user and the EC.
+
+ StartLoop checks to see if there are any commands to process, processing them
+ if any, and forwards EC output to the user.
+
+ When sending a command to the EC, we send the command once and check the
+ response to see if the EC encountered an error when receiving the command. An
+ error condition is reported to the interpreter by a string with at least one
+ '&' and 'E'. The full string is actually '&&EE', however it's possible that
+ the leading ampersand or trailing 'E' could be dropped. If an error is
+ encountered, the interpreter will retry up to the amount configured.
+
+ Args:
+ interp: An Interpreter object that has been properly initialised.
+ shutdown_pipe: A file object for a pipe or equivalent that becomes readable
+ (not blocked) to indicate that the loop should exit. Can be None to never
+ exit the loop.
+ """
+ try:
+ # This is used instead of "break" to avoid exiting the loop in the middle of
+ # an iteration.
+ continue_looping = True
+
+ while continue_looping:
+ # The inputs list is created anew in each loop iteration because the
+ # Interpreter class sometimes modifies the interp.inputs list.
+ if shutdown_pipe is None:
+ inputs = interp.inputs
+ else:
+ inputs = list(interp.inputs)
+ inputs.append(shutdown_pipe)
+
+ readable, writeable, _ = select.select(inputs, interp.outputs, [])
+
+ for obj in readable:
+ # Handle any debug prints from the EC.
+ if obj is interp.ec_uart_pty:
+ interp.HandleECData()
+
+ # Handle any commands from the user.
+ elif obj is interp.cmd_pipe:
+ try:
+ interp.HandleUserData()
+ except EOFError:
+ interp.logger.debug(
+ "ec3po interpreter received EOF from cmd_pipe in "
+ "HandleUserData()"
+ )
+ continue_looping = False
+
+ elif obj is shutdown_pipe:
+ interp.logger.debug(
+ "ec3po interpreter received shutdown pipe unblocked notification"
+ )
+ continue_looping = False
+
+ for obj in writeable:
+ # Send a command to the EC.
+ if obj is interp.ec_uart_pty:
+ interp.SendCmdToEC()
+
+ except KeyboardInterrupt:
+ pass
+
+ finally:
+ interp.cmd_pipe.close()
+ interp.dbg_pipe.close()
+ interp.ec_uart_pty.close()
+ if shutdown_pipe is not None:
+ shutdown_pipe.close()
+ interp.logger.debug(
+ "Exit ec3po interpreter loop for %s", interp.ec_uart_pty_name
+ )
diff --git a/util/ec3po/interpreter_unittest.py b/util/ec3po/interpreter_unittest.py
index fe4d43c351..509b90f667 100755
--- a/util/ec3po/interpreter_unittest.py
+++ b/util/ec3po/interpreter_unittest.py
@@ -10,371 +10,389 @@
from __future__ import print_function
import logging
-import mock
import tempfile
import unittest
+import mock
import six
-
-from ec3po import interpreter
-from ec3po import threadproc_shim
+from ec3po import interpreter, threadproc_shim
def GetBuiltins(func):
- if six.PY2:
- return '__builtin__.' + func
- return 'builtins.' + func
+ if six.PY2:
+ return "__builtin__." + func
+ return "builtins." + func
class TestEnhancedECBehaviour(unittest.TestCase):
- """Test case to verify all enhanced EC interpretation tasks."""
- def setUp(self):
- """Setup the test harness."""
- # Setup logging with a timestamp, the module, and the log level.
- logging.basicConfig(level=logging.DEBUG,
- format=('%(asctime)s - %(module)s -'
- ' %(levelname)s - %(message)s'))
-
- # Create a tempfile that would represent the EC UART PTY.
- self.tempfile = tempfile.NamedTemporaryFile()
-
- # Create the pipes that the interpreter will use.
- self.cmd_pipe_user, self.cmd_pipe_itpr = threadproc_shim.Pipe()
- self.dbg_pipe_user, self.dbg_pipe_itpr = threadproc_shim.Pipe(duplex=False)
-
- # Mock the open() function so we can inspect reads/writes to the EC.
- self.ec_uart_pty = mock.mock_open()
-
- with mock.patch(GetBuiltins('open'), self.ec_uart_pty):
- # Create an interpreter.
- self.itpr = interpreter.Interpreter(self.tempfile.name,
- self.cmd_pipe_itpr,
- self.dbg_pipe_itpr,
- log_level=logging.DEBUG,
- name="EC")
-
- @mock.patch('ec3po.interpreter.os')
- def test_HandlingCommandsThatProduceNoOutput(self, mock_os):
- """Verify that the Interpreter correctly handles non-output commands.
-
- Args:
- mock_os: MagicMock object replacing the 'os' module for this test
- case.
- """
- # The interpreter init should open the EC UART PTY.
- expected_ec_calls = [mock.call(self.tempfile.name, 'r+b', buffering=0)]
- # Have a command come in the command pipe. The first command will be an
- # interrogation to determine if the EC is enhanced or not.
- self.cmd_pipe_user.send(interpreter.EC_SYN)
- self.itpr.HandleUserData()
- # At this point, the command should be queued up waiting to be sent, so
- # let's actually send it to the EC.
- self.itpr.SendCmdToEC()
- expected_ec_calls.extend([mock.call().write(interpreter.EC_SYN),
- mock.call().flush()])
- # Now, assume that the EC sends only 1 response back of EC_ACK.
- mock_os.read.side_effect = [interpreter.EC_ACK]
- # When reading the EC, the interpreter will call file.fileno() to pass to
- # os.read().
- expected_ec_calls.append(mock.call().fileno())
- # Simulate the response.
- self.itpr.HandleECData()
-
- # Now that the interrogation was complete, it's time to send down the real
- # command.
- test_cmd = b'chan save'
- # Send the test command down the pipe.
- self.cmd_pipe_user.send(test_cmd)
- self.itpr.HandleUserData()
- self.itpr.SendCmdToEC()
- # Since the EC image is enhanced, we should have sent a packed command.
- expected_ec_calls.append(mock.call().write(self.itpr.PackCommand(test_cmd)))
- expected_ec_calls.append(mock.call().flush())
-
- # Now that the first command was sent, we should send another command which
- # produces no output. The console would send another interrogation.
- self.cmd_pipe_user.send(interpreter.EC_SYN)
- self.itpr.HandleUserData()
- self.itpr.SendCmdToEC()
- expected_ec_calls.extend([mock.call().write(interpreter.EC_SYN),
- mock.call().flush()])
- # Again, assume that the EC sends only 1 response back of EC_ACK.
- mock_os.read.side_effect = [interpreter.EC_ACK]
- # When reading the EC, the interpreter will call file.fileno() to pass to
- # os.read().
- expected_ec_calls.append(mock.call().fileno())
- # Simulate the response.
- self.itpr.HandleECData()
-
- # Now send the second test command.
- test_cmd = b'chan 0'
- self.cmd_pipe_user.send(test_cmd)
- self.itpr.HandleUserData()
- self.itpr.SendCmdToEC()
- # Since the EC image is enhanced, we should have sent a packed command.
- expected_ec_calls.append(mock.call().write(self.itpr.PackCommand(test_cmd)))
- expected_ec_calls.append(mock.call().flush())
-
- # Finally, verify that the appropriate writes were actually sent to the EC.
- self.ec_uart_pty.assert_has_calls(expected_ec_calls)
-
- @mock.patch('ec3po.interpreter.os')
- def test_CommandRetryingOnError(self, mock_os):
- """Verify that commands are retried if an error is encountered.
-
- Args:
- mock_os: MagicMock object replacing the 'os' module for this test
- case.
- """
- # The interpreter init should open the EC UART PTY.
- expected_ec_calls = [mock.call(self.tempfile.name, 'r+b', buffering=0)]
- # Have a command come in the command pipe. The first command will be an
- # interrogation to determine if the EC is enhanced or not.
- self.cmd_pipe_user.send(interpreter.EC_SYN)
- self.itpr.HandleUserData()
- # At this point, the command should be queued up waiting to be sent, so
- # let's actually send it to the EC.
- self.itpr.SendCmdToEC()
- expected_ec_calls.extend([mock.call().write(interpreter.EC_SYN),
- mock.call().flush()])
- # Now, assume that the EC sends only 1 response back of EC_ACK.
- mock_os.read.side_effect = [interpreter.EC_ACK]
- # When reading the EC, the interpreter will call file.fileno() to pass to
- # os.read().
- expected_ec_calls.append(mock.call().fileno())
- # Simulate the response.
- self.itpr.HandleECData()
-
- # Let's send a command that is received on the EC-side with an error.
- test_cmd = b'accelinfo'
- self.cmd_pipe_user.send(test_cmd)
- self.itpr.HandleUserData()
- self.itpr.SendCmdToEC()
- packed_cmd = self.itpr.PackCommand(test_cmd)
- expected_ec_calls.extend([mock.call().write(packed_cmd),
- mock.call().flush()])
- # Have the EC return the error string twice.
- mock_os.read.side_effect = [b'&&EE', b'&&EE']
- for i in range(2):
- # When reading the EC, the interpreter will call file.fileno() to pass to
- # os.read().
- expected_ec_calls.append(mock.call().fileno())
- # Simulate the response.
- self.itpr.HandleECData()
-
- # Since an error was received, the EC should attempt to retry the command.
- expected_ec_calls.extend([mock.call().write(packed_cmd),
- mock.call().flush()])
- # Verify that the retry count was decremented.
- self.assertEqual(interpreter.COMMAND_RETRIES-i-1, self.itpr.cmd_retries,
- 'Unexpected cmd_remaining count.')
- # Actually retry the command.
- self.itpr.SendCmdToEC()
-
- # Now assume that the last one goes through with no trouble.
- expected_ec_calls.extend([mock.call().write(packed_cmd),
- mock.call().flush()])
- self.itpr.SendCmdToEC()
-
- # Verify all the calls.
- self.ec_uart_pty.assert_has_calls(expected_ec_calls)
-
- def test_PackCommandsForEnhancedEC(self):
- """Verify that the interpreter packs commands for enhanced EC images."""
- # Assume current EC image is enhanced.
- self.itpr.enhanced_ec = True
- # Receive a command from the user.
- test_cmd = b'gettime'
- self.cmd_pipe_user.send(test_cmd)
- # Mock out PackCommand to see if it was called.
- self.itpr.PackCommand = mock.MagicMock()
- # Have the interpreter handle the command.
- self.itpr.HandleUserData()
- # Verify that PackCommand() was called.
- self.itpr.PackCommand.assert_called_once_with(test_cmd)
-
- def test_DontPackCommandsForNonEnhancedEC(self):
- """Verify the interpreter doesn't pack commands for non-enhanced images."""
- # Assume current EC image is not enhanced.
- self.itpr.enhanced_ec = False
- # Receive a command from the user.
- test_cmd = b'gettime'
- self.cmd_pipe_user.send(test_cmd)
- # Mock out PackCommand to see if it was called.
- self.itpr.PackCommand = mock.MagicMock()
- # Have the interpreter handle the command.
- self.itpr.HandleUserData()
- # Verify that PackCommand() was called.
- self.itpr.PackCommand.assert_not_called()
-
- @mock.patch('ec3po.interpreter.os')
- def test_KeepingTrackOfInterrogation(self, mock_os):
- """Verify that the interpreter can track the state of the interrogation.
-
- Args:
- mock_os: MagicMock object replacing the 'os' module. for this test
- case.
- """
- # Upon init, the interpreter should assume that the current EC image is not
- # enhanced.
- self.assertFalse(self.itpr.enhanced_ec, msg=('State of enhanced_ec upon'
- ' init is not False.'))
-
- # Assume an interrogation request comes in from the user.
- self.cmd_pipe_user.send(interpreter.EC_SYN)
- self.itpr.HandleUserData()
-
- # Verify the state is now within an interrogation.
- self.assertTrue(self.itpr.interrogating, 'interrogating should be True')
- # The state of enhanced_ec should not be changed yet because we haven't
- # received a valid response yet.
- self.assertFalse(self.itpr.enhanced_ec, msg=('State of enhanced_ec is '
- 'not False.'))
-
- # Assume that the EC responds with an EC_ACK.
- mock_os.read.side_effect = [interpreter.EC_ACK]
- self.itpr.HandleECData()
-
- # Now, the interrogation should be complete and we should know that the
- # current EC image is enhanced.
- self.assertFalse(self.itpr.interrogating, msg=('interrogating should be '
- 'False'))
- self.assertTrue(self.itpr.enhanced_ec, msg='enhanced_ec sholud be True')
-
- # Now let's perform another interrogation, but pretend that the EC ignores
- # it.
- self.cmd_pipe_user.send(interpreter.EC_SYN)
- self.itpr.HandleUserData()
-
- # Verify interrogating state.
- self.assertTrue(self.itpr.interrogating, 'interrogating sholud be True')
- # We should assume that the image is not enhanced until we get the valid
- # response.
- self.assertFalse(self.itpr.enhanced_ec, 'enhanced_ec should be False now.')
-
- # Let's pretend that we get a random debug print. This should clear the
- # interrogating flag.
- mock_os.read.side_effect = [b'[1660.593076 HC 0x103]']
- self.itpr.HandleECData()
-
- # Verify that interrogating flag is cleared and enhanced_ec is still False.
- self.assertFalse(self.itpr.interrogating, 'interrogating should be False.')
- self.assertFalse(self.itpr.enhanced_ec,
- 'enhanced_ec should still be False.')
+ """Test case to verify all enhanced EC interpretation tasks."""
+
+ def setUp(self):
+ """Set up the test harness."""
+ # Setup logging with a timestamp, the module, and the log level.
+ logging.basicConfig(
+ level=logging.DEBUG,
+ format=("%(asctime)s - %(module)s -" " %(levelname)s - %(message)s"),
+ )
+
+ # Create a tempfile that would represent the EC UART PTY.
+ self.tempfile = tempfile.NamedTemporaryFile()
+
+ # Create the pipes that the interpreter will use.
+ self.cmd_pipe_user, self.cmd_pipe_itpr = threadproc_shim.Pipe()
+ self.dbg_pipe_user, self.dbg_pipe_itpr = threadproc_shim.Pipe(duplex=False)
+
+ # Mock the open() function so we can inspect reads/writes to the EC.
+ self.ec_uart_pty = mock.mock_open()
+
+ with mock.patch(GetBuiltins("open"), self.ec_uart_pty):
+ # Create an interpreter.
+ self.itpr = interpreter.Interpreter(
+ self.tempfile.name,
+ self.cmd_pipe_itpr,
+ self.dbg_pipe_itpr,
+ log_level=logging.DEBUG,
+ name="EC",
+ )
+
+ @mock.patch("ec3po.interpreter.os")
+ def test_HandlingCommandsThatProduceNoOutput(self, mock_os):
+ """Verify that the Interpreter correctly handles non-output commands.
+
+ Args:
+ mock_os: MagicMock object replacing the 'os' module for this test
+ case.
+ """
+ # The interpreter init should open the EC UART PTY.
+ expected_ec_calls = [mock.call(self.tempfile.name, "r+b", buffering=0)]
+ # Have a command come in the command pipe. The first command will be an
+ # interrogation to determine if the EC is enhanced or not.
+ self.cmd_pipe_user.send(interpreter.EC_SYN)
+ self.itpr.HandleUserData()
+ # At this point, the command should be queued up waiting to be sent, so
+ # let's actually send it to the EC.
+ self.itpr.SendCmdToEC()
+ expected_ec_calls.extend(
+ [mock.call().write(interpreter.EC_SYN), mock.call().flush()]
+ )
+ # Now, assume that the EC sends only 1 response back of EC_ACK.
+ mock_os.read.side_effect = [interpreter.EC_ACK]
+ # When reading the EC, the interpreter will call file.fileno() to pass to
+ # os.read().
+ expected_ec_calls.append(mock.call().fileno())
+ # Simulate the response.
+ self.itpr.HandleECData()
+
+ # Now that the interrogation was complete, it's time to send down the real
+ # command.
+ test_cmd = b"chan save"
+ # Send the test command down the pipe.
+ self.cmd_pipe_user.send(test_cmd)
+ self.itpr.HandleUserData()
+ self.itpr.SendCmdToEC()
+ # Since the EC image is enhanced, we should have sent a packed command.
+ expected_ec_calls.append(mock.call().write(self.itpr.PackCommand(test_cmd)))
+ expected_ec_calls.append(mock.call().flush())
+
+ # Now that the first command was sent, we should send another command which
+ # produces no output. The console would send another interrogation.
+ self.cmd_pipe_user.send(interpreter.EC_SYN)
+ self.itpr.HandleUserData()
+ self.itpr.SendCmdToEC()
+ expected_ec_calls.extend(
+ [mock.call().write(interpreter.EC_SYN), mock.call().flush()]
+ )
+ # Again, assume that the EC sends only 1 response back of EC_ACK.
+ mock_os.read.side_effect = [interpreter.EC_ACK]
+ # When reading the EC, the interpreter will call file.fileno() to pass to
+ # os.read().
+ expected_ec_calls.append(mock.call().fileno())
+ # Simulate the response.
+ self.itpr.HandleECData()
+
+ # Now send the second test command.
+ test_cmd = b"chan 0"
+ self.cmd_pipe_user.send(test_cmd)
+ self.itpr.HandleUserData()
+ self.itpr.SendCmdToEC()
+ # Since the EC image is enhanced, we should have sent a packed command.
+ expected_ec_calls.append(mock.call().write(self.itpr.PackCommand(test_cmd)))
+ expected_ec_calls.append(mock.call().flush())
+
+ # Finally, verify that the appropriate writes were actually sent to the EC.
+ self.ec_uart_pty.assert_has_calls(expected_ec_calls)
+
+ @mock.patch("ec3po.interpreter.os")
+ def test_CommandRetryingOnError(self, mock_os):
+ """Verify that commands are retried if an error is encountered.
+
+ Args:
+ mock_os: MagicMock object replacing the 'os' module for this test
+ case.
+ """
+ # The interpreter init should open the EC UART PTY.
+ expected_ec_calls = [mock.call(self.tempfile.name, "r+b", buffering=0)]
+ # Have a command come in the command pipe. The first command will be an
+ # interrogation to determine if the EC is enhanced or not.
+ self.cmd_pipe_user.send(interpreter.EC_SYN)
+ self.itpr.HandleUserData()
+ # At this point, the command should be queued up waiting to be sent, so
+ # let's actually send it to the EC.
+ self.itpr.SendCmdToEC()
+ expected_ec_calls.extend(
+ [mock.call().write(interpreter.EC_SYN), mock.call().flush()]
+ )
+ # Now, assume that the EC sends only 1 response back of EC_ACK.
+ mock_os.read.side_effect = [interpreter.EC_ACK]
+ # When reading the EC, the interpreter will call file.fileno() to pass to
+ # os.read().
+ expected_ec_calls.append(mock.call().fileno())
+ # Simulate the response.
+ self.itpr.HandleECData()
+
+ # Let's send a command that is received on the EC-side with an error.
+ test_cmd = b"accelinfo"
+ self.cmd_pipe_user.send(test_cmd)
+ self.itpr.HandleUserData()
+ self.itpr.SendCmdToEC()
+ packed_cmd = self.itpr.PackCommand(test_cmd)
+ expected_ec_calls.extend([mock.call().write(packed_cmd), mock.call().flush()])
+ # Have the EC return the error string twice.
+ mock_os.read.side_effect = [b"&&EE", b"&&EE"]
+ for i in range(2):
+ # When reading the EC, the interpreter will call file.fileno() to pass to
+ # os.read().
+ expected_ec_calls.append(mock.call().fileno())
+ # Simulate the response.
+ self.itpr.HandleECData()
+
+ # Since an error was received, the EC should attempt to retry the command.
+ expected_ec_calls.extend(
+ [mock.call().write(packed_cmd), mock.call().flush()]
+ )
+ # Verify that the retry count was decremented.
+ self.assertEqual(
+ interpreter.COMMAND_RETRIES - i - 1,
+ self.itpr.cmd_retries,
+ "Unexpected cmd_remaining count.",
+ )
+ # Actually retry the command.
+ self.itpr.SendCmdToEC()
+
+ # Now assume that the last one goes through with no trouble.
+ expected_ec_calls.extend([mock.call().write(packed_cmd), mock.call().flush()])
+ self.itpr.SendCmdToEC()
+
+ # Verify all the calls.
+ self.ec_uart_pty.assert_has_calls(expected_ec_calls)
+
+ def test_PackCommandsForEnhancedEC(self):
+ """Verify that the interpreter packs commands for enhanced EC images."""
+ # Assume current EC image is enhanced.
+ self.itpr.enhanced_ec = True
+ # Receive a command from the user.
+ test_cmd = b"gettime"
+ self.cmd_pipe_user.send(test_cmd)
+ # Mock out PackCommand to see if it was called.
+ self.itpr.PackCommand = mock.MagicMock()
+ # Have the interpreter handle the command.
+ self.itpr.HandleUserData()
+ # Verify that PackCommand() was called.
+ self.itpr.PackCommand.assert_called_once_with(test_cmd)
+
+ def test_DontPackCommandsForNonEnhancedEC(self):
+ """Verify the interpreter doesn't pack commands for non-enhanced images."""
+ # Assume current EC image is not enhanced.
+ self.itpr.enhanced_ec = False
+ # Receive a command from the user.
+ test_cmd = b"gettime"
+ self.cmd_pipe_user.send(test_cmd)
+ # Mock out PackCommand to see if it was called.
+ self.itpr.PackCommand = mock.MagicMock()
+ # Have the interpreter handle the command.
+ self.itpr.HandleUserData()
+ # Verify that PackCommand() was called.
+ self.itpr.PackCommand.assert_not_called()
+
+ @mock.patch("ec3po.interpreter.os")
+ def test_KeepingTrackOfInterrogation(self, mock_os):
+ """Verify that the interpreter can track the state of the interrogation.
+
+ Args:
+ mock_os: MagicMock object replacing the 'os' module. for this test
+ case.
+ """
+ # Upon init, the interpreter should assume that the current EC image is not
+ # enhanced.
+ self.assertFalse(
+ self.itpr.enhanced_ec,
+ msg=("State of enhanced_ec upon" " init is not False."),
+ )
+
+ # Assume an interrogation request comes in from the user.
+ self.cmd_pipe_user.send(interpreter.EC_SYN)
+ self.itpr.HandleUserData()
+
+ # Verify the state is now within an interrogation.
+ self.assertTrue(self.itpr.interrogating, "interrogating should be True")
+ # The state of enhanced_ec should not be changed yet because we haven't
+ # received a valid response yet.
+ self.assertFalse(
+ self.itpr.enhanced_ec, msg=("State of enhanced_ec is " "not False.")
+ )
+
+ # Assume that the EC responds with an EC_ACK.
+ mock_os.read.side_effect = [interpreter.EC_ACK]
+ self.itpr.HandleECData()
+
+ # Now, the interrogation should be complete and we should know that the
+ # current EC image is enhanced.
+ self.assertFalse(
+ self.itpr.interrogating, msg=("interrogating should be " "False")
+ )
+        self.assertTrue(self.itpr.enhanced_ec, msg="enhanced_ec should be True")
+
+ # Now let's perform another interrogation, but pretend that the EC ignores
+ # it.
+ self.cmd_pipe_user.send(interpreter.EC_SYN)
+ self.itpr.HandleUserData()
+
+ # Verify interrogating state.
+        self.assertTrue(self.itpr.interrogating, "interrogating should be True")
+ # We should assume that the image is not enhanced until we get the valid
+ # response.
+ self.assertFalse(self.itpr.enhanced_ec, "enhanced_ec should be False now.")
+
+ # Let's pretend that we get a random debug print. This should clear the
+ # interrogating flag.
+ mock_os.read.side_effect = [b"[1660.593076 HC 0x103]"]
+ self.itpr.HandleECData()
+
+ # Verify that interrogating flag is cleared and enhanced_ec is still False.
+ self.assertFalse(self.itpr.interrogating, "interrogating should be False.")
+ self.assertFalse(self.itpr.enhanced_ec, "enhanced_ec should still be False.")
class TestUARTDisconnection(unittest.TestCase):
- """Test case to verify interpreter disconnection/reconnection."""
- def setUp(self):
- """Setup the test harness."""
- # Setup logging with a timestamp, the module, and the log level.
- logging.basicConfig(level=logging.DEBUG,
- format=('%(asctime)s - %(module)s -'
- ' %(levelname)s - %(message)s'))
-
- # Create a tempfile that would represent the EC UART PTY.
- self.tempfile = tempfile.NamedTemporaryFile()
-
- # Create the pipes that the interpreter will use.
- self.cmd_pipe_user, self.cmd_pipe_itpr = threadproc_shim.Pipe()
- self.dbg_pipe_user, self.dbg_pipe_itpr = threadproc_shim.Pipe(duplex=False)
-
- # Mock the open() function so we can inspect reads/writes to the EC.
- self.ec_uart_pty = mock.mock_open()
-
- with mock.patch(GetBuiltins('open'), self.ec_uart_pty):
- # Create an interpreter.
- self.itpr = interpreter.Interpreter(self.tempfile.name,
- self.cmd_pipe_itpr,
- self.dbg_pipe_itpr,
- log_level=logging.DEBUG,
- name="EC")
-
- # First, check that interpreter is initialized to connected.
- self.assertTrue(self.itpr.connected, ('The interpreter should be'
- ' initialized in a connected state'))
-
- def test_DisconnectStopsECTraffic(self):
- """Verify that when in disconnected state, no debug prints are sent."""
- # Let's send a disconnect command through the command pipe.
- self.cmd_pipe_user.send(b'disconnect')
- self.itpr.HandleUserData()
-
- # Verify interpreter is disconnected from EC.
- self.assertFalse(self.itpr.connected, ('The interpreter should be'
- 'disconnected.'))
- # Verify that the EC UART is no longer a member of the inputs. The
- # interpreter will never pull data from the EC if it's not a member of the
- # inputs list.
- self.assertFalse(self.itpr.ec_uart_pty in self.itpr.inputs)
-
- def test_CommandsDroppedWhenDisconnected(self):
- """Verify that when in disconnected state, commands are dropped."""
- # Send a command, followed by 'disconnect'.
- self.cmd_pipe_user.send(b'taskinfo')
- self.itpr.HandleUserData()
- self.cmd_pipe_user.send(b'disconnect')
- self.itpr.HandleUserData()
-
- # Verify interpreter is disconnected from EC.
- self.assertFalse(self.itpr.connected, ('The interpreter should be'
- 'disconnected.'))
- # Verify that the EC UART is no longer a member of the inputs nor outputs.
- self.assertFalse(self.itpr.ec_uart_pty in self.itpr.inputs)
- self.assertFalse(self.itpr.ec_uart_pty in self.itpr.outputs)
-
- # Have the user send a few more commands in the disconnected state.
- command = 'help\n'
- for char in command:
- self.cmd_pipe_user.send(char.encode('utf-8'))
- self.itpr.HandleUserData()
-
- # The command queue should be empty.
- self.assertEqual(0, self.itpr.ec_cmd_queue.qsize())
-
- # Now send the reconnect command.
- self.cmd_pipe_user.send(b'reconnect')
-
- with mock.patch(GetBuiltins('open'), mock.mock_open()):
- self.itpr.HandleUserData()
-
- # Verify interpreter is connected.
- self.assertTrue(self.itpr.connected)
- # Verify that EC UART is a member of the inputs.
- self.assertTrue(self.itpr.ec_uart_pty in self.itpr.inputs)
- # Since no command was sent after reconnection, verify that the EC UART is
- # not a member of the outputs.
- self.assertFalse(self.itpr.ec_uart_pty in self.itpr.outputs)
-
- def test_ReconnectAllowsECTraffic(self):
- """Verify that when connected, EC UART traffic is allowed."""
- # Let's send a disconnect command through the command pipe.
- self.cmd_pipe_user.send(b'disconnect')
- self.itpr.HandleUserData()
-
- # Verify interpreter is disconnected.
- self.assertFalse(self.itpr.connected, ('The interpreter should be'
- 'disconnected.'))
- # Verify that the EC UART is no longer a member of the inputs nor outputs.
- self.assertFalse(self.itpr.ec_uart_pty in self.itpr.inputs)
- self.assertFalse(self.itpr.ec_uart_pty in self.itpr.outputs)
-
- # Issue reconnect command through the command pipe.
- self.cmd_pipe_user.send(b'reconnect')
-
- with mock.patch(GetBuiltins('open'), mock.mock_open()):
- self.itpr.HandleUserData()
-
- # Verify interpreter is connected.
- self.assertTrue(self.itpr.connected, ('The interpreter should be'
- 'connected.'))
- # Verify that the EC UART is now a member of the inputs.
- self.assertTrue(self.itpr.ec_uart_pty in self.itpr.inputs)
- # Since we have issued no commands during the disconnected state, no
- # commands are pending and therefore the PTY should not be added to the
- # outputs.
- self.assertFalse(self.itpr.ec_uart_pty in self.itpr.outputs)
-
-
-if __name__ == '__main__':
- unittest.main()
+ """Test case to verify interpreter disconnection/reconnection."""
+
+ def setUp(self):
+ """Setup the test harness."""
+ # Setup logging with a timestamp, the module, and the log level.
+ logging.basicConfig(
+ level=logging.DEBUG,
+ format=("%(asctime)s - %(module)s -" " %(levelname)s - %(message)s"),
+ )
+
+ # Create a tempfile that would represent the EC UART PTY.
+ self.tempfile = tempfile.NamedTemporaryFile()
+
+ # Create the pipes that the interpreter will use.
+ self.cmd_pipe_user, self.cmd_pipe_itpr = threadproc_shim.Pipe()
+ self.dbg_pipe_user, self.dbg_pipe_itpr = threadproc_shim.Pipe(duplex=False)
+
+ # Mock the open() function so we can inspect reads/writes to the EC.
+ self.ec_uart_pty = mock.mock_open()
+
+ with mock.patch(GetBuiltins("open"), self.ec_uart_pty):
+ # Create an interpreter.
+ self.itpr = interpreter.Interpreter(
+ self.tempfile.name,
+ self.cmd_pipe_itpr,
+ self.dbg_pipe_itpr,
+ log_level=logging.DEBUG,
+ name="EC",
+ )
+
+ # First, check that interpreter is initialized to connected.
+ self.assertTrue(
+ self.itpr.connected,
+ ("The interpreter should be" " initialized in a connected state"),
+ )
+
+ def test_DisconnectStopsECTraffic(self):
+ """Verify that when in disconnected state, no debug prints are sent."""
+ # Let's send a disconnect command through the command pipe.
+ self.cmd_pipe_user.send(b"disconnect")
+ self.itpr.HandleUserData()
+
+ # Verify interpreter is disconnected from EC.
+ self.assertFalse(
+            self.itpr.connected, ("The interpreter should be " "disconnected.")
+ )
+ # Verify that the EC UART is no longer a member of the inputs. The
+ # interpreter will never pull data from the EC if it's not a member of the
+ # inputs list.
+ self.assertFalse(self.itpr.ec_uart_pty in self.itpr.inputs)
+
+ def test_CommandsDroppedWhenDisconnected(self):
+ """Verify that when in disconnected state, commands are dropped."""
+ # Send a command, followed by 'disconnect'.
+ self.cmd_pipe_user.send(b"taskinfo")
+ self.itpr.HandleUserData()
+ self.cmd_pipe_user.send(b"disconnect")
+ self.itpr.HandleUserData()
+
+ # Verify interpreter is disconnected from EC.
+ self.assertFalse(
+            self.itpr.connected, ("The interpreter should be " "disconnected.")
+ )
+ # Verify that the EC UART is no longer a member of the inputs nor outputs.
+ self.assertFalse(self.itpr.ec_uart_pty in self.itpr.inputs)
+ self.assertFalse(self.itpr.ec_uart_pty in self.itpr.outputs)
+
+ # Have the user send a few more commands in the disconnected state.
+ command = "help\n"
+ for char in command:
+ self.cmd_pipe_user.send(char.encode("utf-8"))
+ self.itpr.HandleUserData()
+
+ # The command queue should be empty.
+ self.assertEqual(0, self.itpr.ec_cmd_queue.qsize())
+
+ # Now send the reconnect command.
+ self.cmd_pipe_user.send(b"reconnect")
+
+ with mock.patch(GetBuiltins("open"), mock.mock_open()):
+ self.itpr.HandleUserData()
+
+ # Verify interpreter is connected.
+ self.assertTrue(self.itpr.connected)
+ # Verify that EC UART is a member of the inputs.
+ self.assertTrue(self.itpr.ec_uart_pty in self.itpr.inputs)
+ # Since no command was sent after reconnection, verify that the EC UART is
+ # not a member of the outputs.
+ self.assertFalse(self.itpr.ec_uart_pty in self.itpr.outputs)
+
+ def test_ReconnectAllowsECTraffic(self):
+ """Verify that when connected, EC UART traffic is allowed."""
+ # Let's send a disconnect command through the command pipe.
+ self.cmd_pipe_user.send(b"disconnect")
+ self.itpr.HandleUserData()
+
+ # Verify interpreter is disconnected.
+ self.assertFalse(
+            self.itpr.connected, ("The interpreter should be " "disconnected.")
+ )
+ # Verify that the EC UART is no longer a member of the inputs nor outputs.
+ self.assertFalse(self.itpr.ec_uart_pty in self.itpr.inputs)
+ self.assertFalse(self.itpr.ec_uart_pty in self.itpr.outputs)
+
+ # Issue reconnect command through the command pipe.
+ self.cmd_pipe_user.send(b"reconnect")
+
+ with mock.patch(GetBuiltins("open"), mock.mock_open()):
+ self.itpr.HandleUserData()
+
+ # Verify interpreter is connected.
+        self.assertTrue(self.itpr.connected, ("The interpreter should be " "connected."))
+ # Verify that the EC UART is now a member of the inputs.
+ self.assertTrue(self.itpr.ec_uart_pty in self.itpr.inputs)
+ # Since we have issued no commands during the disconnected state, no
+ # commands are pending and therefore the PTY should not be added to the
+ # outputs.
+ self.assertFalse(self.itpr.ec_uart_pty in self.itpr.outputs)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/util/ec3po/threadproc_shim.py b/util/ec3po/threadproc_shim.py
index da5440b1f3..c0b3ce0bf4 100644
--- a/util/ec3po/threadproc_shim.py
+++ b/util/ec3po/threadproc_shim.py
@@ -34,33 +34,34 @@ wait until after completing the TODO above to stop using multiprocessing.Pipe!
# Imports to bring objects into this namespace for users of this module.
from multiprocessing import Pipe
-from six.moves.queue import Queue
from threading import Thread as ThreadOrProcess
+from six.moves.queue import Queue
+
# True if this module has ec3po using subprocesses, False if using threads.
USING_SUBPROCS = False
def _DoNothing():
- """Do-nothing function for use as a callback with DoIf()."""
+ """Do-nothing function for use as a callback with DoIf()."""
def DoIf(subprocs=_DoNothing, threads=_DoNothing):
- """Return a callback or not based on ec3po use of subprocesses or threads.
+ """Return a callback or not based on ec3po use of subprocesses or threads.
- Args:
- subprocs: callback that does not require any args - This will be returned
- (not called!) if and only if ec3po is using subprocesses. This is
- OPTIONAL, the default value is a do-nothing callback that returns None.
- threads: callback that does not require any args - This will be returned
- (not called!) if and only if ec3po is using threads. This is OPTIONAL,
- the default value is a do-nothing callback that returns None.
+ Args:
+ subprocs: callback that does not require any args - This will be returned
+ (not called!) if and only if ec3po is using subprocesses. This is
+ OPTIONAL, the default value is a do-nothing callback that returns None.
+ threads: callback that does not require any args - This will be returned
+ (not called!) if and only if ec3po is using threads. This is OPTIONAL,
+ the default value is a do-nothing callback that returns None.
- Returns:
- Either the subprocs or threads argument will be returned.
- """
- return subprocs if USING_SUBPROCS else threads
+ Returns:
+ Either the subprocs or threads argument will be returned.
+ """
+ return subprocs if USING_SUBPROCS else threads
def Value(ctype, *args):
- return ctype(*args)
+ return ctype(*args)
diff --git a/util/ec_openocd.py b/util/ec_openocd.py
index a84c00643c..11956ffa1c 100755
--- a/util/ec_openocd.py
+++ b/util/ec_openocd.py
@@ -16,6 +16,7 @@ import time
Flashes and debugs the EC through openocd
"""
+
@dataclasses.dataclass
class BoardInfo:
gdb_variant: str
@@ -24,9 +25,7 @@ class BoardInfo:
# Debuggers for each board, OpenOCD currently only supports GDB
-boards = {
- "skyrim": BoardInfo("arm-none-eabi-gdb", 6, 4)
-}
+boards = {"skyrim": BoardInfo("arm-none-eabi-gdb", 6, 4)}
def create_openocd_args(interface, board):
@@ -36,9 +35,12 @@ def create_openocd_args(interface, board):
board_info = boards[board]
args = [
"openocd",
- "-f", f"interface/{interface}.cfg",
- "-c", "add_script_search_dir openocd",
- "-f", f"board/{board}.cfg",
+ "-f",
+ f"interface/{interface}.cfg",
+ "-c",
+ "add_script_search_dir openocd",
+ "-f",
+ f"board/{board}.cfg",
]
return args
@@ -53,11 +55,13 @@ def create_gdb_args(board, port, executable):
board_info.gdb_variant,
executable,
# GDB can't autodetect these according to OpenOCD
- "-ex", f"set remote hardware-breakpoint-limit {board_info.num_breakpoints}",
- "-ex", f"set remote hardware-watchpoint-limit {board_info.num_watchpoints}",
-
+ "-ex",
+ f"set remote hardware-breakpoint-limit {board_info.num_breakpoints}",
+ "-ex",
+ f"set remote hardware-watchpoint-limit {board_info.num_watchpoints}",
# Connect to OpenOCD
- "-ex", f"target extended-remote localhost:{port}",
+ "-ex",
+ f"target extended-remote localhost:{port}",
]
return args
diff --git a/util/flash_jlink.py b/util/flash_jlink.py
index 26c3c2e709..50a0bfca20 100755
--- a/util/flash_jlink.py
+++ b/util/flash_jlink.py
@@ -25,7 +25,6 @@ import sys
import tempfile
import time
-
DEFAULT_SEGGER_REMOTE_PORT = 19020
# Commands are documented here: https://wiki.segger.com/J-Link_Commander
@@ -41,27 +40,34 @@ exit
class BoardConfig:
"""Board configuration."""
+
def __init__(self, interface, device, flash_address):
self.interface = interface
self.device = device
self.flash_address = flash_address
-SWD_INTERFACE = 'SWD'
-STM32_DEFAULT_FLASH_ADDRESS = '0x8000000'
-DRAGONCLAW_CONFIG = BoardConfig(interface=SWD_INTERFACE, device='STM32F412CG',
- flash_address=STM32_DEFAULT_FLASH_ADDRESS)
-ICETOWER_CONFIG = BoardConfig(interface=SWD_INTERFACE, device='STM32H743ZI',
- flash_address=STM32_DEFAULT_FLASH_ADDRESS)
+SWD_INTERFACE = "SWD"
+STM32_DEFAULT_FLASH_ADDRESS = "0x8000000"
+DRAGONCLAW_CONFIG = BoardConfig(
+ interface=SWD_INTERFACE,
+ device="STM32F412CG",
+ flash_address=STM32_DEFAULT_FLASH_ADDRESS,
+)
+ICETOWER_CONFIG = BoardConfig(
+ interface=SWD_INTERFACE,
+ device="STM32H743ZI",
+ flash_address=STM32_DEFAULT_FLASH_ADDRESS,
+)
BOARD_CONFIGS = {
- 'dragonclaw': DRAGONCLAW_CONFIG,
- 'bloonchipper': DRAGONCLAW_CONFIG,
- 'nucleo-f412zg': DRAGONCLAW_CONFIG,
- 'dartmonkey': ICETOWER_CONFIG,
- 'icetower': ICETOWER_CONFIG,
- 'nucleo-dartmonkey': ICETOWER_CONFIG,
- 'nucleo-h743zi': ICETOWER_CONFIG,
+ "dragonclaw": DRAGONCLAW_CONFIG,
+ "bloonchipper": DRAGONCLAW_CONFIG,
+ "nucleo-f412zg": DRAGONCLAW_CONFIG,
+ "dartmonkey": ICETOWER_CONFIG,
+ "icetower": ICETOWER_CONFIG,
+ "nucleo-dartmonkey": ICETOWER_CONFIG,
+ "nucleo-h743zi": ICETOWER_CONFIG,
}
@@ -93,9 +99,11 @@ def is_tcp_port_open(host: str, tcp_port: int) -> bool:
def create_jlink_command_file(firmware_file, config):
tmp = tempfile.NamedTemporaryFile()
- tmp.write(JLINK_COMMANDS.format(FIRMWARE=firmware_file,
- FLASH_ADDRESS=config.flash_address).encode(
- 'utf-8'))
+ tmp.write(
+ JLINK_COMMANDS.format(
+ FIRMWARE=firmware_file, FLASH_ADDRESS=config.flash_address
+ ).encode("utf-8")
+ )
tmp.flush()
return tmp
@@ -106,8 +114,8 @@ def flash(jlink_exe, remote, device, interface, cmd_file):
]
if remote:
- logging.debug(f'Connecting to J-Link over TCP/IP {remote}.')
- remote_components = remote.split(':')
+ logging.debug(f"Connecting to J-Link over TCP/IP {remote}.")
+ remote_components = remote.split(":")
if len(remote_components) not in [1, 2]:
logging.debug(f'Given remote "{remote}" is malformed.')
return 1
@@ -118,7 +126,7 @@ def flash(jlink_exe, remote, device, interface, cmd_file):
except socket.gaierror as e:
logging.error(f'Failed to resolve host "{host}": {e}.')
return 1
- logging.debug(f'Resolved {host} as {ip}.')
+ logging.debug(f"Resolved {host} as {ip}.")
port = DEFAULT_SEGGER_REMOTE_PORT
if len(remote_components) == 2:
@@ -126,29 +134,36 @@ def flash(jlink_exe, remote, device, interface, cmd_file):
port = int(remote_components[1])
except ValueError:
logging.error(
- f'Given remote port "{remote_components[1]}" is malformed.')
+ f'Given remote port "{remote_components[1]}" is malformed.'
+ )
return 1
- remote = f'{ip}:{port}'
+ remote = f"{ip}:{port}"
- logging.debug(f'Checking connection to {remote}.')
+ logging.debug(f"Checking connection to {remote}.")
if not is_tcp_port_open(ip, port):
- logging.error(
- f"JLink server doesn't seem to be listening on {remote}.")
- logging.error('Ensure that JLinkRemoteServerCLExe is running.')
+ logging.error(f"JLink server doesn't seem to be listening on {remote}.")
+ logging.error("Ensure that JLinkRemoteServerCLExe is running.")
return 1
- cmd.extend(['-ip', remote])
-
- cmd.extend([
- '-device', device,
- '-if', interface,
- '-speed', 'auto',
- '-autoconnect', '1',
- '-CommandFile', cmd_file,
- ])
- logging.debug('Running command: "%s"', ' '.join(cmd))
+ cmd.extend(["-ip", remote])
+
+ cmd.extend(
+ [
+ "-device",
+ device,
+ "-if",
+ interface,
+ "-speed",
+ "auto",
+ "-autoconnect",
+ "1",
+ "-CommandFile",
+ cmd_file,
+ ]
+ )
+ logging.debug('Running command: "%s"', " ".join(cmd))
completed_process = subprocess.run(cmd) # pylint: disable=subprocess-run-check
- logging.debug('JLink return code: %d', completed_process.returncode)
+ logging.debug("JLink return code: %d", completed_process.returncode)
return completed_process.returncode
@@ -156,38 +171,42 @@ def main(argv: list):
parser = argparse.ArgumentParser()
- default_jlink = './JLink_Linux_V684a_x86_64/JLinkExe'
+ default_jlink = "./JLink_Linux_V684a_x86_64/JLinkExe"
if shutil.which(default_jlink) is None:
- default_jlink = 'JLinkExe'
- parser.add_argument(
- '--jlink', '-j',
- help='JLinkExe path (default: ' + default_jlink + ')',
- default=default_jlink)
-
+ default_jlink = "JLinkExe"
parser.add_argument(
- '--remote', '-n',
- help='Use TCP/IP host[:port] to connect to a J-Link or '
- 'JLinkRemoteServerCLExe. If unspecified, connect over USB.')
+ "--jlink",
+ "-j",
+ help="JLinkExe path (default: " + default_jlink + ")",
+ default=default_jlink,
+ )
- default_board = 'bloonchipper'
parser.add_argument(
- '--board', '-b',
- help='Board (default: ' + default_board + ')',
- default=default_board)
+ "--remote",
+ "-n",
+ help="Use TCP/IP host[:port] to connect to a J-Link or "
+ "JLinkRemoteServerCLExe. If unspecified, connect over USB.",
+ )
- default_firmware = os.path.join('./build', default_board, 'ec.bin')
+ default_board = "bloonchipper"
parser.add_argument(
- '--image', '-i',
- help='Firmware binary (default: ' + default_firmware + ')',
- default=default_firmware)
+ "--board",
+ "-b",
+ help="Board (default: " + default_board + ")",
+ default=default_board,
+ )
- log_level_choices = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
+ default_firmware = os.path.join("./build", default_board, "ec.bin")
parser.add_argument(
- '--log_level', '-l',
- choices=log_level_choices,
- default='DEBUG'
+ "--image",
+ "-i",
+ help="Firmware binary (default: " + default_firmware + ")",
+ default=default_firmware,
)
+ log_level_choices = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
+ parser.add_argument("--log_level", "-l", choices=log_level_choices, default="DEBUG")
+
args = parser.parse_args(argv)
logging.basicConfig(level=args.log_level)
@@ -201,11 +220,12 @@ def main(argv: list):
args.jlink = args.jlink
cmd_file = create_jlink_command_file(args.image, config)
- ret_code = flash(args.jlink, args.remote, config.device, config.interface,
- cmd_file.name)
+ ret_code = flash(
+ args.jlink, args.remote, config.device, config.interface, cmd_file.name
+ )
cmd_file.close()
return ret_code
-if __name__ == '__main__':
+if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
diff --git a/util/fptool.py b/util/fptool.py
index 5d73302bbc..b7f2150289 100755
--- a/util/fptool.py
+++ b/util/fptool.py
@@ -19,14 +19,14 @@ def cmd_flash(args: argparse.Namespace) -> int:
disabled.
"""
- if not shutil.which('flash_fp_mcu'):
- print('Error - The flash_fp_mcu utility does not exist.')
+ if not shutil.which("flash_fp_mcu"):
+ print("Error - The flash_fp_mcu utility does not exist.")
return 1
- cmd = ['flash_fp_mcu']
+ cmd = ["flash_fp_mcu"]
if args.image:
if not os.path.isfile(args.image):
- print(f'Error - image {args.image} is not a file.')
+ print(f"Error - image {args.image} is not a file.")
return 1
cmd.append(args.image)
@@ -38,18 +38,17 @@ def cmd_flash(args: argparse.Namespace) -> int:
def main(argv: list) -> int:
parser = argparse.ArgumentParser(description=__doc__)
- subparsers = parser.add_subparsers(dest='subcommand', title='subcommands')
+ subparsers = parser.add_subparsers(dest="subcommand", title="subcommands")
# This method of setting required is more compatible with older python.
subparsers.required = True
# Parser for "flash" subcommand.
- parser_decrypt = subparsers.add_parser('flash', help=cmd_flash.__doc__)
- parser_decrypt.add_argument(
- 'image', nargs='?', help='Path to the firmware image')
+ parser_decrypt = subparsers.add_parser("flash", help=cmd_flash.__doc__)
+ parser_decrypt.add_argument("image", nargs="?", help="Path to the firmware image")
parser_decrypt.set_defaults(func=cmd_flash)
opts = parser.parse_args(argv)
return opts.func(opts)
-if __name__ == '__main__':
+if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
diff --git a/util/inject-keys.py b/util/inject-keys.py
index bd10b693ad..d05d4fbed7 100755
--- a/util/inject-keys.py
+++ b/util/inject-keys.py
@@ -8,50 +8,124 @@
# Note: This is a py2/3 compatible file.
from __future__ import print_function
+
import string
import subprocess
import sys
-
-KEYMATRIX = {'`': (3, 1), '1': (6, 1), '2': (6, 4), '3': (6, 2), '4': (6, 3),
- '5': (3, 3), '6': (3, 6), '7': (6, 6), '8': (6, 5), '9': (6, 9),
- '0': (6, 8), '-': (3, 8), '=': (0, 8), 'q': (7, 1), 'w': (7, 4),
- 'e': (7, 2), 'r': (7, 3), 't': (2, 3), 'y': (2, 6), 'u': (7, 6),
- 'i': (7, 5), 'o': (7, 9), 'p': (7, 8), '[': (2, 8), ']': (2, 5),
- '\\': (3, 11), 'a': (4, 1), 's': (4, 4), 'd': (4, 2), 'f': (4, 3),
- 'g': (1, 3), 'h': (1, 6), 'j': (4, 6), 'k': (4, 5), 'l': (4, 9),
- ';': (4, 8), '\'': (1, 8), 'z': (5, 1), 'x': (5, 4), 'c': (5, 2),
- 'v': (5, 3), 'b': (0, 3), 'n': (0, 6), 'm': (5, 6), ',': (5, 5),
- '.': (5, 9), '/': (5, 8), ' ': (5, 11), '<right>': (6, 12),
- '<alt_r>': (0, 10), '<down>': (6, 11), '<tab>': (2, 1),
- '<f10>': (0, 4), '<shift_r>': (7, 7), '<ctrl_r>': (4, 0),
- '<esc>': (1, 1), '<backspace>': (1, 11), '<f2>': (3, 2),
- '<alt_l>': (6, 10), '<ctrl_l>': (2, 0), '<f1>': (0, 2),
- '<search>': (0, 1), '<f3>': (2, 2), '<f4>': (1, 2), '<f5>': (3, 4),
- '<f6>': (2, 4), '<f7>': (1, 4), '<f8>': (2, 9), '<f9>': (1, 9),
- '<up>': (7, 11), '<shift_l>': (5, 7), '<enter>': (4, 11),
- '<left>': (7, 12)}
-
-
-UNSHIFT_TABLE = { '~': '`', '!': '1', '@': '2', '#': '3', '$': '4',
- '%': '5', '^': '6', '&': '7', '*': '8', '(': '9',
- ')': '0', '_': '-', '+': '=', '{': '[', '}': ']',
- '|': '\\',
- ':': ';', '"': "'", '<': ',', '>': '.', '?': '/'}
+KEYMATRIX = {
+ "`": (3, 1),
+ "1": (6, 1),
+ "2": (6, 4),
+ "3": (6, 2),
+ "4": (6, 3),
+ "5": (3, 3),
+ "6": (3, 6),
+ "7": (6, 6),
+ "8": (6, 5),
+ "9": (6, 9),
+ "0": (6, 8),
+ "-": (3, 8),
+ "=": (0, 8),
+ "q": (7, 1),
+ "w": (7, 4),
+ "e": (7, 2),
+ "r": (7, 3),
+ "t": (2, 3),
+ "y": (2, 6),
+ "u": (7, 6),
+ "i": (7, 5),
+ "o": (7, 9),
+ "p": (7, 8),
+ "[": (2, 8),
+ "]": (2, 5),
+ "\\": (3, 11),
+ "a": (4, 1),
+ "s": (4, 4),
+ "d": (4, 2),
+ "f": (4, 3),
+ "g": (1, 3),
+ "h": (1, 6),
+ "j": (4, 6),
+ "k": (4, 5),
+ "l": (4, 9),
+ ";": (4, 8),
+ "'": (1, 8),
+ "z": (5, 1),
+ "x": (5, 4),
+ "c": (5, 2),
+ "v": (5, 3),
+ "b": (0, 3),
+ "n": (0, 6),
+ "m": (5, 6),
+ ",": (5, 5),
+ ".": (5, 9),
+ "/": (5, 8),
+ " ": (5, 11),
+ "<right>": (6, 12),
+ "<alt_r>": (0, 10),
+ "<down>": (6, 11),
+ "<tab>": (2, 1),
+ "<f10>": (0, 4),
+ "<shift_r>": (7, 7),
+ "<ctrl_r>": (4, 0),
+ "<esc>": (1, 1),
+ "<backspace>": (1, 11),
+ "<f2>": (3, 2),
+ "<alt_l>": (6, 10),
+ "<ctrl_l>": (2, 0),
+ "<f1>": (0, 2),
+ "<search>": (0, 1),
+ "<f3>": (2, 2),
+ "<f4>": (1, 2),
+ "<f5>": (3, 4),
+ "<f6>": (2, 4),
+ "<f7>": (1, 4),
+ "<f8>": (2, 9),
+ "<f9>": (1, 9),
+ "<up>": (7, 11),
+ "<shift_l>": (5, 7),
+ "<enter>": (4, 11),
+ "<left>": (7, 12),
+}
+
+
+UNSHIFT_TABLE = {
+ "~": "`",
+ "!": "1",
+ "@": "2",
+ "#": "3",
+ "$": "4",
+ "%": "5",
+ "^": "6",
+ "&": "7",
+ "*": "8",
+ "(": "9",
+ ")": "0",
+ "_": "-",
+ "+": "=",
+ "{": "[",
+ "}": "]",
+ "|": "\\",
+ ":": ";",
+ '"': "'",
+ "<": ",",
+ ">": ".",
+ "?": "/",
+}
for c in string.ascii_lowercase:
UNSHIFT_TABLE[c.upper()] = c
def inject_event(key, press):
- if len(key) >= 2 and key[0] != '<':
- key = '<' + key + '>'
+ if len(key) >= 2 and key[0] != "<":
+ key = "<" + key + ">"
if key not in KEYMATRIX:
print("%s: invalid key: %s" % (this_script, key))
sys.exit(1)
(row, col) = KEYMATRIX[key]
- subprocess.call(["ectool", "kbpress", str(row), str(col),
- "1" if press else "0"])
+ subprocess.call(["ectool", "kbpress", str(row), str(col), "1" if press else "0"])
def inject_key(key):
@@ -73,8 +147,10 @@ def inject_string(string):
def usage():
- print("Usage: %s [-s <string>] [-k <key>]" % this_script,
- "[-p <pressed-key>] [-r <released-key>] ...")
+ print(
+ "Usage: %s [-s <string>] [-k <key>]" % this_script,
+ "[-p <pressed-key>] [-r <released-key>] ...",
+ )
print("Examples:")
print("%s -s MyPassw0rd -k enter" % this_script)
print("%s -p ctrl_l -p alt_l -k f3 -r alt_l -r ctrl_l" % this_script)
@@ -85,7 +161,7 @@ def help():
print("Valid keys are:")
i = 0
for key in KEYMATRIX:
- print("%12s" % key, end='')
+ print("%12s" % key, end="")
i += 1
if i % 4 == 0:
print()
@@ -114,12 +190,13 @@ usage_check(arg_len > 1, "not enough arguments")
usage_check(arg_len % 2 == 1, "mismatched arguments")
for i in range(1, arg_len, 2):
- usage_check(sys.argv[i] in ("-s", "-k", "-p", "-r"),
- "unknown flag: %s" % sys.argv[i])
+ usage_check(
+ sys.argv[i] in ("-s", "-k", "-p", "-r"), "unknown flag: %s" % sys.argv[i]
+ )
for i in range(1, arg_len, 2):
flag = sys.argv[i]
- arg = sys.argv[i+1]
+ arg = sys.argv[i + 1]
if flag == "-s":
inject_string(arg)
elif flag == "-k":
diff --git a/util/kconfig_check.py b/util/kconfig_check.py
index d1eba8e62b..04cdf9a990 100755
--- a/util/kconfig_check.py
+++ b/util/kconfig_check.py
@@ -32,12 +32,13 @@ import sys
USE_KCONFIGLIB = False
try:
import kconfiglib
+
USE_KCONFIGLIB = True
except ImportError:
pass
# Where we put the new config_allowed file
-NEW_ALLOWED_FNAME = pathlib.Path('/tmp/new_config_allowed.txt')
+NEW_ALLOWED_FNAME = pathlib.Path("/tmp/new_config_allowed.txt")
def parse_args(argv):
@@ -49,38 +50,72 @@ def parse_args(argv):
Returns:
argparse.Namespace object containing the results
"""
- epilog = '''Checks that new ad-hoc CONFIG options are not introduced without
-a corresponding Kconfig option for Zephyr'''
+ epilog = """Checks that new ad-hoc CONFIG options are not introduced without
+a corresponding Kconfig option for Zephyr"""
parser = argparse.ArgumentParser(epilog=epilog)
- parser.add_argument('-a', '--allowed', type=str,
- default='util/config_allowed.txt',
- help='File containing list of allowed ad-hoc CONFIGs')
- parser.add_argument('-c', '--configs', type=str, default='.config',
- help='File containing CONFIG options to check')
- parser.add_argument('-d', '--use-defines', action='store_true',
- help='Lines in the configs file use #define')
parser.add_argument(
- '-D', '--debug', action='store_true',
- help='Enabling debugging (provides a full traceback on error)')
+ "-a",
+ "--allowed",
+ type=str,
+ default="util/config_allowed.txt",
+ help="File containing list of allowed ad-hoc CONFIGs",
+ )
+ parser.add_argument(
+ "-c",
+ "--configs",
+ type=str,
+ default=".config",
+ help="File containing CONFIG options to check",
+ )
+ parser.add_argument(
+ "-d",
+ "--use-defines",
+ action="store_true",
+ help="Lines in the configs file use #define",
+ )
+ parser.add_argument(
+ "-D",
+ "--debug",
+ action="store_true",
+ help="Enabling debugging (provides a full traceback on error)",
+ )
+ parser.add_argument(
+ "-i",
+ "--ignore",
+ action="append",
+ help="Kconfig options to ignore (without CONFIG_ prefix)",
+ )
parser.add_argument(
- '-i', '--ignore', action='append',
- help='Kconfig options to ignore (without CONFIG_ prefix)')
- parser.add_argument('-I', '--search-path', type=str, action='append',
- help='Search paths to look for Kconfigs')
- parser.add_argument('-p', '--prefix', type=str, default='PLATFORM_EC_',
- help='Prefix to string from Kconfig options')
- parser.add_argument('-s', '--srctree', type=str, default='zephyr/',
- help='Path to source tree to look for Kconfigs')
+ "-I",
+ "--search-path",
+ type=str,
+ action="append",
+ help="Search paths to look for Kconfigs",
+ )
+ parser.add_argument(
+ "-p",
+ "--prefix",
+ type=str,
+ default="PLATFORM_EC_",
+ help="Prefix to string from Kconfig options",
+ )
+ parser.add_argument(
+ "-s",
+ "--srctree",
+ type=str,
+ default="zephyr/",
+ help="Path to source tree to look for Kconfigs",
+ )
# TODO(sjg@chromium.org): The chroot uses a very old Python. Once it moves
# to 3.7 or later we can use this instead:
# subparsers = parser.add_subparsers(dest='cmd', required=True)
- subparsers = parser.add_subparsers(dest='cmd')
+ subparsers = parser.add_subparsers(dest="cmd")
subparsers.required = True
- subparsers.add_parser('build', help='Build new list of ad-hoc CONFIGs')
- subparsers.add_parser('check', help='Check for new ad-hoc CONFIGs')
+ subparsers.add_parser("build", help="Build new list of ad-hoc CONFIGs")
+ subparsers.add_parser("check", help="Check for new ad-hoc CONFIGs")
return parser.parse_args(argv)
@@ -107,6 +142,7 @@ class KconfigCheck:
the user is exhorted to add a new Kconfig. This helps avoid adding new ad-hoc
CONFIG options, eventually returning the number to zero.
"""
+
@classmethod
def find_new_adhoc(cls, configs, kconfigs, allowed):
"""Get a list of new ad-hoc CONFIG options
@@ -172,11 +208,12 @@ class KconfigCheck:
List of CONFIG_xxx options found in the file, with the 'CONFIG_'
prefix removed
"""
- with open(configs_file, 'r') as inf:
- configs = re.findall('%sCONFIG_([A-Za-z0-9_]*)%s' %
- ((use_defines and '#define ' or ''),
- (use_defines and ' ' or '')),
- inf.read())
+ with open(configs_file, "r") as inf:
+ configs = re.findall(
+ "%sCONFIG_([A-Za-z0-9_]*)%s"
+ % ((use_defines and "#define " or ""), (use_defines and " " or "")),
+ inf.read(),
+ )
return configs
@classmethod
@@ -190,8 +227,8 @@ class KconfigCheck:
List of CONFIG_xxx options found in the file, with the 'CONFIG_'
prefix removed
"""
- with open(allowed_file, 'r') as inf:
- configs = re.findall('CONFIG_([A-Za-z0-9_]*)', inf.read())
+ with open(allowed_file, "r") as inf:
+ configs = re.findall("CONFIG_([A-Za-z0-9_]*)", inf.read())
return configs
@classmethod
@@ -209,15 +246,17 @@ class KconfigCheck:
"""
kconfig_files = []
for root, dirs, files in os.walk(srcdir):
- kconfig_files += [os.path.join(root, fname)
- for fname in files if fname.startswith('Kconfig')]
- if 'Kconfig' in dirs:
- dirs.remove('Kconfig')
+ kconfig_files += [
+ os.path.join(root, fname)
+ for fname in files
+ if fname.startswith("Kconfig")
+ ]
+ if "Kconfig" in dirs:
+ dirs.remove("Kconfig")
return kconfig_files
@classmethod
- def scan_kconfigs(cls, srcdir, prefix='', search_paths=None,
- try_kconfiglib=True):
+ def scan_kconfigs(cls, srcdir, prefix="", search_paths=None, try_kconfiglib=True):
"""Scan a source tree for Kconfig options
Args:
@@ -231,31 +270,40 @@ class KconfigCheck:
List of config and menuconfig options found
"""
if USE_KCONFIGLIB and try_kconfiglib:
- os.environ['srctree'] = srcdir
- kconf = kconfiglib.Kconfig('Kconfig', warn=False,
- search_paths=search_paths,
- allow_empty_macros=True)
+ os.environ["srctree"] = srcdir
+ kconf = kconfiglib.Kconfig(
+ "Kconfig",
+ warn=False,
+ search_paths=search_paths,
+ allow_empty_macros=True,
+ )
# There is always a MODULES config, since kconfiglib is designed for
# linux, but we don't want it
- kconfigs = [name for name in kconf.syms if name != 'MODULES']
+ kconfigs = [name for name in kconf.syms if name != "MODULES"]
if prefix:
- re_drop_prefix = re.compile(r'^%s' % prefix)
- kconfigs = [re_drop_prefix.sub('', name) for name in kconfigs]
+ re_drop_prefix = re.compile(r"^%s" % prefix)
+ kconfigs = [re_drop_prefix.sub("", name) for name in kconfigs]
else:
kconfigs = []
# Remove the prefix if present
- expr = re.compile(r'\n(config|menuconfig) (%s)?([A-Za-z0-9_]*)\n' %
- prefix)
+ expr = re.compile(r"\n(config|menuconfig) (%s)?([A-Za-z0-9_]*)\n" % prefix)
for fname in cls.find_kconfigs(srcdir):
with open(fname) as inf:
found = re.findall(expr, inf.read())
kconfigs += [name for kctype, _, name in found]
return sorted(kconfigs)
- def check_adhoc_configs(self, configs_file, srcdir, allowed_file,
- prefix='', use_defines=False, search_paths=None):
+ def check_adhoc_configs(
+ self,
+ configs_file,
+ srcdir,
+ allowed_file,
+ prefix="",
+ use_defines=False,
+ search_paths=None,
+ ):
"""Find new and unneeded ad-hoc configs in the configs_file
Args:
@@ -283,8 +331,9 @@ class KconfigCheck:
except kconfiglib.KconfigError:
# If we don't actually have access to the full Kconfig then we may
# get an error. Fall back to using manual methods.
- kconfigs = self.scan_kconfigs(srcdir, prefix, search_paths,
- try_kconfiglib=False)
+ kconfigs = self.scan_kconfigs(
+ srcdir, prefix, search_paths, try_kconfiglib=False
+ )
allowed = self.read_allowed(allowed_file)
new_adhoc = self.find_new_adhoc(configs, kconfigs, allowed)
@@ -292,8 +341,16 @@ class KconfigCheck:
updated_adhoc = self.get_updated_adhoc(unneeded_adhoc, allowed)
return new_adhoc, unneeded_adhoc, updated_adhoc
- def do_check(self, configs_file, srcdir, allowed_file, prefix, use_defines,
- search_paths, ignore=None):
+ def do_check(
+ self,
+ configs_file,
+ srcdir,
+ allowed_file,
+ prefix,
+ use_defines,
+ search_paths,
+ ignore=None,
+ ):
"""Find new ad-hoc configs in the configs_file
Args:
@@ -313,11 +370,12 @@ class KconfigCheck:
Exit code: 0 if OK, 1 if a problem was found
"""
new_adhoc, unneeded_adhoc, updated_adhoc = self.check_adhoc_configs(
- configs_file, srcdir, allowed_file, prefix, use_defines,
- search_paths)
+ configs_file, srcdir, allowed_file, prefix, use_defines, search_paths
+ )
if new_adhoc:
- file_list = '\n'.join(['CONFIG_%s' % name for name in new_adhoc])
- print(f'''Error:\tThe EC is in the process of migrating to Zephyr.
+ file_list = "\n".join(["CONFIG_%s" % name for name in new_adhoc])
+ print(
+ f"""Error:\tThe EC is in the process of migrating to Zephyr.
\tZephyr uses Kconfig for configuration rather than ad-hoc #defines.
\tAny new EC CONFIG options must ALSO be added to Zephyr so that new
\tfunctionality is available in Zephyr also. The following new ad-hoc
@@ -330,19 +388,21 @@ file in zephyr/ and add a 'config' or 'menuconfig' option.
Also see details in http://issuetracker.google.com/181253613
To temporarily disable this, use: ALLOW_CONFIG=1 make ...
-''', file=sys.stderr)
+""",
+ file=sys.stderr,
+ )
return 1
if not ignore:
ignore = []
unneeded_adhoc = [name for name in unneeded_adhoc if name not in ignore]
if unneeded_adhoc:
- with open(NEW_ALLOWED_FNAME, 'w') as out:
+ with open(NEW_ALLOWED_FNAME, "w") as out:
for config in updated_adhoc:
- print('CONFIG_%s' % config, file=out)
- now_in_kconfig = '\n'.join(
- ['CONFIG_%s' % name for name in unneeded_adhoc])
- print(f'''The following options are now in Kconfig:
+ print("CONFIG_%s" % config, file=out)
+ now_in_kconfig = "\n".join(["CONFIG_%s" % name for name in unneeded_adhoc])
+ print(
+ f"""The following options are now in Kconfig:
{now_in_kconfig}
@@ -350,12 +410,14 @@ Please run this to update the list of allowed ad-hoc CONFIGs and include this
update in your CL:
cp {NEW_ALLOWED_FNAME} util/config_allowed.txt
-''')
+"""
+ )
return 1
return 0
- def do_build(self, configs_file, srcdir, allowed_file, prefix, use_defines,
- search_paths):
+ def do_build(
+ self, configs_file, srcdir, allowed_file, prefix, use_defines, search_paths
+ ):
"""Find new ad-hoc configs in the configs_file
Args:
@@ -372,13 +434,14 @@ update in your CL:
Exit code: 0 if OK, 1 if a problem was found
"""
new_adhoc, _, updated_adhoc = self.check_adhoc_configs(
- configs_file, srcdir, allowed_file, prefix, use_defines,
- search_paths)
- with open(NEW_ALLOWED_FNAME, 'w') as out:
+ configs_file, srcdir, allowed_file, prefix, use_defines, search_paths
+ )
+ with open(NEW_ALLOWED_FNAME, "w") as out:
combined = sorted(new_adhoc + updated_adhoc)
for config in combined:
- print(f'CONFIG_{config}', file=out)
- print(f'New list is in {NEW_ALLOWED_FNAME}')
+ print(f"CONFIG_{config}", file=out)
+ print(f"New list is in {NEW_ALLOWED_FNAME}")
+
def main(argv):
"""Main function"""
@@ -386,18 +449,27 @@ def main(argv):
if not args.debug:
sys.tracebacklimit = 0
checker = KconfigCheck()
- if args.cmd == 'check':
+ if args.cmd == "check":
return checker.do_check(
- configs_file=args.configs, srcdir=args.srctree,
- allowed_file=args.allowed, prefix=args.prefix,
- use_defines=args.use_defines, search_paths=args.search_path,
- ignore=args.ignore)
- elif args.cmd == 'build':
- return checker.do_build(configs_file=args.configs, srcdir=args.srctree,
- allowed_file=args.allowed, prefix=args.prefix,
- use_defines=args.use_defines, search_paths=args.search_path)
+ configs_file=args.configs,
+ srcdir=args.srctree,
+ allowed_file=args.allowed,
+ prefix=args.prefix,
+ use_defines=args.use_defines,
+ search_paths=args.search_path,
+ ignore=args.ignore,
+ )
+ elif args.cmd == "build":
+ return checker.do_build(
+ configs_file=args.configs,
+ srcdir=args.srctree,
+ allowed_file=args.allowed,
+ prefix=args.prefix,
+ use_defines=args.use_defines,
+ search_paths=args.search_path,
+ )
return 2
-if __name__ == '__main__':
+if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
diff --git a/util/kconfiglib.py b/util/kconfiglib.py
index 0e05aaaeac..a0033bba2d 100644
--- a/util/kconfiglib.py
+++ b/util/kconfiglib.py
@@ -553,7 +553,6 @@ import sys
from glob import iglob
from os.path import dirname, exists, expandvars, islink, join, realpath
-
VERSION = (14, 1, 0)
@@ -810,6 +809,7 @@ class Kconfig(object):
The current parsing location, for use in Python preprocessor functions.
See the module docstring.
"""
+
__slots__ = (
"_encoding",
"_functions",
@@ -848,7 +848,6 @@ class Kconfig(object):
"warn_to_stderr",
"warnings",
"y",
-
# Parsing-related
"_parsing_kconfigs",
"_readline",
@@ -866,9 +865,16 @@ class Kconfig(object):
# Public interface
#
- def __init__(self, filename="Kconfig", warn=True, warn_to_stderr=True,
- encoding="utf-8", suppress_traceback=False, search_paths=None,
- allow_empty_macros=False):
+ def __init__(
+ self,
+ filename="Kconfig",
+ warn=True,
+ warn_to_stderr=True,
+ encoding="utf-8",
+ suppress_traceback=False,
+ search_paths=None,
+ allow_empty_macros=False,
+ ):
"""
Creates a new Kconfig object by parsing Kconfig files.
Note that Kconfig files are not the same as .config files (which store
@@ -972,8 +978,14 @@ class Kconfig(object):
Pass True here to allow empty / undefined macros.
"""
try:
- self._init(filename, warn, warn_to_stderr, encoding, search_paths,
- allow_empty_macros)
+ self._init(
+ filename,
+ warn,
+ warn_to_stderr,
+ encoding,
+ search_paths,
+ allow_empty_macros,
+ )
except (EnvironmentError, KconfigError) as e:
if suppress_traceback:
cmd = sys.argv[0] # Empty string if missing
@@ -985,8 +997,9 @@ class Kconfig(object):
sys.exit(cmd + str(e).strip())
raise
- def _init(self, filename, warn, warn_to_stderr, encoding, search_paths,
- allow_empty_macros):
+ def _init(
+ self, filename, warn, warn_to_stderr, encoding, search_paths, allow_empty_macros
+ ):
# See __init__()
self._encoding = encoding
@@ -1011,8 +1024,9 @@ class Kconfig(object):
self.config_prefix = os.getenv("CONFIG_", "CONFIG_")
# Regular expressions for parsing .config files
self._set_match = _re_match(self.config_prefix + r"([^=]+)=(.*)")
- self._unset_match = _re_match(r"# {}([^ ]+) is not set".format(
- self.config_prefix))
+ self._unset_match = _re_match(
+ r"# {}([^ ]+) is not set".format(self.config_prefix)
+ )
self.config_header = os.getenv("KCONFIG_CONFIG_HEADER", "")
self.header_header = os.getenv("KCONFIG_AUTOHEADER_HEADER", "")
@@ -1050,11 +1064,11 @@ class Kconfig(object):
# Predefined preprocessor functions, with min/max number of arguments
self._functions = {
- "info": (_info_fn, 1, 1),
- "error-if": (_error_if_fn, 2, 2),
- "filename": (_filename_fn, 0, 0),
- "lineno": (_lineno_fn, 0, 0),
- "shell": (_shell_fn, 1, 1),
+ "info": (_info_fn, 1, 1),
+ "error-if": (_error_if_fn, 2, 2),
+ "filename": (_filename_fn, 0, 0),
+ "lineno": (_lineno_fn, 0, 0),
+ "shell": (_shell_fn, 1, 1),
"warning-if": (_warning_if_fn, 2, 2),
}
@@ -1063,7 +1077,8 @@ class Kconfig(object):
self._functions.update(
importlib.import_module(
os.getenv("KCONFIG_FUNCTIONS", "kconfigfunctions")
- ).functions)
+ ).functions
+ )
except ImportError:
pass
@@ -1138,8 +1153,7 @@ class Kconfig(object):
# KCONFIG_STRICT is an older alias for KCONFIG_WARN_UNDEF, supported
# for backwards compatibility
- if os.getenv("KCONFIG_WARN_UNDEF") == "y" or \
- os.getenv("KCONFIG_STRICT") == "y":
+ if os.getenv("KCONFIG_WARN_UNDEF") == "y" or os.getenv("KCONFIG_STRICT") == "y":
self._check_undef_syms()
@@ -1247,15 +1261,14 @@ class Kconfig(object):
msg = None
if filename is None:
filename = standard_config_filename()
- if not exists(filename) and \
- not exists(join(self.srctree, filename)):
+ if not exists(filename) and not exists(join(self.srctree, filename)):
defconfig = self.defconfig_filename
if defconfig is None:
- return "Using default symbol values (no '{}')" \
- .format(filename)
+ return "Using default symbol values (no '{}')".format(filename)
- msg = " default configuration '{}' (no '{}')" \
- .format(defconfig, filename)
+ msg = " default configuration '{}' (no '{}')".format(
+ defconfig, filename
+ )
filename = defconfig
if not msg:
@@ -1313,15 +1326,20 @@ class Kconfig(object):
if sym.orig_type in _BOOL_TRISTATE:
# The C implementation only checks the first character
# to the right of '=', for whatever reason
- if not (sym.orig_type is BOOL
- and val.startswith(("y", "n")) or
- sym.orig_type is TRISTATE
- and val.startswith(("y", "m", "n"))):
- self._warn("'{}' is not a valid value for the {} "
- "symbol {}. Assignment ignored."
- .format(val, TYPE_TO_STR[sym.orig_type],
- sym.name_and_loc),
- filename, linenr)
+ if not (
+ sym.orig_type is BOOL
+ and val.startswith(("y", "n"))
+ or sym.orig_type is TRISTATE
+ and val.startswith(("y", "m", "n"))
+ ):
+ self._warn(
+ "'{}' is not a valid value for the {} "
+ "symbol {}. Assignment ignored.".format(
+ val, TYPE_TO_STR[sym.orig_type], sym.name_and_loc
+ ),
+ filename,
+ linenr,
+ )
continue
val = val[0]
@@ -1332,12 +1350,14 @@ class Kconfig(object):
# to the choice symbols
prev_mode = sym.choice.user_value
- if prev_mode is not None and \
- TRI_TO_STR[prev_mode] != val:
+ if prev_mode is not None and TRI_TO_STR[prev_mode] != val:
- self._warn("both m and y assigned to symbols "
- "within the same choice",
- filename, linenr)
+ self._warn(
+ "both m and y assigned to symbols "
+ "within the same choice",
+ filename,
+ linenr,
+ )
# Set the choice's mode
sym.choice.set_value(val)
@@ -1345,10 +1365,14 @@ class Kconfig(object):
elif sym.orig_type is STRING:
match = _conf_string_match(val)
if not match:
- self._warn("malformed string literal in "
- "assignment to {}. Assignment ignored."
- .format(sym.name_and_loc),
- filename, linenr)
+ self._warn(
+ "malformed string literal in "
+ "assignment to {}. Assignment ignored.".format(
+ sym.name_and_loc
+ ),
+ filename,
+ linenr,
+ )
continue
val = unescape(match.group(1))
@@ -1361,9 +1385,11 @@ class Kconfig(object):
# lines or comments. 'line' has already been
# rstrip()'d, so blank lines show up as "" here.
if line and not line.lstrip().startswith("#"):
- self._warn("ignoring malformed line '{}'"
- .format(line),
- filename, linenr)
+ self._warn(
+ "ignoring malformed line '{}'".format(line),
+ filename,
+ linenr,
+ )
continue
@@ -1403,8 +1429,12 @@ class Kconfig(object):
self.missing_syms.append((name, val))
if self.warn_assign_undef:
self._warn(
- "attempt to assign the value '{}' to the undefined symbol {}"
- .format(val, name), filename, linenr)
+ "attempt to assign the value '{}' to the undefined symbol {}".format(
+ val, name
+ ),
+ filename,
+ linenr,
+ )
def _assigned_twice(self, sym, new_val, filename, linenr):
# Called when a symbol is assigned more than once in a .config file
@@ -1416,7 +1446,8 @@ class Kconfig(object):
user_val = sym.user_value
msg = '{} set more than once. Old value "{}", new value "{}".'.format(
- sym.name_and_loc, user_val, new_val)
+ sym.name_and_loc, user_val, new_val
+ )
if user_val == new_val:
if self.warn_assign_redun:
@@ -1482,8 +1513,7 @@ class Kconfig(object):
in tools, which can do e.g. print(kconf.write_autoconf()).
"""
if filename is None:
- filename = os.getenv("KCONFIG_AUTOHEADER",
- "include/generated/autoconf.h")
+ filename = os.getenv("KCONFIG_AUTOHEADER", "include/generated/autoconf.h")
if self._write_if_changed(filename, self._autoconf_contents(header)):
return "Kconfig header saved to '{}'".format(filename)
@@ -1512,28 +1542,26 @@ class Kconfig(object):
if sym.orig_type in _BOOL_TRISTATE:
if val == "y":
- add("#define {}{} 1\n"
- .format(self.config_prefix, sym.name))
+ add("#define {}{} 1\n".format(self.config_prefix, sym.name))
elif val == "m":
- add("#define {}{}_MODULE 1\n"
- .format(self.config_prefix, sym.name))
+ add("#define {}{}_MODULE 1\n".format(self.config_prefix, sym.name))
elif sym.orig_type is STRING:
- add('#define {}{} "{}"\n'
- .format(self.config_prefix, sym.name, escape(val)))
+ add(
+ '#define {}{} "{}"\n'.format(
+ self.config_prefix, sym.name, escape(val)
+ )
+ )
else: # sym.orig_type in _INT_HEX:
- if sym.orig_type is HEX and \
- not val.startswith(("0x", "0X")):
+ if sym.orig_type is HEX and not val.startswith(("0x", "0X")):
val = "0x" + val
- add("#define {}{} {}\n"
- .format(self.config_prefix, sym.name, val))
+ add("#define {}{} {}\n".format(self.config_prefix, sym.name, val))
return "".join(chunks)
- def write_config(self, filename=None, header=None, save_old=True,
- verbose=None):
+ def write_config(self, filename=None, header=None, save_old=True, verbose=None):
r"""
Writes out symbol values in the .config format. The format matches the
C implementation, including ordering.
@@ -1647,9 +1675,12 @@ class Kconfig(object):
node = node.parent
# Add a comment when leaving visible menus
- if node.item is MENU and expr_value(node.dep) and \
- expr_value(node.visibility) and \
- node is not self.top_node:
+ if (
+ node.item is MENU
+ and expr_value(node.dep)
+ and expr_value(node.visibility)
+ and node is not self.top_node
+ ):
add("# end of {}\n".format(node.prompt[0]))
after_end_comment = True
@@ -1680,9 +1711,9 @@ class Kconfig(object):
add("\n")
add(conf_string)
- elif expr_value(node.dep) and \
- ((item is MENU and expr_value(node.visibility)) or
- item is COMMENT):
+ elif expr_value(node.dep) and (
+ (item is MENU and expr_value(node.visibility)) or item is COMMENT
+ ):
add("\n#\n# {}\n#\n".format(node.prompt[0]))
after_end_comment = False
@@ -1738,8 +1769,7 @@ class Kconfig(object):
# Skip symbols that cannot be changed. Only check
# non-choice symbols, as selects don't affect choice
# symbols.
- if not sym.choice and \
- sym.visibility <= expr_value(sym.rev_dep):
+ if not sym.choice and sym.visibility <= expr_value(sym.rev_dep):
continue
# Skip symbols whose value matches their default
@@ -1750,11 +1780,13 @@ class Kconfig(object):
# choice, unless the choice is optional or the symbol type
# isn't bool (it might be possible to set the choice mode
# to n or the symbol to m in those cases).
- if sym.choice and \
- not sym.choice.is_optional and \
- sym.choice._selection_from_defaults() is sym and \
- sym.orig_type is BOOL and \
- sym.tri_value == 2:
+ if (
+ sym.choice
+ and not sym.choice.is_optional
+ and sym.choice._selection_from_defaults() is sym
+ and sym.orig_type is BOOL
+ and sym.tri_value == 2
+ ):
continue
add(sym.config_string)
@@ -1842,9 +1874,11 @@ class Kconfig(object):
# making a missing symbol logically equivalent to n
if sym._write_to_conf:
- if sym._old_val is None and \
- sym.orig_type in _BOOL_TRISTATE and \
- val == "n":
+ if (
+ sym._old_val is None
+ and sym.orig_type in _BOOL_TRISTATE
+ and val == "n"
+ ):
# No old value (the symbol was missing or n), new value n.
# No change.
continue
@@ -1924,17 +1958,20 @@ class Kconfig(object):
# by passing a flag to it, plus we only need to look at symbols here.
self._write_if_changed(
- os.path.join(path, "auto.conf"),
- self._old_vals_contents())
+ os.path.join(path, "auto.conf"), self._old_vals_contents()
+ )
def _old_vals_contents(self):
# _write_old_vals() helper. Returns the contents to write as a string.
# Temporary list instead of generator makes this a bit faster
- return "".join([
- sym.config_string for sym in self.unique_defined_syms
+ return "".join(
+ [
+ sym.config_string
+ for sym in self.unique_defined_syms
if not (sym.orig_type in _BOOL_TRISTATE and not sym.tri_value)
- ])
+ ]
+ )
def node_iter(self, unique_syms=False):
"""
@@ -2112,30 +2149,35 @@ class Kconfig(object):
Returns a string with information about the Kconfig object when it is
evaluated on e.g. the interactive Python prompt.
"""
+
def status(flag):
return "enabled" if flag else "disabled"
- return "<{}>".format(", ".join((
- "configuration with {} symbols".format(len(self.syms)),
- 'main menu prompt "{}"'.format(self.mainmenu_text),
- "srctree is current directory" if not self.srctree else
- 'srctree "{}"'.format(self.srctree),
- 'config symbol prefix "{}"'.format(self.config_prefix),
- "warnings " + status(self.warn),
- "printing of warnings to stderr " + status(self.warn_to_stderr),
- "undef. symbol assignment warnings " +
- status(self.warn_assign_undef),
- "overriding symbol assignment warnings " +
- status(self.warn_assign_override),
- "redundant symbol assignment warnings " +
- status(self.warn_assign_redun)
- )))
+ return "<{}>".format(
+ ", ".join(
+ (
+ "configuration with {} symbols".format(len(self.syms)),
+ 'main menu prompt "{}"'.format(self.mainmenu_text),
+ "srctree is current directory"
+ if not self.srctree
+ else 'srctree "{}"'.format(self.srctree),
+ 'config symbol prefix "{}"'.format(self.config_prefix),
+ "warnings " + status(self.warn),
+ "printing of warnings to stderr " + status(self.warn_to_stderr),
+ "undef. symbol assignment warnings "
+ + status(self.warn_assign_undef),
+ "overriding symbol assignment warnings "
+ + status(self.warn_assign_override),
+ "redundant symbol assignment warnings "
+ + status(self.warn_assign_redun),
+ )
+ )
+ )
#
# Private methods
#
-
#
# File reading
#
@@ -2160,11 +2202,17 @@ class Kconfig(object):
e = e2
raise _KconfigIOError(
- e, "Could not open '{}' ({}: {}). Check that the $srctree "
- "environment variable ({}) is set correctly."
- .format(filename, errno.errorcode[e.errno], e.strerror,
- "set to '{}'".format(self.srctree) if self.srctree
- else "unset or blank"))
+ e,
+ "Could not open '{}' ({}: {}). Check that the $srctree "
+ "environment variable ({}) is set correctly.".format(
+ filename,
+ errno.errorcode[e.errno],
+ e.strerror,
+ "set to '{}'".format(self.srctree)
+ if self.srctree
+ else "unset or blank",
+ ),
+ )
def _enter_file(self, filename):
# Jumps to the beginning of a sourced Kconfig file, saving the previous
@@ -2179,7 +2227,7 @@ class Kconfig(object):
if filename.startswith(self._srctree_prefix):
# Relative path (or a redundant absolute path to within $srctree,
# but it's probably fine to reduce those too)
- rel_filename = filename[len(self._srctree_prefix):]
+ rel_filename = filename[len(self._srctree_prefix) :]
else:
# Absolute path
rel_filename = filename
@@ -2212,20 +2260,32 @@ class Kconfig(object):
raise KconfigError(
"\n{}:{}: recursive 'source' of '{}' detected. Check that "
"environment variables are set correctly.\n"
- "Include path:\n{}"
- .format(self.filename, self.linenr, rel_filename,
- "\n".join("{}:{}".format(name, linenr)
- for name, linenr in self._include_path)))
+ "Include path:\n{}".format(
+ self.filename,
+ self.linenr,
+ rel_filename,
+ "\n".join(
+ "{}:{}".format(name, linenr)
+ for name, linenr in self._include_path
+ ),
+ )
+ )
try:
self._readline = self._open(filename, "r").readline
except EnvironmentError as e:
# We already know that the file exists
raise _KconfigIOError(
- e, "{}:{}: Could not open '{}' (in '{}') ({}: {})"
- .format(self.filename, self.linenr, filename,
- self._line.strip(),
- errno.errorcode[e.errno], e.strerror))
+ e,
+ "{}:{}: Could not open '{}' (in '{}') ({}: {})".format(
+ self.filename,
+ self.linenr,
+ filename,
+ self._line.strip(),
+ errno.errorcode[e.errno],
+ e.strerror,
+ ),
+ )
self.filename = rel_filename
self.linenr = 0
@@ -2438,8 +2498,11 @@ class Kconfig(object):
else:
i = match.end()
- token = self.const_syms[name] if name in STR_TO_TRI else \
- self._lookup_sym(name)
+ token = (
+ self.const_syms[name]
+ if name in STR_TO_TRI
+ else self._lookup_sym(name)
+ )
else:
# It's a case of missing quotes. For example, the
@@ -2455,9 +2518,13 @@ class Kconfig(object):
# Named choices ('choice FOO') also end up here.
if token is not _T_CHOICE:
- self._warn("style: quotes recommended around '{}' in '{}'"
- .format(name, self._line.strip()),
- self.filename, self.linenr)
+ self._warn(
+ "style: quotes recommended around '{}' in '{}'".format(
+ name, self._line.strip()
+ ),
+ self.filename,
+ self.linenr,
+ )
token = name
i = match.end()
@@ -2476,7 +2543,7 @@ class Kconfig(object):
end_i = s.find(c, i + 1) + 1
if not end_i:
self._parse_error("unterminated string")
- val = s[i + 1:end_i - 1]
+ val = s[i + 1 : end_i - 1]
i = end_i
else:
# Slow path
@@ -2489,18 +2556,22 @@ class Kconfig(object):
#
# The preprocessor functionality changed how
# environment variables are referenced, to $(FOO).
- val = expandvars(s[i + 1:end_i - 1]
- .replace("$UNAME_RELEASE",
- _UNAME_RELEASE))
+ val = expandvars(
+ s[i + 1 : end_i - 1].replace(
+ "$UNAME_RELEASE", _UNAME_RELEASE
+ )
+ )
i = end_i
# This is the only place where we don't survive with a
# single token of lookback: 'option env="FOO"' does not
# refer to a constant symbol named "FOO".
- token = \
- val if token in _STRING_LEX or tokens[0] is _T_OPTION \
+ token = (
+ val
+ if token in _STRING_LEX or tokens[0] is _T_OPTION
else self._lookup_const_sym(val)
+ )
elif s.startswith("&&", i):
token = _T_AND
@@ -2533,7 +2604,6 @@ class Kconfig(object):
elif c == "#":
break
-
# Very rare
elif s.startswith("<=", i):
@@ -2552,16 +2622,13 @@ class Kconfig(object):
token = _T_GREATER
i += 1
-
else:
self._parse_error("unknown tokens in line")
-
# Skip trailing whitespace
while i < len(s) and s[i].isspace():
i += 1
-
# Add the token
tokens.append(token)
@@ -2652,7 +2719,6 @@ class Kconfig(object):
# Assigned variable
name = s[:i]
-
# Extract assignment operator (=, :=, or +=) and value
rhs_match = _assignment_rhs_match(s, i)
if not rhs_match:
@@ -2660,7 +2726,6 @@ class Kconfig(object):
op, val = rhs_match.groups()
-
if name in self.variables:
# Already seen variable
var = self.variables[name]
@@ -2686,8 +2751,9 @@ class Kconfig(object):
else: # op == "+="
# += does immediate expansion if the variable was last set
# with :=
- var.value += " " + (val if var.is_recursive else
- self._expand_whole(val, ()))
+ var.value += " " + (
+ val if var.is_recursive else self._expand_whole(val, ())
+ )
def _expand_whole(self, s, args):
# Expands preprocessor macros in all of 's'. Used whenever we don't
@@ -2753,7 +2819,6 @@ class Kconfig(object):
if not match:
self._parse_error("unterminated string")
-
if match.group() == quote:
# Found the end of the string
return (s, match.end())
@@ -2762,7 +2827,7 @@ class Kconfig(object):
# Replace '\x' with 'x'. 'i' ends up pointing to the character
# after 'x', which allows macros to be canceled with '\$(foo)'.
i = match.end()
- s = s[:match.start()] + s[i:]
+ s = s[: match.start()] + s[i:]
elif match.group() == "$(":
# A macro call within the string
@@ -2792,7 +2857,6 @@ class Kconfig(object):
if not match:
self._parse_error("missing end parenthesis in macro expansion")
-
if match.group() == "(":
nesting += 1
i = match.end()
@@ -2805,7 +2869,7 @@ class Kconfig(object):
# Found the end of the macro
- new_args.append(s[arg_start:match.start()])
+ new_args.append(s[arg_start : match.start()])
# $(1) is replaced by the first argument to the function, etc.,
# provided at least that many arguments were passed
@@ -2819,7 +2883,7 @@ class Kconfig(object):
# and also go through the function value path
res += self._fn_val(new_args)
- return (res + s[match.end():], len(res))
+ return (res + s[match.end() :], len(res))
elif match.group() == ",":
i = match.end()
@@ -2827,7 +2891,7 @@ class Kconfig(object):
continue
# Found the end of a macro argument
- new_args.append(s[arg_start:match.start()])
+ new_args.append(s[arg_start : match.start()])
arg_start = i
else: # match.group() == "$("
@@ -2847,13 +2911,17 @@ class Kconfig(object):
if len(args) == 1:
# Plain variable
if var._n_expansions:
- self._parse_error("Preprocessor variable {} recursively "
- "references itself".format(var.name))
+ self._parse_error(
+ "Preprocessor variable {} recursively "
+ "references itself".format(var.name)
+ )
elif var._n_expansions > 100:
# Allow functions to call themselves, but guess that functions
# that are overly recursive are stuck
- self._parse_error("Preprocessor function {} seems stuck "
- "in infinite recursion".format(var.name))
+ self._parse_error(
+ "Preprocessor function {} seems stuck "
+ "in infinite recursion".format(var.name)
+ )
var._n_expansions += 1
res = self._expand_whole(self.variables[fn].value, args)
@@ -2865,8 +2933,9 @@ class Kconfig(object):
py_fn, min_arg, max_arg = self._functions[fn]
- if len(args) - 1 < min_arg or \
- (max_arg is not None and len(args) - 1 > max_arg):
+ if len(args) - 1 < min_arg or (
+ max_arg is not None and len(args) - 1 > max_arg
+ ):
if min_arg == max_arg:
expected_args = min_arg
@@ -2875,10 +2944,12 @@ class Kconfig(object):
else:
expected_args = "{}-{}".format(min_arg, max_arg)
- raise KconfigError("{}:{}: bad number of arguments in call "
- "to {}, expected {}, got {}"
- .format(self.filename, self.linenr, fn,
- expected_args, len(args) - 1))
+ raise KconfigError(
+ "{}:{}: bad number of arguments in call "
+ "to {}, expected {}, got {}".format(
+ self.filename, self.linenr, fn, expected_args, len(args) - 1
+ )
+ )
return py_fn(self, *args)
@@ -2962,7 +3033,7 @@ class Kconfig(object):
node = MenuNode()
node.kconfig = self
node.item = sym
- node.is_menuconfig = (t0 is _T_MENUCONFIG)
+ node.is_menuconfig = t0 is _T_MENUCONFIG
node.prompt = node.help = node.list = None
node.parent = parent
node.filename = self.filename
@@ -2974,8 +3045,11 @@ class Kconfig(object):
self._parse_props(node)
if node.is_menuconfig and not node.prompt:
- self._warn("the menuconfig symbol {} has no prompt"
- .format(sym.name_and_loc))
+ self._warn(
+ "the menuconfig symbol {} has no prompt".format(
+ sym.name_and_loc
+ )
+ )
# Equivalent to
#
@@ -3014,11 +3088,16 @@ class Kconfig(object):
"{}:{}: '{}' not found (in '{}'). Check that "
"environment variables are set correctly (e.g. "
"$srctree, which is {}). Also note that unset "
- "environment variables expand to the empty string."
- .format(self.filename, self.linenr, pattern,
- self._line.strip(),
- "set to '{}'".format(self.srctree)
- if self.srctree else "unset or blank"))
+ "environment variables expand to the empty string.".format(
+ self.filename,
+ self.linenr,
+ pattern,
+ self._line.strip(),
+ "set to '{}'".format(self.srctree)
+ if self.srctree
+ else "unset or blank",
+ )
+ )
for filename in filenames:
self._enter_file(filename)
@@ -3125,20 +3204,28 @@ class Kconfig(object):
# A valid endchoice/endif/endmenu is caught by the 'end_token'
# check above
self._parse_error(
- "no corresponding 'choice'" if t0 is _T_ENDCHOICE else
- "no corresponding 'if'" if t0 is _T_ENDIF else
- "no corresponding 'menu'" if t0 is _T_ENDMENU else
- "unrecognized construct")
+ "no corresponding 'choice'"
+ if t0 is _T_ENDCHOICE
+ else "no corresponding 'if'"
+ if t0 is _T_ENDIF
+ else "no corresponding 'menu'"
+ if t0 is _T_ENDMENU
+ else "unrecognized construct"
+ )
# End of file reached. Return the last node.
if end_token:
raise KconfigError(
- "error: expected '{}' at end of '{}'"
- .format("endchoice" if end_token is _T_ENDCHOICE else
- "endif" if end_token is _T_ENDIF else
- "endmenu",
- self.filename))
+ "error: expected '{}' at end of '{}'".format(
+ "endchoice"
+ if end_token is _T_ENDCHOICE
+ else "endif"
+ if end_token is _T_ENDIF
+ else "endmenu",
+ self.filename,
+ )
+ )
return prev
@@ -3187,8 +3274,7 @@ class Kconfig(object):
if not self._check_token(_T_ON):
self._parse_error("expected 'on' after 'depends'")
- node.dep = self._make_and(node.dep,
- self._expect_expr_and_eol())
+ node.dep = self._make_and(node.dep, self._expect_expr_and_eol())
elif t0 is _T_HELP:
self._parse_help(node)
@@ -3197,42 +3283,40 @@ class Kconfig(object):
if node.item.__class__ is not Symbol:
self._parse_error("only symbols can select")
- node.selects.append((self._expect_nonconst_sym(),
- self._parse_cond()))
+ node.selects.append((self._expect_nonconst_sym(), self._parse_cond()))
elif t0 is None:
# Blank line
continue
elif t0 is _T_DEFAULT:
- node.defaults.append((self._parse_expr(False),
- self._parse_cond()))
+ node.defaults.append((self._parse_expr(False), self._parse_cond()))
elif t0 in _DEF_TOKEN_TO_TYPE:
self._set_type(node.item, _DEF_TOKEN_TO_TYPE[t0])
- node.defaults.append((self._parse_expr(False),
- self._parse_cond()))
+ node.defaults.append((self._parse_expr(False), self._parse_cond()))
elif t0 is _T_PROMPT:
self._parse_prompt(node)
elif t0 is _T_RANGE:
- node.ranges.append((self._expect_sym(), self._expect_sym(),
- self._parse_cond()))
+ node.ranges.append(
+ (self._expect_sym(), self._expect_sym(), self._parse_cond())
+ )
elif t0 is _T_IMPLY:
if node.item.__class__ is not Symbol:
self._parse_error("only symbols can imply")
- node.implies.append((self._expect_nonconst_sym(),
- self._parse_cond()))
+ node.implies.append((self._expect_nonconst_sym(), self._parse_cond()))
elif t0 is _T_VISIBLE:
if not self._check_token(_T_IF):
self._parse_error("expected 'if' after 'visible'")
- node.visibility = self._make_and(node.visibility,
- self._expect_expr_and_eol())
+ node.visibility = self._make_and(
+ node.visibility, self._expect_expr_and_eol()
+ )
elif t0 is _T_OPTION:
if self._check_token(_T_ENV):
@@ -3244,33 +3328,42 @@ class Kconfig(object):
if env_var in os.environ:
node.defaults.append(
- (self._lookup_const_sym(os.environ[env_var]),
- self.y))
+ (self._lookup_const_sym(os.environ[env_var]), self.y)
+ )
else:
- self._warn("{1} has 'option env=\"{0}\"', "
- "but the environment variable {0} is not "
- "set".format(node.item.name, env_var),
- self.filename, self.linenr)
+ self._warn(
+ "{1} has 'option env=\"{0}\"', "
+ "but the environment variable {0} is not "
+ "set".format(node.item.name, env_var),
+ self.filename,
+ self.linenr,
+ )
if env_var != node.item.name:
- self._warn("Kconfiglib expands environment variables "
- "in strings directly, meaning you do not "
- "need 'option env=...' \"bounce\" symbols. "
- "For compatibility with the C tools, "
- "rename {} to {} (so that the symbol name "
- "matches the environment variable name)."
- .format(node.item.name, env_var),
- self.filename, self.linenr)
+ self._warn(
+ "Kconfiglib expands environment variables "
+ "in strings directly, meaning you do not "
+ "need 'option env=...' \"bounce\" symbols. "
+ "For compatibility with the C tools, "
+ "rename {} to {} (so that the symbol name "
+ "matches the environment variable name).".format(
+ node.item.name, env_var
+ ),
+ self.filename,
+ self.linenr,
+ )
elif self._check_token(_T_DEFCONFIG_LIST):
if not self.defconfig_list:
self.defconfig_list = node.item
else:
- self._warn("'option defconfig_list' set on multiple "
- "symbols ({0} and {1}). Only {0} will be "
- "used.".format(self.defconfig_list.name,
- node.item.name),
- self.filename, self.linenr)
+ self._warn(
+ "'option defconfig_list' set on multiple "
+ "symbols ({0} and {1}). Only {0} will be "
+ "used.".format(self.defconfig_list.name, node.item.name),
+ self.filename,
+ self.linenr,
+ )
elif self._check_token(_T_MODULES):
# To reduce warning spam, only warn if 'option modules' is
@@ -3279,20 +3372,24 @@ class Kconfig(object):
# modules besides the kernel yet, and there it's likely to
# keep being called "MODULES".
if node.item is not self.modules:
- self._warn("the 'modules' option is not supported. "
- "Let me know if this is a problem for you, "
- "as it wouldn't be that hard to implement. "
- "Note that modules are supported -- "
- "Kconfiglib just assumes the symbol name "
- "MODULES, like older versions of the C "
- "implementation did when 'option modules' "
- "wasn't used.",
- self.filename, self.linenr)
+ self._warn(
+ "the 'modules' option is not supported. "
+ "Let me know if this is a problem for you, "
+ "as it wouldn't be that hard to implement. "
+ "Note that modules are supported -- "
+ "Kconfiglib just assumes the symbol name "
+ "MODULES, like older versions of the C "
+ "implementation did when 'option modules' "
+ "wasn't used.",
+ self.filename,
+ self.linenr,
+ )
elif self._check_token(_T_ALLNOCONFIG_Y):
if node.item.__class__ is not Symbol:
- self._parse_error("the 'allnoconfig_y' option is only "
- "valid for symbols")
+ self._parse_error(
+ "the 'allnoconfig_y' option is only " "valid for symbols"
+ )
node.item.is_allnoconfig_y = True
@@ -3315,8 +3412,11 @@ class Kconfig(object):
# UNKNOWN is falsy
if sc.orig_type and sc.orig_type is not new_type:
- self._warn("{} defined with multiple types, {} will be used"
- .format(sc.name_and_loc, TYPE_TO_STR[new_type]))
+ self._warn(
+ "{} defined with multiple types, {} will be used".format(
+ sc.name_and_loc, TYPE_TO_STR[new_type]
+ )
+ )
sc.orig_type = new_type
@@ -3326,8 +3426,10 @@ class Kconfig(object):
# multiple times
if node.prompt:
- self._warn(node.item.name_and_loc +
- " defined with multiple prompts in single location")
+ self._warn(
+ node.item.name_and_loc
+ + " defined with multiple prompts in single location"
+ )
prompt = self._tokens[1]
self._tokens_i = 2
@@ -3336,8 +3438,10 @@ class Kconfig(object):
self._parse_error("expected prompt string")
if prompt != prompt.strip():
- self._warn(node.item.name_and_loc +
- " has leading or trailing whitespace in its prompt")
+ self._warn(
+ node.item.name_and_loc
+ + " has leading or trailing whitespace in its prompt"
+ )
# This avoid issues for e.g. reStructuredText documentation, where
# '*prompt *' is invalid
@@ -3347,8 +3451,10 @@ class Kconfig(object):
def _parse_help(self, node):
if node.help is not None:
- self._warn(node.item.name_and_loc + " defined with more than "
- "one help text -- only the last one will be used")
+ self._warn(
+ node.item.name_and_loc + " defined with more than "
+ "one help text -- only the last one will be used"
+ )
# Micro-optimization. This code is pretty hot.
readline = self._readline
@@ -3403,8 +3509,7 @@ class Kconfig(object):
self._line_after_help(line)
def _empty_help(self, node, line):
- self._warn(node.item.name_and_loc +
- " has 'help' but empty help text")
+ self._warn(node.item.name_and_loc + " has 'help' but empty help text")
node.help = ""
if line:
self._line_after_help(line)
@@ -3447,8 +3552,11 @@ class Kconfig(object):
# Return 'and_expr' directly if we have a "single-operand" OR.
# Otherwise, parse the expression on the right and make an OR node.
# This turns A || B || C || D into (OR, A, (OR, B, (OR, C, D))).
- return and_expr if not self._check_token(_T_OR) else \
- (OR, and_expr, self._parse_expr(transform_m))
+ return (
+ and_expr
+ if not self._check_token(_T_OR)
+ else (OR, and_expr, self._parse_expr(transform_m))
+ )
def _parse_and_expr(self, transform_m):
factor = self._parse_factor(transform_m)
@@ -3456,8 +3564,11 @@ class Kconfig(object):
# Return 'factor' directly if we have a "single-operand" AND.
# Otherwise, parse the right operand and make an AND node. This turns
# A && B && C && D into (AND, A, (AND, B, (AND, C, D))).
- return factor if not self._check_token(_T_AND) else \
- (AND, factor, self._parse_and_expr(transform_m))
+ return (
+ factor
+ if not self._check_token(_T_AND)
+ else (AND, factor, self._parse_and_expr(transform_m))
+ )
def _parse_factor(self, transform_m):
token = self._tokens[self._tokens_i]
@@ -3481,8 +3592,7 @@ class Kconfig(object):
# _T_EQUAL, _T_UNEQUAL, etc., deliberately have the same values as
# EQUAL, UNEQUAL, etc., so we can just use the token directly
self._tokens_i += 1
- return (self._tokens[self._tokens_i - 1], token,
- self._expect_sym())
+ return (self._tokens[self._tokens_i - 1], token, self._expect_sym())
if token is _T_NOT:
# token == _T_NOT == NOT
@@ -3689,36 +3799,43 @@ class Kconfig(object):
if cur.item.__class__ in _SYMBOL_CHOICE:
# Propagate 'visible if' and dependencies to the prompt
if cur.prompt:
- cur.prompt = (cur.prompt[0],
- self._make_and(
- cur.prompt[1],
- self._make_and(visible_if, dep)))
+ cur.prompt = (
+ cur.prompt[0],
+ self._make_and(cur.prompt[1], self._make_and(visible_if, dep)),
+ )
# Propagate dependencies to defaults
if cur.defaults:
- cur.defaults = [(default, self._make_and(cond, dep))
- for default, cond in cur.defaults]
+ cur.defaults = [
+ (default, self._make_and(cond, dep))
+ for default, cond in cur.defaults
+ ]
# Propagate dependencies to ranges
if cur.ranges:
- cur.ranges = [(low, high, self._make_and(cond, dep))
- for low, high, cond in cur.ranges]
+ cur.ranges = [
+ (low, high, self._make_and(cond, dep))
+ for low, high, cond in cur.ranges
+ ]
# Propagate dependencies to selects
if cur.selects:
- cur.selects = [(target, self._make_and(cond, dep))
- for target, cond in cur.selects]
+ cur.selects = [
+ (target, self._make_and(cond, dep))
+ for target, cond in cur.selects
+ ]
# Propagate dependencies to implies
if cur.implies:
- cur.implies = [(target, self._make_and(cond, dep))
- for target, cond in cur.implies]
+ cur.implies = [
+ (target, self._make_and(cond, dep))
+ for target, cond in cur.implies
+ ]
elif cur.prompt: # Not a symbol/choice
# Propagate dependencies to the prompt. 'visible if' is only
# propagated to symbols/choices.
- cur.prompt = (cur.prompt[0],
- self._make_and(cur.prompt[1], dep))
+ cur.prompt = (cur.prompt[0], self._make_and(cur.prompt[1], dep))
cur = cur.next
@@ -3744,16 +3861,14 @@ class Kconfig(object):
# Modify the reverse dependencies of the selected symbol
for target, cond in node.selects:
- target.rev_dep = self._make_or(
- target.rev_dep,
- self._make_and(sym, cond))
+ target.rev_dep = self._make_or(target.rev_dep, self._make_and(sym, cond))
# Modify the weak reverse dependencies of the implied
# symbol
for target, cond in node.implies:
target.weak_rev_dep = self._make_or(
- target.weak_rev_dep,
- self._make_and(sym, cond))
+ target.weak_rev_dep, self._make_and(sym, cond)
+ )
#
# Misc.
@@ -3781,82 +3896,106 @@ class Kconfig(object):
for target_sym, _ in sym.selects:
if target_sym.orig_type not in _BOOL_TRISTATE_UNKNOWN:
- self._warn("{} selects the {} symbol {}, which is not "
- "bool or tristate"
- .format(sym.name_and_loc,
- TYPE_TO_STR[target_sym.orig_type],
- target_sym.name_and_loc))
+ self._warn(
+ "{} selects the {} symbol {}, which is not "
+ "bool or tristate".format(
+ sym.name_and_loc,
+ TYPE_TO_STR[target_sym.orig_type],
+ target_sym.name_and_loc,
+ )
+ )
for target_sym, _ in sym.implies:
if target_sym.orig_type not in _BOOL_TRISTATE_UNKNOWN:
- self._warn("{} implies the {} symbol {}, which is not "
- "bool or tristate"
- .format(sym.name_and_loc,
- TYPE_TO_STR[target_sym.orig_type],
- target_sym.name_and_loc))
+ self._warn(
+ "{} implies the {} symbol {}, which is not "
+ "bool or tristate".format(
+ sym.name_and_loc,
+ TYPE_TO_STR[target_sym.orig_type],
+ target_sym.name_and_loc,
+ )
+ )
elif sym.orig_type: # STRING/INT/HEX
for default, _ in sym.defaults:
if default.__class__ is not Symbol:
raise KconfigError(
"the {} symbol {} has a malformed default {} -- "
- "expected a single symbol"
- .format(TYPE_TO_STR[sym.orig_type],
- sym.name_and_loc, expr_str(default)))
+ "expected a single symbol".format(
+ TYPE_TO_STR[sym.orig_type],
+ sym.name_and_loc,
+ expr_str(default),
+ )
+ )
if sym.orig_type is STRING:
- if not default.is_constant and not default.nodes and \
- not default.name.isupper():
+ if (
+ not default.is_constant
+ and not default.nodes
+ and not default.name.isupper()
+ ):
# 'default foo' on a string symbol could be either a symbol
# reference or someone leaving out the quotes. Guess that
# the quotes were left out if 'foo' isn't all-uppercase
# (and no symbol named 'foo' exists).
- self._warn("style: quotes recommended around "
- "default value for string symbol "
- + sym.name_and_loc)
+ self._warn(
+ "style: quotes recommended around "
+ "default value for string symbol " + sym.name_and_loc
+ )
elif not num_ok(default, sym.orig_type): # INT/HEX
- self._warn("the {0} symbol {1} has a non-{0} default {2}"
- .format(TYPE_TO_STR[sym.orig_type],
- sym.name_and_loc,
- default.name_and_loc))
+ self._warn(
+ "the {0} symbol {1} has a non-{0} default {2}".format(
+ TYPE_TO_STR[sym.orig_type],
+ sym.name_and_loc,
+ default.name_and_loc,
+ )
+ )
if sym.selects or sym.implies:
- self._warn("the {} symbol {} has selects or implies"
- .format(TYPE_TO_STR[sym.orig_type],
- sym.name_and_loc))
+ self._warn(
+ "the {} symbol {} has selects or implies".format(
+ TYPE_TO_STR[sym.orig_type], sym.name_and_loc
+ )
+ )
else: # UNKNOWN
- self._warn("{} defined without a type"
- .format(sym.name_and_loc))
-
+ self._warn("{} defined without a type".format(sym.name_and_loc))
if sym.ranges:
if sym.orig_type not in _INT_HEX:
self._warn(
- "the {} symbol {} has ranges, but is not int or hex"
- .format(TYPE_TO_STR[sym.orig_type],
- sym.name_and_loc))
+ "the {} symbol {} has ranges, but is not int or hex".format(
+ TYPE_TO_STR[sym.orig_type], sym.name_and_loc
+ )
+ )
else:
for low, high, _ in sym.ranges:
- if not num_ok(low, sym.orig_type) or \
- not num_ok(high, sym.orig_type):
-
- self._warn("the {0} symbol {1} has a non-{0} "
- "range [{2}, {3}]"
- .format(TYPE_TO_STR[sym.orig_type],
- sym.name_and_loc,
- low.name_and_loc,
- high.name_and_loc))
+ if not num_ok(low, sym.orig_type) or not num_ok(
+ high, sym.orig_type
+ ):
+
+ self._warn(
+ "the {0} symbol {1} has a non-{0} "
+ "range [{2}, {3}]".format(
+ TYPE_TO_STR[sym.orig_type],
+ sym.name_and_loc,
+ low.name_and_loc,
+ high.name_and_loc,
+ )
+ )
def _check_choice_sanity(self):
# Checks various choice properties that are handiest to check after
# parsing. Only generates errors and warnings.
def warn_select_imply(sym, expr, expr_type):
- msg = "the choice symbol {} is {} by the following symbols, but " \
- "select/imply has no effect on choice symbols" \
- .format(sym.name_and_loc, expr_type)
+ msg = (
+ "the choice symbol {} is {} by the following symbols, but "
+ "select/imply has no effect on choice symbols".format(
+ sym.name_and_loc, expr_type
+ )
+ )
# si = select/imply
for si in split_expr(expr, OR):
@@ -3866,9 +4005,11 @@ class Kconfig(object):
for choice in self.unique_choices:
if choice.orig_type not in _BOOL_TRISTATE:
- self._warn("{} defined with type {}"
- .format(choice.name_and_loc,
- TYPE_TO_STR[choice.orig_type]))
+ self._warn(
+ "{} defined with type {}".format(
+ choice.name_and_loc, TYPE_TO_STR[choice.orig_type]
+ )
+ )
for node in choice.nodes:
if node.prompt:
@@ -3879,20 +4020,26 @@ class Kconfig(object):
for default, _ in choice.defaults:
if default.__class__ is not Symbol:
raise KconfigError(
- "{} has a malformed default {}"
- .format(choice.name_and_loc, expr_str(default)))
+ "{} has a malformed default {}".format(
+ choice.name_and_loc, expr_str(default)
+ )
+ )
if default.choice is not choice:
- self._warn("the default selection {} of {} is not "
- "contained in the choice"
- .format(default.name_and_loc,
- choice.name_and_loc))
+ self._warn(
+ "the default selection {} of {} is not "
+ "contained in the choice".format(
+ default.name_and_loc, choice.name_and_loc
+ )
+ )
for sym in choice.syms:
if sym.defaults:
- self._warn("default on the choice symbol {} will have "
- "no effect, as defaults do not affect choice "
- "symbols".format(sym.name_and_loc))
+ self._warn(
+ "default on the choice symbol {} will have "
+ "no effect, as defaults do not affect choice "
+ "symbols".format(sym.name_and_loc)
+ )
if sym.rev_dep is not sym.kconfig.n:
warn_select_imply(sym, sym.rev_dep, "selected")
@@ -3903,19 +4050,28 @@ class Kconfig(object):
for node in sym.nodes:
if node.parent.item is choice:
if not node.prompt:
- self._warn("the choice symbol {} has no prompt"
- .format(sym.name_and_loc))
+ self._warn(
+ "the choice symbol {} has no prompt".format(
+ sym.name_and_loc
+ )
+ )
elif node.prompt:
- self._warn("the choice symbol {} is defined with a "
- "prompt outside the choice"
- .format(sym.name_and_loc))
+ self._warn(
+ "the choice symbol {} is defined with a "
+ "prompt outside the choice".format(sym.name_and_loc)
+ )
def _parse_error(self, msg):
- raise KconfigError("{}error: couldn't parse '{}': {}".format(
- "" if self.filename is None else
- "{}:{}: ".format(self.filename, self.linenr),
- self._line.strip(), msg))
+ raise KconfigError(
+ "{}error: couldn't parse '{}': {}".format(
+ ""
+ if self.filename is None
+ else "{}:{}: ".format(self.filename, self.linenr),
+ self._line.strip(),
+ msg,
+ )
+ )
def _trailing_tokens_error(self):
self._parse_error("extra tokens at end of line")
@@ -3954,8 +4110,11 @@ class Kconfig(object):
# - For Python 3, force the encoding. Forcing the encoding on Python 2
# turns strings into Unicode strings, which gets messy. Python 2
# doesn't decode regular strings anyway.
- return open(filename, "rU" if mode == "r" else mode) if _IS_PY2 else \
- open(filename, mode, encoding=self._encoding)
+ return (
+ open(filename, "rU" if mode == "r" else mode)
+ if _IS_PY2
+ else open(filename, mode, encoding=self._encoding)
+ )
def _check_undef_syms(self):
# Prints warnings for all references to undefined symbols within the
@@ -3992,14 +4151,14 @@ class Kconfig(object):
# symbols, but shouldn't be flagged
#
# - The MODULES symbol always exists
- if not sym.nodes and not is_num(sym.name) and \
- sym.name != "MODULES":
+ if not sym.nodes and not is_num(sym.name) and sym.name != "MODULES":
msg = "undefined symbol {}:".format(sym.name)
for node in self.node_iter():
if sym in node.referenced:
- msg += "\n\n- Referenced at {}:{}:\n\n{}" \
- .format(node.filename, node.linenr, node)
+ msg += "\n\n- Referenced at {}:{}:\n\n{}".format(
+ node.filename, node.linenr, node
+ )
self._warn(msg)
def _warn(self, msg, filename=None, linenr=None):
@@ -4274,6 +4433,7 @@ class Symbol(object):
kconfig:
The Kconfig instance this symbol is from.
"""
+
__slots__ = (
"_cached_assignable",
"_cached_str_val",
@@ -4311,9 +4471,11 @@ class Symbol(object):
"""
See the class documentation.
"""
- if self.orig_type is TRISTATE and \
- (self.choice and self.choice.tri_value == 2 or
- not self.kconfig.modules.tri_value):
+ if self.orig_type is TRISTATE and (
+ self.choice
+ and self.choice.tri_value == 2
+ or not self.kconfig.modules.tri_value
+ ):
return BOOL
@@ -4344,7 +4506,7 @@ class Symbol(object):
# function call (property magic)
vis = self.visibility
- self._write_to_conf = (vis != 0)
+ self._write_to_conf = vis != 0
if self.orig_type in _INT_HEX:
# The C implementation checks the user value against the range in a
@@ -4361,10 +4523,16 @@ class Symbol(object):
# The zeros are from the C implementation running strtoll()
# on empty strings
- low = int(low_expr.str_value, base) if \
- _is_base_n(low_expr.str_value, base) else 0
- high = int(high_expr.str_value, base) if \
- _is_base_n(high_expr.str_value, base) else 0
+ low = (
+ int(low_expr.str_value, base)
+ if _is_base_n(low_expr.str_value, base)
+ else 0
+ )
+ high = (
+ int(high_expr.str_value, base)
+ if _is_base_n(high_expr.str_value, base)
+ else 0
+ )
break
else:
@@ -4381,10 +4549,14 @@ class Symbol(object):
self.kconfig._warn(
"user value {} on the {} symbol {} ignored due to "
"being outside the active range ([{}, {}]) -- falling "
- "back on defaults"
- .format(num2str(user_val), TYPE_TO_STR[self.orig_type],
- self.name_and_loc,
- num2str(low), num2str(high)))
+ "back on defaults".format(
+ num2str(user_val),
+ TYPE_TO_STR[self.orig_type],
+ self.name_and_loc,
+ num2str(low),
+ num2str(high),
+ )
+ )
else:
# If the user value is well-formed and satisfies range
# contraints, it is stored in exactly the same form as
@@ -4424,18 +4596,20 @@ class Symbol(object):
if clamp is not None:
# The value is rewritten to a standard form if it is
# clamped
- val = str(clamp) \
- if self.orig_type is INT else \
- hex(clamp)
+ val = str(clamp) if self.orig_type is INT else hex(clamp)
if has_default:
num2str = str if base == 10 else hex
self.kconfig._warn(
"default value {} on {} clamped to {} due to "
- "being outside the active range ([{}, {}])"
- .format(val_num, self.name_and_loc,
- num2str(clamp), num2str(low),
- num2str(high)))
+ "being outside the active range ([{}, {}])".format(
+ val_num,
+ self.name_and_loc,
+ num2str(clamp),
+ num2str(low),
+ num2str(high),
+ )
+ )
elif self.orig_type is STRING:
if vis and self.user_value is not None:
@@ -4473,8 +4647,10 @@ class Symbol(object):
# Would take some work to give the location here
self.kconfig._warn(
"The {} symbol {} is being evaluated in a logical context "
- "somewhere. It will always evaluate to n."
- .format(TYPE_TO_STR[self.orig_type], self.name_and_loc))
+ "somewhere. It will always evaluate to n.".format(
+ TYPE_TO_STR[self.orig_type], self.name_and_loc
+ )
+ )
self._cached_tri_val = 0
return 0
@@ -4482,7 +4658,7 @@ class Symbol(object):
# Warning: See Symbol._rec_invalidate(), and note that this is a hidden
# function call (property magic)
vis = self.visibility
- self._write_to_conf = (vis != 0)
+ self._write_to_conf = vis != 0
val = 0
@@ -4523,8 +4699,7 @@ class Symbol(object):
# m is promoted to y for (1) bool symbols and (2) symbols with a
# weak_rev_dep (from imply) of y
- if val == 1 and \
- (self.type is BOOL or expr_value(self.weak_rev_dep) == 2):
+ if val == 1 and (self.type is BOOL or expr_value(self.weak_rev_dep) == 2):
val = 2
elif vis == 2:
@@ -4570,19 +4745,17 @@ class Symbol(object):
return ""
if self.orig_type in _BOOL_TRISTATE:
- return "{}{}={}\n" \
- .format(self.kconfig.config_prefix, self.name, val) \
- if val != "n" else \
- "# {}{} is not set\n" \
- .format(self.kconfig.config_prefix, self.name)
+ return (
+ "{}{}={}\n".format(self.kconfig.config_prefix, self.name, val)
+ if val != "n"
+ else "# {}{} is not set\n".format(self.kconfig.config_prefix, self.name)
+ )
if self.orig_type in _INT_HEX:
- return "{}{}={}\n" \
- .format(self.kconfig.config_prefix, self.name, val)
+ return "{}{}={}\n".format(self.kconfig.config_prefix, self.name, val)
# sym.orig_type is STRING
- return '{}{}="{}"\n' \
- .format(self.kconfig.config_prefix, self.name, escape(val))
+ return '{}{}="{}"\n'.format(self.kconfig.config_prefix, self.name, escape(val))
@property
def name_and_loc(self):
@@ -4646,21 +4819,31 @@ class Symbol(object):
return True
# Check if the value is valid for our type
- if not (self.orig_type is BOOL and value in (2, 0) or
- self.orig_type is TRISTATE and value in TRI_TO_STR or
- value.__class__ is str and
- (self.orig_type is STRING or
- self.orig_type is INT and _is_base_n(value, 10) or
- self.orig_type is HEX and _is_base_n(value, 16)
- and int(value, 16) >= 0)):
+ if not (
+ self.orig_type is BOOL
+ and value in (2, 0)
+ or self.orig_type is TRISTATE
+ and value in TRI_TO_STR
+ or value.__class__ is str
+ and (
+ self.orig_type is STRING
+ or self.orig_type is INT
+ and _is_base_n(value, 10)
+ or self.orig_type is HEX
+ and _is_base_n(value, 16)
+ and int(value, 16) >= 0
+ )
+ ):
# Display tristate values as n, m, y in the warning
self.kconfig._warn(
"the value {} is invalid for {}, which has type {} -- "
- "assignment ignored"
- .format(TRI_TO_STR[value] if value in TRI_TO_STR else
- "'{}'".format(value),
- self.name_and_loc, TYPE_TO_STR[self.orig_type]))
+ "assignment ignored".format(
+ TRI_TO_STR[value] if value in TRI_TO_STR else "'{}'".format(value),
+ self.name_and_loc,
+ TYPE_TO_STR[self.orig_type],
+ )
+ )
return False
@@ -4738,17 +4921,28 @@ class Symbol(object):
add('"{}"'.format(node.prompt[0]))
# Only add quotes for non-bool/tristate symbols
- add("value " + (self.str_value if self.orig_type in _BOOL_TRISTATE
- else '"{}"'.format(self.str_value)))
+ add(
+ "value "
+ + (
+ self.str_value
+ if self.orig_type in _BOOL_TRISTATE
+ else '"{}"'.format(self.str_value)
+ )
+ )
if not self.is_constant:
# These aren't helpful to show for constant symbols
if self.user_value is not None:
# Only add quotes for non-bool/tristate symbols
- add("user value " + (TRI_TO_STR[self.user_value]
- if self.orig_type in _BOOL_TRISTATE
- else '"{}"'.format(self.user_value)))
+ add(
+ "user value "
+ + (
+ TRI_TO_STR[self.user_value]
+ if self.orig_type in _BOOL_TRISTATE
+ else '"{}"'.format(self.user_value)
+ )
+ )
add("visibility " + TRI_TO_STR[self.visibility])
@@ -4798,8 +4992,7 @@ class Symbol(object):
Works like Symbol.__str__(), but allows a custom format to be used for
all symbol/choice references. See expr_str().
"""
- return "\n\n".join(node.custom_str(sc_expr_str_fn)
- for node in self.nodes)
+ return "\n\n".join(node.custom_str(sc_expr_str_fn) for node in self.nodes)
#
# Private methods
@@ -4830,18 +5023,18 @@ class Symbol(object):
self.implies = []
self.ranges = []
- self.user_value = \
- self.choice = \
- self.env_var = \
- self._cached_str_val = self._cached_tri_val = self._cached_vis = \
- self._cached_assignable = None
+ self.user_value = (
+ self.choice
+ ) = (
+ self.env_var
+ ) = (
+ self._cached_str_val
+ ) = self._cached_tri_val = self._cached_vis = self._cached_assignable = None
# _write_to_conf is calculated along with the value. If True, the
# Symbol gets a .config entry.
- self.is_allnoconfig_y = \
- self._was_set = \
- self._write_to_conf = False
+ self.is_allnoconfig_y = self._was_set = self._write_to_conf = False
# See Kconfig._build_dep()
self._dependents = set()
@@ -4895,8 +5088,9 @@ class Symbol(object):
def _invalidate(self):
# Marks the symbol as needing to be recalculated
- self._cached_str_val = self._cached_tri_val = self._cached_vis = \
- self._cached_assignable = None
+ self._cached_str_val = (
+ self._cached_tri_val
+ ) = self._cached_vis = self._cached_assignable = None
def _rec_invalidate(self):
# Invalidates the symbol and all items that (possibly) depend on it
@@ -4948,8 +5142,10 @@ class Symbol(object):
return
if self.kconfig._warn_assign_no_prompt:
- self.kconfig._warn(self.name_and_loc + " has no prompt, meaning "
- "user values have no effect on it")
+ self.kconfig._warn(
+ self.name_and_loc + " has no prompt, meaning "
+ "user values have no effect on it"
+ )
def _str_default(self):
# write_min_config() helper function. Returns the value the symbol
@@ -4968,9 +5164,7 @@ class Symbol(object):
val = min(expr_value(default), cond_val)
break
- val = max(expr_value(self.rev_dep),
- expr_value(self.weak_rev_dep),
- val)
+ val = max(expr_value(self.rev_dep), expr_value(self.weak_rev_dep), val)
# Transpose mod to yes if type is bool (possibly due to modules
# being disabled)
@@ -4992,11 +5186,15 @@ class Symbol(object):
# and menus) is selected by some other symbol. Also warn if a symbol
# whose direct dependencies evaluate to m is selected to y.
- msg = "{} has direct dependencies {} with value {}, but is " \
- "currently being {}-selected by the following symbols:" \
- .format(self.name_and_loc, expr_str(self.direct_dep),
- TRI_TO_STR[expr_value(self.direct_dep)],
- TRI_TO_STR[expr_value(self.rev_dep)])
+ msg = (
+ "{} has direct dependencies {} with value {}, but is "
+ "currently being {}-selected by the following symbols:".format(
+ self.name_and_loc,
+ expr_str(self.direct_dep),
+ TRI_TO_STR[expr_value(self.direct_dep)],
+ TRI_TO_STR[expr_value(self.rev_dep)],
+ )
+ )
# The reverse dependencies from each select are ORed together
for select in split_expr(self.rev_dep, OR):
@@ -5010,17 +5208,20 @@ class Symbol(object):
# In both cases, we can split on AND and pick the first operand
selecting_sym = split_expr(select, AND)[0]
- msg += "\n - {}, with value {}, direct dependencies {} " \
- "(value: {})" \
- .format(selecting_sym.name_and_loc,
- selecting_sym.str_value,
- expr_str(selecting_sym.direct_dep),
- TRI_TO_STR[expr_value(selecting_sym.direct_dep)])
+ msg += (
+ "\n - {}, with value {}, direct dependencies {} "
+ "(value: {})".format(
+ selecting_sym.name_and_loc,
+ selecting_sym.str_value,
+ expr_str(selecting_sym.direct_dep),
+ TRI_TO_STR[expr_value(selecting_sym.direct_dep)],
+ )
+ )
if select.__class__ is tuple:
- msg += ", and select condition {} (value: {})" \
- .format(expr_str(select[2]),
- TRI_TO_STR[expr_value(select[2])])
+ msg += ", and select condition {} (value: {})".format(
+ expr_str(select[2]), TRI_TO_STR[expr_value(select[2])]
+ )
self.kconfig._warn(msg)
@@ -5182,6 +5383,7 @@ class Choice(object):
kconfig:
The Kconfig instance this choice is from.
"""
+
__slots__ = (
"_cached_assignable",
"_cached_selection",
@@ -5299,16 +5501,22 @@ class Choice(object):
self._was_set = True
return True
- if not (self.orig_type is BOOL and value in (2, 0) or
- self.orig_type is TRISTATE and value in TRI_TO_STR):
+ if not (
+ self.orig_type is BOOL
+ and value in (2, 0)
+ or self.orig_type is TRISTATE
+ and value in TRI_TO_STR
+ ):
# Display tristate values as n, m, y in the warning
self.kconfig._warn(
"the value {} is invalid for {}, which has type {} -- "
- "assignment ignored"
- .format(TRI_TO_STR[value] if value in TRI_TO_STR else
- "'{}'".format(value),
- self.name_and_loc, TYPE_TO_STR[self.orig_type]))
+ "assignment ignored".format(
+ TRI_TO_STR[value] if value in TRI_TO_STR else "'{}'".format(value),
+ self.name_and_loc,
+ TYPE_TO_STR[self.orig_type],
+ )
+ )
return False
@@ -5346,8 +5554,10 @@ class Choice(object):
Returns a string with information about the choice when it is evaluated
on e.g. the interactive Python prompt.
"""
- fields = ["choice " + self.name if self.name else "choice",
- TYPE_TO_STR[self.type]]
+ fields = [
+ "choice " + self.name if self.name else "choice",
+ TYPE_TO_STR[self.type],
+ ]
add = fields.append
for node in self.nodes:
@@ -5357,14 +5567,13 @@ class Choice(object):
add("mode " + self.str_value)
if self.user_value is not None:
- add('user mode {}'.format(TRI_TO_STR[self.user_value]))
+ add("user mode {}".format(TRI_TO_STR[self.user_value]))
if self.selection:
add("{} selected".format(self.selection.name))
if self.user_selection:
- user_sel_str = "{} selected by user" \
- .format(self.user_selection.name)
+ user_sel_str = "{} selected by user".format(self.user_selection.name)
if self.selection is not self.user_selection:
user_sel_str += " (overridden)"
@@ -5399,8 +5608,7 @@ class Choice(object):
Works like Choice.__str__(), but allows a custom format to be used for
all symbol/choice references. See expr_str().
"""
- return "\n\n".join(node.custom_str(sc_expr_str_fn)
- for node in self.nodes)
+ return "\n\n".join(node.custom_str(sc_expr_str_fn) for node in self.nodes)
#
# Private methods
@@ -5425,9 +5633,9 @@ class Choice(object):
self.syms = []
self.defaults = []
- self.name = \
- self.user_value = self.user_selection = \
- self._cached_vis = self._cached_assignable = None
+ self.name = (
+ self.user_value
+ ) = self.user_selection = self._cached_vis = self._cached_assignable = None
self._cached_selection = _NO_CACHED_SELECTION
@@ -5644,6 +5852,7 @@ class MenuNode(object):
kconfig:
The Kconfig instance the menu node is from.
"""
+
__slots__ = (
"dep",
"filename",
@@ -5658,7 +5867,6 @@ class MenuNode(object):
"parent",
"prompt",
"visibility",
-
# Properties
"defaults",
"selects",
@@ -5689,32 +5897,28 @@ class MenuNode(object):
"""
See the class documentation.
"""
- return [(default, self._strip_dep(cond))
- for default, cond in self.defaults]
+ return [(default, self._strip_dep(cond)) for default, cond in self.defaults]
@property
def orig_selects(self):
"""
See the class documentation.
"""
- return [(select, self._strip_dep(cond))
- for select, cond in self.selects]
+ return [(select, self._strip_dep(cond)) for select, cond in self.selects]
@property
def orig_implies(self):
"""
See the class documentation.
"""
- return [(imply, self._strip_dep(cond))
- for imply, cond in self.implies]
+ return [(imply, self._strip_dep(cond)) for imply, cond in self.implies]
@property
def orig_ranges(self):
"""
See the class documentation.
"""
- return [(low, high, self._strip_dep(cond))
- for low, high, cond in self.ranges]
+ return [(low, high, self._strip_dep(cond)) for low, high, cond in self.ranges]
@property
def referenced(self):
@@ -5774,8 +5978,11 @@ class MenuNode(object):
add("menu node for comment")
if self.prompt:
- add('prompt "{}" (visibility {})'.format(
- self.prompt[0], TRI_TO_STR[expr_value(self.prompt[1])]))
+ add(
+ 'prompt "{}" (visibility {})'.format(
+ self.prompt[0], TRI_TO_STR[expr_value(self.prompt[1])]
+ )
+ )
if self.item.__class__ is Symbol and self.is_menuconfig:
add("is menuconfig")
@@ -5822,20 +6029,20 @@ class MenuNode(object):
Works like MenuNode.__str__(), but allows a custom format to be used
for all symbol/choice references. See expr_str().
"""
- return self._menu_comment_node_str(sc_expr_str_fn) \
- if self.item in _MENU_COMMENT else \
- self._sym_choice_node_str(sc_expr_str_fn)
+ return (
+ self._menu_comment_node_str(sc_expr_str_fn)
+ if self.item in _MENU_COMMENT
+ else self._sym_choice_node_str(sc_expr_str_fn)
+ )
def _menu_comment_node_str(self, sc_expr_str_fn):
- s = '{} "{}"'.format("menu" if self.item is MENU else "comment",
- self.prompt[0])
+ s = '{} "{}"'.format("menu" if self.item is MENU else "comment", self.prompt[0])
if self.dep is not self.kconfig.y:
s += "\n\tdepends on {}".format(expr_str(self.dep, sc_expr_str_fn))
if self.item is MENU and self.visibility is not self.kconfig.y:
- s += "\n\tvisible if {}".format(expr_str(self.visibility,
- sc_expr_str_fn))
+ s += "\n\tvisible if {}".format(expr_str(self.visibility, sc_expr_str_fn))
return s
@@ -5851,8 +6058,7 @@ class MenuNode(object):
sc = self.item
if sc.__class__ is Symbol:
- lines = [("menuconfig " if self.is_menuconfig else "config ")
- + sc.name]
+ lines = [("menuconfig " if self.is_menuconfig else "config ") + sc.name]
else:
lines = ["choice " + sc.name if sc.name else "choice"]
@@ -5868,8 +6074,9 @@ class MenuNode(object):
# Symbol defined without a type (which generates a warning)
prefix = "prompt"
- indent_add_cond(prefix + ' "{}"'.format(escape(self.prompt[0])),
- self.orig_prompt[1])
+ indent_add_cond(
+ prefix + ' "{}"'.format(escape(self.prompt[0])), self.orig_prompt[1]
+ )
if sc.__class__ is Symbol:
if sc.is_allnoconfig_y:
@@ -5886,13 +6093,12 @@ class MenuNode(object):
for low, high, cond in self.orig_ranges:
indent_add_cond(
- "range {} {}".format(sc_expr_str_fn(low),
- sc_expr_str_fn(high)),
- cond)
+ "range {} {}".format(sc_expr_str_fn(low), sc_expr_str_fn(high)),
+ cond,
+ )
for default, cond in self.orig_defaults:
- indent_add_cond("default " + expr_str(default, sc_expr_str_fn),
- cond)
+ indent_add_cond("default " + expr_str(default, sc_expr_str_fn), cond)
if sc.__class__ is Choice and sc.is_optional:
indent_add("optional")
@@ -5954,6 +6160,7 @@ class Variable(object):
is_recursive:
True if the variable is recursive (defined with =).
"""
+
__slots__ = (
"_n_expansions",
"is_recursive",
@@ -5979,10 +6186,9 @@ class Variable(object):
return self.kconfig._fn_val((self.name,) + args)
def __repr__(self):
- return "<variable {}, {}, value '{}'>" \
- .format(self.name,
- "recursive" if self.is_recursive else "immediate",
- self.value)
+ return "<variable {}, {}, value '{}'>".format(
+ self.name, "recursive" if self.is_recursive else "immediate", self.value
+ )
class KconfigError(Exception):
@@ -5993,6 +6199,7 @@ class KconfigError(Exception):
KconfigSyntaxError alias is only maintained for backwards compatibility.
"""
+
KconfigSyntaxError = KconfigError # Backwards compatibility
@@ -6010,7 +6217,8 @@ class _KconfigIOError(IOError):
def __init__(self, ioerror, msg):
self.msg = msg
super(_KconfigIOError, self).__init__(
- ioerror.errno, ioerror.strerror, ioerror.filename)
+ ioerror.errno, ioerror.strerror, ioerror.filename
+ )
def __str__(self):
return self.msg
@@ -6070,12 +6278,19 @@ def expr_value(expr):
# parse as numbers
comp = _strcmp(v1.str_value, v2.str_value)
- return 2*(comp == 0 if rel is EQUAL else
- comp != 0 if rel is UNEQUAL else
- comp < 0 if rel is LESS else
- comp <= 0 if rel is LESS_EQUAL else
- comp > 0 if rel is GREATER else
- comp >= 0)
+ return 2 * (
+ comp == 0
+ if rel is EQUAL
+ else comp != 0
+ if rel is UNEQUAL
+ else comp < 0
+ if rel is LESS
+ else comp <= 0
+ if rel is LESS_EQUAL
+ else comp > 0
+ if rel is GREATER
+ else comp >= 0
+ )
def standard_sc_expr_str(sc):
@@ -6115,14 +6330,18 @@ def expr_str(expr, sc_expr_str_fn=standard_sc_expr_str):
return sc_expr_str_fn(expr)
if expr[0] is AND:
- return "{} && {}".format(_parenthesize(expr[1], OR, sc_expr_str_fn),
- _parenthesize(expr[2], OR, sc_expr_str_fn))
+ return "{} && {}".format(
+ _parenthesize(expr[1], OR, sc_expr_str_fn),
+ _parenthesize(expr[2], OR, sc_expr_str_fn),
+ )
if expr[0] is OR:
# This turns A && B || C && D into "(A && B) || (C && D)", which is
# redundant, but more readable
- return "{} || {}".format(_parenthesize(expr[1], AND, sc_expr_str_fn),
- _parenthesize(expr[2], AND, sc_expr_str_fn))
+ return "{} || {}".format(
+ _parenthesize(expr[1], AND, sc_expr_str_fn),
+ _parenthesize(expr[2], AND, sc_expr_str_fn),
+ )
if expr[0] is NOT:
if expr[1].__class__ is tuple:
@@ -6133,8 +6352,9 @@ def expr_str(expr, sc_expr_str_fn=standard_sc_expr_str):
#
# Relation operands are always symbols (quoted strings are constant
# symbols)
- return "{} {} {}".format(sc_expr_str_fn(expr[1]), REL_TO_STR[expr[0]],
- sc_expr_str_fn(expr[2]))
+ return "{} {} {}".format(
+ sc_expr_str_fn(expr[1]), REL_TO_STR[expr[0]], sc_expr_str_fn(expr[2])
+ )
def expr_items(expr):
@@ -6216,7 +6436,7 @@ def escape(s):
replaced by \" and \\, respectively.
"""
# \ must be escaped before " to avoid double escaping
- return s.replace("\\", r"\\").replace('"', r'\"')
+ return s.replace("\\", r"\\").replace('"', r"\"")
def unescape(s):
@@ -6226,6 +6446,7 @@ def unescape(s):
"""
return _unescape_sub(r"\1", s)
+
# unescape() helper
_unescape_sub = re.compile(r"\\(.)").sub
@@ -6245,15 +6466,16 @@ def standard_kconfig(description=None):
import argparse
parser = argparse.ArgumentParser(
- formatter_class=argparse.RawDescriptionHelpFormatter,
- description=description)
+ formatter_class=argparse.RawDescriptionHelpFormatter, description=description
+ )
parser.add_argument(
"kconfig",
metavar="KCONFIG",
default="Kconfig",
nargs="?",
- help="Top-level Kconfig file (default: Kconfig)")
+ help="Top-level Kconfig file (default: Kconfig)",
+ )
return Kconfig(parser.parse_args().kconfig, suppress_traceback=True)
@@ -6299,16 +6521,20 @@ def load_allconfig(kconf, filename):
try:
print(kconf.load_config("all.config", False))
except EnvironmentError as e2:
- sys.exit("error: KCONFIG_ALLCONFIG is set, but neither {} "
- "nor all.config could be opened: {}, {}"
- .format(filename, std_msg(e1), std_msg(e2)))
+ sys.exit(
+ "error: KCONFIG_ALLCONFIG is set, but neither {} "
+ "nor all.config could be opened: {}, {}".format(
+ filename, std_msg(e1), std_msg(e2)
+ )
+ )
else:
try:
print(kconf.load_config(allconfig, False))
except EnvironmentError as e:
- sys.exit("error: KCONFIG_ALLCONFIG is set to '{}', which "
- "could not be opened: {}"
- .format(allconfig, std_msg(e)))
+ sys.exit(
+ "error: KCONFIG_ALLCONFIG is set to '{}', which "
+ "could not be opened: {}".format(allconfig, std_msg(e))
+ )
kconf.warn_assign_override = old_warn_assign_override
kconf.warn_assign_redun = old_warn_assign_redun
@@ -6332,8 +6558,11 @@ def _visibility(sc):
vis = max(vis, expr_value(node.prompt[1]))
if sc.__class__ is Symbol and sc.choice:
- if sc.choice.orig_type is TRISTATE and \
- sc.orig_type is not TRISTATE and sc.choice.tri_value != 2:
+ if (
+ sc.choice.orig_type is TRISTATE
+ and sc.orig_type is not TRISTATE
+ and sc.choice.tri_value != 2
+ ):
# Non-tristate choice symbols are only visible in y mode
return 0
@@ -6407,8 +6636,11 @@ def _sym_to_num(sym):
# For BOOL and TRISTATE, n/m/y count as 0/1/2. This mirrors 9059a3493ef
# ("kconfig: fix relational operators for bool and tristate symbols") in
# the C implementation.
- return sym.tri_value if sym.orig_type in _BOOL_TRISTATE else \
- int(sym.str_value, _TYPE_TO_BASE[sym.orig_type])
+ return (
+ sym.tri_value
+ if sym.orig_type in _BOOL_TRISTATE
+ else int(sym.str_value, _TYPE_TO_BASE[sym.orig_type])
+ )
def _touch_dep_file(path, sym_name):
@@ -6421,8 +6653,7 @@ def _touch_dep_file(path, sym_name):
os.makedirs(sym_path_dir, 0o755)
# A kind of truncating touch, mirroring the C tools
- os.close(os.open(
- sym_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644))
+ os.close(os.open(sym_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644))
def _save_old(path):
@@ -6431,6 +6662,7 @@ def _save_old(path):
def copy(src, dst):
# Import as needed, to save some startup time
import shutil
+
shutil.copyfile(src, dst)
if islink(path):
@@ -6463,8 +6695,8 @@ def _locs(sc):
if sc.nodes:
return "(defined at {})".format(
- ", ".join("{0.filename}:{0.linenr}".format(node)
- for node in sc.nodes))
+ ", ".join("{0.filename}:{0.linenr}".format(node) for node in sc.nodes)
+ )
return "(undefined)"
@@ -6491,13 +6723,13 @@ def _expr_depends_on(expr, sym):
elif left is not sym:
return False
- return (expr[0] is EQUAL and right is sym.kconfig.m or
- right is sym.kconfig.y) or \
- (expr[0] is UNEQUAL and right is sym.kconfig.n)
+ return (
+ expr[0] is EQUAL and right is sym.kconfig.m or right is sym.kconfig.y
+ ) or (expr[0] is UNEQUAL and right is sym.kconfig.n)
- return expr[0] is AND and \
- (_expr_depends_on(expr[1], sym) or
- _expr_depends_on(expr[2], sym))
+ return expr[0] is AND and (
+ _expr_depends_on(expr[1], sym) or _expr_depends_on(expr[2], sym)
+ )
def _auto_menu_dep(node1, node2):
@@ -6505,8 +6737,7 @@ def _auto_menu_dep(node1, node2):
# node2 has a prompt, we check its condition. Otherwise, we look directly
# at node2.dep.
- return _expr_depends_on(node2.prompt[1] if node2.prompt else node2.dep,
- node1.item)
+ return _expr_depends_on(node2.prompt[1] if node2.prompt else node2.dep, node1.item)
def _flatten(node):
@@ -6521,8 +6752,7 @@ def _flatten(node):
# you enter the choice at some location with a prompt.
while node:
- if node.list and not node.prompt and \
- node.item.__class__ is not Choice:
+ if node.list and not node.prompt and node.item.__class__ is not Choice:
last_node = node.list
while 1:
@@ -6637,9 +6867,11 @@ def _check_dep_loop_sym(sym, ignore_choice):
#
# Since we aren't entering the choice via a choice symbol, all
# choice symbols need to be checked, hence the None.
- loop = _check_dep_loop_choice(dep, None) \
- if dep.__class__ is Choice \
- else _check_dep_loop_sym(dep, False)
+ loop = (
+ _check_dep_loop_choice(dep, None)
+ if dep.__class__ is Choice
+ else _check_dep_loop_sym(dep, False)
+ )
if loop:
# Dependency loop found
@@ -6711,8 +6943,7 @@ def _found_dep_loop(loop, cur):
# Yep, we have the entire loop. Throw an exception that shows it.
- msg = "\nDependency loop\n" \
- "===============\n\n"
+ msg = "\nDependency loop\n" "===============\n\n"
for item in loop:
if item is not loop[0]:
@@ -6720,8 +6951,7 @@ def _found_dep_loop(loop, cur):
if item.__class__ is Symbol and item.choice:
msg += "the choice symbol "
- msg += "{}, with definition...\n\n{}\n\n" \
- .format(item.name_and_loc, item)
+ msg += "{}, with definition...\n\n{}\n\n".format(item.name_and_loc, item)
# Small wart: Since we reuse the already calculated
# Symbol/Choice._dependents sets for recursive dependency detection, we
@@ -6738,12 +6968,14 @@ def _found_dep_loop(loop, cur):
if item.__class__ is Symbol:
if item.rev_dep is not item.kconfig.n:
- msg += "(select-related dependencies: {})\n\n" \
- .format(expr_str(item.rev_dep))
+ msg += "(select-related dependencies: {})\n\n".format(
+ expr_str(item.rev_dep)
+ )
if item.weak_rev_dep is not item.kconfig.n:
+            msg += "(imply-related dependencies: {})\n\n".format(
+                expr_str(item.weak_rev_dep)
+ msg += "(imply-related dependencies: {})\n\n".format(
+ expr_str(item.rev_dep)
+ )
msg += "...depends again on " + loop[0].name_and_loc
@@ -6765,11 +6997,14 @@ def _decoding_error(e, filename, macro_linenr=None):
"Problematic data: {}\n"
"Reason: {}".format(
e.encoding,
- "'{}'".format(filename) if macro_linenr is None else
- "output from macro at {}:{}".format(filename, macro_linenr),
- e.object[max(e.start - 40, 0):e.end + 40],
- e.object[e.start:e.end],
- e.reason))
+ "'{}'".format(filename)
+ if macro_linenr is None
+ else "output from macro at {}:{}".format(filename, macro_linenr),
+ e.object[max(e.start - 40, 0) : e.end + 40],
+ e.object[e.start : e.end],
+ e.reason,
+ )
+ )
def _warn_verbose_deprecated(fn_name):
@@ -6779,7 +7014,8 @@ def _warn_verbose_deprecated(fn_name):
"and is always generated. Do e.g. print(kconf.{0}()) if you want to "
"want to show a message like \"Loaded configuration '.config'\" on "
"stdout. The old API required ugly hacks to reuse messages in "
- "configuration interfaces.\n".format(fn_name))
+ "configuration interfaces.\n".format(fn_name)
+ )
# Predefined preprocessor functions
@@ -6808,8 +7044,7 @@ def _warning_if_fn(kconf, _, cond, msg):
def _error_if_fn(kconf, _, cond, msg):
if cond == "y":
- raise KconfigError("{}:{}: {}".format(
- kconf.filename, kconf.linenr, msg))
+ raise KconfigError("{}:{}: {}".format(kconf.filename, kconf.linenr, msg))
return ""
@@ -6829,9 +7064,11 @@ def _shell_fn(kconf, _, command):
_decoding_error(e, kconf.filename, kconf.linenr)
if stderr:
- kconf._warn("'{}' wrote to stderr: {}".format(
- command, "\n".join(stderr.splitlines())),
- kconf.filename, kconf.linenr)
+ kconf._warn(
+ "'{}' wrote to stderr: {}".format(command, "\n".join(stderr.splitlines())),
+ kconf.filename,
+ kconf.linenr,
+ )
# Universal newlines with splitlines() (to prevent e.g. stray \r's in
# command output on Windows), trailing newline removal, and
@@ -6842,6 +7079,7 @@ def _shell_fn(kconf, _, command):
# parameter was added in 3.6), so we do this manual version instead.
return "\n".join(stdout.splitlines()).rstrip("\n").replace("\n", " ")
+
#
# Global constants
#
@@ -6871,6 +7109,7 @@ try:
except AttributeError:
# Only import as needed, to save some startup time
import platform
+
_UNAME_RELEASE = platform.uname()[2]
# The token and type constants below are safe to test with 'is', which is a bit
@@ -6940,112 +7179,112 @@ except AttributeError:
# Keyword to token map, with the get() method assigned directly as a small
# optimization
_get_keyword = {
- "---help---": _T_HELP,
- "allnoconfig_y": _T_ALLNOCONFIG_Y,
- "bool": _T_BOOL,
- "boolean": _T_BOOL,
- "choice": _T_CHOICE,
- "comment": _T_COMMENT,
- "config": _T_CONFIG,
- "def_bool": _T_DEF_BOOL,
- "def_hex": _T_DEF_HEX,
- "def_int": _T_DEF_INT,
- "def_string": _T_DEF_STRING,
- "def_tristate": _T_DEF_TRISTATE,
- "default": _T_DEFAULT,
+ "---help---": _T_HELP,
+ "allnoconfig_y": _T_ALLNOCONFIG_Y,
+ "bool": _T_BOOL,
+ "boolean": _T_BOOL,
+ "choice": _T_CHOICE,
+ "comment": _T_COMMENT,
+ "config": _T_CONFIG,
+ "def_bool": _T_DEF_BOOL,
+ "def_hex": _T_DEF_HEX,
+ "def_int": _T_DEF_INT,
+ "def_string": _T_DEF_STRING,
+ "def_tristate": _T_DEF_TRISTATE,
+ "default": _T_DEFAULT,
"defconfig_list": _T_DEFCONFIG_LIST,
- "depends": _T_DEPENDS,
- "endchoice": _T_ENDCHOICE,
- "endif": _T_ENDIF,
- "endmenu": _T_ENDMENU,
- "env": _T_ENV,
- "grsource": _T_ORSOURCE, # Backwards compatibility
- "gsource": _T_OSOURCE, # Backwards compatibility
- "help": _T_HELP,
- "hex": _T_HEX,
- "if": _T_IF,
- "imply": _T_IMPLY,
- "int": _T_INT,
- "mainmenu": _T_MAINMENU,
- "menu": _T_MENU,
- "menuconfig": _T_MENUCONFIG,
- "modules": _T_MODULES,
- "on": _T_ON,
- "option": _T_OPTION,
- "optional": _T_OPTIONAL,
- "orsource": _T_ORSOURCE,
- "osource": _T_OSOURCE,
- "prompt": _T_PROMPT,
- "range": _T_RANGE,
- "rsource": _T_RSOURCE,
- "select": _T_SELECT,
- "source": _T_SOURCE,
- "string": _T_STRING,
- "tristate": _T_TRISTATE,
- "visible": _T_VISIBLE,
+ "depends": _T_DEPENDS,
+ "endchoice": _T_ENDCHOICE,
+ "endif": _T_ENDIF,
+ "endmenu": _T_ENDMENU,
+ "env": _T_ENV,
+ "grsource": _T_ORSOURCE, # Backwards compatibility
+ "gsource": _T_OSOURCE, # Backwards compatibility
+ "help": _T_HELP,
+ "hex": _T_HEX,
+ "if": _T_IF,
+ "imply": _T_IMPLY,
+ "int": _T_INT,
+ "mainmenu": _T_MAINMENU,
+ "menu": _T_MENU,
+ "menuconfig": _T_MENUCONFIG,
+ "modules": _T_MODULES,
+ "on": _T_ON,
+ "option": _T_OPTION,
+ "optional": _T_OPTIONAL,
+ "orsource": _T_ORSOURCE,
+ "osource": _T_OSOURCE,
+ "prompt": _T_PROMPT,
+ "range": _T_RANGE,
+ "rsource": _T_RSOURCE,
+ "select": _T_SELECT,
+ "source": _T_SOURCE,
+ "string": _T_STRING,
+ "tristate": _T_TRISTATE,
+ "visible": _T_VISIBLE,
}.get
# The constants below match the value of the corresponding tokens to remove the
# need for conversion
# Node types
-MENU = _T_MENU
+MENU = _T_MENU
COMMENT = _T_COMMENT
# Expression types
-AND = _T_AND
-OR = _T_OR
-NOT = _T_NOT
-EQUAL = _T_EQUAL
-UNEQUAL = _T_UNEQUAL
-LESS = _T_LESS
-LESS_EQUAL = _T_LESS_EQUAL
-GREATER = _T_GREATER
+AND = _T_AND
+OR = _T_OR
+NOT = _T_NOT
+EQUAL = _T_EQUAL
+UNEQUAL = _T_UNEQUAL
+LESS = _T_LESS
+LESS_EQUAL = _T_LESS_EQUAL
+GREATER = _T_GREATER
GREATER_EQUAL = _T_GREATER_EQUAL
REL_TO_STR = {
- EQUAL: "=",
- UNEQUAL: "!=",
- LESS: "<",
- LESS_EQUAL: "<=",
- GREATER: ">",
+ EQUAL: "=",
+ UNEQUAL: "!=",
+ LESS: "<",
+ LESS_EQUAL: "<=",
+ GREATER: ">",
GREATER_EQUAL: ">=",
}
# Symbol/choice types. UNKNOWN is 0 (falsy) to simplify some checks.
# Client code shouldn't rely on it though, as it was non-zero in
# older versions.
-UNKNOWN = 0
-BOOL = _T_BOOL
+UNKNOWN = 0
+BOOL = _T_BOOL
TRISTATE = _T_TRISTATE
-STRING = _T_STRING
-INT = _T_INT
-HEX = _T_HEX
+STRING = _T_STRING
+INT = _T_INT
+HEX = _T_HEX
TYPE_TO_STR = {
- UNKNOWN: "unknown",
- BOOL: "bool",
+ UNKNOWN: "unknown",
+ BOOL: "bool",
TRISTATE: "tristate",
- STRING: "string",
- INT: "int",
- HEX: "hex",
+ STRING: "string",
+ INT: "int",
+ HEX: "hex",
}
# Used in comparisons. 0 means the base is inferred from the format of the
# string.
_TYPE_TO_BASE = {
- HEX: 16,
- INT: 10,
- STRING: 0,
- UNKNOWN: 0,
+ HEX: 16,
+ INT: 10,
+ STRING: 0,
+ UNKNOWN: 0,
}
# def_bool -> BOOL, etc.
_DEF_TOKEN_TO_TYPE = {
- _T_DEF_BOOL: BOOL,
- _T_DEF_HEX: HEX,
- _T_DEF_INT: INT,
- _T_DEF_STRING: STRING,
+ _T_DEF_BOOL: BOOL,
+ _T_DEF_HEX: HEX,
+ _T_DEF_INT: INT,
+ _T_DEF_STRING: STRING,
_T_DEF_TRISTATE: TRISTATE,
}
@@ -7056,91 +7295,115 @@ _DEF_TOKEN_TO_TYPE = {
# Identifier-like lexemes ("missing quotes") are also treated as strings after
# these tokens. _T_CHOICE is included to avoid symbols being registered for
# named choices.
-_STRING_LEX = frozenset({
- _T_BOOL,
- _T_CHOICE,
- _T_COMMENT,
- _T_HEX,
- _T_INT,
- _T_MAINMENU,
- _T_MENU,
- _T_ORSOURCE,
- _T_OSOURCE,
- _T_PROMPT,
- _T_RSOURCE,
- _T_SOURCE,
- _T_STRING,
- _T_TRISTATE,
-})
+_STRING_LEX = frozenset(
+ {
+ _T_BOOL,
+ _T_CHOICE,
+ _T_COMMENT,
+ _T_HEX,
+ _T_INT,
+ _T_MAINMENU,
+ _T_MENU,
+ _T_ORSOURCE,
+ _T_OSOURCE,
+ _T_PROMPT,
+ _T_RSOURCE,
+ _T_SOURCE,
+ _T_STRING,
+ _T_TRISTATE,
+ }
+)
# Various sets for quick membership tests. Gives a single global lookup and
# avoids creating temporary dicts/tuples.
-_TYPE_TOKENS = frozenset({
- _T_BOOL,
- _T_TRISTATE,
- _T_INT,
- _T_HEX,
- _T_STRING,
-})
-
-_SOURCE_TOKENS = frozenset({
- _T_SOURCE,
- _T_RSOURCE,
- _T_OSOURCE,
- _T_ORSOURCE,
-})
-
-_REL_SOURCE_TOKENS = frozenset({
- _T_RSOURCE,
- _T_ORSOURCE,
-})
+_TYPE_TOKENS = frozenset(
+ {
+ _T_BOOL,
+ _T_TRISTATE,
+ _T_INT,
+ _T_HEX,
+ _T_STRING,
+ }
+)
+
+_SOURCE_TOKENS = frozenset(
+ {
+ _T_SOURCE,
+ _T_RSOURCE,
+ _T_OSOURCE,
+ _T_ORSOURCE,
+ }
+)
+
+_REL_SOURCE_TOKENS = frozenset(
+ {
+ _T_RSOURCE,
+ _T_ORSOURCE,
+ }
+)
# Obligatory (non-optional) sources
-_OBL_SOURCE_TOKENS = frozenset({
- _T_SOURCE,
- _T_RSOURCE,
-})
-
-_BOOL_TRISTATE = frozenset({
- BOOL,
- TRISTATE,
-})
-
-_BOOL_TRISTATE_UNKNOWN = frozenset({
- BOOL,
- TRISTATE,
- UNKNOWN,
-})
-
-_INT_HEX = frozenset({
- INT,
- HEX,
-})
-
-_SYMBOL_CHOICE = frozenset({
- Symbol,
- Choice,
-})
-
-_MENU_COMMENT = frozenset({
- MENU,
- COMMENT,
-})
-
-_EQUAL_UNEQUAL = frozenset({
- EQUAL,
- UNEQUAL,
-})
-
-_RELATIONS = frozenset({
- EQUAL,
- UNEQUAL,
- LESS,
- LESS_EQUAL,
- GREATER,
- GREATER_EQUAL,
-})
+_OBL_SOURCE_TOKENS = frozenset(
+ {
+ _T_SOURCE,
+ _T_RSOURCE,
+ }
+)
+
+_BOOL_TRISTATE = frozenset(
+ {
+ BOOL,
+ TRISTATE,
+ }
+)
+
+_BOOL_TRISTATE_UNKNOWN = frozenset(
+ {
+ BOOL,
+ TRISTATE,
+ UNKNOWN,
+ }
+)
+
+_INT_HEX = frozenset(
+ {
+ INT,
+ HEX,
+ }
+)
+
+_SYMBOL_CHOICE = frozenset(
+ {
+ Symbol,
+ Choice,
+ }
+)
+
+_MENU_COMMENT = frozenset(
+ {
+ MENU,
+ COMMENT,
+ }
+)
+
+_EQUAL_UNEQUAL = frozenset(
+ {
+ EQUAL,
+ UNEQUAL,
+ }
+)
+
+_RELATIONS = frozenset(
+ {
+ EQUAL,
+ UNEQUAL,
+ LESS,
+ LESS_EQUAL,
+ GREATER,
+ GREATER_EQUAL,
+ }
+)
# Helper functions for getting compiled regular expressions, with the needed
# matching function returned directly as a small optimization.
@@ -7189,7 +7452,7 @@ _string_special_search = _re_search(r'"|\'|\\|\$\(')
# Special characters/strings while expanding a symbol name. Also includes
# end-of-line, in case the macro is the last thing on the line.
-_name_special_search = _re_search(r'[^A-Za-z0-9_$/.-]|\$\(|$')
+_name_special_search = _re_search(r"[^A-Za-z0-9_$/.-]|\$\(|$")
# A valid right-hand side for an assignment to a string symbol in a .config
# file, including escaped characters. Extracts the contents.
diff --git a/util/run_ects.py b/util/run_ects.py
index 9178328e5f..9293f60779 100644
--- a/util/run_ects.py
+++ b/util/run_ects.py
@@ -16,81 +16,81 @@ import subprocess
import sys
# List of tests to run.
-TESTS = ['meta', 'gpio', 'hook', 'i2c', 'interrupt', 'mutex', 'task', 'timer']
+TESTS = ["meta", "gpio", "hook", "i2c", "interrupt", "mutex", "task", "timer"]
class CtsRunner(object):
- """Class running eCTS tests."""
-
- def __init__(self, ec_dir, dryrun):
- self.ec_dir = ec_dir
- self.cts_py = []
- if dryrun:
- self.cts_py += ['echo']
- self.cts_py += [os.path.join(ec_dir, 'cts/cts.py')]
-
- def run_cmd(self, cmd):
- try:
- rc = subprocess.call(cmd)
- if rc != 0:
- return False
- except OSError:
- return False
- return True
-
- def run_test(self, test):
- cmd = self.cts_py + ['-m', test]
- self.run_cmd(cmd)
-
- def run(self, tests):
- for test in tests:
- logging.info('Running', test, 'test.')
- self.run_test(test)
-
- def sync(self):
- logging.info('Syncing tree...')
- os.chdir(self.ec_dir)
- cmd = ['repo', 'sync', '.']
- return self.run_cmd(cmd)
-
- def upload(self):
- logging.info('Uploading results...')
+ """Class running eCTS tests."""
+
+ def __init__(self, ec_dir, dryrun):
+ self.ec_dir = ec_dir
+ self.cts_py = []
+ if dryrun:
+ self.cts_py += ["echo"]
+ self.cts_py += [os.path.join(ec_dir, "cts/cts.py")]
+
+ def run_cmd(self, cmd):
+ try:
+ rc = subprocess.call(cmd)
+ if rc != 0:
+ return False
+ except OSError:
+ return False
+ return True
+
+ def run_test(self, test):
+ cmd = self.cts_py + ["-m", test]
+ self.run_cmd(cmd)
+
+ def run(self, tests):
+ for test in tests:
+            logging.info("Running %s test.", test)
+ self.run_test(test)
+
+ def sync(self):
+ logging.info("Syncing tree...")
+ os.chdir(self.ec_dir)
+ cmd = ["repo", "sync", "."]
+ return self.run_cmd(cmd)
+
+ def upload(self):
+ logging.info("Uploading results...")
def main():
- if not os.path.exists('/etc/cros_chroot_version'):
- logging.error('This script has to run inside chroot.')
- sys.exit(-1)
-
- ec_dir = os.path.realpath(os.path.dirname(__file__) + '/..')
-
- parser = argparse.ArgumentParser(description='Run eCTS and report results.')
- parser.add_argument('-d',
- '--dryrun',
- action='store_true',
- help='Echo commands to be executed without running them.')
- parser.add_argument('-s',
- '--sync',
- action='store_true',
- help='Sync tree before running tests.')
- parser.add_argument('-u',
- '--upload',
- action='store_true',
- help='Upload test results.')
- args = parser.parse_args()
-
- runner = CtsRunner(ec_dir, args.dryrun)
-
- if args.sync:
- if not runner.sync():
- logging.error('Failed to sync.')
- sys.exit(-1)
-
- runner.run(TESTS)
-
- if args.upload:
- runner.upload()
-
-
-if __name__ == '__main__':
- main()
+ if not os.path.exists("/etc/cros_chroot_version"):
+ logging.error("This script has to run inside chroot.")
+ sys.exit(-1)
+
+ ec_dir = os.path.realpath(os.path.dirname(__file__) + "/..")
+
+ parser = argparse.ArgumentParser(description="Run eCTS and report results.")
+ parser.add_argument(
+ "-d",
+ "--dryrun",
+ action="store_true",
+ help="Echo commands to be executed without running them.",
+ )
+ parser.add_argument(
+ "-s", "--sync", action="store_true", help="Sync tree before running tests."
+ )
+ parser.add_argument(
+ "-u", "--upload", action="store_true", help="Upload test results."
+ )
+ args = parser.parse_args()
+
+ runner = CtsRunner(ec_dir, args.dryrun)
+
+ if args.sync:
+ if not runner.sync():
+ logging.error("Failed to sync.")
+ sys.exit(-1)
+
+ runner.run(TESTS)
+
+ if args.upload:
+ runner.upload()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/util/test_kconfig_check.py b/util/test_kconfig_check.py
index cd1b9bf098..db73e7ee71 100644
--- a/util/test_kconfig_check.py
+++ b/util/test_kconfig_check.py
@@ -16,7 +16,8 @@ import kconfig_check
# Prefix that we strip from each Kconfig option, when considering whether it is
# equivalent to a CONFIG option with the same name
-PREFIX = 'PLATFORM_EC_'
+PREFIX = "PLATFORM_EC_"
+
@contextlib.contextmanager
def capture_sys_output():
@@ -39,38 +40,49 @@ def capture_sys_output():
# directly from Python. You can still run this test with 'pytest' if you like.
class KconfigCheck(unittest.TestCase):
"""Tests for the KconfigCheck class"""
+
def test_simple_check(self):
"""Check it detected a new ad-hoc CONFIG"""
checker = kconfig_check.KconfigCheck()
- self.assertEqual(['NEW_ONE'], checker.find_new_adhoc(
- configs=['NEW_ONE', 'OLD_ONE', 'IN_KCONFIG'],
- kconfigs=['IN_KCONFIG'],
- allowed=['OLD_ONE']))
+ self.assertEqual(
+ ["NEW_ONE"],
+ checker.find_new_adhoc(
+ configs=["NEW_ONE", "OLD_ONE", "IN_KCONFIG"],
+ kconfigs=["IN_KCONFIG"],
+ allowed=["OLD_ONE"],
+ ),
+ )
def test_sorted_check(self):
"""Check it sorts the results in order"""
checker = kconfig_check.KconfigCheck()
self.assertSequenceEqual(
- ['ANOTHER_NEW_ONE', 'NEW_ONE'],
+ ["ANOTHER_NEW_ONE", "NEW_ONE"],
checker.find_new_adhoc(
- configs=['NEW_ONE', 'ANOTHER_NEW_ONE', 'OLD_ONE', 'IN_KCONFIG'],
- kconfigs=['IN_KCONFIG'],
- allowed=['OLD_ONE']))
+ configs=["NEW_ONE", "ANOTHER_NEW_ONE", "OLD_ONE", "IN_KCONFIG"],
+ kconfigs=["IN_KCONFIG"],
+ allowed=["OLD_ONE"],
+ ),
+ )
def check_read_configs(self, use_defines):
checker = kconfig_check.KconfigCheck()
with tempfile.NamedTemporaryFile() as configs:
- with open(configs.name, 'w') as out:
- prefix = '#define ' if use_defines else ''
- suffix = ' ' if use_defines else '='
- out.write(f'''{prefix}CONFIG_OLD_ONE{suffix}y
+ with open(configs.name, "w") as out:
+ prefix = "#define " if use_defines else ""
+ suffix = " " if use_defines else "="
+ out.write(
+ f"""{prefix}CONFIG_OLD_ONE{suffix}y
{prefix}NOT_A_CONFIG{suffix}
{prefix}CONFIG_STRING{suffix}"something"
{prefix}CONFIG_INT{suffix}123
{prefix}CONFIG_HEX{suffix}45ab
-''')
- self.assertEqual(['OLD_ONE', 'STRING', 'INT', 'HEX'],
- checker.read_configs(configs.name, use_defines))
+"""
+ )
+ self.assertEqual(
+ ["OLD_ONE", "STRING", "INT", "HEX"],
+ checker.read_configs(configs.name, use_defines),
+ )
def test_read_configs(self):
"""Test KconfigCheck.read_configs()"""
@@ -87,22 +99,24 @@ class KconfigCheck(unittest.TestCase):
Args:
srctree: Directory to write to
"""
- with open(os.path.join(srctree, 'Kconfig'), 'w') as out:
- out.write(f'''config {PREFIX}MY_KCONFIG
+ with open(os.path.join(srctree, "Kconfig"), "w") as out:
+ out.write(
+ f"""config {PREFIX}MY_KCONFIG
\tbool "my kconfig"
rsource "subdir/Kconfig.wibble"
-''')
- subdir = os.path.join(srctree, 'subdir')
+"""
+ )
+ subdir = os.path.join(srctree, "subdir")
os.mkdir(subdir)
- with open(os.path.join(subdir, 'Kconfig.wibble'), 'w') as out:
- out.write('menuconfig %sMENU_KCONFIG\n' % PREFIX)
+ with open(os.path.join(subdir, "Kconfig.wibble"), "w") as out:
+ out.write("menuconfig %sMENU_KCONFIG\n" % PREFIX)
# Add a directory which should be ignored
- bad_subdir = os.path.join(subdir, 'Kconfig')
+ bad_subdir = os.path.join(subdir, "Kconfig")
os.mkdir(bad_subdir)
- with open(os.path.join(bad_subdir, 'Kconfig.bad'), 'w') as out:
- out.write('menuconfig %sBAD_KCONFIG' % PREFIX)
+ with open(os.path.join(bad_subdir, "Kconfig.bad"), "w") as out:
+ out.write("menuconfig %sBAD_KCONFIG" % PREFIX)
def test_find_kconfigs(self):
"""Test KconfigCheck.find_kconfigs()"""
@@ -110,20 +124,20 @@ rsource "subdir/Kconfig.wibble"
with tempfile.TemporaryDirectory() as srctree:
self.setup_srctree(srctree)
files = checker.find_kconfigs(srctree)
- fnames = [fname[len(srctree):] for fname in files]
- self.assertEqual(['/Kconfig', '/subdir/Kconfig.wibble'], fnames)
+ fnames = [fname[len(srctree) :] for fname in files]
+ self.assertEqual(["/Kconfig", "/subdir/Kconfig.wibble"], fnames)
def test_scan_kconfigs(self):
"""Test KconfigCheck.scan_configs()"""
checker = kconfig_check.KconfigCheck()
with tempfile.TemporaryDirectory() as srctree:
self.setup_srctree(srctree)
- self.assertEqual(['MENU_KCONFIG', 'MY_KCONFIG'],
- checker.scan_kconfigs(srctree, PREFIX))
+ self.assertEqual(
+ ["MENU_KCONFIG", "MY_KCONFIG"], checker.scan_kconfigs(srctree, PREFIX)
+ )
@classmethod
- def setup_allowed_and_configs(cls, allowed_fname, configs_fname,
- add_new_one=True):
+ def setup_allowed_and_configs(cls, allowed_fname, configs_fname, add_new_one=True):
"""Set up the 'allowed' and 'configs' files for tests
Args:
@@ -131,14 +145,14 @@ rsource "subdir/Kconfig.wibble"
configs_fname: Filename to which CONFIGs to check should be written
add_new_one: True to add CONFIG_NEW_ONE to the configs_fname file
"""
- with open(allowed_fname, 'w') as out:
- out.write('CONFIG_OLD_ONE\n')
- out.write('CONFIG_MENU_KCONFIG\n')
- with open(configs_fname, 'w') as out:
- to_add = ['CONFIG_OLD_ONE', 'CONFIG_MY_KCONFIG']
+ with open(allowed_fname, "w") as out:
+ out.write("CONFIG_OLD_ONE\n")
+ out.write("CONFIG_MENU_KCONFIG\n")
+ with open(configs_fname, "w") as out:
+ to_add = ["CONFIG_OLD_ONE", "CONFIG_MY_KCONFIG"]
if add_new_one:
- to_add.append('CONFIG_NEW_ONE')
- out.write('\n'.join(to_add))
+ to_add.append("CONFIG_NEW_ONE")
+ out.write("\n".join(to_add))
def test_check_adhoc_configs(self):
"""Test KconfigCheck.check_adhoc_configs()"""
@@ -148,12 +162,16 @@ rsource "subdir/Kconfig.wibble"
with tempfile.NamedTemporaryFile() as allowed:
with tempfile.NamedTemporaryFile() as configs:
self.setup_allowed_and_configs(allowed.name, configs.name)
- new_adhoc, unneeded_adhoc, updated_adhoc = (
- checker.check_adhoc_configs(
- configs.name, srctree, allowed.name, PREFIX))
- self.assertEqual(['NEW_ONE'], new_adhoc)
- self.assertEqual(['MENU_KCONFIG'], unneeded_adhoc)
- self.assertEqual(['OLD_ONE'], updated_adhoc)
+ (
+ new_adhoc,
+ unneeded_adhoc,
+ updated_adhoc,
+ ) = checker.check_adhoc_configs(
+ configs.name, srctree, allowed.name, PREFIX
+ )
+ self.assertEqual(["NEW_ONE"], new_adhoc)
+ self.assertEqual(["MENU_KCONFIG"], unneeded_adhoc)
+ self.assertEqual(["OLD_ONE"], updated_adhoc)
def test_check(self):
"""Test running the 'check' subcommand"""
@@ -162,29 +180,39 @@ rsource "subdir/Kconfig.wibble"
self.setup_srctree(srctree)
with tempfile.NamedTemporaryFile() as allowed:
with tempfile.NamedTemporaryFile() as configs:
- self.setup_allowed_and_configs(allowed.name,
- configs.name)
+ self.setup_allowed_and_configs(allowed.name, configs.name)
ret_code = kconfig_check.main(
- ['-c', configs.name, '-s', srctree,
- '-a', allowed.name, '-p', PREFIX, 'check'])
+ [
+ "-c",
+ configs.name,
+ "-s",
+ srctree,
+ "-a",
+ allowed.name,
+ "-p",
+ PREFIX,
+ "check",
+ ]
+ )
self.assertEqual(1, ret_code)
- self.assertEqual('', stdout.getvalue())
- found = re.findall('(CONFIG_.*)', stderr.getvalue())
- self.assertEqual(['CONFIG_NEW_ONE'], found)
+ self.assertEqual("", stdout.getvalue())
+ found = re.findall("(CONFIG_.*)", stderr.getvalue())
+ self.assertEqual(["CONFIG_NEW_ONE"], found)
def test_real_kconfig(self):
"""Same Kconfig should be returned for kconfiglib / adhoc"""
if not kconfig_check.USE_KCONFIGLIB:
- self.skipTest('No kconfiglib available')
- zephyr_path = pathlib.Path('../../third_party/zephyr/main').resolve()
+ self.skipTest("No kconfiglib available")
+ zephyr_path = pathlib.Path("../../third_party/zephyr/main").resolve()
if not zephyr_path.exists():
- self.skipTest('No zephyr tree available')
+ self.skipTest("No zephyr tree available")
checker = kconfig_check.KconfigCheck()
- srcdir = 'zephyr'
+ srcdir = "zephyr"
search_paths = [zephyr_path]
kc_version = checker.scan_kconfigs(
- srcdir, search_paths=search_paths, try_kconfiglib=True)
+ srcdir, search_paths=search_paths, try_kconfiglib=True
+ )
adhoc_version = checker.scan_kconfigs(srcdir, try_kconfiglib=False)
# List of things missing from the Kconfig
@@ -192,15 +220,17 @@ rsource "subdir/Kconfig.wibble"
# The Kconfig is disjoint in some places, e.g. the boards have their
# own Kconfig files which are not included from the main Kconfig
- missing = [item for item in missing
- if not item.startswith('BOARD') and
- not item.startswith('VARIANT')]
+ missing = [
+ item
+ for item in missing
+ if not item.startswith("BOARD") and not item.startswith("VARIANT")
+ ]
# Similarly, some other items are defined in files that are not included
# in all cases, only for particular values of $(ARCH)
self.assertEqual(
- ['FLASH_LOAD_OFFSET', 'NPCX_HEADER', 'SYS_CLOCK_HW_CYCLES_PER_SEC'],
- missing)
+ ["FLASH_LOAD_OFFSET", "NPCX_HEADER", "SYS_CLOCK_HW_CYCLES_PER_SEC"], missing
+ )
def test_check_unneeded(self):
"""Test running the 'check' subcommand with unneeded ad-hoc configs"""
@@ -209,18 +239,29 @@ rsource "subdir/Kconfig.wibble"
self.setup_srctree(srctree)
with tempfile.NamedTemporaryFile() as allowed:
with tempfile.NamedTemporaryFile() as configs:
- self.setup_allowed_and_configs(allowed.name,
- configs.name, False)
+ self.setup_allowed_and_configs(
+ allowed.name, configs.name, False
+ )
ret_code = kconfig_check.main(
- ['-c', configs.name, '-s', srctree,
- '-a', allowed.name, '-p', PREFIX, 'check'])
+ [
+ "-c",
+ configs.name,
+ "-s",
+ srctree,
+ "-a",
+ allowed.name,
+ "-p",
+ PREFIX,
+ "check",
+ ]
+ )
self.assertEqual(1, ret_code)
- self.assertEqual('', stderr.getvalue())
- found = re.findall('(CONFIG_.*)', stdout.getvalue())
- self.assertEqual(['CONFIG_MENU_KCONFIG'], found)
+ self.assertEqual("", stderr.getvalue())
+ found = re.findall("(CONFIG_.*)", stdout.getvalue())
+ self.assertEqual(["CONFIG_MENU_KCONFIG"], found)
allowed = kconfig_check.NEW_ALLOWED_FNAME.read_text().splitlines()
- self.assertEqual(['CONFIG_OLD_ONE'], allowed)
+ self.assertEqual(["CONFIG_OLD_ONE"], allowed)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/util/uart_stress_tester.py b/util/uart_stress_tester.py
index b3db60060e..a89fe730c9 100755
--- a/util/uart_stress_tester.py
+++ b/util/uart_stress_tester.py
@@ -21,9 +21,7 @@ Prerequisite:
e.g. dut-control cr50_uart_timestamp:off
"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function
import argparse
import atexit
@@ -36,472 +34,501 @@ import time
import serial
-BAUDRATE = 115200 # Default baudrate setting for UART port
-CROS_USERNAME = 'root' # Account name to login to ChromeOS
-CROS_PASSWORD = 'test0000' # Password to login to ChromeOS
-CHARGEN_TXT = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
- # The result of 'chargen 62 62'
+BAUDRATE = 115200 # Default baudrate setting for UART port
+CROS_USERNAME = "root" # Account name to login to ChromeOS
+CROS_PASSWORD = "test0000" # Password to login to ChromeOS
+CHARGEN_TXT = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+# The result of 'chargen 62 62'
CHARGEN_TXT_LEN = len(CHARGEN_TXT)
-CR = '\r' # Carriage Return
-LF = '\n' # Line Feed
+CR = "\r" # Carriage Return
+LF = "\n" # Line Feed
CRLF = CR + LF
-FLAG_FILENAME = '/tmp/chargen_testing'
-TPM_CMD = ('trunks_client --key_create --rsa=2048 --usage=sign'
- ' --key_blob=/tmp/blob &> /dev/null')
- # A ChromeOS TPM command for the cr50 stress
- # purpose.
-CR50_LOAD_GEN_CMD = ('while [[ -f %s ]]; do %s; done &'
- % (FLAG_FILENAME, TPM_CMD))
- # A command line to run TPM_CMD in background
- # infinitely.
+FLAG_FILENAME = "/tmp/chargen_testing"
+TPM_CMD = (
+ "trunks_client --key_create --rsa=2048 --usage=sign"
+ " --key_blob=/tmp/blob &> /dev/null"
+)
+# A ChromeOS TPM command for the cr50 stress
+# purpose.
+CR50_LOAD_GEN_CMD = "while [[ -f %s ]]; do %s; done &" % (FLAG_FILENAME, TPM_CMD)
+# A command line to run TPM_CMD in background
+# infinitely.
class ChargenTestError(Exception):
- """Exception for Uart Stress Test Error"""
- pass
+ """Exception for Uart Stress Test Error"""
+ pass
-class UartSerial(object):
- """Test Object for a single UART serial device
-
- Attributes:
- UART_DEV_PROFILES
- char_loss_occurrences: Number that character loss happens
- cleanup_cli: Command list to perform before the test exits
- cr50_workload: True if cr50 should be stressed, or False otherwise
- usb_output: True if output should be generated to USB channel
- dev_prof: Dictionary of device profile
- duration: Time to keep chargen running
- eol: Characters to add at the end of input
- logger: object that store the log
- num_ch_exp: Expected number of characters in output
- num_ch_cap: Number of captured characters in output
- test_cli: Command list to run for chargen test
- test_thread: Thread object that captures the UART output
- serial: serial.Serial object
- """
- UART_DEV_PROFILES = (
- # Kernel
- {
- 'prompt':'localhost login:',
- 'device_type':'AP',
- 'prepare_cmd':[
- CROS_USERNAME, # Login
- CROS_PASSWORD, # Password
- 'dmesg -D', # Disable console message
- 'touch ' + FLAG_FILENAME, # Create a temp file
- ],
- 'cleanup_cmd':[
- 'rm -f ' + FLAG_FILENAME, # Remove the temp file
- 'dmesg -E', # Enable console message
- 'logout', # Logout
- ],
- 'end_of_input':LF,
- },
- # EC
- {
- 'prompt':'> ',
- 'device_type':'EC',
- 'prepare_cmd':[
- 'chan save',
- 'chan 0' # Disable console message
- ],
- 'cleanup_cmd':['', 'chan restore'],
- 'end_of_input':CRLF,
- },
- )
-
- def __init__(self, port, duration, timeout=1,
- baudrate=BAUDRATE, cr50_workload=False,
- usb_output=False):
- """Initialize UartSerial
-
- Args:
- port: UART device path. e.g. /dev/ttyUSB0
- duration: Time to test, in seconds
- timeout: Read timeout value.
- baudrate: Baud rate such as 9600 or 115200.
- cr50_workload: True if a workload should be generated on cr50
- usb_output: True if a workload should be generated to USB channel
- """
-
- # Initialize serial object
- self.serial = serial.Serial()
- self.serial.port = port
- self.serial.timeout = timeout
- self.serial.baudrate = baudrate
-
- self.duration = duration
- self.cr50_workload = cr50_workload
- self.usb_output = usb_output
-
- self.logger = logging.getLogger(type(self).__name__ + '| ' + port)
- self.test_thread = threading.Thread(target=self.stress_test_thread)
- self.dev_prof = {}
- self.cleanup_cli = []
- self.test_cli = []
- self.eol = CRLF
- self.num_ch_exp = 0
- self.num_ch_cap = 0
- self.char_loss_occurrences = 0
- atexit.register(self.cleanup)
-
- def run_command(self, command_lines, delay=0):
- """Run command(s) at UART prompt
-
- Args:
- command_lines: list of commands to run.
- delay: delay after a command in second
+class UartSerial(object):
+ """Test Object for a single UART serial device
+
+ Attributes:
+ UART_DEV_PROFILES
+ char_loss_occurrences: Number that character loss happens
+ cleanup_cli: Command list to perform before the test exits
+ cr50_workload: True if cr50 should be stressed, or False otherwise
+ usb_output: True if output should be generated to USB channel
+ dev_prof: Dictionary of device profile
+ duration: Time to keep chargen running
+ eol: Characters to add at the end of input
+ logger: object that store the log
+ num_ch_exp: Expected number of characters in output
+ num_ch_cap: Number of captured characters in output
+ test_cli: Command list to run for chargen test
+ test_thread: Thread object that captures the UART output
+ serial: serial.Serial object
"""
- for cli in command_lines:
- self.logger.debug('run %r', cli)
-
- self.serial.write((cli + self.eol).encode())
- self.serial.flush()
- if delay:
- time.sleep(delay)
-
- def cleanup(self):
- """Before termination, clean up the UART device."""
- self.logger.debug('Closing...')
-
- self.serial.open()
- self.run_command(self.cleanup_cli) # Run cleanup commands
- self.serial.close()
- self.logger.debug('Cleanup done')
+ UART_DEV_PROFILES = (
+ # Kernel
+ {
+ "prompt": "localhost login:",
+ "device_type": "AP",
+ "prepare_cmd": [
+ CROS_USERNAME, # Login
+ CROS_PASSWORD, # Password
+ "dmesg -D", # Disable console message
+ "touch " + FLAG_FILENAME, # Create a temp file
+ ],
+ "cleanup_cmd": [
+ "rm -f " + FLAG_FILENAME, # Remove the temp file
+ "dmesg -E", # Enable console message
+ "logout", # Logout
+ ],
+ "end_of_input": LF,
+ },
+ # EC
+ {
+ "prompt": "> ",
+ "device_type": "EC",
+ "prepare_cmd": ["chan save", "chan 0"], # Disable console message
+ "cleanup_cmd": ["", "chan restore"],
+ "end_of_input": CRLF,
+ },
+ )
+
+ def __init__(
+ self,
+ port,
+ duration,
+ timeout=1,
+ baudrate=BAUDRATE,
+ cr50_workload=False,
+ usb_output=False,
+ ):
+ """Initialize UartSerial
+
+ Args:
+ port: UART device path. e.g. /dev/ttyUSB0
+ duration: Time to test, in seconds
+ timeout: Read timeout value.
+ baudrate: Baud rate such as 9600 or 115200.
+ cr50_workload: True if a workload should be generated on cr50
+ usb_output: True if a workload should be generated to USB channel
+ """
+
+ # Initialize serial object
+ self.serial = serial.Serial()
+ self.serial.port = port
+ self.serial.timeout = timeout
+ self.serial.baudrate = baudrate
+
+ self.duration = duration
+ self.cr50_workload = cr50_workload
+ self.usb_output = usb_output
+
+ self.logger = logging.getLogger(type(self).__name__ + "| " + port)
+ self.test_thread = threading.Thread(target=self.stress_test_thread)
+
+ self.dev_prof = {}
+ self.cleanup_cli = []
+ self.test_cli = []
+ self.eol = CRLF
+ self.num_ch_exp = 0
+ self.num_ch_cap = 0
+ self.char_loss_occurrences = 0
+ atexit.register(self.cleanup)
+
+ def run_command(self, command_lines, delay=0):
+ """Run command(s) at UART prompt
+
+ Args:
+ command_lines: list of commands to run.
+ delay: delay after a command in second
+ """
+ for cli in command_lines:
+ self.logger.debug("run %r", cli)
+
+ self.serial.write((cli + self.eol).encode())
+ self.serial.flush()
+ if delay:
+ time.sleep(delay)
+
+ def cleanup(self):
+ """Before termination, clean up the UART device."""
+ self.logger.debug("Closing...")
+
+ self.serial.open()
+ self.run_command(self.cleanup_cli) # Run cleanup commands
+ self.serial.close()
+
+ self.logger.debug("Cleanup done")
+
+ def get_output(self):
+ """Capture the UART output
+
+ Args:
+ stop_char: Read output buffer until it reads stop_char.
+
+ Returns:
+ text from UART output.
+ """
+ if self.serial.inWaiting() == 0:
+ time.sleep(1)
+
+ return self.serial.read(self.serial.inWaiting()).decode()
+
+ def prepare(self):
+ """Prepare the test:
+
+ Identify the type of UART device (EC or Kernel?), then
+ decide what kind of commands to use to generate stress loads.
+
+ Raises:
+ ChargenTestError if UART source can't be identified.
+ """
+ try:
+ self.logger.info("Preparing...")
+
+ self.serial.open()
+
+ # Prepare the device for test
+ self.serial.flushInput()
+ self.serial.flushOutput()
+
+ self.get_output() # drain data
+
+ # Give a couple of line feeds, and capture the prompt text
+ self.run_command(["", ""])
+ prompt_txt = self.get_output()
+
+ # Detect the device source: EC or AP?
+ # Detect if the device is AP or EC console based on the captured.
+ for dev_prof in self.UART_DEV_PROFILES:
+ if dev_prof["prompt"] in prompt_txt:
+ self.dev_prof = dev_prof
+ break
+ else:
+ # No prompt patterns were found. UART seems not responding or in
+ # an undesirable status.
+ if prompt_txt:
+ raise ChargenTestError(
+ "%s: Got an unknown prompt text: %s\n"
+ "Check manually whether %s is available."
+ % (self.serial.port, prompt_txt, self.serial.port)
+ )
+ else:
+ raise ChargenTestError(
+ "%s: Got no input. Close any other connections"
+ " to this port, and try it again." % self.serial.port
+ )
+
+ self.logger.info("Detected as %s UART", self.dev_prof["device_type"])
+ # Log displays the UART type (AP|EC) instead of device filename.
+ self.logger = logging.getLogger(
+ type(self).__name__ + "| " + self.dev_prof["device_type"]
+ )
+
+ # Either login to AP or run some commands to prepare the device
+ # for test
+ self.eol = self.dev_prof["end_of_input"]
+ self.run_command(self.dev_prof["prepare_cmd"], delay=2)
+ self.cleanup_cli += self.dev_prof["cleanup_cmd"]
+
+ # 'chargen' of AP does not have option for USB output.
+ # Force it work on UART.
+ if self.dev_prof["device_type"] == "AP":
+ self.usb_output = False
+
+ # Check whether the command 'chargen' is available in the device.
+ # 'chargen 1 4' is supposed to print '0000'
+ self.get_output() # drain data
+
+ chargen_cmd = "chargen 1 4"
+ if self.usb_output:
+ chargen_cmd += " usb"
+ self.run_command([chargen_cmd])
+ tmp_txt = self.get_output()
+
+ # Check whether chargen command is available.
+ if "0000" not in tmp_txt:
+ raise ChargenTestError(
+ "%s: Chargen got an unexpected result: %s"
+ % (self.dev_prof["device_type"], tmp_txt)
+ )
+
+ self.num_ch_exp = int(self.serial.baudrate * self.duration / 10)
+ chargen_cmd = "chargen " + str(CHARGEN_TXT_LEN) + " " + str(self.num_ch_exp)
+ if self.usb_output:
+ chargen_cmd += " usb"
+ self.test_cli = [chargen_cmd]
+
+ self.logger.info("Ready to test")
+ finally:
+ self.serial.close()
+
+ def stress_test_thread(self):
+ """Test thread
+
+ Raises:
+ ChargenTestError: if broken character is found.
+ """
+ try:
+ self.serial.open()
+ self.serial.flushInput()
+ self.serial.flushOutput()
+
+ # Run TPM command in background to burden cr50.
+ if self.dev_prof["device_type"] == "AP" and self.cr50_workload:
+ self.run_command([CR50_LOAD_GEN_CMD])
+ self.logger.debug("run TPM job while %s exists", FLAG_FILENAME)
+
+ # Run the command 'chargen', one time
+ self.run_command([""]) # Give a line feed
+ self.get_output() # Drain the output
+ self.run_command(self.test_cli)
+ self.serial.readline() # Drain the echoed command line.
+
+ err_msg = "%s: Expected %r but got %s after %d char received"
+
+ # Keep capturing the output until the test timer is expired.
+ self.num_ch_cap = 0
+ self.char_loss_occurrences = 0
+ data_starve_count = 0
+
+ total_num_ch = self.num_ch_exp # Expected number of characters in total
+ ch_exp = CHARGEN_TXT[0]
+ ch_cap = "z" # any character value is ok for loop initial condition.
+ while self.num_ch_cap < total_num_ch:
+ captured = self.get_output()
+
+ if captured:
+ # There is some output data. Reset the data starvation count.
+ data_starve_count = 0
+ else:
+ data_starve_count += 1
+ if data_starve_count > 1:
+ # If nothing was captured more than once, then terminate the test.
+ self.logger.debug("No more output")
+ break
+
+ for ch_cap in captured:
+ if ch_cap not in CHARGEN_TXT:
+ # If it is not alpha-numeric, terminate the test.
+ if ch_cap not in CRLF:
+ # If it is neither a CR nor LF, then it is an error case.
+ self.logger.error("Whole captured characters: %r", captured)
+ raise ChargenTestError(
+ err_msg
+ % (
+ "Broken char captured",
+ ch_exp,
+ hex(ord(ch_cap)),
+ self.num_ch_cap,
+ )
+ )
+
+ # Set the loop termination condition true.
+ total_num_ch = self.num_ch_cap
+
+ if self.num_ch_cap >= total_num_ch:
+ break
+
+ if ch_exp != ch_cap:
+ # If it is alpha-numeric but not continuous, then some characters
+ # are lost.
+ self.logger.error(
+ err_msg,
+ "Char loss detected",
+ ch_exp,
+ repr(ch_cap),
+ self.num_ch_cap,
+ )
+ self.char_loss_occurrences += 1
+
+ # Recalculate the expected number of characters to adjust
+ # termination condition. The loss might be bigger than this
+ # adjustment, but it is okay since it will terminates by either
+ # CR/LF detection or by data starvation.
+ idx_ch_exp = CHARGEN_TXT.find(ch_exp)
+ idx_ch_cap = CHARGEN_TXT.find(ch_cap)
+ if idx_ch_cap < idx_ch_exp:
+ idx_ch_cap += len(CHARGEN_TXT)
+ total_num_ch -= idx_ch_cap - idx_ch_exp
+
+ self.num_ch_cap += 1
+
+ # Determine What character is expected next?
+ ch_exp = CHARGEN_TXT[
+ (CHARGEN_TXT.find(ch_cap) + 1) % CHARGEN_TXT_LEN
+ ]
+
+ finally:
+ self.serial.close()
+
+ def start_test(self):
+ """Start the test thread"""
+ self.logger.info("Test thread starts")
+ self.test_thread.start()
+
+ def wait_test_done(self):
+ """Wait until the test thread get done and join"""
+ self.test_thread.join()
+ self.logger.info("Test thread is done")
+
+ def get_result(self):
+ """Display the result
+
+ Returns:
+ Integer = the number of lost character
+
+ Raises:
+ ChargenTestError: if the capture is corrupted.
+ """
+ # If more characters than expected are captured, it means some messages
+ # from other than chargen are mixed. Stop processing further.
+ if self.num_ch_exp < self.num_ch_cap:
+ raise ChargenTestError(
+ "%s: UART output is corrupted." % self.dev_prof["device_type"]
+ )
+
+ # Get the count difference between the expected to the captured
+ # as the number of lost character.
+ char_lost = self.num_ch_exp - self.num_ch_cap
+ self.logger.info(
+ "%8d char lost / %10d (%.1f %%)",
+ char_lost,
+ self.num_ch_exp,
+ char_lost * 100.0 / self.num_ch_exp,
+ )
+
+ return char_lost, self.num_ch_exp, self.char_loss_occurrences
- def get_output(self):
- """Capture the UART output
- Args:
- stop_char: Read output buffer until it reads stop_char.
+class ChargenTest(object):
+ """UART stress tester
- Returns:
- text from UART output.
+ Attributes:
+ logger: logging object
+ serials: Dictionary where key is filename of UART device, and the value is
+ UartSerial object
"""
- if self.serial.inWaiting() == 0:
- time.sleep(1)
-
- return self.serial.read(self.serial.inWaiting()).decode()
- def prepare(self):
- """Prepare the test:
-
- Identify the type of UART device (EC or Kernel?), then
- decide what kind of commands to use to generate stress loads.
-
- Raises:
- ChargenTestError if UART source can't be identified.
- """
- try:
- self.logger.info('Preparing...')
-
- self.serial.open()
-
- # Prepare the device for test
- self.serial.flushInput()
- self.serial.flushOutput()
-
- self.get_output() # drain data
-
- # Give a couple of line feeds, and capture the prompt text
- self.run_command(['', ''])
- prompt_txt = self.get_output()
-
- # Detect the device source: EC or AP?
- # Detect if the device is AP or EC console based on the captured.
- for dev_prof in self.UART_DEV_PROFILES:
- if dev_prof['prompt'] in prompt_txt:
- self.dev_prof = dev_prof
- break
- else:
- # No prompt patterns were found. UART seems not responding or in
- # an undesirable status.
- if prompt_txt:
- raise ChargenTestError('%s: Got an unknown prompt text: %s\n'
- 'Check manually whether %s is available.' %
- (self.serial.port, prompt_txt,
- self.serial.port))
- else:
- raise ChargenTestError('%s: Got no input. Close any other connections'
- ' to this port, and try it again.' %
- self.serial.port)
-
- self.logger.info('Detected as %s UART', self.dev_prof['device_type'])
- # Log displays the UART type (AP|EC) instead of device filename.
- self.logger = logging.getLogger(type(self).__name__ + '| ' +
- self.dev_prof['device_type'])
-
- # Either login to AP or run some commands to prepare the device
- # for test
- self.eol = self.dev_prof['end_of_input']
- self.run_command(self.dev_prof['prepare_cmd'], delay=2)
- self.cleanup_cli += self.dev_prof['cleanup_cmd']
-
- # 'chargen' of AP does not have option for USB output.
- # Force it work on UART.
- if self.dev_prof['device_type'] == 'AP':
- self.usb_output = False
-
- # Check whether the command 'chargen' is available in the device.
- # 'chargen 1 4' is supposed to print '0000'
- self.get_output() # drain data
-
- chargen_cmd = 'chargen 1 4'
- if self.usb_output:
- chargen_cmd += ' usb'
- self.run_command([chargen_cmd])
- tmp_txt = self.get_output()
-
- # Check whether chargen command is available.
- if '0000' not in tmp_txt:
- raise ChargenTestError('%s: Chargen got an unexpected result: %s' %
- (self.dev_prof['device_type'], tmp_txt))
-
- self.num_ch_exp = int(self.serial.baudrate * self.duration / 10)
- chargen_cmd = 'chargen ' + str(CHARGEN_TXT_LEN) + ' ' + \
- str(self.num_ch_exp)
- if self.usb_output:
- chargen_cmd += ' usb'
- self.test_cli = [chargen_cmd]
-
- self.logger.info('Ready to test')
- finally:
- self.serial.close()
-
- def stress_test_thread(self):
- """Test thread
-
- Raises:
- ChargenTestError: if broken character is found.
- """
- try:
- self.serial.open()
- self.serial.flushInput()
- self.serial.flushOutput()
-
- # Run TPM command in background to burden cr50.
- if self.dev_prof['device_type'] == 'AP' and self.cr50_workload:
- self.run_command([CR50_LOAD_GEN_CMD])
- self.logger.debug('run TPM job while %s exists', FLAG_FILENAME)
-
- # Run the command 'chargen', one time
- self.run_command(['']) # Give a line feed
- self.get_output() # Drain the output
- self.run_command(self.test_cli)
- self.serial.readline() # Drain the echoed command line.
-
- err_msg = '%s: Expected %r but got %s after %d char received'
-
- # Keep capturing the output until the test timer is expired.
- self.num_ch_cap = 0
- self.char_loss_occurrences = 0
- data_starve_count = 0
-
- total_num_ch = self.num_ch_exp # Expected number of characters in total
- ch_exp = CHARGEN_TXT[0]
- ch_cap = 'z' # any character value is ok for loop initial condition.
- while self.num_ch_cap < total_num_ch:
- captured = self.get_output()
-
- if captured:
- # There is some output data. Reset the data starvation count.
- data_starve_count = 0
+ def __init__(self, ports, duration, cr50_workload=False, usb_output=False):
+ """Initialize UART stress tester
+
+ Args:
+ ports: List of UART ports to test.
+ duration: Time to keep testing in seconds.
+ cr50_workload: True if a workload should be generated on cr50
+ usb_output: True if a workload should be generated to USB channel
+
+ Raises:
+ ChargenTestError: if any of ports is not a valid character device.
+ """
+
+ # Save the arguments
+ for port in ports:
+ try:
+ mode = os.stat(port).st_mode
+ except OSError as e:
+ raise ChargenTestError(e)
+ if not stat.S_ISCHR(mode):
+ raise ChargenTestError("%s is not a character device." % port)
+
+ if duration <= 0:
+ raise ChargenTestError("Input error: duration is not positive.")
+
+ # Initialize logging object
+ self.logger = logging.getLogger(type(self).__name__)
+
+ # Create an UartSerial object per UART port
+ self.serials = {} # UartSerial objects
+ for port in ports:
+ self.serials[port] = UartSerial(
+ port=port,
+ duration=duration,
+ cr50_workload=cr50_workload,
+ usb_output=usb_output,
+ )
+
+ def prepare(self):
+ """Prepare the test for each UART port"""
+ self.logger.info("Prepare ports for test")
+ for _, ser in self.serials.items():
+ ser.prepare()
+ self.logger.info("Ports are ready to test")
+
+ def print_result(self):
+ """Display the test result for each UART port
+
+ Returns:
+ char_lost: Total number of characters lost
+ """
+ char_lost = 0
+ for _, ser in self.serials.items():
+ (tmp_lost, _, _) = ser.get_result()
+ char_lost += tmp_lost
+
+ # If any characters are lost, then test fails.
+ msg = "lost %d character(s) from the test" % char_lost
+ if char_lost > 0:
+ self.logger.error("FAIL: %s", msg)
else:
- data_starve_count += 1
- if data_starve_count > 1:
- # If nothing was captured more than once, then terminate the test.
- self.logger.debug('No more output')
- break
-
- for ch_cap in captured:
- if ch_cap not in CHARGEN_TXT:
- # If it is not alpha-numeric, terminate the test.
- if ch_cap not in CRLF:
- # If it is neither a CR nor LF, then it is an error case.
- self.logger.error('Whole captured characters: %r', captured)
- raise ChargenTestError(err_msg % ('Broken char captured', ch_exp,
- hex(ord(ch_cap)),
- self.num_ch_cap))
-
- # Set the loop termination condition true.
- total_num_ch = self.num_ch_cap
-
- if self.num_ch_cap >= total_num_ch:
- break
-
- if ch_exp != ch_cap:
- # If it is alpha-numeric but not continuous, then some characters
- # are lost.
- self.logger.error(err_msg, 'Char loss detected',
- ch_exp, repr(ch_cap), self.num_ch_cap)
- self.char_loss_occurrences += 1
-
- # Recalculate the expected number of characters to adjust
- # termination condition. The loss might be bigger than this
- # adjustment, but it is okay since it will terminates by either
- # CR/LF detection or by data starvation.
- idx_ch_exp = CHARGEN_TXT.find(ch_exp)
- idx_ch_cap = CHARGEN_TXT.find(ch_cap)
- if idx_ch_cap < idx_ch_exp:
- idx_ch_cap += len(CHARGEN_TXT)
- total_num_ch -= (idx_ch_cap - idx_ch_exp)
-
- self.num_ch_cap += 1
-
- # Determine What character is expected next?
- ch_exp = CHARGEN_TXT[(CHARGEN_TXT.find(ch_cap) + 1) % CHARGEN_TXT_LEN]
-
- finally:
- self.serial.close()
-
- def start_test(self):
- """Start the test thread"""
- self.logger.info('Test thread starts')
- self.test_thread.start()
-
- def wait_test_done(self):
- """Wait until the test thread get done and join"""
- self.test_thread.join()
- self.logger.info('Test thread is done')
-
- def get_result(self):
- """Display the result
+ self.logger.info("PASS: %s", msg)
- Returns:
- Integer = the number of lost character
+ return char_lost
- Raises:
- ChargenTestError: if the capture is corrupted.
- """
- # If more characters than expected are captured, it means some messages
- # from other than chargen are mixed. Stop processing further.
- if self.num_ch_exp < self.num_ch_cap:
- raise ChargenTestError('%s: UART output is corrupted.' %
- self.dev_prof['device_type'])
+ def run(self):
+ """Run the stress test on UART port(s)
- # Get the count difference between the expected to the captured
- # as the number of lost character.
- char_lost = self.num_ch_exp - self.num_ch_cap
- self.logger.info('%8d char lost / %10d (%.1f %%)',
- char_lost, self.num_ch_exp,
- char_lost * 100.0 / self.num_ch_exp)
+ Raises:
+ ChargenTestError: If any characters are lost.
+ """
- return char_lost, self.num_ch_exp, self.char_loss_occurrences
+ # Detect UART source type, and decide which command to test.
+ self.prepare()
+ # Run the test on each UART port in thread.
+ self.logger.info("Test starts")
+ for _, ser in self.serials.items():
+ ser.start_test()
-class ChargenTest(object):
- """UART stress tester
+ # Wait all tests to finish.
+ for _, ser in self.serials.items():
+ ser.wait_test_done()
- Attributes:
- logger: logging object
- serials: Dictionary where key is filename of UART device, and the value is
- UartSerial object
- """
+ # Print the result.
+ char_lost = self.print_result()
+ if char_lost:
+ raise ChargenTestError("Test failed: lost %d character(s)" % char_lost)
- def __init__(self, ports, duration, cr50_workload=False,
- usb_output=False):
- """Initialize UART stress tester
+ self.logger.info("Test is done")
- Args:
- ports: List of UART ports to test.
- duration: Time to keep testing in seconds.
- cr50_workload: True if a workload should be generated on cr50
- usb_output: True if a workload should be generated to USB channel
- Raises:
- ChargenTestError: if any of ports is not a valid character device.
- """
+def parse_args(cmdline):
+ """Parse command line arguments.
- # Save the arguments
- for port in ports:
- try:
- mode = os.stat(port).st_mode
- except OSError as e:
- raise ChargenTestError(e)
- if not stat.S_ISCHR(mode):
- raise ChargenTestError('%s is not a character device.' % port)
-
- if duration <= 0:
- raise ChargenTestError('Input error: duration is not positive.')
-
- # Initialize logging object
- self.logger = logging.getLogger(type(self).__name__)
-
- # Create an UartSerial object per UART port
- self.serials = {} # UartSerial objects
- for port in ports:
- self.serials[port] = UartSerial(port=port, duration=duration,
- cr50_workload=cr50_workload,
- usb_output=usb_output)
-
- def prepare(self):
- """Prepare the test for each UART port"""
- self.logger.info('Prepare ports for test')
- for _, ser in self.serials.items():
- ser.prepare()
- self.logger.info('Ports are ready to test')
-
- def print_result(self):
- """Display the test result for each UART port
+ Args:
+ cmdline: list to be parsed
Returns:
- char_lost: Total number of characters lost
- """
- char_lost = 0
- for _, ser in self.serials.items():
- (tmp_lost, _, _) = ser.get_result()
- char_lost += tmp_lost
-
- # If any characters are lost, then test fails.
- msg = 'lost %d character(s) from the test' % char_lost
- if char_lost > 0:
- self.logger.error('FAIL: %s', msg)
- else:
- self.logger.info('PASS: %s', msg)
-
- return char_lost
-
- def run(self):
- """Run the stress test on UART port(s)
-
- Raises:
- ChargenTestError: If any characters are lost.
+ tuple (options, args) where args is a list of cmdline arguments that the
+ parser was unable to match i.e. they're servod controls, not options.
"""
-
- # Detect UART source type, and decide which command to test.
- self.prepare()
-
- # Run the test on each UART port in thread.
- self.logger.info('Test starts')
- for _, ser in self.serials.items():
- ser.start_test()
-
- # Wait all tests to finish.
- for _, ser in self.serials.items():
- ser.wait_test_done()
-
- # Print the result.
- char_lost = self.print_result()
- if char_lost:
- raise ChargenTestError('Test failed: lost %d character(s)' %
- char_lost)
-
- self.logger.info('Test is done')
-
-def parse_args(cmdline):
- """Parse command line arguments.
-
- Args:
- cmdline: list to be parsed
-
- Returns:
- tuple (options, args) where args is a list of cmdline arguments that the
- parser was unable to match i.e. they're servod controls, not options.
- """
- description = """%(prog)s repeats sending a uart console command
+ description = """%(prog)s repeats sending a uart console command
to each UART device for a given time, and check if output
has any missing characters.
@@ -511,52 +538,70 @@ Examples:
%(prog)s /dev/ttyUSB1 /dev/ttyUSB2 --cr50
"""
- parser = argparse.ArgumentParser(description=description,
- formatter_class=argparse.RawTextHelpFormatter
- )
- parser.add_argument('port', type=str, nargs='*',
- help='UART device path to test')
- parser.add_argument('-c', '--cr50', action='store_true', default=False,
- help='generate TPM workload on cr50')
- parser.add_argument('-d', '--debug', action='store_true', default=False,
- help='enable debug messages')
- parser.add_argument('-t', '--time', type=int,
- help='Test duration in second', default=300)
- parser.add_argument('-u', '--usb', action='store_true', default=False,
- help='Generate output to USB channel instead')
- return parser.parse_known_args(cmdline)
+ parser = argparse.ArgumentParser(
+ description=description, formatter_class=argparse.RawTextHelpFormatter
+ )
+ parser.add_argument("port", type=str, nargs="*", help="UART device path to test")
+ parser.add_argument(
+ "-c",
+ "--cr50",
+ action="store_true",
+ default=False,
+ help="generate TPM workload on cr50",
+ )
+ parser.add_argument(
+ "-d",
+ "--debug",
+ action="store_true",
+ default=False,
+ help="enable debug messages",
+ )
+ parser.add_argument(
+ "-t", "--time", type=int, help="Test duration in second", default=300
+ )
+ parser.add_argument(
+ "-u",
+ "--usb",
+ action="store_true",
+ default=False,
+ help="Generate output to USB channel instead",
+ )
+ return parser.parse_known_args(cmdline)
def main():
- """Main function wrapper"""
- try:
- (options, _) = parse_args(sys.argv[1:])
-
- # Set Log format
- log_format = '%(asctime)s %(levelname)-6s | %(name)-25s'
- date_format = '%Y-%m-%d %H:%M:%S'
- if options.debug:
- log_format += ' | %(filename)s:%(lineno)4d:%(funcName)-18s'
- loglevel = logging.DEBUG
- else:
- loglevel = logging.INFO
- log_format += ' | %(message)s'
-
- logging.basicConfig(level=loglevel, format=log_format,
- datefmt=date_format)
-
- # Create a ChargenTest object
- utest = ChargenTest(options.port, options.time,
- cr50_workload=options.cr50,
- usb_output=options.usb)
- utest.run() # Run
-
- except KeyboardInterrupt:
- sys.exit(0)
-
- except ChargenTestError as e:
- logging.error(str(e))
- sys.exit(1)
-
-if __name__ == '__main__':
- main()
+ """Main function wrapper"""
+ try:
+ (options, _) = parse_args(sys.argv[1:])
+
+ # Set Log format
+ log_format = "%(asctime)s %(levelname)-6s | %(name)-25s"
+ date_format = "%Y-%m-%d %H:%M:%S"
+ if options.debug:
+ log_format += " | %(filename)s:%(lineno)4d:%(funcName)-18s"
+ loglevel = logging.DEBUG
+ else:
+ loglevel = logging.INFO
+ log_format += " | %(message)s"
+
+ logging.basicConfig(level=loglevel, format=log_format, datefmt=date_format)
+
+ # Create a ChargenTest object
+ utest = ChargenTest(
+ options.port,
+ options.time,
+ cr50_workload=options.cr50,
+ usb_output=options.usb,
+ )
+ utest.run() # Run
+
+ except KeyboardInterrupt:
+ sys.exit(0)
+
+ except ChargenTestError as e:
+ logging.error(str(e))
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/util/unpack_ftb.py b/util/unpack_ftb.py
index 03127a7089..a68662d82b 100755
--- a/util/unpack_ftb.py
+++ b/util/unpack_ftb.py
@@ -10,26 +10,28 @@
# Note: This is a py2/3 compatible file.
from __future__ import print_function
+
import argparse
import ctypes
import os
class Header(ctypes.Structure):
- _pack_ = 1
- _fields_ = [
- ('signature', ctypes.c_uint32),
- ('ftb_ver', ctypes.c_uint32),
- ('chip_id', ctypes.c_uint32),
- ('svn_ver', ctypes.c_uint32),
- ('fw_ver', ctypes.c_uint32),
- ('config_id', ctypes.c_uint32),
- ('config_ver', ctypes.c_uint32),
- ('reserved', ctypes.c_uint8 * 8),
- ('release_info', ctypes.c_ulonglong),
- ('sec_size', ctypes.c_uint32 * 4),
- ('crc', ctypes.c_uint32),
- ]
+ _pack_ = 1
+ _fields_ = [
+ ("signature", ctypes.c_uint32),
+ ("ftb_ver", ctypes.c_uint32),
+ ("chip_id", ctypes.c_uint32),
+ ("svn_ver", ctypes.c_uint32),
+ ("fw_ver", ctypes.c_uint32),
+ ("config_id", ctypes.c_uint32),
+ ("config_ver", ctypes.c_uint32),
+ ("reserved", ctypes.c_uint8 * 8),
+ ("release_info", ctypes.c_ulonglong),
+ ("sec_size", ctypes.c_uint32 * 4),
+ ("crc", ctypes.c_uint32),
+ ]
+
FW_HEADER_SIZE = 64
FW_HEADER_SIGNATURE = 0xAA55AA55
@@ -44,7 +46,7 @@ FLASH_SEC_ADDR = [
0x0000 * 4, # CODE
0x7C00 * 4, # CONFIG
0x7000 * 4, # CX
- None # This section shouldn't exist
+ None, # This section shouldn't exist
]
UPDATE_PDU_SIZE = 4096
@@ -59,64 +61,66 @@ OUTPUT_FILE_SIZE = UPDATE_PDU_SIZE + 128 * 1024
def main():
- parser = argparse.ArgumentParser()
- parser.add_argument('--input', '-i', required=True)
- parser.add_argument('--output', '-o', required=True)
- args = parser.parse_args()
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input", "-i", required=True)
+ parser.add_argument("--output", "-o", required=True)
+ args = parser.parse_args()
- with open(args.input, 'rb') as f:
- bs = f.read()
+ with open(args.input, "rb") as f:
+ bs = f.read()
- size = len(bs)
- if size < FW_HEADER_SIZE + FW_BYTES_ALIGN:
- raise Exception('FW size too small')
+ size = len(bs)
+ if size < FW_HEADER_SIZE + FW_BYTES_ALIGN:
+ raise Exception("FW size too small")
- print('FTB file size:', size)
+ print("FTB file size:", size)
- header = Header()
- assert ctypes.sizeof(header) == FW_HEADER_SIZE
+ header = Header()
+ assert ctypes.sizeof(header) == FW_HEADER_SIZE
- ctypes.memmove(ctypes.addressof(header), bs, ctypes.sizeof(header))
- if (header.signature != FW_HEADER_SIGNATURE or
- header.ftb_ver != FW_FTB_VER or
- header.chip_id != FW_CHIP_ID):
- raise Exception('Invalid header')
+ ctypes.memmove(ctypes.addressof(header), bs, ctypes.sizeof(header))
+ if (
+ header.signature != FW_HEADER_SIGNATURE
+ or header.ftb_ver != FW_FTB_VER
+ or header.chip_id != FW_CHIP_ID
+ ):
+ raise Exception("Invalid header")
- for key, _ in header._fields_:
- v = getattr(header, key)
- if isinstance(v, ctypes.Array):
- print(key, list(map(hex, v)))
- else:
- print(key, hex(v))
+ for key, _ in header._fields_:
+ v = getattr(header, key)
+ if isinstance(v, ctypes.Array):
+ print(key, list(map(hex, v)))
+ else:
+ print(key, hex(v))
- dimension = sum(header.sec_size)
+ dimension = sum(header.sec_size)
- assert dimension + FW_HEADER_SIZE + FW_BYTES_ALIGN == size
- data = bs[FW_HEADER_SIZE:FW_HEADER_SIZE + dimension]
+ assert dimension + FW_HEADER_SIZE + FW_BYTES_ALIGN == size
+ data = bs[FW_HEADER_SIZE : FW_HEADER_SIZE + dimension]
- with open(args.output, 'wb') as f:
- # ensure the file size
- f.seek(OUTPUT_FILE_SIZE - 1, os.SEEK_SET)
- f.write(b'\x00')
+ with open(args.output, "wb") as f:
+ # ensure the file size
+ f.seek(OUTPUT_FILE_SIZE - 1, os.SEEK_SET)
+ f.write(b"\x00")
- f.seek(0, os.SEEK_SET)
- f.write(bs[0 : ctypes.sizeof(header)])
+ f.seek(0, os.SEEK_SET)
+ f.write(bs[0 : ctypes.sizeof(header)])
- offset = 0
- # write each sections
- for i, addr in enumerate(FLASH_SEC_ADDR):
- size = header.sec_size[i]
- assert addr is not None or size == 0
+ offset = 0
+ # write each sections
+ for i, addr in enumerate(FLASH_SEC_ADDR):
+ size = header.sec_size[i]
+ assert addr is not None or size == 0
- if size == 0:
- continue
+ if size == 0:
+ continue
- f.seek(UPDATE_PDU_SIZE + addr, os.SEEK_SET)
- f.write(data[offset : offset + size])
- offset += size
+ f.seek(UPDATE_PDU_SIZE + addr, os.SEEK_SET)
+ f.write(data[offset : offset + size])
+ offset += size
- f.flush()
+ f.flush()
-if __name__ == '__main__':
- main()
+if __name__ == "__main__":
+ main()
diff --git a/util/update_release_branch.py b/util/update_release_branch.py
index b9063d4970..4d9c89df4a 100755
--- a/util/update_release_branch.py
+++ b/util/update_release_branch.py
@@ -19,8 +19,7 @@ import subprocess
import sys
import textwrap
-
-BUG_NONE_PATTERN = re.compile('none', flags=re.IGNORECASE)
+BUG_NONE_PATTERN = re.compile("none", flags=re.IGNORECASE)
def git_commit_msg(branch, head, merge_head, rel_paths, cmd):
@@ -42,18 +41,17 @@ def git_commit_msg(branch, head, merge_head, rel_paths, cmd):
A String containing the git commit message with the exception of the
Signed-Off-By field and Change-ID field.
"""
- relevant_commits_cmd, relevant_commits = get_relevant_commits(head,
- merge_head,
- '--oneline',
- rel_paths)
+ relevant_commits_cmd, relevant_commits = get_relevant_commits(
+ head, merge_head, "--oneline", rel_paths
+ )
- _, relevant_bugs = get_relevant_commits(head, merge_head, '', rel_paths)
- relevant_bugs = set(re.findall('BUG=(.*)', relevant_bugs))
+ _, relevant_bugs = get_relevant_commits(head, merge_head, "", rel_paths)
+ relevant_bugs = set(re.findall("BUG=(.*)", relevant_bugs))
# Filter out "none" from set of bugs
filtered = []
for bug_line in relevant_bugs:
- bug_line = bug_line.replace(',', ' ')
- bugs = bug_line.split(' ')
+ bug_line = bug_line.replace(",", " ")
+ bugs = bug_line.split(" ")
for bug in bugs:
if bug and not BUG_NONE_PATTERN.match(bug):
filtered.append(bug)
@@ -82,18 +80,20 @@ Cq-Include-Trybots: chromeos/cq:cq-orchestrator
# 72 cols.
relevant_commits_cmd = textwrap.fill(relevant_commits_cmd, width=72)
# Wrap at 68 cols to save room for 'BUG='
- bugs = textwrap.wrap(' '.join(relevant_bugs), width=68)
- bug_field = ''
+ bugs = textwrap.wrap(" ".join(relevant_bugs), width=68)
+ bug_field = ""
for line in bugs:
- bug_field += 'BUG=' + line + '\n'
+ bug_field += "BUG=" + line + "\n"
# Remove the final newline since the template adds it for us.
bug_field = bug_field[:-1]
- return COMMIT_MSG_TEMPLATE.format(BRANCH=branch,
- RELEVANT_COMMITS_CMD=relevant_commits_cmd,
- RELEVANT_COMMITS=relevant_commits,
- BUG_FIELD=bug_field,
- COMMAND_LINE=cmd)
+ return COMMIT_MSG_TEMPLATE.format(
+ BRANCH=branch,
+ RELEVANT_COMMITS_CMD=relevant_commits_cmd,
+ RELEVANT_COMMITS=relevant_commits,
+ BUG_FIELD=bug_field,
+ COMMAND_LINE=cmd,
+ )
def get_relevant_boards(baseboard):
@@ -105,15 +105,16 @@ def get_relevant_boards(baseboard):
Returns:
A list of strings containing the boards based off of the baseboard.
"""
- proc = subprocess.run(['git', 'grep', 'BASEBOARD:=' + baseboard, '--',
- 'board/'],
- stdout=subprocess.PIPE,
- encoding='utf-8',
- check=True)
+ proc = subprocess.run(
+ ["git", "grep", "BASEBOARD:=" + baseboard, "--", "board/"],
+ stdout=subprocess.PIPE,
+ encoding="utf-8",
+ check=True,
+ )
boards = []
res = proc.stdout.splitlines()
for line in res:
- boards.append(line.split('/')[1])
+ boards.append(line.split("/")[1])
return boards
@@ -135,21 +136,18 @@ def get_relevant_commits(head, merge_head, fmt, relevant_paths):
stdout.
"""
if fmt:
- cmd = ['git', 'log', fmt, head + '..' + merge_head, '--',
- relevant_paths]
+ cmd = ["git", "log", fmt, head + ".." + merge_head, "--", relevant_paths]
else:
- cmd = ['git', 'log', head + '..' + merge_head, '--', relevant_paths]
+ cmd = ["git", "log", head + ".." + merge_head, "--", relevant_paths]
# Pass cmd as a string to subprocess.run() since we need to run with shell
# equal to True. The reason we are using shell equal to True is to take
# advantage of the glob expansion for the relevant paths.
- cmd = ' '.join(cmd)
- proc = subprocess.run(cmd,
- stdout=subprocess.PIPE,
- encoding='utf-8',
- check=True,
- shell=True)
- return ''.join(proc.args), proc.stdout
+ cmd = " ".join(cmd)
+ proc = subprocess.run(
+ cmd, stdout=subprocess.PIPE, encoding="utf-8", check=True, shell=True
+ )
+ return "".join(proc.args), proc.stdout
def main(argv):
@@ -165,46 +163,61 @@ def main(argv):
argv: A list of the command line arguments passed to this script.
"""
# Set up argument parser.
- parser = argparse.ArgumentParser(description=('A script that generates a '
- 'merge commit from cros/main'
- ' to a desired release '
- 'branch. By default, the '
- '"recursive" merge strategy '
- 'with the "theirs" strategy '
- 'option is used.'))
- parser.add_argument('--baseboard')
- parser.add_argument('--board')
- parser.add_argument('release_branch', help=('The name of the target release'
- ' branch'))
- parser.add_argument('--relevant_paths_file',
- help=('A path to a text file which includes other '
- 'relevant paths of interest for this board '
- 'or baseboard'))
- parser.add_argument('--merge_strategy', '-s', default='recursive',
- help='The merge strategy to pass to `git merge -s`')
- parser.add_argument('--strategy_option', '-X',
- help=('The strategy option for the chosen merge '
- 'strategy'))
+ parser = argparse.ArgumentParser(
+ description=(
+ "A script that generates a "
+ "merge commit from cros/main"
+ " to a desired release "
+ "branch. By default, the "
+ '"recursive" merge strategy '
+ 'with the "theirs" strategy '
+ "option is used."
+ )
+ )
+ parser.add_argument("--baseboard")
+ parser.add_argument("--board")
+ parser.add_argument(
+ "release_branch", help=("The name of the target release" " branch")
+ )
+ parser.add_argument(
+ "--relevant_paths_file",
+ help=(
+ "A path to a text file which includes other "
+ "relevant paths of interest for this board "
+ "or baseboard"
+ ),
+ )
+ parser.add_argument(
+ "--merge_strategy",
+ "-s",
+ default="recursive",
+ help="The merge strategy to pass to `git merge -s`",
+ )
+ parser.add_argument(
+ "--strategy_option",
+ "-X",
+ help=("The strategy option for the chosen merge " "strategy"),
+ )
opts = parser.parse_args(argv[1:])
- baseboard_dir = ''
- board_dir = ''
+ baseboard_dir = ""
+ board_dir = ""
if opts.baseboard:
# Dereference symlinks so "git log" works as expected.
- baseboard_dir = os.path.relpath('baseboard/' + opts.baseboard)
+ baseboard_dir = os.path.relpath("baseboard/" + opts.baseboard)
baseboard_dir = os.path.relpath(os.path.realpath(baseboard_dir))
boards = get_relevant_boards(opts.baseboard)
elif opts.board:
- board_dir = os.path.relpath('board/' + opts.board)
+ board_dir = os.path.relpath("board/" + opts.board)
board_dir = os.path.relpath(os.path.realpath(board_dir))
boards = [opts.board]
else:
- parser.error('You must specify a board OR a baseboard')
+ parser.error("You must specify a board OR a baseboard")
- print('Gathering relevant paths...')
+ print("Gathering relevant paths...")
relevant_paths = []
if opts.baseboard:
relevant_paths.append(baseboard_dir)
@@ -212,65 +225,91 @@ def main(argv):
relevant_paths.append(board_dir)
for board in boards:
- relevant_paths.append('board/' + board)
+ relevant_paths.append("board/" + board)
# Check for the existence of a file that has other paths of interest.
if opts.relevant_paths_file and os.path.exists(opts.relevant_paths_file):
- with open(opts.relevant_paths_file, 'r') as relevant_paths_file:
+ with open(opts.relevant_paths_file, "r") as relevant_paths_file:
for line in relevant_paths_file:
- if not line.startswith('#'):
+ if not line.startswith("#"):
relevant_paths.append(line.rstrip())
- relevant_paths.append('util/getversion.sh')
- relevant_paths = ' '.join(relevant_paths)
+ relevant_paths.append("util/getversion.sh")
+ relevant_paths = " ".join(relevant_paths)
# Check if we are already in merge process
- result = subprocess.run(['git', 'rev-parse', '--quiet', '--verify',
- 'MERGE_HEAD'], stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL, check=False)
+ result = subprocess.run(
+ ["git", "rev-parse", "--quiet", "--verify", "MERGE_HEAD"],
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL,
+ check=False,
+ )
if result.returncode:
# Let's perform the merge
- print('Updating remote...')
- subprocess.run(['git', 'remote', 'update'], check=True)
- subprocess.run(['git', 'checkout', '-B', opts.release_branch, 'cros/' +
- opts.release_branch], check=True)
- print('Attempting git merge...')
- if opts.merge_strategy == 'recursive' and not opts.strategy_option:
- opts.strategy_option = 'theirs'
- print('Using "%s" merge strategy' % opts.merge_strategy,
- ("with strategy option '%s'" % opts.strategy_option
- if opts.strategy_option else ''))
- arglist = ['git', 'merge', '--no-ff', '--no-commit', 'cros/main', '-s',
- opts.merge_strategy]
+ print("Updating remote...")
+ subprocess.run(["git", "remote", "update"], check=True)
+ subprocess.run(
+ [
+ "git",
+ "checkout",
+ "-B",
+ opts.release_branch,
+ "cros/" + opts.release_branch,
+ ],
+ check=True,
+ )
+ print("Attempting git merge...")
+ if opts.merge_strategy == "recursive" and not opts.strategy_option:
+ opts.strategy_option = "theirs"
+ print(
+ 'Using "%s" merge strategy' % opts.merge_strategy,
+ (
+ "with strategy option '%s'" % opts.strategy_option
+ if opts.strategy_option
+ else ""
+ ),
+ )
+ arglist = [
+ "git",
+ "merge",
+ "--no-ff",
+ "--no-commit",
+ "cros/main",
+ "-s",
+ opts.merge_strategy,
+ ]
if opts.strategy_option:
- arglist.append('-X' + opts.strategy_option)
+ arglist.append("-X" + opts.strategy_option)
subprocess.run(arglist, check=True)
else:
- print('We have already started merge process.',
- 'Attempt to generate commit.')
-
- print('Generating commit message...')
- branch = subprocess.run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
- stdout=subprocess.PIPE,
- encoding='utf-8',
- check=True).stdout.rstrip()
- head = subprocess.run(['git', 'rev-parse', '--short', 'HEAD'],
- stdout=subprocess.PIPE,
- encoding='utf-8',
- check=True).stdout.rstrip()
- merge_head = subprocess.run(['git', 'rev-parse', '--short',
- 'MERGE_HEAD'],
- stdout=subprocess.PIPE,
- encoding='utf-8',
- check=True).stdout.rstrip()
-
- cmd = ' '.join(argv)
- print('Typing as fast as I can...')
+ print("We have already started merge process.", "Attempt to generate commit.")
+
+ print("Generating commit message...")
+ branch = subprocess.run(
+ ["git", "rev-parse", "--abbrev-ref", "HEAD"],
+ stdout=subprocess.PIPE,
+ encoding="utf-8",
+ check=True,
+ ).stdout.rstrip()
+ head = subprocess.run(
+ ["git", "rev-parse", "--short", "HEAD"],
+ stdout=subprocess.PIPE,
+ encoding="utf-8",
+ check=True,
+ ).stdout.rstrip()
+ merge_head = subprocess.run(
+ ["git", "rev-parse", "--short", "MERGE_HEAD"],
+ stdout=subprocess.PIPE,
+ encoding="utf-8",
+ check=True,
+ ).stdout.rstrip()
+
+ cmd = " ".join(argv)
+ print("Typing as fast as I can...")
commit_msg = git_commit_msg(branch, head, merge_head, relevant_paths, cmd)
- subprocess.run(['git', 'commit', '--signoff', '-m', commit_msg], check=True)
- subprocess.run(['git', 'commit', '--amend'], check=True)
- print(("Finished! **Please review the commit to see if it's to your "
- 'liking.**'))
+ subprocess.run(["git", "commit", "--signoff", "-m", commit_msg], check=True)
+ subprocess.run(["git", "commit", "--amend"], check=True)
+ print(("Finished! **Please review the commit to see if it's to your " "liking.**"))
-if __name__ == '__main__':
+if __name__ == "__main__":
main(sys.argv)
diff --git a/zephyr/firmware_builder.py b/zephyr/firmware_builder.py
index f77e51d6c4..c0963a84db 100755
--- a/zephyr/firmware_builder.py
+++ b/zephyr/firmware_builder.py
@@ -15,14 +15,12 @@ import re
import subprocess
import sys
-from google.protobuf import json_format # pylint: disable=import-error
import zmake.project
-
from chromite.api.gen_sdk.chromite.api import firmware_pb2
+from google.protobuf import json_format # pylint: disable=import-error
-
-DEFAULT_BUNDLE_DIRECTORY = '/tmp/artifact_bundles'
-DEFAULT_BUNDLE_METADATA_FILE = '/tmp/artifact_bundle_metadata'
+DEFAULT_BUNDLE_DIRECTORY = "/tmp/artifact_bundles"
+DEFAULT_BUNDLE_METADATA_FILE = "/tmp/artifact_bundle_metadata"
def build(opts):
@@ -33,54 +31,52 @@ def build(opts):
platform_ec = zephyr_dir.resolve().parent
subprocess.run([platform_ec / "util" / "check_clang_format.py"], check=True)
- cmd = ['zmake', '-D', 'build', '-a']
+ cmd = ["zmake", "-D", "build", "-a"]
if opts.code_coverage:
- cmd.append('--coverage')
+ cmd.append("--coverage")
subprocess.run(cmd, cwd=pathlib.Path(__file__).parent, check=True)
if not opts.code_coverage:
for project in zmake.project.find_projects(zephyr_dir).values():
if project.config.is_test:
continue
- build_dir = (
- platform_ec / 'build' / 'zephyr' / project.config.project_name
- )
+ build_dir = platform_ec / "build" / "zephyr" / project.config.project_name
metric = metric_list.value.add()
metric.target_name = project.config.project_name
metric.platform_name = project.config.zephyr_board
for (variant, _) in project.iter_builds():
- build_log = build_dir / f'build-{variant}' / 'build.log'
+ build_log = build_dir / f"build-{variant}" / "build.log"
parse_buildlog(build_log, metric, variant.upper())
- with open(opts.metrics, 'w') as file:
+ with open(opts.metrics, "w") as file:
file.write(json_format.MessageToJson(metric_list))
return 0
UNITS = {
- 'B': 1,
- 'KB': 1024,
- 'MB': 1024 * 1024,
- 'GB': 1024 * 1024 * 1024,
+ "B": 1,
+ "KB": 1024,
+ "MB": 1024 * 1024,
+ "GB": 1024 * 1024 * 1024,
}
def parse_buildlog(filename, metric, variant):
"""Parse the build.log generated by zmake to find the size of the image."""
- with open(filename, 'r') as infile:
+ with open(filename, "r") as infile:
# Skip over all lines until the memory report is found
while True:
line = infile.readline()
if not line:
return
- if line.startswith('Memory region'):
+ if line.startswith("Memory region"):
break
for line in infile.readlines():
# Skip any lines that are not part of the report
- if not line.startswith(' '):
+ if not line.startswith(" "):
continue
parts = line.split()
fw_section = metric.fw_section.add()
- fw_section.region = variant + '_' + parts[0][:-1]
+ fw_section.region = variant + "_" + parts[0][:-1]
fw_section.used = int(parts[1]) * UNITS[parts[2]]
fw_section.total = int(parts[3]) * UNITS[parts[4]]
fw_section.track_on_gerrit = False
@@ -114,7 +110,7 @@ def write_metadata(opts, info):
bundle_metadata_file = (
opts.metadata if opts.metadata else DEFAULT_BUNDLE_METADATA_FILE
)
- with open(bundle_metadata_file, 'w') as file:
+ with open(bundle_metadata_file, "w") as file:
file.write(json_format.MessageToJson(info))
@@ -125,10 +121,10 @@ def bundle_coverage(opts):
bundle_dir = get_bundle_dir(opts)
zephyr_dir = pathlib.Path(__file__).parent
platform_ec = zephyr_dir.resolve().parent
- build_dir = platform_ec / 'build' / 'zephyr'
- tarball_name = 'coverage.tbz2'
+ build_dir = platform_ec / "build" / "zephyr"
+ tarball_name = "coverage.tbz2"
tarball_path = bundle_dir / tarball_name
- cmd = ['tar', 'cvfj', tarball_path, 'lcov.info']
+ cmd = ["tar", "cvfj", tarball_path, "lcov.info"]
subprocess.run(cmd, cwd=build_dir, check=True)
meta = info.objects.add()
meta.file_name = tarball_name
@@ -149,13 +145,11 @@ def bundle_firmware(opts):
for project in zmake.project.find_projects(zephyr_dir).values():
if project.config.is_test:
continue
- build_dir = (
- platform_ec / 'build' / 'zephyr' / project.config.project_name
- )
- artifacts_dir = build_dir / 'output'
- tarball_name = f'{project.config.project_name}.firmware.tbz2'
+ build_dir = platform_ec / "build" / "zephyr" / project.config.project_name
+ artifacts_dir = build_dir / "output"
+ tarball_name = f"{project.config.project_name}.firmware.tbz2"
tarball_path = bundle_dir.joinpath(tarball_name)
- cmd = ['tar', 'cvfj', tarball_path, '.']
+ cmd = ["tar", "cvfj", tarball_path, "."]
subprocess.run(cmd, cwd=artifacts_dir, check=True)
meta = info.objects.add()
meta.file_name = tarball_name
@@ -176,76 +170,92 @@ def test(opts):
# Run zmake tests to ensure we have a fully working zmake before
# proceeding.
- subprocess.run([zephyr_dir / 'zmake' / 'run_tests.sh'], check=True)
+ subprocess.run([zephyr_dir / "zmake" / "run_tests.sh"], check=True)
# Run formatting checks on all BUILD.py files.
- config_files = zephyr_dir.rglob('**/BUILD.py')
- subprocess.run(['black', '--diff', '--check', *config_files], check=True)
+ config_files = zephyr_dir.rglob("**/BUILD.py")
+ subprocess.run(["black", "--diff", "--check", *config_files], check=True)
- cmd = ['zmake', '-D', 'test', '-a', '--no-rebuild']
+ cmd = ["zmake", "-D", "test", "-a", "--no-rebuild"]
if opts.code_coverage:
- cmd.append('--coverage')
+ cmd.append("--coverage")
ret = subprocess.run(cmd, check=True).returncode
if ret:
return ret
if opts.code_coverage:
platform_ec = zephyr_dir.parent
- build_dir = platform_ec / 'build' / 'zephyr'
+ build_dir = platform_ec / "build" / "zephyr"
# Merge lcov files here because bundle failures are "infra" failures.
cmd = [
- '/usr/bin/lcov',
- '-o',
- build_dir / 'zephyr_merged.info',
- '--rc',
- 'lcov_branch_coverage=1',
- '-a',
- build_dir / 'all_tests.info',
- '-a',
- build_dir / 'all_builds.info',
+ "/usr/bin/lcov",
+ "-o",
+ build_dir / "zephyr_merged.info",
+ "--rc",
+ "lcov_branch_coverage=1",
+ "-a",
+ build_dir / "all_tests.info",
+ "-a",
+ build_dir / "all_builds.info",
]
output = subprocess.run(
- cmd, cwd=pathlib.Path(__file__).parent, check=True,
- stdout=subprocess.PIPE, universal_newlines=True).stdout
- _extract_lcov_summary('EC_ZEPHYR_MERGED', metrics, output)
+ cmd,
+ cwd=pathlib.Path(__file__).parent,
+ check=True,
+ stdout=subprocess.PIPE,
+ universal_newlines=True,
+ ).stdout
+ _extract_lcov_summary("EC_ZEPHYR_MERGED", metrics, output)
output = subprocess.run(
- ['/usr/bin/lcov', '--summary', build_dir / 'all_tests.info'],
- cwd=pathlib.Path(__file__).parent, check=True,
- stdout=subprocess.PIPE, universal_newlines=True).stdout
- _extract_lcov_summary('EC_ZEPHYR_TESTS', metrics, output)
-
- cmd = ['make', 'coverage', f'-j{opts.cpus}']
+ ["/usr/bin/lcov", "--summary", build_dir / "all_tests.info"],
+ cwd=pathlib.Path(__file__).parent,
+ check=True,
+ stdout=subprocess.PIPE,
+ universal_newlines=True,
+ ).stdout
+ _extract_lcov_summary("EC_ZEPHYR_TESTS", metrics, output)
+
+ cmd = ["make", "coverage", f"-j{opts.cpus}"]
print(f"# Running {' '.join(cmd)}.")
subprocess.run(cmd, cwd=platform_ec, check=True)
output = subprocess.run(
- ['/usr/bin/lcov', '--summary', platform_ec / 'build/coverage/lcov.info'],
- cwd=pathlib.Path(__file__).parent, check=True,
- stdout=subprocess.PIPE, universal_newlines=True).stdout
- _extract_lcov_summary('EC_LEGACY_MERGED', metrics, output)
+ ["/usr/bin/lcov", "--summary", platform_ec / "build/coverage/lcov.info"],
+ cwd=pathlib.Path(__file__).parent,
+ check=True,
+ stdout=subprocess.PIPE,
+ universal_newlines=True,
+ ).stdout
+ _extract_lcov_summary("EC_LEGACY_MERGED", metrics, output)
cmd = [
- '/usr/bin/lcov',
- '-o',
- build_dir / 'lcov.info',
- '--rc',
- 'lcov_branch_coverage=1',
- '-a',
- build_dir / 'zephyr_merged.info',
- '-a',
- platform_ec / 'build/coverage/lcov.info',
+ "/usr/bin/lcov",
+ "-o",
+ build_dir / "lcov.info",
+ "--rc",
+ "lcov_branch_coverage=1",
+ "-a",
+ build_dir / "zephyr_merged.info",
+ "-a",
+ platform_ec / "build/coverage/lcov.info",
]
output = subprocess.run(
- cmd, cwd=pathlib.Path(__file__).parent, check=True,
- stdout=subprocess.PIPE, universal_newlines=True).stdout
- _extract_lcov_summary('ALL_MERGED', metrics, output)
-
- with open(opts.metrics, 'w') as file:
+ cmd,
+ cwd=pathlib.Path(__file__).parent,
+ check=True,
+ stdout=subprocess.PIPE,
+ universal_newlines=True,
+ ).stdout
+ _extract_lcov_summary("ALL_MERGED", metrics, output)
+
+ with open(opts.metrics, "w") as file:
file.write(json_format.MessageToJson(metrics))
return 0
-COVERAGE_RE = re.compile(r'lines\.*: *([0-9\.]+)% \(([0-9]+) of ([0-9]+) lines\)')
+COVERAGE_RE = re.compile(r"lines\.*: *([0-9\.]+)% \(([0-9]+) of ([0-9]+) lines\)")
+
+
def _extract_lcov_summary(name, metrics, output):
re_match = COVERAGE_RE.search(output)
if re_match:
@@ -255,12 +265,13 @@ def _extract_lcov_summary(name, metrics, output):
metric.covered_lines = int(re_match.group(2))
metric.total_lines = int(re_match.group(3))
+
def main(args):
"""Builds and tests all of the Zephyr targets and reports build metrics"""
opts = parse_args(args)
- if not hasattr(opts, 'func'):
- print('Must select a valid sub command!')
+ if not hasattr(opts, "func"):
+ print("Must select a valid sub command!")
return -1
# Run selected sub command function
@@ -272,70 +283,66 @@ def parse_args(args):
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
- '--cpus',
+ "--cpus",
default=multiprocessing.cpu_count(),
- help='The number of cores to use.',
+ help="The number of cores to use.",
)
parser.add_argument(
- '--metrics',
- dest='metrics',
+ "--metrics",
+ dest="metrics",
required=True,
- help='File to write the json-encoded MetricsList proto message.',
+ help="File to write the json-encoded MetricsList proto message.",
)
parser.add_argument(
- '--metadata',
+ "--metadata",
required=False,
help=(
- 'Full pathname for the file in which to write build artifact '
- 'metadata.'
+ "Full pathname for the file in which to write build artifact " "metadata."
),
)
parser.add_argument(
- '--output-dir',
+ "--output-dir",
required=False,
- help=(
- 'Full pathname for the directory in which to bundle build '
- 'artifacts.'
- ),
+ help=("Full pathname for the directory in which to bundle build " "artifacts."),
)
parser.add_argument(
- '--code-coverage',
+ "--code-coverage",
required=False,
- action='store_true',
- help='Build host-based unit tests for code coverage.',
+ action="store_true",
+ help="Build host-based unit tests for code coverage.",
)
parser.add_argument(
- '--bcs-version',
- dest='bcs_version',
- default='',
+ "--bcs-version",
+ dest="bcs_version",
+ default="",
required=False,
# TODO(b/180008931): make this required=True.
- help='BCS version to include in metadata.',
+ help="BCS version to include in metadata.",
)
# Would make this required=True, but not available until 3.7
sub_cmds = parser.add_subparsers()
- build_cmd = sub_cmds.add_parser('build', help='Builds all firmware targets')
+ build_cmd = sub_cmds.add_parser("build", help="Builds all firmware targets")
build_cmd.set_defaults(func=build)
build_cmd = sub_cmds.add_parser(
- 'bundle',
- help='Creates a tarball containing build '
- 'artifacts from all firmware targets',
+ "bundle",
+ help="Creates a tarball containing build "
+ "artifacts from all firmware targets",
)
build_cmd.set_defaults(func=bundle)
- test_cmd = sub_cmds.add_parser('test', help='Runs all firmware unit tests')
+ test_cmd = sub_cmds.add_parser("test", help="Runs all firmware unit tests")
test_cmd.set_defaults(func=test)
return parser.parse_args(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
diff --git a/zephyr/zmake/tests/conftest.py b/zephyr/zmake/tests/conftest.py
index 38e34bef56..be1de01401 100644
--- a/zephyr/zmake/tests/conftest.py
+++ b/zephyr/zmake/tests/conftest.py
@@ -9,7 +9,6 @@ import pathlib
import hypothesis
import pytest
-
import zmake.zmake as zm
hypothesis.settings.register_profile(
diff --git a/zephyr/zmake/tests/test_build_config.py b/zephyr/zmake/tests/test_build_config.py
index 76cc0a2028..f79ed1f8a0 100644
--- a/zephyr/zmake/tests/test_build_config.py
+++ b/zephyr/zmake/tests/test_build_config.py
@@ -13,7 +13,6 @@ import tempfile
import hypothesis
import hypothesis.strategies as st
import pytest
-
import zmake.jobserver
import zmake.util as util
from zmake.build_config import BuildConfig
diff --git a/zephyr/zmake/tests/test_generate_readme.py b/zephyr/zmake/tests/test_generate_readme.py
index cb4bcf6cc1..2149b3fc6e 100644
--- a/zephyr/zmake/tests/test_generate_readme.py
+++ b/zephyr/zmake/tests/test_generate_readme.py
@@ -7,7 +7,6 @@ Tests for the generate_readme.py file.
"""
import pytest
-
import zmake.generate_readme as gen_readme
diff --git a/zephyr/zmake/tests/test_modules.py b/zephyr/zmake/tests/test_modules.py
index 600544d2e7..9446e54f1c 100644
--- a/zephyr/zmake/tests/test_modules.py
+++ b/zephyr/zmake/tests/test_modules.py
@@ -9,7 +9,6 @@ import tempfile
import hypothesis
import hypothesis.strategies as st
-
import zmake.modules
module_lists = st.lists(
diff --git a/zephyr/zmake/tests/test_packers.py b/zephyr/zmake/tests/test_packers.py
index 43b63a908f..402cee690e 100644
--- a/zephyr/zmake/tests/test_packers.py
+++ b/zephyr/zmake/tests/test_packers.py
@@ -10,7 +10,6 @@ import tempfile
import hypothesis
import hypothesis.strategies as st
import pytest
-
import zmake.output_packers as packers
# Strategies for use with hypothesis
diff --git a/zephyr/zmake/tests/test_project.py b/zephyr/zmake/tests/test_project.py
index b5facbc331..5b5ca12583 100644
--- a/zephyr/zmake/tests/test_project.py
+++ b/zephyr/zmake/tests/test_project.py
@@ -11,7 +11,6 @@ import tempfile
import hypothesis
import hypothesis.strategies as st
import pytest
-
import zmake.modules
import zmake.output_packers
import zmake.project
diff --git a/zephyr/zmake/tests/test_reexec.py b/zephyr/zmake/tests/test_reexec.py
index 5d7905cd8f..08943909b2 100644
--- a/zephyr/zmake/tests/test_reexec.py
+++ b/zephyr/zmake/tests/test_reexec.py
@@ -8,7 +8,6 @@ import sys
import unittest.mock as mock
import pytest
-
import zmake.__main__ as main
diff --git a/zephyr/zmake/tests/test_toolchains.py b/zephyr/zmake/tests/test_toolchains.py
index 910a5faa78..f210bb7511 100644
--- a/zephyr/zmake/tests/test_toolchains.py
+++ b/zephyr/zmake/tests/test_toolchains.py
@@ -8,7 +8,6 @@ import os
import pathlib
import pytest
-
import zmake.output_packers
import zmake.project as project
import zmake.toolchains as toolchains
diff --git a/zephyr/zmake/tests/test_util.py b/zephyr/zmake/tests/test_util.py
index 1ec0076162..4a6c39f904 100644
--- a/zephyr/zmake/tests/test_util.py
+++ b/zephyr/zmake/tests/test_util.py
@@ -10,7 +10,6 @@ import tempfile
import hypothesis
import hypothesis.strategies as st
import pytest
-
import zmake.util as util
# Strategies for use with hypothesis
diff --git a/zephyr/zmake/tests/test_zmake.py b/zephyr/zmake/tests/test_zmake.py
index 4ca1d7f077..1c892ca2e4 100644
--- a/zephyr/zmake/tests/test_zmake.py
+++ b/zephyr/zmake/tests/test_zmake.py
@@ -11,14 +11,13 @@ import re
import unittest.mock
import pytest
-from testfixtures import LogCapture
-
import zmake.build_config
import zmake.jobserver
import zmake.multiproc as multiproc
import zmake.output_packers
import zmake.project
import zmake.toolchains
+from testfixtures import LogCapture
OUR_PATH = os.path.dirname(os.path.realpath(__file__))