diff options
author | Jeremy Bettis <jbettis@google.com> | 2022-07-08 10:58:19 -0600 |
---|---|---|
committer | Chromeos LUCI <chromeos-scoped@luci-project-accounts.iam.gserviceaccount.com> | 2022-07-12 19:13:33 +0000 |
commit | 7540e7b47b55447475bb8191fb3520dd67cf7998 (patch) | |
tree | 13309dbcf1db48e60fa2c2e5aed79f63bce00b5e | |
parent | 7c114b8e1a3bb29991da70b9de394ac5d4f6c909 (diff) | |
download | chrome-ec-7540e7b47b55447475bb8191fb3520dd67cf7998.tar.gz |
ec: Format all python files with black and isort
find . \( -path ./private -prune \) -o -name '*.py' -print | xargs black
find . \( -path ./private -prune \) -o -name '*.py' -print | xargs ~/chromiumos/chromite/scripts/isort --settings-file=.isort.cfg
BRANCH=None
BUG=b:238434058
TEST=None
Signed-off-by: Jeremy Bettis <jbettis@google.com>
Change-Id: I63462d6f15d1eaf3db84eb20d1404ee976be8382
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/ec/+/3749242
Commit-Queue: Jeremy Bettis <jbettis@chromium.org>
Reviewed-by: Tom Hughes <tomhughes@chromium.org>
Tested-by: Jeremy Bettis <jbettis@chromium.org>
Commit-Queue: Jack Rosenthal <jrosenth@chromium.org>
Auto-Submit: Jeremy Bettis <jbettis@chromium.org>
Reviewed-by: Jack Rosenthal <jrosenth@chromium.org>
62 files changed, 16388 insertions, 14541 deletions
diff --git a/chip/ish/util/pack_ec.py b/chip/ish/util/pack_ec.py index bd9b823cab..8dde6ab6a9 100755 --- a/chip/ish/util/pack_ec.py +++ b/chip/ish/util/pack_ec.py @@ -28,85 +28,100 @@ MANIFEST_ENTRY_SIZE = 0x80 HEADER_SIZE = 0x1000 PAGE_SIZE = 0x1000 + def parseargs(): - parser = argparse.ArgumentParser() - parser.add_argument("-k", "--kernel", - help="EC kernel binary to pack, \ + parser = argparse.ArgumentParser() + parser.add_argument( + "-k", + "--kernel", + help="EC kernel binary to pack, \ usually ec.RW.bin or ec.RW.flat.", - required=True) - parser.add_argument("--kernel-size", type=int, - help="Size of EC kernel image", - required=True) - parser.add_argument("-a", "--aon", - help="EC aontask binary to pack, \ + required=True, + ) + parser.add_argument( + "--kernel-size", type=int, help="Size of EC kernel image", required=True + ) + parser.add_argument( + "-a", + "--aon", + help="EC aontask binary to pack, \ usually ish_aontask.bin.", - required=False) - parser.add_argument("--aon-size", type=int, - help="Size of EC aontask image", - required=False) - parser.add_argument("-o", "--output", - help="Output flash binary file") + required=False, + ) + parser.add_argument( + "--aon-size", type=int, help="Size of EC aontask image", required=False + ) + parser.add_argument("-o", "--output", help="Output flash binary file") + + return parser.parse_args() - return parser.parse_args() def gen_manifest(ext_id, comp_app_name, code_offset, module_size): - """Returns a binary blob that represents a manifest entry""" - m = bytearray(MANIFEST_ENTRY_SIZE) + """Returns a binary blob that represents a manifest entry""" + m = bytearray(MANIFEST_ENTRY_SIZE) - # 4 bytes of ASCII encode ID (little endian) - struct.pack_into('<4s', m, 0, ext_id) - # 8 bytes of ASCII encode ID (little endian) - struct.pack_into('<8s', m, 32, comp_app_name) - # 4 bytes of code offset (little endian) - struct.pack_into('<I', m, 96, code_offset) - # 2 bytes of module in page size increments (little endian) - struct.pack_into('<H', m, 100, module_size) + # 4 bytes of ASCII encode ID (little endian) + struct.pack_into("<4s", m, 0, ext_id) + # 8 bytes of ASCII encode ID (little endian) + struct.pack_into("<8s", m, 32, comp_app_name) + # 4 bytes of code offset (little endian) + struct.pack_into("<I", m, 96, code_offset) + # 2 bytes of module in page size increments (little endian) + struct.pack_into("<H", m, 100, module_size) + + return m - return m def roundup_page(size): - """Returns roundup-ed page size from size of bytes""" - return int(size / PAGE_SIZE) + (size % PAGE_SIZE > 0) + """Returns roundup-ed page size from size of bytes""" + return int(size / PAGE_SIZE) + (size % PAGE_SIZE > 0) + def main(): - args = parseargs() - print(" Packing EC image file for ISH") - - with open(args.output, 'wb') as f: - print(" kernel binary size:", args.kernel_size) - kern_rdup_pg_size = roundup_page(args.kernel_size) - # Add manifest for main ISH binary - f.write(gen_manifest(b'ISHM', b'ISH_KERN', HEADER_SIZE, kern_rdup_pg_size)) - - if args.aon is not None: - print(" AON binary size: ", args.aon_size) - aon_rdup_pg_size = roundup_page(args.aon_size) - # Add manifest for aontask binary - f.write(gen_manifest(b'ISHM', b'AON_TASK', - (HEADER_SIZE + kern_rdup_pg_size * PAGE_SIZE - - MANIFEST_ENTRY_SIZE), aon_rdup_pg_size)) - - # Add manifest that signals end of manifests - f.write(gen_manifest(b'ISHE', b'', 0, 0)) - - # Pad the remaining HEADER with 0s - if args.aon is not None: - f.write(b'\x00' * (HEADER_SIZE - (MANIFEST_ENTRY_SIZE * 3))) - else: - f.write(b'\x00' * (HEADER_SIZE - (MANIFEST_ENTRY_SIZE * 2))) - - # Append original kernel image - with open(args.kernel, 'rb') as in_file: - f.write(in_file.read()) - # Filling padings due to size round up as pages - f.write(b'\x00' * (kern_rdup_pg_size * PAGE_SIZE - args.kernel_size)) - - if args.aon is not None: - # Append original aon image - with open(args.aon, 'rb') as in_file: - f.write(in_file.read()) - # Filling padings due to size round up as pages - f.write(b'\x00' * (aon_rdup_pg_size * PAGE_SIZE - args.aon_size)) - -if __name__ == '__main__': - main() + args = parseargs() + print(" Packing EC image file for ISH") + + with open(args.output, "wb") as f: + print(" kernel binary size:", args.kernel_size) + kern_rdup_pg_size = roundup_page(args.kernel_size) + # Add manifest for main ISH binary + f.write(gen_manifest(b"ISHM", b"ISH_KERN", HEADER_SIZE, kern_rdup_pg_size)) + + if args.aon is not None: + print(" AON binary size: ", args.aon_size) + aon_rdup_pg_size = roundup_page(args.aon_size) + # Add manifest for aontask binary + f.write( + gen_manifest( + b"ISHM", + b"AON_TASK", + (HEADER_SIZE + kern_rdup_pg_size * PAGE_SIZE - MANIFEST_ENTRY_SIZE), + aon_rdup_pg_size, + ) + ) + + # Add manifest that signals end of manifests + f.write(gen_manifest(b"ISHE", b"", 0, 0)) + + # Pad the remaining HEADER with 0s + if args.aon is not None: + f.write(b"\x00" * (HEADER_SIZE - (MANIFEST_ENTRY_SIZE * 3))) + else: + f.write(b"\x00" * (HEADER_SIZE - (MANIFEST_ENTRY_SIZE * 2))) + + # Append original kernel image + with open(args.kernel, "rb") as in_file: + f.write(in_file.read()) + # Filling padings due to size round up as pages + f.write(b"\x00" * (kern_rdup_pg_size * PAGE_SIZE - args.kernel_size)) + + if args.aon is not None: + # Append original aon image + with open(args.aon, "rb") as in_file: + f.write(in_file.read()) + # Filling padings due to size round up as pages + f.write(b"\x00" * (aon_rdup_pg_size * PAGE_SIZE - args.aon_size)) + + +if __name__ == "__main__": + main() diff --git a/chip/mchp/util/pack_ec.py b/chip/mchp/util/pack_ec.py index 7908b0bf37..15be16c0d4 100755 --- a/chip/mchp/util/pack_ec.py +++ b/chip/mchp/util/pack_ec.py @@ -16,7 +16,7 @@ import os import struct import subprocess import tempfile -import zlib # CRC32 +import zlib # CRC32 # MEC1701 has 256KB SRAM from 0xE0000 - 0x120000 # SRAM is divided into contiguous CODE & DATA @@ -30,165 +30,199 @@ LOAD_ADDR = 0x0E0000 LOAD_ADDR_RW = 0xE1000 HEADER_SIZE = 0x40 SPI_CLOCK_LIST = [48, 24, 16, 12] -SPI_READ_CMD_LIST = [0x3, 0xb, 0x3b, 0x6b] +SPI_READ_CMD_LIST = [0x3, 0xB, 0x3B, 0x6B] + +CRC_TABLE = [ + 0x00, + 0x07, + 0x0E, + 0x09, + 0x1C, + 0x1B, + 0x12, + 0x15, + 0x38, + 0x3F, + 0x36, + 0x31, + 0x24, + 0x23, + 0x2A, + 0x2D, +] -CRC_TABLE = [0x00, 0x07, 0x0e, 0x09, 0x1c, 0x1b, 0x12, 0x15, - 0x38, 0x3f, 0x36, 0x31, 0x24, 0x23, 0x2a, 0x2d] def mock_print(*args, **kwargs): - pass + pass + debug_print = mock_print + def Crc8(crc, data): - """Update CRC8 value.""" - for v in data: - crc = ((crc << 4) & 0xff) ^ (CRC_TABLE[(crc >> 4) ^ (v >> 4)]); - crc = ((crc << 4) & 0xff) ^ (CRC_TABLE[(crc >> 4) ^ (v & 0xf)]); - return crc ^ 0x55 + """Update CRC8 value.""" + for v in data: + crc = ((crc << 4) & 0xFF) ^ (CRC_TABLE[(crc >> 4) ^ (v >> 4)]) + crc = ((crc << 4) & 0xFF) ^ (CRC_TABLE[(crc >> 4) ^ (v & 0xF)]) + return crc ^ 0x55 + def GetEntryPoint(payload_file): - """Read entry point from payload EC image.""" - with open(payload_file, 'rb') as f: - f.seek(4) - s = f.read(4) - return struct.unpack('<I', s)[0] + """Read entry point from payload EC image.""" + with open(payload_file, "rb") as f: + f.seek(4) + s = f.read(4) + return struct.unpack("<I", s)[0] + def GetPayloadFromOffset(payload_file, offset): - """Read payload and pad it to 64-byte aligned.""" - with open(payload_file, 'rb') as f: - f.seek(offset) - payload = bytearray(f.read()) - rem_len = len(payload) % 64 - if rem_len: - payload += b'\0' * (64 - rem_len) - return payload + """Read payload and pad it to 64-byte aligned.""" + with open(payload_file, "rb") as f: + f.seek(offset) + payload = bytearray(f.read()) + rem_len = len(payload) % 64 + if rem_len: + payload += b"\0" * (64 - rem_len) + return payload + def GetPayload(payload_file): - """Read payload and pad it to 64-byte aligned.""" - return GetPayloadFromOffset(payload_file, 0) + """Read payload and pad it to 64-byte aligned.""" + return GetPayloadFromOffset(payload_file, 0) + def GetPublicKey(pem_file): - """Extract public exponent and modulus from PEM file.""" - result = subprocess.run(['openssl', 'rsa', '-in', pem_file, '-text', - '-noout'], stdout=subprocess.PIPE, encoding='utf-8') - modulus_raw = [] - in_modulus = False - for line in result.stdout.splitlines(): - if line.startswith('modulus'): - in_modulus = True - elif not line.startswith(' '): - in_modulus = False - elif in_modulus: - modulus_raw.extend(line.strip().strip(':').split(':')) - if line.startswith('publicExponent'): - exp = int(line.split(' ')[1], 10) - modulus_raw.reverse() - modulus = bytearray((int(x, 16) for x in modulus_raw[:256])) - return struct.pack('<Q', exp), modulus + """Extract public exponent and modulus from PEM file.""" + result = subprocess.run( + ["openssl", "rsa", "-in", pem_file, "-text", "-noout"], + stdout=subprocess.PIPE, + encoding="utf-8", + ) + modulus_raw = [] + in_modulus = False + for line in result.stdout.splitlines(): + if line.startswith("modulus"): + in_modulus = True + elif not line.startswith(" "): + in_modulus = False + elif in_modulus: + modulus_raw.extend(line.strip().strip(":").split(":")) + if line.startswith("publicExponent"): + exp = int(line.split(" ")[1], 10) + modulus_raw.reverse() + modulus = bytearray((int(x, 16) for x in modulus_raw[:256])) + return struct.pack("<Q", exp), modulus + def GetSpiClockParameter(args): - assert args.spi_clock in SPI_CLOCK_LIST, \ - "Unsupported SPI clock speed %d MHz" % args.spi_clock - return SPI_CLOCK_LIST.index(args.spi_clock) + assert args.spi_clock in SPI_CLOCK_LIST, ( + "Unsupported SPI clock speed %d MHz" % args.spi_clock + ) + return SPI_CLOCK_LIST.index(args.spi_clock) + def GetSpiReadCmdParameter(args): - assert args.spi_read_cmd in SPI_READ_CMD_LIST, \ - "Unsupported SPI read command 0x%x" % args.spi_read_cmd - return SPI_READ_CMD_LIST.index(args.spi_read_cmd) + assert args.spi_read_cmd in SPI_READ_CMD_LIST, ( + "Unsupported SPI read command 0x%x" % args.spi_read_cmd + ) + return SPI_READ_CMD_LIST.index(args.spi_read_cmd) + def PadZeroTo(data, size): - data.extend(b'\0' * (size - len(data))) + data.extend(b"\0" * (size - len(data))) + def BuildHeader(args, payload_len, load_addr, rorofile): - # Identifier and header version - header = bytearray(b'PHCM\0') + # Identifier and header version + header = bytearray(b"PHCM\0") - # byte[5] - b = GetSpiClockParameter(args) - b |= (1 << 2) - header.append(b) + # byte[5] + b = GetSpiClockParameter(args) + b |= 1 << 2 + header.append(b) - # byte[6] - b = 0 - header.append(b) + # byte[6] + b = 0 + header.append(b) - # byte[7] - header.append(GetSpiReadCmdParameter(args)) + # byte[7] + header.append(GetSpiReadCmdParameter(args)) - # bytes 0x08 - 0x0b - header.extend(struct.pack('<I', load_addr)) - # bytes 0x0c - 0x0f - header.extend(struct.pack('<I', GetEntryPoint(rorofile))) - # bytes 0x10 - 0x13 - header.append((payload_len >> 6) & 0xff) - header.append((payload_len >> 14) & 0xff) - PadZeroTo(header, 0x14) - # bytes 0x14 - 0x17 - header.extend(struct.pack('<I', args.payload_offset)) + # bytes 0x08 - 0x0b + header.extend(struct.pack("<I", load_addr)) + # bytes 0x0c - 0x0f + header.extend(struct.pack("<I", GetEntryPoint(rorofile))) + # bytes 0x10 - 0x13 + header.append((payload_len >> 6) & 0xFF) + header.append((payload_len >> 14) & 0xFF) + PadZeroTo(header, 0x14) + # bytes 0x14 - 0x17 + header.extend(struct.pack("<I", args.payload_offset)) - # bytes 0x14 - 0x3F all 0 - PadZeroTo(header, 0x40) + # bytes 0x14 - 0x3F all 0 + PadZeroTo(header, 0x40) - # header signature is appended by the caller + # header signature is appended by the caller - return header + return header def BuildHeader2(args, payload_len, load_addr, payload_entry): - # Identifier and header version - header = bytearray(b'PHCM\0') + # Identifier and header version + header = bytearray(b"PHCM\0") - # byte[5] - b = GetSpiClockParameter(args) - b |= (1 << 2) - header.append(b) + # byte[5] + b = GetSpiClockParameter(args) + b |= 1 << 2 + header.append(b) - # byte[6] - b = 0 - header.append(b) + # byte[6] + b = 0 + header.append(b) - # byte[7] - header.append(GetSpiReadCmdParameter(args)) + # byte[7] + header.append(GetSpiReadCmdParameter(args)) - # bytes 0x08 - 0x0b - header.extend(struct.pack('<I', load_addr)) - # bytes 0x0c - 0x0f - header.extend(struct.pack('<I', payload_entry)) - # bytes 0x10 - 0x13 - header.append((payload_len >> 6) & 0xff) - header.append((payload_len >> 14) & 0xff) - PadZeroTo(header, 0x14) - # bytes 0x14 - 0x17 - header.extend(struct.pack('<I', args.payload_offset)) + # bytes 0x08 - 0x0b + header.extend(struct.pack("<I", load_addr)) + # bytes 0x0c - 0x0f + header.extend(struct.pack("<I", payload_entry)) + # bytes 0x10 - 0x13 + header.append((payload_len >> 6) & 0xFF) + header.append((payload_len >> 14) & 0xFF) + PadZeroTo(header, 0x14) + # bytes 0x14 - 0x17 + header.extend(struct.pack("<I", args.payload_offset)) - # bytes 0x14 - 0x3F all 0 - PadZeroTo(header, 0x40) + # bytes 0x14 - 0x3F all 0 + PadZeroTo(header, 0x40) - # header signature is appended by the caller + # header signature is appended by the caller + + return header - return header # # Compute SHA-256 of data and return digest # as a bytearray # def HashByteArray(data): - hasher = hashlib.sha256() - hasher.update(data) - h = hasher.digest() - bah = bytearray(h) - return bah + hasher = hashlib.sha256() + hasher.update(data) + h = hasher.digest() + bah = bytearray(h) + return bah + # # Return 64-byte signature of byte array data. # Signature is SHA256 of data with 32 0 bytes appended # def SignByteArray(data): - debug_print("Signature is SHA-256 of data") - sigb = HashByteArray(data) - sigb.extend(b'\0' * 32) - return sigb + debug_print("Signature is SHA-256 of data") + sigb = HashByteArray(data) + sigb.extend(b"\0" * 32) + return sigb # MEC1701H supports two 32-bit Tags located at offsets 0x0 and 0x4 @@ -201,16 +235,21 @@ def SignByteArray(data): # to the same flash part. # def BuildTag(args): - tag = bytearray([(args.header_loc >> 8) & 0xff, - (args.header_loc >> 16) & 0xff, - (args.header_loc >> 24) & 0xff]) - tag.append(Crc8(0, tag)) - return tag + tag = bytearray( + [ + (args.header_loc >> 8) & 0xFF, + (args.header_loc >> 16) & 0xFF, + (args.header_loc >> 24) & 0xFF, + ] + ) + tag.append(Crc8(0, tag)) + return tag + def BuildTagFromHdrAddr(header_loc): - tag = bytearray([(header_loc >> 8) & 0xff, - (header_loc >> 16) & 0xff, - (header_loc >> 24) & 0xff]) + tag = bytearray( + [(header_loc >> 8) & 0xFF, (header_loc >> 16) & 0xFF, (header_loc >> 24) & 0xFF] + ) tag.append(Crc8(0, tag)) return tag @@ -224,20 +263,21 @@ def BuildTagFromHdrAddr(header_loc): # Returns temporary file name # def PacklfwRoImage(rorw_file, loader_file, image_size): - """Create a temp file with the - first image_size bytes from the loader file and append bytes - from the rorw file. - return the filename""" - fo=tempfile.NamedTemporaryFile(delete=False) # Need to keep file around - with open(loader_file,'rb') as fin1: # read 4KB loader file - pro = fin1.read() - fo.write(pro) # write 4KB loader data to temp file - with open(rorw_file, 'rb') as fin: - ro = fin.read(image_size) - - fo.write(ro) - fo.close() - return fo.name + """Create a temp file with the + first image_size bytes from the loader file and append bytes + from the rorw file. + return the filename""" + fo = tempfile.NamedTemporaryFile(delete=False) # Need to keep file around + with open(loader_file, "rb") as fin1: # read 4KB loader file + pro = fin1.read() + fo.write(pro) # write 4KB loader data to temp file + with open(rorw_file, "rb") as fin: + ro = fin.read(image_size) + + fo.write(ro) + fo.close() + return fo.name + # # Generate a test EC_RW image of same size @@ -248,105 +288,145 @@ def PacklfwRoImage(rorw_file, loader_file, image_size): # process hash generation. # def gen_test_ecrw(pldrw): - debug_print("gen_test_ecrw: pldrw type =", type(pldrw)) - debug_print("len pldrw =", len(pldrw), " = ", hex(len(pldrw))) - cookie1_pos = pldrw.find(b'\x99\x88\x77\xce') - cookie2_pos = pldrw.find(b'\xdd\xbb\xaa\xce', cookie1_pos+4) - t = struct.unpack("<L", pldrw[cookie1_pos+0x24:cookie1_pos+0x28]) - size = t[0] - debug_print("EC_RW size =", size, " = ", hex(size)) - - debug_print("Found cookie1 at ", hex(cookie1_pos)) - debug_print("Found cookie2 at ", hex(cookie2_pos)) - - if cookie1_pos > 0 and cookie2_pos > cookie1_pos: - for i in range(0, cookie1_pos): - pldrw[i] = 0xA5 - for i in range(cookie2_pos+4, len(pldrw)): - pldrw[i] = 0xA5 - - with open("ec_RW_test.bin", "wb") as fecrw: - fecrw.write(pldrw[:size]) + debug_print("gen_test_ecrw: pldrw type =", type(pldrw)) + debug_print("len pldrw =", len(pldrw), " = ", hex(len(pldrw))) + cookie1_pos = pldrw.find(b"\x99\x88\x77\xce") + cookie2_pos = pldrw.find(b"\xdd\xbb\xaa\xce", cookie1_pos + 4) + t = struct.unpack("<L", pldrw[cookie1_pos + 0x24 : cookie1_pos + 0x28]) + size = t[0] + debug_print("EC_RW size =", size, " = ", hex(size)) + + debug_print("Found cookie1 at ", hex(cookie1_pos)) + debug_print("Found cookie2 at ", hex(cookie2_pos)) + + if cookie1_pos > 0 and cookie2_pos > cookie1_pos: + for i in range(0, cookie1_pos): + pldrw[i] = 0xA5 + for i in range(cookie2_pos + 4, len(pldrw)): + pldrw[i] = 0xA5 + + with open("ec_RW_test.bin", "wb") as fecrw: + fecrw.write(pldrw[:size]) + def parseargs(): - rpath = os.path.dirname(os.path.relpath(__file__)) - - parser = argparse.ArgumentParser() - parser.add_argument("-i", "--input", - help="EC binary to pack, usually ec.bin or ec.RO.flat.", - metavar="EC_BIN", default="ec.bin") - parser.add_argument("-o", "--output", - help="Output flash binary file", - metavar="EC_SPI_FLASH", default="ec.packed.bin") - parser.add_argument("--loader_file", - help="EC loader binary", - default="ecloader.bin") - parser.add_argument("-s", "--spi_size", type=int, - help="Size of the SPI flash in KB", - default=512) - parser.add_argument("-l", "--header_loc", type=int, - help="Location of header in SPI flash", - default=0x1000) - parser.add_argument("-p", "--payload_offset", type=int, - help="The offset of payload from the start of header", - default=0x80) - parser.add_argument("-r", "--rw_loc", type=int, - help="Start offset of EC_RW. Default is -1 meaning 1/2 flash size", - default=-1) - parser.add_argument("--spi_clock", type=int, - help="SPI clock speed. 8, 12, 24, or 48 MHz.", - default=24) - parser.add_argument("--spi_read_cmd", type=int, - help="SPI read command. 0x3, 0xB, or 0x3B.", - default=0xb) - parser.add_argument("--image_size", type=int, - help="Size of a single image. Default 220KB", - default=(220 * 1024)) - parser.add_argument("--test_spi", action='store_true', - help="Test SPI data integrity by adding CRC32 in last 4-bytes of RO/RW binaries", - default=False) - parser.add_argument("--test_ecrw", action='store_true', - help="Use fixed pattern for EC_RW but preserve image_data", - default=False) - parser.add_argument("--verbose", action='store_true', - help="Enable verbose output", - default=False) - - return parser.parse_args() + rpath = os.path.dirname(os.path.relpath(__file__)) + + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--input", + help="EC binary to pack, usually ec.bin or ec.RO.flat.", + metavar="EC_BIN", + default="ec.bin", + ) + parser.add_argument( + "-o", + "--output", + help="Output flash binary file", + metavar="EC_SPI_FLASH", + default="ec.packed.bin", + ) + parser.add_argument( + "--loader_file", help="EC loader binary", default="ecloader.bin" + ) + parser.add_argument( + "-s", "--spi_size", type=int, help="Size of the SPI flash in KB", default=512 + ) + parser.add_argument( + "-l", + "--header_loc", + type=int, + help="Location of header in SPI flash", + default=0x1000, + ) + parser.add_argument( + "-p", + "--payload_offset", + type=int, + help="The offset of payload from the start of header", + default=0x80, + ) + parser.add_argument( + "-r", + "--rw_loc", + type=int, + help="Start offset of EC_RW. Default is -1 meaning 1/2 flash size", + default=-1, + ) + parser.add_argument( + "--spi_clock", + type=int, + help="SPI clock speed. 8, 12, 24, or 48 MHz.", + default=24, + ) + parser.add_argument( + "--spi_read_cmd", + type=int, + help="SPI read command. 0x3, 0xB, or 0x3B.", + default=0xB, + ) + parser.add_argument( + "--image_size", + type=int, + help="Size of a single image. Default 220KB", + default=(220 * 1024), + ) + parser.add_argument( + "--test_spi", + action="store_true", + help="Test SPI data integrity by adding CRC32 in last 4-bytes of RO/RW binaries", + default=False, + ) + parser.add_argument( + "--test_ecrw", + action="store_true", + help="Use fixed pattern for EC_RW but preserve image_data", + default=False, + ) + parser.add_argument( + "--verbose", action="store_true", help="Enable verbose output", default=False + ) + + return parser.parse_args() + # Debug helper routine def dumpsects(spi_list): - debug_print("spi_list has {0} entries".format(len(spi_list))) - for s in spi_list: - debug_print("0x{0:x} 0x{1:x} {2:s}".format(s[0],len(s[1]),s[2])) + debug_print("spi_list has {0} entries".format(len(spi_list))) + for s in spi_list: + debug_print("0x{0:x} 0x{1:x} {2:s}".format(s[0], len(s[1]), s[2])) + def printByteArrayAsHex(ba, title): - debug_print(title,"= ") - count = 0 - for b in ba: - count = count + 1 - debug_print("0x{0:02x}, ".format(b),end="") - if (count % 8) == 0: - debug_print("") - debug_print("\n") + debug_print(title, "= ") + count = 0 + for b in ba: + count = count + 1 + debug_print("0x{0:02x}, ".format(b), end="") + if (count % 8) == 0: + debug_print("") + debug_print("\n") + def print_args(args): - debug_print("parsed arguments:") - debug_print(".input = ", args.input) - debug_print(".output = ", args.output) - debug_print(".loader_file = ", args.loader_file) - debug_print(".spi_size (KB) = ", hex(args.spi_size)) - debug_print(".image_size = ", hex(args.image_size)) - debug_print(".header_loc = ", hex(args.header_loc)) - debug_print(".payload_offset = ", hex(args.payload_offset)) - if args.rw_loc < 0: - debug_print(".rw_loc = ", args.rw_loc) - else: - debug_print(".rw_loc = ", hex(args.rw_loc)) - debug_print(".spi_clock = ", args.spi_clock) - debug_print(".spi_read_cmd = ", args.spi_read_cmd) - debug_print(".test_spi = ", args.test_spi) - debug_print(".verbose = ", args.verbose) + debug_print("parsed arguments:") + debug_print(".input = ", args.input) + debug_print(".output = ", args.output) + debug_print(".loader_file = ", args.loader_file) + debug_print(".spi_size (KB) = ", hex(args.spi_size)) + debug_print(".image_size = ", hex(args.image_size)) + debug_print(".header_loc = ", hex(args.header_loc)) + debug_print(".payload_offset = ", hex(args.payload_offset)) + if args.rw_loc < 0: + debug_print(".rw_loc = ", args.rw_loc) + else: + debug_print(".rw_loc = ", hex(args.rw_loc)) + debug_print(".spi_clock = ", args.spi_clock) + debug_print(".spi_read_cmd = ", args.spi_read_cmd) + debug_print(".test_spi = ", args.test_spi) + debug_print(".verbose = ", args.verbose) + # # Handle quiet mode build from Makefile @@ -354,183 +434,188 @@ def print_args(args): # Verbose mode when V=1 # def main(): - global debug_print - - args = parseargs() - - if args.verbose: - debug_print = print - - debug_print("Begin MEC17xx pack_ec.py script") - - - # MEC17xx maximum 192KB each for RO & RW - # mec1701 chip Makefile sets args.spi_size = 512 - # Tags at offset 0 - # - print_args(args) - - spi_size = args.spi_size * 1024 - debug_print("SPI Flash image size in bytes =", hex(spi_size)) - - # !!! IMPORTANT !!! - # These values MUST match chip/mec1701/config_flash_layout.h - # defines. - # MEC17xx Boot-ROM TAGs are at offset 0 and 4. - # lfw + EC_RO starts at beginning of second 4KB sector - # EC_RW starts at offset 0x40000 (256KB) - - spi_list = [] - - debug_print("args.input = ",args.input) - debug_print("args.loader_file = ",args.loader_file) - debug_print("args.image_size = ",hex(args.image_size)) - - rorofile=PacklfwRoImage(args.input, args.loader_file, args.image_size) - - payload = GetPayload(rorofile) - payload_len = len(payload) - # debug - debug_print("EC_LFW + EC_RO length = ",hex(payload_len)) - - # SPI image integrity test - # compute CRC32 of EC_RO except for last 4 bytes - # skip over 4KB LFW - # Store CRC32 in last 4 bytes - if args.test_spi == True: - crc = zlib.crc32(bytes(payload[LFW_SIZE:(payload_len - 4)])) - crc_ofs = payload_len - 4 - debug_print("EC_RO CRC32 = 0x{0:08x} @ 0x{1:08x}".format(crc, crc_ofs)) - for i in range(4): - payload[crc_ofs + i] = crc & 0xff - crc = crc >> 8 - - # Chromebooks are not using MEC BootROM ECDSA. - # We implemented the ECDSA disabled case where - # the 64-byte signature contains a SHA-256 of the binary plus - # 32 zeros bytes. - payload_signature = SignByteArray(payload) - # debug - printByteArrayAsHex(payload_signature, "LFW + EC_RO payload_signature") - - # MEC17xx Header is 0x80 bytes with an 64 byte signature - # (32 byte SHA256 + 32 zero bytes) - header = BuildHeader(args, payload_len, LOAD_ADDR, rorofile) - # debug - printByteArrayAsHex(header, "Header LFW + EC_RO") - - # MEC17xx payload ECDSA not used, 64 byte signature is - # SHA256 + 32 zero bytes - header_signature = SignByteArray(header) - # debug - printByteArrayAsHex(header_signature, "header_signature") - - tag = BuildTag(args) - # MEC17xx truncate RW length to args.image_size to not overwrite LFW - # offset may be different due to Header size and other changes - # MCHP we want to append a SHA-256 to the end of the actual payload - # to test SPI read routines. - debug_print("Call to GetPayloadFromOffset") - debug_print("args.input = ", args.input) - debug_print("args.image_size = ", hex(args.image_size)) - - payload_rw = GetPayloadFromOffset(args.input, args.image_size) - debug_print("type(payload_rw) is ", type(payload_rw)) - debug_print("len(payload_rw) is ", hex(len(payload_rw))) - - # truncate to args.image_size - rw_len = args.image_size - payload_rw = payload_rw[:rw_len] - payload_rw_len = len(payload_rw) - debug_print("Truncated size of EC_RW = ", hex(payload_rw_len)) - - payload_entry_tuple = struct.unpack_from('<I', payload_rw, 4) - debug_print("payload_entry_tuple = ", payload_entry_tuple) - - payload_entry = payload_entry_tuple[0] - debug_print("payload_entry = ", hex(payload_entry)) - - # Note: payload_rw is a bytearray therefore is mutable - if args.test_ecrw: - gen_test_ecrw(payload_rw) - - # SPI image integrity test - # compute CRC32 of EC_RW except for last 4 bytes - # Store CRC32 in last 4 bytes - if args.test_spi == True: - crc = zlib.crc32(bytes(payload_rw[:(payload_rw_len - 32)])) - crc_ofs = payload_rw_len - 4 - debug_print("EC_RW CRC32 = 0x{0:08x} at offset 0x{1:08x}".format(crc, crc_ofs)) - for i in range(4): - payload_rw[crc_ofs + i] = crc & 0xff - crc = crc >> 8 - - payload_rw_sig = SignByteArray(payload_rw) - # debug - printByteArrayAsHex(payload_rw_sig, "payload_rw_sig") - - os.remove(rorofile) # clean up the temp file - - # MEC170x Boot-ROM Tags are located at SPI offset 0 - spi_list.append((0, tag, "tag")) - - spi_list.append((args.header_loc, header, "header(lwf + ro)")) - spi_list.append((args.header_loc + HEADER_SIZE, header_signature, - "header(lwf + ro) signature")) - spi_list.append((args.header_loc + args.payload_offset, payload, - "payload(lfw + ro)")) - - offset = args.header_loc + args.payload_offset + payload_len - - # No SPI Header for EC_RW as its not loaded by BootROM - spi_list.append((offset, payload_signature, - "payload(lfw_ro) signature")) - - # EC_RW location - rw_offset = int(spi_size // 2) - if args.rw_loc >= 0: - rw_offset = args.rw_loc - - debug_print("rw_offset = 0x{0:08x}".format(rw_offset)) - - if rw_offset < offset + len(payload_signature): - print("ERROR: EC_RW overlaps EC_RO") - - spi_list.append((rw_offset, payload_rw, "payload(rw)")) - - # don't add to EC_RW. We don't know if Google will process - # EC SPI flash binary with other tools during build of - # coreboot and OS. - #offset = rw_offset + payload_rw_len - #spi_list.append((offset, payload_rw_sig, "payload(rw) signature")) - - spi_list = sorted(spi_list) - - dumpsects(spi_list) - - # - # MEC17xx Boot-ROM locates TAG at SPI offset 0 instead of end of SPI. - # - with open(args.output, 'wb') as f: - debug_print("Write spi list to file", args.output) - addr = 0 - for s in spi_list: - if addr < s[0]: - debug_print("Offset ",hex(addr)," Length", hex(s[0]-addr), - "fill with 0xff") - f.write(b'\xff' * (s[0] - addr)) - addr = s[0] - debug_print("Offset ",hex(addr), " Length", hex(len(s[1])), "write data") - - f.write(s[1]) - addr += len(s[1]) - - if addr < spi_size: - debug_print("Offset ",hex(addr), " Length", hex(spi_size - addr), - "fill with 0xff") - f.write(b'\xff' * (spi_size - addr)) - - f.flush() - -if __name__ == '__main__': - main() + global debug_print + + args = parseargs() + + if args.verbose: + debug_print = print + + debug_print("Begin MEC17xx pack_ec.py script") + + # MEC17xx maximum 192KB each for RO & RW + # mec1701 chip Makefile sets args.spi_size = 512 + # Tags at offset 0 + # + print_args(args) + + spi_size = args.spi_size * 1024 + debug_print("SPI Flash image size in bytes =", hex(spi_size)) + + # !!! IMPORTANT !!! + # These values MUST match chip/mec1701/config_flash_layout.h + # defines. + # MEC17xx Boot-ROM TAGs are at offset 0 and 4. + # lfw + EC_RO starts at beginning of second 4KB sector + # EC_RW starts at offset 0x40000 (256KB) + + spi_list = [] + + debug_print("args.input = ", args.input) + debug_print("args.loader_file = ", args.loader_file) + debug_print("args.image_size = ", hex(args.image_size)) + + rorofile = PacklfwRoImage(args.input, args.loader_file, args.image_size) + + payload = GetPayload(rorofile) + payload_len = len(payload) + # debug + debug_print("EC_LFW + EC_RO length = ", hex(payload_len)) + + # SPI image integrity test + # compute CRC32 of EC_RO except for last 4 bytes + # skip over 4KB LFW + # Store CRC32 in last 4 bytes + if args.test_spi == True: + crc = zlib.crc32(bytes(payload[LFW_SIZE : (payload_len - 4)])) + crc_ofs = payload_len - 4 + debug_print("EC_RO CRC32 = 0x{0:08x} @ 0x{1:08x}".format(crc, crc_ofs)) + for i in range(4): + payload[crc_ofs + i] = crc & 0xFF + crc = crc >> 8 + + # Chromebooks are not using MEC BootROM ECDSA. + # We implemented the ECDSA disabled case where + # the 64-byte signature contains a SHA-256 of the binary plus + # 32 zeros bytes. + payload_signature = SignByteArray(payload) + # debug + printByteArrayAsHex(payload_signature, "LFW + EC_RO payload_signature") + + # MEC17xx Header is 0x80 bytes with an 64 byte signature + # (32 byte SHA256 + 32 zero bytes) + header = BuildHeader(args, payload_len, LOAD_ADDR, rorofile) + # debug + printByteArrayAsHex(header, "Header LFW + EC_RO") + + # MEC17xx payload ECDSA not used, 64 byte signature is + # SHA256 + 32 zero bytes + header_signature = SignByteArray(header) + # debug + printByteArrayAsHex(header_signature, "header_signature") + + tag = BuildTag(args) + # MEC17xx truncate RW length to args.image_size to not overwrite LFW + # offset may be different due to Header size and other changes + # MCHP we want to append a SHA-256 to the end of the actual payload + # to test SPI read routines. + debug_print("Call to GetPayloadFromOffset") + debug_print("args.input = ", args.input) + debug_print("args.image_size = ", hex(args.image_size)) + + payload_rw = GetPayloadFromOffset(args.input, args.image_size) + debug_print("type(payload_rw) is ", type(payload_rw)) + debug_print("len(payload_rw) is ", hex(len(payload_rw))) + + # truncate to args.image_size + rw_len = args.image_size + payload_rw = payload_rw[:rw_len] + payload_rw_len = len(payload_rw) + debug_print("Truncated size of EC_RW = ", hex(payload_rw_len)) + + payload_entry_tuple = struct.unpack_from("<I", payload_rw, 4) + debug_print("payload_entry_tuple = ", payload_entry_tuple) + + payload_entry = payload_entry_tuple[0] + debug_print("payload_entry = ", hex(payload_entry)) + + # Note: payload_rw is a bytearray therefore is mutable + if args.test_ecrw: + gen_test_ecrw(payload_rw) + + # SPI image integrity test + # compute CRC32 of EC_RW except for last 4 bytes + # Store CRC32 in last 4 bytes + if args.test_spi == True: + crc = zlib.crc32(bytes(payload_rw[: (payload_rw_len - 32)])) + crc_ofs = payload_rw_len - 4 + debug_print("EC_RW CRC32 = 0x{0:08x} at offset 0x{1:08x}".format(crc, crc_ofs)) + for i in range(4): + payload_rw[crc_ofs + i] = crc & 0xFF + crc = crc >> 8 + + payload_rw_sig = SignByteArray(payload_rw) + # debug + printByteArrayAsHex(payload_rw_sig, "payload_rw_sig") + + os.remove(rorofile) # clean up the temp file + + # MEC170x Boot-ROM Tags are located at SPI offset 0 + spi_list.append((0, tag, "tag")) + + spi_list.append((args.header_loc, header, "header(lwf + ro)")) + spi_list.append( + (args.header_loc + HEADER_SIZE, header_signature, "header(lwf + ro) signature") + ) + spi_list.append( + (args.header_loc + args.payload_offset, payload, "payload(lfw + ro)") + ) + + offset = args.header_loc + args.payload_offset + payload_len + + # No SPI Header for EC_RW as its not loaded by BootROM + spi_list.append((offset, payload_signature, "payload(lfw_ro) signature")) + + # EC_RW location + rw_offset = int(spi_size // 2) + if args.rw_loc >= 0: + rw_offset = args.rw_loc + + debug_print("rw_offset = 0x{0:08x}".format(rw_offset)) + + if rw_offset < offset + len(payload_signature): + print("ERROR: EC_RW overlaps EC_RO") + + spi_list.append((rw_offset, payload_rw, "payload(rw)")) + + # don't add to EC_RW. We don't know if Google will process + # EC SPI flash binary with other tools during build of + # coreboot and OS. + # offset = rw_offset + payload_rw_len + # spi_list.append((offset, payload_rw_sig, "payload(rw) signature")) + + spi_list = sorted(spi_list) + + dumpsects(spi_list) + + # + # MEC17xx Boot-ROM locates TAG at SPI offset 0 instead of end of SPI. + # + with open(args.output, "wb") as f: + debug_print("Write spi list to file", args.output) + addr = 0 + for s in spi_list: + if addr < s[0]: + debug_print( + "Offset ", hex(addr), " Length", hex(s[0] - addr), "fill with 0xff" + ) + f.write(b"\xff" * (s[0] - addr)) + addr = s[0] + debug_print( + "Offset ", hex(addr), " Length", hex(len(s[1])), "write data" + ) + + f.write(s[1]) + addr += len(s[1]) + + if addr < spi_size: + debug_print( + "Offset ", hex(addr), " Length", hex(spi_size - addr), "fill with 0xff" + ) + f.write(b"\xff" * (spi_size - addr)) + + f.flush() + + +if __name__ == "__main__": + main() diff --git a/chip/mchp/util/pack_ec_mec152x.py b/chip/mchp/util/pack_ec_mec152x.py index 34846cd6ba..89f90f5394 100755 --- a/chip/mchp/util/pack_ec_mec152x.py +++ b/chip/mchp/util/pack_ec_mec152x.py @@ -16,7 +16,7 @@ import os import struct import subprocess import tempfile -import zlib # CRC32 +import zlib # CRC32 # MEC152xH has 256KB SRAM from 0xE0000 - 0x120000 # SRAM is divided into contiguous CODE & DATA @@ -29,119 +29,157 @@ LOAD_ADDR = 0x0E0000 LOAD_ADDR_RW = 0xE1000 MEC152X_HEADER_SIZE = 0x140 MEC152X_HEADER_VERSION = 0x02 -PAYLOAD_PAD_BYTE = b'\xff' +PAYLOAD_PAD_BYTE = b"\xff" SPI_ERASE_BLOCK_SIZE = 0x1000 SPI_CLOCK_LIST = [48, 24, 16, 12] -SPI_READ_CMD_LIST = [0x3, 0xb, 0x3b, 0x6b] -SPI_DRIVE_STR_DICT = {2:0, 4:1, 8:2, 12:3} +SPI_READ_CMD_LIST = [0x3, 0xB, 0x3B, 0x6B] +SPI_DRIVE_STR_DICT = {2: 0, 4: 1, 8: 2, 12: 3} CHIP_MAX_CODE_SRAM_KB = 224 MEC152X_DICT = { - "HEADER_SIZE":0x140, - "HEADER_VER":0x02, - "PAYLOAD_OFFSET":0x140, - "PAYLOAD_GRANULARITY":128, - "EC_INFO_BLK_SZ":128, - "ENCR_KEY_HDR_SZ":128, - "COSIG_SZ":96, - "TRAILER_SZ":160, - "TAILER_PAD_BYTE":b'\xff', - "PAD_SIZE":128 - } - -CRC_TABLE = [0x00, 0x07, 0x0e, 0x09, 0x1c, 0x1b, 0x12, 0x15, - 0x38, 0x3f, 0x36, 0x31, 0x24, 0x23, 0x2a, 0x2d] + "HEADER_SIZE": 0x140, + "HEADER_VER": 0x02, + "PAYLOAD_OFFSET": 0x140, + "PAYLOAD_GRANULARITY": 128, + "EC_INFO_BLK_SZ": 128, + "ENCR_KEY_HDR_SZ": 128, + "COSIG_SZ": 96, + "TRAILER_SZ": 160, + "TAILER_PAD_BYTE": b"\xff", + "PAD_SIZE": 128, +} + +CRC_TABLE = [ + 0x00, + 0x07, + 0x0E, + 0x09, + 0x1C, + 0x1B, + 0x12, + 0x15, + 0x38, + 0x3F, + 0x36, + 0x31, + 0x24, + 0x23, + 0x2A, + 0x2D, +] + def mock_print(*args, **kwargs): - pass + pass + debug_print = mock_print # Debug helper routine def dumpsects(spi_list): - debug_print("spi_list has {0} entries".format(len(spi_list))) - for s in spi_list: - debug_print("0x{0:x} 0x{1:x} {2:s}".format(s[0],len(s[1]),s[2])) + debug_print("spi_list has {0} entries".format(len(spi_list))) + for s in spi_list: + debug_print("0x{0:x} 0x{1:x} {2:s}".format(s[0], len(s[1]), s[2])) + def printByteArrayAsHex(ba, title): - debug_print(title,"= ") - if ba == None: - debug_print("None") - return - - count = 0 - for b in ba: - count = count + 1 - debug_print("0x{0:02x}, ".format(b),end="") - if (count % 8) == 0: - debug_print("") - debug_print("") + debug_print(title, "= ") + if ba == None: + debug_print("None") + return + + count = 0 + for b in ba: + count = count + 1 + debug_print("0x{0:02x}, ".format(b), end="") + if (count % 8) == 0: + debug_print("") + debug_print("") + def Crc8(crc, data): - """Update CRC8 value.""" - for v in data: - crc = ((crc << 4) & 0xff) ^ (CRC_TABLE[(crc >> 4) ^ (v >> 4)]); - crc = ((crc << 4) & 0xff) ^ (CRC_TABLE[(crc >> 4) ^ (v & 0xf)]); - return crc ^ 0x55 + """Update CRC8 value.""" + for v in data: + crc = ((crc << 4) & 0xFF) ^ (CRC_TABLE[(crc >> 4) ^ (v >> 4)]) + crc = ((crc << 4) & 0xFF) ^ (CRC_TABLE[(crc >> 4) ^ (v & 0xF)]) + return crc ^ 0x55 + def GetEntryPoint(payload_file): - """Read entry point from payload EC image.""" - with open(payload_file, 'rb') as f: - f.seek(4) - s = f.read(4) - return int.from_bytes(s, byteorder='little') + """Read entry point from payload EC image.""" + with open(payload_file, "rb") as f: + f.seek(4) + s = f.read(4) + return int.from_bytes(s, byteorder="little") + def GetPayloadFromOffset(payload_file, offset, padsize): - """Read payload and pad it to padsize.""" - with open(payload_file, 'rb') as f: - f.seek(offset) - payload = bytearray(f.read()) - rem_len = len(payload) % padsize - debug_print("GetPayload: padsize={0:0x} len(payload)={1:0x} rem={2:0x}".format(padsize,len(payload),rem_len)) + """Read payload and pad it to padsize.""" + with open(payload_file, "rb") as f: + f.seek(offset) + payload = bytearray(f.read()) + rem_len = len(payload) % padsize + debug_print( + "GetPayload: padsize={0:0x} len(payload)={1:0x} rem={2:0x}".format( + padsize, len(payload), rem_len + ) + ) + + if rem_len: + payload += PAYLOAD_PAD_BYTE * (padsize - rem_len) + debug_print("GetPayload: Added {0} padding bytes".format(padsize - rem_len)) - if rem_len: - payload += PAYLOAD_PAD_BYTE * (padsize - rem_len) - debug_print("GetPayload: Added {0} padding bytes".format(padsize - rem_len)) + return payload - return payload def GetPayload(payload_file, padsize): - """Read payload and pad it to padsize""" - return GetPayloadFromOffset(payload_file, 0, padsize) + """Read payload and pad it to padsize""" + return GetPayloadFromOffset(payload_file, 0, padsize) + def GetPublicKey(pem_file): - """Extract public exponent and modulus from PEM file.""" - result = subprocess.run(['openssl', 'rsa', '-in', pem_file, '-text', - '-noout'], stdout=subprocess.PIPE, encoding='utf-8') - modulus_raw = [] - in_modulus = False - for line in result.stdout.splitlines(): - if line.startswith('modulus'): - in_modulus = True - elif not line.startswith(' '): - in_modulus = False - elif in_modulus: - modulus_raw.extend(line.strip().strip(':').split(':')) - if line.startswith('publicExponent'): - exp = int(line.split(' ')[1], 10) - modulus_raw.reverse() - modulus = bytearray((int(x, 16) for x in modulus_raw[:256])) - return struct.pack('<Q', exp), modulus + """Extract public exponent and modulus from PEM file.""" + result = subprocess.run( + ["openssl", "rsa", "-in", pem_file, "-text", "-noout"], + stdout=subprocess.PIPE, + encoding="utf-8", + ) + modulus_raw = [] + in_modulus = False + for line in result.stdout.splitlines(): + if line.startswith("modulus"): + in_modulus = True + elif not line.startswith(" "): + in_modulus = False + elif in_modulus: + modulus_raw.extend(line.strip().strip(":").split(":")) + if line.startswith("publicExponent"): + exp = int(line.split(" ")[1], 10) + modulus_raw.reverse() + modulus = bytearray((int(x, 16) for x in modulus_raw[:256])) + return struct.pack("<Q", exp), modulus + def GetSpiClockParameter(args): - assert args.spi_clock in SPI_CLOCK_LIST, \ - "Unsupported SPI clock speed %d MHz" % args.spi_clock - return SPI_CLOCK_LIST.index(args.spi_clock) + assert args.spi_clock in SPI_CLOCK_LIST, ( + "Unsupported SPI clock speed %d MHz" % args.spi_clock + ) + return SPI_CLOCK_LIST.index(args.spi_clock) + def GetSpiReadCmdParameter(args): - assert args.spi_read_cmd in SPI_READ_CMD_LIST, \ - "Unsupported SPI read command 0x%x" % args.spi_read_cmd - return SPI_READ_CMD_LIST.index(args.spi_read_cmd) + assert args.spi_read_cmd in SPI_READ_CMD_LIST, ( + "Unsupported SPI read command 0x%x" % args.spi_read_cmd + ) + return SPI_READ_CMD_LIST.index(args.spi_read_cmd) + def GetEncodedSpiDriveStrength(args): - assert args.spi_drive_str in SPI_DRIVE_STR_DICT, \ - "Unsupported SPI drive strength %d mA" % args.spi_drive_str - return SPI_DRIVE_STR_DICT.get(args.spi_drive_str) + assert args.spi_drive_str in SPI_DRIVE_STR_DICT, ( + "Unsupported SPI drive strength %d mA" % args.spi_drive_str + ) + return SPI_DRIVE_STR_DICT.get(args.spi_drive_str) + # Return 0=Slow slew rate or 1=Fast slew rate def GetSpiSlewRate(args): @@ -149,12 +187,14 @@ def GetSpiSlewRate(args): return 1 return 0 + # Return SPI CPOL = 0 or 1 def GetSpiCpol(args): if args.spi_cpol == 0: return 0 return 1 + # Return SPI CPHA_MOSI # 0 = SPI Master drives data is stable on inactive to clock edge # 1 = SPI Master drives data is stable on active to inactive clock edge @@ -163,6 +203,7 @@ def GetSpiCphaMosi(args): return 0 return 1 + # Return SPI CPHA_MISO 0 or 1 # 0 = SPI Master samples data on inactive to active clock edge # 1 = SPI Master samples data on active to inactive clock edge @@ -171,8 +212,10 @@ def GetSpiCphaMiso(args): return 0 return 1 + def PadZeroTo(data, size): - data.extend(b'\0' * (size - len(data))) + data.extend(b"\0" * (size - len(data))) + # # Boot-ROM SPI image encryption not used with Chromebooks @@ -180,6 +223,7 @@ def PadZeroTo(data, size): def EncryptPayload(args, chip_dict, payload): return None + # # Build SPI image header for MEC152x # MEC152x image header size = 320(0x140) bytes @@ -237,67 +281,69 @@ def EncryptPayload(args, chip_dict, payload): # header[0x110:0x140] = Header ECDSA-384 signature y-coor. = 0 Auth. disabled # def BuildHeader2(args, chip_dict, payload_len, load_addr, payload_entry): - header_size = MEC152X_HEADER_SIZE - - # allocate zero filled header - header = bytearray(b'\x00' * header_size) - debug_print("len(header) = ", len(header)) - - # Identifier and header version - header[0:4] = b'PHCM' - header[4] = MEC152X_HEADER_VERSION - - # SPI frequency, drive strength, CPOL/CPHA encoding same for both chips - spiFreqMHz = GetSpiClockParameter(args) - header[5] = (int(spiFreqMHz // 48) - 1) & 0x03 - header[5] |= ((GetEncodedSpiDriveStrength(args) & 0x03) << 2) - header[5] |= ((GetSpiSlewRate(args) & 0x01) << 4) - header[5] |= ((GetSpiCpol(args) & 0x01) << 5) - header[5] |= ((GetSpiCphaMosi(args) & 0x01) << 6) - header[5] |= ((GetSpiCphaMiso(args) & 0x01) << 7) - - # b[0]=0 VTR1 must be 3.3V - # b[1]=0(VTR2 3.3V), 1(VTR2 1.8V) - # b[2]=0(VTR3 3.3V), 1(VTR3 1.8V) - # b[5:3]=111b - # b[6]=0 No ECDSA - # b[7]=0 No encrypted FW image - header[6] = 0x7 << 3 - if args.vtr2_V18 == True: - header[6] |= 0x02 - if args.vtr3_V18 == True: - header[6] |= 0x04 - - # SPI read command set same for both chips - header[7] = GetSpiReadCmdParameter(args) & 0xFF - - # bytes 0x08 - 0x0b - header[0x08:0x0C] = load_addr.to_bytes(4, byteorder='little') - # bytes 0x0c - 0x0f - header[0x0C:0x10] = payload_entry.to_bytes(4, byteorder='little') - # bytes 0x10 - 0x11 payload length in units of 128 bytes - - payload_units = int(payload_len // chip_dict["PAYLOAD_GRANULARITY"]) - assert payload_units < 0x10000, \ - print("Payload too large: len={0} units={1}".format(payload_len, payload_units)) - - header[0x10:0x12] = payload_units.to_bytes(2, 'little') - - # bytes 0x14 - 0x17 - header[0x14:0x18] = chip_dict["PAYLOAD_OFFSET"].to_bytes(4, 'little') - - # MEC152x: Disable ECDSA and encryption - header[0x18] = 0 - - # header[0xB0:0xE0] = SHA384(header[0:0xB0]) - header[0xB0:0xE0] = hashlib.sha384(header[0:0xB0]).digest() - # When ECDSA authentication is disabled MCHP SPI image generator - # is filling the last 48 bytes of the Header with 0xff - header[-48:] = b'\xff' * 48 - - debug_print("After hash: len(header) = ", len(header)) - - return header + header_size = MEC152X_HEADER_SIZE + + # allocate zero filled header + header = bytearray(b"\x00" * header_size) + debug_print("len(header) = ", len(header)) + + # Identifier and header version + header[0:4] = b"PHCM" + header[4] = MEC152X_HEADER_VERSION + + # SPI frequency, drive strength, CPOL/CPHA encoding same for both chips + spiFreqMHz = GetSpiClockParameter(args) + header[5] = (int(spiFreqMHz // 48) - 1) & 0x03 + header[5] |= (GetEncodedSpiDriveStrength(args) & 0x03) << 2 + header[5] |= (GetSpiSlewRate(args) & 0x01) << 4 + header[5] |= (GetSpiCpol(args) & 0x01) << 5 + header[5] |= (GetSpiCphaMosi(args) & 0x01) << 6 + header[5] |= (GetSpiCphaMiso(args) & 0x01) << 7 + + # b[0]=0 VTR1 must be 3.3V + # b[1]=0(VTR2 3.3V), 1(VTR2 1.8V) + # b[2]=0(VTR3 3.3V), 1(VTR3 1.8V) + # b[5:3]=111b + # b[6]=0 No ECDSA + # b[7]=0 No encrypted FW image + header[6] = 0x7 << 3 + if args.vtr2_V18 == True: + header[6] |= 0x02 + if args.vtr3_V18 == True: + header[6] |= 0x04 + + # SPI read command set same for both chips + header[7] = GetSpiReadCmdParameter(args) & 0xFF + + # bytes 0x08 - 0x0b + header[0x08:0x0C] = load_addr.to_bytes(4, byteorder="little") + # bytes 0x0c - 0x0f + header[0x0C:0x10] = payload_entry.to_bytes(4, byteorder="little") + # bytes 0x10 - 0x11 payload length in units of 128 bytes + + payload_units = int(payload_len // chip_dict["PAYLOAD_GRANULARITY"]) + assert payload_units < 0x10000, print( + "Payload too large: len={0} units={1}".format(payload_len, payload_units) + ) + + header[0x10:0x12] = payload_units.to_bytes(2, "little") + + # bytes 0x14 - 0x17 + header[0x14:0x18] = chip_dict["PAYLOAD_OFFSET"].to_bytes(4, "little") + + # MEC152x: Disable ECDSA and encryption + header[0x18] = 0 + + # header[0xB0:0xE0] = SHA384(header[0:0xB0]) + header[0xB0:0xE0] = hashlib.sha384(header[0:0xB0]).digest() + # When ECDSA authentication is disabled MCHP SPI image generator + # is filling the last 48 bytes of the Header with 0xff + header[-48:] = b"\xff" * 48 + + debug_print("After hash: len(header) = ", len(header)) + + return header + # # MEC152x 128-byte EC Info Block appended to @@ -311,8 +357,9 @@ def BuildHeader2(args, chip_dict, payload_len, load_addr, payload_entry): # byte 127 = customer current image revision # def GenEcInfoBlock(args, chip_dict): - ecinfo = bytearray(chip_dict["EC_INFO_BLK_SZ"]) - return ecinfo + ecinfo = bytearray(chip_dict["EC_INFO_BLK_SZ"]) + return ecinfo + # # Generate SPI FW image co-signature. @@ -325,7 +372,8 @@ def GenEcInfoBlock(args, chip_dict): # signature. # def GenCoSignature(args, chip_dict, payload): - return bytearray(b'\xff' * chip_dict["COSIG_SZ"]) + return bytearray(b"\xff" * chip_dict["COSIG_SZ"]) + # # Generate SPI FW Image trailer. @@ -336,22 +384,24 @@ def GenCoSignature(args, chip_dict, payload): # trailer[144:160] = 0xFF. Boot-ROM spec. says these bytes should be random. # Authentication & encryption are not used therefore random data # is not necessary. -def GenTrailer(args, chip_dict, payload, encryption_key_header, - ec_info_block, cosignature): +def GenTrailer( + args, chip_dict, payload, encryption_key_header, ec_info_block, cosignature +): trailer = bytearray(chip_dict["TAILER_PAD_BYTE"] * chip_dict["TRAILER_SZ"]) hasher = hashlib.sha384() hasher.update(payload) if ec_info_block != None: - hasher.update(ec_info_block) + hasher.update(ec_info_block) if encryption_key_header != None: - hasher.update(encryption_key_header) + hasher.update(encryption_key_header) if cosignature != None: - hasher.update(cosignature) + hasher.update(cosignature) trailer[0:48] = hasher.digest() - trailer[-16:] = 16 * b'\xff' + trailer[-16:] = 16 * b"\xff" return trailer + # MEC152xH supports two 32-bit Tags located at offsets 0x0 and 0x4 # in the SPI flash. # Tag format: @@ -362,16 +412,21 @@ def GenTrailer(args, chip_dict, payload, encryption_key_header, # to the same flash part. # def BuildTag(args): - tag = bytearray([(args.header_loc >> 8) & 0xff, - (args.header_loc >> 16) & 0xff, - (args.header_loc >> 24) & 0xff]) - tag.append(Crc8(0, tag)) - return tag + tag = bytearray( + [ + (args.header_loc >> 8) & 0xFF, + (args.header_loc >> 16) & 0xFF, + (args.header_loc >> 24) & 0xFF, + ] + ) + tag.append(Crc8(0, tag)) + return tag + def BuildTagFromHdrAddr(header_loc): - tag = bytearray([(header_loc >> 8) & 0xff, - (header_loc >> 16) & 0xff, - (header_loc >> 24) & 0xff]) + tag = bytearray( + [(header_loc >> 8) & 0xFF, (header_loc >> 16) & 0xFF, (header_loc >> 24) & 0xFF] + ) tag.append(Crc8(0, tag)) return tag @@ -388,12 +443,13 @@ def BuildTagFromHdrAddr(header_loc): # Output: # bytearray of length 4 def BuildFlashMap(secondSpiFlashBaseAddr): - flashmap = bytearray(4) - flashmap[0] = (secondSpiFlashBaseAddr >> 12) & 0xff - flashmap[1] = (secondSpiFlashBaseAddr >> 20) & 0xff - flashmap[2] = (secondSpiFlashBaseAddr >> 28) & 0xff - flashmap[3] = Crc8(0, flashmap) - return flashmap + flashmap = bytearray(4) + flashmap[0] = (secondSpiFlashBaseAddr >> 12) & 0xFF + flashmap[1] = (secondSpiFlashBaseAddr >> 20) & 0xFF + flashmap[2] = (secondSpiFlashBaseAddr >> 28) & 0xFF + flashmap[3] = Crc8(0, flashmap) + return flashmap + # # Creates temporary file for read/write @@ -404,21 +460,22 @@ def BuildFlashMap(secondSpiFlashBaseAddr): # Returns temporary file name # def PacklfwRoImage(rorw_file, loader_file, image_size): - """Create a temp file with the - first image_size bytes from the loader file and append bytes - from the rorw file. - return the filename""" - fo=tempfile.NamedTemporaryFile(delete=False) # Need to keep file around - with open(loader_file,'rb') as fin1: # read 4KB loader file - pro = fin1.read() - fo.write(pro) # write 4KB loader data to temp file - with open(rorw_file, 'rb') as fin: - ro = fin.read(image_size) - - fo.write(ro) - fo.close() - - return fo.name + """Create a temp file with the + first image_size bytes from the loader file and append bytes + from the rorw file. + return the filename""" + fo = tempfile.NamedTemporaryFile(delete=False) # Need to keep file around + with open(loader_file, "rb") as fin1: # read 4KB loader file + pro = fin1.read() + fo.write(pro) # write 4KB loader data to temp file + with open(rorw_file, "rb") as fin: + ro = fin.read(image_size) + + fo.write(ro) + fo.close() + + return fo.name + # # Generate a test EC_RW image of same size @@ -429,129 +486,184 @@ def PacklfwRoImage(rorw_file, loader_file, image_size): # process hash generation. # def gen_test_ecrw(pldrw): - debug_print("gen_test_ecrw: pldrw type =", type(pldrw)) - debug_print("len pldrw =", len(pldrw), " = ", hex(len(pldrw))) - cookie1_pos = pldrw.find(b'\x99\x88\x77\xce') - cookie2_pos = pldrw.find(b'\xdd\xbb\xaa\xce', cookie1_pos+4) - t = struct.unpack("<L", pldrw[cookie1_pos+0x24:cookie1_pos+0x28]) - size = t[0] - debug_print("EC_RW size =", size, " = ", hex(size)) - - debug_print("Found cookie1 at ", hex(cookie1_pos)) - debug_print("Found cookie2 at ", hex(cookie2_pos)) - - if cookie1_pos > 0 and cookie2_pos > cookie1_pos: - for i in range(0, cookie1_pos): - pldrw[i] = 0xA5 - for i in range(cookie2_pos+4, len(pldrw)): - pldrw[i] = 0xA5 - - with open("ec_RW_test.bin", "wb") as fecrw: - fecrw.write(pldrw[:size]) + debug_print("gen_test_ecrw: pldrw type =", type(pldrw)) + debug_print("len pldrw =", len(pldrw), " = ", hex(len(pldrw))) + cookie1_pos = pldrw.find(b"\x99\x88\x77\xce") + cookie2_pos = pldrw.find(b"\xdd\xbb\xaa\xce", cookie1_pos + 4) + t = struct.unpack("<L", pldrw[cookie1_pos + 0x24 : cookie1_pos + 0x28]) + size = t[0] + debug_print("EC_RW size =", size, " = ", hex(size)) + + debug_print("Found cookie1 at ", hex(cookie1_pos)) + debug_print("Found cookie2 at ", hex(cookie2_pos)) + + if cookie1_pos > 0 and cookie2_pos > cookie1_pos: + for i in range(0, cookie1_pos): + pldrw[i] = 0xA5 + for i in range(cookie2_pos + 4, len(pldrw)): + pldrw[i] = 0xA5 + + with open("ec_RW_test.bin", "wb") as fecrw: + fecrw.write(pldrw[:size]) + def parseargs(): - #TODO I commented this out. Why? - rpath = os.path.dirname(os.path.relpath(__file__)) - - parser = argparse.ArgumentParser() - parser.add_argument("-i", "--input", - help="EC binary to pack, usually ec.bin or ec.RO.flat.", - metavar="EC_BIN", default="ec.bin") - parser.add_argument("-o", "--output", - help="Output flash binary file", - metavar="EC_SPI_FLASH", default="ec.packed.bin") - parser.add_argument("--loader_file", - help="EC loader binary", - default="ecloader.bin") - parser.add_argument("-s", "--spi_size", type=int, - help="Size of the SPI flash in KB", - default=512) - parser.add_argument("-l", "--header_loc", type=int, - help="Location of header in SPI flash", - default=0x1000) - parser.add_argument("-r", "--rw_loc", type=int, - help="Start offset of EC_RW. Default is -1 meaning 1/2 flash size", - default=-1) - parser.add_argument("--spi_clock", type=int, - help="SPI clock speed. 8, 12, 24, or 48 MHz.", - default=24) - parser.add_argument("--spi_read_cmd", type=int, - help="SPI read command. 0x3, 0xB, or 0x3B.", - default=0xb) - parser.add_argument("--image_size", type=int, - help="Size of a single image. Default 220KB", - default=(220 * 1024)) - parser.add_argument("--test_spi", action='store_true', - help="Test SPI data integrity by adding CRC32 in last 4-bytes of RO/RW binaries", - default=False) - parser.add_argument("--test_ecrw", action='store_true', - help="Use fixed pattern for EC_RW but preserve image_data", - default=False) - parser.add_argument("--verbose", action='store_true', - help="Enable verbose output", - default=False) - parser.add_argument("--tag0_loc", type=int, - help="MEC152X TAG0 SPI offset", - default=0) - parser.add_argument("--tag1_loc", type=int, - help="MEC152X TAG1 SPI offset", - default=4) - parser.add_argument("--spi_drive_str", type=int, - help="Chip SPI drive strength in mA: 2, 4, 8, or 12", - default=4) - parser.add_argument("--spi_slew_fast", action='store_true', - help="SPI use fast slew rate. Default is False", - default=False) - parser.add_argument("--spi_cpol", type=int, - help="SPI clock polarity when idle. Defealt is 0(low)", - default=0) - parser.add_argument("--spi_cpha_mosi", type=int, - help="""SPI clock phase master drives data. + # TODO I commented this out. Why? + rpath = os.path.dirname(os.path.relpath(__file__)) + + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--input", + help="EC binary to pack, usually ec.bin or ec.RO.flat.", + metavar="EC_BIN", + default="ec.bin", + ) + parser.add_argument( + "-o", + "--output", + help="Output flash binary file", + metavar="EC_SPI_FLASH", + default="ec.packed.bin", + ) + parser.add_argument( + "--loader_file", help="EC loader binary", default="ecloader.bin" + ) + parser.add_argument( + "-s", "--spi_size", type=int, help="Size of the SPI flash in KB", default=512 + ) + parser.add_argument( + "-l", + "--header_loc", + type=int, + help="Location of header in SPI flash", + default=0x1000, + ) + parser.add_argument( + "-r", + "--rw_loc", + type=int, + help="Start offset of EC_RW. Default is -1 meaning 1/2 flash size", + default=-1, + ) + parser.add_argument( + "--spi_clock", + type=int, + help="SPI clock speed. 8, 12, 24, or 48 MHz.", + default=24, + ) + parser.add_argument( + "--spi_read_cmd", + type=int, + help="SPI read command. 0x3, 0xB, or 0x3B.", + default=0xB, + ) + parser.add_argument( + "--image_size", + type=int, + help="Size of a single image. Default 220KB", + default=(220 * 1024), + ) + parser.add_argument( + "--test_spi", + action="store_true", + help="Test SPI data integrity by adding CRC32 in last 4-bytes of RO/RW binaries", + default=False, + ) + parser.add_argument( + "--test_ecrw", + action="store_true", + help="Use fixed pattern for EC_RW but preserve image_data", + default=False, + ) + parser.add_argument( + "--verbose", action="store_true", help="Enable verbose output", default=False + ) + parser.add_argument( + "--tag0_loc", type=int, help="MEC152X TAG0 SPI offset", default=0 + ) + parser.add_argument( + "--tag1_loc", type=int, help="MEC152X TAG1 SPI offset", default=4 + ) + parser.add_argument( + "--spi_drive_str", + type=int, + help="Chip SPI drive strength in mA: 2, 4, 8, or 12", + default=4, + ) + parser.add_argument( + "--spi_slew_fast", + action="store_true", + help="SPI use fast slew rate. Default is False", + default=False, + ) + parser.add_argument( + "--spi_cpol", + type=int, + help="SPI clock polarity when idle. Defealt is 0(low)", + default=0, + ) + parser.add_argument( + "--spi_cpha_mosi", + type=int, + help="""SPI clock phase master drives data. 0=Data driven on active to inactive clock edge, 1=Data driven on inactive to active clock edge""", - default=0) - parser.add_argument("--spi_cpha_miso", type=int, - help="""SPI clock phase master samples data. + default=0, + ) + parser.add_argument( + "--spi_cpha_miso", + type=int, + help="""SPI clock phase master samples data. 0=Data sampled on inactive to active clock edge, 1=Data sampled on active to inactive clock edge""", - default=0) + default=0, + ) + + parser.add_argument( + "--vtr2_V18", + action="store_true", + help="Chip VTR2 rail is 1.8V. Default is False(3.3V)", + default=False, + ) - parser.add_argument("--vtr2_V18", action='store_true', - help="Chip VTR2 rail is 1.8V. Default is False(3.3V)", - default=False) + parser.add_argument( + "--vtr3_V18", + action="store_true", + help="Chip VTR3 rail is 1.8V. Default is False(3.3V)", + default=False, + ) - parser.add_argument("--vtr3_V18", action='store_true', - help="Chip VTR3 rail is 1.8V. Default is False(3.3V)", - default=False) + return parser.parse_args() - return parser.parse_args() def print_args(args): - debug_print("parsed arguments:") - debug_print(".input = ", args.input) - debug_print(".output = ", args.output) - debug_print(".loader_file = ", args.loader_file) - debug_print(".spi_size (KB) = ", hex(args.spi_size)) - debug_print(".image_size = ", hex(args.image_size)) - debug_print(".tag0_loc = ", hex(args.tag0_loc)) - debug_print(".tag1_loc = ", hex(args.tag1_loc)) - debug_print(".header_loc = ", hex(args.header_loc)) - if args.rw_loc < 0: - debug_print(".rw_loc = ", args.rw_loc) - else: - debug_print(".rw_loc = ", hex(args.rw_loc)) - debug_print(".spi_clock (MHz) = ", args.spi_clock) - debug_print(".spi_read_cmd = ", hex(args.spi_read_cmd)) - debug_print(".test_spi = ", args.test_spi) - debug_print(".test_ecrw = ", args.test_ecrw) - debug_print(".verbose = ", args.verbose) - debug_print(".spi_drive_str = ", args.spi_drive_str) - debug_print(".spi_slew_fast = ", args.spi_slew_fast) - debug_print(".spi_cpol = ", args.spi_cpol) - debug_print(".spi_cpha_mosi = ", args.spi_cpha_mosi) - debug_print(".spi_cpha_miso = ", args.spi_cpha_miso) - debug_print(".vtr2_V18 = ", args.vtr2_V18) - debug_print(".vtr3_V18 = ", args.vtr3_V18) + debug_print("parsed arguments:") + debug_print(".input = ", args.input) + debug_print(".output = ", args.output) + debug_print(".loader_file = ", args.loader_file) + debug_print(".spi_size (KB) = ", hex(args.spi_size)) + debug_print(".image_size = ", hex(args.image_size)) + debug_print(".tag0_loc = ", hex(args.tag0_loc)) + debug_print(".tag1_loc = ", hex(args.tag1_loc)) + debug_print(".header_loc = ", hex(args.header_loc)) + if args.rw_loc < 0: + debug_print(".rw_loc = ", args.rw_loc) + else: + debug_print(".rw_loc = ", hex(args.rw_loc)) + debug_print(".spi_clock (MHz) = ", args.spi_clock) + debug_print(".spi_read_cmd = ", hex(args.spi_read_cmd)) + debug_print(".test_spi = ", args.test_spi) + debug_print(".test_ecrw = ", args.test_ecrw) + debug_print(".verbose = ", args.verbose) + debug_print(".spi_drive_str = ", args.spi_drive_str) + debug_print(".spi_slew_fast = ", args.spi_slew_fast) + debug_print(".spi_cpol = ", args.spi_cpol) + debug_print(".spi_cpha_mosi = ", args.spi_cpha_mosi) + debug_print(".spi_cpha_miso = ", args.spi_cpha_miso) + debug_print(".vtr2_V18 = ", args.vtr2_V18) + debug_print(".vtr3_V18 = ", args.vtr3_V18) + # # Handle quiet mode build from Makefile @@ -589,215 +701,229 @@ def print_args(args): # || 48 * [0] # def main(): - global debug_print + global debug_print + + args = parseargs() + + if args.verbose: + debug_print = print - args = parseargs() - - if args.verbose: - debug_print = print + debug_print("Begin pack_ec_mec152x.py script") + + print_args(args) + + chip_dict = MEC152X_DICT + + # Boot-ROM requires header location aligned >= 256 bytes. + # CrOS EC flash image update code requires EC_RO/RW location to be aligned + # on a flash erase size boundary and EC_RO/RW size to be a multiple of + # the smallest flash erase block size. + # + assert (args.header_loc % SPI_ERASE_BLOCK_SIZE) == 0, ( + "Header location %d is not on a flash erase block boundary boundary" + % args.header_loc + ) + + max_image_size = CHIP_MAX_CODE_SRAM_KB - LFW_SIZE + if args.test_spi: + max_image_size -= 32 # SHA256 digest + + assert args.image_size > max_image_size, ( + "Image size exceeds maximum" % args.image_size + ) + + spi_size = args.spi_size * 1024 + debug_print("SPI Flash image size in bytes =", hex(spi_size)) + + # !!! IMPORTANT !!! + # These values MUST match chip/mchp/config_flash_layout.h + # defines. + # MEC152x Boot-ROM TAGs are at offset 0 and 4. + # lfw + EC_RO starts at beginning of second 4KB sector + # EC_RW starts at (flash size / 2) i.e. 0x40000 for a 512KB flash. + + spi_list = [] + + debug_print("args.input = ", args.input) + debug_print("args.loader_file = ", args.loader_file) + debug_print("args.image_size = ", hex(args.image_size)) + + rorofile = PacklfwRoImage(args.input, args.loader_file, args.image_size) + debug_print("Temporary file containing LFW + EC_RO is ", rorofile) + + lfw_ecro = GetPayload(rorofile, chip_dict["PAD_SIZE"]) + lfw_ecro_len = len(lfw_ecro) + debug_print("Padded LFW + EC_RO length = ", hex(lfw_ecro_len)) + + # SPI test mode compute CRC32 of EC_RO and store in last 4 bytes + if args.test_spi: + crc32_ecro = zlib.crc32(bytes(lfw_ecro[LFW_SIZE:-4])) + crc32_ecro_bytes = crc32_ecro.to_bytes(4, byteorder="little") + lfw_ecro[-4:] = crc32_ecro_bytes + debug_print("ecro len = ", hex(len(lfw_ecro) - LFW_SIZE)) + debug_print("CRC32(ecro-4) = ", hex(crc32_ecro)) + + # Reads entry point from offset 4 of file. + # This assumes binary has Cortex-M4 vector table at offset 0. + # 32-bit word at offset 0x0 initial stack pointer value + # 32-bit word at offset 0x4 address of reset handler + # NOTE: reset address will have bit[0]=1 to ensure thumb mode. + lfw_ecro_entry = GetEntryPoint(rorofile) + + # Chromebooks are not using MEC BootROM SPI header/payload authentication + # or payload encryption. In this case the header authentication signature + # is filled with the hash digest of the respective entity. + # BuildHeader2 computes the hash digest and stores it in the correct + # header location. + header = BuildHeader2(args, chip_dict, lfw_ecro_len, LOAD_ADDR, lfw_ecro_entry) + printByteArrayAsHex(header, "Header(lfw_ecro)") + + # If payload encryption used then encrypt payload and + # generate Payload Key Header. If encryption not used + # payload is not modified and the method returns None + encryption_key_header = EncryptPayload(args, chip_dict, lfw_ecro) + printByteArrayAsHex(encryption_key_header, "LFW + EC_RO encryption_key_header") + + ec_info_block = GenEcInfoBlock(args, chip_dict) + printByteArrayAsHex(ec_info_block, "EC Info Block") + + cosignature = GenCoSignature(args, chip_dict, lfw_ecro) + printByteArrayAsHex(cosignature, "LFW + EC_RO cosignature") + + trailer = GenTrailer( + args, chip_dict, lfw_ecro, encryption_key_header, ec_info_block, cosignature + ) + + printByteArrayAsHex(trailer, "LFW + EC_RO trailer") + + # Build TAG0. Set TAG1=TAG0 Boot-ROM is allowed to load EC-RO only. + tag0 = BuildTag(args) + tag1 = tag0 + + debug_print("Call to GetPayloadFromOffset") + debug_print("args.input = ", args.input) + debug_print("args.image_size = ", hex(args.image_size)) + + ecrw = GetPayloadFromOffset(args.input, args.image_size, chip_dict["PAD_SIZE"]) + debug_print("type(ecrw) is ", type(ecrw)) + debug_print("len(ecrw) is ", hex(len(ecrw))) + + # truncate to args.image_size + ecrw_len = len(ecrw) + if ecrw_len > args.image_size: + debug_print( + "Truncate EC_RW len={0:0x} to image_size={1:0x}".format( + ecrw_len, args.image_size + ) + ) + ecrw = ecrw[: args.image_size] + ecrw_len = len(ecrw) + + debug_print("len(EC_RW) = ", hex(ecrw_len)) + + # SPI test mode compute CRC32 of EC_RW and store in last 4 bytes + if args.test_spi: + crc32_ecrw = zlib.crc32(bytes(ecrw[0:-4])) + crc32_ecrw_bytes = crc32_ecrw.to_bytes(4, byteorder="little") + ecrw[-4:] = crc32_ecrw_bytes + debug_print("ecrw len = ", hex(len(ecrw))) + debug_print("CRC32(ecrw) = ", hex(crc32_ecrw)) + + # Assume FW layout is standard Cortex-M style with vector + # table at start of binary. + # 32-bit word at offset 0x0 = Initial stack pointer + # 32-bit word at offset 0x4 = Address of reset handler + ecrw_entry_tuple = struct.unpack_from("<I", ecrw, 4) + debug_print("ecrw_entry_tuple[0] = ", hex(ecrw_entry_tuple[0])) + + ecrw_entry = ecrw_entry_tuple[0] + debug_print("ecrw_entry = ", hex(ecrw_entry)) + + # Note: payload_rw is a bytearray therefore is mutable + if args.test_ecrw: + gen_test_ecrw(ecrw) + + os.remove(rorofile) # clean up the temp file + + # MEC152X Add TAG's + spi_list.append((args.tag0_loc, tag0, "tag0")) + spi_list.append((args.tag1_loc, tag1, "tag1")) + + # flashmap is non-zero only for systems with two external + # SPI flash chips. + flashmap = BuildFlashMap(0) + spi_list.append((8, flashmap, "flashmap")) + + # Boot-ROM SPI image header for LFW+EC-RO + spi_list.append((args.header_loc, header, "header(lfw + ro)")) + spi_list.append( + (args.header_loc + chip_dict["PAYLOAD_OFFSET"], lfw_ecro, "lfw_ecro") + ) + + offset = args.header_loc + chip_dict["PAYLOAD_OFFSET"] + lfw_ecro_len - debug_print("Begin pack_ec_mec152x.py script") - - print_args(args) - - chip_dict = MEC152X_DICT - - # Boot-ROM requires header location aligned >= 256 bytes. - # CrOS EC flash image update code requires EC_RO/RW location to be aligned - # on a flash erase size boundary and EC_RO/RW size to be a multiple of - # the smallest flash erase block size. - # - assert (args.header_loc % SPI_ERASE_BLOCK_SIZE) == 0, \ - "Header location %d is not on a flash erase block boundary boundary" % args.header_loc - - max_image_size = CHIP_MAX_CODE_SRAM_KB - LFW_SIZE - if args.test_spi: - max_image_size -= 32 # SHA256 digest - - assert args.image_size > max_image_size, \ - "Image size exceeds maximum" % args.image_size - - spi_size = args.spi_size * 1024 - debug_print("SPI Flash image size in bytes =", hex(spi_size)) - - # !!! IMPORTANT !!! - # These values MUST match chip/mchp/config_flash_layout.h - # defines. - # MEC152x Boot-ROM TAGs are at offset 0 and 4. - # lfw + EC_RO starts at beginning of second 4KB sector - # EC_RW starts at (flash size / 2) i.e. 0x40000 for a 512KB flash. - - spi_list = [] - - debug_print("args.input = ",args.input) - debug_print("args.loader_file = ",args.loader_file) - debug_print("args.image_size = ",hex(args.image_size)) - - rorofile=PacklfwRoImage(args.input, args.loader_file, args.image_size) - debug_print("Temporary file containing LFW + EC_RO is ", rorofile) - - lfw_ecro = GetPayload(rorofile, chip_dict["PAD_SIZE"]) - lfw_ecro_len = len(lfw_ecro) - debug_print("Padded LFW + EC_RO length = ", hex(lfw_ecro_len)) - - # SPI test mode compute CRC32 of EC_RO and store in last 4 bytes - if args.test_spi: - crc32_ecro = zlib.crc32(bytes(lfw_ecro[LFW_SIZE:-4])) - crc32_ecro_bytes = crc32_ecro.to_bytes(4, byteorder='little') - lfw_ecro[-4:] = crc32_ecro_bytes - debug_print("ecro len = ", hex(len(lfw_ecro) - LFW_SIZE)) - debug_print("CRC32(ecro-4) = ", hex(crc32_ecro)) - - # Reads entry point from offset 4 of file. - # This assumes binary has Cortex-M4 vector table at offset 0. - # 32-bit word at offset 0x0 initial stack pointer value - # 32-bit word at offset 0x4 address of reset handler - # NOTE: reset address will have bit[0]=1 to ensure thumb mode. - lfw_ecro_entry = GetEntryPoint(rorofile) - - # Chromebooks are not using MEC BootROM SPI header/payload authentication - # or payload encryption. In this case the header authentication signature - # is filled with the hash digest of the respective entity. - # BuildHeader2 computes the hash digest and stores it in the correct - # header location. - header = BuildHeader2(args, chip_dict, lfw_ecro_len, - LOAD_ADDR, lfw_ecro_entry) - printByteArrayAsHex(header, "Header(lfw_ecro)") - - # If payload encryption used then encrypt payload and - # generate Payload Key Header. If encryption not used - # payload is not modified and the method returns None - encryption_key_header = EncryptPayload(args, chip_dict, lfw_ecro) - printByteArrayAsHex(encryption_key_header, - "LFW + EC_RO encryption_key_header") - - ec_info_block = GenEcInfoBlock(args, chip_dict) - printByteArrayAsHex(ec_info_block, "EC Info Block") - - cosignature = GenCoSignature(args, chip_dict, lfw_ecro) - printByteArrayAsHex(cosignature, "LFW + EC_RO cosignature") - - trailer = GenTrailer(args, chip_dict, lfw_ecro, encryption_key_header, - ec_info_block, cosignature) - - printByteArrayAsHex(trailer, "LFW + EC_RO trailer") - - # Build TAG0. Set TAG1=TAG0 Boot-ROM is allowed to load EC-RO only. - tag0 = BuildTag(args) - tag1 = tag0 - - debug_print("Call to GetPayloadFromOffset") - debug_print("args.input = ", args.input) - debug_print("args.image_size = ", hex(args.image_size)) - - ecrw = GetPayloadFromOffset(args.input, args.image_size, - chip_dict["PAD_SIZE"]) - debug_print("type(ecrw) is ", type(ecrw)) - debug_print("len(ecrw) is ", hex(len(ecrw))) - - # truncate to args.image_size - ecrw_len = len(ecrw) - if ecrw_len > args.image_size: - debug_print("Truncate EC_RW len={0:0x} to image_size={1:0x}".format(ecrw_len,args.image_size)) - ecrw = ecrw[:args.image_size] - ecrw_len = len(ecrw) - - debug_print("len(EC_RW) = ", hex(ecrw_len)) - - # SPI test mode compute CRC32 of EC_RW and store in last 4 bytes - if args.test_spi: - crc32_ecrw = zlib.crc32(bytes(ecrw[0:-4])) - crc32_ecrw_bytes = crc32_ecrw.to_bytes(4, byteorder='little') - ecrw[-4:] = crc32_ecrw_bytes - debug_print("ecrw len = ", hex(len(ecrw))) - debug_print("CRC32(ecrw) = ", hex(crc32_ecrw)) - - # Assume FW layout is standard Cortex-M style with vector - # table at start of binary. - # 32-bit word at offset 0x0 = Initial stack pointer - # 32-bit word at offset 0x4 = Address of reset handler - ecrw_entry_tuple = struct.unpack_from('<I', ecrw, 4) - debug_print("ecrw_entry_tuple[0] = ", hex(ecrw_entry_tuple[0])) - - ecrw_entry = ecrw_entry_tuple[0] - debug_print("ecrw_entry = ", hex(ecrw_entry)) - - # Note: payload_rw is a bytearray therefore is mutable - if args.test_ecrw: - gen_test_ecrw(ecrw) - - os.remove(rorofile) # clean up the temp file - - # MEC152X Add TAG's - spi_list.append((args.tag0_loc, tag0, "tag0")) - spi_list.append((args.tag1_loc, tag1, "tag1")) - - # flashmap is non-zero only for systems with two external - # SPI flash chips. - flashmap = BuildFlashMap(0) - spi_list.append((8, flashmap, "flashmap")) - - # Boot-ROM SPI image header for LFW+EC-RO - spi_list.append((args.header_loc, header, "header(lfw + ro)")) - spi_list.append((args.header_loc + chip_dict["PAYLOAD_OFFSET"], lfw_ecro, - "lfw_ecro")) - - offset = args.header_loc + chip_dict["PAYLOAD_OFFSET"] + lfw_ecro_len - - if ec_info_block != None: - spi_list.append((offset, ec_info_block, "EC Info Block")) - offset += len(ec_info_block) - - if cosignature != None: - spi_list.append((offset, cosignature, "ECRO Cosignature")) - offset += len(cosignature) - - if trailer != None: - spi_list.append((offset, trailer, "ECRO Trailer")) - offset += len(trailer) - - # EC_RW location - rw_offset = int(spi_size // 2) - if args.rw_loc >= 0: - rw_offset = args.rw_loc - - debug_print("rw_offset = 0x{0:08x}".format(rw_offset)) - - assert rw_offset >= offset, \ - print("""Offset of EC_RW at {0:08x} overlaps end - of EC_RO at {0:08x}""".format(rw_offset, offset)) - - spi_list.append((rw_offset, ecrw, "ecrw")) - offset = rw_offset + len(ecrw) - - spi_list = sorted(spi_list) - - dumpsects(spi_list) - - # - # MEC152X Boot-ROM locates TAG0/1 at SPI offset 0 - # instead of end of SPI. - # - with open(args.output, 'wb') as f: - debug_print("Write spi list to file", args.output) - addr = 0 - for s in spi_list: - if addr < s[0]: - debug_print("Offset ",hex(addr)," Length", hex(s[0]-addr), - "fill with 0xff") - f.write(b'\xff' * (s[0] - addr)) - addr = s[0] - debug_print("Offset ",hex(addr), " Length", hex(len(s[1])), "write data") - - f.write(s[1]) - addr += len(s[1]) - - if addr < spi_size: - debug_print("Offset ",hex(addr), " Length", hex(spi_size - addr), - "fill with 0xff") - f.write(b'\xff' * (spi_size - addr)) - - f.flush() + if ec_info_block != None: + spi_list.append((offset, ec_info_block, "EC Info Block")) + offset += len(ec_info_block) -if __name__ == '__main__': - main() + if cosignature != None: + spi_list.append((offset, cosignature, "ECRO Cosignature")) + offset += len(cosignature) + + if trailer != None: + spi_list.append((offset, trailer, "ECRO Trailer")) + offset += len(trailer) + + # EC_RW location + rw_offset = int(spi_size // 2) + if args.rw_loc >= 0: + rw_offset = args.rw_loc + + debug_print("rw_offset = 0x{0:08x}".format(rw_offset)) + + assert rw_offset >= offset, print( + """Offset of EC_RW at {0:08x} overlaps end + of EC_RO at {0:08x}""".format( + rw_offset, offset + ) + ) + + spi_list.append((rw_offset, ecrw, "ecrw")) + offset = rw_offset + len(ecrw) + + spi_list = sorted(spi_list) + + dumpsects(spi_list) + + # + # MEC152X Boot-ROM locates TAG0/1 at SPI offset 0 + # instead of end of SPI. + # + with open(args.output, "wb") as f: + debug_print("Write spi list to file", args.output) + addr = 0 + for s in spi_list: + if addr < s[0]: + debug_print( + "Offset ", hex(addr), " Length", hex(s[0] - addr), "fill with 0xff" + ) + f.write(b"\xff" * (s[0] - addr)) + addr = s[0] + debug_print( + "Offset ", hex(addr), " Length", hex(len(s[1])), "write data" + ) + + f.write(s[1]) + addr += len(s[1]) + + if addr < spi_size: + debug_print( + "Offset ", hex(addr), " Length", hex(spi_size - addr), "fill with 0xff" + ) + f.write(b"\xff" * (spi_size - addr)) + + f.flush() + + +if __name__ == "__main__": + main() diff --git a/chip/mchp/util/pack_ec_mec172x.py b/chip/mchp/util/pack_ec_mec172x.py index 32747d3d9a..bd5ff6edba 100755 --- a/chip/mchp/util/pack_ec_mec172x.py +++ b/chip/mchp/util/pack_ec_mec172x.py @@ -16,7 +16,7 @@ import os import struct import subprocess import tempfile -import zlib # CRC32 +import zlib # CRC32 # MEC172x has 416KB SRAM from 0xC0000 - 0x127FFF # SRAM is divided into contiguous CODE & DATA @@ -28,8 +28,8 @@ import zlib # CRC32 # SPI_ERASE_BLOCK_SIZE = 0x1000 SPI_CLOCK_LIST = [48, 24, 16, 12, 96] -SPI_READ_CMD_LIST = [0x3, 0xb, 0x3b, 0x6b] -SPI_DRIVE_STR_DICT = {2:0, 4:1, 8:2, 12:3} +SPI_READ_CMD_LIST = [0x3, 0xB, 0x3B, 0x6B] +SPI_DRIVE_STR_DICT = {2: 0, 4: 1, 8: 2, 12: 3} # Maximum EC_RO/EC_RW code size is based upon SPI flash erase # sector size, MEC172x Boot-ROM TAG, Header, Footer. # SPI Offset Description @@ -38,118 +38,156 @@ SPI_DRIVE_STR_DICT = {2:0, 4:1, 8:2, 12:3} # 0x1140 - 0x213F 4KB LFW # 0x2040 - 0x3EFFF # 0x3F000 - 0x3FFFF BootROM EC_INFO_BLK || COSIG || ENCR_KEY_HDR(optional) || TRAILER -CHIP_MAX_CODE_SRAM_KB = (256 - 12) +CHIP_MAX_CODE_SRAM_KB = 256 - 12 MEC172X_DICT = { - "LFW_SIZE": 0x1000, - "LOAD_ADDR": 0xC0000, - "TAG_SIZE": 4, - "KEY_BLOB_SIZE": 1584, - "HEADER_SIZE":0x140, - "HEADER_VER":0x03, - "PAYLOAD_GRANULARITY":128, - "PAYLOAD_PAD_BYTE":b'\xff', - "EC_INFO_BLK_SZ":128, - "ENCR_KEY_HDR_SZ":128, - "COSIG_SZ":96, - "TRAILER_SZ":160, - "TAILER_PAD_BYTE":b'\xff', - "PAD_SIZE":128 - } - -CRC_TABLE = [0x00, 0x07, 0x0e, 0x09, 0x1c, 0x1b, 0x12, 0x15, - 0x38, 0x3f, 0x36, 0x31, 0x24, 0x23, 0x2a, 0x2d] + "LFW_SIZE": 0x1000, + "LOAD_ADDR": 0xC0000, + "TAG_SIZE": 4, + "KEY_BLOB_SIZE": 1584, + "HEADER_SIZE": 0x140, + "HEADER_VER": 0x03, + "PAYLOAD_GRANULARITY": 128, + "PAYLOAD_PAD_BYTE": b"\xff", + "EC_INFO_BLK_SZ": 128, + "ENCR_KEY_HDR_SZ": 128, + "COSIG_SZ": 96, + "TRAILER_SZ": 160, + "TAILER_PAD_BYTE": b"\xff", + "PAD_SIZE": 128, +} + +CRC_TABLE = [ + 0x00, + 0x07, + 0x0E, + 0x09, + 0x1C, + 0x1B, + 0x12, + 0x15, + 0x38, + 0x3F, + 0x36, + 0x31, + 0x24, + 0x23, + 0x2A, + 0x2D, +] + def mock_print(*args, **kwargs): - pass + pass + debug_print = mock_print # Debug helper routine def dumpsects(spi_list): - debug_print("spi_list has {0} entries".format(len(spi_list))) - for s in spi_list: - debug_print("0x{0:x} 0x{1:x} {2:s}".format(s[0],len(s[1]),s[2])) + debug_print("spi_list has {0} entries".format(len(spi_list))) + for s in spi_list: + debug_print("0x{0:x} 0x{1:x} {2:s}".format(s[0], len(s[1]), s[2])) + def printByteArrayAsHex(ba, title): - debug_print(title,"= ") - if ba == None: - debug_print("None") - return - - count = 0 - for b in ba: - count = count + 1 - debug_print("0x{0:02x}, ".format(b),end="") - if (count % 8) == 0: - debug_print("") - debug_print("") + debug_print(title, "= ") + if ba == None: + debug_print("None") + return + + count = 0 + for b in ba: + count = count + 1 + debug_print("0x{0:02x}, ".format(b), end="") + if (count % 8) == 0: + debug_print("") + debug_print("") + def Crc8(crc, data): - """Update CRC8 value.""" - for v in data: - crc = ((crc << 4) & 0xff) ^ (CRC_TABLE[(crc >> 4) ^ (v >> 4)]); - crc = ((crc << 4) & 0xff) ^ (CRC_TABLE[(crc >> 4) ^ (v & 0xf)]); - return crc ^ 0x55 + """Update CRC8 value.""" + for v in data: + crc = ((crc << 4) & 0xFF) ^ (CRC_TABLE[(crc >> 4) ^ (v >> 4)]) + crc = ((crc << 4) & 0xFF) ^ (CRC_TABLE[(crc >> 4) ^ (v & 0xF)]) + return crc ^ 0x55 + def GetEntryPoint(payload_file): - """Read entry point from payload EC image.""" - with open(payload_file, 'rb') as f: - f.seek(4) - s = f.read(4) - return int.from_bytes(s, byteorder='little') + """Read entry point from payload EC image.""" + with open(payload_file, "rb") as f: + f.seek(4) + s = f.read(4) + return int.from_bytes(s, byteorder="little") + def GetPayloadFromOffset(payload_file, offset, padsize): - """Read payload and pad it to padsize.""" - with open(payload_file, 'rb') as f: - f.seek(offset) - payload = bytearray(f.read()) - rem_len = len(payload) % padsize - debug_print("GetPayload: padsize={0:0x} len(payload)={1:0x} rem={2:0x}".format(padsize,len(payload),rem_len)) + """Read payload and pad it to padsize.""" + with open(payload_file, "rb") as f: + f.seek(offset) + payload = bytearray(f.read()) + rem_len = len(payload) % padsize + debug_print( + "GetPayload: padsize={0:0x} len(payload)={1:0x} rem={2:0x}".format( + padsize, len(payload), rem_len + ) + ) - if rem_len: - payload += PAYLOAD_PAD_BYTE * (padsize - rem_len) - debug_print("GetPayload: Added {0} padding bytes".format(padsize - rem_len)) + if rem_len: + payload += PAYLOAD_PAD_BYTE * (padsize - rem_len) + debug_print("GetPayload: Added {0} padding bytes".format(padsize - rem_len)) + + return payload - return payload def GetPayload(payload_file, padsize): - """Read payload and pad it to padsize""" - return GetPayloadFromOffset(payload_file, 0, padsize) + """Read payload and pad it to padsize""" + return GetPayloadFromOffset(payload_file, 0, padsize) + def GetPublicKey(pem_file): - """Extract public exponent and modulus from PEM file.""" - result = subprocess.run(['openssl', 'rsa', '-in', pem_file, '-text', - '-noout'], stdout=subprocess.PIPE, encoding='utf-8') - modulus_raw = [] - in_modulus = False - for line in result.stdout.splitlines(): - if line.startswith('modulus'): - in_modulus = True - elif not line.startswith(' '): - in_modulus = False - elif in_modulus: - modulus_raw.extend(line.strip().strip(':').split(':')) - if line.startswith('publicExponent'): - exp = int(line.split(' ')[1], 10) - modulus_raw.reverse() - modulus = bytearray((int(x, 16) for x in modulus_raw[:256])) - return struct.pack('<Q', exp), modulus + """Extract public exponent and modulus from PEM file.""" + result = subprocess.run( + ["openssl", "rsa", "-in", pem_file, "-text", "-noout"], + stdout=subprocess.PIPE, + encoding="utf-8", + ) + modulus_raw = [] + in_modulus = False + for line in result.stdout.splitlines(): + if line.startswith("modulus"): + in_modulus = True + elif not line.startswith(" "): + in_modulus = False + elif in_modulus: + modulus_raw.extend(line.strip().strip(":").split(":")) + if line.startswith("publicExponent"): + exp = int(line.split(" ")[1], 10) + modulus_raw.reverse() + modulus = bytearray((int(x, 16) for x in modulus_raw[:256])) + return struct.pack("<Q", exp), modulus + def GetSpiClockParameter(args): - assert args.spi_clock in SPI_CLOCK_LIST, \ - "Unsupported SPI clock speed %d MHz" % args.spi_clock - return SPI_CLOCK_LIST.index(args.spi_clock) + assert args.spi_clock in SPI_CLOCK_LIST, ( + "Unsupported SPI clock speed %d MHz" % args.spi_clock + ) + return SPI_CLOCK_LIST.index(args.spi_clock) + def GetSpiReadCmdParameter(args): - assert args.spi_read_cmd in SPI_READ_CMD_LIST, \ - "Unsupported SPI read command 0x%x" % args.spi_read_cmd - return SPI_READ_CMD_LIST.index(args.spi_read_cmd) + assert args.spi_read_cmd in SPI_READ_CMD_LIST, ( + "Unsupported SPI read command 0x%x" % args.spi_read_cmd + ) + return SPI_READ_CMD_LIST.index(args.spi_read_cmd) + def GetEncodedSpiDriveStrength(args): - assert args.spi_drive_str in SPI_DRIVE_STR_DICT, \ - "Unsupported SPI drive strength %d mA" % args.spi_drive_str - return SPI_DRIVE_STR_DICT.get(args.spi_drive_str) + assert args.spi_drive_str in SPI_DRIVE_STR_DICT, ( + "Unsupported SPI drive strength %d mA" % args.spi_drive_str + ) + return SPI_DRIVE_STR_DICT.get(args.spi_drive_str) + # Return 0=Slow slew rate or 1=Fast slew rate def GetSpiSlewRate(args): @@ -157,12 +195,14 @@ def GetSpiSlewRate(args): return 1 return 0 + # Return SPI CPOL = 0 or 1 def GetSpiCpol(args): if args.spi_cpol == 0: return 0 return 1 + # Return SPI CPHA_MOSI # 0 = SPI Master drives data is stable on inactive to clock edge # 1 = SPI Master drives data is stable on active to inactive clock edge @@ -171,6 +211,7 @@ def GetSpiCphaMosi(args): return 0 return 1 + # Return SPI CPHA_MISO 0 or 1 # 0 = SPI Master samples data on inactive to active clock edge # 1 = SPI Master samples data on active to inactive clock edge @@ -179,8 +220,10 @@ def GetSpiCphaMiso(args): return 0 return 1 + def PadZeroTo(data, size): - data.extend(b'\0' * (size - len(data))) + data.extend(b"\0" * (size - len(data))) + # # Boot-ROM SPI image encryption not used with Chromebooks @@ -188,6 +231,7 @@ def PadZeroTo(data, size): def EncryptPayload(args, chip_dict, payload): return None + # # Build SPI image header for MEC172x # MEC172x image header size = 320(0x140) bytes @@ -262,89 +306,94 @@ def EncryptPayload(args, chip_dict, payload): # header[0x110:0x140] = Header ECDSA-384 signature y-coor. = 0 Auth. disabled # def BuildHeader2(args, chip_dict, payload_len, load_addr, payload_entry): - header_size = chip_dict["HEADER_SIZE"] - - # allocate zero filled header - header = bytearray(b'\x00' * header_size) - debug_print("len(header) = ", len(header)) - - # Identifier and header version - header[0:4] = b'PHCM' - header[4] = chip_dict["HEADER_VER"] - - # SPI frequency, drive strength, CPOL/CPHA encoding same for both chips - spiFreqIndex = GetSpiClockParameter(args) - if spiFreqIndex > 3: - header[6] |= 0x01 - else: - header[5] = spiFreqIndex - - header[5] |= ((GetEncodedSpiDriveStrength(args) & 0x03) << 2) - header[5] |= ((GetSpiSlewRate(args) & 0x01) << 4) - header[5] |= ((GetSpiCpol(args) & 0x01) << 5) - header[5] |= ((GetSpiCphaMosi(args) & 0x01) << 6) - header[5] |= ((GetSpiCphaMiso(args) & 0x01) << 7) - - # header[6] - # b[0] value set above - # b[2:1] = 00b, b[5:3]=111b - # b[7]=0 No encryption of FW payload - header[6] |= 0x7 << 3 - - # SPI read command set same for both chips - header[7] = GetSpiReadCmdParameter(args) & 0xFF - - # bytes 0x08 - 0x0b - header[0x08:0x0C] = load_addr.to_bytes(4, byteorder='little') - # bytes 0x0c - 0x0f - header[0x0C:0x10] = payload_entry.to_bytes(4, byteorder='little') - - # bytes 0x10 - 0x11 payload length in units of 128 bytes - assert payload_len % chip_dict["PAYLOAD_GRANULARITY"] == 0, \ - print("Payload size not a multiple of {0}".format(chip_dict["PAYLOAD_GRANULARITY"])) - - payload_units = int(payload_len // chip_dict["PAYLOAD_GRANULARITY"]) - assert payload_units < 0x10000, \ - print("Payload too large: len={0} units={1}".format(payload_len, payload_units)) - - header[0x10:0x12] = payload_units.to_bytes(2, 'little') - - # bytes 0x14 - 0x17 TODO offset from start of payload to FW payload to be - # loaded by Boot-ROM. We ask Boot-ROM to load (LFW || EC_RO). - # LFW location provided on the command line. - assert (args.lfw_loc % 4096 == 0), \ - print("LFW location not on a 4KB boundary! 0x{0:0x}".format(args.lfw_loc)) - - assert args.lfw_loc >= (args.header_loc + chip_dict["HEADER_SIZE"]), \ - print("LFW location not greater than header location + header size") - - lfw_ofs = args.lfw_loc - args.header_loc - header[0x14:0x18] = lfw_ofs.to_bytes(4, 'little') - - # MEC172x: authentication key select. Authentication not used, set to 0. - header[0x18] = 0 - - # header[0x19], header[0x20:0x28] - # header[0x1A:0x20] reserved 0 - # MEC172x: supports SPI flash devices with drive strength settings - # TODO leave these fields at 0 for now. We must add 6 command line - # arguments. - - # header[0x28:0x48] reserve can be any value - # header[0x48:0x50] Customer use. TODO - # authentication disabled, leave these 0. - # header[0x50:0x80] ECDSA P384 Authentication Public key Rx - # header[0x80:0xB0] ECDSA P384 Authentication Public key Ry - - # header[0xB0:0xE0] = SHA384(header[0:0xB0]) - header[0xB0:0xE0] = hashlib.sha384(header[0:0xB0]).digest() - # When ECDSA authentication is disabled MCHP SPI image generator - # is filling the last 48 bytes of the Header with 0xff - header[-48:] = b'\xff' * 48 - - debug_print("After hash: len(header) = ", len(header)) - - return header + header_size = chip_dict["HEADER_SIZE"] + + # allocate zero filled header + header = bytearray(b"\x00" * header_size) + debug_print("len(header) = ", len(header)) + + # Identifier and header version + header[0:4] = b"PHCM" + header[4] = chip_dict["HEADER_VER"] + + # SPI frequency, drive strength, CPOL/CPHA encoding same for both chips + spiFreqIndex = GetSpiClockParameter(args) + if spiFreqIndex > 3: + header[6] |= 0x01 + else: + header[5] = spiFreqIndex + + header[5] |= (GetEncodedSpiDriveStrength(args) & 0x03) << 2 + header[5] |= (GetSpiSlewRate(args) & 0x01) << 4 + header[5] |= (GetSpiCpol(args) & 0x01) << 5 + header[5] |= (GetSpiCphaMosi(args) & 0x01) << 6 + header[5] |= (GetSpiCphaMiso(args) & 0x01) << 7 + + # header[6] + # b[0] value set above + # b[2:1] = 00b, b[5:3]=111b + # b[7]=0 No encryption of FW payload + header[6] |= 0x7 << 3 + + # SPI read command set same for both chips + header[7] = GetSpiReadCmdParameter(args) & 0xFF + + # bytes 0x08 - 0x0b + header[0x08:0x0C] = load_addr.to_bytes(4, byteorder="little") + # bytes 0x0c - 0x0f + header[0x0C:0x10] = payload_entry.to_bytes(4, byteorder="little") + + # bytes 0x10 - 0x11 payload length in units of 128 bytes + assert payload_len % chip_dict["PAYLOAD_GRANULARITY"] == 0, print( + "Payload size not a multiple of {0}".format(chip_dict["PAYLOAD_GRANULARITY"]) + ) + + payload_units = int(payload_len // chip_dict["PAYLOAD_GRANULARITY"]) + assert payload_units < 0x10000, print( + "Payload too large: len={0} units={1}".format(payload_len, payload_units) + ) + + header[0x10:0x12] = payload_units.to_bytes(2, "little") + + # bytes 0x14 - 0x17 TODO offset from start of payload to FW payload to be + # loaded by Boot-ROM. We ask Boot-ROM to load (LFW || EC_RO). + # LFW location provided on the command line. + assert args.lfw_loc % 4096 == 0, print( + "LFW location not on a 4KB boundary! 0x{0:0x}".format(args.lfw_loc) + ) + + assert args.lfw_loc >= (args.header_loc + chip_dict["HEADER_SIZE"]), print( + "LFW location not greater than header location + header size" + ) + + lfw_ofs = args.lfw_loc - args.header_loc + header[0x14:0x18] = lfw_ofs.to_bytes(4, "little") + + # MEC172x: authentication key select. Authentication not used, set to 0. + header[0x18] = 0 + + # header[0x19], header[0x20:0x28] + # header[0x1A:0x20] reserved 0 + # MEC172x: supports SPI flash devices with drive strength settings + # TODO leave these fields at 0 for now. We must add 6 command line + # arguments. + + # header[0x28:0x48] reserve can be any value + # header[0x48:0x50] Customer use. TODO + # authentication disabled, leave these 0. + # header[0x50:0x80] ECDSA P384 Authentication Public key Rx + # header[0x80:0xB0] ECDSA P384 Authentication Public key Ry + + # header[0xB0:0xE0] = SHA384(header[0:0xB0]) + header[0xB0:0xE0] = hashlib.sha384(header[0:0xB0]).digest() + # When ECDSA authentication is disabled MCHP SPI image generator + # is filling the last 48 bytes of the Header with 0xff + header[-48:] = b"\xff" * 48 + + debug_print("After hash: len(header) = ", len(header)) + + return header + # # MEC172x 128-byte EC Info Block appended to end of padded FW binary. @@ -361,9 +410,10 @@ def BuildHeader2(args, chip_dict, payload_len, load_addr, payload_entry): # byte[0x7f] = current imeage revision # def GenEcInfoBlock(args, chip_dict): - # ecinfo = bytearray([0xff] * chip_dict["EC_INFO_BLK_SZ"]) - ecinfo = bytearray(chip_dict["EC_INFO_BLK_SZ"]) - return ecinfo + # ecinfo = bytearray([0xff] * chip_dict["EC_INFO_BLK_SZ"]) + ecinfo = bytearray(chip_dict["EC_INFO_BLK_SZ"]) + return ecinfo + # # Generate SPI FW image co-signature. @@ -376,7 +426,8 @@ def GenEcInfoBlock(args, chip_dict): # signature. # def GenCoSignature(args, chip_dict, payload): - return bytearray(b'\xff' * chip_dict["COSIG_SZ"]) + return bytearray(b"\xff" * chip_dict["COSIG_SZ"]) + # # Generate SPI FW Image trailer. @@ -387,27 +438,33 @@ def GenCoSignature(args, chip_dict, payload): # trailer[144:160] = 0xFF. Boot-ROM spec. says these bytes should be random. # Authentication & encryption are not used therefore random data # is not necessary. -def GenTrailer(args, chip_dict, payload, encryption_key_header, - ec_info_block, cosignature): +def GenTrailer( + args, chip_dict, payload, encryption_key_header, ec_info_block, cosignature +): debug_print("GenTrailer SHA384 computation") trailer = bytearray(chip_dict["TAILER_PAD_BYTE"] * chip_dict["TRAILER_SZ"]) hasher = hashlib.sha384() hasher.update(payload) debug_print(" Update: payload len=0x{0:0x}".format(len(payload))) if ec_info_block != None: - hasher.update(ec_info_block) - debug_print(" Update: ec_info_block len=0x{0:0x}".format(len(ec_info_block))) + hasher.update(ec_info_block) + debug_print(" Update: ec_info_block len=0x{0:0x}".format(len(ec_info_block))) if encryption_key_header != None: - hasher.update(encryption_key_header) - debug_print(" Update: encryption_key_header len=0x{0:0x}".format(len(encryption_key_header))) + hasher.update(encryption_key_header) + debug_print( + " Update: encryption_key_header len=0x{0:0x}".format( + len(encryption_key_header) + ) + ) if cosignature != None: - hasher.update(cosignature) - debug_print(" Update: cosignature len=0x{0:0x}".format(len(cosignature))) + hasher.update(cosignature) + debug_print(" Update: cosignature len=0x{0:0x}".format(len(cosignature))) trailer[0:48] = hasher.digest() - trailer[-16:] = 16 * b'\xff' + trailer[-16:] = 16 * b"\xff" return trailer + # MEC172x supports two 32-bit Tags located at offsets 0x0 and 0x4 # in the SPI flash. # Tag format: @@ -418,16 +475,21 @@ def GenTrailer(args, chip_dict, payload, encryption_key_header, # to the same flash part. # def BuildTag(args): - tag = bytearray([(args.header_loc >> 8) & 0xff, - (args.header_loc >> 16) & 0xff, - (args.header_loc >> 24) & 0xff]) - tag.append(Crc8(0, tag)) - return tag + tag = bytearray( + [ + (args.header_loc >> 8) & 0xFF, + (args.header_loc >> 16) & 0xFF, + (args.header_loc >> 24) & 0xFF, + ] + ) + tag.append(Crc8(0, tag)) + return tag + def BuildTagFromHdrAddr(header_loc): - tag = bytearray([(header_loc >> 8) & 0xff, - (header_loc >> 16) & 0xff, - (header_loc >> 24) & 0xff]) + tag = bytearray( + [(header_loc >> 8) & 0xFF, (header_loc >> 16) & 0xFF, (header_loc >> 24) & 0xFF] + ) tag.append(Crc8(0, tag)) return tag @@ -444,12 +506,13 @@ def BuildTagFromHdrAddr(header_loc): # Output: # bytearray of length 4 def BuildFlashMap(secondSpiFlashBaseAddr): - flashmap = bytearray(4) - flashmap[0] = (secondSpiFlashBaseAddr >> 12) & 0xff - flashmap[1] = (secondSpiFlashBaseAddr >> 20) & 0xff - flashmap[2] = (secondSpiFlashBaseAddr >> 28) & 0xff - flashmap[3] = Crc8(0, flashmap) - return flashmap + flashmap = bytearray(4) + flashmap[0] = (secondSpiFlashBaseAddr >> 12) & 0xFF + flashmap[1] = (secondSpiFlashBaseAddr >> 20) & 0xFF + flashmap[2] = (secondSpiFlashBaseAddr >> 28) & 0xFF + flashmap[3] = Crc8(0, flashmap) + return flashmap + # # Creates temporary file for read/write @@ -460,21 +523,22 @@ def BuildFlashMap(secondSpiFlashBaseAddr): # Returns temporary file name # def PacklfwRoImage(rorw_file, loader_file, image_size): - """Create a temp file with the - first image_size bytes from the loader file and append bytes - from the rorw file. - return the filename""" - fo=tempfile.NamedTemporaryFile(delete=False) # Need to keep file around - with open(loader_file,'rb') as fin1: # read 4KB loader file - pro = fin1.read() - fo.write(pro) # write 4KB loader data to temp file - with open(rorw_file, 'rb') as fin: - ro = fin.read(image_size) - - fo.write(ro) - fo.close() - - return fo.name + """Create a temp file with the + first image_size bytes from the loader file and append bytes + from the rorw file. + return the filename""" + fo = tempfile.NamedTemporaryFile(delete=False) # Need to keep file around + with open(loader_file, "rb") as fin1: # read 4KB loader file + pro = fin1.read() + fo.write(pro) # write 4KB loader data to temp file + with open(rorw_file, "rb") as fin: + ro = fin.read(image_size) + + fo.write(ro) + fo.close() + + return fo.name + # # Generate a test EC_RW image of same size @@ -485,136 +549,193 @@ def PacklfwRoImage(rorw_file, loader_file, image_size): # process hash generation. # def gen_test_ecrw(pldrw): - debug_print("gen_test_ecrw: pldrw type =", type(pldrw)) - debug_print("len pldrw =", len(pldrw), " = ", hex(len(pldrw))) - cookie1_pos = pldrw.find(b'\x99\x88\x77\xce') - cookie2_pos = pldrw.find(b'\xdd\xbb\xaa\xce', cookie1_pos+4) - t = struct.unpack("<L", pldrw[cookie1_pos+0x24:cookie1_pos+0x28]) - size = t[0] - debug_print("EC_RW size =", size, " = ", hex(size)) - - debug_print("Found cookie1 at ", hex(cookie1_pos)) - debug_print("Found cookie2 at ", hex(cookie2_pos)) - - if cookie1_pos > 0 and cookie2_pos > cookie1_pos: - for i in range(0, cookie1_pos): - pldrw[i] = 0xA5 - for i in range(cookie2_pos+4, len(pldrw)): - pldrw[i] = 0xA5 - - with open("ec_RW_test.bin", "wb") as fecrw: - fecrw.write(pldrw[:size]) + debug_print("gen_test_ecrw: pldrw type =", type(pldrw)) + debug_print("len pldrw =", len(pldrw), " = ", hex(len(pldrw))) + cookie1_pos = pldrw.find(b"\x99\x88\x77\xce") + cookie2_pos = pldrw.find(b"\xdd\xbb\xaa\xce", cookie1_pos + 4) + t = struct.unpack("<L", pldrw[cookie1_pos + 0x24 : cookie1_pos + 0x28]) + size = t[0] + debug_print("EC_RW size =", size, " = ", hex(size)) + + debug_print("Found cookie1 at ", hex(cookie1_pos)) + debug_print("Found cookie2 at ", hex(cookie2_pos)) + + if cookie1_pos > 0 and cookie2_pos > cookie1_pos: + for i in range(0, cookie1_pos): + pldrw[i] = 0xA5 + for i in range(cookie2_pos + 4, len(pldrw)): + pldrw[i] = 0xA5 + + with open("ec_RW_test.bin", "wb") as fecrw: + fecrw.write(pldrw[:size]) + def parseargs(): - rpath = os.path.dirname(os.path.relpath(__file__)) - - parser = argparse.ArgumentParser() - parser.add_argument("-i", "--input", - help="EC binary to pack, usually ec.bin or ec.RO.flat.", - metavar="EC_BIN", default="ec.bin") - parser.add_argument("-o", "--output", - help="Output flash binary file", - metavar="EC_SPI_FLASH", default="ec.packed.bin") - parser.add_argument("--loader_file", - help="EC loader binary", - default="ecloader.bin") - parser.add_argument("--load_addr", type=int, - help="EC SRAM load address", - default=0xC0000) - parser.add_argument("-s", "--spi_size", type=int, - help="Size of the SPI flash in KB", - default=512) - parser.add_argument("-l", "--header_loc", type=int, - help="Location of header in SPI flash. Must be on a 256 byte boundary", - default=0x0100) - parser.add_argument("--lfw_loc", type=int, - help="Location of LFW in SPI flash. Must be on a 4KB boundary", - default=0x1000) - parser.add_argument("--lfw_size", type=int, - help="LFW size in bytes", - default=0x1000) - parser.add_argument("-r", "--rw_loc", type=int, - help="Start offset of EC_RW. Default is -1 meaning 1/2 flash size", - default=-1) - parser.add_argument("--spi_clock", type=int, - help="SPI clock speed. 8, 12, 24, or 48 MHz.", - default=24) - parser.add_argument("--spi_read_cmd", type=int, - help="SPI read command. 0x3, 0xB, 0x3B, or 0x6B.", - default=0xb) - parser.add_argument("--image_size", type=int, - help="Size of a single image. Default 244KB", - default=(244 * 1024)) - parser.add_argument("--test_spi", action='store_true', - help="Test SPI data integrity by adding CRC32 in last 4-bytes of RO/RW binaries", - default=False) - parser.add_argument("--test_ecrw", action='store_true', - help="Use fixed pattern for EC_RW but preserve image_data", - default=False) - parser.add_argument("--verbose", action='store_true', - help="Enable verbose output", - default=False) - parser.add_argument("--tag0_loc", type=int, - help="MEC172x TAG0 SPI offset", - default=0) - parser.add_argument("--tag1_loc", type=int, - help="MEC172x TAG1 SPI offset", - default=4) - parser.add_argument("--spi_drive_str", type=int, - help="Chip SPI drive strength in mA: 2, 4, 8, or 12", - default=4) - parser.add_argument("--spi_slew_fast", action='store_true', - help="SPI use fast slew rate. Default is False", - default=False) - parser.add_argument("--spi_cpol", type=int, - help="SPI clock polarity when idle. Defealt is 0(low)", - default=0) - parser.add_argument("--spi_cpha_mosi", type=int, - help="""SPI clock phase controller drives data. + rpath = os.path.dirname(os.path.relpath(__file__)) + + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--input", + help="EC binary to pack, usually ec.bin or ec.RO.flat.", + metavar="EC_BIN", + default="ec.bin", + ) + parser.add_argument( + "-o", + "--output", + help="Output flash binary file", + metavar="EC_SPI_FLASH", + default="ec.packed.bin", + ) + parser.add_argument( + "--loader_file", help="EC loader binary", default="ecloader.bin" + ) + parser.add_argument( + "--load_addr", type=int, help="EC SRAM load address", default=0xC0000 + ) + parser.add_argument( + "-s", "--spi_size", type=int, help="Size of the SPI flash in KB", default=512 + ) + parser.add_argument( + "-l", + "--header_loc", + type=int, + help="Location of header in SPI flash. Must be on a 256 byte boundary", + default=0x0100, + ) + parser.add_argument( + "--lfw_loc", + type=int, + help="Location of LFW in SPI flash. Must be on a 4KB boundary", + default=0x1000, + ) + parser.add_argument( + "--lfw_size", type=int, help="LFW size in bytes", default=0x1000 + ) + parser.add_argument( + "-r", + "--rw_loc", + type=int, + help="Start offset of EC_RW. Default is -1 meaning 1/2 flash size", + default=-1, + ) + parser.add_argument( + "--spi_clock", + type=int, + help="SPI clock speed. 8, 12, 24, or 48 MHz.", + default=24, + ) + parser.add_argument( + "--spi_read_cmd", + type=int, + help="SPI read command. 0x3, 0xB, 0x3B, or 0x6B.", + default=0xB, + ) + parser.add_argument( + "--image_size", + type=int, + help="Size of a single image. Default 244KB", + default=(244 * 1024), + ) + parser.add_argument( + "--test_spi", + action="store_true", + help="Test SPI data integrity by adding CRC32 in last 4-bytes of RO/RW binaries", + default=False, + ) + parser.add_argument( + "--test_ecrw", + action="store_true", + help="Use fixed pattern for EC_RW but preserve image_data", + default=False, + ) + parser.add_argument( + "--verbose", action="store_true", help="Enable verbose output", default=False + ) + parser.add_argument( + "--tag0_loc", type=int, help="MEC172x TAG0 SPI offset", default=0 + ) + parser.add_argument( + "--tag1_loc", type=int, help="MEC172x TAG1 SPI offset", default=4 + ) + parser.add_argument( + "--spi_drive_str", + type=int, + help="Chip SPI drive strength in mA: 2, 4, 8, or 12", + default=4, + ) + parser.add_argument( + "--spi_slew_fast", + action="store_true", + help="SPI use fast slew rate. Default is False", + default=False, + ) + parser.add_argument( + "--spi_cpol", + type=int, + help="SPI clock polarity when idle. Defealt is 0(low)", + default=0, + ) + parser.add_argument( + "--spi_cpha_mosi", + type=int, + help="""SPI clock phase controller drives data. 0=Data driven on active to inactive clock edge, 1=Data driven on inactive to active clock edge""", - default=0) - parser.add_argument("--spi_cpha_miso", type=int, - help="""SPI clock phase controller samples data. + default=0, + ) + parser.add_argument( + "--spi_cpha_miso", + type=int, + help="""SPI clock phase controller samples data. 0=Data sampled on inactive to active clock edge, 1=Data sampled on active to inactive clock edge""", - default=0) + default=0, + ) + + return parser.parse_args() - return parser.parse_args() def print_args(args): - debug_print("parsed arguments:") - debug_print(".input = ", args.input) - debug_print(".output = ", args.output) - debug_print(".loader_file = ", args.loader_file) - debug_print(".spi_size (KB) = ", hex(args.spi_size)) - debug_print(".image_size = ", hex(args.image_size)) - debug_print(".load_addr", hex(args.load_addr)) - debug_print(".tag0_loc = ", hex(args.tag0_loc)) - debug_print(".tag1_loc = ", hex(args.tag1_loc)) - debug_print(".header_loc = ", hex(args.header_loc)) - debug_print(".lfw_loc = ", hex(args.lfw_loc)) - debug_print(".lfw_size = ", hex(args.lfw_size)) - if args.rw_loc < 0: - debug_print(".rw_loc = ", args.rw_loc) - else: - debug_print(".rw_loc = ", hex(args.rw_loc)) - debug_print(".spi_clock (MHz) = ", args.spi_clock) - debug_print(".spi_read_cmd = ", hex(args.spi_read_cmd)) - debug_print(".test_spi = ", args.test_spi) - debug_print(".test_ecrw = ", args.test_ecrw) - debug_print(".verbose = ", args.verbose) - debug_print(".spi_drive_str = ", args.spi_drive_str) - debug_print(".spi_slew_fast = ", args.spi_slew_fast) - debug_print(".spi_cpol = ", args.spi_cpol) - debug_print(".spi_cpha_mosi = ", args.spi_cpha_mosi) - debug_print(".spi_cpha_miso = ", args.spi_cpha_miso) + debug_print("parsed arguments:") + debug_print(".input = ", args.input) + debug_print(".output = ", args.output) + debug_print(".loader_file = ", args.loader_file) + debug_print(".spi_size (KB) = ", hex(args.spi_size)) + debug_print(".image_size = ", hex(args.image_size)) + debug_print(".load_addr", hex(args.load_addr)) + debug_print(".tag0_loc = ", hex(args.tag0_loc)) + debug_print(".tag1_loc = ", hex(args.tag1_loc)) + debug_print(".header_loc = ", hex(args.header_loc)) + debug_print(".lfw_loc = ", hex(args.lfw_loc)) + debug_print(".lfw_size = ", hex(args.lfw_size)) + if args.rw_loc < 0: + debug_print(".rw_loc = ", args.rw_loc) + else: + debug_print(".rw_loc = ", hex(args.rw_loc)) + debug_print(".spi_clock (MHz) = ", args.spi_clock) + debug_print(".spi_read_cmd = ", hex(args.spi_read_cmd)) + debug_print(".test_spi = ", args.test_spi) + debug_print(".test_ecrw = ", args.test_ecrw) + debug_print(".verbose = ", args.verbose) + debug_print(".spi_drive_str = ", args.spi_drive_str) + debug_print(".spi_slew_fast = ", args.spi_slew_fast) + debug_print(".spi_cpol = ", args.spi_cpol) + debug_print(".spi_cpha_mosi = ", args.spi_cpha_mosi) + debug_print(".spi_cpha_miso = ", args.spi_cpha_miso) + def spi_list_append(mylist, loc, data, description): - """Append SPI data block tuple to list""" - t = (loc, data, description) - mylist.append(t) - debug_print("Add SPI entry: offset=0x{0:08x} len=0x{1:0x} descr={2}".format(loc, len(data), description)) + """Append SPI data block tuple to list""" + t = (loc, data, description) + mylist.append(t) + debug_print( + "Add SPI entry: offset=0x{0:08x} len=0x{1:0x} descr={2}".format( + loc, len(data), description + ) + ) + # # Handle quiet mode build from Makefile @@ -652,200 +773,207 @@ def spi_list_append(mylist, loc, data, description): # || 48 * [0] # def main(): - global debug_print - - args = parseargs() - - if args.verbose: - debug_print = print - - debug_print("Begin pack_ec_mec172x.py script") - - print_args(args) - - chip_dict = MEC172X_DICT - - # Boot-ROM requires header location aligned >= 256 bytes. - # CrOS EC flash image update code requires EC_RO/RW location to be aligned - # on a flash erase size boundary and EC_RO/RW size to be a multiple of - # the smallest flash erase block size. - - spi_size = args.spi_size * 1024 - spi_image_size = spi_size // 2 - - rorofile=PacklfwRoImage(args.input, args.loader_file, args.image_size) - debug_print("Temporary file containing LFW + EC_RO is ", rorofile) - - lfw_ecro = GetPayload(rorofile, chip_dict["PAD_SIZE"]) - lfw_ecro_len = len(lfw_ecro) - debug_print("Padded LFW + EC_RO length = ", hex(lfw_ecro_len)) - - # SPI test mode compute CRC32 of EC_RO and store in last 4 bytes - if args.test_spi: - crc32_ecro = zlib.crc32(bytes(lfw_ecro[LFW_SIZE:-4])) - crc32_ecro_bytes = crc32_ecro.to_bytes(4, byteorder='little') - lfw_ecro[-4:] = crc32_ecro_bytes - debug_print("ecro len = ", hex(len(lfw_ecro) - LFW_SIZE)) - debug_print("CRC32(ecro-4) = ", hex(crc32_ecro)) - - # Reads entry point from offset 4 of file. - # This assumes binary has Cortex-M4 vector table at offset 0. - # 32-bit word at offset 0x0 initial stack pointer value - # 32-bit word at offset 0x4 address of reset handler - # NOTE: reset address will have bit[0]=1 to ensure thumb mode. - lfw_ecro_entry = GetEntryPoint(rorofile) - debug_print("LFW Entry point from GetEntryPoint = 0x{0:08x}".format(lfw_ecro_entry)) - - # Chromebooks are not using MEC BootROM SPI header/payload authentication - # or payload encryption. In this case the header authentication signature - # is filled with the hash digest of the respective entity. - # BuildHeader2 computes the hash digest and stores it in the correct - # header location. - header = BuildHeader2(args, chip_dict, lfw_ecro_len, - args.load_addr, lfw_ecro_entry) - printByteArrayAsHex(header, "Header(lfw_ecro)") - - # If payload encryption used then encrypt payload and - # generate Payload Key Header. If encryption not used - # payload is not modified and the method returns None - encryption_key_header = EncryptPayload(args, chip_dict, lfw_ecro) - printByteArrayAsHex(encryption_key_header, - "LFW + EC_RO encryption_key_header") - - ec_info_block = GenEcInfoBlock(args, chip_dict) - printByteArrayAsHex(ec_info_block, "EC Info Block") - - cosignature = GenCoSignature(args, chip_dict, lfw_ecro) - printByteArrayAsHex(cosignature, "LFW + EC_RO cosignature") - - trailer = GenTrailer(args, chip_dict, lfw_ecro, encryption_key_header, - ec_info_block, cosignature) - - printByteArrayAsHex(trailer, "LFW + EC_RO trailer") - - # Build TAG0. Set TAG1=TAG0 Boot-ROM is allowed to load EC-RO only. - tag0 = BuildTag(args) - tag1 = tag0 - - debug_print("Call to GetPayloadFromOffset") - debug_print("args.input = ", args.input) - debug_print("args.image_size = ", hex(args.image_size)) - - ecrw = GetPayloadFromOffset(args.input, args.image_size, - chip_dict["PAD_SIZE"]) - debug_print("type(ecrw) is ", type(ecrw)) - debug_print("len(ecrw) is ", hex(len(ecrw))) - - # truncate to args.image_size - ecrw_len = len(ecrw) - if ecrw_len > args.image_size: - debug_print("Truncate EC_RW len={0:0x} to image_size={1:0x}".format(ecrw_len,args.image_size)) - ecrw = ecrw[:args.image_size] - ecrw_len = len(ecrw) + global debug_print + + args = parseargs() + + if args.verbose: + debug_print = print + + debug_print("Begin pack_ec_mec172x.py script") + + print_args(args) + + chip_dict = MEC172X_DICT + + # Boot-ROM requires header location aligned >= 256 bytes. + # CrOS EC flash image update code requires EC_RO/RW location to be aligned + # on a flash erase size boundary and EC_RO/RW size to be a multiple of + # the smallest flash erase block size. + + spi_size = args.spi_size * 1024 + spi_image_size = spi_size // 2 + + rorofile = PacklfwRoImage(args.input, args.loader_file, args.image_size) + debug_print("Temporary file containing LFW + EC_RO is ", rorofile) + + lfw_ecro = GetPayload(rorofile, chip_dict["PAD_SIZE"]) + lfw_ecro_len = len(lfw_ecro) + debug_print("Padded LFW + EC_RO length = ", hex(lfw_ecro_len)) + + # SPI test mode compute CRC32 of EC_RO and store in last 4 bytes + if args.test_spi: + crc32_ecro = zlib.crc32(bytes(lfw_ecro[LFW_SIZE:-4])) + crc32_ecro_bytes = crc32_ecro.to_bytes(4, byteorder="little") + lfw_ecro[-4:] = crc32_ecro_bytes + debug_print("ecro len = ", hex(len(lfw_ecro) - LFW_SIZE)) + debug_print("CRC32(ecro-4) = ", hex(crc32_ecro)) + + # Reads entry point from offset 4 of file. + # This assumes binary has Cortex-M4 vector table at offset 0. + # 32-bit word at offset 0x0 initial stack pointer value + # 32-bit word at offset 0x4 address of reset handler + # NOTE: reset address will have bit[0]=1 to ensure thumb mode. + lfw_ecro_entry = GetEntryPoint(rorofile) + debug_print("LFW Entry point from GetEntryPoint = 0x{0:08x}".format(lfw_ecro_entry)) + + # Chromebooks are not using MEC BootROM SPI header/payload authentication + # or payload encryption. In this case the header authentication signature + # is filled with the hash digest of the respective entity. + # BuildHeader2 computes the hash digest and stores it in the correct + # header location. + header = BuildHeader2(args, chip_dict, lfw_ecro_len, args.load_addr, lfw_ecro_entry) + printByteArrayAsHex(header, "Header(lfw_ecro)") + + # If payload encryption used then encrypt payload and + # generate Payload Key Header. If encryption not used + # payload is not modified and the method returns None + encryption_key_header = EncryptPayload(args, chip_dict, lfw_ecro) + printByteArrayAsHex(encryption_key_header, "LFW + EC_RO encryption_key_header") + + ec_info_block = GenEcInfoBlock(args, chip_dict) + printByteArrayAsHex(ec_info_block, "EC Info Block") + + cosignature = GenCoSignature(args, chip_dict, lfw_ecro) + printByteArrayAsHex(cosignature, "LFW + EC_RO cosignature") + + trailer = GenTrailer( + args, chip_dict, lfw_ecro, encryption_key_header, ec_info_block, cosignature + ) + + printByteArrayAsHex(trailer, "LFW + EC_RO trailer") + + # Build TAG0. Set TAG1=TAG0 Boot-ROM is allowed to load EC-RO only. + tag0 = BuildTag(args) + tag1 = tag0 + + debug_print("Call to GetPayloadFromOffset") + debug_print("args.input = ", args.input) + debug_print("args.image_size = ", hex(args.image_size)) + + ecrw = GetPayloadFromOffset(args.input, args.image_size, chip_dict["PAD_SIZE"]) + debug_print("type(ecrw) is ", type(ecrw)) + debug_print("len(ecrw) is ", hex(len(ecrw))) + + # truncate to args.image_size + ecrw_len = len(ecrw) + if ecrw_len > args.image_size: + debug_print( + "Truncate EC_RW len={0:0x} to image_size={1:0x}".format( + ecrw_len, args.image_size + ) + ) + ecrw = ecrw[: args.image_size] + ecrw_len = len(ecrw) + + debug_print("len(EC_RW) = ", hex(ecrw_len)) + + # SPI test mode compute CRC32 of EC_RW and store in last 4 bytes + if args.test_spi: + crc32_ecrw = zlib.crc32(bytes(ecrw[0:-4])) + crc32_ecrw_bytes = crc32_ecrw.to_bytes(4, byteorder="little") + ecrw[-4:] = crc32_ecrw_bytes + debug_print("ecrw len = ", hex(len(ecrw))) + debug_print("CRC32(ecrw) = ", hex(crc32_ecrw)) + + # Assume FW layout is standard Cortex-M style with vector + # table at start of binary. + # 32-bit word at offset 0x0 = Initial stack pointer + # 32-bit word at offset 0x4 = Address of reset handler + ecrw_entry_tuple = struct.unpack_from("<I", ecrw, 4) + debug_print("ecrw_entry_tuple[0] = ", hex(ecrw_entry_tuple[0])) + + ecrw_entry = ecrw_entry_tuple[0] + debug_print("ecrw_entry = ", hex(ecrw_entry)) + + # Note: payload_rw is a bytearray therefore is mutable + if args.test_ecrw: + gen_test_ecrw(ecrw) + + os.remove(rorofile) # clean up the temp file + + spi_list = [] + + # MEC172x Add TAG's + # spi_list.append((args.tag0_loc, tag0, "tag0")) + # spi_list.append((args.tag1_loc, tag1, "tag1")) + spi_list_append(spi_list, args.tag0_loc, tag0, "TAG0") + spi_list_append(spi_list, args.tag1_loc, tag1, "TAG1") + + # Boot-ROM SPI image header for LFW+EC-RO + # spi_list.append((args.header_loc, header, "header(lfw + ro)")) + spi_list_append(spi_list, args.header_loc, header, "LFW-EC_RO Header") + + spi_list_append(spi_list, args.lfw_loc, lfw_ecro, "LFW-EC_RO FW") + + offset = args.lfw_loc + len(lfw_ecro) + debug_print("SPI offset after LFW_ECRO = 0x{0:08x}".format(offset)) - debug_print("len(EC_RW) = ", hex(ecrw_len)) - - # SPI test mode compute CRC32 of EC_RW and store in last 4 bytes - if args.test_spi: - crc32_ecrw = zlib.crc32(bytes(ecrw[0:-4])) - crc32_ecrw_bytes = crc32_ecrw.to_bytes(4, byteorder='little') - ecrw[-4:] = crc32_ecrw_bytes - debug_print("ecrw len = ", hex(len(ecrw))) - debug_print("CRC32(ecrw) = ", hex(crc32_ecrw)) - - # Assume FW layout is standard Cortex-M style with vector - # table at start of binary. - # 32-bit word at offset 0x0 = Initial stack pointer - # 32-bit word at offset 0x4 = Address of reset handler - ecrw_entry_tuple = struct.unpack_from('<I', ecrw, 4) - debug_print("ecrw_entry_tuple[0] = ", hex(ecrw_entry_tuple[0])) - - ecrw_entry = ecrw_entry_tuple[0] - debug_print("ecrw_entry = ", hex(ecrw_entry)) - - # Note: payload_rw is a bytearray therefore is mutable - if args.test_ecrw: - gen_test_ecrw(ecrw) - - os.remove(rorofile) # clean up the temp file - - spi_list = [] - - # MEC172x Add TAG's - #spi_list.append((args.tag0_loc, tag0, "tag0")) - #spi_list.append((args.tag1_loc, tag1, "tag1")) - spi_list_append(spi_list, args.tag0_loc, tag0, "TAG0") - spi_list_append(spi_list, args.tag1_loc, tag1, "TAG1") - - # Boot-ROM SPI image header for LFW+EC-RO - #spi_list.append((args.header_loc, header, "header(lfw + ro)")) - spi_list_append(spi_list, args.header_loc, header, "LFW-EC_RO Header") - - spi_list_append(spi_list, args.lfw_loc, lfw_ecro, "LFW-EC_RO FW") - - offset = args.lfw_loc + len(lfw_ecro) - debug_print("SPI offset after LFW_ECRO = 0x{0:08x}".format(offset)) - - if ec_info_block != None: - spi_list_append(spi_list, offset, ec_info_block, "LFW-EC_RO Info Block") - offset += len(ec_info_block) - - debug_print("SPI offset after ec_info_block = 0x{0:08x}".format(offset)) - - if cosignature != None: - #spi_list.append((offset, co-signature, "ECRO Co-signature")) - spi_list_append(spi_list, offset, cosignature, "LFW-EC_RO Co-signature") - offset += len(cosignature) - - debug_print("SPI offset after co-signature = 0x{0:08x}".format(offset)) - - if trailer != None: - #spi_list.append((offset, trailer, "ECRO Trailer")) - spi_list_append(spi_list, offset, trailer, "LFW-EC_RO trailer") - offset += len(trailer) - - debug_print("SPI offset after trailer = 0x{0:08x}".format(offset)) - - # EC_RW location - rw_offset = int(spi_size // 2) - if args.rw_loc >= 0: - rw_offset = args.rw_loc - - debug_print("rw_offset = 0x{0:08x}".format(rw_offset)) - - #spi_list.append((rw_offset, ecrw, "ecrw")) - spi_list_append(spi_list, rw_offset, ecrw, "EC_RW") - offset = rw_offset + len(ecrw) - - spi_list = sorted(spi_list) - - debug_print("Display spi_list:") - dumpsects(spi_list) - - # - # MEC172x Boot-ROM locates TAG0/1 at SPI offset 0 - # instead of end of SPI. - # - with open(args.output, 'wb') as f: - debug_print("Write spi list to file", args.output) - addr = 0 - for s in spi_list: - if addr < s[0]: - debug_print("Offset ",hex(addr)," Length", hex(s[0]-addr), - "fill with 0xff") - f.write(b'\xff' * (s[0] - addr)) - addr = s[0] - debug_print("Offset ",hex(addr), " Length", hex(len(s[1])), "write data") - - f.write(s[1]) - addr += len(s[1]) - - if addr < spi_size: - debug_print("Offset ",hex(addr), " Length", hex(spi_size - addr), - "fill with 0xff") - f.write(b'\xff' * (spi_size - addr)) + if ec_info_block != None: + spi_list_append(spi_list, offset, ec_info_block, "LFW-EC_RO Info Block") + offset += len(ec_info_block) - f.flush() + debug_print("SPI offset after ec_info_block = 0x{0:08x}".format(offset)) -if __name__ == '__main__': - main() + if cosignature != None: + # spi_list.append((offset, co-signature, "ECRO Co-signature")) + spi_list_append(spi_list, offset, cosignature, "LFW-EC_RO Co-signature") + offset += len(cosignature) + + debug_print("SPI offset after co-signature = 0x{0:08x}".format(offset)) + + if trailer != None: + # spi_list.append((offset, trailer, "ECRO Trailer")) + spi_list_append(spi_list, offset, trailer, "LFW-EC_RO trailer") + offset += len(trailer) + + debug_print("SPI offset after trailer = 0x{0:08x}".format(offset)) + + # EC_RW location + rw_offset = int(spi_size // 2) + if args.rw_loc >= 0: + rw_offset = args.rw_loc + + debug_print("rw_offset = 0x{0:08x}".format(rw_offset)) + + # spi_list.append((rw_offset, ecrw, "ecrw")) + spi_list_append(spi_list, rw_offset, ecrw, "EC_RW") + offset = rw_offset + len(ecrw) + + spi_list = sorted(spi_list) + + debug_print("Display spi_list:") + dumpsects(spi_list) + + # + # MEC172x Boot-ROM locates TAG0/1 at SPI offset 0 + # instead of end of SPI. + # + with open(args.output, "wb") as f: + debug_print("Write spi list to file", args.output) + addr = 0 + for s in spi_list: + if addr < s[0]: + debug_print( + "Offset ", hex(addr), " Length", hex(s[0] - addr), "fill with 0xff" + ) + f.write(b"\xff" * (s[0] - addr)) + addr = s[0] + debug_print( + "Offset ", hex(addr), " Length", hex(len(s[1])), "write data" + ) + + f.write(s[1]) + addr += len(s[1]) + + if addr < spi_size: + debug_print( + "Offset ", hex(addr), " Length", hex(spi_size - addr), "fill with 0xff" + ) + f.write(b"\xff" * (spi_size - addr)) + + f.flush() + + +if __name__ == "__main__": + main() diff --git a/chip/mec1322/util/pack_ec.py b/chip/mec1322/util/pack_ec.py index 9783ffb2d5..736f9efcac 100755 --- a/chip/mec1322/util/pack_ec.py +++ b/chip/mec1322/util/pack_ec.py @@ -21,228 +21,325 @@ import tempfile LOAD_ADDR = 0x100000 HEADER_SIZE = 0x140 SPI_CLOCK_LIST = [48, 24, 12, 8] -SPI_READ_CMD_LIST = [0x3, 0xb, 0x3b] +SPI_READ_CMD_LIST = [0x3, 0xB, 0x3B] + +CRC_TABLE = [ + 0x00, + 0x07, + 0x0E, + 0x09, + 0x1C, + 0x1B, + 0x12, + 0x15, + 0x38, + 0x3F, + 0x36, + 0x31, + 0x24, + 0x23, + 0x2A, + 0x2D, +] -CRC_TABLE = [0x00, 0x07, 0x0e, 0x09, 0x1c, 0x1b, 0x12, 0x15, - 0x38, 0x3f, 0x36, 0x31, 0x24, 0x23, 0x2a, 0x2d] def Crc8(crc, data): - """Update CRC8 value.""" - for v in data: - crc = ((crc << 4) & 0xff) ^ (CRC_TABLE[(crc >> 4) ^ (v >> 4)]); - crc = ((crc << 4) & 0xff) ^ (CRC_TABLE[(crc >> 4) ^ (v & 0xf)]); - return crc ^ 0x55 + """Update CRC8 value.""" + for v in data: + crc = ((crc << 4) & 0xFF) ^ (CRC_TABLE[(crc >> 4) ^ (v >> 4)]) + crc = ((crc << 4) & 0xFF) ^ (CRC_TABLE[(crc >> 4) ^ (v & 0xF)]) + return crc ^ 0x55 + def GetEntryPoint(payload_file): - """Read entry point from payload EC image.""" - with open(payload_file, 'rb') as f: - f.seek(4) - s = f.read(4) - return struct.unpack('<I', s)[0] - -def GetPayloadFromOffset(payload_file,offset): - """Read payload and pad it to 64-byte aligned.""" - with open(payload_file, 'rb') as f: - f.seek(offset) - payload = bytearray(f.read()) - rem_len = len(payload) % 64 - if rem_len: - payload += b'\0' * (64 - rem_len) - return payload + """Read entry point from payload EC image.""" + with open(payload_file, "rb") as f: + f.seek(4) + s = f.read(4) + return struct.unpack("<I", s)[0] + + +def GetPayloadFromOffset(payload_file, offset): + """Read payload and pad it to 64-byte aligned.""" + with open(payload_file, "rb") as f: + f.seek(offset) + payload = bytearray(f.read()) + rem_len = len(payload) % 64 + if rem_len: + payload += b"\0" * (64 - rem_len) + return payload + def GetPayload(payload_file): - """Read payload and pad it to 64-byte aligned.""" - return GetPayloadFromOffset(payload_file, 0) + """Read payload and pad it to 64-byte aligned.""" + return GetPayloadFromOffset(payload_file, 0) + def GetPublicKey(pem_file): - """Extract public exponent and modulus from PEM file.""" - result = subprocess.run(['openssl', 'rsa', '-in', pem_file, '-text', - '-noout'], stdout=subprocess.PIPE, encoding='utf-8') - modulus_raw = [] - in_modulus = False - for line in result.stdout.splitlines(): - if line.startswith('modulus'): - in_modulus = True - elif not line.startswith(' '): - in_modulus = False - elif in_modulus: - modulus_raw.extend(line.strip().strip(':').split(':')) - if line.startswith('publicExponent'): - exp = int(line.split(' ')[1], 10) - modulus_raw.reverse() - modulus = bytearray((int(x, 16) for x in modulus_raw[:256])) - return struct.pack('<Q', exp), modulus + """Extract public exponent and modulus from PEM file.""" + result = subprocess.run( + ["openssl", "rsa", "-in", pem_file, "-text", "-noout"], + stdout=subprocess.PIPE, + encoding="utf-8", + ) + modulus_raw = [] + in_modulus = False + for line in result.stdout.splitlines(): + if line.startswith("modulus"): + in_modulus = True + elif not line.startswith(" "): + in_modulus = False + elif in_modulus: + modulus_raw.extend(line.strip().strip(":").split(":")) + if line.startswith("publicExponent"): + exp = int(line.split(" ")[1], 10) + modulus_raw.reverse() + modulus = bytearray((int(x, 16) for x in modulus_raw[:256])) + return struct.pack("<Q", exp), modulus + def GetSpiClockParameter(args): - assert args.spi_clock in SPI_CLOCK_LIST, \ - "Unsupported SPI clock speed %d MHz" % args.spi_clock - return SPI_CLOCK_LIST.index(args.spi_clock) + assert args.spi_clock in SPI_CLOCK_LIST, ( + "Unsupported SPI clock speed %d MHz" % args.spi_clock + ) + return SPI_CLOCK_LIST.index(args.spi_clock) + def GetSpiReadCmdParameter(args): - assert args.spi_read_cmd in SPI_READ_CMD_LIST, \ - "Unsupported SPI read command 0x%x" % args.spi_read_cmd - return SPI_READ_CMD_LIST.index(args.spi_read_cmd) + assert args.spi_read_cmd in SPI_READ_CMD_LIST, ( + "Unsupported SPI read command 0x%x" % args.spi_read_cmd + ) + return SPI_READ_CMD_LIST.index(args.spi_read_cmd) + def PadZeroTo(data, size): - data.extend(b'\0' * (size - len(data))) + data.extend(b"\0" * (size - len(data))) + def BuildHeader(args, payload_len, rorofile): - # Identifier and header version - header = bytearray(b'CSMS\0') + # Identifier and header version + header = bytearray(b"CSMS\0") + + PadZeroTo(header, 0x6) + header.append(GetSpiClockParameter(args)) + header.append(GetSpiReadCmdParameter(args)) - PadZeroTo(header, 0x6) - header.append(GetSpiClockParameter(args)) - header.append(GetSpiReadCmdParameter(args)) + header.extend(struct.pack("<I", LOAD_ADDR)) + header.extend(struct.pack("<I", GetEntryPoint(rorofile))) + header.append((payload_len >> 6) & 0xFF) + header.append((payload_len >> 14) & 0xFF) + PadZeroTo(header, 0x14) + header.extend(struct.pack("<I", args.payload_offset)) - header.extend(struct.pack('<I', LOAD_ADDR)) - header.extend(struct.pack('<I', GetEntryPoint(rorofile))) - header.append((payload_len >> 6) & 0xff) - header.append((payload_len >> 14) & 0xff) - PadZeroTo(header, 0x14) - header.extend(struct.pack('<I', args.payload_offset)) + exp, modulus = GetPublicKey(args.payload_key) + PadZeroTo(header, 0x20) + header.extend(exp) + PadZeroTo(header, 0x30) + header.extend(modulus) + PadZeroTo(header, HEADER_SIZE) - exp, modulus = GetPublicKey(args.payload_key) - PadZeroTo(header, 0x20) - header.extend(exp) - PadZeroTo(header, 0x30) - header.extend(modulus) - PadZeroTo(header, HEADER_SIZE) + return header - return header def SignByteArray(data, pem_file): - hash_file = tempfile.mkstemp(prefix='pack_ec.')[1] - sign_file = tempfile.mkstemp(prefix='pack_ec.')[1] - try: - with open(hash_file, 'wb') as f: - hasher = hashlib.sha256() - hasher.update(data) - f.write(hasher.digest()) - subprocess.run(['openssl', 'rsautl', '-sign', '-inkey', pem_file, - '-keyform', 'PEM', '-in', hash_file, '-out', sign_file], - check=True) - with open(sign_file, 'rb') as f: - signed = f.read() - return bytearray(reversed(signed)) - finally: - os.remove(hash_file) - os.remove(sign_file) + hash_file = tempfile.mkstemp(prefix="pack_ec.")[1] + sign_file = tempfile.mkstemp(prefix="pack_ec.")[1] + try: + with open(hash_file, "wb") as f: + hasher = hashlib.sha256() + hasher.update(data) + f.write(hasher.digest()) + subprocess.run( + [ + "openssl", + "rsautl", + "-sign", + "-inkey", + pem_file, + "-keyform", + "PEM", + "-in", + hash_file, + "-out", + sign_file, + ], + check=True, + ) + with open(sign_file, "rb") as f: + signed = f.read() + return bytearray(reversed(signed)) + finally: + os.remove(hash_file) + os.remove(sign_file) + def BuildTag(args): - tag = bytearray([(args.header_loc >> 8) & 0xff, - (args.header_loc >> 16) & 0xff, - (args.header_loc >> 24) & 0xff]) - if args.chip_select != 0: - tag[2] |= 0x80 - tag.append(Crc8(0, tag)) - return tag + tag = bytearray( + [ + (args.header_loc >> 8) & 0xFF, + (args.header_loc >> 16) & 0xFF, + (args.header_loc >> 24) & 0xFF, + ] + ) + if args.chip_select != 0: + tag[2] |= 0x80 + tag.append(Crc8(0, tag)) + return tag + def PacklfwRoImage(rorw_file, loader_file, image_size): - """TODO:Clean up to get rid of Temp file and just use memory - to save data""" - """Create a temp file with the + """TODO:Clean up to get rid of Temp file and just use memory + to save data""" + """Create a temp file with the first image_size bytes from the rorw file and the bytes from the loader_file appended return the filename""" - fo=tempfile.NamedTemporaryFile(delete=False) # Need to keep file around - with open(loader_file,'rb') as fin1: - pro = fin1.read() - fo.write(pro) - with open(rorw_file, 'rb') as fin: - ro = fin.read(image_size) - fo.write(ro) - fo.close() - return fo.name + fo = tempfile.NamedTemporaryFile(delete=False) # Need to keep file around + with open(loader_file, "rb") as fin1: + pro = fin1.read() + fo.write(pro) + with open(rorw_file, "rb") as fin: + ro = fin.read(image_size) + fo.write(ro) + fo.close() + return fo.name + def parseargs(): - parser = argparse.ArgumentParser() - parser.add_argument("-i", "--input", - help="EC binary to pack, usually ec.bin or ec.RO.flat.", - metavar="EC_BIN", default="ec.bin") - parser.add_argument("-o", "--output", - help="Output flash binary file", - metavar="EC_SPI_FLASH", default="ec.packed.bin") - parser.add_argument("--header_key", - help="PEM key file for signing header", - default="rsakey_sign_header.pem") - parser.add_argument("--payload_key", - help="PEM key file for signing payload", - default="rsakey_sign_payload.pem") - parser.add_argument("--loader_file", - help="EC loader binary", - default="ecloader.bin") - parser.add_argument("-s", "--spi_size", type=int, - help="Size of the SPI flash in MB", - default=4) - parser.add_argument("-l", "--header_loc", type=int, - help="Location of header in SPI flash", - default=0x170000) - parser.add_argument("-p", "--payload_offset", type=int, - help="The offset of payload from the header", - default=0x240) - parser.add_argument("-r", "--rwpayload_loc", type=int, - help="The offset of payload from the header", - default=0x190000) - parser.add_argument("-z", "--romstart", type=int, - help="The first location to output of the rom", - default=0) - parser.add_argument("-c", "--chip_select", type=int, - help="Chip select signal to use, either 0 or 1.", - default=0) - parser.add_argument("--spi_clock", type=int, - help="SPI clock speed. 8, 12, 24, or 48 MHz.", - default=24) - parser.add_argument("--spi_read_cmd", type=int, - help="SPI read command. 0x3, 0xB, or 0x3B.", - default=0xb) - parser.add_argument("--image_size", type=int, - help="Size of a single image.", - default=(96 * 1024)) - return parser.parse_args() + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--input", + help="EC binary to pack, usually ec.bin or ec.RO.flat.", + metavar="EC_BIN", + default="ec.bin", + ) + parser.add_argument( + "-o", + "--output", + help="Output flash binary file", + metavar="EC_SPI_FLASH", + default="ec.packed.bin", + ) + parser.add_argument( + "--header_key", + help="PEM key file for signing header", + default="rsakey_sign_header.pem", + ) + parser.add_argument( + "--payload_key", + help="PEM key file for signing payload", + default="rsakey_sign_payload.pem", + ) + parser.add_argument( + "--loader_file", help="EC loader binary", default="ecloader.bin" + ) + parser.add_argument( + "-s", "--spi_size", type=int, help="Size of the SPI flash in MB", default=4 + ) + parser.add_argument( + "-l", + "--header_loc", + type=int, + help="Location of header in SPI flash", + default=0x170000, + ) + parser.add_argument( + "-p", + "--payload_offset", + type=int, + help="The offset of payload from the header", + default=0x240, + ) + parser.add_argument( + "-r", + "--rwpayload_loc", + type=int, + help="The offset of payload from the header", + default=0x190000, + ) + parser.add_argument( + "-z", + "--romstart", + type=int, + help="The first location to output of the rom", + default=0, + ) + parser.add_argument( + "-c", + "--chip_select", + type=int, + help="Chip select signal to use, either 0 or 1.", + default=0, + ) + parser.add_argument( + "--spi_clock", + type=int, + help="SPI clock speed. 8, 12, 24, or 48 MHz.", + default=24, + ) + parser.add_argument( + "--spi_read_cmd", + type=int, + help="SPI read command. 0x3, 0xB, or 0x3B.", + default=0xB, + ) + parser.add_argument( + "--image_size", type=int, help="Size of a single image.", default=(96 * 1024) + ) + return parser.parse_args() + def main(): - args = parseargs() - - spi_size = args.spi_size * 1024 - args.header_loc = spi_size - (128 * 1024) - args.rwpayload_loc = spi_size - (256 * 1024) - args.romstart = spi_size - (256 * 1024) - - spi_list = [] - - rorofile=PacklfwRoImage(args.input, args.loader_file, args.image_size) - payload = GetPayload(rorofile) - payload_len = len(payload) - payload_signature = SignByteArray(payload, args.payload_key) - header = BuildHeader(args, payload_len, rorofile) - header_signature = SignByteArray(header, args.header_key) - tag = BuildTag(args) - # truncate the RW to 128k - payloadrw = GetPayloadFromOffset(args.input,args.image_size)[:128*1024] - os.remove(rorofile) # clean up the temp file - - spi_list.append((args.header_loc, header, "header")) - spi_list.append((args.header_loc + HEADER_SIZE, header_signature, "header_signature")) - spi_list.append((args.header_loc + args.payload_offset, payload, "payload")) - spi_list.append((args.header_loc + args.payload_offset + payload_len, - payload_signature, "payload_signature")) - spi_list.append((spi_size - 256, tag, "tag")) - spi_list.append((args.rwpayload_loc, payloadrw, "payloadrw")) - - - spi_list = sorted(spi_list) - - with open(args.output, 'wb') as f: - addr = args.romstart - for s in spi_list: - assert addr <= s[0] - if addr < s[0]: - f.write(b'\xff' * (s[0] - addr)) - addr = s[0] - f.write(s[1]) - addr += len(s[1]) - if addr < spi_size: - f.write(b'\xff' * (spi_size - addr)) - -if __name__ == '__main__': - main() + args = parseargs() + + spi_size = args.spi_size * 1024 + args.header_loc = spi_size - (128 * 1024) + args.rwpayload_loc = spi_size - (256 * 1024) + args.romstart = spi_size - (256 * 1024) + + spi_list = [] + + rorofile = PacklfwRoImage(args.input, args.loader_file, args.image_size) + payload = GetPayload(rorofile) + payload_len = len(payload) + payload_signature = SignByteArray(payload, args.payload_key) + header = BuildHeader(args, payload_len, rorofile) + header_signature = SignByteArray(header, args.header_key) + tag = BuildTag(args) + # truncate the RW to 128k + payloadrw = GetPayloadFromOffset(args.input, args.image_size)[: 128 * 1024] + os.remove(rorofile) # clean up the temp file + + spi_list.append((args.header_loc, header, "header")) + spi_list.append( + (args.header_loc + HEADER_SIZE, header_signature, "header_signature") + ) + spi_list.append((args.header_loc + args.payload_offset, payload, "payload")) + spi_list.append( + ( + args.header_loc + args.payload_offset + payload_len, + payload_signature, + "payload_signature", + ) + ) + spi_list.append((spi_size - 256, tag, "tag")) + spi_list.append((args.rwpayload_loc, payloadrw, "payloadrw")) + + spi_list = sorted(spi_list) + + with open(args.output, "wb") as f: + addr = args.romstart + for s in spi_list: + assert addr <= s[0] + if addr < s[0]: + f.write(b"\xff" * (s[0] - addr)) + addr = s[0] + f.write(s[1]) + addr += len(s[1]) + if addr < spi_size: + f.write(b"\xff" * (spi_size - addr)) + + +if __name__ == "__main__": + main() diff --git a/cts/common/board.py b/cts/common/board.py index d2c8e02b04..f62d7bdfc5 100644 --- a/cts/common/board.py +++ b/cts/common/board.py @@ -10,379 +10,401 @@ from __future__ import print_function -from abc import ABCMeta -from abc import abstractmethod import os import shutil import subprocess as sp -import serial +from abc import ABCMeta, abstractmethod +import serial import six - -OCD_SCRIPT_DIR = '/usr/share/openocd/scripts' +OCD_SCRIPT_DIR = "/usr/share/openocd/scripts" OPENOCD_CONFIGS = { - 'stm32l476g-eval': 'board/stm32l4discovery.cfg', - 'nucleo-f072rb': 'board/st_nucleo_f0.cfg', - 'nucleo-f411re': 'board/st_nucleo_f4.cfg', + "stm32l476g-eval": "board/stm32l4discovery.cfg", + "nucleo-f072rb": "board/st_nucleo_f0.cfg", + "nucleo-f411re": "board/st_nucleo_f4.cfg", } FLASH_OFFSETS = { - 'stm32l476g-eval': '0x08000000', - 'nucleo-f072rb': '0x08000000', - 'nucleo-f411re': '0x08000000', + "stm32l476g-eval": "0x08000000", + "nucleo-f072rb": "0x08000000", + "nucleo-f411re": "0x08000000", } -REBOOT_MARKER = 'UART initialized after reboot' +REBOOT_MARKER = "UART initialized after reboot" def get_subprocess_args(): - if six.PY3: - return {'encoding': 'utf-8'} - return {} + if six.PY3: + return {"encoding": "utf-8"} + return {} class Board(six.with_metaclass(ABCMeta, object)): - """Class representing a single board connected to a host machine. - - Attributes: - board: String containing actual type of board, i.e. nucleo-f072rb - config: Directory of board config file relative to openocd's - scripts directory - hla_serial: String containing board's hla_serial number (if board - is an stm32 board) - tty_port: String that is the path to the tty port which board's - UART outputs to - tty: String of file descriptor for tty_port - """ - - def __init__(self, board, module, hla_serial=None): - """Initializes a board object with given attributes. - - Args: - board: String containing board name - module: String of the test module you are building, - i.e. gpio, timer, etc. - hla_serial: Serial number if board's adaptor is an HLA - - Raises: - RuntimeError: Board is not supported - """ - if board not in OPENOCD_CONFIGS: - msg = 'OpenOcd configuration not found for ' + board - raise RuntimeError(msg) - if board not in FLASH_OFFSETS: - msg = 'Flash offset not found for ' + board - raise RuntimeError(msg) - self.board = board - self.flash_offset = FLASH_OFFSETS[self.board] - self.openocd_config = OPENOCD_CONFIGS[self.board] - self.module = module - self.hla_serial = hla_serial - self.tty_port = None - self.tty = None - - def reset_log_dir(self): - """Reset log directory.""" - if os.path.isdir(self.log_dir): - shutil.rmtree(self.log_dir) - os.makedirs(self.log_dir) - - @staticmethod - def get_stlink_serials(): - """Gets serial numbers of all st-link v2.1 board attached to host. - - Returns: - List of serials + """Class representing a single board connected to a host machine. + + Attributes: + board: String containing actual type of board, i.e. nucleo-f072rb + config: Directory of board config file relative to openocd's + scripts directory + hla_serial: String containing board's hla_serial number (if board + is an stm32 board) + tty_port: String that is the path to the tty port which board's + UART outputs to + tty: String of file descriptor for tty_port """ - usb_args = ['sudo', 'lsusb', '-v', '-d', '0x0483:0x374b'] - st_link_info = sp.check_output(usb_args, **get_subprocess_args()) - st_serials = [] - for line in st_link_info.split('\n'): - if 'iSerial' not in line: - continue - words = line.split() - if len(words) <= 2: - continue - st_serials.append(words[2].strip()) - return st_serials - - @abstractmethod - def get_serial(self): - """Subclass should implement this.""" - pass - - def send_openocd_commands(self, commands): - """Send a command to the board via openocd. - - Args: - commands: A list of commands to send - - Returns: - True if execution is successful or False otherwise. - """ - args = ['sudo', 'openocd', '-s', OCD_SCRIPT_DIR, - '-f', self.openocd_config, '-c', 'hla_serial ' + self.hla_serial] - - for cmd in commands: - args += ['-c', cmd] - args += ['-c', 'shutdown'] - - rv = 1 - with open(self.openocd_log, 'a') as output: - rv = sp.call(args, stdout=output, stderr=sp.STDOUT) - if rv != 0: - self.dump_openocd_log() - - return rv == 0 - - def dump_openocd_log(self): - with open(self.openocd_log) as log: - print(log.read()) + def __init__(self, board, module, hla_serial=None): + """Initializes a board object with given attributes. + + Args: + board: String containing board name + module: String of the test module you are building, + i.e. gpio, timer, etc. + hla_serial: Serial number if board's adaptor is an HLA + + Raises: + RuntimeError: Board is not supported + """ + if board not in OPENOCD_CONFIGS: + msg = "OpenOcd configuration not found for " + board + raise RuntimeError(msg) + if board not in FLASH_OFFSETS: + msg = "Flash offset not found for " + board + raise RuntimeError(msg) + self.board = board + self.flash_offset = FLASH_OFFSETS[self.board] + self.openocd_config = OPENOCD_CONFIGS[self.board] + self.module = module + self.hla_serial = hla_serial + self.tty_port = None + self.tty = None + + def reset_log_dir(self): + """Reset log directory.""" + if os.path.isdir(self.log_dir): + shutil.rmtree(self.log_dir) + os.makedirs(self.log_dir) + + @staticmethod + def get_stlink_serials(): + """Gets serial numbers of all st-link v2.1 board attached to host. + + Returns: + List of serials + """ + usb_args = ["sudo", "lsusb", "-v", "-d", "0x0483:0x374b"] + st_link_info = sp.check_output(usb_args, **get_subprocess_args()) + st_serials = [] + for line in st_link_info.split("\n"): + if "iSerial" not in line: + continue + words = line.split() + if len(words) <= 2: + continue + st_serials.append(words[2].strip()) + return st_serials + + @abstractmethod + def get_serial(self): + """Subclass should implement this.""" + pass + + def send_openocd_commands(self, commands): + """Send a command to the board via openocd. + + Args: + commands: A list of commands to send + + Returns: + True if execution is successful or False otherwise. + """ + args = [ + "sudo", + "openocd", + "-s", + OCD_SCRIPT_DIR, + "-f", + self.openocd_config, + "-c", + "hla_serial " + self.hla_serial, + ] + + for cmd in commands: + args += ["-c", cmd] + args += ["-c", "shutdown"] + + rv = 1 + with open(self.openocd_log, "a") as output: + rv = sp.call(args, stdout=output, stderr=sp.STDOUT) + + if rv != 0: + self.dump_openocd_log() + + return rv == 0 + + def dump_openocd_log(self): + with open(self.openocd_log) as log: + print(log.read()) + + def build(self, ec_dir): + """Builds test suite module for board. + + Args: + ec_dir: String of the ec directory path + + Returns: + True if build is successful or False otherwise. + """ + cmds = [ + "make", + "--directory=" + ec_dir, + "BOARD=" + self.board, + "CTS_MODULE=" + self.module, + "-j", + ] + + rv = 1 + with open(self.build_log, "a") as output: + rv = sp.call(cmds, stdout=output, stderr=sp.STDOUT) + + if rv != 0: + self.dump_build_log() + + return rv == 0 + + def dump_build_log(self): + with open(self.build_log) as log: + print(log.read()) + + def flash(self, image_path): + """Flashes board with most recent build ec.bin.""" + cmd = [ + "reset_config connect_assert_srst", + "init", + "reset init", + "flash write_image erase %s %s" % (image_path, self.flash_offset), + ] + return self.send_openocd_commands(cmd) + + def to_string(self): + s = ( + "Type: Board\n" + "board: " + self.board + "\n" + "hla_serial: " + self.hla_serial + "\n" + "openocd_config: " + self.openocd_config + "\n" + "tty_port: " + self.tty_port + "\n" + "tty: " + str(self.tty) + "\n" + ) + return s + + def reset_halt(self): + """Reset then halt board.""" + return self.send_openocd_commands(["init", "reset halt"]) + + def resume(self): + """Resume halting board.""" + return self.send_openocd_commands(["init", "resume"]) + + def setup_tty(self): + """Call this before calling read_tty for the first time. + + This is not in the initialization because caller only should call + this function after serial numbers are setup + """ + self.get_serial() + self.reset_halt() + self.identify_tty_port() + + tty = None + try: + tty = serial.Serial(self.tty_port, 115200, timeout=1) + except serial.SerialException: + raise ValueError( + "Failed to open " + + self.tty_port + + " of " + + self.board + + ". Please make sure the port is available and you have" + + " permission to read it. Create dialout group and run:" + + " sudo usermod -a -G dialout <username>." + ) + self.tty = tty + + def read_tty(self, max_boot_count=1): + """Read info from a serial port described by a file descriptor. + + Args: + max_boot_count: Stop reading if boot count exceeds this number + + Returns: + result: characters read from tty + boot: boot counts + """ + buf = [] + line = [] + boot = 0 + while True: + c = self.tty.read().decode("utf-8") + if not c: + break + line.append(c) + if c == "\n": + l = "".join(line) + buf.append(l) + if REBOOT_MARKER in l: + boot += 1 + line = [] + if boot > max_boot_count: + break + + l = "".join(line) + buf.append(l) + result = "".join(buf) - def build(self, ec_dir): - """Builds test suite module for board. + return result, boot - Args: - ec_dir: String of the ec directory path + def identify_tty_port(self): + """Saves this board's serial port.""" + dev_dir = "/dev" + id_prefix = "ID_SERIAL_SHORT=" + com_devices = [f for f in os.listdir(dev_dir) if f.startswith("ttyACM")] - Returns: - True if build is successful or False otherwise. - """ - cmds = ['make', - '--directory=' + ec_dir, - 'BOARD=' + self.board, - 'CTS_MODULE=' + self.module, - '-j'] - - rv = 1 - with open(self.build_log, 'a') as output: - rv = sp.call(cmds, stdout=output, stderr=sp.STDOUT) - - if rv != 0: - self.dump_build_log() - - return rv == 0 - - def dump_build_log(self): - with open(self.build_log) as log: - print(log.read()) - - def flash(self, image_path): - """Flashes board with most recent build ec.bin.""" - cmd = ['reset_config connect_assert_srst', - 'init', - 'reset init', - 'flash write_image erase %s %s' % (image_path, self.flash_offset)] - return self.send_openocd_commands(cmd) - - def to_string(self): - s = ('Type: Board\n' - 'board: ' + self.board + '\n' - 'hla_serial: ' + self.hla_serial + '\n' - 'openocd_config: ' + self.openocd_config + '\n' - 'tty_port: ' + self.tty_port + '\n' - 'tty: ' + str(self.tty) + '\n') - return s - - def reset_halt(self): - """Reset then halt board.""" - return self.send_openocd_commands(['init', 'reset halt']) - - def resume(self): - """Resume halting board.""" - return self.send_openocd_commands(['init', 'resume']) - - def setup_tty(self): - """Call this before calling read_tty for the first time. - - This is not in the initialization because caller only should call - this function after serial numbers are setup - """ - self.get_serial() - self.reset_halt() - self.identify_tty_port() - - tty = None - try: - tty = serial.Serial(self.tty_port, 115200, timeout=1) - except serial.SerialException: - raise ValueError('Failed to open ' + self.tty_port + ' of ' + self.board + - '. Please make sure the port is available and you have' + - ' permission to read it. Create dialout group and run:' + - ' sudo usermod -a -G dialout <username>.') - self.tty = tty - - def read_tty(self, max_boot_count=1): - """Read info from a serial port described by a file descriptor. - - Args: - max_boot_count: Stop reading if boot count exceeds this number - - Returns: - result: characters read from tty - boot: boot counts - """ - buf = [] - line = [] - boot = 0 - while True: - c = self.tty.read().decode('utf-8') - if not c: - break - line.append(c) - if c == '\n': - l = ''.join(line) - buf.append(l) - if REBOOT_MARKER in l: - boot += 1 - line = [] - if boot > max_boot_count: - break - - l = ''.join(line) - buf.append(l) - result = ''.join(buf) - - return result, boot - - def identify_tty_port(self): - """Saves this board's serial port.""" - dev_dir = '/dev' - id_prefix = 'ID_SERIAL_SHORT=' - com_devices = [f for f in os.listdir(dev_dir) if f.startswith('ttyACM')] - - for device in com_devices: - self.tty_port = os.path.join(dev_dir, device) - properties = sp.check_output( - ['udevadm', 'info', '-a', '-n', self.tty_port, '--query=property'], - **get_subprocess_args()) - for line in [l.strip() for l in properties.split('\n')]: - if line.startswith(id_prefix): - if self.hla_serial == line[len(id_prefix):]: - return + for device in com_devices: + self.tty_port = os.path.join(dev_dir, device) + properties = sp.check_output( + ["udevadm", "info", "-a", "-n", self.tty_port, "--query=property"], + **get_subprocess_args() + ) + for line in [l.strip() for l in properties.split("\n")]: + if line.startswith(id_prefix): + if self.hla_serial == line[len(id_prefix) :]: + return - # If we get here without returning, something is wrong - raise RuntimeError('The device dev path could not be found') + # If we get here without returning, something is wrong + raise RuntimeError("The device dev path could not be found") - def close_tty(self): - """Close tty.""" - self.tty.close() + def close_tty(self): + """Close tty.""" + self.tty.close() class TestHarness(Board): - """Subclass of Board representing a Test Harness. + """Subclass of Board representing a Test Harness. - Attributes: - serial_path: Path to file containing serial number - """ - - def __init__(self, board, module, log_dir, serial_path): - """Initializes a board object with given attributes. - - Args: - board: board name - module: module name - log_dir: Directory where log file is stored + Attributes: serial_path: Path to file containing serial number """ - Board.__init__(self, board, module) - self.log_dir = log_dir - self.openocd_log = os.path.join(log_dir, 'openocd_th.log') - self.build_log = os.path.join(log_dir, 'build_th.log') - self.serial_path = serial_path - self.reset_log_dir() - - def get_serial(self): - """Loads serial number from saved location.""" - if self.hla_serial: - return # serial was already loaded - try: - with open(self.serial_path, mode='r') as f: - s = f.read() - self.hla_serial = s.strip() + + def __init__(self, board, module, log_dir, serial_path): + """Initializes a board object with given attributes. + + Args: + board: board name + module: module name + log_dir: Directory where log file is stored + serial_path: Path to file containing serial number + """ + Board.__init__(self, board, module) + self.log_dir = log_dir + self.openocd_log = os.path.join(log_dir, "openocd_th.log") + self.build_log = os.path.join(log_dir, "build_th.log") + self.serial_path = serial_path + self.reset_log_dir() + + def get_serial(self): + """Loads serial number from saved location.""" + if self.hla_serial: + return # serial was already loaded + try: + with open(self.serial_path, mode="r") as f: + s = f.read() + self.hla_serial = s.strip() + return + except IOError: + msg = ( + "Your TH board has not been identified.\n" + "Connect only TH and run the script --setup, then try again." + ) + raise RuntimeError(msg) + + def save_serial(self): + """Saves the TH serial number to a file.""" + serials = Board.get_stlink_serials() + if len(serials) > 1: + msg = ( + "There are more than one test board connected to the host." + "\nConnect only the test harness and remove other boards." + ) + raise RuntimeError(msg) + if len(serials) < 1: + msg = "No test boards were found.\n" "Check boards are connected." + raise RuntimeError(msg) + + s = serials[0] + serial_dir = os.path.dirname(self.serial_path) + if not os.path.exists(serial_dir): + os.makedirs(serial_dir) + with open(self.serial_path, mode="w") as f: + f.write(s) + self.hla_serial = s + + print("Your TH serial", s, "has been saved as", self.serial_path) return - except IOError: - msg = ('Your TH board has not been identified.\n' - 'Connect only TH and run the script --setup, then try again.') - raise RuntimeError(msg) - - def save_serial(self): - """Saves the TH serial number to a file.""" - serials = Board.get_stlink_serials() - if len(serials) > 1: - msg = ('There are more than one test board connected to the host.' - '\nConnect only the test harness and remove other boards.') - raise RuntimeError(msg) - if len(serials) < 1: - msg = ('No test boards were found.\n' - 'Check boards are connected.') - raise RuntimeError(msg) - - s = serials[0] - serial_dir = os.path.dirname(self.serial_path) - if not os.path.exists(serial_dir): - os.makedirs(serial_dir) - with open(self.serial_path, mode='w') as f: - f.write(s) - self.hla_serial = s - - print('Your TH serial', s, 'has been saved as', self.serial_path) - return class DeviceUnderTest(Board): - """Subclass of Board representing a DUT board. + """Subclass of Board representing a DUT board. - Attributes: - th: Reference to test harness board to which this DUT is attached - """ - - def __init__(self, board, th, module, log_dir, hla_ser=None): - """Initializes a DUT object. - - Args: - board: String containing board name + Attributes: th: Reference to test harness board to which this DUT is attached - module: module name - log_dir: Directory where log file is stored - hla_ser: Serial number if board uses an HLA adaptor """ - Board.__init__(self, board, module, hla_serial=hla_ser) - self.th = th - self.log_dir = log_dir - self.openocd_log = os.path.join(log_dir, 'openocd_dut.log') - self.build_log = os.path.join(log_dir, 'build_dut.log') - self.reset_log_dir() - - def get_serial(self): - """Get serial number. - Precondition: The DUT and TH must both be connected, and th.hla_serial - must hold the correct value (the th's serial #) + def __init__(self, board, th, module, log_dir, hla_ser=None): + """Initializes a DUT object. + + Args: + board: String containing board name + th: Reference to test harness board to which this DUT is attached + module: module name + log_dir: Directory where log file is stored + hla_ser: Serial number if board uses an HLA adaptor + """ + Board.__init__(self, board, module, hla_serial=hla_ser) + self.th = th + self.log_dir = log_dir + self.openocd_log = os.path.join(log_dir, "openocd_dut.log") + self.build_log = os.path.join(log_dir, "build_dut.log") + self.reset_log_dir() + + def get_serial(self): + """Get serial number. + + Precondition: The DUT and TH must both be connected, and th.hla_serial + must hold the correct value (the th's serial #) + + Raises: + RuntimeError: DUT isn't found or multiple DUTs are found. + """ + if self.hla_serial is not None: + # serial was already set ('' is a valid serial) + return - Raises: - RuntimeError: DUT isn't found or multiple DUTs are found. - """ - if self.hla_serial is not None: - # serial was already set ('' is a valid serial) - return - - serials = Board.get_stlink_serials() - dut = [s for s in serials if self.th.hla_serial != s] - - # If len(dut) is 0 then your dut doesn't use an st-link device, so we - # don't have to worry about its serial number - if not dut: - msg = ('Failed to find serial for DUT.\n' - 'Is ' + self.board + ' connected?') - raise RuntimeError(msg) - if len(dut) > 1: - msg = ('Found multiple DUTs.\n' - 'You can connect only one DUT at a time. This may be caused by\n' - 'an incorrect TH serial. Check if ' + self.th.serial_path + '\n' - 'contains a correct serial.') - raise RuntimeError(msg) - - # Found your other st-link device serial! - self.hla_serial = dut[0] - return + serials = Board.get_stlink_serials() + dut = [s for s in serials if self.th.hla_serial != s] + + # If len(dut) is 0 then your dut doesn't use an st-link device, so we + # don't have to worry about its serial number + if not dut: + msg = "Failed to find serial for DUT.\n" "Is " + self.board + " connected?" + raise RuntimeError(msg) + if len(dut) > 1: + msg = ( + "Found multiple DUTs.\n" + "You can connect only one DUT at a time. This may be caused by\n" + "an incorrect TH serial. Check if " + self.th.serial_path + "\n" + "contains a correct serial." + ) + raise RuntimeError(msg) + + # Found your other st-link device serial! + self.hla_serial = dut[0] + return diff --git a/cts/cts.py b/cts/cts.py index c3e0335cab..ebc526c701 100755 --- a/cts/cts.py +++ b/cts/cts.py @@ -28,416 +28,424 @@ import argparse import os import shutil import time -import common.board as board +import common.board as board -CTS_RC_PREFIX = 'CTS_RC_' -DEFAULT_TH = 'stm32l476g-eval' -DEFAULT_DUT = 'nucleo-f072rb' +CTS_RC_PREFIX = "CTS_RC_" +DEFAULT_TH = "stm32l476g-eval" +DEFAULT_DUT = "nucleo-f072rb" MAX_SUITE_TIME_SEC = 5 -CTS_TEST_RESULT_DIR = '/tmp/ects' +CTS_TEST_RESULT_DIR = "/tmp/ects" # Host only return codes. Make sure they match values in cts.rc -CTS_RC_DID_NOT_START = -1 # test did not run. -CTS_RC_DID_NOT_END = -2 # test did not run. -CTS_RC_DUPLICATE_RUN = -3 # test was run multiple times. -CTS_RC_INVALID_RETURN_CODE = -4 # failed to parse return code +CTS_RC_DID_NOT_START = -1 # test did not run. +CTS_RC_DID_NOT_END = -2 # test did not run. +CTS_RC_DUPLICATE_RUN = -3 # test was run multiple times. +CTS_RC_INVALID_RETURN_CODE = -4 # failed to parse return code class Cts(object): - """Class that represents a eCTS run. - - Attributes: - dut: DeviceUnderTest object representing DUT - th: TestHarness object representing a test harness - module: Name of module to build/run tests for - testlist: List of strings of test names contained in given module - return_codes: Dict of strings of return codes, with a code's integer - value being the index for the corresponding string representation - """ - - def __init__(self, ec_dir, th, dut, module): - """Initializes cts class object with given arguments. - - Args: - ec_dir: Path to ec directory - th: Name of the test harness board - dut: Name of the device under test board - module: Name of module to build/run tests for (e.g. gpio, interrupt) - """ - self.results_dir = os.path.join(CTS_TEST_RESULT_DIR, dut, module) - if os.path.isdir(self.results_dir): - shutil.rmtree(self.results_dir) - else: - os.makedirs(self.results_dir) - self.ec_dir = ec_dir - self.module = module - serial_path = os.path.join(CTS_TEST_RESULT_DIR, 'th_serial') - self.th = board.TestHarness(th, module, self.results_dir, serial_path) - self.dut = board.DeviceUnderTest(dut, self.th, module, self.results_dir) - cts_dir = os.path.join(self.ec_dir, 'cts') - testlist_path = os.path.join(cts_dir, self.module, 'cts.testlist') - return_codes_path = os.path.join(cts_dir, 'common', 'cts.rc') - self.get_return_codes(return_codes_path) - self.testlist = self.get_macro_args(testlist_path, 'CTS_TEST') - - def build(self): - """Build images for DUT and TH.""" - print('Building DUT image...') - if not self.dut.build(self.ec_dir): - raise RuntimeError('Building module %s for DUT failed' % (self.module)) - print('Building TH image...') - if not self.th.build(self.ec_dir): - raise RuntimeError('Building module %s for TH failed' % (self.module)) - - def flash_boards(self): - """Flashes TH and DUT with their most recently built ec.bin.""" - cts_module = 'cts_' + self.module - image_path = os.path.join('build', self.th.board, cts_module, 'ec.bin') - self.identify_boards() - print('Flashing TH with', image_path) - if not self.th.flash(image_path): - raise RuntimeError('Flashing TH failed') - image_path = os.path.join('build', self.dut.board, cts_module, 'ec.bin') - print('Flashing DUT with', image_path) - if not self.dut.flash(image_path): - raise RuntimeError('Flashing DUT failed') - - def setup(self): - """Setup boards.""" - self.th.save_serial() - - def identify_boards(self): - """Updates serials of TH and DUT in that order (order matters).""" - self.th.get_serial() - self.dut.get_serial() - - def get_macro_args(self, filepath, macro): - """Get list of args of a macro in a file when macro. - - Args: - filepath: String containing absolute path to the file - macro: String containing text of macro to get args of - - Returns: - List of dictionaries where each entry is: - 'name': Test name, - 'th_string': Expected string from TH, - 'dut_string': Expected string from DUT, - """ - tests = [] - with open(filepath, 'r') as f: - lines = f.readlines() - joined = ''.join(lines).replace('\\\n', '').splitlines() - for l in joined: - if not l.strip().startswith(macro): - continue - d = {} - l = l.strip()[len(macro):] - l = l.strip('()').split(',') - d['name'] = l[0].strip() - d['th_rc'] = self.get_return_code_value(l[1].strip().strip('"')) - d['th_string'] = l[2].strip().strip('"') - d['dut_rc'] = self.get_return_code_value(l[3].strip().strip('"')) - d['dut_string'] = l[4].strip().strip('"') - tests.append(d) - return tests - - def get_return_codes(self, filepath): - """Read return code names from the return code definition file.""" - self.return_codes = {} - val = 0 - with open(filepath, 'r') as f: - for line in f: - line = line.strip() - if not line.startswith(CTS_RC_PREFIX): - continue - line = line.split(',')[0] - if '=' in line: - tokens = line.split('=') - line = tokens[0].strip() - val = int(tokens[1].strip()) - self.return_codes[line] = val - val += 1 - - def parse_output(self, output): - """Parse console output from DUT or TH. - - Args: - output: String containing consoule output - - Returns: - List of dictionaries where each key and value are: - name = 'ects_test_x', - started = True/False, - ended = True/False, - rc = CTS_RC_*, - output = All text between 'ects_test_x start' and 'ects_test_x end' - """ - results = [] - i = 0 - for test in self.testlist: - results.append({}) - results[i]['name'] = test['name'] - results[i]['started'] = False - results[i]['rc'] = CTS_RC_DID_NOT_START - results[i]['string'] = False - results[i]['output'] = [] - i += 1 - - i = 0 - for ln in [ln.strip() for ln in output.split('\n')]: - if i + 1 > len(results): - break - tokens = ln.split() - if len(tokens) >= 2: - if tokens[0].strip() == results[i]['name']: - if tokens[1].strip() == 'start': - # start line found - if results[i]['started']: # Already started - results[i]['rc'] = CTS_RC_DUPLICATE_RUN - else: - results[i]['rc'] = CTS_RC_DID_NOT_END - results[i]['started'] = True - continue - elif results[i]['started'] and tokens[1].strip() == 'end': - # end line found - results[i]['rc'] = CTS_RC_INVALID_RETURN_CODE - if len(tokens) == 3: - try: - results[i]['rc'] = int(tokens[2].strip()) - except ValueError: - pass - # Since index is incremented when 'end' is encountered, we don't - # need to check duplicate 'end'. - i += 1 - continue - if results[i]['started']: - results[i]['output'].append(ln) - - return results - - def get_return_code_name(self, code, strip_prefix=False): - name = '' - for k, v in self.return_codes.items(): - if v == code: - if strip_prefix: - name = k[len(CTS_RC_PREFIX):] - else: - name = k - return name - - def get_return_code_value(self, name): - if name: - return self.return_codes[name] - return 0 - - def evaluate_run(self, dut_output, th_output): - """Parse outputs to derive test results. - - Args: - dut_output: String output of DUT - th_output: String output of TH - - Returns: - th_results: list of test results for TH - dut_results: list of test results for DUT + """Class that represents a eCTS run. + + Attributes: + dut: DeviceUnderTest object representing DUT + th: TestHarness object representing a test harness + module: Name of module to build/run tests for + testlist: List of strings of test names contained in given module + return_codes: Dict of strings of return codes, with a code's integer + value being the index for the corresponding string representation """ - th_results = self.parse_output(th_output) - dut_results = self.parse_output(dut_output) - # Search for expected string in each output - for i, v in enumerate(self.testlist): - if v['th_string'] in th_results[i]['output'] or not v['th_string']: - th_results[i]['string'] = True - if v['dut_string'] in dut_results[i]['output'] or not v['dut_string']: - dut_results[i]['string'] = True + def __init__(self, ec_dir, th, dut, module): + """Initializes cts class object with given arguments. + + Args: + ec_dir: Path to ec directory + th: Name of the test harness board + dut: Name of the device under test board + module: Name of module to build/run tests for (e.g. gpio, interrupt) + """ + self.results_dir = os.path.join(CTS_TEST_RESULT_DIR, dut, module) + if os.path.isdir(self.results_dir): + shutil.rmtree(self.results_dir) + else: + os.makedirs(self.results_dir) + self.ec_dir = ec_dir + self.module = module + serial_path = os.path.join(CTS_TEST_RESULT_DIR, "th_serial") + self.th = board.TestHarness(th, module, self.results_dir, serial_path) + self.dut = board.DeviceUnderTest(dut, self.th, module, self.results_dir) + cts_dir = os.path.join(self.ec_dir, "cts") + testlist_path = os.path.join(cts_dir, self.module, "cts.testlist") + return_codes_path = os.path.join(cts_dir, "common", "cts.rc") + self.get_return_codes(return_codes_path) + self.testlist = self.get_macro_args(testlist_path, "CTS_TEST") + + def build(self): + """Build images for DUT and TH.""" + print("Building DUT image...") + if not self.dut.build(self.ec_dir): + raise RuntimeError("Building module %s for DUT failed" % (self.module)) + print("Building TH image...") + if not self.th.build(self.ec_dir): + raise RuntimeError("Building module %s for TH failed" % (self.module)) + + def flash_boards(self): + """Flashes TH and DUT with their most recently built ec.bin.""" + cts_module = "cts_" + self.module + image_path = os.path.join("build", self.th.board, cts_module, "ec.bin") + self.identify_boards() + print("Flashing TH with", image_path) + if not self.th.flash(image_path): + raise RuntimeError("Flashing TH failed") + image_path = os.path.join("build", self.dut.board, cts_module, "ec.bin") + print("Flashing DUT with", image_path) + if not self.dut.flash(image_path): + raise RuntimeError("Flashing DUT failed") + + def setup(self): + """Setup boards.""" + self.th.save_serial() + + def identify_boards(self): + """Updates serials of TH and DUT in that order (order matters).""" + self.th.get_serial() + self.dut.get_serial() + + def get_macro_args(self, filepath, macro): + """Get list of args of a macro in a file when macro. + + Args: + filepath: String containing absolute path to the file + macro: String containing text of macro to get args of + + Returns: + List of dictionaries where each entry is: + 'name': Test name, + 'th_string': Expected string from TH, + 'dut_string': Expected string from DUT, + """ + tests = [] + with open(filepath, "r") as f: + lines = f.readlines() + joined = "".join(lines).replace("\\\n", "").splitlines() + for l in joined: + if not l.strip().startswith(macro): + continue + d = {} + l = l.strip()[len(macro) :] + l = l.strip("()").split(",") + d["name"] = l[0].strip() + d["th_rc"] = self.get_return_code_value(l[1].strip().strip('"')) + d["th_string"] = l[2].strip().strip('"') + d["dut_rc"] = self.get_return_code_value(l[3].strip().strip('"')) + d["dut_string"] = l[4].strip().strip('"') + tests.append(d) + return tests + + def get_return_codes(self, filepath): + """Read return code names from the return code definition file.""" + self.return_codes = {} + val = 0 + with open(filepath, "r") as f: + for line in f: + line = line.strip() + if not line.startswith(CTS_RC_PREFIX): + continue + line = line.split(",")[0] + if "=" in line: + tokens = line.split("=") + line = tokens[0].strip() + val = int(tokens[1].strip()) + self.return_codes[line] = val + val += 1 + + def parse_output(self, output): + """Parse console output from DUT or TH. + + Args: + output: String containing consoule output + + Returns: + List of dictionaries where each key and value are: + name = 'ects_test_x', + started = True/False, + ended = True/False, + rc = CTS_RC_*, + output = All text between 'ects_test_x start' and 'ects_test_x end' + """ + results = [] + i = 0 + for test in self.testlist: + results.append({}) + results[i]["name"] = test["name"] + results[i]["started"] = False + results[i]["rc"] = CTS_RC_DID_NOT_START + results[i]["string"] = False + results[i]["output"] = [] + i += 1 - return th_results, dut_results + i = 0 + for ln in [ln.strip() for ln in output.split("\n")]: + if i + 1 > len(results): + break + tokens = ln.split() + if len(tokens) >= 2: + if tokens[0].strip() == results[i]["name"]: + if tokens[1].strip() == "start": + # start line found + if results[i]["started"]: # Already started + results[i]["rc"] = CTS_RC_DUPLICATE_RUN + else: + results[i]["rc"] = CTS_RC_DID_NOT_END + results[i]["started"] = True + continue + elif results[i]["started"] and tokens[1].strip() == "end": + # end line found + results[i]["rc"] = CTS_RC_INVALID_RETURN_CODE + if len(tokens) == 3: + try: + results[i]["rc"] = int(tokens[2].strip()) + except ValueError: + pass + # Since index is incremented when 'end' is encountered, we don't + # need to check duplicate 'end'. + i += 1 + continue + if results[i]["started"]: + results[i]["output"].append(ln) + + return results + + def get_return_code_name(self, code, strip_prefix=False): + name = "" + for k, v in self.return_codes.items(): + if v == code: + if strip_prefix: + name = k[len(CTS_RC_PREFIX) :] + else: + name = k + return name + + def get_return_code_value(self, name): + if name: + return self.return_codes[name] + return 0 + + def evaluate_run(self, dut_output, th_output): + """Parse outputs to derive test results. + + Args: + dut_output: String output of DUT + th_output: String output of TH + + Returns: + th_results: list of test results for TH + dut_results: list of test results for DUT + """ + th_results = self.parse_output(th_output) + dut_results = self.parse_output(dut_output) + + # Search for expected string in each output + for i, v in enumerate(self.testlist): + if v["th_string"] in th_results[i]["output"] or not v["th_string"]: + th_results[i]["string"] = True + if v["dut_string"] in dut_results[i]["output"] or not v["dut_string"]: + dut_results[i]["string"] = True + + return th_results, dut_results + + def print_result(self, th_results, dut_results): + """Print results to the screen. + + Args: + th_results: list of test results for TH + dut_results: list of test results for DUT + """ + len_test_name = max(len(s["name"]) for s in self.testlist) + len_code_name = max( + len(self.get_return_code_name(v, True)) for v in self.return_codes.values() + ) + + head = "{:^" + str(len_test_name) + "} " + head += "{:^" + str(len_code_name) + "} " + head += "{:^" + str(len_code_name) + "}" + head += "{:^" + str(len(" TH_STR")) + "}" + head += "{:^" + str(len(" DUT_STR")) + "}" + head += "{:^" + str(len(" RESULT")) + "}\n" + fmt = "{:" + str(len_test_name) + "} " + fmt += "{:>" + str(len_code_name) + "} " + fmt += "{:>" + str(len_code_name) + "}" + fmt += "{:>" + str(len(" TH_STR")) + "}" + fmt += "{:>" + str(len(" DUT_STR")) + "}" + fmt += "{:>" + str(len(" RESULT")) + "}\n" + + self.formatted_results = head.format( + "TEST NAME", "TH_RC", "DUT_RC", " TH_STR", " DUT_STR", " RESULT" + ) + for i, d in enumerate(dut_results): + th_cn = self.get_return_code_name(th_results[i]["rc"], True) + dut_cn = self.get_return_code_name(dut_results[i]["rc"], True) + th_res = self.evaluate_result( + th_results[i], self.testlist[i]["th_rc"], self.testlist[i]["th_string"] + ) + dut_res = self.evaluate_result( + dut_results[i], + self.testlist[i]["dut_rc"], + self.testlist[i]["dut_string"], + ) + self.formatted_results += fmt.format( + d["name"], + th_cn, + dut_cn, + "YES" if th_results[i]["string"] else "NO", + "YES" if dut_results[i]["string"] else "NO", + "PASS" if th_res and dut_res else "FAIL", + ) + + def evaluate_result(self, result, expected_rc, expected_string): + if result["rc"] != expected_rc: + return False + if expected_string and expected_string not in result["output"]: + return False + return True + + def run(self): + """Resets boards, records test results in results dir.""" + print("Reading serials...") + self.identify_boards() + print("Opening DUT tty...") + self.dut.setup_tty() + print("Opening TH tty...") + self.th.setup_tty() + + # Boards might be still writing to tty. Wait a few seconds before flashing. + time.sleep(3) + + # clear buffers + print("Clearing DUT tty...") + self.dut.read_tty() + print("Clearing TH tty...") + self.th.read_tty() + + # Resets the boards and allows them to run tests + # Due to current (7/27/16) version of sync function, + # both boards must be rest and halted, with the th + # resuming first, in order for the test suite to run in sync + print("Halting TH...") + if not self.th.reset_halt(): + raise RuntimeError("Failed to halt TH") + print("Halting DUT...") + if not self.dut.reset_halt(): + raise RuntimeError("Failed to halt DUT") + print("Resuming TH...") + if not self.th.resume(): + raise RuntimeError("Failed to resume TH") + print("Resuming DUT...") + if not self.dut.resume(): + raise RuntimeError("Failed to resume DUT") + + time.sleep(MAX_SUITE_TIME_SEC) + + print("Reading DUT tty...") + dut_output, _ = self.dut.read_tty() + self.dut.close_tty() + print("Reading TH tty...") + th_output, _ = self.th.read_tty() + self.th.close_tty() + + print("Halting TH...") + if not self.th.reset_halt(): + raise RuntimeError("Failed to halt TH") + print("Halting DUT...") + if not self.dut.reset_halt(): + raise RuntimeError("Failed to halt DUT") + + if not dut_output or not th_output: + raise ValueError( + "Output missing from boards. If you have a process " + "reading ttyACMx, please kill that process and try " + "again." + ) + + print("Pursing results...") + th_results, dut_results = self.evaluate_run(dut_output, th_output) + + # Print out results + self.print_result(th_results, dut_results) + + # Write results + dest = os.path.join(self.results_dir, "results.log") + with open(dest, "w") as fl: + fl.write(self.formatted_results) + + # Write UART outputs + dest = os.path.join(self.results_dir, "uart_th.log") + with open(dest, "w") as fl: + fl.write(th_output) + dest = os.path.join(self.results_dir, "uart_dut.log") + with open(dest, "w") as fl: + fl.write(dut_output) + + print(self.formatted_results) + + # TODO(chromium:735652): Should set exit code for the shell - def print_result(self, th_results, dut_results): - """Print results to the screen. - Args: - th_results: list of test results for TH - dut_results: list of test results for DUT - """ - len_test_name = max(len(s['name']) for s in self.testlist) - len_code_name = max(len(self.get_return_code_name(v, True)) - for v in self.return_codes.values()) - - head = '{:^' + str(len_test_name) + '} ' - head += '{:^' + str(len_code_name) + '} ' - head += '{:^' + str(len_code_name) + '}' - head += '{:^' + str(len(' TH_STR')) + '}' - head += '{:^' + str(len(' DUT_STR')) + '}' - head += '{:^' + str(len(' RESULT')) + '}\n' - fmt = '{:' + str(len_test_name) + '} ' - fmt += '{:>' + str(len_code_name) + '} ' - fmt += '{:>' + str(len_code_name) + '}' - fmt += '{:>' + str(len(' TH_STR')) + '}' - fmt += '{:>' + str(len(' DUT_STR')) + '}' - fmt += '{:>' + str(len(' RESULT')) + '}\n' - - self.formatted_results = head.format( - 'TEST NAME', 'TH_RC', 'DUT_RC', - ' TH_STR', ' DUT_STR', ' RESULT') - for i, d in enumerate(dut_results): - th_cn = self.get_return_code_name(th_results[i]['rc'], True) - dut_cn = self.get_return_code_name(dut_results[i]['rc'], True) - th_res = self.evaluate_result(th_results[i], - self.testlist[i]['th_rc'], - self.testlist[i]['th_string']) - dut_res = self.evaluate_result(dut_results[i], - self.testlist[i]['dut_rc'], - self.testlist[i]['dut_string']) - self.formatted_results += fmt.format( - d['name'], th_cn, dut_cn, - 'YES' if th_results[i]['string'] else 'NO', - 'YES' if dut_results[i]['string'] else 'NO', - 'PASS' if th_res and dut_res else 'FAIL') - - def evaluate_result(self, result, expected_rc, expected_string): - if result['rc'] != expected_rc: - return False - if expected_string and expected_string not in result['output']: - return False - return True - - def run(self): - """Resets boards, records test results in results dir.""" - print('Reading serials...') - self.identify_boards() - print('Opening DUT tty...') - self.dut.setup_tty() - print('Opening TH tty...') - self.th.setup_tty() - - # Boards might be still writing to tty. Wait a few seconds before flashing. - time.sleep(3) - - # clear buffers - print('Clearing DUT tty...') - self.dut.read_tty() - print('Clearing TH tty...') - self.th.read_tty() - - # Resets the boards and allows them to run tests - # Due to current (7/27/16) version of sync function, - # both boards must be rest and halted, with the th - # resuming first, in order for the test suite to run in sync - print('Halting TH...') - if not self.th.reset_halt(): - raise RuntimeError('Failed to halt TH') - print('Halting DUT...') - if not self.dut.reset_halt(): - raise RuntimeError('Failed to halt DUT') - print('Resuming TH...') - if not self.th.resume(): - raise RuntimeError('Failed to resume TH') - print('Resuming DUT...') - if not self.dut.resume(): - raise RuntimeError('Failed to resume DUT') - - time.sleep(MAX_SUITE_TIME_SEC) - - print('Reading DUT tty...') - dut_output, _ = self.dut.read_tty() - self.dut.close_tty() - print('Reading TH tty...') - th_output, _ = self.th.read_tty() - self.th.close_tty() - - print('Halting TH...') - if not self.th.reset_halt(): - raise RuntimeError('Failed to halt TH') - print('Halting DUT...') - if not self.dut.reset_halt(): - raise RuntimeError('Failed to halt DUT') - - if not dut_output or not th_output: - raise ValueError('Output missing from boards. If you have a process ' - 'reading ttyACMx, please kill that process and try ' - 'again.') - - print('Pursing results...') - th_results, dut_results = self.evaluate_run(dut_output, th_output) - - # Print out results - self.print_result(th_results, dut_results) - - # Write results - dest = os.path.join(self.results_dir, 'results.log') - with open(dest, 'w') as fl: - fl.write(self.formatted_results) - - # Write UART outputs - dest = os.path.join(self.results_dir, 'uart_th.log') - with open(dest, 'w') as fl: - fl.write(th_output) - dest = os.path.join(self.results_dir, 'uart_dut.log') - with open(dest, 'w') as fl: - fl.write(dut_output) - - print(self.formatted_results) - - # TODO(chromium:735652): Should set exit code for the shell +def main(): + ec_dir = os.path.realpath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") + ) + os.chdir(ec_dir) + + dut = DEFAULT_DUT + module = "meta" + + parser = argparse.ArgumentParser(description="Used to build/flash boards") + parser.add_argument("-d", "--dut", help="Specify DUT you want to build/flash") + parser.add_argument("-m", "--module", help="Specify module you want to build/flash") + parser.add_argument( + "-s", + "--setup", + action="store_true", + help="Connect only the TH to save its serial", + ) + parser.add_argument( + "-b", "--build", action="store_true", help="Build test suite (no flashing)" + ) + parser.add_argument( + "-f", + "--flash", + action="store_true", + help="Flash boards with most recent images", + ) + parser.add_argument( + "-r", "--run", action="store_true", help="Run tests without flashing" + ) + + args = parser.parse_args() + + if args.module: + module = args.module + + if args.dut: + dut = args.dut + + cts = Cts(ec_dir, DEFAULT_TH, dut=dut, module=module) + + if args.setup: + cts.setup() + elif args.build: + cts.build() + elif args.flash: + cts.flash_boards() + elif args.run: + cts.run() + else: + cts.build() + cts.flash_boards() + cts.run() -def main(): - ec_dir = os.path.realpath(os.path.join( - os.path.dirname(os.path.abspath(__file__)), '..')) - os.chdir(ec_dir) - - dut = DEFAULT_DUT - module = 'meta' - - parser = argparse.ArgumentParser(description='Used to build/flash boards') - parser.add_argument('-d', - '--dut', - help='Specify DUT you want to build/flash') - parser.add_argument('-m', - '--module', - help='Specify module you want to build/flash') - parser.add_argument('-s', - '--setup', - action='store_true', - help='Connect only the TH to save its serial') - parser.add_argument('-b', - '--build', - action='store_true', - help='Build test suite (no flashing)') - parser.add_argument('-f', - '--flash', - action='store_true', - help='Flash boards with most recent images') - parser.add_argument('-r', - '--run', - action='store_true', - help='Run tests without flashing') - - args = parser.parse_args() - - if args.module: - module = args.module - - if args.dut: - dut = args.dut - - cts = Cts(ec_dir, DEFAULT_TH, dut=dut, module=module) - - if args.setup: - cts.setup() - elif args.build: - cts.build() - elif args.flash: - cts.flash_boards() - elif args.run: - cts.run() - else: - cts.build() - cts.flash_boards() - cts.run() - -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/extra/cr50_rma_open/cr50_rma_open.py b/extra/cr50_rma_open/cr50_rma_open.py index 42ddbbac2d..b77b8f3dbb 100755 --- a/extra/cr50_rma_open/cr50_rma_open.py +++ b/extra/cr50_rma_open/cr50_rma_open.py @@ -57,16 +57,17 @@ CCD_IS_UNRESTRICTED = 1 << 0 WP_IS_DISABLED = 1 << 1 TESTLAB_IS_ENABLED = 1 << 2 RMA_OPENED = CCD_IS_UNRESTRICTED | WP_IS_DISABLED -URL = ('https://www.google.com/chromeos/partner/console/cr50reset?' - 'challenge=%s&hwid=%s') -RMA_SUPPORT_PROD = '0.3.3' -RMA_SUPPORT_PREPVT = '0.4.5' -DEV_MODE_OPEN_PROD = '0.3.9' -DEV_MODE_OPEN_PREPVT = '0.4.7' -TESTLAB_PROD = '0.3.10' -CR50_USB = '18d1:5014' -CR50_LSUSB_CMD = ['lsusb', '-vd', CR50_USB] -ERASED_BID = 'ffffffff' +URL = ( + "https://www.google.com/chromeos/partner/console/cr50reset?" "challenge=%s&hwid=%s" +) +RMA_SUPPORT_PROD = "0.3.3" +RMA_SUPPORT_PREPVT = "0.4.5" +DEV_MODE_OPEN_PROD = "0.3.9" +DEV_MODE_OPEN_PREPVT = "0.4.7" +TESTLAB_PROD = "0.3.10" +CR50_USB = "18d1:5014" +CR50_LSUSB_CMD = ["lsusb", "-vd", CR50_USB] +ERASED_BID = "ffffffff" DEBUG_MISSING_USB = """ Unable to find Cr50 Device 18d1:5014 @@ -128,13 +129,14 @@ DEBUG_DUT_CONTROL_OSERROR = """ Run from chroot if you are trying to use a /dev/pts ccd servo console """ + class RMAOpen(object): """Used to find the cr50 console and run RMA open""" - ENABLE_TESTLAB_CMD = 'ccd testlab enabled\n' + ENABLE_TESTLAB_CMD = "ccd testlab enabled\n" def __init__(self, device=None, usb_serial=None, servo_port=None, ip=None): - self.servo_port = servo_port if servo_port else '9999' + self.servo_port = servo_port if servo_port else "9999" self.ip = ip if device: self.set_cr50_device(device) @@ -142,18 +144,18 @@ class RMAOpen(object): self.find_cr50_servo_uart() else: self.find_cr50_device(usb_serial) - logging.info('DEVICE: %s', self.device) + logging.info("DEVICE: %s", self.device) self.check_version() self.print_platform_info() - logging.info('Cr50 setup ok') + logging.info("Cr50 setup ok") self.update_ccd_state() self.using_ccd = self.device_is_running_with_servo_ccd() def _dut_control(self, control): """Run dut-control and return the response""" try: - cmd = ['dut-control', '-p', self.servo_port, control] - return subprocess.check_output(cmd, encoding='utf-8').strip() + cmd = ["dut-control", "-p", self.servo_port, control] + return subprocess.check_output(cmd, encoding="utf-8").strip() except OSError: logging.warning(DEBUG_DUT_CONTROL_OSERROR) raise @@ -163,8 +165,8 @@ class RMAOpen(object): Find the console and configure it, so it can be used with this script. """ - self._dut_control('cr50_uart_timestamp:off') - self.device = self._dut_control('cr50_uart_pty').split(':')[-1] + self._dut_control("cr50_uart_timestamp:off") + self.device = self._dut_control("cr50_uart_pty").split(":")[-1] def set_cr50_device(self, device): """Save the device used for the console""" @@ -183,38 +185,38 @@ class RMAOpen(object): try: ser = serial.Serial(self.device, timeout=1) except OSError: - logging.warning('Permission denied %s', self.device) - logging.warning('Try running cr50_rma_open with sudo') + logging.warning("Permission denied %s", self.device) + logging.warning("Try running cr50_rma_open with sudo") raise - write_cmd = cmd + '\n\n' - ser.write(write_cmd.encode('utf-8')) + write_cmd = cmd + "\n\n" + ser.write(write_cmd.encode("utf-8")) if nbytes: output = ser.read(nbytes) else: output = ser.readall() ser.close() - output = output.decode('utf-8').strip() if output else '' + output = output.decode("utf-8").strip() if output else "" # Return only the command output - split_cmd = cmd + '\r' + split_cmd = cmd + "\r" if cmd and split_cmd in output: - return ''.join(output.rpartition(split_cmd)[1::]).split('>')[0] + return "".join(output.rpartition(split_cmd)[1::]).split(">")[0] return output def device_is_running_with_servo_ccd(self): """Return True if the device is a servod ccd console""" # servod uses /dev/pts consoles. Non-servod uses /dev/ttyUSBX - if '/dev/pts' not in self.device: + if "/dev/pts" not in self.device: return False # If cr50 doesn't show rdd is connected, cr50 the device must not be # a ccd device - if 'Rdd: connected' not in self.send_cmd_get_output('ccdstate'): + if "Rdd: connected" not in self.send_cmd_get_output("ccdstate"): return False # Check if the servod is running with ccd. This requires the script # is run in the chroot, so run it last. - if 'ccd_cr50' not in self._dut_control('servo_type'): + if "ccd_cr50" not in self._dut_control("servo_type"): return False - logging.info('running through servod ccd') + logging.info("running through servod ccd") return True def get_rma_challenge(self): @@ -239,14 +241,14 @@ class RMAOpen(object): Returns: The RMA challenge with all whitespace removed. """ - output = self.send_cmd_get_output('rma_auth').strip() - logging.info('rma_auth output:\n%s', output) + output = self.send_cmd_get_output("rma_auth").strip() + logging.info("rma_auth output:\n%s", output) # Extract the challenge from the console output - if 'generated challenge:' in output: - return output.split('generated challenge:')[-1].strip() - challenge = ''.join(re.findall(r' \S{5}' * 4, output)) + if "generated challenge:" in output: + return output.split("generated challenge:")[-1].strip() + challenge = "".join(re.findall(r" \S{5}" * 4, output)) # Remove all whitespace - return re.sub(r'\s', '', challenge) + return re.sub(r"\s", "", challenge) def generate_challenge_url(self, hwid): """Get the rma_auth challenge @@ -257,12 +259,14 @@ class RMAOpen(object): challenge = self.get_rma_challenge() self.print_platform_info() - logging.info('CHALLENGE: %s', challenge) - logging.info('HWID: %s', hwid) + logging.info("CHALLENGE: %s", challenge) + logging.info("HWID: %s", hwid) url = URL % (challenge, hwid) - logging.info('GOTO:\n %s', url) - logging.info('If the server fails to debug the challenge make sure the ' - 'RLZ is allowlisted') + logging.info("GOTO:\n %s", url) + logging.info( + "If the server fails to debug the challenge make sure the " + "RLZ is allowlisted" + ) def try_authcode(self, authcode): """Try opening cr50 with the authcode @@ -272,48 +276,48 @@ class RMAOpen(object): """ # rma_auth may cause the system to reboot. Don't wait to read all that # output. Read the first 300 bytes and call it a day. - output = self.send_cmd_get_output('rma_auth ' + authcode, nbytes=300) - logging.info('CR50 RESPONSE: %s', output) - logging.info('waiting for cr50 reboot') + output = self.send_cmd_get_output("rma_auth " + authcode, nbytes=300) + logging.info("CR50 RESPONSE: %s", output) + logging.info("waiting for cr50 reboot") # Cr50 may be rebooting. Wait a bit time.sleep(5) if self.using_ccd: # After reboot, reset the ccd endpoints - self._dut_control('power_state:ccd_reset') + self._dut_control("power_state:ccd_reset") # Update the ccd state after the authcode attempt self.update_ccd_state() - authcode_match = 'process_response: success!' in output + authcode_match = "process_response: success!" in output if not self.check(CCD_IS_UNRESTRICTED): if not authcode_match: logging.warning(DEBUG_AUTHCODE_MISMATCH) - message = 'Authcode mismatch. Check args and url' + message = "Authcode mismatch. Check args and url" else: - message = 'Could not set all capability privileges to Always' + message = "Could not set all capability privileges to Always" raise ValueError(message) def wp_is_force_disabled(self): """Returns True if write protect is forced disabled""" - output = self.send_cmd_get_output('wp') - wp_state = output.split('Flash WP:', 1)[-1].split('\n', 1)[0].strip() - logging.info('wp: %s', wp_state) - return wp_state == 'forced disabled' + output = self.send_cmd_get_output("wp") + wp_state = output.split("Flash WP:", 1)[-1].split("\n", 1)[0].strip() + logging.info("wp: %s", wp_state) + return wp_state == "forced disabled" def testlab_is_enabled(self): """Returns True if testlab mode is enabled""" - output = self.send_cmd_get_output('ccd testlab') - testlab_state = output.split('mode')[-1].strip().lower() - logging.info('testlab: %s', testlab_state) - return testlab_state == 'enabled' + output = self.send_cmd_get_output("ccd testlab") + testlab_state = output.split("mode")[-1].strip().lower() + logging.info("testlab: %s", testlab_state) + return testlab_state == "enabled" def ccd_is_restricted(self): """Returns True if any of the capabilities are still restricted""" - output = self.send_cmd_get_output('ccd') - if 'Capabilities' not in output: - raise ValueError('Could not get ccd output') - logging.debug('CURRENT CCD SETTINGS:\n%s', output) - restricted = 'IfOpened' in output or 'IfUnlocked' in output - logging.info('ccd: %srestricted', '' if restricted else 'Un') + output = self.send_cmd_get_output("ccd") + if "Capabilities" not in output: + raise ValueError("Could not get ccd output") + logging.debug("CURRENT CCD SETTINGS:\n%s", output) + restricted = "IfOpened" in output or "IfUnlocked" in output + logging.info("ccd: %srestricted", "" if restricted else "Un") return restricted def update_ccd_state(self): @@ -339,9 +343,10 @@ class RMAOpen(object): def _capabilities_allow_open_from_console(self): """Return True if ccd open is Always allowed from usb""" - output = self.send_cmd_get_output('ccd') - return (re.search('OpenNoDevMode.*Always', output) and - re.search('OpenFromUSB.*Always', output)) + output = self.send_cmd_get_output("ccd") + return re.search("OpenNoDevMode.*Always", output) and re.search( + "OpenFromUSB.*Always", output + ) def _requires_dev_mode_open(self): """Return True if the image requires dev mode to open""" @@ -354,78 +359,81 @@ class RMAOpen(object): def _run_on_dut(self, command): """Run the command on the DUT.""" - return subprocess.check_output(['ssh', self.ip, command], - encoding='utf-8') + return subprocess.check_output(["ssh", self.ip, command], encoding="utf-8") def _open_in_dev_mode(self): """Open Cr50 when it's in dev mode""" - output = self.send_cmd_get_output('ccd') + output = self.send_cmd_get_output("ccd") # If the device is already open, nothing needs to be done. - if 'State: Open' not in output: + if "State: Open" not in output: # Verify the device is in devmode before trying to run open. - if 'dev_mode' not in output: - logging.warning('Enter dev mode to open ccd or update to %s', - TESTLAB_PROD) - raise ValueError('DUT not in dev mode') + if "dev_mode" not in output: + logging.warning( + "Enter dev mode to open ccd or update to %s", TESTLAB_PROD + ) + raise ValueError("DUT not in dev mode") if not self.ip: - logging.warning("If your DUT doesn't have ssh support, run " - "'gsctool -a -o' from the AP") - raise ValueError('Cannot run ccd open without dut ip') - self._run_on_dut('gsctool -a -o') + logging.warning( + "If your DUT doesn't have ssh support, run " + "'gsctool -a -o' from the AP" + ) + raise ValueError("Cannot run ccd open without dut ip") + self._run_on_dut("gsctool -a -o") # Wait >1 second for cr50 to update ccd state time.sleep(3) - output = self.send_cmd_get_output('ccd') - if 'State: Open' not in output: - raise ValueError('Could not open cr50') - logging.info('ccd is open') + output = self.send_cmd_get_output("ccd") + if "State: Open" not in output: + raise ValueError("Could not open cr50") + logging.info("ccd is open") def enable_testlab(self): """Disable write protect""" if not self._has_testlab_support(): - logging.warning('Testlab mode is not supported in prod iamges') + logging.warning("Testlab mode is not supported in prod iamges") return # Some cr50 images need to be in dev mode before they can be opened. if self._requires_dev_mode_open(): self._open_in_dev_mode() else: - self.send_cmd_get_output('ccd open') - logging.info('Enabling testlab mode reqires pressing the power button.') - logging.info('Once the process starts keep tapping the power button ' - 'for 10 seconds.') + self.send_cmd_get_output("ccd open") + logging.info("Enabling testlab mode reqires pressing the power button.") + logging.info( + "Once the process starts keep tapping the power button " "for 10 seconds." + ) input("Press Enter when you're ready to start...") end_time = time.time() + 15 ser = serial.Serial(self.device, timeout=1) - printed_lines = '' - output = '' + printed_lines = "" + output = "" # start ccd testlab enable - ser.write(self.ENABLE_TESTLAB_CMD.encode('utf-8')) - logging.info('start pressing the power button\n\n') + ser.write(self.ENABLE_TESTLAB_CMD.encode("utf-8")) + logging.info("start pressing the power button\n\n") # Print all of the cr50 output as we get it, so the user will have more # information about pressing the power button. Tapping the power button # a couple of times should do it, but this will give us more confidence # the process is still running/worked. try: while time.time() < end_time: - output += ser.read(100).decode('utf-8') - full_lines = output.rsplit('\n', 1)[0] + output += ser.read(100).decode("utf-8") + full_lines = output.rsplit("\n", 1)[0] new_lines = full_lines if printed_lines: new_lines = full_lines.split(printed_lines, 1)[-1].strip() - logging.info('\n%s', new_lines) + logging.info("\n%s", new_lines) printed_lines = full_lines # Make sure the process hasn't ended. If it has, print the last # of the output and exit. new_lines = output.split(printed_lines, 1)[-1] - if 'CCD test lab mode enabled' in output: + if "CCD test lab mode enabled" in output: # print the last of the ou logging.info(new_lines) break - elif 'Physical presence check timeout' in output: + elif "Physical presence check timeout" in output: logging.info(new_lines) - logging.warning('Did not detect power button press in time') - raise ValueError('Could not enable testlab mode try again') + logging.warning("Did not detect power button press in time") + raise ValueError("Could not enable testlab mode try again") finally: ser.close() # Wait for the ccd hook to update things @@ -433,44 +441,48 @@ class RMAOpen(object): # Update the state after attempting to disable write protect self.update_ccd_state() if not self.check(TESTLAB_IS_ENABLED): - raise ValueError('Could not enable testlab mode try again') + raise ValueError("Could not enable testlab mode try again") def wp_disable(self): """Disable write protect""" - logging.info('Disabling write protect') - self.send_cmd_get_output('wp disable') + logging.info("Disabling write protect") + self.send_cmd_get_output("wp disable") # Update the state after attempting to disable write protect self.update_ccd_state() if not self.check(WP_IS_DISABLED): - raise ValueError('Could not disable write protect') + raise ValueError("Could not disable write protect") def check_version(self): """Make sure cr50 is running a version that supports RMA Open""" - output = self.send_cmd_get_output('version') + output = self.send_cmd_get_output("version") if not output.strip(): logging.warning(DEBUG_DEVICE, self.device) - raise ValueError('Could not communicate with %s' % self.device) + raise ValueError("Could not communicate with %s" % self.device) - version = re.search(r'RW.*\* ([\d\.]+)/', output).group(1) - logging.info('Running Cr50 Version: %s', version) - self.running_ver_fields = [int(field) for field in version.split('.')] + version = re.search(r"RW.*\* ([\d\.]+)/", output).group(1) + logging.info("Running Cr50 Version: %s", version) + self.running_ver_fields = [int(field) for field in version.split(".")] # prePVT images have even major versions. Prod have odd self.is_prepvt = self.running_ver_fields[1] % 2 == 0 rma_support = RMA_SUPPORT_PREPVT if self.is_prepvt else RMA_SUPPORT_PROD - logging.info('%s RMA support added in: %s', - 'prePVT' if self.is_prepvt else 'prod', rma_support) + logging.info( + "%s RMA support added in: %s", + "prePVT" if self.is_prepvt else "prod", + rma_support, + ) if not self.is_prepvt and self._running_version_is_older(TESTLAB_PROD): - raise ValueError('Update cr50. No testlab support in old prod ' - 'images.') + raise ValueError("Update cr50. No testlab support in old prod " "images.") if self._running_version_is_older(rma_support): - raise ValueError('%s does not have RMA support. Update to at ' - 'least %s' % (version, rma_support)) + raise ValueError( + "%s does not have RMA support. Update to at " + "least %s" % (version, rma_support) + ) def _running_version_is_older(self, target_ver): """Returns True if running version is older than target_ver.""" - target_ver_fields = [int(field) for field in target_ver.split('.')] + target_ver_fields = [int(field) for field in target_ver.split(".")] for i, field in enumerate(self.running_ver_fields): if field > int(target_ver_fields[i]): return False @@ -486,11 +498,11 @@ class RMAOpen(object): is no output or sysinfo doesn't contain the devid. """ self.set_cr50_device(device) - sysinfo = self.send_cmd_get_output('sysinfo') + sysinfo = self.send_cmd_get_output("sysinfo") # Make sure there is some output, and it shows it's from Cr50 - if not sysinfo or 'cr50' not in sysinfo: + if not sysinfo or "cr50" not in sysinfo: return False - logging.debug('Sysinfo output: %s', sysinfo) + logging.debug("Sysinfo output: %s", sysinfo) # The cr50 device id should be in the sysinfo output, if we found # the right console. Make sure it is return devid in sysinfo @@ -508,104 +520,137 @@ class RMAOpen(object): ValueError if the console can't be found with the given serialname """ usb_serial = self.find_cr50_usb(usb_serial) - logging.info('SERIALNAME: %s', usb_serial) - devid = '0x' + ' 0x'.join(usb_serial.lower().split('-')) - logging.info('DEVID: %s', devid) + logging.info("SERIALNAME: %s", usb_serial) + devid = "0x" + " 0x".join(usb_serial.lower().split("-")) + logging.info("DEVID: %s", devid) # Get all the usb devices - devices = glob.glob('/dev/ttyUSB*') + devices = glob.glob("/dev/ttyUSB*") # Typically Cr50 has the lowest number. Sort the devices, so we're more # likely to try the cr50 console first. devices.sort() # Find the one that is the cr50 console for device in devices: - logging.info('testing %s', device) + logging.info("testing %s", device) if self.device_matches_devid(devid, device): - logging.info('found device: %s', device) + logging.info("found device: %s", device) return logging.warning(DEBUG_CONNECTION) - raise ValueError('Found USB device, but could not communicate with ' - 'cr50 console') + raise ValueError( + "Found USB device, but could not communicate with " "cr50 console" + ) def print_platform_info(self): """Print the cr50 BID RLZ code""" - bid_output = self.send_cmd_get_output('bid') - bid = re.search(r'Board ID: (\S+?)[:,]', bid_output).group(1) + bid_output = self.send_cmd_get_output("bid") + bid = re.search(r"Board ID: (\S+?)[:,]", bid_output).group(1) if bid == ERASED_BID: logging.warning(DEBUG_ERASED_BOARD_ID) - raise ValueError('Cannot run RMA Open when board id is erased') + raise ValueError("Cannot run RMA Open when board id is erased") bid = int(bid, 16) - chrs = [chr((bid >> (8 * i)) & 0xff) for i in range(4)] - logging.info('RLZ: %s', ''.join(chrs[::-1])) + chrs = [chr((bid >> (8 * i)) & 0xFF) for i in range(4)] + logging.info("RLZ: %s", "".join(chrs[::-1])) @staticmethod def find_cr50_usb(usb_serial): """Make sure the Cr50 USB device exists""" try: - output = subprocess.check_output(CR50_LSUSB_CMD, encoding='utf-8') + output = subprocess.check_output(CR50_LSUSB_CMD, encoding="utf-8") except: logging.warning(DEBUG_MISSING_USB) - raise ValueError('Could not find Cr50 USB device') - serialnames = re.findall(r'iSerial +\d+ (\S+)\s', output) + raise ValueError("Could not find Cr50 USB device") + serialnames = re.findall(r"iSerial +\d+ (\S+)\s", output) if usb_serial: if usb_serial not in serialnames: logging.warning(DEBUG_SERIALNAME) raise ValueError('Could not find usb device "%s"' % usb_serial) return usb_serial if len(serialnames) > 1: - logging.info('Found Cr50 device serialnames %s', - ', '.join(serialnames)) + logging.info("Found Cr50 device serialnames %s", ", ".join(serialnames)) logging.warning(DEBUG_TOO_MANY_USB_DEVICES) - raise ValueError('Too many cr50 usb devices') + raise ValueError("Too many cr50 usb devices") return serialnames[0] def print_dut_state(self): """Print CCD RMA and testlab mode state.""" if not self.check(CCD_IS_UNRESTRICTED): - logging.info('CCD is still restricted.') - logging.info('Run cr50_rma_open.py -g -i $HWID to generate a url') - logging.info('Run cr50_rma_open.py -a $AUTHCODE to open cr50 with ' - 'an authcode') + logging.info("CCD is still restricted.") + logging.info("Run cr50_rma_open.py -g -i $HWID to generate a url") + logging.info( + "Run cr50_rma_open.py -a $AUTHCODE to open cr50 with " "an authcode" + ) elif not self.check(WP_IS_DISABLED): - logging.info('WP is still enabled.') - logging.info('Run cr50_rma_open.py -w to disable write protect') + logging.info("WP is still enabled.") + logging.info("Run cr50_rma_open.py -w to disable write protect") if self.check(RMA_OPENED): - logging.info('RMA Open complete') + logging.info("RMA Open complete") if not self.check(TESTLAB_IS_ENABLED) and self.is_prepvt: - logging.info('testlab mode is disabled.') - logging.info('If you are prepping a device for the testlab, you ' - 'should enable testlab mode.') - logging.info('Run cr50_rma_open.py -t to enable testlab mode') + logging.info("testlab mode is disabled.") + logging.info( + "If you are prepping a device for the testlab, you " + "should enable testlab mode." + ) + logging.info("Run cr50_rma_open.py -t to enable testlab mode") def parse_args(argv): """Get cr50_rma_open args.""" parser = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument('-g', '--generate_challenge', action='store_true', - help='Generate Cr50 challenge. Must be used with -i') - parser.add_argument('-t', '--enable_testlab', action='store_true', - help='enable testlab mode') - parser.add_argument('-w', '--wp_disable', action='store_true', - help='Disable write protect') - parser.add_argument('-c', '--check_connection', action='store_true', - help='Check cr50 console connection works') - parser.add_argument('-s', '--serialname', type=str, default='', - help='The cr50 usb serialname') - parser.add_argument('-D', '--debug', action='store_true', - help='print debug messages') - parser.add_argument('-d', '--device', type=str, default='', - help='cr50 console device ex /dev/ttyUSB0') - parser.add_argument('-i', '--hwid', type=str, default='', - help='The board hwid. Needed to generate a challenge') - parser.add_argument('-a', '--authcode', type=str, default='', - help='The authcode string from the challenge url') - parser.add_argument('-P', '--servo_port', type=str, default='', - help='the servo port') - parser.add_argument('-I', '--ip', type=str, default='', - help='The DUT IP. Necessary to do ccd open') + description=__doc__, formatter_class=argparse.RawTextHelpFormatter + ) + parser.add_argument( + "-g", + "--generate_challenge", + action="store_true", + help="Generate Cr50 challenge. Must be used with -i", + ) + parser.add_argument( + "-t", "--enable_testlab", action="store_true", help="enable testlab mode" + ) + parser.add_argument( + "-w", "--wp_disable", action="store_true", help="Disable write protect" + ) + parser.add_argument( + "-c", + "--check_connection", + action="store_true", + help="Check cr50 console connection works", + ) + parser.add_argument( + "-s", "--serialname", type=str, default="", help="The cr50 usb serialname" + ) + parser.add_argument( + "-D", "--debug", action="store_true", help="print debug messages" + ) + parser.add_argument( + "-d", + "--device", + type=str, + default="", + help="cr50 console device ex /dev/ttyUSB0", + ) + parser.add_argument( + "-i", + "--hwid", + type=str, + default="", + help="The board hwid. Needed to generate a challenge", + ) + parser.add_argument( + "-a", + "--authcode", + type=str, + default="", + help="The authcode string from the challenge url", + ) + parser.add_argument( + "-P", "--servo_port", type=str, default="", help="the servo port" + ) + parser.add_argument( + "-I", "--ip", type=str, default="", help="The DUT IP. Necessary to do ccd open" + ) return parser.parse_args(argv) @@ -614,52 +659,53 @@ def main(argv): opts = parse_args(argv) loglevel = logging.INFO - log_format = '%(levelname)7s' + log_format = "%(levelname)7s" if opts.debug: loglevel = logging.DEBUG - log_format += ' - %(lineno)3d:%(funcName)-15s' - log_format += ' - %(message)s' + log_format += " - %(lineno)3d:%(funcName)-15s" + log_format += " - %(message)s" logging.basicConfig(level=loglevel, format=log_format) tried_authcode = False - logging.info('Running cr50_rma_open version %s', SCRIPT_VERSION) + logging.info("Running cr50_rma_open version %s", SCRIPT_VERSION) - cr50_rma_open = RMAOpen(opts.device, opts.serialname, opts.servo_port, - opts.ip) + cr50_rma_open = RMAOpen(opts.device, opts.serialname, opts.servo_port, opts.ip) if opts.check_connection: sys.exit(0) if not cr50_rma_open.check(CCD_IS_UNRESTRICTED): if opts.generate_challenge: if not opts.hwid: - logging.warning('--hwid necessary to generate challenge url') + logging.warning("--hwid necessary to generate challenge url") sys.exit(0) cr50_rma_open.generate_challenge_url(opts.hwid) sys.exit(0) elif opts.authcode: - logging.info('Using authcode: %s', opts.authcode) + logging.info("Using authcode: %s", opts.authcode) cr50_rma_open.try_authcode(opts.authcode) tried_authcode = True - if not cr50_rma_open.check(WP_IS_DISABLED) and (tried_authcode or - opts.wp_disable): + if not cr50_rma_open.check(WP_IS_DISABLED) and (tried_authcode or opts.wp_disable): if not cr50_rma_open.check(CCD_IS_UNRESTRICTED): - raise ValueError("Can't disable write protect unless ccd is " - "open. Run through the rma open process first") + raise ValueError( + "Can't disable write protect unless ccd is " + "open. Run through the rma open process first" + ) if tried_authcode: - logging.warning('RMA Open did not disable write protect. File a ' - 'bug') - logging.warning('Trying to disable it manually') + logging.warning("RMA Open did not disable write protect. File a " "bug") + logging.warning("Trying to disable it manually") cr50_rma_open.wp_disable() if not cr50_rma_open.check(TESTLAB_IS_ENABLED) and opts.enable_testlab: if not cr50_rma_open.check(CCD_IS_UNRESTRICTED): - raise ValueError("Can't enable testlab mode unless ccd is open." - "Run through the rma open process first") + raise ValueError( + "Can't enable testlab mode unless ccd is open." + "Run through the rma open process first" + ) cr50_rma_open.enable_testlab() cr50_rma_open.print_dut_state() -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main(sys.argv[1:])) diff --git a/extra/stack_analyzer/stack_analyzer.py b/extra/stack_analyzer/stack_analyzer.py index 77d16d5450..17b2651972 100755 --- a/extra/stack_analyzer/stack_analyzer.py +++ b/extra/stack_analyzer/stack_analyzer.py @@ -25,1848 +25,1992 @@ import ctypes import os import re import subprocess -import yaml +import yaml -SECTION_RO = 'RO' -SECTION_RW = 'RW' +SECTION_RO = "RO" +SECTION_RW = "RW" # Default size of extra stack frame needed by exception context switch. # This value is for cortex-m with FPU enabled. DEFAULT_EXCEPTION_FRAME_SIZE = 224 class StackAnalyzerError(Exception): - """Exception class for stack analyzer utility.""" + """Exception class for stack analyzer utility.""" class TaskInfo(ctypes.Structure): - """Taskinfo ctypes structure. + """Taskinfo ctypes structure. - The structure definition is corresponding to the "struct taskinfo" - in "util/export_taskinfo.so.c". - """ - _fields_ = [('name', ctypes.c_char_p), - ('routine', ctypes.c_char_p), - ('stack_size', ctypes.c_uint32)] + The structure definition is corresponding to the "struct taskinfo" + in "util/export_taskinfo.so.c". + """ + _fields_ = [ + ("name", ctypes.c_char_p), + ("routine", ctypes.c_char_p), + ("stack_size", ctypes.c_uint32), + ] -class Task(object): - """Task information. - Attributes: - name: Task name. - routine_name: Routine function name. - stack_max_size: Max stack size. - routine_address: Resolved routine address. None if it hasn't been resolved. - """ - - def __init__(self, name, routine_name, stack_max_size, routine_address=None): - """Constructor. +class Task(object): + """Task information. - Args: + Attributes: name: Task name. routine_name: Routine function name. stack_max_size: Max stack size. - routine_address: Resolved routine address. + routine_address: Resolved routine address. None if it hasn't been resolved. """ - self.name = name - self.routine_name = routine_name - self.stack_max_size = stack_max_size - self.routine_address = routine_address - def __eq__(self, other): - """Task equality. + def __init__(self, name, routine_name, stack_max_size, routine_address=None): + """Constructor. - Args: - other: The compared object. + Args: + name: Task name. + routine_name: Routine function name. + stack_max_size: Max stack size. + routine_address: Resolved routine address. + """ + self.name = name + self.routine_name = routine_name + self.stack_max_size = stack_max_size + self.routine_address = routine_address - Returns: - True if equal, False if not. - """ - if not isinstance(other, Task): - return False + def __eq__(self, other): + """Task equality. - return (self.name == other.name and - self.routine_name == other.routine_name and - self.stack_max_size == other.stack_max_size and - self.routine_address == other.routine_address) + Args: + other: The compared object. + Returns: + True if equal, False if not. + """ + if not isinstance(other, Task): + return False -class Symbol(object): - """Symbol information. + return ( + self.name == other.name + and self.routine_name == other.routine_name + and self.stack_max_size == other.stack_max_size + and self.routine_address == other.routine_address + ) - Attributes: - address: Symbol address. - symtype: Symbol type, 'O' (data, object) or 'F' (function). - size: Symbol size. - name: Symbol name. - """ - def __init__(self, address, symtype, size, name): - """Constructor. +class Symbol(object): + """Symbol information. - Args: + Attributes: address: Symbol address. - symtype: Symbol type. + symtype: Symbol type, 'O' (data, object) or 'F' (function). size: Symbol size. name: Symbol name. """ - assert symtype in ['O', 'F'] - self.address = address - self.symtype = symtype - self.size = size - self.name = name - def __eq__(self, other): - """Symbol equality. - - Args: - other: The compared object. - - Returns: - True if equal, False if not. - """ - if not isinstance(other, Symbol): - return False - - return (self.address == other.address and - self.symtype == other.symtype and - self.size == other.size and - self.name == other.name) + def __init__(self, address, symtype, size, name): + """Constructor. + + Args: + address: Symbol address. + symtype: Symbol type. + size: Symbol size. + name: Symbol name. + """ + assert symtype in ["O", "F"] + self.address = address + self.symtype = symtype + self.size = size + self.name = name + + def __eq__(self, other): + """Symbol equality. + + Args: + other: The compared object. + + Returns: + True if equal, False if not. + """ + if not isinstance(other, Symbol): + return False + + return ( + self.address == other.address + and self.symtype == other.symtype + and self.size == other.size + and self.name == other.name + ) class Callsite(object): - """Function callsite. - - Attributes: - address: Address of callsite location. None if it is unknown. - target: Callee address. None if it is unknown. - is_tail: A bool indicates that it is a tailing call. - callee: Resolved callee function. None if it hasn't been resolved. - """ - - def __init__(self, address, target, is_tail, callee=None): - """Constructor. + """Function callsite. - Args: + Attributes: address: Address of callsite location. None if it is unknown. target: Callee address. None if it is unknown. - is_tail: A bool indicates that it is a tailing call. (function jump to - another function without restoring the stack frame) - callee: Resolved callee function. + is_tail: A bool indicates that it is a tailing call. + callee: Resolved callee function. None if it hasn't been resolved. """ - # It makes no sense that both address and target are unknown. - assert not (address is None and target is None) - self.address = address - self.target = target - self.is_tail = is_tail - self.callee = callee - - def __eq__(self, other): - """Callsite equality. - - Args: - other: The compared object. - Returns: - True if equal, False if not. - """ - if not isinstance(other, Callsite): - return False - - if not (self.address == other.address and - self.target == other.target and - self.is_tail == other.is_tail): - return False - - if self.callee is None: - return other.callee is None - elif other.callee is None: - return False - - # Assume the addresses of functions are unique. - return self.callee.address == other.callee.address + def __init__(self, address, target, is_tail, callee=None): + """Constructor. + + Args: + address: Address of callsite location. None if it is unknown. + target: Callee address. None if it is unknown. + is_tail: A bool indicates that it is a tailing call. (function jump to + another function without restoring the stack frame) + callee: Resolved callee function. + """ + # It makes no sense that both address and target are unknown. + assert not (address is None and target is None) + self.address = address + self.target = target + self.is_tail = is_tail + self.callee = callee + + def __eq__(self, other): + """Callsite equality. + + Args: + other: The compared object. + + Returns: + True if equal, False if not. + """ + if not isinstance(other, Callsite): + return False + + if not ( + self.address == other.address + and self.target == other.target + and self.is_tail == other.is_tail + ): + return False + + if self.callee is None: + return other.callee is None + elif other.callee is None: + return False + + # Assume the addresses of functions are unique. + return self.callee.address == other.callee.address class Function(object): - """Function. - - Attributes: - address: Address of function. - name: Name of function from its symbol. - stack_frame: Size of stack frame. - callsites: Callsite list. - stack_max_usage: Max stack usage. None if it hasn't been analyzed. - stack_max_path: Max stack usage path. None if it hasn't been analyzed. - """ + """Function. - def __init__(self, address, name, stack_frame, callsites): - """Constructor. - - Args: + Attributes: address: Address of function. name: Name of function from its symbol. stack_frame: Size of stack frame. callsites: Callsite list. + stack_max_usage: Max stack usage. None if it hasn't been analyzed. + stack_max_path: Max stack usage path. None if it hasn't been analyzed. """ - self.address = address - self.name = name - self.stack_frame = stack_frame - self.callsites = callsites - self.stack_max_usage = None - self.stack_max_path = None - - def __eq__(self, other): - """Function equality. - - Args: - other: The compared object. - - Returns: - True if equal, False if not. - """ - if not isinstance(other, Function): - return False - - if not (self.address == other.address and - self.name == other.name and - self.stack_frame == other.stack_frame and - self.callsites == other.callsites and - self.stack_max_usage == other.stack_max_usage): - return False - if self.stack_max_path is None: - return other.stack_max_path is None - elif other.stack_max_path is None: - return False + def __init__(self, address, name, stack_frame, callsites): + """Constructor. + + Args: + address: Address of function. + name: Name of function from its symbol. + stack_frame: Size of stack frame. + callsites: Callsite list. + """ + self.address = address + self.name = name + self.stack_frame = stack_frame + self.callsites = callsites + self.stack_max_usage = None + self.stack_max_path = None + + def __eq__(self, other): + """Function equality. + + Args: + other: The compared object. + + Returns: + True if equal, False if not. + """ + if not isinstance(other, Function): + return False + + if not ( + self.address == other.address + and self.name == other.name + and self.stack_frame == other.stack_frame + and self.callsites == other.callsites + and self.stack_max_usage == other.stack_max_usage + ): + return False + + if self.stack_max_path is None: + return other.stack_max_path is None + elif other.stack_max_path is None: + return False + + if len(self.stack_max_path) != len(other.stack_max_path): + return False + + for self_func, other_func in zip(self.stack_max_path, other.stack_max_path): + # Assume the addresses of functions are unique. + if self_func.address != other_func.address: + return False + + return True + + def __hash__(self): + return id(self) - if len(self.stack_max_path) != len(other.stack_max_path): - return False - - for self_func, other_func in zip(self.stack_max_path, other.stack_max_path): - # Assume the addresses of functions are unique. - if self_func.address != other_func.address: - return False - - return True - - def __hash__(self): - return id(self) class AndesAnalyzer(object): - """Disassembly analyzer for Andes architecture. - - Public Methods: - AnalyzeFunction: Analyze stack frame and callsites of the function. - """ - - GENERAL_PURPOSE_REGISTER_SIZE = 4 - - # Possible condition code suffixes. - CONDITION_CODES = [ 'eq', 'eqz', 'gez', 'gtz', 'lez', 'ltz', 'ne', 'nez', - 'eqc', 'nec', 'nezs', 'nes', 'eqs'] - CONDITION_CODES_RE = '({})'.format('|'.join(CONDITION_CODES)) - - IMM_ADDRESS_RE = r'([0-9A-Fa-f]+)\s+<([^>]+)>' - # Branch instructions. - JUMP_OPCODE_RE = re.compile(r'^(b{0}|j|jr|jr.|jrnez)(\d?|\d\d)$' \ - .format(CONDITION_CODES_RE)) - # Call instructions. - CALL_OPCODE_RE = re.compile \ - (r'^(jal|jral|jral.|jralnez|beqzal|bltzal|bgezal)(\d)?$') - CALL_OPERAND_RE = re.compile(r'^{}$'.format(IMM_ADDRESS_RE)) - # Ignore lp register because it's for return. - INDIRECT_CALL_OPERAND_RE = re.compile \ - (r'^\$r\d{1,}$|\$fp$|\$gp$|\$ta$|\$sp$|\$pc$') - # TODO: Handle other kinds of store instructions. - PUSH_OPCODE_RE = re.compile(r'^push(\d{1,})$') - PUSH_OPERAND_RE = re.compile(r'^\$r\d{1,}, \#\d{1,} \! \{([^\]]+)\}') - SMW_OPCODE_RE = re.compile(r'^smw(\.\w\w|\.\w\w\w)$') - SMW_OPERAND_RE = re.compile(r'^(\$r\d{1,}|\$\wp), \[\$\wp\], ' - r'(\$r\d{1,}|\$\wp), \#\d\w\d \! \{([^\]]+)\}') - OPERANDGROUP_RE = re.compile(r'^\$r\d{1,}\~\$r\d{1,}') - - LWI_OPCODE_RE = re.compile(r'^lwi(\.\w\w)$') - LWI_PC_OPERAND_RE = re.compile(r'^\$pc, \[([^\]]+)\]') - # Example: "34280: 3f c8 0f ec addi.gp $fp, #0xfec" - # Assume there is always a "\t" after the hex data. - DISASM_REGEX_RE = re.compile(r'^(?P<address>[0-9A-Fa-f]+):\s+' - r'(?P<words>[0-9A-Fa-f ]+)' - r'\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?') - - def ParseInstruction(self, line, function_end): - """Parse the line of instruction. + """Disassembly analyzer for Andes architecture. - Args: - line: Text of disassembly. - function_end: End address of the current function. None if unknown. - - Returns: - (address, words, opcode, operand_text): The instruction address, words, - opcode, and the text of operands. - None if it isn't an instruction line. + Public Methods: + AnalyzeFunction: Analyze stack frame and callsites of the function. """ - result = self.DISASM_REGEX_RE.match(line) - if result is None: - return None - - address = int(result.group('address'), 16) - # Check if it's out of bound. - if function_end is not None and address >= function_end: - return None - - opcode = result.group('opcode').strip() - operand_text = result.group('operand') - words = result.group('words') - if operand_text is None: - operand_text = '' - else: - operand_text = operand_text.strip() - - return (address, words, opcode, operand_text) - - def AnalyzeFunction(self, function_symbol, instructions): - - stack_frame = 0 - callsites = [] - for address, words, opcode, operand_text in instructions: - is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None - is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None - - if is_jump_opcode or is_call_opcode: - is_tail = is_jump_opcode - - result = self.CALL_OPERAND_RE.match(operand_text) + GENERAL_PURPOSE_REGISTER_SIZE = 4 + + # Possible condition code suffixes. + CONDITION_CODES = [ + "eq", + "eqz", + "gez", + "gtz", + "lez", + "ltz", + "ne", + "nez", + "eqc", + "nec", + "nezs", + "nes", + "eqs", + ] + CONDITION_CODES_RE = "({})".format("|".join(CONDITION_CODES)) + + IMM_ADDRESS_RE = r"([0-9A-Fa-f]+)\s+<([^>]+)>" + # Branch instructions. + JUMP_OPCODE_RE = re.compile( + r"^(b{0}|j|jr|jr.|jrnez)(\d?|\d\d)$".format(CONDITION_CODES_RE) + ) + # Call instructions. + CALL_OPCODE_RE = re.compile(r"^(jal|jral|jral.|jralnez|beqzal|bltzal|bgezal)(\d)?$") + CALL_OPERAND_RE = re.compile(r"^{}$".format(IMM_ADDRESS_RE)) + # Ignore lp register because it's for return. + INDIRECT_CALL_OPERAND_RE = re.compile(r"^\$r\d{1,}$|\$fp$|\$gp$|\$ta$|\$sp$|\$pc$") + # TODO: Handle other kinds of store instructions. + PUSH_OPCODE_RE = re.compile(r"^push(\d{1,})$") + PUSH_OPERAND_RE = re.compile(r"^\$r\d{1,}, \#\d{1,} \! \{([^\]]+)\}") + SMW_OPCODE_RE = re.compile(r"^smw(\.\w\w|\.\w\w\w)$") + SMW_OPERAND_RE = re.compile( + r"^(\$r\d{1,}|\$\wp), \[\$\wp\], " + r"(\$r\d{1,}|\$\wp), \#\d\w\d \! \{([^\]]+)\}" + ) + OPERANDGROUP_RE = re.compile(r"^\$r\d{1,}\~\$r\d{1,}") + + LWI_OPCODE_RE = re.compile(r"^lwi(\.\w\w)$") + LWI_PC_OPERAND_RE = re.compile(r"^\$pc, \[([^\]]+)\]") + # Example: "34280: 3f c8 0f ec addi.gp $fp, #0xfec" + # Assume there is always a "\t" after the hex data. + DISASM_REGEX_RE = re.compile( + r"^(?P<address>[0-9A-Fa-f]+):\s+" + r"(?P<words>[0-9A-Fa-f ]+)" + r"\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?" + ) + + def ParseInstruction(self, line, function_end): + """Parse the line of instruction. + + Args: + line: Text of disassembly. + function_end: End address of the current function. None if unknown. + + Returns: + (address, words, opcode, operand_text): The instruction address, words, + opcode, and the text of operands. + None if it isn't an instruction line. + """ + result = self.DISASM_REGEX_RE.match(line) if result is None: - if (self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None): - # Found an indirect call. - callsites.append(Callsite(address, None, is_tail)) - + return None + + address = int(result.group("address"), 16) + # Check if it's out of bound. + if function_end is not None and address >= function_end: + return None + + opcode = result.group("opcode").strip() + operand_text = result.group("operand") + words = result.group("words") + if operand_text is None: + operand_text = "" else: - target_address = int(result.group(1), 16) - # Filter out the in-function target (branches and in-function calls, - # which are actually branches). - if not (function_symbol.size > 0 and - function_symbol.address < target_address < - (function_symbol.address + function_symbol.size)): - # Maybe it is a callsite. - callsites.append(Callsite(address, target_address, is_tail)) - - elif self.LWI_OPCODE_RE.match(opcode) is not None: - result = self.LWI_PC_OPERAND_RE.match(operand_text) - if result is not None: - # Ignore "lwi $pc, [$sp], xx" because it's usually a return. - if result.group(1) != '$sp': - # Found an indirect call. - callsites.append(Callsite(address, None, True)) - - elif self.PUSH_OPCODE_RE.match(opcode) is not None: - # Example: fc 20 push25 $r8, #0 ! {$r6~$r8, $fp, $gp, $lp} - if self.PUSH_OPERAND_RE.match(operand_text) is not None: - # capture fc 20 - imm5u = int(words.split(' ')[1], 16) - # sp = sp - (imm5u << 3) - imm8u = (imm5u<<3) & 0xff - stack_frame += imm8u - - result = self.PUSH_OPERAND_RE.match(operand_text) - operandgroup_text = result.group(1) - # capture $rx~$ry - if self.OPERANDGROUP_RE.match(operandgroup_text) is not None: - # capture number & transfer string to integer - oprandgrouphead = operandgroup_text.split(',')[0] - rx=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[0]))) - ry=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[1]))) - - stack_frame += ((len(operandgroup_text.split(','))+ry-rx) * - self.GENERAL_PURPOSE_REGISTER_SIZE) - else: - stack_frame += (len(operandgroup_text.split(',')) * - self.GENERAL_PURPOSE_REGISTER_SIZE) - - elif self.SMW_OPCODE_RE.match(opcode) is not None: - # Example: smw.adm $r6, [$sp], $r10, #0x2 ! {$r6~$r10, $lp} - if self.SMW_OPERAND_RE.match(operand_text) is not None: - result = self.SMW_OPERAND_RE.match(operand_text) - operandgroup_text = result.group(3) - # capture $rx~$ry - if self.OPERANDGROUP_RE.match(operandgroup_text) is not None: - # capture number & transfer string to integer - oprandgrouphead = operandgroup_text.split(',')[0] - rx=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[0]))) - ry=int(''.join(filter(str.isdigit, oprandgrouphead.split('~')[1]))) - - stack_frame += ((len(operandgroup_text.split(','))+ry-rx) * - self.GENERAL_PURPOSE_REGISTER_SIZE) - else: - stack_frame += (len(operandgroup_text.split(',')) * - self.GENERAL_PURPOSE_REGISTER_SIZE) - - return (stack_frame, callsites) - -class ArmAnalyzer(object): - """Disassembly analyzer for ARM architecture. - - Public Methods: - AnalyzeFunction: Analyze stack frame and callsites of the function. - """ - - GENERAL_PURPOSE_REGISTER_SIZE = 4 - - # Possible condition code suffixes. - CONDITION_CODES = ['', 'eq', 'ne', 'cs', 'hs', 'cc', 'lo', 'mi', 'pl', 'vs', - 'vc', 'hi', 'ls', 'ge', 'lt', 'gt', 'le'] - CONDITION_CODES_RE = '({})'.format('|'.join(CONDITION_CODES)) - # Assume there is no function name containing ">". - IMM_ADDRESS_RE = r'([0-9A-Fa-f]+)\s+<([^>]+)>' - - # Fuzzy regular expressions for instruction and operand parsing. - # Branch instructions. - JUMP_OPCODE_RE = re.compile( - r'^(b{0}|bx{0})(\.\w)?$'.format(CONDITION_CODES_RE)) - # Call instructions. - CALL_OPCODE_RE = re.compile( - r'^(bl{0}|blx{0})(\.\w)?$'.format(CONDITION_CODES_RE)) - CALL_OPERAND_RE = re.compile(r'^{}$'.format(IMM_ADDRESS_RE)) - CBZ_CBNZ_OPCODE_RE = re.compile(r'^(cbz|cbnz)(\.\w)?$') - # Example: "r0, 1009bcbe <host_cmd_motion_sense+0x1d2>" - CBZ_CBNZ_OPERAND_RE = re.compile(r'^[^,]+,\s+{}$'.format(IMM_ADDRESS_RE)) - # Ignore lr register because it's for return. - INDIRECT_CALL_OPERAND_RE = re.compile(r'^r\d+|sb|sl|fp|ip|sp|pc$') - # TODO(cheyuw): Handle conditional versions of following - # instructions. - # TODO(cheyuw): Handle other kinds of pc modifying instructions (e.g. mov pc). - LDR_OPCODE_RE = re.compile(r'^ldr(\.\w)?$') - # Example: "pc, [sp], #4" - LDR_PC_OPERAND_RE = re.compile(r'^pc, \[([^\]]+)\]') - # TODO(cheyuw): Handle other kinds of stm instructions. - PUSH_OPCODE_RE = re.compile(r'^push$') - STM_OPCODE_RE = re.compile(r'^stmdb$') - # Stack subtraction instructions. - SUB_OPCODE_RE = re.compile(r'^sub(s|w)?(\.\w)?$') - SUB_OPERAND_RE = re.compile(r'^sp[^#]+#(\d+)') - # Example: "44d94: f893 0068 ldrb.w r0, [r3, #104] ; 0x68" - # Assume there is always a "\t" after the hex data. - DISASM_REGEX_RE = re.compile(r'^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+' - r'\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?') - - def ParseInstruction(self, line, function_end): - """Parse the line of instruction. + operand_text = operand_text.strip() + + return (address, words, opcode, operand_text) + + def AnalyzeFunction(self, function_symbol, instructions): + + stack_frame = 0 + callsites = [] + for address, words, opcode, operand_text in instructions: + is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None + is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None + + if is_jump_opcode or is_call_opcode: + is_tail = is_jump_opcode + + result = self.CALL_OPERAND_RE.match(operand_text) + + if result is None: + if self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None: + # Found an indirect call. + callsites.append(Callsite(address, None, is_tail)) + + else: + target_address = int(result.group(1), 16) + # Filter out the in-function target (branches and in-function calls, + # which are actually branches). + if not ( + function_symbol.size > 0 + and function_symbol.address + < target_address + < (function_symbol.address + function_symbol.size) + ): + # Maybe it is a callsite. + callsites.append(Callsite(address, target_address, is_tail)) + + elif self.LWI_OPCODE_RE.match(opcode) is not None: + result = self.LWI_PC_OPERAND_RE.match(operand_text) + if result is not None: + # Ignore "lwi $pc, [$sp], xx" because it's usually a return. + if result.group(1) != "$sp": + # Found an indirect call. + callsites.append(Callsite(address, None, True)) + + elif self.PUSH_OPCODE_RE.match(opcode) is not None: + # Example: fc 20 push25 $r8, #0 ! {$r6~$r8, $fp, $gp, $lp} + if self.PUSH_OPERAND_RE.match(operand_text) is not None: + # capture fc 20 + imm5u = int(words.split(" ")[1], 16) + # sp = sp - (imm5u << 3) + imm8u = (imm5u << 3) & 0xFF + stack_frame += imm8u + + result = self.PUSH_OPERAND_RE.match(operand_text) + operandgroup_text = result.group(1) + # capture $rx~$ry + if self.OPERANDGROUP_RE.match(operandgroup_text) is not None: + # capture number & transfer string to integer + oprandgrouphead = operandgroup_text.split(",")[0] + rx = int( + "".join(filter(str.isdigit, oprandgrouphead.split("~")[0])) + ) + ry = int( + "".join(filter(str.isdigit, oprandgrouphead.split("~")[1])) + ) + + stack_frame += ( + len(operandgroup_text.split(",")) + ry - rx + ) * self.GENERAL_PURPOSE_REGISTER_SIZE + else: + stack_frame += ( + len(operandgroup_text.split(",")) + * self.GENERAL_PURPOSE_REGISTER_SIZE + ) + + elif self.SMW_OPCODE_RE.match(opcode) is not None: + # Example: smw.adm $r6, [$sp], $r10, #0x2 ! {$r6~$r10, $lp} + if self.SMW_OPERAND_RE.match(operand_text) is not None: + result = self.SMW_OPERAND_RE.match(operand_text) + operandgroup_text = result.group(3) + # capture $rx~$ry + if self.OPERANDGROUP_RE.match(operandgroup_text) is not None: + # capture number & transfer string to integer + oprandgrouphead = operandgroup_text.split(",")[0] + rx = int( + "".join(filter(str.isdigit, oprandgrouphead.split("~")[0])) + ) + ry = int( + "".join(filter(str.isdigit, oprandgrouphead.split("~")[1])) + ) + + stack_frame += ( + len(operandgroup_text.split(",")) + ry - rx + ) * self.GENERAL_PURPOSE_REGISTER_SIZE + else: + stack_frame += ( + len(operandgroup_text.split(",")) + * self.GENERAL_PURPOSE_REGISTER_SIZE + ) + + return (stack_frame, callsites) - Args: - line: Text of disassembly. - function_end: End address of the current function. None if unknown. - Returns: - (address, opcode, operand_text): The instruction address, opcode, - and the text of operands. None if it - isn't an instruction line. - """ - result = self.DISASM_REGEX_RE.match(line) - if result is None: - return None - - address = int(result.group('address'), 16) - # Check if it's out of bound. - if function_end is not None and address >= function_end: - return None - - opcode = result.group('opcode').strip() - operand_text = result.group('operand') - if operand_text is None: - operand_text = '' - else: - operand_text = operand_text.strip() - - return (address, opcode, operand_text) - - def AnalyzeFunction(self, function_symbol, instructions): - """Analyze function, resolve the size of stack frame and callsites. - - Args: - function_symbol: Function symbol. - instructions: Instruction list. +class ArmAnalyzer(object): + """Disassembly analyzer for ARM architecture. - Returns: - (stack_frame, callsites): Size of stack frame, callsite list. + Public Methods: + AnalyzeFunction: Analyze stack frame and callsites of the function. """ - stack_frame = 0 - callsites = [] - for address, opcode, operand_text in instructions: - is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None - is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None - is_cbz_cbnz_opcode = self.CBZ_CBNZ_OPCODE_RE.match(opcode) is not None - if is_jump_opcode or is_call_opcode or is_cbz_cbnz_opcode: - is_tail = is_jump_opcode or is_cbz_cbnz_opcode - - if is_cbz_cbnz_opcode: - result = self.CBZ_CBNZ_OPERAND_RE.match(operand_text) - else: - result = self.CALL_OPERAND_RE.match(operand_text) + GENERAL_PURPOSE_REGISTER_SIZE = 4 + + # Possible condition code suffixes. + CONDITION_CODES = [ + "", + "eq", + "ne", + "cs", + "hs", + "cc", + "lo", + "mi", + "pl", + "vs", + "vc", + "hi", + "ls", + "ge", + "lt", + "gt", + "le", + ] + CONDITION_CODES_RE = "({})".format("|".join(CONDITION_CODES)) + # Assume there is no function name containing ">". + IMM_ADDRESS_RE = r"([0-9A-Fa-f]+)\s+<([^>]+)>" + + # Fuzzy regular expressions for instruction and operand parsing. + # Branch instructions. + JUMP_OPCODE_RE = re.compile(r"^(b{0}|bx{0})(\.\w)?$".format(CONDITION_CODES_RE)) + # Call instructions. + CALL_OPCODE_RE = re.compile(r"^(bl{0}|blx{0})(\.\w)?$".format(CONDITION_CODES_RE)) + CALL_OPERAND_RE = re.compile(r"^{}$".format(IMM_ADDRESS_RE)) + CBZ_CBNZ_OPCODE_RE = re.compile(r"^(cbz|cbnz)(\.\w)?$") + # Example: "r0, 1009bcbe <host_cmd_motion_sense+0x1d2>" + CBZ_CBNZ_OPERAND_RE = re.compile(r"^[^,]+,\s+{}$".format(IMM_ADDRESS_RE)) + # Ignore lr register because it's for return. + INDIRECT_CALL_OPERAND_RE = re.compile(r"^r\d+|sb|sl|fp|ip|sp|pc$") + # TODO(cheyuw): Handle conditional versions of following + # instructions. + # TODO(cheyuw): Handle other kinds of pc modifying instructions (e.g. mov pc). + LDR_OPCODE_RE = re.compile(r"^ldr(\.\w)?$") + # Example: "pc, [sp], #4" + LDR_PC_OPERAND_RE = re.compile(r"^pc, \[([^\]]+)\]") + # TODO(cheyuw): Handle other kinds of stm instructions. + PUSH_OPCODE_RE = re.compile(r"^push$") + STM_OPCODE_RE = re.compile(r"^stmdb$") + # Stack subtraction instructions. + SUB_OPCODE_RE = re.compile(r"^sub(s|w)?(\.\w)?$") + SUB_OPERAND_RE = re.compile(r"^sp[^#]+#(\d+)") + # Example: "44d94: f893 0068 ldrb.w r0, [r3, #104] ; 0x68" + # Assume there is always a "\t" after the hex data. + DISASM_REGEX_RE = re.compile( + r"^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+" + r"\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?" + ) + + def ParseInstruction(self, line, function_end): + """Parse the line of instruction. + + Args: + line: Text of disassembly. + function_end: End address of the current function. None if unknown. + + Returns: + (address, opcode, operand_text): The instruction address, opcode, + and the text of operands. None if it + isn't an instruction line. + """ + result = self.DISASM_REGEX_RE.match(line) if result is None: - # Failed to match immediate address, maybe it is an indirect call. - # CBZ and CBNZ can't be indirect calls. - if (not is_cbz_cbnz_opcode and - self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None): - # Found an indirect call. - callsites.append(Callsite(address, None, is_tail)) + return None - else: - target_address = int(result.group(1), 16) - # Filter out the in-function target (branches and in-function calls, - # which are actually branches). - if not (function_symbol.size > 0 and - function_symbol.address < target_address < - (function_symbol.address + function_symbol.size)): - # Maybe it is a callsite. - callsites.append(Callsite(address, target_address, is_tail)) - - elif self.LDR_OPCODE_RE.match(opcode) is not None: - result = self.LDR_PC_OPERAND_RE.match(operand_text) - if result is not None: - # Ignore "ldr pc, [sp], xx" because it's usually a return. - if result.group(1) != 'sp': - # Found an indirect call. - callsites.append(Callsite(address, None, True)) - - elif self.PUSH_OPCODE_RE.match(opcode) is not None: - # Example: "{r4, r5, r6, r7, lr}" - stack_frame += (len(operand_text.split(',')) * - self.GENERAL_PURPOSE_REGISTER_SIZE) - elif self.SUB_OPCODE_RE.match(opcode) is not None: - result = self.SUB_OPERAND_RE.match(operand_text) - if result is not None: - stack_frame += int(result.group(1)) - else: - # Unhandled stack register subtraction. - assert not operand_text.startswith('sp') + address = int(result.group("address"), 16) + # Check if it's out of bound. + if function_end is not None and address >= function_end: + return None - elif self.STM_OPCODE_RE.match(opcode) is not None: - if operand_text.startswith('sp!'): - # Subtract and writeback to stack register. - # Example: "sp!, {r4, r5, r6, r7, r8, r9, lr}" - # Get the text of pushed register list. - unused_sp, unused_sep, parameter_text = operand_text.partition(',') - stack_frame += (len(parameter_text.split(',')) * - self.GENERAL_PURPOSE_REGISTER_SIZE) + opcode = result.group("opcode").strip() + operand_text = result.group("operand") + if operand_text is None: + operand_text = "" + else: + operand_text = operand_text.strip() + + return (address, opcode, operand_text) + + def AnalyzeFunction(self, function_symbol, instructions): + """Analyze function, resolve the size of stack frame and callsites. + + Args: + function_symbol: Function symbol. + instructions: Instruction list. + + Returns: + (stack_frame, callsites): Size of stack frame, callsite list. + """ + stack_frame = 0 + callsites = [] + for address, opcode, operand_text in instructions: + is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None + is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None + is_cbz_cbnz_opcode = self.CBZ_CBNZ_OPCODE_RE.match(opcode) is not None + if is_jump_opcode or is_call_opcode or is_cbz_cbnz_opcode: + is_tail = is_jump_opcode or is_cbz_cbnz_opcode + + if is_cbz_cbnz_opcode: + result = self.CBZ_CBNZ_OPERAND_RE.match(operand_text) + else: + result = self.CALL_OPERAND_RE.match(operand_text) + + if result is None: + # Failed to match immediate address, maybe it is an indirect call. + # CBZ and CBNZ can't be indirect calls. + if ( + not is_cbz_cbnz_opcode + and self.INDIRECT_CALL_OPERAND_RE.match(operand_text) + is not None + ): + # Found an indirect call. + callsites.append(Callsite(address, None, is_tail)) + + else: + target_address = int(result.group(1), 16) + # Filter out the in-function target (branches and in-function calls, + # which are actually branches). + if not ( + function_symbol.size > 0 + and function_symbol.address + < target_address + < (function_symbol.address + function_symbol.size) + ): + # Maybe it is a callsite. + callsites.append(Callsite(address, target_address, is_tail)) + + elif self.LDR_OPCODE_RE.match(opcode) is not None: + result = self.LDR_PC_OPERAND_RE.match(operand_text) + if result is not None: + # Ignore "ldr pc, [sp], xx" because it's usually a return. + if result.group(1) != "sp": + # Found an indirect call. + callsites.append(Callsite(address, None, True)) + + elif self.PUSH_OPCODE_RE.match(opcode) is not None: + # Example: "{r4, r5, r6, r7, lr}" + stack_frame += ( + len(operand_text.split(",")) * self.GENERAL_PURPOSE_REGISTER_SIZE + ) + elif self.SUB_OPCODE_RE.match(opcode) is not None: + result = self.SUB_OPERAND_RE.match(operand_text) + if result is not None: + stack_frame += int(result.group(1)) + else: + # Unhandled stack register subtraction. + assert not operand_text.startswith("sp") + + elif self.STM_OPCODE_RE.match(opcode) is not None: + if operand_text.startswith("sp!"): + # Subtract and writeback to stack register. + # Example: "sp!, {r4, r5, r6, r7, r8, r9, lr}" + # Get the text of pushed register list. + unused_sp, unused_sep, parameter_text = operand_text.partition(",") + stack_frame += ( + len(parameter_text.split(",")) + * self.GENERAL_PURPOSE_REGISTER_SIZE + ) + + return (stack_frame, callsites) - return (stack_frame, callsites) class RiscvAnalyzer(object): - """Disassembly analyzer for RISC-V architecture. - - Public Methods: - AnalyzeFunction: Analyze stack frame and callsites of the function. - """ - - # Possible condition code suffixes. - CONDITION_CODES = [ 'eqz', 'nez', 'lez', 'gez', 'ltz', 'gtz', 'gt', 'le', - 'gtu', 'leu', 'eq', 'ne', 'ge', 'lt', 'ltu', 'geu'] - CONDITION_CODES_RE = '({})'.format('|'.join(CONDITION_CODES)) - # Branch instructions. - JUMP_OPCODE_RE = re.compile(r'^(b{0}|j|jr)$'.format(CONDITION_CODES_RE)) - # Call instructions. - CALL_OPCODE_RE = re.compile(r'^(jal|jalr)$') - # Example: "j 8009b318 <set_state_prl_hr>" or - # "jal ra,800a4394 <power_get_signals>" or - # "bltu t0,t1,80080300 <data_loop>" - JUMP_ADDRESS_RE = r'((\w(\w|\d\d),){0,2})([0-9A-Fa-f]+)\s+<([^>]+)>' - CALL_OPERAND_RE = re.compile(r'^{}$'.format(JUMP_ADDRESS_RE)) - # Capture address, Example: 800a4394 - CAPTURE_ADDRESS = re.compile(r'[0-9A-Fa-f]{8}') - # Indirect jump, Example: jalr a5 - INDIRECT_CALL_OPERAND_RE = re.compile(r'^t\d+|s\d+|a\d+$') - # Example: addi - ADDI_OPCODE_RE = re.compile(r'^addi$') - # Allocate stack instructions. - ADDI_OPERAND_RE = re.compile(r'^(sp,sp,-\d+)$') - # Example: "800804b6: 1101 addi sp,sp,-32" - DISASM_REGEX_RE = re.compile(r'^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+' - r'\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?') - - def ParseInstruction(self, line, function_end): - """Parse the line of instruction. + """Disassembly analyzer for RISC-V architecture. - Args: - line: Text of disassembly. - function_end: End address of the current function. None if unknown. - - Returns: - (address, opcode, operand_text): The instruction address, opcode, - and the text of operands. None if it - isn't an instruction line. + Public Methods: + AnalyzeFunction: Analyze stack frame and callsites of the function. """ - result = self.DISASM_REGEX_RE.match(line) - if result is None: - return None - - address = int(result.group('address'), 16) - # Check if it's out of bound. - if function_end is not None and address >= function_end: - return None - - opcode = result.group('opcode').strip() - operand_text = result.group('operand') - if operand_text is None: - operand_text = '' - else: - operand_text = operand_text.strip() - - return (address, opcode, operand_text) - - def AnalyzeFunction(self, function_symbol, instructions): - - stack_frame = 0 - callsites = [] - for address, opcode, operand_text in instructions: - is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None - is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None - if is_jump_opcode or is_call_opcode: - is_tail = is_jump_opcode - - result = self.CALL_OPERAND_RE.match(operand_text) + # Possible condition code suffixes. + CONDITION_CODES = [ + "eqz", + "nez", + "lez", + "gez", + "ltz", + "gtz", + "gt", + "le", + "gtu", + "leu", + "eq", + "ne", + "ge", + "lt", + "ltu", + "geu", + ] + CONDITION_CODES_RE = "({})".format("|".join(CONDITION_CODES)) + # Branch instructions. + JUMP_OPCODE_RE = re.compile(r"^(b{0}|j|jr)$".format(CONDITION_CODES_RE)) + # Call instructions. + CALL_OPCODE_RE = re.compile(r"^(jal|jalr)$") + # Example: "j 8009b318 <set_state_prl_hr>" or + # "jal ra,800a4394 <power_get_signals>" or + # "bltu t0,t1,80080300 <data_loop>" + JUMP_ADDRESS_RE = r"((\w(\w|\d\d),){0,2})([0-9A-Fa-f]+)\s+<([^>]+)>" + CALL_OPERAND_RE = re.compile(r"^{}$".format(JUMP_ADDRESS_RE)) + # Capture address, Example: 800a4394 + CAPTURE_ADDRESS = re.compile(r"[0-9A-Fa-f]{8}") + # Indirect jump, Example: jalr a5 + INDIRECT_CALL_OPERAND_RE = re.compile(r"^t\d+|s\d+|a\d+$") + # Example: addi + ADDI_OPCODE_RE = re.compile(r"^addi$") + # Allocate stack instructions. + ADDI_OPERAND_RE = re.compile(r"^(sp,sp,-\d+)$") + # Example: "800804b6: 1101 addi sp,sp,-32" + DISASM_REGEX_RE = re.compile( + r"^(?P<address>[0-9A-Fa-f]+):\s+[0-9A-Fa-f ]+" + r"\t\s*(?P<opcode>\S+)(\s+(?P<operand>[^;]*))?" + ) + + def ParseInstruction(self, line, function_end): + """Parse the line of instruction. + + Args: + line: Text of disassembly. + function_end: End address of the current function. None if unknown. + + Returns: + (address, opcode, operand_text): The instruction address, opcode, + and the text of operands. None if it + isn't an instruction line. + """ + result = self.DISASM_REGEX_RE.match(line) if result is None: - if (self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None): - # Found an indirect call. - callsites.append(Callsite(address, None, is_tail)) - - else: - # Capture address form operand_text and then convert to string - address_str = "".join(self.CAPTURE_ADDRESS.findall(operand_text)) - # String to integer - target_address = int(address_str, 16) - # Filter out the in-function target (branches and in-function calls, - # which are actually branches). - if not (function_symbol.size > 0 and - function_symbol.address < target_address < - (function_symbol.address + function_symbol.size)): - # Maybe it is a callsite. - callsites.append(Callsite(address, target_address, is_tail)) - - elif self.ADDI_OPCODE_RE.match(opcode) is not None: - # Example: sp,sp,-32 - if self.ADDI_OPERAND_RE.match(operand_text) is not None: - stack_frame += abs(int(operand_text.split(",")[2])) - - return (stack_frame, callsites) - -class StackAnalyzer(object): - """Class to analyze stack usage. - - Public Methods: - Analyze: Run the stack analysis. - """ - - C_FUNCTION_NAME = r'_A-Za-z0-9' - - # Assume there is no ":" in the path. - # Example: "driver/accel_kionix.c:321 (discriminator 3)" - ADDRTOLINE_RE = re.compile( - r'^(?P<path>[^:]+):(?P<linenum>\d+)(\s+\(discriminator\s+\d+\))?$') - # To eliminate the suffix appended by compilers, try to extract the - # C function name from the prefix of symbol name. - # Example: "SHA256_transform.constprop.28" - FUNCTION_PREFIX_NAME_RE = re.compile( - r'^(?P<name>[{0}]+)([^{0}].*)?$'.format(C_FUNCTION_NAME)) - - # Errors of annotation resolving. - ANNOTATION_ERROR_INVALID = 'invalid signature' - ANNOTATION_ERROR_NOTFOUND = 'function is not found' - ANNOTATION_ERROR_AMBIGUOUS = 'signature is ambiguous' - - def __init__(self, options, symbols, rodata, tasklist, annotation): - """Constructor. - - Args: - options: Namespace from argparse.parse_args(). - symbols: Symbol list. - rodata: Content of .rodata section (offset, data) - tasklist: Task list. - annotation: Annotation config. - """ - self.options = options - self.symbols = symbols - self.rodata_offset = rodata[0] - self.rodata = rodata[1] - self.tasklist = tasklist - self.annotation = annotation - self.address_to_line_cache = {} - - def AddressToLine(self, address, resolve_inline=False): - """Convert address to line. - - Args: - address: Target address. - resolve_inline: Output the stack of inlining. - - Returns: - lines: List of the corresponding lines. - - Raises: - StackAnalyzerError: If addr2line is failed. - """ - cache_key = (address, resolve_inline) - if cache_key in self.address_to_line_cache: - return self.address_to_line_cache[cache_key] - - try: - args = [self.options.addr2line, - '-f', - '-e', - self.options.elf_path, - '{:x}'.format(address)] - if resolve_inline: - args.append('-i') - - line_text = subprocess.check_output(args, encoding='utf-8') - except subprocess.CalledProcessError: - raise StackAnalyzerError('addr2line failed to resolve lines.') - except OSError: - raise StackAnalyzerError('Failed to run addr2line.') - - lines = [line.strip() for line in line_text.splitlines()] - # Assume the output has at least one pair like "function\nlocation\n", and - # they always show up in pairs. - # Example: "handle_request\n - # common/usb_pd_protocol.c:1191\n" - assert len(lines) >= 2 and len(lines) % 2 == 0 - - line_infos = [] - for index in range(0, len(lines), 2): - (function_name, line_text) = lines[index:index + 2] - if line_text in ['??:0', ':?']: - line_infos.append(None) - else: - result = self.ADDRTOLINE_RE.match(line_text) - # Assume the output is always well-formed. - assert result is not None - line_infos.append((function_name.strip(), - os.path.realpath(result.group('path').strip()), - int(result.group('linenum')))) - - self.address_to_line_cache[cache_key] = line_infos - return line_infos - - def AnalyzeDisassembly(self, disasm_text): - """Parse the disassembly text, analyze, and build a map of all functions. - - Args: - disasm_text: Disassembly text. - - Returns: - function_map: Dict of functions. - """ - disasm_lines = [line.strip() for line in disasm_text.splitlines()] - - if 'nds' in disasm_lines[1]: - analyzer = AndesAnalyzer() - elif 'arm' in disasm_lines[1]: - analyzer = ArmAnalyzer() - elif 'riscv' in disasm_lines[1]: - analyzer = RiscvAnalyzer() - else: - raise StackAnalyzerError('Unsupported architecture.') - - # Example: "08028c8c <motion_lid_calc>:" - function_signature_regex = re.compile( - r'^(?P<address>[0-9A-Fa-f]+)\s+<(?P<name>[^>]+)>:$') - - def DetectFunctionHead(line): - """Check if the line is a function head. - - Args: - line: Text of disassembly. - - Returns: - symbol: Function symbol. None if it isn't a function head. - """ - result = function_signature_regex.match(line) - if result is None: - return None - - address = int(result.group('address'), 16) - symbol = symbol_map.get(address) - - # Check if the function exists and matches. - if symbol is None or symbol.symtype != 'F': - return None - - return symbol - - # Build symbol map, indexed by symbol address. - symbol_map = {} - for symbol in self.symbols: - # If there are multiple symbols with same address, keeping any of them is - # good enough. - symbol_map[symbol.address] = symbol - - # Parse the disassembly text. We update the variable "line" to next line - # when needed. There are two steps of parser: - # - # Step 1: Searching for the function head. Once reach the function head, - # move to the next line, which is the first line of function body. - # - # Step 2: Parsing each instruction line of function body. Once reach a - # non-instruction line, stop parsing and analyze the parsed instructions. - # - # Finally turn back to the step 1 without updating the line, because the - # current non-instruction line can be another function head. - function_map = {} - # The following three variables are the states of the parsing processing. - # They will be initialized properly during the state changes. - function_symbol = None - function_end = None - instructions = [] - - # Remove heading and tailing spaces for each line. - line_index = 0 - while line_index < len(disasm_lines): - # Get the current line. - line = disasm_lines[line_index] - - if function_symbol is None: - # Step 1: Search for the function head. - - function_symbol = DetectFunctionHead(line) - if function_symbol is not None: - # Assume there is no empty function. If the function head is followed - # by EOF, it is an empty function. - assert line_index + 1 < len(disasm_lines) - - # Found the function head, initialize and turn to the step 2. - instructions = [] - # If symbol size exists, use it as a hint of function size. - if function_symbol.size > 0: - function_end = function_symbol.address + function_symbol.size - else: - function_end = None - - else: - # Step 2: Parse the function body. - - instruction = analyzer.ParseInstruction(line, function_end) - if instruction is not None: - instructions.append(instruction) - - if instruction is None or line_index + 1 == len(disasm_lines): - # Either the invalid instruction or EOF indicates the end of the - # function, finalize the function analysis. - - # Assume there is no empty function. - assert len(instructions) > 0 - - (stack_frame, callsites) = analyzer.AnalyzeFunction(function_symbol, - instructions) - # Assume the function addresses are unique in the disassembly. - assert function_symbol.address not in function_map - function_map[function_symbol.address] = Function( - function_symbol.address, - function_symbol.name, - stack_frame, - callsites) - - # Initialize and turn back to the step 1. - function_symbol = None - - # If the current line isn't an instruction, it can be another function - # head, skip moving to the next line. - if instruction is None: - continue - - # Move to the next line. - line_index += 1 + return None - # Resolve callees of functions. - for function in function_map.values(): - for callsite in function.callsites: - if callsite.target is not None: - # Remain the callee as None if we can't resolve it. - callsite.callee = function_map.get(callsite.target) + address = int(result.group("address"), 16) + # Check if it's out of bound. + if function_end is not None and address >= function_end: + return None - return function_map + opcode = result.group("opcode").strip() + operand_text = result.group("operand") + if operand_text is None: + operand_text = "" + else: + operand_text = operand_text.strip() + + return (address, opcode, operand_text) + + def AnalyzeFunction(self, function_symbol, instructions): + + stack_frame = 0 + callsites = [] + for address, opcode, operand_text in instructions: + is_jump_opcode = self.JUMP_OPCODE_RE.match(opcode) is not None + is_call_opcode = self.CALL_OPCODE_RE.match(opcode) is not None + + if is_jump_opcode or is_call_opcode: + is_tail = is_jump_opcode + + result = self.CALL_OPERAND_RE.match(operand_text) + if result is None: + if self.INDIRECT_CALL_OPERAND_RE.match(operand_text) is not None: + # Found an indirect call. + callsites.append(Callsite(address, None, is_tail)) + + else: + # Capture address form operand_text and then convert to string + address_str = "".join(self.CAPTURE_ADDRESS.findall(operand_text)) + # String to integer + target_address = int(address_str, 16) + # Filter out the in-function target (branches and in-function calls, + # which are actually branches). + if not ( + function_symbol.size > 0 + and function_symbol.address + < target_address + < (function_symbol.address + function_symbol.size) + ): + # Maybe it is a callsite. + callsites.append(Callsite(address, target_address, is_tail)) + + elif self.ADDI_OPCODE_RE.match(opcode) is not None: + # Example: sp,sp,-32 + if self.ADDI_OPERAND_RE.match(operand_text) is not None: + stack_frame += abs(int(operand_text.split(",")[2])) + + return (stack_frame, callsites) - def MapAnnotation(self, function_map, signature_set): - """Map annotation signatures to functions. - Args: - function_map: Function map. - signature_set: Set of annotation signatures. +class StackAnalyzer(object): + """Class to analyze stack usage. - Returns: - Map of signatures to functions, map of signatures which can't be resolved. + Public Methods: + Analyze: Run the stack analysis. """ - # Build the symbol map indexed by symbol name. If there are multiple symbols - # with the same name, add them into a set. (e.g. symbols of static function - # with the same name) - symbol_map = collections.defaultdict(set) - for symbol in self.symbols: - if symbol.symtype == 'F': - # Function symbol. - result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name) - if result is not None: - function = function_map.get(symbol.address) - # Ignore the symbol not in disassembly. - if function is not None: - # If there are multiple symbol with the same name and point to the - # same function, the set will deduplicate them. - symbol_map[result.group('name').strip()].add(function) - - # Build the signature map indexed by annotation signature. - signature_map = {} - sig_error_map = {} - symbol_path_map = {} - for sig in signature_set: - (name, path, _) = sig - - functions = symbol_map.get(name) - if functions is None: - sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND - continue - - if name not in symbol_path_map: - # Lazy symbol path resolving. Since the addr2line isn't fast, only - # resolve needed symbol paths. - group_map = collections.defaultdict(list) - for function in functions: - line_info = self.AddressToLine(function.address)[0] - if line_info is None: - continue - (_, symbol_path, _) = line_info + C_FUNCTION_NAME = r"_A-Za-z0-9" - # Group the functions with the same symbol signature (symbol name + - # symbol path). Assume they are the same copies and do the same - # annotation operations of them because we don't know which copy is - # indicated by the users. - group_map[symbol_path].append(function) + # Assume there is no ":" in the path. + # Example: "driver/accel_kionix.c:321 (discriminator 3)" + ADDRTOLINE_RE = re.compile( + r"^(?P<path>[^:]+):(?P<linenum>\d+)(\s+\(discriminator\s+\d+\))?$" + ) + # To eliminate the suffix appended by compilers, try to extract the + # C function name from the prefix of symbol name. + # Example: "SHA256_transform.constprop.28" + FUNCTION_PREFIX_NAME_RE = re.compile( + r"^(?P<name>[{0}]+)([^{0}].*)?$".format(C_FUNCTION_NAME) + ) + + # Errors of annotation resolving. + ANNOTATION_ERROR_INVALID = "invalid signature" + ANNOTATION_ERROR_NOTFOUND = "function is not found" + ANNOTATION_ERROR_AMBIGUOUS = "signature is ambiguous" + + def __init__(self, options, symbols, rodata, tasklist, annotation): + """Constructor. + + Args: + options: Namespace from argparse.parse_args(). + symbols: Symbol list. + rodata: Content of .rodata section (offset, data) + tasklist: Task list. + annotation: Annotation config. + """ + self.options = options + self.symbols = symbols + self.rodata_offset = rodata[0] + self.rodata = rodata[1] + self.tasklist = tasklist + self.annotation = annotation + self.address_to_line_cache = {} + + def AddressToLine(self, address, resolve_inline=False): + """Convert address to line. + + Args: + address: Target address. + resolve_inline: Output the stack of inlining. + + Returns: + lines: List of the corresponding lines. + + Raises: + StackAnalyzerError: If addr2line is failed. + """ + cache_key = (address, resolve_inline) + if cache_key in self.address_to_line_cache: + return self.address_to_line_cache[cache_key] + + try: + args = [ + self.options.addr2line, + "-f", + "-e", + self.options.elf_path, + "{:x}".format(address), + ] + if resolve_inline: + args.append("-i") + + line_text = subprocess.check_output(args, encoding="utf-8") + except subprocess.CalledProcessError: + raise StackAnalyzerError("addr2line failed to resolve lines.") + except OSError: + raise StackAnalyzerError("Failed to run addr2line.") + + lines = [line.strip() for line in line_text.splitlines()] + # Assume the output has at least one pair like "function\nlocation\n", and + # they always show up in pairs. + # Example: "handle_request\n + # common/usb_pd_protocol.c:1191\n" + assert len(lines) >= 2 and len(lines) % 2 == 0 + + line_infos = [] + for index in range(0, len(lines), 2): + (function_name, line_text) = lines[index : index + 2] + if line_text in ["??:0", ":?"]: + line_infos.append(None) + else: + result = self.ADDRTOLINE_RE.match(line_text) + # Assume the output is always well-formed. + assert result is not None + line_infos.append( + ( + function_name.strip(), + os.path.realpath(result.group("path").strip()), + int(result.group("linenum")), + ) + ) + + self.address_to_line_cache[cache_key] = line_infos + return line_infos + + def AnalyzeDisassembly(self, disasm_text): + """Parse the disassembly text, analyze, and build a map of all functions. + + Args: + disasm_text: Disassembly text. + + Returns: + function_map: Dict of functions. + """ + disasm_lines = [line.strip() for line in disasm_text.splitlines()] + + if "nds" in disasm_lines[1]: + analyzer = AndesAnalyzer() + elif "arm" in disasm_lines[1]: + analyzer = ArmAnalyzer() + elif "riscv" in disasm_lines[1]: + analyzer = RiscvAnalyzer() + else: + raise StackAnalyzerError("Unsupported architecture.") - symbol_path_map[name] = group_map + # Example: "08028c8c <motion_lid_calc>:" + function_signature_regex = re.compile( + r"^(?P<address>[0-9A-Fa-f]+)\s+<(?P<name>[^>]+)>:$" + ) - # Symbol matching. - function_group = None - group_map = symbol_path_map[name] - if len(group_map) > 0: - if path is None: - if len(group_map) > 1: - # There is ambiguity but the path isn't specified. - sig_error_map[sig] = self.ANNOTATION_ERROR_AMBIGUOUS - continue + def DetectFunctionHead(line): + """Check if the line is a function head. - # No path signature but all symbol signatures of functions are same. - # Assume they are the same functions, so there is no ambiguity. - (function_group,) = group_map.values() - else: - function_group = group_map.get(path) + Args: + line: Text of disassembly. - if function_group is None: - sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND - continue + Returns: + symbol: Function symbol. None if it isn't a function head. + """ + result = function_signature_regex.match(line) + if result is None: + return None - # The function_group is a list of all the same functions (according to - # our assumption) which should be annotated together. - signature_map[sig] = function_group + address = int(result.group("address"), 16) + symbol = symbol_map.get(address) - return (signature_map, sig_error_map) + # Check if the function exists and matches. + if symbol is None or symbol.symtype != "F": + return None - def LoadAnnotation(self): - """Load annotation rules. + return symbol - Returns: - Map of add rules, set of remove rules, set of text signatures which can't - be parsed. - """ - # Assume there is no ":" in the path. - # Example: "get_range.lto.2501[driver/accel_kionix.c:327]" - annotation_signature_regex = re.compile( - r'^(?P<name>[^\[]+)(\[(?P<path>[^:]+)(:(?P<linenum>\d+))?\])?$') - - def NormalizeSignature(signature_text): - """Parse and normalize the annotation signature. - - Args: - signature_text: Text of the annotation signature. - - Returns: - (function name, path, line number) of the signature. The path and line - number can be None if not exist. None if failed to parse. - """ - result = annotation_signature_regex.match(signature_text.strip()) - if result is None: - return None - - name_result = self.FUNCTION_PREFIX_NAME_RE.match( - result.group('name').strip()) - if name_result is None: - return None - - path = result.group('path') - if path is not None: - path = os.path.realpath(path.strip()) - - linenum = result.group('linenum') - if linenum is not None: - linenum = int(linenum.strip()) - - return (name_result.group('name').strip(), path, linenum) - - def ExpandArray(dic): - """Parse and expand a symbol array - - Args: - dic: Dictionary for the array annotation - - Returns: - array of (symbol name, None, None). - """ - # TODO(drinkcat): This function is quite inefficient, as it goes through - # the symbol table multiple times. - - begin_name = dic['name'] - end_name = dic['name'] + "_end" - offset = dic['offset'] if 'offset' in dic else 0 - stride = dic['stride'] - - begin_address = None - end_address = None - - for symbol in self.symbols: - if (symbol.name == begin_name): - begin_address = symbol.address - if (symbol.name == end_name): - end_address = symbol.address - - if (not begin_address or not end_address): - return None - - output = [] - # TODO(drinkcat): This is inefficient as we go from address to symbol - # object then to symbol name, and later on we'll go back from symbol name - # to symbol object. - for addr in range(begin_address+offset, end_address, stride): - # TODO(drinkcat): Not all architectures need to drop the first bit. - val = self.rodata[(addr-self.rodata_offset) // 4] & 0xfffffffe - name = None + # Build symbol map, indexed by symbol address. + symbol_map = {} for symbol in self.symbols: - if (symbol.address == val): - result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name) - name = result.group('name') - break - - if not name: - raise StackAnalyzerError('Cannot find function for address %s.', - hex(val)) - - output.append((name, None, None)) - - return output - - add_rules = collections.defaultdict(set) - remove_rules = list() - invalid_sigtxts = set() - - if 'add' in self.annotation and self.annotation['add'] is not None: - for src_sigtxt, dst_sigtxts in self.annotation['add'].items(): - src_sig = NormalizeSignature(src_sigtxt) - if src_sig is None: - invalid_sigtxts.add(src_sigtxt) - continue - - for dst_sigtxt in dst_sigtxts: - if isinstance(dst_sigtxt, dict): - dst_sig = ExpandArray(dst_sigtxt) - if dst_sig is None: - invalid_sigtxts.add(str(dst_sigtxt)) + # If there are multiple symbols with same address, keeping any of them is + # good enough. + symbol_map[symbol.address] = symbol + + # Parse the disassembly text. We update the variable "line" to next line + # when needed. There are two steps of parser: + # + # Step 1: Searching for the function head. Once reach the function head, + # move to the next line, which is the first line of function body. + # + # Step 2: Parsing each instruction line of function body. Once reach a + # non-instruction line, stop parsing and analyze the parsed instructions. + # + # Finally turn back to the step 1 without updating the line, because the + # current non-instruction line can be another function head. + function_map = {} + # The following three variables are the states of the parsing processing. + # They will be initialized properly during the state changes. + function_symbol = None + function_end = None + instructions = [] + + # Remove heading and tailing spaces for each line. + line_index = 0 + while line_index < len(disasm_lines): + # Get the current line. + line = disasm_lines[line_index] + + if function_symbol is None: + # Step 1: Search for the function head. + + function_symbol = DetectFunctionHead(line) + if function_symbol is not None: + # Assume there is no empty function. If the function head is followed + # by EOF, it is an empty function. + assert line_index + 1 < len(disasm_lines) + + # Found the function head, initialize and turn to the step 2. + instructions = [] + # If symbol size exists, use it as a hint of function size. + if function_symbol.size > 0: + function_end = function_symbol.address + function_symbol.size + else: + function_end = None + else: - add_rules[src_sig].update(dst_sig) - else: - dst_sig = NormalizeSignature(dst_sigtxt) - if dst_sig is None: - invalid_sigtxts.add(dst_sigtxt) + # Step 2: Parse the function body. + + instruction = analyzer.ParseInstruction(line, function_end) + if instruction is not None: + instructions.append(instruction) + + if instruction is None or line_index + 1 == len(disasm_lines): + # Either the invalid instruction or EOF indicates the end of the + # function, finalize the function analysis. + + # Assume there is no empty function. + assert len(instructions) > 0 + + (stack_frame, callsites) = analyzer.AnalyzeFunction( + function_symbol, instructions + ) + # Assume the function addresses are unique in the disassembly. + assert function_symbol.address not in function_map + function_map[function_symbol.address] = Function( + function_symbol.address, + function_symbol.name, + stack_frame, + callsites, + ) + + # Initialize and turn back to the step 1. + function_symbol = None + + # If the current line isn't an instruction, it can be another function + # head, skip moving to the next line. + if instruction is None: + continue + + # Move to the next line. + line_index += 1 + + # Resolve callees of functions. + for function in function_map.values(): + for callsite in function.callsites: + if callsite.target is not None: + # Remain the callee as None if we can't resolve it. + callsite.callee = function_map.get(callsite.target) + + return function_map + + def MapAnnotation(self, function_map, signature_set): + """Map annotation signatures to functions. + + Args: + function_map: Function map. + signature_set: Set of annotation signatures. + + Returns: + Map of signatures to functions, map of signatures which can't be resolved. + """ + # Build the symbol map indexed by symbol name. If there are multiple symbols + # with the same name, add them into a set. (e.g. symbols of static function + # with the same name) + symbol_map = collections.defaultdict(set) + for symbol in self.symbols: + if symbol.symtype == "F": + # Function symbol. + result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name) + if result is not None: + function = function_map.get(symbol.address) + # Ignore the symbol not in disassembly. + if function is not None: + # If there are multiple symbol with the same name and point to the + # same function, the set will deduplicate them. + symbol_map[result.group("name").strip()].add(function) + + # Build the signature map indexed by annotation signature. + signature_map = {} + sig_error_map = {} + symbol_path_map = {} + for sig in signature_set: + (name, path, _) = sig + + functions = symbol_map.get(name) + if functions is None: + sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND + continue + + if name not in symbol_path_map: + # Lazy symbol path resolving. Since the addr2line isn't fast, only + # resolve needed symbol paths. + group_map = collections.defaultdict(list) + for function in functions: + line_info = self.AddressToLine(function.address)[0] + if line_info is None: + continue + + (_, symbol_path, _) = line_info + + # Group the functions with the same symbol signature (symbol name + + # symbol path). Assume they are the same copies and do the same + # annotation operations of them because we don't know which copy is + # indicated by the users. + group_map[symbol_path].append(function) + + symbol_path_map[name] = group_map + + # Symbol matching. + function_group = None + group_map = symbol_path_map[name] + if len(group_map) > 0: + if path is None: + if len(group_map) > 1: + # There is ambiguity but the path isn't specified. + sig_error_map[sig] = self.ANNOTATION_ERROR_AMBIGUOUS + continue + + # No path signature but all symbol signatures of functions are same. + # Assume they are the same functions, so there is no ambiguity. + (function_group,) = group_map.values() + else: + function_group = group_map.get(path) + + if function_group is None: + sig_error_map[sig] = self.ANNOTATION_ERROR_NOTFOUND + continue + + # The function_group is a list of all the same functions (according to + # our assumption) which should be annotated together. + signature_map[sig] = function_group + + return (signature_map, sig_error_map) + + def LoadAnnotation(self): + """Load annotation rules. + + Returns: + Map of add rules, set of remove rules, set of text signatures which can't + be parsed. + """ + # Assume there is no ":" in the path. + # Example: "get_range.lto.2501[driver/accel_kionix.c:327]" + annotation_signature_regex = re.compile( + r"^(?P<name>[^\[]+)(\[(?P<path>[^:]+)(:(?P<linenum>\d+))?\])?$" + ) + + def NormalizeSignature(signature_text): + """Parse and normalize the annotation signature. + + Args: + signature_text: Text of the annotation signature. + + Returns: + (function name, path, line number) of the signature. The path and line + number can be None if not exist. None if failed to parse. + """ + result = annotation_signature_regex.match(signature_text.strip()) + if result is None: + return None + + name_result = self.FUNCTION_PREFIX_NAME_RE.match( + result.group("name").strip() + ) + if name_result is None: + return None + + path = result.group("path") + if path is not None: + path = os.path.realpath(path.strip()) + + linenum = result.group("linenum") + if linenum is not None: + linenum = int(linenum.strip()) + + return (name_result.group("name").strip(), path, linenum) + + def ExpandArray(dic): + """Parse and expand a symbol array + + Args: + dic: Dictionary for the array annotation + + Returns: + array of (symbol name, None, None). + """ + # TODO(drinkcat): This function is quite inefficient, as it goes through + # the symbol table multiple times. + + begin_name = dic["name"] + end_name = dic["name"] + "_end" + offset = dic["offset"] if "offset" in dic else 0 + stride = dic["stride"] + + begin_address = None + end_address = None + + for symbol in self.symbols: + if symbol.name == begin_name: + begin_address = symbol.address + if symbol.name == end_name: + end_address = symbol.address + + if not begin_address or not end_address: + return None + + output = [] + # TODO(drinkcat): This is inefficient as we go from address to symbol + # object then to symbol name, and later on we'll go back from symbol name + # to symbol object. + for addr in range(begin_address + offset, end_address, stride): + # TODO(drinkcat): Not all architectures need to drop the first bit. + val = self.rodata[(addr - self.rodata_offset) // 4] & 0xFFFFFFFE + name = None + for symbol in self.symbols: + if symbol.address == val: + result = self.FUNCTION_PREFIX_NAME_RE.match(symbol.name) + name = result.group("name") + break + + if not name: + raise StackAnalyzerError( + "Cannot find function for address %s.", hex(val) + ) + + output.append((name, None, None)) + + return output + + add_rules = collections.defaultdict(set) + remove_rules = list() + invalid_sigtxts = set() + + if "add" in self.annotation and self.annotation["add"] is not None: + for src_sigtxt, dst_sigtxts in self.annotation["add"].items(): + src_sig = NormalizeSignature(src_sigtxt) + if src_sig is None: + invalid_sigtxts.add(src_sigtxt) + continue + + for dst_sigtxt in dst_sigtxts: + if isinstance(dst_sigtxt, dict): + dst_sig = ExpandArray(dst_sigtxt) + if dst_sig is None: + invalid_sigtxts.add(str(dst_sigtxt)) + else: + add_rules[src_sig].update(dst_sig) + else: + dst_sig = NormalizeSignature(dst_sigtxt) + if dst_sig is None: + invalid_sigtxts.add(dst_sigtxt) + else: + add_rules[src_sig].add(dst_sig) + + if "remove" in self.annotation and self.annotation["remove"] is not None: + for sigtxt_path in self.annotation["remove"]: + if isinstance(sigtxt_path, str): + # The path has only one vertex. + sigtxt_path = [sigtxt_path] + + if len(sigtxt_path) == 0: + continue + + # Generate multiple remove paths from all the combinations of the + # signatures of each vertex. + sig_paths = [[]] + broken_flag = False + for sigtxt_node in sigtxt_path: + if isinstance(sigtxt_node, str): + # The vertex has only one signature. + sigtxt_set = {sigtxt_node} + elif isinstance(sigtxt_node, list): + # The vertex has multiple signatures. + sigtxt_set = set(sigtxt_node) + else: + # Assume the format of annotation is verified. There should be no + # invalid case. + assert False + + sig_set = set() + for sigtxt in sigtxt_set: + sig = NormalizeSignature(sigtxt) + if sig is None: + invalid_sigtxts.add(sigtxt) + broken_flag = True + elif not broken_flag: + sig_set.add(sig) + + if broken_flag: + continue + + # Append each signature of the current node to the all previous + # remove paths. + sig_paths = [path + [sig] for path in sig_paths for sig in sig_set] + + if not broken_flag: + # All signatures are normalized. The remove path has no error. + remove_rules.extend(sig_paths) + + return (add_rules, remove_rules, invalid_sigtxts) + + def ResolveAnnotation(self, function_map): + """Resolve annotation. + + Args: + function_map: Function map. + + Returns: + Set of added call edges, list of remove paths, set of eliminated + callsite addresses, set of annotation signatures which can't be resolved. + """ + + def StringifySignature(signature): + """Stringify the tupled signature. + + Args: + signature: Tupled signature. + + Returns: + Signature string. + """ + (name, path, linenum) = signature + bracket_text = "" + if path is not None: + path = os.path.relpath(path) + if linenum is None: + bracket_text = "[{}]".format(path) + else: + bracket_text = "[{}:{}]".format(path, linenum) + + return name + bracket_text + + (add_rules, remove_rules, invalid_sigtxts) = self.LoadAnnotation() + + signature_set = set() + for src_sig, dst_sigs in add_rules.items(): + signature_set.add(src_sig) + signature_set.update(dst_sigs) + + for remove_sigs in remove_rules: + signature_set.update(remove_sigs) + + # Map signatures to functions. + (signature_map, sig_error_map) = self.MapAnnotation(function_map, signature_set) + + # Build the indirect callsite map indexed by callsite signature. + indirect_map = collections.defaultdict(set) + for function in function_map.values(): + for callsite in function.callsites: + if callsite.target is not None: + continue + + # Found an indirect callsite. + line_info = self.AddressToLine(callsite.address)[0] + if line_info is None: + continue + + (name, path, linenum) = line_info + result = self.FUNCTION_PREFIX_NAME_RE.match(name) + if result is None: + continue + + indirect_map[(result.group("name").strip(), path, linenum)].add( + (function, callsite.address) + ) + + # Generate the annotation sets. + add_set = set() + remove_list = list() + eliminated_addrs = set() + + for src_sig, dst_sigs in add_rules.items(): + src_funcs = set(signature_map.get(src_sig, [])) + # Try to match the source signature to the indirect callsites. Even if it + # can't be found in disassembly. + indirect_calls = indirect_map.get(src_sig) + if indirect_calls is not None: + for function, callsite_address in indirect_calls: + # Add the caller of the indirect callsite to the source functions. + src_funcs.add(function) + # Assume each callsite can be represented by a unique address. + eliminated_addrs.add(callsite_address) + + if src_sig in sig_error_map: + # Assume the error is always the not found error. Since the signature + # found in indirect callsite map must be a full signature, it can't + # happen the ambiguous error. + assert sig_error_map[src_sig] == self.ANNOTATION_ERROR_NOTFOUND + # Found in inline stack, remove the not found error. + del sig_error_map[src_sig] + + for dst_sig in dst_sigs: + dst_funcs = signature_map.get(dst_sig) + if dst_funcs is None: + continue + + # Duplicate the call edge for all the same source and destination + # functions. + for src_func in src_funcs: + for dst_func in dst_funcs: + add_set.add((src_func, dst_func)) + + for remove_sigs in remove_rules: + # Since each signature can be mapped to multiple functions, generate + # multiple remove paths from all the combinations of these functions. + remove_paths = [[]] + skip_flag = False + for remove_sig in remove_sigs: + # Transform each signature to the corresponding functions. + remove_funcs = signature_map.get(remove_sig) + if remove_funcs is None: + # There is an unresolved signature in the remove path. Ignore the + # whole broken remove path. + skip_flag = True + break + else: + # Append each function of the current signature to the all previous + # remove paths. + remove_paths = [p + [f] for p in remove_paths for f in remove_funcs] + + if skip_flag: + # Ignore the broken remove path. + continue + + for remove_path in remove_paths: + # Deduplicate the remove paths. + if remove_path not in remove_list: + remove_list.append(remove_path) + + # Format the error messages. + failed_sigtxts = set() + for sigtxt in invalid_sigtxts: + failed_sigtxts.add((sigtxt, self.ANNOTATION_ERROR_INVALID)) + + for sig, error in sig_error_map.items(): + failed_sigtxts.add((StringifySignature(sig), error)) + + return (add_set, remove_list, eliminated_addrs, failed_sigtxts) + + def PreprocessAnnotation( + self, function_map, add_set, remove_list, eliminated_addrs + ): + """Preprocess the annotation and callgraph. + + Add the missing call edges, and delete simple remove paths (the paths have + one or two vertices) from the function_map. + + Eliminate the annotated indirect callsites. + + Return the remaining remove list. + + Args: + function_map: Function map. + add_set: Set of missing call edges. + remove_list: List of remove paths. + eliminated_addrs: Set of eliminated callsite addresses. + + Returns: + List of remaining remove paths. + """ + + def CheckEdge(path): + """Check if all edges of the path are on the callgraph. + + Args: + path: Path. + + Returns: + True or False. + """ + for index in range(len(path) - 1): + if (path[index], path[index + 1]) not in edge_set: + return False + + return True + + for src_func, dst_func in add_set: + # TODO(cheyuw): Support tailing call annotation. + src_func.callsites.append(Callsite(None, dst_func.address, False, dst_func)) + + # Delete simple remove paths. + remove_simple = set(tuple(p) for p in remove_list if len(p) <= 2) + edge_set = set() + for function in function_map.values(): + cleaned_callsites = [] + for callsite in function.callsites: + if (callsite.callee,) in remove_simple or ( + function, + callsite.callee, + ) in remove_simple: + continue + + if callsite.target is None and callsite.address in eliminated_addrs: + continue + + cleaned_callsites.append(callsite) + if callsite.callee is not None: + edge_set.add((function, callsite.callee)) + + function.callsites = cleaned_callsites + + return [p for p in remove_list if len(p) >= 3 and CheckEdge(p)] + + def AnalyzeCallGraph(self, function_map, remove_list): + """Analyze callgraph. + + It will update the max stack size and path for each function. + + Args: + function_map: Function map. + remove_list: List of remove paths. + + Returns: + List of function cycles. + """ + + def Traverse(curr_state): + """Traverse the callgraph and calculate the max stack usages of functions. + + Args: + curr_state: Current state. + + Returns: + SCC lowest link. + """ + scc_index = scc_index_counter[0] + scc_index_counter[0] += 1 + scc_index_map[curr_state] = scc_index + scc_lowlink = scc_index + scc_stack.append(curr_state) + # Push the current state in the stack. We can use a set to maintain this + # because the stacked states are unique; otherwise we will find a cycle + # first. + stacked_states.add(curr_state) + + (curr_address, curr_positions) = curr_state + curr_func = function_map[curr_address] + + invalid_flag = False + new_positions = list(curr_positions) + for index, position in enumerate(curr_positions): + remove_path = remove_list[index] + + # The position of each remove path in the state is the length of the + # longest matching path between the prefix of the remove path and the + # suffix of the current traversing path. We maintain this length when + # appending the next callee to the traversing path. And it can be used + # to check if the remove path appears in the traversing path. + + # TODO(cheyuw): Implement KMP algorithm to match remove paths + # efficiently. + if remove_path[position] is curr_func: + # Matches the current function, extend the length. + new_positions[index] = position + 1 + if new_positions[index] == len(remove_path): + # The length of the longest matching path is equal to the length of + # the remove path, which means the suffix of the current traversing + # path matches the remove path. + invalid_flag = True + break + + else: + # We can't get the new longest matching path by extending the previous + # one directly. Fallback to search the new longest matching path. + + # If we can't find any matching path in the following search, reset + # the matching length to 0. + new_positions[index] = 0 + + # We want to find the new longest matching prefix of remove path with + # the suffix of the current traversing path. Because the new longest + # matching path won't be longer than the prevous one now, and part of + # the suffix matches the prefix of remove path, we can get the needed + # suffix from the previous matching prefix of the invalid path. + suffix = remove_path[:position] + [curr_func] + for offset in range(1, len(suffix)): + length = position - offset + if remove_path[:length] == suffix[offset:]: + new_positions[index] = length + break + + new_positions = tuple(new_positions) + + # If the current suffix is invalid, set the max stack usage to 0. + max_stack_usage = 0 + max_callee_state = None + self_loop = False + + if not invalid_flag: + # Max stack usage is at least equal to the stack frame. + max_stack_usage = curr_func.stack_frame + for callsite in curr_func.callsites: + callee = callsite.callee + if callee is None: + continue + + callee_state = (callee.address, new_positions) + if callee_state not in scc_index_map: + # Unvisited state. + scc_lowlink = min(scc_lowlink, Traverse(callee_state)) + elif callee_state in stacked_states: + # The state is shown in the stack. There is a cycle. + sub_stack_usage = 0 + scc_lowlink = min(scc_lowlink, scc_index_map[callee_state]) + if callee_state == curr_state: + self_loop = True + + done_result = done_states.get(callee_state) + if done_result is not None: + # Already done this state and use its result. If the state reaches a + # cycle, reusing the result will cause inaccuracy (the stack usage + # of cycle depends on where the entrance is). But it's fine since we + # can't get accurate stack usage under this situation, and we rely + # on user-provided annotations to break the cycle, after which the + # result will be accurate again. + (sub_stack_usage, _) = done_result + + if callsite.is_tail: + # For tailing call, since the callee reuses the stack frame of the + # caller, choose the larger one directly. + stack_usage = max(curr_func.stack_frame, sub_stack_usage) + else: + stack_usage = curr_func.stack_frame + sub_stack_usage + + if stack_usage > max_stack_usage: + max_stack_usage = stack_usage + max_callee_state = callee_state + + if scc_lowlink == scc_index: + group = [] + while scc_stack[-1] != curr_state: + scc_state = scc_stack.pop() + stacked_states.remove(scc_state) + group.append(scc_state) + + scc_stack.pop() + stacked_states.remove(curr_state) + + # If the cycle is not empty, record it. + if len(group) > 0 or self_loop: + group.append(curr_state) + cycle_groups.append(group) + + # Store the done result. + done_states[curr_state] = (max_stack_usage, max_callee_state) + + if curr_positions == initial_positions: + # If the current state is initial state, we traversed the callgraph by + # using the current function as start point. Update the stack usage of + # the function. + # If the function matches a single vertex remove path, this will set its + # max stack usage to 0, which is not expected (we still calculate its + # max stack usage, but prevent any function from calling it). However, + # all the single vertex remove paths have been preprocessed and removed. + curr_func.stack_max_usage = max_stack_usage + + # Reconstruct the max stack path by traversing the state transitions. + max_stack_path = [curr_func] + callee_state = max_callee_state + while callee_state is not None: + # The first element of state tuple is function address. + max_stack_path.append(function_map[callee_state[0]]) + done_result = done_states.get(callee_state) + # All of the descendants should be done. + assert done_result is not None + (_, callee_state) = done_result + + curr_func.stack_max_path = max_stack_path + + return scc_lowlink + + # The state is the concatenation of the current function address and the + # state of matching position. + initial_positions = (0,) * len(remove_list) + done_states = {} + stacked_states = set() + scc_index_counter = [0] + scc_index_map = {} + scc_stack = [] + cycle_groups = [] + for function in function_map.values(): + if function.stack_max_usage is None: + Traverse((function.address, initial_positions)) + + cycle_functions = [] + for group in cycle_groups: + cycle = set(function_map[state[0]] for state in group) + if cycle not in cycle_functions: + cycle_functions.append(cycle) + + return cycle_functions + + def Analyze(self): + """Run the stack analysis. + + Raises: + StackAnalyzerError: If disassembly fails. + """ + + def OutputInlineStack(address, prefix=""): + """Output beautiful inline stack. + + Args: + address: Address. + prefix: Prefix of each line. + + Returns: + Key for sorting, output text + """ + line_infos = self.AddressToLine(address, True) + + if line_infos[0] is None: + order_key = (None, None) else: - add_rules[src_sig].add(dst_sig) - - if 'remove' in self.annotation and self.annotation['remove'] is not None: - for sigtxt_path in self.annotation['remove']: - if isinstance(sigtxt_path, str): - # The path has only one vertex. - sigtxt_path = [sigtxt_path] - - if len(sigtxt_path) == 0: - continue - - # Generate multiple remove paths from all the combinations of the - # signatures of each vertex. - sig_paths = [[]] - broken_flag = False - for sigtxt_node in sigtxt_path: - if isinstance(sigtxt_node, str): - # The vertex has only one signature. - sigtxt_set = {sigtxt_node} - elif isinstance(sigtxt_node, list): - # The vertex has multiple signatures. - sigtxt_set = set(sigtxt_node) - else: - # Assume the format of annotation is verified. There should be no - # invalid case. - assert False - - sig_set = set() - for sigtxt in sigtxt_set: - sig = NormalizeSignature(sigtxt) - if sig is None: - invalid_sigtxts.add(sigtxt) - broken_flag = True - elif not broken_flag: - sig_set.add(sig) - - if broken_flag: - continue - - # Append each signature of the current node to the all previous - # remove paths. - sig_paths = [path + [sig] for path in sig_paths for sig in sig_set] - - if not broken_flag: - # All signatures are normalized. The remove path has no error. - remove_rules.extend(sig_paths) - - return (add_rules, remove_rules, invalid_sigtxts) + (_, path, linenum) = line_infos[0] + order_key = (linenum, path) + + line_texts = [] + for line_info in reversed(line_infos): + if line_info is None: + (function_name, path, linenum) = ("??", "??", 0) + else: + (function_name, path, linenum) = line_info + + line_texts.append( + "{}[{}:{}]".format(function_name, os.path.relpath(path), linenum) + ) + + output = "{}-> {} {:x}\n".format(prefix, line_texts[0], address) + for depth, line_text in enumerate(line_texts[1:]): + output += "{} {}- {}\n".format(prefix, " " * depth, line_text) + + # Remove the last newline character. + return (order_key, output.rstrip("\n")) + + # Analyze disassembly. + try: + disasm_text = subprocess.check_output( + [self.options.objdump, "-d", self.options.elf_path], encoding="utf-8" + ) + except subprocess.CalledProcessError: + raise StackAnalyzerError("objdump failed to disassemble.") + except OSError: + raise StackAnalyzerError("Failed to run objdump.") + + function_map = self.AnalyzeDisassembly(disasm_text) + result = self.ResolveAnnotation(function_map) + (add_set, remove_list, eliminated_addrs, failed_sigtxts) = result + remove_list = self.PreprocessAnnotation( + function_map, add_set, remove_list, eliminated_addrs + ) + cycle_functions = self.AnalyzeCallGraph(function_map, remove_list) + + # Print the results of task-aware stack analysis. + extra_stack_frame = self.annotation.get( + "exception_frame_size", DEFAULT_EXCEPTION_FRAME_SIZE + ) + for task in self.tasklist: + routine_func = function_map[task.routine_address] + print( + "Task: {}, Max size: {} ({} + {}), Allocated size: {}".format( + task.name, + routine_func.stack_max_usage + extra_stack_frame, + routine_func.stack_max_usage, + extra_stack_frame, + task.stack_max_size, + ) + ) + + print("Call Trace:") + max_stack_path = routine_func.stack_max_path + # Assume the routine function is resolved. + assert max_stack_path is not None + for depth, curr_func in enumerate(max_stack_path): + line_info = self.AddressToLine(curr_func.address)[0] + if line_info is None: + (path, linenum) = ("??", 0) + else: + (_, path, linenum) = line_info + + print( + " {} ({}) [{}:{}] {:x}".format( + curr_func.name, + curr_func.stack_frame, + os.path.relpath(path), + linenum, + curr_func.address, + ) + ) + + if depth + 1 < len(max_stack_path): + succ_func = max_stack_path[depth + 1] + text_list = [] + for callsite in curr_func.callsites: + if callsite.callee is succ_func: + indent_prefix = " " + if callsite.address is None: + order_text = ( + None, + "{}-> [annotation]".format(indent_prefix), + ) + else: + order_text = OutputInlineStack( + callsite.address, indent_prefix + ) + + text_list.append(order_text) + + for _, text in sorted(text_list, key=lambda item: item[0]): + print(text) + + print("Unresolved indirect callsites:") + for function in function_map.values(): + indirect_callsites = [] + for callsite in function.callsites: + if callsite.target is None: + indirect_callsites.append(callsite.address) + + if len(indirect_callsites) > 0: + print(" In function {}:".format(function.name)) + text_list = [] + for address in indirect_callsites: + text_list.append(OutputInlineStack(address, " ")) + + for _, text in sorted(text_list, key=lambda item: item[0]): + print(text) + + print("Unresolved annotation signatures:") + for sigtxt, error in failed_sigtxts: + print(" {}: {}".format(sigtxt, error)) + + if len(cycle_functions) > 0: + print("There are cycles in the following function sets:") + for functions in cycle_functions: + print("[{}]".format(", ".join(function.name for function in functions))) - def ResolveAnnotation(self, function_map): - """Resolve annotation. - Args: - function_map: Function map. +def ParseArgs(): + """Parse commandline arguments. Returns: - Set of added call edges, list of remove paths, set of eliminated - callsite addresses, set of annotation signatures which can't be resolved. + options: Namespace from argparse.parse_args(). """ - def StringifySignature(signature): - """Stringify the tupled signature. - - Args: - signature: Tupled signature. - - Returns: - Signature string. - """ - (name, path, linenum) = signature - bracket_text = '' - if path is not None: - path = os.path.relpath(path) - if linenum is None: - bracket_text = '[{}]'.format(path) - else: - bracket_text = '[{}:{}]'.format(path, linenum) - - return name + bracket_text - - (add_rules, remove_rules, invalid_sigtxts) = self.LoadAnnotation() - - signature_set = set() - for src_sig, dst_sigs in add_rules.items(): - signature_set.add(src_sig) - signature_set.update(dst_sigs) - - for remove_sigs in remove_rules: - signature_set.update(remove_sigs) - - # Map signatures to functions. - (signature_map, sig_error_map) = self.MapAnnotation(function_map, - signature_set) - - # Build the indirect callsite map indexed by callsite signature. - indirect_map = collections.defaultdict(set) - for function in function_map.values(): - for callsite in function.callsites: - if callsite.target is not None: - continue - - # Found an indirect callsite. - line_info = self.AddressToLine(callsite.address)[0] - if line_info is None: - continue - - (name, path, linenum) = line_info - result = self.FUNCTION_PREFIX_NAME_RE.match(name) - if result is None: - continue - - indirect_map[(result.group('name').strip(), path, linenum)].add( - (function, callsite.address)) - - # Generate the annotation sets. - add_set = set() - remove_list = list() - eliminated_addrs = set() - - for src_sig, dst_sigs in add_rules.items(): - src_funcs = set(signature_map.get(src_sig, [])) - # Try to match the source signature to the indirect callsites. Even if it - # can't be found in disassembly. - indirect_calls = indirect_map.get(src_sig) - if indirect_calls is not None: - for function, callsite_address in indirect_calls: - # Add the caller of the indirect callsite to the source functions. - src_funcs.add(function) - # Assume each callsite can be represented by a unique address. - eliminated_addrs.add(callsite_address) - - if src_sig in sig_error_map: - # Assume the error is always the not found error. Since the signature - # found in indirect callsite map must be a full signature, it can't - # happen the ambiguous error. - assert sig_error_map[src_sig] == self.ANNOTATION_ERROR_NOTFOUND - # Found in inline stack, remove the not found error. - del sig_error_map[src_sig] - - for dst_sig in dst_sigs: - dst_funcs = signature_map.get(dst_sig) - if dst_funcs is None: - continue - - # Duplicate the call edge for all the same source and destination - # functions. - for src_func in src_funcs: - for dst_func in dst_funcs: - add_set.add((src_func, dst_func)) - - for remove_sigs in remove_rules: - # Since each signature can be mapped to multiple functions, generate - # multiple remove paths from all the combinations of these functions. - remove_paths = [[]] - skip_flag = False - for remove_sig in remove_sigs: - # Transform each signature to the corresponding functions. - remove_funcs = signature_map.get(remove_sig) - if remove_funcs is None: - # There is an unresolved signature in the remove path. Ignore the - # whole broken remove path. - skip_flag = True - break - else: - # Append each function of the current signature to the all previous - # remove paths. - remove_paths = [p + [f] for p in remove_paths for f in remove_funcs] - - if skip_flag: - # Ignore the broken remove path. - continue - - for remove_path in remove_paths: - # Deduplicate the remove paths. - if remove_path not in remove_list: - remove_list.append(remove_path) - - # Format the error messages. - failed_sigtxts = set() - for sigtxt in invalid_sigtxts: - failed_sigtxts.add((sigtxt, self.ANNOTATION_ERROR_INVALID)) - - for sig, error in sig_error_map.items(): - failed_sigtxts.add((StringifySignature(sig), error)) - - return (add_set, remove_list, eliminated_addrs, failed_sigtxts) - - def PreprocessAnnotation(self, function_map, add_set, remove_list, - eliminated_addrs): - """Preprocess the annotation and callgraph. + parser = argparse.ArgumentParser(description="EC firmware stack analyzer.") + parser.add_argument("elf_path", help="the path of EC firmware ELF") + parser.add_argument( + "--export_taskinfo", + required=True, + help="the path of export_taskinfo.so utility", + ) + parser.add_argument( + "--section", + required=True, + help="the section.", + choices=[SECTION_RO, SECTION_RW], + ) + parser.add_argument("--objdump", default="objdump", help="the path of objdump") + parser.add_argument( + "--addr2line", default="addr2line", help="the path of addr2line" + ) + parser.add_argument( + "--annotation", default=None, help="the path of annotation file" + ) + + # TODO(cheyuw): Add an option for dumping stack usage of all functions. + + return parser.parse_args() - Add the missing call edges, and delete simple remove paths (the paths have - one or two vertices) from the function_map. - Eliminate the annotated indirect callsites. - - Return the remaining remove list. +def ParseSymbolText(symbol_text): + """Parse the content of the symbol text. Args: - function_map: Function map. - add_set: Set of missing call edges. - remove_list: List of remove paths. - eliminated_addrs: Set of eliminated callsite addresses. + symbol_text: Text of the symbols. Returns: - List of remaining remove paths. + symbols: Symbol list. """ - def CheckEdge(path): - """Check if all edges of the path are on the callgraph. - - Args: - path: Path. - - Returns: - True or False. - """ - for index in range(len(path) - 1): - if (path[index], path[index + 1]) not in edge_set: - return False - - return True - - for src_func, dst_func in add_set: - # TODO(cheyuw): Support tailing call annotation. - src_func.callsites.append( - Callsite(None, dst_func.address, False, dst_func)) - - # Delete simple remove paths. - remove_simple = set(tuple(p) for p in remove_list if len(p) <= 2) - edge_set = set() - for function in function_map.values(): - cleaned_callsites = [] - for callsite in function.callsites: - if ((callsite.callee,) in remove_simple or - (function, callsite.callee) in remove_simple): - continue - - if callsite.target is None and callsite.address in eliminated_addrs: - continue - - cleaned_callsites.append(callsite) - if callsite.callee is not None: - edge_set.add((function, callsite.callee)) + # Example: "10093064 g F .text 0000015c .hidden hook_task" + symbol_regex = re.compile( + r"^(?P<address>[0-9A-Fa-f]+)\s+[lwg]\s+" + r"((?P<type>[OF])\s+)?\S+\s+" + r"(?P<size>[0-9A-Fa-f]+)\s+" + r"(\S+\s+)?(?P<name>\S+)$" + ) + + symbols = [] + for line in symbol_text.splitlines(): + line = line.strip() + result = symbol_regex.match(line) + if result is not None: + address = int(result.group("address"), 16) + symtype = result.group("type") + if symtype is None: + symtype = "O" - function.callsites = cleaned_callsites + size = int(result.group("size"), 16) + name = result.group("name") + symbols.append(Symbol(address, symtype, size, name)) - return [p for p in remove_list if len(p) >= 3 and CheckEdge(p)] + return symbols - def AnalyzeCallGraph(self, function_map, remove_list): - """Analyze callgraph. - It will update the max stack size and path for each function. +def ParseRoDataText(rodata_text): + """Parse the content of rodata Args: - function_map: Function map. - remove_list: List of remove paths. + symbol_text: Text of the rodata dump. Returns: - List of function cycles. + symbols: Symbol list. """ - def Traverse(curr_state): - """Traverse the callgraph and calculate the max stack usages of functions. - - Args: - curr_state: Current state. - - Returns: - SCC lowest link. - """ - scc_index = scc_index_counter[0] - scc_index_counter[0] += 1 - scc_index_map[curr_state] = scc_index - scc_lowlink = scc_index - scc_stack.append(curr_state) - # Push the current state in the stack. We can use a set to maintain this - # because the stacked states are unique; otherwise we will find a cycle - # first. - stacked_states.add(curr_state) - - (curr_address, curr_positions) = curr_state - curr_func = function_map[curr_address] - - invalid_flag = False - new_positions = list(curr_positions) - for index, position in enumerate(curr_positions): - remove_path = remove_list[index] - - # The position of each remove path in the state is the length of the - # longest matching path between the prefix of the remove path and the - # suffix of the current traversing path. We maintain this length when - # appending the next callee to the traversing path. And it can be used - # to check if the remove path appears in the traversing path. - - # TODO(cheyuw): Implement KMP algorithm to match remove paths - # efficiently. - if remove_path[position] is curr_func: - # Matches the current function, extend the length. - new_positions[index] = position + 1 - if new_positions[index] == len(remove_path): - # The length of the longest matching path is equal to the length of - # the remove path, which means the suffix of the current traversing - # path matches the remove path. - invalid_flag = True - break - - else: - # We can't get the new longest matching path by extending the previous - # one directly. Fallback to search the new longest matching path. - - # If we can't find any matching path in the following search, reset - # the matching length to 0. - new_positions[index] = 0 - - # We want to find the new longest matching prefix of remove path with - # the suffix of the current traversing path. Because the new longest - # matching path won't be longer than the prevous one now, and part of - # the suffix matches the prefix of remove path, we can get the needed - # suffix from the previous matching prefix of the invalid path. - suffix = remove_path[:position] + [curr_func] - for offset in range(1, len(suffix)): - length = position - offset - if remove_path[:length] == suffix[offset:]: - new_positions[index] = length - break - - new_positions = tuple(new_positions) - - # If the current suffix is invalid, set the max stack usage to 0. - max_stack_usage = 0 - max_callee_state = None - self_loop = False - - if not invalid_flag: - # Max stack usage is at least equal to the stack frame. - max_stack_usage = curr_func.stack_frame - for callsite in curr_func.callsites: - callee = callsite.callee - if callee is None: + # Examples: 8018ab0 00040048 00010000 10020000 4b8e0108 ...H........K... + # 100a7294 00000000 00000000 01000000 ............ + + base_offset = None + offset = None + rodata = [] + for line in rodata_text.splitlines(): + line = line.strip() + space = line.find(" ") + if space < 0: + continue + try: + address = int(line[0:space], 16) + except ValueError: continue - callee_state = (callee.address, new_positions) - if callee_state not in scc_index_map: - # Unvisited state. - scc_lowlink = min(scc_lowlink, Traverse(callee_state)) - elif callee_state in stacked_states: - # The state is shown in the stack. There is a cycle. - sub_stack_usage = 0 - scc_lowlink = min(scc_lowlink, scc_index_map[callee_state]) - if callee_state == curr_state: - self_loop = True - - done_result = done_states.get(callee_state) - if done_result is not None: - # Already done this state and use its result. If the state reaches a - # cycle, reusing the result will cause inaccuracy (the stack usage - # of cycle depends on where the entrance is). But it's fine since we - # can't get accurate stack usage under this situation, and we rely - # on user-provided annotations to break the cycle, after which the - # result will be accurate again. - (sub_stack_usage, _) = done_result - - if callsite.is_tail: - # For tailing call, since the callee reuses the stack frame of the - # caller, choose the larger one directly. - stack_usage = max(curr_func.stack_frame, sub_stack_usage) - else: - stack_usage = curr_func.stack_frame + sub_stack_usage - - if stack_usage > max_stack_usage: - max_stack_usage = stack_usage - max_callee_state = callee_state - - if scc_lowlink == scc_index: - group = [] - while scc_stack[-1] != curr_state: - scc_state = scc_stack.pop() - stacked_states.remove(scc_state) - group.append(scc_state) - - scc_stack.pop() - stacked_states.remove(curr_state) - - # If the cycle is not empty, record it. - if len(group) > 0 or self_loop: - group.append(curr_state) - cycle_groups.append(group) - - # Store the done result. - done_states[curr_state] = (max_stack_usage, max_callee_state) - - if curr_positions == initial_positions: - # If the current state is initial state, we traversed the callgraph by - # using the current function as start point. Update the stack usage of - # the function. - # If the function matches a single vertex remove path, this will set its - # max stack usage to 0, which is not expected (we still calculate its - # max stack usage, but prevent any function from calling it). However, - # all the single vertex remove paths have been preprocessed and removed. - curr_func.stack_max_usage = max_stack_usage - - # Reconstruct the max stack path by traversing the state transitions. - max_stack_path = [curr_func] - callee_state = max_callee_state - while callee_state is not None: - # The first element of state tuple is function address. - max_stack_path.append(function_map[callee_state[0]]) - done_result = done_states.get(callee_state) - # All of the descendants should be done. - assert done_result is not None - (_, callee_state) = done_result - - curr_func.stack_max_path = max_stack_path - - return scc_lowlink - - # The state is the concatenation of the current function address and the - # state of matching position. - initial_positions = (0,) * len(remove_list) - done_states = {} - stacked_states = set() - scc_index_counter = [0] - scc_index_map = {} - scc_stack = [] - cycle_groups = [] - for function in function_map.values(): - if function.stack_max_usage is None: - Traverse((function.address, initial_positions)) - - cycle_functions = [] - for group in cycle_groups: - cycle = set(function_map[state[0]] for state in group) - if cycle not in cycle_functions: - cycle_functions.append(cycle) - - return cycle_functions - - def Analyze(self): - """Run the stack analysis. - - Raises: - StackAnalyzerError: If disassembly fails. - """ - def OutputInlineStack(address, prefix=''): - """Output beautiful inline stack. - - Args: - address: Address. - prefix: Prefix of each line. - - Returns: - Key for sorting, output text - """ - line_infos = self.AddressToLine(address, True) - - if line_infos[0] is None: - order_key = (None, None) - else: - (_, path, linenum) = line_infos[0] - order_key = (linenum, path) - - line_texts = [] - for line_info in reversed(line_infos): - if line_info is None: - (function_name, path, linenum) = ('??', '??', 0) - else: - (function_name, path, linenum) = line_info - - line_texts.append('{}[{}:{}]'.format(function_name, - os.path.relpath(path), - linenum)) - - output = '{}-> {} {:x}\n'.format(prefix, line_texts[0], address) - for depth, line_text in enumerate(line_texts[1:]): - output += '{} {}- {}\n'.format(prefix, ' ' * depth, line_text) - - # Remove the last newline character. - return (order_key, output.rstrip('\n')) - - # Analyze disassembly. - try: - disasm_text = subprocess.check_output([self.options.objdump, - '-d', - self.options.elf_path], - encoding='utf-8') - except subprocess.CalledProcessError: - raise StackAnalyzerError('objdump failed to disassemble.') - except OSError: - raise StackAnalyzerError('Failed to run objdump.') - - function_map = self.AnalyzeDisassembly(disasm_text) - result = self.ResolveAnnotation(function_map) - (add_set, remove_list, eliminated_addrs, failed_sigtxts) = result - remove_list = self.PreprocessAnnotation(function_map, - add_set, - remove_list, - eliminated_addrs) - cycle_functions = self.AnalyzeCallGraph(function_map, remove_list) - - # Print the results of task-aware stack analysis. - extra_stack_frame = self.annotation.get('exception_frame_size', - DEFAULT_EXCEPTION_FRAME_SIZE) - for task in self.tasklist: - routine_func = function_map[task.routine_address] - print('Task: {}, Max size: {} ({} + {}), Allocated size: {}'.format( - task.name, - routine_func.stack_max_usage + extra_stack_frame, - routine_func.stack_max_usage, - extra_stack_frame, - task.stack_max_size)) - - print('Call Trace:') - max_stack_path = routine_func.stack_max_path - # Assume the routine function is resolved. - assert max_stack_path is not None - for depth, curr_func in enumerate(max_stack_path): - line_info = self.AddressToLine(curr_func.address)[0] - if line_info is None: - (path, linenum) = ('??', 0) - else: - (_, path, linenum) = line_info - - print(' {} ({}) [{}:{}] {:x}'.format(curr_func.name, - curr_func.stack_frame, - os.path.relpath(path), - linenum, - curr_func.address)) - - if depth + 1 < len(max_stack_path): - succ_func = max_stack_path[depth + 1] - text_list = [] - for callsite in curr_func.callsites: - if callsite.callee is succ_func: - indent_prefix = ' ' - if callsite.address is None: - order_text = (None, '{}-> [annotation]'.format(indent_prefix)) - else: - order_text = OutputInlineStack(callsite.address, indent_prefix) - - text_list.append(order_text) - - for _, text in sorted(text_list, key=lambda item: item[0]): - print(text) - - print('Unresolved indirect callsites:') - for function in function_map.values(): - indirect_callsites = [] - for callsite in function.callsites: - if callsite.target is None: - indirect_callsites.append(callsite.address) - - if len(indirect_callsites) > 0: - print(' In function {}:'.format(function.name)) - text_list = [] - for address in indirect_callsites: - text_list.append(OutputInlineStack(address, ' ')) - - for _, text in sorted(text_list, key=lambda item: item[0]): - print(text) - - print('Unresolved annotation signatures:') - for sigtxt, error in failed_sigtxts: - print(' {}: {}'.format(sigtxt, error)) - - if len(cycle_functions) > 0: - print('There are cycles in the following function sets:') - for functions in cycle_functions: - print('[{}]'.format(', '.join(function.name for function in functions))) - - -def ParseArgs(): - """Parse commandline arguments. - - Returns: - options: Namespace from argparse.parse_args(). - """ - parser = argparse.ArgumentParser(description="EC firmware stack analyzer.") - parser.add_argument('elf_path', help="the path of EC firmware ELF") - parser.add_argument('--export_taskinfo', required=True, - help="the path of export_taskinfo.so utility") - parser.add_argument('--section', required=True, help='the section.', - choices=[SECTION_RO, SECTION_RW]) - parser.add_argument('--objdump', default='objdump', - help='the path of objdump') - parser.add_argument('--addr2line', default='addr2line', - help='the path of addr2line') - parser.add_argument('--annotation', default=None, - help='the path of annotation file') - - # TODO(cheyuw): Add an option for dumping stack usage of all functions. - - return parser.parse_args() - - -def ParseSymbolText(symbol_text): - """Parse the content of the symbol text. - - Args: - symbol_text: Text of the symbols. - - Returns: - symbols: Symbol list. - """ - # Example: "10093064 g F .text 0000015c .hidden hook_task" - symbol_regex = re.compile(r'^(?P<address>[0-9A-Fa-f]+)\s+[lwg]\s+' - r'((?P<type>[OF])\s+)?\S+\s+' - r'(?P<size>[0-9A-Fa-f]+)\s+' - r'(\S+\s+)?(?P<name>\S+)$') - - symbols = [] - for line in symbol_text.splitlines(): - line = line.strip() - result = symbol_regex.match(line) - if result is not None: - address = int(result.group('address'), 16) - symtype = result.group('type') - if symtype is None: - symtype = 'O' - - size = int(result.group('size'), 16) - name = result.group('name') - symbols.append(Symbol(address, symtype, size, name)) - - return symbols - - -def ParseRoDataText(rodata_text): - """Parse the content of rodata - - Args: - symbol_text: Text of the rodata dump. - - Returns: - symbols: Symbol list. - """ - # Examples: 8018ab0 00040048 00010000 10020000 4b8e0108 ...H........K... - # 100a7294 00000000 00000000 01000000 ............ - - base_offset = None - offset = None - rodata = [] - for line in rodata_text.splitlines(): - line = line.strip() - space = line.find(' ') - if space < 0: - continue - try: - address = int(line[0:space], 16) - except ValueError: - continue - - if not base_offset: - base_offset = address - offset = address - elif address != offset: - raise StackAnalyzerError('objdump of rodata not contiguous.') + if not base_offset: + base_offset = address + offset = address + elif address != offset: + raise StackAnalyzerError("objdump of rodata not contiguous.") - for i in range(0, 4): - num = line[(space + 1 + i*9):(space + 9 + i*9)] - if len(num.strip()) > 0: - val = int(num, 16) - else: - val = 0 - # TODO(drinkcat): Not all platforms are necessarily big-endian - rodata.append((val & 0x000000ff) << 24 | - (val & 0x0000ff00) << 8 | - (val & 0x00ff0000) >> 8 | - (val & 0xff000000) >> 24) + for i in range(0, 4): + num = line[(space + 1 + i * 9) : (space + 9 + i * 9)] + if len(num.strip()) > 0: + val = int(num, 16) + else: + val = 0 + # TODO(drinkcat): Not all platforms are necessarily big-endian + rodata.append( + (val & 0x000000FF) << 24 + | (val & 0x0000FF00) << 8 + | (val & 0x00FF0000) >> 8 + | (val & 0xFF000000) >> 24 + ) - offset = offset + 4*4 + offset = offset + 4 * 4 - return (base_offset, rodata) + return (base_offset, rodata) def LoadTasklist(section, export_taskinfo, symbols): - """Load the task information. + """Load the task information. - Args: - section: Section (RO | RW). - export_taskinfo: Handle of export_taskinfo.so. - symbols: Symbol list. + Args: + section: Section (RO | RW). + export_taskinfo: Handle of export_taskinfo.so. + symbols: Symbol list. - Returns: - tasklist: Task list. - """ + Returns: + tasklist: Task list. + """ - TaskInfoPointer = ctypes.POINTER(TaskInfo) - taskinfos = TaskInfoPointer() - if section == SECTION_RO: - get_taskinfos_func = export_taskinfo.get_ro_taskinfos - else: - get_taskinfos_func = export_taskinfo.get_rw_taskinfos + TaskInfoPointer = ctypes.POINTER(TaskInfo) + taskinfos = TaskInfoPointer() + if section == SECTION_RO: + get_taskinfos_func = export_taskinfo.get_ro_taskinfos + else: + get_taskinfos_func = export_taskinfo.get_rw_taskinfos - taskinfo_num = get_taskinfos_func(ctypes.pointer(taskinfos)) + taskinfo_num = get_taskinfos_func(ctypes.pointer(taskinfos)) - tasklist = [] - for index in range(taskinfo_num): - taskinfo = taskinfos[index] - tasklist.append(Task(taskinfo.name.decode('utf-8'), - taskinfo.routine.decode('utf-8'), - taskinfo.stack_size)) + tasklist = [] + for index in range(taskinfo_num): + taskinfo = taskinfos[index] + tasklist.append( + Task( + taskinfo.name.decode("utf-8"), + taskinfo.routine.decode("utf-8"), + taskinfo.stack_size, + ) + ) - # Resolve routine address for each task. It's more efficient to resolve all - # routine addresses of tasks together. - routine_map = dict((task.routine_name, None) for task in tasklist) + # Resolve routine address for each task. It's more efficient to resolve all + # routine addresses of tasks together. + routine_map = dict((task.routine_name, None) for task in tasklist) - for symbol in symbols: - # Resolve task routine address. - if symbol.name in routine_map: - # Assume the symbol of routine is unique. - assert routine_map[symbol.name] is None - routine_map[symbol.name] = symbol.address + for symbol in symbols: + # Resolve task routine address. + if symbol.name in routine_map: + # Assume the symbol of routine is unique. + assert routine_map[symbol.name] is None + routine_map[symbol.name] = symbol.address - for task in tasklist: - address = routine_map[task.routine_name] - # Assume we have resolved all routine addresses. - assert address is not None - task.routine_address = address + for task in tasklist: + address = routine_map[task.routine_name] + # Assume we have resolved all routine addresses. + assert address is not None + task.routine_address = address - return tasklist + return tasklist def main(): - """Main function.""" - try: - options = ParseArgs() - - # Load annotation config. - if options.annotation is None: - annotation = {} - elif not os.path.exists(options.annotation): - print('Warning: Annotation file {} does not exist.' - .format(options.annotation)) - annotation = {} - else: - try: - with open(options.annotation, 'r') as annotation_file: - annotation = yaml.safe_load(annotation_file) - - except yaml.YAMLError: - raise StackAnalyzerError('Failed to parse annotation file {}.' - .format(options.annotation)) - except IOError: - raise StackAnalyzerError('Failed to open annotation file {}.' - .format(options.annotation)) - - # TODO(cheyuw): Do complete annotation format verification. - if not isinstance(annotation, dict): - raise StackAnalyzerError('Invalid annotation file {}.' - .format(options.annotation)) - - # Generate and parse the symbols. + """Main function.""" try: - symbol_text = subprocess.check_output([options.objdump, - '-t', - options.elf_path], - encoding='utf-8') - rodata_text = subprocess.check_output([options.objdump, - '-s', - '-j', '.rodata', - options.elf_path], - encoding='utf-8') - except subprocess.CalledProcessError: - raise StackAnalyzerError('objdump failed to dump symbol table or rodata.') - except OSError: - raise StackAnalyzerError('Failed to run objdump.') - - symbols = ParseSymbolText(symbol_text) - rodata = ParseRoDataText(rodata_text) - - # Load the tasklist. - try: - export_taskinfo = ctypes.CDLL(options.export_taskinfo) - except OSError: - raise StackAnalyzerError('Failed to load export_taskinfo.') - - tasklist = LoadTasklist(options.section, export_taskinfo, symbols) - - analyzer = StackAnalyzer(options, symbols, rodata, tasklist, annotation) - analyzer.Analyze() - except StackAnalyzerError as e: - print('Error: {}'.format(e)) - - -if __name__ == '__main__': - main() + options = ParseArgs() + + # Load annotation config. + if options.annotation is None: + annotation = {} + elif not os.path.exists(options.annotation): + print( + "Warning: Annotation file {} does not exist.".format(options.annotation) + ) + annotation = {} + else: + try: + with open(options.annotation, "r") as annotation_file: + annotation = yaml.safe_load(annotation_file) + + except yaml.YAMLError: + raise StackAnalyzerError( + "Failed to parse annotation file {}.".format(options.annotation) + ) + except IOError: + raise StackAnalyzerError( + "Failed to open annotation file {}.".format(options.annotation) + ) + + # TODO(cheyuw): Do complete annotation format verification. + if not isinstance(annotation, dict): + raise StackAnalyzerError( + "Invalid annotation file {}.".format(options.annotation) + ) + + # Generate and parse the symbols. + try: + symbol_text = subprocess.check_output( + [options.objdump, "-t", options.elf_path], encoding="utf-8" + ) + rodata_text = subprocess.check_output( + [options.objdump, "-s", "-j", ".rodata", options.elf_path], + encoding="utf-8", + ) + except subprocess.CalledProcessError: + raise StackAnalyzerError("objdump failed to dump symbol table or rodata.") + except OSError: + raise StackAnalyzerError("Failed to run objdump.") + + symbols = ParseSymbolText(symbol_text) + rodata = ParseRoDataText(rodata_text) + + # Load the tasklist. + try: + export_taskinfo = ctypes.CDLL(options.export_taskinfo) + except OSError: + raise StackAnalyzerError("Failed to load export_taskinfo.") + + tasklist = LoadTasklist(options.section, export_taskinfo, symbols) + + analyzer = StackAnalyzer(options, symbols, rodata, tasklist, annotation) + analyzer.Analyze() + except StackAnalyzerError as e: + print("Error: {}".format(e)) + + +if __name__ == "__main__": + main() diff --git a/extra/stack_analyzer/stack_analyzer_unittest.py b/extra/stack_analyzer/stack_analyzer_unittest.py index c36fa9da45..ad2837a8a4 100755 --- a/extra/stack_analyzer/stack_analyzer_unittest.py +++ b/extra/stack_analyzer/stack_analyzer_unittest.py @@ -11,820 +11,924 @@ from __future__ import print_function -import mock import os import subprocess import unittest +import mock import stack_analyzer as sa class ObjectTest(unittest.TestCase): - """Tests for classes of basic objects.""" - - def testTask(self): - task_a = sa.Task('a', 'a_task', 1234) - task_b = sa.Task('b', 'b_task', 5678, 0x1000) - self.assertEqual(task_a, task_a) - self.assertNotEqual(task_a, task_b) - self.assertNotEqual(task_a, None) - - def testSymbol(self): - symbol_a = sa.Symbol(0x1234, 'F', 32, 'a') - symbol_b = sa.Symbol(0x234, 'O', 42, 'b') - self.assertEqual(symbol_a, symbol_a) - self.assertNotEqual(symbol_a, symbol_b) - self.assertNotEqual(symbol_a, None) - - def testCallsite(self): - callsite_a = sa.Callsite(0x1002, 0x3000, False) - callsite_b = sa.Callsite(0x1002, 0x3000, True) - self.assertEqual(callsite_a, callsite_a) - self.assertNotEqual(callsite_a, callsite_b) - self.assertNotEqual(callsite_a, None) - - def testFunction(self): - func_a = sa.Function(0x100, 'a', 0, []) - func_b = sa.Function(0x200, 'b', 0, []) - self.assertEqual(func_a, func_a) - self.assertNotEqual(func_a, func_b) - self.assertNotEqual(func_a, None) + """Tests for classes of basic objects.""" + + def testTask(self): + task_a = sa.Task("a", "a_task", 1234) + task_b = sa.Task("b", "b_task", 5678, 0x1000) + self.assertEqual(task_a, task_a) + self.assertNotEqual(task_a, task_b) + self.assertNotEqual(task_a, None) + + def testSymbol(self): + symbol_a = sa.Symbol(0x1234, "F", 32, "a") + symbol_b = sa.Symbol(0x234, "O", 42, "b") + self.assertEqual(symbol_a, symbol_a) + self.assertNotEqual(symbol_a, symbol_b) + self.assertNotEqual(symbol_a, None) + + def testCallsite(self): + callsite_a = sa.Callsite(0x1002, 0x3000, False) + callsite_b = sa.Callsite(0x1002, 0x3000, True) + self.assertEqual(callsite_a, callsite_a) + self.assertNotEqual(callsite_a, callsite_b) + self.assertNotEqual(callsite_a, None) + + def testFunction(self): + func_a = sa.Function(0x100, "a", 0, []) + func_b = sa.Function(0x200, "b", 0, []) + self.assertEqual(func_a, func_a) + self.assertNotEqual(func_a, func_b) + self.assertNotEqual(func_a, None) class ArmAnalyzerTest(unittest.TestCase): - """Tests for class ArmAnalyzer.""" - - def AppendConditionCode(self, opcodes): - rets = [] - for opcode in opcodes: - rets.extend(opcode + cc for cc in sa.ArmAnalyzer.CONDITION_CODES) - - return rets - - def testInstructionMatching(self): - jump_list = self.AppendConditionCode(['b', 'bx']) - jump_list += (list(opcode + '.n' for opcode in jump_list) + - list(opcode + '.w' for opcode in jump_list)) - for opcode in jump_list: - self.assertIsNotNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match(opcode)) - - self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match('bl')) - self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match('blx')) - - cbz_list = ['cbz', 'cbnz', 'cbz.n', 'cbnz.n', 'cbz.w', 'cbnz.w'] - for opcode in cbz_list: - self.assertIsNotNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match(opcode)) - - self.assertIsNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match('cbn')) - - call_list = self.AppendConditionCode(['bl', 'blx']) - call_list += list(opcode + '.n' for opcode in call_list) - for opcode in call_list: - self.assertIsNotNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match(opcode)) - - self.assertIsNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match('ble')) - - result = sa.ArmAnalyzer.CALL_OPERAND_RE.match('53f90 <get_time+0x18>') - self.assertIsNotNone(result) - self.assertEqual(result.group(1), '53f90') - self.assertEqual(result.group(2), 'get_time+0x18') - - result = sa.ArmAnalyzer.CBZ_CBNZ_OPERAND_RE.match('r6, 53f90 <get+0x0>') - self.assertIsNotNone(result) - self.assertEqual(result.group(1), '53f90') - self.assertEqual(result.group(2), 'get+0x0') - - self.assertIsNotNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match('push')) - self.assertIsNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match('pushal')) - self.assertIsNotNone(sa.ArmAnalyzer.STM_OPCODE_RE.match('stmdb')) - self.assertIsNone(sa.ArmAnalyzer.STM_OPCODE_RE.match('lstm')) - self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('sub')) - self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('subs')) - self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('subw')) - self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('sub.w')) - self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match('subs.w')) - - result = sa.ArmAnalyzer.SUB_OPERAND_RE.match('sp, sp, #1668 ; 0x684') - self.assertIsNotNone(result) - self.assertEqual(result.group(1), '1668') - result = sa.ArmAnalyzer.SUB_OPERAND_RE.match('sp, #1668') - self.assertIsNotNone(result) - self.assertEqual(result.group(1), '1668') - self.assertIsNone(sa.ArmAnalyzer.SUB_OPERAND_RE.match('sl, #1668')) - - def testAnalyzeFunction(self): - analyzer = sa.ArmAnalyzer() - symbol = sa.Symbol(0x10, 'F', 0x100, 'foo') - instructions = [ - (0x10, 'push', '{r4, r5, r6, r7, lr}'), - (0x12, 'subw', 'sp, sp, #16 ; 0x10'), - (0x16, 'movs', 'lr, r1'), - (0x18, 'beq.n', '26 <foo+0x26>'), - (0x1a, 'bl', '30 <foo+0x30>'), - (0x1e, 'bl', 'deadbeef <bar>'), - (0x22, 'blx', '0 <woo>'), - (0x26, 'push', '{r1}'), - (0x28, 'stmdb', 'sp!, {r4, r5, r6, r7, r8, r9, lr}'), - (0x2c, 'stmdb', 'sp!, {r4}'), - (0x30, 'stmdb', 'sp, {r4}'), - (0x34, 'bx.n', '10 <foo>'), - (0x36, 'bx.n', 'r3'), - (0x38, 'ldr', 'pc, [r10]'), - ] - (size, callsites) = analyzer.AnalyzeFunction(symbol, instructions) - self.assertEqual(size, 72) - expect_callsites = [sa.Callsite(0x1e, 0xdeadbeef, False), - sa.Callsite(0x22, 0x0, False), - sa.Callsite(0x34, 0x10, True), - sa.Callsite(0x36, None, True), - sa.Callsite(0x38, None, True)] - self.assertEqual(callsites, expect_callsites) + """Tests for class ArmAnalyzer.""" + + def AppendConditionCode(self, opcodes): + rets = [] + for opcode in opcodes: + rets.extend(opcode + cc for cc in sa.ArmAnalyzer.CONDITION_CODES) + + return rets + + def testInstructionMatching(self): + jump_list = self.AppendConditionCode(["b", "bx"]) + jump_list += list(opcode + ".n" for opcode in jump_list) + list( + opcode + ".w" for opcode in jump_list + ) + for opcode in jump_list: + self.assertIsNotNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match(opcode)) + + self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match("bl")) + self.assertIsNone(sa.ArmAnalyzer.JUMP_OPCODE_RE.match("blx")) + + cbz_list = ["cbz", "cbnz", "cbz.n", "cbnz.n", "cbz.w", "cbnz.w"] + for opcode in cbz_list: + self.assertIsNotNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match(opcode)) + + self.assertIsNone(sa.ArmAnalyzer.CBZ_CBNZ_OPCODE_RE.match("cbn")) + + call_list = self.AppendConditionCode(["bl", "blx"]) + call_list += list(opcode + ".n" for opcode in call_list) + for opcode in call_list: + self.assertIsNotNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match(opcode)) + + self.assertIsNone(sa.ArmAnalyzer.CALL_OPCODE_RE.match("ble")) + + result = sa.ArmAnalyzer.CALL_OPERAND_RE.match("53f90 <get_time+0x18>") + self.assertIsNotNone(result) + self.assertEqual(result.group(1), "53f90") + self.assertEqual(result.group(2), "get_time+0x18") + + result = sa.ArmAnalyzer.CBZ_CBNZ_OPERAND_RE.match("r6, 53f90 <get+0x0>") + self.assertIsNotNone(result) + self.assertEqual(result.group(1), "53f90") + self.assertEqual(result.group(2), "get+0x0") + + self.assertIsNotNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match("push")) + self.assertIsNone(sa.ArmAnalyzer.PUSH_OPCODE_RE.match("pushal")) + self.assertIsNotNone(sa.ArmAnalyzer.STM_OPCODE_RE.match("stmdb")) + self.assertIsNone(sa.ArmAnalyzer.STM_OPCODE_RE.match("lstm")) + self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("sub")) + self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("subs")) + self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("subw")) + self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("sub.w")) + self.assertIsNotNone(sa.ArmAnalyzer.SUB_OPCODE_RE.match("subs.w")) + + result = sa.ArmAnalyzer.SUB_OPERAND_RE.match("sp, sp, #1668 ; 0x684") + self.assertIsNotNone(result) + self.assertEqual(result.group(1), "1668") + result = sa.ArmAnalyzer.SUB_OPERAND_RE.match("sp, #1668") + self.assertIsNotNone(result) + self.assertEqual(result.group(1), "1668") + self.assertIsNone(sa.ArmAnalyzer.SUB_OPERAND_RE.match("sl, #1668")) + + def testAnalyzeFunction(self): + analyzer = sa.ArmAnalyzer() + symbol = sa.Symbol(0x10, "F", 0x100, "foo") + instructions = [ + (0x10, "push", "{r4, r5, r6, r7, lr}"), + (0x12, "subw", "sp, sp, #16 ; 0x10"), + (0x16, "movs", "lr, r1"), + (0x18, "beq.n", "26 <foo+0x26>"), + (0x1A, "bl", "30 <foo+0x30>"), + (0x1E, "bl", "deadbeef <bar>"), + (0x22, "blx", "0 <woo>"), + (0x26, "push", "{r1}"), + (0x28, "stmdb", "sp!, {r4, r5, r6, r7, r8, r9, lr}"), + (0x2C, "stmdb", "sp!, {r4}"), + (0x30, "stmdb", "sp, {r4}"), + (0x34, "bx.n", "10 <foo>"), + (0x36, "bx.n", "r3"), + (0x38, "ldr", "pc, [r10]"), + ] + (size, callsites) = analyzer.AnalyzeFunction(symbol, instructions) + self.assertEqual(size, 72) + expect_callsites = [ + sa.Callsite(0x1E, 0xDEADBEEF, False), + sa.Callsite(0x22, 0x0, False), + sa.Callsite(0x34, 0x10, True), + sa.Callsite(0x36, None, True), + sa.Callsite(0x38, None, True), + ] + self.assertEqual(callsites, expect_callsites) class StackAnalyzerTest(unittest.TestCase): - """Tests for class StackAnalyzer.""" - - def setUp(self): - symbols = [sa.Symbol(0x1000, 'F', 0x15C, 'hook_task'), - sa.Symbol(0x2000, 'F', 0x51C, 'console_task'), - sa.Symbol(0x3200, 'O', 0x124, '__just_data'), - sa.Symbol(0x4000, 'F', 0x11C, 'touchpad_calc'), - sa.Symbol(0x5000, 'F', 0x12C, 'touchpad_calc.constprop.42'), - sa.Symbol(0x12000, 'F', 0x13C, 'trackpad_range'), - sa.Symbol(0x13000, 'F', 0x200, 'inlined_mul'), - sa.Symbol(0x13100, 'F', 0x200, 'inlined_mul'), - sa.Symbol(0x13100, 'F', 0x200, 'inlined_mul_alias'), - sa.Symbol(0x20000, 'O', 0x0, '__array'), - sa.Symbol(0x20010, 'O', 0x0, '__array_end'), - ] - tasklist = [sa.Task('HOOKS', 'hook_task', 2048, 0x1000), - sa.Task('CONSOLE', 'console_task', 460, 0x2000)] - # Array at 0x20000 that contains pointers to hook_task and console_task, - # with stride=8, offset=4 - rodata = (0x20000, [ 0xDEAD1000, 0x00001000, 0xDEAD2000, 0x00002000 ]) - options = mock.MagicMock(elf_path='./ec.RW.elf', - export_taskinfo='fake', - section='RW', - objdump='objdump', - addr2line='addr2line', - annotation=None) - self.analyzer = sa.StackAnalyzer(options, symbols, rodata, tasklist, {}) - - def testParseSymbolText(self): - symbol_text = ( - '0 g F .text e8 Foo\n' - '0000dead w F .text 000000e8 .hidden Bar\n' - 'deadbeef l O .bss 00000004 .hidden Woooo\n' - 'deadbee g O .rodata 00000008 __Hooo_ooo\n' - 'deadbee g .rodata 00000000 __foo_doo_coo_end\n' - ) - symbols = sa.ParseSymbolText(symbol_text) - expect_symbols = [sa.Symbol(0x0, 'F', 0xe8, 'Foo'), - sa.Symbol(0xdead, 'F', 0xe8, 'Bar'), - sa.Symbol(0xdeadbeef, 'O', 0x4, 'Woooo'), - sa.Symbol(0xdeadbee, 'O', 0x8, '__Hooo_ooo'), - sa.Symbol(0xdeadbee, 'O', 0x0, '__foo_doo_coo_end')] - self.assertEqual(symbols, expect_symbols) - - def testParseRoData(self): - rodata_text = ( - '\n' - 'Contents of section .rodata:\n' - ' 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n' - ) - rodata = sa.ParseRoDataText(rodata_text) - expect_rodata = (0x20000, - [ 0x0010adde, 0x00001000, 0x0020adde, 0x00002000 ]) - self.assertEqual(rodata, expect_rodata) - - def testLoadTasklist(self): - def tasklist_to_taskinfos(pointer, tasklist): - taskinfos = [] - for task in tasklist: - taskinfos.append(sa.TaskInfo(name=task.name.encode('utf-8'), - routine=task.routine_name.encode('utf-8'), - stack_size=task.stack_max_size)) - - TaskInfoArray = sa.TaskInfo * len(taskinfos) - pointer.contents.contents = TaskInfoArray(*taskinfos) - return len(taskinfos) - - def ro_taskinfos(pointer): - return tasklist_to_taskinfos(pointer, expect_ro_tasklist) - - def rw_taskinfos(pointer): - return tasklist_to_taskinfos(pointer, expect_rw_tasklist) - - expect_ro_tasklist = [ - sa.Task('HOOKS', 'hook_task', 2048, 0x1000), - ] - - expect_rw_tasklist = [ - sa.Task('HOOKS', 'hook_task', 2048, 0x1000), - sa.Task('WOOKS', 'hook_task', 4096, 0x1000), - sa.Task('CONSOLE', 'console_task', 460, 0x2000), - ] - - export_taskinfo = mock.MagicMock( - get_ro_taskinfos=mock.MagicMock(side_effect=ro_taskinfos), - get_rw_taskinfos=mock.MagicMock(side_effect=rw_taskinfos)) - - tasklist = sa.LoadTasklist('RO', export_taskinfo, self.analyzer.symbols) - self.assertEqual(tasklist, expect_ro_tasklist) - tasklist = sa.LoadTasklist('RW', export_taskinfo, self.analyzer.symbols) - self.assertEqual(tasklist, expect_rw_tasklist) - - def testResolveAnnotation(self): - self.analyzer.annotation = {} - (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() - self.assertEqual(add_rules, {}) - self.assertEqual(remove_rules, []) - self.assertEqual(invalid_sigtxts, set()) - - self.analyzer.annotation = {'add': None, 'remove': None} - (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() - self.assertEqual(add_rules, {}) - self.assertEqual(remove_rules, []) - self.assertEqual(invalid_sigtxts, set()) - - self.analyzer.annotation = { - 'add': None, - 'remove': [ - [['a', 'b'], ['0', '[', '2'], 'x'], - [['a', 'b[x:3]'], ['0', '1', '2'], 'x'], - ], - } - (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() - self.assertEqual(add_rules, {}) - self.assertEqual(list.sort(remove_rules), list.sort([ - [('a', None, None), ('1', None, None), ('x', None, None)], - [('a', None, None), ('0', None, None), ('x', None, None)], - [('a', None, None), ('2', None, None), ('x', None, None)], - [('b', os.path.abspath('x'), 3), ('1', None, None), ('x', None, None)], - [('b', os.path.abspath('x'), 3), ('0', None, None), ('x', None, None)], - [('b', os.path.abspath('x'), 3), ('2', None, None), ('x', None, None)], - ])) - self.assertEqual(invalid_sigtxts, {'['}) - - self.analyzer.annotation = { - 'add': { - 'touchpad_calc': [ dict(name='__array', stride=8, offset=4) ], + """Tests for class StackAnalyzer.""" + + def setUp(self): + symbols = [ + sa.Symbol(0x1000, "F", 0x15C, "hook_task"), + sa.Symbol(0x2000, "F", 0x51C, "console_task"), + sa.Symbol(0x3200, "O", 0x124, "__just_data"), + sa.Symbol(0x4000, "F", 0x11C, "touchpad_calc"), + sa.Symbol(0x5000, "F", 0x12C, "touchpad_calc.constprop.42"), + sa.Symbol(0x12000, "F", 0x13C, "trackpad_range"), + sa.Symbol(0x13000, "F", 0x200, "inlined_mul"), + sa.Symbol(0x13100, "F", 0x200, "inlined_mul"), + sa.Symbol(0x13100, "F", 0x200, "inlined_mul_alias"), + sa.Symbol(0x20000, "O", 0x0, "__array"), + sa.Symbol(0x20010, "O", 0x0, "__array_end"), + ] + tasklist = [ + sa.Task("HOOKS", "hook_task", 2048, 0x1000), + sa.Task("CONSOLE", "console_task", 460, 0x2000), + ] + # Array at 0x20000 that contains pointers to hook_task and console_task, + # with stride=8, offset=4 + rodata = (0x20000, [0xDEAD1000, 0x00001000, 0xDEAD2000, 0x00002000]) + options = mock.MagicMock( + elf_path="./ec.RW.elf", + export_taskinfo="fake", + section="RW", + objdump="objdump", + addr2line="addr2line", + annotation=None, + ) + self.analyzer = sa.StackAnalyzer(options, symbols, rodata, tasklist, {}) + + def testParseSymbolText(self): + symbol_text = ( + "0 g F .text e8 Foo\n" + "0000dead w F .text 000000e8 .hidden Bar\n" + "deadbeef l O .bss 00000004 .hidden Woooo\n" + "deadbee g O .rodata 00000008 __Hooo_ooo\n" + "deadbee g .rodata 00000000 __foo_doo_coo_end\n" + ) + symbols = sa.ParseSymbolText(symbol_text) + expect_symbols = [ + sa.Symbol(0x0, "F", 0xE8, "Foo"), + sa.Symbol(0xDEAD, "F", 0xE8, "Bar"), + sa.Symbol(0xDEADBEEF, "O", 0x4, "Woooo"), + sa.Symbol(0xDEADBEE, "O", 0x8, "__Hooo_ooo"), + sa.Symbol(0xDEADBEE, "O", 0x0, "__foo_doo_coo_end"), + ] + self.assertEqual(symbols, expect_symbols) + + def testParseRoData(self): + rodata_text = ( + "\n" + "Contents of section .rodata:\n" + " 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n" + ) + rodata = sa.ParseRoDataText(rodata_text) + expect_rodata = (0x20000, [0x0010ADDE, 0x00001000, 0x0020ADDE, 0x00002000]) + self.assertEqual(rodata, expect_rodata) + + def testLoadTasklist(self): + def tasklist_to_taskinfos(pointer, tasklist): + taskinfos = [] + for task in tasklist: + taskinfos.append( + sa.TaskInfo( + name=task.name.encode("utf-8"), + routine=task.routine_name.encode("utf-8"), + stack_size=task.stack_max_size, + ) + ) + + TaskInfoArray = sa.TaskInfo * len(taskinfos) + pointer.contents.contents = TaskInfoArray(*taskinfos) + return len(taskinfos) + + def ro_taskinfos(pointer): + return tasklist_to_taskinfos(pointer, expect_ro_tasklist) + + def rw_taskinfos(pointer): + return tasklist_to_taskinfos(pointer, expect_rw_tasklist) + + expect_ro_tasklist = [ + sa.Task("HOOKS", "hook_task", 2048, 0x1000), + ] + + expect_rw_tasklist = [ + sa.Task("HOOKS", "hook_task", 2048, 0x1000), + sa.Task("WOOKS", "hook_task", 4096, 0x1000), + sa.Task("CONSOLE", "console_task", 460, 0x2000), + ] + + export_taskinfo = mock.MagicMock( + get_ro_taskinfos=mock.MagicMock(side_effect=ro_taskinfos), + get_rw_taskinfos=mock.MagicMock(side_effect=rw_taskinfos), + ) + + tasklist = sa.LoadTasklist("RO", export_taskinfo, self.analyzer.symbols) + self.assertEqual(tasklist, expect_ro_tasklist) + tasklist = sa.LoadTasklist("RW", export_taskinfo, self.analyzer.symbols) + self.assertEqual(tasklist, expect_rw_tasklist) + + def testResolveAnnotation(self): + self.analyzer.annotation = {} + (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() + self.assertEqual(add_rules, {}) + self.assertEqual(remove_rules, []) + self.assertEqual(invalid_sigtxts, set()) + + self.analyzer.annotation = {"add": None, "remove": None} + (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() + self.assertEqual(add_rules, {}) + self.assertEqual(remove_rules, []) + self.assertEqual(invalid_sigtxts, set()) + + self.analyzer.annotation = { + "add": None, + "remove": [ + [["a", "b"], ["0", "[", "2"], "x"], + [["a", "b[x:3]"], ["0", "1", "2"], "x"], + ], + } + (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() + self.assertEqual(add_rules, {}) + self.assertEqual( + list.sort(remove_rules), + list.sort( + [ + [("a", None, None), ("1", None, None), ("x", None, None)], + [("a", None, None), ("0", None, None), ("x", None, None)], + [("a", None, None), ("2", None, None), ("x", None, None)], + [ + ("b", os.path.abspath("x"), 3), + ("1", None, None), + ("x", None, None), + ], + [ + ("b", os.path.abspath("x"), 3), + ("0", None, None), + ("x", None, None), + ], + [ + ("b", os.path.abspath("x"), 3), + ("2", None, None), + ("x", None, None), + ], + ] + ), + ) + self.assertEqual(invalid_sigtxts, {"["}) + + self.analyzer.annotation = { + "add": { + "touchpad_calc": [dict(name="__array", stride=8, offset=4)], + } + } + (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() + self.assertEqual( + add_rules, + { + ("touchpad_calc", None, None): set( + [("console_task", None, None), ("hook_task", None, None)] + ) + }, + ) + + funcs = { + 0x1000: sa.Function(0x1000, "hook_task", 0, []), + 0x2000: sa.Function(0x2000, "console_task", 0, []), + 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []), + 0x5000: sa.Function(0x5000, "touchpad_calc.constprop.42", 0, []), + 0x13000: sa.Function(0x13000, "inlined_mul", 0, []), + 0x13100: sa.Function(0x13100, "inlined_mul", 0, []), + } + funcs[0x1000].callsites = [sa.Callsite(0x1002, None, False, None)] + # Set address_to_line_cache to fake the results of addr2line. + self.analyzer.address_to_line_cache = { + (0x1000, False): [("hook_task", os.path.abspath("a.c"), 10)], + (0x1002, False): [("toot_calc", os.path.abspath("t.c"), 1234)], + (0x2000, False): [("console_task", os.path.abspath("b.c"), 20)], + (0x4000, False): [("toudhpad_calc", os.path.abspath("a.c"), 20)], + (0x5000, False): [ + ("touchpad_calc.constprop.42", os.path.abspath("b.c"), 40) + ], + (0x12000, False): [("trackpad_range", os.path.abspath("t.c"), 10)], + (0x13000, False): [("inlined_mul", os.path.abspath("x.c"), 12)], + (0x13100, False): [("inlined_mul", os.path.abspath("x.c"), 12)], + } + self.analyzer.annotation = { + "add": { + "hook_task.lto.573": ["touchpad_calc.lto.2501[a.c]"], + "console_task": ["touchpad_calc[b.c]", "inlined_mul_alias"], + "hook_task[q.c]": ["hook_task"], + "inlined_mul[x.c]": ["inlined_mul"], + "toot_calc[t.c:1234]": ["hook_task"], + }, + "remove": [ + ["touchpad?calc["], + "touchpad_calc", + ["touchpad_calc[a.c]"], + ["task_unk[a.c]"], + ["touchpad_calc[x/a.c]"], + ["trackpad_range"], + ["inlined_mul"], + ["inlined_mul", "console_task", "touchpad_calc[a.c]"], + ["inlined_mul", "inlined_mul_alias", "console_task"], + ["inlined_mul", "inlined_mul_alias", "console_task"], + ], + } + (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() + self.assertEqual(invalid_sigtxts, {"touchpad?calc["}) + + signature_set = set() + for src_sig, dst_sigs in add_rules.items(): + signature_set.add(src_sig) + signature_set.update(dst_sigs) + + for remove_sigs in remove_rules: + signature_set.update(remove_sigs) + + (signature_map, failed_sigs) = self.analyzer.MapAnnotation(funcs, signature_set) + result = self.analyzer.ResolveAnnotation(funcs) + (add_set, remove_list, eliminated_addrs, failed_sigs) = result + + expect_signature_map = { + ("hook_task", None, None): {funcs[0x1000]}, + ("touchpad_calc", os.path.abspath("a.c"), None): {funcs[0x4000]}, + ("touchpad_calc", os.path.abspath("b.c"), None): {funcs[0x5000]}, + ("console_task", None, None): {funcs[0x2000]}, + ("inlined_mul_alias", None, None): {funcs[0x13100]}, + ("inlined_mul", os.path.abspath("x.c"), None): { + funcs[0x13000], + funcs[0x13100], + }, + ("inlined_mul", None, None): {funcs[0x13000], funcs[0x13100]}, + } + self.assertEqual(len(signature_map), len(expect_signature_map)) + for sig, funclist in signature_map.items(): + self.assertEqual(set(funclist), expect_signature_map[sig]) + + self.assertEqual( + add_set, + { + (funcs[0x1000], funcs[0x4000]), + (funcs[0x1000], funcs[0x1000]), + (funcs[0x2000], funcs[0x5000]), + (funcs[0x2000], funcs[0x13100]), + (funcs[0x13000], funcs[0x13000]), + (funcs[0x13000], funcs[0x13100]), + (funcs[0x13100], funcs[0x13000]), + (funcs[0x13100], funcs[0x13100]), + }, + ) + expect_remove_list = [ + [funcs[0x4000]], + [funcs[0x13000]], + [funcs[0x13100]], + [funcs[0x13000], funcs[0x2000], funcs[0x4000]], + [funcs[0x13100], funcs[0x2000], funcs[0x4000]], + [funcs[0x13000], funcs[0x13100], funcs[0x2000]], + [funcs[0x13100], funcs[0x13100], funcs[0x2000]], + ] + self.assertEqual(len(remove_list), len(expect_remove_list)) + for remove_path in remove_list: + self.assertTrue(remove_path in expect_remove_list) + + self.assertEqual(eliminated_addrs, {0x1002}) + self.assertEqual( + failed_sigs, + { + ("touchpad?calc[", sa.StackAnalyzer.ANNOTATION_ERROR_INVALID), + ("touchpad_calc", sa.StackAnalyzer.ANNOTATION_ERROR_AMBIGUOUS), + ("hook_task[q.c]", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), + ("task_unk[a.c]", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), + ("touchpad_calc[x/a.c]", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), + ("trackpad_range", sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), + }, + ) + + def testPreprocessAnnotation(self): + funcs = { + 0x1000: sa.Function(0x1000, "hook_task", 0, []), + 0x2000: sa.Function(0x2000, "console_task", 0, []), + 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []), + } + funcs[0x1000].callsites = [sa.Callsite(0x1002, 0x1000, False, funcs[0x1000])] + funcs[0x2000].callsites = [ + sa.Callsite(0x2002, 0x1000, False, funcs[0x1000]), + sa.Callsite(0x2006, None, True, None), + ] + add_set = { + (funcs[0x2000], funcs[0x2000]), + (funcs[0x2000], funcs[0x4000]), + (funcs[0x4000], funcs[0x1000]), + (funcs[0x4000], funcs[0x2000]), } - } - (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() - self.assertEqual(add_rules, { - ('touchpad_calc', None, None): - set([('console_task', None, None), ('hook_task', None, None)])}) - - funcs = { - 0x1000: sa.Function(0x1000, 'hook_task', 0, []), - 0x2000: sa.Function(0x2000, 'console_task', 0, []), - 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []), - 0x5000: sa.Function(0x5000, 'touchpad_calc.constprop.42', 0, []), - 0x13000: sa.Function(0x13000, 'inlined_mul', 0, []), - 0x13100: sa.Function(0x13100, 'inlined_mul', 0, []), - } - funcs[0x1000].callsites = [ - sa.Callsite(0x1002, None, False, None)] - # Set address_to_line_cache to fake the results of addr2line. - self.analyzer.address_to_line_cache = { - (0x1000, False): [('hook_task', os.path.abspath('a.c'), 10)], - (0x1002, False): [('toot_calc', os.path.abspath('t.c'), 1234)], - (0x2000, False): [('console_task', os.path.abspath('b.c'), 20)], - (0x4000, False): [('toudhpad_calc', os.path.abspath('a.c'), 20)], - (0x5000, False): [ - ('touchpad_calc.constprop.42', os.path.abspath('b.c'), 40)], - (0x12000, False): [('trackpad_range', os.path.abspath('t.c'), 10)], - (0x13000, False): [('inlined_mul', os.path.abspath('x.c'), 12)], - (0x13100, False): [('inlined_mul', os.path.abspath('x.c'), 12)], - } - self.analyzer.annotation = { - 'add': { - 'hook_task.lto.573': ['touchpad_calc.lto.2501[a.c]'], - 'console_task': ['touchpad_calc[b.c]', 'inlined_mul_alias'], - 'hook_task[q.c]': ['hook_task'], - 'inlined_mul[x.c]': ['inlined_mul'], - 'toot_calc[t.c:1234]': ['hook_task'], - }, - 'remove': [ - ['touchpad?calc['], - 'touchpad_calc', - ['touchpad_calc[a.c]'], - ['task_unk[a.c]'], - ['touchpad_calc[x/a.c]'], - ['trackpad_range'], - ['inlined_mul'], - ['inlined_mul', 'console_task', 'touchpad_calc[a.c]'], - ['inlined_mul', 'inlined_mul_alias', 'console_task'], - ['inlined_mul', 'inlined_mul_alias', 'console_task'], - ], - } - (add_rules, remove_rules, invalid_sigtxts) = self.analyzer.LoadAnnotation() - self.assertEqual(invalid_sigtxts, {'touchpad?calc['}) - - signature_set = set() - for src_sig, dst_sigs in add_rules.items(): - signature_set.add(src_sig) - signature_set.update(dst_sigs) - - for remove_sigs in remove_rules: - signature_set.update(remove_sigs) - - (signature_map, failed_sigs) = self.analyzer.MapAnnotation(funcs, - signature_set) - result = self.analyzer.ResolveAnnotation(funcs) - (add_set, remove_list, eliminated_addrs, failed_sigs) = result - - expect_signature_map = { - ('hook_task', None, None): {funcs[0x1000]}, - ('touchpad_calc', os.path.abspath('a.c'), None): {funcs[0x4000]}, - ('touchpad_calc', os.path.abspath('b.c'), None): {funcs[0x5000]}, - ('console_task', None, None): {funcs[0x2000]}, - ('inlined_mul_alias', None, None): {funcs[0x13100]}, - ('inlined_mul', os.path.abspath('x.c'), None): {funcs[0x13000], - funcs[0x13100]}, - ('inlined_mul', None, None): {funcs[0x13000], funcs[0x13100]}, - } - self.assertEqual(len(signature_map), len(expect_signature_map)) - for sig, funclist in signature_map.items(): - self.assertEqual(set(funclist), expect_signature_map[sig]) - - self.assertEqual(add_set, { - (funcs[0x1000], funcs[0x4000]), - (funcs[0x1000], funcs[0x1000]), - (funcs[0x2000], funcs[0x5000]), - (funcs[0x2000], funcs[0x13100]), - (funcs[0x13000], funcs[0x13000]), - (funcs[0x13000], funcs[0x13100]), - (funcs[0x13100], funcs[0x13000]), - (funcs[0x13100], funcs[0x13100]), - }) - expect_remove_list = [ - [funcs[0x4000]], - [funcs[0x13000]], - [funcs[0x13100]], - [funcs[0x13000], funcs[0x2000], funcs[0x4000]], - [funcs[0x13100], funcs[0x2000], funcs[0x4000]], - [funcs[0x13000], funcs[0x13100], funcs[0x2000]], - [funcs[0x13100], funcs[0x13100], funcs[0x2000]], - ] - self.assertEqual(len(remove_list), len(expect_remove_list)) - for remove_path in remove_list: - self.assertTrue(remove_path in expect_remove_list) - - self.assertEqual(eliminated_addrs, {0x1002}) - self.assertEqual(failed_sigs, { - ('touchpad?calc[', sa.StackAnalyzer.ANNOTATION_ERROR_INVALID), - ('touchpad_calc', sa.StackAnalyzer.ANNOTATION_ERROR_AMBIGUOUS), - ('hook_task[q.c]', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), - ('task_unk[a.c]', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), - ('touchpad_calc[x/a.c]', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), - ('trackpad_range', sa.StackAnalyzer.ANNOTATION_ERROR_NOTFOUND), - }) - - def testPreprocessAnnotation(self): - funcs = { - 0x1000: sa.Function(0x1000, 'hook_task', 0, []), - 0x2000: sa.Function(0x2000, 'console_task', 0, []), - 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []), - } - funcs[0x1000].callsites = [ - sa.Callsite(0x1002, 0x1000, False, funcs[0x1000])] - funcs[0x2000].callsites = [ - sa.Callsite(0x2002, 0x1000, False, funcs[0x1000]), - sa.Callsite(0x2006, None, True, None), - ] - add_set = { - (funcs[0x2000], funcs[0x2000]), - (funcs[0x2000], funcs[0x4000]), - (funcs[0x4000], funcs[0x1000]), - (funcs[0x4000], funcs[0x2000]), - } - remove_list = [ - [funcs[0x1000]], - [funcs[0x2000], funcs[0x2000]], - [funcs[0x4000], funcs[0x1000]], - [funcs[0x2000], funcs[0x4000], funcs[0x2000]], - [funcs[0x4000], funcs[0x1000], funcs[0x4000]], - ] - eliminated_addrs = {0x2006} - - remaining_remove_list = self.analyzer.PreprocessAnnotation(funcs, - add_set, - remove_list, - eliminated_addrs) - - expect_funcs = { - 0x1000: sa.Function(0x1000, 'hook_task', 0, []), - 0x2000: sa.Function(0x2000, 'console_task', 0, []), - 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []), - } - expect_funcs[0x2000].callsites = [ - sa.Callsite(None, 0x4000, False, expect_funcs[0x4000])] - expect_funcs[0x4000].callsites = [ - sa.Callsite(None, 0x2000, False, expect_funcs[0x2000])] - self.assertEqual(funcs, expect_funcs) - self.assertEqual(remaining_remove_list, [ - [funcs[0x2000], funcs[0x4000], funcs[0x2000]], - ]) - - def testAndesAnalyzeDisassembly(self): - disasm_text = ( - '\n' - 'build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le' - '\n' - 'Disassembly of section .text:\n' - '\n' - '00000900 <wook_task>:\n' - ' ...\n' - '00001000 <hook_task>:\n' - ' 1000: fc 42\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n' - ' 1004: 47 70\t\tmovi55 $r0, #1\n' - ' 1006: b1 13\tbnezs8 100929de <flash_command_write>\n' - ' 1008: 00 01 5c fc\tbne $r6, $r0, 2af6a\n' - '00002000 <console_task>:\n' - ' 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n' - ' 2002: f0 0e fc c5\tjal 1000 <hook_task>\n' - ' 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n' - ' 200a: de ad be ef\tswi.gp $r0, [ + #-11036]\n' - '00004000 <touchpad_calc>:\n' - ' 4000: 47 70\t\tmovi55 $r0, #1\n' - '00010000 <look_task>:' - ) - function_map = self.analyzer.AnalyzeDisassembly(disasm_text) - func_hook_task = sa.Function(0x1000, 'hook_task', 48, [ - sa.Callsite(0x1006, 0x100929de, True, None)]) - expect_funcmap = { - 0x1000: func_hook_task, - 0x2000: sa.Function(0x2000, 'console_task', 16, - [sa.Callsite(0x2002, 0x1000, False, func_hook_task), - sa.Callsite(0x2006, 0x53968, True, None)]), - 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []), - } - self.assertEqual(function_map, expect_funcmap) - - def testArmAnalyzeDisassembly(self): - disasm_text = ( - '\n' - 'build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm' - '\n' - 'Disassembly of section .text:\n' - '\n' - '00000900 <wook_task>:\n' - ' ...\n' - '00001000 <hook_task>:\n' - ' 1000: dead beef\tfake\n' - ' 1004: 4770\t\tbx lr\n' - ' 1006: b113\tcbz r3, 100929de <flash_command_write>\n' - ' 1008: 00015cfc\t.word 0x00015cfc\n' - '00002000 <console_task>:\n' - ' 2000: b508\t\tpush {r3, lr} ; malformed comments,; r0, r1 \n' - ' 2002: f00e fcc5\tbl 1000 <hook_task>\n' - ' 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n' - ' 200a: dead beef\tfake\n' - '00004000 <touchpad_calc>:\n' - ' 4000: 4770\t\tbx lr\n' - '00010000 <look_task>:' - ) - function_map = self.analyzer.AnalyzeDisassembly(disasm_text) - func_hook_task = sa.Function(0x1000, 'hook_task', 0, [ - sa.Callsite(0x1006, 0x100929de, True, None)]) - expect_funcmap = { - 0x1000: func_hook_task, - 0x2000: sa.Function(0x2000, 'console_task', 8, - [sa.Callsite(0x2002, 0x1000, False, func_hook_task), - sa.Callsite(0x2006, 0x53968, True, None)]), - 0x4000: sa.Function(0x4000, 'touchpad_calc', 0, []), - } - self.assertEqual(function_map, expect_funcmap) - - def testAnalyzeCallGraph(self): - funcs = { - 0x1000: sa.Function(0x1000, 'hook_task', 0, []), - 0x2000: sa.Function(0x2000, 'console_task', 8, []), - 0x3000: sa.Function(0x3000, 'task_a', 12, []), - 0x4000: sa.Function(0x4000, 'task_b', 96, []), - 0x5000: sa.Function(0x5000, 'task_c', 32, []), - 0x6000: sa.Function(0x6000, 'task_d', 100, []), - 0x7000: sa.Function(0x7000, 'task_e', 24, []), - 0x8000: sa.Function(0x8000, 'task_f', 20, []), - 0x9000: sa.Function(0x9000, 'task_g', 20, []), - 0x10000: sa.Function(0x10000, 'task_x', 16, []), - } - funcs[0x1000].callsites = [ - sa.Callsite(0x1002, 0x3000, False, funcs[0x3000]), - sa.Callsite(0x1006, 0x4000, False, funcs[0x4000])] - funcs[0x2000].callsites = [ - sa.Callsite(0x2002, 0x5000, False, funcs[0x5000]), - sa.Callsite(0x2006, 0x2000, False, funcs[0x2000]), - sa.Callsite(0x200a, 0x10000, False, funcs[0x10000])] - funcs[0x3000].callsites = [ - sa.Callsite(0x3002, 0x4000, False, funcs[0x4000]), - sa.Callsite(0x3006, 0x1000, False, funcs[0x1000])] - funcs[0x4000].callsites = [ - sa.Callsite(0x4002, 0x6000, True, funcs[0x6000]), - sa.Callsite(0x4006, 0x7000, False, funcs[0x7000]), - sa.Callsite(0x400a, 0x8000, False, funcs[0x8000])] - funcs[0x5000].callsites = [ - sa.Callsite(0x5002, 0x4000, False, funcs[0x4000])] - funcs[0x7000].callsites = [ - sa.Callsite(0x7002, 0x7000, False, funcs[0x7000])] - funcs[0x8000].callsites = [ - sa.Callsite(0x8002, 0x9000, False, funcs[0x9000])] - funcs[0x9000].callsites = [ - sa.Callsite(0x9002, 0x4000, False, funcs[0x4000])] - funcs[0x10000].callsites = [ - sa.Callsite(0x10002, 0x2000, False, funcs[0x2000])] - - cycles = self.analyzer.AnalyzeCallGraph(funcs, [ - [funcs[0x2000]] * 2, - [funcs[0x10000], funcs[0x2000]] * 3, - [funcs[0x1000], funcs[0x3000], funcs[0x1000]] - ]) - - expect_func_stack = { - 0x1000: (268, [funcs[0x1000], - funcs[0x3000], - funcs[0x4000], - funcs[0x8000], - funcs[0x9000], - funcs[0x4000], - funcs[0x7000]]), - 0x2000: (208, [funcs[0x2000], - funcs[0x10000], - funcs[0x2000], - funcs[0x10000], - funcs[0x2000], - funcs[0x5000], - funcs[0x4000], - funcs[0x7000]]), - 0x3000: (280, [funcs[0x3000], - funcs[0x1000], - funcs[0x3000], - funcs[0x4000], - funcs[0x8000], - funcs[0x9000], - funcs[0x4000], - funcs[0x7000]]), - 0x4000: (120, [funcs[0x4000], funcs[0x7000]]), - 0x5000: (152, [funcs[0x5000], funcs[0x4000], funcs[0x7000]]), - 0x6000: (100, [funcs[0x6000]]), - 0x7000: (24, [funcs[0x7000]]), - 0x8000: (160, [funcs[0x8000], - funcs[0x9000], - funcs[0x4000], - funcs[0x7000]]), - 0x9000: (140, [funcs[0x9000], funcs[0x4000], funcs[0x7000]]), - 0x10000: (200, [funcs[0x10000], - funcs[0x2000], - funcs[0x10000], - funcs[0x2000], - funcs[0x5000], - funcs[0x4000], - funcs[0x7000]]), - } - expect_cycles = [ - {funcs[0x4000], funcs[0x8000], funcs[0x9000]}, - {funcs[0x7000]}, - ] - for func in funcs.values(): - (stack_max_usage, stack_max_path) = expect_func_stack[func.address] - self.assertEqual(func.stack_max_usage, stack_max_usage) - self.assertEqual(func.stack_max_path, stack_max_path) - - self.assertEqual(len(cycles), len(expect_cycles)) - for cycle in cycles: - self.assertTrue(cycle in expect_cycles) - - @mock.patch('subprocess.check_output') - def testAddressToLine(self, checkoutput_mock): - checkoutput_mock.return_value = 'fake_func\n/test.c:1' - self.assertEqual(self.analyzer.AddressToLine(0x1234), - [('fake_func', '/test.c', 1)]) - checkoutput_mock.assert_called_once_with( - ['addr2line', '-f', '-e', './ec.RW.elf', '1234'], encoding='utf-8') - checkoutput_mock.reset_mock() - - checkoutput_mock.return_value = 'fake_func\n/a.c:1\nbake_func\n/b.c:2\n' - self.assertEqual(self.analyzer.AddressToLine(0x1234, True), - [('fake_func', '/a.c', 1), ('bake_func', '/b.c', 2)]) - checkoutput_mock.assert_called_once_with( - ['addr2line', '-f', '-e', './ec.RW.elf', '1234', '-i'], - encoding='utf-8') - checkoutput_mock.reset_mock() - - checkoutput_mock.return_value = 'fake_func\n/test.c:1 (discriminator 128)' - self.assertEqual(self.analyzer.AddressToLine(0x12345), - [('fake_func', '/test.c', 1)]) - checkoutput_mock.assert_called_once_with( - ['addr2line', '-f', '-e', './ec.RW.elf', '12345'], encoding='utf-8') - checkoutput_mock.reset_mock() - - checkoutput_mock.return_value = '??\n:?\nbake_func\n/b.c:2\n' - self.assertEqual(self.analyzer.AddressToLine(0x123456), - [None, ('bake_func', '/b.c', 2)]) - checkoutput_mock.assert_called_once_with( - ['addr2line', '-f', '-e', './ec.RW.elf', '123456'], encoding='utf-8') - checkoutput_mock.reset_mock() - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'addr2line failed to resolve lines.'): - checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '') - self.analyzer.AddressToLine(0x5678) - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'Failed to run addr2line.'): - checkoutput_mock.side_effect = OSError() - self.analyzer.AddressToLine(0x9012) - - @mock.patch('subprocess.check_output') - @mock.patch('stack_analyzer.StackAnalyzer.AddressToLine') - def testAndesAnalyze(self, addrtoline_mock, checkoutput_mock): - disasm_text = ( - '\n' - 'build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le' - '\n' - 'Disassembly of section .text:\n' - '\n' - '00000900 <wook_task>:\n' - ' ...\n' - '00001000 <hook_task>:\n' - ' 1000: fc 00\t\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n' - ' 1002: 47 70\t\tmovi55 $r0, #1\n' - ' 1006: 00 01 5c fc\tbne $r6, $r0, 2af6a\n' - '00002000 <console_task>:\n' - ' 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n' - ' 2002: f0 0e fc c5\tjal 1000 <hook_task>\n' - ' 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n' - ' 200a: 12 34 56 78\tjral5 $r0\n' - ) - - addrtoline_mock.return_value = [('??', '??', 0)] - self.analyzer.annotation = { - 'exception_frame_size': 64, - 'remove': [['fake_func']], - } - - with mock.patch('builtins.print') as print_mock: - checkoutput_mock.return_value = disasm_text - self.analyzer.Analyze() - print_mock.assert_has_calls([ - mock.call( - 'Task: HOOKS, Max size: 96 (32 + 64), Allocated size: 2048'), - mock.call('Call Trace:'), - mock.call(' hook_task (32) [??:0] 1000'), - mock.call( - 'Task: CONSOLE, Max size: 112 (48 + 64), Allocated size: 460'), - mock.call('Call Trace:'), - mock.call(' console_task (16) [??:0] 2000'), - mock.call(' -> ??[??:0] 2002'), - mock.call(' hook_task (32) [??:0] 1000'), - mock.call('Unresolved indirect callsites:'), - mock.call(' In function console_task:'), - mock.call(' -> ??[??:0] 200a'), - mock.call('Unresolved annotation signatures:'), - mock.call(' fake_func: function is not found'), - ]) - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'Failed to run objdump.'): - checkoutput_mock.side_effect = OSError() - self.analyzer.Analyze() - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'objdump failed to disassemble.'): - checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '') - self.analyzer.Analyze() - - @mock.patch('subprocess.check_output') - @mock.patch('stack_analyzer.StackAnalyzer.AddressToLine') - def testArmAnalyze(self, addrtoline_mock, checkoutput_mock): - disasm_text = ( - '\n' - 'build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm' - '\n' - 'Disassembly of section .text:\n' - '\n' - '00000900 <wook_task>:\n' - ' ...\n' - '00001000 <hook_task>:\n' - ' 1000: b508\t\tpush {r3, lr}\n' - ' 1002: 4770\t\tbx lr\n' - ' 1006: 00015cfc\t.word 0x00015cfc\n' - '00002000 <console_task>:\n' - ' 2000: b508\t\tpush {r3, lr}\n' - ' 2002: f00e fcc5\tbl 1000 <hook_task>\n' - ' 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n' - ' 200a: 1234 5678\tb.w sl\n' - ) - - addrtoline_mock.return_value = [('??', '??', 0)] - self.analyzer.annotation = { - 'exception_frame_size': 64, - 'remove': [['fake_func']], - } - - with mock.patch('builtins.print') as print_mock: - checkoutput_mock.return_value = disasm_text - self.analyzer.Analyze() - print_mock.assert_has_calls([ - mock.call( - 'Task: HOOKS, Max size: 72 (8 + 64), Allocated size: 2048'), - mock.call('Call Trace:'), - mock.call(' hook_task (8) [??:0] 1000'), - mock.call( - 'Task: CONSOLE, Max size: 80 (16 + 64), Allocated size: 460'), - mock.call('Call Trace:'), - mock.call(' console_task (8) [??:0] 2000'), - mock.call(' -> ??[??:0] 2002'), - mock.call(' hook_task (8) [??:0] 1000'), - mock.call('Unresolved indirect callsites:'), - mock.call(' In function console_task:'), - mock.call(' -> ??[??:0] 200a'), - mock.call('Unresolved annotation signatures:'), - mock.call(' fake_func: function is not found'), - ]) - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'Failed to run objdump.'): - checkoutput_mock.side_effect = OSError() - self.analyzer.Analyze() - - with self.assertRaisesRegexp(sa.StackAnalyzerError, - 'objdump failed to disassemble.'): - checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '') - self.analyzer.Analyze() - - @mock.patch('subprocess.check_output') - @mock.patch('stack_analyzer.ParseArgs') - def testMain(self, parseargs_mock, checkoutput_mock): - symbol_text = ('1000 g F .text 0000015c .hidden hook_task\n' - '2000 g F .text 0000051c .hidden console_task\n') - rodata_text = ( - '\n' - 'Contents of section .rodata:\n' - ' 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n' - ) - - args = mock.MagicMock(elf_path='./ec.RW.elf', - export_taskinfo='fake', - section='RW', - objdump='objdump', - addr2line='addr2line', - annotation='fake') - parseargs_mock.return_value = args - - with mock.patch('os.path.exists') as path_mock: - path_mock.return_value = False - with mock.patch('builtins.print') as print_mock: - with mock.patch('builtins.open', mock.mock_open()) as open_mock: - sa.main() - print_mock.assert_any_call( - 'Warning: Annotation file fake does not exist.') - - with mock.patch('os.path.exists') as path_mock: - path_mock.return_value = True - with mock.patch('builtins.print') as print_mock: - with mock.patch('builtins.open', mock.mock_open()) as open_mock: - open_mock.side_effect = IOError() - sa.main() - print_mock.assert_called_once_with( - 'Error: Failed to open annotation file fake.') - - with mock.patch('builtins.print') as print_mock: - with mock.patch('builtins.open', mock.mock_open()) as open_mock: - open_mock.return_value.read.side_effect = ['{', ''] - sa.main() - open_mock.assert_called_once_with('fake', 'r') - print_mock.assert_called_once_with( - 'Error: Failed to parse annotation file fake.') - - with mock.patch('builtins.print') as print_mock: - with mock.patch('builtins.open', - mock.mock_open(read_data='')) as open_mock: - sa.main() - print_mock.assert_called_once_with( - 'Error: Invalid annotation file fake.') - - args.annotation = None - - with mock.patch('builtins.print') as print_mock: - checkoutput_mock.side_effect = [symbol_text, rodata_text] - sa.main() - print_mock.assert_called_once_with( - 'Error: Failed to load export_taskinfo.') - - with mock.patch('builtins.print') as print_mock: - checkoutput_mock.side_effect = subprocess.CalledProcessError(1, '') - sa.main() - print_mock.assert_called_once_with( - 'Error: objdump failed to dump symbol table or rodata.') - - with mock.patch('builtins.print') as print_mock: - checkoutput_mock.side_effect = OSError() - sa.main() - print_mock.assert_called_once_with('Error: Failed to run objdump.') - - -if __name__ == '__main__': - unittest.main() + remove_list = [ + [funcs[0x1000]], + [funcs[0x2000], funcs[0x2000]], + [funcs[0x4000], funcs[0x1000]], + [funcs[0x2000], funcs[0x4000], funcs[0x2000]], + [funcs[0x4000], funcs[0x1000], funcs[0x4000]], + ] + eliminated_addrs = {0x2006} + + remaining_remove_list = self.analyzer.PreprocessAnnotation( + funcs, add_set, remove_list, eliminated_addrs + ) + + expect_funcs = { + 0x1000: sa.Function(0x1000, "hook_task", 0, []), + 0x2000: sa.Function(0x2000, "console_task", 0, []), + 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []), + } + expect_funcs[0x2000].callsites = [ + sa.Callsite(None, 0x4000, False, expect_funcs[0x4000]) + ] + expect_funcs[0x4000].callsites = [ + sa.Callsite(None, 0x2000, False, expect_funcs[0x2000]) + ] + self.assertEqual(funcs, expect_funcs) + self.assertEqual( + remaining_remove_list, + [ + [funcs[0x2000], funcs[0x4000], funcs[0x2000]], + ], + ) + + def testAndesAnalyzeDisassembly(self): + disasm_text = ( + "\n" + "build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le" + "\n" + "Disassembly of section .text:\n" + "\n" + "00000900 <wook_task>:\n" + " ...\n" + "00001000 <hook_task>:\n" + " 1000: fc 42\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n" + " 1004: 47 70\t\tmovi55 $r0, #1\n" + " 1006: b1 13\tbnezs8 100929de <flash_command_write>\n" + " 1008: 00 01 5c fc\tbne $r6, $r0, 2af6a\n" + "00002000 <console_task>:\n" + " 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n" + " 2002: f0 0e fc c5\tjal 1000 <hook_task>\n" + " 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n" + " 200a: de ad be ef\tswi.gp $r0, [ + #-11036]\n" + "00004000 <touchpad_calc>:\n" + " 4000: 47 70\t\tmovi55 $r0, #1\n" + "00010000 <look_task>:" + ) + function_map = self.analyzer.AnalyzeDisassembly(disasm_text) + func_hook_task = sa.Function( + 0x1000, "hook_task", 48, [sa.Callsite(0x1006, 0x100929DE, True, None)] + ) + expect_funcmap = { + 0x1000: func_hook_task, + 0x2000: sa.Function( + 0x2000, + "console_task", + 16, + [ + sa.Callsite(0x2002, 0x1000, False, func_hook_task), + sa.Callsite(0x2006, 0x53968, True, None), + ], + ), + 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []), + } + self.assertEqual(function_map, expect_funcmap) + + def testArmAnalyzeDisassembly(self): + disasm_text = ( + "\n" + "build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm" + "\n" + "Disassembly of section .text:\n" + "\n" + "00000900 <wook_task>:\n" + " ...\n" + "00001000 <hook_task>:\n" + " 1000: dead beef\tfake\n" + " 1004: 4770\t\tbx lr\n" + " 1006: b113\tcbz r3, 100929de <flash_command_write>\n" + " 1008: 00015cfc\t.word 0x00015cfc\n" + "00002000 <console_task>:\n" + " 2000: b508\t\tpush {r3, lr} ; malformed comments,; r0, r1 \n" + " 2002: f00e fcc5\tbl 1000 <hook_task>\n" + " 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n" + " 200a: dead beef\tfake\n" + "00004000 <touchpad_calc>:\n" + " 4000: 4770\t\tbx lr\n" + "00010000 <look_task>:" + ) + function_map = self.analyzer.AnalyzeDisassembly(disasm_text) + func_hook_task = sa.Function( + 0x1000, "hook_task", 0, [sa.Callsite(0x1006, 0x100929DE, True, None)] + ) + expect_funcmap = { + 0x1000: func_hook_task, + 0x2000: sa.Function( + 0x2000, + "console_task", + 8, + [ + sa.Callsite(0x2002, 0x1000, False, func_hook_task), + sa.Callsite(0x2006, 0x53968, True, None), + ], + ), + 0x4000: sa.Function(0x4000, "touchpad_calc", 0, []), + } + self.assertEqual(function_map, expect_funcmap) + + def testAnalyzeCallGraph(self): + funcs = { + 0x1000: sa.Function(0x1000, "hook_task", 0, []), + 0x2000: sa.Function(0x2000, "console_task", 8, []), + 0x3000: sa.Function(0x3000, "task_a", 12, []), + 0x4000: sa.Function(0x4000, "task_b", 96, []), + 0x5000: sa.Function(0x5000, "task_c", 32, []), + 0x6000: sa.Function(0x6000, "task_d", 100, []), + 0x7000: sa.Function(0x7000, "task_e", 24, []), + 0x8000: sa.Function(0x8000, "task_f", 20, []), + 0x9000: sa.Function(0x9000, "task_g", 20, []), + 0x10000: sa.Function(0x10000, "task_x", 16, []), + } + funcs[0x1000].callsites = [ + sa.Callsite(0x1002, 0x3000, False, funcs[0x3000]), + sa.Callsite(0x1006, 0x4000, False, funcs[0x4000]), + ] + funcs[0x2000].callsites = [ + sa.Callsite(0x2002, 0x5000, False, funcs[0x5000]), + sa.Callsite(0x2006, 0x2000, False, funcs[0x2000]), + sa.Callsite(0x200A, 0x10000, False, funcs[0x10000]), + ] + funcs[0x3000].callsites = [ + sa.Callsite(0x3002, 0x4000, False, funcs[0x4000]), + sa.Callsite(0x3006, 0x1000, False, funcs[0x1000]), + ] + funcs[0x4000].callsites = [ + sa.Callsite(0x4002, 0x6000, True, funcs[0x6000]), + sa.Callsite(0x4006, 0x7000, False, funcs[0x7000]), + sa.Callsite(0x400A, 0x8000, False, funcs[0x8000]), + ] + funcs[0x5000].callsites = [sa.Callsite(0x5002, 0x4000, False, funcs[0x4000])] + funcs[0x7000].callsites = [sa.Callsite(0x7002, 0x7000, False, funcs[0x7000])] + funcs[0x8000].callsites = [sa.Callsite(0x8002, 0x9000, False, funcs[0x9000])] + funcs[0x9000].callsites = [sa.Callsite(0x9002, 0x4000, False, funcs[0x4000])] + funcs[0x10000].callsites = [sa.Callsite(0x10002, 0x2000, False, funcs[0x2000])] + + cycles = self.analyzer.AnalyzeCallGraph( + funcs, + [ + [funcs[0x2000]] * 2, + [funcs[0x10000], funcs[0x2000]] * 3, + [funcs[0x1000], funcs[0x3000], funcs[0x1000]], + ], + ) + + expect_func_stack = { + 0x1000: ( + 268, + [ + funcs[0x1000], + funcs[0x3000], + funcs[0x4000], + funcs[0x8000], + funcs[0x9000], + funcs[0x4000], + funcs[0x7000], + ], + ), + 0x2000: ( + 208, + [ + funcs[0x2000], + funcs[0x10000], + funcs[0x2000], + funcs[0x10000], + funcs[0x2000], + funcs[0x5000], + funcs[0x4000], + funcs[0x7000], + ], + ), + 0x3000: ( + 280, + [ + funcs[0x3000], + funcs[0x1000], + funcs[0x3000], + funcs[0x4000], + funcs[0x8000], + funcs[0x9000], + funcs[0x4000], + funcs[0x7000], + ], + ), + 0x4000: (120, [funcs[0x4000], funcs[0x7000]]), + 0x5000: (152, [funcs[0x5000], funcs[0x4000], funcs[0x7000]]), + 0x6000: (100, [funcs[0x6000]]), + 0x7000: (24, [funcs[0x7000]]), + 0x8000: (160, [funcs[0x8000], funcs[0x9000], funcs[0x4000], funcs[0x7000]]), + 0x9000: (140, [funcs[0x9000], funcs[0x4000], funcs[0x7000]]), + 0x10000: ( + 200, + [ + funcs[0x10000], + funcs[0x2000], + funcs[0x10000], + funcs[0x2000], + funcs[0x5000], + funcs[0x4000], + funcs[0x7000], + ], + ), + } + expect_cycles = [ + {funcs[0x4000], funcs[0x8000], funcs[0x9000]}, + {funcs[0x7000]}, + ] + for func in funcs.values(): + (stack_max_usage, stack_max_path) = expect_func_stack[func.address] + self.assertEqual(func.stack_max_usage, stack_max_usage) + self.assertEqual(func.stack_max_path, stack_max_path) + + self.assertEqual(len(cycles), len(expect_cycles)) + for cycle in cycles: + self.assertTrue(cycle in expect_cycles) + + @mock.patch("subprocess.check_output") + def testAddressToLine(self, checkoutput_mock): + checkoutput_mock.return_value = "fake_func\n/test.c:1" + self.assertEqual( + self.analyzer.AddressToLine(0x1234), [("fake_func", "/test.c", 1)] + ) + checkoutput_mock.assert_called_once_with( + ["addr2line", "-f", "-e", "./ec.RW.elf", "1234"], encoding="utf-8" + ) + checkoutput_mock.reset_mock() + + checkoutput_mock.return_value = "fake_func\n/a.c:1\nbake_func\n/b.c:2\n" + self.assertEqual( + self.analyzer.AddressToLine(0x1234, True), + [("fake_func", "/a.c", 1), ("bake_func", "/b.c", 2)], + ) + checkoutput_mock.assert_called_once_with( + ["addr2line", "-f", "-e", "./ec.RW.elf", "1234", "-i"], encoding="utf-8" + ) + checkoutput_mock.reset_mock() + + checkoutput_mock.return_value = "fake_func\n/test.c:1 (discriminator 128)" + self.assertEqual( + self.analyzer.AddressToLine(0x12345), [("fake_func", "/test.c", 1)] + ) + checkoutput_mock.assert_called_once_with( + ["addr2line", "-f", "-e", "./ec.RW.elf", "12345"], encoding="utf-8" + ) + checkoutput_mock.reset_mock() + + checkoutput_mock.return_value = "??\n:?\nbake_func\n/b.c:2\n" + self.assertEqual( + self.analyzer.AddressToLine(0x123456), [None, ("bake_func", "/b.c", 2)] + ) + checkoutput_mock.assert_called_once_with( + ["addr2line", "-f", "-e", "./ec.RW.elf", "123456"], encoding="utf-8" + ) + checkoutput_mock.reset_mock() + + with self.assertRaisesRegexp( + sa.StackAnalyzerError, "addr2line failed to resolve lines." + ): + checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "") + self.analyzer.AddressToLine(0x5678) + + with self.assertRaisesRegexp(sa.StackAnalyzerError, "Failed to run addr2line."): + checkoutput_mock.side_effect = OSError() + self.analyzer.AddressToLine(0x9012) + + @mock.patch("subprocess.check_output") + @mock.patch("stack_analyzer.StackAnalyzer.AddressToLine") + def testAndesAnalyze(self, addrtoline_mock, checkoutput_mock): + disasm_text = ( + "\n" + "build/{BOARD}/RW/ec.RW.elf: file format elf32-nds32le" + "\n" + "Disassembly of section .text:\n" + "\n" + "00000900 <wook_task>:\n" + " ...\n" + "00001000 <hook_task>:\n" + " 1000: fc 00\t\tpush25 $r10, #16 ! {$r6~$r10, $fp, $gp, $lp}\n" + " 1002: 47 70\t\tmovi55 $r0, #1\n" + " 1006: 00 01 5c fc\tbne $r6, $r0, 2af6a\n" + "00002000 <console_task>:\n" + " 2000: fc 00\t\tpush25 $r6, #0 ! {$r6, $fp, $gp, $lp} \n" + " 2002: f0 0e fc c5\tjal 1000 <hook_task>\n" + " 2006: f0 0e bd 3b\tj 53968 <get_program_memory_addr>\n" + " 200a: 12 34 56 78\tjral5 $r0\n" + ) + + addrtoline_mock.return_value = [("??", "??", 0)] + self.analyzer.annotation = { + "exception_frame_size": 64, + "remove": [["fake_func"]], + } + + with mock.patch("builtins.print") as print_mock: + checkoutput_mock.return_value = disasm_text + self.analyzer.Analyze() + print_mock.assert_has_calls( + [ + mock.call( + "Task: HOOKS, Max size: 96 (32 + 64), Allocated size: 2048" + ), + mock.call("Call Trace:"), + mock.call(" hook_task (32) [??:0] 1000"), + mock.call( + "Task: CONSOLE, Max size: 112 (48 + 64), Allocated size: 460" + ), + mock.call("Call Trace:"), + mock.call(" console_task (16) [??:0] 2000"), + mock.call(" -> ??[??:0] 2002"), + mock.call(" hook_task (32) [??:0] 1000"), + mock.call("Unresolved indirect callsites:"), + mock.call(" In function console_task:"), + mock.call(" -> ??[??:0] 200a"), + mock.call("Unresolved annotation signatures:"), + mock.call(" fake_func: function is not found"), + ] + ) + + with self.assertRaisesRegexp(sa.StackAnalyzerError, "Failed to run objdump."): + checkoutput_mock.side_effect = OSError() + self.analyzer.Analyze() + + with self.assertRaisesRegexp( + sa.StackAnalyzerError, "objdump failed to disassemble." + ): + checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "") + self.analyzer.Analyze() + + @mock.patch("subprocess.check_output") + @mock.patch("stack_analyzer.StackAnalyzer.AddressToLine") + def testArmAnalyze(self, addrtoline_mock, checkoutput_mock): + disasm_text = ( + "\n" + "build/{BOARD}/RW/ec.RW.elf: file format elf32-littlearm" + "\n" + "Disassembly of section .text:\n" + "\n" + "00000900 <wook_task>:\n" + " ...\n" + "00001000 <hook_task>:\n" + " 1000: b508\t\tpush {r3, lr}\n" + " 1002: 4770\t\tbx lr\n" + " 1006: 00015cfc\t.word 0x00015cfc\n" + "00002000 <console_task>:\n" + " 2000: b508\t\tpush {r3, lr}\n" + " 2002: f00e fcc5\tbl 1000 <hook_task>\n" + " 2006: f00e bd3b\tb.w 53968 <get_program_memory_addr>\n" + " 200a: 1234 5678\tb.w sl\n" + ) + + addrtoline_mock.return_value = [("??", "??", 0)] + self.analyzer.annotation = { + "exception_frame_size": 64, + "remove": [["fake_func"]], + } + + with mock.patch("builtins.print") as print_mock: + checkoutput_mock.return_value = disasm_text + self.analyzer.Analyze() + print_mock.assert_has_calls( + [ + mock.call( + "Task: HOOKS, Max size: 72 (8 + 64), Allocated size: 2048" + ), + mock.call("Call Trace:"), + mock.call(" hook_task (8) [??:0] 1000"), + mock.call( + "Task: CONSOLE, Max size: 80 (16 + 64), Allocated size: 460" + ), + mock.call("Call Trace:"), + mock.call(" console_task (8) [??:0] 2000"), + mock.call(" -> ??[??:0] 2002"), + mock.call(" hook_task (8) [??:0] 1000"), + mock.call("Unresolved indirect callsites:"), + mock.call(" In function console_task:"), + mock.call(" -> ??[??:0] 200a"), + mock.call("Unresolved annotation signatures:"), + mock.call(" fake_func: function is not found"), + ] + ) + + with self.assertRaisesRegexp(sa.StackAnalyzerError, "Failed to run objdump."): + checkoutput_mock.side_effect = OSError() + self.analyzer.Analyze() + + with self.assertRaisesRegexp( + sa.StackAnalyzerError, "objdump failed to disassemble." + ): + checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "") + self.analyzer.Analyze() + + @mock.patch("subprocess.check_output") + @mock.patch("stack_analyzer.ParseArgs") + def testMain(self, parseargs_mock, checkoutput_mock): + symbol_text = ( + "1000 g F .text 0000015c .hidden hook_task\n" + "2000 g F .text 0000051c .hidden console_task\n" + ) + rodata_text = ( + "\n" + "Contents of section .rodata:\n" + " 20000 dead1000 00100000 dead2000 00200000 He..f.He..s.\n" + ) + + args = mock.MagicMock( + elf_path="./ec.RW.elf", + export_taskinfo="fake", + section="RW", + objdump="objdump", + addr2line="addr2line", + annotation="fake", + ) + parseargs_mock.return_value = args + + with mock.patch("os.path.exists") as path_mock: + path_mock.return_value = False + with mock.patch("builtins.print") as print_mock: + with mock.patch("builtins.open", mock.mock_open()) as open_mock: + sa.main() + print_mock.assert_any_call( + "Warning: Annotation file fake does not exist." + ) + + with mock.patch("os.path.exists") as path_mock: + path_mock.return_value = True + with mock.patch("builtins.print") as print_mock: + with mock.patch("builtins.open", mock.mock_open()) as open_mock: + open_mock.side_effect = IOError() + sa.main() + print_mock.assert_called_once_with( + "Error: Failed to open annotation file fake." + ) + + with mock.patch("builtins.print") as print_mock: + with mock.patch("builtins.open", mock.mock_open()) as open_mock: + open_mock.return_value.read.side_effect = ["{", ""] + sa.main() + open_mock.assert_called_once_with("fake", "r") + print_mock.assert_called_once_with( + "Error: Failed to parse annotation file fake." + ) + + with mock.patch("builtins.print") as print_mock: + with mock.patch( + "builtins.open", mock.mock_open(read_data="") + ) as open_mock: + sa.main() + print_mock.assert_called_once_with( + "Error: Invalid annotation file fake." + ) + + args.annotation = None + + with mock.patch("builtins.print") as print_mock: + checkoutput_mock.side_effect = [symbol_text, rodata_text] + sa.main() + print_mock.assert_called_once_with("Error: Failed to load export_taskinfo.") + + with mock.patch("builtins.print") as print_mock: + checkoutput_mock.side_effect = subprocess.CalledProcessError(1, "") + sa.main() + print_mock.assert_called_once_with( + "Error: objdump failed to dump symbol table or rodata." + ) + + with mock.patch("builtins.print") as print_mock: + checkoutput_mock.side_effect = OSError() + sa.main() + print_mock.assert_called_once_with("Error: Failed to run objdump.") + + +if __name__ == "__main__": + unittest.main() diff --git a/extra/tigertool/ecusb/__init__.py b/extra/tigertool/ecusb/__init__.py index fe4dbc6749..7a48fdb360 100644 --- a/extra/tigertool/ecusb/__init__.py +++ b/extra/tigertool/ecusb/__init__.py @@ -6,4 +6,4 @@ # pylint: disable=bad-indentation,docstring-section-indent # pylint: disable=docstring-trailing-quotes -__all__ = ['tiny_servo_common', 'stm32usb', 'stm32uart', 'pty_driver'] +__all__ = ["tiny_servo_common", "stm32usb", "stm32uart", "pty_driver"] diff --git a/extra/tigertool/ecusb/pty_driver.py b/extra/tigertool/ecusb/pty_driver.py index 09ef8c42e4..037ff8d529 100644 --- a/extra/tigertool/ecusb/pty_driver.py +++ b/extra/tigertool/ecusb/pty_driver.py @@ -17,8 +17,9 @@ import ast import errno import fcntl import os -import pexpect import time + +import pexpect from pexpect import fdpexpect # Expecting a result in 3 seconds is plenty even for slow platforms. @@ -27,281 +28,285 @@ FLUSH_UART_TIMEOUT = 1 class ptyError(Exception): - """Exception class for pty errors.""" + """Exception class for pty errors.""" UART_PARAMS = { - 'uart_cmd': None, - 'uart_multicmd': None, - 'uart_regexp': None, - 'uart_timeout': DEFAULT_UART_TIMEOUT, + "uart_cmd": None, + "uart_multicmd": None, + "uart_regexp": None, + "uart_timeout": DEFAULT_UART_TIMEOUT, } class ptyDriver(object): - """Automate interactive commands on a pty interface.""" - def __init__(self, interface, params, fast=False): - """Init class variables.""" - self._child = None - self._fd = None - self._interface = interface - self._pty_path = self._interface.get_pty() - self._dict = UART_PARAMS.copy() - self._fast = fast - - def __del__(self): - self.close() - - def close(self): - """Close any open files and interfaces.""" - if self._fd: - self._close() - self._interface.close() - - def _open(self): - """Connect to serial device and create pexpect interface.""" - assert self._fd is None - self._fd = os.open(self._pty_path, os.O_RDWR | os.O_NONBLOCK) - # Don't allow forked processes to access. - fcntl.fcntl(self._fd, fcntl.F_SETFD, - fcntl.fcntl(self._fd, fcntl.F_GETFD) | fcntl.FD_CLOEXEC) - self._child = fdpexpect.fdspawn(self._fd) - # pexpect defaults to a 100ms delay before sending characters, to - # work around race conditions in ssh. We don't need this feature - # so we'll change delaybeforesend from 0.1 to 0.001 to speed things up. - if self._fast: - self._child.delaybeforesend = 0.001 - - def _close(self): - """Close serial device connection.""" - os.close(self._fd) - self._fd = None - self._child = None - - def _flush(self): - """Flush device output to prevent previous messages interfering.""" - if self._child.sendline('') != 1: - raise ptyError('Failed to send newline.') - # Have a maximum timeout for the flush operation. We should have cleared - # all data from the buffer, but if data is regularly being generated, we - # can't guarantee it will ever stop. - flush_end_time = time.time() + FLUSH_UART_TIMEOUT - while time.time() <= flush_end_time: - try: - self._child.expect('.', timeout=0.01) - except (pexpect.TIMEOUT, pexpect.EOF): - break - except OSError as e: - # EAGAIN indicates no data available, maybe we didn't wait long enough. - if e.errno != errno.EAGAIN: - raise - break - - def _send(self, cmds): - """Send command to EC. - - This function always flushes serial device before sending, and is used as - a wrapper function to make sure the channel is always flushed before - sending commands. - - Args: - cmds: The commands to send to the device, either a list or a string. - - Raises: - ptyError: Raised when writing to the device fails. - """ - self._flush() - if not isinstance(cmds, list): - cmds = [cmds] - for cmd in cmds: - if self._child.sendline(cmd) != len(cmd) + 1: - raise ptyError('Failed to send command.') - - def _issue_cmd(self, cmds): - """Send command to the device and do not wait for response. - - Args: - cmds: The commands to send to the device, either a list or a string. - """ - self._issue_cmd_get_results(cmds, []) - - def _issue_cmd_get_results(self, cmds, - regex_list, timeout=DEFAULT_UART_TIMEOUT): - """Send command to the device and wait for response. - - This function waits for response message matching a regular - expressions. - - Args: - cmds: The commands issued, either a list or a string. - regex_list: List of Regular expressions used to match response message. - Note1, list must be ordered. - Note2, empty list sends and returns. - timeout: time to wait for matching results before failing. - - Returns: - List of tuples, each of which contains the entire matched string and - all the subgroups of the match. None if not matched. - For example: - response of the given command: - High temp: 37.2 - Low temp: 36.4 - regex_list: - ['High temp: (\d+)\.(\d+)', 'Low temp: (\d+)\.(\d+)'] - returns: - [('High temp: 37.2', '37', '2'), ('Low temp: 36.4', '36', '4')] - - Raises: - ptyError: If timed out waiting for a response - """ - result_list = [] - self._open() - try: - self._send(cmds) - for regex in regex_list: - self._child.expect(regex, timeout) - match = self._child.match - lastindex = match.lastindex if match and match.lastindex else 0 - # Create a tuple which contains the entire matched string and all - # the subgroups of the match. - result = match.group(*range(lastindex + 1)) if match else None - if result: - result = tuple(res.decode('utf-8') for res in result) - result_list.append(result) - except pexpect.TIMEOUT: - raise ptyError('Timeout waiting for response.') - finally: - if not regex_list: - # Must be longer than delaybeforesend - time.sleep(0.1) - self._close() - return result_list - - def _issue_cmd_get_multi_results(self, cmd, regex): - """Send command to the device and wait for multiple response. - - This function waits for arbitrary number of response message - matching a regular expression. - - Args: - cmd: The command issued. - regex: Regular expression used to match response message. - - Returns: - List of tuples, each of which contains the entire matched string and - all the subgroups of the match. None if not matched. - """ - result_list = [] - self._open() - try: - self._send(cmd) - while True: + """Automate interactive commands on a pty interface.""" + + def __init__(self, interface, params, fast=False): + """Init class variables.""" + self._child = None + self._fd = None + self._interface = interface + self._pty_path = self._interface.get_pty() + self._dict = UART_PARAMS.copy() + self._fast = fast + + def __del__(self): + self.close() + + def close(self): + """Close any open files and interfaces.""" + if self._fd: + self._close() + self._interface.close() + + def _open(self): + """Connect to serial device and create pexpect interface.""" + assert self._fd is None + self._fd = os.open(self._pty_path, os.O_RDWR | os.O_NONBLOCK) + # Don't allow forked processes to access. + fcntl.fcntl( + self._fd, + fcntl.F_SETFD, + fcntl.fcntl(self._fd, fcntl.F_GETFD) | fcntl.FD_CLOEXEC, + ) + self._child = fdpexpect.fdspawn(self._fd) + # pexpect defaults to a 100ms delay before sending characters, to + # work around race conditions in ssh. We don't need this feature + # so we'll change delaybeforesend from 0.1 to 0.001 to speed things up. + if self._fast: + self._child.delaybeforesend = 0.001 + + def _close(self): + """Close serial device connection.""" + os.close(self._fd) + self._fd = None + self._child = None + + def _flush(self): + """Flush device output to prevent previous messages interfering.""" + if self._child.sendline("") != 1: + raise ptyError("Failed to send newline.") + # Have a maximum timeout for the flush operation. We should have cleared + # all data from the buffer, but if data is regularly being generated, we + # can't guarantee it will ever stop. + flush_end_time = time.time() + FLUSH_UART_TIMEOUT + while time.time() <= flush_end_time: + try: + self._child.expect(".", timeout=0.01) + except (pexpect.TIMEOUT, pexpect.EOF): + break + except OSError as e: + # EAGAIN indicates no data available, maybe we didn't wait long enough. + if e.errno != errno.EAGAIN: + raise + break + + def _send(self, cmds): + """Send command to EC. + + This function always flushes serial device before sending, and is used as + a wrapper function to make sure the channel is always flushed before + sending commands. + + Args: + cmds: The commands to send to the device, either a list or a string. + + Raises: + ptyError: Raised when writing to the device fails. + """ + self._flush() + if not isinstance(cmds, list): + cmds = [cmds] + for cmd in cmds: + if self._child.sendline(cmd) != len(cmd) + 1: + raise ptyError("Failed to send command.") + + def _issue_cmd(self, cmds): + """Send command to the device and do not wait for response. + + Args: + cmds: The commands to send to the device, either a list or a string. + """ + self._issue_cmd_get_results(cmds, []) + + def _issue_cmd_get_results(self, cmds, regex_list, timeout=DEFAULT_UART_TIMEOUT): + """Send command to the device and wait for response. + + This function waits for response message matching a regular + expressions. + + Args: + cmds: The commands issued, either a list or a string. + regex_list: List of Regular expressions used to match response message. + Note1, list must be ordered. + Note2, empty list sends and returns. + timeout: time to wait for matching results before failing. + + Returns: + List of tuples, each of which contains the entire matched string and + all the subgroups of the match. None if not matched. + For example: + response of the given command: + High temp: 37.2 + Low temp: 36.4 + regex_list: + ['High temp: (\d+)\.(\d+)', 'Low temp: (\d+)\.(\d+)'] + returns: + [('High temp: 37.2', '37', '2'), ('Low temp: 36.4', '36', '4')] + + Raises: + ptyError: If timed out waiting for a response + """ + result_list = [] + self._open() try: - self._child.expect(regex, timeout=0.1) - match = self._child.match - lastindex = match.lastindex if match and match.lastindex else 0 - # Create a tuple which contains the entire matched string and all - # the subgroups of the match. - result = match.group(*range(lastindex + 1)) if match else None - if result: - result = tuple(res.decode('utf-8') for res in result) - result_list.append(result) + self._send(cmds) + for regex in regex_list: + self._child.expect(regex, timeout) + match = self._child.match + lastindex = match.lastindex if match and match.lastindex else 0 + # Create a tuple which contains the entire matched string and all + # the subgroups of the match. + result = match.group(*range(lastindex + 1)) if match else None + if result: + result = tuple(res.decode("utf-8") for res in result) + result_list.append(result) except pexpect.TIMEOUT: - break - finally: - self._close() - return result_list - - def _Set_uart_timeout(self, timeout): - """Set timeout value for waiting for the device response. - - Args: - timeout: Timeout value in second. - """ - self._dict['uart_timeout'] = timeout - - def _Get_uart_timeout(self): - """Get timeout value for waiting for the device response. - - Returns: - Timeout value in second. - """ - return self._dict['uart_timeout'] - - def _Set_uart_regexp(self, regexp): - """Set the list of regular expressions which matches the command response. - - Args: - regexp: A string which contains a list of regular expressions. - """ - if not isinstance(regexp, str): - raise ptyError('The argument regexp should be a string.') - self._dict['uart_regexp'] = ast.literal_eval(regexp) - - def _Get_uart_regexp(self): - """Get the list of regular expressions which matches the command response. - - Returns: - A string which contains a list of regular expressions. - """ - return str(self._dict['uart_regexp']) - - def _Set_uart_cmd(self, cmd): - """Set the UART command and send it to the device. - - If ec_uart_regexp is 'None', the command is just sent and it doesn't care - about its response. - - If ec_uart_regexp is not 'None', the command is send and its response, - which matches the regular expression of ec_uart_regexp, will be kept. - Use its getter to obtain this result. If no match after ec_uart_timeout - seconds, a timeout error will be raised. - - Args: - cmd: A string of UART command. - """ - if self._dict['uart_regexp']: - self._dict['uart_cmd'] = self._issue_cmd_get_results( - cmd, self._dict['uart_regexp'], self._dict['uart_timeout']) - else: - self._dict['uart_cmd'] = None - self._issue_cmd(cmd) - - def _Set_uart_multicmd(self, cmds): - """Set multiple UART commands and send them to the device. - - Note that ec_uart_regexp is not supported to match the results. - - Args: - cmds: A semicolon-separated string of UART commands. - """ - self._issue_cmd(cmds.split(';')) - - def _Get_uart_cmd(self): - """Get the result of the latest UART command. - - Returns: - A string which contains a list of tuples, each of which contains the - entire matched string and all the subgroups of the match. 'None' if - the ec_uart_regexp is 'None'. - """ - return str(self._dict['uart_cmd']) - - def _Set_uart_capture(self, cmd): - """Set UART capture mode (on or off). - - Once capture is enabled, UART output could be collected periodically by - invoking _Get_uart_stream() below. - - Args: - cmd: True for on, False for off - """ - self._interface.set_capture_active(cmd) - - def _Get_uart_capture(self): - """Get the UART capture mode (on or off).""" - return self._interface.get_capture_active() - - def _Get_uart_stream(self): - """Get uart stream generated since last time.""" - return self._interface.get_stream() + raise ptyError("Timeout waiting for response.") + finally: + if not regex_list: + # Must be longer than delaybeforesend + time.sleep(0.1) + self._close() + return result_list + + def _issue_cmd_get_multi_results(self, cmd, regex): + """Send command to the device and wait for multiple response. + + This function waits for arbitrary number of response message + matching a regular expression. + + Args: + cmd: The command issued. + regex: Regular expression used to match response message. + + Returns: + List of tuples, each of which contains the entire matched string and + all the subgroups of the match. None if not matched. + """ + result_list = [] + self._open() + try: + self._send(cmd) + while True: + try: + self._child.expect(regex, timeout=0.1) + match = self._child.match + lastindex = match.lastindex if match and match.lastindex else 0 + # Create a tuple which contains the entire matched string and all + # the subgroups of the match. + result = match.group(*range(lastindex + 1)) if match else None + if result: + result = tuple(res.decode("utf-8") for res in result) + result_list.append(result) + except pexpect.TIMEOUT: + break + finally: + self._close() + return result_list + + def _Set_uart_timeout(self, timeout): + """Set timeout value for waiting for the device response. + + Args: + timeout: Timeout value in second. + """ + self._dict["uart_timeout"] = timeout + + def _Get_uart_timeout(self): + """Get timeout value for waiting for the device response. + + Returns: + Timeout value in second. + """ + return self._dict["uart_timeout"] + + def _Set_uart_regexp(self, regexp): + """Set the list of regular expressions which matches the command response. + + Args: + regexp: A string which contains a list of regular expressions. + """ + if not isinstance(regexp, str): + raise ptyError("The argument regexp should be a string.") + self._dict["uart_regexp"] = ast.literal_eval(regexp) + + def _Get_uart_regexp(self): + """Get the list of regular expressions which matches the command response. + + Returns: + A string which contains a list of regular expressions. + """ + return str(self._dict["uart_regexp"]) + + def _Set_uart_cmd(self, cmd): + """Set the UART command and send it to the device. + + If ec_uart_regexp is 'None', the command is just sent and it doesn't care + about its response. + + If ec_uart_regexp is not 'None', the command is send and its response, + which matches the regular expression of ec_uart_regexp, will be kept. + Use its getter to obtain this result. If no match after ec_uart_timeout + seconds, a timeout error will be raised. + + Args: + cmd: A string of UART command. + """ + if self._dict["uart_regexp"]: + self._dict["uart_cmd"] = self._issue_cmd_get_results( + cmd, self._dict["uart_regexp"], self._dict["uart_timeout"] + ) + else: + self._dict["uart_cmd"] = None + self._issue_cmd(cmd) + + def _Set_uart_multicmd(self, cmds): + """Set multiple UART commands and send them to the device. + + Note that ec_uart_regexp is not supported to match the results. + + Args: + cmds: A semicolon-separated string of UART commands. + """ + self._issue_cmd(cmds.split(";")) + + def _Get_uart_cmd(self): + """Get the result of the latest UART command. + + Returns: + A string which contains a list of tuples, each of which contains the + entire matched string and all the subgroups of the match. 'None' if + the ec_uart_regexp is 'None'. + """ + return str(self._dict["uart_cmd"]) + + def _Set_uart_capture(self, cmd): + """Set UART capture mode (on or off). + + Once capture is enabled, UART output could be collected periodically by + invoking _Get_uart_stream() below. + + Args: + cmd: True for on, False for off + """ + self._interface.set_capture_active(cmd) + + def _Get_uart_capture(self): + """Get the UART capture mode (on or off).""" + return self._interface.get_capture_active() + + def _Get_uart_stream(self): + """Get uart stream generated since last time.""" + return self._interface.get_stream() diff --git a/extra/tigertool/ecusb/stm32uart.py b/extra/tigertool/ecusb/stm32uart.py index 95219455a9..5794bd091b 100644 --- a/extra/tigertool/ecusb/stm32uart.py +++ b/extra/tigertool/ecusb/stm32uart.py @@ -17,232 +17,244 @@ import termios import threading import time import tty + import usb from . import stm32usb class SuartError(Exception): - """Class for exceptions of Suart.""" - def __init__(self, msg, value=0): - """SuartError constructor. + """Class for exceptions of Suart.""" - Args: - msg: string, message describing error in detail - value: integer, value of error when non-zero status returned. Default=0 - """ - super(SuartError, self).__init__(msg, value) - self.msg = msg - self.value = value + def __init__(self, msg, value=0): + """SuartError constructor. + Args: + msg: string, message describing error in detail + value: integer, value of error when non-zero status returned. Default=0 + """ + super(SuartError, self).__init__(msg, value) + self.msg = msg + self.value = value -class Suart(object): - """Provide interface to stm32 serial usb endpoint.""" - def __init__(self, vendor=0x18d1, product=0x501a, interface=0, - serialname=None, debuglog=False): - """Suart contstructor. - - Initializes stm32 USB stream interface. - - Args: - vendor: usb vendor id of stm32 device - product: usb product id of stm32 device - interface: interface number of stm32 device to use - serialname: serial name to target. Defaults to None. - debuglog: chatty output. Defaults to False. - - Raises: - SuartError: If init fails - """ - self._ptym = None - self._ptys = None - self._ptyname = None - self._rx_thread = None - self._tx_thread = None - self._debuglog = debuglog - self._susb = stm32usb.Susb(vendor=vendor, product=product, - interface=interface, serialname=serialname) - self._running = False - - def __del__(self): - """Suart destructor.""" - self.close() - - def close(self): - """Stop all running threads.""" - self._running = False - if self._rx_thread: - self._rx_thread.join(2) - self._rx_thread = None - if self._tx_thread: - self._tx_thread.join(2) - self._tx_thread = None - self._susb.close() - - def run_rx_thread(self): - """Background loop to pass data from USB to pty.""" - ep = select.epoll() - ep.register(self._ptym, select.EPOLLHUP) - try: - while self._running: - events = ep.poll(0) - # Check if the pty is connected to anything, or hungup. - if not events: - try: - r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS) - if r: - if self._debuglog: - print(''.join([chr(x) for x in r]), end='') - os.write(self._ptym, r) - - # If we miss some characters on pty disconnect, that's fine. - # ep.read() also throws USBError on timeout, which we discard. - except OSError: - pass - except usb.core.USBError: - pass - else: - time.sleep(.1) - except Exception as e: - raise e - - def run_tx_thread(self): - """Background loop to pass data from pty to USB.""" - ep = select.epoll() - ep.register(self._ptym, select.EPOLLHUP) - try: - while self._running: - events = ep.poll(0) - # Check if the pty is connected to anything, or hungup. - if not events: - try: - r = os.read(self._ptym, 64) - # TODO(crosbug.com/936182): Remove when the servo v4/micro console - # issues are fixed. - time.sleep(0.001) - if r: - self._susb._write_ep.write(r, self._susb.TIMEOUT_MS) - - except OSError: - pass - except usb.core.USBError: - pass - else: - time.sleep(.1) - except Exception as e: - raise e - - def run(self): - """Creates pthreads to poll stm32 & PTY for data.""" - m, s = os.openpty() - self._ptyname = os.ttyname(s) - - self._ptym = m - self._ptys = s - - os.fchmod(s, 0o660) - - # Change the owner and group of the PTY to the user who started servod. - try: - uid = int(os.environ.get('SUDO_UID', -1)) - except TypeError: - uid = -1 - try: - gid = int(os.environ.get('SUDO_GID', -1)) - except TypeError: - gid = -1 - os.fchown(s, uid, gid) - - tty.setraw(self._ptym, termios.TCSADRAIN) - - # Generate a HUP flag on pty slave fd. - os.fdopen(s).close() - - self._running = True - - self._rx_thread = threading.Thread(target=self.run_rx_thread, args=[]) - self._rx_thread.daemon = True - self._rx_thread.start() - - self._tx_thread = threading.Thread(target=self.run_tx_thread, args=[]) - self._tx_thread.daemon = True - self._tx_thread.start() - - def get_uart_props(self): - """Get the uart's properties. - - Returns: - dict where: - baudrate: integer of uarts baudrate - bits: integer, number of bits of data Can be 5|6|7|8 inclusive - parity: integer, parity of 0-2 inclusive where: - 0: no parity - 1: odd parity - 2: even parity - sbits: integer, number of stop bits. Can be 0|1|2 inclusive where: - 0: 1 stop bit - 1: 1.5 stop bits - 2: 2 stop bits - """ - return { - 'baudrate': 115200, - 'bits': 8, - 'parity': 0, - 'sbits': 1, - } - - def set_uart_props(self, line_props): - """Set the uart's properties. - - Note that Suart cannot set properties - and will fail if the properties are not the default 115200,8n1. - - Args: - line_props: dict where: - baudrate: integer of uarts baudrate - bits: integer, number of bits of data ( prior to stop bit) - parity: integer, parity of 0-2 inclusive where - 0: no parity - 1: odd parity - 2: even parity - sbits: integer, number of stop bits. Can be 0|1|2 inclusive where: - 0: 1 stop bit - 1: 1.5 stop bits - 2: 2 stop bits - - Raises: - SuartError: If requested line properties are not the default. - """ - curr_props = self.get_uart_props() - for prop in line_props: - if line_props[prop] != curr_props[prop]: - raise SuartError('Line property %s cannot be set from %s to %s' % ( - prop, curr_props[prop], line_props[prop])) - return True - - def get_pty(self): - """Gets path to pty for communication to/from uart. - - Returns: - String path to the pty connected to the uart - """ - return self._ptyname +class Suart(object): + """Provide interface to stm32 serial usb endpoint.""" + + def __init__( + self, + vendor=0x18D1, + product=0x501A, + interface=0, + serialname=None, + debuglog=False, + ): + """Suart contstructor. + + Initializes stm32 USB stream interface. + + Args: + vendor: usb vendor id of stm32 device + product: usb product id of stm32 device + interface: interface number of stm32 device to use + serialname: serial name to target. Defaults to None. + debuglog: chatty output. Defaults to False. + + Raises: + SuartError: If init fails + """ + self._ptym = None + self._ptys = None + self._ptyname = None + self._rx_thread = None + self._tx_thread = None + self._debuglog = debuglog + self._susb = stm32usb.Susb( + vendor=vendor, product=product, interface=interface, serialname=serialname + ) + self._running = False + + def __del__(self): + """Suart destructor.""" + self.close() + + def close(self): + """Stop all running threads.""" + self._running = False + if self._rx_thread: + self._rx_thread.join(2) + self._rx_thread = None + if self._tx_thread: + self._tx_thread.join(2) + self._tx_thread = None + self._susb.close() + + def run_rx_thread(self): + """Background loop to pass data from USB to pty.""" + ep = select.epoll() + ep.register(self._ptym, select.EPOLLHUP) + try: + while self._running: + events = ep.poll(0) + # Check if the pty is connected to anything, or hungup. + if not events: + try: + r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS) + if r: + if self._debuglog: + print("".join([chr(x) for x in r]), end="") + os.write(self._ptym, r) + + # If we miss some characters on pty disconnect, that's fine. + # ep.read() also throws USBError on timeout, which we discard. + except OSError: + pass + except usb.core.USBError: + pass + else: + time.sleep(0.1) + except Exception as e: + raise e + + def run_tx_thread(self): + """Background loop to pass data from pty to USB.""" + ep = select.epoll() + ep.register(self._ptym, select.EPOLLHUP) + try: + while self._running: + events = ep.poll(0) + # Check if the pty is connected to anything, or hungup. + if not events: + try: + r = os.read(self._ptym, 64) + # TODO(crosbug.com/936182): Remove when the servo v4/micro console + # issues are fixed. + time.sleep(0.001) + if r: + self._susb._write_ep.write(r, self._susb.TIMEOUT_MS) + + except OSError: + pass + except usb.core.USBError: + pass + else: + time.sleep(0.1) + except Exception as e: + raise e + + def run(self): + """Creates pthreads to poll stm32 & PTY for data.""" + m, s = os.openpty() + self._ptyname = os.ttyname(s) + + self._ptym = m + self._ptys = s + + os.fchmod(s, 0o660) + + # Change the owner and group of the PTY to the user who started servod. + try: + uid = int(os.environ.get("SUDO_UID", -1)) + except TypeError: + uid = -1 + + try: + gid = int(os.environ.get("SUDO_GID", -1)) + except TypeError: + gid = -1 + os.fchown(s, uid, gid) + + tty.setraw(self._ptym, termios.TCSADRAIN) + + # Generate a HUP flag on pty slave fd. + os.fdopen(s).close() + + self._running = True + + self._rx_thread = threading.Thread(target=self.run_rx_thread, args=[]) + self._rx_thread.daemon = True + self._rx_thread.start() + + self._tx_thread = threading.Thread(target=self.run_tx_thread, args=[]) + self._tx_thread.daemon = True + self._tx_thread.start() + + def get_uart_props(self): + """Get the uart's properties. + + Returns: + dict where: + baudrate: integer of uarts baudrate + bits: integer, number of bits of data Can be 5|6|7|8 inclusive + parity: integer, parity of 0-2 inclusive where: + 0: no parity + 1: odd parity + 2: even parity + sbits: integer, number of stop bits. Can be 0|1|2 inclusive where: + 0: 1 stop bit + 1: 1.5 stop bits + 2: 2 stop bits + """ + return { + "baudrate": 115200, + "bits": 8, + "parity": 0, + "sbits": 1, + } + + def set_uart_props(self, line_props): + """Set the uart's properties. + + Note that Suart cannot set properties + and will fail if the properties are not the default 115200,8n1. + + Args: + line_props: dict where: + baudrate: integer of uarts baudrate + bits: integer, number of bits of data ( prior to stop bit) + parity: integer, parity of 0-2 inclusive where + 0: no parity + 1: odd parity + 2: even parity + sbits: integer, number of stop bits. Can be 0|1|2 inclusive where: + 0: 1 stop bit + 1: 1.5 stop bits + 2: 2 stop bits + + Raises: + SuartError: If requested line properties are not the default. + """ + curr_props = self.get_uart_props() + for prop in line_props: + if line_props[prop] != curr_props[prop]: + raise SuartError( + "Line property %s cannot be set from %s to %s" + % (prop, curr_props[prop], line_props[prop]) + ) + return True + + def get_pty(self): + """Gets path to pty for communication to/from uart. + + Returns: + String path to the pty connected to the uart + """ + return self._ptyname def main(): - """Run a suart test with the default parameters.""" - try: - sobj = Suart() - sobj.run() + """Run a suart test with the default parameters.""" + try: + sobj = Suart() + sobj.run() - # run() is a thread so just busy wait to mimic server. - while True: - # Ours sleeps to eleven! - time.sleep(11) - except KeyboardInterrupt: - sys.exit(0) + # run() is a thread so just busy wait to mimic server. + while True: + # Ours sleeps to eleven! + time.sleep(11) + except KeyboardInterrupt: + sys.exit(0) -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/extra/tigertool/ecusb/stm32usb.py b/extra/tigertool/ecusb/stm32usb.py index bfd5fbb1fb..4b2b23fbac 100644 --- a/extra/tigertool/ecusb/stm32usb.py +++ b/extra/tigertool/ecusb/stm32usb.py @@ -12,108 +12,115 @@ import usb class SusbError(Exception): - """Class for exceptions of Susb.""" - def __init__(self, msg, value=0): - """SusbError constructor. + """Class for exceptions of Susb.""" - Args: - msg: string, message describing error in detail - value: integer, value of error when non-zero status returned. Default=0 - """ - super(SusbError, self).__init__(msg, value) - self.msg = msg - self.value = value + def __init__(self, msg, value=0): + """SusbError constructor. + + Args: + msg: string, message describing error in detail + value: integer, value of error when non-zero status returned. Default=0 + """ + super(SusbError, self).__init__(msg, value) + self.msg = msg + self.value = value class Susb(object): - """Provide stm32 USB functionality. - - Instance Variables: - _read_ep: pyUSB read endpoint for this interface - _write_ep: pyUSB write endpoint for this interface - """ - READ_ENDPOINT = 0x81 - WRITE_ENDPOINT = 0x1 - TIMEOUT_MS = 100 - - def __init__(self, vendor=0x18d1, - product=0x5027, interface=1, serialname=None, logger=None): - """Susb constructor. - - Discovers and connects to stm32 USB endpoints. - - Args: - vendor: usb vendor id of stm32 device. - product: usb product id of stm32 device. - interface: interface number ( 1 - 4 ) of stm32 device to use. - serialname: string of device serialname. - logger: none - - Raises: - SusbError: An error accessing Susb object + """Provide stm32 USB functionality. + + Instance Variables: + _read_ep: pyUSB read endpoint for this interface + _write_ep: pyUSB write endpoint for this interface """ - self._vendor = vendor - self._product = product - self._interface = interface - self._serialname = serialname - self._find_device() - - def _find_device(self): - """Set up the usb endpoint""" - # Find the stm32. - dev_g = usb.core.find(idVendor=self._vendor, idProduct=self._product, - find_all=True) - dev_list = list(dev_g) - - if not dev_list: - raise SusbError('USB device not found') - - # Check if we have multiple stm32s and we've specified the serial. - dev = None - if self._serialname: - for d in dev_list: - dev_serial = usb.util.get_string(d, d.iSerialNumber) - if dev_serial == self._serialname: - dev = d - break - if dev is None: - raise SusbError('USB device(%s) not found' % self._serialname) - else: - try: - dev = dev_list[0] - except StopIteration: - raise SusbError('USB device %04x:%04x not found' % ( - self._vendor, self._product)) - - # If we can't set configuration, it's already been set. - try: - dev.set_configuration() - except usb.core.USBError: - pass - - self._dev = dev - - # Get an endpoint instance. - cfg = dev.get_active_configuration() - intf = usb.util.find_descriptor(cfg, bInterfaceNumber=self._interface) - self._intf = intf - if not intf: - raise SusbError('Interface %04x:%04x - 0x%x not found' % ( - self._vendor, self._product, self._interface)) - - # Detach raiden.ko if it is loaded. CCD endpoints support either a kernel - # module driver that produces a ttyUSB, or direct endpoint access, but - # can't do both at the same time. - if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True: - dev.detach_kernel_driver(intf.bInterfaceNumber) - - read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT - read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number) - self._read_ep = read_ep - - write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT - write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number) - self._write_ep = write_ep - - def close(self): - usb.util.dispose_resources(self._dev) + + READ_ENDPOINT = 0x81 + WRITE_ENDPOINT = 0x1 + TIMEOUT_MS = 100 + + def __init__( + self, vendor=0x18D1, product=0x5027, interface=1, serialname=None, logger=None + ): + """Susb constructor. + + Discovers and connects to stm32 USB endpoints. + + Args: + vendor: usb vendor id of stm32 device. + product: usb product id of stm32 device. + interface: interface number ( 1 - 4 ) of stm32 device to use. + serialname: string of device serialname. + logger: none + + Raises: + SusbError: An error accessing Susb object + """ + self._vendor = vendor + self._product = product + self._interface = interface + self._serialname = serialname + self._find_device() + + def _find_device(self): + """Set up the usb endpoint""" + # Find the stm32. + dev_g = usb.core.find( + idVendor=self._vendor, idProduct=self._product, find_all=True + ) + dev_list = list(dev_g) + + if not dev_list: + raise SusbError("USB device not found") + + # Check if we have multiple stm32s and we've specified the serial. + dev = None + if self._serialname: + for d in dev_list: + dev_serial = usb.util.get_string(d, d.iSerialNumber) + if dev_serial == self._serialname: + dev = d + break + if dev is None: + raise SusbError("USB device(%s) not found" % self._serialname) + else: + try: + dev = dev_list[0] + except StopIteration: + raise SusbError( + "USB device %04x:%04x not found" % (self._vendor, self._product) + ) + + # If we can't set configuration, it's already been set. + try: + dev.set_configuration() + except usb.core.USBError: + pass + + self._dev = dev + + # Get an endpoint instance. + cfg = dev.get_active_configuration() + intf = usb.util.find_descriptor(cfg, bInterfaceNumber=self._interface) + self._intf = intf + if not intf: + raise SusbError( + "Interface %04x:%04x - 0x%x not found" + % (self._vendor, self._product, self._interface) + ) + + # Detach raiden.ko if it is loaded. CCD endpoints support either a kernel + # module driver that produces a ttyUSB, or direct endpoint access, but + # can't do both at the same time. + if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True: + dev.detach_kernel_driver(intf.bInterfaceNumber) + + read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT + read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number) + self._read_ep = read_ep + + write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT + write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number) + self._write_ep = write_ep + + def close(self): + usb.util.dispose_resources(self._dev) diff --git a/extra/tigertool/ecusb/tiny_servo_common.py b/extra/tigertool/ecusb/tiny_servo_common.py index 726e2e64b7..599bae80dc 100644 --- a/extra/tigertool/ecusb/tiny_servo_common.py +++ b/extra/tigertool/ecusb/tiny_servo_common.py @@ -17,215 +17,224 @@ import time import six import usb -from . import pty_driver -from . import stm32uart +from . import pty_driver, stm32uart def get_subprocess_args(): - if six.PY3: - return {'encoding': 'utf-8'} - return {} + if six.PY3: + return {"encoding": "utf-8"} + return {} class TinyServoError(Exception): - """Exceptions.""" + """Exceptions.""" def log(output): - """Print output to console, logfiles can be added here. + """Print output to console, logfiles can be added here. + + Args: + output: string to output. + """ + sys.stdout.write(output) + sys.stdout.write("\n") + sys.stdout.flush() - Args: - output: string to output. - """ - sys.stdout.write(output) - sys.stdout.write('\n') - sys.stdout.flush() def check_usb(vidpid, serialname=None): - """Check if |vidpid| is present on the system's USB. + """Check if |vidpid| is present on the system's USB. + + Args: + vidpid: string representation of the usb vid:pid, eg. '18d1:2001' + serialname: serialname if specified. - Args: - vidpid: string representation of the usb vid:pid, eg. '18d1:2001' - serialname: serialname if specified. + Returns: + True if found, False, otherwise. + """ + if get_usb_dev(vidpid, serialname): + return True - Returns: - True if found, False, otherwise. - """ - if get_usb_dev(vidpid, serialname): - return True + return False - return False def check_usb_sn(vidpid): - """Return the serial number + """Return the serial number - Return the serial number of the first USB device with VID:PID vidpid, - or None if no device is found. This will not work well with two of - the same device attached. + Return the serial number of the first USB device with VID:PID vidpid, + or None if no device is found. This will not work well with two of + the same device attached. - Args: - vidpid: string representation of the usb vid:pid, eg. '18d1:2001' + Args: + vidpid: string representation of the usb vid:pid, eg. '18d1:2001' - Returns: - string serial number if found, None otherwise. - """ - dev = get_usb_dev(vidpid) + Returns: + string serial number if found, None otherwise. + """ + dev = get_usb_dev(vidpid) - if dev: - dev_serial = usb.util.get_string(dev, dev.iSerialNumber) + if dev: + dev_serial = usb.util.get_string(dev, dev.iSerialNumber) - return dev_serial + return dev_serial + + return None - return None def get_usb_dev(vidpid, serialname=None): - """Return the USB pyusb devie struct + """Return the USB pyusb devie struct + + Return the dev struct of the first USB device with VID:PID vidpid, + or None if no device is found. If more than one device check serial + if supplied. + + Args: + vidpid: string representation of the usb vid:pid, eg. '18d1:2001' + serialname: serialname if specified. + + Returns: + pyusb device if found, None otherwise. + """ + vidpidst = vidpid.split(":") + vid = int(vidpidst[0], 16) + pid = int(vidpidst[1], 16) + + dev_g = usb.core.find(idVendor=vid, idProduct=pid, find_all=True) + dev_list = list(dev_g) + + if not dev_list: + return None + + # Check if we have multiple devices and we've specified the serial. + dev = None + if serialname: + for d in dev_list: + dev_serial = usb.util.get_string(d, d.iSerialNumber) + if dev_serial == serialname: + dev = d + break + if dev is None: + return None + else: + try: + dev = dev_list[0] + except StopIteration: + return None + + return dev + - Return the dev struct of the first USB device with VID:PID vidpid, - or None if no device is found. If more than one device check serial - if supplied. +def check_usb_dev(vidpid, serialname=None): + """Return the USB dev number + + Return the dev number of the first USB device with VID:PID vidpid, + or None if no device is found. If more than one device check serial + if supplied. - Args: - vidpid: string representation of the usb vid:pid, eg. '18d1:2001' - serialname: serialname if specified. + Args: + vidpid: string representation of the usb vid:pid, eg. '18d1:2001' + serialname: serialname if specified. - Returns: - pyusb device if found, None otherwise. - """ - vidpidst = vidpid.split(':') - vid = int(vidpidst[0], 16) - pid = int(vidpidst[1], 16) + Returns: + usb device number if found, None otherwise. + """ + dev = get_usb_dev(vidpid, serialname=serialname) - dev_g = usb.core.find(idVendor=vid, idProduct=pid, find_all=True) - dev_list = list(dev_g) + if dev: + return dev.address - if not dev_list: return None - # Check if we have multiple devices and we've specified the serial. - dev = None - if serialname: - for d in dev_list: - dev_serial = usb.util.get_string(d, d.iSerialNumber) - if dev_serial == serialname: - dev = d - break - if dev is None: - return None - else: - try: - dev = dev_list[0] - except StopIteration: - return None - - return dev -def check_usb_dev(vidpid, serialname=None): - """Return the USB dev number +def wait_for_usb_remove(vidpid, serialname=None, timeout=None): + """Wait for USB device with vidpid to be removed. - Return the dev number of the first USB device with VID:PID vidpid, - or None if no device is found. If more than one device check serial - if supplied. + Wrapper for wait_for_usb below + """ + wait_for_usb(vidpid, serialname=serialname, timeout=timeout, desiredpresence=False) - Args: - vidpid: string representation of the usb vid:pid, eg. '18d1:2001' - serialname: serialname if specified. - Returns: - usb device number if found, None otherwise. - """ - dev = get_usb_dev(vidpid, serialname=serialname) +def wait_for_usb(vidpid, serialname=None, timeout=None, desiredpresence=True): + """Wait for usb device with vidpid to be present/absent. - if dev: - return dev.address + Args: + vidpid: string representation of the usb vid:pid, eg. '18d1:2001' + serialname: serialname if specificed. + timeout: timeout in seconds, None for no timeout. + desiredpresence: True for present, False for not present. - return None + Raises: + TinyServoError: on timeout. + """ + if timeout: + finish = datetime.datetime.now() + datetime.timedelta(seconds=timeout) + while check_usb(vidpid, serialname) != desiredpresence: + time.sleep(0.1) + if timeout: + if datetime.datetime.now() > finish: + raise TinyServoError("Timeout", "Timeout waiting for USB %s" % vidpid) -def wait_for_usb_remove(vidpid, serialname=None, timeout=None): - """Wait for USB device with vidpid to be removed. +def do_serialno(serialno, pty): + """Set serialnumber 'serialno' via ec console 'pty'. - Wrapper for wait_for_usb below - """ - wait_for_usb(vidpid, serialname=serialname, - timeout=timeout, desiredpresence=False) + Commands are: + # > serialno set 1234 + # Saving serial number + # Serial number: 1234 -def wait_for_usb(vidpid, serialname=None, timeout=None, desiredpresence=True): - """Wait for usb device with vidpid to be present/absent. - - Args: - vidpid: string representation of the usb vid:pid, eg. '18d1:2001' - serialname: serialname if specificed. - timeout: timeout in seconds, None for no timeout. - desiredpresence: True for present, False for not present. - - Raises: - TinyServoError: on timeout. - """ - if timeout: - finish = datetime.datetime.now() + datetime.timedelta(seconds=timeout) - while check_usb(vidpid, serialname) != desiredpresence: - time.sleep(.1) - if timeout: - if datetime.datetime.now() > finish: - raise TinyServoError('Timeout', 'Timeout waiting for USB %s' % vidpid) + Args: + serialno: string serial number to set. + pty: tinyservo console to send commands. + + Raises: + TinyServoError: on failure to set. + ptyError: on command interface error. + """ + cmd = r"serialno set %s" % serialno + regex = r"Serial number:\s+(\S+)" + + results = pty._issue_cmd_get_results(cmd, [regex])[0] + sn = results[1].strip().strip("\n\r") + + if sn == serialno: + log("Success !") + log("Serial set to %s" % sn) + else: + log("Serial number set to %s but saved as %s." % (serialno, sn)) + raise TinyServoError( + "Serial Number", "Serial number set to %s but saved as %s." % (serialno, sn) + ) -def do_serialno(serialno, pty): - """Set serialnumber 'serialno' via ec console 'pty'. - - Commands are: - # > serialno set 1234 - # Saving serial number - # Serial number: 1234 - - Args: - serialno: string serial number to set. - pty: tinyservo console to send commands. - - Raises: - TinyServoError: on failure to set. - ptyError: on command interface error. - """ - cmd = r'serialno set %s' % serialno - regex = r'Serial number:\s+(\S+)' - - results = pty._issue_cmd_get_results(cmd, [regex])[0] - sn = results[1].strip().strip('\n\r') - - if sn == serialno: - log('Success !') - log('Serial set to %s' % sn) - else: - log('Serial number set to %s but saved as %s.' % (serialno, sn)) - raise TinyServoError( - 'Serial Number', - 'Serial number set to %s but saved as %s.' % (serialno, sn)) def setup_tinyservod(vidpid, interface, serialname=None, debuglog=False): - """Set up a pty - - Set up a pty to the ec console in order - to send commands. Returns a pty_driver object. - - Args: - vidpid: string vidpid of device to access. - interface: not used. - serialname: string serial name of device requested, optional. - debuglog: chatty printout (boolean) - - Returns: - pty object - - Raises: - UsbError, SusbError: on device not found - """ - vidstr, pidstr = vidpid.split(':') - vid = int(vidstr, 16) - pid = int(pidstr, 16) - suart = stm32uart.Suart(vendor=vid, product=pid, - interface=interface, serialname=serialname, - debuglog=debuglog) - suart.run() - pty = pty_driver.ptyDriver(suart, []) - - return pty + """Set up a pty + + Set up a pty to the ec console in order + to send commands. Returns a pty_driver object. + + Args: + vidpid: string vidpid of device to access. + interface: not used. + serialname: string serial name of device requested, optional. + debuglog: chatty printout (boolean) + + Returns: + pty object + + Raises: + UsbError, SusbError: on device not found + """ + vidstr, pidstr = vidpid.split(":") + vid = int(vidstr, 16) + pid = int(pidstr, 16) + suart = stm32uart.Suart( + vendor=vid, + product=pid, + interface=interface, + serialname=serialname, + debuglog=debuglog, + ) + suart.run() + pty = pty_driver.ptyDriver(suart, []) + + return pty diff --git a/extra/tigertool/ecusb/tiny_servod.py b/extra/tigertool/ecusb/tiny_servod.py index 632d9c3a20..ca5ca63f31 100644 --- a/extra/tigertool/ecusb/tiny_servod.py +++ b/extra/tigertool/ecusb/tiny_servod.py @@ -8,47 +8,48 @@ """Helper class to facilitate communication to servo ec console.""" -from ecusb import pty_driver -from ecusb import stm32uart +from ecusb import pty_driver, stm32uart class TinyServod(object): - """Helper class to wrap a pty_driver with interface.""" - - def __init__(self, vid, pid, interface, serialname=None, debug=False): - """Build the driver and interface. - - Args: - vid: servo device vid - pid: servo device pid - interface: which usb interface the servo console is on - serialname: the servo device serial (if available) - """ - self._vid = vid - self._pid = pid - self._interface = interface - self._serial = serialname - self._debug = debug - self._init() - - def _init(self): - self.suart = stm32uart.Suart(vendor=self._vid, - product=self._pid, - interface=self._interface, - serialname=self._serial, - debuglog=self._debug) - self.suart.run() - self.pty = pty_driver.ptyDriver(self.suart, []) - - def reinitialize(self): - """Reinitialize the connect after a reset/disconnect/etc.""" - self.close() - self._init() - - def close(self): - """Close out the connection and release resources. - - Note: if another TinyServod process or servod itself needs the same device - it's necessary to call this to ensure the usb device is available. - """ - self.suart.close() + """Helper class to wrap a pty_driver with interface.""" + + def __init__(self, vid, pid, interface, serialname=None, debug=False): + """Build the driver and interface. + + Args: + vid: servo device vid + pid: servo device pid + interface: which usb interface the servo console is on + serialname: the servo device serial (if available) + """ + self._vid = vid + self._pid = pid + self._interface = interface + self._serial = serialname + self._debug = debug + self._init() + + def _init(self): + self.suart = stm32uart.Suart( + vendor=self._vid, + product=self._pid, + interface=self._interface, + serialname=self._serial, + debuglog=self._debug, + ) + self.suart.run() + self.pty = pty_driver.ptyDriver(self.suart, []) + + def reinitialize(self): + """Reinitialize the connect after a reset/disconnect/etc.""" + self.close() + self._init() + + def close(self): + """Close out the connection and release resources. + + Note: if another TinyServod process or servod itself needs the same device + it's necessary to call this to ensure the usb device is available. + """ + self.suart.close() diff --git a/extra/tigertool/tigertest.py b/extra/tigertool/tigertest.py index 0cd31c8cce..8f8b2c7f03 100755 --- a/extra/tigertool/tigertest.py +++ b/extra/tigertool/tigertest.py @@ -13,7 +13,6 @@ import argparse import subprocess import sys - # Script to control tigertail USB-C Mux board. # # optional arguments: @@ -35,58 +34,60 @@ import sys def testCmd(cmd, expected_results): - """Run command on console, check for success. - - Args: - cmd: shell command to run. - expected_results: a list object of strings expected in the result. - - Raises: - Exception on fail. - """ - print('run: ' + cmd) - try: - p = subprocess.run(cmd, shell=True, check=False, capture_output=True) - output = p.stdout.decode('utf-8') - error = p.stderr.decode('utf-8') - assert p.returncode == 0 - for result in expected_results: - output.index(result) - except Exception as e: - print('FAIL') - print('cmd: ' + cmd) - print('error: ' + str(e)) - print('stdout:\n' + output) - print('stderr:\n' + error) - print('expected: ' + str(expected_results)) - print('RC: ' + str(p.returncode)) - raise e + """Run command on console, check for success. + + Args: + cmd: shell command to run. + expected_results: a list object of strings expected in the result. + + Raises: + Exception on fail. + """ + print("run: " + cmd) + try: + p = subprocess.run(cmd, shell=True, check=False, capture_output=True) + output = p.stdout.decode("utf-8") + error = p.stderr.decode("utf-8") + assert p.returncode == 0 + for result in expected_results: + output.index(result) + except Exception as e: + print("FAIL") + print("cmd: " + cmd) + print("error: " + str(e)) + print("stdout:\n" + output) + print("stderr:\n" + error) + print("expected: " + str(expected_results)) + print("RC: " + str(p.returncode)) + raise e + def test_sequence(): - testCmd('./tigertool.py --reboot', ['PASS']) - testCmd('./tigertool.py --setserialno test', ['PASS']) - testCmd('./tigertool.py --check_serial', ['test', 'PASS']) - testCmd('./tigertool.py -s test --check_serial', ['test', 'PASS']) - testCmd('./tigertool.py -m A', ['Mux set to A', 'PASS']) - testCmd('./tigertool.py -m B', ['Mux set to B', 'PASS']) - testCmd('./tigertool.py -m off', ['Mux set to off', 'PASS']) - testCmd('./tigertool.py -p', ['PASS']) - testCmd('./tigertool.py -r rw', ['PASS']) - testCmd('./tigertool.py -r ro', ['PASS']) - testCmd('./tigertool.py --check_version', ['RW', 'RO', 'PASS']) - - print('PASS') + testCmd("./tigertool.py --reboot", ["PASS"]) + testCmd("./tigertool.py --setserialno test", ["PASS"]) + testCmd("./tigertool.py --check_serial", ["test", "PASS"]) + testCmd("./tigertool.py -s test --check_serial", ["test", "PASS"]) + testCmd("./tigertool.py -m A", ["Mux set to A", "PASS"]) + testCmd("./tigertool.py -m B", ["Mux set to B", "PASS"]) + testCmd("./tigertool.py -m off", ["Mux set to off", "PASS"]) + testCmd("./tigertool.py -p", ["PASS"]) + testCmd("./tigertool.py -r rw", ["PASS"]) + testCmd("./tigertool.py -r ro", ["PASS"]) + testCmd("./tigertool.py --check_version", ["RW", "RO", "PASS"]) + + print("PASS") + def main(argv): - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument('-c', '--count', type=int, default=1, - help='loops to run') + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("-c", "--count", type=int, default=1, help="loops to run") + + opts = parser.parse_args(argv) - opts = parser.parse_args(argv) + for i in range(1, opts.count + 1): + print("Iteration: %d" % i) + test_sequence() - for i in range(1, opts.count + 1): - print('Iteration: %d' % i) - test_sequence() -if __name__ == '__main__': +if __name__ == "__main__": main(sys.argv[1:]) diff --git a/extra/tigertool/tigertool.py b/extra/tigertool/tigertool.py index 6baae8abdf..37b7b01495 100755 --- a/extra/tigertool/tigertool.py +++ b/extra/tigertool/tigertool.py @@ -17,287 +17,308 @@ import time import ecusb.tiny_servo_common as c -STM_VIDPID = '18d1:5027' -serialno = 'Uninitialized' +STM_VIDPID = "18d1:5027" +serialno = "Uninitialized" + def do_mux(mux, pty): - """Set mux via ec console 'pty'. + """Set mux via ec console 'pty'. + + Args: + mux: mux to connect to DUT, 'A', 'B', or 'off' + pty: a pty object connected to tigertail - Args: - mux: mux to connect to DUT, 'A', 'B', or 'off' - pty: a pty object connected to tigertail + Commands are: + # > mux A + # TYPE-C mux is A + """ + validmux = ["A", "B", "off"] + if mux not in validmux: + c.log("Mux setting %s invalid, try one of %s" % (mux, validmux)) + return False - Commands are: - # > mux A - # TYPE-C mux is A - """ - validmux = ['A', 'B', 'off'] - if mux not in validmux: - c.log('Mux setting %s invalid, try one of %s' % (mux, validmux)) - return False + cmd = "mux %s" % mux + regex = "TYPE\-C mux is ([^\s\r\n]*)\r" - cmd = 'mux %s' % mux - regex = 'TYPE\-C mux is ([^\s\r\n]*)\r' + results = pty._issue_cmd_get_results(cmd, [regex])[0] + result = results[1].strip().strip("\n\r") - results = pty._issue_cmd_get_results(cmd, [regex])[0] - result = results[1].strip().strip('\n\r') + if result != mux: + c.log("Mux set to %s but saved as %s." % (mux, result)) + return False + c.log("Mux set to %s" % result) + return True - if result != mux: - c.log('Mux set to %s but saved as %s.' % (mux, result)) - return False - c.log('Mux set to %s' % result) - return True def do_version(pty): - """Check version via ec console 'pty'. - - Args: - pty: a pty object connected to tigertail - - Commands are: - # > version - # Chip: stm stm32f07x - # Board: 0 - # RO: tigertail_v1.1.6749-74d1a312e - # RW: tigertail_v1.1.6749-74d1a312e - # Build: tigertail_v1.1.6749-74d1a312e - # 2017-07-25 20:08:34 nsanders@meatball.mtv.corp.google.com - """ - cmd = 'version' - regex = r'RO:\s+(\S+)\s+RW:\s+(\S+)\s+Build:\s+(\S+)\s+' \ - r'(\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d) (\S+)' - - results = pty._issue_cmd_get_results(cmd, [regex])[0] - c.log('Version is %s' % results[3]) - c.log('RO: %s' % results[1]) - c.log('RW: %s' % results[2]) - c.log('Date: %s' % results[4]) - c.log('Src: %s' % results[5]) - - return True + """Check version via ec console 'pty'. + + Args: + pty: a pty object connected to tigertail + + Commands are: + # > version + # Chip: stm stm32f07x + # Board: 0 + # RO: tigertail_v1.1.6749-74d1a312e + # RW: tigertail_v1.1.6749-74d1a312e + # Build: tigertail_v1.1.6749-74d1a312e + # 2017-07-25 20:08:34 nsanders@meatball.mtv.corp.google.com + """ + cmd = "version" + regex = ( + r"RO:\s+(\S+)\s+RW:\s+(\S+)\s+Build:\s+(\S+)\s+" + r"(\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d) (\S+)" + ) + + results = pty._issue_cmd_get_results(cmd, [regex])[0] + c.log("Version is %s" % results[3]) + c.log("RO: %s" % results[1]) + c.log("RW: %s" % results[2]) + c.log("Date: %s" % results[4]) + c.log("Src: %s" % results[5]) + + return True + def do_check_serial(pty): - """Check serial via ec console 'pty'. + """Check serial via ec console 'pty'. - Args: - pty: a pty object connected to tigertail + Args: + pty: a pty object connected to tigertail - Commands are: - # > serialno - # Serial number: number - """ - cmd = 'serialno' - regex = r'Serial number: ([^\n\r]+)' + Commands are: + # > serialno + # Serial number: number + """ + cmd = "serialno" + regex = r"Serial number: ([^\n\r]+)" - results = pty._issue_cmd_get_results(cmd, [regex])[0] - c.log('Serial is %s' % results[1]) + results = pty._issue_cmd_get_results(cmd, [regex])[0] + c.log("Serial is %s" % results[1]) - return True + return True def do_power(count, bus, pty): - """Check power usage via ec console 'pty'. - - Args: - count: number of samples to capture - bus: rail to monitor, 'vbus', 'cc1', or 'cc2' - pty: a pty object connected to tigertail - - Commands are: - # > ina 0 - # Configuration: 4127 - # Shunt voltage: 02c4 => 1770 uV - # Bus voltage : 1008 => 5130 mV - # Power : 0019 => 625 mW - # Current : 0082 => 130 mA - # Calibration : 0155 - # Mask/Enable : 0008 - # Alert limit : 0000 - """ - if bus == 'vbus': - ina = 0 - if bus == 'cc1': - ina = 4 - if bus == 'cc2': - ina = 1 - - start = time.time() - - c.log('time,\tmV,\tmW,\tmA') - - cmd = 'ina %s' % ina - regex = r'Bus voltage : \S+ \S+ (\d+) mV\s+' \ - r'Power : \S+ \S+ (\d+) mW\s+' \ - r'Current : \S+ \S+ (\d+) mA' - - for i in range(0, count): - results = pty._issue_cmd_get_results(cmd, [regex])[0] - c.log('%.2f,\t%s,\t%s\t%s' % ( - time.time() - start, - results[1], results[2], results[3])) + """Check power usage via ec console 'pty'. + + Args: + count: number of samples to capture + bus: rail to monitor, 'vbus', 'cc1', or 'cc2' + pty: a pty object connected to tigertail + + Commands are: + # > ina 0 + # Configuration: 4127 + # Shunt voltage: 02c4 => 1770 uV + # Bus voltage : 1008 => 5130 mV + # Power : 0019 => 625 mW + # Current : 0082 => 130 mA + # Calibration : 0155 + # Mask/Enable : 0008 + # Alert limit : 0000 + """ + if bus == "vbus": + ina = 0 + if bus == "cc1": + ina = 4 + if bus == "cc2": + ina = 1 + + start = time.time() + + c.log("time,\tmV,\tmW,\tmA") + + cmd = "ina %s" % ina + regex = ( + r"Bus voltage : \S+ \S+ (\d+) mV\s+" + r"Power : \S+ \S+ (\d+) mW\s+" + r"Current : \S+ \S+ (\d+) mA" + ) + + for i in range(0, count): + results = pty._issue_cmd_get_results(cmd, [regex])[0] + c.log( + "%.2f,\t%s,\t%s\t%s" + % (time.time() - start, results[1], results[2], results[3]) + ) + + return True - return True def do_reboot(pty, serialname): - """Reboot via ec console pty - - Args: - pty: a pty object connected to tigertail - serialname: serial name, can be None. - - Command is: reboot. - """ - cmd = 'reboot' - - # Check usb dev number on current instance. - devno = c.check_usb_dev(STM_VIDPID, serialname=serialname) - if not devno: - c.log('Device not found') - return False - - try: - pty._issue_cmd(cmd) - except Exception as e: - c.log('Failed to send command: ' + str(e)) - return False - - try: - c.wait_for_usb_remove(STM_VIDPID, timeout=3., serialname=serialname) - except Exception as e: - # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds - # it's not going to. This step just goes faster if it's detected. - pass - - try: - c.wait_for_usb(STM_VIDPID, timeout=3., serialname=serialname) - except Exception as e: - c.log('Failed to return from reboot: ' + str(e)) - return False - - # Check that the device had a new device number, i.e. it's - # disconnected and reconnected. - newdevno = c.check_usb_dev(STM_VIDPID, serialname=serialname) - if newdevno == devno: - c.log("Device didn't reboot") - return False - - return True + """Reboot via ec console pty + + Args: + pty: a pty object connected to tigertail + serialname: serial name, can be None. + + Command is: reboot. + """ + cmd = "reboot" + + # Check usb dev number on current instance. + devno = c.check_usb_dev(STM_VIDPID, serialname=serialname) + if not devno: + c.log("Device not found") + return False + + try: + pty._issue_cmd(cmd) + except Exception as e: + c.log("Failed to send command: " + str(e)) + return False + + try: + c.wait_for_usb_remove(STM_VIDPID, timeout=3.0, serialname=serialname) + except Exception as e: + # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds + # it's not going to. This step just goes faster if it's detected. + pass + + try: + c.wait_for_usb(STM_VIDPID, timeout=3.0, serialname=serialname) + except Exception as e: + c.log("Failed to return from reboot: " + str(e)) + return False + + # Check that the device had a new device number, i.e. it's + # disconnected and reconnected. + newdevno = c.check_usb_dev(STM_VIDPID, serialname=serialname) + if newdevno == devno: + c.log("Device didn't reboot") + return False + + return True + def do_sysjump(region, pty, serialname): - """Set region via ec console 'pty'. - - Args: - region: ec code region to execute, 'ro' or 'rw' - pty: a pty object connected to tigertail - serialname: serial name, can be None. - - Commands are: - # > sysjump rw - """ - validregion = ['ro', 'rw'] - if region not in validregion: - c.log('Region setting %s invalid, try one of %s' % ( - region, validregion)) - return False - - cmd = 'sysjump %s' % region - try: - pty._issue_cmd(cmd) - except Exception as e: - c.log('Exception: ' + str(e)) - return False - - try: - c.wait_for_usb_remove(STM_VIDPID, timeout=3., serialname=serialname) - except Exception as e: - # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds - # it's not going to. This step just goes faster if it's detected. - pass - - try: - c.wait_for_usb(STM_VIDPID, timeout=3., serialname=serialname) - except Exception as e: - c.log('Failed to return from restart: ' + str(e)) - return False - - c.log('Region requested %s' % region) - return True + """Set region via ec console 'pty'. + + Args: + region: ec code region to execute, 'ro' or 'rw' + pty: a pty object connected to tigertail + serialname: serial name, can be None. + + Commands are: + # > sysjump rw + """ + validregion = ["ro", "rw"] + if region not in validregion: + c.log("Region setting %s invalid, try one of %s" % (region, validregion)) + return False + + cmd = "sysjump %s" % region + try: + pty._issue_cmd(cmd) + except Exception as e: + c.log("Exception: " + str(e)) + return False + + try: + c.wait_for_usb_remove(STM_VIDPID, timeout=3.0, serialname=serialname) + except Exception as e: + # Polling for reboot isn't reliable but if it hasn't happened in 3 seconds + # it's not going to. This step just goes faster if it's detected. + pass + + try: + c.wait_for_usb(STM_VIDPID, timeout=3.0, serialname=serialname) + except Exception as e: + c.log("Failed to return from restart: " + str(e)) + return False + + c.log("Region requested %s" % region) + return True + def get_parser(): - parser = argparse.ArgumentParser( - description=__doc__) - parser.add_argument('-s', '--serialno', type=str, default=None, - help='serial number of board to use') - parser.add_argument('-b', '--bus', type=str, default='vbus', - help='Which rail to log: [vbus|cc1|cc2]') - group = parser.add_mutually_exclusive_group() - group.add_argument('--setserialno', type=str, default=None, - help='serial number to set on the board.') - group.add_argument('--check_serial', action='store_true', - help='check serial number set on the board.') - group.add_argument('-m', '--mux', type=str, default=None, - help='mux selection') - group.add_argument('-p', '--power', action='store_true', - help='check VBUS') - group.add_argument('-l', '--powerlog', type=int, default=None, - help='log VBUS') - group.add_argument('-r', '--sysjump', type=str, default=None, - help='region selection') - group.add_argument('--reboot', action='store_true', - help='reboot tigertail') - group.add_argument('--check_version', action='store_true', - help='check tigertail version') - return parser + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "-s", "--serialno", type=str, default=None, help="serial number of board to use" + ) + parser.add_argument( + "-b", + "--bus", + type=str, + default="vbus", + help="Which rail to log: [vbus|cc1|cc2]", + ) + group = parser.add_mutually_exclusive_group() + group.add_argument( + "--setserialno", + type=str, + default=None, + help="serial number to set on the board.", + ) + group.add_argument( + "--check_serial", + action="store_true", + help="check serial number set on the board.", + ) + group.add_argument("-m", "--mux", type=str, default=None, help="mux selection") + group.add_argument("-p", "--power", action="store_true", help="check VBUS") + group.add_argument("-l", "--powerlog", type=int, default=None, help="log VBUS") + group.add_argument( + "-r", "--sysjump", type=str, default=None, help="region selection" + ) + group.add_argument("--reboot", action="store_true", help="reboot tigertail") + group.add_argument( + "--check_version", action="store_true", help="check tigertail version" + ) + return parser + def main(argv): - parser = get_parser() - opts = parser.parse_args(argv) + parser = get_parser() + opts = parser.parse_args(argv) - result = True + result = True - # Let's make sure there's a tigertail - # If nothing found in 5 seconds, fail. - c.wait_for_usb(STM_VIDPID, timeout=5., serialname=opts.serialno) + # Let's make sure there's a tigertail + # If nothing found in 5 seconds, fail. + c.wait_for_usb(STM_VIDPID, timeout=5.0, serialname=opts.serialno) - pty = c.setup_tinyservod(STM_VIDPID, 0, serialname=opts.serialno) + pty = c.setup_tinyservod(STM_VIDPID, 0, serialname=opts.serialno) - if opts.bus not in ('vbus', 'cc1', 'cc2'): - c.log('Try --bus [vbus|cc1|cc2]') - result = False + if opts.bus not in ("vbus", "cc1", "cc2"): + c.log("Try --bus [vbus|cc1|cc2]") + result = False - elif opts.setserialno: - try: - c.do_serialno(opts.setserialno, pty) - except Exception: - result = False + elif opts.setserialno: + try: + c.do_serialno(opts.setserialno, pty) + except Exception: + result = False - elif opts.mux: - result &= do_mux(opts.mux, pty) + elif opts.mux: + result &= do_mux(opts.mux, pty) - elif opts.sysjump: - result &= do_sysjump(opts.sysjump, pty, serialname=opts.serialno) + elif opts.sysjump: + result &= do_sysjump(opts.sysjump, pty, serialname=opts.serialno) - elif opts.reboot: - result &= do_reboot(pty, serialname=opts.serialno) + elif opts.reboot: + result &= do_reboot(pty, serialname=opts.serialno) - elif opts.check_version: - result &= do_version(pty) + elif opts.check_version: + result &= do_version(pty) - elif opts.check_serial: - result &= do_check_serial(pty) + elif opts.check_serial: + result &= do_check_serial(pty) - elif opts.power: - result &= do_power(1, opts.bus, pty) + elif opts.power: + result &= do_power(1, opts.bus, pty) - elif opts.powerlog: - result &= do_power(opts.powerlog, opts.bus, pty) + elif opts.powerlog: + result &= do_power(opts.powerlog, opts.bus, pty) - if result: - c.log('PASS') - else: - c.log('FAIL') - sys.exit(-1) + if result: + c.log("PASS") + else: + c.log("FAIL") + sys.exit(-1) -if __name__ == '__main__': - sys.exit(main(sys.argv[1:])) +if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) diff --git a/extra/usb_power/convert_power_log_board.py b/extra/usb_power/convert_power_log_board.py index 8aab77ee4c..c1b25f57db 100644 --- a/extra/usb_power/convert_power_log_board.py +++ b/extra/usb_power/convert_power_log_board.py @@ -14,6 +14,7 @@ Program to convert sweetberry config to servod config template. # Note: This is a py2/3 compatible file. from __future__ import print_function + import json import os import sys @@ -48,21 +49,25 @@ def write_to_file(file, sweetberry, inas): inas: list of inas read from board file. """ - with open(file, 'w') as pyfile: + with open(file, "w") as pyfile: - pyfile.write('inas = [\n') + pyfile.write("inas = [\n") for rec in inas: - if rec['sweetberry'] != sweetberry: + if rec["sweetberry"] != sweetberry: continue # EX : ('sweetberry', 0x40, 'SB_FW_CAM_2P8', 5.0, 1.000, 3, False), - channel, i2c_addr = Spower.CHMAP[rec['channel']] - record = (" ('sweetberry', 0x%02x, '%s', 5.0, %f, %d, 'True')" - ",\n" % (i2c_addr, rec['name'], rec['rs'], channel)) + channel, i2c_addr = Spower.CHMAP[rec["channel"]] + record = " ('sweetberry', 0x%02x, '%s', 5.0, %f, %d, 'True')" ",\n" % ( + i2c_addr, + rec["name"], + rec["rs"], + channel, + ) pyfile.write(record) - pyfile.write(']\n') + pyfile.write("]\n") def main(argv): @@ -76,16 +81,18 @@ def main(argv): inas = fetch_records(inputf) - sweetberry = set(rec['sweetberry'] for rec in inas) + sweetberry = set(rec["sweetberry"] for rec in inas) if len(sweetberry) == 2: - print("Converting %s to %s and %s" % (inputf, basename + '_a.py', - basename + '_b.py')) - write_to_file(basename + '_a.py', 'A', inas) - write_to_file(basename + '_b.py', 'B', inas) + print( + "Converting %s to %s and %s" + % (inputf, basename + "_a.py", basename + "_b.py") + ) + write_to_file(basename + "_a.py", "A", inas) + write_to_file(basename + "_b.py", "B", inas) else: - print("Converting %s to %s" % (inputf, basename + '.py')) - write_to_file(basename + '.py', sweetberry.pop(), inas) + print("Converting %s to %s" % (inputf, basename + ".py")) + write_to_file(basename + ".py", sweetberry.pop(), inas) if __name__ == "__main__": diff --git a/extra/usb_power/convert_servo_ina.py b/extra/usb_power/convert_servo_ina.py index 1c70f31aeb..6ccd474e6c 100755 --- a/extra/usb_power/convert_servo_ina.py +++ b/extra/usb_power/convert_servo_ina.py @@ -14,67 +14,71 @@ # Note: This is a py2/3 compatible file. from __future__ import print_function + import os import sys def fetch_records(basename): - """Import records from servo_ina file. + """Import records from servo_ina file. - servo_ina files are python imports, and have a list of tuples with - the INA data. - (inatype, i2caddr, rail name, bus voltage, shunt ohms, mux, True) + servo_ina files are python imports, and have a list of tuples with + the INA data. + (inatype, i2caddr, rail name, bus voltage, shunt ohms, mux, True) - Args: - basename: python import name (filename -.py) + Args: + basename: python import name (filename -.py) - Returns: - list of tuples as described above. - """ - ina_desc = __import__(basename) - return ina_desc.inas + Returns: + list of tuples as described above. + """ + ina_desc = __import__(basename) + return ina_desc.inas def main(argv): - if len(argv) != 2: - print("usage:") - print(" %s input.py" % argv[0]) - return + if len(argv) != 2: + print("usage:") + print(" %s input.py" % argv[0]) + return - inputf = argv[1] - basename = os.path.splitext(inputf)[0] - outputf = basename + '.board' - outputs = basename + '.scenario' + inputf = argv[1] + basename = os.path.splitext(inputf)[0] + outputf = basename + ".board" + outputs = basename + ".scenario" - print("Converting %s to %s, %s" % (inputf, outputf, outputs)) + print("Converting %s to %s, %s" % (inputf, outputf, outputs)) - inas = fetch_records(basename) + inas = fetch_records(basename) + boardfile = open(outputf, "w") + scenario = open(outputs, "w") - boardfile = open(outputf, 'w') - scenario = open(outputs, 'w') + boardfile.write("[\n") + scenario.write("[\n") + start = True - boardfile.write('[\n') - scenario.write('[\n') - start = True + for rec in inas: + if start: + start = False + else: + boardfile.write(",\n") + scenario.write(",\n") - for rec in inas: - if start: - start = False - else: - boardfile.write(',\n') - scenario.write(',\n') + record = ' {"name": "%s", "rs": %f, "sweetberry": "A", "channel": %d}' % ( + rec[2], + rec[4], + rec[1] - 64, + ) + boardfile.write(record) + scenario.write('"%s"' % rec[2]) - record = ' {"name": "%s", "rs": %f, "sweetberry": "A", "channel": %d}' % ( - rec[2], rec[4], rec[1] - 64) - boardfile.write(record) - scenario.write('"%s"' % rec[2]) + boardfile.write("\n") + boardfile.write("]") - boardfile.write('\n') - boardfile.write(']') + scenario.write("\n") + scenario.write("]") - scenario.write('\n') - scenario.write(']') if __name__ == "__main__": - main(sys.argv) + main(sys.argv) diff --git a/extra/usb_power/powerlog.py b/extra/usb_power/powerlog.py index 82cce3daed..d893e5c6b9 100755 --- a/extra/usb_power/powerlog.py +++ b/extra/usb_power/powerlog.py @@ -14,9 +14,9 @@ # Note: This is a py2/3 compatible file. from __future__ import print_function + import argparse import array -from distutils import sysconfig import json import logging import os @@ -25,884 +25,988 @@ import struct import sys import time import traceback +from distutils import sysconfig import usb - from stats_manager import StatsManager # Directory where hdctools installs configuration files into. -LIB_DIR = os.path.join(sysconfig.get_python_lib(standard_lib=False), 'servo', - 'data') +LIB_DIR = os.path.join(sysconfig.get_python_lib(standard_lib=False), "servo", "data") # Potential config file locations: current working directory, the same directory # as powerlog.py file or LIB_DIR. -CONFIG_LOCATIONS = [os.getcwd(), os.path.dirname(os.path.realpath(__file__)), - LIB_DIR] - -def logoutput(msg): - print(msg) - sys.stdout.flush() - -def process_filename(filename): - """Find the file path from the filename. - - If filename is already the complete path, return that directly. If filename is - just the short name, look for the file in the current working directory, in - the directory of the current .py file, and then in the directory installed by - hdctools. If the file is found, return the complete path of the file. - - Args: - filename: complete file path or short file name. - - Returns: - a complete file path. - - Raises: - IOError if filename does not exist. - """ - # Check if filename is absolute path. - if os.path.isabs(filename) and os.path.isfile(filename): - return filename - # Check if filename is relative to a known config location. - for dirname in CONFIG_LOCATIONS: - file_at_dir = os.path.join(dirname, filename) - if os.path.isfile(file_at_dir): - return file_at_dir - raise IOError('No such file or directory: \'%s\'' % filename) - - -class Spower(object): - """Power class to access devices on the bus. - - Usage: - bus = Spower() - - Instance Variables: - _dev: pyUSB device object - _read_ep: pyUSB read endpoint for this interface - _write_ep: pyUSB write endpoint for this interface - """ - - # INA interface type. - INA_POWER = 1 - INA_BUSV = 2 - INA_CURRENT = 3 - INA_SHUNTV = 4 - # INA_SUFFIX is used to differentiate multiple ina types for the same power - # rail. No suffix for when ina type is 0 (non-existent) and when ina type is 1 - # (power, no suffix for backward compatibility). - INA_SUFFIX = ['', '', '_busv', '_cur', '_shuntv'] - - # usb power commands - CMD_RESET = 0x0000 - CMD_STOP = 0x0001 - CMD_ADDINA = 0x0002 - CMD_START = 0x0003 - CMD_NEXT = 0x0004 - CMD_SETTIME = 0x0005 - - # Map between header channel number (0-47) - # and INA I2C bus/addr on sweetberry. - CHMAP = { - 0: (3, 0x40), - 1: (1, 0x40), - 2: (2, 0x40), - 3: (0, 0x40), - 4: (3, 0x41), - 5: (1, 0x41), - 6: (2, 0x41), - 7: (0, 0x41), - 8: (3, 0x42), - 9: (1, 0x42), - 10: (2, 0x42), - 11: (0, 0x42), - 12: (3, 0x43), - 13: (1, 0x43), - 14: (2, 0x43), - 15: (0, 0x43), - 16: (3, 0x44), - 17: (1, 0x44), - 18: (2, 0x44), - 19: (0, 0x44), - 20: (3, 0x45), - 21: (1, 0x45), - 22: (2, 0x45), - 23: (0, 0x45), - 24: (3, 0x46), - 25: (1, 0x46), - 26: (2, 0x46), - 27: (0, 0x46), - 28: (3, 0x47), - 29: (1, 0x47), - 30: (2, 0x47), - 31: (0, 0x47), - 32: (3, 0x48), - 33: (1, 0x48), - 34: (2, 0x48), - 35: (0, 0x48), - 36: (3, 0x49), - 37: (1, 0x49), - 38: (2, 0x49), - 39: (0, 0x49), - 40: (3, 0x4a), - 41: (1, 0x4a), - 42: (2, 0x4a), - 43: (0, 0x4a), - 44: (3, 0x4b), - 45: (1, 0x4b), - 46: (2, 0x4b), - 47: (0, 0x4b), - } - - def __init__(self, board, vendor=0x18d1, - product=0x5020, interface=1, serialname=None): - self._logger = logging.getLogger(__name__) - self._board = board - - # Find the stm32. - dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) - dev_list = list(dev_g) - if dev_list is None: - raise Exception("Power", "USB device not found") - - # Check if we have multiple stm32s and we've specified the serial. - dev = None - if serialname: - for d in dev_list: - dev_serial = "PyUSB dioesn't have a stable interface" - try: - dev_serial = usb.util.get_string(d, 256, d.iSerialNumber) - except ValueError: - # Incompatible pyUsb version. - dev_serial = usb.util.get_string(d, d.iSerialNumber) - if dev_serial == serialname: - dev = d - break - if dev is None: - raise Exception("Power", "USB device(%s) not found" % serialname) - else: - try: - dev = dev_list[0] - except TypeError: - # Incompatible pyUsb version. - dev = dev_list.next() - - self._logger.debug("Found USB device: %04x:%04x", vendor, product) - self._dev = dev - - # Get an endpoint instance. - try: - dev.set_configuration() - except usb.USBError: - pass - cfg = dev.get_active_configuration() - - intf = usb.util.find_descriptor(cfg, custom_match=lambda i: \ - i.bInterfaceClass==255 and i.bInterfaceSubClass==0x54) - - self._intf = intf - self._logger.debug("InterfaceNumber: %s", intf.bInterfaceNumber) - - read_ep = usb.util.find_descriptor( - intf, - # match the first IN endpoint - custom_match = \ - lambda e: \ - usb.util.endpoint_direction(e.bEndpointAddress) == \ - usb.util.ENDPOINT_IN - ) - - self._read_ep = read_ep - self._logger.debug("Reader endpoint: 0x%x", read_ep.bEndpointAddress) - - write_ep = usb.util.find_descriptor( - intf, - # match the first OUT endpoint - custom_match = \ - lambda e: \ - usb.util.endpoint_direction(e.bEndpointAddress) == \ - usb.util.ENDPOINT_OUT - ) - - self._write_ep = write_ep - self._logger.debug("Writer endpoint: 0x%x", write_ep.bEndpointAddress) - - self.clear_ina_struct() - - self._logger.debug("Found power logging USB endpoint.") - - def clear_ina_struct(self): - """ Clear INA description struct.""" - self._inas = [] +CONFIG_LOCATIONS = [os.getcwd(), os.path.dirname(os.path.realpath(__file__)), LIB_DIR] - def append_ina_struct(self, name, rs, port, addr, - data=None, ina_type=INA_POWER): - """Add an INA descriptor into the list of active INAs. - Args: - name: Readable name of this channel. - rs: Sense resistor value in ohms, floating point. - port: I2C channel this INA is connected to. - addr: I2C addr of this INA. - data: Misc data for special handling, board specific. - ina_type: INA function to use, power, voltage, etc. - """ - ina = {} - ina['name'] = name - ina['rs'] = rs - ina['port'] = port - ina['addr'] = addr - ina['type'] = ina_type - # Calculate INA231 Calibration register - # (see INA231 spec p.15) - # CurrentLSB = uA per div = 80mV / (Rsh * 2^15) - # CurrentLSB uA = 80000000nV / (Rsh mOhm * 0x8000) - ina['uAscale'] = 80000000. / (rs * 0x8000); - ina['uWscale'] = 25. * ina['uAscale']; - ina['mVscale'] = 1.25 - ina['uVscale'] = 2.5 - ina['data'] = data - self._inas.append(ina) - - def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=1000): - """Write command to logger logic. - - This function writes byte command values list to stm, then reads - byte status. - - Args: - write_list: list of command byte values [0~255]. - read_count: number of status byte values to read. - - Interface: - write: [command, data ... ] - read: [status ] - - Returns: - bytes read, or None on failure. - """ - self._logger.debug("Spower.wr_command(write_list=[%s] (%d), read_count=%s)", - list(bytearray(write_list)), len(write_list), read_count) - - # Clean up args from python style to correct types. - write_length = 0 - if write_list: - write_length = len(write_list) - if not read_count: - read_count = 0 - - # Send command to stm32. - if write_list: - cmd = write_list - ret = self._write_ep.write(cmd, wtimeout) - - self._logger.debug("RET: %s ", ret) - - # Read back response if necessary. - if read_count: - bytesread = self._read_ep.read(512, rtimeout) - self._logger.debug("BYTES: [%s]", bytesread) - - if len(bytesread) != read_count: - pass - - self._logger.debug("STATUS: 0x%02x", int(bytesread[0])) - if read_count == 1: - return bytesread[0] - else: - return bytesread - - return None - - def clear(self): - """Clear pending reads on the stm32""" - try: - while True: - ret = self.wr_command(b"", read_count=512, rtimeout=100, wtimeout=50) - self._logger.debug("Try Clear: read %s", - "success" if ret == 0 else "failure") - except: - pass - - def send_reset(self): - """Reset the power interface on the stm32""" - cmd = struct.pack("<H", self.CMD_RESET) - ret = self.wr_command(cmd, rtimeout=50, wtimeout=50) - self._logger.debug("Command RESET: %s", - "success" if ret == 0 else "failure") - - def reset(self): - """Try resetting the USB interface until success. - - Use linear back off strategy when encounter the error with 10ms increment. - - Raises: - Exception on failure. - """ - max_reset_retry = 100 - for count in range(1, max_reset_retry + 1): - self.clear() - try: - self.send_reset() - return - except Exception as e: - self.clear() - self.clear() - self._logger.debug("TRY %d of %d: %s", count, max_reset_retry, e) - time.sleep(count * 0.01) - raise Exception("Power", "Failed to reset") - - def stop(self): - """Stop any active data acquisition.""" - cmd = struct.pack("<H", self.CMD_STOP) - ret = self.wr_command(cmd) - self._logger.debug("Command STOP: %s", - "success" if ret == 0 else "failure") - - def start(self, integration_us): - """Start data acquisition. - - Args: - integration_us: int, how many us between samples, and - how often the data block must be read. +def logoutput(msg): + print(msg) + sys.stdout.flush() - Returns: - actual sampling interval in ms. - """ - cmd = struct.pack("<HI", self.CMD_START, integration_us) - read = self.wr_command(cmd, read_count=5) - actual_us = 0 - if len(read) == 5: - ret, actual_us = struct.unpack("<BI", read) - self._logger.debug("Command START: %s %dus", - "success" if ret == 0 else "failure", actual_us) - else: - self._logger.debug("Command START: FAIL") - return actual_us +def process_filename(filename): + """Find the file path from the filename. - def add_ina_name(self, name_tuple): - """Add INA from board config. + If filename is already the complete path, return that directly. If filename is + just the short name, look for the file in the current working directory, in + the directory of the current .py file, and then in the directory installed by + hdctools. If the file is found, return the complete path of the file. Args: - name_tuple: name and type of power rail in board config. + filename: complete file path or short file name. Returns: - True if INA added, False if the INA is not on this board. + a complete file path. Raises: - Exception on unexpected failure. + IOError if filename does not exist. """ - name, ina_type = name_tuple - - for datum in self._brdcfg: - if datum["name"] == name: - rs = int(float(datum["rs"]) * 1000.) - board = datum["sweetberry"] - - if board == self._board: - if 'port' in datum and 'addr' in datum: - port = datum['port'] - addr = datum['addr'] - else: - channel = int(datum["channel"]) - port, addr = self.CHMAP[channel] - self.add_ina(port, ina_type, addr, 0, rs, data=datum) - return True - else: - return False - raise Exception("Power", "Failed to find INA %s" % name) - - def set_time(self, timestamp_us): - """Set sweetberry time to match host time. - - Args: - timestamp_us: host timestmap in us. - """ - # 0x0005 , 8 byte timestamp - cmd = struct.pack("<HQ", self.CMD_SETTIME, timestamp_us) - ret = self.wr_command(cmd) + # Check if filename is absolute path. + if os.path.isabs(filename) and os.path.isfile(filename): + return filename + # Check if filename is relative to a known config location. + for dirname in CONFIG_LOCATIONS: + file_at_dir = os.path.join(dirname, filename) + if os.path.isfile(file_at_dir): + return file_at_dir + raise IOError("No such file or directory: '%s'" % filename) - self._logger.debug("Command SETTIME: %s", - "success" if ret == 0 else "failure") - def add_ina(self, bus, ina_type, addr, extra, resistance, data=None): - """Add an INA to the data acquisition list. +class Spower(object): + """Power class to access devices on the bus. - Args: - bus: which i2c bus the INA is on. Same ordering as Si2c. - ina_type: Ina interface: INA_POWER/BUSV/etc. - addr: 7 bit i2c addr of this INA - extra: extra data for nonstandard configs. - resistance: int, shunt resistance in mOhm - """ - # 0x0002, 1B: bus, 1B:INA type, 1B: INA addr, 1B: extra, 4B: Rs - cmd = struct.pack("<HBBBBI", self.CMD_ADDINA, - bus, ina_type, addr, extra, resistance) - ret = self.wr_command(cmd) - if ret == 0: - if data: - name = data['name'] - else: - name = "ina%d_%02x" % (bus, addr) - self.append_ina_struct(name, resistance, bus, addr, - data=data, ina_type=ina_type) - self._logger.debug("Command ADD_INA: %s", - "success" if ret == 0 else "failure") - - def report_header_size(self): - """Helper function to calculate power record header size.""" - result = 2 - timestamp = 8 - return result + timestamp - - def report_size(self, ina_count): - """Helper function to calculate full power record size.""" - record = 2 - - datasize = self.report_header_size() + ina_count * record - # Round to multiple of 4 bytes. - datasize = int(((datasize + 3) // 4) * 4) - - return datasize - - def read_line(self): - """Read a line of data from the setup INAs + Usage: + bus = Spower() - Returns: - list of dicts of the values read by ina/type tuple, otherwise None. - [{ts:100, (vbat, power):450}, {ts:200, (vbat, power):440}] + Instance Variables: + _dev: pyUSB device object + _read_ep: pyUSB read endpoint for this interface + _write_ep: pyUSB write endpoint for this interface """ - try: - expected_bytes = self.report_size(len(self._inas)) - cmd = struct.pack("<H", self.CMD_NEXT) - bytesread = self.wr_command(cmd, read_count=expected_bytes) - except usb.core.USBError as e: - self._logger.error("READ LINE FAILED %s", e) - return None - - if len(bytesread) == 1: - if bytesread[0] != 0x6: - self._logger.debug("READ LINE FAILED bytes: %d ret: %02x", - len(bytesread), bytesread[0]) - return None - - if len(bytesread) % expected_bytes != 0: - self._logger.debug("READ LINE WARNING: expected %d, got %d", - expected_bytes, len(bytesread)) - - packet_count = len(bytesread) // expected_bytes - - values = [] - for i in range(0, packet_count): - start = i * expected_bytes - end = (i + 1) * expected_bytes - record = self.interpret_line(bytesread[start:end]) - values.append(record) - - return values - - def interpret_line(self, data): - """Interpret a power record from INAs - Args: - data: one single record of bytes. + # INA interface type. + INA_POWER = 1 + INA_BUSV = 2 + INA_CURRENT = 3 + INA_SHUNTV = 4 + # INA_SUFFIX is used to differentiate multiple ina types for the same power + # rail. No suffix for when ina type is 0 (non-existent) and when ina type is 1 + # (power, no suffix for backward compatibility). + INA_SUFFIX = ["", "", "_busv", "_cur", "_shuntv"] + + # usb power commands + CMD_RESET = 0x0000 + CMD_STOP = 0x0001 + CMD_ADDINA = 0x0002 + CMD_START = 0x0003 + CMD_NEXT = 0x0004 + CMD_SETTIME = 0x0005 + + # Map between header channel number (0-47) + # and INA I2C bus/addr on sweetberry. + CHMAP = { + 0: (3, 0x40), + 1: (1, 0x40), + 2: (2, 0x40), + 3: (0, 0x40), + 4: (3, 0x41), + 5: (1, 0x41), + 6: (2, 0x41), + 7: (0, 0x41), + 8: (3, 0x42), + 9: (1, 0x42), + 10: (2, 0x42), + 11: (0, 0x42), + 12: (3, 0x43), + 13: (1, 0x43), + 14: (2, 0x43), + 15: (0, 0x43), + 16: (3, 0x44), + 17: (1, 0x44), + 18: (2, 0x44), + 19: (0, 0x44), + 20: (3, 0x45), + 21: (1, 0x45), + 22: (2, 0x45), + 23: (0, 0x45), + 24: (3, 0x46), + 25: (1, 0x46), + 26: (2, 0x46), + 27: (0, 0x46), + 28: (3, 0x47), + 29: (1, 0x47), + 30: (2, 0x47), + 31: (0, 0x47), + 32: (3, 0x48), + 33: (1, 0x48), + 34: (2, 0x48), + 35: (0, 0x48), + 36: (3, 0x49), + 37: (1, 0x49), + 38: (2, 0x49), + 39: (0, 0x49), + 40: (3, 0x4A), + 41: (1, 0x4A), + 42: (2, 0x4A), + 43: (0, 0x4A), + 44: (3, 0x4B), + 45: (1, 0x4B), + 46: (2, 0x4B), + 47: (0, 0x4B), + } + + def __init__( + self, board, vendor=0x18D1, product=0x5020, interface=1, serialname=None + ): + self._logger = logging.getLogger(__name__) + self._board = board + + # Find the stm32. + dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) + dev_list = list(dev_g) + if dev_list is None: + raise Exception("Power", "USB device not found") + + # Check if we have multiple stm32s and we've specified the serial. + dev = None + if serialname: + for d in dev_list: + dev_serial = "PyUSB dioesn't have a stable interface" + try: + dev_serial = usb.util.get_string(d, 256, d.iSerialNumber) + except ValueError: + # Incompatible pyUsb version. + dev_serial = usb.util.get_string(d, d.iSerialNumber) + if dev_serial == serialname: + dev = d + break + if dev is None: + raise Exception("Power", "USB device(%s) not found" % serialname) + else: + try: + dev = dev_list[0] + except TypeError: + # Incompatible pyUsb version. + dev = dev_list.next() - Output: - stdout of the record in csv format. + self._logger.debug("Found USB device: %04x:%04x", vendor, product) + self._dev = dev - Returns: - dict containing name, value of recorded data. - """ - status, size = struct.unpack("<BB", data[0:2]) - if len(data) != self.report_size(size): - self._logger.error("READ LINE FAILED st:%d size:%d expected:%d len:%d", - status, size, self.report_size(size), len(data)) - else: - pass + # Get an endpoint instance. + try: + dev.set_configuration() + except usb.USBError: + pass + cfg = dev.get_active_configuration() + + intf = usb.util.find_descriptor( + cfg, + custom_match=lambda i: i.bInterfaceClass == 255 + and i.bInterfaceSubClass == 0x54, + ) + + self._intf = intf + self._logger.debug("InterfaceNumber: %s", intf.bInterfaceNumber) + + read_ep = usb.util.find_descriptor( + intf, + # match the first IN endpoint + custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) + == usb.util.ENDPOINT_IN, + ) + + self._read_ep = read_ep + self._logger.debug("Reader endpoint: 0x%x", read_ep.bEndpointAddress) + + write_ep = usb.util.find_descriptor( + intf, + # match the first OUT endpoint + custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) + == usb.util.ENDPOINT_OUT, + ) + + self._write_ep = write_ep + self._logger.debug("Writer endpoint: 0x%x", write_ep.bEndpointAddress) + + self.clear_ina_struct() + + self._logger.debug("Found power logging USB endpoint.") + + def clear_ina_struct(self): + """Clear INA description struct.""" + self._inas = [] + + def append_ina_struct(self, name, rs, port, addr, data=None, ina_type=INA_POWER): + """Add an INA descriptor into the list of active INAs. + + Args: + name: Readable name of this channel. + rs: Sense resistor value in ohms, floating point. + port: I2C channel this INA is connected to. + addr: I2C addr of this INA. + data: Misc data for special handling, board specific. + ina_type: INA function to use, power, voltage, etc. + """ + ina = {} + ina["name"] = name + ina["rs"] = rs + ina["port"] = port + ina["addr"] = addr + ina["type"] = ina_type + # Calculate INA231 Calibration register + # (see INA231 spec p.15) + # CurrentLSB = uA per div = 80mV / (Rsh * 2^15) + # CurrentLSB uA = 80000000nV / (Rsh mOhm * 0x8000) + ina["uAscale"] = 80000000.0 / (rs * 0x8000) + ina["uWscale"] = 25.0 * ina["uAscale"] + ina["mVscale"] = 1.25 + ina["uVscale"] = 2.5 + ina["data"] = data + self._inas.append(ina) + + def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=1000): + """Write command to logger logic. + + This function writes byte command values list to stm, then reads + byte status. + + Args: + write_list: list of command byte values [0~255]. + read_count: number of status byte values to read. + + Interface: + write: [command, data ... ] + read: [status ] + + Returns: + bytes read, or None on failure. + """ + self._logger.debug( + "Spower.wr_command(write_list=[%s] (%d), read_count=%s)", + list(bytearray(write_list)), + len(write_list), + read_count, + ) + + # Clean up args from python style to correct types. + write_length = 0 + if write_list: + write_length = len(write_list) + if not read_count: + read_count = 0 + + # Send command to stm32. + if write_list: + cmd = write_list + ret = self._write_ep.write(cmd, wtimeout) + + self._logger.debug("RET: %s ", ret) + + # Read back response if necessary. + if read_count: + bytesread = self._read_ep.read(512, rtimeout) + self._logger.debug("BYTES: [%s]", bytesread) + + if len(bytesread) != read_count: + pass + + self._logger.debug("STATUS: 0x%02x", int(bytesread[0])) + if read_count == 1: + return bytesread[0] + else: + return bytesread + + return None + + def clear(self): + """Clear pending reads on the stm32""" + try: + while True: + ret = self.wr_command(b"", read_count=512, rtimeout=100, wtimeout=50) + self._logger.debug( + "Try Clear: read %s", "success" if ret == 0 else "failure" + ) + except: + pass + + def send_reset(self): + """Reset the power interface on the stm32""" + cmd = struct.pack("<H", self.CMD_RESET) + ret = self.wr_command(cmd, rtimeout=50, wtimeout=50) + self._logger.debug("Command RESET: %s", "success" if ret == 0 else "failure") + + def reset(self): + """Try resetting the USB interface until success. + + Use linear back off strategy when encounter the error with 10ms increment. + + Raises: + Exception on failure. + """ + max_reset_retry = 100 + for count in range(1, max_reset_retry + 1): + self.clear() + try: + self.send_reset() + return + except Exception as e: + self.clear() + self.clear() + self._logger.debug("TRY %d of %d: %s", count, max_reset_retry, e) + time.sleep(count * 0.01) + raise Exception("Power", "Failed to reset") + + def stop(self): + """Stop any active data acquisition.""" + cmd = struct.pack("<H", self.CMD_STOP) + ret = self.wr_command(cmd) + self._logger.debug("Command STOP: %s", "success" if ret == 0 else "failure") + + def start(self, integration_us): + """Start data acquisition. + + Args: + integration_us: int, how many us between samples, and + how often the data block must be read. + + Returns: + actual sampling interval in ms. + """ + cmd = struct.pack("<HI", self.CMD_START, integration_us) + read = self.wr_command(cmd, read_count=5) + actual_us = 0 + if len(read) == 5: + ret, actual_us = struct.unpack("<BI", read) + self._logger.debug( + "Command START: %s %dus", + "success" if ret == 0 else "failure", + actual_us, + ) + else: + self._logger.debug("Command START: FAIL") + + return actual_us + + def add_ina_name(self, name_tuple): + """Add INA from board config. + + Args: + name_tuple: name and type of power rail in board config. + + Returns: + True if INA added, False if the INA is not on this board. + + Raises: + Exception on unexpected failure. + """ + name, ina_type = name_tuple + + for datum in self._brdcfg: + if datum["name"] == name: + rs = int(float(datum["rs"]) * 1000.0) + board = datum["sweetberry"] + + if board == self._board: + if "port" in datum and "addr" in datum: + port = datum["port"] + addr = datum["addr"] + else: + channel = int(datum["channel"]) + port, addr = self.CHMAP[channel] + self.add_ina(port, ina_type, addr, 0, rs, data=datum) + return True + else: + return False + raise Exception("Power", "Failed to find INA %s" % name) + + def set_time(self, timestamp_us): + """Set sweetberry time to match host time. + + Args: + timestamp_us: host timestmap in us. + """ + # 0x0005 , 8 byte timestamp + cmd = struct.pack("<HQ", self.CMD_SETTIME, timestamp_us) + ret = self.wr_command(cmd) + + self._logger.debug("Command SETTIME: %s", "success" if ret == 0 else "failure") + + def add_ina(self, bus, ina_type, addr, extra, resistance, data=None): + """Add an INA to the data acquisition list. + + Args: + bus: which i2c bus the INA is on. Same ordering as Si2c. + ina_type: Ina interface: INA_POWER/BUSV/etc. + addr: 7 bit i2c addr of this INA + extra: extra data for nonstandard configs. + resistance: int, shunt resistance in mOhm + """ + # 0x0002, 1B: bus, 1B:INA type, 1B: INA addr, 1B: extra, 4B: Rs + cmd = struct.pack( + "<HBBBBI", self.CMD_ADDINA, bus, ina_type, addr, extra, resistance + ) + ret = self.wr_command(cmd) + if ret == 0: + if data: + name = data["name"] + else: + name = "ina%d_%02x" % (bus, addr) + self.append_ina_struct( + name, resistance, bus, addr, data=data, ina_type=ina_type + ) + self._logger.debug("Command ADD_INA: %s", "success" if ret == 0 else "failure") + + def report_header_size(self): + """Helper function to calculate power record header size.""" + result = 2 + timestamp = 8 + return result + timestamp + + def report_size(self, ina_count): + """Helper function to calculate full power record size.""" + record = 2 + + datasize = self.report_header_size() + ina_count * record + # Round to multiple of 4 bytes. + datasize = int(((datasize + 3) // 4) * 4) + + return datasize + + def read_line(self): + """Read a line of data from the setup INAs + + Returns: + list of dicts of the values read by ina/type tuple, otherwise None. + [{ts:100, (vbat, power):450}, {ts:200, (vbat, power):440}] + """ + try: + expected_bytes = self.report_size(len(self._inas)) + cmd = struct.pack("<H", self.CMD_NEXT) + bytesread = self.wr_command(cmd, read_count=expected_bytes) + except usb.core.USBError as e: + self._logger.error("READ LINE FAILED %s", e) + return None + + if len(bytesread) == 1: + if bytesread[0] != 0x6: + self._logger.debug( + "READ LINE FAILED bytes: %d ret: %02x", len(bytesread), bytesread[0] + ) + return None + + if len(bytesread) % expected_bytes != 0: + self._logger.debug( + "READ LINE WARNING: expected %d, got %d", expected_bytes, len(bytesread) + ) + + packet_count = len(bytesread) // expected_bytes + + values = [] + for i in range(0, packet_count): + start = i * expected_bytes + end = (i + 1) * expected_bytes + record = self.interpret_line(bytesread[start:end]) + values.append(record) + + return values + + def interpret_line(self, data): + """Interpret a power record from INAs + + Args: + data: one single record of bytes. + + Output: + stdout of the record in csv format. + + Returns: + dict containing name, value of recorded data. + """ + status, size = struct.unpack("<BB", data[0:2]) + if len(data) != self.report_size(size): + self._logger.error( + "READ LINE FAILED st:%d size:%d expected:%d len:%d", + status, + size, + self.report_size(size), + len(data), + ) + else: + pass - timestamp = struct.unpack("<Q", data[2:10])[0] - self._logger.debug("READ LINE: st:%d size:%d time:%dus", status, size, - timestamp) - ftimestamp = float(timestamp) / 1000000. + timestamp = struct.unpack("<Q", data[2:10])[0] + self._logger.debug( + "READ LINE: st:%d size:%d time:%dus", status, size, timestamp + ) + ftimestamp = float(timestamp) / 1000000.0 - record = {"ts": ftimestamp, "status": status, "berry":self._board} + record = {"ts": ftimestamp, "status": status, "berry": self._board} - for i in range(0, size): - idx = self.report_header_size() + 2*i - name = self._inas[i]['name'] - name_tuple = (self._inas[i]['name'], self._inas[i]['type']) + for i in range(0, size): + idx = self.report_header_size() + 2 * i + name = self._inas[i]["name"] + name_tuple = (self._inas[i]["name"], self._inas[i]["type"]) - raw_val = struct.unpack("<h", data[idx:idx+2])[0] + raw_val = struct.unpack("<h", data[idx : idx + 2])[0] - if self._inas[i]['type'] == Spower.INA_POWER: - val = raw_val * self._inas[i]['uWscale'] - elif self._inas[i]['type'] == Spower.INA_BUSV: - val = raw_val * self._inas[i]['mVscale'] - elif self._inas[i]['type'] == Spower.INA_CURRENT: - val = raw_val * self._inas[i]['uAscale'] - elif self._inas[i]['type'] == Spower.INA_SHUNTV: - val = raw_val * self._inas[i]['uVscale'] + if self._inas[i]["type"] == Spower.INA_POWER: + val = raw_val * self._inas[i]["uWscale"] + elif self._inas[i]["type"] == Spower.INA_BUSV: + val = raw_val * self._inas[i]["mVscale"] + elif self._inas[i]["type"] == Spower.INA_CURRENT: + val = raw_val * self._inas[i]["uAscale"] + elif self._inas[i]["type"] == Spower.INA_SHUNTV: + val = raw_val * self._inas[i]["uVscale"] - self._logger.debug("READ %d %s: %fs: 0x%04x %f", i, name, ftimestamp, - raw_val, val) - record[name_tuple] = val + self._logger.debug( + "READ %d %s: %fs: 0x%04x %f", i, name, ftimestamp, raw_val, val + ) + record[name_tuple] = val - return record + return record - def load_board(self, brdfile): - """Load a board config. + def load_board(self, brdfile): + """Load a board config. - Args: - brdfile: Filename of a json file decribing the INA wiring of this board. - """ - with open(process_filename(brdfile)) as data_file: - data = json.load(data_file) + Args: + brdfile: Filename of a json file decribing the INA wiring of this board. + """ + with open(process_filename(brdfile)) as data_file: + data = json.load(data_file) - #TODO: validate this. - self._brdcfg = data; - self._logger.debug(pprint.pformat(data)) + # TODO: validate this. + self._brdcfg = data + self._logger.debug(pprint.pformat(data)) class powerlog(object): - """Power class to log aggregated power. - - Usage: - obj = powerlog() + """Power class to log aggregated power. - Instance Variables: - _data: a StatsManager object that records sweetberry readings and calculates - statistics. - _pwr[]: Spower objects for individual sweetberries. - """ - - def __init__(self, brdfile, cfgfile, serial_a=None, serial_b=None, - sync_date=False, use_ms=False, use_mW=False, print_stats=False, - stats_dir=None, stats_json_dir=None, print_raw_data=True, - raw_data_dir=None): - """Init the powerlog class and set the variables. - - Args: - brdfile: string name of json file containing board layout. - cfgfile: string name of json containing list of rails to read. - serial_a: serial number of sweetberry A. - serial_b: serial number of sweetberry B. - sync_date: report timestamps synced with host datetime. - use_ms: report timestamps in ms rather than us. - use_mW: report power as milliwatts, otherwise default to microwatts. - print_stats: print statistics for sweetberry readings at the end. - stats_dir: directory to save sweetberry readings statistics; if None then - do not save the statistics. - stats_json_dir: directory to save means of sweetberry readings in json - format; if None then do not save the statistics. - print_raw_data: print sweetberry readings raw data in real time, default - is to print. - raw_data_dir: directory to save sweetberry readings raw data; if None then - do not save the raw data. - """ - self._logger = logging.getLogger(__name__) - self._data = StatsManager() - self._pwr = {} - self._use_ms = use_ms - self._use_mW = use_mW - self._print_stats = print_stats - self._stats_dir = stats_dir - self._stats_json_dir = stats_json_dir - self._print_raw_data = print_raw_data - self._raw_data_dir = raw_data_dir - - if not serial_a and not serial_b: - self._pwr['A'] = Spower('A') - if serial_a: - self._pwr['A'] = Spower('A', serialname=serial_a) - if serial_b: - self._pwr['B'] = Spower('B', serialname=serial_b) - - with open(process_filename(cfgfile)) as data_file: - names = json.load(data_file) - self._names = self.process_scenario(names) - - for key in self._pwr: - self._pwr[key].load_board(brdfile) - self._pwr[key].reset() - - # Allocate the rails to the appropriate boards. - used_boards = [] - for name in self._names: - success = False - for key in self._pwr.keys(): - if self._pwr[key].add_ina_name(name): - success = True - if key not in used_boards: - used_boards.append(key) - if not success: - raise Exception("Failed to add %s (maybe missing " - "sweetberry, or bad board file?)" % name) - - # Evict unused boards. - for key in list(self._pwr.keys()): - if key not in used_boards: - self._pwr.pop(key) - - for key in self._pwr.keys(): - if sync_date: - self._pwr[key].set_time(time.time() * 1000000) - else: - self._pwr[key].set_time(0) - - def process_scenario(self, name_list): - """Return list of tuples indicating name and type. + Usage: + obj = powerlog() - Args: - json originated list of names, or [name, type] - Returns: - list of tuples of (name, type) defaulting to type "POWER" - Raises: exception, invalid INA type. + Instance Variables: + _data: a StatsManager object that records sweetberry readings and calculates + statistics. + _pwr[]: Spower objects for individual sweetberries. """ - names = [] - for entry in name_list: - if isinstance(entry, list): - name = entry[0] - if entry[1] == "POWER": - type = Spower.INA_POWER - elif entry[1] == "BUSV": - type = Spower.INA_BUSV - elif entry[1] == "CURRENT": - type = Spower.INA_CURRENT - elif entry[1] == "SHUNTV": - type = Spower.INA_SHUNTV - else: - raise Exception("Invalid INA type", "Type of %s [%s] not recognized," - " try one of POWER, BUSV, CURRENT" % (entry[0], entry[1])) - else: - name = entry - type = Spower.INA_POWER - - names.append((name, type)) - return names - def start(self, integration_us_request, seconds, sync_speed=.8): - """Starts sampling. + def __init__( + self, + brdfile, + cfgfile, + serial_a=None, + serial_b=None, + sync_date=False, + use_ms=False, + use_mW=False, + print_stats=False, + stats_dir=None, + stats_json_dir=None, + print_raw_data=True, + raw_data_dir=None, + ): + """Init the powerlog class and set the variables. + + Args: + brdfile: string name of json file containing board layout. + cfgfile: string name of json containing list of rails to read. + serial_a: serial number of sweetberry A. + serial_b: serial number of sweetberry B. + sync_date: report timestamps synced with host datetime. + use_ms: report timestamps in ms rather than us. + use_mW: report power as milliwatts, otherwise default to microwatts. + print_stats: print statistics for sweetberry readings at the end. + stats_dir: directory to save sweetberry readings statistics; if None then + do not save the statistics. + stats_json_dir: directory to save means of sweetberry readings in json + format; if None then do not save the statistics. + print_raw_data: print sweetberry readings raw data in real time, default + is to print. + raw_data_dir: directory to save sweetberry readings raw data; if None then + do not save the raw data. + """ + self._logger = logging.getLogger(__name__) + self._data = StatsManager() + self._pwr = {} + self._use_ms = use_ms + self._use_mW = use_mW + self._print_stats = print_stats + self._stats_dir = stats_dir + self._stats_json_dir = stats_json_dir + self._print_raw_data = print_raw_data + self._raw_data_dir = raw_data_dir + + if not serial_a and not serial_b: + self._pwr["A"] = Spower("A") + if serial_a: + self._pwr["A"] = Spower("A", serialname=serial_a) + if serial_b: + self._pwr["B"] = Spower("B", serialname=serial_b) + + with open(process_filename(cfgfile)) as data_file: + names = json.load(data_file) + self._names = self.process_scenario(names) - Args: - integration_us_request: requested interval between sample values. - seconds: time until exit, or None to run until cancel. - sync_speed: A usb request is sent every [.8] * integration_us. - """ - # We will get back the actual integration us. - # It should be the same for all devices. - integration_us = None - for key in self._pwr: - integration_us_new = self._pwr[key].start(integration_us_request) - if integration_us: - if integration_us != integration_us_new: - raise Exception("FAIL", - "Integration on A: %dus != integration on B %dus" % ( - integration_us, integration_us_new)) - integration_us = integration_us_new - - # CSV header - title = "ts:%dus" % integration_us - for name_tuple in self._names: - name, ina_type = name_tuple - - if ina_type == Spower.INA_POWER: - unit = "mW" if self._use_mW else "uW" - elif ina_type == Spower.INA_BUSV: - unit = "mV" - elif ina_type == Spower.INA_CURRENT: - unit = "uA" - elif ina_type == Spower.INA_SHUNTV: - unit = "uV" - - title += ", %s %s" % (name, unit) - name_type = name + Spower.INA_SUFFIX[ina_type] - self._data.SetUnit(name_type, unit) - title += ", status" - if self._print_raw_data: - logoutput(title) - - forever = False - if not seconds: - forever = True - end_time = time.time() + seconds - try: - pending_records = [] - while forever or end_time > time.time(): - if (integration_us > 5000): - time.sleep((integration_us / 1000000.) * sync_speed) for key in self._pwr: - records = self._pwr[key].read_line() - if not records: - continue - - for record in records: - pending_records.append(record) - - pending_records.sort(key=lambda r: r['ts']) - - aggregate_record = {"boards": set()} - for record in pending_records: - if record["berry"] not in aggregate_record["boards"]: - for rkey in record.keys(): - aggregate_record[rkey] = record[rkey] - aggregate_record["boards"].add(record["berry"]) - else: - self._logger.info("break %s, %s", record["berry"], - aggregate_record["boards"]) - break - - if aggregate_record["boards"] == set(self._pwr.keys()): - csv = "%f" % aggregate_record["ts"] - for name in self._names: - if name in aggregate_record: - multiplier = 0.001 if (self._use_mW and - name[1]==Spower.INA_POWER) else 1 - value = aggregate_record[name] * multiplier - csv += ", %.2f" % value - name_type = name[0] + Spower.INA_SUFFIX[name[1]] - self._data.AddSample(name_type, value) - else: - csv += ", " - csv += ", %d" % aggregate_record["status"] - if self._print_raw_data: - logoutput(csv) - - aggregate_record = {"boards": set()} - for r in range(0, len(self._pwr)): - pending_records.pop(0) - - except KeyboardInterrupt: - self._logger.info('\nCTRL+C caught.') - - finally: - for key in self._pwr: - self._pwr[key].stop() - self._data.CalculateStats() - if self._print_stats: - print(self._data.SummaryToString()) - save_dir = 'sweetberry%s' % time.time() - if self._stats_dir: - stats_dir = os.path.join(self._stats_dir, save_dir) - self._data.SaveSummary(stats_dir) - if self._stats_json_dir: - stats_json_dir = os.path.join(self._stats_json_dir, save_dir) - self._data.SaveSummaryJSON(stats_json_dir) - if self._raw_data_dir: - raw_data_dir = os.path.join(self._raw_data_dir, save_dir) - self._data.SaveRawData(raw_data_dir) + self._pwr[key].load_board(brdfile) + self._pwr[key].reset() + + # Allocate the rails to the appropriate boards. + used_boards = [] + for name in self._names: + success = False + for key in self._pwr.keys(): + if self._pwr[key].add_ina_name(name): + success = True + if key not in used_boards: + used_boards.append(key) + if not success: + raise Exception( + "Failed to add %s (maybe missing " + "sweetberry, or bad board file?)" % name + ) + + # Evict unused boards. + for key in list(self._pwr.keys()): + if key not in used_boards: + self._pwr.pop(key) + + for key in self._pwr.keys(): + if sync_date: + self._pwr[key].set_time(time.time() * 1000000) + else: + self._pwr[key].set_time(0) + + def process_scenario(self, name_list): + """Return list of tuples indicating name and type. + + Args: + json originated list of names, or [name, type] + Returns: + list of tuples of (name, type) defaulting to type "POWER" + Raises: exception, invalid INA type. + """ + names = [] + for entry in name_list: + if isinstance(entry, list): + name = entry[0] + if entry[1] == "POWER": + type = Spower.INA_POWER + elif entry[1] == "BUSV": + type = Spower.INA_BUSV + elif entry[1] == "CURRENT": + type = Spower.INA_CURRENT + elif entry[1] == "SHUNTV": + type = Spower.INA_SHUNTV + else: + raise Exception( + "Invalid INA type", + "Type of %s [%s] not recognized," + " try one of POWER, BUSV, CURRENT" % (entry[0], entry[1]), + ) + else: + name = entry + type = Spower.INA_POWER + + names.append((name, type)) + return names + + def start(self, integration_us_request, seconds, sync_speed=0.8): + """Starts sampling. + + Args: + integration_us_request: requested interval between sample values. + seconds: time until exit, or None to run until cancel. + sync_speed: A usb request is sent every [.8] * integration_us. + """ + # We will get back the actual integration us. + # It should be the same for all devices. + integration_us = None + for key in self._pwr: + integration_us_new = self._pwr[key].start(integration_us_request) + if integration_us: + if integration_us != integration_us_new: + raise Exception( + "FAIL", + "Integration on A: %dus != integration on B %dus" + % (integration_us, integration_us_new), + ) + integration_us = integration_us_new + + # CSV header + title = "ts:%dus" % integration_us + for name_tuple in self._names: + name, ina_type = name_tuple + + if ina_type == Spower.INA_POWER: + unit = "mW" if self._use_mW else "uW" + elif ina_type == Spower.INA_BUSV: + unit = "mV" + elif ina_type == Spower.INA_CURRENT: + unit = "uA" + elif ina_type == Spower.INA_SHUNTV: + unit = "uV" + + title += ", %s %s" % (name, unit) + name_type = name + Spower.INA_SUFFIX[ina_type] + self._data.SetUnit(name_type, unit) + title += ", status" + if self._print_raw_data: + logoutput(title) + + forever = False + if not seconds: + forever = True + end_time = time.time() + seconds + try: + pending_records = [] + while forever or end_time > time.time(): + if integration_us > 5000: + time.sleep((integration_us / 1000000.0) * sync_speed) + for key in self._pwr: + records = self._pwr[key].read_line() + if not records: + continue + + for record in records: + pending_records.append(record) + + pending_records.sort(key=lambda r: r["ts"]) + + aggregate_record = {"boards": set()} + for record in pending_records: + if record["berry"] not in aggregate_record["boards"]: + for rkey in record.keys(): + aggregate_record[rkey] = record[rkey] + aggregate_record["boards"].add(record["berry"]) + else: + self._logger.info( + "break %s, %s", record["berry"], aggregate_record["boards"] + ) + break + + if aggregate_record["boards"] == set(self._pwr.keys()): + csv = "%f" % aggregate_record["ts"] + for name in self._names: + if name in aggregate_record: + multiplier = ( + 0.001 + if (self._use_mW and name[1] == Spower.INA_POWER) + else 1 + ) + value = aggregate_record[name] * multiplier + csv += ", %.2f" % value + name_type = name[0] + Spower.INA_SUFFIX[name[1]] + self._data.AddSample(name_type, value) + else: + csv += ", " + csv += ", %d" % aggregate_record["status"] + if self._print_raw_data: + logoutput(csv) + + aggregate_record = {"boards": set()} + for r in range(0, len(self._pwr)): + pending_records.pop(0) + + except KeyboardInterrupt: + self._logger.info("\nCTRL+C caught.") + + finally: + for key in self._pwr: + self._pwr[key].stop() + self._data.CalculateStats() + if self._print_stats: + print(self._data.SummaryToString()) + save_dir = "sweetberry%s" % time.time() + if self._stats_dir: + stats_dir = os.path.join(self._stats_dir, save_dir) + self._data.SaveSummary(stats_dir) + if self._stats_json_dir: + stats_json_dir = os.path.join(self._stats_json_dir, save_dir) + self._data.SaveSummaryJSON(stats_json_dir) + if self._raw_data_dir: + raw_data_dir = os.path.join(self._raw_data_dir, save_dir) + self._data.SaveRawData(raw_data_dir) def main(argv=None): - if argv is None: - argv = sys.argv[1:] - # Command line argument description. - parser = argparse.ArgumentParser( - description="Gather CSV data from sweetberry") - parser.add_argument('-b', '--board', type=str, - help="Board configuration file, eg. my.board", default="") - parser.add_argument('-c', '--config', type=str, - help="Rail config to monitor, eg my.scenario", default="") - parser.add_argument('-A', '--serial', type=str, - help="Serial number of sweetberry A", default="") - parser.add_argument('-B', '--serial_b', type=str, - help="Serial number of sweetberry B", default="") - parser.add_argument('-t', '--integration_us', type=int, - help="Target integration time for samples", default=100000) - parser.add_argument('-s', '--seconds', type=float, - help="Seconds to run capture", default=0.) - parser.add_argument('--date', default=False, - help="Sync logged timestamp to host date", action="store_true") - parser.add_argument('--ms', default=False, - help="Print timestamp as milliseconds", action="store_true") - parser.add_argument('--mW', default=False, - help="Print power as milliwatts, otherwise default to microwatts", - action="store_true") - parser.add_argument('--slow', default=False, - help="Intentionally overflow", action="store_true") - parser.add_argument('--print_stats', default=False, action="store_true", - help="Print statistics for sweetberry readings at the end") - parser.add_argument('--save_stats', type=str, nargs='?', - dest='stats_dir', metavar='STATS_DIR', - const=os.path.dirname(os.path.abspath(__file__)), default=None, - help="Save statistics for sweetberry readings to %(metavar)s if " - "%(metavar)s is specified, %(metavar)s will be created if it does " - "not exist; if %(metavar)s is not specified but the flag is set, " - "stats will be saved to where %(prog)s is located; if this flag is " - "not set, then do not save stats") - parser.add_argument('--save_stats_json', type=str, nargs='?', - dest='stats_json_dir', metavar='STATS_JSON_DIR', - const=os.path.dirname(os.path.abspath(__file__)), default=None, - help="Save means for sweetberry readings in json to %(metavar)s if " - "%(metavar)s is specified, %(metavar)s will be created if it does " - "not exist; if %(metavar)s is not specified but the flag is set, " - "stats will be saved to where %(prog)s is located; if this flag is " - "not set, then do not save stats") - parser.add_argument('--no_print_raw_data', - dest='print_raw_data', default=True, action="store_false", - help="Not print raw sweetberry readings at real time, default is to " - "print") - parser.add_argument('--save_raw_data', type=str, nargs='?', - dest='raw_data_dir', metavar='RAW_DATA_DIR', - const=os.path.dirname(os.path.abspath(__file__)), default=None, - help="Save raw data for sweetberry readings to %(metavar)s if " - "%(metavar)s is specified, %(metavar)s will be created if it does " - "not exist; if %(metavar)s is not specified but the flag is set, " - "raw data will be saved to where %(prog)s is located; if this flag " - "is not set, then do not save raw data") - parser.add_argument('-v', '--verbose', default=False, - help="Very chatty printout", action="store_true") - - args = parser.parse_args(argv) - - root_logger = logging.getLogger(__name__) - if args.verbose: - root_logger.setLevel(logging.DEBUG) - else: - root_logger.setLevel(logging.INFO) - - # if powerlog is used through main, log to sys.stdout - if __name__ == "__main__": - stdout_handler = logging.StreamHandler(sys.stdout) - stdout_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s')) - root_logger.addHandler(stdout_handler) - - integration_us_request = args.integration_us - if not args.board: - raise Exception("Power", "No board file selected, see board.README") - if not args.config: - raise Exception("Power", "No config file selected, see board.README") - - brdfile = args.board - cfgfile = args.config - seconds = args.seconds - serial_a = args.serial - serial_b = args.serial_b - sync_date = args.date - use_ms = args.ms - use_mW = args.mW - print_stats = args.print_stats - stats_dir = args.stats_dir - stats_json_dir = args.stats_json_dir - print_raw_data = args.print_raw_data - raw_data_dir = args.raw_data_dir - - boards = [] - - sync_speed = .8 - if args.slow: - sync_speed = 1.2 - - # Set up logging interface. - powerlogger = powerlog(brdfile, cfgfile, serial_a=serial_a, serial_b=serial_b, - sync_date=sync_date, use_ms=use_ms, use_mW=use_mW, - print_stats=print_stats, stats_dir=stats_dir, - stats_json_dir=stats_json_dir, - print_raw_data=print_raw_data,raw_data_dir=raw_data_dir) - - # Start logging. - powerlogger.start(integration_us_request, seconds, sync_speed=sync_speed) + if argv is None: + argv = sys.argv[1:] + # Command line argument description. + parser = argparse.ArgumentParser(description="Gather CSV data from sweetberry") + parser.add_argument( + "-b", + "--board", + type=str, + help="Board configuration file, eg. my.board", + default="", + ) + parser.add_argument( + "-c", + "--config", + type=str, + help="Rail config to monitor, eg my.scenario", + default="", + ) + parser.add_argument( + "-A", "--serial", type=str, help="Serial number of sweetberry A", default="" + ) + parser.add_argument( + "-B", "--serial_b", type=str, help="Serial number of sweetberry B", default="" + ) + parser.add_argument( + "-t", + "--integration_us", + type=int, + help="Target integration time for samples", + default=100000, + ) + parser.add_argument( + "-s", "--seconds", type=float, help="Seconds to run capture", default=0.0 + ) + parser.add_argument( + "--date", + default=False, + help="Sync logged timestamp to host date", + action="store_true", + ) + parser.add_argument( + "--ms", + default=False, + help="Print timestamp as milliseconds", + action="store_true", + ) + parser.add_argument( + "--mW", + default=False, + help="Print power as milliwatts, otherwise default to microwatts", + action="store_true", + ) + parser.add_argument( + "--slow", default=False, help="Intentionally overflow", action="store_true" + ) + parser.add_argument( + "--print_stats", + default=False, + action="store_true", + help="Print statistics for sweetberry readings at the end", + ) + parser.add_argument( + "--save_stats", + type=str, + nargs="?", + dest="stats_dir", + metavar="STATS_DIR", + const=os.path.dirname(os.path.abspath(__file__)), + default=None, + help="Save statistics for sweetberry readings to %(metavar)s if " + "%(metavar)s is specified, %(metavar)s will be created if it does " + "not exist; if %(metavar)s is not specified but the flag is set, " + "stats will be saved to where %(prog)s is located; if this flag is " + "not set, then do not save stats", + ) + parser.add_argument( + "--save_stats_json", + type=str, + nargs="?", + dest="stats_json_dir", + metavar="STATS_JSON_DIR", + const=os.path.dirname(os.path.abspath(__file__)), + default=None, + help="Save means for sweetberry readings in json to %(metavar)s if " + "%(metavar)s is specified, %(metavar)s will be created if it does " + "not exist; if %(metavar)s is not specified but the flag is set, " + "stats will be saved to where %(prog)s is located; if this flag is " + "not set, then do not save stats", + ) + parser.add_argument( + "--no_print_raw_data", + dest="print_raw_data", + default=True, + action="store_false", + help="Not print raw sweetberry readings at real time, default is to " "print", + ) + parser.add_argument( + "--save_raw_data", + type=str, + nargs="?", + dest="raw_data_dir", + metavar="RAW_DATA_DIR", + const=os.path.dirname(os.path.abspath(__file__)), + default=None, + help="Save raw data for sweetberry readings to %(metavar)s if " + "%(metavar)s is specified, %(metavar)s will be created if it does " + "not exist; if %(metavar)s is not specified but the flag is set, " + "raw data will be saved to where %(prog)s is located; if this flag " + "is not set, then do not save raw data", + ) + parser.add_argument( + "-v", + "--verbose", + default=False, + help="Very chatty printout", + action="store_true", + ) + + args = parser.parse_args(argv) + + root_logger = logging.getLogger(__name__) + if args.verbose: + root_logger.setLevel(logging.DEBUG) + else: + root_logger.setLevel(logging.INFO) + + # if powerlog is used through main, log to sys.stdout + if __name__ == "__main__": + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) + root_logger.addHandler(stdout_handler) + + integration_us_request = args.integration_us + if not args.board: + raise Exception("Power", "No board file selected, see board.README") + if not args.config: + raise Exception("Power", "No config file selected, see board.README") + + brdfile = args.board + cfgfile = args.config + seconds = args.seconds + serial_a = args.serial + serial_b = args.serial_b + sync_date = args.date + use_ms = args.ms + use_mW = args.mW + print_stats = args.print_stats + stats_dir = args.stats_dir + stats_json_dir = args.stats_json_dir + print_raw_data = args.print_raw_data + raw_data_dir = args.raw_data_dir + + boards = [] + + sync_speed = 0.8 + if args.slow: + sync_speed = 1.2 + + # Set up logging interface. + powerlogger = powerlog( + brdfile, + cfgfile, + serial_a=serial_a, + serial_b=serial_b, + sync_date=sync_date, + use_ms=use_ms, + use_mW=use_mW, + print_stats=print_stats, + stats_dir=stats_dir, + stats_json_dir=stats_json_dir, + print_raw_data=print_raw_data, + raw_data_dir=raw_data_dir, + ) + + # Start logging. + powerlogger.start(integration_us_request, seconds, sync_speed=sync_speed) if __name__ == "__main__": - main() + main() diff --git a/extra/usb_power/powerlog_unittest.py b/extra/usb_power/powerlog_unittest.py index 1d0718530e..693826c16d 100644 --- a/extra/usb_power/powerlog_unittest.py +++ b/extra/usb_power/powerlog_unittest.py @@ -15,40 +15,42 @@ import unittest import powerlog + class TestPowerlog(unittest.TestCase): - """Test to verify powerlog util methods work as expected.""" - - def setUp(self): - """Set up data and create a temporary directory to save data and stats.""" - self.tempdir = tempfile.mkdtemp() - self.filename = 'testfile' - self.filepath = os.path.join(self.tempdir, self.filename) - with open(self.filepath, 'w') as f: - f.write('') - - def tearDown(self): - """Delete the temporary directory and its content.""" - shutil.rmtree(self.tempdir) - - def test_ProcessFilenameAbsoluteFilePath(self): - """Absolute file path is returned unchanged.""" - processed_fname = powerlog.process_filename(self.filepath) - self.assertEqual(self.filepath, processed_fname) - - def test_ProcessFilenameRelativeFilePath(self): - """Finds relative file path inside a known config location.""" - original = powerlog.CONFIG_LOCATIONS - powerlog.CONFIG_LOCATIONS = [self.tempdir] - processed_fname = powerlog.process_filename(self.filename) - try: - self.assertEqual(self.filepath, processed_fname) - finally: - powerlog.CONFIG_LOCATIONS = original - - def test_ProcessFilenameInvalid(self): - """IOError is raised when file cannot be found by any of the four ways.""" - with self.assertRaises(IOError): - powerlog.process_filename(self.filename) - -if __name__ == '__main__': - unittest.main() + """Test to verify powerlog util methods work as expected.""" + + def setUp(self): + """Set up data and create a temporary directory to save data and stats.""" + self.tempdir = tempfile.mkdtemp() + self.filename = "testfile" + self.filepath = os.path.join(self.tempdir, self.filename) + with open(self.filepath, "w") as f: + f.write("") + + def tearDown(self): + """Delete the temporary directory and its content.""" + shutil.rmtree(self.tempdir) + + def test_ProcessFilenameAbsoluteFilePath(self): + """Absolute file path is returned unchanged.""" + processed_fname = powerlog.process_filename(self.filepath) + self.assertEqual(self.filepath, processed_fname) + + def test_ProcessFilenameRelativeFilePath(self): + """Finds relative file path inside a known config location.""" + original = powerlog.CONFIG_LOCATIONS + powerlog.CONFIG_LOCATIONS = [self.tempdir] + processed_fname = powerlog.process_filename(self.filename) + try: + self.assertEqual(self.filepath, processed_fname) + finally: + powerlog.CONFIG_LOCATIONS = original + + def test_ProcessFilenameInvalid(self): + """IOError is raised when file cannot be found by any of the four ways.""" + with self.assertRaises(IOError): + powerlog.process_filename(self.filename) + + +if __name__ == "__main__": + unittest.main() diff --git a/extra/usb_power/stats_manager.py b/extra/usb_power/stats_manager.py index 0f8c3fcb15..633312311d 100644 --- a/extra/usb_power/stats_manager.py +++ b/extra/usb_power/stats_manager.py @@ -20,382 +20,391 @@ import os import numpy -STATS_PREFIX = '@@' -NAN_TAG = '*' -NAN_DESCRIPTION = '%s domains contain NaN samples' % NAN_TAG +STATS_PREFIX = "@@" +NAN_TAG = "*" +NAN_DESCRIPTION = "%s domains contain NaN samples" % NAN_TAG LONG_UNIT = { - '': 'N/A', - 'mW': 'milliwatt', - 'uW': 'microwatt', - 'mV': 'millivolt', - 'uA': 'microamp', - 'uV': 'microvolt' + "": "N/A", + "mW": "milliwatt", + "uW": "microwatt", + "mV": "millivolt", + "uA": "microamp", + "uV": "microvolt", } class StatsManagerError(Exception): - """Errors in StatsManager class.""" - pass + """Errors in StatsManager class.""" + pass -class StatsManager(object): - """Calculates statistics for several lists of data(float). - - Example usage: - - >>> stats = StatsManager(title='Title Banner') - >>> stats.AddSample(TIME_KEY, 50.0) - >>> stats.AddSample(TIME_KEY, 25.0) - >>> stats.AddSample(TIME_KEY, 40.0) - >>> stats.AddSample(TIME_KEY, 10.0) - >>> stats.AddSample(TIME_KEY, 10.0) - >>> stats.AddSample('frobnicate', 11.5) - >>> stats.AddSample('frobnicate', 9.0) - >>> stats.AddSample('foobar', 11111.0) - >>> stats.AddSample('foobar', 22222.0) - >>> stats.CalculateStats() - >>> print(stats.SummaryToString()) - ` @@-------------------------------------------------------------- - ` @@ Title Banner - @@-------------------------------------------------------------- - @@ NAME COUNT MEAN STDDEV MAX MIN - @@ sample_msecs 4 31.25 15.16 50.00 10.00 - @@ foobar 2 16666.50 5555.50 22222.00 11111.00 - @@ frobnicate 2 10.25 1.25 11.50 9.00 - ` @@-------------------------------------------------------------- - - Attributes: - _data: dict of list of readings for each domain(key) - _unit: dict of unit for each domain(key) - _smid: id supplied to differentiate data output to other StatsManager - instances that potentially save to the same directory - if smid all output files will be named |smid|_|fname| - _title: title to add as banner to formatted summary. If no title, - no banner gets added - _order: list of formatting order for domains. Domains not listed are - displayed in sorted order - _hide_domains: collection of domains to hide when formatting summary string - _accept_nan: flag to indicate if NaN samples are acceptable - _nan_domains: set to keep track of which domains contain NaN samples - _summary: dict of stats per domain (key): min, max, count, mean, stddev - _logger = StatsManager logger - - Note: - _summary is empty until CalculateStats() is called, and is updated when - CalculateStats() is called. - """ - - # pylint: disable=W0102 - def __init__(self, smid='', title='', order=[], hide_domains=[], - accept_nan=True): - """Initialize infrastructure for data and their statistics.""" - self._title = title - self._data = collections.defaultdict(list) - self._unit = collections.defaultdict(str) - self._smid = smid - self._order = order - self._hide_domains = hide_domains - self._accept_nan = accept_nan - self._nan_domains = set() - self._summary = {} - self._logger = logging.getLogger(type(self).__name__) - - def AddSample(self, domain, sample): - """Add one sample for a domain. - - Args: - domain: the domain name for the sample. - sample: one time sample for domain, expect type float. - - Raises: - StatsManagerError: if trying to add NaN and |_accept_nan| is false - """ - try: - sample = float(sample) - except ValueError: - # if we don't accept nan this will be caught below - self._logger.debug('sample %s for domain %s is not a number. Making NaN', - sample, domain) - sample = float('NaN') - if not self._accept_nan and math.isnan(sample): - raise StatsManagerError('accept_nan is false. Cannot add NaN sample.') - self._data[domain].append(sample) - if math.isnan(sample): - self._nan_domains.add(domain) - - def SetUnit(self, domain, unit): - """Set the unit for a domain. - - There can be only one unit for each domain. Setting unit twice will - overwrite the original unit. - - Args: - domain: the domain name. - unit: unit of the domain. - """ - if domain in self._unit: - self._logger.warning('overwriting the unit of %s, old unit is %s, new ' - 'unit is %s.', domain, self._unit[domain], unit) - self._unit[domain] = unit - def CalculateStats(self): - """Calculate stats for all domain-data pairs. - - First erases all previous stats, then calculate stats for all data. - """ - self._summary = {} - for domain, data in self._data.items(): - data_np = numpy.array(data) - self._summary[domain] = { - 'mean': numpy.nanmean(data_np), - 'min': numpy.nanmin(data_np), - 'max': numpy.nanmax(data_np), - 'stddev': numpy.nanstd(data_np), - 'count': data_np.size, - } - - @property - def DomainsToDisplay(self): - """List of domains that the manager will output in summaries.""" - return set(self._summary.keys()) - set(self._hide_domains) - - @property - def NanInOutput(self): - """Return whether any of the domains to display have NaN values.""" - return bool(len(set(self._nan_domains) & self.DomainsToDisplay)) - - def _SummaryTable(self): - """Generate the matrix to output as a summary. - - Returns: - A 2d matrix of headers and their data for each domain - e.g. - [[NAME, COUNT, MEAN, STDDEV, MAX, MIN], - [pp5000_mw, 10, 50, 0, 50, 50]] - """ - headers = ('NAME', 'COUNT', 'MEAN', 'STDDEV', 'MAX', 'MIN') - table = [headers] - # determine what domains to display & and the order - domains_to_display = self.DomainsToDisplay - display_order = [key for key in self._order if key in domains_to_display] - domains_to_display -= set(display_order) - display_order.extend(sorted(domains_to_display)) - for domain in display_order: - stats = self._summary[domain] - if not domain.endswith(self._unit[domain]): - domain = '%s_%s' % (domain, self._unit[domain]) - if domain in self._nan_domains: - domain = '%s%s' % (domain, NAN_TAG) - row = [domain] - row.append(str(stats['count'])) - for entry in headers[2:]: - row.append('%.2f' % stats[entry.lower()]) - table.append(row) - return table - - def SummaryToMarkdownString(self): - """Format the summary into a b/ compatible markdown table string. - - This requires this sort of output format - - | header1 | header2 | header3 | ... - | --------- | --------- | --------- | ... - | sample1h1 | sample1h2 | sample1h3 | ... - . - . - . - - Returns: - formatted summary string. - """ - # All we need to do before processing is insert a row of '-' between - # the headers, and the data - table = self._SummaryTable() - columns = len(table[0]) - # Using '-:' to allow the numbers to be right aligned - sep_row = ['-'] + ['-:'] * (columns - 1) - table.insert(1, sep_row) - text_rows = ['|'.join(r) for r in table] - body = '\n'.join(['|%s|' % r for r in text_rows]) - if self._title: - title_section = '**%s** \n\n' % self._title - body = title_section + body - # Make sure that the body is terminated with a newline. - return body + '\n' - - def SummaryToString(self, prefix=STATS_PREFIX): - """Format summary into a string, ready for pretty print. - - See class description for format example. - - Args: - prefix: start every row in summary string with prefix, for easier reading. - - Returns: - formatted summary string. - """ - table = self._SummaryTable() - max_col_width = [] - for col_idx in range(len(table[0])): - col_item_widths = [len(row[col_idx]) for row in table] - max_col_width.append(max(col_item_widths)) - - formatted_lines = [] - for row in table: - formatted_row = prefix + ' ' - for i in range(len(row)): - formatted_row += row[i].rjust(max_col_width[i] + 2) - formatted_lines.append(formatted_row) - if self.NanInOutput: - formatted_lines.append('%s %s' % (prefix, NAN_DESCRIPTION)) - - if self._title: - line_length = len(formatted_lines[0]) - dec_length = len(prefix) - # trim title to be at most as long as the longest line without the prefix - title = self._title[:(line_length - dec_length)] - # line is a seperator line consisting of ----- - line = '%s%s' % (prefix, '-' * (line_length - dec_length)) - # prepend the prefix to the centered title - padded_title = '%s%s' % (prefix, title.center(line_length)[dec_length:]) - formatted_lines = [line, padded_title, line] + formatted_lines + [line] - formatted_output = '\n'.join(formatted_lines) - return formatted_output - - def GetSummary(self): - """Getter for summary.""" - return self._summary - - def _MakeUniqueFName(self, fname): - """prepend |_smid| to fname & rotate fname to ensure uniqueness. - - Before saving a file through the StatsManager, make sure that the filename - is unique, first by prepending the smid if any and otherwise by appending - increasing integer suffixes until the filename is unique. - - If |smid| is defined /path/to/example/file.txt becomes - /path/to/example/{smid}_file.txt. - - The rotation works by changing /path/to/example/somename.txt to - /path/to/example/somename1.txt if the first one already exists on the - system. - - Note: this is not thread-safe. While it makes sense to use StatsManager - in a threaded data-collection, the data retrieval should happen in a - single threaded environment to ensure files don't get potentially clobbered. - - Args: - fname: filename to ensure uniqueness. - - Returns: - {smid_}fname{tag}.[b].ext - the smid portion gets prepended if |smid| is defined - the tag portion gets appended if necessary to ensure unique fname - """ - fdir = os.path.dirname(fname) - base, ext = os.path.splitext(os.path.basename(fname)) - if self._smid: - base = '%s_%s' % (self._smid, base) - unique_fname = os.path.join(fdir, '%s%s' % (base, ext)) - tag = 0 - while os.path.exists(unique_fname): - old_fname = unique_fname - unique_fname = os.path.join(fdir, '%s%d%s' % (base, tag, ext)) - self._logger.warning('Attempted to store stats information at %s, but ' - 'file already exists. Attempting to store at %s ' - 'now.', old_fname, unique_fname) - tag += 1 - return unique_fname - - def SaveSummary(self, directory, fname='summary.txt', prefix=STATS_PREFIX): - """Save summary to file. - - Args: - directory: directory to save the summary in. - fname: filename to save summary under. - prefix: start every row in summary string with prefix, for easier reading. - - Returns: - full path of summary save location - """ - summary_str = self.SummaryToString(prefix=prefix) + '\n' - return self._SaveSummary(summary_str, directory, fname) - - def SaveSummaryJSON(self, directory, fname='summary.json'): - """Save summary (only MEAN) into a JSON file. - - Args: - directory: directory to save the JSON summary in. - fname: filename to save summary under. - - Returns: - full path of summary save location - """ - data = {} - for domain in self._summary: - unit = LONG_UNIT.get(self._unit[domain], self._unit[domain]) - data_entry = {'mean': self._summary[domain]['mean'], 'unit': unit} - data[domain] = data_entry - summary_str = json.dumps(data, indent=2) - return self._SaveSummary(summary_str, directory, fname) - - def SaveSummaryMD(self, directory, fname='summary.md'): - """Save summary into a MD file to paste into b/. - - Args: - directory: directory to save the MD summary in. - fname: filename to save summary under. - - Returns: - full path of summary save location +class StatsManager(object): + """Calculates statistics for several lists of data(float). + + Example usage: + + >>> stats = StatsManager(title='Title Banner') + >>> stats.AddSample(TIME_KEY, 50.0) + >>> stats.AddSample(TIME_KEY, 25.0) + >>> stats.AddSample(TIME_KEY, 40.0) + >>> stats.AddSample(TIME_KEY, 10.0) + >>> stats.AddSample(TIME_KEY, 10.0) + >>> stats.AddSample('frobnicate', 11.5) + >>> stats.AddSample('frobnicate', 9.0) + >>> stats.AddSample('foobar', 11111.0) + >>> stats.AddSample('foobar', 22222.0) + >>> stats.CalculateStats() + >>> print(stats.SummaryToString()) + ` @@-------------------------------------------------------------- + ` @@ Title Banner + @@-------------------------------------------------------------- + @@ NAME COUNT MEAN STDDEV MAX MIN + @@ sample_msecs 4 31.25 15.16 50.00 10.00 + @@ foobar 2 16666.50 5555.50 22222.00 11111.00 + @@ frobnicate 2 10.25 1.25 11.50 9.00 + ` @@-------------------------------------------------------------- + + Attributes: + _data: dict of list of readings for each domain(key) + _unit: dict of unit for each domain(key) + _smid: id supplied to differentiate data output to other StatsManager + instances that potentially save to the same directory + if smid all output files will be named |smid|_|fname| + _title: title to add as banner to formatted summary. If no title, + no banner gets added + _order: list of formatting order for domains. Domains not listed are + displayed in sorted order + _hide_domains: collection of domains to hide when formatting summary string + _accept_nan: flag to indicate if NaN samples are acceptable + _nan_domains: set to keep track of which domains contain NaN samples + _summary: dict of stats per domain (key): min, max, count, mean, stddev + _logger = StatsManager logger + + Note: + _summary is empty until CalculateStats() is called, and is updated when + CalculateStats() is called. """ - summary_str = self.SummaryToMarkdownString() - return self._SaveSummary(summary_str, directory, fname) - def _SaveSummary(self, output_str, directory, fname): - """Wrote |output_str| to |fname|. - - Args: - output_str: formatted output string - directory: directory to save the summary in. - fname: filename to save summary under. - - Returns: - full path of summary save location - """ - if not os.path.exists(directory): - os.makedirs(directory) - fname = self._MakeUniqueFName(os.path.join(directory, fname)) - with open(fname, 'w') as f: - f.write(output_str) - return fname - - def GetRawData(self): - """Getter for all raw_data.""" - return self._data - - def SaveRawData(self, directory, dirname='raw_data'): - """Save raw data to file. - - Args: - directory: directory to create the raw data folder in. - dirname: folder in which raw data live. - - Returns: - list of full path of each domain's raw data save location - """ - if not os.path.exists(directory): - os.makedirs(directory) - dirname = os.path.join(directory, dirname) - if not os.path.exists(dirname): - os.makedirs(dirname) - fnames = [] - for domain, data in self._data.items(): - if not domain.endswith(self._unit[domain]): - domain = '%s_%s' % (domain, self._unit[domain]) - fname = self._MakeUniqueFName(os.path.join(dirname, '%s.txt' % domain)) - with open(fname, 'w') as f: - f.write('\n'.join('%.2f' % sample for sample in data) + '\n') - fnames.append(fname) - return fnames + # pylint: disable=W0102 + def __init__(self, smid="", title="", order=[], hide_domains=[], accept_nan=True): + """Initialize infrastructure for data and their statistics.""" + self._title = title + self._data = collections.defaultdict(list) + self._unit = collections.defaultdict(str) + self._smid = smid + self._order = order + self._hide_domains = hide_domains + self._accept_nan = accept_nan + self._nan_domains = set() + self._summary = {} + self._logger = logging.getLogger(type(self).__name__) + + def AddSample(self, domain, sample): + """Add one sample for a domain. + + Args: + domain: the domain name for the sample. + sample: one time sample for domain, expect type float. + + Raises: + StatsManagerError: if trying to add NaN and |_accept_nan| is false + """ + try: + sample = float(sample) + except ValueError: + # if we don't accept nan this will be caught below + self._logger.debug( + "sample %s for domain %s is not a number. Making NaN", sample, domain + ) + sample = float("NaN") + if not self._accept_nan and math.isnan(sample): + raise StatsManagerError("accept_nan is false. Cannot add NaN sample.") + self._data[domain].append(sample) + if math.isnan(sample): + self._nan_domains.add(domain) + + def SetUnit(self, domain, unit): + """Set the unit for a domain. + + There can be only one unit for each domain. Setting unit twice will + overwrite the original unit. + + Args: + domain: the domain name. + unit: unit of the domain. + """ + if domain in self._unit: + self._logger.warning( + "overwriting the unit of %s, old unit is %s, new " "unit is %s.", + domain, + self._unit[domain], + unit, + ) + self._unit[domain] = unit + + def CalculateStats(self): + """Calculate stats for all domain-data pairs. + + First erases all previous stats, then calculate stats for all data. + """ + self._summary = {} + for domain, data in self._data.items(): + data_np = numpy.array(data) + self._summary[domain] = { + "mean": numpy.nanmean(data_np), + "min": numpy.nanmin(data_np), + "max": numpy.nanmax(data_np), + "stddev": numpy.nanstd(data_np), + "count": data_np.size, + } + + @property + def DomainsToDisplay(self): + """List of domains that the manager will output in summaries.""" + return set(self._summary.keys()) - set(self._hide_domains) + + @property + def NanInOutput(self): + """Return whether any of the domains to display have NaN values.""" + return bool(len(set(self._nan_domains) & self.DomainsToDisplay)) + + def _SummaryTable(self): + """Generate the matrix to output as a summary. + + Returns: + A 2d matrix of headers and their data for each domain + e.g. + [[NAME, COUNT, MEAN, STDDEV, MAX, MIN], + [pp5000_mw, 10, 50, 0, 50, 50]] + """ + headers = ("NAME", "COUNT", "MEAN", "STDDEV", "MAX", "MIN") + table = [headers] + # determine what domains to display & and the order + domains_to_display = self.DomainsToDisplay + display_order = [key for key in self._order if key in domains_to_display] + domains_to_display -= set(display_order) + display_order.extend(sorted(domains_to_display)) + for domain in display_order: + stats = self._summary[domain] + if not domain.endswith(self._unit[domain]): + domain = "%s_%s" % (domain, self._unit[domain]) + if domain in self._nan_domains: + domain = "%s%s" % (domain, NAN_TAG) + row = [domain] + row.append(str(stats["count"])) + for entry in headers[2:]: + row.append("%.2f" % stats[entry.lower()]) + table.append(row) + return table + + def SummaryToMarkdownString(self): + """Format the summary into a b/ compatible markdown table string. + + This requires this sort of output format + + | header1 | header2 | header3 | ... + | --------- | --------- | --------- | ... + | sample1h1 | sample1h2 | sample1h3 | ... + . + . + . + + Returns: + formatted summary string. + """ + # All we need to do before processing is insert a row of '-' between + # the headers, and the data + table = self._SummaryTable() + columns = len(table[0]) + # Using '-:' to allow the numbers to be right aligned + sep_row = ["-"] + ["-:"] * (columns - 1) + table.insert(1, sep_row) + text_rows = ["|".join(r) for r in table] + body = "\n".join(["|%s|" % r for r in text_rows]) + if self._title: + title_section = "**%s** \n\n" % self._title + body = title_section + body + # Make sure that the body is terminated with a newline. + return body + "\n" + + def SummaryToString(self, prefix=STATS_PREFIX): + """Format summary into a string, ready for pretty print. + + See class description for format example. + + Args: + prefix: start every row in summary string with prefix, for easier reading. + + Returns: + formatted summary string. + """ + table = self._SummaryTable() + max_col_width = [] + for col_idx in range(len(table[0])): + col_item_widths = [len(row[col_idx]) for row in table] + max_col_width.append(max(col_item_widths)) + + formatted_lines = [] + for row in table: + formatted_row = prefix + " " + for i in range(len(row)): + formatted_row += row[i].rjust(max_col_width[i] + 2) + formatted_lines.append(formatted_row) + if self.NanInOutput: + formatted_lines.append("%s %s" % (prefix, NAN_DESCRIPTION)) + + if self._title: + line_length = len(formatted_lines[0]) + dec_length = len(prefix) + # trim title to be at most as long as the longest line without the prefix + title = self._title[: (line_length - dec_length)] + # line is a seperator line consisting of ----- + line = "%s%s" % (prefix, "-" * (line_length - dec_length)) + # prepend the prefix to the centered title + padded_title = "%s%s" % (prefix, title.center(line_length)[dec_length:]) + formatted_lines = [line, padded_title, line] + formatted_lines + [line] + formatted_output = "\n".join(formatted_lines) + return formatted_output + + def GetSummary(self): + """Getter for summary.""" + return self._summary + + def _MakeUniqueFName(self, fname): + """prepend |_smid| to fname & rotate fname to ensure uniqueness. + + Before saving a file through the StatsManager, make sure that the filename + is unique, first by prepending the smid if any and otherwise by appending + increasing integer suffixes until the filename is unique. + + If |smid| is defined /path/to/example/file.txt becomes + /path/to/example/{smid}_file.txt. + + The rotation works by changing /path/to/example/somename.txt to + /path/to/example/somename1.txt if the first one already exists on the + system. + + Note: this is not thread-safe. While it makes sense to use StatsManager + in a threaded data-collection, the data retrieval should happen in a + single threaded environment to ensure files don't get potentially clobbered. + + Args: + fname: filename to ensure uniqueness. + + Returns: + {smid_}fname{tag}.[b].ext + the smid portion gets prepended if |smid| is defined + the tag portion gets appended if necessary to ensure unique fname + """ + fdir = os.path.dirname(fname) + base, ext = os.path.splitext(os.path.basename(fname)) + if self._smid: + base = "%s_%s" % (self._smid, base) + unique_fname = os.path.join(fdir, "%s%s" % (base, ext)) + tag = 0 + while os.path.exists(unique_fname): + old_fname = unique_fname + unique_fname = os.path.join(fdir, "%s%d%s" % (base, tag, ext)) + self._logger.warning( + "Attempted to store stats information at %s, but " + "file already exists. Attempting to store at %s " + "now.", + old_fname, + unique_fname, + ) + tag += 1 + return unique_fname + + def SaveSummary(self, directory, fname="summary.txt", prefix=STATS_PREFIX): + """Save summary to file. + + Args: + directory: directory to save the summary in. + fname: filename to save summary under. + prefix: start every row in summary string with prefix, for easier reading. + + Returns: + full path of summary save location + """ + summary_str = self.SummaryToString(prefix=prefix) + "\n" + return self._SaveSummary(summary_str, directory, fname) + + def SaveSummaryJSON(self, directory, fname="summary.json"): + """Save summary (only MEAN) into a JSON file. + + Args: + directory: directory to save the JSON summary in. + fname: filename to save summary under. + + Returns: + full path of summary save location + """ + data = {} + for domain in self._summary: + unit = LONG_UNIT.get(self._unit[domain], self._unit[domain]) + data_entry = {"mean": self._summary[domain]["mean"], "unit": unit} + data[domain] = data_entry + summary_str = json.dumps(data, indent=2) + return self._SaveSummary(summary_str, directory, fname) + + def SaveSummaryMD(self, directory, fname="summary.md"): + """Save summary into a MD file to paste into b/. + + Args: + directory: directory to save the MD summary in. + fname: filename to save summary under. + + Returns: + full path of summary save location + """ + summary_str = self.SummaryToMarkdownString() + return self._SaveSummary(summary_str, directory, fname) + + def _SaveSummary(self, output_str, directory, fname): + """Wrote |output_str| to |fname|. + + Args: + output_str: formatted output string + directory: directory to save the summary in. + fname: filename to save summary under. + + Returns: + full path of summary save location + """ + if not os.path.exists(directory): + os.makedirs(directory) + fname = self._MakeUniqueFName(os.path.join(directory, fname)) + with open(fname, "w") as f: + f.write(output_str) + return fname + + def GetRawData(self): + """Getter for all raw_data.""" + return self._data + + def SaveRawData(self, directory, dirname="raw_data"): + """Save raw data to file. + + Args: + directory: directory to create the raw data folder in. + dirname: folder in which raw data live. + + Returns: + list of full path of each domain's raw data save location + """ + if not os.path.exists(directory): + os.makedirs(directory) + dirname = os.path.join(directory, dirname) + if not os.path.exists(dirname): + os.makedirs(dirname) + fnames = [] + for domain, data in self._data.items(): + if not domain.endswith(self._unit[domain]): + domain = "%s_%s" % (domain, self._unit[domain]) + fname = self._MakeUniqueFName(os.path.join(dirname, "%s.txt" % domain)) + with open(fname, "w") as f: + f.write("\n".join("%.2f" % sample for sample in data) + "\n") + fnames.append(fname) + return fnames diff --git a/extra/usb_power/stats_manager_unittest.py b/extra/usb_power/stats_manager_unittest.py index beb9984b93..37a472af98 100644 --- a/extra/usb_power/stats_manager_unittest.py +++ b/extra/usb_power/stats_manager_unittest.py @@ -9,6 +9,7 @@ """Unit tests for StatsManager.""" from __future__ import print_function + import json import os import re @@ -20,296 +21,304 @@ import stats_manager class TestStatsManager(unittest.TestCase): - """Test to verify StatsManager methods work as expected. - - StatsManager should collect raw data, calculate their statistics, and save - them in expected format. - """ - - def _populate_mock_stats(self): - """Create a populated & processed StatsManager to test data retrieval.""" - self.data.AddSample('A', 99999.5) - self.data.AddSample('A', 100000.5) - self.data.SetUnit('A', 'uW') - self.data.SetUnit('A', 'mW') - self.data.AddSample('B', 1.5) - self.data.AddSample('B', 2.5) - self.data.AddSample('B', 3.5) - self.data.SetUnit('B', 'mV') - self.data.CalculateStats() - - def _populate_mock_stats_no_unit(self): - self.data.AddSample('B', 1000) - self.data.AddSample('A', 200) - self.data.SetUnit('A', 'blue') - - def setUp(self): - """Set up StatsManager and create a temporary directory for test.""" - self.tempdir = tempfile.mkdtemp() - self.data = stats_manager.StatsManager() - - def tearDown(self): - """Delete the temporary directory and its content.""" - shutil.rmtree(self.tempdir) - - def test_AddSample(self): - """Adding a sample successfully adds a sample.""" - self.data.AddSample('Test', 1000) - self.data.SetUnit('Test', 'test') - self.data.CalculateStats() - summary = self.data.GetSummary() - self.assertEqual(1, summary['Test']['count']) - - def test_AddSampleNoFloatAcceptNaN(self): - """Adding a non-number adds 'NaN' and doesn't raise an exception.""" - self.data.AddSample('Test', 10) - self.data.AddSample('Test', 20) - # adding a fake NaN: one that gets converted into NaN internally - self.data.AddSample('Test', 'fiesta') - # adding a real NaN - self.data.AddSample('Test', float('NaN')) - self.data.SetUnit('Test', 'test') - self.data.CalculateStats() - summary = self.data.GetSummary() - # assert that 'NaN' as added. - self.assertEqual(4, summary['Test']['count']) - # assert that mean, min, and max calculatings ignore the 'NaN' - self.assertEqual(10, summary['Test']['min']) - self.assertEqual(20, summary['Test']['max']) - self.assertEqual(15, summary['Test']['mean']) - - def test_AddSampleNoFloatNotAcceptNaN(self): - """Adding a non-number raises a StatsManagerError if accept_nan is False.""" - self.data = stats_manager.StatsManager(accept_nan=False) - with self.assertRaisesRegexp(stats_manager.StatsManagerError, - 'accept_nan is false. Cannot add NaN sample.'): - # adding a fake NaN: one that gets converted into NaN internally - self.data.AddSample('Test', 'fiesta') - with self.assertRaisesRegexp(stats_manager.StatsManagerError, - 'accept_nan is false. Cannot add NaN sample.'): - # adding a real NaN - self.data.AddSample('Test', float('NaN')) - - def test_AddSampleNoUnit(self): - """Not adding a unit does not cause an exception on CalculateStats().""" - self.data.AddSample('Test', 17) - self.data.CalculateStats() - summary = self.data.GetSummary() - self.assertEqual(1, summary['Test']['count']) - - def test_UnitSuffix(self): - """Unit gets appended as a suffix in the displayed summary.""" - self.data.AddSample('test', 250) - self.data.SetUnit('test', 'mw') - self.data.CalculateStats() - summary_str = self.data.SummaryToString() - self.assertIn('test_mw', summary_str) - - def test_DoubleUnitSuffix(self): - """If domain already ends in unit, verify that unit doesn't get appended.""" - self.data.AddSample('test_mw', 250) - self.data.SetUnit('test_mw', 'mw') - self.data.CalculateStats() - summary_str = self.data.SummaryToString() - self.assertIn('test_mw', summary_str) - self.assertNotIn('test_mw_mw', summary_str) - - def test_GetRawData(self): - """GetRawData returns exact same data as fed in.""" - self._populate_mock_stats() - raw_data = self.data.GetRawData() - self.assertListEqual([99999.5, 100000.5], raw_data['A']) - self.assertListEqual([1.5, 2.5, 3.5], raw_data['B']) - - def test_GetSummary(self): - """GetSummary returns expected stats about the data fed in.""" - self._populate_mock_stats() - summary = self.data.GetSummary() - self.assertEqual(2, summary['A']['count']) - self.assertAlmostEqual(100000.5, summary['A']['max']) - self.assertAlmostEqual(99999.5, summary['A']['min']) - self.assertAlmostEqual(0.5, summary['A']['stddev']) - self.assertAlmostEqual(100000.0, summary['A']['mean']) - self.assertEqual(3, summary['B']['count']) - self.assertAlmostEqual(3.5, summary['B']['max']) - self.assertAlmostEqual(1.5, summary['B']['min']) - self.assertAlmostEqual(0.81649658092773, summary['B']['stddev']) - self.assertAlmostEqual(2.5, summary['B']['mean']) - - def test_SaveRawData(self): - """SaveRawData stores same data as fed in.""" - self._populate_mock_stats() - dirname = 'unittest_raw_data' - expected_files = set(['A_mW.txt', 'B_mV.txt']) - fnames = self.data.SaveRawData(self.tempdir, dirname) - files_returned = set([os.path.basename(f) for f in fnames]) - # Assert that only the expected files got returned. - self.assertEqual(expected_files, files_returned) - # Assert that only the returned files are in the outdir. - self.assertEqual(set(os.listdir(os.path.join(self.tempdir, dirname))), - files_returned) - for fname in fnames: - with open(fname, 'r') as f: - if 'A_mW' in fname: - self.assertEqual('99999.50', f.readline().strip()) - self.assertEqual('100000.50', f.readline().strip()) - if 'B_mV' in fname: - self.assertEqual('1.50', f.readline().strip()) - self.assertEqual('2.50', f.readline().strip()) - self.assertEqual('3.50', f.readline().strip()) - - def test_SaveRawDataNoUnit(self): - """SaveRawData appends no unit suffix if the unit is not specified.""" - self._populate_mock_stats_no_unit() - self.data.CalculateStats() - outdir = 'unittest_raw_data' - files = self.data.SaveRawData(self.tempdir, outdir) - files = [os.path.basename(f) for f in files] - # Verify nothing gets appended to domain for filename if no unit exists. - self.assertIn('B.txt', files) - - def test_SaveRawDataSMID(self): - """SaveRawData uses the smid when creating output filename.""" - identifier = 'ec' - self.data = stats_manager.StatsManager(smid=identifier) - self._populate_mock_stats() - files = self.data.SaveRawData(self.tempdir) - for fname in files: - self.assertTrue(os.path.basename(fname).startswith(identifier)) - - def test_SummaryToStringNaNHelp(self): - """NaN containing row gets tagged with *, help banner gets added.""" - help_banner_exp = '%s %s' % (stats_manager.STATS_PREFIX, - stats_manager.NAN_DESCRIPTION) - nan_domain = 'A-domain' - nan_domain_exp = '%s%s' % (nan_domain, stats_manager.NAN_TAG) - # NaN helper banner is added when a NaN domain is found & domain gets tagged - data = stats_manager.StatsManager() - data.AddSample(nan_domain, float('NaN')) - data.AddSample(nan_domain, 17) - data.AddSample('B-domain', 17) - data.CalculateStats() - summarystr = data.SummaryToString() - self.assertIn(help_banner_exp, summarystr) - self.assertIn(nan_domain_exp, summarystr) - # NaN helper banner is not added when no NaN domain output, no tagging - data = stats_manager.StatsManager() - # nan_domain in this scenario does not contain any NaN - data.AddSample(nan_domain, 19) - data.AddSample('B-domain', 17) - data.CalculateStats() - summarystr = data.SummaryToString() - self.assertNotIn(help_banner_exp, summarystr) - self.assertNotIn(nan_domain_exp, summarystr) - - def test_SummaryToStringTitle(self): - """Title shows up in SummaryToString if title specified.""" - title = 'titulo' - data = stats_manager.StatsManager(title=title) - self._populate_mock_stats() - summary_str = data.SummaryToString() - self.assertIn(title, summary_str) - - def test_SummaryToStringHideDomains(self): - """Keys indicated in hide_domains are not printed in the summary.""" - data = stats_manager.StatsManager(hide_domains=['A-domain']) - data.AddSample('A-domain', 17) - data.AddSample('B-domain', 17) - data.CalculateStats() - summary_str = data.SummaryToString() - self.assertIn('B-domain', summary_str) - self.assertNotIn('A-domain', summary_str) - - def test_SummaryToStringOrder(self): - """Order passed into StatsManager is honoured when formatting summary.""" - # StatsManager that should print D & B first, and the subsequent elements - # are sorted. - d_b_a_c_regexp = re.compile('D-domain.*B-domain.*A-domain.*C-domain', - re.DOTALL) - data = stats_manager.StatsManager(order=['D-domain', 'B-domain']) - data.AddSample('A-domain', 17) - data.AddSample('B-domain', 17) - data.AddSample('C-domain', 17) - data.AddSample('D-domain', 17) - data.CalculateStats() - summary_str = data.SummaryToString() - self.assertRegexpMatches(summary_str, d_b_a_c_regexp) - - def test_MakeUniqueFName(self): - data = stats_manager.StatsManager() - testfile = os.path.join(self.tempdir, 'testfile.txt') - with open(testfile, 'w') as f: - f.write('') - expected_fname = os.path.join(self.tempdir, 'testfile0.txt') - self.assertEqual(expected_fname, data._MakeUniqueFName(testfile)) - - def test_SaveSummary(self): - """SaveSummary properly dumps the summary into a file.""" - self._populate_mock_stats() - fname = 'unittest_summary.txt' - expected_fname = os.path.join(self.tempdir, fname) - fname = self.data.SaveSummary(self.tempdir, fname) - # Assert the reported fname is the same as the expected fname - self.assertEqual(expected_fname, fname) - # Assert only the reported fname is output (in the tempdir) - self.assertEqual(set([os.path.basename(fname)]), - set(os.listdir(self.tempdir))) - with open(fname, 'r') as f: - self.assertEqual( - '@@ NAME COUNT MEAN STDDEV MAX MIN\n', - f.readline()) - self.assertEqual( - '@@ A_mW 2 100000.00 0.50 100000.50 99999.50\n', - f.readline()) - self.assertEqual( - '@@ B_mV 3 2.50 0.82 3.50 1.50\n', - f.readline()) - - def test_SaveSummarySMID(self): - """SaveSummary uses the smid when creating output filename.""" - identifier = 'ec' - self.data = stats_manager.StatsManager(smid=identifier) - self._populate_mock_stats() - fname = os.path.basename(self.data.SaveSummary(self.tempdir)) - self.assertTrue(fname.startswith(identifier)) - - def test_SaveSummaryJSON(self): - """SaveSummaryJSON saves the added data properly in JSON format.""" - self._populate_mock_stats() - fname = 'unittest_summary.json' - expected_fname = os.path.join(self.tempdir, fname) - fname = self.data.SaveSummaryJSON(self.tempdir, fname) - # Assert the reported fname is the same as the expected fname - self.assertEqual(expected_fname, fname) - # Assert only the reported fname is output (in the tempdir) - self.assertEqual(set([os.path.basename(fname)]), - set(os.listdir(self.tempdir))) - with open(fname, 'r') as f: - summary = json.load(f) - self.assertAlmostEqual(100000.0, summary['A']['mean']) - self.assertEqual('milliwatt', summary['A']['unit']) - self.assertAlmostEqual(2.5, summary['B']['mean']) - self.assertEqual('millivolt', summary['B']['unit']) - - def test_SaveSummaryJSONSMID(self): - """SaveSummaryJSON uses the smid when creating output filename.""" - identifier = 'ec' - self.data = stats_manager.StatsManager(smid=identifier) - self._populate_mock_stats() - fname = os.path.basename(self.data.SaveSummaryJSON(self.tempdir)) - self.assertTrue(fname.startswith(identifier)) - - def test_SaveSummaryJSONNoUnit(self): - """SaveSummaryJSON marks unknown units properly as N/A.""" - self._populate_mock_stats_no_unit() - self.data.CalculateStats() - fname = 'unittest_summary.json' - fname = self.data.SaveSummaryJSON(self.tempdir, fname) - with open(fname, 'r') as f: - summary = json.load(f) - self.assertEqual('blue', summary['A']['unit']) - # if no unit is specified, JSON should save 'N/A' as the unit. - self.assertEqual('N/A', summary['B']['unit']) - -if __name__ == '__main__': - unittest.main() + """Test to verify StatsManager methods work as expected. + + StatsManager should collect raw data, calculate their statistics, and save + them in expected format. + """ + + def _populate_mock_stats(self): + """Create a populated & processed StatsManager to test data retrieval.""" + self.data.AddSample("A", 99999.5) + self.data.AddSample("A", 100000.5) + self.data.SetUnit("A", "uW") + self.data.SetUnit("A", "mW") + self.data.AddSample("B", 1.5) + self.data.AddSample("B", 2.5) + self.data.AddSample("B", 3.5) + self.data.SetUnit("B", "mV") + self.data.CalculateStats() + + def _populate_mock_stats_no_unit(self): + self.data.AddSample("B", 1000) + self.data.AddSample("A", 200) + self.data.SetUnit("A", "blue") + + def setUp(self): + """Set up StatsManager and create a temporary directory for test.""" + self.tempdir = tempfile.mkdtemp() + self.data = stats_manager.StatsManager() + + def tearDown(self): + """Delete the temporary directory and its content.""" + shutil.rmtree(self.tempdir) + + def test_AddSample(self): + """Adding a sample successfully adds a sample.""" + self.data.AddSample("Test", 1000) + self.data.SetUnit("Test", "test") + self.data.CalculateStats() + summary = self.data.GetSummary() + self.assertEqual(1, summary["Test"]["count"]) + + def test_AddSampleNoFloatAcceptNaN(self): + """Adding a non-number adds 'NaN' and doesn't raise an exception.""" + self.data.AddSample("Test", 10) + self.data.AddSample("Test", 20) + # adding a fake NaN: one that gets converted into NaN internally + self.data.AddSample("Test", "fiesta") + # adding a real NaN + self.data.AddSample("Test", float("NaN")) + self.data.SetUnit("Test", "test") + self.data.CalculateStats() + summary = self.data.GetSummary() + # assert that 'NaN' as added. + self.assertEqual(4, summary["Test"]["count"]) + # assert that mean, min, and max calculatings ignore the 'NaN' + self.assertEqual(10, summary["Test"]["min"]) + self.assertEqual(20, summary["Test"]["max"]) + self.assertEqual(15, summary["Test"]["mean"]) + + def test_AddSampleNoFloatNotAcceptNaN(self): + """Adding a non-number raises a StatsManagerError if accept_nan is False.""" + self.data = stats_manager.StatsManager(accept_nan=False) + with self.assertRaisesRegexp( + stats_manager.StatsManagerError, + "accept_nan is false. Cannot add NaN sample.", + ): + # adding a fake NaN: one that gets converted into NaN internally + self.data.AddSample("Test", "fiesta") + with self.assertRaisesRegexp( + stats_manager.StatsManagerError, + "accept_nan is false. Cannot add NaN sample.", + ): + # adding a real NaN + self.data.AddSample("Test", float("NaN")) + + def test_AddSampleNoUnit(self): + """Not adding a unit does not cause an exception on CalculateStats().""" + self.data.AddSample("Test", 17) + self.data.CalculateStats() + summary = self.data.GetSummary() + self.assertEqual(1, summary["Test"]["count"]) + + def test_UnitSuffix(self): + """Unit gets appended as a suffix in the displayed summary.""" + self.data.AddSample("test", 250) + self.data.SetUnit("test", "mw") + self.data.CalculateStats() + summary_str = self.data.SummaryToString() + self.assertIn("test_mw", summary_str) + + def test_DoubleUnitSuffix(self): + """If domain already ends in unit, verify that unit doesn't get appended.""" + self.data.AddSample("test_mw", 250) + self.data.SetUnit("test_mw", "mw") + self.data.CalculateStats() + summary_str = self.data.SummaryToString() + self.assertIn("test_mw", summary_str) + self.assertNotIn("test_mw_mw", summary_str) + + def test_GetRawData(self): + """GetRawData returns exact same data as fed in.""" + self._populate_mock_stats() + raw_data = self.data.GetRawData() + self.assertListEqual([99999.5, 100000.5], raw_data["A"]) + self.assertListEqual([1.5, 2.5, 3.5], raw_data["B"]) + + def test_GetSummary(self): + """GetSummary returns expected stats about the data fed in.""" + self._populate_mock_stats() + summary = self.data.GetSummary() + self.assertEqual(2, summary["A"]["count"]) + self.assertAlmostEqual(100000.5, summary["A"]["max"]) + self.assertAlmostEqual(99999.5, summary["A"]["min"]) + self.assertAlmostEqual(0.5, summary["A"]["stddev"]) + self.assertAlmostEqual(100000.0, summary["A"]["mean"]) + self.assertEqual(3, summary["B"]["count"]) + self.assertAlmostEqual(3.5, summary["B"]["max"]) + self.assertAlmostEqual(1.5, summary["B"]["min"]) + self.assertAlmostEqual(0.81649658092773, summary["B"]["stddev"]) + self.assertAlmostEqual(2.5, summary["B"]["mean"]) + + def test_SaveRawData(self): + """SaveRawData stores same data as fed in.""" + self._populate_mock_stats() + dirname = "unittest_raw_data" + expected_files = set(["A_mW.txt", "B_mV.txt"]) + fnames = self.data.SaveRawData(self.tempdir, dirname) + files_returned = set([os.path.basename(f) for f in fnames]) + # Assert that only the expected files got returned. + self.assertEqual(expected_files, files_returned) + # Assert that only the returned files are in the outdir. + self.assertEqual( + set(os.listdir(os.path.join(self.tempdir, dirname))), files_returned + ) + for fname in fnames: + with open(fname, "r") as f: + if "A_mW" in fname: + self.assertEqual("99999.50", f.readline().strip()) + self.assertEqual("100000.50", f.readline().strip()) + if "B_mV" in fname: + self.assertEqual("1.50", f.readline().strip()) + self.assertEqual("2.50", f.readline().strip()) + self.assertEqual("3.50", f.readline().strip()) + + def test_SaveRawDataNoUnit(self): + """SaveRawData appends no unit suffix if the unit is not specified.""" + self._populate_mock_stats_no_unit() + self.data.CalculateStats() + outdir = "unittest_raw_data" + files = self.data.SaveRawData(self.tempdir, outdir) + files = [os.path.basename(f) for f in files] + # Verify nothing gets appended to domain for filename if no unit exists. + self.assertIn("B.txt", files) + + def test_SaveRawDataSMID(self): + """SaveRawData uses the smid when creating output filename.""" + identifier = "ec" + self.data = stats_manager.StatsManager(smid=identifier) + self._populate_mock_stats() + files = self.data.SaveRawData(self.tempdir) + for fname in files: + self.assertTrue(os.path.basename(fname).startswith(identifier)) + + def test_SummaryToStringNaNHelp(self): + """NaN containing row gets tagged with *, help banner gets added.""" + help_banner_exp = "%s %s" % ( + stats_manager.STATS_PREFIX, + stats_manager.NAN_DESCRIPTION, + ) + nan_domain = "A-domain" + nan_domain_exp = "%s%s" % (nan_domain, stats_manager.NAN_TAG) + # NaN helper banner is added when a NaN domain is found & domain gets tagged + data = stats_manager.StatsManager() + data.AddSample(nan_domain, float("NaN")) + data.AddSample(nan_domain, 17) + data.AddSample("B-domain", 17) + data.CalculateStats() + summarystr = data.SummaryToString() + self.assertIn(help_banner_exp, summarystr) + self.assertIn(nan_domain_exp, summarystr) + # NaN helper banner is not added when no NaN domain output, no tagging + data = stats_manager.StatsManager() + # nan_domain in this scenario does not contain any NaN + data.AddSample(nan_domain, 19) + data.AddSample("B-domain", 17) + data.CalculateStats() + summarystr = data.SummaryToString() + self.assertNotIn(help_banner_exp, summarystr) + self.assertNotIn(nan_domain_exp, summarystr) + + def test_SummaryToStringTitle(self): + """Title shows up in SummaryToString if title specified.""" + title = "titulo" + data = stats_manager.StatsManager(title=title) + self._populate_mock_stats() + summary_str = data.SummaryToString() + self.assertIn(title, summary_str) + + def test_SummaryToStringHideDomains(self): + """Keys indicated in hide_domains are not printed in the summary.""" + data = stats_manager.StatsManager(hide_domains=["A-domain"]) + data.AddSample("A-domain", 17) + data.AddSample("B-domain", 17) + data.CalculateStats() + summary_str = data.SummaryToString() + self.assertIn("B-domain", summary_str) + self.assertNotIn("A-domain", summary_str) + + def test_SummaryToStringOrder(self): + """Order passed into StatsManager is honoured when formatting summary.""" + # StatsManager that should print D & B first, and the subsequent elements + # are sorted. + d_b_a_c_regexp = re.compile("D-domain.*B-domain.*A-domain.*C-domain", re.DOTALL) + data = stats_manager.StatsManager(order=["D-domain", "B-domain"]) + data.AddSample("A-domain", 17) + data.AddSample("B-domain", 17) + data.AddSample("C-domain", 17) + data.AddSample("D-domain", 17) + data.CalculateStats() + summary_str = data.SummaryToString() + self.assertRegexpMatches(summary_str, d_b_a_c_regexp) + + def test_MakeUniqueFName(self): + data = stats_manager.StatsManager() + testfile = os.path.join(self.tempdir, "testfile.txt") + with open(testfile, "w") as f: + f.write("") + expected_fname = os.path.join(self.tempdir, "testfile0.txt") + self.assertEqual(expected_fname, data._MakeUniqueFName(testfile)) + + def test_SaveSummary(self): + """SaveSummary properly dumps the summary into a file.""" + self._populate_mock_stats() + fname = "unittest_summary.txt" + expected_fname = os.path.join(self.tempdir, fname) + fname = self.data.SaveSummary(self.tempdir, fname) + # Assert the reported fname is the same as the expected fname + self.assertEqual(expected_fname, fname) + # Assert only the reported fname is output (in the tempdir) + self.assertEqual(set([os.path.basename(fname)]), set(os.listdir(self.tempdir))) + with open(fname, "r") as f: + self.assertEqual( + "@@ NAME COUNT MEAN STDDEV MAX MIN\n", + f.readline(), + ) + self.assertEqual( + "@@ A_mW 2 100000.00 0.50 100000.50 99999.50\n", + f.readline(), + ) + self.assertEqual( + "@@ B_mV 3 2.50 0.82 3.50 1.50\n", + f.readline(), + ) + + def test_SaveSummarySMID(self): + """SaveSummary uses the smid when creating output filename.""" + identifier = "ec" + self.data = stats_manager.StatsManager(smid=identifier) + self._populate_mock_stats() + fname = os.path.basename(self.data.SaveSummary(self.tempdir)) + self.assertTrue(fname.startswith(identifier)) + + def test_SaveSummaryJSON(self): + """SaveSummaryJSON saves the added data properly in JSON format.""" + self._populate_mock_stats() + fname = "unittest_summary.json" + expected_fname = os.path.join(self.tempdir, fname) + fname = self.data.SaveSummaryJSON(self.tempdir, fname) + # Assert the reported fname is the same as the expected fname + self.assertEqual(expected_fname, fname) + # Assert only the reported fname is output (in the tempdir) + self.assertEqual(set([os.path.basename(fname)]), set(os.listdir(self.tempdir))) + with open(fname, "r") as f: + summary = json.load(f) + self.assertAlmostEqual(100000.0, summary["A"]["mean"]) + self.assertEqual("milliwatt", summary["A"]["unit"]) + self.assertAlmostEqual(2.5, summary["B"]["mean"]) + self.assertEqual("millivolt", summary["B"]["unit"]) + + def test_SaveSummaryJSONSMID(self): + """SaveSummaryJSON uses the smid when creating output filename.""" + identifier = "ec" + self.data = stats_manager.StatsManager(smid=identifier) + self._populate_mock_stats() + fname = os.path.basename(self.data.SaveSummaryJSON(self.tempdir)) + self.assertTrue(fname.startswith(identifier)) + + def test_SaveSummaryJSONNoUnit(self): + """SaveSummaryJSON marks unknown units properly as N/A.""" + self._populate_mock_stats_no_unit() + self.data.CalculateStats() + fname = "unittest_summary.json" + fname = self.data.SaveSummaryJSON(self.tempdir, fname) + with open(fname, "r") as f: + summary = json.load(f) + self.assertEqual("blue", summary["A"]["unit"]) + # if no unit is specified, JSON should save 'N/A' as the unit. + self.assertEqual("N/A", summary["B"]["unit"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/extra/usb_serial/console.py b/extra/usb_serial/console.py index 7211dceff6..c08cb72092 100755 --- a/extra/usb_serial/console.py +++ b/extra/usb_serial/console.py @@ -12,6 +12,7 @@ # Note: This is a py2/3 compatible file. from __future__ import print_function + import argparse import array import os @@ -20,23 +21,24 @@ import termios import threading import time import tty + try: - import usb + import usb except ModuleNotFoundError: - print('import usb failed') - print('try running these commands:') - print(' sudo apt-get install python-pip') - print(' sudo pip install --pre pyusb') - print() - sys.exit(-1) + print("import usb failed") + print("try running these commands:") + print(" sudo apt-get install python-pip") + print(" sudo pip install --pre pyusb") + print() + sys.exit(-1) import six def GetBuffer(stream): - if six.PY3: - return stream.buffer - return stream + if six.PY3: + return stream.buffer + return stream """Class Susb covers USB device discovery and initialization. @@ -45,99 +47,104 @@ def GetBuffer(stream): and interface number. """ + class SusbError(Exception): - """Class for exceptions of Susb.""" - def __init__(self, msg, value=0): - """SusbError constructor. + """Class for exceptions of Susb.""" - Args: - msg: string, message describing error in detail - value: integer, value of error when non-zero status returned. Default=0 - """ - super(SusbError, self).__init__(msg, value) - self.msg = msg - self.value = value - -class Susb(): - """Provide USB functionality. - - Instance Variables: - _read_ep: pyUSB read endpoint for this interface - _write_ep: pyUSB write endpoint for this interface - """ - READ_ENDPOINT = 0x81 - WRITE_ENDPOINT = 0x1 - TIMEOUT_MS = 100 - - def __init__(self, vendor=0x18d1, - product=0x500f, interface=1, serialname=None): - """Susb constructor. - - Discovers and connects to USB endpoints. - - Args: - vendor: usb vendor id of device - product: usb product id of device - interface: interface number ( 1 - 8 ) of device to use - serialname: string of device serialnumber. - - Raises: - SusbError: An error accessing Susb object + def __init__(self, msg, value=0): + """SusbError constructor. + + Args: + msg: string, message describing error in detail + value: integer, value of error when non-zero status returned. Default=0 + """ + super(SusbError, self).__init__(msg, value) + self.msg = msg + self.value = value + + +class Susb: + """Provide USB functionality. + + Instance Variables: + _read_ep: pyUSB read endpoint for this interface + _write_ep: pyUSB write endpoint for this interface """ - # Find the device. - dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) - dev_list = list(dev_g) - if dev_list is None: - raise SusbError('USB device not found') - - # Check if we have multiple devices. - dev = None - if serialname: - for d in dev_list: - dev_serial = "PyUSB doesn't have a stable interface" - try: - dev_serial = usb.util.get_string(d, 256, d.iSerialNumber) - except: - dev_serial = usb.util.get_string(d, d.iSerialNumber) - if dev_serial == serialname: - dev = d - break - if dev is None: - raise SusbError('USB device(%s) not found' % (serialname,)) - else: - try: - dev = dev_list[0] - except: - try: - dev = dev_list.next() - except: - raise SusbError('USB device %04x:%04x not found' % (vendor, product)) - # If we can't set configuration, it's already been set. - try: - dev.set_configuration() - except usb.core.USBError: - pass + READ_ENDPOINT = 0x81 + WRITE_ENDPOINT = 0x1 + TIMEOUT_MS = 100 + + def __init__(self, vendor=0x18D1, product=0x500F, interface=1, serialname=None): + """Susb constructor. + + Discovers and connects to USB endpoints. + + Args: + vendor: usb vendor id of device + product: usb product id of device + interface: interface number ( 1 - 8 ) of device to use + serialname: string of device serialnumber. + + Raises: + SusbError: An error accessing Susb object + """ + # Find the device. + dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) + dev_list = list(dev_g) + if dev_list is None: + raise SusbError("USB device not found") + + # Check if we have multiple devices. + dev = None + if serialname: + for d in dev_list: + dev_serial = "PyUSB doesn't have a stable interface" + try: + dev_serial = usb.util.get_string(d, 256, d.iSerialNumber) + except: + dev_serial = usb.util.get_string(d, d.iSerialNumber) + if dev_serial == serialname: + dev = d + break + if dev is None: + raise SusbError("USB device(%s) not found" % (serialname,)) + else: + try: + dev = dev_list[0] + except: + try: + dev = dev_list.next() + except: + raise SusbError( + "USB device %04x:%04x not found" % (vendor, product) + ) + + # If we can't set configuration, it's already been set. + try: + dev.set_configuration() + except usb.core.USBError: + pass - # Get an endpoint instance. - cfg = dev.get_active_configuration() - intf = usb.util.find_descriptor(cfg, bInterfaceNumber=interface) - self._intf = intf + # Get an endpoint instance. + cfg = dev.get_active_configuration() + intf = usb.util.find_descriptor(cfg, bInterfaceNumber=interface) + self._intf = intf - if not intf: - raise SusbError('Interface not found') + if not intf: + raise SusbError("Interface not found") - # Detach raiden.ko if it is loaded. - if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True: + # Detach raiden.ko if it is loaded. + if dev.is_kernel_driver_active(intf.bInterfaceNumber) is True: dev.detach_kernel_driver(intf.bInterfaceNumber) - read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT - read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number) - self._read_ep = read_ep + read_ep_number = intf.bInterfaceNumber + self.READ_ENDPOINT + read_ep = usb.util.find_descriptor(intf, bEndpointAddress=read_ep_number) + self._read_ep = read_ep - write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT - write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number) - self._write_ep = write_ep + write_ep_number = intf.bInterfaceNumber + self.WRITE_ENDPOINT + write_ep = usb.util.find_descriptor(intf, bEndpointAddress=write_ep_number) + self._write_ep = write_ep """Suart class implements a stream interface, to access Google's USB class. @@ -146,89 +153,91 @@ class Susb(): and forwards them across. This particular class is hardcoded to stdin/out. """ -class SuartError(Exception): - """Class for exceptions of Suart.""" - def __init__(self, msg, value=0): - """SuartError constructor. - - Args: - msg: string, message describing error in detail - value: integer, value of error when non-zero status returned. Default=0 - """ - super(SuartError, self).__init__(msg, value) - self.msg = msg - self.value = value - - -class Suart(): - """Provide interface to serial usb endpoint.""" - def __init__(self, vendor=0x18d1, product=0x501c, interface=0, - serialname=None): - """Suart contstructor. +class SuartError(Exception): + """Class for exceptions of Suart.""" - Initializes USB stream interface. + def __init__(self, msg, value=0): + """SuartError constructor. - Args: - vendor: usb vendor id of device - product: usb product id of device - interface: interface number of device to use - serialname: Defaults to None. + Args: + msg: string, message describing error in detail + value: integer, value of error when non-zero status returned. Default=0 + """ + super(SuartError, self).__init__(msg, value) + self.msg = msg + self.value = value - Raises: - SuartError: If init fails - """ - self._done = threading.Event() - self._susb = Susb(vendor=vendor, product=product, - interface=interface, serialname=serialname) - def wait_until_done(self, timeout=None): - return self._done.wait(timeout=timeout) +class Suart: + """Provide interface to serial usb endpoint.""" - def run_rx_thread(self): - try: - while True: - try: - r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS) - if r: - GetBuffer(sys.stdout).write(r.tobytes()) - GetBuffer(sys.stdout).flush() - - except Exception as e: - # If we miss some characters on pty disconnect, that's fine. - # ep.read() also throws USBError on timeout, which we discard. - if not isinstance(e, (OSError, usb.core.USBError)): - print('rx %s' % e) - finally: - self._done.set() + def __init__(self, vendor=0x18D1, product=0x501C, interface=0, serialname=None): + """Suart contstructor. - def run_tx_thread(self): - try: - while True: - try: - r = GetBuffer(sys.stdin).read(1) - if not r or r == b'\x03': - break - if r: - self._susb._write_ep.write(array.array('B', r), - self._susb.TIMEOUT_MS) - except Exception as e: - print('tx %s' % e) - finally: - self._done.set() + Initializes USB stream interface. - def run(self): - """Creates pthreads to poll USB & PTY for data.""" - self._exit = False + Args: + vendor: usb vendor id of device + product: usb product id of device + interface: interface number of device to use + serialname: Defaults to None. - self._rx_thread = threading.Thread(target=self.run_rx_thread) - self._rx_thread.daemon = True - self._rx_thread.start() + Raises: + SuartError: If init fails + """ + self._done = threading.Event() + self._susb = Susb( + vendor=vendor, product=product, interface=interface, serialname=serialname + ) - self._tx_thread = threading.Thread(target=self.run_tx_thread) - self._tx_thread.daemon = True - self._tx_thread.start() + def wait_until_done(self, timeout=None): + return self._done.wait(timeout=timeout) + def run_rx_thread(self): + try: + while True: + try: + r = self._susb._read_ep.read(64, self._susb.TIMEOUT_MS) + if r: + GetBuffer(sys.stdout).write(r.tobytes()) + GetBuffer(sys.stdout).flush() + + except Exception as e: + # If we miss some characters on pty disconnect, that's fine. + # ep.read() also throws USBError on timeout, which we discard. + if not isinstance(e, (OSError, usb.core.USBError)): + print("rx %s" % e) + finally: + self._done.set() + + def run_tx_thread(self): + try: + while True: + try: + r = GetBuffer(sys.stdin).read(1) + if not r or r == b"\x03": + break + if r: + self._susb._write_ep.write( + array.array("B", r), self._susb.TIMEOUT_MS + ) + except Exception as e: + print("tx %s" % e) + finally: + self._done.set() + + def run(self): + """Creates pthreads to poll USB & PTY for data.""" + self._exit = False + + self._rx_thread = threading.Thread(target=self.run_rx_thread) + self._rx_thread.daemon = True + self._rx_thread.start() + + self._tx_thread = threading.Thread(target=self.run_tx_thread) + self._tx_thread.daemon = True + self._tx_thread.start() """Command line functionality @@ -237,59 +246,67 @@ class Suart(): Ctrl-C exits. """ -parser = argparse.ArgumentParser(description='Open a console to a USB device') -parser.add_argument('-d', '--device', type=str, - help='vid:pid of target device', default='18d1:501c') -parser.add_argument('-i', '--interface', type=int, - help='interface number of console', default=0) -parser.add_argument('-s', '--serialno', type=str, - help='serial number of device', default='') -parser.add_argument('-S', '--notty-exit-sleep', type=float, default=0.2, - help='When stdin is *not* a TTY, wait this many seconds ' - 'after EOF from stdin before exiting, to give time for ' - 'receiving a reply from the USB device.') +parser = argparse.ArgumentParser(description="Open a console to a USB device") +parser.add_argument( + "-d", "--device", type=str, help="vid:pid of target device", default="18d1:501c" +) +parser.add_argument( + "-i", "--interface", type=int, help="interface number of console", default=0 +) +parser.add_argument( + "-s", "--serialno", type=str, help="serial number of device", default="" +) +parser.add_argument( + "-S", + "--notty-exit-sleep", + type=float, + default=0.2, + help="When stdin is *not* a TTY, wait this many seconds " + "after EOF from stdin before exiting, to give time for " + "receiving a reply from the USB device.", +) + def runconsole(): - """Run the usb console code + """Run the usb console code - Starts the pty thread, and idles until a ^C is caught. - """ - args = parser.parse_args() + Starts the pty thread, and idles until a ^C is caught. + """ + args = parser.parse_args() - vidstr, pidstr = args.device.split(':') - vid = int(vidstr, 16) - pid = int(pidstr, 16) + vidstr, pidstr = args.device.split(":") + vid = int(vidstr, 16) + pid = int(pidstr, 16) - serialno = args.serialno - interface = args.interface + serialno = args.serialno + interface = args.interface - sobj = Suart(vendor=vid, product=pid, interface=interface, - serialname=serialno) - if sys.stdin.isatty(): - tty.setraw(sys.stdin.fileno()) - sobj.run() - sobj.wait_until_done() - if not sys.stdin.isatty() and args.notty_exit_sleep > 0: - time.sleep(args.notty_exit_sleep) + sobj = Suart(vendor=vid, product=pid, interface=interface, serialname=serialno) + if sys.stdin.isatty(): + tty.setraw(sys.stdin.fileno()) + sobj.run() + sobj.wait_until_done() + if not sys.stdin.isatty() and args.notty_exit_sleep > 0: + time.sleep(args.notty_exit_sleep) def main(): - stdin_isatty = sys.stdin.isatty() - if stdin_isatty: - fd = sys.stdin.fileno() - os.system('stty -echo') - old_settings = termios.tcgetattr(fd) - - try: - runconsole() - finally: + stdin_isatty = sys.stdin.isatty() if stdin_isatty: - termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) - os.system('stty echo') - # Avoid having the user's shell prompt start mid-line after the final output - # from this program. - print() + fd = sys.stdin.fileno() + os.system("stty -echo") + old_settings = termios.tcgetattr(fd) + + try: + runconsole() + finally: + if stdin_isatty: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + os.system("stty echo") + # Avoid having the user's shell prompt start mid-line after the final output + # from this program. + print() -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/extra/usb_updater/fw_update.py b/extra/usb_updater/fw_update.py index 0d7a570fc3..f05797bfb6 100755 --- a/extra/usb_updater/fw_update.py +++ b/extra/usb_updater/fw_update.py @@ -20,407 +20,425 @@ import struct import sys import time from pprint import pprint -import usb +import usb debug = False -def debuglog(msg): - if debug: - print(msg) - -def log(msg): - print(msg) - sys.stdout.flush() - - -"""Sends firmware update to CROS EC usb endpoint.""" - -class Supdate(object): - """Class to access firmware update endpoints. - - Usage: - d = Supdate() - Instance Variables: - _dev: pyUSB device object - _read_ep: pyUSB read endpoint for this interface - _write_ep: pyUSB write endpoint for this interface - """ - USB_SUBCLASS_GOOGLE_UPDATE = 0x53 - USB_CLASS_VENDOR = 0xFF - def __init__(self): - pass - - - def connect_usb(self, serialname=None): - """Initial discovery and connection to USB endpoint. - - This searches for a USB device matching the VID:PID specified - in the config file, optionally matching a specified serialname. - - Args: - serialname: Find the device with this serial, in case multiple - devices are attached. - - Returns: - True on success. - Raises: - Exception on error. - """ - # Find the stm32. - vendor = self._brdcfg['vid'] - product = self._brdcfg['pid'] - - dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) - dev_list = list(dev_g) - if dev_list is None: - raise Exception("Update", "USB device not found") - - # Check if we have multiple stm32s and we've specified the serial. - dev = None - if serialname: - for d in dev_list: - if usb.util.get_string(d, d.iSerialNumber) == serialname: - dev = d - break - if dev is None: - raise SusbError("USB device(%s) not found" % serialname) - else: - try: - dev = dev_list[0] - except: - dev = dev_list.next() - - debuglog("Found stm32: %04x:%04x" % (vendor, product)) - self._dev = dev - - # Get an endpoint instance. - try: - dev.set_configuration() - except: - pass - cfg = dev.get_active_configuration() - - intf = usb.util.find_descriptor(cfg, custom_match=lambda i: \ - i.bInterfaceClass==self.USB_CLASS_VENDOR and \ - i.bInterfaceSubClass==self.USB_SUBCLASS_GOOGLE_UPDATE) - - self._intf = intf - debuglog("Interface: %s" % intf) - debuglog("InterfaceNumber: %s" % intf.bInterfaceNumber) - - read_ep = usb.util.find_descriptor( - intf, - # match the first IN endpoint - custom_match = \ - lambda e: \ - usb.util.endpoint_direction(e.bEndpointAddress) == \ - usb.util.ENDPOINT_IN - ) - - self._read_ep = read_ep - debuglog("Reader endpoint: 0x%x" % read_ep.bEndpointAddress) - - write_ep = usb.util.find_descriptor( - intf, - # match the first OUT endpoint - custom_match = \ - lambda e: \ - usb.util.endpoint_direction(e.bEndpointAddress) == \ - usb.util.ENDPOINT_OUT - ) - - self._write_ep = write_ep - debuglog("Writer endpoint: 0x%x" % write_ep.bEndpointAddress) - - return True - - - def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=2000): - """Write command to logger logic.. - - This function writes byte command values list to stm, then reads - byte status. - - Args: - write_list: list of command byte values [0~255]. - read_count: number of status byte values to read. - wtimeout: mS to wait for write success - rtimeout: mS to wait for read success - - Returns: - status byte, if one byte is read, - byte list, if multiple bytes are read, - None, if no bytes are read. - - Interface: - write: [command, data ... ] - read: [status ] - """ - debuglog("wr_command(write_list=[%s] (%d), read_count=%s)" % ( - list(bytearray(write_list)), len(write_list), read_count)) - - # Clean up args from python style to correct types. - write_length = 0 - if write_list: - write_length = len(write_list) - if not read_count: - read_count = 0 - - # Send command to stm32. - if write_list: - cmd = write_list - ret = self._write_ep.write(cmd, wtimeout) - debuglog("RET: %s " % ret) - - # Read back response if necessary. - if read_count: - bytesread = self._read_ep.read(512, rtimeout) - debuglog("BYTES: [%s]" % bytesread) - - if len(bytesread) != read_count: - debuglog("Unexpected bytes read: %d, expected: %d" % (len(bytesread), read_count)) - pass - - debuglog("STATUS: 0x%02x" % int(bytesread[0])) - if read_count == 1: - return bytesread[0] - else: - return bytesread - - return None +def debuglog(msg): + if debug: + print(msg) - def stop(self): - """Finalize system flash and exit.""" - cmd = struct.pack(">I", 0xB007AB1E) - read = self.wr_command(cmd, read_count=4) - if len(read) == 4: - log("Finished flashing") - return +def log(msg): + print(msg) + sys.stdout.flush() - raise Exception("Update", "Stop failed [%s]" % read) +"""Sends firmware update to CROS EC usb endpoint.""" - def write_file(self): - """Write the update region packet by packet to USB - This sends write packets of size 128B out, in 32B chunks. - Overall, this will write all data in the inactive code region. +class Supdate(object): + """Class to access firmware update endpoints. - Raises: - Exception if write failed or address out of bounds. - """ - region = self._region - flash_base = self._brdcfg["flash"] - offset = self._base - flash_base - if offset != self._brdcfg['regions'][region][0]: - raise Exception("Update", "Region %s offset 0x%x != available offset 0x%x" % ( - region, self._brdcfg['regions'][region][0], offset)) - - length = self._brdcfg['regions'][region][1] - log("Sending") - - # Go to the correct region in the ec.bin file. - self._binfile.seek(offset) - - # Send 32 bytes at a time. Must be less than the endpoint's max packet size. - maxpacket = 32 - - # While data is left, create update packets. - while length > 0: - # Update packets are 128B. We can use any number - # but the micro must malloc this memory. - pagesize = min(length, 128) - - # Packet is: - # packet size: page bytes transferred plus 3 x 32b values header. - # cmd: n/a - # base: flash address to write this packet. - # data: 128B of data to write into flash_base - cmd = struct.pack(">III", pagesize + 12, 0, offset + flash_base) - read = self.wr_command(cmd, read_count=0) - - # Push 'todo' bytes out the pipe. - todo = pagesize - while todo > 0: - packetsize = min(maxpacket, todo) - data = self._binfile.read(packetsize) - if len(data) != packetsize: - raise Exception("Update", "No more data from file") - for i in range(0, 10): - try: - self.wr_command(data, read_count=0) - break - except: - log("Timeout fail") - todo -= packetsize - # Done with this packet, move to the next one. - length -= pagesize - offset += pagesize - - # Validate that the micro thinks it successfully wrote the data. - read = self.wr_command(''.encode(), read_count=4) - result = struct.unpack("<I", read) - result = result[0] - if result != 0: - raise Exception("Update", "Upload failed with rc: 0x%x" % result) - - - def start(self): - """Start a transaction and erase currently inactive region. - - This function sends a start command, and receives the base of the - preferred inactive region. This could be RW, RW_B, - or RO (if there's no RW_B) - - Note that the region is erased here, so you'd better program the RO if - you just erased it. TODO(nsanders): Modify the protocol to allow active - region select or query before erase. - """ + Usage: + d = Supdate() - # Size is 3 uint32 fields - # packet: [packetsize, cmd, base] - size = 4 + 4 + 4 - # Return value is [status, base_addr] - expected = 4 + 4 - - cmd = struct.pack("<III", size, 0, 0) - read = self.wr_command(cmd, read_count=expected) - - if len(read) == 4: - raise Exception("Update", "Protocol version 0 not supported") - elif len(read) == expected: - base, version = struct.unpack(">II", read) - log("Update protocol v. %d" % version) - log("Available flash region base: %x" % base) - else: - raise Exception("Update", "Start command returned %d bytes" % len(read)) - - if base < 256: - raise Exception("Update", "Start returned error code 0x%x" % base) - - self._base = base - flash_base = self._brdcfg["flash"] - self._offset = self._base - flash_base - - # Find our active region. - for region in self._brdcfg['regions']: - if (self._offset >= self._brdcfg['regions'][region][0]) and \ - (self._offset < (self._brdcfg['regions'][region][0] + \ - self._brdcfg['regions'][region][1])): - log("Active region: %s" % region) - self._region = region - - - def load_board(self, brdfile): - """Load firmware layout file. - - example as follows: - { - "board": "servo micro", - "vid": 6353, - "pid": 20506, - "flash": 134217728, - "regions": { - "RW": [65536, 65536], - "PSTATE": [61440, 4096], - "RO": [0, 61440] - } - } - - Args: - brdfile: path to board description file. + Instance Variables: + _dev: pyUSB device object + _read_ep: pyUSB read endpoint for this interface + _write_ep: pyUSB write endpoint for this interface """ - with open(brdfile) as data_file: - data = json.load(data_file) - - # TODO(nsanders): validate this data before moving on. - self._brdcfg = data; - if debug: - pprint(data) - - log("Board is %s" % self._brdcfg['board']) - # Cast hex strings to int. - self._brdcfg['flash'] = int(self._brdcfg['flash'], 0) - self._brdcfg['vid'] = int(self._brdcfg['vid'], 0) - self._brdcfg['pid'] = int(self._brdcfg['pid'], 0) - log("Flash Base is %x" % self._brdcfg['flash']) - self._flashsize = 0 - for region in self._brdcfg['regions']: - base = int(self._brdcfg['regions'][region][0], 0) - length = int(self._brdcfg['regions'][region][1], 0) - log("region %s\tbase:0x%08x size:0x%08x" % ( - region, base, length)) - self._flashsize += length + USB_SUBCLASS_GOOGLE_UPDATE = 0x53 + USB_CLASS_VENDOR = 0xFF - # Convert these to int because json doesn't support hex. - self._brdcfg['regions'][region][0] = base - self._brdcfg['regions'][region][1] = length - - log("Flash Size: 0x%x" % self._flashsize) - - def load_file(self, binfile): - """Open and verify size of the target ec.bin file. - - Args: - binfile: path to ec.bin - - Raises: - Exception on file not found or filesize not matching. - """ - self._filesize = os.path.getsize(binfile) - self._binfile = open(binfile, 'rb') - - if self._filesize != self._flashsize: - raise Exception("Update", "Flash size 0x%x != file size 0x%x" % (self._flashsize, self._filesize)) + def __init__(self): + pass + def connect_usb(self, serialname=None): + """Initial discovery and connection to USB endpoint. + + This searches for a USB device matching the VID:PID specified + in the config file, optionally matching a specified serialname. + + Args: + serialname: Find the device with this serial, in case multiple + devices are attached. + + Returns: + True on success. + Raises: + Exception on error. + """ + # Find the stm32. + vendor = self._brdcfg["vid"] + product = self._brdcfg["pid"] + + dev_g = usb.core.find(idVendor=vendor, idProduct=product, find_all=True) + dev_list = list(dev_g) + if dev_list is None: + raise Exception("Update", "USB device not found") + + # Check if we have multiple stm32s and we've specified the serial. + dev = None + if serialname: + for d in dev_list: + if usb.util.get_string(d, d.iSerialNumber) == serialname: + dev = d + break + if dev is None: + raise SusbError("USB device(%s) not found" % serialname) + else: + try: + dev = dev_list[0] + except: + dev = dev_list.next() + + debuglog("Found stm32: %04x:%04x" % (vendor, product)) + self._dev = dev + + # Get an endpoint instance. + try: + dev.set_configuration() + except: + pass + cfg = dev.get_active_configuration() + + intf = usb.util.find_descriptor( + cfg, + custom_match=lambda i: i.bInterfaceClass == self.USB_CLASS_VENDOR + and i.bInterfaceSubClass == self.USB_SUBCLASS_GOOGLE_UPDATE, + ) + + self._intf = intf + debuglog("Interface: %s" % intf) + debuglog("InterfaceNumber: %s" % intf.bInterfaceNumber) + + read_ep = usb.util.find_descriptor( + intf, + # match the first IN endpoint + custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) + == usb.util.ENDPOINT_IN, + ) + + self._read_ep = read_ep + debuglog("Reader endpoint: 0x%x" % read_ep.bEndpointAddress) + + write_ep = usb.util.find_descriptor( + intf, + # match the first OUT endpoint + custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) + == usb.util.ENDPOINT_OUT, + ) + + self._write_ep = write_ep + debuglog("Writer endpoint: 0x%x" % write_ep.bEndpointAddress) + + return True + + def wr_command(self, write_list, read_count=1, wtimeout=100, rtimeout=2000): + """Write command to logger logic.. + + This function writes byte command values list to stm, then reads + byte status. + + Args: + write_list: list of command byte values [0~255]. + read_count: number of status byte values to read. + wtimeout: mS to wait for write success + rtimeout: mS to wait for read success + + Returns: + status byte, if one byte is read, + byte list, if multiple bytes are read, + None, if no bytes are read. + + Interface: + write: [command, data ... ] + read: [status ] + """ + debuglog( + "wr_command(write_list=[%s] (%d), read_count=%s)" + % (list(bytearray(write_list)), len(write_list), read_count) + ) + + # Clean up args from python style to correct types. + write_length = 0 + if write_list: + write_length = len(write_list) + if not read_count: + read_count = 0 + + # Send command to stm32. + if write_list: + cmd = write_list + ret = self._write_ep.write(cmd, wtimeout) + debuglog("RET: %s " % ret) + + # Read back response if necessary. + if read_count: + bytesread = self._read_ep.read(512, rtimeout) + debuglog("BYTES: [%s]" % bytesread) + + if len(bytesread) != read_count: + debuglog( + "Unexpected bytes read: %d, expected: %d" + % (len(bytesread), read_count) + ) + pass + + debuglog("STATUS: 0x%02x" % int(bytesread[0])) + if read_count == 1: + return bytesread[0] + else: + return bytesread + + return None + + def stop(self): + """Finalize system flash and exit.""" + cmd = struct.pack(">I", 0xB007AB1E) + read = self.wr_command(cmd, read_count=4) + + if len(read) == 4: + log("Finished flashing") + return + + raise Exception("Update", "Stop failed [%s]" % read) + + def write_file(self): + """Write the update region packet by packet to USB + + This sends write packets of size 128B out, in 32B chunks. + Overall, this will write all data in the inactive code region. + + Raises: + Exception if write failed or address out of bounds. + """ + region = self._region + flash_base = self._brdcfg["flash"] + offset = self._base - flash_base + if offset != self._brdcfg["regions"][region][0]: + raise Exception( + "Update", + "Region %s offset 0x%x != available offset 0x%x" + % (region, self._brdcfg["regions"][region][0], offset), + ) + + length = self._brdcfg["regions"][region][1] + log("Sending") + + # Go to the correct region in the ec.bin file. + self._binfile.seek(offset) + + # Send 32 bytes at a time. Must be less than the endpoint's max packet size. + maxpacket = 32 + + # While data is left, create update packets. + while length > 0: + # Update packets are 128B. We can use any number + # but the micro must malloc this memory. + pagesize = min(length, 128) + + # Packet is: + # packet size: page bytes transferred plus 3 x 32b values header. + # cmd: n/a + # base: flash address to write this packet. + # data: 128B of data to write into flash_base + cmd = struct.pack(">III", pagesize + 12, 0, offset + flash_base) + read = self.wr_command(cmd, read_count=0) + + # Push 'todo' bytes out the pipe. + todo = pagesize + while todo > 0: + packetsize = min(maxpacket, todo) + data = self._binfile.read(packetsize) + if len(data) != packetsize: + raise Exception("Update", "No more data from file") + for i in range(0, 10): + try: + self.wr_command(data, read_count=0) + break + except: + log("Timeout fail") + todo -= packetsize + # Done with this packet, move to the next one. + length -= pagesize + offset += pagesize + + # Validate that the micro thinks it successfully wrote the data. + read = self.wr_command("".encode(), read_count=4) + result = struct.unpack("<I", read) + result = result[0] + if result != 0: + raise Exception("Update", "Upload failed with rc: 0x%x" % result) + + def start(self): + """Start a transaction and erase currently inactive region. + + This function sends a start command, and receives the base of the + preferred inactive region. This could be RW, RW_B, + or RO (if there's no RW_B) + + Note that the region is erased here, so you'd better program the RO if + you just erased it. TODO(nsanders): Modify the protocol to allow active + region select or query before erase. + """ + + # Size is 3 uint32 fields + # packet: [packetsize, cmd, base] + size = 4 + 4 + 4 + # Return value is [status, base_addr] + expected = 4 + 4 + + cmd = struct.pack("<III", size, 0, 0) + read = self.wr_command(cmd, read_count=expected) + + if len(read) == 4: + raise Exception("Update", "Protocol version 0 not supported") + elif len(read) == expected: + base, version = struct.unpack(">II", read) + log("Update protocol v. %d" % version) + log("Available flash region base: %x" % base) + else: + raise Exception("Update", "Start command returned %d bytes" % len(read)) + + if base < 256: + raise Exception("Update", "Start returned error code 0x%x" % base) + + self._base = base + flash_base = self._brdcfg["flash"] + self._offset = self._base - flash_base + + # Find our active region. + for region in self._brdcfg["regions"]: + if (self._offset >= self._brdcfg["regions"][region][0]) and ( + self._offset + < ( + self._brdcfg["regions"][region][0] + + self._brdcfg["regions"][region][1] + ) + ): + log("Active region: %s" % region) + self._region = region + + def load_board(self, brdfile): + """Load firmware layout file. + + example as follows: + { + "board": "servo micro", + "vid": 6353, + "pid": 20506, + "flash": 134217728, + "regions": { + "RW": [65536, 65536], + "PSTATE": [61440, 4096], + "RO": [0, 61440] + } + } + + Args: + brdfile: path to board description file. + """ + with open(brdfile) as data_file: + data = json.load(data_file) + + # TODO(nsanders): validate this data before moving on. + self._brdcfg = data + if debug: + pprint(data) + + log("Board is %s" % self._brdcfg["board"]) + # Cast hex strings to int. + self._brdcfg["flash"] = int(self._brdcfg["flash"], 0) + self._brdcfg["vid"] = int(self._brdcfg["vid"], 0) + self._brdcfg["pid"] = int(self._brdcfg["pid"], 0) + + log("Flash Base is %x" % self._brdcfg["flash"]) + self._flashsize = 0 + for region in self._brdcfg["regions"]: + base = int(self._brdcfg["regions"][region][0], 0) + length = int(self._brdcfg["regions"][region][1], 0) + log("region %s\tbase:0x%08x size:0x%08x" % (region, base, length)) + self._flashsize += length + + # Convert these to int because json doesn't support hex. + self._brdcfg["regions"][region][0] = base + self._brdcfg["regions"][region][1] = length + + log("Flash Size: 0x%x" % self._flashsize) + + def load_file(self, binfile): + """Open and verify size of the target ec.bin file. + + Args: + binfile: path to ec.bin + + Raises: + Exception on file not found or filesize not matching. + """ + self._filesize = os.path.getsize(binfile) + self._binfile = open(binfile, "rb") + + if self._filesize != self._flashsize: + raise Exception( + "Update", + "Flash size 0x%x != file size 0x%x" % (self._flashsize, self._filesize), + ) # Generate command line arguments parser = argparse.ArgumentParser(description="Update firmware over usb") -parser.add_argument('-b', '--board', type=str, help="Board configuration json file", default="board.json") -parser.add_argument('-f', '--file', type=str, help="Complete ec.bin file", default="ec.bin") -parser.add_argument('-s', '--serial', type=str, help="Serial number", default="") -parser.add_argument('-l', '--list', action="store_true", help="List regions") -parser.add_argument('-v', '--verbose', action="store_true", help="Chatty output") +parser.add_argument( + "-b", + "--board", + type=str, + help="Board configuration json file", + default="board.json", +) +parser.add_argument( + "-f", "--file", type=str, help="Complete ec.bin file", default="ec.bin" +) +parser.add_argument("-s", "--serial", type=str, help="Serial number", default="") +parser.add_argument("-l", "--list", action="store_true", help="List regions") +parser.add_argument("-v", "--verbose", action="store_true", help="Chatty output") + def main(): - global debug - args = parser.parse_args() + global debug + args = parser.parse_args() + brdfile = args.board + serial = args.serial + binfile = args.file + if args.verbose: + debug = True - brdfile = args.board - serial = args.serial - binfile = args.file - if args.verbose: - debug = True + with open(brdfile) as data_file: + names = json.load(data_file) - with open(brdfile) as data_file: - names = json.load(data_file) + p = Supdate() + p.load_board(brdfile) + p.connect_usb(serialname=serial) + p.load_file(binfile) - p = Supdate() - p.load_board(brdfile) - p.connect_usb(serialname=serial) - p.load_file(binfile) + # List solely prints the config. + if args.list: + return - # List solely prints the config. - if (args.list): - return + # Start transfer and erase. + p.start() + # Upload the bin file + log("Uploading %s" % binfile) + p.write_file() - # Start transfer and erase. - p.start() - # Upload the bin file - log("Uploading %s" % binfile) - p.write_file() + # Finalize + log("Done. Finalizing.") + p.stop() - # Finalize - log("Done. Finalizing.") - p.stop() if __name__ == "__main__": - main() - - + main() diff --git a/extra/usb_updater/servo_updater.py b/extra/usb_updater/servo_updater.py index 432ee120de..5402af70aa 100755 --- a/extra/usb_updater/servo_updater.py +++ b/extra/usb_updater/servo_updater.py @@ -14,40 +14,46 @@ from __future__ import print_function import argparse +import json import os import re import subprocess import time -import json - -import fw_update import ecusb.tiny_servo_common as c +import fw_update from ecusb import tiny_servod + class ServoUpdaterException(Exception): - """Raised on exceptions generated by servo_updater.""" + """Raised on exceptions generated by servo_updater.""" -BOARD_C2D2 = 'c2d2' -BOARD_SERVO_MICRO = 'servo_micro' -BOARD_SERVO_V4 = 'servo_v4' -BOARD_SERVO_V4P1 = 'servo_v4p1' -BOARD_SWEETBERRY = 'sweetberry' + +BOARD_C2D2 = "c2d2" +BOARD_SERVO_MICRO = "servo_micro" +BOARD_SERVO_V4 = "servo_v4" +BOARD_SERVO_V4P1 = "servo_v4p1" +BOARD_SWEETBERRY = "sweetberry" DEFAULT_BOARD = BOARD_SERVO_V4 # These lists are to facilitate exposing choices in the command-line tool # below. -BOARDS = [BOARD_C2D2, BOARD_SERVO_MICRO, BOARD_SERVO_V4, BOARD_SERVO_V4P1, - BOARD_SWEETBERRY] +BOARDS = [ + BOARD_C2D2, + BOARD_SERVO_MICRO, + BOARD_SERVO_V4, + BOARD_SERVO_V4P1, + BOARD_SWEETBERRY, +] # Servo firmware bundles four channels of firmware. We need to make sure the # user does not request a non-existing channel, so keep the lists around to # guard on command-line usage. -DEFAULT_CHANNEL = STABLE_CHANNEL = 'stable' +DEFAULT_CHANNEL = STABLE_CHANNEL = "stable" -PREV_CHANNEL = 'prev' +PREV_CHANNEL = "prev" # The ordering here matters. From left to right it's the channel that the user # is most likely to be running. This is used to inform and warn the user if @@ -55,403 +61,436 @@ PREV_CHANNEL = 'prev' # user know they are running the 'stable' version before letting them know they # are running 'dev' or even 'alpah' which (while true) might cause confusion. -CHANNELS = [DEFAULT_CHANNEL, PREV_CHANNEL, 'dev', 'alpha'] +CHANNELS = [DEFAULT_CHANNEL, PREV_CHANNEL, "dev", "alpha"] -DEFAULT_BASE_PATH = '/usr/' -TEST_IMAGE_BASE_PATH = '/usr/local/' +DEFAULT_BASE_PATH = "/usr/" +TEST_IMAGE_BASE_PATH = "/usr/local/" -COMMON_PATH = 'share/servo_updater' +COMMON_PATH = "share/servo_updater" -FIRMWARE_DIR = 'firmware/' -CONFIGS_DIR = 'configs/' +FIRMWARE_DIR = "firmware/" +CONFIGS_DIR = "configs/" RETRIES_COUNT = 10 RETRIES_DELAY = 1 + def do_with_retries(func, *args): - """Try a function several times - - Call function passed as argument and check if no error happened. - If exception was raised by function, - it will be retried up to RETRIES_COUNT times. - - Args: - func: function that will be called - args: arguments passed to 'func' - - Returns: - If call to function was successful, its result will be returned. - If retries count was exceeded, exception will be raised. - """ - - retry = 0 - while retry < RETRIES_COUNT: - try: - return func(*args) - except Exception as e: - print('Retrying function %s: %s' % (func.__name__, e)) - retry = retry + 1 - time.sleep(RETRIES_DELAY) - continue - - raise Exception("'{}' failed after {} retries".format( - func.__name__, RETRIES_COUNT)) + """Try a function several times + + Call function passed as argument and check if no error happened. + If exception was raised by function, + it will be retried up to RETRIES_COUNT times. + + Args: + func: function that will be called + args: arguments passed to 'func' + + Returns: + If call to function was successful, its result will be returned. + If retries count was exceeded, exception will be raised. + """ + + retry = 0 + while retry < RETRIES_COUNT: + try: + return func(*args) + except Exception as e: + print("Retrying function %s: %s" % (func.__name__, e)) + retry = retry + 1 + time.sleep(RETRIES_DELAY) + continue + + raise Exception("'{}' failed after {} retries".format(func.__name__, RETRIES_COUNT)) + def flash(brdfile, serialno, binfile): - """Call fw_update to upload to updater USB endpoint. + """Call fw_update to upload to updater USB endpoint. + + Args: + brdfile: path to board configuration file + serialno: device serial number + binfile: firmware file + """ - Args: - brdfile: path to board configuration file - serialno: device serial number - binfile: firmware file - """ + p = fw_update.Supdate() + p.load_board(brdfile) + p.connect_usb(serialname=serialno) + p.load_file(binfile) - p = fw_update.Supdate() - p.load_board(brdfile) - p.connect_usb(serialname=serialno) - p.load_file(binfile) + # Start transfer and erase. + p.start() + # Upload the bin file + print("Uploading %s" % binfile) + p.write_file() - # Start transfer and erase. - p.start() - # Upload the bin file - print('Uploading %s' % binfile) - p.write_file() + # Finalize + print("Done. Finalizing.") + p.stop() - # Finalize - print('Done. Finalizing.') - p.stop() def flash2(vidpid, serialno, binfile): - """Call fw update via usb_updater2 commandline. - - Args: - vidpid: vendor id and product id of device - serialno: device serial number (optional) - binfile: firmware file - """ - - tool = 'usb_updater2' - cmd = '%s -d %s' % (tool, vidpid) - if serialno: - cmd += ' -S %s' % serialno - cmd += ' -n' - cmd += ' %s' % binfile - - print(cmd) - help_cmd = '%s --help' % tool - with open('/dev/null') as devnull: - valid_check = subprocess.call(help_cmd.split(), stdout=devnull, - stderr=devnull) - if valid_check: - raise ServoUpdaterException('%s exit with res = %d. Make sure the tool ' - 'is available on the device.' % (help_cmd, - valid_check)) - res = subprocess.call(cmd.split()) - - if res in (0, 1, 2): - return res - else: - raise ServoUpdaterException('%s exit with res = %d' % (cmd, res)) + """Call fw update via usb_updater2 commandline. + + Args: + vidpid: vendor id and product id of device + serialno: device serial number (optional) + binfile: firmware file + """ + + tool = "usb_updater2" + cmd = "%s -d %s" % (tool, vidpid) + if serialno: + cmd += " -S %s" % serialno + cmd += " -n" + cmd += " %s" % binfile + + print(cmd) + help_cmd = "%s --help" % tool + with open("/dev/null") as devnull: + valid_check = subprocess.call(help_cmd.split(), stdout=devnull, stderr=devnull) + if valid_check: + raise ServoUpdaterException( + "%s exit with res = %d. Make sure the tool " + "is available on the device." % (help_cmd, valid_check) + ) + res = subprocess.call(cmd.split()) + + if res in (0, 1, 2): + return res + else: + raise ServoUpdaterException("%s exit with res = %d" % (cmd, res)) + def select(tinys, region): - """Jump to specified boot region + """Jump to specified boot region - Ensure the servo is in the expected ro/rw region. - This function jumps to the required region and verify if jump was - successful by executing 'sysinfo' command and reading current region. - If response was not received or region is invalid, exception is raised. + Ensure the servo is in the expected ro/rw region. + This function jumps to the required region and verify if jump was + successful by executing 'sysinfo' command and reading current region. + If response was not received or region is invalid, exception is raised. - Args: - tinys: TinyServod object - region: region to jump to, only "rw" and "ro" is allowed - """ + Args: + tinys: TinyServod object + region: region to jump to, only "rw" and "ro" is allowed + """ - if region not in ['rw', 'ro']: - raise Exception('Region must be ro or rw') + if region not in ["rw", "ro"]: + raise Exception("Region must be ro or rw") - if region == 'ro': - cmd = 'reboot' - else: - cmd = 'sysjump %s' % region + if region == "ro": + cmd = "reboot" + else: + cmd = "sysjump %s" % region - tinys.pty._issue_cmd(cmd) + tinys.pty._issue_cmd(cmd) - tinys.close() - time.sleep(2) - tinys.reinitialize() + tinys.close() + time.sleep(2) + tinys.reinitialize() + + res = tinys.pty._issue_cmd_get_results("sysinfo", [r"Copy:[\s]+(RO|RW)"]) + current_region = res[0][1].lower() + if current_region != region: + raise Exception("Invalid region: %s/%s" % (current_region, region)) - res = tinys.pty._issue_cmd_get_results('sysinfo', [r'Copy:[\s]+(RO|RW)']) - current_region = res[0][1].lower() - if current_region != region: - raise Exception('Invalid region: %s/%s' % (current_region, region)) def do_version(tinys): - """Check version via ec console 'pty'. + """Check version via ec console 'pty'. + + Args: + tinys: TinyServod object - Args: - tinys: TinyServod object + Returns: + detected version number - Returns: - detected version number + Commands are: + # > version + # ... + # Build: tigertail_v1.1.6749-74d1a312e + """ + cmd = "version" + regex = r"Build:\s+(\S+)[\r\n]+" - Commands are: - # > version - # ... - # Build: tigertail_v1.1.6749-74d1a312e - """ - cmd = 'version' - regex = r'Build:\s+(\S+)[\r\n]+' + results = tinys.pty._issue_cmd_get_results(cmd, [regex])[0] - results = tinys.pty._issue_cmd_get_results(cmd, [regex])[0] + return results[1].strip(" \t\r\n\0") - return results[1].strip(' \t\r\n\0') def do_updater_version(tinys): - """Check whether this uses python updater or c++ updater - - Args: - tinys: TinyServod object - - Returns: - updater version number. 2 or 6. - """ - vers = do_version(tinys) - - # Servo versions below 58 are from servo-9040.B. Versions starting with _v2 - # are newer than anything _v1, no need to check the exact number. Updater - # version is not directly queryable. - if re.search(r'_v[2-9]\.\d', vers): - return 6 - m = re.search(r'_v1\.1\.(\d\d\d\d)', vers) - if m: - version_number = int(m.group(1)) - if version_number < 5800: - return 2 - else: - return 6 - raise ServoUpdaterException( - "Can't determine updater target from vers: [%s]" % vers) + """Check whether this uses python updater or c++ updater -def _extract_version(boardname, binfile): - """Find the version string from |binfile|. + Args: + tinys: TinyServod object - Args: - boardname: the name of the board, eg. "servo_micro" - binfile: path to the binary to search + Returns: + updater version number. 2 or 6. + """ + vers = do_version(tinys) - Returns: - the version string. - """ - if boardname is None: - # cannot extract the version if the name is None - return None - rawstrings = subprocess.check_output( - ['cbfstool', binfile, 'read', '-r', 'RO_FRID', '-f', '/dev/stdout'], - **c.get_subprocess_args()) - m = re.match(r'%s_v\S+' % boardname, rawstrings) - if m: - newvers = m.group(0).strip(' \t\r\n\0') - else: - raise ServoUpdaterException("Can't find version from file: %s." % binfile) + # Servo versions below 58 are from servo-9040.B. Versions starting with _v2 + # are newer than anything _v1, no need to check the exact number. Updater + # version is not directly queryable. + if re.search(r"_v[2-9]\.\d", vers): + return 6 + m = re.search(r"_v1\.1\.(\d\d\d\d)", vers) + if m: + version_number = int(m.group(1)) + if version_number < 5800: + return 2 + else: + return 6 + raise ServoUpdaterException("Can't determine updater target from vers: [%s]" % vers) + + +def _extract_version(boardname, binfile): + """Find the version string from |binfile|. + + Args: + boardname: the name of the board, eg. "servo_micro" + binfile: path to the binary to search + + Returns: + the version string. + """ + if boardname is None: + # cannot extract the version if the name is None + return None + rawstrings = subprocess.check_output( + ["cbfstool", binfile, "read", "-r", "RO_FRID", "-f", "/dev/stdout"], + **c.get_subprocess_args() + ) + m = re.match(r"%s_v\S+" % boardname, rawstrings) + if m: + newvers = m.group(0).strip(" \t\r\n\0") + else: + raise ServoUpdaterException("Can't find version from file: %s." % binfile) + + return newvers - return newvers def get_firmware_channel(bname, version): - """Find out which channel |version| for |bname| came from. - - Args: - bname: board name - version: current version string - - Returns: - one of the channel names if |version| came from one of those, or None - """ - for channel in CHANNELS: - # Pass |bname| as cname to find the board specific file, and pass None as - # fname to ensure the default directory is searched - _, _, vers = get_files_and_version(bname, None, channel=channel) - if version == vers: - return channel - # None of the channels matched. This firmware is currently unknown. - return None + """Find out which channel |version| for |bname| came from. + + Args: + bname: board name + version: current version string + + Returns: + one of the channel names if |version| came from one of those, or None + """ + for channel in CHANNELS: + # Pass |bname| as cname to find the board specific file, and pass None as + # fname to ensure the default directory is searched + _, _, vers = get_files_and_version(bname, None, channel=channel) + if version == vers: + return channel + # None of the channels matched. This firmware is currently unknown. + return None + def get_files_and_version(cname, fname=None, channel=DEFAULT_CHANNEL): - """Select config and firmware binary files. - - This checks default file names and paths. - In: /usr/share/servo_updater/[firmware|configs] - check for board.json, board.bin - - Args: - cname: board name, or config name. eg. "servo_v4" or "servo_v4.json" - fname: firmware binary name. Can be None to try default. - channel: the channel requested for servo firmware. See |CHANNELS| above. - - Returns: - cname, fname, version: validated filenames selected from the path. - """ - for p in (DEFAULT_BASE_PATH, TEST_IMAGE_BASE_PATH): - updater_path = os.path.join(p, COMMON_PATH) - if os.path.exists(updater_path): - break - else: - raise ServoUpdaterException('servo_updater/ dir not found in known spots.') - - firmware_path = os.path.join(updater_path, FIRMWARE_DIR) - configs_path = os.path.join(updater_path, CONFIGS_DIR) - - for p in (firmware_path, configs_path): - if not os.path.exists(p): - raise ServoUpdaterException('Could not find required path %r' % p) - - if not os.path.isfile(cname): - # If not an existing file, try checking on the default path. - newname = os.path.join(configs_path, cname) - if os.path.isfile(newname): - cname = newname + """Select config and firmware binary files. + + This checks default file names and paths. + In: /usr/share/servo_updater/[firmware|configs] + check for board.json, board.bin + + Args: + cname: board name, or config name. eg. "servo_v4" or "servo_v4.json" + fname: firmware binary name. Can be None to try default. + channel: the channel requested for servo firmware. See |CHANNELS| above. + + Returns: + cname, fname, version: validated filenames selected from the path. + """ + for p in (DEFAULT_BASE_PATH, TEST_IMAGE_BASE_PATH): + updater_path = os.path.join(p, COMMON_PATH) + if os.path.exists(updater_path): + break else: - # Try appending ".json" to convert board name to config file. - cname = newname + '.json' + raise ServoUpdaterException("servo_updater/ dir not found in known spots.") + + firmware_path = os.path.join(updater_path, FIRMWARE_DIR) + configs_path = os.path.join(updater_path, CONFIGS_DIR) + + for p in (firmware_path, configs_path): + if not os.path.exists(p): + raise ServoUpdaterException("Could not find required path %r" % p) + if not os.path.isfile(cname): - raise ServoUpdaterException("Can't find config file: %s." % cname) - - # Always retrieve the boardname - with open(cname) as data_file: - data = json.load(data_file) - boardname = data['board'] - - if not fname: - # If no |fname| supplied, look for the default locations with the board - # and channel requested. - binary_file = '%s.%s.bin' % (boardname, channel) - newname = os.path.join(firmware_path, binary_file) - if os.path.isfile(newname): - fname = newname - else: - raise ServoUpdaterException("Can't find firmware binary: %s." % - binary_file) - elif not os.path.isfile(fname): - # If a name is specified but not found, try the default path. - newname = os.path.join(firmware_path, fname) - if os.path.isfile(newname): - fname = newname + # If not an existing file, try checking on the default path. + newname = os.path.join(configs_path, cname) + if os.path.isfile(newname): + cname = newname + else: + # Try appending ".json" to convert board name to config file. + cname = newname + ".json" + if not os.path.isfile(cname): + raise ServoUpdaterException("Can't find config file: %s." % cname) + + # Always retrieve the boardname + with open(cname) as data_file: + data = json.load(data_file) + boardname = data["board"] + + if not fname: + # If no |fname| supplied, look for the default locations with the board + # and channel requested. + binary_file = "%s.%s.bin" % (boardname, channel) + newname = os.path.join(firmware_path, binary_file) + if os.path.isfile(newname): + fname = newname + else: + raise ServoUpdaterException("Can't find firmware binary: %s." % binary_file) + elif not os.path.isfile(fname): + # If a name is specified but not found, try the default path. + newname = os.path.join(firmware_path, fname) + if os.path.isfile(newname): + fname = newname + else: + raise ServoUpdaterException("Can't find file: %s." % fname) + + # Lastly, retrieve the version as well for decision making, debug, and + # informational purposes. + binvers = _extract_version(boardname, fname) + + return cname, fname, binvers + + +def main(): + parser = argparse.ArgumentParser(description="Image a servo device") + parser.add_argument( + "-p", + "--print", + dest="print_only", + action="store_true", + default=False, + help="only print available firmware for board/channel", + ) + parser.add_argument( + "-s", "--serialno", type=str, help="serial number to program", default=None + ) + parser.add_argument( + "-b", + "--board", + type=str, + help="Board configuration json file", + default=DEFAULT_BOARD, + choices=BOARDS, + ) + parser.add_argument( + "-c", + "--channel", + type=str, + help="Firmware channel to use", + default=DEFAULT_CHANNEL, + choices=CHANNELS, + ) + parser.add_argument( + "-f", "--file", type=str, help="Complete ec.bin file", default=None + ) + parser.add_argument( + "--force", + action="store_true", + help="Update even if version match", + default=False, + ) + parser.add_argument("-v", "--verbose", action="store_true", help="Chatty output") + parser.add_argument( + "-r", "--reboot", action="store_true", help="Always reboot, even after probe." + ) + + args = parser.parse_args() + + brdfile, binfile, newvers = get_files_and_version( + args.board, args.file, args.channel + ) + + # If the user only cares about the information then just print it here, + # and exit. + if args.print_only: + output = ("board: %s\n" "channel: %s\n" "firmware: %s") % ( + args.board, + args.channel, + newvers, + ) + print(output) + return + + serialno = args.serialno + + with open(brdfile) as data_file: + data = json.load(data_file) + vid, pid = int(data["vid"], 0), int(data["pid"], 0) + vidpid = "%04x:%04x" % (vid, pid) + iface = int(data["console"], 0) + boardname = data["board"] + + # Make sure device is up. + print("===== Waiting for USB device =====") + c.wait_for_usb(vidpid, serialname=serialno) + # We need a tiny_servod to query some information. Set it up first. + tinys = tiny_servod.TinyServod(vid, pid, iface, serialno, args.verbose) + + if not args.force: + vers = do_version(tinys) + print("Current %s version is %s" % (boardname, vers)) + print("Available %s version is %s" % (boardname, newvers)) + + if newvers == vers: + print("No version update needed") + if args.reboot: + select(tinys, "ro") + return + else: + print("Updating to recommended version.") + + # Make sure the servo MCU is in RO + print("===== Jumping to RO =====") + do_with_retries(select, tinys, "ro") + + print("===== Flashing RW =====") + vers = do_with_retries(do_updater_version, tinys) + # To make sure that the tiny_servod here does not interfere with other + # processes, close it out. + tinys.close() + + if vers == 2: + flash(brdfile, serialno, binfile) + elif vers == 6: + do_with_retries(flash2, vidpid, serialno, binfile) else: - raise ServoUpdaterException("Can't find file: %s." % fname) + raise ServoUpdaterException("Can't detect updater version") - # Lastly, retrieve the version as well for decision making, debug, and - # informational purposes. - binvers = _extract_version(boardname, fname) + # Make sure device is up. + c.wait_for_usb(vidpid, serialname=serialno) + # After we have made sure that it's back/available, reconnect the tiny servod. + tinys.reinitialize() - return cname, fname, binvers + # Make sure the servo MCU is in RW + print("===== Jumping to RW =====") + do_with_retries(select, tinys, "rw") -def main(): - parser = argparse.ArgumentParser(description='Image a servo device') - parser.add_argument('-p', '--print', dest='print_only', action='store_true', - default=False, - help='only print available firmware for board/channel') - parser.add_argument('-s', '--serialno', type=str, - help='serial number to program', default=None) - parser.add_argument('-b', '--board', type=str, - help='Board configuration json file', - default=DEFAULT_BOARD, choices=BOARDS) - parser.add_argument('-c', '--channel', type=str, - help='Firmware channel to use', - default=DEFAULT_CHANNEL, choices=CHANNELS) - parser.add_argument('-f', '--file', type=str, - help='Complete ec.bin file', default=None) - parser.add_argument('--force', action='store_true', - help='Update even if version match', default=False) - parser.add_argument('-v', '--verbose', action='store_true', - help='Chatty output') - parser.add_argument('-r', '--reboot', action='store_true', - help='Always reboot, even after probe.') - - args = parser.parse_args() - - brdfile, binfile, newvers = get_files_and_version(args.board, args.file, - args.channel) - - # If the user only cares about the information then just print it here, - # and exit. - if args.print_only: - output = ('board: %s\n' - 'channel: %s\n' - 'firmware: %s') % (args.board, args.channel, newvers) - print(output) - return - - serialno = args.serialno - - with open(brdfile) as data_file: - data = json.load(data_file) - vid, pid = int(data['vid'], 0), int(data['pid'], 0) - vidpid = '%04x:%04x' % (vid, pid) - iface = int(data['console'], 0) - boardname = data['board'] - - # Make sure device is up. - print('===== Waiting for USB device =====') - c.wait_for_usb(vidpid, serialname=serialno) - # We need a tiny_servod to query some information. Set it up first. - tinys = tiny_servod.TinyServod(vid, pid, iface, serialno, args.verbose) - - if not args.force: - vers = do_version(tinys) - print('Current %s version is %s' % (boardname, vers)) - print('Available %s version is %s' % (boardname, newvers)) - - if newvers == vers: - print('No version update needed') - if args.reboot: - select(tinys, 'ro') - return + print("===== Flashing RO =====") + vers = do_with_retries(do_updater_version, tinys) + + if vers == 2: + flash(brdfile, serialno, binfile) + elif vers == 6: + do_with_retries(flash2, vidpid, serialno, binfile) else: - print('Updating to recommended version.') - - # Make sure the servo MCU is in RO - print('===== Jumping to RO =====') - do_with_retries(select, tinys, 'ro') - - print('===== Flashing RW =====') - vers = do_with_retries(do_updater_version, tinys) - # To make sure that the tiny_servod here does not interfere with other - # processes, close it out. - tinys.close() - - if vers == 2: - flash(brdfile, serialno, binfile) - elif vers == 6: - do_with_retries(flash2, vidpid, serialno, binfile) - else: - raise ServoUpdaterException("Can't detect updater version") - - # Make sure device is up. - c.wait_for_usb(vidpid, serialname=serialno) - # After we have made sure that it's back/available, reconnect the tiny servod. - tinys.reinitialize() - - # Make sure the servo MCU is in RW - print('===== Jumping to RW =====') - do_with_retries(select, tinys, 'rw') - - print('===== Flashing RO =====') - vers = do_with_retries(do_updater_version, tinys) - - if vers == 2: - flash(brdfile, serialno, binfile) - elif vers == 6: - do_with_retries(flash2, vidpid, serialno, binfile) - else: - raise ServoUpdaterException("Can't detect updater version") - - # Make sure the servo MCU is in RO - print('===== Rebooting =====') - do_with_retries(select, tinys, 'ro') - # Perform additional reboot to free USB/UART resources, taken by tiny servod. - # See https://issuetracker.google.com/196021317 for background. - tinys.pty._issue_cmd('reboot') - - print('===== Finished =====') - -if __name__ == '__main__': - main() + raise ServoUpdaterException("Can't detect updater version") + + # Make sure the servo MCU is in RO + print("===== Rebooting =====") + do_with_retries(select, tinys, "ro") + # Perform additional reboot to free USB/UART resources, taken by tiny servod. + # See https://issuetracker.google.com/196021317 for background. + tinys.pty._issue_cmd("reboot") + + print("===== Finished =====") + + +if __name__ == "__main__": + main() diff --git a/firmware_builder.py b/firmware_builder.py index 78b7614190..d2548c3f5f 100755 --- a/firmware_builder.py +++ b/firmware_builder.py @@ -16,21 +16,20 @@ import pathlib import subprocess import sys -# pylint: disable=import-error -from google.protobuf import json_format - from chromite.api.gen_sdk.chromite.api import firmware_pb2 +# pylint: disable=import-error +from google.protobuf import json_format -DEFAULT_BUNDLE_DIRECTORY = '/tmp/artifact_bundles' -DEFAULT_BUNDLE_METADATA_FILE = '/tmp/artifact_bundle_metadata' +DEFAULT_BUNDLE_DIRECTORY = "/tmp/artifact_bundles" +DEFAULT_BUNDLE_METADATA_FILE = "/tmp/artifact_bundle_metadata" # The the list of boards whose on-device unit tests we will verify compilation. # TODO(b/172501728) On-device unit tests should build for all boards, but # they've bit rotted, so we only build the ones that compile. BOARDS_UNIT_TEST = [ - 'bloonchipper', - 'dartmonkey', + "bloonchipper", + "dartmonkey", ] @@ -51,59 +50,57 @@ def build(opts): "When --code-coverage is selected, 'build' is a no-op. " "Run 'test' with --code-coverage instead." ) - with open(opts.metrics, 'w') as f: + with open(opts.metrics, "w") as f: f.write(json_format.MessageToJson(metric_list)) return ec_dir = pathlib.Path(__file__).parent subprocess.run([ec_dir / "util" / "check_clang_format.py"], check=True) - cmd = ['make', 'buildall_only', f'-j{opts.cpus}'] + cmd = ["make", "buildall_only", f"-j{opts.cpus}"] print(f"# Running {' '.join(cmd)}.") subprocess.run(cmd, cwd=os.path.dirname(__file__), check=True) ec_dir = os.path.dirname(__file__) - build_dir = os.path.join(ec_dir, 'build') + build_dir = os.path.join(ec_dir, "build") for build_target in sorted(os.listdir(build_dir)): metric = metric_list.value.add() metric.target_name = build_target - metric.platform_name = 'ec' - for variant in ['RO', 'RW']: + metric.platform_name = "ec" + for variant in ["RO", "RW"]: memsize_file = ( pathlib.Path(build_dir) / build_target / variant - / f'ec.{variant}.elf.memsize.txt' + / f"ec.{variant}.elf.memsize.txt" ) if memsize_file.exists(): parse_memsize(memsize_file, metric, variant) - with open(opts.metrics, 'w') as f: + with open(opts.metrics, "w") as f: f.write(json_format.MessageToJson(metric_list)) # Ensure that there are no regressions for boards that build successfully # with clang: b/172020503. - cmd = ['./util/build_with_clang.py'] + cmd = ["./util/build_with_clang.py"] print(f'# Running {" ".join(cmd)}.') - subprocess.run(cmd, - cwd=os.path.dirname(__file__), - check=True) + subprocess.run(cmd, cwd=os.path.dirname(__file__), check=True) UNITS = { - 'B': 1, - 'KB': 1024, - 'MB': 1024 * 1024, - 'GB': 1024 * 1024 * 1024, + "B": 1, + "KB": 1024, + "MB": 1024 * 1024, + "GB": 1024 * 1024 * 1024, } def parse_memsize(filename, metric, variant): - with open(filename, 'r') as infile: + with open(filename, "r") as infile: # Skip header line infile.readline() for line in infile.readlines(): parts = line.split() fw_section = metric.fw_section.add() - fw_section.region = variant + '_' + parts[0][:-1] + fw_section.region = variant + "_" + parts[0][:-1] fw_section.used = int(parts[1]) * UNITS[parts[2]] fw_section.total = int(parts[3]) * UNITS[parts[4]] fw_section.track_on_gerrit = False @@ -135,7 +132,7 @@ def write_metadata(opts, info): bundle_metadata_file = ( opts.metadata if opts.metadata else DEFAULT_BUNDLE_METADATA_FILE ) - with open(bundle_metadata_file, 'w') as f: + with open(bundle_metadata_file, "w") as f: f.write(json_format.MessageToJson(info)) @@ -145,10 +142,10 @@ def bundle_coverage(opts): info.bcs_version_info.version_string = opts.bcs_version bundle_dir = get_bundle_dir(opts) ec_dir = os.path.dirname(__file__) - tarball_name = 'coverage.tbz2' + tarball_name = "coverage.tbz2" tarball_path = os.path.join(bundle_dir, tarball_name) - cmd = ['tar', 'cvfj', tarball_path, 'lcov.info'] - subprocess.run(cmd, cwd=os.path.join(ec_dir, 'build/coverage'), check=True) + cmd = ["tar", "cvfj", tarball_path, "lcov.info"] + subprocess.run(cmd, cwd=os.path.join(ec_dir, "build/coverage"), check=True) meta = info.objects.add() meta.file_name = tarball_name meta.lcov_info.type = ( @@ -164,16 +161,20 @@ def bundle_firmware(opts): info.bcs_version_info.version_string = opts.bcs_version bundle_dir = get_bundle_dir(opts) ec_dir = os.path.dirname(__file__) - for build_target in sorted(os.listdir(os.path.join(ec_dir, 'build'))): - tarball_name = ''.join([build_target, '.firmware.tbz2']) + for build_target in sorted(os.listdir(os.path.join(ec_dir, "build"))): + tarball_name = "".join([build_target, ".firmware.tbz2"]) tarball_path = os.path.join(bundle_dir, tarball_name) cmd = [ - 'tar', 'cvfj', tarball_path, - '--exclude=*.o.d', '--exclude=*.o', '.', + "tar", + "cvfj", + tarball_path, + "--exclude=*.o.d", + "--exclude=*.o", + ".", ] subprocess.run( cmd, - cwd=os.path.join(ec_dir, 'build', build_target), + cwd=os.path.join(ec_dir, "build", build_target), check=True, ) meta = info.objects.add() @@ -191,7 +192,7 @@ def test(opts): """Runs all of the unit tests for EC firmware""" # TODO(b/169178847): Add appropriate metric information metrics = firmware_pb2.FwTestMetricList() - with open(opts.metrics, 'w') as f: + with open(opts.metrics, "w") as f: f.write(json_format.MessageToJson(metrics)) # If building for code coverage, build the 'coverage' target, which @@ -200,8 +201,8 @@ def test(opts): # # Otherwise, build the 'runtests' target, which verifies all # posix-based unit tests build and pass. - target = 'coverage' if opts.code_coverage else 'runtests' - cmd = ['make', target, f'-j{opts.cpus}'] + target = "coverage" if opts.code_coverage else "runtests" + cmd = ["make", target, f"-j{opts.cpus}"] print(f"# Running {' '.join(cmd)}.") subprocess.run(cmd, cwd=os.path.dirname(__file__), check=True) @@ -209,13 +210,13 @@ def test(opts): # Verify compilation of the on-device unit test binaries. # TODO(b/172501728) These should build for all boards, but they've bit # rotted, so we only build the ones that compile. - cmd = ['make', f'-j{opts.cpus}'] - cmd.extend(['tests-' + b for b in BOARDS_UNIT_TEST]) + cmd = ["make", f"-j{opts.cpus}"] + cmd.extend(["tests-" + b for b in BOARDS_UNIT_TEST]) print(f"# Running {' '.join(cmd)}.") subprocess.run(cmd, cwd=os.path.dirname(__file__), check=True) # Verify the tests pass with ASan also - cmd = ['make', 'TEST_ASAN=y', target, f'-j{opts.cpus}'] + cmd = ["make", "TEST_ASAN=y", target, f"-j{opts.cpus}"] print(f"# Running {' '.join(cmd)}.") subprocess.run(cmd, cwd=os.path.dirname(__file__), check=True) @@ -227,8 +228,8 @@ def main(args): """ opts = parse_args(args) - if not hasattr(opts, 'func'): - print('Must select a valid sub command!') + if not hasattr(opts, "func"): + print("Must select a valid sub command!") return -1 # Run selected sub command function @@ -244,66 +245,64 @@ def parse_args(args): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - '--cpus', + "--cpus", default=multiprocessing.cpu_count(), - help='The number of cores to use.', + help="The number of cores to use.", ) parser.add_argument( - '--metrics', - dest='metrics', + "--metrics", + dest="metrics", required=True, - help='File to write the json-encoded MetricsList proto message.', + help="File to write the json-encoded MetricsList proto message.", ) parser.add_argument( - '--metadata', + "--metadata", required=False, - help='Full pathname for the file in which to write build artifact ' - 'metadata.', + help="Full pathname for the file in which to write build artifact " "metadata.", ) parser.add_argument( - '--output-dir', + "--output-dir", required=False, - help='Full pathanme for the directory in which to bundle build ' - 'artifacts.', + help="Full pathanme for the directory in which to bundle build " "artifacts.", ) parser.add_argument( - '--code-coverage', + "--code-coverage", required=False, - action='store_true', - help='Build host-based unit tests for code coverage.', + action="store_true", + help="Build host-based unit tests for code coverage.", ) parser.add_argument( - '--bcs-version', - dest='bcs_version', - default='', + "--bcs-version", + dest="bcs_version", + default="", required=False, # TODO(b/180008931): make this required=True. - help='BCS version to include in metadata.', + help="BCS version to include in metadata.", ) # Would make this required=True, but not available until 3.7 sub_cmds = parser.add_subparsers() - build_cmd = sub_cmds.add_parser('build', help='Builds all firmware targets') + build_cmd = sub_cmds.add_parser("build", help="Builds all firmware targets") build_cmd.set_defaults(func=build) build_cmd = sub_cmds.add_parser( - 'bundle', - help='Creates a tarball containing build ' - 'artifacts from all firmware targets', + "bundle", + help="Creates a tarball containing build " + "artifacts from all firmware targets", ) build_cmd.set_defaults(func=bundle) - test_cmd = sub_cmds.add_parser('test', help='Runs all firmware unit tests') + test_cmd = sub_cmds.add_parser("test", help="Runs all firmware unit tests") test_cmd.set_defaults(func=test) return parser.parse_args(args) -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main(sys.argv[1:])) @@ -10,7 +10,7 @@ setup( author="Aseda Aboagye", author_email="aaboagye@chromium.org", url="https://www.chromium.org/chromium-os/ec-development", - package_dir={"" : "util"}, + package_dir={"": "util"}, packages=["ec3po"], py_modules=["ec3po.console", "ec3po.interpreter"], description="EC console interpreter.", @@ -22,7 +22,7 @@ setup( author="Nick Sanders", author_email="nsanders@chromium.org", url="https://www.chromium.org/chromium-os/ec-development", - package_dir={"" : "extra/tigertool"}, + package_dir={"": "extra/tigertool"}, packages=["ecusb"], description="Tiny implementation of servod.", ) @@ -33,17 +33,23 @@ setup( author="Nick Sanders", author_email="nsanders@chromium.org", url="https://www.chromium.org/chromium-os/ec-development", - package_dir={"" : "extra/usb_updater"}, + package_dir={"": "extra/usb_updater"}, py_modules=["servo_updater", "fw_update"], - entry_points = { + entry_points={ "console_scripts": ["servo_updater=servo_updater:main"], }, - data_files=[("share/servo_updater/configs", - ["extra/usb_updater/c2d2.json", - "extra/usb_updater/servo_v4.json", - "extra/usb_updater/servo_v4p1.json", - "extra/usb_updater/servo_micro.json", - "extra/usb_updater/sweetberry.json"])], + data_files=[ + ( + "share/servo_updater/configs", + [ + "extra/usb_updater/c2d2.json", + "extra/usb_updater/servo_v4.json", + "extra/usb_updater/servo_v4p1.json", + "extra/usb_updater/servo_micro.json", + "extra/usb_updater/sweetberry.json", + ], + ) + ], description="Servo usb updater.", ) @@ -53,9 +59,9 @@ setup( author="Nick Sanders", author_email="nsanders@chromium.org", url="https://www.chromium.org/chromium-os/ec-development", - package_dir={"" : "extra/usb_power"}, + package_dir={"": "extra/usb_power"}, py_modules=["powerlog", "stats_manager"], - entry_points = { + entry_points={ "console_scripts": ["powerlog=powerlog:main"], }, description="Sweetberry power logger.", @@ -67,9 +73,9 @@ setup( author="Nick Sanders", author_email="nsanders@chromium.org", url="https://www.chromium.org/chromium-os/ec-development", - package_dir={"" : "extra/usb_serial"}, + package_dir={"": "extra/usb_serial"}, py_modules=["console"], - entry_points = { + entry_points={ "console_scripts": ["usb_console=console:main"], }, description="Tool to open the usb console on servo, cr50.", @@ -81,11 +87,10 @@ setup( author="Wei-Han Chen", author_email="stimim@chromium.org", url="https://www.chromium.org/chromium-os/ec-development", - package_dir={"" : "util"}, + package_dir={"": "util"}, py_modules=["unpack_ftb"], - entry_points = { + entry_points={ "console_scripts": ["unpack_ftb=unpack_ftb:main"], }, description="Tool to convert ST touchpad .ftb file to .bin", ) - diff --git a/test/run_device_tests.py b/test/run_device_tests.py index 8de2fa417c..09f255f43e 100755 --- a/test/run_device_tests.py +++ b/test/run_device_tests.py @@ -51,63 +51,74 @@ import time from concurrent.futures.thread import ThreadPoolExecutor from enum import Enum from pathlib import Path -from typing import Optional, BinaryIO, List +from typing import BinaryIO, List, Optional # pylint: disable=import-error import colorama # type: ignore[import] -from contextlib2 import ExitStack import fmap +from contextlib2 import ExitStack + # pylint: enable=import-error EC_DIR = Path(os.path.dirname(os.path.realpath(__file__))).parent -JTRACE_FLASH_SCRIPT = os.path.join(EC_DIR, 'util/flash_jlink.py') -SERVO_MICRO_FLASH_SCRIPT = os.path.join(EC_DIR, 'util/flash_ec') +JTRACE_FLASH_SCRIPT = os.path.join(EC_DIR, "util/flash_jlink.py") +SERVO_MICRO_FLASH_SCRIPT = os.path.join(EC_DIR, "util/flash_ec") -ALL_TESTS_PASSED_REGEX = re.compile(r'Pass!\r\n') -ALL_TESTS_FAILED_REGEX = re.compile(r'Fail! \(\d+ tests\)\r\n') +ALL_TESTS_PASSED_REGEX = re.compile(r"Pass!\r\n") +ALL_TESTS_FAILED_REGEX = re.compile(r"Fail! \(\d+ tests\)\r\n") -SINGLE_CHECK_PASSED_REGEX = re.compile(r'Pass: .*') -SINGLE_CHECK_FAILED_REGEX = re.compile(r'.*failed:.*') +SINGLE_CHECK_PASSED_REGEX = re.compile(r"Pass: .*") +SINGLE_CHECK_FAILED_REGEX = re.compile(r".*failed:.*") -ASSERTION_FAILURE_REGEX = re.compile(r'ASSERTION FAILURE.*') +ASSERTION_FAILURE_REGEX = re.compile(r"ASSERTION FAILURE.*") DATA_ACCESS_VIOLATION_8020000_REGEX = re.compile( - r'Data access violation, mfar = 8020000\r\n') + r"Data access violation, mfar = 8020000\r\n" +) DATA_ACCESS_VIOLATION_8040000_REGEX = re.compile( - r'Data access violation, mfar = 8040000\r\n') + r"Data access violation, mfar = 8040000\r\n" +) DATA_ACCESS_VIOLATION_80C0000_REGEX = re.compile( - r'Data access violation, mfar = 80c0000\r\n') + r"Data access violation, mfar = 80c0000\r\n" +) DATA_ACCESS_VIOLATION_80E0000_REGEX = re.compile( - r'Data access violation, mfar = 80e0000\r\n') + r"Data access violation, mfar = 80e0000\r\n" +) DATA_ACCESS_VIOLATION_20000000_REGEX = re.compile( - r'Data access violation, mfar = 20000000\r\n') + r"Data access violation, mfar = 20000000\r\n" +) DATA_ACCESS_VIOLATION_24000000_REGEX = re.compile( - r'Data access violation, mfar = 24000000\r\n') + r"Data access violation, mfar = 24000000\r\n" +) -BLOONCHIPPER = 'bloonchipper' -DARTMONKEY = 'dartmonkey' +BLOONCHIPPER = "bloonchipper" +DARTMONKEY = "dartmonkey" -JTRACE = 'jtrace' -SERVO_MICRO = 'servo_micro' +JTRACE = "jtrace" +SERVO_MICRO = "servo_micro" -GCC = 'gcc' -CLANG = 'clang' +GCC = "gcc" +CLANG = "clang" -TEST_ASSETS_BUCKET = 'gs://chromiumos-test-assets-public/fpmcu/RO' +TEST_ASSETS_BUCKET = "gs://chromiumos-test-assets-public/fpmcu/RO" DARTMONKEY_IMAGE_PATH = os.path.join( - TEST_ASSETS_BUCKET, 'dartmonkey_v2.0.2887-311310808.bin') + TEST_ASSETS_BUCKET, "dartmonkey_v2.0.2887-311310808.bin" +) NOCTURNE_FP_IMAGE_PATH = os.path.join( - TEST_ASSETS_BUCKET, 'nocturne_fp_v2.2.64-58cf5974e.bin') -NAMI_FP_IMAGE_PATH = os.path.join( - TEST_ASSETS_BUCKET, 'nami_fp_v2.2.144-7a08e07eb.bin') + TEST_ASSETS_BUCKET, "nocturne_fp_v2.2.64-58cf5974e.bin" +) +NAMI_FP_IMAGE_PATH = os.path.join(TEST_ASSETS_BUCKET, "nami_fp_v2.2.144-7a08e07eb.bin") BLOONCHIPPER_V4277_IMAGE_PATH = os.path.join( - TEST_ASSETS_BUCKET, 'bloonchipper_v2.0.4277-9f652bb3.bin') + TEST_ASSETS_BUCKET, "bloonchipper_v2.0.4277-9f652bb3.bin" +) BLOONCHIPPER_V5938_IMAGE_PATH = os.path.join( - TEST_ASSETS_BUCKET, 'bloonchipper_v2.0.5938-197506c1.bin') + TEST_ASSETS_BUCKET, "bloonchipper_v2.0.5938-197506c1.bin" +) class ImageType(Enum): """EC Image type to use for the test.""" + RO = 1 RW = 2 @@ -115,9 +126,16 @@ class ImageType(Enum): class BoardConfig: """Board-specific configuration.""" - def __init__(self, name, servo_uart_name, servo_power_enable, - rollback_region0_regex, rollback_region1_regex, mpu_regex, - variants): + def __init__( + self, + name, + servo_uart_name, + servo_power_enable, + rollback_region0_regex, + rollback_region1_regex, + mpu_regex, + variants, + ): self.name = name self.servo_uart_name = servo_uart_name self.servo_power_enable = servo_power_enable @@ -130,18 +148,31 @@ class BoardConfig: class TestConfig: """Configuration for a given test.""" - def __init__(self, test_name, image_to_use=ImageType.RW, - finish_regexes=None, fail_regexes=None, toggle_power=False, - test_args=None, num_flash_attempts=2, timeout_secs=10, - enable_hw_write_protect=False, ro_image=None, build_board=None, - config_name=None): + def __init__( + self, + test_name, + image_to_use=ImageType.RW, + finish_regexes=None, + fail_regexes=None, + toggle_power=False, + test_args=None, + num_flash_attempts=2, + timeout_secs=10, + enable_hw_write_protect=False, + ro_image=None, + build_board=None, + config_name=None, + ): if test_args is None: test_args = [] if finish_regexes is None: finish_regexes = [ALL_TESTS_PASSED_REGEX, ALL_TESTS_FAILED_REGEX] if fail_regexes is None: - fail_regexes = [SINGLE_CHECK_FAILED_REGEX, ALL_TESTS_FAILED_REGEX, - ASSERTION_FAILURE_REGEX] + fail_regexes = [ + SINGLE_CHECK_FAILED_REGEX, + ALL_TESTS_FAILED_REGEX, + ASSERTION_FAILURE_REGEX, + ] if config_name is None: config_name = test_name @@ -177,68 +208,104 @@ class AllTests: @staticmethod def get_public_tests(board_config: BoardConfig) -> List[TestConfig]: tests = [ - TestConfig(test_name='aes'), - TestConfig(test_name='cec'), - TestConfig(test_name='cortexm_fpu'), - TestConfig(test_name='crc'), - TestConfig(test_name='flash_physical', image_to_use=ImageType.RO, - toggle_power=True), - TestConfig(test_name='flash_write_protect', - image_to_use=ImageType.RO, - toggle_power=True, enable_hw_write_protect=True), - TestConfig(test_name='fpsensor_hw'), - TestConfig(config_name='fpsensor_spi_ro', test_name='fpsensor', - image_to_use=ImageType.RO, test_args=['spi']), - TestConfig(config_name='fpsensor_spi_rw', test_name='fpsensor', - test_args=['spi']), - TestConfig(config_name='fpsensor_uart_ro', test_name='fpsensor', - image_to_use=ImageType.RO, test_args=['uart']), - TestConfig(config_name='fpsensor_uart_rw', test_name='fpsensor', - test_args=['uart']), - TestConfig(config_name='mpu_ro', test_name='mpu', - image_to_use=ImageType.RO, - finish_regexes=[board_config.mpu_regex]), - TestConfig(config_name='mpu_rw', test_name='mpu', - finish_regexes=[board_config.mpu_regex]), - TestConfig(test_name='mutex'), - TestConfig(test_name='pingpong'), - TestConfig(test_name='printf'), - TestConfig(test_name='queue'), - TestConfig(config_name='rollback_region0', test_name='rollback', - finish_regexes=[board_config.rollback_region0_regex], - test_args=['region0']), - TestConfig(config_name='rollback_region1', test_name='rollback', - finish_regexes=[board_config.rollback_region1_regex], - test_args=['region1']), - TestConfig(test_name='rollback_entropy', image_to_use=ImageType.RO), - TestConfig(test_name='rtc'), - TestConfig(test_name='sha256'), - TestConfig(test_name='sha256_unrolled'), - TestConfig(test_name='static_if'), - TestConfig(test_name='stdlib'), - TestConfig(config_name='system_is_locked_wp_on', - test_name='system_is_locked', test_args=['wp_on'], - toggle_power=True, enable_hw_write_protect=True), - TestConfig(config_name='system_is_locked_wp_off', - test_name='system_is_locked', test_args=['wp_off'], - toggle_power=True, enable_hw_write_protect=False), - TestConfig(test_name='timer_dos'), - TestConfig(test_name='utils', timeout_secs=20), - TestConfig(test_name='utils_str'), + TestConfig(test_name="aes"), + TestConfig(test_name="cec"), + TestConfig(test_name="cortexm_fpu"), + TestConfig(test_name="crc"), + TestConfig( + test_name="flash_physical", image_to_use=ImageType.RO, toggle_power=True + ), + TestConfig( + test_name="flash_write_protect", + image_to_use=ImageType.RO, + toggle_power=True, + enable_hw_write_protect=True, + ), + TestConfig(test_name="fpsensor_hw"), + TestConfig( + config_name="fpsensor_spi_ro", + test_name="fpsensor", + image_to_use=ImageType.RO, + test_args=["spi"], + ), + TestConfig( + config_name="fpsensor_spi_rw", test_name="fpsensor", test_args=["spi"] + ), + TestConfig( + config_name="fpsensor_uart_ro", + test_name="fpsensor", + image_to_use=ImageType.RO, + test_args=["uart"], + ), + TestConfig( + config_name="fpsensor_uart_rw", test_name="fpsensor", test_args=["uart"] + ), + TestConfig( + config_name="mpu_ro", + test_name="mpu", + image_to_use=ImageType.RO, + finish_regexes=[board_config.mpu_regex], + ), + TestConfig( + config_name="mpu_rw", + test_name="mpu", + finish_regexes=[board_config.mpu_regex], + ), + TestConfig(test_name="mutex"), + TestConfig(test_name="pingpong"), + TestConfig(test_name="printf"), + TestConfig(test_name="queue"), + TestConfig( + config_name="rollback_region0", + test_name="rollback", + finish_regexes=[board_config.rollback_region0_regex], + test_args=["region0"], + ), + TestConfig( + config_name="rollback_region1", + test_name="rollback", + finish_regexes=[board_config.rollback_region1_regex], + test_args=["region1"], + ), + TestConfig(test_name="rollback_entropy", image_to_use=ImageType.RO), + TestConfig(test_name="rtc"), + TestConfig(test_name="sha256"), + TestConfig(test_name="sha256_unrolled"), + TestConfig(test_name="static_if"), + TestConfig(test_name="stdlib"), + TestConfig( + config_name="system_is_locked_wp_on", + test_name="system_is_locked", + test_args=["wp_on"], + toggle_power=True, + enable_hw_write_protect=True, + ), + TestConfig( + config_name="system_is_locked_wp_off", + test_name="system_is_locked", + test_args=["wp_off"], + toggle_power=True, + enable_hw_write_protect=False, + ), + TestConfig(test_name="timer_dos"), + TestConfig(test_name="utils", timeout_secs=20), + TestConfig(test_name="utils_str"), ] if board_config.name == BLOONCHIPPER: - tests.append(TestConfig(test_name='stm32f_rtc')) + tests.append(TestConfig(test_name="stm32f_rtc")) # Run panic data tests for all boards and RO versions. for variant_name, variant_info in board_config.variants.items(): tests.append( - TestConfig(config_name='panic_data_' + variant_name, - test_name='panic_data', - fail_regexes=[SINGLE_CHECK_FAILED_REGEX, - ALL_TESTS_FAILED_REGEX], - ro_image=variant_info.get('ro_image_path'), - build_board=variant_info.get('build_board'))) + TestConfig( + config_name="panic_data_" + variant_name, + test_name="panic_data", + fail_regexes=[SINGLE_CHECK_FAILED_REGEX, ALL_TESTS_FAILED_REGEX], + ro_image=variant_info.get("ro_image_path"), + build_board=variant_info.get("build_board"), + ) + ) return tests @@ -248,75 +315,72 @@ class AllTests: tests = [] try: current_dir = os.path.dirname(__file__) - private_dir = os.path.join(current_dir, os.pardir, 'private/test') + private_dir = os.path.join(current_dir, os.pardir, "private/test") have_private = os.path.isdir(private_dir) if not have_private: return [] sys.path.append(private_dir) import private_tests # pylint: disable=import-error + for test_args in private_tests.tests: tests.append(TestConfig(**test_args)) # Catch all exceptions to avoid disruptions in public repo except BaseException as e: - logging.debug('Failed to get list of private tests: %s', str(e)) - logging.debug('Ignore error and continue.') + logging.debug("Failed to get list of private tests: %s", str(e)) + logging.debug("Ignore error and continue.") return [] return tests BLOONCHIPPER_CONFIG = BoardConfig( name=BLOONCHIPPER, - servo_uart_name='raw_fpmcu_console_uart_pty', - servo_power_enable='fpmcu_pp3300', + servo_uart_name="raw_fpmcu_console_uart_pty", + servo_power_enable="fpmcu_pp3300", rollback_region0_regex=DATA_ACCESS_VIOLATION_8020000_REGEX, rollback_region1_regex=DATA_ACCESS_VIOLATION_8040000_REGEX, mpu_regex=DATA_ACCESS_VIOLATION_20000000_REGEX, variants={ - 'bloonchipper_v2.0.4277': { - 'ro_image_path': BLOONCHIPPER_V4277_IMAGE_PATH - }, - 'bloonchipper_v2.0.5938': { - 'ro_image_path': BLOONCHIPPER_V5938_IMAGE_PATH - } - } + "bloonchipper_v2.0.4277": {"ro_image_path": BLOONCHIPPER_V4277_IMAGE_PATH}, + "bloonchipper_v2.0.5938": {"ro_image_path": BLOONCHIPPER_V5938_IMAGE_PATH}, + }, ) DARTMONKEY_CONFIG = BoardConfig( name=DARTMONKEY, - servo_uart_name='raw_fpmcu_console_uart_pty', - servo_power_enable='fpmcu_pp3300', + servo_uart_name="raw_fpmcu_console_uart_pty", + servo_power_enable="fpmcu_pp3300", rollback_region0_regex=DATA_ACCESS_VIOLATION_80C0000_REGEX, rollback_region1_regex=DATA_ACCESS_VIOLATION_80E0000_REGEX, mpu_regex=DATA_ACCESS_VIOLATION_24000000_REGEX, # For dartmonkey board, run panic data test also on nocturne_fp and # nami_fp boards with appropriate RO image. variants={ - 'dartmonkey_v2.0.2887': { - 'ro_image_path': DARTMONKEY_IMAGE_PATH + "dartmonkey_v2.0.2887": {"ro_image_path": DARTMONKEY_IMAGE_PATH}, + "nocturne_fp_v2.2.64": { + "ro_image_path": NOCTURNE_FP_IMAGE_PATH, + "build_board": "nocturne_fp", }, - 'nocturne_fp_v2.2.64': { - 'ro_image_path': NOCTURNE_FP_IMAGE_PATH, - 'build_board': 'nocturne_fp' + "nami_fp_v2.2.144": { + "ro_image_path": NAMI_FP_IMAGE_PATH, + "build_board": "nami_fp", }, - 'nami_fp_v2.2.144': { - 'ro_image_path': NAMI_FP_IMAGE_PATH, - 'build_board': 'nami_fp' - } - } + }, ) BOARD_CONFIGS = { - 'bloonchipper': BLOONCHIPPER_CONFIG, - 'dartmonkey': DARTMONKEY_CONFIG, + "bloonchipper": BLOONCHIPPER_CONFIG, + "dartmonkey": DARTMONKEY_CONFIG, } def read_file_gsutil(path: str) -> bytes: """Get data from bucket, using gsutil tool""" - cmd = ['gsutil', 'cat', path] + cmd = ["gsutil", "cat", path] - logging.debug('Running command: "%s"', ' '.join(cmd)) - gsutil = subprocess.run(cmd, stdout=subprocess.PIPE) # pylint: disable=subprocess-run-check + logging.debug('Running command: "%s"', " ".join(cmd)) + gsutil = subprocess.run( + cmd, stdout=subprocess.PIPE + ) # pylint: disable=subprocess-run-check gsutil.check_returncode() return gsutil.stdout @@ -324,9 +388,9 @@ def read_file_gsutil(path: str) -> bytes: def find_section_offset_size(section: str, image: bytes) -> (int, int): """Get offset and size of the section in image""" - areas = fmap.fmap_decode(image)['areas'] - area = next(area for area in areas if area['name'] == section) - return area['offset'], area['size'] + areas = fmap.fmap_decode(image)["areas"] + area = next(area for area in areas if area["name"] == section) + return area["offset"], area["size"] def read_section(src: bytes, section: str) -> bytes: @@ -341,10 +405,10 @@ def write_section(data: bytes, image: bytearray, section: str): (section_start, section_size) = find_section_offset_size(section, image) if section_size < len(data): - raise ValueError(section + ' section size is not enough to store data') + raise ValueError(section + " section size is not enough to store data") section_end = section_start + section_size - filling = bytes([0xff for _ in range(section_size - len(data))]) + filling = bytes([0xFF for _ in range(section_size - len(data))]) image[section_start:section_end] = data + filling @@ -355,12 +419,14 @@ def copy_section(src: bytes, dst: bytearray, section: str): (dst_start, dst_size) = find_section_offset_size(section, dst) if dst_size < src_size: - raise ValueError('Section ' + section + ' from source image has ' - 'greater size than the section in destination image') + raise ValueError( + "Section " + section + " from source image has " + "greater size than the section in destination image" + ) src_end = src_start + src_size dst_end = dst_start + dst_size - filling = bytes([0xff for _ in range(dst_size - src_size)]) + filling = bytes([0xFF for _ in range(dst_size - src_size)]) dst[dst_start:dst_end] = src[src_start:src_end] + filling @@ -368,28 +434,28 @@ def copy_section(src: bytes, dst: bytearray, section: str): def replace_ro(image: bytearray, ro: bytes): """Replace RO in image with provided one""" # Backup RO public key since its private part was used to sign RW. - ro_pubkey = read_section(image, 'KEY_RO') + ro_pubkey = read_section(image, "KEY_RO") # Copy RO part of the firmware to the image. Please note that RO public key # is copied too since EC_RO area includes KEY_RO area. - copy_section(ro, image, 'EC_RO') + copy_section(ro, image, "EC_RO") # Restore RO public key. - write_section(ro_pubkey, image, 'KEY_RO') + write_section(ro_pubkey, image, "KEY_RO") def get_console(board_config: BoardConfig) -> Optional[str]: """Get the name of the console for a given board.""" cmd = [ - 'dut-control', + "dut-control", board_config.servo_uart_name, ] - logging.debug('Running command: "%s"', ' '.join(cmd)) + logging.debug('Running command: "%s"', " ".join(cmd)) with subprocess.Popen(cmd, stdout=subprocess.PIPE) as proc: for line in io.TextIOWrapper(proc.stdout): # type: ignore[arg-type] logging.debug(line) - pty = line.split(':') + pty = line.split(":") if len(pty) == 2 and pty[0] == board_config.servo_uart_name: return pty[1].strip() @@ -399,77 +465,82 @@ def get_console(board_config: BoardConfig) -> Optional[str]: def power(board_config: BoardConfig, on: bool) -> None: """Turn power to board on/off.""" if on: - state = 'pp3300' + state = "pp3300" else: - state = 'off' + state = "off" cmd = [ - 'dut-control', - board_config.servo_power_enable + ':' + state, + "dut-control", + board_config.servo_power_enable + ":" + state, ] - logging.debug('Running command: "%s"', ' '.join(cmd)) + logging.debug('Running command: "%s"', " ".join(cmd)) subprocess.run(cmd).check_returncode() # pylint: disable=subprocess-run-check def hw_write_protect(enable: bool) -> None: """Enable/disable hardware write protect.""" if enable: - state = 'force_on' + state = "force_on" else: - state = 'force_off' + state = "force_off" cmd = [ - 'dut-control', - 'fw_wp_state:' + state, - ] - logging.debug('Running command: "%s"', ' '.join(cmd)) + "dut-control", + "fw_wp_state:" + state, + ] + logging.debug('Running command: "%s"', " ".join(cmd)) subprocess.run(cmd).check_returncode() # pylint: disable=subprocess-run-check def build(test_name: str, board_name: str, compiler: str) -> None: """Build specified test for specified board.""" - cmd = ['make'] + cmd = ["make"] if compiler == CLANG: - cmd = cmd + ['CC=arm-none-eabi-clang'] + cmd = cmd + ["CC=arm-none-eabi-clang"] cmd = cmd + [ - 'BOARD=' + board_name, - 'test-' + test_name, - '-j', + "BOARD=" + board_name, + "test-" + test_name, + "-j", ] - logging.debug('Running command: "%s"', ' '.join(cmd)) + logging.debug('Running command: "%s"', " ".join(cmd)) subprocess.run(cmd).check_returncode() # pylint: disable=subprocess-run-check -def flash(image_path: str, board: str, flasher: str, remote_ip: str, - remote_port: int) -> bool: +def flash( + image_path: str, board: str, flasher: str, remote_ip: str, remote_port: int +) -> bool: """Flash specified test to specified board.""" - logging.info('Flashing test') + logging.info("Flashing test") cmd = [] if flasher == JTRACE: cmd.append(JTRACE_FLASH_SCRIPT) if remote_ip: - cmd.extend(['--remote', remote_ip + ':' + str(remote_port)]) + cmd.extend(["--remote", remote_ip + ":" + str(remote_port)]) elif flasher == SERVO_MICRO: cmd.append(SERVO_MICRO_FLASH_SCRIPT) else: logging.error('Unknown flasher: "%s"', flasher) return False - cmd.extend([ - '--board', board, - '--image', image_path, - ]) - logging.debug('Running command: "%s"', ' '.join(cmd)) + cmd.extend( + [ + "--board", + board, + "--image", + image_path, + ] + ) + logging.debug('Running command: "%s"', " ".join(cmd)) completed_process = subprocess.run(cmd) # pylint: disable=subprocess-run-check return completed_process.returncode == 0 def patch_image(test: TestConfig, image_path: str): """Replace RO part of the firmware with provided one.""" - with open(image_path, 'rb+') as f: + with open(image_path, "rb+") as f: image = bytearray(f.read()) ro = read_file_gsutil(test.ro_image) replace_ro(image, ro) @@ -478,8 +549,9 @@ def patch_image(test: TestConfig, image_path: str): f.truncate() -def readline(executor: ThreadPoolExecutor, f: BinaryIO, timeout_secs: int) -> \ - Optional[bytes]: +def readline( + executor: ThreadPoolExecutor, f: BinaryIO, timeout_secs: int +) -> Optional[bytes]: """Read a line with timeout.""" a = executor.submit(f.readline) try: @@ -488,8 +560,7 @@ def readline(executor: ThreadPoolExecutor, f: BinaryIO, timeout_secs: int) -> \ return None -def readlines_until_timeout(executor, f: BinaryIO, timeout_secs: int) -> \ - List[bytes]: +def readlines_until_timeout(executor, f: BinaryIO, timeout_secs: int) -> List[bytes]: """Continuously read lines for timeout_secs.""" lines: List[bytes] = [] while True: @@ -519,19 +590,20 @@ def process_console_output_line(line: bytes, test: TestConfig): return None -def run_test(test: TestConfig, console: io.FileIO, - executor: ThreadPoolExecutor) -> bool: +def run_test( + test: TestConfig, console: io.FileIO, executor: ThreadPoolExecutor +) -> bool: """Run specified test.""" start = time.time() # Wait for boot to finish time.sleep(1) - console.write('\n'.encode()) + console.write("\n".encode()) if test.image_to_use == ImageType.RO: - console.write('reboot ro\n'.encode()) + console.write("reboot ro\n".encode()) time.sleep(1) - test_cmd = 'runtest ' + ' '.join(test.test_args) + '\n' + test_cmd = "runtest " + " ".join(test.test_args) + "\n" console.write(test_cmd.encode()) while True: @@ -540,7 +612,7 @@ def run_test(test: TestConfig, console: io.FileIO, if not line: now = time.time() if now - start > test.timeout_secs: - logging.debug('Test timed out') + logging.debug("Test timed out") return False continue @@ -569,15 +641,18 @@ def run_test(test: TestConfig, console: io.FileIO, def get_test_list(config: BoardConfig, test_args) -> List[TestConfig]: """Get a list of tests to run.""" - if test_args == 'all': + if test_args == "all": return AllTests.get(config) test_list = [] for t in test_args: - logging.debug('test: %s', t) + logging.debug("test: %s", t) test_regex = re.compile(t) - tests = [test for test in AllTests.get(config) - if test_regex.fullmatch(test.config_name)] + tests = [ + test + for test in AllTests.get(config) + if test_regex.fullmatch(test.config_name) + ] if not tests: logging.error('Unable to find test config for "%s"', t) sys.exit(1) @@ -588,7 +663,7 @@ def get_test_list(config: BoardConfig, test_args) -> List[TestConfig]: def parse_remote_arg(remote: str) -> str: if not remote: - return '' + return "" try: ip = socket.gethostbyname(remote) @@ -601,67 +676,69 @@ def parse_remote_arg(remote: str) -> str: def main(): parser = argparse.ArgumentParser() - default_board = 'bloonchipper' - parser.add_argument( - '--board', '-b', - help='Board (default: ' + default_board + ')', - default=default_board) - - default_tests = 'all' + default_board = "bloonchipper" parser.add_argument( - '--tests', '-t', - nargs='+', - help='Tests (default: ' + default_tests + ')', - default=default_tests) + "--board", + "-b", + help="Board (default: " + default_board + ")", + default=default_board, + ) - log_level_choices = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] + default_tests = "all" parser.add_argument( - '--log_level', '-l', - choices=log_level_choices, - default='DEBUG' + "--tests", + "-t", + nargs="+", + help="Tests (default: " + default_tests + ")", + default=default_tests, ) + log_level_choices = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + parser.add_argument("--log_level", "-l", choices=log_level_choices, default="DEBUG") + flasher_choices = [SERVO_MICRO, JTRACE] - parser.add_argument( - '--flasher', '-f', - choices=flasher_choices, - default=JTRACE - ) + parser.add_argument("--flasher", "-f", choices=flasher_choices, default=JTRACE) compiler_options = [GCC, CLANG] - parser.add_argument('--compiler', '-c', - choices=compiler_options, - default=GCC) + parser.add_argument("--compiler", "-c", choices=compiler_options, default=GCC) # This might be expanded to serve as a "remote" for flash_ec also, so # we will leave it generic. parser.add_argument( - '--remote', '-n', - help='The remote host connected to one or both of: J-Link and Servo.', + "--remote", + "-n", + help="The remote host connected to one or both of: J-Link and Servo.", ) - parser.add_argument('--jlink_port', '-j', - type=int, - help='The port to use when connecting to JLink.') - parser.add_argument('--console_port', '-p', - type=int, - help='The port connected to the FPMCU console.') + parser.add_argument( + "--jlink_port", "-j", type=int, help="The port to use when connecting to JLink." + ) + parser.add_argument( + "--console_port", + "-p", + type=int, + help="The port connected to the FPMCU console.", + ) args = parser.parse_args() logging.basicConfig(level=args.log_level) if args.jlink_port and not args.flasher == JTRACE: - logging.error('jlink_port specified, but flasher is not set to J-Link.') + logging.error("jlink_port specified, but flasher is not set to J-Link.") sys.exit(1) if args.remote and not (args.jlink_port or args.console_port): - logging.error('jlink_port or console_port must be specified when using ' - 'the remote option.') + logging.error( + "jlink_port or console_port must be specified when using " + "the remote option." + ) sys.exit(1) if (args.jlink_port or args.console_port) and not args.remote: - logging.error('The remote option must be specified when using the ' - 'jlink_port or console_port options.') + logging.error( + "The remote option must be specified when using the " + "jlink_port or console_port options." + ) sys.exit(1) if args.board not in BOARD_CONFIGS: @@ -675,9 +752,7 @@ def main(): e = ThreadPoolExecutor(max_workers=1) test_list = get_test_list(board_config, args.tests) - logging.debug( - 'Running tests: %s', [ - test.config_name for test in test_list]) + logging.debug("Running tests: %s", [test.config_name for test in test_list]) for test in test_list: build_board = args.board @@ -689,15 +764,17 @@ def main(): # build test binary build(test.test_name, build_board, args.compiler) - image_path = os.path.join(EC_DIR, 'build', build_board, test.test_name, - test.test_name + '.bin') + image_path = os.path.join( + EC_DIR, "build", build_board, test.test_name, test.test_name + ".bin" + ) if test.ro_image is not None: try: patch_image(test, image_path) except Exception as exception: - logging.warning('An exception occurred while patching ' - 'image: %s', exception) + logging.warning( + "An exception occurred while patching " "image: %s", exception + ) test.passed = False continue @@ -706,16 +783,16 @@ def main(): # flash_write_protect test is run; works after second attempt. flash_succeeded = False for i in range(0, test.num_flash_attempts): - logging.debug('Flash attempt %d', i + 1) - if flash(image_path, args.board, args.flasher, remote_ip, - args.jlink_port): + logging.debug("Flash attempt %d", i + 1) + if flash(image_path, args.board, args.flasher, remote_ip, args.jlink_port): flash_succeeded = True break time.sleep(1) if not flash_succeeded: - logging.debug('Flashing failed after max attempts: %d', - test.num_flash_attempts) + logging.debug( + "Flashing failed after max attempts: %d", test.num_flash_attempts + ) test.passed = False continue @@ -733,11 +810,11 @@ def main(): if remote_ip and args.console_port: s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.connect((remote_ip, args.console_port)) - console = stack.enter_context( - s.makefile(mode='rwb', buffering=0)) + console = stack.enter_context(s.makefile(mode="rwb", buffering=0)) else: console = stack.enter_context( - open(get_console(board_config), 'wb+', buffering=0)) + open(get_console(board_config), "wb+", buffering=0) + ) test.passed = run_test(test, console, executor=e) @@ -745,11 +822,11 @@ def main(): exit_code = 0 for test in test_list: # print results - print('Test "' + test.config_name + '": ', end='') + print('Test "' + test.config_name + '": ', end="") if test.passed: - print(colorama.Fore.GREEN + 'PASSED') + print(colorama.Fore.GREEN + "PASSED") else: - print(colorama.Fore.RED + 'FAILED') + print(colorama.Fore.RED + "FAILED") exit_code = 1 print(colorama.Style.RESET_ALL) @@ -758,5 +835,5 @@ def main(): sys.exit(exit_code) -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/test/timer_calib.py b/test/timer_calib.py index 2a625d80c7..c7ac9fc23b 100644 --- a/test/timer_calib.py +++ b/test/timer_calib.py @@ -7,48 +7,52 @@ import time -def one_pass(helper): - helper.wait_output("=== Timer calibration ===") - res = helper.wait_output("back-to-back get_time : (?P<lat>[0-9]+) us", - use_re=True)["lat"] - minlat = int(res) - helper.trace("get_time latency %d us\n" % minlat) - - helper.wait_output("sleep 1s") - t0 = time.time() - second = helper.wait_output("done. delay = (?P<second>[0-9]+) us", - use_re=True)["second"] - t1 = time.time() - secondreal = t1 - t0 - secondlat = int(second) - 1000000 - helper.trace("1s timer latency %d us / real time %f s\n" % (secondlat, - secondreal)) - - us = {} - for pow2 in range(7): - delay = 1 << (7-pow2) - us[delay] = helper.wait_output("%d us => (?P<us>[0-9]+) us" % delay, - use_re=True)["us"] - helper.wait_output("Done.") - - return minlat, secondlat, secondreal +def one_pass(helper): + helper.wait_output("=== Timer calibration ===") + res = helper.wait_output("back-to-back get_time : (?P<lat>[0-9]+) us", use_re=True)[ + "lat" + ] + minlat = int(res) + helper.trace("get_time latency %d us\n" % minlat) + + helper.wait_output("sleep 1s") + t0 = time.time() + second = helper.wait_output("done. delay = (?P<second>[0-9]+) us", use_re=True)[ + "second" + ] + t1 = time.time() + secondreal = t1 - t0 + secondlat = int(second) - 1000000 + helper.trace("1s timer latency %d us / real time %f s\n" % (secondlat, secondreal)) + + us = {} + for pow2 in range(7): + delay = 1 << (7 - pow2) + us[delay] = helper.wait_output( + "%d us => (?P<us>[0-9]+) us" % delay, use_re=True + )["us"] + helper.wait_output("Done.") + + return minlat, secondlat, secondreal def test(helper): - one_pass(helper) + one_pass(helper) - helper.ec_command("reboot") - helper.wait_output("--- UART initialized") + helper.ec_command("reboot") + helper.wait_output("--- UART initialized") - # get the timing results on the second pass - # to avoid binary translation overhead - minlat, secondlat, secondreal = one_pass(helper) + # get the timing results on the second pass + # to avoid binary translation overhead + minlat, secondlat, secondreal = one_pass(helper) - # check that the timings somewhat make sense - if minlat > 220 or secondlat > 500 or abs(secondreal-1.0) > 0.200: - helper.fail("imprecise timings " + - "(get_time %d us sleep %d us / real time %.3f s)" % - (minlat, secondlat, secondreal)) + # check that the timings somewhat make sense + if minlat > 220 or secondlat > 500 or abs(secondreal - 1.0) > 0.200: + helper.fail( + "imprecise timings " + + "(get_time %d us sleep %d us / real time %.3f s)" + % (minlat, secondlat, secondreal) + ) - return True # PASS ! + return True # PASS ! diff --git a/test/timer_jump.py b/test/timer_jump.py index f506a69fcf..2801c3b3fa 100644 --- a/test/timer_jump.py +++ b/test/timer_jump.py @@ -10,22 +10,25 @@ import time DELAY = 5 ERROR_MARGIN = 0.5 + def test(helper): - helper.wait_output("idle task started") - helper.ec_command("sysinfo") - copy = helper.wait_output("Copy:\s+(?P<c>\S+)", use_re=True)["c"] - if copy != "RO": - helper.ec_command("sysjump ro") - helper.wait_output("idle task started") - helper.ec_command("gettime") - ec_start_time = helper.wait_output("Time: 0x[0-9a-f]* = (?P<t>[\d\.]+) s", - use_re=True)["t"] - time.sleep(DELAY) - helper.ec_command("sysjump a") - helper.wait_output("idle task started") - helper.ec_command("gettime") - ec_end_time = helper.wait_output("Time: 0x[0-9a-f]* = (?P<t>[\d\.]+) s", - use_re=True)["t"] + helper.wait_output("idle task started") + helper.ec_command("sysinfo") + copy = helper.wait_output("Copy:\s+(?P<c>\S+)", use_re=True)["c"] + if copy != "RO": + helper.ec_command("sysjump ro") + helper.wait_output("idle task started") + helper.ec_command("gettime") + ec_start_time = helper.wait_output( + "Time: 0x[0-9a-f]* = (?P<t>[\d\.]+) s", use_re=True + )["t"] + time.sleep(DELAY) + helper.ec_command("sysjump a") + helper.wait_output("idle task started") + helper.ec_command("gettime") + ec_end_time = helper.wait_output( + "Time: 0x[0-9a-f]* = (?P<t>[\d\.]+) s", use_re=True + )["t"] - time_diff = float(ec_end_time) - float(ec_start_time) - return time_diff >= DELAY and time_diff <= DELAY + ERROR_MARGIN + time_diff = float(ec_end_time) - float(ec_start_time) + return time_diff >= DELAY and time_diff <= DELAY + ERROR_MARGIN diff --git a/util/build_with_clang.py b/util/build_with_clang.py index a38ade2cb8..98da942152 100755 --- a/util/build_with_clang.py +++ b/util/build_with_clang.py @@ -12,15 +12,14 @@ import multiprocessing import os import subprocess import sys - from concurrent.futures import ThreadPoolExecutor # Add to this list as compilation errors are fixed for boards. BOARDS_THAT_COMPILE_SUCCESSFULLY_WITH_CLANG = [ - 'dartmonkey', - 'bloonchipper', - 'nucleo-f412zg', - 'nucleo-h743zi', + "dartmonkey", + "bloonchipper", + "nucleo-f412zg", + "nucleo-h743zi", ] @@ -29,35 +28,29 @@ def build(board_name: str) -> None: logging.debug('Building board: "%s"', board_name) cmd = [ - 'make', - 'BOARD=' + board_name, - '-j', + "make", + "BOARD=" + board_name, + "-j", ] - logging.debug('Running command: "%s"', ' '.join(cmd)) - subprocess.run(cmd, env=dict(os.environ, CC='clang'), check=True) + logging.debug('Running command: "%s"', " ".join(cmd)) + subprocess.run(cmd, env=dict(os.environ, CC="clang"), check=True) def main() -> int: parser = argparse.ArgumentParser() - log_level_choices = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] - parser.add_argument( - '--log_level', '-l', - choices=log_level_choices, - default='DEBUG' - ) + log_level_choices = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + parser.add_argument("--log_level", "-l", choices=log_level_choices, default="DEBUG") parser.add_argument( - '--num_threads', '-j', - type=int, - default=multiprocessing.cpu_count() + "--num_threads", "-j", type=int, default=multiprocessing.cpu_count() ) args = parser.parse_args() logging.basicConfig(level=args.log_level) - logging.debug('Building with %d threads', args.num_threads) + logging.debug("Building with %d threads", args.num_threads) failed_boards = [] with ThreadPoolExecutor(max_workers=args.num_threads) as executor: @@ -73,13 +66,14 @@ def main() -> int: failed_boards.append(board) if len(failed_boards) > 0: - logging.error('The following boards failed to compile:\n%s', - '\n'.join(failed_boards)) + logging.error( + "The following boards failed to compile:\n%s", "\n".join(failed_boards) + ) return 1 - logging.info('All boards compiled successfully!') + logging.info("All boards compiled successfully!") return 0 -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/util/chargen b/util/chargen index 9ba14d3d6a..a1f7947a14 100644 --- a/util/chargen +++ b/util/chargen @@ -5,6 +5,7 @@ import sys + def chargen(modulo, max_chars): """Generate a stream of characters on the console. @@ -18,7 +19,7 @@ def chargen(modulo, max_chars): zero, if zero - print indefinitely """ - base = '0' + base = "0" c = base counter = 0 while True: @@ -26,25 +27,25 @@ def chargen(modulo, max_chars): counter = counter + 1 if (max_chars != 0) and (counter == max_chars): - sys.stdout.write('\n') + sys.stdout.write("\n") return if modulo and ((counter % modulo) == 0): c = base continue - if c == 'z': + if c == "z": c = base - elif c == 'Z': - c = 'a' - elif c == '9': - c = 'A' + elif c == "Z": + c = "a" + elif c == "9": + c = "A" else: - c = '%c' % (ord(c) + 1) + c = "%c" % (ord(c) + 1) def main(args): - '''Process command line arguments and invoke chargen if args are valid''' + """Process command line arguments and invoke chargen if args are valid""" modulo = 0 max_chars = 0 @@ -55,8 +56,7 @@ def main(args): if len(args) > 1: max_chars = int(args[1]) except ValueError: - sys.stderr.write('usage %s:' - "['seq_length' ['max_chars']]\n") + sys.stderr.write("usage %s:" "['seq_length' ['max_chars']]\n") sys.exit(1) try: @@ -64,6 +64,7 @@ def main(args): except KeyboardInterrupt: print() -if __name__ == '__main__': + +if __name__ == "__main__": main(sys.argv[1:]) sys.exit(0) diff --git a/util/config_option_check.py b/util/config_option_check.py index 8bd8ecb1f0..0b05b64091 100755 --- a/util/config_option_check.py +++ b/util/config_option_check.py @@ -13,6 +13,7 @@ Script to ensure that all configuration options for the Chrome EC are defined in config.h. """ from __future__ import print_function + import enum import os import re @@ -21,368 +22,392 @@ import sys class Line(object): - """Class for each changed line in diff output. + """Class for each changed line in diff output. - Attributes: - line_num: The integer line number that this line appears in the file. - string: The literal string of this line. - line_type: '+' or '-' indicating if this line was an addition or - deletion. - """ + Attributes: + line_num: The integer line number that this line appears in the file. + string: The literal string of this line. + line_type: '+' or '-' indicating if this line was an addition or + deletion. + """ - def __init__(self, line_num, string, line_type): - """Inits Line with the line number and the actual string.""" - self.line_num = line_num - self.string = string - self.line_type = line_type + def __init__(self, line_num, string, line_type): + """Inits Line with the line number and the actual string.""" + self.line_num = line_num + self.string = string + self.line_type = line_type class Hunk(object): - """Class for a git diff hunk. + """Class for a git diff hunk. - Attributes: - filename: The name of the file that this hunk belongs to. - lines: A list of Line objects that are a part of this hunk. - """ + Attributes: + filename: The name of the file that this hunk belongs to. + lines: A list of Line objects that are a part of this hunk. + """ - def __init__(self, filename, lines): - """Inits Hunk with the filename and the list of lines of the hunk.""" - self.filename = filename - self.lines = lines + def __init__(self, filename, lines): + """Inits Hunk with the filename and the list of lines of the hunk.""" + self.filename = filename + self.lines = lines # Master file which is supposed to include all CONFIG_xxxx descriptions. -CONFIG_FILE = 'include/config.h' +CONFIG_FILE = "include/config.h" # Specific files which the checker should ignore. -ALLOWLIST = [CONFIG_FILE, 'util/config_option_check.py'] +ALLOWLIST = [CONFIG_FILE, "util/config_option_check.py"] # Specific directories which the checker should ignore. -ALLOW_PATTERN = re.compile('zephyr/.*') +ALLOW_PATTERN = re.compile("zephyr/.*") # Specific CONFIG_* flags which the checker should ignore. -ALLOWLIST_CONFIGS = ['CONFIG_ZTEST'] +ALLOWLIST_CONFIGS = ["CONFIG_ZTEST"] + def obtain_current_config_options(): - """Obtains current config options from include/config.h. - - Scans through the main config file defined in CONFIG_FILE for all CONFIG_* - options. - - Returns: - config_options: A list of all the config options in the main CONFIG_FILE. - """ - - config_options = [] - config_option_re = re.compile(r'^#(define|undef)\s+(CONFIG_[A-Z0-9_]+)') - with open(CONFIG_FILE, 'r') as config_file: - for line in config_file: - result = config_option_re.search(line) - if not result: - continue - word = result.groups()[1] - if word not in config_options: - config_options.append(word) - return config_options + """Obtains current config options from include/config.h. + + Scans through the main config file defined in CONFIG_FILE for all CONFIG_* + options. + + Returns: + config_options: A list of all the config options in the main CONFIG_FILE. + """ + + config_options = [] + config_option_re = re.compile(r"^#(define|undef)\s+(CONFIG_[A-Z0-9_]+)") + with open(CONFIG_FILE, "r") as config_file: + for line in config_file: + result = config_option_re.search(line) + if not result: + continue + word = result.groups()[1] + if word not in config_options: + config_options.append(word) + return config_options + def obtain_config_options_in_use(): - """Obtains all the config options in use in the repo. - - Scans through the entire repo looking for all CONFIG_* options actively used. - - Returns: - options_in_use: A set of all the config options in use in the repo. - """ - file_list = [] - cwd = os.getcwd() - config_option_re = re.compile(r'\b(CONFIG_[a-zA-Z0-9_]+)') - config_debug_option_re = re.compile(r'\b(CONFIG_DEBUG_[a-zA-Z0-9_]+)') - options_in_use = set() - for (dirpath, dirnames, filenames) in os.walk(cwd, topdown=True): - # Ignore the build and private directories (taken from .gitignore) - if 'build' in dirnames: - dirnames.remove('build') - if 'private' in dirnames: - dirnames.remove('private') - for f in filenames: - # Ignore hidden files. - if f.startswith('.'): - continue - # Only consider C source, assembler, and Make-style files. - if (os.path.splitext(f)[1] in ('.c', '.h', '.inc', '.S', '.mk') or - 'Makefile' in f): - file_list.append(os.path.join(dirpath, f)) - - # Search through each file and build a set of the CONFIG_* options being - # used. - - for f in file_list: - if CONFIG_FILE in f: - continue - with open(f, 'r') as cur_file: - for line in cur_file: - match = config_option_re.findall(line) - if match: - for option in match: - if not in_comment(f, line, option): - if option not in options_in_use: - options_in_use.add(option) - - # Since debug options can be turned on at any time, assume that they are - # always in use in case any aren't being used. - - with open(CONFIG_FILE, 'r') as config_file: - for line in config_file: - match = config_debug_option_re.findall(line) - if match: - for option in match: - if not in_comment(CONFIG_FILE, line, option): - if option not in options_in_use: - options_in_use.add(option) - - return options_in_use + """Obtains all the config options in use in the repo. + + Scans through the entire repo looking for all CONFIG_* options actively used. + + Returns: + options_in_use: A set of all the config options in use in the repo. + """ + file_list = [] + cwd = os.getcwd() + config_option_re = re.compile(r"\b(CONFIG_[a-zA-Z0-9_]+)") + config_debug_option_re = re.compile(r"\b(CONFIG_DEBUG_[a-zA-Z0-9_]+)") + options_in_use = set() + for (dirpath, dirnames, filenames) in os.walk(cwd, topdown=True): + # Ignore the build and private directories (taken from .gitignore) + if "build" in dirnames: + dirnames.remove("build") + if "private" in dirnames: + dirnames.remove("private") + for f in filenames: + # Ignore hidden files. + if f.startswith("."): + continue + # Only consider C source, assembler, and Make-style files. + if ( + os.path.splitext(f)[1] in (".c", ".h", ".inc", ".S", ".mk") + or "Makefile" in f + ): + file_list.append(os.path.join(dirpath, f)) + + # Search through each file and build a set of the CONFIG_* options being + # used. + + for f in file_list: + if CONFIG_FILE in f: + continue + with open(f, "r") as cur_file: + for line in cur_file: + match = config_option_re.findall(line) + if match: + for option in match: + if not in_comment(f, line, option): + if option not in options_in_use: + options_in_use.add(option) + + # Since debug options can be turned on at any time, assume that they are + # always in use in case any aren't being used. + + with open(CONFIG_FILE, "r") as config_file: + for line in config_file: + match = config_debug_option_re.findall(line) + if match: + for option in match: + if not in_comment(CONFIG_FILE, line, option): + if option not in options_in_use: + options_in_use.add(option) + + return options_in_use + def print_missing_config_options(hunks, config_options): - """Searches thru all the changes in hunks for missing options and prints them. - - Args: - hunks: A list of Hunk objects which represent the hunks from the git - diff output. - config_options: A list of all the config options in the main CONFIG_FILE. - - Returns: - missing_config_option: A boolean indicating if any CONFIG_* options - are missing from the main CONFIG_FILE in this commit or if any CONFIG_* - options removed are no longer being used in the repo. - """ - missing_config_option = False - print_banner = True - deprecated_options = set() - # Determine longest CONFIG_* length to be used for formatting. - max_option_length = max(len(option) for option in config_options) - config_option_re = re.compile(r'\b(CONFIG_[a-zA-Z0-9_]+)') - - # Search for all CONFIG_* options in use in the repo. - options_in_use = obtain_config_options_in_use() - - # Check each hunk's line for a missing config option. - for h in hunks: - for l in h.lines: - # Check for the existence of a CONFIG_* in the line. - match = filter(lambda opt: opt in ALLOWLIST_CONFIGS, - config_option_re.findall(l.string)) - if not match: - continue - - # At this point, an option was found in the line. However, we need to - # verify that it is not within a comment. - violations = set() - - for option in match: - if not in_comment(h.filename, l.string, option): - # Since the CONFIG_* option is not within a comment, we've found a - # violation. We now need to determine if this line is a deletion or - # not. For deletions, we will need to verify if this CONFIG_* option - # is no longer being used in the entire repo. - - if l.line_type == '-': - if option not in options_in_use and option in config_options: - deprecated_options.add(option) - else: - violations.add(option) - - # Check to see if the CONFIG_* option is in the config file and print the - # violations. - for option in match: - if option not in config_options and option in violations: - # Print the banner once. - if print_banner: - print('The following config options were found to be missing ' - 'from %s.\n' - 'Please add new config options there along with ' - 'descriptions.\n\n' % CONFIG_FILE) - print_banner = False + """Searches thru all the changes in hunks for missing options and prints them. + + Args: + hunks: A list of Hunk objects which represent the hunks from the git + diff output. + config_options: A list of all the config options in the main CONFIG_FILE. + + Returns: + missing_config_option: A boolean indicating if any CONFIG_* options + are missing from the main CONFIG_FILE in this commit or if any CONFIG_* + options removed are no longer being used in the repo. + """ + missing_config_option = False + print_banner = True + deprecated_options = set() + # Determine longest CONFIG_* length to be used for formatting. + max_option_length = max(len(option) for option in config_options) + config_option_re = re.compile(r"\b(CONFIG_[a-zA-Z0-9_]+)") + + # Search for all CONFIG_* options in use in the repo. + options_in_use = obtain_config_options_in_use() + + # Check each hunk's line for a missing config option. + for h in hunks: + for l in h.lines: + # Check for the existence of a CONFIG_* in the line. + match = filter( + lambda opt: opt in ALLOWLIST_CONFIGS, config_option_re.findall(l.string) + ) + if not match: + continue + + # At this point, an option was found in the line. However, we need to + # verify that it is not within a comment. + violations = set() + + for option in match: + if not in_comment(h.filename, l.string, option): + # Since the CONFIG_* option is not within a comment, we've found a + # violation. We now need to determine if this line is a deletion or + # not. For deletions, we will need to verify if this CONFIG_* option + # is no longer being used in the entire repo. + + if l.line_type == "-": + if option not in options_in_use and option in config_options: + deprecated_options.add(option) + else: + violations.add(option) + + # Check to see if the CONFIG_* option is in the config file and print the + # violations. + for option in match: + if option not in config_options and option in violations: + # Print the banner once. + if print_banner: + print( + "The following config options were found to be missing " + "from %s.\n" + "Please add new config options there along with " + "descriptions.\n\n" % CONFIG_FILE + ) + print_banner = False + missing_config_option = True + # Print the misssing config option. + print( + "> %-*s %s:%s" + % (max_option_length, option, h.filename, l.line_num) + ) + + if deprecated_options: + print( + "\n\nThe following config options are being removed and also appear" + " to be the last uses\nof that option. Please remove these " + "options from %s.\n\n" % CONFIG_FILE + ) + for option in deprecated_options: + print("> %s" % option) missing_config_option = True - # Print the misssing config option. - print('> %-*s %s:%s' % (max_option_length, option, - h.filename, - l.line_num)) - if deprecated_options: - print('\n\nThe following config options are being removed and also appear' - ' to be the last uses\nof that option. Please remove these ' - 'options from %s.\n\n' % CONFIG_FILE) - for option in deprecated_options: - print('> %s' % option) - missing_config_option = True + return missing_config_option - return missing_config_option def in_comment(filename, line, substr): - """Checks if given substring appears in a comment. - - Args: - filename: The filename where this line is from. This is used to determine - what kind of comments to look for. - line: String of line to search in. - substr: Substring to search for in the line. - - Returns: - is_in_comment: Boolean indicating if substr was in a comment. - """ - - c_style_ext = ('.c', '.h', '.inc', '.S') - make_style_ext = ('.mk') - is_in_comment = False - - extension = os.path.splitext(filename)[1] - substr_idx = line.find(substr) - - # Different files have different comment syntax; Handle appropriately. - if extension in c_style_ext: - beg_comment_idx = line.find('/*') - end_comment_idx = line.find('*/') - if end_comment_idx == -1: - end_comment_idx = len(line) - - if beg_comment_idx == -1: - # Check to see if this line is from a multi-line comment. - if line.lstrip().startswith('* '): - # It _seems_ like it is. - is_in_comment = True - else: - # Check to see if its actually inside the comment. - if beg_comment_idx < substr_idx < end_comment_idx: - is_in_comment = True - elif extension in make_style_ext or 'Makefile' in filename: - beg_comment_idx = line.find('#') - # Ignore everything to the right of the hash. - if beg_comment_idx < substr_idx and beg_comment_idx != -1: - is_in_comment = True - return is_in_comment + """Checks if given substring appears in a comment. + + Args: + filename: The filename where this line is from. This is used to determine + what kind of comments to look for. + line: String of line to search in. + substr: Substring to search for in the line. + + Returns: + is_in_comment: Boolean indicating if substr was in a comment. + """ + + c_style_ext = (".c", ".h", ".inc", ".S") + make_style_ext = ".mk" + is_in_comment = False + + extension = os.path.splitext(filename)[1] + substr_idx = line.find(substr) + + # Different files have different comment syntax; Handle appropriately. + if extension in c_style_ext: + beg_comment_idx = line.find("/*") + end_comment_idx = line.find("*/") + if end_comment_idx == -1: + end_comment_idx = len(line) + + if beg_comment_idx == -1: + # Check to see if this line is from a multi-line comment. + if line.lstrip().startswith("* "): + # It _seems_ like it is. + is_in_comment = True + else: + # Check to see if its actually inside the comment. + if beg_comment_idx < substr_idx < end_comment_idx: + is_in_comment = True + elif extension in make_style_ext or "Makefile" in filename: + beg_comment_idx = line.find("#") + # Ignore everything to the right of the hash. + if beg_comment_idx < substr_idx and beg_comment_idx != -1: + is_in_comment = True + return is_in_comment + def get_hunks(): - """Gets the hunks of the most recent commit. - - States: - new_file: Searching for a new file in the git diff. - filename_search: Searching for the filename of this hunk. - hunk: Searching for the beginning of a new hunk. - lines: Counting line numbers and searching for changes. - - Returns: - hunks: A list of Hunk objects which represent the hunks in the git diff - output. - """ - - diff = [] - hunks = [] - hunk_lines = [] - line = '' - filename = '' - i = 0 - line_num = 0 - - # Regex patterns - new_file_re = re.compile(r'^diff --git') - filename_re = re.compile(r'^[+]{3} (.*)') - hunk_line_num_re = re.compile(r'^@@ -[0-9]+,[0-9]+ \+([0-9]+),[0-9]+ @@.*') - line_re = re.compile(r'^([+| |-])(.*)') - - # Get the diff output. - proc = subprocess.run(['git', 'diff', '--cached', '-GCONFIG_*', '--no-prefix', - '--no-ext-diff', 'HEAD~1'], - stdout=subprocess.PIPE, - encoding='utf-8', - check=True) - diff = proc.stdout.splitlines() - if not diff: - return [] - line = diff[0] - - state = enum.Enum('state', 'NEW_FILE FILENAME_SEARCH HUNK LINES') - current_state = state.NEW_FILE - - while True: - # Search for the beginning of a new file. - if current_state is state.NEW_FILE: - match = new_file_re.search(line) - if match: - current_state = state.FILENAME_SEARCH - - # Search the diff output for a file name. - elif current_state is state.FILENAME_SEARCH: - # Search for a file name. - match = filename_re.search(line) - if match: - filename = match.groups(1)[0] - if filename in ALLOWLIST or ALLOW_PATTERN.match(filename): - # Skip the file if it's allowlisted. - current_state = state.NEW_FILE - else: - current_state = state.HUNK - - # Search for a hunk. Each hunk starts with a line describing the line - # numbers in the file. - elif current_state is state.HUNK: - hunk_lines = [] - match = hunk_line_num_re.search(line) - if match: - # Extract the line number offset. - line_num = int(match.groups(1)[0]) - current_state = state.LINES - - # Start looking for changes. - elif current_state is state.LINES: - # Check if state needs updating. - new_hunk = hunk_line_num_re.search(line) - new_file = new_file_re.search(line) - if new_hunk: - current_state = state.HUNK - hunks.append(Hunk(filename, hunk_lines)) - continue - elif new_file: - current_state = state.NEW_FILE - hunks.append(Hunk(filename, hunk_lines)) - continue - - match = line_re.search(line) - if match: - line_type = match.groups(1)[0] - # We only care about modifications. - if line_type != ' ': - hunk_lines.append(Line(line_num, match.groups(2)[1], line_type)) - # Deletions don't count towards the line numbers. - if line_type != '-': - line_num += 1 - - # Advance to the next line - try: - i += 1 - line = diff[i] - except IndexError: - # We've reached the end of the diff. Return what we have. - if hunk_lines: - hunks.append(Hunk(filename, hunk_lines)) - return hunks + """Gets the hunks of the most recent commit. + + States: + new_file: Searching for a new file in the git diff. + filename_search: Searching for the filename of this hunk. + hunk: Searching for the beginning of a new hunk. + lines: Counting line numbers and searching for changes. + + Returns: + hunks: A list of Hunk objects which represent the hunks in the git diff + output. + """ + + diff = [] + hunks = [] + hunk_lines = [] + line = "" + filename = "" + i = 0 + line_num = 0 + + # Regex patterns + new_file_re = re.compile(r"^diff --git") + filename_re = re.compile(r"^[+]{3} (.*)") + hunk_line_num_re = re.compile(r"^@@ -[0-9]+,[0-9]+ \+([0-9]+),[0-9]+ @@.*") + line_re = re.compile(r"^([+| |-])(.*)") + + # Get the diff output. + proc = subprocess.run( + [ + "git", + "diff", + "--cached", + "-GCONFIG_*", + "--no-prefix", + "--no-ext-diff", + "HEAD~1", + ], + stdout=subprocess.PIPE, + encoding="utf-8", + check=True, + ) + diff = proc.stdout.splitlines() + if not diff: + return [] + line = diff[0] + + state = enum.Enum("state", "NEW_FILE FILENAME_SEARCH HUNK LINES") + current_state = state.NEW_FILE + + while True: + # Search for the beginning of a new file. + if current_state is state.NEW_FILE: + match = new_file_re.search(line) + if match: + current_state = state.FILENAME_SEARCH + + # Search the diff output for a file name. + elif current_state is state.FILENAME_SEARCH: + # Search for a file name. + match = filename_re.search(line) + if match: + filename = match.groups(1)[0] + if filename in ALLOWLIST or ALLOW_PATTERN.match(filename): + # Skip the file if it's allowlisted. + current_state = state.NEW_FILE + else: + current_state = state.HUNK + + # Search for a hunk. Each hunk starts with a line describing the line + # numbers in the file. + elif current_state is state.HUNK: + hunk_lines = [] + match = hunk_line_num_re.search(line) + if match: + # Extract the line number offset. + line_num = int(match.groups(1)[0]) + current_state = state.LINES + + # Start looking for changes. + elif current_state is state.LINES: + # Check if state needs updating. + new_hunk = hunk_line_num_re.search(line) + new_file = new_file_re.search(line) + if new_hunk: + current_state = state.HUNK + hunks.append(Hunk(filename, hunk_lines)) + continue + elif new_file: + current_state = state.NEW_FILE + hunks.append(Hunk(filename, hunk_lines)) + continue + + match = line_re.search(line) + if match: + line_type = match.groups(1)[0] + # We only care about modifications. + if line_type != " ": + hunk_lines.append(Line(line_num, match.groups(2)[1], line_type)) + # Deletions don't count towards the line numbers. + if line_type != "-": + line_num += 1 + + # Advance to the next line + try: + i += 1 + line = diff[i] + except IndexError: + # We've reached the end of the diff. Return what we have. + if hunk_lines: + hunks.append(Hunk(filename, hunk_lines)) + return hunks + def main(): - """Searches through committed changes for missing config options. - - Checks through committed changes for CONFIG_* options. Then checks to make - sure that all CONFIG_* options used are defined in include/config.h. Finally, - reports any missing config options. - """ - # Obtain the hunks of the commit to search through. - hunks = get_hunks() - # Obtain config options from include/config.h. - config_options = obtain_current_config_options() - # Find any missing config options from the hunks and print them. - missing_opts = print_missing_config_options(hunks, config_options) - - if missing_opts: - print('\nIt may also be possible that you have a typo.') - sys.exit(1) - -if __name__ == '__main__': - main() + """Searches through committed changes for missing config options. + + Checks through committed changes for CONFIG_* options. Then checks to make + sure that all CONFIG_* options used are defined in include/config.h. Finally, + reports any missing config options. + """ + # Obtain the hunks of the commit to search through. + hunks = get_hunks() + # Obtain config options from include/config.h. + config_options = obtain_current_config_options() + # Find any missing config options from the hunks and print them. + missing_opts = print_missing_config_options(hunks, config_options) + + if missing_opts: + print("\nIt may also be possible that you have a typo.") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/util/ec3po/console.py b/util/ec3po/console.py index e71216e3f2..33fa5a6775 100755 --- a/util/ec3po/console.py +++ b/util/ec3po/console.py @@ -17,7 +17,6 @@ from __future__ import print_function import argparse import binascii import ctypes -from datetime import datetime import logging import os import pty @@ -26,26 +25,25 @@ import select import stat import sys import traceback +from datetime import datetime import six +from ec3po import interpreter, threadproc_shim -from ec3po import interpreter -from ec3po import threadproc_shim - - -PROMPT = b'> ' +PROMPT = b"> " CONSOLE_INPUT_LINE_SIZE = 80 # Taken from the CONFIG_* with the same name. CONSOLE_MAX_READ = 100 # Max bytes to read at a time from the user. LOOK_BUFFER_SIZE = 256 # Size of search window when looking for the enhanced EC - # image string. +# image string. # In console_init(), the EC will print a string saying that the EC console is # enabled. Enhanced images will print a slightly different string. These # regular expressions are used to determine at reboot whether the EC image is # enhanced or not. -ENHANCED_IMAGE_RE = re.compile(br'Enhanced Console is enabled ' - br'\(v([0-9]+\.[0-9]+\.[0-9]+)\)') -NON_ENHANCED_IMAGE_RE = re.compile(br'Console is enabled; ') +ENHANCED_IMAGE_RE = re.compile( + rb"Enhanced Console is enabled " rb"\(v([0-9]+\.[0-9]+\.[0-9]+)\)" +) +NON_ENHANCED_IMAGE_RE = re.compile(rb"Console is enabled; ") # The timeouts are really only useful for enhanced EC images, but otherwise just # serve as a delay for non-enhanced EC images. Therefore, we can keep this @@ -54,1118 +52,1173 @@ NON_ENHANCED_IMAGE_RE = re.compile(br'Console is enabled; ') # EC image, we can increase the timeout for stability just in case it takes a # bit longer to receive an ACK for some reason. NON_ENHANCED_EC_INTERROGATION_TIMEOUT = 0.3 # Maximum number of seconds to wait - # for a response to an - # interrogation of a non-enhanced - # EC image. +# for a response to an +# interrogation of a non-enhanced +# EC image. ENHANCED_EC_INTERROGATION_TIMEOUT = 1.0 # Maximum number of seconds to wait for - # a response to an interrogation of an - # enhanced EC image. +# a response to an interrogation of an +# enhanced EC image. # List of modes which control when interrogations are performed with the EC. -INTERROGATION_MODES = [b'never', b'always', b'auto'] +INTERROGATION_MODES = [b"never", b"always", b"auto"] # Format for printing host timestamp -HOST_STRFTIME="%y-%m-%d %H:%M:%S.%f" +HOST_STRFTIME = "%y-%m-%d %H:%M:%S.%f" class EscState(object): - """Class which contains an enumeration for states of ESC sequences.""" - ESC_START = 1 - ESC_BRACKET = 2 - ESC_BRACKET_1 = 3 - ESC_BRACKET_3 = 4 - ESC_BRACKET_8 = 5 - + """Class which contains an enumeration for states of ESC sequences.""" -class ControlKey(object): - """Class which contains codes for various control keys.""" - BACKSPACE = 0x08 - CTRL_A = 0x01 - CTRL_B = 0x02 - CTRL_D = 0x04 - CTRL_E = 0x05 - CTRL_F = 0x06 - CTRL_K = 0x0b - CTRL_N = 0xe - CTRL_P = 0x10 - CARRIAGE_RETURN = 0x0d - ESC = 0x1b + ESC_START = 1 + ESC_BRACKET = 2 + ESC_BRACKET_1 = 3 + ESC_BRACKET_3 = 4 + ESC_BRACKET_8 = 5 -class Console(object): - """Class which provides the console interface between the EC and the user. - - This class essentially represents the console interface between the user and - the EC. It handles all of the console editing behaviour - - Attributes: - logger: A logger for this module. - controller_pty: File descriptor to the controller side of the PTY. Used for - driving output to the user and receiving user input. - user_pty: A string representing the PTY name of the served console. - cmd_pipe: A socket.socket or multiprocessing.Connection object which - represents the console side of the command pipe. This must be a - bidirectional pipe. Console commands and responses utilize this pipe. - dbg_pipe: A socket.socket or multiprocessing.Connection object which - represents the console's read-only side of the debug pipe. This must be a - unidirectional pipe attached to the intepreter. EC debug messages use - this pipe. - oobm_queue: A queue.Queue or multiprocessing.Queue which is used for out of - band management for the interactive console. - input_buffer: A string representing the current input command. - input_buffer_pos: An integer representing the current position in the buffer - to insert a char. - partial_cmd: A string representing the command entered on a line before - pressing the up arrow keys. - esc_state: An integer represeting the current state within an escape - sequence. - line_limit: An integer representing the maximum number of characters on a - line. - history: A list of strings containing the past entered console commands. - history_pos: An integer representing the current history buffer position. - This index is used to show previous commands. - prompt: A string representing the console prompt displayed to the user. - enhanced_ec: A boolean indicating if the EC image that we are currently - communicating with is enhanced or not. Enhanced EC images will support - packed commands and host commands over the UART. This defaults to False - until we perform some handshaking. - interrogation_timeout: A float representing the current maximum seconds to - wait for a response to an interrogation. - receiving_oobm_cmd: A boolean indicating whether or not the console is in - the middle of receiving an out of band command. - pending_oobm_cmd: A string containing the pending OOBM command. - interrogation_mode: A string containing the current mode of whether - interrogations are performed with the EC or not and how often. - raw_debug: Flag to indicate whether per interrupt data should be logged to - debug - output_line_log_buffer: buffer for lines coming from the EC to log to debug - """ - - def __init__(self, controller_pty, user_pty, interface_pty, cmd_pipe, dbg_pipe, - name=None): - """Initalises a Console object with the provided arguments. +class ControlKey(object): + """Class which contains codes for various control keys.""" - Args: - controller_pty: File descriptor to the controller side of the PTY. Used for - driving output to the user and receiving user input. - user_pty: A string representing the PTY name of the served console. - interface_pty: A string representing the PTY name of the served command - interface. - cmd_pipe: A socket.socket or multiprocessing.Connection object which - represents the console side of the command pipe. This must be a - bidirectional pipe. Console commands and responses utilize this pipe. - dbg_pipe: A socket.socket or multiprocessing.Connection object which - represents the console's read-only side of the debug pipe. This must be a - unidirectional pipe attached to the intepreter. EC debug messages use - this pipe. - name: the console source name - """ - # Create a unique logger based on the console name - console_prefix = ('%s - ' % name) if name else '' - logger = logging.getLogger('%sEC3PO.Console' % console_prefix) - self.logger = interpreter.LoggerAdapter(logger, {'pty': user_pty}) - self.controller_pty = controller_pty - self.user_pty = user_pty - self.interface_pty = interface_pty - self.cmd_pipe = cmd_pipe - self.dbg_pipe = dbg_pipe - self.oobm_queue = threadproc_shim.Queue() - self.input_buffer = b'' - self.input_buffer_pos = 0 - self.partial_cmd = b'' - self.esc_state = 0 - self.line_limit = CONSOLE_INPUT_LINE_SIZE - self.history = [] - self.history_pos = 0 - self.prompt = PROMPT - self.enhanced_ec = False - self.interrogation_timeout = NON_ENHANCED_EC_INTERROGATION_TIMEOUT - self.receiving_oobm_cmd = False - self.pending_oobm_cmd = b'' - self.interrogation_mode = b'auto' - self.timestamp_enabled = True - self.look_buffer = b'' - self.raw_debug = False - self.output_line_log_buffer = [] - - def __str__(self): - """Show internal state of Console object as a string.""" - string = [] - string.append('controller_pty: %s' % self.controller_pty) - string.append('user_pty: %s' % self.user_pty) - string.append('interface_pty: %s' % self.interface_pty) - string.append('cmd_pipe: %s' % self.cmd_pipe) - string.append('dbg_pipe: %s' % self.dbg_pipe) - string.append('oobm_queue: %s' % self.oobm_queue) - string.append('input_buffer: %s' % self.input_buffer) - string.append('input_buffer_pos: %d' % self.input_buffer_pos) - string.append('esc_state: %d' % self.esc_state) - string.append('line_limit: %d' % self.line_limit) - string.append('history: %r' % self.history) - string.append('history_pos: %d' % self.history_pos) - string.append('prompt: %r' % self.prompt) - string.append('partial_cmd: %r'% self.partial_cmd) - string.append('interrogation_mode: %r' % self.interrogation_mode) - string.append('look_buffer: %r' % self.look_buffer) - return '\n'.join(string) - - def LogConsoleOutput(self, data): - """Log to debug user MCU output to controller_pty when line is filled. - - The logging also suppresses the Cr50 spinner lines by removing characters - when it sees backspaces. + BACKSPACE = 0x08 + CTRL_A = 0x01 + CTRL_B = 0x02 + CTRL_D = 0x04 + CTRL_E = 0x05 + CTRL_F = 0x06 + CTRL_K = 0x0B + CTRL_N = 0xE + CTRL_P = 0x10 + CARRIAGE_RETURN = 0x0D + ESC = 0x1B - Args: - data: bytes - string received from MCU - """ - data = list(data) - # For compatibility with python2 and python3, standardize on the data - # being a list of integers. This requires one more transformation in py2 - if not isinstance(data[0], int): - data = [ord(c) for c in data] - - # This is a list of already filtered characters (or placeholders). - line = self.output_line_log_buffer - - # TODO(b/177480273): use raw strings here - symbols = { - ord(b'\n'): u'\\n', - ord(b'\r'): u'\\r', - ord(b'\t'): u'\\t' - } - # self.logger.debug(u'%s + %r', u''.join(line), ''.join(data)) - while data: - # Recall, data is a list of integers, namely the byte values sent by - # the MCU. - byte = data.pop(0) - # This means that |byte| is an int. - if byte == ord('\n'): - line.append(symbols[byte]) - if line: - self.logger.debug(u'%s', ''.join(line)) - line = [] - elif byte == ord('\b'): - # Backspace: trim the last character off the buffer - if line: - line.pop(-1) - elif byte in symbols: - line.append(symbols[byte]) - elif byte < ord(' ') or byte > ord('~'): - # Turn any character that isn't printable ASCII into escaped hex. - # ' ' is chr(20), and 0-19 are unprintable control characters. - # '~' is chr(126), and 127 is DELETE. 128-255 are control and Latin-1. - line.append(u'\\x%02x' % byte) - else: - # byte is printable. Thus it is safe to use chr() to get the printable - # character out of it again. - line.append(u'%s' % chr(byte)) - self.output_line_log_buffer = line - - def PrintHistory(self): - """Print the history of entered commands.""" - fd = self.controller_pty - # Make it pretty by figuring out how wide to pad the numbers. - wide = (len(self.history) // 10) + 1 - for i in range(len(self.history)): - line = b' %*d %s\r\n' % (wide, i, self.history[i]) - os.write(fd, line) - - def ShowPreviousCommand(self): - """Shows the previous command from the history list.""" - # There's nothing to do if there's no history at all. - if not self.history: - self.logger.debug('No history to print.') - return - - # Don't do anything if there's no more history to show. - if self.history_pos == 0: - self.logger.debug('No more history to show.') - return - - self.logger.debug('current history position: %d.', self.history_pos) - - # Decrement the history buffer position. - self.history_pos -= 1 - self.logger.debug('new history position.: %d', self.history_pos) - - # Save the text entered on the console if any. - if self.history_pos == len(self.history)-1: - self.logger.debug('saving partial_cmd: %r', self.input_buffer) - self.partial_cmd = self.input_buffer - - # Backspace the line. - for _ in range(self.input_buffer_pos): - self.SendBackspace() - - # Print the last entry in the history buffer. - self.logger.debug('printing previous entry %d - %s', self.history_pos, - self.history[self.history_pos]) - fd = self.controller_pty - prev_cmd = self.history[self.history_pos] - os.write(fd, prev_cmd) - # Update the input buffer. - self.input_buffer = prev_cmd - self.input_buffer_pos = len(prev_cmd) - - def ShowNextCommand(self): - """Shows the next command from the history list.""" - # Don't do anything if there's no history at all. - if not self.history: - self.logger.debug('History buffer is empty.') - return - - fd = self.controller_pty - - self.logger.debug('current history position: %d', self.history_pos) - # Increment the history position. - self.history_pos += 1 - - # Restore the partial cmd. - if self.history_pos == len(self.history): - self.logger.debug('Restoring partial command of %r', self.partial_cmd) - # Backspace the line. - for _ in range(self.input_buffer_pos): - self.SendBackspace() - # Print the partially entered command if any. - os.write(fd, self.partial_cmd) - self.input_buffer = self.partial_cmd - self.input_buffer_pos = len(self.input_buffer) - # Now that we've printed it, clear the partial cmd storage. - self.partial_cmd = b'' - # Reset history position. - self.history_pos = len(self.history) - return - - self.logger.debug('new history position: %d', self.history_pos) - if self.history_pos > len(self.history)-1: - self.logger.debug('No more history to show.') - self.history_pos -= 1 - self.logger.debug('Reset history position to %d', self.history_pos) - return - - # Backspace the line. - for _ in range(self.input_buffer_pos): - self.SendBackspace() - - # Print the newer entry from the history buffer. - self.logger.debug('printing next entry %d - %s', self.history_pos, - self.history[self.history_pos]) - next_cmd = self.history[self.history_pos] - os.write(fd, next_cmd) - # Update the input buffer. - self.input_buffer = next_cmd - self.input_buffer_pos = len(next_cmd) - self.logger.debug('new history position: %d.', self.history_pos) - - def SliceOutChar(self): - """Remove a char from the line and shift everything over 1 column.""" - fd = self.controller_pty - # Remove the character at the input_buffer_pos by slicing it out. - self.input_buffer = self.input_buffer[0:self.input_buffer_pos] + \ - self.input_buffer[self.input_buffer_pos+1:] - # Write the rest of the line - moved_col = os.write(fd, self.input_buffer[self.input_buffer_pos:]) - # Write a space to clear out the last char - moved_col += os.write(fd, b' ') - # Update the input buffer position. - self.input_buffer_pos += moved_col - # Reset the cursor - self.MoveCursor('left', moved_col) - - def HandleEsc(self, byte): - """HandleEsc processes escape sequences. - Args: - byte: An integer representing the current byte in the sequence. +class Console(object): + """Class which provides the console interface between the EC and the user. + + This class essentially represents the console interface between the user and + the EC. It handles all of the console editing behaviour + + Attributes: + logger: A logger for this module. + controller_pty: File descriptor to the controller side of the PTY. Used for + driving output to the user and receiving user input. + user_pty: A string representing the PTY name of the served console. + cmd_pipe: A socket.socket or multiprocessing.Connection object which + represents the console side of the command pipe. This must be a + bidirectional pipe. Console commands and responses utilize this pipe. + dbg_pipe: A socket.socket or multiprocessing.Connection object which + represents the console's read-only side of the debug pipe. This must be a + unidirectional pipe attached to the intepreter. EC debug messages use + this pipe. + oobm_queue: A queue.Queue or multiprocessing.Queue which is used for out of + band management for the interactive console. + input_buffer: A string representing the current input command. + input_buffer_pos: An integer representing the current position in the buffer + to insert a char. + partial_cmd: A string representing the command entered on a line before + pressing the up arrow keys. + esc_state: An integer represeting the current state within an escape + sequence. + line_limit: An integer representing the maximum number of characters on a + line. + history: A list of strings containing the past entered console commands. + history_pos: An integer representing the current history buffer position. + This index is used to show previous commands. + prompt: A string representing the console prompt displayed to the user. + enhanced_ec: A boolean indicating if the EC image that we are currently + communicating with is enhanced or not. Enhanced EC images will support + packed commands and host commands over the UART. This defaults to False + until we perform some handshaking. + interrogation_timeout: A float representing the current maximum seconds to + wait for a response to an interrogation. + receiving_oobm_cmd: A boolean indicating whether or not the console is in + the middle of receiving an out of band command. + pending_oobm_cmd: A string containing the pending OOBM command. + interrogation_mode: A string containing the current mode of whether + interrogations are performed with the EC or not and how often. + raw_debug: Flag to indicate whether per interrupt data should be logged to + debug + output_line_log_buffer: buffer for lines coming from the EC to log to debug """ - # We shouldn't be handling an escape sequence if we haven't seen one. - assert self.esc_state != 0 - - if self.esc_state is EscState.ESC_START: - self.logger.debug('ESC_START') - if byte == ord('['): - self.esc_state = EscState.ESC_BRACKET - return - else: - self.logger.error('Unexpected sequence. %c', byte) - self.esc_state = 0 - - elif self.esc_state is EscState.ESC_BRACKET: - self.logger.debug('ESC_BRACKET') - # Left Arrow key was pressed. - if byte == ord('D'): - self.logger.debug('Left arrow key pressed.') - self.MoveCursor('left', 1) - self.esc_state = 0 # Reset the state. - return - - # Right Arrow key. - elif byte == ord('C'): - self.logger.debug('Right arrow key pressed.') - self.MoveCursor('right', 1) - self.esc_state = 0 # Reset the state. - return - - # Up Arrow key. - elif byte == ord('A'): - self.logger.debug('Up arrow key pressed.') - self.ShowPreviousCommand() - # Reset the state. - self.esc_state = 0 # Reset the state. - return - - # Down Arrow key. - elif byte == ord('B'): - self.logger.debug('Down arrow key pressed.') - self.ShowNextCommand() - # Reset the state. - self.esc_state = 0 # Reset the state. - return - - # For some reason, minicom sends a 1 instead of 7. /shrug - # TODO(aaboagye): Figure out why this happens. - elif byte == ord('1') or byte == ord('7'): - self.esc_state = EscState.ESC_BRACKET_1 - - elif byte == ord('3'): - self.esc_state = EscState.ESC_BRACKET_3 - - elif byte == ord('8'): - self.esc_state = EscState.ESC_BRACKET_8 - - else: - self.logger.error(r'Bad or unhandled escape sequence. got ^[%c\(%d)', - chr(byte), byte) - self.esc_state = 0 - return - - elif self.esc_state is EscState.ESC_BRACKET_1: - self.logger.debug('ESC_BRACKET_1') - # HOME key. - if byte == ord('~'): - self.logger.debug('Home key pressed.') - self.MoveCursor('left', self.input_buffer_pos) - self.esc_state = 0 # Reset the state. - self.logger.debug('ESC sequence complete.') - return - - elif self.esc_state is EscState.ESC_BRACKET_3: - self.logger.debug('ESC_BRACKET_3') - # DEL key. - if byte == ord('~'): - self.logger.debug('Delete key pressed.') - if self.input_buffer_pos != len(self.input_buffer): - self.SliceOutChar() - self.esc_state = 0 # Reset the state. - - elif self.esc_state is EscState.ESC_BRACKET_8: - self.logger.debug('ESC_BRACKET_8') - # END key. - if byte == ord('~'): - self.logger.debug('End key pressed.') - self.MoveCursor('right', - len(self.input_buffer) - self.input_buffer_pos) - self.esc_state = 0 # Reset the state. - self.logger.debug('ESC sequence complete.') - return - - else: - self.logger.error('Unexpected sequence. %c', byte) + def __init__( + self, controller_pty, user_pty, interface_pty, cmd_pipe, dbg_pipe, name=None + ): + """Initalises a Console object with the provided arguments. + + Args: + controller_pty: File descriptor to the controller side of the PTY. Used for + driving output to the user and receiving user input. + user_pty: A string representing the PTY name of the served console. + interface_pty: A string representing the PTY name of the served command + interface. + cmd_pipe: A socket.socket or multiprocessing.Connection object which + represents the console side of the command pipe. This must be a + bidirectional pipe. Console commands and responses utilize this pipe. + dbg_pipe: A socket.socket or multiprocessing.Connection object which + represents the console's read-only side of the debug pipe. This must be a + unidirectional pipe attached to the intepreter. EC debug messages use + this pipe. + name: the console source name + """ + # Create a unique logger based on the console name + console_prefix = ("%s - " % name) if name else "" + logger = logging.getLogger("%sEC3PO.Console" % console_prefix) + self.logger = interpreter.LoggerAdapter(logger, {"pty": user_pty}) + self.controller_pty = controller_pty + self.user_pty = user_pty + self.interface_pty = interface_pty + self.cmd_pipe = cmd_pipe + self.dbg_pipe = dbg_pipe + self.oobm_queue = threadproc_shim.Queue() + self.input_buffer = b"" + self.input_buffer_pos = 0 + self.partial_cmd = b"" self.esc_state = 0 + self.line_limit = CONSOLE_INPUT_LINE_SIZE + self.history = [] + self.history_pos = 0 + self.prompt = PROMPT + self.enhanced_ec = False + self.interrogation_timeout = NON_ENHANCED_EC_INTERROGATION_TIMEOUT + self.receiving_oobm_cmd = False + self.pending_oobm_cmd = b"" + self.interrogation_mode = b"auto" + self.timestamp_enabled = True + self.look_buffer = b"" + self.raw_debug = False + self.output_line_log_buffer = [] + + def __str__(self): + """Show internal state of Console object as a string.""" + string = [] + string.append("controller_pty: %s" % self.controller_pty) + string.append("user_pty: %s" % self.user_pty) + string.append("interface_pty: %s" % self.interface_pty) + string.append("cmd_pipe: %s" % self.cmd_pipe) + string.append("dbg_pipe: %s" % self.dbg_pipe) + string.append("oobm_queue: %s" % self.oobm_queue) + string.append("input_buffer: %s" % self.input_buffer) + string.append("input_buffer_pos: %d" % self.input_buffer_pos) + string.append("esc_state: %d" % self.esc_state) + string.append("line_limit: %d" % self.line_limit) + string.append("history: %r" % self.history) + string.append("history_pos: %d" % self.history_pos) + string.append("prompt: %r" % self.prompt) + string.append("partial_cmd: %r" % self.partial_cmd) + string.append("interrogation_mode: %r" % self.interrogation_mode) + string.append("look_buffer: %r" % self.look_buffer) + return "\n".join(string) + + def LogConsoleOutput(self, data): + """Log to debug user MCU output to controller_pty when line is filled. + + The logging also suppresses the Cr50 spinner lines by removing characters + when it sees backspaces. + + Args: + data: bytes - string received from MCU + """ + data = list(data) + # For compatibility with python2 and python3, standardize on the data + # being a list of integers. This requires one more transformation in py2 + if not isinstance(data[0], int): + data = [ord(c) for c in data] + + # This is a list of already filtered characters (or placeholders). + line = self.output_line_log_buffer + + # TODO(b/177480273): use raw strings here + symbols = {ord(b"\n"): "\\n", ord(b"\r"): "\\r", ord(b"\t"): "\\t"} + # self.logger.debug(u'%s + %r', u''.join(line), ''.join(data)) + while data: + # Recall, data is a list of integers, namely the byte values sent by + # the MCU. + byte = data.pop(0) + # This means that |byte| is an int. + if byte == ord("\n"): + line.append(symbols[byte]) + if line: + self.logger.debug("%s", "".join(line)) + line = [] + elif byte == ord("\b"): + # Backspace: trim the last character off the buffer + if line: + line.pop(-1) + elif byte in symbols: + line.append(symbols[byte]) + elif byte < ord(" ") or byte > ord("~"): + # Turn any character that isn't printable ASCII into escaped hex. + # ' ' is chr(20), and 0-19 are unprintable control characters. + # '~' is chr(126), and 127 is DELETE. 128-255 are control and Latin-1. + line.append("\\x%02x" % byte) + else: + # byte is printable. Thus it is safe to use chr() to get the printable + # character out of it again. + line.append("%s" % chr(byte)) + self.output_line_log_buffer = line + + def PrintHistory(self): + """Print the history of entered commands.""" + fd = self.controller_pty + # Make it pretty by figuring out how wide to pad the numbers. + wide = (len(self.history) // 10) + 1 + for i in range(len(self.history)): + line = b" %*d %s\r\n" % (wide, i, self.history[i]) + os.write(fd, line) + + def ShowPreviousCommand(self): + """Shows the previous command from the history list.""" + # There's nothing to do if there's no history at all. + if not self.history: + self.logger.debug("No history to print.") + return + + # Don't do anything if there's no more history to show. + if self.history_pos == 0: + self.logger.debug("No more history to show.") + return + + self.logger.debug("current history position: %d.", self.history_pos) + + # Decrement the history buffer position. + self.history_pos -= 1 + self.logger.debug("new history position.: %d", self.history_pos) + + # Save the text entered on the console if any. + if self.history_pos == len(self.history) - 1: + self.logger.debug("saving partial_cmd: %r", self.input_buffer) + self.partial_cmd = self.input_buffer + + # Backspace the line. + for _ in range(self.input_buffer_pos): + self.SendBackspace() + + # Print the last entry in the history buffer. + self.logger.debug( + "printing previous entry %d - %s", + self.history_pos, + self.history[self.history_pos], + ) + fd = self.controller_pty + prev_cmd = self.history[self.history_pos] + os.write(fd, prev_cmd) + # Update the input buffer. + self.input_buffer = prev_cmd + self.input_buffer_pos = len(prev_cmd) + + def ShowNextCommand(self): + """Shows the next command from the history list.""" + # Don't do anything if there's no history at all. + if not self.history: + self.logger.debug("History buffer is empty.") + return + + fd = self.controller_pty + + self.logger.debug("current history position: %d", self.history_pos) + # Increment the history position. + self.history_pos += 1 + + # Restore the partial cmd. + if self.history_pos == len(self.history): + self.logger.debug("Restoring partial command of %r", self.partial_cmd) + # Backspace the line. + for _ in range(self.input_buffer_pos): + self.SendBackspace() + # Print the partially entered command if any. + os.write(fd, self.partial_cmd) + self.input_buffer = self.partial_cmd + self.input_buffer_pos = len(self.input_buffer) + # Now that we've printed it, clear the partial cmd storage. + self.partial_cmd = b"" + # Reset history position. + self.history_pos = len(self.history) + return + + self.logger.debug("new history position: %d", self.history_pos) + if self.history_pos > len(self.history) - 1: + self.logger.debug("No more history to show.") + self.history_pos -= 1 + self.logger.debug("Reset history position to %d", self.history_pos) + return + + # Backspace the line. + for _ in range(self.input_buffer_pos): + self.SendBackspace() + + # Print the newer entry from the history buffer. + self.logger.debug( + "printing next entry %d - %s", + self.history_pos, + self.history[self.history_pos], + ) + next_cmd = self.history[self.history_pos] + os.write(fd, next_cmd) + # Update the input buffer. + self.input_buffer = next_cmd + self.input_buffer_pos = len(next_cmd) + self.logger.debug("new history position: %d.", self.history_pos) + + def SliceOutChar(self): + """Remove a char from the line and shift everything over 1 column.""" + fd = self.controller_pty + # Remove the character at the input_buffer_pos by slicing it out. + self.input_buffer = ( + self.input_buffer[0 : self.input_buffer_pos] + + self.input_buffer[self.input_buffer_pos + 1 :] + ) + # Write the rest of the line + moved_col = os.write(fd, self.input_buffer[self.input_buffer_pos :]) + # Write a space to clear out the last char + moved_col += os.write(fd, b" ") + # Update the input buffer position. + self.input_buffer_pos += moved_col + # Reset the cursor + self.MoveCursor("left", moved_col) + + def HandleEsc(self, byte): + """HandleEsc processes escape sequences. + + Args: + byte: An integer representing the current byte in the sequence. + """ + # We shouldn't be handling an escape sequence if we haven't seen one. + assert self.esc_state != 0 + + if self.esc_state is EscState.ESC_START: + self.logger.debug("ESC_START") + if byte == ord("["): + self.esc_state = EscState.ESC_BRACKET + return + + else: + self.logger.error("Unexpected sequence. %c", byte) + self.esc_state = 0 + + elif self.esc_state is EscState.ESC_BRACKET: + self.logger.debug("ESC_BRACKET") + # Left Arrow key was pressed. + if byte == ord("D"): + self.logger.debug("Left arrow key pressed.") + self.MoveCursor("left", 1) + self.esc_state = 0 # Reset the state. + return + + # Right Arrow key. + elif byte == ord("C"): + self.logger.debug("Right arrow key pressed.") + self.MoveCursor("right", 1) + self.esc_state = 0 # Reset the state. + return + + # Up Arrow key. + elif byte == ord("A"): + self.logger.debug("Up arrow key pressed.") + self.ShowPreviousCommand() + # Reset the state. + self.esc_state = 0 # Reset the state. + return + + # Down Arrow key. + elif byte == ord("B"): + self.logger.debug("Down arrow key pressed.") + self.ShowNextCommand() + # Reset the state. + self.esc_state = 0 # Reset the state. + return + + # For some reason, minicom sends a 1 instead of 7. /shrug + # TODO(aaboagye): Figure out why this happens. + elif byte == ord("1") or byte == ord("7"): + self.esc_state = EscState.ESC_BRACKET_1 + + elif byte == ord("3"): + self.esc_state = EscState.ESC_BRACKET_3 + + elif byte == ord("8"): + self.esc_state = EscState.ESC_BRACKET_8 + + else: + self.logger.error( + r"Bad or unhandled escape sequence. got ^[%c\(%d)", chr(byte), byte + ) + self.esc_state = 0 + return + + elif self.esc_state is EscState.ESC_BRACKET_1: + self.logger.debug("ESC_BRACKET_1") + # HOME key. + if byte == ord("~"): + self.logger.debug("Home key pressed.") + self.MoveCursor("left", self.input_buffer_pos) + self.esc_state = 0 # Reset the state. + self.logger.debug("ESC sequence complete.") + return + + elif self.esc_state is EscState.ESC_BRACKET_3: + self.logger.debug("ESC_BRACKET_3") + # DEL key. + if byte == ord("~"): + self.logger.debug("Delete key pressed.") + if self.input_buffer_pos != len(self.input_buffer): + self.SliceOutChar() + self.esc_state = 0 # Reset the state. + + elif self.esc_state is EscState.ESC_BRACKET_8: + self.logger.debug("ESC_BRACKET_8") + # END key. + if byte == ord("~"): + self.logger.debug("End key pressed.") + self.MoveCursor("right", len(self.input_buffer) - self.input_buffer_pos) + self.esc_state = 0 # Reset the state. + self.logger.debug("ESC sequence complete.") + return + + else: + self.logger.error("Unexpected sequence. %c", byte) + self.esc_state = 0 + + else: + self.logger.error("Unexpected sequence. %c", byte) + self.esc_state = 0 + + def ProcessInput(self): + """Captures the input determines what actions to take.""" + # There's nothing to do if the input buffer is empty. + if len(self.input_buffer) == 0: + return + + # Don't store 2 consecutive identical commands in the history. + if self.history and self.history[-1] != self.input_buffer or not self.history: + self.history.append(self.input_buffer) + + # Split the command up by spaces. + line = self.input_buffer.split(b" ") + self.logger.debug("cmd: %s", self.input_buffer) + cmd = line[0].lower() + + # The 'history' command is a special case that we handle locally. + if cmd == "history": + self.PrintHistory() + return + + # Send the command to the interpreter. + self.logger.debug("Sending command to interpreter.") + self.cmd_pipe.send(self.input_buffer) + + def CheckForEnhancedECImage(self): + """Performs an interrogation of the EC image. + + Send a SYN and expect an ACK. If no ACK or the response is incorrect, then + assume that the current EC image that we are talking to is not enhanced. + + Returns: + is_enhanced: A boolean indicating whether the EC responded to the + interrogation correctly. + + Raises: + EOFError: Allowed to propagate through from self.dbg_pipe.recv(). + """ + # Send interrogation byte and wait for the response. + self.logger.debug("Performing interrogation.") + self.cmd_pipe.send(interpreter.EC_SYN) + + response = "" + if self.dbg_pipe.poll(self.interrogation_timeout): + response = self.dbg_pipe.recv() + self.logger.debug("response: %r", binascii.hexlify(response)) + else: + self.logger.debug("Timed out waiting for EC_ACK") + + # Verify the acknowledgment. + is_enhanced = response == interpreter.EC_ACK + + if is_enhanced: + # Increase the interrogation timeout for stability purposes. + self.interrogation_timeout = ENHANCED_EC_INTERROGATION_TIMEOUT + self.logger.debug( + "Increasing interrogation timeout to %rs.", self.interrogation_timeout + ) + else: + # Reduce the timeout in order to reduce the perceivable delay. + self.interrogation_timeout = NON_ENHANCED_EC_INTERROGATION_TIMEOUT + self.logger.debug( + "Reducing interrogation timeout to %rs.", self.interrogation_timeout + ) + + return is_enhanced + + def HandleChar(self, byte): + """HandleChar does a certain action when it receives a character. + + Args: + byte: An integer representing the character received from the user. + + Raises: + EOFError: Allowed to propagate through from self.CheckForEnhancedECImage() + i.e. from self.dbg_pipe.recv(). + """ + fd = self.controller_pty + + # Enter the OOBM prompt mode if the user presses '%'. + if byte == ord("%"): + self.logger.debug("Begin OOBM command.") + self.receiving_oobm_cmd = True + # Print a "prompt". + os.write(self.controller_pty, b"\r\n% ") + return + + # Add chars to the pending OOBM command if we're currently receiving one. + if self.receiving_oobm_cmd and byte != ControlKey.CARRIAGE_RETURN: + tmp_bytes = six.int2byte(byte) + self.pending_oobm_cmd += tmp_bytes + self.logger.debug("%s", tmp_bytes) + os.write(self.controller_pty, tmp_bytes) + return + + if byte == ControlKey.CARRIAGE_RETURN: + if self.receiving_oobm_cmd: + # Terminate the command and place it in the OOBM queue. + self.logger.debug("End OOBM command.") + if self.pending_oobm_cmd: + self.oobm_queue.put(self.pending_oobm_cmd) + self.logger.debug( + "Placed %r into OOBM command queue.", self.pending_oobm_cmd + ) + + # Reset the state. + os.write(self.controller_pty, b"\r\n" + self.prompt) + self.input_buffer = b"" + self.input_buffer_pos = 0 + self.receiving_oobm_cmd = False + self.pending_oobm_cmd = b"" + return + + if self.interrogation_mode == b"never": + self.logger.debug( + "Skipping interrogation because interrogation mode" + " is set to never." + ) + elif self.interrogation_mode == b"always": + # Only interrogate the EC if the interrogation mode is set to 'always'. + self.enhanced_ec = self.CheckForEnhancedECImage() + self.logger.debug("Enhanced EC image? %r", self.enhanced_ec) + + if not self.enhanced_ec: + # Send everything straight to the EC to handle. + self.cmd_pipe.send(six.int2byte(byte)) + # Reset the input buffer. + self.input_buffer = b"" + self.input_buffer_pos = 0 + self.logger.log(1, "Reset input buffer.") + return + + # Keep handling the ESC sequence if we're in the middle of it. + if self.esc_state != 0: + self.HandleEsc(byte) + return + + # When we're at the end of the line, we should only allow going backwards, + # backspace, carriage return, up, or down. The arrow keys are escape + # sequences, so we let the escape...escape. + if self.input_buffer_pos >= self.line_limit and byte not in [ + ControlKey.CTRL_B, + ControlKey.ESC, + ControlKey.BACKSPACE, + ControlKey.CTRL_A, + ControlKey.CARRIAGE_RETURN, + ControlKey.CTRL_P, + ControlKey.CTRL_N, + ]: + return + + # If the input buffer is full we can't accept new chars. + buffer_full = len(self.input_buffer) >= self.line_limit + + # Carriage_Return/Enter + if byte == ControlKey.CARRIAGE_RETURN: + self.logger.debug("Enter key pressed.") + # Put a carriage return/newline and the print the prompt. + os.write(fd, b"\r\n") + + # TODO(aaboagye): When we control the printing of all output, print the + # prompt AFTER printing all the output. We can't do it yet because we + # don't know how much is coming from the EC. + + # Print the prompt. + os.write(fd, self.prompt) + # Process the input. + self.ProcessInput() + # Now, clear the buffer. + self.input_buffer = b"" + self.input_buffer_pos = 0 + # Reset history buffer pos. + self.history_pos = len(self.history) + # Clear partial command. + self.partial_cmd = b"" + + # Backspace + elif byte == ControlKey.BACKSPACE: + self.logger.debug("Backspace pressed.") + if self.input_buffer_pos > 0: + # Move left 1 column. + self.MoveCursor("left", 1) + # Remove the character at the input_buffer_pos by slicing it out. + self.SliceOutChar() + + self.logger.debug("input_buffer_pos: %d", self.input_buffer_pos) + + # Ctrl+A. Move cursor to beginning of the line + elif byte == ControlKey.CTRL_A: + self.logger.debug("Control+A pressed.") + self.MoveCursor("left", self.input_buffer_pos) + + # Ctrl+B. Move cursor left 1 column. + elif byte == ControlKey.CTRL_B: + self.logger.debug("Control+B pressed.") + self.MoveCursor("left", 1) + + # Ctrl+D. Delete a character. + elif byte == ControlKey.CTRL_D: + self.logger.debug("Control+D pressed.") + if self.input_buffer_pos != len(self.input_buffer): + # Remove the character by slicing it out. + self.SliceOutChar() + + # Ctrl+E. Move cursor to end of the line. + elif byte == ControlKey.CTRL_E: + self.logger.debug("Control+E pressed.") + self.MoveCursor("right", len(self.input_buffer) - self.input_buffer_pos) + + # Ctrl+F. Move cursor right 1 column. + elif byte == ControlKey.CTRL_F: + self.logger.debug("Control+F pressed.") + self.MoveCursor("right", 1) + + # Ctrl+K. Kill line. + elif byte == ControlKey.CTRL_K: + self.logger.debug("Control+K pressed.") + self.KillLine() + + # Ctrl+N. Next line. + elif byte == ControlKey.CTRL_N: + self.logger.debug("Control+N pressed.") + self.ShowNextCommand() + + # Ctrl+P. Previous line. + elif byte == ControlKey.CTRL_P: + self.logger.debug("Control+P pressed.") + self.ShowPreviousCommand() + + # ESC sequence + elif byte == ControlKey.ESC: + # Starting an ESC sequence + self.esc_state = EscState.ESC_START + + # Only print printable chars. + elif IsPrintable(byte): + # Drop the character if we're full. + if buffer_full: + self.logger.debug("Dropped char: %c(%d)", byte, byte) + return + # Print the character. + os.write(fd, six.int2byte(byte)) + # Print the rest of the line (if any). + extra_bytes_written = os.write( + fd, self.input_buffer[self.input_buffer_pos :] + ) + + # Recreate the input buffer. + self.input_buffer = ( + self.input_buffer[0 : self.input_buffer_pos] + + six.int2byte(byte) + + self.input_buffer[self.input_buffer_pos :] + ) + # Update the input buffer position. + self.input_buffer_pos += 1 + extra_bytes_written + + # Reset the cursor if we wrote any extra bytes. + if extra_bytes_written: + self.MoveCursor("left", extra_bytes_written) + + self.logger.debug("input_buffer_pos: %d", self.input_buffer_pos) + + def MoveCursor(self, direction, count): + """MoveCursor moves the cursor left or right by count columns. + + Args: + direction: A string that should be either 'left' or 'right' representing + the direction to move the cursor on the console. + count: An integer representing how many columns the cursor should be + moved. + + Raises: + AssertionError: If the direction is not equal to 'left' or 'right'. + """ + # If there's nothing to move, we're done. + if not count: + return + fd = self.controller_pty + seq = b"\033[" + str(count).encode("ascii") + if direction == "left": + # Bind the movement. + if count > self.input_buffer_pos: + count = self.input_buffer_pos + seq += b"D" + self.logger.debug("move cursor left %d", count) + self.input_buffer_pos -= count + + elif direction == "right": + # Bind the movement. + if (count + self.input_buffer_pos) > len(self.input_buffer): + count = 0 + seq += b"C" + self.logger.debug("move cursor right %d", count) + self.input_buffer_pos += count + + else: + raise AssertionError( + ("The only valid directions are 'left' and " "'right'") + ) + + self.logger.debug("input_buffer_pos: %d", self.input_buffer_pos) + # Move the cursor. + if count != 0: + os.write(fd, seq) + + def KillLine(self): + """Kill the rest of the line based on the input buffer position.""" + # Killing the line is killing all the text to the right. + diff = len(self.input_buffer) - self.input_buffer_pos + self.logger.debug("diff: %d", diff) + # Diff shouldn't be negative, but if it is for some reason, let's try to + # correct the cursor. + if diff < 0: + self.logger.warning( + "Resetting input buffer position to %d...", len(self.input_buffer) + ) + self.MoveCursor("left", -diff) + return + if diff: + self.MoveCursor("right", diff) + for _ in range(diff): + self.SendBackspace() + self.input_buffer_pos -= diff + self.input_buffer = self.input_buffer[0 : self.input_buffer_pos] + + def SendBackspace(self): + """Backspace a character on the console.""" + os.write(self.controller_pty, b"\033[1D \033[1D") + + def ProcessOOBMQueue(self): + """Retrieve an item from the OOBM queue and process it.""" + item = self.oobm_queue.get() + self.logger.debug("OOBM cmd: %r", item) + cmd = item.split(b" ") + + if cmd[0] == b"loglevel": + # An integer is required in order to set the log level. + if len(cmd) < 2: + self.logger.debug("Insufficient args") + self.PrintOOBMHelp() + return + try: + self.logger.debug("Log level change request.") + new_log_level = int(cmd[1]) + self.logger.logger.setLevel(new_log_level) + self.logger.info("Log level changed to %d.", new_log_level) + + # Forward the request to the interpreter as well. + self.cmd_pipe.send(item) + except ValueError: + # Ignoring the request if an integer was not provided. + self.PrintOOBMHelp() + + elif cmd[0] == b"timestamp": + mode = cmd[1].lower() + self.timestamp_enabled = mode == b"on" + self.logger.info( + "%sabling uart timestamps.", "En" if self.timestamp_enabled else "Dis" + ) + + elif cmd[0] == b"rawdebug": + mode = cmd[1].lower() + self.raw_debug = mode == b"on" + self.logger.info( + "%sabling per interrupt debug logs.", "En" if self.raw_debug else "Dis" + ) + + elif cmd[0] == b"interrogate" and len(cmd) >= 2: + enhanced = False + mode = cmd[1] + if len(cmd) >= 3 and cmd[2] == b"enhanced": + enhanced = True + + # Set the mode if correct. + if mode in INTERROGATION_MODES: + self.interrogation_mode = mode + self.logger.debug("Updated interrogation mode to %s.", mode) + + # Update the assumptions of the EC image. + self.enhanced_ec = enhanced + self.logger.debug("Enhanced EC image is now %r", self.enhanced_ec) + + # Send command to interpreter as well. + self.cmd_pipe.send(b"enhanced " + str(self.enhanced_ec).encode("ascii")) + else: + self.PrintOOBMHelp() + + else: + self.PrintOOBMHelp() + + def PrintOOBMHelp(self): + """Prints out the OOBM help.""" + # Print help syntax. + os.write(self.controller_pty, b"\r\n" + b"Known OOBM commands:\r\n") + os.write( + self.controller_pty, + b" interrogate <never | always | auto> " b"[enhanced]\r\n", + ) + os.write(self.controller_pty, b" loglevel <int>\r\n") + + def CheckBufferForEnhancedImage(self, data): + """Adds data to a look buffer and checks to see for enhanced EC image. + + The EC's console task prints a string upon initialization which says that + "Console is enabled; type HELP for help.". The enhanced EC images print a + different string as a part of their init. This function searches through a + "look" buffer, scanning for the presence of either of those strings and + updating the enhanced_ec state accordingly. + + Args: + data: A string containing the data sent from the interpreter. + """ + self.look_buffer += data + + # Search the buffer for any of the EC image strings. + enhanced_match = re.search(ENHANCED_IMAGE_RE, self.look_buffer) + non_enhanced_match = re.search(NON_ENHANCED_IMAGE_RE, self.look_buffer) + + # Update the state if any matches were found. + if enhanced_match or non_enhanced_match: + if enhanced_match: + self.enhanced_ec = True + elif non_enhanced_match: + self.enhanced_ec = False + + # Inform the interpreter of the result. + self.cmd_pipe.send(b"enhanced " + str(self.enhanced_ec).encode("ascii")) + self.logger.debug("Enhanced EC image? %r", self.enhanced_ec) + + # Clear look buffer since a match was found. + self.look_buffer = b"" + + # Move the sliding window. + self.look_buffer = self.look_buffer[-LOOK_BUFFER_SIZE:] - else: - self.logger.error('Unexpected sequence. %c', byte) - self.esc_state = 0 - - def ProcessInput(self): - """Captures the input determines what actions to take.""" - # There's nothing to do if the input buffer is empty. - if len(self.input_buffer) == 0: - return - - # Don't store 2 consecutive identical commands in the history. - if (self.history and self.history[-1] != self.input_buffer - or not self.history): - self.history.append(self.input_buffer) - - # Split the command up by spaces. - line = self.input_buffer.split(b' ') - self.logger.debug('cmd: %s', self.input_buffer) - cmd = line[0].lower() - - # The 'history' command is a special case that we handle locally. - if cmd == 'history': - self.PrintHistory() - return - - # Send the command to the interpreter. - self.logger.debug('Sending command to interpreter.') - self.cmd_pipe.send(self.input_buffer) - def CheckForEnhancedECImage(self): - """Performs an interrogation of the EC image. +def CanonicalizeTimeString(timestr): + """Canonicalize the timestamp string. - Send a SYN and expect an ACK. If no ACK or the response is incorrect, then - assume that the current EC image that we are talking to is not enhanced. + Args: + timestr: A timestamp string ended with 6 digits msec. Returns: - is_enhanced: A boolean indicating whether the EC responded to the - interrogation correctly. - - Raises: - EOFError: Allowed to propagate through from self.dbg_pipe.recv(). + A string with 3 digits msec and an extra space. """ - # Send interrogation byte and wait for the response. - self.logger.debug('Performing interrogation.') - self.cmd_pipe.send(interpreter.EC_SYN) - - response = '' - if self.dbg_pipe.poll(self.interrogation_timeout): - response = self.dbg_pipe.recv() - self.logger.debug('response: %r', binascii.hexlify(response)) - else: - self.logger.debug('Timed out waiting for EC_ACK') + return timestr[:-3].encode("ascii") + b" " - # Verify the acknowledgment. - is_enhanced = response == interpreter.EC_ACK - - if is_enhanced: - # Increase the interrogation timeout for stability purposes. - self.interrogation_timeout = ENHANCED_EC_INTERROGATION_TIMEOUT - self.logger.debug('Increasing interrogation timeout to %rs.', - self.interrogation_timeout) - else: - # Reduce the timeout in order to reduce the perceivable delay. - self.interrogation_timeout = NON_ENHANCED_EC_INTERROGATION_TIMEOUT - self.logger.debug('Reducing interrogation timeout to %rs.', - self.interrogation_timeout) - - return is_enhanced - - def HandleChar(self, byte): - """HandleChar does a certain action when it receives a character. - - Args: - byte: An integer representing the character received from the user. - Raises: - EOFError: Allowed to propagate through from self.CheckForEnhancedECImage() - i.e. from self.dbg_pipe.recv(). - """ - fd = self.controller_pty - - # Enter the OOBM prompt mode if the user presses '%'. - if byte == ord('%'): - self.logger.debug('Begin OOBM command.') - self.receiving_oobm_cmd = True - # Print a "prompt". - os.write(self.controller_pty, b'\r\n% ') - return - - # Add chars to the pending OOBM command if we're currently receiving one. - if self.receiving_oobm_cmd and byte != ControlKey.CARRIAGE_RETURN: - tmp_bytes = six.int2byte(byte) - self.pending_oobm_cmd += tmp_bytes - self.logger.debug('%s', tmp_bytes) - os.write(self.controller_pty, tmp_bytes) - return - - if byte == ControlKey.CARRIAGE_RETURN: - if self.receiving_oobm_cmd: - # Terminate the command and place it in the OOBM queue. - self.logger.debug('End OOBM command.') - if self.pending_oobm_cmd: - self.oobm_queue.put(self.pending_oobm_cmd) - self.logger.debug('Placed %r into OOBM command queue.', - self.pending_oobm_cmd) - - # Reset the state. - os.write(self.controller_pty, b'\r\n' + self.prompt) - self.input_buffer = b'' - self.input_buffer_pos = 0 - self.receiving_oobm_cmd = False - self.pending_oobm_cmd = b'' - return - - if self.interrogation_mode == b'never': - self.logger.debug('Skipping interrogation because interrogation mode' - ' is set to never.') - elif self.interrogation_mode == b'always': - # Only interrogate the EC if the interrogation mode is set to 'always'. - self.enhanced_ec = self.CheckForEnhancedECImage() - self.logger.debug('Enhanced EC image? %r', self.enhanced_ec) - - if not self.enhanced_ec: - # Send everything straight to the EC to handle. - self.cmd_pipe.send(six.int2byte(byte)) - # Reset the input buffer. - self.input_buffer = b'' - self.input_buffer_pos = 0 - self.logger.log(1, 'Reset input buffer.') - return - - # Keep handling the ESC sequence if we're in the middle of it. - if self.esc_state != 0: - self.HandleEsc(byte) - return - - # When we're at the end of the line, we should only allow going backwards, - # backspace, carriage return, up, or down. The arrow keys are escape - # sequences, so we let the escape...escape. - if (self.input_buffer_pos >= self.line_limit and - byte not in [ControlKey.CTRL_B, ControlKey.ESC, ControlKey.BACKSPACE, - ControlKey.CTRL_A, ControlKey.CARRIAGE_RETURN, - ControlKey.CTRL_P, ControlKey.CTRL_N]): - return - - # If the input buffer is full we can't accept new chars. - buffer_full = len(self.input_buffer) >= self.line_limit - - - # Carriage_Return/Enter - if byte == ControlKey.CARRIAGE_RETURN: - self.logger.debug('Enter key pressed.') - # Put a carriage return/newline and the print the prompt. - os.write(fd, b'\r\n') - - # TODO(aaboagye): When we control the printing of all output, print the - # prompt AFTER printing all the output. We can't do it yet because we - # don't know how much is coming from the EC. - - # Print the prompt. - os.write(fd, self.prompt) - # Process the input. - self.ProcessInput() - # Now, clear the buffer. - self.input_buffer = b'' - self.input_buffer_pos = 0 - # Reset history buffer pos. - self.history_pos = len(self.history) - # Clear partial command. - self.partial_cmd = b'' - - # Backspace - elif byte == ControlKey.BACKSPACE: - self.logger.debug('Backspace pressed.') - if self.input_buffer_pos > 0: - # Move left 1 column. - self.MoveCursor('left', 1) - # Remove the character at the input_buffer_pos by slicing it out. - self.SliceOutChar() - - self.logger.debug('input_buffer_pos: %d', self.input_buffer_pos) - - # Ctrl+A. Move cursor to beginning of the line - elif byte == ControlKey.CTRL_A: - self.logger.debug('Control+A pressed.') - self.MoveCursor('left', self.input_buffer_pos) - - # Ctrl+B. Move cursor left 1 column. - elif byte == ControlKey.CTRL_B: - self.logger.debug('Control+B pressed.') - self.MoveCursor('left', 1) - - # Ctrl+D. Delete a character. - elif byte == ControlKey.CTRL_D: - self.logger.debug('Control+D pressed.') - if self.input_buffer_pos != len(self.input_buffer): - # Remove the character by slicing it out. - self.SliceOutChar() - - # Ctrl+E. Move cursor to end of the line. - elif byte == ControlKey.CTRL_E: - self.logger.debug('Control+E pressed.') - self.MoveCursor('right', - len(self.input_buffer) - self.input_buffer_pos) - - # Ctrl+F. Move cursor right 1 column. - elif byte == ControlKey.CTRL_F: - self.logger.debug('Control+F pressed.') - self.MoveCursor('right', 1) - - # Ctrl+K. Kill line. - elif byte == ControlKey.CTRL_K: - self.logger.debug('Control+K pressed.') - self.KillLine() - - # Ctrl+N. Next line. - elif byte == ControlKey.CTRL_N: - self.logger.debug('Control+N pressed.') - self.ShowNextCommand() - - # Ctrl+P. Previous line. - elif byte == ControlKey.CTRL_P: - self.logger.debug('Control+P pressed.') - self.ShowPreviousCommand() - - # ESC sequence - elif byte == ControlKey.ESC: - # Starting an ESC sequence - self.esc_state = EscState.ESC_START - - # Only print printable chars. - elif IsPrintable(byte): - # Drop the character if we're full. - if buffer_full: - self.logger.debug('Dropped char: %c(%d)', byte, byte) - return - # Print the character. - os.write(fd, six.int2byte(byte)) - # Print the rest of the line (if any). - extra_bytes_written = os.write(fd, - self.input_buffer[self.input_buffer_pos:]) - - # Recreate the input buffer. - self.input_buffer = (self.input_buffer[0:self.input_buffer_pos] + - six.int2byte(byte) + - self.input_buffer[self.input_buffer_pos:]) - # Update the input buffer position. - self.input_buffer_pos += 1 + extra_bytes_written - - # Reset the cursor if we wrote any extra bytes. - if extra_bytes_written: - self.MoveCursor('left', extra_bytes_written) - - self.logger.debug('input_buffer_pos: %d', self.input_buffer_pos) - - def MoveCursor(self, direction, count): - """MoveCursor moves the cursor left or right by count columns. +def IsPrintable(byte): + """Determines if a byte is printable. Args: - direction: A string that should be either 'left' or 'right' representing - the direction to move the cursor on the console. - count: An integer representing how many columns the cursor should be - moved. + byte: An integer potentially representing a printable character. - Raises: - AssertionError: If the direction is not equal to 'left' or 'right'. + Returns: + A boolean indicating whether the byte is a printable character. """ - # If there's nothing to move, we're done. - if not count: - return - fd = self.controller_pty - seq = b'\033[' + str(count).encode('ascii') - if direction == 'left': - # Bind the movement. - if count > self.input_buffer_pos: - count = self.input_buffer_pos - seq += b'D' - self.logger.debug('move cursor left %d', count) - self.input_buffer_pos -= count - - elif direction == 'right': - # Bind the movement. - if (count + self.input_buffer_pos) > len(self.input_buffer): - count = 0 - seq += b'C' - self.logger.debug('move cursor right %d', count) - self.input_buffer_pos += count - - else: - raise AssertionError(('The only valid directions are \'left\' and ' - '\'right\'')) - - self.logger.debug('input_buffer_pos: %d', self.input_buffer_pos) - # Move the cursor. - if count != 0: - os.write(fd, seq) - - def KillLine(self): - """Kill the rest of the line based on the input buffer position.""" - # Killing the line is killing all the text to the right. - diff = len(self.input_buffer) - self.input_buffer_pos - self.logger.debug('diff: %d', diff) - # Diff shouldn't be negative, but if it is for some reason, let's try to - # correct the cursor. - if diff < 0: - self.logger.warning('Resetting input buffer position to %d...', - len(self.input_buffer)) - self.MoveCursor('left', -diff) - return - if diff: - self.MoveCursor('right', diff) - for _ in range(diff): - self.SendBackspace() - self.input_buffer_pos -= diff - self.input_buffer = self.input_buffer[0:self.input_buffer_pos] - - def SendBackspace(self): - """Backspace a character on the console.""" - os.write(self.controller_pty, b'\033[1D \033[1D') - - def ProcessOOBMQueue(self): - """Retrieve an item from the OOBM queue and process it.""" - item = self.oobm_queue.get() - self.logger.debug('OOBM cmd: %r', item) - cmd = item.split(b' ') - - if cmd[0] == b'loglevel': - # An integer is required in order to set the log level. - if len(cmd) < 2: - self.logger.debug('Insufficient args') - self.PrintOOBMHelp() - return - try: - self.logger.debug('Log level change request.') - new_log_level = int(cmd[1]) - self.logger.logger.setLevel(new_log_level) - self.logger.info('Log level changed to %d.', new_log_level) - - # Forward the request to the interpreter as well. - self.cmd_pipe.send(item) - except ValueError: - # Ignoring the request if an integer was not provided. - self.PrintOOBMHelp() - - elif cmd[0] == b'timestamp': - mode = cmd[1].lower() - self.timestamp_enabled = (mode == b'on') - self.logger.info('%sabling uart timestamps.', - 'En' if self.timestamp_enabled else 'Dis') - - elif cmd[0] == b'rawdebug': - mode = cmd[1].lower() - self.raw_debug = (mode == b'on') - self.logger.info('%sabling per interrupt debug logs.', - 'En' if self.raw_debug else 'Dis') - - elif cmd[0] == b'interrogate' and len(cmd) >= 2: - enhanced = False - mode = cmd[1] - if len(cmd) >= 3 and cmd[2] == b'enhanced': - enhanced = True - - # Set the mode if correct. - if mode in INTERROGATION_MODES: - self.interrogation_mode = mode - self.logger.debug('Updated interrogation mode to %s.', mode) - - # Update the assumptions of the EC image. - self.enhanced_ec = enhanced - self.logger.debug('Enhanced EC image is now %r', self.enhanced_ec) - - # Send command to interpreter as well. - self.cmd_pipe.send(b'enhanced ' + str(self.enhanced_ec).encode('ascii')) - else: - self.PrintOOBMHelp() - - else: - self.PrintOOBMHelp() + return byte >= ord(" ") and byte <= ord("~") - def PrintOOBMHelp(self): - """Prints out the OOBM help.""" - # Print help syntax. - os.write(self.controller_pty, b'\r\n' + b'Known OOBM commands:\r\n') - os.write(self.controller_pty, b' interrogate <never | always | auto> ' - b'[enhanced]\r\n') - os.write(self.controller_pty, b' loglevel <int>\r\n') - def CheckBufferForEnhancedImage(self, data): - """Adds data to a look buffer and checks to see for enhanced EC image. - - The EC's console task prints a string upon initialization which says that - "Console is enabled; type HELP for help.". The enhanced EC images print a - different string as a part of their init. This function searches through a - "look" buffer, scanning for the presence of either of those strings and - updating the enhanced_ec state accordingly. +def StartLoop(console, command_active, shutdown_pipe=None): + """Starts the infinite loop of console processing. Args: - data: A string containing the data sent from the interpreter. + console: A Console object that has been properly initialzed. + command_active: ctypes data object or multiprocessing.Value indicating if + servod owns the console, or user owns the console. This prevents input + collisions. + shutdown_pipe: A file object for a pipe or equivalent that becomes readable + (not blocked) to indicate that the loop should exit. Can be None to never + exit the loop. """ - self.look_buffer += data - - # Search the buffer for any of the EC image strings. - enhanced_match = re.search(ENHANCED_IMAGE_RE, self.look_buffer) - non_enhanced_match = re.search(NON_ENHANCED_IMAGE_RE, self.look_buffer) - - # Update the state if any matches were found. - if enhanced_match or non_enhanced_match: - if enhanced_match: - self.enhanced_ec = True - elif non_enhanced_match: - self.enhanced_ec = False - - # Inform the interpreter of the result. - self.cmd_pipe.send(b'enhanced ' + str(self.enhanced_ec).encode('ascii')) - self.logger.debug('Enhanced EC image? %r', self.enhanced_ec) - - # Clear look buffer since a match was found. - self.look_buffer = b'' - - # Move the sliding window. - self.look_buffer = self.look_buffer[-LOOK_BUFFER_SIZE:] - - -def CanonicalizeTimeString(timestr): - """Canonicalize the timestamp string. - - Args: - timestr: A timestamp string ended with 6 digits msec. - - Returns: - A string with 3 digits msec and an extra space. - """ - return timestr[:-3].encode('ascii') + b' ' - - -def IsPrintable(byte): - """Determines if a byte is printable. - - Args: - byte: An integer potentially representing a printable character. - - Returns: - A boolean indicating whether the byte is a printable character. - """ - return byte >= ord(' ') and byte <= ord('~') - - -def StartLoop(console, command_active, shutdown_pipe=None): - """Starts the infinite loop of console processing. - - Args: - console: A Console object that has been properly initialzed. - command_active: ctypes data object or multiprocessing.Value indicating if - servod owns the console, or user owns the console. This prevents input - collisions. - shutdown_pipe: A file object for a pipe or equivalent that becomes readable - (not blocked) to indicate that the loop should exit. Can be None to never - exit the loop. - """ - try: - console.logger.debug('Console is being served on %s.', console.user_pty) - console.logger.debug('Console controller is on %s.', console.controller_pty) - console.logger.debug('Command interface is being served on %s.', - console.interface_pty) - console.logger.debug(console) - - # This checks for HUP to indicate if the user has connected to the pty. - ep = select.epoll() - ep.register(console.controller_pty, select.EPOLLHUP) - - # This is used instead of "break" to avoid exiting the loop in the middle of - # an iteration. - continue_looping = True - - # Used for determining when to print host timestamps - tm_req = True - - while continue_looping: - # Check to see if pts is connected to anything - events = ep.poll(0) - controller_connected = not events - - # Check to see if pipes or the console are ready for reading. - read_list = [console.interface_pty, - console.cmd_pipe, console.dbg_pipe] - if controller_connected: - read_list.append(console.controller_pty) - if shutdown_pipe is not None: - read_list.append(shutdown_pipe) - - # Check if any input is ready, or wait for .1 sec and re-poll if - # a user has connected to the pts. - select_output = select.select(read_list, [], [], .1) - if not select_output: - continue - ready_for_reading = select_output[0] - - for obj in ready_for_reading: - if obj is console.controller_pty: - if not command_active.value: - # Convert to bytes so we can look for non-printable chars such as - # Ctrl+A, Ctrl+E, etc. - try: - line = bytearray(os.read(console.controller_pty, CONSOLE_MAX_READ)) - console.logger.debug('Input from user: %s, locked:%s', - str(line).strip(), command_active.value) - for i in line: - try: - # Handle each character as it arrives. - console.HandleChar(i) - except EOFError: - console.logger.debug( - 'ec3po console received EOF from dbg_pipe in HandleChar()' - ' while reading console.controller_pty') - continue_looping = False - break - except OSError: - console.logger.debug('Ptm read failed, probably user disconnect.') - - elif obj is console.interface_pty: - if command_active.value: - # Convert to bytes so we can look for non-printable chars such as - # Ctrl+A, Ctrl+E, etc. - line = bytearray(os.read(console.interface_pty, CONSOLE_MAX_READ)) - console.logger.debug('Input from interface: %s, locked:%s', - str(line).strip(), command_active.value) - for i in line: - try: - # Handle each character as it arrives. - console.HandleChar(i) - except EOFError: - console.logger.debug( - 'ec3po console received EOF from dbg_pipe in HandleChar()' - ' while reading console.interface_pty') - continue_looping = False - break - - elif obj is console.cmd_pipe: - try: - data = console.cmd_pipe.recv() - except EOFError: - console.logger.debug('ec3po console received EOF from cmd_pipe') - continue_looping = False - else: - # Write it to the user console. - if console.raw_debug: - console.logger.debug('|CMD|-%s->%r', - ('u' if controller_connected else '') + - ('i' if command_active.value else ''), - data.strip()) + try: + console.logger.debug("Console is being served on %s.", console.user_pty) + console.logger.debug("Console controller is on %s.", console.controller_pty) + console.logger.debug( + "Command interface is being served on %s.", console.interface_pty + ) + console.logger.debug(console) + + # This checks for HUP to indicate if the user has connected to the pty. + ep = select.epoll() + ep.register(console.controller_pty, select.EPOLLHUP) + + # This is used instead of "break" to avoid exiting the loop in the middle of + # an iteration. + continue_looping = True + + # Used for determining when to print host timestamps + tm_req = True + + while continue_looping: + # Check to see if pts is connected to anything + events = ep.poll(0) + controller_connected = not events + + # Check to see if pipes or the console are ready for reading. + read_list = [console.interface_pty, console.cmd_pipe, console.dbg_pipe] if controller_connected: - os.write(console.controller_pty, data) - if command_active.value: - os.write(console.interface_pty, data) - - elif obj is console.dbg_pipe: - try: - data = console.dbg_pipe.recv() - except EOFError: - console.logger.debug('ec3po console received EOF from dbg_pipe') - continue_looping = False - else: - if console.interrogation_mode == b'auto': - # Search look buffer for enhanced EC image string. - console.CheckBufferForEnhancedImage(data) - # Write it to the user console. - if len(data) > 1 and console.raw_debug: - console.logger.debug('|DBG|-%s->%r', - ('u' if controller_connected else '') + - ('i' if command_active.value else ''), - data.strip()) - console.LogConsoleOutput(data) - if controller_connected: - end = len(data) - 1 - if console.timestamp_enabled: - # A timestamp is required at the beginning of this line - if tm_req is True: - now = datetime.now() - tm = CanonicalizeTimeString(now.strftime(HOST_STRFTIME)) - os.write(console.controller_pty, tm) - tm_req = False - - # Insert timestamps into the middle where appropriate - # except if the last character is a newline - nls_found = data.count(b'\n', 0, end) - now = datetime.now() - tm = CanonicalizeTimeString(now.strftime('\n' + HOST_STRFTIME)) - data_tm = data.replace(b'\n', tm, nls_found) - else: - data_tm = data - - # timestamp required on next input - if data[end] == b'\n'[0]: - tm_req = True - os.write(console.controller_pty, data_tm) - if command_active.value: - os.write(console.interface_pty, data) - - elif obj is shutdown_pipe: - console.logger.debug( - 'ec3po console received shutdown pipe unblocked notification') - continue_looping = False - - while not console.oobm_queue.empty(): - console.logger.debug('OOBM queue ready for reading.') - console.ProcessOOBMQueue() - - except KeyboardInterrupt: - pass - - finally: - ep.unregister(console.controller_pty) - console.dbg_pipe.close() - console.cmd_pipe.close() - os.close(console.controller_pty) - os.close(console.interface_pty) - if shutdown_pipe is not None: - shutdown_pipe.close() - console.logger.debug('Exit ec3po console loop for %s', console.user_pty) + read_list.append(console.controller_pty) + if shutdown_pipe is not None: + read_list.append(shutdown_pipe) + + # Check if any input is ready, or wait for .1 sec and re-poll if + # a user has connected to the pts. + select_output = select.select(read_list, [], [], 0.1) + if not select_output: + continue + ready_for_reading = select_output[0] + + for obj in ready_for_reading: + if obj is console.controller_pty: + if not command_active.value: + # Convert to bytes so we can look for non-printable chars such as + # Ctrl+A, Ctrl+E, etc. + try: + line = bytearray( + os.read(console.controller_pty, CONSOLE_MAX_READ) + ) + console.logger.debug( + "Input from user: %s, locked:%s", + str(line).strip(), + command_active.value, + ) + for i in line: + try: + # Handle each character as it arrives. + console.HandleChar(i) + except EOFError: + console.logger.debug( + "ec3po console received EOF from dbg_pipe in HandleChar()" + " while reading console.controller_pty" + ) + continue_looping = False + break + except OSError: + console.logger.debug( + "Ptm read failed, probably user disconnect." + ) + + elif obj is console.interface_pty: + if command_active.value: + # Convert to bytes so we can look for non-printable chars such as + # Ctrl+A, Ctrl+E, etc. + line = bytearray( + os.read(console.interface_pty, CONSOLE_MAX_READ) + ) + console.logger.debug( + "Input from interface: %s, locked:%s", + str(line).strip(), + command_active.value, + ) + for i in line: + try: + # Handle each character as it arrives. + console.HandleChar(i) + except EOFError: + console.logger.debug( + "ec3po console received EOF from dbg_pipe in HandleChar()" + " while reading console.interface_pty" + ) + continue_looping = False + break + + elif obj is console.cmd_pipe: + try: + data = console.cmd_pipe.recv() + except EOFError: + console.logger.debug("ec3po console received EOF from cmd_pipe") + continue_looping = False + else: + # Write it to the user console. + if console.raw_debug: + console.logger.debug( + "|CMD|-%s->%r", + ("u" if controller_connected else "") + + ("i" if command_active.value else ""), + data.strip(), + ) + if controller_connected: + os.write(console.controller_pty, data) + if command_active.value: + os.write(console.interface_pty, data) + + elif obj is console.dbg_pipe: + try: + data = console.dbg_pipe.recv() + except EOFError: + console.logger.debug("ec3po console received EOF from dbg_pipe") + continue_looping = False + else: + if console.interrogation_mode == b"auto": + # Search look buffer for enhanced EC image string. + console.CheckBufferForEnhancedImage(data) + # Write it to the user console. + if len(data) > 1 and console.raw_debug: + console.logger.debug( + "|DBG|-%s->%r", + ("u" if controller_connected else "") + + ("i" if command_active.value else ""), + data.strip(), + ) + console.LogConsoleOutput(data) + if controller_connected: + end = len(data) - 1 + if console.timestamp_enabled: + # A timestamp is required at the beginning of this line + if tm_req is True: + now = datetime.now() + tm = CanonicalizeTimeString( + now.strftime(HOST_STRFTIME) + ) + os.write(console.controller_pty, tm) + tm_req = False + + # Insert timestamps into the middle where appropriate + # except if the last character is a newline + nls_found = data.count(b"\n", 0, end) + now = datetime.now() + tm = CanonicalizeTimeString( + now.strftime("\n" + HOST_STRFTIME) + ) + data_tm = data.replace(b"\n", tm, nls_found) + else: + data_tm = data + + # timestamp required on next input + if data[end] == b"\n"[0]: + tm_req = True + os.write(console.controller_pty, data_tm) + if command_active.value: + os.write(console.interface_pty, data) + + elif obj is shutdown_pipe: + console.logger.debug( + "ec3po console received shutdown pipe unblocked notification" + ) + continue_looping = False + + while not console.oobm_queue.empty(): + console.logger.debug("OOBM queue ready for reading.") + console.ProcessOOBMQueue() + + except KeyboardInterrupt: + pass + + finally: + ep.unregister(console.controller_pty) + console.dbg_pipe.close() + console.cmd_pipe.close() + os.close(console.controller_pty) + os.close(console.interface_pty) + if shutdown_pipe is not None: + shutdown_pipe.close() + console.logger.debug("Exit ec3po console loop for %s", console.user_pty) def main(argv): - """Kicks off the EC-3PO interactive console interface and interpreter. - - We create some pipes to communicate with an interpreter, instantiate an - interpreter, create a PTY pair, and begin serving the console interface. - - Args: - argv: A list of strings containing the arguments this module was called - with. - """ - # Set up argument parser. - parser = argparse.ArgumentParser(description=('Start interactive EC console ' - 'and interpreter.')) - parser.add_argument('ec_uart_pty', - help=('The full PTY name that the EC UART' - ' is present on. eg: /dev/pts/12')) - parser.add_argument('--log-level', - default='info', - help='info, debug, warning, error, or critical') - - # Parse arguments. - opts = parser.parse_args(argv) - - # Set logging level. - opts.log_level = opts.log_level.lower() - if opts.log_level == 'info': - log_level = logging.INFO - elif opts.log_level == 'debug': - log_level = logging.DEBUG - elif opts.log_level == 'warning': - log_level = logging.WARNING - elif opts.log_level == 'error': - log_level = logging.ERROR - elif opts.log_level == 'critical': - log_level = logging.CRITICAL - else: - parser.error('Invalid log level. (info, debug, warning, error, critical)') - - # Start logging with a timestamp, module, and log level shown in each log - # entry. - logging.basicConfig(level=log_level, format=('%(asctime)s - %(module)s -' - ' %(levelname)s - %(message)s')) - - # Create some pipes to communicate between the interpreter and the console. - # The command pipe is bidirectional. - cmd_pipe_interactive, cmd_pipe_interp = threadproc_shim.Pipe() - # The debug pipe is unidirectional from interpreter to console only. - dbg_pipe_interactive, dbg_pipe_interp = threadproc_shim.Pipe(duplex=False) - - # Create an interpreter instance. - itpr = interpreter.Interpreter(opts.ec_uart_pty, cmd_pipe_interp, - dbg_pipe_interp, log_level) - - # Spawn an interpreter process. - itpr_process = threadproc_shim.ThreadOrProcess( - target=interpreter.StartLoop, args=(itpr,)) - # Make sure to kill the interpreter when we terminate. - itpr_process.daemon = True - # Start the interpreter. - itpr_process.start() - - # Open a new pseudo-terminal pair - (controller_pty, user_pty) = pty.openpty() - # Set the permissions to 660. - os.chmod(os.ttyname(user_pty), (stat.S_IRGRP | stat.S_IWGRP | - stat.S_IRUSR | stat.S_IWUSR)) - # Create a console. - console = Console(controller_pty, os.ttyname(user_pty), cmd_pipe_interactive, - dbg_pipe_interactive) - # Start serving the console. - v = threadproc_shim.Value(ctypes.c_bool, False) - StartLoop(console, v) - - -if __name__ == '__main__': - main(sys.argv[1:]) + """Kicks off the EC-3PO interactive console interface and interpreter. + + We create some pipes to communicate with an interpreter, instantiate an + interpreter, create a PTY pair, and begin serving the console interface. + + Args: + argv: A list of strings containing the arguments this module was called + with. + """ + # Set up argument parser. + parser = argparse.ArgumentParser( + description=("Start interactive EC console " "and interpreter.") + ) + parser.add_argument( + "ec_uart_pty", + help=("The full PTY name that the EC UART" " is present on. eg: /dev/pts/12"), + ) + parser.add_argument( + "--log-level", default="info", help="info, debug, warning, error, or critical" + ) + + # Parse arguments. + opts = parser.parse_args(argv) + + # Set logging level. + opts.log_level = opts.log_level.lower() + if opts.log_level == "info": + log_level = logging.INFO + elif opts.log_level == "debug": + log_level = logging.DEBUG + elif opts.log_level == "warning": + log_level = logging.WARNING + elif opts.log_level == "error": + log_level = logging.ERROR + elif opts.log_level == "critical": + log_level = logging.CRITICAL + else: + parser.error("Invalid log level. (info, debug, warning, error, critical)") + + # Start logging with a timestamp, module, and log level shown in each log + # entry. + logging.basicConfig( + level=log_level, + format=("%(asctime)s - %(module)s -" " %(levelname)s - %(message)s"), + ) + + # Create some pipes to communicate between the interpreter and the console. + # The command pipe is bidirectional. + cmd_pipe_interactive, cmd_pipe_interp = threadproc_shim.Pipe() + # The debug pipe is unidirectional from interpreter to console only. + dbg_pipe_interactive, dbg_pipe_interp = threadproc_shim.Pipe(duplex=False) + + # Create an interpreter instance. + itpr = interpreter.Interpreter( + opts.ec_uart_pty, cmd_pipe_interp, dbg_pipe_interp, log_level + ) + + # Spawn an interpreter process. + itpr_process = threadproc_shim.ThreadOrProcess( + target=interpreter.StartLoop, args=(itpr,) + ) + # Make sure to kill the interpreter when we terminate. + itpr_process.daemon = True + # Start the interpreter. + itpr_process.start() + + # Open a new pseudo-terminal pair + (controller_pty, user_pty) = pty.openpty() + # Set the permissions to 660. + os.chmod( + os.ttyname(user_pty), + (stat.S_IRGRP | stat.S_IWGRP | stat.S_IRUSR | stat.S_IWUSR), + ) + # Create a console. + console = Console( + controller_pty, os.ttyname(user_pty), cmd_pipe_interactive, dbg_pipe_interactive + ) + # Start serving the console. + v = threadproc_shim.Value(ctypes.c_bool, False) + StartLoop(console, v) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/util/ec3po/console_unittest.py b/util/ec3po/console_unittest.py index 7e341e7e8d..41ae324ef4 100755 --- a/util/ec3po/console_unittest.py +++ b/util/ec3po/console_unittest.py @@ -11,1262 +11,1310 @@ from __future__ import print_function import binascii import logging -import mock import tempfile import unittest +import mock import six - -from ec3po import console -from ec3po import interpreter -from ec3po import threadproc_shim +from ec3po import console, interpreter, threadproc_shim ESC_STRING = six.int2byte(console.ControlKey.ESC) + class Keys(object): - """A class that contains the escape sequences for special keys.""" - LEFT_ARROW = [console.ControlKey.ESC, ord('['), ord('D')] - RIGHT_ARROW = [console.ControlKey.ESC, ord('['), ord('C')] - UP_ARROW = [console.ControlKey.ESC, ord('['), ord('A')] - DOWN_ARROW = [console.ControlKey.ESC, ord('['), ord('B')] - HOME = [console.ControlKey.ESC, ord('['), ord('1'), ord('~')] - END = [console.ControlKey.ESC, ord('['), ord('8'), ord('~')] - DEL = [console.ControlKey.ESC, ord('['), ord('3'), ord('~')] + """A class that contains the escape sequences for special keys.""" + + LEFT_ARROW = [console.ControlKey.ESC, ord("["), ord("D")] + RIGHT_ARROW = [console.ControlKey.ESC, ord("["), ord("C")] + UP_ARROW = [console.ControlKey.ESC, ord("["), ord("A")] + DOWN_ARROW = [console.ControlKey.ESC, ord("["), ord("B")] + HOME = [console.ControlKey.ESC, ord("["), ord("1"), ord("~")] + END = [console.ControlKey.ESC, ord("["), ord("8"), ord("~")] + DEL = [console.ControlKey.ESC, ord("["), ord("3"), ord("~")] + class OutputStream(object): - """A class that has methods which return common console output.""" + """A class that has methods which return common console output.""" - @staticmethod - def MoveCursorLeft(count): - """Produces what would be printed to the console if the cursor moved left. + @staticmethod + def MoveCursorLeft(count): + """Produces what would be printed to the console if the cursor moved left. - Args: - count: An integer representing how many columns to move left. + Args: + count: An integer representing how many columns to move left. - Returns: - string: A string which contains what would be printed to the console if - the cursor moved left. - """ - string = ESC_STRING - string += b'[' + str(count).encode('ascii') + b'D' - return string + Returns: + string: A string which contains what would be printed to the console if + the cursor moved left. + """ + string = ESC_STRING + string += b"[" + str(count).encode("ascii") + b"D" + return string - @staticmethod - def MoveCursorRight(count): - """Produces what would be printed to the console if the cursor moved right. + @staticmethod + def MoveCursorRight(count): + """Produces what would be printed to the console if the cursor moved right. - Args: - count: An integer representing how many columns to move right. + Args: + count: An integer representing how many columns to move right. - Returns: - string: A string which contains what would be printed to the console if - the cursor moved right. - """ - string = ESC_STRING - string += b'[' + str(count).encode('ascii') + b'C' - return string + Returns: + string: A string which contains what would be printed to the console if + the cursor moved right. + """ + string = ESC_STRING + string += b"[" + str(count).encode("ascii") + b"C" + return string -BACKSPACE_STRING = b'' + +BACKSPACE_STRING = b"" # Move cursor left 1 column. BACKSPACE_STRING += OutputStream.MoveCursorLeft(1) # Write a space. -BACKSPACE_STRING += b' ' +BACKSPACE_STRING += b" " # Move cursor left 1 column. BACKSPACE_STRING += OutputStream.MoveCursorLeft(1) + def BytesToByteList(string): - """Converts a bytes string to list of bytes. + """Converts a bytes string to list of bytes. + + Args: + string: A literal bytes to turn into a list of bytes. - Args: - string: A literal bytes to turn into a list of bytes. + Returns: + A list of integers representing the byte value of each character in the + string. + """ + if six.PY3: + return [c for c in string] + return [ord(c) for c in string] - Returns: - A list of integers representing the byte value of each character in the - string. - """ - if six.PY3: - return [c for c in string] - return [ord(c) for c in string] def CheckConsoleOutput(test_case, exp_console_out): - """Verify what was sent out the console matches what we expect. + """Verify what was sent out the console matches what we expect. - Args: - test_case: A unittest.TestCase object representing the current unit test. - exp_console_out: A string representing the console output stream. - """ - # Read what was sent out the console. - test_case.tempfile.seek(0) - console_out = test_case.tempfile.read() + Args: + test_case: A unittest.TestCase object representing the current unit test. + exp_console_out: A string representing the console output stream. + """ + # Read what was sent out the console. + test_case.tempfile.seek(0) + console_out = test_case.tempfile.read() - test_case.assertEqual(exp_console_out, console_out) + test_case.assertEqual(exp_console_out, console_out) -def CheckInputBuffer(test_case, exp_input_buffer): - """Verify that the input buffer contains what we expect. - - Args: - test_case: A unittest.TestCase object representing the current unit test. - exp_input_buffer: A string containing the contents of the current input - buffer. - """ - test_case.assertEqual(exp_input_buffer, test_case.console.input_buffer, - (b'input buffer does not match expected.\n' - b'expected: |' + exp_input_buffer + b'|\n' - b'got: |' + test_case.console.input_buffer + - b'|\n' + str(test_case.console).encode('ascii'))) -def CheckInputBufferPosition(test_case, exp_pos): - """Verify the input buffer position. +def CheckInputBuffer(test_case, exp_input_buffer): + """Verify that the input buffer contains what we expect. - Args: - test_case: A unittest.TestCase object representing the current unit test. - exp_pos: An integer representing the expected input buffer position. - """ - test_case.assertEqual(exp_pos, test_case.console.input_buffer_pos, - 'input buffer position is incorrect.\ngot: ' + - str(test_case.console.input_buffer_pos) + '\nexp: ' + - str(exp_pos) + '\n' + str(test_case.console)) + Args: + test_case: A unittest.TestCase object representing the current unit test. + exp_input_buffer: A string containing the contents of the current input + buffer. + """ + test_case.assertEqual( + exp_input_buffer, + test_case.console.input_buffer, + ( + b"input buffer does not match expected.\n" + b"expected: |" + exp_input_buffer + b"|\n" + b"got: |" + + test_case.console.input_buffer + + b"|\n" + + str(test_case.console).encode("ascii") + ), + ) -def CheckHistoryBuffer(test_case, exp_history): - """Verify that the items in the history buffer are what we expect. - - Args: - test_case: A unittest.TestCase object representing the current unit test. - exp_history: A list of strings representing the expected contents of the - history buffer. - """ - # First, check to see if the length is what we expect. - test_case.assertEqual(len(exp_history), len(test_case.console.history), - ('The number of items in the history is unexpected.\n' - 'exp: ' + str(len(exp_history)) + '\n' - 'got: ' + str(len(test_case.console.history)) + '\n' - 'internal state:\n' + str(test_case.console))) - - # Next, check the actual contents of the history buffer. - for i in range(len(exp_history)): - test_case.assertEqual(exp_history[i], test_case.console.history[i], - (b'history buffer contents are incorrect.\n' - b'exp: ' + exp_history[i] + b'\n' - b'got: ' + test_case.console.history[i] + b'\n' - b'internal state:\n' + - str(test_case.console).encode('ascii'))) +def CheckInputBufferPosition(test_case, exp_pos): + """Verify the input buffer position. -class TestConsoleEditingMethods(unittest.TestCase): - """Test case to verify all console editing methods.""" - - def setUp(self): - """Setup the test harness.""" - # Setup logging with a timestamp, the module, and the log level. - logging.basicConfig(level=logging.DEBUG, - format=('%(asctime)s - %(module)s -' - ' %(levelname)s - %(message)s')) - - # Create a temp file and set both the controller and peripheral PTYs to the - # file to create a loopback. - self.tempfile = tempfile.TemporaryFile() - - # Create some mock pipes. These won't be used since we'll mock out sends - # to the interpreter. - mock_pipe_end_0, mock_pipe_end_1 = threadproc_shim.Pipe() - self.console = console.Console(self.tempfile.fileno(), self.tempfile, - tempfile.TemporaryFile(), - mock_pipe_end_0, mock_pipe_end_1, "EC") - - # Console editing methods are only valid for enhanced EC images, therefore - # we have to assume that the "EC" we're talking to is enhanced. By default, - # the console believes that the EC it's communicating with is NOT enhanced - # which is why we have to override it here. - self.console.enhanced_ec = True - self.console.CheckForEnhancedECImage = mock.MagicMock(return_value=True) - - def test_EnteringChars(self): - """Verify that characters are echoed onto the console.""" - test_str = b'abc' - input_stream = BytesToByteList(test_str) - - # Send the characters in. - for byte in input_stream: - self.console.HandleChar(byte) - - # Check the input position. - exp_pos = len(test_str) - CheckInputBufferPosition(self, exp_pos) - - # Verify that the input buffer is correct. - expected_buffer = test_str - CheckInputBuffer(self, expected_buffer) - - # Check console output - exp_console_out = test_str - CheckConsoleOutput(self, exp_console_out) - - def test_EnteringDeletingMoreCharsThanEntered(self): - """Verify that we can press backspace more than we have entered chars.""" - test_str = b'spamspam' - input_stream = BytesToByteList(test_str) - - # Send the characters in. - for byte in input_stream: - self.console.HandleChar(byte) - - # Now backspace 1 more than what we sent. - input_stream = [] - for _ in range(len(test_str) + 1): - input_stream.append(console.ControlKey.BACKSPACE) - - # Send that sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # First, verify that input buffer position is 0. - CheckInputBufferPosition(self, 0) - - # Next, examine the output stream for the correct sequence. - exp_console_out = test_str - for _ in range(len(test_str)): - exp_console_out += BACKSPACE_STRING - - # Now, verify that we got what we expected. - CheckConsoleOutput(self, exp_console_out) - - def test_EnteringMoreThanCharLimit(self): - """Verify that we drop characters when the line is too long.""" - test_str = self.console.line_limit * b'o' # All allowed. - test_str += 5 * b'x' # All should be dropped. - input_stream = BytesToByteList(test_str) - - # Send the characters in. - for byte in input_stream: - self.console.HandleChar(byte) - - # First, we expect that input buffer position should be equal to the line - # limit. - exp_pos = self.console.line_limit - CheckInputBufferPosition(self, exp_pos) - - # The input buffer should only hold until the line limit. - exp_buffer = test_str[0:self.console.line_limit] - CheckInputBuffer(self, exp_buffer) - - # Lastly, check that the extra characters are not printed. - exp_console_out = exp_buffer - CheckConsoleOutput(self, exp_console_out) - - def test_ValidKeysOnLongLine(self): - """Verify that we can still press valid keys if the line is too long.""" - # Fill the line. - test_str = self.console.line_limit * b'o' - exp_console_out = test_str - # Try to fill it even more; these should all be dropped. - test_str += 5 * b'x' - input_stream = BytesToByteList(test_str) - - # We should be able to press the following keys: - # - Backspace - # - Arrow Keys/CTRL+B/CTRL+F/CTRL+P/CTRL+N - # - Delete - # - Home/CTRL+A - # - End/CTRL+E - # - Carriage Return - - # Backspace 1 character - input_stream.append(console.ControlKey.BACKSPACE) - exp_console_out += BACKSPACE_STRING - # Refill the line. - input_stream.extend(BytesToByteList(b'o')) - exp_console_out += b'o' - - # Left arrow key. - input_stream.extend(Keys.LEFT_ARROW) - exp_console_out += OutputStream.MoveCursorLeft(1) - - # Right arrow key. - input_stream.extend(Keys.RIGHT_ARROW) - exp_console_out += OutputStream.MoveCursorRight(1) - - # CTRL+B - input_stream.append(console.ControlKey.CTRL_B) - exp_console_out += OutputStream.MoveCursorLeft(1) - - # CTRL+F - input_stream.append(console.ControlKey.CTRL_F) - exp_console_out += OutputStream.MoveCursorRight(1) - - # Let's press enter now so we can test up and down. - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - exp_console_out += b'\r\n' + self.console.prompt - - # Up arrow key. - input_stream.extend(Keys.UP_ARROW) - exp_console_out += test_str[:self.console.line_limit] - - # Down arrow key. - input_stream.extend(Keys.DOWN_ARROW) - # Since the line was blank, we have to backspace the entire line. - exp_console_out += self.console.line_limit * BACKSPACE_STRING - - # CTRL+P - input_stream.append(console.ControlKey.CTRL_P) - exp_console_out += test_str[:self.console.line_limit] - - # CTRL+N - input_stream.append(console.ControlKey.CTRL_N) - # Since the line was blank, we have to backspace the entire line. - exp_console_out += self.console.line_limit * BACKSPACE_STRING - - # Press the Up arrow key to reprint the long line. - input_stream.extend(Keys.UP_ARROW) - exp_console_out += test_str[:self.console.line_limit] - - # Press the Home key to jump to the beginning of the line. - input_stream.extend(Keys.HOME) - exp_console_out += OutputStream.MoveCursorLeft(self.console.line_limit) - - # Press the End key to jump to the end of the line. - input_stream.extend(Keys.END) - exp_console_out += OutputStream.MoveCursorRight(self.console.line_limit) - - # Press CTRL+A to jump to the beginning of the line. - input_stream.append(console.ControlKey.CTRL_A) - exp_console_out += OutputStream.MoveCursorLeft(self.console.line_limit) - - # Press CTRL+E to jump to the end of the line. - input_stream.extend(Keys.END) - exp_console_out += OutputStream.MoveCursorRight(self.console.line_limit) - - # Move left one column so we can delete a character. - input_stream.extend(Keys.LEFT_ARROW) - exp_console_out += OutputStream.MoveCursorLeft(1) - - # Press the delete key. - input_stream.extend(Keys.DEL) - # This should look like a space, and then move cursor left 1 column since - # we're at the end of line. - exp_console_out += b' ' + OutputStream.MoveCursorLeft(1) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # Verify everything happened correctly. - CheckConsoleOutput(self, exp_console_out) - - def test_BackspaceOnEmptyLine(self): - """Verify that we can backspace on an empty line with no bad effects.""" - # Send a single backspace. - test_str = [console.ControlKey.BACKSPACE] - - # Send the characters in. - for byte in test_str: - self.console.HandleChar(byte) - - # Check the input position. - exp_pos = 0 - CheckInputBufferPosition(self, exp_pos) - - # Check that buffer is empty. - exp_input_buffer = b'' - CheckInputBuffer(self, exp_input_buffer) - - # Check that the console output is empty. - exp_console_out = b'' - CheckConsoleOutput(self, exp_console_out) - - def test_BackspaceWithinLine(self): - """Verify that we shift the chars over when backspacing within a line.""" - # Misspell 'help' - test_str = b'heelp' - input_stream = BytesToByteList(test_str) - # Use the arrow key to go back to fix it. - # Move cursor left 1 column. - input_stream.extend(2*Keys.LEFT_ARROW) - # Backspace once to remove the extra 'e'. - input_stream.append(console.ControlKey.BACKSPACE) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # Verify the input buffer - exp_input_buffer = b'help' - CheckInputBuffer(self, exp_input_buffer) - - # Verify the input buffer position. It should be at 2 (cursor over the 'l') - CheckInputBufferPosition(self, 2) - - # We expect the console output to be the test string, with two moves to the - # left, another move left, and then the rest of the line followed by a - # space. - exp_console_out = test_str - exp_console_out += 2 * OutputStream.MoveCursorLeft(1) - - # Move cursor left 1 column. - exp_console_out += OutputStream.MoveCursorLeft(1) - # Rest of the line and a space. (test_str in this case) - exp_console_out += b'lp ' - # Reset the cursor 2 + 1 to the left. - exp_console_out += OutputStream.MoveCursorLeft(3) - - # Verify console output. - CheckConsoleOutput(self, exp_console_out) - - def test_JumpToBeginningOfLineViaCtrlA(self): - """Verify that we can jump to the beginning of a line with Ctrl+A.""" - # Enter some chars and press CTRL+A - test_str = b'abc' - input_stream = BytesToByteList(test_str) + [console.ControlKey.CTRL_A] - - # Send the characters in. - for byte in input_stream: - self.console.HandleChar(byte) - - # We expect to see our test string followed by a move cursor left. - exp_console_out = test_str - exp_console_out += OutputStream.MoveCursorLeft(len(test_str)) - - # Check to see what whas printed on the console. - CheckConsoleOutput(self, exp_console_out) - - # Check that the input buffer position is now 0. - CheckInputBufferPosition(self, 0) - - # Check input buffer still contains our test string. - CheckInputBuffer(self, test_str) - - def test_JumpToBeginningOfLineViaHomeKey(self): - """Jump to beginning of line via HOME key.""" - test_str = b'version' - input_stream = BytesToByteList(test_str) - input_stream.extend(Keys.HOME) - - # Send out the stream. - for byte in input_stream: - self.console.HandleChar(byte) - - # First, verify that input buffer position is now 0. - CheckInputBufferPosition(self, 0) - - # Next, verify that the input buffer did not change. - CheckInputBuffer(self, test_str) - - # Lastly, check that the cursor moved correctly. - exp_console_out = test_str - exp_console_out += OutputStream.MoveCursorLeft(len(test_str)) - CheckConsoleOutput(self, exp_console_out) - - def test_JumpToEndOfLineViaEndKey(self): - """Jump to the end of the line using the END key.""" - test_str = b'version' - input_stream = BytesToByteList(test_str) - input_stream += [console.ControlKey.CTRL_A] - # Now, jump to the end of the line. - input_stream.extend(Keys.END) - - # Send out the stream. - for byte in input_stream: - self.console.HandleChar(byte) - - # Verify that the input buffer position is correct. This should be at the - # end of the test string. - CheckInputBufferPosition(self, len(test_str)) - - # The expected output should be the test string, followed by a jump to the - # beginning of the line, and lastly a jump to the end of the line. - exp_console_out = test_str - exp_console_out += OutputStream.MoveCursorLeft(len(test_str)) - # Now the jump back to the end of the line. - exp_console_out += OutputStream.MoveCursorRight(len(test_str)) - - # Verify console output stream. - CheckConsoleOutput(self, exp_console_out) - - def test_JumpToEndOfLineViaCtrlE(self): - """Enter some chars and then try to jump to the end. (Should be a no-op)""" - test_str = b'sysinfo' - input_stream = BytesToByteList(test_str) - input_stream.append(console.ControlKey.CTRL_E) - - # Send out the stream - for byte in input_stream: - self.console.HandleChar(byte) - - # Verify that the input buffer position isn't any further than we expect. - # At this point, the position should be at the end of the test string. - CheckInputBufferPosition(self, len(test_str)) - - # Now, let's try to jump to the beginning and then jump back to the end. - input_stream = [console.ControlKey.CTRL_A, console.ControlKey.CTRL_E] - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # Perform the same verification. - CheckInputBufferPosition(self, len(test_str)) - - # Lastly try to jump again, beyond the end. - input_stream = [console.ControlKey.CTRL_E] - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # Perform the same verification. - CheckInputBufferPosition(self, len(test_str)) - - # We expect to see the test string, a jump to the beginning of the line, and - # one jump to the end of the line. - exp_console_out = test_str - # Jump to beginning. - exp_console_out += OutputStream.MoveCursorLeft(len(test_str)) - # Jump back to end. - exp_console_out += OutputStream.MoveCursorRight(len(test_str)) - - # Verify the console output. - CheckConsoleOutput(self, exp_console_out) - - def test_MoveLeftWithArrowKey(self): - """Move cursor left one column with arrow key.""" - test_str = b'tastyspam' - input_stream = BytesToByteList(test_str) - input_stream.extend(Keys.LEFT_ARROW) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # Verify that the input buffer position is 1 less than the length. - CheckInputBufferPosition(self, len(test_str) - 1) - - # Also, verify that the input buffer is not modified. - CheckInputBuffer(self, test_str) - - # We expect the test string, followed by a one column move left. - exp_console_out = test_str + OutputStream.MoveCursorLeft(1) - - # Verify console output. - CheckConsoleOutput(self, exp_console_out) - - def test_MoveLeftWithCtrlB(self): - """Move cursor back one column with Ctrl+B.""" - test_str = b'tastyspam' - input_stream = BytesToByteList(test_str) - input_stream.append(console.ControlKey.CTRL_B) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # Verify that the input buffer position is 1 less than the length. - CheckInputBufferPosition(self, len(test_str) - 1) + Args: + test_case: A unittest.TestCase object representing the current unit test. + exp_pos: An integer representing the expected input buffer position. + """ + test_case.assertEqual( + exp_pos, + test_case.console.input_buffer_pos, + "input buffer position is incorrect.\ngot: " + + str(test_case.console.input_buffer_pos) + + "\nexp: " + + str(exp_pos) + + "\n" + + str(test_case.console), + ) - # Also, verify that the input buffer is not modified. - CheckInputBuffer(self, test_str) - # We expect the test string, followed by a one column move left. - exp_console_out = test_str + OutputStream.MoveCursorLeft(1) +def CheckHistoryBuffer(test_case, exp_history): + """Verify that the items in the history buffer are what we expect. - # Verify console output. - CheckConsoleOutput(self, exp_console_out) + Args: + test_case: A unittest.TestCase object representing the current unit test. + exp_history: A list of strings representing the expected contents of the + history buffer. + """ + # First, check to see if the length is what we expect. + test_case.assertEqual( + len(exp_history), + len(test_case.console.history), + ( + "The number of items in the history is unexpected.\n" + "exp: " + str(len(exp_history)) + "\n" + "got: " + str(len(test_case.console.history)) + "\n" + "internal state:\n" + str(test_case.console) + ), + ) + + # Next, check the actual contents of the history buffer. + for i in range(len(exp_history)): + test_case.assertEqual( + exp_history[i], + test_case.console.history[i], + ( + b"history buffer contents are incorrect.\n" + b"exp: " + exp_history[i] + b"\n" + b"got: " + test_case.console.history[i] + b"\n" + b"internal state:\n" + str(test_case.console).encode("ascii") + ), + ) - def test_MoveRightWithArrowKey(self): - """Move cursor one column to the right with the arrow key.""" - test_str = b'version' - input_stream = BytesToByteList(test_str) - # Jump to beginning of line. - input_stream.append(console.ControlKey.CTRL_A) - # Press right arrow key. - input_stream.extend(Keys.RIGHT_ARROW) - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) +class TestConsoleEditingMethods(unittest.TestCase): + """Test case to verify all console editing methods.""" + + def setUp(self): + """Setup the test harness.""" + # Setup logging with a timestamp, the module, and the log level. + logging.basicConfig( + level=logging.DEBUG, + format=("%(asctime)s - %(module)s -" " %(levelname)s - %(message)s"), + ) + + # Create a temp file and set both the controller and peripheral PTYs to the + # file to create a loopback. + self.tempfile = tempfile.TemporaryFile() + + # Create some mock pipes. These won't be used since we'll mock out sends + # to the interpreter. + mock_pipe_end_0, mock_pipe_end_1 = threadproc_shim.Pipe() + self.console = console.Console( + self.tempfile.fileno(), + self.tempfile, + tempfile.TemporaryFile(), + mock_pipe_end_0, + mock_pipe_end_1, + "EC", + ) + + # Console editing methods are only valid for enhanced EC images, therefore + # we have to assume that the "EC" we're talking to is enhanced. By default, + # the console believes that the EC it's communicating with is NOT enhanced + # which is why we have to override it here. + self.console.enhanced_ec = True + self.console.CheckForEnhancedECImage = mock.MagicMock(return_value=True) + + def test_EnteringChars(self): + """Verify that characters are echoed onto the console.""" + test_str = b"abc" + input_stream = BytesToByteList(test_str) + + # Send the characters in. + for byte in input_stream: + self.console.HandleChar(byte) + + # Check the input position. + exp_pos = len(test_str) + CheckInputBufferPosition(self, exp_pos) + + # Verify that the input buffer is correct. + expected_buffer = test_str + CheckInputBuffer(self, expected_buffer) + + # Check console output + exp_console_out = test_str + CheckConsoleOutput(self, exp_console_out) + + def test_EnteringDeletingMoreCharsThanEntered(self): + """Verify that we can press backspace more than we have entered chars.""" + test_str = b"spamspam" + input_stream = BytesToByteList(test_str) + + # Send the characters in. + for byte in input_stream: + self.console.HandleChar(byte) + + # Now backspace 1 more than what we sent. + input_stream = [] + for _ in range(len(test_str) + 1): + input_stream.append(console.ControlKey.BACKSPACE) + + # Send that sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # First, verify that input buffer position is 0. + CheckInputBufferPosition(self, 0) + + # Next, examine the output stream for the correct sequence. + exp_console_out = test_str + for _ in range(len(test_str)): + exp_console_out += BACKSPACE_STRING + + # Now, verify that we got what we expected. + CheckConsoleOutput(self, exp_console_out) + + def test_EnteringMoreThanCharLimit(self): + """Verify that we drop characters when the line is too long.""" + test_str = self.console.line_limit * b"o" # All allowed. + test_str += 5 * b"x" # All should be dropped. + input_stream = BytesToByteList(test_str) + + # Send the characters in. + for byte in input_stream: + self.console.HandleChar(byte) + + # First, we expect that input buffer position should be equal to the line + # limit. + exp_pos = self.console.line_limit + CheckInputBufferPosition(self, exp_pos) + + # The input buffer should only hold until the line limit. + exp_buffer = test_str[0 : self.console.line_limit] + CheckInputBuffer(self, exp_buffer) + + # Lastly, check that the extra characters are not printed. + exp_console_out = exp_buffer + CheckConsoleOutput(self, exp_console_out) + + def test_ValidKeysOnLongLine(self): + """Verify that we can still press valid keys if the line is too long.""" + # Fill the line. + test_str = self.console.line_limit * b"o" + exp_console_out = test_str + # Try to fill it even more; these should all be dropped. + test_str += 5 * b"x" + input_stream = BytesToByteList(test_str) + + # We should be able to press the following keys: + # - Backspace + # - Arrow Keys/CTRL+B/CTRL+F/CTRL+P/CTRL+N + # - Delete + # - Home/CTRL+A + # - End/CTRL+E + # - Carriage Return + + # Backspace 1 character + input_stream.append(console.ControlKey.BACKSPACE) + exp_console_out += BACKSPACE_STRING + # Refill the line. + input_stream.extend(BytesToByteList(b"o")) + exp_console_out += b"o" + + # Left arrow key. + input_stream.extend(Keys.LEFT_ARROW) + exp_console_out += OutputStream.MoveCursorLeft(1) + + # Right arrow key. + input_stream.extend(Keys.RIGHT_ARROW) + exp_console_out += OutputStream.MoveCursorRight(1) + + # CTRL+B + input_stream.append(console.ControlKey.CTRL_B) + exp_console_out += OutputStream.MoveCursorLeft(1) + + # CTRL+F + input_stream.append(console.ControlKey.CTRL_F) + exp_console_out += OutputStream.MoveCursorRight(1) + + # Let's press enter now so we can test up and down. + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + exp_console_out += b"\r\n" + self.console.prompt + + # Up arrow key. + input_stream.extend(Keys.UP_ARROW) + exp_console_out += test_str[: self.console.line_limit] + + # Down arrow key. + input_stream.extend(Keys.DOWN_ARROW) + # Since the line was blank, we have to backspace the entire line. + exp_console_out += self.console.line_limit * BACKSPACE_STRING + + # CTRL+P + input_stream.append(console.ControlKey.CTRL_P) + exp_console_out += test_str[: self.console.line_limit] + + # CTRL+N + input_stream.append(console.ControlKey.CTRL_N) + # Since the line was blank, we have to backspace the entire line. + exp_console_out += self.console.line_limit * BACKSPACE_STRING + + # Press the Up arrow key to reprint the long line. + input_stream.extend(Keys.UP_ARROW) + exp_console_out += test_str[: self.console.line_limit] + + # Press the Home key to jump to the beginning of the line. + input_stream.extend(Keys.HOME) + exp_console_out += OutputStream.MoveCursorLeft(self.console.line_limit) + + # Press the End key to jump to the end of the line. + input_stream.extend(Keys.END) + exp_console_out += OutputStream.MoveCursorRight(self.console.line_limit) + + # Press CTRL+A to jump to the beginning of the line. + input_stream.append(console.ControlKey.CTRL_A) + exp_console_out += OutputStream.MoveCursorLeft(self.console.line_limit) + + # Press CTRL+E to jump to the end of the line. + input_stream.extend(Keys.END) + exp_console_out += OutputStream.MoveCursorRight(self.console.line_limit) + + # Move left one column so we can delete a character. + input_stream.extend(Keys.LEFT_ARROW) + exp_console_out += OutputStream.MoveCursorLeft(1) + + # Press the delete key. + input_stream.extend(Keys.DEL) + # This should look like a space, and then move cursor left 1 column since + # we're at the end of line. + exp_console_out += b" " + OutputStream.MoveCursorLeft(1) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # Verify everything happened correctly. + CheckConsoleOutput(self, exp_console_out) + + def test_BackspaceOnEmptyLine(self): + """Verify that we can backspace on an empty line with no bad effects.""" + # Send a single backspace. + test_str = [console.ControlKey.BACKSPACE] + + # Send the characters in. + for byte in test_str: + self.console.HandleChar(byte) + + # Check the input position. + exp_pos = 0 + CheckInputBufferPosition(self, exp_pos) + + # Check that buffer is empty. + exp_input_buffer = b"" + CheckInputBuffer(self, exp_input_buffer) + + # Check that the console output is empty. + exp_console_out = b"" + CheckConsoleOutput(self, exp_console_out) + + def test_BackspaceWithinLine(self): + """Verify that we shift the chars over when backspacing within a line.""" + # Misspell 'help' + test_str = b"heelp" + input_stream = BytesToByteList(test_str) + # Use the arrow key to go back to fix it. + # Move cursor left 1 column. + input_stream.extend(2 * Keys.LEFT_ARROW) + # Backspace once to remove the extra 'e'. + input_stream.append(console.ControlKey.BACKSPACE) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # Verify the input buffer + exp_input_buffer = b"help" + CheckInputBuffer(self, exp_input_buffer) + + # Verify the input buffer position. It should be at 2 (cursor over the 'l') + CheckInputBufferPosition(self, 2) + + # We expect the console output to be the test string, with two moves to the + # left, another move left, and then the rest of the line followed by a + # space. + exp_console_out = test_str + exp_console_out += 2 * OutputStream.MoveCursorLeft(1) + + # Move cursor left 1 column. + exp_console_out += OutputStream.MoveCursorLeft(1) + # Rest of the line and a space. (test_str in this case) + exp_console_out += b"lp " + # Reset the cursor 2 + 1 to the left. + exp_console_out += OutputStream.MoveCursorLeft(3) + + # Verify console output. + CheckConsoleOutput(self, exp_console_out) + + def test_JumpToBeginningOfLineViaCtrlA(self): + """Verify that we can jump to the beginning of a line with Ctrl+A.""" + # Enter some chars and press CTRL+A + test_str = b"abc" + input_stream = BytesToByteList(test_str) + [console.ControlKey.CTRL_A] + + # Send the characters in. + for byte in input_stream: + self.console.HandleChar(byte) + + # We expect to see our test string followed by a move cursor left. + exp_console_out = test_str + exp_console_out += OutputStream.MoveCursorLeft(len(test_str)) + + # Check to see what whas printed on the console. + CheckConsoleOutput(self, exp_console_out) + + # Check that the input buffer position is now 0. + CheckInputBufferPosition(self, 0) + + # Check input buffer still contains our test string. + CheckInputBuffer(self, test_str) + + def test_JumpToBeginningOfLineViaHomeKey(self): + """Jump to beginning of line via HOME key.""" + test_str = b"version" + input_stream = BytesToByteList(test_str) + input_stream.extend(Keys.HOME) + + # Send out the stream. + for byte in input_stream: + self.console.HandleChar(byte) + + # First, verify that input buffer position is now 0. + CheckInputBufferPosition(self, 0) + + # Next, verify that the input buffer did not change. + CheckInputBuffer(self, test_str) + + # Lastly, check that the cursor moved correctly. + exp_console_out = test_str + exp_console_out += OutputStream.MoveCursorLeft(len(test_str)) + CheckConsoleOutput(self, exp_console_out) + + def test_JumpToEndOfLineViaEndKey(self): + """Jump to the end of the line using the END key.""" + test_str = b"version" + input_stream = BytesToByteList(test_str) + input_stream += [console.ControlKey.CTRL_A] + # Now, jump to the end of the line. + input_stream.extend(Keys.END) + + # Send out the stream. + for byte in input_stream: + self.console.HandleChar(byte) + + # Verify that the input buffer position is correct. This should be at the + # end of the test string. + CheckInputBufferPosition(self, len(test_str)) + + # The expected output should be the test string, followed by a jump to the + # beginning of the line, and lastly a jump to the end of the line. + exp_console_out = test_str + exp_console_out += OutputStream.MoveCursorLeft(len(test_str)) + # Now the jump back to the end of the line. + exp_console_out += OutputStream.MoveCursorRight(len(test_str)) + + # Verify console output stream. + CheckConsoleOutput(self, exp_console_out) + + def test_JumpToEndOfLineViaCtrlE(self): + """Enter some chars and then try to jump to the end. (Should be a no-op)""" + test_str = b"sysinfo" + input_stream = BytesToByteList(test_str) + input_stream.append(console.ControlKey.CTRL_E) + + # Send out the stream + for byte in input_stream: + self.console.HandleChar(byte) + + # Verify that the input buffer position isn't any further than we expect. + # At this point, the position should be at the end of the test string. + CheckInputBufferPosition(self, len(test_str)) + + # Now, let's try to jump to the beginning and then jump back to the end. + input_stream = [console.ControlKey.CTRL_A, console.ControlKey.CTRL_E] + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # Perform the same verification. + CheckInputBufferPosition(self, len(test_str)) + + # Lastly try to jump again, beyond the end. + input_stream = [console.ControlKey.CTRL_E] + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # Perform the same verification. + CheckInputBufferPosition(self, len(test_str)) + + # We expect to see the test string, a jump to the beginning of the line, and + # one jump to the end of the line. + exp_console_out = test_str + # Jump to beginning. + exp_console_out += OutputStream.MoveCursorLeft(len(test_str)) + # Jump back to end. + exp_console_out += OutputStream.MoveCursorRight(len(test_str)) + + # Verify the console output. + CheckConsoleOutput(self, exp_console_out) + + def test_MoveLeftWithArrowKey(self): + """Move cursor left one column with arrow key.""" + test_str = b"tastyspam" + input_stream = BytesToByteList(test_str) + input_stream.extend(Keys.LEFT_ARROW) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # Verify that the input buffer position is 1 less than the length. + CheckInputBufferPosition(self, len(test_str) - 1) + + # Also, verify that the input buffer is not modified. + CheckInputBuffer(self, test_str) + + # We expect the test string, followed by a one column move left. + exp_console_out = test_str + OutputStream.MoveCursorLeft(1) + + # Verify console output. + CheckConsoleOutput(self, exp_console_out) + + def test_MoveLeftWithCtrlB(self): + """Move cursor back one column with Ctrl+B.""" + test_str = b"tastyspam" + input_stream = BytesToByteList(test_str) + input_stream.append(console.ControlKey.CTRL_B) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # Verify that the input buffer position is 1 less than the length. + CheckInputBufferPosition(self, len(test_str) - 1) - # Verify that the input buffer position is 1. - CheckInputBufferPosition(self, 1) + # Also, verify that the input buffer is not modified. + CheckInputBuffer(self, test_str) - # Also, verify that the input buffer is not modified. - CheckInputBuffer(self, test_str) + # We expect the test string, followed by a one column move left. + exp_console_out = test_str + OutputStream.MoveCursorLeft(1) - # We expect the test string, followed by a jump to the beginning of the - # line, and finally a move right 1. - exp_console_out = test_str + OutputStream.MoveCursorLeft(len((test_str))) - - # A move right 1 column. - exp_console_out += OutputStream.MoveCursorRight(1) - - # Verify console output. - CheckConsoleOutput(self, exp_console_out) - - def test_MoveRightWithCtrlF(self): - """Move cursor forward one column with Ctrl+F.""" - test_str = b'panicinfo' - input_stream = BytesToByteList(test_str) - input_stream.append(console.ControlKey.CTRL_A) - # Now, move right one column. - input_stream.append(console.ControlKey.CTRL_F) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # Verify that the input buffer position is 1. - CheckInputBufferPosition(self, 1) - - # Also, verify that the input buffer is not modified. - CheckInputBuffer(self, test_str) - - # We expect the test string, followed by a jump to the beginning of the - # line, and finally a move right 1. - exp_console_out = test_str + OutputStream.MoveCursorLeft(len((test_str))) - - # A move right 1 column. - exp_console_out += OutputStream.MoveCursorRight(1) - - # Verify console output. - CheckConsoleOutput(self, exp_console_out) - - def test_ImpossibleMoveLeftWithArrowKey(self): - """Verify that we can't move left at the beginning of the line.""" - # We shouldn't be able to move left if we're at the beginning of the line. - input_stream = Keys.LEFT_ARROW - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # Nothing should have been output. - exp_console_output = b'' - CheckConsoleOutput(self, exp_console_output) - - # The input buffer position should still be 0. - CheckInputBufferPosition(self, 0) - - # The input buffer itself should be empty. - CheckInputBuffer(self, b'') - - def test_ImpossibleMoveRightWithArrowKey(self): - """Verify that we can't move right at the end of the line.""" - # We shouldn't be able to move right if we're at the end of the line. - input_stream = Keys.RIGHT_ARROW - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # Nothing should have been output. - exp_console_output = b'' - CheckConsoleOutput(self, exp_console_output) - - # The input buffer position should still be 0. - CheckInputBufferPosition(self, 0) - - # The input buffer itself should be empty. - CheckInputBuffer(self, b'') - - def test_KillEntireLine(self): - """Verify that we can kill an entire line with Ctrl+K.""" - test_str = b'accelinfo on' - input_stream = BytesToByteList(test_str) - # Jump to beginning of line and then kill it with Ctrl+K. - input_stream.extend([console.ControlKey.CTRL_A, console.ControlKey.CTRL_K]) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # First, we expect that the input buffer is empty. - CheckInputBuffer(self, b'') - - # The buffer position should be 0. - CheckInputBufferPosition(self, 0) - - # What we expect to see on the console stream should be the following. The - # test string, a jump to the beginning of the line, then jump back to the - # end of the line and replace the line with spaces. - exp_console_out = test_str - # Jump to beginning of line. - exp_console_out += OutputStream.MoveCursorLeft(len(test_str)) - # Jump to end of line. - exp_console_out += OutputStream.MoveCursorRight(len(test_str)) - # Replace line with spaces, which looks like backspaces. - for _ in range(len(test_str)): - exp_console_out += BACKSPACE_STRING - - # Verify the console output. - CheckConsoleOutput(self, exp_console_out) - - def test_KillPartialLine(self): - """Verify that we can kill a portion of a line.""" - test_str = b'accelread 0 1' - input_stream = BytesToByteList(test_str) - len_to_kill = 5 - for _ in range(len_to_kill): - # Move cursor left - input_stream.extend(Keys.LEFT_ARROW) - # Now kill - input_stream.append(console.ControlKey.CTRL_K) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # First, check that the input buffer was truncated. - exp_input_buffer = test_str[:-len_to_kill] - CheckInputBuffer(self, exp_input_buffer) - - # Verify the input buffer position. - CheckInputBufferPosition(self, len(test_str) - len_to_kill) - - # The console output stream that we expect is the test string followed by a - # move left of len_to_kill, then a jump to the end of the line and backspace - # of len_to_kill. - exp_console_out = test_str - for _ in range(len_to_kill): - # Move left 1 column. - exp_console_out += OutputStream.MoveCursorLeft(1) - # Then jump to the end of the line - exp_console_out += OutputStream.MoveCursorRight(len_to_kill) - # Backspace of len_to_kill - for _ in range(len_to_kill): - exp_console_out += BACKSPACE_STRING - - # Verify console output. - CheckConsoleOutput(self, exp_console_out) - - def test_InsertingCharacters(self): - """Verify that we can insert characters within the line.""" - test_str = b'accel 0 1' # Here we forgot the 'read' part in 'accelread' - input_stream = BytesToByteList(test_str) - # We need to move over to the 'l' and add read. - insertion_point = test_str.find(b'l') + 1 - for i in range(len(test_str) - insertion_point): - # Move cursor left. - input_stream.extend(Keys.LEFT_ARROW) - # Now, add in 'read' - added_str = b'read' - input_stream.extend(BytesToByteList(added_str)) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # First, verify that the input buffer is correct. - exp_input_buffer = test_str[:insertion_point] + added_str - exp_input_buffer += test_str[insertion_point:] - CheckInputBuffer(self, exp_input_buffer) - - # Verify that the input buffer position is correct. - exp_input_buffer_pos = insertion_point + len(added_str) - CheckInputBufferPosition(self, exp_input_buffer_pos) - - # The console output stream that we expect is the test string, followed by - # move cursor left until the 'l' was found, the added test string while - # shifting characters around. - exp_console_out = test_str - for i in range(len(test_str) - insertion_point): - # Move cursor left. - exp_console_out += OutputStream.MoveCursorLeft(1) - - # Now for each character, write the rest of the line will be shifted to the - # right one column. - for i in range(len(added_str)): - # Printed character. - exp_console_out += added_str[i:i+1] - # The rest of the line - exp_console_out += test_str[insertion_point:] - # Reset the cursor back left - reset_dist = len(test_str[insertion_point:]) - exp_console_out += OutputStream.MoveCursorLeft(reset_dist) - - # Verify the console output. - CheckConsoleOutput(self, exp_console_out) - - def test_StoreCommandHistory(self): - """Verify that entered commands are stored in the history.""" - test_commands = [] - test_commands.append(b'help') - test_commands.append(b'version') - test_commands.append(b'accelread 0 1') - input_stream = [] - for c in test_commands: - input_stream.extend(BytesToByteList(c)) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # We expect to have the test commands in the history buffer. - exp_history_buf = test_commands - CheckHistoryBuffer(self, exp_history_buf) - - def test_CycleUpThruCommandHistory(self): - """Verify that the UP arrow key will print itmes in the history buffer.""" - # Enter some commands. - test_commands = [b'version', b'accelrange 0', b'battery', b'gettime'] - input_stream = [] - for command in test_commands: - input_stream.extend(BytesToByteList(command)) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Now, hit the UP arrow key to print the previous entries. - for i in range(len(test_commands)): - input_stream.extend(Keys.UP_ARROW) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # The expected output should be test commands with prompts printed in - # between, followed by line kills with the previous test commands printed. - exp_console_out = b'' - for i in range(len(test_commands)): - exp_console_out += test_commands[i] + b'\r\n' + self.console.prompt - - # When we press up, the line should be cleared and print the previous buffer - # entry. - for i in range(len(test_commands)-1, 0, -1): - exp_console_out += test_commands[i] - # Backspace to the beginning. - for i in range(len(test_commands[i])): - exp_console_out += BACKSPACE_STRING + # Verify console output. + CheckConsoleOutput(self, exp_console_out) - # The last command should just be printed out with no backspacing. - exp_console_out += test_commands[0] - - # Now, verify. - CheckConsoleOutput(self, exp_console_out) - - def test_UpArrowOnEmptyHistory(self): - """Ensure nothing happens if the history is empty.""" - # Press the up arrow key twice. - input_stream = 2 * Keys.UP_ARROW - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # We expect nothing to have happened. - exp_console_out = b'' - exp_input_buffer = b'' - exp_input_buffer_pos = 0 - exp_history_buf = [] - - # Verify. - CheckConsoleOutput(self, exp_console_out) - CheckInputBufferPosition(self, exp_input_buffer_pos) - CheckInputBuffer(self, exp_input_buffer) - CheckHistoryBuffer(self, exp_history_buf) - - def test_UpArrowDoesNotGoOutOfBounds(self): - """Verify that pressing the up arrow many times won't go out of bounds.""" - # Enter one command. - test_str = b'help version' - input_stream = BytesToByteList(test_str) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - # Then press the up arrow key twice. - input_stream.extend(2 * Keys.UP_ARROW) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # Verify that the history buffer is correct. - exp_history_buf = [test_str] - CheckHistoryBuffer(self, exp_history_buf) - - # We expect that the console output should only contain our entered command, - # a new prompt, and then our command aggain. - exp_console_out = test_str + b'\r\n' + self.console.prompt - # Pressing up should reprint the command we entered. - exp_console_out += test_str - - # Verify. - CheckConsoleOutput(self, exp_console_out) - - def test_CycleDownThruCommandHistory(self): - """Verify that we can select entries by hitting the down arrow.""" - # Enter at least 4 commands. - test_commands = [b'version', b'accelrange 0', b'battery', b'gettime'] - input_stream = [] - for command in test_commands: - input_stream.extend(BytesToByteList(command)) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Now, hit the UP arrow key twice to print the previous two entries. - for i in range(2): - input_stream.extend(Keys.UP_ARROW) - - # Now, hit the DOWN arrow key twice to print the newer entries. - input_stream.extend(2*Keys.DOWN_ARROW) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # The expected output should be commands that we entered, followed by - # prompts, then followed by our last two commands in reverse. Then, we - # should see the last entry in the list, followed by the saved partial cmd - # of a blank line. - exp_console_out = b'' - for i in range(len(test_commands)): - exp_console_out += test_commands[i] + b'\r\n' + self.console.prompt - - # When we press up, the line should be cleared and print the previous buffer - # entry. - for i in range(len(test_commands)-1, 1, -1): - exp_console_out += test_commands[i] - # Backspace to the beginning. - for i in range(len(test_commands[i])): - exp_console_out += BACKSPACE_STRING + def test_MoveRightWithArrowKey(self): + """Move cursor one column to the right with the arrow key.""" + test_str = b"version" + input_stream = BytesToByteList(test_str) + # Jump to beginning of line. + input_stream.append(console.ControlKey.CTRL_A) + # Press right arrow key. + input_stream.extend(Keys.RIGHT_ARROW) - # When we press down, it should have cleared the last command (which we - # covered with the previous for loop), and then prints the next command. - exp_console_out += test_commands[3] - for i in range(len(test_commands[3])): - exp_console_out += BACKSPACE_STRING - - # Verify console output. - CheckConsoleOutput(self, exp_console_out) - - # Verify input buffer. - exp_input_buffer = b'' # Empty because our partial command was empty. - exp_input_buffer_pos = len(exp_input_buffer) - CheckInputBuffer(self, exp_input_buffer) - CheckInputBufferPosition(self, exp_input_buffer_pos) - - def test_SavingPartialCommandWhenNavigatingHistory(self): - """Verify that partial commands are saved when navigating history.""" - # Enter a command. - test_str = b'accelinfo' - input_stream = BytesToByteList(test_str) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Enter a partial command. - partial_cmd = b'ver' - input_stream.extend(BytesToByteList(partial_cmd)) - - # Hit the UP arrow key. - input_stream.extend(Keys.UP_ARROW) - # Then, the DOWN arrow key. - input_stream.extend(Keys.DOWN_ARROW) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # The expected output should be the command we entered, a prompt, the - # partial command, clearing of the partial command, the command entered, - # clearing of the command entered, and then the partial command. - exp_console_out = test_str + b'\r\n' + self.console.prompt - exp_console_out += partial_cmd - for _ in range(len(partial_cmd)): - exp_console_out += BACKSPACE_STRING - exp_console_out += test_str - for _ in range(len(test_str)): - exp_console_out += BACKSPACE_STRING - exp_console_out += partial_cmd - - # Verify console output. - CheckConsoleOutput(self, exp_console_out) - - # Verify input buffer. - exp_input_buffer = partial_cmd - exp_input_buffer_pos = len(exp_input_buffer) - CheckInputBuffer(self, exp_input_buffer) - CheckInputBufferPosition(self, exp_input_buffer_pos) - - def test_DownArrowOnEmptyHistory(self): - """Ensure nothing happens if the history is empty.""" - # Then press the up down arrow twice. - input_stream = 2 * Keys.DOWN_ARROW - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # We expect nothing to have happened. - exp_console_out = b'' - exp_input_buffer = b'' - exp_input_buffer_pos = 0 - exp_history_buf = [] - - # Verify. - CheckConsoleOutput(self, exp_console_out) - CheckInputBufferPosition(self, exp_input_buffer_pos) - CheckInputBuffer(self, exp_input_buffer) - CheckHistoryBuffer(self, exp_history_buf) - - def test_DeleteCharsUsingDELKey(self): - """Verify that we can delete characters using the DEL key.""" - test_str = b'version' - input_stream = BytesToByteList(test_str) - - # Hit the left arrow key 2 times. - input_stream.extend(2 * Keys.LEFT_ARROW) - - # Press the DEL key. - input_stream.extend(Keys.DEL) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # The expected output should be the command we entered, 2 individual cursor - # moves to the left, and then removing a char and shifting everything to the - # left one column. - exp_console_out = test_str - exp_console_out += 2 * OutputStream.MoveCursorLeft(1) - - # Remove the char by shifting everything to the left one, slicing out the - # remove char. - exp_console_out += test_str[-1:] + b' ' - - # Reset the cursor by moving back 2 columns because of the 'n' and space. - exp_console_out += OutputStream.MoveCursorLeft(2) - - # Verify console output. - CheckConsoleOutput(self, exp_console_out) - - # Verify input buffer. The input buffer should have the char sliced out and - # be positioned where the char was removed. - exp_input_buffer = test_str[:-2] + test_str[-1:] - exp_input_buffer_pos = len(exp_input_buffer) - 1 - CheckInputBuffer(self, exp_input_buffer) - CheckInputBufferPosition(self, exp_input_buffer_pos) - - def test_RepeatedCommandInHistory(self): - """Verify that we don't store 2 consecutive identical commands in history""" - # Enter a few commands. - test_commands = [b'version', b'accelrange 0', b'battery', b'gettime'] - # Repeat the last command. - test_commands.append(test_commands[len(test_commands)-1]) - - input_stream = [] - for command in test_commands: - input_stream.extend(BytesToByteList(command)) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # Verify that the history buffer is correct. The last command, since - # it was repeated, should not have been added to the history. - exp_history_buf = test_commands[0:len(test_commands)-1] - CheckHistoryBuffer(self, exp_history_buf) + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + # Verify that the input buffer position is 1. + CheckInputBufferPosition(self, 1) -class TestConsoleCompatibility(unittest.TestCase): - """Verify that console can speak to enhanced and non-enhanced EC images.""" - def setUp(self): - """Setup the test harness.""" - # Setup logging with a timestamp, the module, and the log level. - logging.basicConfig(level=logging.DEBUG, - format=('%(asctime)s - %(module)s -' - ' %(levelname)s - %(message)s')) - # Create a temp file and set both the controller and peripheral PTYs to the - # file to create a loopback. - self.tempfile = tempfile.TemporaryFile() - - # Mock out the pipes. - mock_pipe_end_0, mock_pipe_end_1 = mock.MagicMock(), mock.MagicMock() - self.console = console.Console(self.tempfile.fileno(), self.tempfile, - tempfile.TemporaryFile(), - mock_pipe_end_0, mock_pipe_end_1, "EC") - - @mock.patch('ec3po.console.Console.CheckForEnhancedECImage') - def test_ActAsPassThruInNonEnhancedMode(self, mock_check): - """Verify we simply pass everything thru to non-enhanced ECs. + # Also, verify that the input buffer is not modified. + CheckInputBuffer(self, test_str) - Args: - mock_check: A MagicMock object replacing the CheckForEnhancedECImage() - method. - """ - # Set the interrogation mode to always so that we actually interrogate. - self.console.interrogation_mode = b'always' - - # Assume EC interrogations indicate that the image is non-enhanced. - mock_check.return_value = False - - # Press enter, followed by the command, and another enter. - input_stream = [] - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - test_command = b'version' - input_stream.extend(BytesToByteList(test_command)) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # Expected calls to send down the pipe would be each character of the test - # command. - expected_calls = [] - expected_calls.append(mock.call( - six.int2byte(console.ControlKey.CARRIAGE_RETURN))) - for char in test_command: - if six.PY3: - expected_calls.append(mock.call(bytes([char]))) - else: - expected_calls.append(mock.call(char)) - expected_calls.append(mock.call( - six.int2byte(console.ControlKey.CARRIAGE_RETURN))) - - # Verify that the calls happened. - self.console.cmd_pipe.send.assert_has_calls(expected_calls) - - # Since we're acting as a pass-thru, the input buffer should be empty and - # input_buffer_pos is 0. - CheckInputBuffer(self, b'') - CheckInputBufferPosition(self, 0) - - @mock.patch('ec3po.console.Console.CheckForEnhancedECImage') - def test_TransitionFromNonEnhancedToEnhanced(self, mock_check): - """Verify that we transition correctly to enhanced mode. + # We expect the test string, followed by a jump to the beginning of the + # line, and finally a move right 1. + exp_console_out = test_str + OutputStream.MoveCursorLeft(len((test_str))) + + # A move right 1 column. + exp_console_out += OutputStream.MoveCursorRight(1) + + # Verify console output. + CheckConsoleOutput(self, exp_console_out) + + def test_MoveRightWithCtrlF(self): + """Move cursor forward one column with Ctrl+F.""" + test_str = b"panicinfo" + input_stream = BytesToByteList(test_str) + input_stream.append(console.ControlKey.CTRL_A) + # Now, move right one column. + input_stream.append(console.ControlKey.CTRL_F) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # Verify that the input buffer position is 1. + CheckInputBufferPosition(self, 1) + + # Also, verify that the input buffer is not modified. + CheckInputBuffer(self, test_str) + + # We expect the test string, followed by a jump to the beginning of the + # line, and finally a move right 1. + exp_console_out = test_str + OutputStream.MoveCursorLeft(len((test_str))) + + # A move right 1 column. + exp_console_out += OutputStream.MoveCursorRight(1) + + # Verify console output. + CheckConsoleOutput(self, exp_console_out) + + def test_ImpossibleMoveLeftWithArrowKey(self): + """Verify that we can't move left at the beginning of the line.""" + # We shouldn't be able to move left if we're at the beginning of the line. + input_stream = Keys.LEFT_ARROW + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # Nothing should have been output. + exp_console_output = b"" + CheckConsoleOutput(self, exp_console_output) + + # The input buffer position should still be 0. + CheckInputBufferPosition(self, 0) + + # The input buffer itself should be empty. + CheckInputBuffer(self, b"") + + def test_ImpossibleMoveRightWithArrowKey(self): + """Verify that we can't move right at the end of the line.""" + # We shouldn't be able to move right if we're at the end of the line. + input_stream = Keys.RIGHT_ARROW + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # Nothing should have been output. + exp_console_output = b"" + CheckConsoleOutput(self, exp_console_output) + + # The input buffer position should still be 0. + CheckInputBufferPosition(self, 0) + + # The input buffer itself should be empty. + CheckInputBuffer(self, b"") + + def test_KillEntireLine(self): + """Verify that we can kill an entire line with Ctrl+K.""" + test_str = b"accelinfo on" + input_stream = BytesToByteList(test_str) + # Jump to beginning of line and then kill it with Ctrl+K. + input_stream.extend([console.ControlKey.CTRL_A, console.ControlKey.CTRL_K]) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # First, we expect that the input buffer is empty. + CheckInputBuffer(self, b"") + + # The buffer position should be 0. + CheckInputBufferPosition(self, 0) + + # What we expect to see on the console stream should be the following. The + # test string, a jump to the beginning of the line, then jump back to the + # end of the line and replace the line with spaces. + exp_console_out = test_str + # Jump to beginning of line. + exp_console_out += OutputStream.MoveCursorLeft(len(test_str)) + # Jump to end of line. + exp_console_out += OutputStream.MoveCursorRight(len(test_str)) + # Replace line with spaces, which looks like backspaces. + for _ in range(len(test_str)): + exp_console_out += BACKSPACE_STRING + + # Verify the console output. + CheckConsoleOutput(self, exp_console_out) + + def test_KillPartialLine(self): + """Verify that we can kill a portion of a line.""" + test_str = b"accelread 0 1" + input_stream = BytesToByteList(test_str) + len_to_kill = 5 + for _ in range(len_to_kill): + # Move cursor left + input_stream.extend(Keys.LEFT_ARROW) + # Now kill + input_stream.append(console.ControlKey.CTRL_K) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # First, check that the input buffer was truncated. + exp_input_buffer = test_str[:-len_to_kill] + CheckInputBuffer(self, exp_input_buffer) + + # Verify the input buffer position. + CheckInputBufferPosition(self, len(test_str) - len_to_kill) + + # The console output stream that we expect is the test string followed by a + # move left of len_to_kill, then a jump to the end of the line and backspace + # of len_to_kill. + exp_console_out = test_str + for _ in range(len_to_kill): + # Move left 1 column. + exp_console_out += OutputStream.MoveCursorLeft(1) + # Then jump to the end of the line + exp_console_out += OutputStream.MoveCursorRight(len_to_kill) + # Backspace of len_to_kill + for _ in range(len_to_kill): + exp_console_out += BACKSPACE_STRING + + # Verify console output. + CheckConsoleOutput(self, exp_console_out) + + def test_InsertingCharacters(self): + """Verify that we can insert characters within the line.""" + test_str = b"accel 0 1" # Here we forgot the 'read' part in 'accelread' + input_stream = BytesToByteList(test_str) + # We need to move over to the 'l' and add read. + insertion_point = test_str.find(b"l") + 1 + for i in range(len(test_str) - insertion_point): + # Move cursor left. + input_stream.extend(Keys.LEFT_ARROW) + # Now, add in 'read' + added_str = b"read" + input_stream.extend(BytesToByteList(added_str)) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # First, verify that the input buffer is correct. + exp_input_buffer = test_str[:insertion_point] + added_str + exp_input_buffer += test_str[insertion_point:] + CheckInputBuffer(self, exp_input_buffer) + + # Verify that the input buffer position is correct. + exp_input_buffer_pos = insertion_point + len(added_str) + CheckInputBufferPosition(self, exp_input_buffer_pos) + + # The console output stream that we expect is the test string, followed by + # move cursor left until the 'l' was found, the added test string while + # shifting characters around. + exp_console_out = test_str + for i in range(len(test_str) - insertion_point): + # Move cursor left. + exp_console_out += OutputStream.MoveCursorLeft(1) + + # Now for each character, write the rest of the line will be shifted to the + # right one column. + for i in range(len(added_str)): + # Printed character. + exp_console_out += added_str[i : i + 1] + # The rest of the line + exp_console_out += test_str[insertion_point:] + # Reset the cursor back left + reset_dist = len(test_str[insertion_point:]) + exp_console_out += OutputStream.MoveCursorLeft(reset_dist) + + # Verify the console output. + CheckConsoleOutput(self, exp_console_out) + + def test_StoreCommandHistory(self): + """Verify that entered commands are stored in the history.""" + test_commands = [] + test_commands.append(b"help") + test_commands.append(b"version") + test_commands.append(b"accelread 0 1") + input_stream = [] + for c in test_commands: + input_stream.extend(BytesToByteList(c)) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # We expect to have the test commands in the history buffer. + exp_history_buf = test_commands + CheckHistoryBuffer(self, exp_history_buf) + + def test_CycleUpThruCommandHistory(self): + """Verify that the UP arrow key will print itmes in the history buffer.""" + # Enter some commands. + test_commands = [b"version", b"accelrange 0", b"battery", b"gettime"] + input_stream = [] + for command in test_commands: + input_stream.extend(BytesToByteList(command)) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Now, hit the UP arrow key to print the previous entries. + for i in range(len(test_commands)): + input_stream.extend(Keys.UP_ARROW) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # The expected output should be test commands with prompts printed in + # between, followed by line kills with the previous test commands printed. + exp_console_out = b"" + for i in range(len(test_commands)): + exp_console_out += test_commands[i] + b"\r\n" + self.console.prompt + + # When we press up, the line should be cleared and print the previous buffer + # entry. + for i in range(len(test_commands) - 1, 0, -1): + exp_console_out += test_commands[i] + # Backspace to the beginning. + for i in range(len(test_commands[i])): + exp_console_out += BACKSPACE_STRING + + # The last command should just be printed out with no backspacing. + exp_console_out += test_commands[0] + + # Now, verify. + CheckConsoleOutput(self, exp_console_out) + + def test_UpArrowOnEmptyHistory(self): + """Ensure nothing happens if the history is empty.""" + # Press the up arrow key twice. + input_stream = 2 * Keys.UP_ARROW + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # We expect nothing to have happened. + exp_console_out = b"" + exp_input_buffer = b"" + exp_input_buffer_pos = 0 + exp_history_buf = [] + + # Verify. + CheckConsoleOutput(self, exp_console_out) + CheckInputBufferPosition(self, exp_input_buffer_pos) + CheckInputBuffer(self, exp_input_buffer) + CheckHistoryBuffer(self, exp_history_buf) + + def test_UpArrowDoesNotGoOutOfBounds(self): + """Verify that pressing the up arrow many times won't go out of bounds.""" + # Enter one command. + test_str = b"help version" + input_stream = BytesToByteList(test_str) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + # Then press the up arrow key twice. + input_stream.extend(2 * Keys.UP_ARROW) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # Verify that the history buffer is correct. + exp_history_buf = [test_str] + CheckHistoryBuffer(self, exp_history_buf) + + # We expect that the console output should only contain our entered command, + # a new prompt, and then our command aggain. + exp_console_out = test_str + b"\r\n" + self.console.prompt + # Pressing up should reprint the command we entered. + exp_console_out += test_str + + # Verify. + CheckConsoleOutput(self, exp_console_out) + + def test_CycleDownThruCommandHistory(self): + """Verify that we can select entries by hitting the down arrow.""" + # Enter at least 4 commands. + test_commands = [b"version", b"accelrange 0", b"battery", b"gettime"] + input_stream = [] + for command in test_commands: + input_stream.extend(BytesToByteList(command)) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Now, hit the UP arrow key twice to print the previous two entries. + for i in range(2): + input_stream.extend(Keys.UP_ARROW) + + # Now, hit the DOWN arrow key twice to print the newer entries. + input_stream.extend(2 * Keys.DOWN_ARROW) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # The expected output should be commands that we entered, followed by + # prompts, then followed by our last two commands in reverse. Then, we + # should see the last entry in the list, followed by the saved partial cmd + # of a blank line. + exp_console_out = b"" + for i in range(len(test_commands)): + exp_console_out += test_commands[i] + b"\r\n" + self.console.prompt + + # When we press up, the line should be cleared and print the previous buffer + # entry. + for i in range(len(test_commands) - 1, 1, -1): + exp_console_out += test_commands[i] + # Backspace to the beginning. + for i in range(len(test_commands[i])): + exp_console_out += BACKSPACE_STRING + + # When we press down, it should have cleared the last command (which we + # covered with the previous for loop), and then prints the next command. + exp_console_out += test_commands[3] + for i in range(len(test_commands[3])): + exp_console_out += BACKSPACE_STRING + + # Verify console output. + CheckConsoleOutput(self, exp_console_out) + + # Verify input buffer. + exp_input_buffer = b"" # Empty because our partial command was empty. + exp_input_buffer_pos = len(exp_input_buffer) + CheckInputBuffer(self, exp_input_buffer) + CheckInputBufferPosition(self, exp_input_buffer_pos) + + def test_SavingPartialCommandWhenNavigatingHistory(self): + """Verify that partial commands are saved when navigating history.""" + # Enter a command. + test_str = b"accelinfo" + input_stream = BytesToByteList(test_str) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Enter a partial command. + partial_cmd = b"ver" + input_stream.extend(BytesToByteList(partial_cmd)) + + # Hit the UP arrow key. + input_stream.extend(Keys.UP_ARROW) + # Then, the DOWN arrow key. + input_stream.extend(Keys.DOWN_ARROW) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # The expected output should be the command we entered, a prompt, the + # partial command, clearing of the partial command, the command entered, + # clearing of the command entered, and then the partial command. + exp_console_out = test_str + b"\r\n" + self.console.prompt + exp_console_out += partial_cmd + for _ in range(len(partial_cmd)): + exp_console_out += BACKSPACE_STRING + exp_console_out += test_str + for _ in range(len(test_str)): + exp_console_out += BACKSPACE_STRING + exp_console_out += partial_cmd + + # Verify console output. + CheckConsoleOutput(self, exp_console_out) + + # Verify input buffer. + exp_input_buffer = partial_cmd + exp_input_buffer_pos = len(exp_input_buffer) + CheckInputBuffer(self, exp_input_buffer) + CheckInputBufferPosition(self, exp_input_buffer_pos) + + def test_DownArrowOnEmptyHistory(self): + """Ensure nothing happens if the history is empty.""" + # Then press the up down arrow twice. + input_stream = 2 * Keys.DOWN_ARROW + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # We expect nothing to have happened. + exp_console_out = b"" + exp_input_buffer = b"" + exp_input_buffer_pos = 0 + exp_history_buf = [] + + # Verify. + CheckConsoleOutput(self, exp_console_out) + CheckInputBufferPosition(self, exp_input_buffer_pos) + CheckInputBuffer(self, exp_input_buffer) + CheckHistoryBuffer(self, exp_history_buf) + + def test_DeleteCharsUsingDELKey(self): + """Verify that we can delete characters using the DEL key.""" + test_str = b"version" + input_stream = BytesToByteList(test_str) + + # Hit the left arrow key 2 times. + input_stream.extend(2 * Keys.LEFT_ARROW) + + # Press the DEL key. + input_stream.extend(Keys.DEL) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # The expected output should be the command we entered, 2 individual cursor + # moves to the left, and then removing a char and shifting everything to the + # left one column. + exp_console_out = test_str + exp_console_out += 2 * OutputStream.MoveCursorLeft(1) + + # Remove the char by shifting everything to the left one, slicing out the + # remove char. + exp_console_out += test_str[-1:] + b" " + + # Reset the cursor by moving back 2 columns because of the 'n' and space. + exp_console_out += OutputStream.MoveCursorLeft(2) + + # Verify console output. + CheckConsoleOutput(self, exp_console_out) + + # Verify input buffer. The input buffer should have the char sliced out and + # be positioned where the char was removed. + exp_input_buffer = test_str[:-2] + test_str[-1:] + exp_input_buffer_pos = len(exp_input_buffer) - 1 + CheckInputBuffer(self, exp_input_buffer) + CheckInputBufferPosition(self, exp_input_buffer_pos) + + def test_RepeatedCommandInHistory(self): + """Verify that we don't store 2 consecutive identical commands in history""" + # Enter a few commands. + test_commands = [b"version", b"accelrange 0", b"battery", b"gettime"] + # Repeat the last command. + test_commands.append(test_commands[len(test_commands) - 1]) + + input_stream = [] + for command in test_commands: + input_stream.extend(BytesToByteList(command)) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # Verify that the history buffer is correct. The last command, since + # it was repeated, should not have been added to the history. + exp_history_buf = test_commands[0 : len(test_commands) - 1] + CheckHistoryBuffer(self, exp_history_buf) - Args: - mock_check: A MagicMock object replacing the CheckForEnhancedECImage() - method. - """ - # Set the interrogation mode to always so that we actually interrogate. - self.console.interrogation_mode = b'always' - - # First, assume that the EC interrogations indicate an enhanced EC image. - mock_check.return_value = True - # But our current knowledge of the EC image (which was actually the - # 'previous' EC) was a non-enhanced image. - self.console.enhanced_ec = False - - test_command = b'sysinfo' - input_stream = [] - input_stream.extend(BytesToByteList(test_command)) - - expected_calls = [] - # All keystrokes to the console should be directed straight through to the - # EC until we press the enter key. - for char in test_command: - if six.PY3: - expected_calls.append(mock.call(bytes([char]))) - else: - expected_calls.append(mock.call(char)) - - # Press the enter key. - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - # The enter key should not be sent to the pipe since we should negotiate - # to an enhanced EC image. - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # At this point, we should have negotiated to enhanced. - self.assertTrue(self.console.enhanced_ec, msg=('Did not negotiate to ' - 'enhanced EC image.')) - - # The command would have been dropped however, so verify this... - CheckInputBuffer(self, b'') - CheckInputBufferPosition(self, 0) - # ...and repeat the command. - input_stream = BytesToByteList(test_command) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # Since we're enhanced now, we should have sent the entire command as one - # string with no trailing carriage return - expected_calls.append(mock.call(test_command)) - - # Verify all of the calls. - self.console.cmd_pipe.send.assert_has_calls(expected_calls) - - @mock.patch('ec3po.console.Console.CheckForEnhancedECImage') - def test_TransitionFromEnhancedToNonEnhanced(self, mock_check): - """Verify that we transition correctly to non-enhanced mode. - Args: - mock_check: A MagicMock object replacing the CheckForEnhancedECImage() - method. - """ - # Set the interrogation mode to always so that we actually interrogate. - self.console.interrogation_mode = b'always' - - # First, assume that the EC interrogations indicate an non-enhanced EC - # image. - mock_check.return_value = False - # But our current knowledge of the EC image (which was actually the - # 'previous' EC) was an enhanced image. - self.console.enhanced_ec = True - - test_command = b'sysinfo' - input_stream = [] - input_stream.extend(BytesToByteList(test_command)) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # But, we will negotiate to non-enhanced however, dropping this command. - # Verify this. - self.assertFalse(self.console.enhanced_ec, msg=('Did not negotiate to' - 'non-enhanced EC image.')) - CheckInputBuffer(self, b'') - CheckInputBufferPosition(self, 0) - - # The carriage return should have passed through though. - expected_calls = [] - expected_calls.append(mock.call( - six.int2byte(console.ControlKey.CARRIAGE_RETURN))) - - # Since the command was dropped, repeat the command. - input_stream = BytesToByteList(test_command) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # Since we're not enhanced now, we should have sent each character in the - # entire command separately and a carriage return. - for char in test_command: - if six.PY3: - expected_calls.append(mock.call(bytes([char]))) - else: - expected_calls.append(mock.call(char)) - expected_calls.append(mock.call( - six.int2byte(console.ControlKey.CARRIAGE_RETURN))) - - # Verify all of the calls. - self.console.cmd_pipe.send.assert_has_calls(expected_calls) - - def test_EnhancedCheckIfTimedOut(self): - """Verify that the check returns false if it times out.""" - # Make the debug pipe "time out". - self.console.dbg_pipe.poll.return_value = False - self.assertFalse(self.console.CheckForEnhancedECImage()) - - def test_EnhancedCheckIfACKReceived(self): - """Verify that the check returns true if the ACK is received.""" - # Make the debug pipe return EC_ACK. - self.console.dbg_pipe.poll.return_value = True - self.console.dbg_pipe.recv.return_value = interpreter.EC_ACK - self.assertTrue(self.console.CheckForEnhancedECImage()) - - def test_EnhancedCheckIfWrong(self): - """Verify that the check returns false if byte received is wrong.""" - # Make the debug pipe return the wrong byte. - self.console.dbg_pipe.poll.return_value = True - self.console.dbg_pipe.recv.return_value = b'\xff' - self.assertFalse(self.console.CheckForEnhancedECImage()) - - def test_EnhancedCheckUsingBuffer(self): - """Verify that given reboot output, enhanced EC images are detected.""" - enhanced_output_stream = b""" +class TestConsoleCompatibility(unittest.TestCase): + """Verify that console can speak to enhanced and non-enhanced EC images.""" + + def setUp(self): + """Setup the test harness.""" + # Setup logging with a timestamp, the module, and the log level. + logging.basicConfig( + level=logging.DEBUG, + format=("%(asctime)s - %(module)s -" " %(levelname)s - %(message)s"), + ) + # Create a temp file and set both the controller and peripheral PTYs to the + # file to create a loopback. + self.tempfile = tempfile.TemporaryFile() + + # Mock out the pipes. + mock_pipe_end_0, mock_pipe_end_1 = mock.MagicMock(), mock.MagicMock() + self.console = console.Console( + self.tempfile.fileno(), + self.tempfile, + tempfile.TemporaryFile(), + mock_pipe_end_0, + mock_pipe_end_1, + "EC", + ) + + @mock.patch("ec3po.console.Console.CheckForEnhancedECImage") + def test_ActAsPassThruInNonEnhancedMode(self, mock_check): + """Verify we simply pass everything thru to non-enhanced ECs. + + Args: + mock_check: A MagicMock object replacing the CheckForEnhancedECImage() + method. + """ + # Set the interrogation mode to always so that we actually interrogate. + self.console.interrogation_mode = b"always" + + # Assume EC interrogations indicate that the image is non-enhanced. + mock_check.return_value = False + + # Press enter, followed by the command, and another enter. + input_stream = [] + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + test_command = b"version" + input_stream.extend(BytesToByteList(test_command)) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # Expected calls to send down the pipe would be each character of the test + # command. + expected_calls = [] + expected_calls.append( + mock.call(six.int2byte(console.ControlKey.CARRIAGE_RETURN)) + ) + for char in test_command: + if six.PY3: + expected_calls.append(mock.call(bytes([char]))) + else: + expected_calls.append(mock.call(char)) + expected_calls.append( + mock.call(six.int2byte(console.ControlKey.CARRIAGE_RETURN)) + ) + + # Verify that the calls happened. + self.console.cmd_pipe.send.assert_has_calls(expected_calls) + + # Since we're acting as a pass-thru, the input buffer should be empty and + # input_buffer_pos is 0. + CheckInputBuffer(self, b"") + CheckInputBufferPosition(self, 0) + + @mock.patch("ec3po.console.Console.CheckForEnhancedECImage") + def test_TransitionFromNonEnhancedToEnhanced(self, mock_check): + """Verify that we transition correctly to enhanced mode. + + Args: + mock_check: A MagicMock object replacing the CheckForEnhancedECImage() + method. + """ + # Set the interrogation mode to always so that we actually interrogate. + self.console.interrogation_mode = b"always" + + # First, assume that the EC interrogations indicate an enhanced EC image. + mock_check.return_value = True + # But our current knowledge of the EC image (which was actually the + # 'previous' EC) was a non-enhanced image. + self.console.enhanced_ec = False + + test_command = b"sysinfo" + input_stream = [] + input_stream.extend(BytesToByteList(test_command)) + + expected_calls = [] + # All keystrokes to the console should be directed straight through to the + # EC until we press the enter key. + for char in test_command: + if six.PY3: + expected_calls.append(mock.call(bytes([char]))) + else: + expected_calls.append(mock.call(char)) + + # Press the enter key. + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + # The enter key should not be sent to the pipe since we should negotiate + # to an enhanced EC image. + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # At this point, we should have negotiated to enhanced. + self.assertTrue( + self.console.enhanced_ec, msg=("Did not negotiate to " "enhanced EC image.") + ) + + # The command would have been dropped however, so verify this... + CheckInputBuffer(self, b"") + CheckInputBufferPosition(self, 0) + # ...and repeat the command. + input_stream = BytesToByteList(test_command) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # Since we're enhanced now, we should have sent the entire command as one + # string with no trailing carriage return + expected_calls.append(mock.call(test_command)) + + # Verify all of the calls. + self.console.cmd_pipe.send.assert_has_calls(expected_calls) + + @mock.patch("ec3po.console.Console.CheckForEnhancedECImage") + def test_TransitionFromEnhancedToNonEnhanced(self, mock_check): + """Verify that we transition correctly to non-enhanced mode. + + Args: + mock_check: A MagicMock object replacing the CheckForEnhancedECImage() + method. + """ + # Set the interrogation mode to always so that we actually interrogate. + self.console.interrogation_mode = b"always" + + # First, assume that the EC interrogations indicate an non-enhanced EC + # image. + mock_check.return_value = False + # But our current knowledge of the EC image (which was actually the + # 'previous' EC) was an enhanced image. + self.console.enhanced_ec = True + + test_command = b"sysinfo" + input_stream = [] + input_stream.extend(BytesToByteList(test_command)) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # But, we will negotiate to non-enhanced however, dropping this command. + # Verify this. + self.assertFalse( + self.console.enhanced_ec, + msg=("Did not negotiate to" "non-enhanced EC image."), + ) + CheckInputBuffer(self, b"") + CheckInputBufferPosition(self, 0) + + # The carriage return should have passed through though. + expected_calls = [] + expected_calls.append( + mock.call(six.int2byte(console.ControlKey.CARRIAGE_RETURN)) + ) + + # Since the command was dropped, repeat the command. + input_stream = BytesToByteList(test_command) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # Since we're not enhanced now, we should have sent each character in the + # entire command separately and a carriage return. + for char in test_command: + if six.PY3: + expected_calls.append(mock.call(bytes([char]))) + else: + expected_calls.append(mock.call(char)) + expected_calls.append( + mock.call(six.int2byte(console.ControlKey.CARRIAGE_RETURN)) + ) + + # Verify all of the calls. + self.console.cmd_pipe.send.assert_has_calls(expected_calls) + + def test_EnhancedCheckIfTimedOut(self): + """Verify that the check returns false if it times out.""" + # Make the debug pipe "time out". + self.console.dbg_pipe.poll.return_value = False + self.assertFalse(self.console.CheckForEnhancedECImage()) + + def test_EnhancedCheckIfACKReceived(self): + """Verify that the check returns true if the ACK is received.""" + # Make the debug pipe return EC_ACK. + self.console.dbg_pipe.poll.return_value = True + self.console.dbg_pipe.recv.return_value = interpreter.EC_ACK + self.assertTrue(self.console.CheckForEnhancedECImage()) + + def test_EnhancedCheckIfWrong(self): + """Verify that the check returns false if byte received is wrong.""" + # Make the debug pipe return the wrong byte. + self.console.dbg_pipe.poll.return_value = True + self.console.dbg_pipe.recv.return_value = b"\xff" + self.assertFalse(self.console.CheckForEnhancedECImage()) + + def test_EnhancedCheckUsingBuffer(self): + """Verify that given reboot output, enhanced EC images are detected.""" + enhanced_output_stream = b""" --- UART initialized after reboot --- [Reset cause: reset-pin soft] [Image: RO, jerry_v1.1.4363-2af8572-dirty 2016-02-23 13:26:20 aaboagye@lithium.mtv.corp.google.com] @@ -1295,19 +1343,19 @@ Enhanced Console is enabled (v1.0.0); type HELP for help. [0.224060 hash done 41dac382e3a6e3d2ea5b4d789c1bc46525cae7cc5ff6758f0de8d8369b506f57] [0.375150 POWER_GOOD seen] """ - for line in enhanced_output_stream.split(b'\n'): - self.console.CheckBufferForEnhancedImage(line) + for line in enhanced_output_stream.split(b"\n"): + self.console.CheckBufferForEnhancedImage(line) - # Since the enhanced console string was present in the output, the console - # should have caught it. - self.assertTrue(self.console.enhanced_ec) + # Since the enhanced console string was present in the output, the console + # should have caught it. + self.assertTrue(self.console.enhanced_ec) - # Also should check that the command was sent to the interpreter. - self.console.cmd_pipe.send.assert_called_once_with(b'enhanced True') + # Also should check that the command was sent to the interpreter. + self.console.cmd_pipe.send.assert_called_once_with(b"enhanced True") - # Now test the non-enhanced EC image. - self.console.cmd_pipe.reset_mock() - non_enhanced_output_stream = b""" + # Now test the non-enhanced EC image. + self.console.cmd_pipe.reset_mock() + non_enhanced_output_stream = b""" --- UART initialized after reboot --- [Reset cause: reset-pin soft] [Image: RO, jerry_v1.1.4363-2af8572-dirty 2016-02-23 13:03:15 aaboagye@lithium.mtv.corp.google.com] @@ -1331,239 +1379,253 @@ Console is enabled; type HELP for help. [0.010285 power on 2] [0.010385 power state 5 = S5->S3, in 0x0000] """ - for line in non_enhanced_output_stream.split(b'\n'): - self.console.CheckBufferForEnhancedImage(line) + for line in non_enhanced_output_stream.split(b"\n"): + self.console.CheckBufferForEnhancedImage(line) - # Since the default console string is present in the output, it should be - # determined to be non enhanced now. - self.assertFalse(self.console.enhanced_ec) + # Since the default console string is present in the output, it should be + # determined to be non enhanced now. + self.assertFalse(self.console.enhanced_ec) - # Check that command was also sent to the interpreter. - self.console.cmd_pipe.send.assert_called_once_with(b'enhanced False') + # Check that command was also sent to the interpreter. + self.console.cmd_pipe.send.assert_called_once_with(b"enhanced False") class TestOOBMConsoleCommands(unittest.TestCase): - """Verify that OOBM console commands work correctly.""" - def setUp(self): - """Setup the test harness.""" - # Setup logging with a timestamp, the module, and the log level. - logging.basicConfig(level=logging.DEBUG, - format=('%(asctime)s - %(module)s -' - ' %(levelname)s - %(message)s')) - # Create a temp file and set both the controller and peripheral PTYs to the - # file to create a loopback. - self.tempfile = tempfile.TemporaryFile() - - # Mock out the pipes. - mock_pipe_end_0, mock_pipe_end_1 = mock.MagicMock(), mock.MagicMock() - self.console = console.Console(self.tempfile.fileno(), self.tempfile, - tempfile.TemporaryFile(), - mock_pipe_end_0, mock_pipe_end_1, "EC") - self.console.oobm_queue = mock.MagicMock() - - @mock.patch('ec3po.console.Console.CheckForEnhancedECImage') - def test_InterrogateCommand(self, mock_check): - """Verify that 'interrogate' command works as expected. - - Args: - mock_check: A MagicMock object replacing the CheckForEnhancedECIMage() - method. - """ - input_stream = [] - expected_calls = [] - mock_check.side_effect = [False] - - # 'interrogate never' should disable the interrogation from happening at - # all. - cmd = b'interrogate never' - # Enter the OOBM prompt. - input_stream.extend(BytesToByteList(b'%')) - # Type the command - input_stream.extend(BytesToByteList(cmd)) - # Press enter. - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - input_stream = [] - - # The OOBM queue should have been called with the command being put. - expected_calls.append(mock.call.put(cmd)) - self.console.oobm_queue.assert_has_calls(expected_calls) - - # Process the OOBM queue. - self.console.oobm_queue.get.side_effect = [cmd] - self.console.ProcessOOBMQueue() - - # Type out a few commands. - input_stream.extend(BytesToByteList(b'version')) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - input_stream.extend(BytesToByteList(b'flashinfo')) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - input_stream.extend(BytesToByteList(b'sysinfo')) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # The Check function should NOT have been called at all. - mock_check.assert_not_called() - - # The EC image should be assumed to be not enhanced. - self.assertFalse(self.console.enhanced_ec, 'The image should be assumed to' - ' be NOT enhanced.') - - # Reset the mocks. - mock_check.reset_mock() - self.console.oobm_queue.reset_mock() - - # 'interrogate auto' should not interrogate at all. It should only be - # scanning the output stream for the 'console is enabled' strings. - cmd = b'interrogate auto' - # Enter the OOBM prompt. - input_stream.extend(BytesToByteList(b'%')) - # Type the command - input_stream.extend(BytesToByteList(cmd)) - # Press enter. - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - input_stream = [] - expected_calls = [] - - # The OOBM queue should have been called with the command being put. - expected_calls.append(mock.call.put(cmd)) - self.console.oobm_queue.assert_has_calls(expected_calls) - - # Process the OOBM queue. - self.console.oobm_queue.get.side_effect = [cmd] - self.console.ProcessOOBMQueue() - - # Type out a few commands. - input_stream.extend(BytesToByteList(b'version')) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - input_stream.extend(BytesToByteList(b'flashinfo')) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - input_stream.extend(BytesToByteList(b'sysinfo')) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # The Check function should NOT have been called at all. - mock_check.assert_not_called() - - # The EC image should be assumed to be not enhanced. - self.assertFalse(self.console.enhanced_ec, 'The image should be assumed to' - ' be NOT enhanced.') - - # Reset the mocks. - mock_check.reset_mock() - self.console.oobm_queue.reset_mock() - - # 'interrogate always' should, like its name implies, interrogate always - # after each press of the enter key. This was the former way of doing - # interrogation. - cmd = b'interrogate always' - # Enter the OOBM prompt. - input_stream.extend(BytesToByteList(b'%')) - # Type the command - input_stream.extend(BytesToByteList(cmd)) - # Press enter. - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - input_stream = [] - expected_calls = [] - - # The OOBM queue should have been called with the command being put. - expected_calls.append(mock.call.put(cmd)) - self.console.oobm_queue.assert_has_calls(expected_calls) - - # Process the OOBM queue. - self.console.oobm_queue.get.side_effect = [cmd] - self.console.ProcessOOBMQueue() - - # The Check method should be called 3 times here. - mock_check.side_effect = [False, False, False] - - # Type out a few commands. - input_stream.extend(BytesToByteList(b'help list')) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - input_stream.extend(BytesToByteList(b'taskinfo')) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - input_stream.extend(BytesToByteList(b'hibdelay')) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # The Check method should have been called 3 times here. - expected_calls = [mock.call(), mock.call(), mock.call()] - mock_check.assert_has_calls(expected_calls) - - # The EC image should be assumed to be not enhanced. - self.assertFalse(self.console.enhanced_ec, 'The image should be assumed to' - ' be NOT enhanced.') - - # Now, let's try to assume that the image is enhanced while still disabling - # interrogation. - mock_check.reset_mock() - self.console.oobm_queue.reset_mock() - input_stream = [] - cmd = b'interrogate never enhanced' - # Enter the OOBM prompt. - input_stream.extend(BytesToByteList(b'%')) - # Type the command - input_stream.extend(BytesToByteList(cmd)) - # Press enter. - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - input_stream = [] - expected_calls = [] - - # The OOBM queue should have been called with the command being put. - expected_calls.append(mock.call.put(cmd)) - self.console.oobm_queue.assert_has_calls(expected_calls) - - # Process the OOBM queue. - self.console.oobm_queue.get.side_effect = [cmd] - self.console.ProcessOOBMQueue() - - # Type out a few commands. - input_stream.extend(BytesToByteList(b'chgstate')) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - input_stream.extend(BytesToByteList(b'hash')) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - input_stream.extend(BytesToByteList(b'sysjump rw')) - input_stream.append(console.ControlKey.CARRIAGE_RETURN) - - # Send the sequence out. - for byte in input_stream: - self.console.HandleChar(byte) - - # The check method should have never been called. - mock_check.assert_not_called() - - # The EC image should be assumed to be enhanced. - self.assertTrue(self.console.enhanced_ec, 'The image should be' - ' assumed to be enhanced.') - - -if __name__ == '__main__': - unittest.main() + """Verify that OOBM console commands work correctly.""" + + def setUp(self): + """Setup the test harness.""" + # Setup logging with a timestamp, the module, and the log level. + logging.basicConfig( + level=logging.DEBUG, + format=("%(asctime)s - %(module)s -" " %(levelname)s - %(message)s"), + ) + # Create a temp file and set both the controller and peripheral PTYs to the + # file to create a loopback. + self.tempfile = tempfile.TemporaryFile() + + # Mock out the pipes. + mock_pipe_end_0, mock_pipe_end_1 = mock.MagicMock(), mock.MagicMock() + self.console = console.Console( + self.tempfile.fileno(), + self.tempfile, + tempfile.TemporaryFile(), + mock_pipe_end_0, + mock_pipe_end_1, + "EC", + ) + self.console.oobm_queue = mock.MagicMock() + + @mock.patch("ec3po.console.Console.CheckForEnhancedECImage") + def test_InterrogateCommand(self, mock_check): + """Verify that 'interrogate' command works as expected. + + Args: + mock_check: A MagicMock object replacing the CheckForEnhancedECIMage() + method. + """ + input_stream = [] + expected_calls = [] + mock_check.side_effect = [False] + + # 'interrogate never' should disable the interrogation from happening at + # all. + cmd = b"interrogate never" + # Enter the OOBM prompt. + input_stream.extend(BytesToByteList(b"%")) + # Type the command + input_stream.extend(BytesToByteList(cmd)) + # Press enter. + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + input_stream = [] + + # The OOBM queue should have been called with the command being put. + expected_calls.append(mock.call.put(cmd)) + self.console.oobm_queue.assert_has_calls(expected_calls) + + # Process the OOBM queue. + self.console.oobm_queue.get.side_effect = [cmd] + self.console.ProcessOOBMQueue() + + # Type out a few commands. + input_stream.extend(BytesToByteList(b"version")) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + input_stream.extend(BytesToByteList(b"flashinfo")) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + input_stream.extend(BytesToByteList(b"sysinfo")) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # The Check function should NOT have been called at all. + mock_check.assert_not_called() + + # The EC image should be assumed to be not enhanced. + self.assertFalse( + self.console.enhanced_ec, + "The image should be assumed to" " be NOT enhanced.", + ) + + # Reset the mocks. + mock_check.reset_mock() + self.console.oobm_queue.reset_mock() + + # 'interrogate auto' should not interrogate at all. It should only be + # scanning the output stream for the 'console is enabled' strings. + cmd = b"interrogate auto" + # Enter the OOBM prompt. + input_stream.extend(BytesToByteList(b"%")) + # Type the command + input_stream.extend(BytesToByteList(cmd)) + # Press enter. + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + input_stream = [] + expected_calls = [] + + # The OOBM queue should have been called with the command being put. + expected_calls.append(mock.call.put(cmd)) + self.console.oobm_queue.assert_has_calls(expected_calls) + + # Process the OOBM queue. + self.console.oobm_queue.get.side_effect = [cmd] + self.console.ProcessOOBMQueue() + + # Type out a few commands. + input_stream.extend(BytesToByteList(b"version")) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + input_stream.extend(BytesToByteList(b"flashinfo")) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + input_stream.extend(BytesToByteList(b"sysinfo")) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # The Check function should NOT have been called at all. + mock_check.assert_not_called() + + # The EC image should be assumed to be not enhanced. + self.assertFalse( + self.console.enhanced_ec, + "The image should be assumed to" " be NOT enhanced.", + ) + + # Reset the mocks. + mock_check.reset_mock() + self.console.oobm_queue.reset_mock() + + # 'interrogate always' should, like its name implies, interrogate always + # after each press of the enter key. This was the former way of doing + # interrogation. + cmd = b"interrogate always" + # Enter the OOBM prompt. + input_stream.extend(BytesToByteList(b"%")) + # Type the command + input_stream.extend(BytesToByteList(cmd)) + # Press enter. + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + input_stream = [] + expected_calls = [] + + # The OOBM queue should have been called with the command being put. + expected_calls.append(mock.call.put(cmd)) + self.console.oobm_queue.assert_has_calls(expected_calls) + + # Process the OOBM queue. + self.console.oobm_queue.get.side_effect = [cmd] + self.console.ProcessOOBMQueue() + + # The Check method should be called 3 times here. + mock_check.side_effect = [False, False, False] + + # Type out a few commands. + input_stream.extend(BytesToByteList(b"help list")) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + input_stream.extend(BytesToByteList(b"taskinfo")) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + input_stream.extend(BytesToByteList(b"hibdelay")) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # The Check method should have been called 3 times here. + expected_calls = [mock.call(), mock.call(), mock.call()] + mock_check.assert_has_calls(expected_calls) + + # The EC image should be assumed to be not enhanced. + self.assertFalse( + self.console.enhanced_ec, + "The image should be assumed to" " be NOT enhanced.", + ) + + # Now, let's try to assume that the image is enhanced while still disabling + # interrogation. + mock_check.reset_mock() + self.console.oobm_queue.reset_mock() + input_stream = [] + cmd = b"interrogate never enhanced" + # Enter the OOBM prompt. + input_stream.extend(BytesToByteList(b"%")) + # Type the command + input_stream.extend(BytesToByteList(cmd)) + # Press enter. + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + input_stream = [] + expected_calls = [] + + # The OOBM queue should have been called with the command being put. + expected_calls.append(mock.call.put(cmd)) + self.console.oobm_queue.assert_has_calls(expected_calls) + + # Process the OOBM queue. + self.console.oobm_queue.get.side_effect = [cmd] + self.console.ProcessOOBMQueue() + + # Type out a few commands. + input_stream.extend(BytesToByteList(b"chgstate")) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + input_stream.extend(BytesToByteList(b"hash")) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + input_stream.extend(BytesToByteList(b"sysjump rw")) + input_stream.append(console.ControlKey.CARRIAGE_RETURN) + + # Send the sequence out. + for byte in input_stream: + self.console.HandleChar(byte) + + # The check method should have never been called. + mock_check.assert_not_called() + + # The EC image should be assumed to be enhanced. + self.assertTrue( + self.console.enhanced_ec, "The image should be" " assumed to be enhanced." + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/util/ec3po/interpreter.py b/util/ec3po/interpreter.py index 4e151083bd..591603e038 100644 --- a/util/ec3po/interpreter.py +++ b/util/ec3po/interpreter.py @@ -25,443 +25,447 @@ import traceback import six - COMMAND_RETRIES = 3 # Number of attempts to retry a command. EC_MAX_READ = 1024 # Max bytes to read at a time from the EC. -EC_SYN = b'\xec' # Byte indicating EC interrogation. -EC_ACK = b'\xc0' # Byte representing correct EC response to interrogation. +EC_SYN = b"\xec" # Byte indicating EC interrogation. +EC_ACK = b"\xc0" # Byte representing correct EC response to interrogation. class LoggerAdapter(logging.LoggerAdapter): - """Class which provides a small adapter for the logger.""" + """Class which provides a small adapter for the logger.""" - def process(self, msg, kwargs): - """Prepends the served PTY to the beginning of the log message.""" - return '%s - %s' % (self.extra['pty'], msg), kwargs + def process(self, msg, kwargs): + """Prepends the served PTY to the beginning of the log message.""" + return "%s - %s" % (self.extra["pty"], msg), kwargs class Interpreter(object): - """Class which provides the interpretation layer between the EC and user. - - This class essentially performs all of the intepretation for the EC and the - user. It handles all of the automatic command retrying as well as the - formation of commands for EC images which support that. - - Attributes: - logger: A logger for this module. - ec_uart_pty: An opened file object to the raw EC UART PTY. - ec_uart_pty_name: A string containing the name of the raw EC UART PTY. - cmd_pipe: A socket.socket or multiprocessing.Connection object which - represents the Interpreter side of the command pipe. This must be a - bidirectional pipe. Commands and responses will utilize this pipe. - dbg_pipe: A socket.socket or multiprocessing.Connection object which - represents the Interpreter side of the debug pipe. This must be a - unidirectional pipe with write capabilities. EC debug output will utilize - this pipe. - cmd_retries: An integer representing the number of attempts the console - should retry commands if it receives an error. - log_level: An integer representing the numeric value of the log level. - inputs: A list of objects that the intpreter selects for reading. - Initially, these are the EC UART and the command pipe. - outputs: A list of objects that the interpreter selects for writing. - ec_cmd_queue: A FIFO queue used for sending commands down to the EC UART. - last_cmd: A string that represents the last command sent to the EC. If an - error is encountered, the interpreter will attempt to retry this command - up to COMMAND_RETRIES. - enhanced_ec: A boolean indicating if the EC image that we are currently - communicating with is enhanced or not. Enhanced EC images will support - packed commands and host commands over the UART. This defaults to False - and is changed depending on the result of an interrogation. - interrogating: A boolean indicating if we are in the middle of interrogating - the EC. - connected: A boolean indicating if the interpreter is actually connected to - the UART and listening. - """ - def __init__(self, ec_uart_pty, cmd_pipe, dbg_pipe, log_level=logging.INFO, - name=None): - """Intializes an Interpreter object with the provided args. + """Class which provides the interpretation layer between the EC and user. - Args: - ec_uart_pty: A string representing the EC UART to connect to. + This class essentially performs all of the intepretation for the EC and the + user. It handles all of the automatic command retrying as well as the + formation of commands for EC images which support that. + + Attributes: + logger: A logger for this module. + ec_uart_pty: An opened file object to the raw EC UART PTY. + ec_uart_pty_name: A string containing the name of the raw EC UART PTY. cmd_pipe: A socket.socket or multiprocessing.Connection object which represents the Interpreter side of the command pipe. This must be a bidirectional pipe. Commands and responses will utilize this pipe. dbg_pipe: A socket.socket or multiprocessing.Connection object which represents the Interpreter side of the debug pipe. This must be a - unidirectional pipe with write capabilities. EC debug output will - utilize this pipe. + unidirectional pipe with write capabilities. EC debug output will utilize + this pipe. cmd_retries: An integer representing the number of attempts the console should retry commands if it receives an error. - log_level: An optional integer representing the numeric value of the log - level. By default, the log level will be logging.INFO (20). - name: the console source name - """ - # Create a unique logger based on the interpreter name - interpreter_prefix = ('%s - ' % name) if name else '' - logger = logging.getLogger('%sEC3PO.Interpreter' % interpreter_prefix) - self.logger = LoggerAdapter(logger, {'pty': ec_uart_pty}) - # TODO(https://crbug.com/1162189): revist the 2 TODOs below - # TODO(https://bugs.python.org/issue27805, python3.7+): revert to ab+ - # TODO(https://bugs.python.org/issue20074): removing buffering=0 if/when - # that gets fixed, or keep two pty: one for reading and one for writing - self.ec_uart_pty = open(ec_uart_pty, 'r+b', buffering=0) - self.ec_uart_pty_name = ec_uart_pty - self.cmd_pipe = cmd_pipe - self.dbg_pipe = dbg_pipe - self.cmd_retries = COMMAND_RETRIES - self.log_level = log_level - self.inputs = [self.ec_uart_pty, self.cmd_pipe] - self.outputs = [] - self.ec_cmd_queue = six.moves.queue.Queue() - self.last_cmd = b'' - self.enhanced_ec = False - self.interrogating = False - self.connected = True - - def __str__(self): - """Show internal state of the Interpreter object. - - Returns: - A string that shows the values of the attributes. - """ - string = [] - string.append('%r' % self) - string.append('ec_uart_pty: %s' % self.ec_uart_pty) - string.append('cmd_pipe: %r' % self.cmd_pipe) - string.append('dbg_pipe: %r' % self.dbg_pipe) - string.append('cmd_retries: %d' % self.cmd_retries) - string.append('log_level: %d' % self.log_level) - string.append('inputs: %r' % self.inputs) - string.append('outputs: %r' % self.outputs) - string.append('ec_cmd_queue: %r' % self.ec_cmd_queue) - string.append('last_cmd: \'%s\'' % self.last_cmd) - string.append('enhanced_ec: %r' % self.enhanced_ec) - string.append('interrogating: %r' % self.interrogating) - return '\n'.join(string) - - def EnqueueCmd(self, command): - """Enqueue a command to be sent to the EC UART. - - Args: - command: A string which contains the command to be sent. + log_level: An integer representing the numeric value of the log level. + inputs: A list of objects that the intpreter selects for reading. + Initially, these are the EC UART and the command pipe. + outputs: A list of objects that the interpreter selects for writing. + ec_cmd_queue: A FIFO queue used for sending commands down to the EC UART. + last_cmd: A string that represents the last command sent to the EC. If an + error is encountered, the interpreter will attempt to retry this command + up to COMMAND_RETRIES. + enhanced_ec: A boolean indicating if the EC image that we are currently + communicating with is enhanced or not. Enhanced EC images will support + packed commands and host commands over the UART. This defaults to False + and is changed depending on the result of an interrogation. + interrogating: A boolean indicating if we are in the middle of interrogating + the EC. + connected: A boolean indicating if the interpreter is actually connected to + the UART and listening. """ - self.ec_cmd_queue.put(command) - self.logger.log(1, 'Commands now in queue: %d', self.ec_cmd_queue.qsize()) - # Add the EC UART as an output to be serviced. - if self.connected and self.ec_uart_pty not in self.outputs: - self.outputs.append(self.ec_uart_pty) - - def PackCommand(self, raw_cmd): - r"""Packs a command for use with error checking. - - For error checking, we pack console commands in a particular format. The - format is as follows: - - &&[x][x][x][x]&{cmd}\n\n - ^ ^ ^^ ^^ ^ ^-- 2 newlines. - | | || || |-- the raw console command. - | | || ||-- 1 ampersand. - | | ||____|--- 2 hex digits representing the CRC8 of cmd. - | |____|-- 2 hex digits reprsenting the length of cmd. - |-- 2 ampersands - - Args: - raw_cmd: A pre-packed string which contains the raw command. - - Returns: - A string which contains the packed command. - """ - # Don't pack a single carriage return. - if raw_cmd != b'\r': - # The command format is as follows. - # &&[x][x][x][x]&{cmd}\n\n - packed_cmd = [] - packed_cmd.append(b'&&') - # The first pair of hex digits are the length of the command. - packed_cmd.append(b'%02x' % len(raw_cmd)) - # Then the CRC8 of cmd. - packed_cmd.append(b'%02x' % Crc8(raw_cmd)) - packed_cmd.append(b'&') - # Now, the raw command followed by 2 newlines. - packed_cmd.append(raw_cmd) - packed_cmd.append(b'\n\n') - return b''.join(packed_cmd) - else: - return raw_cmd - - def ProcessCommand(self, command): - """Captures the input determines what actions to take. - - Args: - command: A string representing the command sent by the user. - """ - if command == b'disconnect': - if self.connected: - self.logger.debug('UART disconnect request.') - # Drop all pending commands if any. - while not self.ec_cmd_queue.empty(): - c = self.ec_cmd_queue.get() - self.logger.debug('dropped: \'%s\'', c) - if self.enhanced_ec: - # Reset retry state. - self.cmd_retries = COMMAND_RETRIES - self.last_cmd = b'' - # Get the UART that the interpreter is attached to. - fileobj = self.ec_uart_pty - self.logger.debug('fileobj: %r', fileobj) - # Remove the descriptor from the inputs and outputs. - self.inputs.remove(fileobj) - if fileobj in self.outputs: - self.outputs.remove(fileobj) - self.logger.debug('Removed fileobj. Remaining inputs: %r', self.inputs) - # Close the file. - fileobj.close() - # Mark the interpreter as disconnected now. - self.connected = False - self.logger.debug('Disconnected from %s.', self.ec_uart_pty_name) - return - - elif command == b'reconnect': - if not self.connected: - self.logger.debug('UART reconnect request.') - # Reopen the PTY. + def __init__( + self, ec_uart_pty, cmd_pipe, dbg_pipe, log_level=logging.INFO, name=None + ): + """Intializes an Interpreter object with the provided args. + + Args: + ec_uart_pty: A string representing the EC UART to connect to. + cmd_pipe: A socket.socket or multiprocessing.Connection object which + represents the Interpreter side of the command pipe. This must be a + bidirectional pipe. Commands and responses will utilize this pipe. + dbg_pipe: A socket.socket or multiprocessing.Connection object which + represents the Interpreter side of the debug pipe. This must be a + unidirectional pipe with write capabilities. EC debug output will + utilize this pipe. + cmd_retries: An integer representing the number of attempts the console + should retry commands if it receives an error. + log_level: An optional integer representing the numeric value of the log + level. By default, the log level will be logging.INFO (20). + name: the console source name + """ + # Create a unique logger based on the interpreter name + interpreter_prefix = ("%s - " % name) if name else "" + logger = logging.getLogger("%sEC3PO.Interpreter" % interpreter_prefix) + self.logger = LoggerAdapter(logger, {"pty": ec_uart_pty}) + # TODO(https://crbug.com/1162189): revist the 2 TODOs below # TODO(https://bugs.python.org/issue27805, python3.7+): revert to ab+ # TODO(https://bugs.python.org/issue20074): removing buffering=0 if/when # that gets fixed, or keep two pty: one for reading and one for writing - fileobj = open(self.ec_uart_pty_name, 'r+b', buffering=0) - self.logger.debug('fileobj: %r', fileobj) - self.ec_uart_pty = fileobj - # Add the descriptor to the inputs. - self.inputs.append(fileobj) - self.logger.debug('fileobj added. curr inputs: %r', self.inputs) - # Mark the interpreter as connected now. - self.connected = True - self.logger.debug('Connected to %s.', self.ec_uart_pty_name) - return - - elif command.startswith(b'enhanced'): - self.enhanced_ec = command.split(b' ')[1] == b'True' - return - - # Ignore any other commands while in the disconnected state. - self.logger.log(1, 'command: \'%s\'', command) - if not self.connected: - self.logger.debug('Ignoring command because currently disconnected.') - return - - # Remove leading and trailing spaces only if this is an enhanced EC image. - # For non-enhanced EC images, commands will be single characters at a time - # and can be spaces. - if self.enhanced_ec: - command = command.strip(b' ') - - # There's nothing to do if the command is empty. - if len(command) == 0: - return - - # Handle log level change requests. - if command.startswith(b'loglevel'): - self.logger.debug('Log level change request.') - new_log_level = int(command.split(b' ')[1]) - self.logger.logger.setLevel(new_log_level) - self.logger.info('Log level changed to %d.', new_log_level) - return - - # Check for interrogation command. - if command == EC_SYN: - # User is requesting interrogation. Send SYN as is. - self.logger.debug('User requesting interrogation.') - self.interrogating = True - # Assume the EC isn't enhanced until we get a response. - self.enhanced_ec = False - elif self.enhanced_ec: - # Enhanced EC images require the plaintext commands to be packed. - command = self.PackCommand(command) - # TODO(aaboagye): Make a dict of commands and keys and eventually, - # handle partial matching based on unique prefixes. - - self.EnqueueCmd(command) - - def HandleCmdRetries(self): - """Attempts to retry commands if possible.""" - if self.cmd_retries > 0: - # The EC encountered an error. We'll have to retry again. - self.logger.warning('Retrying command...') - self.cmd_retries -= 1 - self.logger.warning('Retries remaining: %d', self.cmd_retries) - # Retry the command and add the EC UART to the writers again. - self.EnqueueCmd(self.last_cmd) - self.outputs.append(self.ec_uart_pty) - else: - # We're out of retries, so just give up. - self.logger.error('Command failed. No retries left.') - # Clear the command in progress. - self.last_cmd = b'' - # Reset the retry count. - self.cmd_retries = COMMAND_RETRIES - - def SendCmdToEC(self): - """Sends a command to the EC.""" - # If we're retrying a command, just try to send it again. - if self.cmd_retries < COMMAND_RETRIES: - cmd = self.last_cmd - else: - # If we're not retrying, we should not be writing to the EC if we have no - # items in our command queue. - assert not self.ec_cmd_queue.empty() - # Get the command to send. - cmd = self.ec_cmd_queue.get() - - # Send the command. - self.ec_uart_pty.write(cmd) - self.ec_uart_pty.flush() - self.logger.log(1, 'Sent command to EC.') - - if self.enhanced_ec and cmd != EC_SYN: - # Now, that we've sent the command, store the current command as the last - # command sent. If we encounter an error string, we will attempt to retry - # this command. - if cmd != self.last_cmd: - self.last_cmd = cmd - # Reset the retry count. + self.ec_uart_pty = open(ec_uart_pty, "r+b", buffering=0) + self.ec_uart_pty_name = ec_uart_pty + self.cmd_pipe = cmd_pipe + self.dbg_pipe = dbg_pipe self.cmd_retries = COMMAND_RETRIES + self.log_level = log_level + self.inputs = [self.ec_uart_pty, self.cmd_pipe] + self.outputs = [] + self.ec_cmd_queue = six.moves.queue.Queue() + self.last_cmd = b"" + self.enhanced_ec = False + self.interrogating = False + self.connected = True - # If no command is pending to be sent, then we can remove the EC UART from - # writers. Might need better checking for command retry logic in here. - if self.ec_cmd_queue.empty(): - # Remove the EC UART from the writers while we wait for a response. - self.logger.debug('Removing EC UART from writers.') - self.outputs.remove(self.ec_uart_pty) - - def HandleECData(self): - """Handle any debug prints from the EC.""" - self.logger.log(1, 'EC has data') - # Read what the EC sent us. - data = os.read(self.ec_uart_pty.fileno(), EC_MAX_READ) - self.logger.log(1, 'got: \'%s\'', binascii.hexlify(data)) - if b'&E' in data and self.enhanced_ec: - # We received an error, so we should retry it if possible. - self.logger.warning('Error string found in data.') - self.HandleCmdRetries() - return - - # If we were interrogating, check the response and update our knowledge - # of the current EC image. - if self.interrogating: - self.enhanced_ec = data == EC_ACK - if self.enhanced_ec: - self.logger.debug('The current EC image seems enhanced.') - else: - self.logger.debug('The current EC image does NOT seem enhanced.') - # Done interrogating. - self.interrogating = False - # For now, just forward everything the EC sends us. - self.logger.log(1, 'Forwarding to user...') - self.dbg_pipe.send(data) - - def HandleUserData(self): - """Handle any incoming commands from the user. - - Raises: - EOFError: Allowed to propagate through from self.cmd_pipe.recv(). - """ - self.logger.log(1, 'Command data available. Begin processing.') - data = self.cmd_pipe.recv() - # Process the command. - self.ProcessCommand(data) + def __str__(self): + """Show internal state of the Interpreter object. + + Returns: + A string that shows the values of the attributes. + """ + string = [] + string.append("%r" % self) + string.append("ec_uart_pty: %s" % self.ec_uart_pty) + string.append("cmd_pipe: %r" % self.cmd_pipe) + string.append("dbg_pipe: %r" % self.dbg_pipe) + string.append("cmd_retries: %d" % self.cmd_retries) + string.append("log_level: %d" % self.log_level) + string.append("inputs: %r" % self.inputs) + string.append("outputs: %r" % self.outputs) + string.append("ec_cmd_queue: %r" % self.ec_cmd_queue) + string.append("last_cmd: '%s'" % self.last_cmd) + string.append("enhanced_ec: %r" % self.enhanced_ec) + string.append("interrogating: %r" % self.interrogating) + return "\n".join(string) + + def EnqueueCmd(self, command): + """Enqueue a command to be sent to the EC UART. + + Args: + command: A string which contains the command to be sent. + """ + self.ec_cmd_queue.put(command) + self.logger.log(1, "Commands now in queue: %d", self.ec_cmd_queue.qsize()) + + # Add the EC UART as an output to be serviced. + if self.connected and self.ec_uart_pty not in self.outputs: + self.outputs.append(self.ec_uart_pty) + + def PackCommand(self, raw_cmd): + r"""Packs a command for use with error checking. + + For error checking, we pack console commands in a particular format. The + format is as follows: + + &&[x][x][x][x]&{cmd}\n\n + ^ ^ ^^ ^^ ^ ^-- 2 newlines. + | | || || |-- the raw console command. + | | || ||-- 1 ampersand. + | | ||____|--- 2 hex digits representing the CRC8 of cmd. + | |____|-- 2 hex digits reprsenting the length of cmd. + |-- 2 ampersands + + Args: + raw_cmd: A pre-packed string which contains the raw command. + + Returns: + A string which contains the packed command. + """ + # Don't pack a single carriage return. + if raw_cmd != b"\r": + # The command format is as follows. + # &&[x][x][x][x]&{cmd}\n\n + packed_cmd = [] + packed_cmd.append(b"&&") + # The first pair of hex digits are the length of the command. + packed_cmd.append(b"%02x" % len(raw_cmd)) + # Then the CRC8 of cmd. + packed_cmd.append(b"%02x" % Crc8(raw_cmd)) + packed_cmd.append(b"&") + # Now, the raw command followed by 2 newlines. + packed_cmd.append(raw_cmd) + packed_cmd.append(b"\n\n") + return b"".join(packed_cmd) + else: + return raw_cmd + + def ProcessCommand(self, command): + """Captures the input determines what actions to take. + + Args: + command: A string representing the command sent by the user. + """ + if command == b"disconnect": + if self.connected: + self.logger.debug("UART disconnect request.") + # Drop all pending commands if any. + while not self.ec_cmd_queue.empty(): + c = self.ec_cmd_queue.get() + self.logger.debug("dropped: '%s'", c) + if self.enhanced_ec: + # Reset retry state. + self.cmd_retries = COMMAND_RETRIES + self.last_cmd = b"" + # Get the UART that the interpreter is attached to. + fileobj = self.ec_uart_pty + self.logger.debug("fileobj: %r", fileobj) + # Remove the descriptor from the inputs and outputs. + self.inputs.remove(fileobj) + if fileobj in self.outputs: + self.outputs.remove(fileobj) + self.logger.debug("Removed fileobj. Remaining inputs: %r", self.inputs) + # Close the file. + fileobj.close() + # Mark the interpreter as disconnected now. + self.connected = False + self.logger.debug("Disconnected from %s.", self.ec_uart_pty_name) + return + + elif command == b"reconnect": + if not self.connected: + self.logger.debug("UART reconnect request.") + # Reopen the PTY. + # TODO(https://bugs.python.org/issue27805, python3.7+): revert to ab+ + # TODO(https://bugs.python.org/issue20074): removing buffering=0 if/when + # that gets fixed, or keep two pty: one for reading and one for writing + fileobj = open(self.ec_uart_pty_name, "r+b", buffering=0) + self.logger.debug("fileobj: %r", fileobj) + self.ec_uart_pty = fileobj + # Add the descriptor to the inputs. + self.inputs.append(fileobj) + self.logger.debug("fileobj added. curr inputs: %r", self.inputs) + # Mark the interpreter as connected now. + self.connected = True + self.logger.debug("Connected to %s.", self.ec_uart_pty_name) + return + + elif command.startswith(b"enhanced"): + self.enhanced_ec = command.split(b" ")[1] == b"True" + return + + # Ignore any other commands while in the disconnected state. + self.logger.log(1, "command: '%s'", command) + if not self.connected: + self.logger.debug("Ignoring command because currently disconnected.") + return + + # Remove leading and trailing spaces only if this is an enhanced EC image. + # For non-enhanced EC images, commands will be single characters at a time + # and can be spaces. + if self.enhanced_ec: + command = command.strip(b" ") + + # There's nothing to do if the command is empty. + if len(command) == 0: + return + + # Handle log level change requests. + if command.startswith(b"loglevel"): + self.logger.debug("Log level change request.") + new_log_level = int(command.split(b" ")[1]) + self.logger.logger.setLevel(new_log_level) + self.logger.info("Log level changed to %d.", new_log_level) + return + + # Check for interrogation command. + if command == EC_SYN: + # User is requesting interrogation. Send SYN as is. + self.logger.debug("User requesting interrogation.") + self.interrogating = True + # Assume the EC isn't enhanced until we get a response. + self.enhanced_ec = False + elif self.enhanced_ec: + # Enhanced EC images require the plaintext commands to be packed. + command = self.PackCommand(command) + # TODO(aaboagye): Make a dict of commands and keys and eventually, + # handle partial matching based on unique prefixes. + + self.EnqueueCmd(command) + + def HandleCmdRetries(self): + """Attempts to retry commands if possible.""" + if self.cmd_retries > 0: + # The EC encountered an error. We'll have to retry again. + self.logger.warning("Retrying command...") + self.cmd_retries -= 1 + self.logger.warning("Retries remaining: %d", self.cmd_retries) + # Retry the command and add the EC UART to the writers again. + self.EnqueueCmd(self.last_cmd) + self.outputs.append(self.ec_uart_pty) + else: + # We're out of retries, so just give up. + self.logger.error("Command failed. No retries left.") + # Clear the command in progress. + self.last_cmd = b"" + # Reset the retry count. + self.cmd_retries = COMMAND_RETRIES + + def SendCmdToEC(self): + """Sends a command to the EC.""" + # If we're retrying a command, just try to send it again. + if self.cmd_retries < COMMAND_RETRIES: + cmd = self.last_cmd + else: + # If we're not retrying, we should not be writing to the EC if we have no + # items in our command queue. + assert not self.ec_cmd_queue.empty() + # Get the command to send. + cmd = self.ec_cmd_queue.get() + + # Send the command. + self.ec_uart_pty.write(cmd) + self.ec_uart_pty.flush() + self.logger.log(1, "Sent command to EC.") + + if self.enhanced_ec and cmd != EC_SYN: + # Now, that we've sent the command, store the current command as the last + # command sent. If we encounter an error string, we will attempt to retry + # this command. + if cmd != self.last_cmd: + self.last_cmd = cmd + # Reset the retry count. + self.cmd_retries = COMMAND_RETRIES + + # If no command is pending to be sent, then we can remove the EC UART from + # writers. Might need better checking for command retry logic in here. + if self.ec_cmd_queue.empty(): + # Remove the EC UART from the writers while we wait for a response. + self.logger.debug("Removing EC UART from writers.") + self.outputs.remove(self.ec_uart_pty) + + def HandleECData(self): + """Handle any debug prints from the EC.""" + self.logger.log(1, "EC has data") + # Read what the EC sent us. + data = os.read(self.ec_uart_pty.fileno(), EC_MAX_READ) + self.logger.log(1, "got: '%s'", binascii.hexlify(data)) + if b"&E" in data and self.enhanced_ec: + # We received an error, so we should retry it if possible. + self.logger.warning("Error string found in data.") + self.HandleCmdRetries() + return + + # If we were interrogating, check the response and update our knowledge + # of the current EC image. + if self.interrogating: + self.enhanced_ec = data == EC_ACK + if self.enhanced_ec: + self.logger.debug("The current EC image seems enhanced.") + else: + self.logger.debug("The current EC image does NOT seem enhanced.") + # Done interrogating. + self.interrogating = False + # For now, just forward everything the EC sends us. + self.logger.log(1, "Forwarding to user...") + self.dbg_pipe.send(data) + + def HandleUserData(self): + """Handle any incoming commands from the user. + + Raises: + EOFError: Allowed to propagate through from self.cmd_pipe.recv(). + """ + self.logger.log(1, "Command data available. Begin processing.") + data = self.cmd_pipe.recv() + # Process the command. + self.ProcessCommand(data) def Crc8(data): - """Calculates the CRC8 of data. + """Calculates the CRC8 of data. - The generator polynomial used is: x^8 + x^2 + x + 1. - This is the same implementation that is used in the EC. + The generator polynomial used is: x^8 + x^2 + x + 1. + This is the same implementation that is used in the EC. - Args: - data: A string of data that we wish to calculate the CRC8 on. + Args: + data: A string of data that we wish to calculate the CRC8 on. - Returns: - crc >> 8: An integer representing the CRC8 value. - """ - crc = 0 - for byte in six.iterbytes(data): - crc ^= (byte << 8) - for _ in range(8): - if crc & 0x8000: - crc ^= (0x1070 << 3) - crc <<= 1 - return crc >> 8 + Returns: + crc >> 8: An integer representing the CRC8 value. + """ + crc = 0 + for byte in six.iterbytes(data): + crc ^= byte << 8 + for _ in range(8): + if crc & 0x8000: + crc ^= 0x1070 << 3 + crc <<= 1 + return crc >> 8 def StartLoop(interp, shutdown_pipe=None): - """Starts an infinite loop of servicing the user and the EC. - - StartLoop checks to see if there are any commands to process, processing them - if any, and forwards EC output to the user. - - When sending a command to the EC, we send the command once and check the - response to see if the EC encountered an error when receiving the command. An - error condition is reported to the interpreter by a string with at least one - '&' and 'E'. The full string is actually '&&EE', however it's possible that - the leading ampersand or trailing 'E' could be dropped. If an error is - encountered, the interpreter will retry up to the amount configured. - - Args: - interp: An Interpreter object that has been properly initialised. - shutdown_pipe: A file object for a pipe or equivalent that becomes readable - (not blocked) to indicate that the loop should exit. Can be None to never - exit the loop. - """ - try: - # This is used instead of "break" to avoid exiting the loop in the middle of - # an iteration. - continue_looping = True - - while continue_looping: - # The inputs list is created anew in each loop iteration because the - # Interpreter class sometimes modifies the interp.inputs list. - if shutdown_pipe is None: - inputs = interp.inputs - else: - inputs = list(interp.inputs) - inputs.append(shutdown_pipe) - - readable, writeable, _ = select.select(inputs, interp.outputs, []) - - for obj in readable: - # Handle any debug prints from the EC. - if obj is interp.ec_uart_pty: - interp.HandleECData() - - # Handle any commands from the user. - elif obj is interp.cmd_pipe: - try: - interp.HandleUserData() - except EOFError: - interp.logger.debug( - 'ec3po interpreter received EOF from cmd_pipe in ' - 'HandleUserData()') - continue_looping = False - - elif obj is shutdown_pipe: - interp.logger.debug( - 'ec3po interpreter received shutdown pipe unblocked notification') - continue_looping = False - - for obj in writeable: - # Send a command to the EC. - if obj is interp.ec_uart_pty: - interp.SendCmdToEC() - - except KeyboardInterrupt: - pass - - finally: - interp.cmd_pipe.close() - interp.dbg_pipe.close() - interp.ec_uart_pty.close() - if shutdown_pipe is not None: - shutdown_pipe.close() - interp.logger.debug('Exit ec3po interpreter loop for %s', - interp.ec_uart_pty_name) + """Starts an infinite loop of servicing the user and the EC. + + StartLoop checks to see if there are any commands to process, processing them + if any, and forwards EC output to the user. + + When sending a command to the EC, we send the command once and check the + response to see if the EC encountered an error when receiving the command. An + error condition is reported to the interpreter by a string with at least one + '&' and 'E'. The full string is actually '&&EE', however it's possible that + the leading ampersand or trailing 'E' could be dropped. If an error is + encountered, the interpreter will retry up to the amount configured. + + Args: + interp: An Interpreter object that has been properly initialised. + shutdown_pipe: A file object for a pipe or equivalent that becomes readable + (not blocked) to indicate that the loop should exit. Can be None to never + exit the loop. + """ + try: + # This is used instead of "break" to avoid exiting the loop in the middle of + # an iteration. + continue_looping = True + + while continue_looping: + # The inputs list is created anew in each loop iteration because the + # Interpreter class sometimes modifies the interp.inputs list. + if shutdown_pipe is None: + inputs = interp.inputs + else: + inputs = list(interp.inputs) + inputs.append(shutdown_pipe) + + readable, writeable, _ = select.select(inputs, interp.outputs, []) + + for obj in readable: + # Handle any debug prints from the EC. + if obj is interp.ec_uart_pty: + interp.HandleECData() + + # Handle any commands from the user. + elif obj is interp.cmd_pipe: + try: + interp.HandleUserData() + except EOFError: + interp.logger.debug( + "ec3po interpreter received EOF from cmd_pipe in " + "HandleUserData()" + ) + continue_looping = False + + elif obj is shutdown_pipe: + interp.logger.debug( + "ec3po interpreter received shutdown pipe unblocked notification" + ) + continue_looping = False + + for obj in writeable: + # Send a command to the EC. + if obj is interp.ec_uart_pty: + interp.SendCmdToEC() + + except KeyboardInterrupt: + pass + + finally: + interp.cmd_pipe.close() + interp.dbg_pipe.close() + interp.ec_uart_pty.close() + if shutdown_pipe is not None: + shutdown_pipe.close() + interp.logger.debug( + "Exit ec3po interpreter loop for %s", interp.ec_uart_pty_name + ) diff --git a/util/ec3po/interpreter_unittest.py b/util/ec3po/interpreter_unittest.py index fe4d43c351..509b90f667 100755 --- a/util/ec3po/interpreter_unittest.py +++ b/util/ec3po/interpreter_unittest.py @@ -10,371 +10,389 @@ from __future__ import print_function import logging -import mock import tempfile import unittest +import mock import six - -from ec3po import interpreter -from ec3po import threadproc_shim +from ec3po import interpreter, threadproc_shim def GetBuiltins(func): - if six.PY2: - return '__builtin__.' + func - return 'builtins.' + func + if six.PY2: + return "__builtin__." + func + return "builtins." + func class TestEnhancedECBehaviour(unittest.TestCase): - """Test case to verify all enhanced EC interpretation tasks.""" - def setUp(self): - """Setup the test harness.""" - # Setup logging with a timestamp, the module, and the log level. - logging.basicConfig(level=logging.DEBUG, - format=('%(asctime)s - %(module)s -' - ' %(levelname)s - %(message)s')) - - # Create a tempfile that would represent the EC UART PTY. - self.tempfile = tempfile.NamedTemporaryFile() - - # Create the pipes that the interpreter will use. - self.cmd_pipe_user, self.cmd_pipe_itpr = threadproc_shim.Pipe() - self.dbg_pipe_user, self.dbg_pipe_itpr = threadproc_shim.Pipe(duplex=False) - - # Mock the open() function so we can inspect reads/writes to the EC. - self.ec_uart_pty = mock.mock_open() - - with mock.patch(GetBuiltins('open'), self.ec_uart_pty): - # Create an interpreter. - self.itpr = interpreter.Interpreter(self.tempfile.name, - self.cmd_pipe_itpr, - self.dbg_pipe_itpr, - log_level=logging.DEBUG, - name="EC") - - @mock.patch('ec3po.interpreter.os') - def test_HandlingCommandsThatProduceNoOutput(self, mock_os): - """Verify that the Interpreter correctly handles non-output commands. - - Args: - mock_os: MagicMock object replacing the 'os' module for this test - case. - """ - # The interpreter init should open the EC UART PTY. - expected_ec_calls = [mock.call(self.tempfile.name, 'r+b', buffering=0)] - # Have a command come in the command pipe. The first command will be an - # interrogation to determine if the EC is enhanced or not. - self.cmd_pipe_user.send(interpreter.EC_SYN) - self.itpr.HandleUserData() - # At this point, the command should be queued up waiting to be sent, so - # let's actually send it to the EC. - self.itpr.SendCmdToEC() - expected_ec_calls.extend([mock.call().write(interpreter.EC_SYN), - mock.call().flush()]) - # Now, assume that the EC sends only 1 response back of EC_ACK. - mock_os.read.side_effect = [interpreter.EC_ACK] - # When reading the EC, the interpreter will call file.fileno() to pass to - # os.read(). - expected_ec_calls.append(mock.call().fileno()) - # Simulate the response. - self.itpr.HandleECData() - - # Now that the interrogation was complete, it's time to send down the real - # command. - test_cmd = b'chan save' - # Send the test command down the pipe. - self.cmd_pipe_user.send(test_cmd) - self.itpr.HandleUserData() - self.itpr.SendCmdToEC() - # Since the EC image is enhanced, we should have sent a packed command. - expected_ec_calls.append(mock.call().write(self.itpr.PackCommand(test_cmd))) - expected_ec_calls.append(mock.call().flush()) - - # Now that the first command was sent, we should send another command which - # produces no output. The console would send another interrogation. - self.cmd_pipe_user.send(interpreter.EC_SYN) - self.itpr.HandleUserData() - self.itpr.SendCmdToEC() - expected_ec_calls.extend([mock.call().write(interpreter.EC_SYN), - mock.call().flush()]) - # Again, assume that the EC sends only 1 response back of EC_ACK. - mock_os.read.side_effect = [interpreter.EC_ACK] - # When reading the EC, the interpreter will call file.fileno() to pass to - # os.read(). - expected_ec_calls.append(mock.call().fileno()) - # Simulate the response. - self.itpr.HandleECData() - - # Now send the second test command. - test_cmd = b'chan 0' - self.cmd_pipe_user.send(test_cmd) - self.itpr.HandleUserData() - self.itpr.SendCmdToEC() - # Since the EC image is enhanced, we should have sent a packed command. - expected_ec_calls.append(mock.call().write(self.itpr.PackCommand(test_cmd))) - expected_ec_calls.append(mock.call().flush()) - - # Finally, verify that the appropriate writes were actually sent to the EC. - self.ec_uart_pty.assert_has_calls(expected_ec_calls) - - @mock.patch('ec3po.interpreter.os') - def test_CommandRetryingOnError(self, mock_os): - """Verify that commands are retried if an error is encountered. - - Args: - mock_os: MagicMock object replacing the 'os' module for this test - case. - """ - # The interpreter init should open the EC UART PTY. - expected_ec_calls = [mock.call(self.tempfile.name, 'r+b', buffering=0)] - # Have a command come in the command pipe. The first command will be an - # interrogation to determine if the EC is enhanced or not. - self.cmd_pipe_user.send(interpreter.EC_SYN) - self.itpr.HandleUserData() - # At this point, the command should be queued up waiting to be sent, so - # let's actually send it to the EC. - self.itpr.SendCmdToEC() - expected_ec_calls.extend([mock.call().write(interpreter.EC_SYN), - mock.call().flush()]) - # Now, assume that the EC sends only 1 response back of EC_ACK. - mock_os.read.side_effect = [interpreter.EC_ACK] - # When reading the EC, the interpreter will call file.fileno() to pass to - # os.read(). - expected_ec_calls.append(mock.call().fileno()) - # Simulate the response. - self.itpr.HandleECData() - - # Let's send a command that is received on the EC-side with an error. - test_cmd = b'accelinfo' - self.cmd_pipe_user.send(test_cmd) - self.itpr.HandleUserData() - self.itpr.SendCmdToEC() - packed_cmd = self.itpr.PackCommand(test_cmd) - expected_ec_calls.extend([mock.call().write(packed_cmd), - mock.call().flush()]) - # Have the EC return the error string twice. - mock_os.read.side_effect = [b'&&EE', b'&&EE'] - for i in range(2): - # When reading the EC, the interpreter will call file.fileno() to pass to - # os.read(). - expected_ec_calls.append(mock.call().fileno()) - # Simulate the response. - self.itpr.HandleECData() - - # Since an error was received, the EC should attempt to retry the command. - expected_ec_calls.extend([mock.call().write(packed_cmd), - mock.call().flush()]) - # Verify that the retry count was decremented. - self.assertEqual(interpreter.COMMAND_RETRIES-i-1, self.itpr.cmd_retries, - 'Unexpected cmd_remaining count.') - # Actually retry the command. - self.itpr.SendCmdToEC() - - # Now assume that the last one goes through with no trouble. - expected_ec_calls.extend([mock.call().write(packed_cmd), - mock.call().flush()]) - self.itpr.SendCmdToEC() - - # Verify all the calls. - self.ec_uart_pty.assert_has_calls(expected_ec_calls) - - def test_PackCommandsForEnhancedEC(self): - """Verify that the interpreter packs commands for enhanced EC images.""" - # Assume current EC image is enhanced. - self.itpr.enhanced_ec = True - # Receive a command from the user. - test_cmd = b'gettime' - self.cmd_pipe_user.send(test_cmd) - # Mock out PackCommand to see if it was called. - self.itpr.PackCommand = mock.MagicMock() - # Have the interpreter handle the command. - self.itpr.HandleUserData() - # Verify that PackCommand() was called. - self.itpr.PackCommand.assert_called_once_with(test_cmd) - - def test_DontPackCommandsForNonEnhancedEC(self): - """Verify the interpreter doesn't pack commands for non-enhanced images.""" - # Assume current EC image is not enhanced. - self.itpr.enhanced_ec = False - # Receive a command from the user. - test_cmd = b'gettime' - self.cmd_pipe_user.send(test_cmd) - # Mock out PackCommand to see if it was called. - self.itpr.PackCommand = mock.MagicMock() - # Have the interpreter handle the command. - self.itpr.HandleUserData() - # Verify that PackCommand() was called. - self.itpr.PackCommand.assert_not_called() - - @mock.patch('ec3po.interpreter.os') - def test_KeepingTrackOfInterrogation(self, mock_os): - """Verify that the interpreter can track the state of the interrogation. - - Args: - mock_os: MagicMock object replacing the 'os' module. for this test - case. - """ - # Upon init, the interpreter should assume that the current EC image is not - # enhanced. - self.assertFalse(self.itpr.enhanced_ec, msg=('State of enhanced_ec upon' - ' init is not False.')) - - # Assume an interrogation request comes in from the user. - self.cmd_pipe_user.send(interpreter.EC_SYN) - self.itpr.HandleUserData() - - # Verify the state is now within an interrogation. - self.assertTrue(self.itpr.interrogating, 'interrogating should be True') - # The state of enhanced_ec should not be changed yet because we haven't - # received a valid response yet. - self.assertFalse(self.itpr.enhanced_ec, msg=('State of enhanced_ec is ' - 'not False.')) - - # Assume that the EC responds with an EC_ACK. - mock_os.read.side_effect = [interpreter.EC_ACK] - self.itpr.HandleECData() - - # Now, the interrogation should be complete and we should know that the - # current EC image is enhanced. - self.assertFalse(self.itpr.interrogating, msg=('interrogating should be ' - 'False')) - self.assertTrue(self.itpr.enhanced_ec, msg='enhanced_ec sholud be True') - - # Now let's perform another interrogation, but pretend that the EC ignores - # it. - self.cmd_pipe_user.send(interpreter.EC_SYN) - self.itpr.HandleUserData() - - # Verify interrogating state. - self.assertTrue(self.itpr.interrogating, 'interrogating sholud be True') - # We should assume that the image is not enhanced until we get the valid - # response. - self.assertFalse(self.itpr.enhanced_ec, 'enhanced_ec should be False now.') - - # Let's pretend that we get a random debug print. This should clear the - # interrogating flag. - mock_os.read.side_effect = [b'[1660.593076 HC 0x103]'] - self.itpr.HandleECData() - - # Verify that interrogating flag is cleared and enhanced_ec is still False. - self.assertFalse(self.itpr.interrogating, 'interrogating should be False.') - self.assertFalse(self.itpr.enhanced_ec, - 'enhanced_ec should still be False.') + """Test case to verify all enhanced EC interpretation tasks.""" + + def setUp(self): + """Setup the test harness.""" + # Setup logging with a timestamp, the module, and the log level. + logging.basicConfig( + level=logging.DEBUG, + format=("%(asctime)s - %(module)s -" " %(levelname)s - %(message)s"), + ) + + # Create a tempfile that would represent the EC UART PTY. + self.tempfile = tempfile.NamedTemporaryFile() + + # Create the pipes that the interpreter will use. + self.cmd_pipe_user, self.cmd_pipe_itpr = threadproc_shim.Pipe() + self.dbg_pipe_user, self.dbg_pipe_itpr = threadproc_shim.Pipe(duplex=False) + + # Mock the open() function so we can inspect reads/writes to the EC. + self.ec_uart_pty = mock.mock_open() + + with mock.patch(GetBuiltins("open"), self.ec_uart_pty): + # Create an interpreter. + self.itpr = interpreter.Interpreter( + self.tempfile.name, + self.cmd_pipe_itpr, + self.dbg_pipe_itpr, + log_level=logging.DEBUG, + name="EC", + ) + + @mock.patch("ec3po.interpreter.os") + def test_HandlingCommandsThatProduceNoOutput(self, mock_os): + """Verify that the Interpreter correctly handles non-output commands. + + Args: + mock_os: MagicMock object replacing the 'os' module for this test + case. + """ + # The interpreter init should open the EC UART PTY. + expected_ec_calls = [mock.call(self.tempfile.name, "r+b", buffering=0)] + # Have a command come in the command pipe. The first command will be an + # interrogation to determine if the EC is enhanced or not. + self.cmd_pipe_user.send(interpreter.EC_SYN) + self.itpr.HandleUserData() + # At this point, the command should be queued up waiting to be sent, so + # let's actually send it to the EC. + self.itpr.SendCmdToEC() + expected_ec_calls.extend( + [mock.call().write(interpreter.EC_SYN), mock.call().flush()] + ) + # Now, assume that the EC sends only 1 response back of EC_ACK. + mock_os.read.side_effect = [interpreter.EC_ACK] + # When reading the EC, the interpreter will call file.fileno() to pass to + # os.read(). + expected_ec_calls.append(mock.call().fileno()) + # Simulate the response. + self.itpr.HandleECData() + + # Now that the interrogation was complete, it's time to send down the real + # command. + test_cmd = b"chan save" + # Send the test command down the pipe. + self.cmd_pipe_user.send(test_cmd) + self.itpr.HandleUserData() + self.itpr.SendCmdToEC() + # Since the EC image is enhanced, we should have sent a packed command. + expected_ec_calls.append(mock.call().write(self.itpr.PackCommand(test_cmd))) + expected_ec_calls.append(mock.call().flush()) + + # Now that the first command was sent, we should send another command which + # produces no output. The console would send another interrogation. + self.cmd_pipe_user.send(interpreter.EC_SYN) + self.itpr.HandleUserData() + self.itpr.SendCmdToEC() + expected_ec_calls.extend( + [mock.call().write(interpreter.EC_SYN), mock.call().flush()] + ) + # Again, assume that the EC sends only 1 response back of EC_ACK. + mock_os.read.side_effect = [interpreter.EC_ACK] + # When reading the EC, the interpreter will call file.fileno() to pass to + # os.read(). + expected_ec_calls.append(mock.call().fileno()) + # Simulate the response. + self.itpr.HandleECData() + + # Now send the second test command. + test_cmd = b"chan 0" + self.cmd_pipe_user.send(test_cmd) + self.itpr.HandleUserData() + self.itpr.SendCmdToEC() + # Since the EC image is enhanced, we should have sent a packed command. + expected_ec_calls.append(mock.call().write(self.itpr.PackCommand(test_cmd))) + expected_ec_calls.append(mock.call().flush()) + + # Finally, verify that the appropriate writes were actually sent to the EC. + self.ec_uart_pty.assert_has_calls(expected_ec_calls) + + @mock.patch("ec3po.interpreter.os") + def test_CommandRetryingOnError(self, mock_os): + """Verify that commands are retried if an error is encountered. + + Args: + mock_os: MagicMock object replacing the 'os' module for this test + case. + """ + # The interpreter init should open the EC UART PTY. + expected_ec_calls = [mock.call(self.tempfile.name, "r+b", buffering=0)] + # Have a command come in the command pipe. The first command will be an + # interrogation to determine if the EC is enhanced or not. + self.cmd_pipe_user.send(interpreter.EC_SYN) + self.itpr.HandleUserData() + # At this point, the command should be queued up waiting to be sent, so + # let's actually send it to the EC. + self.itpr.SendCmdToEC() + expected_ec_calls.extend( + [mock.call().write(interpreter.EC_SYN), mock.call().flush()] + ) + # Now, assume that the EC sends only 1 response back of EC_ACK. + mock_os.read.side_effect = [interpreter.EC_ACK] + # When reading the EC, the interpreter will call file.fileno() to pass to + # os.read(). + expected_ec_calls.append(mock.call().fileno()) + # Simulate the response. + self.itpr.HandleECData() + + # Let's send a command that is received on the EC-side with an error. + test_cmd = b"accelinfo" + self.cmd_pipe_user.send(test_cmd) + self.itpr.HandleUserData() + self.itpr.SendCmdToEC() + packed_cmd = self.itpr.PackCommand(test_cmd) + expected_ec_calls.extend([mock.call().write(packed_cmd), mock.call().flush()]) + # Have the EC return the error string twice. + mock_os.read.side_effect = [b"&&EE", b"&&EE"] + for i in range(2): + # When reading the EC, the interpreter will call file.fileno() to pass to + # os.read(). + expected_ec_calls.append(mock.call().fileno()) + # Simulate the response. + self.itpr.HandleECData() + + # Since an error was received, the EC should attempt to retry the command. + expected_ec_calls.extend( + [mock.call().write(packed_cmd), mock.call().flush()] + ) + # Verify that the retry count was decremented. + self.assertEqual( + interpreter.COMMAND_RETRIES - i - 1, + self.itpr.cmd_retries, + "Unexpected cmd_remaining count.", + ) + # Actually retry the command. + self.itpr.SendCmdToEC() + + # Now assume that the last one goes through with no trouble. + expected_ec_calls.extend([mock.call().write(packed_cmd), mock.call().flush()]) + self.itpr.SendCmdToEC() + + # Verify all the calls. + self.ec_uart_pty.assert_has_calls(expected_ec_calls) + + def test_PackCommandsForEnhancedEC(self): + """Verify that the interpreter packs commands for enhanced EC images.""" + # Assume current EC image is enhanced. + self.itpr.enhanced_ec = True + # Receive a command from the user. + test_cmd = b"gettime" + self.cmd_pipe_user.send(test_cmd) + # Mock out PackCommand to see if it was called. + self.itpr.PackCommand = mock.MagicMock() + # Have the interpreter handle the command. + self.itpr.HandleUserData() + # Verify that PackCommand() was called. + self.itpr.PackCommand.assert_called_once_with(test_cmd) + + def test_DontPackCommandsForNonEnhancedEC(self): + """Verify the interpreter doesn't pack commands for non-enhanced images.""" + # Assume current EC image is not enhanced. + self.itpr.enhanced_ec = False + # Receive a command from the user. + test_cmd = b"gettime" + self.cmd_pipe_user.send(test_cmd) + # Mock out PackCommand to see if it was called. + self.itpr.PackCommand = mock.MagicMock() + # Have the interpreter handle the command. + self.itpr.HandleUserData() + # Verify that PackCommand() was called. + self.itpr.PackCommand.assert_not_called() + + @mock.patch("ec3po.interpreter.os") + def test_KeepingTrackOfInterrogation(self, mock_os): + """Verify that the interpreter can track the state of the interrogation. + + Args: + mock_os: MagicMock object replacing the 'os' module. for this test + case. + """ + # Upon init, the interpreter should assume that the current EC image is not + # enhanced. + self.assertFalse( + self.itpr.enhanced_ec, + msg=("State of enhanced_ec upon" " init is not False."), + ) + + # Assume an interrogation request comes in from the user. + self.cmd_pipe_user.send(interpreter.EC_SYN) + self.itpr.HandleUserData() + + # Verify the state is now within an interrogation. + self.assertTrue(self.itpr.interrogating, "interrogating should be True") + # The state of enhanced_ec should not be changed yet because we haven't + # received a valid response yet. + self.assertFalse( + self.itpr.enhanced_ec, msg=("State of enhanced_ec is " "not False.") + ) + + # Assume that the EC responds with an EC_ACK. + mock_os.read.side_effect = [interpreter.EC_ACK] + self.itpr.HandleECData() + + # Now, the interrogation should be complete and we should know that the + # current EC image is enhanced. + self.assertFalse( + self.itpr.interrogating, msg=("interrogating should be " "False") + ) + self.assertTrue(self.itpr.enhanced_ec, msg="enhanced_ec sholud be True") + + # Now let's perform another interrogation, but pretend that the EC ignores + # it. + self.cmd_pipe_user.send(interpreter.EC_SYN) + self.itpr.HandleUserData() + + # Verify interrogating state. + self.assertTrue(self.itpr.interrogating, "interrogating sholud be True") + # We should assume that the image is not enhanced until we get the valid + # response. + self.assertFalse(self.itpr.enhanced_ec, "enhanced_ec should be False now.") + + # Let's pretend that we get a random debug print. This should clear the + # interrogating flag. + mock_os.read.side_effect = [b"[1660.593076 HC 0x103]"] + self.itpr.HandleECData() + + # Verify that interrogating flag is cleared and enhanced_ec is still False. + self.assertFalse(self.itpr.interrogating, "interrogating should be False.") + self.assertFalse(self.itpr.enhanced_ec, "enhanced_ec should still be False.") class TestUARTDisconnection(unittest.TestCase): - """Test case to verify interpreter disconnection/reconnection.""" - def setUp(self): - """Setup the test harness.""" - # Setup logging with a timestamp, the module, and the log level. - logging.basicConfig(level=logging.DEBUG, - format=('%(asctime)s - %(module)s -' - ' %(levelname)s - %(message)s')) - - # Create a tempfile that would represent the EC UART PTY. - self.tempfile = tempfile.NamedTemporaryFile() - - # Create the pipes that the interpreter will use. - self.cmd_pipe_user, self.cmd_pipe_itpr = threadproc_shim.Pipe() - self.dbg_pipe_user, self.dbg_pipe_itpr = threadproc_shim.Pipe(duplex=False) - - # Mock the open() function so we can inspect reads/writes to the EC. - self.ec_uart_pty = mock.mock_open() - - with mock.patch(GetBuiltins('open'), self.ec_uart_pty): - # Create an interpreter. - self.itpr = interpreter.Interpreter(self.tempfile.name, - self.cmd_pipe_itpr, - self.dbg_pipe_itpr, - log_level=logging.DEBUG, - name="EC") - - # First, check that interpreter is initialized to connected. - self.assertTrue(self.itpr.connected, ('The interpreter should be' - ' initialized in a connected state')) - - def test_DisconnectStopsECTraffic(self): - """Verify that when in disconnected state, no debug prints are sent.""" - # Let's send a disconnect command through the command pipe. - self.cmd_pipe_user.send(b'disconnect') - self.itpr.HandleUserData() - - # Verify interpreter is disconnected from EC. - self.assertFalse(self.itpr.connected, ('The interpreter should be' - 'disconnected.')) - # Verify that the EC UART is no longer a member of the inputs. The - # interpreter will never pull data from the EC if it's not a member of the - # inputs list. - self.assertFalse(self.itpr.ec_uart_pty in self.itpr.inputs) - - def test_CommandsDroppedWhenDisconnected(self): - """Verify that when in disconnected state, commands are dropped.""" - # Send a command, followed by 'disconnect'. - self.cmd_pipe_user.send(b'taskinfo') - self.itpr.HandleUserData() - self.cmd_pipe_user.send(b'disconnect') - self.itpr.HandleUserData() - - # Verify interpreter is disconnected from EC. - self.assertFalse(self.itpr.connected, ('The interpreter should be' - 'disconnected.')) - # Verify that the EC UART is no longer a member of the inputs nor outputs. - self.assertFalse(self.itpr.ec_uart_pty in self.itpr.inputs) - self.assertFalse(self.itpr.ec_uart_pty in self.itpr.outputs) - - # Have the user send a few more commands in the disconnected state. - command = 'help\n' - for char in command: - self.cmd_pipe_user.send(char.encode('utf-8')) - self.itpr.HandleUserData() - - # The command queue should be empty. - self.assertEqual(0, self.itpr.ec_cmd_queue.qsize()) - - # Now send the reconnect command. - self.cmd_pipe_user.send(b'reconnect') - - with mock.patch(GetBuiltins('open'), mock.mock_open()): - self.itpr.HandleUserData() - - # Verify interpreter is connected. - self.assertTrue(self.itpr.connected) - # Verify that EC UART is a member of the inputs. - self.assertTrue(self.itpr.ec_uart_pty in self.itpr.inputs) - # Since no command was sent after reconnection, verify that the EC UART is - # not a member of the outputs. - self.assertFalse(self.itpr.ec_uart_pty in self.itpr.outputs) - - def test_ReconnectAllowsECTraffic(self): - """Verify that when connected, EC UART traffic is allowed.""" - # Let's send a disconnect command through the command pipe. - self.cmd_pipe_user.send(b'disconnect') - self.itpr.HandleUserData() - - # Verify interpreter is disconnected. - self.assertFalse(self.itpr.connected, ('The interpreter should be' - 'disconnected.')) - # Verify that the EC UART is no longer a member of the inputs nor outputs. - self.assertFalse(self.itpr.ec_uart_pty in self.itpr.inputs) - self.assertFalse(self.itpr.ec_uart_pty in self.itpr.outputs) - - # Issue reconnect command through the command pipe. - self.cmd_pipe_user.send(b'reconnect') - - with mock.patch(GetBuiltins('open'), mock.mock_open()): - self.itpr.HandleUserData() - - # Verify interpreter is connected. - self.assertTrue(self.itpr.connected, ('The interpreter should be' - 'connected.')) - # Verify that the EC UART is now a member of the inputs. - self.assertTrue(self.itpr.ec_uart_pty in self.itpr.inputs) - # Since we have issued no commands during the disconnected state, no - # commands are pending and therefore the PTY should not be added to the - # outputs. - self.assertFalse(self.itpr.ec_uart_pty in self.itpr.outputs) - - -if __name__ == '__main__': - unittest.main() + """Test case to verify interpreter disconnection/reconnection.""" + + def setUp(self): + """Setup the test harness.""" + # Setup logging with a timestamp, the module, and the log level. + logging.basicConfig( + level=logging.DEBUG, + format=("%(asctime)s - %(module)s -" " %(levelname)s - %(message)s"), + ) + + # Create a tempfile that would represent the EC UART PTY. + self.tempfile = tempfile.NamedTemporaryFile() + + # Create the pipes that the interpreter will use. + self.cmd_pipe_user, self.cmd_pipe_itpr = threadproc_shim.Pipe() + self.dbg_pipe_user, self.dbg_pipe_itpr = threadproc_shim.Pipe(duplex=False) + + # Mock the open() function so we can inspect reads/writes to the EC. + self.ec_uart_pty = mock.mock_open() + + with mock.patch(GetBuiltins("open"), self.ec_uart_pty): + # Create an interpreter. + self.itpr = interpreter.Interpreter( + self.tempfile.name, + self.cmd_pipe_itpr, + self.dbg_pipe_itpr, + log_level=logging.DEBUG, + name="EC", + ) + + # First, check that interpreter is initialized to connected. + self.assertTrue( + self.itpr.connected, + ("The interpreter should be" " initialized in a connected state"), + ) + + def test_DisconnectStopsECTraffic(self): + """Verify that when in disconnected state, no debug prints are sent.""" + # Let's send a disconnect command through the command pipe. + self.cmd_pipe_user.send(b"disconnect") + self.itpr.HandleUserData() + + # Verify interpreter is disconnected from EC. + self.assertFalse( + self.itpr.connected, ("The interpreter should be" "disconnected.") + ) + # Verify that the EC UART is no longer a member of the inputs. The + # interpreter will never pull data from the EC if it's not a member of the + # inputs list. + self.assertFalse(self.itpr.ec_uart_pty in self.itpr.inputs) + + def test_CommandsDroppedWhenDisconnected(self): + """Verify that when in disconnected state, commands are dropped.""" + # Send a command, followed by 'disconnect'. + self.cmd_pipe_user.send(b"taskinfo") + self.itpr.HandleUserData() + self.cmd_pipe_user.send(b"disconnect") + self.itpr.HandleUserData() + + # Verify interpreter is disconnected from EC. + self.assertFalse( + self.itpr.connected, ("The interpreter should be" "disconnected.") + ) + # Verify that the EC UART is no longer a member of the inputs nor outputs. + self.assertFalse(self.itpr.ec_uart_pty in self.itpr.inputs) + self.assertFalse(self.itpr.ec_uart_pty in self.itpr.outputs) + + # Have the user send a few more commands in the disconnected state. + command = "help\n" + for char in command: + self.cmd_pipe_user.send(char.encode("utf-8")) + self.itpr.HandleUserData() + + # The command queue should be empty. + self.assertEqual(0, self.itpr.ec_cmd_queue.qsize()) + + # Now send the reconnect command. + self.cmd_pipe_user.send(b"reconnect") + + with mock.patch(GetBuiltins("open"), mock.mock_open()): + self.itpr.HandleUserData() + + # Verify interpreter is connected. + self.assertTrue(self.itpr.connected) + # Verify that EC UART is a member of the inputs. + self.assertTrue(self.itpr.ec_uart_pty in self.itpr.inputs) + # Since no command was sent after reconnection, verify that the EC UART is + # not a member of the outputs. + self.assertFalse(self.itpr.ec_uart_pty in self.itpr.outputs) + + def test_ReconnectAllowsECTraffic(self): + """Verify that when connected, EC UART traffic is allowed.""" + # Let's send a disconnect command through the command pipe. + self.cmd_pipe_user.send(b"disconnect") + self.itpr.HandleUserData() + + # Verify interpreter is disconnected. + self.assertFalse( + self.itpr.connected, ("The interpreter should be" "disconnected.") + ) + # Verify that the EC UART is no longer a member of the inputs nor outputs. + self.assertFalse(self.itpr.ec_uart_pty in self.itpr.inputs) + self.assertFalse(self.itpr.ec_uart_pty in self.itpr.outputs) + + # Issue reconnect command through the command pipe. + self.cmd_pipe_user.send(b"reconnect") + + with mock.patch(GetBuiltins("open"), mock.mock_open()): + self.itpr.HandleUserData() + + # Verify interpreter is connected. + self.assertTrue(self.itpr.connected, ("The interpreter should be" "connected.")) + # Verify that the EC UART is now a member of the inputs. + self.assertTrue(self.itpr.ec_uart_pty in self.itpr.inputs) + # Since we have issued no commands during the disconnected state, no + # commands are pending and therefore the PTY should not be added to the + # outputs. + self.assertFalse(self.itpr.ec_uart_pty in self.itpr.outputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/util/ec3po/threadproc_shim.py b/util/ec3po/threadproc_shim.py index da5440b1f3..c0b3ce0bf4 100644 --- a/util/ec3po/threadproc_shim.py +++ b/util/ec3po/threadproc_shim.py @@ -34,33 +34,34 @@ wait until after completing the TODO above to stop using multiprocessing.Pipe! # Imports to bring objects into this namespace for users of this module. from multiprocessing import Pipe -from six.moves.queue import Queue from threading import Thread as ThreadOrProcess +from six.moves.queue import Queue + # True if this module has ec3po using subprocesses, False if using threads. USING_SUBPROCS = False def _DoNothing(): - """Do-nothing function for use as a callback with DoIf().""" + """Do-nothing function for use as a callback with DoIf().""" def DoIf(subprocs=_DoNothing, threads=_DoNothing): - """Return a callback or not based on ec3po use of subprocesses or threads. + """Return a callback or not based on ec3po use of subprocesses or threads. - Args: - subprocs: callback that does not require any args - This will be returned - (not called!) if and only if ec3po is using subprocesses. This is - OPTIONAL, the default value is a do-nothing callback that returns None. - threads: callback that does not require any args - This will be returned - (not called!) if and only if ec3po is using threads. This is OPTIONAL, - the default value is a do-nothing callback that returns None. + Args: + subprocs: callback that does not require any args - This will be returned + (not called!) if and only if ec3po is using subprocesses. This is + OPTIONAL, the default value is a do-nothing callback that returns None. + threads: callback that does not require any args - This will be returned + (not called!) if and only if ec3po is using threads. This is OPTIONAL, + the default value is a do-nothing callback that returns None. - Returns: - Either the subprocs or threads argument will be returned. - """ - return subprocs if USING_SUBPROCS else threads + Returns: + Either the subprocs or threads argument will be returned. + """ + return subprocs if USING_SUBPROCS else threads def Value(ctype, *args): - return ctype(*args) + return ctype(*args) diff --git a/util/ec_openocd.py b/util/ec_openocd.py index a84c00643c..11956ffa1c 100755 --- a/util/ec_openocd.py +++ b/util/ec_openocd.py @@ -16,6 +16,7 @@ import time Flashes and debugs the EC through openocd """ + @dataclasses.dataclass class BoardInfo: gdb_variant: str @@ -24,9 +25,7 @@ class BoardInfo: # Debuggers for each board, OpenOCD currently only supports GDB -boards = { - "skyrim": BoardInfo("arm-none-eabi-gdb", 6, 4) -} +boards = {"skyrim": BoardInfo("arm-none-eabi-gdb", 6, 4)} def create_openocd_args(interface, board): @@ -36,9 +35,12 @@ def create_openocd_args(interface, board): board_info = boards[board] args = [ "openocd", - "-f", f"interface/{interface}.cfg", - "-c", "add_script_search_dir openocd", - "-f", f"board/{board}.cfg", + "-f", + f"interface/{interface}.cfg", + "-c", + "add_script_search_dir openocd", + "-f", + f"board/{board}.cfg", ] return args @@ -53,11 +55,13 @@ def create_gdb_args(board, port, executable): board_info.gdb_variant, executable, # GDB can't autodetect these according to OpenOCD - "-ex", f"set remote hardware-breakpoint-limit {board_info.num_breakpoints}", - "-ex", f"set remote hardware-watchpoint-limit {board_info.num_watchpoints}", - + "-ex", + f"set remote hardware-breakpoint-limit {board_info.num_breakpoints}", + "-ex", + f"set remote hardware-watchpoint-limit {board_info.num_watchpoints}", # Connect to OpenOCD - "-ex", f"target extended-remote localhost:{port}", + "-ex", + f"target extended-remote localhost:{port}", ] return args diff --git a/util/flash_jlink.py b/util/flash_jlink.py index 26c3c2e709..50a0bfca20 100755 --- a/util/flash_jlink.py +++ b/util/flash_jlink.py @@ -25,7 +25,6 @@ import sys import tempfile import time - DEFAULT_SEGGER_REMOTE_PORT = 19020 # Commands are documented here: https://wiki.segger.com/J-Link_Commander @@ -41,27 +40,34 @@ exit class BoardConfig: """Board configuration.""" + def __init__(self, interface, device, flash_address): self.interface = interface self.device = device self.flash_address = flash_address -SWD_INTERFACE = 'SWD' -STM32_DEFAULT_FLASH_ADDRESS = '0x8000000' -DRAGONCLAW_CONFIG = BoardConfig(interface=SWD_INTERFACE, device='STM32F412CG', - flash_address=STM32_DEFAULT_FLASH_ADDRESS) -ICETOWER_CONFIG = BoardConfig(interface=SWD_INTERFACE, device='STM32H743ZI', - flash_address=STM32_DEFAULT_FLASH_ADDRESS) +SWD_INTERFACE = "SWD" +STM32_DEFAULT_FLASH_ADDRESS = "0x8000000" +DRAGONCLAW_CONFIG = BoardConfig( + interface=SWD_INTERFACE, + device="STM32F412CG", + flash_address=STM32_DEFAULT_FLASH_ADDRESS, +) +ICETOWER_CONFIG = BoardConfig( + interface=SWD_INTERFACE, + device="STM32H743ZI", + flash_address=STM32_DEFAULT_FLASH_ADDRESS, +) BOARD_CONFIGS = { - 'dragonclaw': DRAGONCLAW_CONFIG, - 'bloonchipper': DRAGONCLAW_CONFIG, - 'nucleo-f412zg': DRAGONCLAW_CONFIG, - 'dartmonkey': ICETOWER_CONFIG, - 'icetower': ICETOWER_CONFIG, - 'nucleo-dartmonkey': ICETOWER_CONFIG, - 'nucleo-h743zi': ICETOWER_CONFIG, + "dragonclaw": DRAGONCLAW_CONFIG, + "bloonchipper": DRAGONCLAW_CONFIG, + "nucleo-f412zg": DRAGONCLAW_CONFIG, + "dartmonkey": ICETOWER_CONFIG, + "icetower": ICETOWER_CONFIG, + "nucleo-dartmonkey": ICETOWER_CONFIG, + "nucleo-h743zi": ICETOWER_CONFIG, } @@ -93,9 +99,11 @@ def is_tcp_port_open(host: str, tcp_port: int) -> bool: def create_jlink_command_file(firmware_file, config): tmp = tempfile.NamedTemporaryFile() - tmp.write(JLINK_COMMANDS.format(FIRMWARE=firmware_file, - FLASH_ADDRESS=config.flash_address).encode( - 'utf-8')) + tmp.write( + JLINK_COMMANDS.format( + FIRMWARE=firmware_file, FLASH_ADDRESS=config.flash_address + ).encode("utf-8") + ) tmp.flush() return tmp @@ -106,8 +114,8 @@ def flash(jlink_exe, remote, device, interface, cmd_file): ] if remote: - logging.debug(f'Connecting to J-Link over TCP/IP {remote}.') - remote_components = remote.split(':') + logging.debug(f"Connecting to J-Link over TCP/IP {remote}.") + remote_components = remote.split(":") if len(remote_components) not in [1, 2]: logging.debug(f'Given remote "{remote}" is malformed.') return 1 @@ -118,7 +126,7 @@ def flash(jlink_exe, remote, device, interface, cmd_file): except socket.gaierror as e: logging.error(f'Failed to resolve host "{host}": {e}.') return 1 - logging.debug(f'Resolved {host} as {ip}.') + logging.debug(f"Resolved {host} as {ip}.") port = DEFAULT_SEGGER_REMOTE_PORT if len(remote_components) == 2: @@ -126,29 +134,36 @@ def flash(jlink_exe, remote, device, interface, cmd_file): port = int(remote_components[1]) except ValueError: logging.error( - f'Given remote port "{remote_components[1]}" is malformed.') + f'Given remote port "{remote_components[1]}" is malformed.' + ) return 1 - remote = f'{ip}:{port}' + remote = f"{ip}:{port}" - logging.debug(f'Checking connection to {remote}.') + logging.debug(f"Checking connection to {remote}.") if not is_tcp_port_open(ip, port): - logging.error( - f"JLink server doesn't seem to be listening on {remote}.") - logging.error('Ensure that JLinkRemoteServerCLExe is running.') + logging.error(f"JLink server doesn't seem to be listening on {remote}.") + logging.error("Ensure that JLinkRemoteServerCLExe is running.") return 1 - cmd.extend(['-ip', remote]) - - cmd.extend([ - '-device', device, - '-if', interface, - '-speed', 'auto', - '-autoconnect', '1', - '-CommandFile', cmd_file, - ]) - logging.debug('Running command: "%s"', ' '.join(cmd)) + cmd.extend(["-ip", remote]) + + cmd.extend( + [ + "-device", + device, + "-if", + interface, + "-speed", + "auto", + "-autoconnect", + "1", + "-CommandFile", + cmd_file, + ] + ) + logging.debug('Running command: "%s"', " ".join(cmd)) completed_process = subprocess.run(cmd) # pylint: disable=subprocess-run-check - logging.debug('JLink return code: %d', completed_process.returncode) + logging.debug("JLink return code: %d", completed_process.returncode) return completed_process.returncode @@ -156,38 +171,42 @@ def main(argv: list): parser = argparse.ArgumentParser() - default_jlink = './JLink_Linux_V684a_x86_64/JLinkExe' + default_jlink = "./JLink_Linux_V684a_x86_64/JLinkExe" if shutil.which(default_jlink) is None: - default_jlink = 'JLinkExe' - parser.add_argument( - '--jlink', '-j', - help='JLinkExe path (default: ' + default_jlink + ')', - default=default_jlink) - + default_jlink = "JLinkExe" parser.add_argument( - '--remote', '-n', - help='Use TCP/IP host[:port] to connect to a J-Link or ' - 'JLinkRemoteServerCLExe. If unspecified, connect over USB.') + "--jlink", + "-j", + help="JLinkExe path (default: " + default_jlink + ")", + default=default_jlink, + ) - default_board = 'bloonchipper' parser.add_argument( - '--board', '-b', - help='Board (default: ' + default_board + ')', - default=default_board) + "--remote", + "-n", + help="Use TCP/IP host[:port] to connect to a J-Link or " + "JLinkRemoteServerCLExe. If unspecified, connect over USB.", + ) - default_firmware = os.path.join('./build', default_board, 'ec.bin') + default_board = "bloonchipper" parser.add_argument( - '--image', '-i', - help='Firmware binary (default: ' + default_firmware + ')', - default=default_firmware) + "--board", + "-b", + help="Board (default: " + default_board + ")", + default=default_board, + ) - log_level_choices = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] + default_firmware = os.path.join("./build", default_board, "ec.bin") parser.add_argument( - '--log_level', '-l', - choices=log_level_choices, - default='DEBUG' + "--image", + "-i", + help="Firmware binary (default: " + default_firmware + ")", + default=default_firmware, ) + log_level_choices = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + parser.add_argument("--log_level", "-l", choices=log_level_choices, default="DEBUG") + args = parser.parse_args(argv) logging.basicConfig(level=args.log_level) @@ -201,11 +220,12 @@ def main(argv: list): args.jlink = args.jlink cmd_file = create_jlink_command_file(args.image, config) - ret_code = flash(args.jlink, args.remote, config.device, config.interface, - cmd_file.name) + ret_code = flash( + args.jlink, args.remote, config.device, config.interface, cmd_file.name + ) cmd_file.close() return ret_code -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main(sys.argv[1:])) diff --git a/util/fptool.py b/util/fptool.py index 5d73302bbc..b7f2150289 100755 --- a/util/fptool.py +++ b/util/fptool.py @@ -19,14 +19,14 @@ def cmd_flash(args: argparse.Namespace) -> int: disabled. """ - if not shutil.which('flash_fp_mcu'): - print('Error - The flash_fp_mcu utility does not exist.') + if not shutil.which("flash_fp_mcu"): + print("Error - The flash_fp_mcu utility does not exist.") return 1 - cmd = ['flash_fp_mcu'] + cmd = ["flash_fp_mcu"] if args.image: if not os.path.isfile(args.image): - print(f'Error - image {args.image} is not a file.') + print(f"Error - image {args.image} is not a file.") return 1 cmd.append(args.image) @@ -38,18 +38,17 @@ def cmd_flash(args: argparse.Namespace) -> int: def main(argv: list) -> int: parser = argparse.ArgumentParser(description=__doc__) - subparsers = parser.add_subparsers(dest='subcommand', title='subcommands') + subparsers = parser.add_subparsers(dest="subcommand", title="subcommands") # This method of setting required is more compatible with older python. subparsers.required = True # Parser for "flash" subcommand. - parser_decrypt = subparsers.add_parser('flash', help=cmd_flash.__doc__) - parser_decrypt.add_argument( - 'image', nargs='?', help='Path to the firmware image') + parser_decrypt = subparsers.add_parser("flash", help=cmd_flash.__doc__) + parser_decrypt.add_argument("image", nargs="?", help="Path to the firmware image") parser_decrypt.set_defaults(func=cmd_flash) opts = parser.parse_args(argv) return opts.func(opts) -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main(sys.argv[1:])) diff --git a/util/inject-keys.py b/util/inject-keys.py index bd10b693ad..d05d4fbed7 100755 --- a/util/inject-keys.py +++ b/util/inject-keys.py @@ -8,50 +8,124 @@ # Note: This is a py2/3 compatible file. from __future__ import print_function + import string import subprocess import sys - -KEYMATRIX = {'`': (3, 1), '1': (6, 1), '2': (6, 4), '3': (6, 2), '4': (6, 3), - '5': (3, 3), '6': (3, 6), '7': (6, 6), '8': (6, 5), '9': (6, 9), - '0': (6, 8), '-': (3, 8), '=': (0, 8), 'q': (7, 1), 'w': (7, 4), - 'e': (7, 2), 'r': (7, 3), 't': (2, 3), 'y': (2, 6), 'u': (7, 6), - 'i': (7, 5), 'o': (7, 9), 'p': (7, 8), '[': (2, 8), ']': (2, 5), - '\\': (3, 11), 'a': (4, 1), 's': (4, 4), 'd': (4, 2), 'f': (4, 3), - 'g': (1, 3), 'h': (1, 6), 'j': (4, 6), 'k': (4, 5), 'l': (4, 9), - ';': (4, 8), '\'': (1, 8), 'z': (5, 1), 'x': (5, 4), 'c': (5, 2), - 'v': (5, 3), 'b': (0, 3), 'n': (0, 6), 'm': (5, 6), ',': (5, 5), - '.': (5, 9), '/': (5, 8), ' ': (5, 11), '<right>': (6, 12), - '<alt_r>': (0, 10), '<down>': (6, 11), '<tab>': (2, 1), - '<f10>': (0, 4), '<shift_r>': (7, 7), '<ctrl_r>': (4, 0), - '<esc>': (1, 1), '<backspace>': (1, 11), '<f2>': (3, 2), - '<alt_l>': (6, 10), '<ctrl_l>': (2, 0), '<f1>': (0, 2), - '<search>': (0, 1), '<f3>': (2, 2), '<f4>': (1, 2), '<f5>': (3, 4), - '<f6>': (2, 4), '<f7>': (1, 4), '<f8>': (2, 9), '<f9>': (1, 9), - '<up>': (7, 11), '<shift_l>': (5, 7), '<enter>': (4, 11), - '<left>': (7, 12)} - - -UNSHIFT_TABLE = { '~': '`', '!': '1', '@': '2', '#': '3', '$': '4', - '%': '5', '^': '6', '&': '7', '*': '8', '(': '9', - ')': '0', '_': '-', '+': '=', '{': '[', '}': ']', - '|': '\\', - ':': ';', '"': "'", '<': ',', '>': '.', '?': '/'} +KEYMATRIX = { + "`": (3, 1), + "1": (6, 1), + "2": (6, 4), + "3": (6, 2), + "4": (6, 3), + "5": (3, 3), + "6": (3, 6), + "7": (6, 6), + "8": (6, 5), + "9": (6, 9), + "0": (6, 8), + "-": (3, 8), + "=": (0, 8), + "q": (7, 1), + "w": (7, 4), + "e": (7, 2), + "r": (7, 3), + "t": (2, 3), + "y": (2, 6), + "u": (7, 6), + "i": (7, 5), + "o": (7, 9), + "p": (7, 8), + "[": (2, 8), + "]": (2, 5), + "\\": (3, 11), + "a": (4, 1), + "s": (4, 4), + "d": (4, 2), + "f": (4, 3), + "g": (1, 3), + "h": (1, 6), + "j": (4, 6), + "k": (4, 5), + "l": (4, 9), + ";": (4, 8), + "'": (1, 8), + "z": (5, 1), + "x": (5, 4), + "c": (5, 2), + "v": (5, 3), + "b": (0, 3), + "n": (0, 6), + "m": (5, 6), + ",": (5, 5), + ".": (5, 9), + "/": (5, 8), + " ": (5, 11), + "<right>": (6, 12), + "<alt_r>": (0, 10), + "<down>": (6, 11), + "<tab>": (2, 1), + "<f10>": (0, 4), + "<shift_r>": (7, 7), + "<ctrl_r>": (4, 0), + "<esc>": (1, 1), + "<backspace>": (1, 11), + "<f2>": (3, 2), + "<alt_l>": (6, 10), + "<ctrl_l>": (2, 0), + "<f1>": (0, 2), + "<search>": (0, 1), + "<f3>": (2, 2), + "<f4>": (1, 2), + "<f5>": (3, 4), + "<f6>": (2, 4), + "<f7>": (1, 4), + "<f8>": (2, 9), + "<f9>": (1, 9), + "<up>": (7, 11), + "<shift_l>": (5, 7), + "<enter>": (4, 11), + "<left>": (7, 12), +} + + +UNSHIFT_TABLE = { + "~": "`", + "!": "1", + "@": "2", + "#": "3", + "$": "4", + "%": "5", + "^": "6", + "&": "7", + "*": "8", + "(": "9", + ")": "0", + "_": "-", + "+": "=", + "{": "[", + "}": "]", + "|": "\\", + ":": ";", + '"': "'", + "<": ",", + ">": ".", + "?": "/", +} for c in string.ascii_lowercase: UNSHIFT_TABLE[c.upper()] = c def inject_event(key, press): - if len(key) >= 2 and key[0] != '<': - key = '<' + key + '>' + if len(key) >= 2 and key[0] != "<": + key = "<" + key + ">" if key not in KEYMATRIX: print("%s: invalid key: %s" % (this_script, key)) sys.exit(1) (row, col) = KEYMATRIX[key] - subprocess.call(["ectool", "kbpress", str(row), str(col), - "1" if press else "0"]) + subprocess.call(["ectool", "kbpress", str(row), str(col), "1" if press else "0"]) def inject_key(key): @@ -73,8 +147,10 @@ def inject_string(string): def usage(): - print("Usage: %s [-s <string>] [-k <key>]" % this_script, - "[-p <pressed-key>] [-r <released-key>] ...") + print( + "Usage: %s [-s <string>] [-k <key>]" % this_script, + "[-p <pressed-key>] [-r <released-key>] ...", + ) print("Examples:") print("%s -s MyPassw0rd -k enter" % this_script) print("%s -p ctrl_l -p alt_l -k f3 -r alt_l -r ctrl_l" % this_script) @@ -85,7 +161,7 @@ def help(): print("Valid keys are:") i = 0 for key in KEYMATRIX: - print("%12s" % key, end='') + print("%12s" % key, end="") i += 1 if i % 4 == 0: print() @@ -114,12 +190,13 @@ usage_check(arg_len > 1, "not enough arguments") usage_check(arg_len % 2 == 1, "mismatched arguments") for i in range(1, arg_len, 2): - usage_check(sys.argv[i] in ("-s", "-k", "-p", "-r"), - "unknown flag: %s" % sys.argv[i]) + usage_check( + sys.argv[i] in ("-s", "-k", "-p", "-r"), "unknown flag: %s" % sys.argv[i] + ) for i in range(1, arg_len, 2): flag = sys.argv[i] - arg = sys.argv[i+1] + arg = sys.argv[i + 1] if flag == "-s": inject_string(arg) elif flag == "-k": diff --git a/util/kconfig_check.py b/util/kconfig_check.py index d1eba8e62b..04cdf9a990 100755 --- a/util/kconfig_check.py +++ b/util/kconfig_check.py @@ -32,12 +32,13 @@ import sys USE_KCONFIGLIB = False try: import kconfiglib + USE_KCONFIGLIB = True except ImportError: pass # Where we put the new config_allowed file -NEW_ALLOWED_FNAME = pathlib.Path('/tmp/new_config_allowed.txt') +NEW_ALLOWED_FNAME = pathlib.Path("/tmp/new_config_allowed.txt") def parse_args(argv): @@ -49,38 +50,72 @@ def parse_args(argv): Returns: argparse.Namespace object containing the results """ - epilog = '''Checks that new ad-hoc CONFIG options are not introduced without -a corresponding Kconfig option for Zephyr''' + epilog = """Checks that new ad-hoc CONFIG options are not introduced without +a corresponding Kconfig option for Zephyr""" parser = argparse.ArgumentParser(epilog=epilog) - parser.add_argument('-a', '--allowed', type=str, - default='util/config_allowed.txt', - help='File containing list of allowed ad-hoc CONFIGs') - parser.add_argument('-c', '--configs', type=str, default='.config', - help='File containing CONFIG options to check') - parser.add_argument('-d', '--use-defines', action='store_true', - help='Lines in the configs file use #define') parser.add_argument( - '-D', '--debug', action='store_true', - help='Enabling debugging (provides a full traceback on error)') + "-a", + "--allowed", + type=str, + default="util/config_allowed.txt", + help="File containing list of allowed ad-hoc CONFIGs", + ) + parser.add_argument( + "-c", + "--configs", + type=str, + default=".config", + help="File containing CONFIG options to check", + ) + parser.add_argument( + "-d", + "--use-defines", + action="store_true", + help="Lines in the configs file use #define", + ) + parser.add_argument( + "-D", + "--debug", + action="store_true", + help="Enabling debugging (provides a full traceback on error)", + ) + parser.add_argument( + "-i", + "--ignore", + action="append", + help="Kconfig options to ignore (without CONFIG_ prefix)", + ) parser.add_argument( - '-i', '--ignore', action='append', - help='Kconfig options to ignore (without CONFIG_ prefix)') - parser.add_argument('-I', '--search-path', type=str, action='append', - help='Search paths to look for Kconfigs') - parser.add_argument('-p', '--prefix', type=str, default='PLATFORM_EC_', - help='Prefix to string from Kconfig options') - parser.add_argument('-s', '--srctree', type=str, default='zephyr/', - help='Path to source tree to look for Kconfigs') + "-I", + "--search-path", + type=str, + action="append", + help="Search paths to look for Kconfigs", + ) + parser.add_argument( + "-p", + "--prefix", + type=str, + default="PLATFORM_EC_", + help="Prefix to string from Kconfig options", + ) + parser.add_argument( + "-s", + "--srctree", + type=str, + default="zephyr/", + help="Path to source tree to look for Kconfigs", + ) # TODO(sjg@chromium.org): The chroot uses a very old Python. Once it moves # to 3.7 or later we can use this instead: # subparsers = parser.add_subparsers(dest='cmd', required=True) - subparsers = parser.add_subparsers(dest='cmd') + subparsers = parser.add_subparsers(dest="cmd") subparsers.required = True - subparsers.add_parser('build', help='Build new list of ad-hoc CONFIGs') - subparsers.add_parser('check', help='Check for new ad-hoc CONFIGs') + subparsers.add_parser("build", help="Build new list of ad-hoc CONFIGs") + subparsers.add_parser("check", help="Check for new ad-hoc CONFIGs") return parser.parse_args(argv) @@ -107,6 +142,7 @@ class KconfigCheck: the user is exhorted to add a new Kconfig. This helps avoid adding new ad-hoc CONFIG options, eventually returning the number to zero. """ + @classmethod def find_new_adhoc(cls, configs, kconfigs, allowed): """Get a list of new ad-hoc CONFIG options @@ -172,11 +208,12 @@ class KconfigCheck: List of CONFIG_xxx options found in the file, with the 'CONFIG_' prefix removed """ - with open(configs_file, 'r') as inf: - configs = re.findall('%sCONFIG_([A-Za-z0-9_]*)%s' % - ((use_defines and '#define ' or ''), - (use_defines and ' ' or '')), - inf.read()) + with open(configs_file, "r") as inf: + configs = re.findall( + "%sCONFIG_([A-Za-z0-9_]*)%s" + % ((use_defines and "#define " or ""), (use_defines and " " or "")), + inf.read(), + ) return configs @classmethod @@ -190,8 +227,8 @@ class KconfigCheck: List of CONFIG_xxx options found in the file, with the 'CONFIG_' prefix removed """ - with open(allowed_file, 'r') as inf: - configs = re.findall('CONFIG_([A-Za-z0-9_]*)', inf.read()) + with open(allowed_file, "r") as inf: + configs = re.findall("CONFIG_([A-Za-z0-9_]*)", inf.read()) return configs @classmethod @@ -209,15 +246,17 @@ class KconfigCheck: """ kconfig_files = [] for root, dirs, files in os.walk(srcdir): - kconfig_files += [os.path.join(root, fname) - for fname in files if fname.startswith('Kconfig')] - if 'Kconfig' in dirs: - dirs.remove('Kconfig') + kconfig_files += [ + os.path.join(root, fname) + for fname in files + if fname.startswith("Kconfig") + ] + if "Kconfig" in dirs: + dirs.remove("Kconfig") return kconfig_files @classmethod - def scan_kconfigs(cls, srcdir, prefix='', search_paths=None, - try_kconfiglib=True): + def scan_kconfigs(cls, srcdir, prefix="", search_paths=None, try_kconfiglib=True): """Scan a source tree for Kconfig options Args: @@ -231,31 +270,40 @@ class KconfigCheck: List of config and menuconfig options found """ if USE_KCONFIGLIB and try_kconfiglib: - os.environ['srctree'] = srcdir - kconf = kconfiglib.Kconfig('Kconfig', warn=False, - search_paths=search_paths, - allow_empty_macros=True) + os.environ["srctree"] = srcdir + kconf = kconfiglib.Kconfig( + "Kconfig", + warn=False, + search_paths=search_paths, + allow_empty_macros=True, + ) # There is always a MODULES config, since kconfiglib is designed for # linux, but we don't want it - kconfigs = [name for name in kconf.syms if name != 'MODULES'] + kconfigs = [name for name in kconf.syms if name != "MODULES"] if prefix: - re_drop_prefix = re.compile(r'^%s' % prefix) - kconfigs = [re_drop_prefix.sub('', name) for name in kconfigs] + re_drop_prefix = re.compile(r"^%s" % prefix) + kconfigs = [re_drop_prefix.sub("", name) for name in kconfigs] else: kconfigs = [] # Remove the prefix if present - expr = re.compile(r'\n(config|menuconfig) (%s)?([A-Za-z0-9_]*)\n' % - prefix) + expr = re.compile(r"\n(config|menuconfig) (%s)?([A-Za-z0-9_]*)\n" % prefix) for fname in cls.find_kconfigs(srcdir): with open(fname) as inf: found = re.findall(expr, inf.read()) kconfigs += [name for kctype, _, name in found] return sorted(kconfigs) - def check_adhoc_configs(self, configs_file, srcdir, allowed_file, - prefix='', use_defines=False, search_paths=None): + def check_adhoc_configs( + self, + configs_file, + srcdir, + allowed_file, + prefix="", + use_defines=False, + search_paths=None, + ): """Find new and unneeded ad-hoc configs in the configs_file Args: @@ -283,8 +331,9 @@ class KconfigCheck: except kconfiglib.KconfigError: # If we don't actually have access to the full Kconfig then we may # get an error. Fall back to using manual methods. - kconfigs = self.scan_kconfigs(srcdir, prefix, search_paths, - try_kconfiglib=False) + kconfigs = self.scan_kconfigs( + srcdir, prefix, search_paths, try_kconfiglib=False + ) allowed = self.read_allowed(allowed_file) new_adhoc = self.find_new_adhoc(configs, kconfigs, allowed) @@ -292,8 +341,16 @@ class KconfigCheck: updated_adhoc = self.get_updated_adhoc(unneeded_adhoc, allowed) return new_adhoc, unneeded_adhoc, updated_adhoc - def do_check(self, configs_file, srcdir, allowed_file, prefix, use_defines, - search_paths, ignore=None): + def do_check( + self, + configs_file, + srcdir, + allowed_file, + prefix, + use_defines, + search_paths, + ignore=None, + ): """Find new ad-hoc configs in the configs_file Args: @@ -313,11 +370,12 @@ class KconfigCheck: Exit code: 0 if OK, 1 if a problem was found """ new_adhoc, unneeded_adhoc, updated_adhoc = self.check_adhoc_configs( - configs_file, srcdir, allowed_file, prefix, use_defines, - search_paths) + configs_file, srcdir, allowed_file, prefix, use_defines, search_paths + ) if new_adhoc: - file_list = '\n'.join(['CONFIG_%s' % name for name in new_adhoc]) - print(f'''Error:\tThe EC is in the process of migrating to Zephyr. + file_list = "\n".join(["CONFIG_%s" % name for name in new_adhoc]) + print( + f"""Error:\tThe EC is in the process of migrating to Zephyr. \tZephyr uses Kconfig for configuration rather than ad-hoc #defines. \tAny new EC CONFIG options must ALSO be added to Zephyr so that new \tfunctionality is available in Zephyr also. The following new ad-hoc @@ -330,19 +388,21 @@ file in zephyr/ and add a 'config' or 'menuconfig' option. Also see details in http://issuetracker.google.com/181253613 To temporarily disable this, use: ALLOW_CONFIG=1 make ... -''', file=sys.stderr) +""", + file=sys.stderr, + ) return 1 if not ignore: ignore = [] unneeded_adhoc = [name for name in unneeded_adhoc if name not in ignore] if unneeded_adhoc: - with open(NEW_ALLOWED_FNAME, 'w') as out: + with open(NEW_ALLOWED_FNAME, "w") as out: for config in updated_adhoc: - print('CONFIG_%s' % config, file=out) - now_in_kconfig = '\n'.join( - ['CONFIG_%s' % name for name in unneeded_adhoc]) - print(f'''The following options are now in Kconfig: + print("CONFIG_%s" % config, file=out) + now_in_kconfig = "\n".join(["CONFIG_%s" % name for name in unneeded_adhoc]) + print( + f"""The following options are now in Kconfig: {now_in_kconfig} @@ -350,12 +410,14 @@ Please run this to update the list of allowed ad-hoc CONFIGs and include this update in your CL: cp {NEW_ALLOWED_FNAME} util/config_allowed.txt -''') +""" + ) return 1 return 0 - def do_build(self, configs_file, srcdir, allowed_file, prefix, use_defines, - search_paths): + def do_build( + self, configs_file, srcdir, allowed_file, prefix, use_defines, search_paths + ): """Find new ad-hoc configs in the configs_file Args: @@ -372,13 +434,14 @@ update in your CL: Exit code: 0 if OK, 1 if a problem was found """ new_adhoc, _, updated_adhoc = self.check_adhoc_configs( - configs_file, srcdir, allowed_file, prefix, use_defines, - search_paths) - with open(NEW_ALLOWED_FNAME, 'w') as out: + configs_file, srcdir, allowed_file, prefix, use_defines, search_paths + ) + with open(NEW_ALLOWED_FNAME, "w") as out: combined = sorted(new_adhoc + updated_adhoc) for config in combined: - print(f'CONFIG_{config}', file=out) - print(f'New list is in {NEW_ALLOWED_FNAME}') + print(f"CONFIG_{config}", file=out) + print(f"New list is in {NEW_ALLOWED_FNAME}") + def main(argv): """Main function""" @@ -386,18 +449,27 @@ def main(argv): if not args.debug: sys.tracebacklimit = 0 checker = KconfigCheck() - if args.cmd == 'check': + if args.cmd == "check": return checker.do_check( - configs_file=args.configs, srcdir=args.srctree, - allowed_file=args.allowed, prefix=args.prefix, - use_defines=args.use_defines, search_paths=args.search_path, - ignore=args.ignore) - elif args.cmd == 'build': - return checker.do_build(configs_file=args.configs, srcdir=args.srctree, - allowed_file=args.allowed, prefix=args.prefix, - use_defines=args.use_defines, search_paths=args.search_path) + configs_file=args.configs, + srcdir=args.srctree, + allowed_file=args.allowed, + prefix=args.prefix, + use_defines=args.use_defines, + search_paths=args.search_path, + ignore=args.ignore, + ) + elif args.cmd == "build": + return checker.do_build( + configs_file=args.configs, + srcdir=args.srctree, + allowed_file=args.allowed, + prefix=args.prefix, + use_defines=args.use_defines, + search_paths=args.search_path, + ) return 2 -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main(sys.argv[1:])) diff --git a/util/kconfiglib.py b/util/kconfiglib.py index 0e05aaaeac..a0033bba2d 100644 --- a/util/kconfiglib.py +++ b/util/kconfiglib.py @@ -553,7 +553,6 @@ import sys from glob import iglob from os.path import dirname, exists, expandvars, islink, join, realpath - VERSION = (14, 1, 0) @@ -810,6 +809,7 @@ class Kconfig(object): The current parsing location, for use in Python preprocessor functions. See the module docstring. """ + __slots__ = ( "_encoding", "_functions", @@ -848,7 +848,6 @@ class Kconfig(object): "warn_to_stderr", "warnings", "y", - # Parsing-related "_parsing_kconfigs", "_readline", @@ -866,9 +865,16 @@ class Kconfig(object): # Public interface # - def __init__(self, filename="Kconfig", warn=True, warn_to_stderr=True, - encoding="utf-8", suppress_traceback=False, search_paths=None, - allow_empty_macros=False): + def __init__( + self, + filename="Kconfig", + warn=True, + warn_to_stderr=True, + encoding="utf-8", + suppress_traceback=False, + search_paths=None, + allow_empty_macros=False, + ): """ Creates a new Kconfig object by parsing Kconfig files. Note that Kconfig files are not the same as .config files (which store @@ -972,8 +978,14 @@ class Kconfig(object): Pass True here to allow empty / undefined macros. """ try: - self._init(filename, warn, warn_to_stderr, encoding, search_paths, - allow_empty_macros) + self._init( + filename, + warn, + warn_to_stderr, + encoding, + search_paths, + allow_empty_macros, + ) except (EnvironmentError, KconfigError) as e: if suppress_traceback: cmd = sys.argv[0] # Empty string if missing @@ -985,8 +997,9 @@ class Kconfig(object): sys.exit(cmd + str(e).strip()) raise - def _init(self, filename, warn, warn_to_stderr, encoding, search_paths, - allow_empty_macros): + def _init( + self, filename, warn, warn_to_stderr, encoding, search_paths, allow_empty_macros + ): # See __init__() self._encoding = encoding @@ -1011,8 +1024,9 @@ class Kconfig(object): self.config_prefix = os.getenv("CONFIG_", "CONFIG_") # Regular expressions for parsing .config files self._set_match = _re_match(self.config_prefix + r"([^=]+)=(.*)") - self._unset_match = _re_match(r"# {}([^ ]+) is not set".format( - self.config_prefix)) + self._unset_match = _re_match( + r"# {}([^ ]+) is not set".format(self.config_prefix) + ) self.config_header = os.getenv("KCONFIG_CONFIG_HEADER", "") self.header_header = os.getenv("KCONFIG_AUTOHEADER_HEADER", "") @@ -1050,11 +1064,11 @@ class Kconfig(object): # Predefined preprocessor functions, with min/max number of arguments self._functions = { - "info": (_info_fn, 1, 1), - "error-if": (_error_if_fn, 2, 2), - "filename": (_filename_fn, 0, 0), - "lineno": (_lineno_fn, 0, 0), - "shell": (_shell_fn, 1, 1), + "info": (_info_fn, 1, 1), + "error-if": (_error_if_fn, 2, 2), + "filename": (_filename_fn, 0, 0), + "lineno": (_lineno_fn, 0, 0), + "shell": (_shell_fn, 1, 1), "warning-if": (_warning_if_fn, 2, 2), } @@ -1063,7 +1077,8 @@ class Kconfig(object): self._functions.update( importlib.import_module( os.getenv("KCONFIG_FUNCTIONS", "kconfigfunctions") - ).functions) + ).functions + ) except ImportError: pass @@ -1138,8 +1153,7 @@ class Kconfig(object): # KCONFIG_STRICT is an older alias for KCONFIG_WARN_UNDEF, supported # for backwards compatibility - if os.getenv("KCONFIG_WARN_UNDEF") == "y" or \ - os.getenv("KCONFIG_STRICT") == "y": + if os.getenv("KCONFIG_WARN_UNDEF") == "y" or os.getenv("KCONFIG_STRICT") == "y": self._check_undef_syms() @@ -1247,15 +1261,14 @@ class Kconfig(object): msg = None if filename is None: filename = standard_config_filename() - if not exists(filename) and \ - not exists(join(self.srctree, filename)): + if not exists(filename) and not exists(join(self.srctree, filename)): defconfig = self.defconfig_filename if defconfig is None: - return "Using default symbol values (no '{}')" \ - .format(filename) + return "Using default symbol values (no '{}')".format(filename) - msg = " default configuration '{}' (no '{}')" \ - .format(defconfig, filename) + msg = " default configuration '{}' (no '{}')".format( + defconfig, filename + ) filename = defconfig if not msg: @@ -1313,15 +1326,20 @@ class Kconfig(object): if sym.orig_type in _BOOL_TRISTATE: # The C implementation only checks the first character # to the right of '=', for whatever reason - if not (sym.orig_type is BOOL - and val.startswith(("y", "n")) or - sym.orig_type is TRISTATE - and val.startswith(("y", "m", "n"))): - self._warn("'{}' is not a valid value for the {} " - "symbol {}. Assignment ignored." - .format(val, TYPE_TO_STR[sym.orig_type], - sym.name_and_loc), - filename, linenr) + if not ( + sym.orig_type is BOOL + and val.startswith(("y", "n")) + or sym.orig_type is TRISTATE + and val.startswith(("y", "m", "n")) + ): + self._warn( + "'{}' is not a valid value for the {} " + "symbol {}. Assignment ignored.".format( + val, TYPE_TO_STR[sym.orig_type], sym.name_and_loc + ), + filename, + linenr, + ) continue val = val[0] @@ -1332,12 +1350,14 @@ class Kconfig(object): # to the choice symbols prev_mode = sym.choice.user_value - if prev_mode is not None and \ - TRI_TO_STR[prev_mode] != val: + if prev_mode is not None and TRI_TO_STR[prev_mode] != val: - self._warn("both m and y assigned to symbols " - "within the same choice", - filename, linenr) + self._warn( + "both m and y assigned to symbols " + "within the same choice", + filename, + linenr, + ) # Set the choice's mode sym.choice.set_value(val) @@ -1345,10 +1365,14 @@ class Kconfig(object): elif sym.orig_type is STRING: match = _conf_string_match(val) if not match: - self._warn("malformed string literal in " - "assignment to {}. Assignment ignored." - .format(sym.name_and_loc), - filename, linenr) + self._warn( + "malformed string literal in " + "assignment to {}. Assignment ignored.".format( + sym.name_and_loc + ), + filename, + linenr, + ) continue val = unescape(match.group(1)) @@ -1361,9 +1385,11 @@ class Kconfig(object): # lines or comments. 'line' has already been # rstrip()'d, so blank lines show up as "" here. if line and not line.lstrip().startswith("#"): - self._warn("ignoring malformed line '{}'" - .format(line), - filename, linenr) + self._warn( + "ignoring malformed line '{}'".format(line), + filename, + linenr, + ) continue @@ -1403,8 +1429,12 @@ class Kconfig(object): self.missing_syms.append((name, val)) if self.warn_assign_undef: self._warn( - "attempt to assign the value '{}' to the undefined symbol {}" - .format(val, name), filename, linenr) + "attempt to assign the value '{}' to the undefined symbol {}".format( + val, name + ), + filename, + linenr, + ) def _assigned_twice(self, sym, new_val, filename, linenr): # Called when a symbol is assigned more than once in a .config file @@ -1416,7 +1446,8 @@ class Kconfig(object): user_val = sym.user_value msg = '{} set more than once. Old value "{}", new value "{}".'.format( - sym.name_and_loc, user_val, new_val) + sym.name_and_loc, user_val, new_val + ) if user_val == new_val: if self.warn_assign_redun: @@ -1482,8 +1513,7 @@ class Kconfig(object): in tools, which can do e.g. print(kconf.write_autoconf()). """ if filename is None: - filename = os.getenv("KCONFIG_AUTOHEADER", - "include/generated/autoconf.h") + filename = os.getenv("KCONFIG_AUTOHEADER", "include/generated/autoconf.h") if self._write_if_changed(filename, self._autoconf_contents(header)): return "Kconfig header saved to '{}'".format(filename) @@ -1512,28 +1542,26 @@ class Kconfig(object): if sym.orig_type in _BOOL_TRISTATE: if val == "y": - add("#define {}{} 1\n" - .format(self.config_prefix, sym.name)) + add("#define {}{} 1\n".format(self.config_prefix, sym.name)) elif val == "m": - add("#define {}{}_MODULE 1\n" - .format(self.config_prefix, sym.name)) + add("#define {}{}_MODULE 1\n".format(self.config_prefix, sym.name)) elif sym.orig_type is STRING: - add('#define {}{} "{}"\n' - .format(self.config_prefix, sym.name, escape(val))) + add( + '#define {}{} "{}"\n'.format( + self.config_prefix, sym.name, escape(val) + ) + ) else: # sym.orig_type in _INT_HEX: - if sym.orig_type is HEX and \ - not val.startswith(("0x", "0X")): + if sym.orig_type is HEX and not val.startswith(("0x", "0X")): val = "0x" + val - add("#define {}{} {}\n" - .format(self.config_prefix, sym.name, val)) + add("#define {}{} {}\n".format(self.config_prefix, sym.name, val)) return "".join(chunks) - def write_config(self, filename=None, header=None, save_old=True, - verbose=None): + def write_config(self, filename=None, header=None, save_old=True, verbose=None): r""" Writes out symbol values in the .config format. The format matches the C implementation, including ordering. @@ -1647,9 +1675,12 @@ class Kconfig(object): node = node.parent # Add a comment when leaving visible menus - if node.item is MENU and expr_value(node.dep) and \ - expr_value(node.visibility) and \ - node is not self.top_node: + if ( + node.item is MENU + and expr_value(node.dep) + and expr_value(node.visibility) + and node is not self.top_node + ): add("# end of {}\n".format(node.prompt[0])) after_end_comment = True @@ -1680,9 +1711,9 @@ class Kconfig(object): add("\n") add(conf_string) - elif expr_value(node.dep) and \ - ((item is MENU and expr_value(node.visibility)) or - item is COMMENT): + elif expr_value(node.dep) and ( + (item is MENU and expr_value(node.visibility)) or item is COMMENT + ): add("\n#\n# {}\n#\n".format(node.prompt[0])) after_end_comment = False @@ -1738,8 +1769,7 @@ class Kconfig(object): # Skip symbols that cannot be changed. Only check # non-choice symbols, as selects don't affect choice # symbols. - if not sym.choice and \ - sym.visibility <= expr_value(sym.rev_dep): + if not sym.choice and sym.visibility <= expr_value(sym.rev_dep): continue # Skip symbols whose value matches their default @@ -1750,11 +1780,13 @@ class Kconfig(object): # choice, unless the choice is optional or the symbol type # isn't bool (it might be possible to set the choice mode # to n or the symbol to m in those cases). - if sym.choice and \ - not sym.choice.is_optional and \ - sym.choice._selection_from_defaults() is sym and \ - sym.orig_type is BOOL and \ - sym.tri_value == 2: + if ( + sym.choice + and not sym.choice.is_optional + and sym.choice._selection_from_defaults() is sym + and sym.orig_type is BOOL + and sym.tri_value == 2 + ): continue add(sym.config_string) @@ -1842,9 +1874,11 @@ class Kconfig(object): # making a missing symbol logically equivalent to n if sym._write_to_conf: - if sym._old_val is None and \ - sym.orig_type in _BOOL_TRISTATE and \ - val == "n": + if ( + sym._old_val is None + and sym.orig_type in _BOOL_TRISTATE + and val == "n" + ): # No old value (the symbol was missing or n), new value n. # No change. continue @@ -1924,17 +1958,20 @@ class Kconfig(object): # by passing a flag to it, plus we only need to look at symbols here. self._write_if_changed( - os.path.join(path, "auto.conf"), - self._old_vals_contents()) + os.path.join(path, "auto.conf"), self._old_vals_contents() + ) def _old_vals_contents(self): # _write_old_vals() helper. Returns the contents to write as a string. # Temporary list instead of generator makes this a bit faster - return "".join([ - sym.config_string for sym in self.unique_defined_syms + return "".join( + [ + sym.config_string + for sym in self.unique_defined_syms if not (sym.orig_type in _BOOL_TRISTATE and not sym.tri_value) - ]) + ] + ) def node_iter(self, unique_syms=False): """ @@ -2112,30 +2149,35 @@ class Kconfig(object): Returns a string with information about the Kconfig object when it is evaluated on e.g. the interactive Python prompt. """ + def status(flag): return "enabled" if flag else "disabled" - return "<{}>".format(", ".join(( - "configuration with {} symbols".format(len(self.syms)), - 'main menu prompt "{}"'.format(self.mainmenu_text), - "srctree is current directory" if not self.srctree else - 'srctree "{}"'.format(self.srctree), - 'config symbol prefix "{}"'.format(self.config_prefix), - "warnings " + status(self.warn), - "printing of warnings to stderr " + status(self.warn_to_stderr), - "undef. symbol assignment warnings " + - status(self.warn_assign_undef), - "overriding symbol assignment warnings " + - status(self.warn_assign_override), - "redundant symbol assignment warnings " + - status(self.warn_assign_redun) - ))) + return "<{}>".format( + ", ".join( + ( + "configuration with {} symbols".format(len(self.syms)), + 'main menu prompt "{}"'.format(self.mainmenu_text), + "srctree is current directory" + if not self.srctree + else 'srctree "{}"'.format(self.srctree), + 'config symbol prefix "{}"'.format(self.config_prefix), + "warnings " + status(self.warn), + "printing of warnings to stderr " + status(self.warn_to_stderr), + "undef. symbol assignment warnings " + + status(self.warn_assign_undef), + "overriding symbol assignment warnings " + + status(self.warn_assign_override), + "redundant symbol assignment warnings " + + status(self.warn_assign_redun), + ) + ) + ) # # Private methods # - # # File reading # @@ -2160,11 +2202,17 @@ class Kconfig(object): e = e2 raise _KconfigIOError( - e, "Could not open '{}' ({}: {}). Check that the $srctree " - "environment variable ({}) is set correctly." - .format(filename, errno.errorcode[e.errno], e.strerror, - "set to '{}'".format(self.srctree) if self.srctree - else "unset or blank")) + e, + "Could not open '{}' ({}: {}). Check that the $srctree " + "environment variable ({}) is set correctly.".format( + filename, + errno.errorcode[e.errno], + e.strerror, + "set to '{}'".format(self.srctree) + if self.srctree + else "unset or blank", + ), + ) def _enter_file(self, filename): # Jumps to the beginning of a sourced Kconfig file, saving the previous @@ -2179,7 +2227,7 @@ class Kconfig(object): if filename.startswith(self._srctree_prefix): # Relative path (or a redundant absolute path to within $srctree, # but it's probably fine to reduce those too) - rel_filename = filename[len(self._srctree_prefix):] + rel_filename = filename[len(self._srctree_prefix) :] else: # Absolute path rel_filename = filename @@ -2212,20 +2260,32 @@ class Kconfig(object): raise KconfigError( "\n{}:{}: recursive 'source' of '{}' detected. Check that " "environment variables are set correctly.\n" - "Include path:\n{}" - .format(self.filename, self.linenr, rel_filename, - "\n".join("{}:{}".format(name, linenr) - for name, linenr in self._include_path))) + "Include path:\n{}".format( + self.filename, + self.linenr, + rel_filename, + "\n".join( + "{}:{}".format(name, linenr) + for name, linenr in self._include_path + ), + ) + ) try: self._readline = self._open(filename, "r").readline except EnvironmentError as e: # We already know that the file exists raise _KconfigIOError( - e, "{}:{}: Could not open '{}' (in '{}') ({}: {})" - .format(self.filename, self.linenr, filename, - self._line.strip(), - errno.errorcode[e.errno], e.strerror)) + e, + "{}:{}: Could not open '{}' (in '{}') ({}: {})".format( + self.filename, + self.linenr, + filename, + self._line.strip(), + errno.errorcode[e.errno], + e.strerror, + ), + ) self.filename = rel_filename self.linenr = 0 @@ -2438,8 +2498,11 @@ class Kconfig(object): else: i = match.end() - token = self.const_syms[name] if name in STR_TO_TRI else \ - self._lookup_sym(name) + token = ( + self.const_syms[name] + if name in STR_TO_TRI + else self._lookup_sym(name) + ) else: # It's a case of missing quotes. For example, the @@ -2455,9 +2518,13 @@ class Kconfig(object): # Named choices ('choice FOO') also end up here. if token is not _T_CHOICE: - self._warn("style: quotes recommended around '{}' in '{}'" - .format(name, self._line.strip()), - self.filename, self.linenr) + self._warn( + "style: quotes recommended around '{}' in '{}'".format( + name, self._line.strip() + ), + self.filename, + self.linenr, + ) token = name i = match.end() @@ -2476,7 +2543,7 @@ class Kconfig(object): end_i = s.find(c, i + 1) + 1 if not end_i: self._parse_error("unterminated string") - val = s[i + 1:end_i - 1] + val = s[i + 1 : end_i - 1] i = end_i else: # Slow path @@ -2489,18 +2556,22 @@ class Kconfig(object): # # The preprocessor functionality changed how # environment variables are referenced, to $(FOO). - val = expandvars(s[i + 1:end_i - 1] - .replace("$UNAME_RELEASE", - _UNAME_RELEASE)) + val = expandvars( + s[i + 1 : end_i - 1].replace( + "$UNAME_RELEASE", _UNAME_RELEASE + ) + ) i = end_i # This is the only place where we don't survive with a # single token of lookback: 'option env="FOO"' does not # refer to a constant symbol named "FOO". - token = \ - val if token in _STRING_LEX or tokens[0] is _T_OPTION \ + token = ( + val + if token in _STRING_LEX or tokens[0] is _T_OPTION else self._lookup_const_sym(val) + ) elif s.startswith("&&", i): token = _T_AND @@ -2533,7 +2604,6 @@ class Kconfig(object): elif c == "#": break - # Very rare elif s.startswith("<=", i): @@ -2552,16 +2622,13 @@ class Kconfig(object): token = _T_GREATER i += 1 - else: self._parse_error("unknown tokens in line") - # Skip trailing whitespace while i < len(s) and s[i].isspace(): i += 1 - # Add the token tokens.append(token) @@ -2652,7 +2719,6 @@ class Kconfig(object): # Assigned variable name = s[:i] - # Extract assignment operator (=, :=, or +=) and value rhs_match = _assignment_rhs_match(s, i) if not rhs_match: @@ -2660,7 +2726,6 @@ class Kconfig(object): op, val = rhs_match.groups() - if name in self.variables: # Already seen variable var = self.variables[name] @@ -2686,8 +2751,9 @@ class Kconfig(object): else: # op == "+=" # += does immediate expansion if the variable was last set # with := - var.value += " " + (val if var.is_recursive else - self._expand_whole(val, ())) + var.value += " " + ( + val if var.is_recursive else self._expand_whole(val, ()) + ) def _expand_whole(self, s, args): # Expands preprocessor macros in all of 's'. Used whenever we don't @@ -2753,7 +2819,6 @@ class Kconfig(object): if not match: self._parse_error("unterminated string") - if match.group() == quote: # Found the end of the string return (s, match.end()) @@ -2762,7 +2827,7 @@ class Kconfig(object): # Replace '\x' with 'x'. 'i' ends up pointing to the character # after 'x', which allows macros to be canceled with '\$(foo)'. i = match.end() - s = s[:match.start()] + s[i:] + s = s[: match.start()] + s[i:] elif match.group() == "$(": # A macro call within the string @@ -2792,7 +2857,6 @@ class Kconfig(object): if not match: self._parse_error("missing end parenthesis in macro expansion") - if match.group() == "(": nesting += 1 i = match.end() @@ -2805,7 +2869,7 @@ class Kconfig(object): # Found the end of the macro - new_args.append(s[arg_start:match.start()]) + new_args.append(s[arg_start : match.start()]) # $(1) is replaced by the first argument to the function, etc., # provided at least that many arguments were passed @@ -2819,7 +2883,7 @@ class Kconfig(object): # and also go through the function value path res += self._fn_val(new_args) - return (res + s[match.end():], len(res)) + return (res + s[match.end() :], len(res)) elif match.group() == ",": i = match.end() @@ -2827,7 +2891,7 @@ class Kconfig(object): continue # Found the end of a macro argument - new_args.append(s[arg_start:match.start()]) + new_args.append(s[arg_start : match.start()]) arg_start = i else: # match.group() == "$(" @@ -2847,13 +2911,17 @@ class Kconfig(object): if len(args) == 1: # Plain variable if var._n_expansions: - self._parse_error("Preprocessor variable {} recursively " - "references itself".format(var.name)) + self._parse_error( + "Preprocessor variable {} recursively " + "references itself".format(var.name) + ) elif var._n_expansions > 100: # Allow functions to call themselves, but guess that functions # that are overly recursive are stuck - self._parse_error("Preprocessor function {} seems stuck " - "in infinite recursion".format(var.name)) + self._parse_error( + "Preprocessor function {} seems stuck " + "in infinite recursion".format(var.name) + ) var._n_expansions += 1 res = self._expand_whole(self.variables[fn].value, args) @@ -2865,8 +2933,9 @@ class Kconfig(object): py_fn, min_arg, max_arg = self._functions[fn] - if len(args) - 1 < min_arg or \ - (max_arg is not None and len(args) - 1 > max_arg): + if len(args) - 1 < min_arg or ( + max_arg is not None and len(args) - 1 > max_arg + ): if min_arg == max_arg: expected_args = min_arg @@ -2875,10 +2944,12 @@ class Kconfig(object): else: expected_args = "{}-{}".format(min_arg, max_arg) - raise KconfigError("{}:{}: bad number of arguments in call " - "to {}, expected {}, got {}" - .format(self.filename, self.linenr, fn, - expected_args, len(args) - 1)) + raise KconfigError( + "{}:{}: bad number of arguments in call " + "to {}, expected {}, got {}".format( + self.filename, self.linenr, fn, expected_args, len(args) - 1 + ) + ) return py_fn(self, *args) @@ -2962,7 +3033,7 @@ class Kconfig(object): node = MenuNode() node.kconfig = self node.item = sym - node.is_menuconfig = (t0 is _T_MENUCONFIG) + node.is_menuconfig = t0 is _T_MENUCONFIG node.prompt = node.help = node.list = None node.parent = parent node.filename = self.filename @@ -2974,8 +3045,11 @@ class Kconfig(object): self._parse_props(node) if node.is_menuconfig and not node.prompt: - self._warn("the menuconfig symbol {} has no prompt" - .format(sym.name_and_loc)) + self._warn( + "the menuconfig symbol {} has no prompt".format( + sym.name_and_loc + ) + ) # Equivalent to # @@ -3014,11 +3088,16 @@ class Kconfig(object): "{}:{}: '{}' not found (in '{}'). Check that " "environment variables are set correctly (e.g. " "$srctree, which is {}). Also note that unset " - "environment variables expand to the empty string." - .format(self.filename, self.linenr, pattern, - self._line.strip(), - "set to '{}'".format(self.srctree) - if self.srctree else "unset or blank")) + "environment variables expand to the empty string.".format( + self.filename, + self.linenr, + pattern, + self._line.strip(), + "set to '{}'".format(self.srctree) + if self.srctree + else "unset or blank", + ) + ) for filename in filenames: self._enter_file(filename) @@ -3125,20 +3204,28 @@ class Kconfig(object): # A valid endchoice/endif/endmenu is caught by the 'end_token' # check above self._parse_error( - "no corresponding 'choice'" if t0 is _T_ENDCHOICE else - "no corresponding 'if'" if t0 is _T_ENDIF else - "no corresponding 'menu'" if t0 is _T_ENDMENU else - "unrecognized construct") + "no corresponding 'choice'" + if t0 is _T_ENDCHOICE + else "no corresponding 'if'" + if t0 is _T_ENDIF + else "no corresponding 'menu'" + if t0 is _T_ENDMENU + else "unrecognized construct" + ) # End of file reached. Return the last node. if end_token: raise KconfigError( - "error: expected '{}' at end of '{}'" - .format("endchoice" if end_token is _T_ENDCHOICE else - "endif" if end_token is _T_ENDIF else - "endmenu", - self.filename)) + "error: expected '{}' at end of '{}'".format( + "endchoice" + if end_token is _T_ENDCHOICE + else "endif" + if end_token is _T_ENDIF + else "endmenu", + self.filename, + ) + ) return prev @@ -3187,8 +3274,7 @@ class Kconfig(object): if not self._check_token(_T_ON): self._parse_error("expected 'on' after 'depends'") - node.dep = self._make_and(node.dep, - self._expect_expr_and_eol()) + node.dep = self._make_and(node.dep, self._expect_expr_and_eol()) elif t0 is _T_HELP: self._parse_help(node) @@ -3197,42 +3283,40 @@ class Kconfig(object): if node.item.__class__ is not Symbol: self._parse_error("only symbols can select") - node.selects.append((self._expect_nonconst_sym(), - self._parse_cond())) + node.selects.append((self._expect_nonconst_sym(), self._parse_cond())) elif t0 is None: # Blank line continue elif t0 is _T_DEFAULT: - node.defaults.append((self._parse_expr(False), - self._parse_cond())) + node.defaults.append((self._parse_expr(False), self._parse_cond())) elif t0 in _DEF_TOKEN_TO_TYPE: self._set_type(node.item, _DEF_TOKEN_TO_TYPE[t0]) - node.defaults.append((self._parse_expr(False), - self._parse_cond())) + node.defaults.append((self._parse_expr(False), self._parse_cond())) elif t0 is _T_PROMPT: self._parse_prompt(node) elif t0 is _T_RANGE: - node.ranges.append((self._expect_sym(), self._expect_sym(), - self._parse_cond())) + node.ranges.append( + (self._expect_sym(), self._expect_sym(), self._parse_cond()) + ) elif t0 is _T_IMPLY: if node.item.__class__ is not Symbol: self._parse_error("only symbols can imply") - node.implies.append((self._expect_nonconst_sym(), - self._parse_cond())) + node.implies.append((self._expect_nonconst_sym(), self._parse_cond())) elif t0 is _T_VISIBLE: if not self._check_token(_T_IF): self._parse_error("expected 'if' after 'visible'") - node.visibility = self._make_and(node.visibility, - self._expect_expr_and_eol()) + node.visibility = self._make_and( + node.visibility, self._expect_expr_and_eol() + ) elif t0 is _T_OPTION: if self._check_token(_T_ENV): @@ -3244,33 +3328,42 @@ class Kconfig(object): if env_var in os.environ: node.defaults.append( - (self._lookup_const_sym(os.environ[env_var]), - self.y)) + (self._lookup_const_sym(os.environ[env_var]), self.y) + ) else: - self._warn("{1} has 'option env=\"{0}\"', " - "but the environment variable {0} is not " - "set".format(node.item.name, env_var), - self.filename, self.linenr) + self._warn( + "{1} has 'option env=\"{0}\"', " + "but the environment variable {0} is not " + "set".format(node.item.name, env_var), + self.filename, + self.linenr, + ) if env_var != node.item.name: - self._warn("Kconfiglib expands environment variables " - "in strings directly, meaning you do not " - "need 'option env=...' \"bounce\" symbols. " - "For compatibility with the C tools, " - "rename {} to {} (so that the symbol name " - "matches the environment variable name)." - .format(node.item.name, env_var), - self.filename, self.linenr) + self._warn( + "Kconfiglib expands environment variables " + "in strings directly, meaning you do not " + "need 'option env=...' \"bounce\" symbols. " + "For compatibility with the C tools, " + "rename {} to {} (so that the symbol name " + "matches the environment variable name).".format( + node.item.name, env_var + ), + self.filename, + self.linenr, + ) elif self._check_token(_T_DEFCONFIG_LIST): if not self.defconfig_list: self.defconfig_list = node.item else: - self._warn("'option defconfig_list' set on multiple " - "symbols ({0} and {1}). Only {0} will be " - "used.".format(self.defconfig_list.name, - node.item.name), - self.filename, self.linenr) + self._warn( + "'option defconfig_list' set on multiple " + "symbols ({0} and {1}). Only {0} will be " + "used.".format(self.defconfig_list.name, node.item.name), + self.filename, + self.linenr, + ) elif self._check_token(_T_MODULES): # To reduce warning spam, only warn if 'option modules' is @@ -3279,20 +3372,24 @@ class Kconfig(object): # modules besides the kernel yet, and there it's likely to # keep being called "MODULES". if node.item is not self.modules: - self._warn("the 'modules' option is not supported. " - "Let me know if this is a problem for you, " - "as it wouldn't be that hard to implement. " - "Note that modules are supported -- " - "Kconfiglib just assumes the symbol name " - "MODULES, like older versions of the C " - "implementation did when 'option modules' " - "wasn't used.", - self.filename, self.linenr) + self._warn( + "the 'modules' option is not supported. " + "Let me know if this is a problem for you, " + "as it wouldn't be that hard to implement. " + "Note that modules are supported -- " + "Kconfiglib just assumes the symbol name " + "MODULES, like older versions of the C " + "implementation did when 'option modules' " + "wasn't used.", + self.filename, + self.linenr, + ) elif self._check_token(_T_ALLNOCONFIG_Y): if node.item.__class__ is not Symbol: - self._parse_error("the 'allnoconfig_y' option is only " - "valid for symbols") + self._parse_error( + "the 'allnoconfig_y' option is only " "valid for symbols" + ) node.item.is_allnoconfig_y = True @@ -3315,8 +3412,11 @@ class Kconfig(object): # UNKNOWN is falsy if sc.orig_type and sc.orig_type is not new_type: - self._warn("{} defined with multiple types, {} will be used" - .format(sc.name_and_loc, TYPE_TO_STR[new_type])) + self._warn( + "{} defined with multiple types, {} will be used".format( + sc.name_and_loc, TYPE_TO_STR[new_type] + ) + ) sc.orig_type = new_type @@ -3326,8 +3426,10 @@ class Kconfig(object): # multiple times if node.prompt: - self._warn(node.item.name_and_loc + - " defined with multiple prompts in single location") + self._warn( + node.item.name_and_loc + + " defined with multiple prompts in single location" + ) prompt = self._tokens[1] self._tokens_i = 2 @@ -3336,8 +3438,10 @@ class Kconfig(object): self._parse_error("expected prompt string") if prompt != prompt.strip(): - self._warn(node.item.name_and_loc + - " has leading or trailing whitespace in its prompt") + self._warn( + node.item.name_and_loc + + " has leading or trailing whitespace in its prompt" + ) # This avoid issues for e.g. reStructuredText documentation, where # '*prompt *' is invalid @@ -3347,8 +3451,10 @@ class Kconfig(object): def _parse_help(self, node): if node.help is not None: - self._warn(node.item.name_and_loc + " defined with more than " - "one help text -- only the last one will be used") + self._warn( + node.item.name_and_loc + " defined with more than " + "one help text -- only the last one will be used" + ) # Micro-optimization. This code is pretty hot. readline = self._readline @@ -3403,8 +3509,7 @@ class Kconfig(object): self._line_after_help(line) def _empty_help(self, node, line): - self._warn(node.item.name_and_loc + - " has 'help' but empty help text") + self._warn(node.item.name_and_loc + " has 'help' but empty help text") node.help = "" if line: self._line_after_help(line) @@ -3447,8 +3552,11 @@ class Kconfig(object): # Return 'and_expr' directly if we have a "single-operand" OR. # Otherwise, parse the expression on the right and make an OR node. # This turns A || B || C || D into (OR, A, (OR, B, (OR, C, D))). - return and_expr if not self._check_token(_T_OR) else \ - (OR, and_expr, self._parse_expr(transform_m)) + return ( + and_expr + if not self._check_token(_T_OR) + else (OR, and_expr, self._parse_expr(transform_m)) + ) def _parse_and_expr(self, transform_m): factor = self._parse_factor(transform_m) @@ -3456,8 +3564,11 @@ class Kconfig(object): # Return 'factor' directly if we have a "single-operand" AND. # Otherwise, parse the right operand and make an AND node. This turns # A && B && C && D into (AND, A, (AND, B, (AND, C, D))). - return factor if not self._check_token(_T_AND) else \ - (AND, factor, self._parse_and_expr(transform_m)) + return ( + factor + if not self._check_token(_T_AND) + else (AND, factor, self._parse_and_expr(transform_m)) + ) def _parse_factor(self, transform_m): token = self._tokens[self._tokens_i] @@ -3481,8 +3592,7 @@ class Kconfig(object): # _T_EQUAL, _T_UNEQUAL, etc., deliberately have the same values as # EQUAL, UNEQUAL, etc., so we can just use the token directly self._tokens_i += 1 - return (self._tokens[self._tokens_i - 1], token, - self._expect_sym()) + return (self._tokens[self._tokens_i - 1], token, self._expect_sym()) if token is _T_NOT: # token == _T_NOT == NOT @@ -3689,36 +3799,43 @@ class Kconfig(object): if cur.item.__class__ in _SYMBOL_CHOICE: # Propagate 'visible if' and dependencies to the prompt if cur.prompt: - cur.prompt = (cur.prompt[0], - self._make_and( - cur.prompt[1], - self._make_and(visible_if, dep))) + cur.prompt = ( + cur.prompt[0], + self._make_and(cur.prompt[1], self._make_and(visible_if, dep)), + ) # Propagate dependencies to defaults if cur.defaults: - cur.defaults = [(default, self._make_and(cond, dep)) - for default, cond in cur.defaults] + cur.defaults = [ + (default, self._make_and(cond, dep)) + for default, cond in cur.defaults + ] # Propagate dependencies to ranges if cur.ranges: - cur.ranges = [(low, high, self._make_and(cond, dep)) - for low, high, cond in cur.ranges] + cur.ranges = [ + (low, high, self._make_and(cond, dep)) + for low, high, cond in cur.ranges + ] # Propagate dependencies to selects if cur.selects: - cur.selects = [(target, self._make_and(cond, dep)) - for target, cond in cur.selects] + cur.selects = [ + (target, self._make_and(cond, dep)) + for target, cond in cur.selects + ] # Propagate dependencies to implies if cur.implies: - cur.implies = [(target, self._make_and(cond, dep)) - for target, cond in cur.implies] + cur.implies = [ + (target, self._make_and(cond, dep)) + for target, cond in cur.implies + ] elif cur.prompt: # Not a symbol/choice # Propagate dependencies to the prompt. 'visible if' is only # propagated to symbols/choices. - cur.prompt = (cur.prompt[0], - self._make_and(cur.prompt[1], dep)) + cur.prompt = (cur.prompt[0], self._make_and(cur.prompt[1], dep)) cur = cur.next @@ -3744,16 +3861,14 @@ class Kconfig(object): # Modify the reverse dependencies of the selected symbol for target, cond in node.selects: - target.rev_dep = self._make_or( - target.rev_dep, - self._make_and(sym, cond)) + target.rev_dep = self._make_or(target.rev_dep, self._make_and(sym, cond)) # Modify the weak reverse dependencies of the implied # symbol for target, cond in node.implies: target.weak_rev_dep = self._make_or( - target.weak_rev_dep, - self._make_and(sym, cond)) + target.weak_rev_dep, self._make_and(sym, cond) + ) # # Misc. @@ -3781,82 +3896,106 @@ class Kconfig(object): for target_sym, _ in sym.selects: if target_sym.orig_type not in _BOOL_TRISTATE_UNKNOWN: - self._warn("{} selects the {} symbol {}, which is not " - "bool or tristate" - .format(sym.name_and_loc, - TYPE_TO_STR[target_sym.orig_type], - target_sym.name_and_loc)) + self._warn( + "{} selects the {} symbol {}, which is not " + "bool or tristate".format( + sym.name_and_loc, + TYPE_TO_STR[target_sym.orig_type], + target_sym.name_and_loc, + ) + ) for target_sym, _ in sym.implies: if target_sym.orig_type not in _BOOL_TRISTATE_UNKNOWN: - self._warn("{} implies the {} symbol {}, which is not " - "bool or tristate" - .format(sym.name_and_loc, - TYPE_TO_STR[target_sym.orig_type], - target_sym.name_and_loc)) + self._warn( + "{} implies the {} symbol {}, which is not " + "bool or tristate".format( + sym.name_and_loc, + TYPE_TO_STR[target_sym.orig_type], + target_sym.name_and_loc, + ) + ) elif sym.orig_type: # STRING/INT/HEX for default, _ in sym.defaults: if default.__class__ is not Symbol: raise KconfigError( "the {} symbol {} has a malformed default {} -- " - "expected a single symbol" - .format(TYPE_TO_STR[sym.orig_type], - sym.name_and_loc, expr_str(default))) + "expected a single symbol".format( + TYPE_TO_STR[sym.orig_type], + sym.name_and_loc, + expr_str(default), + ) + ) if sym.orig_type is STRING: - if not default.is_constant and not default.nodes and \ - not default.name.isupper(): + if ( + not default.is_constant + and not default.nodes + and not default.name.isupper() + ): # 'default foo' on a string symbol could be either a symbol # reference or someone leaving out the quotes. Guess that # the quotes were left out if 'foo' isn't all-uppercase # (and no symbol named 'foo' exists). - self._warn("style: quotes recommended around " - "default value for string symbol " - + sym.name_and_loc) + self._warn( + "style: quotes recommended around " + "default value for string symbol " + sym.name_and_loc + ) elif not num_ok(default, sym.orig_type): # INT/HEX - self._warn("the {0} symbol {1} has a non-{0} default {2}" - .format(TYPE_TO_STR[sym.orig_type], - sym.name_and_loc, - default.name_and_loc)) + self._warn( + "the {0} symbol {1} has a non-{0} default {2}".format( + TYPE_TO_STR[sym.orig_type], + sym.name_and_loc, + default.name_and_loc, + ) + ) if sym.selects or sym.implies: - self._warn("the {} symbol {} has selects or implies" - .format(TYPE_TO_STR[sym.orig_type], - sym.name_and_loc)) + self._warn( + "the {} symbol {} has selects or implies".format( + TYPE_TO_STR[sym.orig_type], sym.name_and_loc + ) + ) else: # UNKNOWN - self._warn("{} defined without a type" - .format(sym.name_and_loc)) - + self._warn("{} defined without a type".format(sym.name_and_loc)) if sym.ranges: if sym.orig_type not in _INT_HEX: self._warn( - "the {} symbol {} has ranges, but is not int or hex" - .format(TYPE_TO_STR[sym.orig_type], - sym.name_and_loc)) + "the {} symbol {} has ranges, but is not int or hex".format( + TYPE_TO_STR[sym.orig_type], sym.name_and_loc + ) + ) else: for low, high, _ in sym.ranges: - if not num_ok(low, sym.orig_type) or \ - not num_ok(high, sym.orig_type): - - self._warn("the {0} symbol {1} has a non-{0} " - "range [{2}, {3}]" - .format(TYPE_TO_STR[sym.orig_type], - sym.name_and_loc, - low.name_and_loc, - high.name_and_loc)) + if not num_ok(low, sym.orig_type) or not num_ok( + high, sym.orig_type + ): + + self._warn( + "the {0} symbol {1} has a non-{0} " + "range [{2}, {3}]".format( + TYPE_TO_STR[sym.orig_type], + sym.name_and_loc, + low.name_and_loc, + high.name_and_loc, + ) + ) def _check_choice_sanity(self): # Checks various choice properties that are handiest to check after # parsing. Only generates errors and warnings. def warn_select_imply(sym, expr, expr_type): - msg = "the choice symbol {} is {} by the following symbols, but " \ - "select/imply has no effect on choice symbols" \ - .format(sym.name_and_loc, expr_type) + msg = ( + "the choice symbol {} is {} by the following symbols, but " + "select/imply has no effect on choice symbols".format( + sym.name_and_loc, expr_type + ) + ) # si = select/imply for si in split_expr(expr, OR): @@ -3866,9 +4005,11 @@ class Kconfig(object): for choice in self.unique_choices: if choice.orig_type not in _BOOL_TRISTATE: - self._warn("{} defined with type {}" - .format(choice.name_and_loc, - TYPE_TO_STR[choice.orig_type])) + self._warn( + "{} defined with type {}".format( + choice.name_and_loc, TYPE_TO_STR[choice.orig_type] + ) + ) for node in choice.nodes: if node.prompt: @@ -3879,20 +4020,26 @@ class Kconfig(object): for default, _ in choice.defaults: if default.__class__ is not Symbol: raise KconfigError( - "{} has a malformed default {}" - .format(choice.name_and_loc, expr_str(default))) + "{} has a malformed default {}".format( + choice.name_and_loc, expr_str(default) + ) + ) if default.choice is not choice: - self._warn("the default selection {} of {} is not " - "contained in the choice" - .format(default.name_and_loc, - choice.name_and_loc)) + self._warn( + "the default selection {} of {} is not " + "contained in the choice".format( + default.name_and_loc, choice.name_and_loc + ) + ) for sym in choice.syms: if sym.defaults: - self._warn("default on the choice symbol {} will have " - "no effect, as defaults do not affect choice " - "symbols".format(sym.name_and_loc)) + self._warn( + "default on the choice symbol {} will have " + "no effect, as defaults do not affect choice " + "symbols".format(sym.name_and_loc) + ) if sym.rev_dep is not sym.kconfig.n: warn_select_imply(sym, sym.rev_dep, "selected") @@ -3903,19 +4050,28 @@ class Kconfig(object): for node in sym.nodes: if node.parent.item is choice: if not node.prompt: - self._warn("the choice symbol {} has no prompt" - .format(sym.name_and_loc)) + self._warn( + "the choice symbol {} has no prompt".format( + sym.name_and_loc + ) + ) elif node.prompt: - self._warn("the choice symbol {} is defined with a " - "prompt outside the choice" - .format(sym.name_and_loc)) + self._warn( + "the choice symbol {} is defined with a " + "prompt outside the choice".format(sym.name_and_loc) + ) def _parse_error(self, msg): - raise KconfigError("{}error: couldn't parse '{}': {}".format( - "" if self.filename is None else - "{}:{}: ".format(self.filename, self.linenr), - self._line.strip(), msg)) + raise KconfigError( + "{}error: couldn't parse '{}': {}".format( + "" + if self.filename is None + else "{}:{}: ".format(self.filename, self.linenr), + self._line.strip(), + msg, + ) + ) def _trailing_tokens_error(self): self._parse_error("extra tokens at end of line") @@ -3954,8 +4110,11 @@ class Kconfig(object): # - For Python 3, force the encoding. Forcing the encoding on Python 2 # turns strings into Unicode strings, which gets messy. Python 2 # doesn't decode regular strings anyway. - return open(filename, "rU" if mode == "r" else mode) if _IS_PY2 else \ - open(filename, mode, encoding=self._encoding) + return ( + open(filename, "rU" if mode == "r" else mode) + if _IS_PY2 + else open(filename, mode, encoding=self._encoding) + ) def _check_undef_syms(self): # Prints warnings for all references to undefined symbols within the @@ -3992,14 +4151,14 @@ class Kconfig(object): # symbols, but shouldn't be flagged # # - The MODULES symbol always exists - if not sym.nodes and not is_num(sym.name) and \ - sym.name != "MODULES": + if not sym.nodes and not is_num(sym.name) and sym.name != "MODULES": msg = "undefined symbol {}:".format(sym.name) for node in self.node_iter(): if sym in node.referenced: - msg += "\n\n- Referenced at {}:{}:\n\n{}" \ - .format(node.filename, node.linenr, node) + msg += "\n\n- Referenced at {}:{}:\n\n{}".format( + node.filename, node.linenr, node + ) self._warn(msg) def _warn(self, msg, filename=None, linenr=None): @@ -4274,6 +4433,7 @@ class Symbol(object): kconfig: The Kconfig instance this symbol is from. """ + __slots__ = ( "_cached_assignable", "_cached_str_val", @@ -4311,9 +4471,11 @@ class Symbol(object): """ See the class documentation. """ - if self.orig_type is TRISTATE and \ - (self.choice and self.choice.tri_value == 2 or - not self.kconfig.modules.tri_value): + if self.orig_type is TRISTATE and ( + self.choice + and self.choice.tri_value == 2 + or not self.kconfig.modules.tri_value + ): return BOOL @@ -4344,7 +4506,7 @@ class Symbol(object): # function call (property magic) vis = self.visibility - self._write_to_conf = (vis != 0) + self._write_to_conf = vis != 0 if self.orig_type in _INT_HEX: # The C implementation checks the user value against the range in a @@ -4361,10 +4523,16 @@ class Symbol(object): # The zeros are from the C implementation running strtoll() # on empty strings - low = int(low_expr.str_value, base) if \ - _is_base_n(low_expr.str_value, base) else 0 - high = int(high_expr.str_value, base) if \ - _is_base_n(high_expr.str_value, base) else 0 + low = ( + int(low_expr.str_value, base) + if _is_base_n(low_expr.str_value, base) + else 0 + ) + high = ( + int(high_expr.str_value, base) + if _is_base_n(high_expr.str_value, base) + else 0 + ) break else: @@ -4381,10 +4549,14 @@ class Symbol(object): self.kconfig._warn( "user value {} on the {} symbol {} ignored due to " "being outside the active range ([{}, {}]) -- falling " - "back on defaults" - .format(num2str(user_val), TYPE_TO_STR[self.orig_type], - self.name_and_loc, - num2str(low), num2str(high))) + "back on defaults".format( + num2str(user_val), + TYPE_TO_STR[self.orig_type], + self.name_and_loc, + num2str(low), + num2str(high), + ) + ) else: # If the user value is well-formed and satisfies range # contraints, it is stored in exactly the same form as @@ -4424,18 +4596,20 @@ class Symbol(object): if clamp is not None: # The value is rewritten to a standard form if it is # clamped - val = str(clamp) \ - if self.orig_type is INT else \ - hex(clamp) + val = str(clamp) if self.orig_type is INT else hex(clamp) if has_default: num2str = str if base == 10 else hex self.kconfig._warn( "default value {} on {} clamped to {} due to " - "being outside the active range ([{}, {}])" - .format(val_num, self.name_and_loc, - num2str(clamp), num2str(low), - num2str(high))) + "being outside the active range ([{}, {}])".format( + val_num, + self.name_and_loc, + num2str(clamp), + num2str(low), + num2str(high), + ) + ) elif self.orig_type is STRING: if vis and self.user_value is not None: @@ -4473,8 +4647,10 @@ class Symbol(object): # Would take some work to give the location here self.kconfig._warn( "The {} symbol {} is being evaluated in a logical context " - "somewhere. It will always evaluate to n." - .format(TYPE_TO_STR[self.orig_type], self.name_and_loc)) + "somewhere. It will always evaluate to n.".format( + TYPE_TO_STR[self.orig_type], self.name_and_loc + ) + ) self._cached_tri_val = 0 return 0 @@ -4482,7 +4658,7 @@ class Symbol(object): # Warning: See Symbol._rec_invalidate(), and note that this is a hidden # function call (property magic) vis = self.visibility - self._write_to_conf = (vis != 0) + self._write_to_conf = vis != 0 val = 0 @@ -4523,8 +4699,7 @@ class Symbol(object): # m is promoted to y for (1) bool symbols and (2) symbols with a # weak_rev_dep (from imply) of y - if val == 1 and \ - (self.type is BOOL or expr_value(self.weak_rev_dep) == 2): + if val == 1 and (self.type is BOOL or expr_value(self.weak_rev_dep) == 2): val = 2 elif vis == 2: @@ -4570,19 +4745,17 @@ class Symbol(object): return "" if self.orig_type in _BOOL_TRISTATE: - return "{}{}={}\n" \ - .format(self.kconfig.config_prefix, self.name, val) \ - if val != "n" else \ - "# {}{} is not set\n" \ - .format(self.kconfig.config_prefix, self.name) + return ( + "{}{}={}\n".format(self.kconfig.config_prefix, self.name, val) + if val != "n" + else "# {}{} is not set\n".format(self.kconfig.config_prefix, self.name) + ) if self.orig_type in _INT_HEX: - return "{}{}={}\n" \ - .format(self.kconfig.config_prefix, self.name, val) + return "{}{}={}\n".format(self.kconfig.config_prefix, self.name, val) # sym.orig_type is STRING - return '{}{}="{}"\n' \ - .format(self.kconfig.config_prefix, self.name, escape(val)) + return '{}{}="{}"\n'.format(self.kconfig.config_prefix, self.name, escape(val)) @property def name_and_loc(self): @@ -4646,21 +4819,31 @@ class Symbol(object): return True # Check if the value is valid for our type - if not (self.orig_type is BOOL and value in (2, 0) or - self.orig_type is TRISTATE and value in TRI_TO_STR or - value.__class__ is str and - (self.orig_type is STRING or - self.orig_type is INT and _is_base_n(value, 10) or - self.orig_type is HEX and _is_base_n(value, 16) - and int(value, 16) >= 0)): + if not ( + self.orig_type is BOOL + and value in (2, 0) + or self.orig_type is TRISTATE + and value in TRI_TO_STR + or value.__class__ is str + and ( + self.orig_type is STRING + or self.orig_type is INT + and _is_base_n(value, 10) + or self.orig_type is HEX + and _is_base_n(value, 16) + and int(value, 16) >= 0 + ) + ): # Display tristate values as n, m, y in the warning self.kconfig._warn( "the value {} is invalid for {}, which has type {} -- " - "assignment ignored" - .format(TRI_TO_STR[value] if value in TRI_TO_STR else - "'{}'".format(value), - self.name_and_loc, TYPE_TO_STR[self.orig_type])) + "assignment ignored".format( + TRI_TO_STR[value] if value in TRI_TO_STR else "'{}'".format(value), + self.name_and_loc, + TYPE_TO_STR[self.orig_type], + ) + ) return False @@ -4738,17 +4921,28 @@ class Symbol(object): add('"{}"'.format(node.prompt[0])) # Only add quotes for non-bool/tristate symbols - add("value " + (self.str_value if self.orig_type in _BOOL_TRISTATE - else '"{}"'.format(self.str_value))) + add( + "value " + + ( + self.str_value + if self.orig_type in _BOOL_TRISTATE + else '"{}"'.format(self.str_value) + ) + ) if not self.is_constant: # These aren't helpful to show for constant symbols if self.user_value is not None: # Only add quotes for non-bool/tristate symbols - add("user value " + (TRI_TO_STR[self.user_value] - if self.orig_type in _BOOL_TRISTATE - else '"{}"'.format(self.user_value))) + add( + "user value " + + ( + TRI_TO_STR[self.user_value] + if self.orig_type in _BOOL_TRISTATE + else '"{}"'.format(self.user_value) + ) + ) add("visibility " + TRI_TO_STR[self.visibility]) @@ -4798,8 +4992,7 @@ class Symbol(object): Works like Symbol.__str__(), but allows a custom format to be used for all symbol/choice references. See expr_str(). """ - return "\n\n".join(node.custom_str(sc_expr_str_fn) - for node in self.nodes) + return "\n\n".join(node.custom_str(sc_expr_str_fn) for node in self.nodes) # # Private methods @@ -4830,18 +5023,18 @@ class Symbol(object): self.implies = [] self.ranges = [] - self.user_value = \ - self.choice = \ - self.env_var = \ - self._cached_str_val = self._cached_tri_val = self._cached_vis = \ - self._cached_assignable = None + self.user_value = ( + self.choice + ) = ( + self.env_var + ) = ( + self._cached_str_val + ) = self._cached_tri_val = self._cached_vis = self._cached_assignable = None # _write_to_conf is calculated along with the value. If True, the # Symbol gets a .config entry. - self.is_allnoconfig_y = \ - self._was_set = \ - self._write_to_conf = False + self.is_allnoconfig_y = self._was_set = self._write_to_conf = False # See Kconfig._build_dep() self._dependents = set() @@ -4895,8 +5088,9 @@ class Symbol(object): def _invalidate(self): # Marks the symbol as needing to be recalculated - self._cached_str_val = self._cached_tri_val = self._cached_vis = \ - self._cached_assignable = None + self._cached_str_val = ( + self._cached_tri_val + ) = self._cached_vis = self._cached_assignable = None def _rec_invalidate(self): # Invalidates the symbol and all items that (possibly) depend on it @@ -4948,8 +5142,10 @@ class Symbol(object): return if self.kconfig._warn_assign_no_prompt: - self.kconfig._warn(self.name_and_loc + " has no prompt, meaning " - "user values have no effect on it") + self.kconfig._warn( + self.name_and_loc + " has no prompt, meaning " + "user values have no effect on it" + ) def _str_default(self): # write_min_config() helper function. Returns the value the symbol @@ -4968,9 +5164,7 @@ class Symbol(object): val = min(expr_value(default), cond_val) break - val = max(expr_value(self.rev_dep), - expr_value(self.weak_rev_dep), - val) + val = max(expr_value(self.rev_dep), expr_value(self.weak_rev_dep), val) # Transpose mod to yes if type is bool (possibly due to modules # being disabled) @@ -4992,11 +5186,15 @@ class Symbol(object): # and menus) is selected by some other symbol. Also warn if a symbol # whose direct dependencies evaluate to m is selected to y. - msg = "{} has direct dependencies {} with value {}, but is " \ - "currently being {}-selected by the following symbols:" \ - .format(self.name_and_loc, expr_str(self.direct_dep), - TRI_TO_STR[expr_value(self.direct_dep)], - TRI_TO_STR[expr_value(self.rev_dep)]) + msg = ( + "{} has direct dependencies {} with value {}, but is " + "currently being {}-selected by the following symbols:".format( + self.name_and_loc, + expr_str(self.direct_dep), + TRI_TO_STR[expr_value(self.direct_dep)], + TRI_TO_STR[expr_value(self.rev_dep)], + ) + ) # The reverse dependencies from each select are ORed together for select in split_expr(self.rev_dep, OR): @@ -5010,17 +5208,20 @@ class Symbol(object): # In both cases, we can split on AND and pick the first operand selecting_sym = split_expr(select, AND)[0] - msg += "\n - {}, with value {}, direct dependencies {} " \ - "(value: {})" \ - .format(selecting_sym.name_and_loc, - selecting_sym.str_value, - expr_str(selecting_sym.direct_dep), - TRI_TO_STR[expr_value(selecting_sym.direct_dep)]) + msg += ( + "\n - {}, with value {}, direct dependencies {} " + "(value: {})".format( + selecting_sym.name_and_loc, + selecting_sym.str_value, + expr_str(selecting_sym.direct_dep), + TRI_TO_STR[expr_value(selecting_sym.direct_dep)], + ) + ) if select.__class__ is tuple: - msg += ", and select condition {} (value: {})" \ - .format(expr_str(select[2]), - TRI_TO_STR[expr_value(select[2])]) + msg += ", and select condition {} (value: {})".format( + expr_str(select[2]), TRI_TO_STR[expr_value(select[2])] + ) self.kconfig._warn(msg) @@ -5182,6 +5383,7 @@ class Choice(object): kconfig: The Kconfig instance this choice is from. """ + __slots__ = ( "_cached_assignable", "_cached_selection", @@ -5299,16 +5501,22 @@ class Choice(object): self._was_set = True return True - if not (self.orig_type is BOOL and value in (2, 0) or - self.orig_type is TRISTATE and value in TRI_TO_STR): + if not ( + self.orig_type is BOOL + and value in (2, 0) + or self.orig_type is TRISTATE + and value in TRI_TO_STR + ): # Display tristate values as n, m, y in the warning self.kconfig._warn( "the value {} is invalid for {}, which has type {} -- " - "assignment ignored" - .format(TRI_TO_STR[value] if value in TRI_TO_STR else - "'{}'".format(value), - self.name_and_loc, TYPE_TO_STR[self.orig_type])) + "assignment ignored".format( + TRI_TO_STR[value] if value in TRI_TO_STR else "'{}'".format(value), + self.name_and_loc, + TYPE_TO_STR[self.orig_type], + ) + ) return False @@ -5346,8 +5554,10 @@ class Choice(object): Returns a string with information about the choice when it is evaluated on e.g. the interactive Python prompt. """ - fields = ["choice " + self.name if self.name else "choice", - TYPE_TO_STR[self.type]] + fields = [ + "choice " + self.name if self.name else "choice", + TYPE_TO_STR[self.type], + ] add = fields.append for node in self.nodes: @@ -5357,14 +5567,13 @@ class Choice(object): add("mode " + self.str_value) if self.user_value is not None: - add('user mode {}'.format(TRI_TO_STR[self.user_value])) + add("user mode {}".format(TRI_TO_STR[self.user_value])) if self.selection: add("{} selected".format(self.selection.name)) if self.user_selection: - user_sel_str = "{} selected by user" \ - .format(self.user_selection.name) + user_sel_str = "{} selected by user".format(self.user_selection.name) if self.selection is not self.user_selection: user_sel_str += " (overridden)" @@ -5399,8 +5608,7 @@ class Choice(object): Works like Choice.__str__(), but allows a custom format to be used for all symbol/choice references. See expr_str(). """ - return "\n\n".join(node.custom_str(sc_expr_str_fn) - for node in self.nodes) + return "\n\n".join(node.custom_str(sc_expr_str_fn) for node in self.nodes) # # Private methods @@ -5425,9 +5633,9 @@ class Choice(object): self.syms = [] self.defaults = [] - self.name = \ - self.user_value = self.user_selection = \ - self._cached_vis = self._cached_assignable = None + self.name = ( + self.user_value + ) = self.user_selection = self._cached_vis = self._cached_assignable = None self._cached_selection = _NO_CACHED_SELECTION @@ -5644,6 +5852,7 @@ class MenuNode(object): kconfig: The Kconfig instance the menu node is from. """ + __slots__ = ( "dep", "filename", @@ -5658,7 +5867,6 @@ class MenuNode(object): "parent", "prompt", "visibility", - # Properties "defaults", "selects", @@ -5689,32 +5897,28 @@ class MenuNode(object): """ See the class documentation. """ - return [(default, self._strip_dep(cond)) - for default, cond in self.defaults] + return [(default, self._strip_dep(cond)) for default, cond in self.defaults] @property def orig_selects(self): """ See the class documentation. """ - return [(select, self._strip_dep(cond)) - for select, cond in self.selects] + return [(select, self._strip_dep(cond)) for select, cond in self.selects] @property def orig_implies(self): """ See the class documentation. """ - return [(imply, self._strip_dep(cond)) - for imply, cond in self.implies] + return [(imply, self._strip_dep(cond)) for imply, cond in self.implies] @property def orig_ranges(self): """ See the class documentation. """ - return [(low, high, self._strip_dep(cond)) - for low, high, cond in self.ranges] + return [(low, high, self._strip_dep(cond)) for low, high, cond in self.ranges] @property def referenced(self): @@ -5774,8 +5978,11 @@ class MenuNode(object): add("menu node for comment") if self.prompt: - add('prompt "{}" (visibility {})'.format( - self.prompt[0], TRI_TO_STR[expr_value(self.prompt[1])])) + add( + 'prompt "{}" (visibility {})'.format( + self.prompt[0], TRI_TO_STR[expr_value(self.prompt[1])] + ) + ) if self.item.__class__ is Symbol and self.is_menuconfig: add("is menuconfig") @@ -5822,20 +6029,20 @@ class MenuNode(object): Works like MenuNode.__str__(), but allows a custom format to be used for all symbol/choice references. See expr_str(). """ - return self._menu_comment_node_str(sc_expr_str_fn) \ - if self.item in _MENU_COMMENT else \ - self._sym_choice_node_str(sc_expr_str_fn) + return ( + self._menu_comment_node_str(sc_expr_str_fn) + if self.item in _MENU_COMMENT + else self._sym_choice_node_str(sc_expr_str_fn) + ) def _menu_comment_node_str(self, sc_expr_str_fn): - s = '{} "{}"'.format("menu" if self.item is MENU else "comment", - self.prompt[0]) + s = '{} "{}"'.format("menu" if self.item is MENU else "comment", self.prompt[0]) if self.dep is not self.kconfig.y: s += "\n\tdepends on {}".format(expr_str(self.dep, sc_expr_str_fn)) if self.item is MENU and self.visibility is not self.kconfig.y: - s += "\n\tvisible if {}".format(expr_str(self.visibility, - sc_expr_str_fn)) + s += "\n\tvisible if {}".format(expr_str(self.visibility, sc_expr_str_fn)) return s @@ -5851,8 +6058,7 @@ class MenuNode(object): sc = self.item if sc.__class__ is Symbol: - lines = [("menuconfig " if self.is_menuconfig else "config ") - + sc.name] + lines = [("menuconfig " if self.is_menuconfig else "config ") + sc.name] else: lines = ["choice " + sc.name if sc.name else "choice"] @@ -5868,8 +6074,9 @@ class MenuNode(object): # Symbol defined without a type (which generates a warning) prefix = "prompt" - indent_add_cond(prefix + ' "{}"'.format(escape(self.prompt[0])), - self.orig_prompt[1]) + indent_add_cond( + prefix + ' "{}"'.format(escape(self.prompt[0])), self.orig_prompt[1] + ) if sc.__class__ is Symbol: if sc.is_allnoconfig_y: @@ -5886,13 +6093,12 @@ class MenuNode(object): for low, high, cond in self.orig_ranges: indent_add_cond( - "range {} {}".format(sc_expr_str_fn(low), - sc_expr_str_fn(high)), - cond) + "range {} {}".format(sc_expr_str_fn(low), sc_expr_str_fn(high)), + cond, + ) for default, cond in self.orig_defaults: - indent_add_cond("default " + expr_str(default, sc_expr_str_fn), - cond) + indent_add_cond("default " + expr_str(default, sc_expr_str_fn), cond) if sc.__class__ is Choice and sc.is_optional: indent_add("optional") @@ -5954,6 +6160,7 @@ class Variable(object): is_recursive: True if the variable is recursive (defined with =). """ + __slots__ = ( "_n_expansions", "is_recursive", @@ -5979,10 +6186,9 @@ class Variable(object): return self.kconfig._fn_val((self.name,) + args) def __repr__(self): - return "<variable {}, {}, value '{}'>" \ - .format(self.name, - "recursive" if self.is_recursive else "immediate", - self.value) + return "<variable {}, {}, value '{}'>".format( + self.name, "recursive" if self.is_recursive else "immediate", self.value + ) class KconfigError(Exception): @@ -5993,6 +6199,7 @@ class KconfigError(Exception): KconfigSyntaxError alias is only maintained for backwards compatibility. """ + KconfigSyntaxError = KconfigError # Backwards compatibility @@ -6010,7 +6217,8 @@ class _KconfigIOError(IOError): def __init__(self, ioerror, msg): self.msg = msg super(_KconfigIOError, self).__init__( - ioerror.errno, ioerror.strerror, ioerror.filename) + ioerror.errno, ioerror.strerror, ioerror.filename + ) def __str__(self): return self.msg @@ -6070,12 +6278,19 @@ def expr_value(expr): # parse as numbers comp = _strcmp(v1.str_value, v2.str_value) - return 2*(comp == 0 if rel is EQUAL else - comp != 0 if rel is UNEQUAL else - comp < 0 if rel is LESS else - comp <= 0 if rel is LESS_EQUAL else - comp > 0 if rel is GREATER else - comp >= 0) + return 2 * ( + comp == 0 + if rel is EQUAL + else comp != 0 + if rel is UNEQUAL + else comp < 0 + if rel is LESS + else comp <= 0 + if rel is LESS_EQUAL + else comp > 0 + if rel is GREATER + else comp >= 0 + ) def standard_sc_expr_str(sc): @@ -6115,14 +6330,18 @@ def expr_str(expr, sc_expr_str_fn=standard_sc_expr_str): return sc_expr_str_fn(expr) if expr[0] is AND: - return "{} && {}".format(_parenthesize(expr[1], OR, sc_expr_str_fn), - _parenthesize(expr[2], OR, sc_expr_str_fn)) + return "{} && {}".format( + _parenthesize(expr[1], OR, sc_expr_str_fn), + _parenthesize(expr[2], OR, sc_expr_str_fn), + ) if expr[0] is OR: # This turns A && B || C && D into "(A && B) || (C && D)", which is # redundant, but more readable - return "{} || {}".format(_parenthesize(expr[1], AND, sc_expr_str_fn), - _parenthesize(expr[2], AND, sc_expr_str_fn)) + return "{} || {}".format( + _parenthesize(expr[1], AND, sc_expr_str_fn), + _parenthesize(expr[2], AND, sc_expr_str_fn), + ) if expr[0] is NOT: if expr[1].__class__ is tuple: @@ -6133,8 +6352,9 @@ def expr_str(expr, sc_expr_str_fn=standard_sc_expr_str): # # Relation operands are always symbols (quoted strings are constant # symbols) - return "{} {} {}".format(sc_expr_str_fn(expr[1]), REL_TO_STR[expr[0]], - sc_expr_str_fn(expr[2])) + return "{} {} {}".format( + sc_expr_str_fn(expr[1]), REL_TO_STR[expr[0]], sc_expr_str_fn(expr[2]) + ) def expr_items(expr): @@ -6216,7 +6436,7 @@ def escape(s): replaced by \" and \\, respectively. """ # \ must be escaped before " to avoid double escaping - return s.replace("\\", r"\\").replace('"', r'\"') + return s.replace("\\", r"\\").replace('"', r"\"") def unescape(s): @@ -6226,6 +6446,7 @@ def unescape(s): """ return _unescape_sub(r"\1", s) + # unescape() helper _unescape_sub = re.compile(r"\\(.)").sub @@ -6245,15 +6466,16 @@ def standard_kconfig(description=None): import argparse parser = argparse.ArgumentParser( - formatter_class=argparse.RawDescriptionHelpFormatter, - description=description) + formatter_class=argparse.RawDescriptionHelpFormatter, description=description + ) parser.add_argument( "kconfig", metavar="KCONFIG", default="Kconfig", nargs="?", - help="Top-level Kconfig file (default: Kconfig)") + help="Top-level Kconfig file (default: Kconfig)", + ) return Kconfig(parser.parse_args().kconfig, suppress_traceback=True) @@ -6299,16 +6521,20 @@ def load_allconfig(kconf, filename): try: print(kconf.load_config("all.config", False)) except EnvironmentError as e2: - sys.exit("error: KCONFIG_ALLCONFIG is set, but neither {} " - "nor all.config could be opened: {}, {}" - .format(filename, std_msg(e1), std_msg(e2))) + sys.exit( + "error: KCONFIG_ALLCONFIG is set, but neither {} " + "nor all.config could be opened: {}, {}".format( + filename, std_msg(e1), std_msg(e2) + ) + ) else: try: print(kconf.load_config(allconfig, False)) except EnvironmentError as e: - sys.exit("error: KCONFIG_ALLCONFIG is set to '{}', which " - "could not be opened: {}" - .format(allconfig, std_msg(e))) + sys.exit( + "error: KCONFIG_ALLCONFIG is set to '{}', which " + "could not be opened: {}".format(allconfig, std_msg(e)) + ) kconf.warn_assign_override = old_warn_assign_override kconf.warn_assign_redun = old_warn_assign_redun @@ -6332,8 +6558,11 @@ def _visibility(sc): vis = max(vis, expr_value(node.prompt[1])) if sc.__class__ is Symbol and sc.choice: - if sc.choice.orig_type is TRISTATE and \ - sc.orig_type is not TRISTATE and sc.choice.tri_value != 2: + if ( + sc.choice.orig_type is TRISTATE + and sc.orig_type is not TRISTATE + and sc.choice.tri_value != 2 + ): # Non-tristate choice symbols are only visible in y mode return 0 @@ -6407,8 +6636,11 @@ def _sym_to_num(sym): # For BOOL and TRISTATE, n/m/y count as 0/1/2. This mirrors 9059a3493ef # ("kconfig: fix relational operators for bool and tristate symbols") in # the C implementation. - return sym.tri_value if sym.orig_type in _BOOL_TRISTATE else \ - int(sym.str_value, _TYPE_TO_BASE[sym.orig_type]) + return ( + sym.tri_value + if sym.orig_type in _BOOL_TRISTATE + else int(sym.str_value, _TYPE_TO_BASE[sym.orig_type]) + ) def _touch_dep_file(path, sym_name): @@ -6421,8 +6653,7 @@ def _touch_dep_file(path, sym_name): os.makedirs(sym_path_dir, 0o755) # A kind of truncating touch, mirroring the C tools - os.close(os.open( - sym_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644)) + os.close(os.open(sym_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644)) def _save_old(path): @@ -6431,6 +6662,7 @@ def _save_old(path): def copy(src, dst): # Import as needed, to save some startup time import shutil + shutil.copyfile(src, dst) if islink(path): @@ -6463,8 +6695,8 @@ def _locs(sc): if sc.nodes: return "(defined at {})".format( - ", ".join("{0.filename}:{0.linenr}".format(node) - for node in sc.nodes)) + ", ".join("{0.filename}:{0.linenr}".format(node) for node in sc.nodes) + ) return "(undefined)" @@ -6491,13 +6723,13 @@ def _expr_depends_on(expr, sym): elif left is not sym: return False - return (expr[0] is EQUAL and right is sym.kconfig.m or - right is sym.kconfig.y) or \ - (expr[0] is UNEQUAL and right is sym.kconfig.n) + return ( + expr[0] is EQUAL and right is sym.kconfig.m or right is sym.kconfig.y + ) or (expr[0] is UNEQUAL and right is sym.kconfig.n) - return expr[0] is AND and \ - (_expr_depends_on(expr[1], sym) or - _expr_depends_on(expr[2], sym)) + return expr[0] is AND and ( + _expr_depends_on(expr[1], sym) or _expr_depends_on(expr[2], sym) + ) def _auto_menu_dep(node1, node2): @@ -6505,8 +6737,7 @@ def _auto_menu_dep(node1, node2): # node2 has a prompt, we check its condition. Otherwise, we look directly # at node2.dep. - return _expr_depends_on(node2.prompt[1] if node2.prompt else node2.dep, - node1.item) + return _expr_depends_on(node2.prompt[1] if node2.prompt else node2.dep, node1.item) def _flatten(node): @@ -6521,8 +6752,7 @@ def _flatten(node): # you enter the choice at some location with a prompt. while node: - if node.list and not node.prompt and \ - node.item.__class__ is not Choice: + if node.list and not node.prompt and node.item.__class__ is not Choice: last_node = node.list while 1: @@ -6637,9 +6867,11 @@ def _check_dep_loop_sym(sym, ignore_choice): # # Since we aren't entering the choice via a choice symbol, all # choice symbols need to be checked, hence the None. - loop = _check_dep_loop_choice(dep, None) \ - if dep.__class__ is Choice \ - else _check_dep_loop_sym(dep, False) + loop = ( + _check_dep_loop_choice(dep, None) + if dep.__class__ is Choice + else _check_dep_loop_sym(dep, False) + ) if loop: # Dependency loop found @@ -6711,8 +6943,7 @@ def _found_dep_loop(loop, cur): # Yep, we have the entire loop. Throw an exception that shows it. - msg = "\nDependency loop\n" \ - "===============\n\n" + msg = "\nDependency loop\n" "===============\n\n" for item in loop: if item is not loop[0]: @@ -6720,8 +6951,7 @@ def _found_dep_loop(loop, cur): if item.__class__ is Symbol and item.choice: msg += "the choice symbol " - msg += "{}, with definition...\n\n{}\n\n" \ - .format(item.name_and_loc, item) + msg += "{}, with definition...\n\n{}\n\n".format(item.name_and_loc, item) # Small wart: Since we reuse the already calculated # Symbol/Choice._dependents sets for recursive dependency detection, we @@ -6738,12 +6968,14 @@ def _found_dep_loop(loop, cur): if item.__class__ is Symbol: if item.rev_dep is not item.kconfig.n: - msg += "(select-related dependencies: {})\n\n" \ - .format(expr_str(item.rev_dep)) + msg += "(select-related dependencies: {})\n\n".format( + expr_str(item.rev_dep) + ) if item.weak_rev_dep is not item.kconfig.n: - msg += "(imply-related dependencies: {})\n\n" \ - .format(expr_str(item.rev_dep)) + msg += "(imply-related dependencies: {})\n\n".format( + expr_str(item.rev_dep) + ) msg += "...depends again on " + loop[0].name_and_loc @@ -6765,11 +6997,14 @@ def _decoding_error(e, filename, macro_linenr=None): "Problematic data: {}\n" "Reason: {}".format( e.encoding, - "'{}'".format(filename) if macro_linenr is None else - "output from macro at {}:{}".format(filename, macro_linenr), - e.object[max(e.start - 40, 0):e.end + 40], - e.object[e.start:e.end], - e.reason)) + "'{}'".format(filename) + if macro_linenr is None + else "output from macro at {}:{}".format(filename, macro_linenr), + e.object[max(e.start - 40, 0) : e.end + 40], + e.object[e.start : e.end], + e.reason, + ) + ) def _warn_verbose_deprecated(fn_name): @@ -6779,7 +7014,8 @@ def _warn_verbose_deprecated(fn_name): "and is always generated. Do e.g. print(kconf.{0}()) if you want to " "want to show a message like \"Loaded configuration '.config'\" on " "stdout. The old API required ugly hacks to reuse messages in " - "configuration interfaces.\n".format(fn_name)) + "configuration interfaces.\n".format(fn_name) + ) # Predefined preprocessor functions @@ -6808,8 +7044,7 @@ def _warning_if_fn(kconf, _, cond, msg): def _error_if_fn(kconf, _, cond, msg): if cond == "y": - raise KconfigError("{}:{}: {}".format( - kconf.filename, kconf.linenr, msg)) + raise KconfigError("{}:{}: {}".format(kconf.filename, kconf.linenr, msg)) return "" @@ -6829,9 +7064,11 @@ def _shell_fn(kconf, _, command): _decoding_error(e, kconf.filename, kconf.linenr) if stderr: - kconf._warn("'{}' wrote to stderr: {}".format( - command, "\n".join(stderr.splitlines())), - kconf.filename, kconf.linenr) + kconf._warn( + "'{}' wrote to stderr: {}".format(command, "\n".join(stderr.splitlines())), + kconf.filename, + kconf.linenr, + ) # Universal newlines with splitlines() (to prevent e.g. stray \r's in # command output on Windows), trailing newline removal, and @@ -6842,6 +7079,7 @@ def _shell_fn(kconf, _, command): # parameter was added in 3.6), so we do this manual version instead. return "\n".join(stdout.splitlines()).rstrip("\n").replace("\n", " ") + # # Global constants # @@ -6871,6 +7109,7 @@ try: except AttributeError: # Only import as needed, to save some startup time import platform + _UNAME_RELEASE = platform.uname()[2] # The token and type constants below are safe to test with 'is', which is a bit @@ -6940,112 +7179,112 @@ except AttributeError: # Keyword to token map, with the get() method assigned directly as a small # optimization _get_keyword = { - "---help---": _T_HELP, - "allnoconfig_y": _T_ALLNOCONFIG_Y, - "bool": _T_BOOL, - "boolean": _T_BOOL, - "choice": _T_CHOICE, - "comment": _T_COMMENT, - "config": _T_CONFIG, - "def_bool": _T_DEF_BOOL, - "def_hex": _T_DEF_HEX, - "def_int": _T_DEF_INT, - "def_string": _T_DEF_STRING, - "def_tristate": _T_DEF_TRISTATE, - "default": _T_DEFAULT, + "---help---": _T_HELP, + "allnoconfig_y": _T_ALLNOCONFIG_Y, + "bool": _T_BOOL, + "boolean": _T_BOOL, + "choice": _T_CHOICE, + "comment": _T_COMMENT, + "config": _T_CONFIG, + "def_bool": _T_DEF_BOOL, + "def_hex": _T_DEF_HEX, + "def_int": _T_DEF_INT, + "def_string": _T_DEF_STRING, + "def_tristate": _T_DEF_TRISTATE, + "default": _T_DEFAULT, "defconfig_list": _T_DEFCONFIG_LIST, - "depends": _T_DEPENDS, - "endchoice": _T_ENDCHOICE, - "endif": _T_ENDIF, - "endmenu": _T_ENDMENU, - "env": _T_ENV, - "grsource": _T_ORSOURCE, # Backwards compatibility - "gsource": _T_OSOURCE, # Backwards compatibility - "help": _T_HELP, - "hex": _T_HEX, - "if": _T_IF, - "imply": _T_IMPLY, - "int": _T_INT, - "mainmenu": _T_MAINMENU, - "menu": _T_MENU, - "menuconfig": _T_MENUCONFIG, - "modules": _T_MODULES, - "on": _T_ON, - "option": _T_OPTION, - "optional": _T_OPTIONAL, - "orsource": _T_ORSOURCE, - "osource": _T_OSOURCE, - "prompt": _T_PROMPT, - "range": _T_RANGE, - "rsource": _T_RSOURCE, - "select": _T_SELECT, - "source": _T_SOURCE, - "string": _T_STRING, - "tristate": _T_TRISTATE, - "visible": _T_VISIBLE, + "depends": _T_DEPENDS, + "endchoice": _T_ENDCHOICE, + "endif": _T_ENDIF, + "endmenu": _T_ENDMENU, + "env": _T_ENV, + "grsource": _T_ORSOURCE, # Backwards compatibility + "gsource": _T_OSOURCE, # Backwards compatibility + "help": _T_HELP, + "hex": _T_HEX, + "if": _T_IF, + "imply": _T_IMPLY, + "int": _T_INT, + "mainmenu": _T_MAINMENU, + "menu": _T_MENU, + "menuconfig": _T_MENUCONFIG, + "modules": _T_MODULES, + "on": _T_ON, + "option": _T_OPTION, + "optional": _T_OPTIONAL, + "orsource": _T_ORSOURCE, + "osource": _T_OSOURCE, + "prompt": _T_PROMPT, + "range": _T_RANGE, + "rsource": _T_RSOURCE, + "select": _T_SELECT, + "source": _T_SOURCE, + "string": _T_STRING, + "tristate": _T_TRISTATE, + "visible": _T_VISIBLE, }.get # The constants below match the value of the corresponding tokens to remove the # need for conversion # Node types -MENU = _T_MENU +MENU = _T_MENU COMMENT = _T_COMMENT # Expression types -AND = _T_AND -OR = _T_OR -NOT = _T_NOT -EQUAL = _T_EQUAL -UNEQUAL = _T_UNEQUAL -LESS = _T_LESS -LESS_EQUAL = _T_LESS_EQUAL -GREATER = _T_GREATER +AND = _T_AND +OR = _T_OR +NOT = _T_NOT +EQUAL = _T_EQUAL +UNEQUAL = _T_UNEQUAL +LESS = _T_LESS +LESS_EQUAL = _T_LESS_EQUAL +GREATER = _T_GREATER GREATER_EQUAL = _T_GREATER_EQUAL REL_TO_STR = { - EQUAL: "=", - UNEQUAL: "!=", - LESS: "<", - LESS_EQUAL: "<=", - GREATER: ">", + EQUAL: "=", + UNEQUAL: "!=", + LESS: "<", + LESS_EQUAL: "<=", + GREATER: ">", GREATER_EQUAL: ">=", } # Symbol/choice types. UNKNOWN is 0 (falsy) to simplify some checks. # Client code shouldn't rely on it though, as it was non-zero in # older versions. -UNKNOWN = 0 -BOOL = _T_BOOL +UNKNOWN = 0 +BOOL = _T_BOOL TRISTATE = _T_TRISTATE -STRING = _T_STRING -INT = _T_INT -HEX = _T_HEX +STRING = _T_STRING +INT = _T_INT +HEX = _T_HEX TYPE_TO_STR = { - UNKNOWN: "unknown", - BOOL: "bool", + UNKNOWN: "unknown", + BOOL: "bool", TRISTATE: "tristate", - STRING: "string", - INT: "int", - HEX: "hex", + STRING: "string", + INT: "int", + HEX: "hex", } # Used in comparisons. 0 means the base is inferred from the format of the # string. _TYPE_TO_BASE = { - HEX: 16, - INT: 10, - STRING: 0, - UNKNOWN: 0, + HEX: 16, + INT: 10, + STRING: 0, + UNKNOWN: 0, } # def_bool -> BOOL, etc. _DEF_TOKEN_TO_TYPE = { - _T_DEF_BOOL: BOOL, - _T_DEF_HEX: HEX, - _T_DEF_INT: INT, - _T_DEF_STRING: STRING, + _T_DEF_BOOL: BOOL, + _T_DEF_HEX: HEX, + _T_DEF_INT: INT, + _T_DEF_STRING: STRING, _T_DEF_TRISTATE: TRISTATE, } @@ -7056,91 +7295,115 @@ _DEF_TOKEN_TO_TYPE = { # Identifier-like lexemes ("missing quotes") are also treated as strings after # these tokens. _T_CHOICE is included to avoid symbols being registered for # named choices. -_STRING_LEX = frozenset({ - _T_BOOL, - _T_CHOICE, - _T_COMMENT, - _T_HEX, - _T_INT, - _T_MAINMENU, - _T_MENU, - _T_ORSOURCE, - _T_OSOURCE, - _T_PROMPT, - _T_RSOURCE, - _T_SOURCE, - _T_STRING, - _T_TRISTATE, -}) +_STRING_LEX = frozenset( + { + _T_BOOL, + _T_CHOICE, + _T_COMMENT, + _T_HEX, + _T_INT, + _T_MAINMENU, + _T_MENU, + _T_ORSOURCE, + _T_OSOURCE, + _T_PROMPT, + _T_RSOURCE, + _T_SOURCE, + _T_STRING, + _T_TRISTATE, + } +) # Various sets for quick membership tests. Gives a single global lookup and # avoids creating temporary dicts/tuples. -_TYPE_TOKENS = frozenset({ - _T_BOOL, - _T_TRISTATE, - _T_INT, - _T_HEX, - _T_STRING, -}) - -_SOURCE_TOKENS = frozenset({ - _T_SOURCE, - _T_RSOURCE, - _T_OSOURCE, - _T_ORSOURCE, -}) - -_REL_SOURCE_TOKENS = frozenset({ - _T_RSOURCE, - _T_ORSOURCE, -}) +_TYPE_TOKENS = frozenset( + { + _T_BOOL, + _T_TRISTATE, + _T_INT, + _T_HEX, + _T_STRING, + } +) + +_SOURCE_TOKENS = frozenset( + { + _T_SOURCE, + _T_RSOURCE, + _T_OSOURCE, + _T_ORSOURCE, + } +) + +_REL_SOURCE_TOKENS = frozenset( + { + _T_RSOURCE, + _T_ORSOURCE, + } +) # Obligatory (non-optional) sources -_OBL_SOURCE_TOKENS = frozenset({ - _T_SOURCE, - _T_RSOURCE, -}) - -_BOOL_TRISTATE = frozenset({ - BOOL, - TRISTATE, -}) - -_BOOL_TRISTATE_UNKNOWN = frozenset({ - BOOL, - TRISTATE, - UNKNOWN, -}) - -_INT_HEX = frozenset({ - INT, - HEX, -}) - -_SYMBOL_CHOICE = frozenset({ - Symbol, - Choice, -}) - -_MENU_COMMENT = frozenset({ - MENU, - COMMENT, -}) - -_EQUAL_UNEQUAL = frozenset({ - EQUAL, - UNEQUAL, -}) - -_RELATIONS = frozenset({ - EQUAL, - UNEQUAL, - LESS, - LESS_EQUAL, - GREATER, - GREATER_EQUAL, -}) +_OBL_SOURCE_TOKENS = frozenset( + { + _T_SOURCE, + _T_RSOURCE, + } +) + +_BOOL_TRISTATE = frozenset( + { + BOOL, + TRISTATE, + } +) + +_BOOL_TRISTATE_UNKNOWN = frozenset( + { + BOOL, + TRISTATE, + UNKNOWN, + } +) + +_INT_HEX = frozenset( + { + INT, + HEX, + } +) + +_SYMBOL_CHOICE = frozenset( + { + Symbol, + Choice, + } +) + +_MENU_COMMENT = frozenset( + { + MENU, + COMMENT, + } +) + +_EQUAL_UNEQUAL = frozenset( + { + EQUAL, + UNEQUAL, + } +) + +_RELATIONS = frozenset( + { + EQUAL, + UNEQUAL, + LESS, + LESS_EQUAL, + GREATER, + GREATER_EQUAL, + } +) # Helper functions for getting compiled regular expressions, with the needed # matching function returned directly as a small optimization. @@ -7189,7 +7452,7 @@ _string_special_search = _re_search(r'"|\'|\\|\$\(') # Special characters/strings while expanding a symbol name. Also includes # end-of-line, in case the macro is the last thing on the line. -_name_special_search = _re_search(r'[^A-Za-z0-9_$/.-]|\$\(|$') +_name_special_search = _re_search(r"[^A-Za-z0-9_$/.-]|\$\(|$") # A valid right-hand side for an assignment to a string symbol in a .config # file, including escaped characters. Extracts the contents. diff --git a/util/run_ects.py b/util/run_ects.py index 9178328e5f..9293f60779 100644 --- a/util/run_ects.py +++ b/util/run_ects.py @@ -16,81 +16,81 @@ import subprocess import sys # List of tests to run. -TESTS = ['meta', 'gpio', 'hook', 'i2c', 'interrupt', 'mutex', 'task', 'timer'] +TESTS = ["meta", "gpio", "hook", "i2c", "interrupt", "mutex", "task", "timer"] class CtsRunner(object): - """Class running eCTS tests.""" - - def __init__(self, ec_dir, dryrun): - self.ec_dir = ec_dir - self.cts_py = [] - if dryrun: - self.cts_py += ['echo'] - self.cts_py += [os.path.join(ec_dir, 'cts/cts.py')] - - def run_cmd(self, cmd): - try: - rc = subprocess.call(cmd) - if rc != 0: - return False - except OSError: - return False - return True - - def run_test(self, test): - cmd = self.cts_py + ['-m', test] - self.run_cmd(cmd) - - def run(self, tests): - for test in tests: - logging.info('Running', test, 'test.') - self.run_test(test) - - def sync(self): - logging.info('Syncing tree...') - os.chdir(self.ec_dir) - cmd = ['repo', 'sync', '.'] - return self.run_cmd(cmd) - - def upload(self): - logging.info('Uploading results...') + """Class running eCTS tests.""" + + def __init__(self, ec_dir, dryrun): + self.ec_dir = ec_dir + self.cts_py = [] + if dryrun: + self.cts_py += ["echo"] + self.cts_py += [os.path.join(ec_dir, "cts/cts.py")] + + def run_cmd(self, cmd): + try: + rc = subprocess.call(cmd) + if rc != 0: + return False + except OSError: + return False + return True + + def run_test(self, test): + cmd = self.cts_py + ["-m", test] + self.run_cmd(cmd) + + def run(self, tests): + for test in tests: + logging.info("Running", test, "test.") + self.run_test(test) + + def sync(self): + logging.info("Syncing tree...") + os.chdir(self.ec_dir) + cmd = ["repo", "sync", "."] + return self.run_cmd(cmd) + + def upload(self): + logging.info("Uploading results...") def main(): - if not os.path.exists('/etc/cros_chroot_version'): - logging.error('This script has to run inside chroot.') - sys.exit(-1) - - ec_dir = os.path.realpath(os.path.dirname(__file__) + '/..') - - parser = argparse.ArgumentParser(description='Run eCTS and report results.') - parser.add_argument('-d', - '--dryrun', - action='store_true', - help='Echo commands to be executed without running them.') - parser.add_argument('-s', - '--sync', - action='store_true', - help='Sync tree before running tests.') - parser.add_argument('-u', - '--upload', - action='store_true', - help='Upload test results.') - args = parser.parse_args() - - runner = CtsRunner(ec_dir, args.dryrun) - - if args.sync: - if not runner.sync(): - logging.error('Failed to sync.') - sys.exit(-1) - - runner.run(TESTS) - - if args.upload: - runner.upload() - - -if __name__ == '__main__': - main() + if not os.path.exists("/etc/cros_chroot_version"): + logging.error("This script has to run inside chroot.") + sys.exit(-1) + + ec_dir = os.path.realpath(os.path.dirname(__file__) + "/..") + + parser = argparse.ArgumentParser(description="Run eCTS and report results.") + parser.add_argument( + "-d", + "--dryrun", + action="store_true", + help="Echo commands to be executed without running them.", + ) + parser.add_argument( + "-s", "--sync", action="store_true", help="Sync tree before running tests." + ) + parser.add_argument( + "-u", "--upload", action="store_true", help="Upload test results." + ) + args = parser.parse_args() + + runner = CtsRunner(ec_dir, args.dryrun) + + if args.sync: + if not runner.sync(): + logging.error("Failed to sync.") + sys.exit(-1) + + runner.run(TESTS) + + if args.upload: + runner.upload() + + +if __name__ == "__main__": + main() diff --git a/util/test_kconfig_check.py b/util/test_kconfig_check.py index cd1b9bf098..db73e7ee71 100644 --- a/util/test_kconfig_check.py +++ b/util/test_kconfig_check.py @@ -16,7 +16,8 @@ import kconfig_check # Prefix that we strip from each Kconfig option, when considering whether it is # equivalent to a CONFIG option with the same name -PREFIX = 'PLATFORM_EC_' +PREFIX = "PLATFORM_EC_" + @contextlib.contextmanager def capture_sys_output(): @@ -39,38 +40,49 @@ def capture_sys_output(): # directly from Python. You can still run this test with 'pytest' if you like. class KconfigCheck(unittest.TestCase): """Tests for the KconfigCheck class""" + def test_simple_check(self): """Check it detected a new ad-hoc CONFIG""" checker = kconfig_check.KconfigCheck() - self.assertEqual(['NEW_ONE'], checker.find_new_adhoc( - configs=['NEW_ONE', 'OLD_ONE', 'IN_KCONFIG'], - kconfigs=['IN_KCONFIG'], - allowed=['OLD_ONE'])) + self.assertEqual( + ["NEW_ONE"], + checker.find_new_adhoc( + configs=["NEW_ONE", "OLD_ONE", "IN_KCONFIG"], + kconfigs=["IN_KCONFIG"], + allowed=["OLD_ONE"], + ), + ) def test_sorted_check(self): """Check it sorts the results in order""" checker = kconfig_check.KconfigCheck() self.assertSequenceEqual( - ['ANOTHER_NEW_ONE', 'NEW_ONE'], + ["ANOTHER_NEW_ONE", "NEW_ONE"], checker.find_new_adhoc( - configs=['NEW_ONE', 'ANOTHER_NEW_ONE', 'OLD_ONE', 'IN_KCONFIG'], - kconfigs=['IN_KCONFIG'], - allowed=['OLD_ONE'])) + configs=["NEW_ONE", "ANOTHER_NEW_ONE", "OLD_ONE", "IN_KCONFIG"], + kconfigs=["IN_KCONFIG"], + allowed=["OLD_ONE"], + ), + ) def check_read_configs(self, use_defines): checker = kconfig_check.KconfigCheck() with tempfile.NamedTemporaryFile() as configs: - with open(configs.name, 'w') as out: - prefix = '#define ' if use_defines else '' - suffix = ' ' if use_defines else '=' - out.write(f'''{prefix}CONFIG_OLD_ONE{suffix}y + with open(configs.name, "w") as out: + prefix = "#define " if use_defines else "" + suffix = " " if use_defines else "=" + out.write( + f"""{prefix}CONFIG_OLD_ONE{suffix}y {prefix}NOT_A_CONFIG{suffix} {prefix}CONFIG_STRING{suffix}"something" {prefix}CONFIG_INT{suffix}123 {prefix}CONFIG_HEX{suffix}45ab -''') - self.assertEqual(['OLD_ONE', 'STRING', 'INT', 'HEX'], - checker.read_configs(configs.name, use_defines)) +""" + ) + self.assertEqual( + ["OLD_ONE", "STRING", "INT", "HEX"], + checker.read_configs(configs.name, use_defines), + ) def test_read_configs(self): """Test KconfigCheck.read_configs()""" @@ -87,22 +99,24 @@ class KconfigCheck(unittest.TestCase): Args: srctree: Directory to write to """ - with open(os.path.join(srctree, 'Kconfig'), 'w') as out: - out.write(f'''config {PREFIX}MY_KCONFIG + with open(os.path.join(srctree, "Kconfig"), "w") as out: + out.write( + f"""config {PREFIX}MY_KCONFIG \tbool "my kconfig" rsource "subdir/Kconfig.wibble" -''') - subdir = os.path.join(srctree, 'subdir') +""" + ) + subdir = os.path.join(srctree, "subdir") os.mkdir(subdir) - with open(os.path.join(subdir, 'Kconfig.wibble'), 'w') as out: - out.write('menuconfig %sMENU_KCONFIG\n' % PREFIX) + with open(os.path.join(subdir, "Kconfig.wibble"), "w") as out: + out.write("menuconfig %sMENU_KCONFIG\n" % PREFIX) # Add a directory which should be ignored - bad_subdir = os.path.join(subdir, 'Kconfig') + bad_subdir = os.path.join(subdir, "Kconfig") os.mkdir(bad_subdir) - with open(os.path.join(bad_subdir, 'Kconfig.bad'), 'w') as out: - out.write('menuconfig %sBAD_KCONFIG' % PREFIX) + with open(os.path.join(bad_subdir, "Kconfig.bad"), "w") as out: + out.write("menuconfig %sBAD_KCONFIG" % PREFIX) def test_find_kconfigs(self): """Test KconfigCheck.find_kconfigs()""" @@ -110,20 +124,20 @@ rsource "subdir/Kconfig.wibble" with tempfile.TemporaryDirectory() as srctree: self.setup_srctree(srctree) files = checker.find_kconfigs(srctree) - fnames = [fname[len(srctree):] for fname in files] - self.assertEqual(['/Kconfig', '/subdir/Kconfig.wibble'], fnames) + fnames = [fname[len(srctree) :] for fname in files] + self.assertEqual(["/Kconfig", "/subdir/Kconfig.wibble"], fnames) def test_scan_kconfigs(self): """Test KconfigCheck.scan_configs()""" checker = kconfig_check.KconfigCheck() with tempfile.TemporaryDirectory() as srctree: self.setup_srctree(srctree) - self.assertEqual(['MENU_KCONFIG', 'MY_KCONFIG'], - checker.scan_kconfigs(srctree, PREFIX)) + self.assertEqual( + ["MENU_KCONFIG", "MY_KCONFIG"], checker.scan_kconfigs(srctree, PREFIX) + ) @classmethod - def setup_allowed_and_configs(cls, allowed_fname, configs_fname, - add_new_one=True): + def setup_allowed_and_configs(cls, allowed_fname, configs_fname, add_new_one=True): """Set up the 'allowed' and 'configs' files for tests Args: @@ -131,14 +145,14 @@ rsource "subdir/Kconfig.wibble" configs_fname: Filename to which CONFIGs to check should be written add_new_one: True to add CONFIG_NEW_ONE to the configs_fname file """ - with open(allowed_fname, 'w') as out: - out.write('CONFIG_OLD_ONE\n') - out.write('CONFIG_MENU_KCONFIG\n') - with open(configs_fname, 'w') as out: - to_add = ['CONFIG_OLD_ONE', 'CONFIG_MY_KCONFIG'] + with open(allowed_fname, "w") as out: + out.write("CONFIG_OLD_ONE\n") + out.write("CONFIG_MENU_KCONFIG\n") + with open(configs_fname, "w") as out: + to_add = ["CONFIG_OLD_ONE", "CONFIG_MY_KCONFIG"] if add_new_one: - to_add.append('CONFIG_NEW_ONE') - out.write('\n'.join(to_add)) + to_add.append("CONFIG_NEW_ONE") + out.write("\n".join(to_add)) def test_check_adhoc_configs(self): """Test KconfigCheck.check_adhoc_configs()""" @@ -148,12 +162,16 @@ rsource "subdir/Kconfig.wibble" with tempfile.NamedTemporaryFile() as allowed: with tempfile.NamedTemporaryFile() as configs: self.setup_allowed_and_configs(allowed.name, configs.name) - new_adhoc, unneeded_adhoc, updated_adhoc = ( - checker.check_adhoc_configs( - configs.name, srctree, allowed.name, PREFIX)) - self.assertEqual(['NEW_ONE'], new_adhoc) - self.assertEqual(['MENU_KCONFIG'], unneeded_adhoc) - self.assertEqual(['OLD_ONE'], updated_adhoc) + ( + new_adhoc, + unneeded_adhoc, + updated_adhoc, + ) = checker.check_adhoc_configs( + configs.name, srctree, allowed.name, PREFIX + ) + self.assertEqual(["NEW_ONE"], new_adhoc) + self.assertEqual(["MENU_KCONFIG"], unneeded_adhoc) + self.assertEqual(["OLD_ONE"], updated_adhoc) def test_check(self): """Test running the 'check' subcommand""" @@ -162,29 +180,39 @@ rsource "subdir/Kconfig.wibble" self.setup_srctree(srctree) with tempfile.NamedTemporaryFile() as allowed: with tempfile.NamedTemporaryFile() as configs: - self.setup_allowed_and_configs(allowed.name, - configs.name) + self.setup_allowed_and_configs(allowed.name, configs.name) ret_code = kconfig_check.main( - ['-c', configs.name, '-s', srctree, - '-a', allowed.name, '-p', PREFIX, 'check']) + [ + "-c", + configs.name, + "-s", + srctree, + "-a", + allowed.name, + "-p", + PREFIX, + "check", + ] + ) self.assertEqual(1, ret_code) - self.assertEqual('', stdout.getvalue()) - found = re.findall('(CONFIG_.*)', stderr.getvalue()) - self.assertEqual(['CONFIG_NEW_ONE'], found) + self.assertEqual("", stdout.getvalue()) + found = re.findall("(CONFIG_.*)", stderr.getvalue()) + self.assertEqual(["CONFIG_NEW_ONE"], found) def test_real_kconfig(self): """Same Kconfig should be returned for kconfiglib / adhoc""" if not kconfig_check.USE_KCONFIGLIB: - self.skipTest('No kconfiglib available') - zephyr_path = pathlib.Path('../../third_party/zephyr/main').resolve() + self.skipTest("No kconfiglib available") + zephyr_path = pathlib.Path("../../third_party/zephyr/main").resolve() if not zephyr_path.exists(): - self.skipTest('No zephyr tree available') + self.skipTest("No zephyr tree available") checker = kconfig_check.KconfigCheck() - srcdir = 'zephyr' + srcdir = "zephyr" search_paths = [zephyr_path] kc_version = checker.scan_kconfigs( - srcdir, search_paths=search_paths, try_kconfiglib=True) + srcdir, search_paths=search_paths, try_kconfiglib=True + ) adhoc_version = checker.scan_kconfigs(srcdir, try_kconfiglib=False) # List of things missing from the Kconfig @@ -192,15 +220,17 @@ rsource "subdir/Kconfig.wibble" # The Kconfig is disjoint in some places, e.g. the boards have their # own Kconfig files which are not included from the main Kconfig - missing = [item for item in missing - if not item.startswith('BOARD') and - not item.startswith('VARIANT')] + missing = [ + item + for item in missing + if not item.startswith("BOARD") and not item.startswith("VARIANT") + ] # Similarly, some other items are defined in files that are not included # in all cases, only for particular values of $(ARCH) self.assertEqual( - ['FLASH_LOAD_OFFSET', 'NPCX_HEADER', 'SYS_CLOCK_HW_CYCLES_PER_SEC'], - missing) + ["FLASH_LOAD_OFFSET", "NPCX_HEADER", "SYS_CLOCK_HW_CYCLES_PER_SEC"], missing + ) def test_check_unneeded(self): """Test running the 'check' subcommand with unneeded ad-hoc configs""" @@ -209,18 +239,29 @@ rsource "subdir/Kconfig.wibble" self.setup_srctree(srctree) with tempfile.NamedTemporaryFile() as allowed: with tempfile.NamedTemporaryFile() as configs: - self.setup_allowed_and_configs(allowed.name, - configs.name, False) + self.setup_allowed_and_configs( + allowed.name, configs.name, False + ) ret_code = kconfig_check.main( - ['-c', configs.name, '-s', srctree, - '-a', allowed.name, '-p', PREFIX, 'check']) + [ + "-c", + configs.name, + "-s", + srctree, + "-a", + allowed.name, + "-p", + PREFIX, + "check", + ] + ) self.assertEqual(1, ret_code) - self.assertEqual('', stderr.getvalue()) - found = re.findall('(CONFIG_.*)', stdout.getvalue()) - self.assertEqual(['CONFIG_MENU_KCONFIG'], found) + self.assertEqual("", stderr.getvalue()) + found = re.findall("(CONFIG_.*)", stdout.getvalue()) + self.assertEqual(["CONFIG_MENU_KCONFIG"], found) allowed = kconfig_check.NEW_ALLOWED_FNAME.read_text().splitlines() - self.assertEqual(['CONFIG_OLD_ONE'], allowed) + self.assertEqual(["CONFIG_OLD_ONE"], allowed) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/util/uart_stress_tester.py b/util/uart_stress_tester.py index b3db60060e..a89fe730c9 100755 --- a/util/uart_stress_tester.py +++ b/util/uart_stress_tester.py @@ -21,9 +21,7 @@ Prerequisite: e.g. dut-control cr50_uart_timestamp:off """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import argparse import atexit @@ -36,472 +34,501 @@ import time import serial -BAUDRATE = 115200 # Default baudrate setting for UART port -CROS_USERNAME = 'root' # Account name to login to ChromeOS -CROS_PASSWORD = 'test0000' # Password to login to ChromeOS -CHARGEN_TXT = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' - # The result of 'chargen 62 62' +BAUDRATE = 115200 # Default baudrate setting for UART port +CROS_USERNAME = "root" # Account name to login to ChromeOS +CROS_PASSWORD = "test0000" # Password to login to ChromeOS +CHARGEN_TXT = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +# The result of 'chargen 62 62' CHARGEN_TXT_LEN = len(CHARGEN_TXT) -CR = '\r' # Carriage Return -LF = '\n' # Line Feed +CR = "\r" # Carriage Return +LF = "\n" # Line Feed CRLF = CR + LF -FLAG_FILENAME = '/tmp/chargen_testing' -TPM_CMD = ('trunks_client --key_create --rsa=2048 --usage=sign' - ' --key_blob=/tmp/blob &> /dev/null') - # A ChromeOS TPM command for the cr50 stress - # purpose. -CR50_LOAD_GEN_CMD = ('while [[ -f %s ]]; do %s; done &' - % (FLAG_FILENAME, TPM_CMD)) - # A command line to run TPM_CMD in background - # infinitely. +FLAG_FILENAME = "/tmp/chargen_testing" +TPM_CMD = ( + "trunks_client --key_create --rsa=2048 --usage=sign" + " --key_blob=/tmp/blob &> /dev/null" +) +# A ChromeOS TPM command for the cr50 stress +# purpose. +CR50_LOAD_GEN_CMD = "while [[ -f %s ]]; do %s; done &" % (FLAG_FILENAME, TPM_CMD) +# A command line to run TPM_CMD in background +# infinitely. class ChargenTestError(Exception): - """Exception for Uart Stress Test Error""" - pass + """Exception for Uart Stress Test Error""" + pass -class UartSerial(object): - """Test Object for a single UART serial device - - Attributes: - UART_DEV_PROFILES - char_loss_occurrences: Number that character loss happens - cleanup_cli: Command list to perform before the test exits - cr50_workload: True if cr50 should be stressed, or False otherwise - usb_output: True if output should be generated to USB channel - dev_prof: Dictionary of device profile - duration: Time to keep chargen running - eol: Characters to add at the end of input - logger: object that store the log - num_ch_exp: Expected number of characters in output - num_ch_cap: Number of captured characters in output - test_cli: Command list to run for chargen test - test_thread: Thread object that captures the UART output - serial: serial.Serial object - """ - UART_DEV_PROFILES = ( - # Kernel - { - 'prompt':'localhost login:', - 'device_type':'AP', - 'prepare_cmd':[ - CROS_USERNAME, # Login - CROS_PASSWORD, # Password - 'dmesg -D', # Disable console message - 'touch ' + FLAG_FILENAME, # Create a temp file - ], - 'cleanup_cmd':[ - 'rm -f ' + FLAG_FILENAME, # Remove the temp file - 'dmesg -E', # Enable console message - 'logout', # Logout - ], - 'end_of_input':LF, - }, - # EC - { - 'prompt':'> ', - 'device_type':'EC', - 'prepare_cmd':[ - 'chan save', - 'chan 0' # Disable console message - ], - 'cleanup_cmd':['', 'chan restore'], - 'end_of_input':CRLF, - }, - ) - - def __init__(self, port, duration, timeout=1, - baudrate=BAUDRATE, cr50_workload=False, - usb_output=False): - """Initialize UartSerial - - Args: - port: UART device path. e.g. /dev/ttyUSB0 - duration: Time to test, in seconds - timeout: Read timeout value. - baudrate: Baud rate such as 9600 or 115200. - cr50_workload: True if a workload should be generated on cr50 - usb_output: True if a workload should be generated to USB channel - """ - - # Initialize serial object - self.serial = serial.Serial() - self.serial.port = port - self.serial.timeout = timeout - self.serial.baudrate = baudrate - - self.duration = duration - self.cr50_workload = cr50_workload - self.usb_output = usb_output - - self.logger = logging.getLogger(type(self).__name__ + '| ' + port) - self.test_thread = threading.Thread(target=self.stress_test_thread) - self.dev_prof = {} - self.cleanup_cli = [] - self.test_cli = [] - self.eol = CRLF - self.num_ch_exp = 0 - self.num_ch_cap = 0 - self.char_loss_occurrences = 0 - atexit.register(self.cleanup) - - def run_command(self, command_lines, delay=0): - """Run command(s) at UART prompt - - Args: - command_lines: list of commands to run. - delay: delay after a command in second +class UartSerial(object): + """Test Object for a single UART serial device + + Attributes: + UART_DEV_PROFILES + char_loss_occurrences: Number that character loss happens + cleanup_cli: Command list to perform before the test exits + cr50_workload: True if cr50 should be stressed, or False otherwise + usb_output: True if output should be generated to USB channel + dev_prof: Dictionary of device profile + duration: Time to keep chargen running + eol: Characters to add at the end of input + logger: object that store the log + num_ch_exp: Expected number of characters in output + num_ch_cap: Number of captured characters in output + test_cli: Command list to run for chargen test + test_thread: Thread object that captures the UART output + serial: serial.Serial object """ - for cli in command_lines: - self.logger.debug('run %r', cli) - - self.serial.write((cli + self.eol).encode()) - self.serial.flush() - if delay: - time.sleep(delay) - - def cleanup(self): - """Before termination, clean up the UART device.""" - self.logger.debug('Closing...') - - self.serial.open() - self.run_command(self.cleanup_cli) # Run cleanup commands - self.serial.close() - self.logger.debug('Cleanup done') + UART_DEV_PROFILES = ( + # Kernel + { + "prompt": "localhost login:", + "device_type": "AP", + "prepare_cmd": [ + CROS_USERNAME, # Login + CROS_PASSWORD, # Password + "dmesg -D", # Disable console message + "touch " + FLAG_FILENAME, # Create a temp file + ], + "cleanup_cmd": [ + "rm -f " + FLAG_FILENAME, # Remove the temp file + "dmesg -E", # Enable console message + "logout", # Logout + ], + "end_of_input": LF, + }, + # EC + { + "prompt": "> ", + "device_type": "EC", + "prepare_cmd": ["chan save", "chan 0"], # Disable console message + "cleanup_cmd": ["", "chan restore"], + "end_of_input": CRLF, + }, + ) + + def __init__( + self, + port, + duration, + timeout=1, + baudrate=BAUDRATE, + cr50_workload=False, + usb_output=False, + ): + """Initialize UartSerial + + Args: + port: UART device path. e.g. /dev/ttyUSB0 + duration: Time to test, in seconds + timeout: Read timeout value. + baudrate: Baud rate such as 9600 or 115200. + cr50_workload: True if a workload should be generated on cr50 + usb_output: True if a workload should be generated to USB channel + """ + + # Initialize serial object + self.serial = serial.Serial() + self.serial.port = port + self.serial.timeout = timeout + self.serial.baudrate = baudrate + + self.duration = duration + self.cr50_workload = cr50_workload + self.usb_output = usb_output + + self.logger = logging.getLogger(type(self).__name__ + "| " + port) + self.test_thread = threading.Thread(target=self.stress_test_thread) + + self.dev_prof = {} + self.cleanup_cli = [] + self.test_cli = [] + self.eol = CRLF + self.num_ch_exp = 0 + self.num_ch_cap = 0 + self.char_loss_occurrences = 0 + atexit.register(self.cleanup) + + def run_command(self, command_lines, delay=0): + """Run command(s) at UART prompt + + Args: + command_lines: list of commands to run. + delay: delay after a command in second + """ + for cli in command_lines: + self.logger.debug("run %r", cli) + + self.serial.write((cli + self.eol).encode()) + self.serial.flush() + if delay: + time.sleep(delay) + + def cleanup(self): + """Before termination, clean up the UART device.""" + self.logger.debug("Closing...") + + self.serial.open() + self.run_command(self.cleanup_cli) # Run cleanup commands + self.serial.close() + + self.logger.debug("Cleanup done") + + def get_output(self): + """Capture the UART output + + Args: + stop_char: Read output buffer until it reads stop_char. + + Returns: + text from UART output. + """ + if self.serial.inWaiting() == 0: + time.sleep(1) + + return self.serial.read(self.serial.inWaiting()).decode() + + def prepare(self): + """Prepare the test: + + Identify the type of UART device (EC or Kernel?), then + decide what kind of commands to use to generate stress loads. + + Raises: + ChargenTestError if UART source can't be identified. + """ + try: + self.logger.info("Preparing...") + + self.serial.open() + + # Prepare the device for test + self.serial.flushInput() + self.serial.flushOutput() + + self.get_output() # drain data + + # Give a couple of line feeds, and capture the prompt text + self.run_command(["", ""]) + prompt_txt = self.get_output() + + # Detect the device source: EC or AP? + # Detect if the device is AP or EC console based on the captured. + for dev_prof in self.UART_DEV_PROFILES: + if dev_prof["prompt"] in prompt_txt: + self.dev_prof = dev_prof + break + else: + # No prompt patterns were found. UART seems not responding or in + # an undesirable status. + if prompt_txt: + raise ChargenTestError( + "%s: Got an unknown prompt text: %s\n" + "Check manually whether %s is available." + % (self.serial.port, prompt_txt, self.serial.port) + ) + else: + raise ChargenTestError( + "%s: Got no input. Close any other connections" + " to this port, and try it again." % self.serial.port + ) + + self.logger.info("Detected as %s UART", self.dev_prof["device_type"]) + # Log displays the UART type (AP|EC) instead of device filename. + self.logger = logging.getLogger( + type(self).__name__ + "| " + self.dev_prof["device_type"] + ) + + # Either login to AP or run some commands to prepare the device + # for test + self.eol = self.dev_prof["end_of_input"] + self.run_command(self.dev_prof["prepare_cmd"], delay=2) + self.cleanup_cli += self.dev_prof["cleanup_cmd"] + + # 'chargen' of AP does not have option for USB output. + # Force it work on UART. + if self.dev_prof["device_type"] == "AP": + self.usb_output = False + + # Check whether the command 'chargen' is available in the device. + # 'chargen 1 4' is supposed to print '0000' + self.get_output() # drain data + + chargen_cmd = "chargen 1 4" + if self.usb_output: + chargen_cmd += " usb" + self.run_command([chargen_cmd]) + tmp_txt = self.get_output() + + # Check whether chargen command is available. + if "0000" not in tmp_txt: + raise ChargenTestError( + "%s: Chargen got an unexpected result: %s" + % (self.dev_prof["device_type"], tmp_txt) + ) + + self.num_ch_exp = int(self.serial.baudrate * self.duration / 10) + chargen_cmd = "chargen " + str(CHARGEN_TXT_LEN) + " " + str(self.num_ch_exp) + if self.usb_output: + chargen_cmd += " usb" + self.test_cli = [chargen_cmd] + + self.logger.info("Ready to test") + finally: + self.serial.close() + + def stress_test_thread(self): + """Test thread + + Raises: + ChargenTestError: if broken character is found. + """ + try: + self.serial.open() + self.serial.flushInput() + self.serial.flushOutput() + + # Run TPM command in background to burden cr50. + if self.dev_prof["device_type"] == "AP" and self.cr50_workload: + self.run_command([CR50_LOAD_GEN_CMD]) + self.logger.debug("run TPM job while %s exists", FLAG_FILENAME) + + # Run the command 'chargen', one time + self.run_command([""]) # Give a line feed + self.get_output() # Drain the output + self.run_command(self.test_cli) + self.serial.readline() # Drain the echoed command line. + + err_msg = "%s: Expected %r but got %s after %d char received" + + # Keep capturing the output until the test timer is expired. + self.num_ch_cap = 0 + self.char_loss_occurrences = 0 + data_starve_count = 0 + + total_num_ch = self.num_ch_exp # Expected number of characters in total + ch_exp = CHARGEN_TXT[0] + ch_cap = "z" # any character value is ok for loop initial condition. + while self.num_ch_cap < total_num_ch: + captured = self.get_output() + + if captured: + # There is some output data. Reset the data starvation count. + data_starve_count = 0 + else: + data_starve_count += 1 + if data_starve_count > 1: + # If nothing was captured more than once, then terminate the test. + self.logger.debug("No more output") + break + + for ch_cap in captured: + if ch_cap not in CHARGEN_TXT: + # If it is not alpha-numeric, terminate the test. + if ch_cap not in CRLF: + # If it is neither a CR nor LF, then it is an error case. + self.logger.error("Whole captured characters: %r", captured) + raise ChargenTestError( + err_msg + % ( + "Broken char captured", + ch_exp, + hex(ord(ch_cap)), + self.num_ch_cap, + ) + ) + + # Set the loop termination condition true. + total_num_ch = self.num_ch_cap + + if self.num_ch_cap >= total_num_ch: + break + + if ch_exp != ch_cap: + # If it is alpha-numeric but not continuous, then some characters + # are lost. + self.logger.error( + err_msg, + "Char loss detected", + ch_exp, + repr(ch_cap), + self.num_ch_cap, + ) + self.char_loss_occurrences += 1 + + # Recalculate the expected number of characters to adjust + # termination condition. The loss might be bigger than this + # adjustment, but it is okay since it will terminates by either + # CR/LF detection or by data starvation. + idx_ch_exp = CHARGEN_TXT.find(ch_exp) + idx_ch_cap = CHARGEN_TXT.find(ch_cap) + if idx_ch_cap < idx_ch_exp: + idx_ch_cap += len(CHARGEN_TXT) + total_num_ch -= idx_ch_cap - idx_ch_exp + + self.num_ch_cap += 1 + + # Determine What character is expected next? + ch_exp = CHARGEN_TXT[ + (CHARGEN_TXT.find(ch_cap) + 1) % CHARGEN_TXT_LEN + ] + + finally: + self.serial.close() + + def start_test(self): + """Start the test thread""" + self.logger.info("Test thread starts") + self.test_thread.start() + + def wait_test_done(self): + """Wait until the test thread get done and join""" + self.test_thread.join() + self.logger.info("Test thread is done") + + def get_result(self): + """Display the result + + Returns: + Integer = the number of lost character + + Raises: + ChargenTestError: if the capture is corrupted. + """ + # If more characters than expected are captured, it means some messages + # from other than chargen are mixed. Stop processing further. + if self.num_ch_exp < self.num_ch_cap: + raise ChargenTestError( + "%s: UART output is corrupted." % self.dev_prof["device_type"] + ) + + # Get the count difference between the expected to the captured + # as the number of lost character. + char_lost = self.num_ch_exp - self.num_ch_cap + self.logger.info( + "%8d char lost / %10d (%.1f %%)", + char_lost, + self.num_ch_exp, + char_lost * 100.0 / self.num_ch_exp, + ) + + return char_lost, self.num_ch_exp, self.char_loss_occurrences - def get_output(self): - """Capture the UART output - Args: - stop_char: Read output buffer until it reads stop_char. +class ChargenTest(object): + """UART stress tester - Returns: - text from UART output. + Attributes: + logger: logging object + serials: Dictionary where key is filename of UART device, and the value is + UartSerial object """ - if self.serial.inWaiting() == 0: - time.sleep(1) - - return self.serial.read(self.serial.inWaiting()).decode() - def prepare(self): - """Prepare the test: - - Identify the type of UART device (EC or Kernel?), then - decide what kind of commands to use to generate stress loads. - - Raises: - ChargenTestError if UART source can't be identified. - """ - try: - self.logger.info('Preparing...') - - self.serial.open() - - # Prepare the device for test - self.serial.flushInput() - self.serial.flushOutput() - - self.get_output() # drain data - - # Give a couple of line feeds, and capture the prompt text - self.run_command(['', '']) - prompt_txt = self.get_output() - - # Detect the device source: EC or AP? - # Detect if the device is AP or EC console based on the captured. - for dev_prof in self.UART_DEV_PROFILES: - if dev_prof['prompt'] in prompt_txt: - self.dev_prof = dev_prof - break - else: - # No prompt patterns were found. UART seems not responding or in - # an undesirable status. - if prompt_txt: - raise ChargenTestError('%s: Got an unknown prompt text: %s\n' - 'Check manually whether %s is available.' % - (self.serial.port, prompt_txt, - self.serial.port)) - else: - raise ChargenTestError('%s: Got no input. Close any other connections' - ' to this port, and try it again.' % - self.serial.port) - - self.logger.info('Detected as %s UART', self.dev_prof['device_type']) - # Log displays the UART type (AP|EC) instead of device filename. - self.logger = logging.getLogger(type(self).__name__ + '| ' + - self.dev_prof['device_type']) - - # Either login to AP or run some commands to prepare the device - # for test - self.eol = self.dev_prof['end_of_input'] - self.run_command(self.dev_prof['prepare_cmd'], delay=2) - self.cleanup_cli += self.dev_prof['cleanup_cmd'] - - # 'chargen' of AP does not have option for USB output. - # Force it work on UART. - if self.dev_prof['device_type'] == 'AP': - self.usb_output = False - - # Check whether the command 'chargen' is available in the device. - # 'chargen 1 4' is supposed to print '0000' - self.get_output() # drain data - - chargen_cmd = 'chargen 1 4' - if self.usb_output: - chargen_cmd += ' usb' - self.run_command([chargen_cmd]) - tmp_txt = self.get_output() - - # Check whether chargen command is available. - if '0000' not in tmp_txt: - raise ChargenTestError('%s: Chargen got an unexpected result: %s' % - (self.dev_prof['device_type'], tmp_txt)) - - self.num_ch_exp = int(self.serial.baudrate * self.duration / 10) - chargen_cmd = 'chargen ' + str(CHARGEN_TXT_LEN) + ' ' + \ - str(self.num_ch_exp) - if self.usb_output: - chargen_cmd += ' usb' - self.test_cli = [chargen_cmd] - - self.logger.info('Ready to test') - finally: - self.serial.close() - - def stress_test_thread(self): - """Test thread - - Raises: - ChargenTestError: if broken character is found. - """ - try: - self.serial.open() - self.serial.flushInput() - self.serial.flushOutput() - - # Run TPM command in background to burden cr50. - if self.dev_prof['device_type'] == 'AP' and self.cr50_workload: - self.run_command([CR50_LOAD_GEN_CMD]) - self.logger.debug('run TPM job while %s exists', FLAG_FILENAME) - - # Run the command 'chargen', one time - self.run_command(['']) # Give a line feed - self.get_output() # Drain the output - self.run_command(self.test_cli) - self.serial.readline() # Drain the echoed command line. - - err_msg = '%s: Expected %r but got %s after %d char received' - - # Keep capturing the output until the test timer is expired. - self.num_ch_cap = 0 - self.char_loss_occurrences = 0 - data_starve_count = 0 - - total_num_ch = self.num_ch_exp # Expected number of characters in total - ch_exp = CHARGEN_TXT[0] - ch_cap = 'z' # any character value is ok for loop initial condition. - while self.num_ch_cap < total_num_ch: - captured = self.get_output() - - if captured: - # There is some output data. Reset the data starvation count. - data_starve_count = 0 + def __init__(self, ports, duration, cr50_workload=False, usb_output=False): + """Initialize UART stress tester + + Args: + ports: List of UART ports to test. + duration: Time to keep testing in seconds. + cr50_workload: True if a workload should be generated on cr50 + usb_output: True if a workload should be generated to USB channel + + Raises: + ChargenTestError: if any of ports is not a valid character device. + """ + + # Save the arguments + for port in ports: + try: + mode = os.stat(port).st_mode + except OSError as e: + raise ChargenTestError(e) + if not stat.S_ISCHR(mode): + raise ChargenTestError("%s is not a character device." % port) + + if duration <= 0: + raise ChargenTestError("Input error: duration is not positive.") + + # Initialize logging object + self.logger = logging.getLogger(type(self).__name__) + + # Create an UartSerial object per UART port + self.serials = {} # UartSerial objects + for port in ports: + self.serials[port] = UartSerial( + port=port, + duration=duration, + cr50_workload=cr50_workload, + usb_output=usb_output, + ) + + def prepare(self): + """Prepare the test for each UART port""" + self.logger.info("Prepare ports for test") + for _, ser in self.serials.items(): + ser.prepare() + self.logger.info("Ports are ready to test") + + def print_result(self): + """Display the test result for each UART port + + Returns: + char_lost: Total number of characters lost + """ + char_lost = 0 + for _, ser in self.serials.items(): + (tmp_lost, _, _) = ser.get_result() + char_lost += tmp_lost + + # If any characters are lost, then test fails. + msg = "lost %d character(s) from the test" % char_lost + if char_lost > 0: + self.logger.error("FAIL: %s", msg) else: - data_starve_count += 1 - if data_starve_count > 1: - # If nothing was captured more than once, then terminate the test. - self.logger.debug('No more output') - break - - for ch_cap in captured: - if ch_cap not in CHARGEN_TXT: - # If it is not alpha-numeric, terminate the test. - if ch_cap not in CRLF: - # If it is neither a CR nor LF, then it is an error case. - self.logger.error('Whole captured characters: %r', captured) - raise ChargenTestError(err_msg % ('Broken char captured', ch_exp, - hex(ord(ch_cap)), - self.num_ch_cap)) - - # Set the loop termination condition true. - total_num_ch = self.num_ch_cap - - if self.num_ch_cap >= total_num_ch: - break - - if ch_exp != ch_cap: - # If it is alpha-numeric but not continuous, then some characters - # are lost. - self.logger.error(err_msg, 'Char loss detected', - ch_exp, repr(ch_cap), self.num_ch_cap) - self.char_loss_occurrences += 1 - - # Recalculate the expected number of characters to adjust - # termination condition. The loss might be bigger than this - # adjustment, but it is okay since it will terminates by either - # CR/LF detection or by data starvation. - idx_ch_exp = CHARGEN_TXT.find(ch_exp) - idx_ch_cap = CHARGEN_TXT.find(ch_cap) - if idx_ch_cap < idx_ch_exp: - idx_ch_cap += len(CHARGEN_TXT) - total_num_ch -= (idx_ch_cap - idx_ch_exp) - - self.num_ch_cap += 1 - - # Determine What character is expected next? - ch_exp = CHARGEN_TXT[(CHARGEN_TXT.find(ch_cap) + 1) % CHARGEN_TXT_LEN] - - finally: - self.serial.close() - - def start_test(self): - """Start the test thread""" - self.logger.info('Test thread starts') - self.test_thread.start() - - def wait_test_done(self): - """Wait until the test thread get done and join""" - self.test_thread.join() - self.logger.info('Test thread is done') - - def get_result(self): - """Display the result + self.logger.info("PASS: %s", msg) - Returns: - Integer = the number of lost character + return char_lost - Raises: - ChargenTestError: if the capture is corrupted. - """ - # If more characters than expected are captured, it means some messages - # from other than chargen are mixed. Stop processing further. - if self.num_ch_exp < self.num_ch_cap: - raise ChargenTestError('%s: UART output is corrupted.' % - self.dev_prof['device_type']) + def run(self): + """Run the stress test on UART port(s) - # Get the count difference between the expected to the captured - # as the number of lost character. - char_lost = self.num_ch_exp - self.num_ch_cap - self.logger.info('%8d char lost / %10d (%.1f %%)', - char_lost, self.num_ch_exp, - char_lost * 100.0 / self.num_ch_exp) + Raises: + ChargenTestError: If any characters are lost. + """ - return char_lost, self.num_ch_exp, self.char_loss_occurrences + # Detect UART source type, and decide which command to test. + self.prepare() + # Run the test on each UART port in thread. + self.logger.info("Test starts") + for _, ser in self.serials.items(): + ser.start_test() -class ChargenTest(object): - """UART stress tester + # Wait all tests to finish. + for _, ser in self.serials.items(): + ser.wait_test_done() - Attributes: - logger: logging object - serials: Dictionary where key is filename of UART device, and the value is - UartSerial object - """ + # Print the result. + char_lost = self.print_result() + if char_lost: + raise ChargenTestError("Test failed: lost %d character(s)" % char_lost) - def __init__(self, ports, duration, cr50_workload=False, - usb_output=False): - """Initialize UART stress tester + self.logger.info("Test is done") - Args: - ports: List of UART ports to test. - duration: Time to keep testing in seconds. - cr50_workload: True if a workload should be generated on cr50 - usb_output: True if a workload should be generated to USB channel - Raises: - ChargenTestError: if any of ports is not a valid character device. - """ +def parse_args(cmdline): + """Parse command line arguments. - # Save the arguments - for port in ports: - try: - mode = os.stat(port).st_mode - except OSError as e: - raise ChargenTestError(e) - if not stat.S_ISCHR(mode): - raise ChargenTestError('%s is not a character device.' % port) - - if duration <= 0: - raise ChargenTestError('Input error: duration is not positive.') - - # Initialize logging object - self.logger = logging.getLogger(type(self).__name__) - - # Create an UartSerial object per UART port - self.serials = {} # UartSerial objects - for port in ports: - self.serials[port] = UartSerial(port=port, duration=duration, - cr50_workload=cr50_workload, - usb_output=usb_output) - - def prepare(self): - """Prepare the test for each UART port""" - self.logger.info('Prepare ports for test') - for _, ser in self.serials.items(): - ser.prepare() - self.logger.info('Ports are ready to test') - - def print_result(self): - """Display the test result for each UART port + Args: + cmdline: list to be parsed Returns: - char_lost: Total number of characters lost - """ - char_lost = 0 - for _, ser in self.serials.items(): - (tmp_lost, _, _) = ser.get_result() - char_lost += tmp_lost - - # If any characters are lost, then test fails. - msg = 'lost %d character(s) from the test' % char_lost - if char_lost > 0: - self.logger.error('FAIL: %s', msg) - else: - self.logger.info('PASS: %s', msg) - - return char_lost - - def run(self): - """Run the stress test on UART port(s) - - Raises: - ChargenTestError: If any characters are lost. + tuple (options, args) where args is a list of cmdline arguments that the + parser was unable to match i.e. they're servod controls, not options. """ - - # Detect UART source type, and decide which command to test. - self.prepare() - - # Run the test on each UART port in thread. - self.logger.info('Test starts') - for _, ser in self.serials.items(): - ser.start_test() - - # Wait all tests to finish. - for _, ser in self.serials.items(): - ser.wait_test_done() - - # Print the result. - char_lost = self.print_result() - if char_lost: - raise ChargenTestError('Test failed: lost %d character(s)' % - char_lost) - - self.logger.info('Test is done') - -def parse_args(cmdline): - """Parse command line arguments. - - Args: - cmdline: list to be parsed - - Returns: - tuple (options, args) where args is a list of cmdline arguments that the - parser was unable to match i.e. they're servod controls, not options. - """ - description = """%(prog)s repeats sending a uart console command + description = """%(prog)s repeats sending a uart console command to each UART device for a given time, and check if output has any missing characters. @@ -511,52 +538,70 @@ Examples: %(prog)s /dev/ttyUSB1 /dev/ttyUSB2 --cr50 """ - parser = argparse.ArgumentParser(description=description, - formatter_class=argparse.RawTextHelpFormatter - ) - parser.add_argument('port', type=str, nargs='*', - help='UART device path to test') - parser.add_argument('-c', '--cr50', action='store_true', default=False, - help='generate TPM workload on cr50') - parser.add_argument('-d', '--debug', action='store_true', default=False, - help='enable debug messages') - parser.add_argument('-t', '--time', type=int, - help='Test duration in second', default=300) - parser.add_argument('-u', '--usb', action='store_true', default=False, - help='Generate output to USB channel instead') - return parser.parse_known_args(cmdline) + parser = argparse.ArgumentParser( + description=description, formatter_class=argparse.RawTextHelpFormatter + ) + parser.add_argument("port", type=str, nargs="*", help="UART device path to test") + parser.add_argument( + "-c", + "--cr50", + action="store_true", + default=False, + help="generate TPM workload on cr50", + ) + parser.add_argument( + "-d", + "--debug", + action="store_true", + default=False, + help="enable debug messages", + ) + parser.add_argument( + "-t", "--time", type=int, help="Test duration in second", default=300 + ) + parser.add_argument( + "-u", + "--usb", + action="store_true", + default=False, + help="Generate output to USB channel instead", + ) + return parser.parse_known_args(cmdline) def main(): - """Main function wrapper""" - try: - (options, _) = parse_args(sys.argv[1:]) - - # Set Log format - log_format = '%(asctime)s %(levelname)-6s | %(name)-25s' - date_format = '%Y-%m-%d %H:%M:%S' - if options.debug: - log_format += ' | %(filename)s:%(lineno)4d:%(funcName)-18s' - loglevel = logging.DEBUG - else: - loglevel = logging.INFO - log_format += ' | %(message)s' - - logging.basicConfig(level=loglevel, format=log_format, - datefmt=date_format) - - # Create a ChargenTest object - utest = ChargenTest(options.port, options.time, - cr50_workload=options.cr50, - usb_output=options.usb) - utest.run() # Run - - except KeyboardInterrupt: - sys.exit(0) - - except ChargenTestError as e: - logging.error(str(e)) - sys.exit(1) - -if __name__ == '__main__': - main() + """Main function wrapper""" + try: + (options, _) = parse_args(sys.argv[1:]) + + # Set Log format + log_format = "%(asctime)s %(levelname)-6s | %(name)-25s" + date_format = "%Y-%m-%d %H:%M:%S" + if options.debug: + log_format += " | %(filename)s:%(lineno)4d:%(funcName)-18s" + loglevel = logging.DEBUG + else: + loglevel = logging.INFO + log_format += " | %(message)s" + + logging.basicConfig(level=loglevel, format=log_format, datefmt=date_format) + + # Create a ChargenTest object + utest = ChargenTest( + options.port, + options.time, + cr50_workload=options.cr50, + usb_output=options.usb, + ) + utest.run() # Run + + except KeyboardInterrupt: + sys.exit(0) + + except ChargenTestError as e: + logging.error(str(e)) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/util/unpack_ftb.py b/util/unpack_ftb.py index 03127a7089..a68662d82b 100755 --- a/util/unpack_ftb.py +++ b/util/unpack_ftb.py @@ -10,26 +10,28 @@ # Note: This is a py2/3 compatible file. from __future__ import print_function + import argparse import ctypes import os class Header(ctypes.Structure): - _pack_ = 1 - _fields_ = [ - ('signature', ctypes.c_uint32), - ('ftb_ver', ctypes.c_uint32), - ('chip_id', ctypes.c_uint32), - ('svn_ver', ctypes.c_uint32), - ('fw_ver', ctypes.c_uint32), - ('config_id', ctypes.c_uint32), - ('config_ver', ctypes.c_uint32), - ('reserved', ctypes.c_uint8 * 8), - ('release_info', ctypes.c_ulonglong), - ('sec_size', ctypes.c_uint32 * 4), - ('crc', ctypes.c_uint32), - ] + _pack_ = 1 + _fields_ = [ + ("signature", ctypes.c_uint32), + ("ftb_ver", ctypes.c_uint32), + ("chip_id", ctypes.c_uint32), + ("svn_ver", ctypes.c_uint32), + ("fw_ver", ctypes.c_uint32), + ("config_id", ctypes.c_uint32), + ("config_ver", ctypes.c_uint32), + ("reserved", ctypes.c_uint8 * 8), + ("release_info", ctypes.c_ulonglong), + ("sec_size", ctypes.c_uint32 * 4), + ("crc", ctypes.c_uint32), + ] + FW_HEADER_SIZE = 64 FW_HEADER_SIGNATURE = 0xAA55AA55 @@ -44,7 +46,7 @@ FLASH_SEC_ADDR = [ 0x0000 * 4, # CODE 0x7C00 * 4, # CONFIG 0x7000 * 4, # CX - None # This section shouldn't exist + None, # This section shouldn't exist ] UPDATE_PDU_SIZE = 4096 @@ -59,64 +61,66 @@ OUTPUT_FILE_SIZE = UPDATE_PDU_SIZE + 128 * 1024 def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--input', '-i', required=True) - parser.add_argument('--output', '-o', required=True) - args = parser.parse_args() + parser = argparse.ArgumentParser() + parser.add_argument("--input", "-i", required=True) + parser.add_argument("--output", "-o", required=True) + args = parser.parse_args() - with open(args.input, 'rb') as f: - bs = f.read() + with open(args.input, "rb") as f: + bs = f.read() - size = len(bs) - if size < FW_HEADER_SIZE + FW_BYTES_ALIGN: - raise Exception('FW size too small') + size = len(bs) + if size < FW_HEADER_SIZE + FW_BYTES_ALIGN: + raise Exception("FW size too small") - print('FTB file size:', size) + print("FTB file size:", size) - header = Header() - assert ctypes.sizeof(header) == FW_HEADER_SIZE + header = Header() + assert ctypes.sizeof(header) == FW_HEADER_SIZE - ctypes.memmove(ctypes.addressof(header), bs, ctypes.sizeof(header)) - if (header.signature != FW_HEADER_SIGNATURE or - header.ftb_ver != FW_FTB_VER or - header.chip_id != FW_CHIP_ID): - raise Exception('Invalid header') + ctypes.memmove(ctypes.addressof(header), bs, ctypes.sizeof(header)) + if ( + header.signature != FW_HEADER_SIGNATURE + or header.ftb_ver != FW_FTB_VER + or header.chip_id != FW_CHIP_ID + ): + raise Exception("Invalid header") - for key, _ in header._fields_: - v = getattr(header, key) - if isinstance(v, ctypes.Array): - print(key, list(map(hex, v))) - else: - print(key, hex(v)) + for key, _ in header._fields_: + v = getattr(header, key) + if isinstance(v, ctypes.Array): + print(key, list(map(hex, v))) + else: + print(key, hex(v)) - dimension = sum(header.sec_size) + dimension = sum(header.sec_size) - assert dimension + FW_HEADER_SIZE + FW_BYTES_ALIGN == size - data = bs[FW_HEADER_SIZE:FW_HEADER_SIZE + dimension] + assert dimension + FW_HEADER_SIZE + FW_BYTES_ALIGN == size + data = bs[FW_HEADER_SIZE : FW_HEADER_SIZE + dimension] - with open(args.output, 'wb') as f: - # ensure the file size - f.seek(OUTPUT_FILE_SIZE - 1, os.SEEK_SET) - f.write(b'\x00') + with open(args.output, "wb") as f: + # ensure the file size + f.seek(OUTPUT_FILE_SIZE - 1, os.SEEK_SET) + f.write(b"\x00") - f.seek(0, os.SEEK_SET) - f.write(bs[0 : ctypes.sizeof(header)]) + f.seek(0, os.SEEK_SET) + f.write(bs[0 : ctypes.sizeof(header)]) - offset = 0 - # write each sections - for i, addr in enumerate(FLASH_SEC_ADDR): - size = header.sec_size[i] - assert addr is not None or size == 0 + offset = 0 + # write each sections + for i, addr in enumerate(FLASH_SEC_ADDR): + size = header.sec_size[i] + assert addr is not None or size == 0 - if size == 0: - continue + if size == 0: + continue - f.seek(UPDATE_PDU_SIZE + addr, os.SEEK_SET) - f.write(data[offset : offset + size]) - offset += size + f.seek(UPDATE_PDU_SIZE + addr, os.SEEK_SET) + f.write(data[offset : offset + size]) + offset += size - f.flush() + f.flush() -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/util/update_release_branch.py b/util/update_release_branch.py index b9063d4970..4d9c89df4a 100755 --- a/util/update_release_branch.py +++ b/util/update_release_branch.py @@ -19,8 +19,7 @@ import subprocess import sys import textwrap - -BUG_NONE_PATTERN = re.compile('none', flags=re.IGNORECASE) +BUG_NONE_PATTERN = re.compile("none", flags=re.IGNORECASE) def git_commit_msg(branch, head, merge_head, rel_paths, cmd): @@ -42,18 +41,17 @@ def git_commit_msg(branch, head, merge_head, rel_paths, cmd): A String containing the git commit message with the exception of the Signed-Off-By field and Change-ID field. """ - relevant_commits_cmd, relevant_commits = get_relevant_commits(head, - merge_head, - '--oneline', - rel_paths) + relevant_commits_cmd, relevant_commits = get_relevant_commits( + head, merge_head, "--oneline", rel_paths + ) - _, relevant_bugs = get_relevant_commits(head, merge_head, '', rel_paths) - relevant_bugs = set(re.findall('BUG=(.*)', relevant_bugs)) + _, relevant_bugs = get_relevant_commits(head, merge_head, "", rel_paths) + relevant_bugs = set(re.findall("BUG=(.*)", relevant_bugs)) # Filter out "none" from set of bugs filtered = [] for bug_line in relevant_bugs: - bug_line = bug_line.replace(',', ' ') - bugs = bug_line.split(' ') + bug_line = bug_line.replace(",", " ") + bugs = bug_line.split(" ") for bug in bugs: if bug and not BUG_NONE_PATTERN.match(bug): filtered.append(bug) @@ -82,18 +80,20 @@ Cq-Include-Trybots: chromeos/cq:cq-orchestrator # 72 cols. relevant_commits_cmd = textwrap.fill(relevant_commits_cmd, width=72) # Wrap at 68 cols to save room for 'BUG=' - bugs = textwrap.wrap(' '.join(relevant_bugs), width=68) - bug_field = '' + bugs = textwrap.wrap(" ".join(relevant_bugs), width=68) + bug_field = "" for line in bugs: - bug_field += 'BUG=' + line + '\n' + bug_field += "BUG=" + line + "\n" # Remove the final newline since the template adds it for us. bug_field = bug_field[:-1] - return COMMIT_MSG_TEMPLATE.format(BRANCH=branch, - RELEVANT_COMMITS_CMD=relevant_commits_cmd, - RELEVANT_COMMITS=relevant_commits, - BUG_FIELD=bug_field, - COMMAND_LINE=cmd) + return COMMIT_MSG_TEMPLATE.format( + BRANCH=branch, + RELEVANT_COMMITS_CMD=relevant_commits_cmd, + RELEVANT_COMMITS=relevant_commits, + BUG_FIELD=bug_field, + COMMAND_LINE=cmd, + ) def get_relevant_boards(baseboard): @@ -105,15 +105,16 @@ def get_relevant_boards(baseboard): Returns: A list of strings containing the boards based off of the baseboard. """ - proc = subprocess.run(['git', 'grep', 'BASEBOARD:=' + baseboard, '--', - 'board/'], - stdout=subprocess.PIPE, - encoding='utf-8', - check=True) + proc = subprocess.run( + ["git", "grep", "BASEBOARD:=" + baseboard, "--", "board/"], + stdout=subprocess.PIPE, + encoding="utf-8", + check=True, + ) boards = [] res = proc.stdout.splitlines() for line in res: - boards.append(line.split('/')[1]) + boards.append(line.split("/")[1]) return boards @@ -135,21 +136,18 @@ def get_relevant_commits(head, merge_head, fmt, relevant_paths): stdout. """ if fmt: - cmd = ['git', 'log', fmt, head + '..' + merge_head, '--', - relevant_paths] + cmd = ["git", "log", fmt, head + ".." + merge_head, "--", relevant_paths] else: - cmd = ['git', 'log', head + '..' + merge_head, '--', relevant_paths] + cmd = ["git", "log", head + ".." + merge_head, "--", relevant_paths] # Pass cmd as a string to subprocess.run() since we need to run with shell # equal to True. The reason we are using shell equal to True is to take # advantage of the glob expansion for the relevant paths. - cmd = ' '.join(cmd) - proc = subprocess.run(cmd, - stdout=subprocess.PIPE, - encoding='utf-8', - check=True, - shell=True) - return ''.join(proc.args), proc.stdout + cmd = " ".join(cmd) + proc = subprocess.run( + cmd, stdout=subprocess.PIPE, encoding="utf-8", check=True, shell=True + ) + return "".join(proc.args), proc.stdout def main(argv): @@ -165,46 +163,61 @@ def main(argv): argv: A list of the command line arguments passed to this script. """ # Set up argument parser. - parser = argparse.ArgumentParser(description=('A script that generates a ' - 'merge commit from cros/main' - ' to a desired release ' - 'branch. By default, the ' - '"recursive" merge strategy ' - 'with the "theirs" strategy ' - 'option is used.')) - parser.add_argument('--baseboard') - parser.add_argument('--board') - parser.add_argument('release_branch', help=('The name of the target release' - ' branch')) - parser.add_argument('--relevant_paths_file', - help=('A path to a text file which includes other ' - 'relevant paths of interest for this board ' - 'or baseboard')) - parser.add_argument('--merge_strategy', '-s', default='recursive', - help='The merge strategy to pass to `git merge -s`') - parser.add_argument('--strategy_option', '-X', - help=('The strategy option for the chosen merge ' - 'strategy')) + parser = argparse.ArgumentParser( + description=( + "A script that generates a " + "merge commit from cros/main" + " to a desired release " + "branch. By default, the " + '"recursive" merge strategy ' + 'with the "theirs" strategy ' + "option is used." + ) + ) + parser.add_argument("--baseboard") + parser.add_argument("--board") + parser.add_argument( + "release_branch", help=("The name of the target release" " branch") + ) + parser.add_argument( + "--relevant_paths_file", + help=( + "A path to a text file which includes other " + "relevant paths of interest for this board " + "or baseboard" + ), + ) + parser.add_argument( + "--merge_strategy", + "-s", + default="recursive", + help="The merge strategy to pass to `git merge -s`", + ) + parser.add_argument( + "--strategy_option", + "-X", + help=("The strategy option for the chosen merge " "strategy"), + ) opts = parser.parse_args(argv[1:]) - baseboard_dir = '' - board_dir = '' + baseboard_dir = "" + board_dir = "" if opts.baseboard: # Dereference symlinks so "git log" works as expected. - baseboard_dir = os.path.relpath('baseboard/' + opts.baseboard) + baseboard_dir = os.path.relpath("baseboard/" + opts.baseboard) baseboard_dir = os.path.relpath(os.path.realpath(baseboard_dir)) boards = get_relevant_boards(opts.baseboard) elif opts.board: - board_dir = os.path.relpath('board/' + opts.board) + board_dir = os.path.relpath("board/" + opts.board) board_dir = os.path.relpath(os.path.realpath(board_dir)) boards = [opts.board] else: - parser.error('You must specify a board OR a baseboard') + parser.error("You must specify a board OR a baseboard") - print('Gathering relevant paths...') + print("Gathering relevant paths...") relevant_paths = [] if opts.baseboard: relevant_paths.append(baseboard_dir) @@ -212,65 +225,91 @@ def main(argv): relevant_paths.append(board_dir) for board in boards: - relevant_paths.append('board/' + board) + relevant_paths.append("board/" + board) # Check for the existence of a file that has other paths of interest. if opts.relevant_paths_file and os.path.exists(opts.relevant_paths_file): - with open(opts.relevant_paths_file, 'r') as relevant_paths_file: + with open(opts.relevant_paths_file, "r") as relevant_paths_file: for line in relevant_paths_file: - if not line.startswith('#'): + if not line.startswith("#"): relevant_paths.append(line.rstrip()) - relevant_paths.append('util/getversion.sh') - relevant_paths = ' '.join(relevant_paths) + relevant_paths.append("util/getversion.sh") + relevant_paths = " ".join(relevant_paths) # Check if we are already in merge process - result = subprocess.run(['git', 'rev-parse', '--quiet', '--verify', - 'MERGE_HEAD'], stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, check=False) + result = subprocess.run( + ["git", "rev-parse", "--quiet", "--verify", "MERGE_HEAD"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + check=False, + ) if result.returncode: # Let's perform the merge - print('Updating remote...') - subprocess.run(['git', 'remote', 'update'], check=True) - subprocess.run(['git', 'checkout', '-B', opts.release_branch, 'cros/' + - opts.release_branch], check=True) - print('Attempting git merge...') - if opts.merge_strategy == 'recursive' and not opts.strategy_option: - opts.strategy_option = 'theirs' - print('Using "%s" merge strategy' % opts.merge_strategy, - ("with strategy option '%s'" % opts.strategy_option - if opts.strategy_option else '')) - arglist = ['git', 'merge', '--no-ff', '--no-commit', 'cros/main', '-s', - opts.merge_strategy] + print("Updating remote...") + subprocess.run(["git", "remote", "update"], check=True) + subprocess.run( + [ + "git", + "checkout", + "-B", + opts.release_branch, + "cros/" + opts.release_branch, + ], + check=True, + ) + print("Attempting git merge...") + if opts.merge_strategy == "recursive" and not opts.strategy_option: + opts.strategy_option = "theirs" + print( + 'Using "%s" merge strategy' % opts.merge_strategy, + ( + "with strategy option '%s'" % opts.strategy_option + if opts.strategy_option + else "" + ), + ) + arglist = [ + "git", + "merge", + "--no-ff", + "--no-commit", + "cros/main", + "-s", + opts.merge_strategy, + ] if opts.strategy_option: - arglist.append('-X' + opts.strategy_option) + arglist.append("-X" + opts.strategy_option) subprocess.run(arglist, check=True) else: - print('We have already started merge process.', - 'Attempt to generate commit.') - - print('Generating commit message...') - branch = subprocess.run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'], - stdout=subprocess.PIPE, - encoding='utf-8', - check=True).stdout.rstrip() - head = subprocess.run(['git', 'rev-parse', '--short', 'HEAD'], - stdout=subprocess.PIPE, - encoding='utf-8', - check=True).stdout.rstrip() - merge_head = subprocess.run(['git', 'rev-parse', '--short', - 'MERGE_HEAD'], - stdout=subprocess.PIPE, - encoding='utf-8', - check=True).stdout.rstrip() - - cmd = ' '.join(argv) - print('Typing as fast as I can...') + print("We have already started merge process.", "Attempt to generate commit.") + + print("Generating commit message...") + branch = subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + stdout=subprocess.PIPE, + encoding="utf-8", + check=True, + ).stdout.rstrip() + head = subprocess.run( + ["git", "rev-parse", "--short", "HEAD"], + stdout=subprocess.PIPE, + encoding="utf-8", + check=True, + ).stdout.rstrip() + merge_head = subprocess.run( + ["git", "rev-parse", "--short", "MERGE_HEAD"], + stdout=subprocess.PIPE, + encoding="utf-8", + check=True, + ).stdout.rstrip() + + cmd = " ".join(argv) + print("Typing as fast as I can...") commit_msg = git_commit_msg(branch, head, merge_head, relevant_paths, cmd) - subprocess.run(['git', 'commit', '--signoff', '-m', commit_msg], check=True) - subprocess.run(['git', 'commit', '--amend'], check=True) - print(("Finished! **Please review the commit to see if it's to your " - 'liking.**')) + subprocess.run(["git", "commit", "--signoff", "-m", commit_msg], check=True) + subprocess.run(["git", "commit", "--amend"], check=True) + print(("Finished! **Please review the commit to see if it's to your " "liking.**")) -if __name__ == '__main__': +if __name__ == "__main__": main(sys.argv) diff --git a/zephyr/firmware_builder.py b/zephyr/firmware_builder.py index f77e51d6c4..c0963a84db 100755 --- a/zephyr/firmware_builder.py +++ b/zephyr/firmware_builder.py @@ -15,14 +15,12 @@ import re import subprocess import sys -from google.protobuf import json_format # pylint: disable=import-error import zmake.project - from chromite.api.gen_sdk.chromite.api import firmware_pb2 +from google.protobuf import json_format # pylint: disable=import-error - -DEFAULT_BUNDLE_DIRECTORY = '/tmp/artifact_bundles' -DEFAULT_BUNDLE_METADATA_FILE = '/tmp/artifact_bundle_metadata' +DEFAULT_BUNDLE_DIRECTORY = "/tmp/artifact_bundles" +DEFAULT_BUNDLE_METADATA_FILE = "/tmp/artifact_bundle_metadata" def build(opts): @@ -33,54 +31,52 @@ def build(opts): platform_ec = zephyr_dir.resolve().parent subprocess.run([platform_ec / "util" / "check_clang_format.py"], check=True) - cmd = ['zmake', '-D', 'build', '-a'] + cmd = ["zmake", "-D", "build", "-a"] if opts.code_coverage: - cmd.append('--coverage') + cmd.append("--coverage") subprocess.run(cmd, cwd=pathlib.Path(__file__).parent, check=True) if not opts.code_coverage: for project in zmake.project.find_projects(zephyr_dir).values(): if project.config.is_test: continue - build_dir = ( - platform_ec / 'build' / 'zephyr' / project.config.project_name - ) + build_dir = platform_ec / "build" / "zephyr" / project.config.project_name metric = metric_list.value.add() metric.target_name = project.config.project_name metric.platform_name = project.config.zephyr_board for (variant, _) in project.iter_builds(): - build_log = build_dir / f'build-{variant}' / 'build.log' + build_log = build_dir / f"build-{variant}" / "build.log" parse_buildlog(build_log, metric, variant.upper()) - with open(opts.metrics, 'w') as file: + with open(opts.metrics, "w") as file: file.write(json_format.MessageToJson(metric_list)) return 0 UNITS = { - 'B': 1, - 'KB': 1024, - 'MB': 1024 * 1024, - 'GB': 1024 * 1024 * 1024, + "B": 1, + "KB": 1024, + "MB": 1024 * 1024, + "GB": 1024 * 1024 * 1024, } def parse_buildlog(filename, metric, variant): """Parse the build.log generated by zmake to find the size of the image.""" - with open(filename, 'r') as infile: + with open(filename, "r") as infile: # Skip over all lines until the memory report is found while True: line = infile.readline() if not line: return - if line.startswith('Memory region'): + if line.startswith("Memory region"): break for line in infile.readlines(): # Skip any lines that are not part of the report - if not line.startswith(' '): + if not line.startswith(" "): continue parts = line.split() fw_section = metric.fw_section.add() - fw_section.region = variant + '_' + parts[0][:-1] + fw_section.region = variant + "_" + parts[0][:-1] fw_section.used = int(parts[1]) * UNITS[parts[2]] fw_section.total = int(parts[3]) * UNITS[parts[4]] fw_section.track_on_gerrit = False @@ -114,7 +110,7 @@ def write_metadata(opts, info): bundle_metadata_file = ( opts.metadata if opts.metadata else DEFAULT_BUNDLE_METADATA_FILE ) - with open(bundle_metadata_file, 'w') as file: + with open(bundle_metadata_file, "w") as file: file.write(json_format.MessageToJson(info)) @@ -125,10 +121,10 @@ def bundle_coverage(opts): bundle_dir = get_bundle_dir(opts) zephyr_dir = pathlib.Path(__file__).parent platform_ec = zephyr_dir.resolve().parent - build_dir = platform_ec / 'build' / 'zephyr' - tarball_name = 'coverage.tbz2' + build_dir = platform_ec / "build" / "zephyr" + tarball_name = "coverage.tbz2" tarball_path = bundle_dir / tarball_name - cmd = ['tar', 'cvfj', tarball_path, 'lcov.info'] + cmd = ["tar", "cvfj", tarball_path, "lcov.info"] subprocess.run(cmd, cwd=build_dir, check=True) meta = info.objects.add() meta.file_name = tarball_name @@ -149,13 +145,11 @@ def bundle_firmware(opts): for project in zmake.project.find_projects(zephyr_dir).values(): if project.config.is_test: continue - build_dir = ( - platform_ec / 'build' / 'zephyr' / project.config.project_name - ) - artifacts_dir = build_dir / 'output' - tarball_name = f'{project.config.project_name}.firmware.tbz2' + build_dir = platform_ec / "build" / "zephyr" / project.config.project_name + artifacts_dir = build_dir / "output" + tarball_name = f"{project.config.project_name}.firmware.tbz2" tarball_path = bundle_dir.joinpath(tarball_name) - cmd = ['tar', 'cvfj', tarball_path, '.'] + cmd = ["tar", "cvfj", tarball_path, "."] subprocess.run(cmd, cwd=artifacts_dir, check=True) meta = info.objects.add() meta.file_name = tarball_name @@ -176,76 +170,92 @@ def test(opts): # Run zmake tests to ensure we have a fully working zmake before # proceeding. - subprocess.run([zephyr_dir / 'zmake' / 'run_tests.sh'], check=True) + subprocess.run([zephyr_dir / "zmake" / "run_tests.sh"], check=True) # Run formatting checks on all BUILD.py files. - config_files = zephyr_dir.rglob('**/BUILD.py') - subprocess.run(['black', '--diff', '--check', *config_files], check=True) + config_files = zephyr_dir.rglob("**/BUILD.py") + subprocess.run(["black", "--diff", "--check", *config_files], check=True) - cmd = ['zmake', '-D', 'test', '-a', '--no-rebuild'] + cmd = ["zmake", "-D", "test", "-a", "--no-rebuild"] if opts.code_coverage: - cmd.append('--coverage') + cmd.append("--coverage") ret = subprocess.run(cmd, check=True).returncode if ret: return ret if opts.code_coverage: platform_ec = zephyr_dir.parent - build_dir = platform_ec / 'build' / 'zephyr' + build_dir = platform_ec / "build" / "zephyr" # Merge lcov files here because bundle failures are "infra" failures. cmd = [ - '/usr/bin/lcov', - '-o', - build_dir / 'zephyr_merged.info', - '--rc', - 'lcov_branch_coverage=1', - '-a', - build_dir / 'all_tests.info', - '-a', - build_dir / 'all_builds.info', + "/usr/bin/lcov", + "-o", + build_dir / "zephyr_merged.info", + "--rc", + "lcov_branch_coverage=1", + "-a", + build_dir / "all_tests.info", + "-a", + build_dir / "all_builds.info", ] output = subprocess.run( - cmd, cwd=pathlib.Path(__file__).parent, check=True, - stdout=subprocess.PIPE, universal_newlines=True).stdout - _extract_lcov_summary('EC_ZEPHYR_MERGED', metrics, output) + cmd, + cwd=pathlib.Path(__file__).parent, + check=True, + stdout=subprocess.PIPE, + universal_newlines=True, + ).stdout + _extract_lcov_summary("EC_ZEPHYR_MERGED", metrics, output) output = subprocess.run( - ['/usr/bin/lcov', '--summary', build_dir / 'all_tests.info'], - cwd=pathlib.Path(__file__).parent, check=True, - stdout=subprocess.PIPE, universal_newlines=True).stdout - _extract_lcov_summary('EC_ZEPHYR_TESTS', metrics, output) - - cmd = ['make', 'coverage', f'-j{opts.cpus}'] + ["/usr/bin/lcov", "--summary", build_dir / "all_tests.info"], + cwd=pathlib.Path(__file__).parent, + check=True, + stdout=subprocess.PIPE, + universal_newlines=True, + ).stdout + _extract_lcov_summary("EC_ZEPHYR_TESTS", metrics, output) + + cmd = ["make", "coverage", f"-j{opts.cpus}"] print(f"# Running {' '.join(cmd)}.") subprocess.run(cmd, cwd=platform_ec, check=True) output = subprocess.run( - ['/usr/bin/lcov', '--summary', platform_ec / 'build/coverage/lcov.info'], - cwd=pathlib.Path(__file__).parent, check=True, - stdout=subprocess.PIPE, universal_newlines=True).stdout - _extract_lcov_summary('EC_LEGACY_MERGED', metrics, output) + ["/usr/bin/lcov", "--summary", platform_ec / "build/coverage/lcov.info"], + cwd=pathlib.Path(__file__).parent, + check=True, + stdout=subprocess.PIPE, + universal_newlines=True, + ).stdout + _extract_lcov_summary("EC_LEGACY_MERGED", metrics, output) cmd = [ - '/usr/bin/lcov', - '-o', - build_dir / 'lcov.info', - '--rc', - 'lcov_branch_coverage=1', - '-a', - build_dir / 'zephyr_merged.info', - '-a', - platform_ec / 'build/coverage/lcov.info', + "/usr/bin/lcov", + "-o", + build_dir / "lcov.info", + "--rc", + "lcov_branch_coverage=1", + "-a", + build_dir / "zephyr_merged.info", + "-a", + platform_ec / "build/coverage/lcov.info", ] output = subprocess.run( - cmd, cwd=pathlib.Path(__file__).parent, check=True, - stdout=subprocess.PIPE, universal_newlines=True).stdout - _extract_lcov_summary('ALL_MERGED', metrics, output) - - with open(opts.metrics, 'w') as file: + cmd, + cwd=pathlib.Path(__file__).parent, + check=True, + stdout=subprocess.PIPE, + universal_newlines=True, + ).stdout + _extract_lcov_summary("ALL_MERGED", metrics, output) + + with open(opts.metrics, "w") as file: file.write(json_format.MessageToJson(metrics)) return 0 -COVERAGE_RE = re.compile(r'lines\.*: *([0-9\.]+)% \(([0-9]+) of ([0-9]+) lines\)') +COVERAGE_RE = re.compile(r"lines\.*: *([0-9\.]+)% \(([0-9]+) of ([0-9]+) lines\)") + + def _extract_lcov_summary(name, metrics, output): re_match = COVERAGE_RE.search(output) if re_match: @@ -255,12 +265,13 @@ def _extract_lcov_summary(name, metrics, output): metric.covered_lines = int(re_match.group(2)) metric.total_lines = int(re_match.group(3)) + def main(args): """Builds and tests all of the Zephyr targets and reports build metrics""" opts = parse_args(args) - if not hasattr(opts, 'func'): - print('Must select a valid sub command!') + if not hasattr(opts, "func"): + print("Must select a valid sub command!") return -1 # Run selected sub command function @@ -272,70 +283,66 @@ def parse_args(args): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - '--cpus', + "--cpus", default=multiprocessing.cpu_count(), - help='The number of cores to use.', + help="The number of cores to use.", ) parser.add_argument( - '--metrics', - dest='metrics', + "--metrics", + dest="metrics", required=True, - help='File to write the json-encoded MetricsList proto message.', + help="File to write the json-encoded MetricsList proto message.", ) parser.add_argument( - '--metadata', + "--metadata", required=False, help=( - 'Full pathname for the file in which to write build artifact ' - 'metadata.' + "Full pathname for the file in which to write build artifact " "metadata." ), ) parser.add_argument( - '--output-dir', + "--output-dir", required=False, - help=( - 'Full pathname for the directory in which to bundle build ' - 'artifacts.' - ), + help=("Full pathname for the directory in which to bundle build " "artifacts."), ) parser.add_argument( - '--code-coverage', + "--code-coverage", required=False, - action='store_true', - help='Build host-based unit tests for code coverage.', + action="store_true", + help="Build host-based unit tests for code coverage.", ) parser.add_argument( - '--bcs-version', - dest='bcs_version', - default='', + "--bcs-version", + dest="bcs_version", + default="", required=False, # TODO(b/180008931): make this required=True. - help='BCS version to include in metadata.', + help="BCS version to include in metadata.", ) # Would make this required=True, but not available until 3.7 sub_cmds = parser.add_subparsers() - build_cmd = sub_cmds.add_parser('build', help='Builds all firmware targets') + build_cmd = sub_cmds.add_parser("build", help="Builds all firmware targets") build_cmd.set_defaults(func=build) build_cmd = sub_cmds.add_parser( - 'bundle', - help='Creates a tarball containing build ' - 'artifacts from all firmware targets', + "bundle", + help="Creates a tarball containing build " + "artifacts from all firmware targets", ) build_cmd.set_defaults(func=bundle) - test_cmd = sub_cmds.add_parser('test', help='Runs all firmware unit tests') + test_cmd = sub_cmds.add_parser("test", help="Runs all firmware unit tests") test_cmd.set_defaults(func=test) return parser.parse_args(args) -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main(sys.argv[1:])) diff --git a/zephyr/zmake/tests/conftest.py b/zephyr/zmake/tests/conftest.py index 38e34bef56..be1de01401 100644 --- a/zephyr/zmake/tests/conftest.py +++ b/zephyr/zmake/tests/conftest.py @@ -9,7 +9,6 @@ import pathlib import hypothesis import pytest - import zmake.zmake as zm hypothesis.settings.register_profile( diff --git a/zephyr/zmake/tests/test_build_config.py b/zephyr/zmake/tests/test_build_config.py index 76cc0a2028..f79ed1f8a0 100644 --- a/zephyr/zmake/tests/test_build_config.py +++ b/zephyr/zmake/tests/test_build_config.py @@ -13,7 +13,6 @@ import tempfile import hypothesis import hypothesis.strategies as st import pytest - import zmake.jobserver import zmake.util as util from zmake.build_config import BuildConfig diff --git a/zephyr/zmake/tests/test_generate_readme.py b/zephyr/zmake/tests/test_generate_readme.py index cb4bcf6cc1..2149b3fc6e 100644 --- a/zephyr/zmake/tests/test_generate_readme.py +++ b/zephyr/zmake/tests/test_generate_readme.py @@ -7,7 +7,6 @@ Tests for the generate_readme.py file. """ import pytest - import zmake.generate_readme as gen_readme diff --git a/zephyr/zmake/tests/test_modules.py b/zephyr/zmake/tests/test_modules.py index 600544d2e7..9446e54f1c 100644 --- a/zephyr/zmake/tests/test_modules.py +++ b/zephyr/zmake/tests/test_modules.py @@ -9,7 +9,6 @@ import tempfile import hypothesis import hypothesis.strategies as st - import zmake.modules module_lists = st.lists( diff --git a/zephyr/zmake/tests/test_packers.py b/zephyr/zmake/tests/test_packers.py index 43b63a908f..402cee690e 100644 --- a/zephyr/zmake/tests/test_packers.py +++ b/zephyr/zmake/tests/test_packers.py @@ -10,7 +10,6 @@ import tempfile import hypothesis import hypothesis.strategies as st import pytest - import zmake.output_packers as packers # Strategies for use with hypothesis diff --git a/zephyr/zmake/tests/test_project.py b/zephyr/zmake/tests/test_project.py index b5facbc331..5b5ca12583 100644 --- a/zephyr/zmake/tests/test_project.py +++ b/zephyr/zmake/tests/test_project.py @@ -11,7 +11,6 @@ import tempfile import hypothesis import hypothesis.strategies as st import pytest - import zmake.modules import zmake.output_packers import zmake.project diff --git a/zephyr/zmake/tests/test_reexec.py b/zephyr/zmake/tests/test_reexec.py index 5d7905cd8f..08943909b2 100644 --- a/zephyr/zmake/tests/test_reexec.py +++ b/zephyr/zmake/tests/test_reexec.py @@ -8,7 +8,6 @@ import sys import unittest.mock as mock import pytest - import zmake.__main__ as main diff --git a/zephyr/zmake/tests/test_toolchains.py b/zephyr/zmake/tests/test_toolchains.py index 910a5faa78..f210bb7511 100644 --- a/zephyr/zmake/tests/test_toolchains.py +++ b/zephyr/zmake/tests/test_toolchains.py @@ -8,7 +8,6 @@ import os import pathlib import pytest - import zmake.output_packers import zmake.project as project import zmake.toolchains as toolchains diff --git a/zephyr/zmake/tests/test_util.py b/zephyr/zmake/tests/test_util.py index 1ec0076162..4a6c39f904 100644 --- a/zephyr/zmake/tests/test_util.py +++ b/zephyr/zmake/tests/test_util.py @@ -10,7 +10,6 @@ import tempfile import hypothesis import hypothesis.strategies as st import pytest - import zmake.util as util # Strategies for use with hypothesis diff --git a/zephyr/zmake/tests/test_zmake.py b/zephyr/zmake/tests/test_zmake.py index 4ca1d7f077..1c892ca2e4 100644 --- a/zephyr/zmake/tests/test_zmake.py +++ b/zephyr/zmake/tests/test_zmake.py @@ -11,14 +11,13 @@ import re import unittest.mock import pytest -from testfixtures import LogCapture - import zmake.build_config import zmake.jobserver import zmake.multiproc as multiproc import zmake.output_packers import zmake.project import zmake.toolchains +from testfixtures import LogCapture OUR_PATH = os.path.dirname(os.path.realpath(__file__)) |