diff options
Diffstat (limited to 'chromium/third_party/openscreen')
186 files changed, 7280 insertions, 2338 deletions
diff --git a/chromium/third_party/openscreen/src/BUILD.gn b/chromium/third_party/openscreen/src/BUILD.gn index cbd318c934f..ea5b2658f25 100644 --- a/chromium/third_party/openscreen/src/BUILD.gn +++ b/chromium/third_party/openscreen/src/BUILD.gn @@ -107,9 +107,7 @@ source_set("openscreen_unittests_all") { if (!build_with_chromium) { executable("openscreen_unittests") { testonly = true - deps = [ - ":openscreen_unittests_all", - ] + deps = [ ":openscreen_unittests_all" ] } } @@ -118,6 +116,7 @@ if (!build_with_chromium && is_posix) { testonly = true public_deps = [ "cast/common:discovery_e2e_test", + "cast/standalone_receiver:e2e_tests", "cast/test:e2e_tests", "cast/test:make_crl_tests($host_toolchain)", "test:test_main", @@ -126,8 +125,6 @@ if (!build_with_chromium && is_posix) { executable("e2e_tests") { testonly = true - deps = [ - ":e2e_tests_all", - ] + deps = [ ":e2e_tests_all" ] } } diff --git a/chromium/third_party/openscreen/src/DEPS b/chromium/third_party/openscreen/src/DEPS index 82eef6ca381..22665a78b30 100644 --- a/chromium/third_party/openscreen/src/DEPS +++ b/chromium/third_party/openscreen/src/DEPS @@ -51,6 +51,12 @@ deps = { 'condition': 'not build_with_chromium', }, + 'third_party/libprotobuf-mutator/src': { + 'url': Var('chromium_git') + + '/external/github.com/google/libprotobuf-mutator.git' + '@' '439e81f8f4847ec6e2bf11b3aa634a5d8485633d', + 'condition': 'not build_with_chromium', + }, + 'third_party/zlib/src': { 'url': Var('github') + '/madler/zlib.git' + @@ -178,6 +184,7 @@ include_rules = [ '+absl/algorithm/container.h', '+absl/base/thread_annotations.h', '+absl/hash/hash.h', + '+absl/hash/hash_testing.h', '+absl/strings/ascii.h', '+absl/strings/match.h', '+absl/strings/numbers.h', diff --git a/chromium/third_party/openscreen/src/PRESUBMIT.py b/chromium/third_party/openscreen/src/PRESUBMIT.py index fd7cca881ea..1178061c666 100755 --- a/chromium/third_party/openscreen/src/PRESUBMIT.py +++ b/chromium/third_party/openscreen/src/PRESUBMIT.py @@ -67,9 +67,15 @@ def _CommonChecks(input_api, output_api): results.extend(input_api.canned_checks.CheckChangeTodoHasOwner( input_api, output_api)) - # Linter + # Linter. + # - We disable c++11 header checks since Open Screen allows them. + # - We disable whitespace/braces because of various false positives. + # - There are some false positives with 'explicit' checks, but it's useful + # enough to keep. results.extend(input_api.canned_checks.CheckChangeLintsClean( - input_api, output_api, lint_filters = None, verbose_level=4)) + input_api, output_api, + lint_filters = ['-build/c++11', '-whitespace/braces'], + verbose_level=4)) # clang-format results.extend(input_api.canned_checks.CheckPatchFormatted( diff --git a/chromium/third_party/openscreen/src/README.md b/chromium/third_party/openscreen/src/README.md index a03d6d419ab..6e3c999b56b 100644 --- a/chromium/third_party/openscreen/src/README.md +++ b/chromium/third_party/openscreen/src/README.md @@ -208,14 +208,15 @@ target. openscreen uses [LUCI builders](https://ci.chromium.org/p/openscreen/builders) to monitor the build and test health of the library. Current builders include: -| Name | Arch | OS | Toolchain | Build | Notes | -|------------------------|--------|--------------------|-----------|---------|--------------| -| linux64_debug | x86-64 | Ubuntu Linux 16.04 | clang | debug | ASAN enabled | -| linux64_gcc_debug | x86-64 | Ubuntu Linux 18.04 | gcc-7 | debug | | -| linux64_tsan | x86-64 | Ubuntu Linux 16.04 | clang | release | TSAN enabled | -| mac_debug | x86-64 | Mac OS X/Xcode | clang | debug | | -| chromium_linux64_debug | x86-64 | Ubuntu Linux 16.04 | clang | debug | built within chromium | -| chromium_mac_debug | x86-64 | Mac OS X/Xcode | clang | debug | built within chromium | +| Name | Arch | OS | Toolchain | Build | Notes | +|------------------------|--------|--------------------|-----------|---------|------------------------| +| linux64_debug | x86-64 | Ubuntu Linux 16.04 | clang | debug | ASAN enabled | +| linux64_gcc_debug | x86-64 | Ubuntu Linux 18.04 | gcc-7 | debug | | +| linux64_tsan | x86-64 | Ubuntu Linux 16.04 | clang | release | TSAN enabled | +| mac_debug | x86-64 | Mac OS X/Xcode | clang | debug | | +| chromium_linux64_debug | x86-64 | Ubuntu Linux 16.04 | clang | debug | built within chromium | +| chromium_mac_debug | x86-64 | Mac OS X/Xcode | clang | debug | built within chromium | +| linux64_coverage_debug | x86-64 | Ubuntu Linux 16.04 | clang | debug | used for code coverage | You can run a patch through the try job queue (which tests it on all non-chromium builders) using `git cl try`, or through Gerrit (details below). diff --git a/chromium/third_party/openscreen/src/build/code_coverage/merge_lib.py b/chromium/third_party/openscreen/src/build/code_coverage/merge_lib.py index 818044c39a3..4b956d06a33 100644 --- a/chromium/third_party/openscreen/src/build/code_coverage/merge_lib.py +++ b/chromium/third_party/openscreen/src/build/code_coverage/merge_lib.py @@ -29,7 +29,7 @@ logging.basicConfig( def _call_profdata_tool(profile_input_file_paths, profile_output_file_path, profdata_tool_path, - retries=3): + sparse=True): """Calls the llvm-profdata tool. Args: @@ -37,6 +37,8 @@ def _call_profdata_tool(profile_input_file_paths, are to be merged. profile_output_file_path: The path to the merged file to write. profdata_tool_path: The path to the llvm-profdata executable. + sparse (bool): flag to indicate whether to run llvm-profdata with --sparse. + Doc: https://llvm.org/docs/CommandGuide/llvm-profdata.html#profdata-merge Returns: A list of paths to profiles that had to be excluded to get the merge to @@ -45,12 +47,16 @@ def _call_profdata_tool(profile_input_file_paths, Raises: CalledProcessError: An error occurred merging profiles. """ + logging.debug('Profile input paths: %r' % profile_input_file_paths) + logging.debug('Profile output path: %r' % profile_output_file_path) try: subprocess_cmd = [ profdata_tool_path, 'merge', '-o', profile_output_file_path, - '-sparse=true' ] + if sparse: + subprocess_cmd += ['-sparse=true',] subprocess_cmd.extend(profile_input_file_paths) + logging.info('profdata command: %r', ' '.join(subprocess_cmd)) # Redirecting stderr is required because when error happens, llvm-profdata # writes the error output to stderr and our error handling logic relies on @@ -58,34 +64,6 @@ def _call_profdata_tool(profile_input_file_paths, output = subprocess.check_output(subprocess_cmd, stderr=subprocess.STDOUT) logging.info('Merge succeeded with output: %r', output) except subprocess.CalledProcessError as error: - if len(profile_input_file_paths) > 1 and retries >= 0: - logging.warning('Merge failed with error output: %r', error.output) - - # The output of the llvm-profdata command will include the path of - # malformed files, such as - # `error: /.../default.profraw: Malformed instrumentation profile data` - invalid_profiles = [ - f for f in profile_input_file_paths if f in error.output - ] - - if not invalid_profiles: - logging.info( - 'Merge failed, but wasn\'t able to figure out the culprit invalid ' - 'profiles from the output, so skip retry and bail out.') - raise error - - valid_profiles = list( - set(profile_input_file_paths) - set(invalid_profiles)) - if valid_profiles: - logging.warning( - 'Following invalid profiles are removed as they were mentioned in ' - 'the merge error output: %r', invalid_profiles) - logging.info('Retry merging with the remaining profiles: %r', - valid_profiles) - return invalid_profiles + _call_profdata_tool( - valid_profiles, profile_output_file_path, profdata_tool_path, - retries - 1) - logging.error('Failed to merge profiles, return code (%d), output: %r' % (error.returncode, error.output)) raise error @@ -108,7 +86,9 @@ def _get_profile_paths(input_dir, return paths -def _validate_and_convert_profraws(profraw_files, profdata_tool_path): +def _validate_and_convert_profraws(profraw_files, + profdata_tool_path, + sparse=True): """Validates and converts profraws to profdatas. For each given .profraw file in the input, this method first validates it by @@ -121,6 +101,8 @@ def _validate_and_convert_profraws(profraw_files, profdata_tool_path): Args: profraw_files: A list of .profraw paths. profdata_tool_path: The path to the llvm-profdata executable. + sparse (bool): flag to indicate whether to run llvm-profdata with --sparse. + Doc: https://llvm.org/docs/CommandGuide/llvm-profdata.html#profdata-merge Returns: A tulple: @@ -144,7 +126,7 @@ def _validate_and_convert_profraws(profraw_files, profdata_tool_path): pool.apply_async( _validate_and_convert_profraw, (profraw_file, output_profdata_files, invalid_profraw_files, - counter_overflows, profdata_tool_path)) + counter_overflows, profdata_tool_path, sparse)) pool.close() pool.join() @@ -159,16 +141,25 @@ def _validate_and_convert_profraws(profraw_files, profdata_tool_path): def _validate_and_convert_profraw(profraw_file, output_profdata_files, invalid_profraw_files, counter_overflows, - profdata_tool_path): + profdata_tool_path, sparse=True): output_profdata_file = profraw_file.replace('.profraw', '.profdata') subprocess_cmd = [ - profdata_tool_path, 'merge', '-o', output_profdata_file, '-sparse=true', - profraw_file + profdata_tool_path, + 'merge', + '-o', + output_profdata_file, ] + if sparse: + subprocess_cmd.append('--sparse') + + subprocess_cmd.append(profraw_file) + profile_valid = False counter_overflow = False validation_output = None + logging.info('profdata command: %r', ' '.join(subprocess_cmd)) + # 1. Determine if the profile is valid. try: # Redirecting stderr is required because when error happens, llvm-profdata @@ -237,7 +228,9 @@ def merge_profiles(input_dir, output_file, input_extension, profdata_tool_path, - input_filename_pattern='.*'): + input_filename_pattern='.*', + sparse=True, + skip_validation=False): """Merges the profiles produced by the shards using llvm-profdata. Args: @@ -248,6 +241,11 @@ def merge_profiles(input_dir, profdata_tool_path: The path to the llvm-profdata executable. input_filename_pattern (str): The regex pattern of input filename. Should be a valid regex pattern if present. + sparse (bool): flag to indicate whether to run llvm-profdata with --sparse. + Doc: https://llvm.org/docs/CommandGuide/llvm-profdata.html#profdata-merge + skip_validation (bool): flag to skip the _validate_and_convert_profraws + invocation. only applicable when input_extension is .profraw. + Returns: The list of profiles that had to be excluded to get the merge to succeed and a list of profiles that had a counter overflow. @@ -257,10 +255,16 @@ def merge_profiles(input_dir, input_filename_pattern) invalid_profraw_files = [] counter_overflows = [] - if input_extension == '.profraw': + + if skip_validation: + logging.warning('--skip-validation has been enabled. Skipping conversion ' + 'to ensure that profiles are valid.') + + if input_extension == '.profraw' and not skip_validation: profile_input_file_paths, invalid_profraw_files, counter_overflows = ( _validate_and_convert_profraws(profile_input_file_paths, - profdata_tool_path)) + profdata_tool_path, + sparse=sparse)) logging.info('List of converted .profdata files: %r', profile_input_file_paths) logging.info(( @@ -284,7 +288,8 @@ def merge_profiles(input_dir, invalid_profdata_files = _call_profdata_tool( profile_input_file_paths=profile_input_file_paths, profile_output_file_path=output_file, - profdata_tool_path=profdata_tool_path) + profdata_tool_path=profdata_tool_path, + sparse=sparse) # Remove inputs when merging profraws as they won't be needed and they can be # pretty large. If the inputs are profdata files, do not remove them as they @@ -320,3 +325,4 @@ def get_shards_to_retry(bad_profiles): assert is_task_id(task_id) bad_shard_ids.add(task_id) return bad_shard_ids + diff --git a/chromium/third_party/openscreen/src/build/code_coverage/merge_results.py b/chromium/third_party/openscreen/src/build/code_coverage/merge_results.py index a049a90639b..67e63365a04 100644 --- a/chromium/third_party/openscreen/src/build/code_coverage/merge_results.py +++ b/chromium/third_party/openscreen/src/build/code_coverage/merge_results.py @@ -35,6 +35,32 @@ def _MergeAPIArgumentParser(*args, **kwargs): '--per-cl-coverage', action='store_true', help='set to indicate that this is a per-CL coverage build') + # TODO(crbug.com/1077304) - migrate this to sparse=False as default, and have + # --sparse to set sparse + parser.add_argument( + '--no-sparse', + action='store_false', + dest='sparse', + help='run llvm-profdata without the sparse flag.') + # TODO(crbug.com/1077304) - The intended behaviour is to default sparse to + # false. --no-sparse above was added as a workaround, and will be removed. + # This is being introduced now in support of the migration to intended + # behavior. Ordering of args matters here, as the default is set by the former + # (sparse defaults to False because of ordering. See unit tests for details) + parser.add_argument( + '--sparse', + action='store_true', + dest='sparse', + help='run llvm-profdata with the sparse flag.') + # (crbug.com/1091310) - IR PGO is incompatible with the initial conversion + # of .profraw -> .profdata that's run to detect validation errors. + # Introducing a bypass flag that'll merge all .profraw directly to .profdata + parser.add_argument( + '--skip-validation', + action='store_true', + help='skip validation for good raw profile data. this will pass all ' + 'raw profiles found to llvm-profdata to be merged. only applicable ' + 'when input extension is .profraw.') return parser @@ -47,7 +73,9 @@ def main(): invalid_profiles, counter_overflows = coverage_merger.merge_profiles( params.task_output_dir, os.path.join(params.profdata_dir, output_prodata_filename), '.profraw', - params.llvm_profdata) + params.llvm_profdata, + sparse=params.sparse, + skip_validation=params.skip_validation) # At the moment counter overflows overlap with invalid profiles, but this is # not guaranteed to remain the case indefinitely. To avoid future conflicts diff --git a/chromium/third_party/openscreen/src/build/code_coverage/merge_steps.py b/chromium/third_party/openscreen/src/build/code_coverage/merge_steps.py index 40d0e94639d..f11409389f2 100644 --- a/chromium/third_party/openscreen/src/build/code_coverage/merge_steps.py +++ b/chromium/third_party/openscreen/src/build/code_coverage/merge_steps.py @@ -28,6 +28,24 @@ def _merge_steps_argument_parser(*args, **kwargs): default='.*', help='regex pattern of profdata filename to merge for current test type. ' 'If not present, all profdata files will be merged.') + # TODO(crbug.com/1077304) - migrate this to sparse=False as default, and have + # --sparse to set sparse + parser.add_argument( + '--no-sparse', + action='store_false', + dest='sparse', + help='run llvm-profdata without the sparse flag.') + # TODO(crbug.com/1077304) - The intended behaviour is to default sparse to + # false. --no-sparse above was added as a workaround, and will be removed. + # This is being introduced now in support of the migration to intended + # behavior. Ordering of args matters here, as the default is set by the former + # (sparse defaults to False because of ordering. See merge_results unit tests + # for details) + parser.add_argument( + '--sparse', + action='store_true', + dest='sparse', + help='run llvm-profdata with the sparse flag.') return parser @@ -36,7 +54,8 @@ def main(): parser = _merge_steps_argument_parser(description=desc) params = parser.parse_args() merger.merge_profiles(params.input_dir, params.output_file, '.profdata', - params.llvm_profdata, params.profdata_filename_pattern) + params.llvm_profdata, params.profdata_filename_pattern, + sparse=params.sparse) if __name__ == '__main__': diff --git a/chromium/third_party/openscreen/src/build/config/BUILD.gn b/chromium/third_party/openscreen/src/build/config/BUILD.gn index a153e5b1c0b..a68031e88c5 100644 --- a/chromium/third_party/openscreen/src/build/config/BUILD.gn +++ b/chromium/third_party/openscreen/src/build/config/BUILD.gn @@ -16,9 +16,6 @@ declare_args() { # Enable thread sanitizer. is_tsan = false - # Must be enabled for fuzzing targets. - use_libfuzzer = false - # Enables clang's source-based coverage (requires is_clang=true). # NOTE: This will slow down the build and increase binary size # significantly. For more details, see: diff --git a/chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn b/chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn index 4e69143f537..0fa9693520d 100644 --- a/chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn +++ b/chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn @@ -75,6 +75,9 @@ declare_args() { # further explanation, see # https://gn.googlesource.com/gn/+/refs/heads/master/docs/reference.md#toolchain-overview host_toolchain = "" + + # Must be enabled for fuzzing targets. + use_libfuzzer = false } declare_args() { diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.cc b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.cc index 63457298716..569d22b1379 100644 --- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.cc +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.cc @@ -10,6 +10,7 @@ #include <openssl/x509v3.h> #include <time.h> +#include <string> #include <vector> #include "cast/common/certificate/types.h" @@ -381,11 +382,15 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, result_path->intermediate_certs; target_cert.reset(ParseX509Der(der_certs[0])); if (!target_cert) { + OSP_DVLOG << "FindCertificatePath: Invalid target certificate"; return Error::Code::kErrCertsParse; } for (size_t i = 1; i < der_certs.size(); ++i) { intermediate_certs.emplace_back(ParseX509Der(der_certs[i])); if (!intermediate_certs.back()) { + OSP_DVLOG + << "FindCertificatePath: Failed to parse intermediate certificate " + << i << " of " << der_certs.size(); return Error::Code::kErrCertsParse; } } @@ -393,10 +398,12 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, // Basic checks on the target certificate. Error::Code error = VerifyCertTime(target_cert.get(), time); if (error != Error::Code::kNone) { + OSP_DVLOG << "FindCertificatePath: Failed to verify certificate time"; return error; } bssl::UniquePtr<EVP_PKEY> public_key{X509_get_pubkey(target_cert.get())}; if (!VerifyPublicKeyLength(public_key.get())) { + OSP_DVLOG << "FindCertificatePath: Failed with invalid public key length"; return Error::Code::kErrCertsVerifyGeneric; } if (X509_ALGOR_cmp(target_cert.get()->sig_alg, @@ -405,11 +412,13 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, } bssl::UniquePtr<ASN1_BIT_STRING> key_usage = GetKeyUsage(target_cert.get()); if (!key_usage) { + OSP_DVLOG << "FindCertificatePath: Failed with no key usage"; return Error::Code::kErrCertsRestrictions; } int bit = ASN1_BIT_STRING_get_bit(key_usage.get(), KeyUsageBits::kDigitalSignature); if (bit == 0) { + OSP_DVLOG << "FindCertificatePath: Failed to get digital signature"; return Error::Code::kErrCertsRestrictions; } @@ -443,6 +452,8 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, Error::Code last_error = Error::Code::kNone; for (;;) { X509_NAME* target_issuer_name = X509_get_issuer_name(path_head); + OSP_DVLOG << "FindCertificatePath: Target certificate issuer name: " + << X509_NAME_oneline(target_issuer_name, 0, 0); // The next issuer certificate to add to the current path. X509* next_issuer = nullptr; @@ -451,6 +462,8 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, X509* trust_store_cert = trust_store->certs[i].get(); X509_NAME* trust_store_cert_name = X509_get_subject_name(trust_store_cert); + OSP_DVLOG << "FindCertificatePath: Trust store certificate issuer name: " + << X509_NAME_oneline(trust_store_cert_name, 0, 0); if (X509_NAME_cmp(trust_store_cert_name, target_issuer_name) == 0) { CertPathStep& next_step = path[--path_index]; next_step.cert = trust_store_cert; @@ -485,6 +498,8 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, if (path_index == first_index) { // There are no more paths to try. Ensure an error is returned. if (last_error == Error::Code::kNone) { + OSP_DVLOG << "FindCertificatePath: Failed after trying all " + "certificate paths, no matches"; return Error::Code::kErrCertsVerifyGeneric; } return last_error; @@ -515,6 +530,8 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, result_path->path.push_back(path[i].cert); } + OSP_DVLOG + << "FindCertificatePath: Succeeded at validating receiver certificates"; return Error::Code::kNone; } diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.h b/chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.h index 8aac9d3905b..801d9274a74 100644 --- a/chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.h +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_trust_store.h @@ -21,7 +21,7 @@ class CastTrustStore { const std::vector<uint8_t>& trust_anchor_der); CastTrustStore(); - CastTrustStore(const std::vector<uint8_t>& trust_anchor_der); + explicit CastTrustStore(const std::vector<uint8_t>& trust_anchor_der); CastTrustStore(const CastTrustStore&) = delete; ~CastTrustStore(); CastTrustStore& operator=(const CastTrustStore&) = delete; diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/proto/BUILD.gn b/chromium/third_party/openscreen/src/cast/common/certificate/proto/BUILD.gn index d171715dbca..629617a1381 100644 --- a/chromium/third_party/openscreen/src/cast/common/certificate/proto/BUILD.gn +++ b/chromium/third_party/openscreen/src/cast/common/certificate/proto/BUILD.gn @@ -3,17 +3,14 @@ # found in the LICENSE file. import("//build_overrides/build.gni") +import("//third_party/libprotobuf-mutator/fuzzable_proto_library.gni") import("//third_party/protobuf/proto_library.gni") -proto_library("certificate_proto") { - sources = [ - "revocation.proto", - ] +fuzzable_proto_library("certificate_proto") { + sources = [ "revocation.proto" ] } proto_library("certificate_unittest_proto") { testonly = true - sources = [ - "test_suite.proto", - ] + sources = [ "test_suite.proto" ] } diff --git a/chromium/third_party/openscreen/src/cast/common/channel/proto/BUILD.gn b/chromium/third_party/openscreen/src/cast/common/channel/proto/BUILD.gn index 13ee38af7bf..152cb450ea4 100644 --- a/chromium/third_party/openscreen/src/cast/common/channel/proto/BUILD.gn +++ b/chromium/third_party/openscreen/src/cast/common/channel/proto/BUILD.gn @@ -3,9 +3,10 @@ # found in the LICENSE file. import("//build_overrides/build.gni") +import("//third_party/libprotobuf-mutator/fuzzable_proto_library.gni") import("//third_party/protobuf/proto_library.gni") -proto_library("channel_proto") { +fuzzable_proto_library("channel_proto") { sources = [ "authority_keys.proto", "cast_channel.proto", diff --git a/chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/tests.cc b/chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/tests.cc index 80216f26a8e..e05d7eb8d78 100644 --- a/chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/tests.cc +++ b/chromium/third_party/openscreen/src/cast/common/discovery/e2e_test/tests.cc @@ -20,28 +20,29 @@ #include "platform/impl/network_interface.h" #include "platform/impl/platform_client_posix.h" #include "platform/impl/task_runner.h" +#include "util/chrono_helpers.h" namespace openscreen { namespace cast { namespace { +// Maximum amount of time needed for a query to be received. +constexpr seconds kMaxQueryDuration{3}; + // Total wait time = 4 seconds. -constexpr std::chrono::milliseconds kWaitLoopSleepTime = - std::chrono::milliseconds(500); +constexpr milliseconds kWaitLoopSleepTime(500); constexpr int kMaxWaitLoopIterations = 8; // Total wait time = 2.5 seconds. // NOTE: This must be less than the above wait time. -constexpr std::chrono::milliseconds kCheckLoopSleepTime = - std::chrono::milliseconds(100); +constexpr milliseconds kCheckLoopSleepTime(100); constexpr int kMaxCheckLoopIterations = 25; -} // namespace // Publishes new service instances. class Publisher : public discovery::DnsSdServicePublisher<ServiceInfo> { public: - Publisher(discovery::DnsSdService* service) + explicit Publisher(discovery::DnsSdService* service) // NOLINT : DnsSdServicePublisher<ServiceInfo>(service, kCastV2ServiceId, ServiceInfoToDnsSdInstance) { @@ -66,9 +67,9 @@ class Publisher : public discovery::DnsSdServicePublisher<ServiceInfo> { }; // Receives incoming services and outputs their results to stdout. -class Receiver : public discovery::DnsSdServiceWatcher<ServiceInfo> { +class ServiceReceiver : public discovery::DnsSdServiceWatcher<ServiceInfo> { public: - Receiver(discovery::DnsSdService* service) + explicit ServiceReceiver(discovery::DnsSdService* service) // NOLINT : discovery::DnsSdServiceWatcher<ServiceInfo>( service, kCastV2ServiceId, @@ -77,7 +78,7 @@ class Receiver : public discovery::DnsSdServiceWatcher<ServiceInfo> { std::vector<std::reference_wrapper<const ServiceInfo>> infos) { ProcessResults(std::move(infos)); }) { - OSP_LOG << "Initializing Receiver..."; + OSP_LOG << "Initializing ServiceReceiver..."; } bool IsServiceFound(const ServiceInfo& check_service) { @@ -133,9 +134,9 @@ class DiscoveryE2ETest : public testing::Test { public: DiscoveryE2ETest() { // Sleep to let any packets clear off the network before further tests. - std::this_thread::sleep_for(std::chrono::milliseconds(500)); + std::this_thread::sleep_for(milliseconds(500)); - PlatformClientPosix::Create(Clock::duration{50}, Clock::duration{50}); + PlatformClientPosix::Create(milliseconds(50), milliseconds(50)); task_runner_ = PlatformClientPosix::GetInstance()->GetTaskRunner(); } @@ -161,7 +162,7 @@ class DiscoveryE2ETest : public testing::Test { task_runner_->PostTask([this, &config, &done]() { dnssd_service_ = discovery::CreateDnsSdService( task_runner_, &reporting_client_, config); - receiver_ = std::make_unique<Receiver>(dnssd_service_.get()); + receiver_ = std::make_unique<ServiceReceiver>(dnssd_service_.get()); publisher_ = std::make_unique<Publisher>(dnssd_service_.get()); done = true; }); @@ -253,6 +254,9 @@ class DiscoveryE2ETest : public testing::Test { }); } + // TODO(issuetracker.google.com/159256503): Change this to use a polling + // method to wait until the service disappears rather than immediately failing + // if it exists, so waits throughout this file can be removed. void CheckNotPublishedService(ServiceInfo service_info, std::atomic_bool* has_been_seen) { OSP_DCHECK(dnssd_service_.get()); @@ -264,7 +268,7 @@ class DiscoveryE2ETest : public testing::Test { TaskRunner* task_runner_; FailOnErrorReporting reporting_client_; SerialDeletePtr<discovery::DnsSdService> dnssd_service_; - std::unique_ptr<Receiver> receiver_; + std::unique_ptr<ServiceReceiver> receiver_; std::unique_ptr<Publisher> publisher_; private: @@ -423,7 +427,7 @@ TEST_F(DiscoveryE2ETest, ValidateAnnouncementFlow) { ASSERT_FALSE(result.is_error()); ASSERT_EQ(result.value(), 3); }); - std::this_thread::sleep_for(std::chrono::seconds(3)); + std::this_thread::sleep_for(seconds(3)); found1 = false; found2 = false; found3 = false; @@ -499,8 +503,9 @@ TEST_F(DiscoveryE2ETest, ValidateRecordsOnlyReceivedWhenQueryRunning) { CheckNotPublishedService(instance, &found); WaitUntilSeen(false, &found); - // Restart discovery and ensure that only the updated record is returned. StartDiscovery(); + std::this_thread::sleep_for(kMaxQueryDuration); + OSP_LOG << "Service discovery in progress..."; found = false; CheckNotPublishedService(updated_instance, &found); @@ -525,7 +530,6 @@ TEST_F(DiscoveryE2ETest, ValidateRefreshFlow) { auto discovery_config = GetConfigSettings(); discovery_config.new_record_announcement_count = 0; discovery_config.new_query_announcement_count = 2; - constexpr std::chrono::seconds kMaxQueryDuration{3}; SetUpService(discovery_config); auto instance = GetInfo(1); @@ -578,5 +582,6 @@ TEST_F(DiscoveryE2ETest, ValidateRefreshFlow) { WaitUntilSeen(true, &found); } +} // namespace } // namespace cast } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/common/public/cast_socket.h b/chromium/third_party/openscreen/src/cast/common/public/cast_socket.h index 0fd065d60a9..d7ac683ffb1 100644 --- a/chromium/third_party/openscreen/src/cast/common/public/cast_socket.h +++ b/chromium/third_party/openscreen/src/cast/common/public/cast_socket.h @@ -10,6 +10,7 @@ #include <vector> #include "platform/api/tls_connection.h" +#include "util/weak_ptr.h" namespace cast { namespace channel { @@ -58,6 +59,8 @@ class CastSocket : public TlsConnection::Client { void OnError(TlsConnection* connection, Error error) override; void OnRead(TlsConnection* connection, std::vector<uint8_t> block) override; + WeakPtr<CastSocket> GetWeakPtr() const { return weak_factory_.GetWeakPtr(); } + private: enum class State : bool { kOpen = true, @@ -72,6 +75,8 @@ class CastSocket : public TlsConnection::Client { bool audio_only_ = false; std::vector<uint8_t> read_buffer_; State state_ = State::kOpen; + + WeakPtrFactory<CastSocket> weak_factory_{this}; }; } // namespace cast diff --git a/chromium/third_party/openscreen/src/cast/common/public/service_info.cc b/chromium/third_party/openscreen/src/cast/common/public/service_info.cc index 852ecca9be2..b6d6f623819 100644 --- a/chromium/third_party/openscreen/src/cast/common/public/service_info.cc +++ b/chromium/third_party/openscreen/src/cast/common/public/service_info.cc @@ -196,9 +196,15 @@ ErrorOr<ServiceInfo> DnsSdInstanceEndpointToServiceInfo( } ServiceInfo record; - record.v4_address = endpoint.address_v4(); - record.v6_address = endpoint.address_v6(); record.port = endpoint.port(); + for (const IPAddress& address : endpoint.addresses()) { + if (!record.v4_address && address.IsV4()) { + record.v4_address = address; + } else if (!record.v6_address && address.IsV6()) { + record.v6_address = address; + } + } + OSP_DCHECK(record.v4_address || record.v6_address); const auto& txt = endpoint.txt(); std::string capabilities_base64; diff --git a/chromium/third_party/openscreen/src/cast/common/public/service_info.h b/chromium/third_party/openscreen/src/cast/common/public/service_info.h index a3f13a208fb..944a68ed73b 100644 --- a/chromium/third_party/openscreen/src/cast/common/public/service_info.h +++ b/chromium/third_party/openscreen/src/cast/common/public/service_info.h @@ -7,6 +7,7 @@ #include <memory> #include <string> +#include <utility> #include "discovery/dnssd/public/dns_sd_instance.h" #include "discovery/dnssd/public/dns_sd_instance_endpoint.h" diff --git a/chromium/third_party/openscreen/src/cast/common/public/service_info_unittest.cc b/chromium/third_party/openscreen/src/cast/common/public/service_info_unittest.cc index a459cafd2e4..2f53ef97ca1 100644 --- a/chromium/third_party/openscreen/src/cast/common/public/service_info_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/common/public/service_info_unittest.cc @@ -11,13 +11,18 @@ namespace openscreen { namespace cast { +namespace { + +constexpr NetworkInterfaceIndex kNetworkInterface = 0; + +} TEST(ServiceInfoTests, ConvertValidFromDnsSd) { std::string instance = "InstanceId"; discovery::DnsSdTxtRecord txt = CreateValidTxt(); - discovery::DnsSdInstanceEndpoint record(instance, kCastV2ServiceId, - kCastV2DomainId, txt, kEndpointV4, - kEndpointV6, 0); + discovery::DnsSdInstanceEndpoint record( + instance, kCastV2ServiceId, kCastV2DomainId, txt, kNetworkInterface, + kEndpointV4, kEndpointV6); ErrorOr<ServiceInfo> info = DnsSdInstanceEndpointToServiceInfo(record); ASSERT_TRUE(info.is_value()); EXPECT_EQ(info.value().unique_id, kTestUniqueId); @@ -33,9 +38,9 @@ TEST(ServiceInfoTests, ConvertValidFromDnsSd) { EXPECT_EQ(info.value().model_name, kModelName); EXPECT_EQ(info.value().friendly_name, kFriendlyName); - record = discovery::DnsSdInstanceEndpoint( - instance, kCastV2ServiceId, kCastV2DomainId, txt, kEndpointV4, 0); - ASSERT_FALSE(record.address_v6()); + record = discovery::DnsSdInstanceEndpoint(instance, kCastV2ServiceId, + kCastV2DomainId, txt, + kNetworkInterface, kEndpointV4); info = DnsSdInstanceEndpointToServiceInfo(record); ASSERT_TRUE(info.is_value()); EXPECT_EQ(info.value().unique_id, kTestUniqueId); @@ -49,9 +54,9 @@ TEST(ServiceInfoTests, ConvertValidFromDnsSd) { EXPECT_EQ(info.value().model_name, kModelName); EXPECT_EQ(info.value().friendly_name, kFriendlyName); - record = discovery::DnsSdInstanceEndpoint( - instance, kCastV2ServiceId, kCastV2DomainId, txt, kEndpointV6, 0); - ASSERT_FALSE(record.address_v4()); + record = discovery::DnsSdInstanceEndpoint(instance, kCastV2ServiceId, + kCastV2DomainId, txt, + kNetworkInterface, kEndpointV6); info = DnsSdInstanceEndpointToServiceInfo(record); ASSERT_TRUE(info.is_value()); EXPECT_EQ(info.value().unique_id, kTestUniqueId); @@ -70,44 +75,44 @@ TEST(ServiceInfoTests, ConvertInvalidFromDnsSd) { std::string instance = "InstanceId"; discovery::DnsSdTxtRecord txt = CreateValidTxt(); txt.ClearValue(kUniqueIdKey); - discovery::DnsSdInstanceEndpoint record(instance, kCastV2ServiceId, - kCastV2DomainId, txt, kEndpointV4, - kEndpointV6, 0); + discovery::DnsSdInstanceEndpoint record( + instance, kCastV2ServiceId, kCastV2DomainId, txt, kNetworkInterface, + kEndpointV4, kEndpointV6); EXPECT_TRUE(DnsSdInstanceEndpointToServiceInfo(record).is_error()); txt = CreateValidTxt(); txt.ClearValue(kVersionId); - record = discovery::DnsSdInstanceEndpoint(instance, kCastV2ServiceId, - kCastV2DomainId, txt, kEndpointV4, - kEndpointV6, 0); + record = discovery::DnsSdInstanceEndpoint( + instance, kCastV2ServiceId, kCastV2DomainId, txt, kNetworkInterface, + kEndpointV4, kEndpointV6); EXPECT_TRUE(DnsSdInstanceEndpointToServiceInfo(record).is_error()); txt = CreateValidTxt(); txt.ClearValue(kCapabilitiesId); - record = discovery::DnsSdInstanceEndpoint(instance, kCastV2ServiceId, - kCastV2DomainId, txt, kEndpointV4, - kEndpointV6, 0); + record = discovery::DnsSdInstanceEndpoint( + instance, kCastV2ServiceId, kCastV2DomainId, txt, kNetworkInterface, + kEndpointV4, kEndpointV6); EXPECT_TRUE(DnsSdInstanceEndpointToServiceInfo(record).is_error()); txt = CreateValidTxt(); txt.ClearValue(kStatusId); - record = discovery::DnsSdInstanceEndpoint(instance, kCastV2ServiceId, - kCastV2DomainId, txt, kEndpointV4, - kEndpointV6, 0); + record = discovery::DnsSdInstanceEndpoint( + instance, kCastV2ServiceId, kCastV2DomainId, txt, kNetworkInterface, + kEndpointV4, kEndpointV6); EXPECT_TRUE(DnsSdInstanceEndpointToServiceInfo(record).is_error()); txt = CreateValidTxt(); txt.ClearValue(kFriendlyNameId); - record = discovery::DnsSdInstanceEndpoint(instance, kCastV2ServiceId, - kCastV2DomainId, txt, kEndpointV4, - kEndpointV6, 0); + record = discovery::DnsSdInstanceEndpoint( + instance, kCastV2ServiceId, kCastV2DomainId, txt, kNetworkInterface, + kEndpointV4, kEndpointV6); EXPECT_TRUE(DnsSdInstanceEndpointToServiceInfo(record).is_error()); txt = CreateValidTxt(); txt.ClearValue(kModelNameId); - record = discovery::DnsSdInstanceEndpoint(instance, kCastV2ServiceId, - kCastV2DomainId, txt, kEndpointV4, - kEndpointV6, 0); + record = discovery::DnsSdInstanceEndpoint( + instance, kCastV2ServiceId, kCastV2DomainId, txt, kNetworkInterface, + kEndpointV4, kEndpointV6); EXPECT_TRUE(DnsSdInstanceEndpointToServiceInfo(record).is_error()); } diff --git a/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.h b/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.h index 4ca6df5e17c..d23ebd7e00d 100644 --- a/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.h +++ b/chromium/third_party/openscreen/src/cast/sender/channel/cast_auth_util.h @@ -7,7 +7,7 @@ #include <openssl/x509.h> -#include <chrono> // NOLINT +#include <chrono> #include <string> #include <vector> diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/BUILD.gn b/chromium/third_party/openscreen/src/cast/standalone_receiver/BUILD.gn index ba54d33f414..46ab0cad3c7 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/BUILD.gn +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/BUILD.gn @@ -9,16 +9,18 @@ import("//build_overrides/build.gni") # standalone platform implementation; since this is itself a standalone # application. if (!build_with_chromium) { - executable("cast_receiver") { + source_set("standalone_receiver") { sources = [ "cast_agent.cc", "cast_agent.h", "cast_socket_message_port.cc", "cast_socket_message_port.h", - "main.cc", + "static_credentials.cc", + "static_credentials.h", "streaming_playback_controller.cc", "streaming_playback_controller.h", ] + deps = [ "../../platform", "../../third_party/jsoncpp", @@ -57,4 +59,25 @@ if (!build_with_chromium) { ] } } + + source_set("e2e_tests") { + testonly = true + sources = [ "cast_agent_integration_tests.cc" ] + + deps = [ + ":standalone_receiver", + "../../third_party/boringssl", + "../../third_party/googletest:gtest", + "../receiver:channel", + ] + } + + executable("cast_receiver") { + sources = [ "main.cc" ] + + deps = [ + ":standalone_receiver", + "../receiver:channel", + ] + } } diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.cc index dfccd351886..ad22f2d6538 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.cc @@ -11,13 +11,12 @@ #include <vector> #include "absl/strings/str_cat.h" +#include "cast/common/channel/message_util.h" #include "cast/standalone_receiver/cast_socket_message_port.h" -#include "cast/standalone_receiver/private_key_der.h" #include "cast/streaming/constants.h" #include "cast/streaming/offer_messages.h" #include "platform/base/tls_credentials.h" #include "platform/base/tls_listen_options.h" -#include "util/crypto/certificate_utils.h" #include "util/json/json_serialization.h" #include "util/osp_logging.h" @@ -28,100 +27,77 @@ namespace { constexpr int kDefaultMaxBacklogSize = 64; const TlsListenOptions kDefaultListenOptions{kDefaultMaxBacklogSize}; -constexpr int kThreeDaysInSeconds = 3 * 24 * 60 * 60; -constexpr auto kCertificateDuration = std::chrono::seconds(kThreeDaysInSeconds); - -// Generates a valid set of credentials for use with the TLS Server socket, -// including a generated X509 certificate generated from the static private key -// stored in private_key_der.h. The certificate is valid for -// kCertificateDuration from when this function is called. -ErrorOr<TlsCredentials> CreateCredentials(const IPEndpoint& endpoint) { - ErrorOr<bssl::UniquePtr<EVP_PKEY>> private_key = - ImportRSAPrivateKey(kPrivateKeyDer.data(), kPrivateKeyDer.size()); - OSP_CHECK(private_key); - - ErrorOr<bssl::UniquePtr<X509>> cert = CreateSelfSignedX509Certificate( - endpoint.ToString(), kCertificateDuration, *private_key.value()); - if (!cert) { - return cert.error(); - } - - auto cert_bytes = ExportX509CertificateToDer(*cert.value()); - if (!cert_bytes) { - return cert_bytes.error(); - } - - // TODO(jophba): either refactor the TLS server socket to use the public key - // and add a valid key here, or remove from the TlsCredentials struct. - return TlsCredentials( - std::vector<uint8_t>(kPrivateKeyDer.begin(), kPrivateKeyDer.end()), - std::vector<uint8_t>{}, std::move(cert_bytes.value())); -} - } // namespace -CastAgent::CastAgent(TaskRunner* task_runner, InterfaceInfo interface) - : task_runner_(task_runner) { - // Create the Environment that holds the required injected dependencies - // (clock, task runner) used throughout the system, and owns the UDP socket - // over which all communication occurs with the Sender. - IPEndpoint receive_endpoint{IPAddress::kV4LoopbackAddress, kDefaultCastPort}; - receive_endpoint.address = interface.GetIpAddressV4() - ? interface.GetIpAddressV4() - : interface.GetIpAddressV6(); - OSP_DCHECK(receive_endpoint.address); - environment_ = std::make_unique<Environment>(&Clock::now, task_runner_, - receive_endpoint); - receive_endpoint_ = std::move(receive_endpoint); +CastAgent::CastAgent( + TaskRunner* task_runner, + InterfaceInfo interface, + DeviceAuthNamespaceHandler::CredentialsProvider* credentials_provider, + TlsCredentials tls_credentials) + : task_runner_(task_runner), + credentials_provider_(credentials_provider), + tls_credentials_(tls_credentials) { + const IPAddress address = interface.GetIpAddressV4() + ? interface.GetIpAddressV4() + : interface.GetIpAddressV6(); + OSP_CHECK(address); + environment_ = std::make_unique<Environment>( + &Clock::now, task_runner_, + IPEndpoint{address, kDefaultCastStreamingPort}); + receive_endpoint_ = IPEndpoint{address, kDefaultCastPort}; } CastAgent::~CastAgent() = default; Error CastAgent::Start() { - OSP_DCHECK(!current_session_); - - task_runner_->PostTask( - [this] { this->wake_lock_ = ScopedWakeLock::Create(); }); - - // TODO(jophba): add command line argument for setting the private key. - ErrorOr<TlsCredentials> credentials = CreateCredentials(receive_endpoint_); - if (!credentials) { - return credentials.error(); - } - - // TODO(jophba, rwkeane): begin discovery process before creating TLS - // connection factory instance. - socket_factory_ = - std::make_unique<ReceiverSocketFactory>(this, &message_port_); - task_runner_->PostTask([this, creds = std::move(credentials.value())] { - connection_factory_ = TlsConnectionFactory::CreateFactory( - socket_factory_.get(), task_runner_); - connection_factory_->SetListenCredentials(creds); + OSP_CHECK(!current_session_); + + auth_handler_ = MakeSerialDelete<DeviceAuthNamespaceHandler>( + task_runner_, credentials_provider_); + router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_, + &connection_manager_); + router_->AddHandlerForLocalId(kPlatformReceiverId, auth_handler_.get()); + socket_factory_ = MakeSerialDelete<ReceiverSocketFactory>(task_runner_, this, + router_.get()); + + task_runner_->PostTask([this] { + wake_lock_ = ScopedWakeLock::Create(task_runner_); + + connection_factory_ = SerialDeletePtr<TlsConnectionFactory>( + task_runner_, + TlsConnectionFactory::CreateFactory(socket_factory_.get(), task_runner_) + .release()); + connection_factory_->SetListenCredentials(tls_credentials_); connection_factory_->Listen(receive_endpoint_, kDefaultListenOptions); + OSP_LOG_INFO << "Listening for connections at: " << receive_endpoint_; }); - OSP_LOG_INFO << "Listening for connections at: " << receive_endpoint_; return Error::None(); } Error CastAgent::Stop() { - controller_.reset(); - current_session_.reset(); + task_runner_->PostTask([this] { + router_.reset(); + connection_factory_.reset(); + controller_.reset(); + current_session_.reset(); + socket_factory_.reset(); + wake_lock_.reset(); + }); return Error::None(); } void CastAgent::OnConnected(ReceiverSocketFactory* factory, const IPEndpoint& endpoint, std::unique_ptr<CastSocket> socket) { - OSP_DCHECK(factory); - if (current_session_) { OSP_LOG_WARN << "Already connected, dropping peer at: " << endpoint; return; } OSP_LOG_INFO << "Received connection from peer at: " << endpoint; - message_port_.SetSocket(std::move(socket)); + message_port_.SetSocket(socket->GetWeakPtr()); + router_->TakeSocket(this, std::move(socket)); controller_ = std::make_unique<StreamingPlaybackController>(task_runner_, this); current_session_ = std::make_unique<ReceiverSession>( @@ -131,6 +107,17 @@ void CastAgent::OnConnected(ReceiverSocketFactory* factory, void CastAgent::OnError(ReceiverSocketFactory* factory, Error error) { OSP_LOG_ERROR << "Cast agent received socket factory error: " << error; + StopCurrentSession(); +} + +void CastAgent::OnClose(CastSocket* cast_socket) { + OSP_VLOG << "Cast agent socket closed."; + StopCurrentSession(); +} + +void CastAgent::OnError(CastSocket* socket, Error error) { + OSP_LOG_ERROR << "Cast agent received socket error: " << error; + StopCurrentSession(); } // Currently we don't do anything with the receiver output--the session @@ -139,23 +126,30 @@ void CastAgent::OnError(ReceiverSocketFactory* factory, Error error) { // about the receiver configurations we will have to handle OnNegotiated here. void CastAgent::OnNegotiated(const ReceiverSession* session, ReceiverSession::ConfiguredReceivers receivers) { - OSP_LOG_INFO << "Successfully negotiated with sender."; + OSP_VLOG << "Successfully negotiated with sender."; } void CastAgent::OnConfiguredReceiversDestroyed(const ReceiverSession* session) { - OSP_LOG_INFO << "Receiver instances destroyed."; + OSP_VLOG << "Receiver instances destroyed."; } // Currently, we just kill the session if an error is encountered. void CastAgent::OnError(const ReceiverSession* session, Error error) { OSP_LOG_ERROR << "Cast agent received receiver session error: " << error; - current_session_.reset(); + StopCurrentSession(); } void CastAgent::OnPlaybackError(StreamingPlaybackController* controller, Error error) { OSP_LOG_ERROR << "Cast agent received playback error: " << error; + StopCurrentSession(); +} + +void CastAgent::StopCurrentSession() { + controller_.reset(); current_session_.reset(); + router_->CloseSocket(message_port_.GetSocketId()); + message_port_.SetSocket(nullptr); } } // namespace cast diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.h index d9932b9d077..b4fca60bed7 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.h +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent.h @@ -8,10 +8,15 @@ #include <openssl/x509.h> #include <memory> +#include <vector> +#include "cast/common/channel/virtual_connection_manager.h" +#include "cast/common/channel/virtual_connection_router.h" #include "cast/common/public/cast_socket.h" +#include "cast/receiver/channel/device_auth_namespace_handler.h" #include "cast/receiver/public/receiver_socket_factory.h" #include "cast/standalone_receiver/cast_socket_message_port.h" +#include "cast/standalone_receiver/static_credentials.h" #include "cast/standalone_receiver/streaming_playback_controller.h" #include "cast/streaming/environment.h" #include "cast/streaming/receiver_session.h" @@ -19,6 +24,7 @@ #include "platform/api/serial_delete_ptr.h" #include "platform/base/error.h" #include "platform/base/interface_info.h" +#include "platform/base/tls_credentials.h" #include "platform/impl/task_runner.h" namespace openscreen { @@ -29,13 +35,19 @@ namespace cast { // received, and linking Receivers to the output decoder and SDL visualizer. // // Consumers of this class are expected to provide a single threaded task runner -// implementation, and a network interface information struct that will be used -// both for TLS listening and UDP messaging. -class CastAgent : public ReceiverSocketFactory::Client, - public ReceiverSession::Client, - public StreamingPlaybackController::Client { +// implementation, a network interface information struct that will be used +// both for TLS listening and UDP messaging, and a credentials provider used +// for TLS listening. +class CastAgent final : public ReceiverSocketFactory::Client, + public VirtualConnectionRouter::SocketErrorHandler, + public ReceiverSession::Client, + public StreamingPlaybackController::Client { public: - CastAgent(TaskRunner* task_runner, InterfaceInfo interface); + CastAgent( + TaskRunner* task_runner, + InterfaceInfo interface, + DeviceAuthNamespaceHandler::CredentialsProvider* credentials_provider, + TlsCredentials tls_credentials); ~CastAgent(); // Initialization occurs as part of construction, however to actually bind @@ -49,6 +61,10 @@ class CastAgent : public ReceiverSocketFactory::Client, std::unique_ptr<CastSocket> socket) override; void OnError(ReceiverSocketFactory* factory, Error error) override; + // VirtualConnectionRouter::SocketErrorHandler overrides. + void OnClose(CastSocket* cast_socket) override; + void OnError(CastSocket* socket, Error error) override; + // ReceiverSession::Client overrides. void OnNegotiated(const ReceiverSession* session, ReceiverSession::ConfiguredReceivers receivers) override; @@ -60,16 +76,26 @@ class CastAgent : public ReceiverSocketFactory::Client, Error error) override; private: + // Helper for stopping the current session. This is useful for when we don't + // want to completely stop (e.g. an issue with a specific Sender) but need + // to terminate the current connection. + void StopCurrentSession(); + // Member variables set as part of construction. std::unique_ptr<Environment> environment_; TaskRunner* const task_runner_; IPEndpoint receive_endpoint_; + DeviceAuthNamespaceHandler::CredentialsProvider* credentials_provider_; CastSocketMessagePort message_port_; + TlsCredentials tls_credentials_; // Member variables set as part of starting up. - std::unique_ptr<TlsConnectionFactory> connection_factory_; - std::unique_ptr<ReceiverSocketFactory> socket_factory_; - std::unique_ptr<ScopedWakeLock> wake_lock_; + SerialDeletePtr<DeviceAuthNamespaceHandler> auth_handler_; + SerialDeletePtr<TlsConnectionFactory> connection_factory_; + VirtualConnectionManager connection_manager_; + SerialDeletePtr<VirtualConnectionRouter> router_; + SerialDeletePtr<ReceiverSocketFactory> socket_factory_; + SerialDeletePtr<ScopedWakeLock> wake_lock_; // Member variables set as part of a sender connection. // NOTE: currently we only support a single sender connection and a diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent_integration_tests.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent_integration_tests.cc new file mode 100644 index 00000000000..919711abeb4 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_agent_integration_tests.cc @@ -0,0 +1,142 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/certificate/cast_trust_store.h" +#include "cast/common/certificate/testing/test_helpers.h" +#include "cast/common/channel/virtual_connection_manager.h" +#include "cast/common/channel/virtual_connection_router.h" +#include "cast/sender/public/sender_socket_factory.h" +#include "cast/standalone_receiver/cast_agent.h" +#include "cast/standalone_receiver/static_credentials.h" +#include "gtest/gtest.h" +#include "platform/api/serial_delete_ptr.h" +#include "platform/api/time.h" +#include "platform/impl/network_interface.h" +#include "platform/impl/platform_client_posix.h" +#include "platform/impl/task_runner.h" + +namespace openscreen { +namespace cast { +namespace { + +// Based heavily on SenderSocketsClient from cast_socket_e2e_test.cc. +class MockSender final : public SenderSocketFactory::Client, + public VirtualConnectionRouter::SocketErrorHandler { + public: + explicit MockSender(VirtualConnectionRouter* router) : router_(router) {} + ~MockSender() = default; + + CastSocket* socket() const { return socket_; } + + // SenderSocketFactory::Client overrides. + void OnConnected(SenderSocketFactory* factory, + const IPEndpoint& endpoint, + std::unique_ptr<CastSocket> socket) override { + ASSERT_FALSE(socket_); + OSP_LOG_INFO << "Sender connected to endpoint: " << endpoint; + socket_ = socket.get(); + router_->TakeSocket(this, std::move(socket)); + } + + void OnError(SenderSocketFactory* factory, + const IPEndpoint& endpoint, + Error error) override { + FAIL() << error; + } + + // VirtualConnectionRouter::SocketErrorHandler overrides. + void OnClose(CastSocket* socket) override {} + void OnError(CastSocket* socket, Error error) override { FAIL() << error; } + + private: + VirtualConnectionRouter* const router_; + std::atomic<CastSocket*> socket_{nullptr}; +}; + +class CastAgentIntegrationTest : public ::testing::Test { + public: + void SetUp() override { + PlatformClientPosix::Create(std::chrono::milliseconds(50), + std::chrono::milliseconds(50)); + task_runner_ = reinterpret_cast<TaskRunnerImpl*>( + PlatformClientPosix::GetInstance()->GetTaskRunner()); + + sender_router_ = MakeSerialDelete<VirtualConnectionRouter>( + task_runner_, &sender_vc_manager_); + sender_client_ = std::make_unique<MockSender>(sender_router_.get()); + sender_factory_ = MakeSerialDelete<SenderSocketFactory>( + task_runner_, sender_client_.get(), task_runner_); + sender_tls_factory_ = SerialDeletePtr<TlsConnectionFactory>( + task_runner_, + TlsConnectionFactory::CreateFactory(sender_factory_.get(), task_runner_) + .release()); + sender_factory_->set_factory(sender_tls_factory_.get()); + } + + void TearDown() override { + sender_router_.reset(); + sender_tls_factory_.reset(); + sender_factory_.reset(); + PlatformClientPosix::ShutDown(); + // Must be shut down after platform client, so joined tasks + // depending on certs are called correctly. + CastTrustStore::ResetInstance(); + } + + void WaitAndAssertSenderSocketConnected() { + constexpr int kMaxAttempts = 10; + constexpr std::chrono::milliseconds kSocketWaitDelay(250); + for (int i = 0; i < kMaxAttempts; ++i) { + OSP_LOG_INFO << "\tChecking for CastSocket, attempt " << i + 1 << "/" + << kMaxAttempts; + if (sender_client_->socket()) { + break; + } + std::this_thread::sleep_for(kSocketWaitDelay); + } + ASSERT_TRUE(sender_client_->socket()); + } + + void AssertConnect(const IPAddress& address) { + OSP_LOG_INFO << "Sending connect task"; + task_runner_->PostTask( + [this, &address, port = (static_cast<uint16_t>(kDefaultCastPort))]() { + OSP_LOG_INFO << "Calling SenderSocketFactory::Connect"; + sender_factory_->Connect( + IPEndpoint{address, port}, + SenderSocketFactory::DeviceMediaPolicy::kNone, + sender_router_.get()); + }); + WaitAndAssertSenderSocketConnected(); + } + + TaskRunnerImpl* task_runner_; + // Cast socket sender components, used in conjuction to mock a Libcast sender. + VirtualConnectionManager sender_vc_manager_; + SerialDeletePtr<VirtualConnectionRouter> sender_router_; + std::unique_ptr<MockSender> sender_client_; + SerialDeletePtr<SenderSocketFactory> sender_factory_; + SerialDeletePtr<TlsConnectionFactory> sender_tls_factory_; +}; + +TEST_F(CastAgentIntegrationTest, StartsListeningProperly) { + absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting(); + ASSERT_TRUE(loopback.has_value()); + + ErrorOr<GeneratedCredentials> creds = + GenerateCredentials("Test Device Certificate"); + ASSERT_TRUE(creds.is_value()); + CastTrustStore::CreateInstanceForTest(creds.value().root_cert_der); + + auto agent = MakeSerialDelete<CastAgent>( + task_runner_, task_runner_, loopback.value(), + creds.value().provider.get(), creds.value().tls_credentials); + EXPECT_TRUE(agent->Start().ok()); + AssertConnect(loopback.value().GetIpAddressV4()); + EXPECT_TRUE(agent->Stop().ok()); +} + +} // namespace +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.cc index d5540f28b3e..6f3c55c8ae8 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.cc @@ -19,27 +19,17 @@ CastSocketMessagePort::~CastSocketMessagePort() = default; // since sockets should map one to one with receiver sessions, we reset our // client. The consumer of this message port should call SetClient with the new // message port client after setting the socket. -void CastSocketMessagePort::SetSocket(std::unique_ptr<CastSocket> socket) { +void CastSocketMessagePort::SetSocket(WeakPtr<CastSocket> socket) { client_ = nullptr; - socket_ = std::move(socket); + socket_ = socket; } -void CastSocketMessagePort::SetClient(MessagePort::Client* client) { - client_ = client; -} - -void CastSocketMessagePort::OnError(CastSocket* socket, Error error) { - if (client_) { - client_->OnError(error); - } +int CastSocketMessagePort::GetSocketId() { + return socket_ ? socket_->socket_id() : -1; } -void CastSocketMessagePort::OnMessage(CastSocket* socket, - ::cast::channel::CastMessage message) { - if (client_) { - client_->OnMessage(message.source_id(), message.namespace_(), - message.payload_utf8()); - } +void CastSocketMessagePort::SetClient(MessagePort::Client* client) { + client_ = client; } void CastSocketMessagePort::PostMessage(absl::string_view sender_id, @@ -51,6 +41,10 @@ void CastSocketMessagePort::PostMessage(absl::string_view sender_id, message_namespace.size()); cast_message.set_payload_utf8(message.data(), message.size()); + if (!socket_) { + client_->OnError(Error::Code::kAlreadyClosed); + return; + } Error error = socket_->Send(cast_message); if (!error.ok()) { client_->OnError(error); diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.h index 98fc47f686c..67d037e96b6 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.h +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/cast_socket_message_port.h @@ -11,16 +11,20 @@ #include "cast/common/public/cast_socket.h" #include "cast/streaming/receiver_session.h" +#include "util/weak_ptr.h" namespace openscreen { namespace cast { -class CastSocketMessagePort : public MessagePort, public CastSocket::Client { +class CastSocketMessagePort : public MessagePort { public: CastSocketMessagePort(); ~CastSocketMessagePort() override; - void SetSocket(std::unique_ptr<CastSocket> socket); + void SetSocket(WeakPtr<CastSocket> socket); + + // Returns current socket identifier, or -1 if not connected. + int GetSocketId(); // MessagePort overrides. void SetClient(MessagePort::Client* client) override; @@ -28,14 +32,9 @@ class CastSocketMessagePort : public MessagePort, public CastSocket::Client { absl::string_view message_namespace, absl::string_view message) override; - // CastSocket::Client overrides. - void OnError(CastSocket* socket, Error error) override; - void OnMessage(CastSocket* socket, - ::cast::channel::CastMessage message) override; - private: MessagePort::Client* client_ = nullptr; - std::unique_ptr<CastSocket> socket_; + WeakPtr<CastSocket> socket_; }; } // namespace cast diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.cc index 221c85700f2..9a2324e3128 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/decoder.cc @@ -6,7 +6,7 @@ #include <algorithm> #include <sstream> -#include <thread> // NOLINT +#include <thread> #include "util/osp_logging.h" #include "util/trace_logging.h" @@ -32,7 +32,8 @@ void Decoder::Buffer::Resize(int new_size) { } absl::Span<const uint8_t> Decoder::Buffer::GetSpan() const { - return const_cast<Buffer*>(this)->GetSpan(); + return absl::Span<const uint8_t>( + buffer_.data(), buffer_.size() - AV_INPUT_BUFFER_PADDING_SIZE); } absl::Span<uint8_t> Decoder::Buffer::GetSpan() { diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.cc index bd667c3a2d1..512ea7da2ad 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/dummy_player.cc @@ -8,10 +8,9 @@ #include "absl/types/span.h" #include "cast/streaming/encoded_frame.h" +#include "util/chrono_helpers.h" #include "util/osp_logging.h" -using std::chrono::microseconds; - namespace openscreen { namespace cast { @@ -33,7 +32,7 @@ void DummyPlayer::OnFramesReady(int buffer_size) { // Convert the RTP timestamp to a human-readable timestamp (in µs) and log // some short information about the frame. const auto media_timestamp = - frame.rtp_timestamp.ToTimeSinceOrigin<microseconds>( + frame.rtp_timestamp.ToTimeSinceOrigin<std::chrono::microseconds>( receiver_->rtp_timebase()); OSP_LOG_INFO << "[SSRC " << receiver_->ssrc() << "] " << (frame.dependency == EncodedFrame::KEY_FRAME ? "KEY " : "") diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/main.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/main.cc index 7f16b2b4cd0..44a1a0a09bf 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/main.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/main.cc @@ -5,11 +5,13 @@ #include <getopt.h> #include <array> -#include <chrono> // NOLINT +#include <chrono> #include <iostream> +#include "absl/strings/str_cat.h" #include "cast/common/public/service_info.h" #include "cast/standalone_receiver/cast_agent.h" +#include "cast/standalone_receiver/static_credentials.h" #include "cast/streaming/ssrc.h" #include "discovery/common/config.h" #include "discovery/common/reporting_client.h" @@ -24,6 +26,7 @@ #include "platform/impl/platform_client_posix.h" #include "platform/impl/task_runner.h" #include "platform/impl/text_trace_logging_platform.h" +#include "util/chrono_helpers.h" #include "util/stringprintf.h" #include "util/trace_logging.h" @@ -93,8 +96,11 @@ ErrorOr<std::unique_ptr<DiscoveryState>> StartDiscovery( return state; } -void StartCastAgent(TaskRunnerImpl* task_runner, InterfaceInfo interface) { - CastAgent agent(task_runner, interface); +void StartCastAgent(TaskRunnerImpl* task_runner, + InterfaceInfo interface, + GeneratedCredentials* creds) { + CastAgent agent(task_runner, interface, creds->provider.get(), + creds->tls_credentials); const auto error = agent.Start(); if (!error.ok()) { OSP_LOG_ERROR << "Error occurred while starting agent: " << error; @@ -173,15 +179,19 @@ int RunStandaloneReceiver(int argc, char* argv[]) { : openscreen::LogLevel::kInfo); auto* const task_runner = new TaskRunnerImpl(&Clock::now); - PlatformClientPosix::Create(Clock::duration{50}, Clock::duration{50}, + PlatformClientPosix::Create(milliseconds(50), milliseconds(50), std::unique_ptr<TaskRunnerImpl>(task_runner)); auto discovery_state = StartDiscovery(task_runner, interface_info); OSP_CHECK(discovery_state.is_value()) << "Failed to start discovery."; + auto creds = GenerateCredentials( + absl::StrCat("Standalone Receiver on ", argv[optind])); + OSP_CHECK(creds.is_value()); + // Runs until the process is interrupted. Safe to pass |task_runner| as it // will not be destroyed by ShutDown() until this exits. - StartCastAgent(task_runner, interface_info); + StartCastAgent(task_runner, interface_info, &(creds.value())); // The task runner must be deleted after all serial delete pointers, such // as the one stored in the discovery state. diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.cc index 5e517d80005..3392b706615 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_audio_player.cc @@ -6,17 +6,15 @@ #include <chrono> #include <sstream> +#include <utility> #include "absl/types/span.h" #include "cast/standalone_receiver/avcodec_glue.h" #include "util/big_endian.h" +#include "util/chrono_helpers.h" #include "util/osp_logging.h" #include "util/trace_logging.h" -using std::chrono::duration_cast; -using std::chrono::milliseconds; -using std::chrono::seconds; - namespace openscreen { namespace cast { @@ -117,9 +115,9 @@ ErrorOr<Clock::time_point> SDLAudioPlayer::RenderNextFrame( pending_audio_spec_.samples *= 2; } - approximate_lead_time_ = (pending_audio_spec_.samples * - duration_cast<Clock::duration>(kOneSecond)) / - pending_audio_spec_.freq; + approximate_lead_time_ = + (pending_audio_spec_.samples * Clock::to_duration(kOneSecond)) / + pending_audio_spec_.freq; } // If the decoded audio is in planar format, interleave it for SDL. diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.cc index 6f793f8a4de..ab2f1327cf1 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/sdl_player_base.cc @@ -6,17 +6,16 @@ #include <chrono> #include <sstream> +#include <utility> #include "absl/types/span.h" #include "cast/standalone_receiver/avcodec_glue.h" #include "cast/streaming/encoded_frame.h" #include "util/big_endian.h" +#include "util/chrono_helpers.h" #include "util/osp_logging.h" #include "util/trace_logging.h" -using std::chrono::duration_cast; -using std::chrono::milliseconds; - namespace openscreen { namespace cast { @@ -74,8 +73,7 @@ Clock::time_point SDLPlayerBase::ResyncAndDeterminePresentationTime( .ToDuration<Clock::duration>(receiver_->rtp_timebase()); Clock::time_point presentation_time = last_sync_reference_time_ + media_time_since_last_sync; - const auto drift = - duration_cast<milliseconds>(frame.reference_time - presentation_time); + const auto drift = to_microseconds(frame.reference_time - presentation_time); if (drift > kMaxPlayoutDrift || drift < -kMaxPlayoutDrift) { // Only log if not the very first frame. OSP_LOG_IF(INFO, frame.frame_id != FrameId::first()) diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/static_credentials.cc b/chromium/third_party/openscreen/src/cast/standalone_receiver/static_credentials.cc new file mode 100644 index 00000000000..9980f20caf7 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/static_credentials.cc @@ -0,0 +1,139 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/standalone_receiver/static_credentials.h" + +#include <openssl/mem.h> + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "cast/standalone_receiver/private_key_der.h" +#include "platform/base/tls_credentials.h" +#include "util/crypto/certificate_utils.h" +#include "util/osp_logging.h" + +namespace openscreen { +namespace cast { +namespace { + +constexpr int kThreeDaysInSeconds = 3 * 24 * 60 * 60; +constexpr auto kCertificateDuration = std::chrono::seconds(kThreeDaysInSeconds); + +} // namespace + +StaticCredentialsProvider::StaticCredentialsProvider() = default; +StaticCredentialsProvider::StaticCredentialsProvider( + DeviceCredentials device_creds, + std::vector<uint8_t> tls_cert_der) + : device_creds(std::move(device_creds)), + tls_cert_der(std::move(tls_cert_der)) {} + +StaticCredentialsProvider::StaticCredentialsProvider( + StaticCredentialsProvider&&) = default; +StaticCredentialsProvider& StaticCredentialsProvider::operator=( + StaticCredentialsProvider&&) = default; +StaticCredentialsProvider::~StaticCredentialsProvider() = default; + +ErrorOr<GeneratedCredentials> GenerateCredentials( + absl::string_view device_certificate_id) { + GeneratedCredentials credentials; + + bssl::UniquePtr<EVP_PKEY> root_key = GenerateRsaKeyPair(); + bssl::UniquePtr<EVP_PKEY> intermediate_key = GenerateRsaKeyPair(); + bssl::UniquePtr<EVP_PKEY> device_key = GenerateRsaKeyPair(); + OSP_CHECK(root_key); + OSP_CHECK(intermediate_key); + OSP_CHECK(device_key); + + ErrorOr<bssl::UniquePtr<X509>> root_cert_or_error = + CreateSelfSignedX509Certificate("Cast Root CA", kCertificateDuration, + *root_key, GetWallTimeSinceUnixEpoch(), + true); + OSP_CHECK(root_cert_or_error); + bssl::UniquePtr<X509> root_cert = std::move(root_cert_or_error.value()); + + ErrorOr<bssl::UniquePtr<X509>> intermediate_cert_or_error = + CreateSelfSignedX509Certificate( + "Cast Intermediate", kCertificateDuration, *intermediate_key, + GetWallTimeSinceUnixEpoch(), true, root_cert.get(), root_key.get()); + OSP_CHECK(intermediate_cert_or_error); + bssl::UniquePtr<X509> intermediate_cert = + std::move(intermediate_cert_or_error.value()); + + ErrorOr<bssl::UniquePtr<X509>> device_cert_or_error = + CreateSelfSignedX509Certificate( + device_certificate_id, kCertificateDuration, *device_key, + GetWallTimeSinceUnixEpoch(), false, intermediate_cert.get(), + intermediate_key.get()); + OSP_CHECK(device_cert_or_error); + bssl::UniquePtr<X509> device_cert = std::move(device_cert_or_error.value()); + + // NOTE: Device cert chain plumbing + serialization. + DeviceCredentials device_creds; + device_creds.private_key = std::move(device_key); + + int cert_length = i2d_X509(device_cert.get(), nullptr); + std::string cert_serial(cert_length, 0); + uint8_t* out = reinterpret_cast<uint8_t*>(&cert_serial[0]); + i2d_X509(device_cert.get(), &out); + device_creds.certs.emplace_back(std::move(cert_serial)); + + cert_length = i2d_X509(intermediate_cert.get(), nullptr); + cert_serial.resize(cert_length); + out = reinterpret_cast<uint8_t*>(&cert_serial[0]); + i2d_X509(intermediate_cert.get(), &out); + device_creds.certs.emplace_back(std::move(cert_serial)); + + cert_length = i2d_X509(root_cert.get(), nullptr); + std::vector<uint8_t> trust_anchor_der(cert_length); + out = &trust_anchor_der[0]; + i2d_X509(root_cert.get(), &out); + + // NOTE: TLS key pair + certificate generation. + bssl::UniquePtr<EVP_PKEY> tls_key = GenerateRsaKeyPair(); + OSP_CHECK_EQ(EVP_PKEY_id(tls_key.get()), EVP_PKEY_RSA); + ErrorOr<bssl::UniquePtr<X509>> tls_cert_or_error = + CreateSelfSignedX509Certificate("Test Device TLS", kCertificateDuration, + *tls_key, GetWallTimeSinceUnixEpoch()); + OSP_CHECK(tls_cert_or_error); + bssl::UniquePtr<X509> tls_cert = std::move(tls_cert_or_error.value()); + + // NOTE: TLS private key serialization. + RSA* rsa_key = EVP_PKEY_get0_RSA(tls_key.get()); + size_t pkey_len = 0; + uint8_t* pkey_bytes = nullptr; + OSP_CHECK(RSA_private_key_to_bytes(&pkey_bytes, &pkey_len, rsa_key)); + OSP_CHECK_GT(pkey_len, 0u); + std::vector<uint8_t> tls_key_serial(pkey_bytes, pkey_bytes + pkey_len); + OPENSSL_free(pkey_bytes); + + // NOTE: TLS public key serialization. + pkey_len = 0; + pkey_bytes = nullptr; + OSP_CHECK(RSA_public_key_to_bytes(&pkey_bytes, &pkey_len, rsa_key)); + OSP_CHECK_GT(pkey_len, 0u); + std::vector<uint8_t> tls_pub_serial(pkey_bytes, pkey_bytes + pkey_len); + OPENSSL_free(pkey_bytes); + + // NOTE: TLS cert serialization. + cert_length = 0; + cert_length = i2d_X509(tls_cert.get(), nullptr); + OSP_CHECK_GT(cert_length, 0); + std::vector<uint8_t> tls_cert_serial(cert_length); + out = &tls_cert_serial[0]; + i2d_X509(tls_cert.get(), &out); + + return GeneratedCredentials{ + std::make_unique<StaticCredentialsProvider>(std::move(device_creds), + tls_cert_serial), + TlsCredentials{std::move(tls_key_serial), std::move(tls_pub_serial), + std::move(tls_cert_serial)}, + std::move(trust_anchor_der)}; +} + +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/standalone_receiver/static_credentials.h b/chromium/third_party/openscreen/src/cast/standalone_receiver/static_credentials.h new file mode 100644 index 00000000000..4707f5f40af --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/standalone_receiver/static_credentials.h @@ -0,0 +1,60 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_RECEIVER_STATIC_CREDENTIALS_H_ +#define CAST_STANDALONE_RECEIVER_STATIC_CREDENTIALS_H_ + +#include <memory> +#include <vector> + +#include "absl/strings/string_view.h" +#include "cast/receiver/channel/device_auth_namespace_handler.h" +#include "platform/base/error.h" +#include "platform/base/tls_credentials.h" + +namespace openscreen { +namespace cast { + +class StaticCredentialsProvider final + : public DeviceAuthNamespaceHandler::CredentialsProvider { + public: + StaticCredentialsProvider(); + StaticCredentialsProvider(DeviceCredentials device_creds, + std::vector<uint8_t> tls_cert_der); + + StaticCredentialsProvider(const StaticCredentialsProvider&) = delete; + StaticCredentialsProvider(StaticCredentialsProvider&&); + StaticCredentialsProvider& operator=(const StaticCredentialsProvider&) = + delete; + StaticCredentialsProvider& operator=(StaticCredentialsProvider&&); + ~StaticCredentialsProvider(); + + absl::Span<const uint8_t> GetCurrentTlsCertAsDer() override { + return absl::Span<uint8_t>(tls_cert_der); + } + const DeviceCredentials& GetCurrentDeviceCredentials() override { + return device_creds; + } + + DeviceCredentials device_creds; + std::vector<uint8_t> tls_cert_der; +}; + +struct GeneratedCredentials { + std::unique_ptr<StaticCredentialsProvider> provider; + TlsCredentials tls_credentials; + std::vector<uint8_t> root_cert_der; +}; + +// Generates a valid set of credentials for use with the TLS Server socket, +// including a generated X509 certificate generated from the static private key +// stored in private_key_der.h. The certificate is valid for +// kCertificateDuration from when this function is called. +ErrorOr<GeneratedCredentials> GenerateCredentials( + absl::string_view device_certificate_id); + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_RECEIVER_STATIC_CREDENTIALS_H_ diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/main.cc b/chromium/third_party/openscreen/src/cast/standalone_sender/main.cc index c621394bd4b..ef417132eb5 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_sender/main.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_sender/main.cc @@ -4,7 +4,7 @@ #include <getopt.h> -#include <chrono> // NOLINT +#include <chrono> #include <cinttypes> #include <csignal> #include <cstdio> @@ -26,6 +26,7 @@ #include "platform/impl/task_runner.h" #include "platform/impl/text_trace_logging_platform.h" #include "util/alarm.h" +#include "util/chrono_helpers.h" #if defined(CAST_STANDALONE_SENDER_HAVE_EXTERNAL_LIBS) #include "cast/standalone_sender/simulated_capturer.h" @@ -37,10 +38,6 @@ namespace openscreen { namespace cast { namespace { -using std::chrono::duration_cast; -using std::chrono::milliseconds; -using std::chrono::seconds; - //////////////////////////////////////////////////////////////////////////////// // Sender Configuration // @@ -115,7 +112,7 @@ class LoopingFileSender final : public SimulatedAudioCapturer::Client, const IPEndpoint& remote_endpoint, int max_bitrate, bool use_android_rtp_hack) - : env_(&Clock::now, task_runner, IPEndpoint{IPAddress(), 0}), + : env_(&Clock::now, task_runner), path_(path), packet_router_(&env_), max_bitrate_(max_bitrate), @@ -247,9 +244,8 @@ class LoopingFileSender final : public SimulatedAudioCapturer::Client, void UpdateStatusOnConsole() { const Clock::duration elapsed = latest_frame_time_ - capture_start_time_; - const auto seconds_part = duration_cast<seconds>(elapsed); - const auto millis_part = - duration_cast<milliseconds>(elapsed - seconds_part); + const auto seconds_part = to_seconds(elapsed); + const auto millis_part = to_microseconds(elapsed - seconds_part); // The control codes here attempt to erase the current line the cursor is // on, and then print out the updated status text. If the terminal does not // support simple ANSI escape codes, the following will still work, but @@ -333,7 +329,7 @@ class LoopingFileSender final : public SimulatedAudioCapturer::Client, #endif // defined(CAST_STANDALONE_SENDER_HAVE_EXTERNAL_LIBS) IPEndpoint GetDefaultEndpoint() { - return IPEndpoint{IPAddress::kV4LoopbackAddress, kDefaultCastStreamingPort}; + return IPEndpoint{IPAddress::kV4LoopbackAddress(), kDefaultCastStreamingPort}; } void LogUsage(const char* argv0) { diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.cc b/chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.cc index 405922be5fa..87313010db5 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_sender/simulated_capturer.cc @@ -5,10 +5,10 @@ #include "cast/standalone_sender/simulated_capturer.h" #include <algorithm> -#include <chrono> // NOLINT -#include <ratio> // NOLINT +#include <chrono> +#include <ratio> #include <sstream> -#include <thread> // NOLINT +#include <thread> #include "cast/streaming/environment.h" #include "util/osp_logging.h" diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.cc b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.cc index 07aeaaefd13..ef9cc577564 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_opus_encoder.cc @@ -7,15 +7,11 @@ #include <opus/opus.h> #include <algorithm> -#include <chrono> // NOLINT +#include <chrono> namespace openscreen { namespace cast { -using std::chrono::duration_cast; -using std::chrono::microseconds; -using std::chrono::seconds; - using openscreen::operator<<; // To pretty-print chrono values. namespace { @@ -38,7 +34,7 @@ StreamingOpusEncoder::StreamingOpusEncoder(int num_channels, sender_(sender), samples_per_cast_frame_(sample_rate() / cast_frames_per_second), approximate_cast_frame_duration_( - duration_cast<Clock::duration>(seconds(1)) / cast_frames_per_second), + Clock::to_duration(std::chrono::seconds(1)) / cast_frames_per_second), encoder_storage_(new uint8_t[opus_encoder_get_size(num_channels_)]), input_(new float[num_channels_ * samples_per_cast_frame_]), output_(new uint8_t[kOpusMaxPayloadSize]) { diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.cc b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.cc index d5aafc94ed1..066e37f5ae1 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.cc +++ b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.cc @@ -8,22 +8,20 @@ #include <string.h> #include <vpx/vp8cx.h> +#include <chrono> #include <cmath> #include <utility> #include "cast/streaming/encoded_frame.h" #include "cast/streaming/environment.h" #include "cast/streaming/sender.h" +#include "util/chrono_helpers.h" #include "util/osp_logging.h" #include "util/saturate_cast.h" namespace openscreen { namespace cast { -using std::chrono::duration_cast; -using std::chrono::milliseconds; -using std::chrono::seconds; - // TODO(https://crbug.com/openscreen/123): Fix the declarations and then remove // this: using openscreen::operator<<; // For std::chrono::duration pretty-printing. @@ -371,7 +369,7 @@ void StreamingVp8Encoder::ComputeFrameEncodeStats( constexpr double kBytesPerBit = 1.0 / CHAR_BIT; constexpr double kSecondsPerClockTick = - 1.0 / duration_cast<Clock::duration>(seconds(1)).count(); + 1.0 / Clock::to_duration(seconds(1)).count(); const double target_bytes_per_clock_tick = target_bitrate * (kBytesPerBit * kSecondsPerClockTick); stats.target_size = target_bytes_per_clock_tick * work_unit->duration.count(); diff --git a/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.h b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.h index 1c64cafc5b0..c5d52248a61 100644 --- a/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.h +++ b/chromium/third_party/openscreen/src/cast/standalone_sender/streaming_vp8_encoder.h @@ -12,9 +12,9 @@ #include <condition_variable> // NOLINT #include <functional> #include <memory> -#include <mutex> // NOLINT +#include <mutex> #include <queue> -#include <thread> // NOLINT +#include <thread> #include <vector> #include "absl/base/thread_annotations.h" diff --git a/chromium/third_party/openscreen/src/cast/streaming/BUILD.gn b/chromium/third_party/openscreen/src/cast/streaming/BUILD.gn index 688a453f1ce..e384705be9e 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/BUILD.gn +++ b/chromium/third_party/openscreen/src/cast/streaming/BUILD.gn @@ -7,6 +7,10 @@ import("../../testing/libfuzzer/fuzzer_test.gni") source_set("common") { sources = [ + "answer_messages.cc", + "answer_messages.h", + "capture_recommendations.cc", + "capture_recommendations.h", "clock_drift_smoother.cc", "clock_drift_smoother.h", "constants.h", @@ -21,6 +25,8 @@ source_set("common") { "frame_id.h", "ntp_time.cc", "ntp_time.h", + "offer_messages.cc", + "offer_messages.h", "packet_util.cc", "packet_util.h", "rtcp_common.cc", @@ -52,14 +58,10 @@ source_set("common") { source_set("receiver") { sources = [ - "answer_messages.cc", - "answer_messages.h", "compound_rtcp_builder.cc", "compound_rtcp_builder.h", "frame_collector.cc", "frame_collector.h", - "offer_messages.cc", - "offer_messages.h", "packet_receive_stats_tracker.cc", "packet_receive_stats_tracker.h", "receiver.cc", @@ -74,13 +76,9 @@ source_set("receiver") { "sender_report_parser.h", ] - public_deps = [ - ":common", - ] + public_deps = [ ":common" ] - deps = [ - "../../util", - ] + deps = [ "../../util" ] } source_set("sender") { @@ -99,9 +97,7 @@ source_set("sender") { "sender_report_builder.h", ] - public_deps = [ - ":common", - ] + public_deps = [ ":common" ] } source_set("unittests") { @@ -110,6 +106,7 @@ source_set("unittests") { sources = [ "answer_messages_unittest.cc", "bandwidth_estimator_unittest.cc", + "capture_recommendations_unittest.cc", "compound_rtcp_builder_unittest.cc", "compound_rtcp_parser_unittest.cc", "expanded_value_base_unittest.cc", @@ -144,9 +141,7 @@ source_set("unittests") { } openscreen_fuzzer_test("compound_rtcp_parser_fuzzer") { - sources = [ - "compound_rtcp_parser_fuzzer.cc", - ] + sources = [ "compound_rtcp_parser_fuzzer.cc" ] deps = [ ":sender", @@ -160,9 +155,7 @@ openscreen_fuzzer_test("compound_rtcp_parser_fuzzer") { } openscreen_fuzzer_test("rtp_packet_parser_fuzzer") { - sources = [ - "rtp_packet_parser_fuzzer.cc", - ] + sources = [ "rtp_packet_parser_fuzzer.cc" ] deps = [ ":receiver", @@ -176,9 +169,7 @@ openscreen_fuzzer_test("rtp_packet_parser_fuzzer") { } openscreen_fuzzer_test("sender_report_parser_fuzzer") { - sources = [ - "sender_report_parser_fuzzer.cc", - ] + sources = [ "sender_report_parser_fuzzer.cc" ] deps = [ ":receiver", diff --git a/chromium/third_party/openscreen/src/cast/streaming/answer_messages.cc b/chromium/third_party/openscreen/src/cast/streaming/answer_messages.cc index e62f5a1ff70..16d99fc1824 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/answer_messages.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/answer_messages.cc @@ -7,8 +7,9 @@ #include <utility> #include "absl/strings/str_cat.h" -#include "cast/streaming/message_util.h" +#include "absl/strings/str_split.h" #include "platform/base/error.h" +#include "util/json/json_helpers.h" #include "util/osp_logging.h" namespace openscreen { @@ -16,28 +17,145 @@ namespace cast { namespace { -static constexpr char kMessageKeyType[] = "type"; -static constexpr char kMessageTypeAnswer[] = "ANSWER"; - -// List of ANSWER message fields. -static constexpr char kAnswerMessageBody[] = "answer"; -static constexpr char kResult[] = "result"; -static constexpr char kResultOk[] = "ok"; -static constexpr char kResultError[] = "error"; -static constexpr char kErrorMessageBody[] = "error"; -static constexpr char kErrorCode[] = "code"; -static constexpr char kErrorDescription[] = "description"; +/// NOTE: Constants here are all taken from the Cast V2: Mirroring Control +/// Protocol specification: http://goto.google.com/mirroring-control-protocol +// TODO(jophba): document the protocol in a public repository. + +/// Constraint properties. +// Audio constraints. See properties below. +static constexpr char kAudio[] = "audio"; +// Video constraints. See properties below. +static constexpr char kVideo[] = "video"; + +// An optional field representing the minimum bits per second. If not specified +// by the receiver, the sender will use kDefaultAudioMinBitRate and +// kDefaultVideoMinBitRate, which represent the true operational minimum. +static constexpr char kMinBitRate[] = "minBitRate"; +// 32kbps is sender default for audio minimum bit rate. +static constexpr int kDefaultAudioMinBitRate = 32 * 1000; +// 300kbps is sender default for video minimum bit rate. +static constexpr int kDefaultVideoMinBitRate = 300 * 1000; + +// Maximum encoded bits per second. This is the lower of (1) the max capability +// of the decoder, or (2) the max data transfer rate. +static constexpr char kMaxBitRate[] = "maxBitRate"; +// Maximum supported end-to-end latency, in milliseconds. Proportional to the +// size of the data buffers in the receiver. +static constexpr char kMaxDelay[] = "maxDelay"; + +/// Video constraint properties. +// Maximum pixel rate (width * height * framerate). Is often less than +// multiplying the fields in maxDimensions. This field is used to set the +// maximum processing rate. +static constexpr char kMaxPixelsPerSecond[] = "maxPixelsPerSecond"; +// Minimum dimensions. If omitted, the sender will assume a reasonable minimum +// with the same aspect ratio as maxDimensions, as close to 320*180 as possible. +// Should reflect the true operational minimum. +static constexpr char kMinDimensions[] = "minDimensions"; +// Maximum dimensions, not necessarily ideal dimensions. +static constexpr char kMaxDimensions[] = "maxDimensions"; + +/// Audio constraint properties. +// Maximum supported sampling frequency (not necessarily ideal). +static constexpr char kMaxSampleRate[] = "maxSampleRate"; +// Maximum number of audio channels (1 is mono, 2 is stereo, etc.). +static constexpr char kMaxChannels[] = "maxChannels"; + +/// Dimension properties. +// Width in pixels. +static constexpr char kWidth[] = "width"; +// Height in pixels. +static constexpr char kHeight[] = "height"; +// Frame rate as a rational decimal number or fraction. +// E.g. 30 and "3000/1001" are both valid representations. +static constexpr char kFrameRate[] = "frameRate"; + +/// Display description properties +// If this optional field is included in the ANSWER message, the receiver is +// attached to a fixed display that has the given dimensions and frame rate +// configuration. These may exceed, be the same, or be less than the values in +// constraints. If undefined, we assume the display is not fixed (e.g. a Google +// Hangouts UI panel). +static constexpr char kDimensions[] = "dimensions"; +// An optional field. When missing and dimensions are specified, the sender +// will assume square pixels and the dimensions imply the aspect ratio of the +// fixed display. WHen present and dimensions are also specified, implies the +// pixels are not square. +static constexpr char kAspectRatio[] = "aspectRatio"; +// The delimeter used for the aspect ratio format ("A:B"). +static constexpr char kAspectRatioDelimiter[] = ":"; +// Sets the aspect ratio constraints. Value must be either "sender" or +// "receiver", see kScalingSender and kScalingReceiver below. +static constexpr char kScaling[] = "scaling"; +// sclaing = "sender" means that the sender must provide video frames of a fixed +// aspect ratio. In this case, the dimensions object must be passed or an error +// case will occur. +static constexpr char kScalingSender[] = "sender"; +// scaling = "receiver" means that the sender may send arbitrarily sized frames, +// and the receiver will handle scaling and letterboxing as necessary. +static constexpr char kScalingReceiver[] = "receiver"; + +/// Answer properties. +// A number specifying the UDP port used for all streams in this session. +// Must have a value between kUdpPortMin and kUdpPortMax. +static constexpr char kUdpPort[] = "udpPort"; +static constexpr int kUdpPortMin = 1; +static constexpr int kUdpPortMax = 65535; +// Numbers specifying the indexes chosen from the offer message. +static constexpr char kSendIndexes[] = "sendIndexes"; +// uint32_t values specifying the RTP SSRC values used to send the RTCP feedback +// of the stream indicated in kSendIndexes. +static constexpr char kSsrcs[] = "ssrcs"; +// Provides detailed maximum and minimum capabilities of the receiver for +// processing the selected streams. The sender may alter video resolution and +// frame rate throughout the session, and the constraints here determine how +// much data volume is allowed. +static constexpr char kConstraints[] = "constraints"; +// Provides details about the display on the receiver. +static constexpr char kDisplay[] = "display"; +// absl::optional array of numbers specifying the indexes of streams that will +// send event logs through RTCP. +static constexpr char kReceiverRtcpEventLog[] = "receiverRtcpEventLog"; +// OPtional array of numbers specifying the indexes of streams that will use +// DSCP values specified in the OFFER message for RTCP packets. +static constexpr char kReceiverRtcpDscp[] = "receiverRtcpDscp"; +// True if receiver can report wifi status. +static constexpr char kReceiverGetStatus[] = "receiverGetStatus"; +// If this optional field is present the receiver supports the specific +// RTP extensions (such as adaptive playout delay). +static constexpr char kRtpExtensions[] = "rtpExtensions"; Json::Value AspectRatioConstraintToJson(AspectRatioConstraint aspect_ratio) { switch (aspect_ratio) { case AspectRatioConstraint::kVariable: - return Json::Value("receiver"); + return Json::Value(kScalingReceiver); case AspectRatioConstraint::kFixed: default: - return Json::Value("sender"); + return Json::Value(kScalingSender); } } +bool AspectRatioConstraintParseAndValidate(const Json::Value& value, + AspectRatioConstraint* out) { + // the aspect ratio constraint is an optional field. + if (!value) { + return true; + } + + std::string aspect_ratio; + if (!json::ParseAndValidateString(value, &aspect_ratio)) { + return false; + } + if (aspect_ratio == kScalingReceiver) { + *out = AspectRatioConstraint::kVariable; + return true; + } else if (aspect_ratio == kScalingSender) { + *out = AspectRatioConstraint::kFixed; + return true; + } + return false; +} + template <typename T> Json::Value PrimitiveVectorToJson(const std::vector<T>& vec) { Json::Value array(Json::ValueType::arrayValue); @@ -50,159 +168,293 @@ Json::Value PrimitiveVectorToJson(const std::vector<T>& vec) { return array; } +template <typename T> +bool ParseOptional(const Json::Value& value, absl::optional<T>* out) { + // It's fine if the value is empty. + if (!value) { + return true; + } + T tentative_out; + if (!T::ParseAndValidate(value, &tentative_out)) { + return false; + } + *out = tentative_out; + return true; +} + } // namespace -ErrorOr<Json::Value> AudioConstraints::ToJson() const { - if (max_sample_rate <= 0 || max_channels <= 0 || min_bit_rate <= 0 || - max_bit_rate < min_bit_rate) { - return CreateParameterError("AudioConstraints"); +// static +bool AspectRatio::ParseAndValidate(const Json::Value& value, AspectRatio* out) { + std::string parsed_value; + if (!json::ParseAndValidateString(value, &parsed_value)) { + return false; + } + + std::vector<absl::string_view> fields = + absl::StrSplit(parsed_value, kAspectRatioDelimiter); + if (fields.size() != 2) { + return false; + } + + if (!absl::SimpleAtoi(fields[0], &out->width) || + !absl::SimpleAtoi(fields[1], &out->height)) { + return false; } + return out->IsValid(); +} +bool AspectRatio::IsValid() const { + return width > 0 && height > 0; +} + +// static +bool AudioConstraints::ParseAndValidate(const Json::Value& root, + AudioConstraints* out) { + if (!json::ParseAndValidateInt(root[kMaxSampleRate], + &(out->max_sample_rate)) || + !json::ParseAndValidateInt(root[kMaxChannels], &(out->max_channels)) || + !json::ParseAndValidateInt(root[kMaxBitRate], &(out->max_bit_rate)) || + !json::ParseAndValidateMilliseconds(root[kMaxDelay], &(out->max_delay))) { + return false; + } + if (!json::ParseAndValidateInt(root[kMinBitRate], &(out->min_bit_rate))) { + out->min_bit_rate = kDefaultAudioMinBitRate; + } + return out->IsValid(); +} + +Json::Value AudioConstraints::ToJson() const { + OSP_DCHECK(IsValid()); Json::Value root; - root["maxSampleRate"] = max_sample_rate; - root["maxChannels"] = max_channels; - root["minBitRate"] = min_bit_rate; - root["maxBitRate"] = max_bit_rate; - root["maxDelay"] = Json::Value::Int64(max_delay.count()); + root[kMaxSampleRate] = max_sample_rate; + root[kMaxChannels] = max_channels; + root[kMinBitRate] = min_bit_rate; + root[kMaxBitRate] = max_bit_rate; + root[kMaxDelay] = Json::Value::Int64(max_delay.count()); return root; } -ErrorOr<Json::Value> Dimensions::ToJson() const { - if (width <= 0 || height <= 0 || !frame_rate.is_defined() || - !frame_rate.is_positive()) { - return CreateParameterError("Dimensions"); +bool AudioConstraints::IsValid() const { + return max_sample_rate > 0 && max_channels > 0 && min_bit_rate > 0 && + max_bit_rate >= min_bit_rate; +} + +bool Dimensions::ParseAndValidate(const Json::Value& root, Dimensions* out) { + if (!json::ParseAndValidateInt(root[kWidth], &(out->width)) || + !json::ParseAndValidateInt(root[kHeight], &(out->height)) || + !json::ParseAndValidateSimpleFraction(root[kFrameRate], + &(out->frame_rate))) { + return false; } + return out->IsValid(); +} + +bool Dimensions::IsValid() const { + return width > 0 && height > 0 && frame_rate.is_positive(); +} +Json::Value Dimensions::ToJson() const { + OSP_DCHECK(IsValid()); Json::Value root; - root["width"] = width; - root["height"] = height; - root["frameRate"] = frame_rate.ToString(); + root[kWidth] = width; + root[kHeight] = height; + root[kFrameRate] = frame_rate.ToString(); return root; } -ErrorOr<Json::Value> VideoConstraints::ToJson() const { - if (max_pixels_per_second <= 0 || min_bit_rate <= 0 || - max_bit_rate < min_bit_rate || max_delay.count() <= 0) { - return CreateParameterError("VideoConstraints"); +// static +bool VideoConstraints::ParseAndValidate(const Json::Value& root, + VideoConstraints* out) { + if (!json::ParseAndValidateDouble(root[kMaxPixelsPerSecond], + &(out->max_pixels_per_second)) || + !Dimensions::ParseAndValidate(root[kMaxDimensions], + &(out->max_dimensions)) || + !json::ParseAndValidateInt(root[kMaxBitRate], &(out->max_bit_rate)) || + !json::ParseAndValidateMilliseconds(root[kMaxDelay], &(out->max_delay)) || + !ParseOptional<Dimensions>(root[kMinDimensions], + &(out->min_dimensions))) { + return false; } - - auto error_or_min_dim = min_dimensions.ToJson(); - if (error_or_min_dim.is_error()) { - return error_or_min_dim.error(); + if (!json::ParseAndValidateInt(root[kMinBitRate], &(out->min_bit_rate))) { + out->min_bit_rate = kDefaultVideoMinBitRate; } + return out->IsValid(); +} - auto error_or_max_dim = max_dimensions.ToJson(); - if (error_or_max_dim.is_error()) { - return error_or_max_dim.error(); - } +bool VideoConstraints::IsValid() const { + return max_pixels_per_second > 0 && min_bit_rate > 0 && + max_bit_rate > min_bit_rate && max_delay.count() > 0 && + max_dimensions.IsValid() && + (!min_dimensions.has_value() || min_dimensions->IsValid()) && + max_dimensions.frame_rate.numerator > 0; +} +Json::Value VideoConstraints::ToJson() const { + OSP_DCHECK(IsValid()); Json::Value root; - root["maxPixelsPerSecond"] = max_pixels_per_second; - root["minDimensions"] = error_or_min_dim.value(); - root["maxDimensions"] = error_or_max_dim.value(); - root["minBitRate"] = min_bit_rate; - root["maxBitRate"] = max_bit_rate; - root["maxDelay"] = Json::Value::Int64(max_delay.count()); + root[kMaxPixelsPerSecond] = max_pixels_per_second; + if (min_dimensions.has_value()) { + root[kMinDimensions] = min_dimensions->ToJson(); + } + root[kMaxDimensions] = max_dimensions.ToJson(); + root[kMinBitRate] = min_bit_rate; + root[kMaxBitRate] = max_bit_rate; + root[kMaxDelay] = Json::Value::Int64(max_delay.count()); return root; } -ErrorOr<Json::Value> Constraints::ToJson() const { - auto audio_or_error = audio.ToJson(); - if (audio_or_error.is_error()) { - return audio_or_error.error(); +// static +bool Constraints::ParseAndValidate(const Json::Value& root, Constraints* out) { + if (!AudioConstraints::ParseAndValidate(root[kAudio], &(out->audio)) || + !VideoConstraints::ParseAndValidate(root[kVideo], &(out->video))) { + return false; } + return out->IsValid(); +} - auto video_or_error = video.ToJson(); - if (video_or_error.is_error()) { - return video_or_error.error(); - } +bool Constraints::IsValid() const { + return audio.IsValid() && video.IsValid(); +} +Json::Value Constraints::ToJson() const { + OSP_DCHECK(IsValid()); Json::Value root; - root["audio"] = audio_or_error.value(); - root["video"] = video_or_error.value(); + root[kAudio] = audio.ToJson(); + root[kVideo] = video.ToJson(); return root; } -ErrorOr<Json::Value> DisplayDescription::ToJson() const { - if (aspect_ratio.width < 1 || aspect_ratio.height < 1) { - return CreateParameterError("DisplayDescription"); +// static +bool DisplayDescription::ParseAndValidate(const Json::Value& root, + DisplayDescription* out) { + if (!ParseOptional<Dimensions>(root[kDimensions], &(out->dimensions)) || + !ParseOptional<AspectRatio>(root[kAspectRatio], &(out->aspect_ratio))) { + return false; + } + + AspectRatioConstraint constraint; + if (AspectRatioConstraintParseAndValidate(root[kScaling], &constraint)) { + out->aspect_ratio_constraint = + absl::optional<AspectRatioConstraint>(std::move(constraint)); + } else { + out->aspect_ratio_constraint = absl::nullopt; } - auto dimensions_or_error = dimensions.ToJson(); - if (dimensions_or_error.is_error()) { - return dimensions_or_error.error(); + return out->IsValid(); +} + +bool DisplayDescription::IsValid() const { + // At least one of the properties must be set, and if a property is set + // it must be valid. + if (aspect_ratio.has_value() && !aspect_ratio->IsValid()) { + return false; + } + if (dimensions.has_value() && !dimensions->IsValid()) { + return false; } + // Sender behavior is undefined if the aspect ratio is fixed but no + // dimensions or aspect ratio are provided. + if (aspect_ratio_constraint.has_value() && + (aspect_ratio_constraint.value() == AspectRatioConstraint::kFixed) && + !dimensions.has_value() && !aspect_ratio.has_value()) { + return false; + } + return aspect_ratio.has_value() || dimensions.has_value() || + aspect_ratio_constraint.has_value(); +} +Json::Value DisplayDescription::ToJson() const { + OSP_DCHECK(IsValid()); Json::Value root; - root["dimensions"] = dimensions_or_error.value(); - root["aspectRatio"] = - absl::StrCat(aspect_ratio.width, ":", aspect_ratio.height); - root["scaling"] = AspectRatioConstraintToJson(aspect_ratio_constraint); + if (aspect_ratio.has_value()) { + root[kAspectRatio] = absl::StrCat( + aspect_ratio->width, kAspectRatioDelimiter, aspect_ratio->height); + } + if (dimensions.has_value()) { + root[kDimensions] = dimensions->ToJson(); + } + if (aspect_ratio_constraint.has_value()) { + root[kScaling] = + AspectRatioConstraintToJson(aspect_ratio_constraint.value()); + } return root; } -ErrorOr<Json::Value> Answer::ToJson() const { - if (udp_port <= 0 || udp_port > 65535) { - return CreateParameterError("Answer - UDP Port number"); +bool Answer::ParseAndValidate(const Json::Value& root, Answer* out) { + if (!json::ParseAndValidateInt(root[kUdpPort], &(out->udp_port)) || + !json::ParseAndValidateIntArray(root[kSendIndexes], + &(out->send_indexes)) || + !json::ParseAndValidateUintArray(root[kSsrcs], &(out->ssrcs)) || + !ParseOptional<Constraints>(root[kConstraints], &(out->constraints)) || + !ParseOptional<DisplayDescription>(root[kDisplay], &(out->display))) { + return false; + } + if (!json::ParseBool(root[kReceiverGetStatus], + &(out->supports_wifi_status_reporting))) { + out->supports_wifi_status_reporting = false; } - Json::Value root; - if (constraints) { - auto constraints_or_error = constraints.value().ToJson(); - if (constraints_or_error.is_error()) { - return constraints_or_error.error(); - } - root["constraints"] = constraints_or_error.value(); + // These function set to empty array if not present, so we can ignore + // the return value for optional values. + json::ParseAndValidateIntArray(root[kReceiverRtcpEventLog], + &(out->receiver_rtcp_event_log)); + json::ParseAndValidateIntArray(root[kReceiverRtcpDscp], + &(out->receiver_rtcp_dscp)); + json::ParseAndValidateStringArray(root[kRtpExtensions], + &(out->rtp_extensions)); + + return out->IsValid(); +} + +bool Answer::IsValid() const { + if (ssrcs.empty() || send_indexes.empty()) { + return false; } - if (display) { - auto display_or_error = display.value().ToJson(); - if (display_or_error.is_error()) { - return display_or_error.error(); + // We don't know what the indexes used in the offer were here, so we just + // sanity check. + for (const int index : send_indexes) { + if (index < 0) { + return false; } - root["display"] = display_or_error.value(); } + if (constraints.has_value() && !constraints->IsValid()) { + return false; + } + if (display.has_value() && !display->IsValid()) { + return false; + } + return kUdpPortMin <= udp_port && udp_port <= kUdpPortMax; +} - root["castMode"] = cast_mode.ToString(); - root["udpPort"] = udp_port; - root["receiverGetStatus"] = supports_wifi_status_reporting; - root["sendIndexes"] = PrimitiveVectorToJson(send_indexes); - root["ssrcs"] = PrimitiveVectorToJson(ssrcs); +Json::Value Answer::ToJson() const { + OSP_DCHECK(IsValid()); + Json::Value root; + if (constraints.has_value()) { + root[kConstraints] = constraints->ToJson(); + } + if (display.has_value()) { + root[kDisplay] = display->ToJson(); + } + root[kUdpPort] = udp_port; + root[kReceiverGetStatus] = supports_wifi_status_reporting; + root[kSendIndexes] = PrimitiveVectorToJson(send_indexes); + root[kSsrcs] = PrimitiveVectorToJson(ssrcs); + // Some sender do not handle empty array properly, so we omit these fields + // if they are empty. if (!receiver_rtcp_event_log.empty()) { - root["receiverRtcpEventLog"] = + root[kReceiverRtcpEventLog] = PrimitiveVectorToJson(receiver_rtcp_event_log); } if (!receiver_rtcp_dscp.empty()) { - root["receiverRtcpDscp"] = PrimitiveVectorToJson(receiver_rtcp_dscp); + root[kReceiverRtcpDscp] = PrimitiveVectorToJson(receiver_rtcp_dscp); } if (!rtp_extensions.empty()) { - root["rtpExtensions"] = PrimitiveVectorToJson(rtp_extensions); + root[kRtpExtensions] = PrimitiveVectorToJson(rtp_extensions); } return root; } -Json::Value Answer::ToAnswerMessage() const { - auto json_or_error = ToJson(); - if (json_or_error.is_error()) { - return CreateInvalidAnswer(json_or_error.error()); - } - - Json::Value message_root; - message_root[kMessageKeyType] = kMessageTypeAnswer; - message_root[kAnswerMessageBody] = std::move(json_or_error.value()); - message_root[kResult] = kResultOk; - return message_root; -} - -Json::Value CreateInvalidAnswer(Error error) { - Json::Value message_root; - message_root[kMessageKeyType] = kMessageTypeAnswer; - message_root[kResult] = kResultError; - message_root[kErrorMessageBody][kErrorCode] = static_cast<int>(error.code()); - message_root[kErrorMessageBody][kErrorDescription] = error.message(); - - return message_root; -} - } // namespace cast } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/answer_messages.h b/chromium/third_party/openscreen/src/cast/streaming/answer_messages.h index 60b9a49479a..4298913b6dc 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/answer_messages.h +++ b/chromium/third_party/openscreen/src/cast/streaming/answer_messages.h @@ -6,7 +6,7 @@ #define CAST_STREAMING_ANSWER_MESSAGES_H_ #include <array> -#include <chrono> // NOLINT +#include <chrono> #include <cstdint> #include <initializer_list> #include <memory> @@ -14,7 +14,7 @@ #include <utility> #include <vector> -#include "cast/streaming/offer_messages.h" +#include "absl/types/optional.h" #include "cast/streaming/ssrc.h" #include "json/value.h" #include "platform/base/error.h" @@ -23,42 +23,61 @@ namespace openscreen { namespace cast { +// For each of the below classes, though a number of methods are shared, the use +// of a shared base class has intentionally been avoided. This is to improve +// readability of the structs provided in this file by cutting down on the +// amount of obscuring boilerplate code. For each of the following struct +// definitions, the following method definitions are shared: +// (1) ParseAndValidate. Shall return a boolean indicating whether the out +// parameter is in a valid state after checking bounds and restrictions. +// (2) ToJson. Should return a proper JSON object. Assumes that IsValid() +// has been called already, OSP_DCHECKs if not IsValid(). +// (3) IsValid. Used by both ParseAndValidate and ToJson to ensure that the +// object is in a good state. struct AudioConstraints { + static bool ParseAndValidate(const Json::Value& value, AudioConstraints* out); + Json::Value ToJson() const; + bool IsValid() const; + int max_sample_rate = 0; int max_channels = 0; // Technically optional, sender will assume 32kbps if omitted. int min_bit_rate = 0; int max_bit_rate = 0; std::chrono::milliseconds max_delay = {}; - - ErrorOr<Json::Value> ToJson() const; }; struct Dimensions { + static bool ParseAndValidate(const Json::Value& value, Dimensions* out); + Json::Value ToJson() const; + bool IsValid() const; + int width = 0; int height = 0; SimpleFraction frame_rate; - - ErrorOr<Json::Value> ToJson() const; }; struct VideoConstraints { + static bool ParseAndValidate(const Json::Value& value, VideoConstraints* out); + Json::Value ToJson() const; + bool IsValid() const; + double max_pixels_per_second = {}; - Dimensions min_dimensions = {}; + absl::optional<Dimensions> min_dimensions = {}; Dimensions max_dimensions = {}; // Technically optional, sender will assume 300kbps if omitted. int min_bit_rate = 0; int max_bit_rate = 0; std::chrono::milliseconds max_delay = {}; - - ErrorOr<Json::Value> ToJson() const; }; struct Constraints { + static bool ParseAndValidate(const Json::Value& value, Constraints* out); + Json::Value ToJson() const; + bool IsValid() const; + AudioConstraints audio; VideoConstraints video; - - ErrorOr<Json::Value> ToJson() const; }; // Decides whether the Sender scales and letterboxes content to 16:9, or if @@ -67,22 +86,35 @@ struct Constraints { enum class AspectRatioConstraint : uint8_t { kVariable = 0, kFixed }; struct AspectRatio { + static bool ParseAndValidate(const Json::Value& value, AspectRatio* out); + bool IsValid() const; + + bool operator==(const AspectRatio& other) const { + return width == other.width && height == other.height; + } + int width = 0; int height = 0; }; struct DisplayDescription { + static bool ParseAndValidate(const Json::Value& value, + DisplayDescription* out); + Json::Value ToJson() const; + bool IsValid() const; + // May exceed, be the same, or less than those mentioned in the // video constraints. - Dimensions dimensions; - AspectRatio aspect_ratio = {}; - AspectRatioConstraint aspect_ratio_constraint = {}; - - ErrorOr<Json::Value> ToJson() const; + absl::optional<Dimensions> dimensions; + absl::optional<AspectRatio> aspect_ratio = {}; + absl::optional<AspectRatioConstraint> aspect_ratio_constraint = {}; }; struct Answer { - CastMode cast_mode = {}; + static bool ParseAndValidate(const Json::Value& value, Answer* out); + Json::Value ToJson() const; + bool IsValid() const; + int udp_port = 0; std::vector<int> send_indexes; std::vector<Ssrc> ssrcs; @@ -97,22 +129,8 @@ struct Answer { // RTP extensions should be empty, but not null. std::vector<std::string> rtp_extensions = {}; - - // ToJson performs a standard serialization, returning an error if this - // instance failed to serialize properly. - ErrorOr<Json::Value> ToJson() const; - - // In constrast to ToJson, ToAnswerMessage performs a successful serialization - // even if the answer object is malformed, by complying to the spec's - // error answer message format in this case. - Json::Value ToAnswerMessage() const; }; -// Helper method that creates an invalid Answer response. Exposed publicly -// here as it is called in ToAnswerMessage(), but can also be called by -// the receiver session. -Json::Value CreateInvalidAnswer(Error error); - } // namespace cast } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/answer_messages_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/answer_messages_unittest.cc index d1c708281d9..e4ec82f4a4e 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/answer_messages_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/answer_messages_unittest.cc @@ -4,10 +4,12 @@ #include "cast/streaming/answer_messages.h" +#include <chrono> #include <utility> #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "util/chrono_helpers.h" #include "util/json/json_serialization.h" namespace openscreen { @@ -15,59 +17,183 @@ namespace cast { namespace { +using ::testing::ElementsAre; + +// NOTE: the castMode property has been removed from the specification. We leave +// it here in the valid offer to ensure that its inclusion does not break +// parsing. +constexpr char kValidAnswerJson[] = R"({ + "castMode": "mirroring", + "udpPort": 1234, + "sendIndexes": [1, 3], + "ssrcs": [1233324, 2234222], + "constraints": { + "audio": { + "maxSampleRate": 96000, + "maxChannels": 5, + "minBitRate": 32000, + "maxBitRate": 320000, + "maxDelay": 5000 + }, + "video": { + "maxPixelsPerSecond": 62208000, + "minDimensions": { + "width": 320, + "height": 180, + "frameRate": 0 + }, + "maxDimensions": { + "width": 1920, + "height": 1080, + "frameRate": "60" + }, + "minBitRate": 300000, + "maxBitRate": 10000000, + "maxDelay": 5000 + } + }, + "display": { + "dimensions": { + "width": 1920, + "height": 1080, + "frameRate": "60000/1001" + }, + "aspectRatio": "64:27", + "scaling": "sender" + }, + "receiverRtcpEventLog": [0, 1], + "receiverRtcpDscp": [234, 567], + "receiverGetStatus": true, + "rtpExtensions": ["adaptive_playout_delay"] +})"; + const Answer kValidAnswer{ - CastMode{CastMode::Type::kMirroring}, 1234, // udp_port std::vector<int>{1, 2, 3}, // send_indexes std::vector<Ssrc>{123, 456}, // ssrcs - Constraints{ + absl::optional<Constraints>(Constraints{ AudioConstraints{ - 96000, // max_sample_rate - 7, // max_channels - 32000, // min_bit_rate - 96000, // max_bit_rate - std::chrono::milliseconds(2000) // max_delay - }, // audio + 96000, // max_sample_rate + 7, // max_channels + 32000, // min_bit_rate + 96000, // max_bit_rate + milliseconds(2000) // max_delay + }, // audio VideoConstraints{ 40000.0, // max_pixels_per_second - Dimensions{ + absl::optional<Dimensions>(Dimensions{ 320, // width 480, // height SimpleFraction{15000, 101} // frame_rate - }, // min_dimensions + }), // min_dimensions Dimensions{ 1920, // width 1080, // height SimpleFraction{288, 2} // frame_rate }, - 300000, // min_bit_rate - 144000000, // max_bit_rate - std::chrono::milliseconds(3000) // max_delay - } // video - }, // constraints - DisplayDescription{ - Dimensions{ + 300000, // min_bit_rate + 144000000, // max_bit_rate + milliseconds(3000) // max_delay + } // video + }), // constraints + absl::optional<DisplayDescription>(DisplayDescription{ + absl::optional<Dimensions>(Dimensions{ 640, // width 480, // height SimpleFraction{30, 1} // frame_rate - }, - AspectRatio{16, 9}, // aspect_ratio - AspectRatioConstraint::kFixed, // scaling - }, + }), + absl::optional<AspectRatio>(AspectRatio{16, 9}), // aspect_ratio + absl::optional<AspectRatioConstraint>( + AspectRatioConstraint::kFixed), // scaling + }), std::vector<int>{7, 8, 9}, // receiver_rtcp_event_log std::vector<int>{11, 12, 13}, // receiver_rtcp_dscp true, // receiver_get_status std::vector<std::string>{"foo", "bar"} // rtp_extensions }; +constexpr int kValidMaxPixelsPerSecond = 1920 * 1080 * 30; +constexpr Dimensions kValidDimensions{1920, 1080, SimpleFraction{60, 1}}; +static const VideoConstraints kValidVideoConstraints{ + kValidMaxPixelsPerSecond, absl::optional<Dimensions>(kValidDimensions), + kValidDimensions, 300 * 1000, + 300 * 1000 * 1000, milliseconds(3000)}; + +constexpr AudioConstraints kValidAudioConstraints{123, 456, 300, 9920, + milliseconds(123)}; + +void ExpectEqualsValidAnswerJson(const Answer& answer) { + EXPECT_EQ(1234, answer.udp_port); + + EXPECT_THAT(answer.send_indexes, ElementsAre(1, 3)); + EXPECT_THAT(answer.ssrcs, ElementsAre(1233324u, 2234222u)); + ASSERT_TRUE(answer.constraints.has_value()); + const AudioConstraints& audio = answer.constraints->audio; + EXPECT_EQ(96000, audio.max_sample_rate); + EXPECT_EQ(5, audio.max_channels); + EXPECT_EQ(32000, audio.min_bit_rate); + EXPECT_EQ(320000, audio.max_bit_rate); + EXPECT_EQ(milliseconds{5000}, audio.max_delay); + + const VideoConstraints& video = answer.constraints->video; + EXPECT_EQ(62208000, video.max_pixels_per_second); + ASSERT_TRUE(video.min_dimensions.has_value()); + EXPECT_EQ(320, video.min_dimensions->width); + EXPECT_EQ(180, video.min_dimensions->height); + EXPECT_EQ((SimpleFraction{0, 1}), video.min_dimensions->frame_rate); + EXPECT_EQ(1920, video.max_dimensions.width); + EXPECT_EQ(1080, video.max_dimensions.height); + EXPECT_EQ((SimpleFraction{60, 1}), video.max_dimensions.frame_rate); + EXPECT_EQ(300000, video.min_bit_rate); + EXPECT_EQ(10000000, video.max_bit_rate); + EXPECT_EQ(milliseconds{5000}, video.max_delay); + + ASSERT_TRUE(answer.display.has_value()); + const DisplayDescription& display = answer.display.value(); + ASSERT_TRUE(display.dimensions.has_value()); + EXPECT_EQ(1920, display.dimensions->width); + EXPECT_EQ(1080, display.dimensions->height); + EXPECT_EQ((SimpleFraction{60000, 1001}), display.dimensions->frame_rate); + EXPECT_EQ((AspectRatio{64, 27}), display.aspect_ratio.value()); + EXPECT_EQ(AspectRatioConstraint::kFixed, + display.aspect_ratio_constraint.value()); + + EXPECT_THAT(answer.receiver_rtcp_event_log, ElementsAre(0, 1)); + EXPECT_THAT(answer.receiver_rtcp_dscp, ElementsAre(234, 567)); + EXPECT_TRUE(answer.supports_wifi_status_reporting); + EXPECT_THAT(answer.rtp_extensions, ElementsAre("adaptive_playout_delay")); +} + +void ExpectFailureOnParse(absl::string_view raw_json) { + ErrorOr<Json::Value> root = json::Parse(raw_json); + // Must be a valid JSON object, but not a valid answer. + ASSERT_TRUE(root.is_value()); + + Answer answer; + EXPECT_FALSE(Answer::ParseAndValidate(std::move(root.value()), &answer)); + EXPECT_FALSE(answer.IsValid()); +} + +// Functions that use ASSERT_* must return void, so we use an out parameter +// here instead of returning. +void ExpectSuccessOnParse(absl::string_view raw_json, Answer* out = nullptr) { + ErrorOr<Json::Value> root = json::Parse(raw_json); + // Must be a valid JSON object, but not a valid answer. + ASSERT_TRUE(root.is_value()); + + Answer answer; + ASSERT_TRUE(Answer::ParseAndValidate(std::move(root.value()), &answer)); + EXPECT_TRUE(answer.IsValid()); + if (out) { + *out = std::move(answer); + } +} + } // anonymous namespace TEST(AnswerMessagesTest, ProperlyPopulatedAnswerSerializesProperly) { - auto value_or_error = kValidAnswer.ToJson(); - EXPECT_TRUE(value_or_error.is_value()); - - Json::Value root = std::move(value_or_error.value()); - EXPECT_EQ(root["castMode"], "mirroring"); + ASSERT_TRUE(kValidAnswer.IsValid()); + Json::Value root = kValidAnswer.ToJson(); EXPECT_EQ(root["udpPort"], 1234); Json::Value sendIndexes = std::move(root["sendIndexes"]); @@ -140,41 +266,393 @@ TEST(AnswerMessagesTest, ProperlyPopulatedAnswerSerializesProperly) { EXPECT_EQ(rtp_extensions[1], "bar"); } -TEST(AnswerMessagesTest, InvalidDimensionsCauseError) { +TEST(AnswerMessagesTest, EmptyArraysOmitted) { + Answer missing_event_log = kValidAnswer; + missing_event_log.receiver_rtcp_event_log.clear(); + ASSERT_TRUE(missing_event_log.IsValid()); + Json::Value root = missing_event_log.ToJson(); + EXPECT_FALSE(root["receiverRtcpEventLog"]); + + Answer missing_rtcp_dscp = kValidAnswer; + missing_rtcp_dscp.receiver_rtcp_dscp.clear(); + ASSERT_TRUE(missing_rtcp_dscp.IsValid()); + root = missing_rtcp_dscp.ToJson(); + EXPECT_FALSE(root["receiverRtcpDscp"]); + + Answer missing_extensions = kValidAnswer; + missing_extensions.rtp_extensions.clear(); + ASSERT_TRUE(missing_extensions.IsValid()); + root = missing_extensions.ToJson(); + EXPECT_FALSE(root["rtpExtensions"]); +} + +TEST(AnswerMessagesTest, InvalidDimensionsCauseInvalid) { Answer invalid_dimensions = kValidAnswer; - invalid_dimensions.display.value().dimensions.width = -1; - auto value_or_error = invalid_dimensions.ToJson(); - EXPECT_TRUE(value_or_error.is_error()); + invalid_dimensions.display->dimensions->width = -1; + EXPECT_FALSE(invalid_dimensions.IsValid()); } TEST(AnswerMessagesTest, InvalidAudioConstraintsCauseError) { Answer invalid_audio = kValidAnswer; - invalid_audio.constraints.value().audio.max_bit_rate = - invalid_audio.constraints.value().audio.min_bit_rate - 1; - auto value_or_error = invalid_audio.ToJson(); - EXPECT_TRUE(value_or_error.is_error()); + invalid_audio.constraints->audio.max_bit_rate = + invalid_audio.constraints->audio.min_bit_rate - 1; + EXPECT_FALSE(invalid_audio.IsValid()); } TEST(AnswerMessagesTest, InvalidVideoConstraintsCauseError) { Answer invalid_video = kValidAnswer; - invalid_video.constraints.value().video.max_pixels_per_second = -1.0; - auto value_or_error = invalid_video.ToJson(); - EXPECT_TRUE(value_or_error.is_error()); + invalid_video.constraints->video.max_pixels_per_second = -1.0; + EXPECT_FALSE(invalid_video.IsValid()); } TEST(AnswerMessagesTest, InvalidDisplayDescriptionsCauseError) { Answer invalid_display = kValidAnswer; - invalid_display.display.value().aspect_ratio = {0, 0}; - auto value_or_error = invalid_display.ToJson(); - EXPECT_TRUE(value_or_error.is_error()); + invalid_display.display->aspect_ratio = {0, 0}; + EXPECT_FALSE(invalid_display.IsValid()); } TEST(AnswerMessagesTest, InvalidUdpPortsCauseError) { Answer invalid_port = kValidAnswer; invalid_port.udp_port = 65536; - auto value_or_error = invalid_port.ToJson(); - EXPECT_TRUE(value_or_error.is_error()); + EXPECT_FALSE(invalid_port.IsValid()); +} + +TEST(AnswerMessagesTest, CanParseValidAnswerJson) { + Answer answer; + ExpectSuccessOnParse(kValidAnswerJson, &answer); + ExpectEqualsValidAnswerJson(answer); +} + +// In practice, the rtpExtensions, receiverRtcpDscp, and receiverRtcpEventLog +// fields may be missing from some receivers. We handle this case by treating +// them as empty. +TEST(AnswerMessagesTest, SucceedsWithMissingRtpFields) { + ExpectSuccessOnParse(R"({ + "udpPort": 1234, + "sendIndexes": [1, 3], + "ssrcs": [1233324, 2234222], + "receiverGetStatus": true + })"); +} + +TEST(AnswerMessagesTest, ErrorOnEmptyAnswer) { + ExpectFailureOnParse("{}"); +} + +TEST(AnswerMessagesTest, ErrorOnMissingUdpPort) { + ExpectFailureOnParse(R"({ + "sendIndexes": [1, 3], + "ssrcs": [1233324, 2234222], + "receiverGetStatus": true + })"); +} + +TEST(AnswerMessagesTest, ErrorOnMissingSsrcs) { + ExpectFailureOnParse(R"({ + "udpPort": 1234, + "sendIndexes": [1, 3], + "receiverGetStatus": true + })"); } +TEST(AnswerMessagesTest, ErrorOnMissingSendIndexes) { + ExpectFailureOnParse(R"({ + "udpPort": 1234, + "ssrcs": [1233324, 2234222], + "receiverGetStatus": true + })"); +} + +TEST(AnswerMessagesTest, AssumesNoReportingIfGetStatusFalse) { + Answer answer; + ExpectSuccessOnParse(R"({ + "udpPort": 1234, + "sendIndexes": [1, 3], + "ssrcs": [1233324, 2234222] + })", + &answer); + + EXPECT_FALSE(answer.supports_wifi_status_reporting); +} + +TEST(AnswerMessagesTest, AllowsReceiverSideScaling) { + Answer answer; + ExpectSuccessOnParse(R"({ + "udpPort": 1234, + "sendIndexes": [1, 3], + "ssrcs": [1233324, 2234222], + "display": { + "dimensions": { + "width": 1920, + "height": 1080, + "frameRate": "60000/1001" + }, + "aspectRatio": "64:27", + "scaling": "receiver" + } + })", + &answer); + ASSERT_TRUE(answer.display.has_value()); + EXPECT_EQ(answer.display->aspect_ratio_constraint.value(), + AspectRatioConstraint::kVariable); +} + +TEST(AnswerMessagesTest, AssumesMinBitRateIfOmitted) { + Answer answer; + ExpectSuccessOnParse(R"({ + "udpPort": 1234, + "sendIndexes": [1, 3], + "ssrcs": [1233324, 2234222], + "constraints": { + "audio": { + "maxSampleRate": 96000, + "maxChannels": 5, + "maxBitRate": 320000, + "maxDelay": 5000 + }, + "video": { + "maxPixelsPerSecond": 62208000, + "maxDimensions": { + "width": 1920, + "height": 1080, + "frameRate": "60" + }, + "maxBitRate": 10000000, + "maxDelay": 5000 + } + }, + "receiverGetStatus": true + })", + &answer); + + EXPECT_EQ(32000, answer.constraints->audio.min_bit_rate); + EXPECT_EQ(300000, answer.constraints->video.min_bit_rate); +} + +// Instead of testing all possible json parsing options for validity, we +// can instead directly test the IsValid() methods. +TEST(AnswerMessagesTest, AudioConstraintsIsValid) { + constexpr AudioConstraints kInvalidSampleRate{0, 456, 300, 9920, + milliseconds(123)}; + constexpr AudioConstraints kInvalidMaxChannels{123, 0, 300, 9920, + milliseconds(123)}; + constexpr AudioConstraints kInvalidMinBitRate{123, 456, 0, 9920, + milliseconds(123)}; + constexpr AudioConstraints kInvalidMaxBitRate{123, 456, 300, 0, + milliseconds(123)}; + constexpr AudioConstraints kInvalidMaxDelay{123, 456, 300, 0, + milliseconds(0)}; + + EXPECT_TRUE(kValidAudioConstraints.IsValid()); + EXPECT_FALSE(kInvalidSampleRate.IsValid()); + EXPECT_FALSE(kInvalidMaxChannels.IsValid()); + EXPECT_FALSE(kInvalidMinBitRate.IsValid()); + EXPECT_FALSE(kInvalidMaxBitRate.IsValid()); + EXPECT_FALSE(kInvalidMaxDelay.IsValid()); +} + +TEST(AnswerMessagesTest, DimensionsIsValid) { + // NOTE: in some cases (such as min dimensions) a frame rate of zero is valid. + constexpr Dimensions kValidZeroFrameRate{1920, 1080, SimpleFraction{0, 60}}; + constexpr Dimensions kInvalidWidth{0, 1080, SimpleFraction{60, 1}}; + constexpr Dimensions kInvalidHeight{1920, 0, SimpleFraction{60, 1}}; + constexpr Dimensions kInvalidFrameRateZeroDenominator{1920, 1080, + SimpleFraction{60, 0}}; + constexpr Dimensions kInvalidFrameRateNegativeNumerator{ + 1920, 1080, SimpleFraction{-1, 30}}; + constexpr Dimensions kInvalidFrameRateNegativeDenominator{ + 1920, 1080, SimpleFraction{30, -1}}; + + EXPECT_TRUE(kValidDimensions.IsValid()); + EXPECT_TRUE(kValidZeroFrameRate.IsValid()); + EXPECT_FALSE(kInvalidWidth.IsValid()); + EXPECT_FALSE(kInvalidHeight.IsValid()); + EXPECT_FALSE(kInvalidFrameRateZeroDenominator.IsValid()); + EXPECT_FALSE(kInvalidFrameRateNegativeNumerator.IsValid()); + EXPECT_FALSE(kInvalidFrameRateNegativeDenominator.IsValid()); +} + +TEST(AnswerMessagesTest, VideoConstraintsIsValid) { + VideoConstraints invalid_max_pixels_per_second = kValidVideoConstraints; + invalid_max_pixels_per_second.max_pixels_per_second = 0; + + VideoConstraints invalid_min_dimensions = kValidVideoConstraints; + invalid_min_dimensions.min_dimensions->width = 0; + + VideoConstraints invalid_max_dimensions = kValidVideoConstraints; + invalid_max_dimensions.max_dimensions.height = 0; + + VideoConstraints invalid_min_bit_rate = kValidVideoConstraints; + invalid_min_bit_rate.min_bit_rate = 0; + + VideoConstraints invalid_max_bit_rate = kValidVideoConstraints; + invalid_max_bit_rate.max_bit_rate = invalid_max_bit_rate.min_bit_rate - 1; + + VideoConstraints invalid_max_delay = kValidVideoConstraints; + invalid_max_delay.max_delay = milliseconds(0); + + EXPECT_TRUE(kValidVideoConstraints.IsValid()); + EXPECT_FALSE(invalid_max_pixels_per_second.IsValid()); + EXPECT_FALSE(invalid_min_dimensions.IsValid()); + EXPECT_FALSE(invalid_max_dimensions.IsValid()); + EXPECT_FALSE(invalid_min_bit_rate.IsValid()); + EXPECT_FALSE(invalid_max_bit_rate.IsValid()); + EXPECT_FALSE(invalid_max_delay.IsValid()); +} + +TEST(AnswerMessagesTest, ConstraintsIsValid) { + VideoConstraints invalid_video_constraints = kValidVideoConstraints; + invalid_video_constraints.max_pixels_per_second = 0; + + AudioConstraints invalid_audio_constraints = kValidAudioConstraints; + invalid_audio_constraints.max_bit_rate = 0; + + const Constraints valid{kValidAudioConstraints, kValidVideoConstraints}; + const Constraints invalid_audio{kValidAudioConstraints, + invalid_video_constraints}; + const Constraints invalid_video{invalid_audio_constraints, + kValidVideoConstraints}; + + EXPECT_TRUE(valid.IsValid()); + EXPECT_FALSE(invalid_audio.IsValid()); + EXPECT_FALSE(invalid_video.IsValid()); +} + +TEST(AnswerMessagesTest, AspectRatioIsValid) { + constexpr AspectRatio kValid{16, 9}; + constexpr AspectRatio kInvalidWidth{0, 9}; + constexpr AspectRatio kInvalidHeight{16, 0}; + + EXPECT_TRUE(kValid.IsValid()); + EXPECT_FALSE(kInvalidWidth.IsValid()); + EXPECT_FALSE(kInvalidHeight.IsValid()); +} + +TEST(AnswerMessagesTest, AspectRatioParseAndValidate) { + const Json::Value kValid = "16:9"; + const Json::Value kWrongDelimiter = "16-9"; + const Json::Value kTooManyFields = "16:9:3"; + const Json::Value kTooFewFields = "1:"; + const Json::Value kNoDelimiter = "12345"; + const Json::Value kNegativeWidth = "-123:2345"; + const Json::Value kNegativeHeight = "22:-7"; + const Json::Value kNegativeBoth = "22:-7"; + const Json::Value kNonNumberWidth = "twenty2#:9"; + const Json::Value kNonNumberHeight = "2:thirty"; + const Json::Value kZeroWidth = "0:9"; + const Json::Value kZeroHeight = "16:0"; + + AspectRatio out; + EXPECT_TRUE(AspectRatio::ParseAndValidate(kValid, &out)); + EXPECT_EQ(out.width, 16); + EXPECT_EQ(out.height, 9); + EXPECT_FALSE(AspectRatio::ParseAndValidate(kWrongDelimiter, &out)); + EXPECT_FALSE(AspectRatio::ParseAndValidate(kTooManyFields, &out)); + EXPECT_FALSE(AspectRatio::ParseAndValidate(kTooFewFields, &out)); + EXPECT_FALSE(AspectRatio::ParseAndValidate(kWrongDelimiter, &out)); + EXPECT_FALSE(AspectRatio::ParseAndValidate(kNoDelimiter, &out)); + EXPECT_FALSE(AspectRatio::ParseAndValidate(kNegativeWidth, &out)); + EXPECT_FALSE(AspectRatio::ParseAndValidate(kNegativeHeight, &out)); + EXPECT_FALSE(AspectRatio::ParseAndValidate(kNegativeBoth, &out)); + EXPECT_FALSE(AspectRatio::ParseAndValidate(kNonNumberWidth, &out)); + EXPECT_FALSE(AspectRatio::ParseAndValidate(kNonNumberHeight, &out)); + EXPECT_FALSE(AspectRatio::ParseAndValidate(kZeroWidth, &out)); + EXPECT_FALSE(AspectRatio::ParseAndValidate(kZeroHeight, &out)); +} + +TEST(AnswerMessagesTest, DisplayDescriptionParseAndValidate) { + Json::Value valid_scaling; + valid_scaling["scaling"] = "receiver"; + Json::Value invalid_scaling; + invalid_scaling["scaling"] = "embedder"; + Json::Value invalid_scaling_valid_ratio; + invalid_scaling_valid_ratio["scaling"] = "embedder"; + invalid_scaling_valid_ratio["aspectRatio"] = "16:9"; + + Json::Value dimensions; + dimensions["width"] = 1920; + dimensions["height"] = 1080; + dimensions["frameRate"] = "30"; + Json::Value valid_dimensions; + valid_dimensions["dimensions"] = dimensions; + + Json::Value dimensions_invalid = dimensions; + dimensions_invalid["frameRate"] = "infinity"; + Json::Value invalid_dimensions; + invalid_dimensions["dimensions"] = dimensions_invalid; + + Json::Value aspect_ratio_and_constraint; + aspect_ratio_and_constraint["scaling"] = "sender"; + aspect_ratio_and_constraint["aspectRatio"] = "4:3"; + + DisplayDescription out; + ASSERT_TRUE(DisplayDescription::ParseAndValidate(valid_scaling, &out)); + ASSERT_TRUE(out.aspect_ratio_constraint.has_value()); + EXPECT_EQ(out.aspect_ratio_constraint.value(), + AspectRatioConstraint::kVariable); + + EXPECT_FALSE(DisplayDescription::ParseAndValidate(invalid_scaling, &out)); + EXPECT_TRUE( + DisplayDescription::ParseAndValidate(invalid_scaling_valid_ratio, &out)); + + ASSERT_TRUE(DisplayDescription::ParseAndValidate(valid_dimensions, &out)); + ASSERT_TRUE(out.dimensions.has_value()); + EXPECT_EQ(1920, out.dimensions->width); + EXPECT_EQ(1080, out.dimensions->height); + EXPECT_EQ((SimpleFraction{30, 1}), out.dimensions->frame_rate); + + EXPECT_FALSE(DisplayDescription::ParseAndValidate(invalid_dimensions, &out)); + + ASSERT_TRUE( + DisplayDescription::ParseAndValidate(aspect_ratio_and_constraint, &out)); + EXPECT_EQ(AspectRatioConstraint::kFixed, out.aspect_ratio_constraint.value()); +} + +TEST(AnswerMessagesTest, DisplayDescriptionIsValid) { + const DisplayDescription kInvalidEmptyDescription{ + absl::optional<Dimensions>{}, absl::optional<AspectRatio>{}, + absl::optional<AspectRatioConstraint>{}}; + + DisplayDescription has_valid_dimensions = kInvalidEmptyDescription; + has_valid_dimensions.dimensions = + absl::optional<Dimensions>(kValidDimensions); + + DisplayDescription has_invalid_dimensions = kInvalidEmptyDescription; + has_invalid_dimensions.dimensions = + absl::optional<Dimensions>(kValidDimensions); + has_invalid_dimensions.dimensions->width = 0; + + DisplayDescription has_aspect_ratio = kInvalidEmptyDescription; + has_aspect_ratio.aspect_ratio = + absl::optional<AspectRatio>{AspectRatio{16, 9}}; + + DisplayDescription has_invalid_aspect_ratio = kInvalidEmptyDescription; + has_invalid_aspect_ratio.aspect_ratio = + absl::optional<AspectRatio>{AspectRatio{0, 20}}; + + DisplayDescription has_aspect_ratio_constraint = kInvalidEmptyDescription; + has_aspect_ratio_constraint.aspect_ratio_constraint = + absl::optional<AspectRatioConstraint>(AspectRatioConstraint::kFixed); + + DisplayDescription has_constraint_and_dimensions = + has_aspect_ratio_constraint; + has_constraint_and_dimensions.dimensions = + absl::optional<Dimensions>(kValidDimensions); + + DisplayDescription has_constraint_and_ratio = has_aspect_ratio_constraint; + has_constraint_and_ratio.aspect_ratio = AspectRatio{4, 3}; + + EXPECT_FALSE(kInvalidEmptyDescription.IsValid()); + EXPECT_TRUE(has_valid_dimensions.IsValid()); + EXPECT_FALSE(has_invalid_dimensions.IsValid()); + EXPECT_TRUE(has_aspect_ratio.IsValid()); + EXPECT_FALSE(has_invalid_aspect_ratio.IsValid()); + EXPECT_FALSE(has_aspect_ratio_constraint.IsValid()); + EXPECT_TRUE(has_constraint_and_dimensions.IsValid()); +} + +// Instead of being tested here, Answer's IsValid is checked in all other +// relevant tests. + } // namespace cast } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.cc b/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.cc index 5c42aa4a1b7..e6bc594f6a2 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator.cc @@ -24,8 +24,7 @@ int ToClampedBitsPerSecond(int32_t bytes, Clock::duration time_window) { // Divide |bytes| by |time_window| and scale the units to bits per second. constexpr int64_t kBitsPerByte = 8; constexpr int64_t kClockTicksPerSecond = - std::chrono::duration_cast<Clock::duration>(std::chrono::seconds(1)) - .count(); + Clock::to_duration(std::chrono::seconds(1)).count(); const int64_t bits = bytes * kBitsPerByte; const int64_t bits_per_second = (bits * kClockTicksPerSecond) / time_window.count(); diff --git a/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator_unittest.cc index 7850573f0f3..6a2e1fc91de 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/bandwidth_estimator_unittest.cc @@ -4,21 +4,19 @@ #include "cast/streaming/bandwidth_estimator.h" +#include <chrono> #include <limits> #include <random> #include "gmock/gmock.h" #include "gtest/gtest.h" #include "platform/api/time.h" +#include "util/chrono_helpers.h" namespace openscreen { namespace cast { namespace { -using std::chrono::duration_cast; -using std::chrono::milliseconds; -using std::chrono::seconds; - using openscreen::operator<<; // For std::chrono::duration gtest pretty-print. // BandwidthEstimator configuration common to all tests. diff --git a/chromium/third_party/openscreen/src/cast/streaming/capture_recommendations.cc b/chromium/third_party/openscreen/src/cast/streaming/capture_recommendations.cc new file mode 100644 index 00000000000..1a240e0493a --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/capture_recommendations.cc @@ -0,0 +1,177 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/capture_recommendations.h" + +#include <algorithm> +#include <utility> + +#include "cast/streaming/answer_messages.h" +#include "util/osp_logging.h" + +namespace openscreen { +namespace cast { +namespace capture_recommendations { +namespace { + +bool DoubleEquals(double a, double b) { + // Choice of epsilon for double comparison allows for proper comparison + // for both aspect ratios and frame rates. For frame rates, it is based on the + // broadcast rate of 29.97fps, which is actually 29.976. For aspect ratios, it + // allows for a one-pixel difference at a 4K resolution, we want it to be + // relatively high to avoid false negative comparison results. + const double kEpsilon = .0001; + return std::abs(a - b) < kEpsilon; +} + +void ApplyDisplay(const DisplayDescription& description, + Recommendations* recommendations) { + if (description.aspect_ratio_constraint && + description.aspect_ratio_constraint.value() == + AspectRatioConstraint::kFixed) { + recommendations->video.supports_scaling = false; + } + + // We should never exceed the display's resolution, since it will always + // force scaling. + if (description.dimensions) { + const double frame_rate = + static_cast<double>(description.dimensions->frame_rate); + recommendations->video.maximum = + Resolution{description.dimensions->width, + description.dimensions->height, frame_rate}; + recommendations->video.bit_rate_limits.maximum = + recommendations->video.maximum.effective_bit_rate(); + recommendations->video.minimum.set_minimum(recommendations->video.maximum); + } + + // If the receiver gives us an aspect ratio that doesn't match the display + // resolution they give us, the behavior is undefined from the spec. + // Here we prioritize the aspect ratio, and the receiver can scale the frame + // as they wish. + double aspect_ratio = 0.0; + if (description.aspect_ratio) { + aspect_ratio = static_cast<double>(description.aspect_ratio->width) / + description.aspect_ratio->height; +#if OSP_DCHECK_IS_ON() + if (description.dimensions) { + const double from_dims = + static_cast<double>(description.dimensions->width) / + description.dimensions->height; + if (!DoubleEquals(from_dims, aspect_ratio)) { + OSP_DLOG_WARN << "Received mismatched aspect ratio from the receiver."; + } + } +#endif + recommendations->video.maximum.width = + recommendations->video.maximum.height * aspect_ratio; + } else if (description.dimensions) { + aspect_ratio = static_cast<double>(description.dimensions->width) / + description.dimensions->height; + } else { + return; + } + recommendations->video.minimum.width = + recommendations->video.minimum.height * aspect_ratio; +} + +Resolution ToResolution(const Dimensions& dims) { + return {dims.width, dims.height, static_cast<double>(dims.frame_rate)}; +} + +void ApplyConstraints(const Constraints& constraints, + Recommendations* recommendations) { + // Audio has no fields in the display description, so we can safely + // ignore the current recommendations when setting values here. + recommendations->audio.max_delay = constraints.audio.max_delay; + recommendations->audio.max_channels = constraints.audio.max_channels; + recommendations->audio.max_sample_rate = constraints.audio.max_sample_rate; + + recommendations->audio.bit_rate_limits = BitRateLimits{ + std::max(constraints.audio.min_bit_rate, kDefaultAudioMinBitRate), + std::max(constraints.audio.max_bit_rate, kDefaultAudioMinBitRate)}; + + // With video, we take the intersection of values of the constraints and + // the display description. + recommendations->video.max_delay = constraints.video.max_delay; + recommendations->video.max_pixels_per_second = + constraints.video.max_pixels_per_second; + recommendations->video.bit_rate_limits = + BitRateLimits{std::max(constraints.video.min_bit_rate, + recommendations->video.bit_rate_limits.minimum), + std::min(constraints.video.max_bit_rate, + recommendations->video.bit_rate_limits.maximum)}; + Resolution max = ToResolution(constraints.video.max_dimensions); + if (max <= kDefaultMinResolution) { + recommendations->video.maximum = kDefaultMinResolution; + } else if (max < recommendations->video.maximum) { + recommendations->video.maximum = std::move(max); + } + // Implicit else: maximum = kDefaultMaxResolution. + + if (constraints.video.min_dimensions) { + Resolution min = ToResolution(constraints.video.min_dimensions.value()); + if (kDefaultMinResolution < min) { + recommendations->video.minimum = std::move(min); + } + } +} + +} // namespace + +bool BitRateLimits::operator==(const BitRateLimits& other) const { + return std::tie(minimum, maximum) == std::tie(other.minimum, other.maximum); +} + +bool Audio::operator==(const Audio& other) const { + return std::tie(bit_rate_limits, max_delay, max_channels, max_sample_rate) == + std::tie(other.bit_rate_limits, other.max_delay, other.max_channels, + other.max_sample_rate); +} + +bool Resolution::operator==(const Resolution& other) const { + return (std::tie(width, height) == std::tie(other.width, other.height)) && + DoubleEquals(frame_rate, other.frame_rate); +} + +bool Resolution::operator<(const Resolution& other) const { + return effective_bit_rate() < other.effective_bit_rate(); +} + +bool Resolution::operator<=(const Resolution& other) const { + return (*this == other) || (*this < other); +} + +void Resolution::set_minimum(const Resolution& other) { + if (other < *this) { + *this = other; + } +} + +bool Video::operator==(const Video& other) const { + return std::tie(bit_rate_limits, minimum, maximum, supports_scaling, + max_delay, max_pixels_per_second) == + std::tie(other.bit_rate_limits, other.minimum, other.maximum, + other.supports_scaling, other.max_delay, + other.max_pixels_per_second); +} + +bool Recommendations::operator==(const Recommendations& other) const { + return std::tie(audio, video) == std::tie(other.audio, other.video); +} + +Recommendations GetRecommendations(const Answer& answer) { + Recommendations recommendations; + if (answer.display.has_value() && answer.display->IsValid()) { + ApplyDisplay(answer.display.value(), &recommendations); + } + if (answer.constraints.has_value() && answer.constraints->IsValid()) { + ApplyConstraints(answer.constraints.value(), &recommendations); + } + return recommendations; +} + +} // namespace capture_recommendations +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/capture_recommendations.h b/chromium/third_party/openscreen/src/cast/streaming/capture_recommendations.h new file mode 100644 index 00000000000..67263c85ee5 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/capture_recommendations.h @@ -0,0 +1,166 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STREAMING_CAPTURE_RECOMMENDATIONS_H_ +#define CAST_STREAMING_CAPTURE_RECOMMENDATIONS_H_ + +#include <chrono> +#include <cmath> +#include <memory> +#include <tuple> + +namespace openscreen { +namespace cast { + +struct Answer; + +// This namespace contains classes and functions to be used by senders for +// determining what constraints are recommended for the capture device, based on +// the limits reported by the receiver. +// +// A general note about recommendations: they are NOT maximum operational +// limits, instead they are targeted to provide a delightful cast experience. +// For example, if a receiver is connected to a 1080P display but cannot provide +// 1080P at a stable FPS with a good experience, 1080P will not be recommended. +namespace capture_recommendations { + +// Default maximum delay for both audio and video. Used if the sender fails +// to provide any constraints. +constexpr std::chrono::milliseconds kDefaultMaxDelayMs(4000); + +// Bit rate limits, used for both audio and video streams. +struct BitRateLimits { + bool operator==(const BitRateLimits& other) const; + + // Minimum bit rate, in bits per second. + int minimum; + + // Maximum bit rate, in bits per second. + int maximum; +}; + +// The mirroring control protocol specifies 32kbps as the absolute minimum +// for audio. Depending on the type of audio content (narrowband, fullband, +// etc.) Opus specifically can perform very well at this bitrate. +// See: https://research.google/pubs/pub41650/ +constexpr int kDefaultAudioMinBitRate = 32 * 1000; + +// Opus generally sees little improvement above 192kbps, but some older codecs +// that we may consider supporting improve at up to 256kbps. +constexpr int kDefaultAudioMaxBitRate = 256 * 1000; +constexpr BitRateLimits kDefaultAudioBitRateLimits{kDefaultAudioMinBitRate, + kDefaultAudioMaxBitRate}; + +// Generally speaking, due to the range of human hearing (20Hz-20kHz) and the +// Nyquist sampling theorem, 44.1kHz captures should capture all the fidelity +// of the audio source. +constexpr int kDefaultAudioMaxSampleRate = 44100; + +// Default to stereo if channel count is not provided. +constexpr int kDefaultAudioMaxChannels = 2; + +// Audio capture recommendations. Maximum delay is determined by buffer +// constraints, and capture bit rate may vary between limits as appropriate. +struct Audio { + bool operator==(const Audio& other) const; + + // Represents the recommended bit rate range. + BitRateLimits bit_rate_limits = kDefaultAudioBitRateLimits; + + // Represents the maximum audio delay, in milliseconds. + std::chrono::milliseconds max_delay = kDefaultMaxDelayMs; + + // Represents the maximum number of audio channels. + int max_channels = kDefaultAudioMaxChannels; + + // Represents the maximum samples per second. + int max_sample_rate = kDefaultAudioMaxSampleRate; +}; + +struct Resolution { + bool operator==(const Resolution& other) const; + bool operator<(const Resolution& other) const; + bool operator<=(const Resolution& other) const; + void set_minimum(const Resolution& other); + + // The effective bit rate is the predicted average bit rate based on the + // properties of the Resolution instance, and is currently just the product. + constexpr int effective_bit_rate() const { + return static_cast<int>(static_cast<double>(width * height) * frame_rate); + } + + int width; + int height; + double frame_rate; +}; + +// The minimum dimensions are as close as possible to low-definition +// television, factoring in the receiver's aspect ratio if provided. +constexpr Resolution kDefaultMinResolution{320, 240, 30}; + +// Currently mirroring only supports 1080P. +constexpr Resolution kDefaultMaxResolution{1920, 1080, 30}; + +// The mirroring spec suggests 300kbps as the absolute minimum bitrate. +constexpr int kDefaultVideoMinBitRate = 300 * 1000; + +// The theoretical maximum pixels per second is the maximum bit rate +// divided by 8 (the max byte rate). In practice it should generally be +// less. +constexpr int kDefaultVideoMaxPixelsPerSecond = + kDefaultMaxResolution.effective_bit_rate() / 8; + +// Our default limits are merely the product of the minimum and maximum +// dimensions, and are only used if the receiver fails to give better +// constraint information. +constexpr BitRateLimits kDefaultVideoBitRateLimits{ + kDefaultVideoMinBitRate, kDefaultMaxResolution.effective_bit_rate()}; + +// Video capture recommendations. +struct Video { + bool operator==(const Video& other) const; + + // Represents the recommended bit rate range. + BitRateLimits bit_rate_limits = kDefaultVideoBitRateLimits; + + // Represents the recommended minimum resolution. + Resolution minimum = kDefaultMinResolution; + + // Represents the recommended maximum resolution. + Resolution maximum = kDefaultMaxResolution; + + // Indicates whether the receiver can scale frames from a different aspect + // ratio, or if it needs to be done by the sender. Default is true, as we + // may not know the aspect ratio that the receiver supports. + bool supports_scaling = true; + + // Represents the maximum video delay, in milliseconds. + std::chrono::milliseconds max_delay = kDefaultMaxDelayMs; + + // Represents the maximum pixels per second, not necessarily correlated + // to bit rate. + int max_pixels_per_second = kDefaultVideoMaxPixelsPerSecond; +}; + +// Outputted recommendations for usage by capture devices. Note that we always +// return both audio and video (it is up to the sender to determine what +// streams actually get created). If the receiver doesn't give us any +// information for making recommendations, the defaults are used. +struct Recommendations { + bool operator==(const Recommendations& other) const; + + // Audio specific recommendations. + Audio audio; + + // Video specific recommendations. + Video video; +}; + +Recommendations GetRecommendations(const Answer& answer); + +} // namespace capture_recommendations +} // namespace cast +} // namespace openscreen + +#endif // CAST_STREAMING_CAPTURE_RECOMMENDATIONS_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/capture_recommendations_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/capture_recommendations_unittest.cc new file mode 100644 index 00000000000..56e42e48895 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/streaming/capture_recommendations_unittest.cc @@ -0,0 +1,286 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/capture_recommendations.h" + +#include "absl/types/optional.h" +#include "cast/streaming/answer_messages.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "util/chrono_helpers.h" + +namespace openscreen { +namespace cast { +namespace capture_recommendations { +namespace { + +constexpr Recommendations kDefaultRecommendations{ + Audio{BitRateLimits{32000, 256000}, milliseconds(4000), 2, 44100}, + Video{BitRateLimits{300000, 1920 * 1080 * 30}, Resolution{320, 240, 30}, + Resolution{1920, 1080, 30}, true, milliseconds(4000), + 1920 * 1080 * 30 / 8}}; + +constexpr DisplayDescription kEmptyDescription{}; + +constexpr DisplayDescription kValidOnlyResolution{ + Dimensions{1024, 768, SimpleFraction{60, 1}}, absl::nullopt, absl::nullopt}; + +constexpr DisplayDescription kValidOnlyAspectRatio{ + absl::nullopt, AspectRatio{4, 3}, absl::nullopt}; + +constexpr DisplayDescription kValidOnlyAspectRatioSixteenNine{ + absl::nullopt, AspectRatio{16, 9}, absl::nullopt}; + +constexpr DisplayDescription kValidOnlyVariable{ + absl::nullopt, absl::nullopt, AspectRatioConstraint::kVariable}; + +constexpr DisplayDescription kInvalidOnlyFixed{absl::nullopt, absl::nullopt, + AspectRatioConstraint::kFixed}; + +constexpr DisplayDescription kValidFixedAspectRatio{ + absl::nullopt, AspectRatio{4, 3}, AspectRatioConstraint::kFixed}; + +constexpr DisplayDescription kValidVariableAspectRatio{ + absl::nullopt, AspectRatio{4, 3}, AspectRatioConstraint::kVariable}; + +constexpr DisplayDescription kValidFixedMissingAspectRatio{ + Dimensions{1024, 768, SimpleFraction{60, 1}}, absl::nullopt, + AspectRatioConstraint::kFixed}; + +constexpr DisplayDescription kValidDisplayFhd{ + Dimensions{1920, 1080, SimpleFraction{30, 1}}, AspectRatio{16, 9}, + AspectRatioConstraint::kVariable}; + +constexpr DisplayDescription kValidDisplayXga{ + Dimensions{1024, 768, SimpleFraction{60, 1}}, AspectRatio{4, 3}, + AspectRatioConstraint::kFixed}; + +constexpr DisplayDescription kValidDisplayTiny{ + Dimensions{300, 200, SimpleFraction{30, 1}}, AspectRatio{3, 2}, + AspectRatioConstraint::kFixed}; + +constexpr DisplayDescription kValidDisplayMismatched{ + Dimensions{300, 200, SimpleFraction{30, 1}}, AspectRatio{3, 4}, + AspectRatioConstraint::kFixed}; + +constexpr Constraints kEmptyConstraints{}; + +constexpr Constraints kValidConstraintsHighEnd{ + {96100, 5, 96000, 500000, std::chrono::seconds(6)}, + {6000000, Dimensions{640, 480, SimpleFraction{30, 1}}, + Dimensions{3840, 2160, SimpleFraction{144, 1}}, 600000, 6000000, + std::chrono::seconds(6)}}; + +constexpr Constraints kValidConstraintsLowEnd{ + {22000, 2, 24000, 50000, std::chrono::seconds(1)}, + {60000, Dimensions{120, 80, SimpleFraction{10, 1}}, + Dimensions{1200, 800, SimpleFraction{30, 1}}, 100000, 1000000, + std::chrono::seconds(1)}}; + +} // namespace + +TEST(CaptureRecommendationsTest, UsesDefaultsIfNoReceiverInformationAvailable) { + EXPECT_EQ(kDefaultRecommendations, GetRecommendations(Answer{})); +} + +TEST(CaptureRecommendationsTest, EmptyDisplayDescription) { + Answer answer; + answer.display = kEmptyDescription; + EXPECT_EQ(kDefaultRecommendations, GetRecommendations(answer)); +} + +TEST(CaptureRecommendationsTest, OnlyResolution) { + Recommendations expected = kDefaultRecommendations; + expected.video.maximum = Resolution{1024, 768, 60.0}; + expected.video.bit_rate_limits.maximum = 47185920; + Answer answer; + answer.display = kValidOnlyResolution; + EXPECT_EQ(expected, GetRecommendations(answer)); +} + +TEST(CaptureRecommendationsTest, OnlyAspectRatioFourThirds) { + Recommendations expected = kDefaultRecommendations; + expected.video.minimum = Resolution{320, 240, 30.0}; + expected.video.maximum = Resolution{1440, 1080, 30.0}; + expected.video.supports_scaling = true; + Answer answer; + answer.display = kValidOnlyAspectRatio; + + EXPECT_EQ(expected, GetRecommendations(answer)); +} + +TEST(CaptureRecommendationsTest, OnlyAspectRatioSixteenNine) { + Recommendations expected = kDefaultRecommendations; + expected.video.minimum = Resolution{426, 240, 30.0}; + expected.video.maximum = Resolution{1920, 1080, 30.0}; + expected.video.supports_scaling = true; + Answer answer; + answer.display = kValidOnlyAspectRatioSixteenNine; + + EXPECT_EQ(expected, GetRecommendations(answer)); +} + +TEST(CaptureRecommendationsTest, OnlyAspectRatioConstraint) { + Answer answer; + answer.display = kValidOnlyVariable; + EXPECT_EQ(kDefaultRecommendations, GetRecommendations(answer)); +} + +// It doesn't make sense to just provide a "fixed" aspect ratio with no +// other dimension information, so we just return default recommendations +// in this case and assume the sender will handle it elsewhere, e.g. on +// ANSWER message parsing. +TEST(CaptureRecommendationsTest, OnlyInvalidAspectRatioConstraint) { + Answer answer; + answer.display = kInvalidOnlyFixed; + EXPECT_EQ(kDefaultRecommendations, GetRecommendations(answer)); +} + +TEST(CaptureRecommendationsTest, FixedAspectRatioConstraint) { + Recommendations expected = kDefaultRecommendations; + expected.video.minimum = Resolution{320, 240, 30.0}; + expected.video.maximum = Resolution{1440, 1080, 30.0}; + expected.video.supports_scaling = false; + Answer answer; + answer.display = kValidFixedAspectRatio; + EXPECT_EQ(expected, GetRecommendations(answer)); +} + +// Our behavior is actually the same whether the constraint is passed, we +// just percolate the constraint up to the capture devices so that intermediate +// frame sizes between minimum and maximum can be properly scaled. +TEST(CaptureRecommendationsTest, VariableAspectRatioConstraint) { + Recommendations expected = kDefaultRecommendations; + expected.video.minimum = Resolution{320, 240, 30.0}; + expected.video.maximum = Resolution{1440, 1080, 30.0}; + Answer answer; + answer.display = kValidVariableAspectRatio; + EXPECT_EQ(expected, GetRecommendations(answer)); +} + +TEST(CaptureRecommendationsTest, ResolutionWithFixedConstraint) { + Recommendations expected = kDefaultRecommendations; + expected.video.minimum = Resolution{320, 240, 30.0}; + expected.video.maximum = Resolution{1024, 768, 60.0}; + expected.video.supports_scaling = false; + expected.video.bit_rate_limits.maximum = 47185920; + Answer answer; + answer.display = kValidFixedMissingAspectRatio; + EXPECT_EQ(expected, GetRecommendations(answer)); +} + +TEST(CaptureRecommendationsTest, ExplicitFhdChangesMinimum) { + Answer answer; + answer.display = kValidDisplayFhd; + Recommendations expected = kDefaultRecommendations; + expected.video.minimum = Resolution{426, 240, 30.0}; + EXPECT_EQ(expected, GetRecommendations(answer)); +} + +TEST(CaptureRecommendationsTest, XgaResolution) { + Recommendations expected = kDefaultRecommendations; + expected.video.minimum = Resolution{320, 240, 30.0}; + expected.video.maximum = Resolution{1024, 768, 60.0}; + expected.video.supports_scaling = false; + expected.video.bit_rate_limits.maximum = 47185920; + Answer answer; + answer.display = kValidDisplayXga; + EXPECT_EQ(expected, GetRecommendations(answer)); +} + +TEST(CaptureRecommendationsTest, MismatchedDisplayAndAspectRatio) { + Recommendations expected = kDefaultRecommendations; + expected.video.minimum = Resolution{150, 200, 30.0}; + expected.video.maximum = Resolution{150, 200, 30.0}; + expected.video.supports_scaling = false; + expected.video.bit_rate_limits.maximum = 300 * 200 * 30; + Answer answer; + answer.display = kValidDisplayMismatched; + EXPECT_EQ(expected, GetRecommendations(answer)); +} + +TEST(CaptureRecommendationsTest, TinyDisplay) { + Recommendations expected = kDefaultRecommendations; + expected.video.minimum = Resolution{300, 200, 30.0}; + expected.video.maximum = Resolution{300, 200, 30.0}; + expected.video.supports_scaling = false; + expected.video.bit_rate_limits.maximum = 300 * 200 * 30; + Answer answer; + answer.display = kValidDisplayTiny; + EXPECT_EQ(expected, GetRecommendations(answer)); +} + +TEST(CaptureRecommendationsTest, EmptyConstraints) { + Answer answer; + answer.constraints = kEmptyConstraints; + EXPECT_EQ(kDefaultRecommendations, GetRecommendations(answer)); +} + +// Generally speaking, if the receiver gives us constraints higher than our +// defaults we will accept them, with the exception of maximum resolutions +// exceeding 1080P. +TEST(CaptureRecommendationsTest, HandlesHighEnd) { + const Recommendations kExpected{ + Audio{BitRateLimits{96000, 500000}, milliseconds(6000), 5, 96100}, + Video{BitRateLimits{600000, 6000000}, Resolution{640, 480, 30}, + Resolution{1920, 1080, 30}, true, milliseconds(6000), 6000000}}; + Answer answer; + answer.constraints = kValidConstraintsHighEnd; + EXPECT_EQ(kExpected, GetRecommendations(answer)); +} + +// However, if the receiver gives us constraints lower than our minimum +// defaults, we will ignore them--they would result in an unacceptable cast +// experience. +TEST(CaptureRecommendationsTest, HandlesLowEnd) { + const Recommendations kExpected{ + Audio{BitRateLimits{32000, 50000}, milliseconds(1000), 2, 22000}, + Video{BitRateLimits{300000, 1000000}, Resolution{320, 240, 30}, + Resolution{1200, 800, 30}, true, milliseconds(1000), 60000}}; + Answer answer; + answer.constraints = kValidConstraintsLowEnd; + EXPECT_EQ(kExpected, GetRecommendations(answer)); +} + +TEST(CaptureRecommendationsTest, HandlesTooSmallScreen) { + const Recommendations kExpected{ + Audio{BitRateLimits{32000, 50000}, milliseconds(1000), 2, 22000}, + Video{BitRateLimits{300000, 1000000}, Resolution{320, 240, 30}, + Resolution{320, 240, 30}, true, milliseconds(1000), 60000}}; + Answer answer; + answer.constraints = kValidConstraintsLowEnd; + answer.constraints->video.max_dimensions = + answer.constraints->video.min_dimensions.value(); + EXPECT_EQ(kExpected, GetRecommendations(answer)); +} + +TEST(CaptureRecommendationsTest, HandlesMinimumSizeScreen) { + const Recommendations kExpected{ + Audio{BitRateLimits{32000, 50000}, milliseconds(1000), 2, 22000}, + Video{BitRateLimits{300000, 1000000}, Resolution{320, 240, 30}, + Resolution{320, 240, 30}, true, milliseconds(1000), 60000}}; + Answer answer; + answer.constraints = kValidConstraintsLowEnd; + answer.constraints->video.max_dimensions = + Dimensions{320, 240, SimpleFraction{30, 1}}; + EXPECT_EQ(kExpected, GetRecommendations(answer)); +} + +TEST(CaptureRecommendationsTest, UsesIntersectionOfDisplayAndConstraints) { + const Recommendations kExpected{ + Audio{BitRateLimits{96000, 500000}, milliseconds(6000), 5, 96100}, + Video{BitRateLimits{600000, 6000000}, Resolution{640, 480, 30}, + // Max resolution should be 1080P, since that's the display + // resolution. No reason to capture at 4K, even though the + // receiver supports it. + Resolution{1920, 1080, 30}, true, milliseconds(6000), 6000000}}; + Answer answer; + answer.display = kValidDisplayFhd; + answer.constraints = kValidConstraintsHighEnd; + EXPECT_EQ(kExpected, GetRecommendations(answer)); +} + +} // namespace capture_recommendations +} // namespace cast +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder_unittest.cc index 969056cfa40..4cc7de005c8 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_builder_unittest.cc @@ -14,6 +14,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "platform/api/time.h" +#include "util/chrono_helpers.h" using testing::_; using testing::Invoke; @@ -57,7 +58,7 @@ class CompoundRtcpBuilderTest : public testing::Test { TEST_F(CompoundRtcpBuilderTest, TheBasics) { const FrameId checkpoint = FrameId::first() + 42; builder()->SetCheckpointFrame(checkpoint); - const std::chrono::milliseconds playout_delay{321}; + const milliseconds playout_delay{321}; builder()->SetPlayoutDelay(playout_delay); const auto send_time = Clock::now(); @@ -115,7 +116,7 @@ TEST_F(CompoundRtcpBuilderTest, WithReceiverReportBlock) { // Build again, but this time the builder should not include the receiver // report block. - const auto second_send_time = send_time + std::chrono::milliseconds(500); + const auto second_send_time = send_time + milliseconds(500); const auto second_packet = builder()->BuildPacket(second_send_time, buffer); ASSERT_TRUE(second_packet.data()); EXPECT_CALL(*(client()), OnReceiverReferenceTimeAdvanced( @@ -160,7 +161,7 @@ TEST_F(CompoundRtcpBuilderTest, WithPictureLossIndicator) { Mock::VerifyAndClearExpectations(client()); ++checkpoint; - send_time += std::chrono::milliseconds(500); + send_time += milliseconds(500); } } } @@ -200,7 +201,7 @@ TEST_F(CompoundRtcpBuilderTest, WithNacks) { Mock::VerifyAndClearExpectations(client()); // Build again, but this time the builder should not include the feedback. - const auto second_send_time = send_time + std::chrono::milliseconds(500); + const auto second_send_time = send_time + milliseconds(500); const auto second_packet = builder()->BuildPacket(second_send_time, buffer); ASSERT_TRUE(second_packet.data()); EXPECT_CALL(*(client()), OnReceiverReferenceTimeAdvanced( @@ -248,7 +249,7 @@ TEST_F(CompoundRtcpBuilderTest, WithAcks) { // Build again, but this time the builder should not include the feedback // because it was already provided in the prior packet. - send_time += std::chrono::milliseconds(500); + send_time += milliseconds(500); const auto second_packet = builder()->BuildPacket(send_time, buffer); ASSERT_TRUE(second_packet.data()); EXPECT_CALL(*(client()), OnReceiverReferenceTimeAdvanced( @@ -258,7 +259,7 @@ TEST_F(CompoundRtcpBuilderTest, WithAcks) { ASSERT_TRUE(parser()->Parse(second_packet, kMaxFeedbackFrameId)); Mock::VerifyAndClearExpectations(client()); - send_time += std::chrono::milliseconds(500); + send_time += milliseconds(500); } } diff --git a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_unittest.cc index 9f8e3c50988..7b2499a23ce 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/compound_rtcp_parser_unittest.cc @@ -12,6 +12,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "platform/api/time.h" +#include "util/chrono_helpers.h" using testing::_; using testing::Mock; @@ -264,7 +265,7 @@ TEST_F(CompoundRtcpParserTest, ParsesSimpleFeedback) { // First scenario: Valid range of FrameIds is [0,42]. const auto kMaxFeedbackFrameId0 = FrameId::first() + 42; const auto expected_frame_id0 = FrameId::first() + 10; - const auto expected_playout_delay = std::chrono::milliseconds(550); + const auto expected_playout_delay = milliseconds(550); EXPECT_CALL(*(client()), OnReceiverCheckpoint(expected_frame_id0, expected_playout_delay)); EXPECT_TRUE(parser()->Parse(kFeedbackPacket, kMaxFeedbackFrameId0)); @@ -321,7 +322,7 @@ TEST_F(CompoundRtcpParserTest, ParsesFeedbackWithNacks) { const auto kMaxFeedbackFrameId = FrameId::first() + 42; const auto expected_frame_id = FrameId::first() + 10; - const auto expected_playout_delay = std::chrono::milliseconds(552); + const auto expected_playout_delay = milliseconds(552); EXPECT_CALL(*(client()), OnReceiverCheckpoint(expected_frame_id, expected_playout_delay)); EXPECT_CALL(*(client()), OnReceiverIsMissingPackets(kMissingPackets)); @@ -386,7 +387,7 @@ TEST_F(CompoundRtcpParserTest, ParsesFeedbackWithAcks) { // Test the smaller packet. const auto kMaxFeedbackFrameId = FrameId::first() + 100; const auto expected_frame_id = FrameId::first() + 10; - const auto expected_playout_delay = std::chrono::milliseconds(294); + const auto expected_playout_delay = milliseconds(294); EXPECT_CALL(*(client()), OnReceiverCheckpoint(expected_frame_id, expected_playout_delay)); EXPECT_CALL(*(client()), OnReceiverHasFrames(kFrame13Only)); diff --git a/chromium/third_party/openscreen/src/cast/streaming/environment.cc b/chromium/third_party/openscreen/src/cast/streaming/environment.cc index dec44c8bb38..16cbb64fdf3 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/environment.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/environment.cc @@ -4,6 +4,8 @@ #include "cast/streaming/environment.h" +#include <utility> + #include "cast/streaming/rtp_defines.h" #include "platform/api/task_runner.h" #include "util/osp_logging.h" @@ -12,16 +14,11 @@ namespace openscreen { namespace cast { Environment::Environment(ClockNowFunctionPtr now_function, - TaskRunner* task_runner) + TaskRunner* task_runner, + const IPEndpoint& local_endpoint) : now_function_(now_function), task_runner_(task_runner) { OSP_DCHECK(now_function_); OSP_DCHECK(task_runner_); -} - -Environment::Environment(ClockNowFunctionPtr now_function, - TaskRunner* task_runner, - const IPEndpoint& local_endpoint) - : Environment(now_function, task_runner) { ErrorOr<std::unique_ptr<UdpSocket>> result = UdpSocket::Create(task_runner_, this, local_endpoint); const_cast<std::unique_ptr<UdpSocket>&>(socket_) = std::move(result.value()); diff --git a/chromium/third_party/openscreen/src/cast/streaming/environment.h b/chromium/third_party/openscreen/src/cast/streaming/environment.h index 278604b0d0f..0ab9a3997e4 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/environment.h +++ b/chromium/third_party/openscreen/src/cast/streaming/environment.h @@ -9,6 +9,7 @@ #include <functional> #include <memory> +#include <vector> #include "absl/types/span.h" #include "platform/api/time.h" @@ -34,10 +35,11 @@ class Environment : public UdpSocket::Client { // Construct with the given clock source and TaskRunner. Creates and // internally-owns a UdpSocket, and immediately binds it to the given - // |local_endpoint|. + // |local_endpoint|. If embedders do not care what interface/address the UDP + // socket is bound on, they may omit that argument. Environment(ClockNowFunctionPtr now_function, TaskRunner* task_runner, - const IPEndpoint& local_endpoint); + const IPEndpoint& local_endpoint = IPEndpoint::kAnyV6()); ~Environment() override; @@ -87,10 +89,11 @@ class Environment : public UdpSocket::Client { virtual void SendPacket(absl::Span<const uint8_t> packet); protected: - // Common constructor that just stores the injected dependencies and does not - // create a socket. Subclasses use this to provide an alternative packet - // receive/send mechanism (e.g., for testing). - Environment(ClockNowFunctionPtr now_function, TaskRunner* task_runner); + Environment() : now_function_(nullptr), task_runner_(nullptr) {} + + // Protected so that they can be set by the MockEnvironment for testing. + ClockNowFunctionPtr now_function_; + TaskRunner* task_runner_; private: // UdpSocket::Client implementation. @@ -98,8 +101,6 @@ class Environment : public UdpSocket::Client { void OnSendError(UdpSocket* socket, Error error) final; void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet_or_error) final; - const ClockNowFunctionPtr now_function_; - TaskRunner* const task_runner_; // The UDP socket bound to the local endpoint that was passed into the // constructor, or null if socket creation failed. diff --git a/chromium/third_party/openscreen/src/cast/streaming/message_util.h b/chromium/third_party/openscreen/src/cast/streaming/message_util.h deleted file mode 100644 index c986f0ca9e8..00000000000 --- a/chromium/third_party/openscreen/src/cast/streaming/message_util.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef CAST_STREAMING_MESSAGE_UTIL_H_ -#define CAST_STREAMING_MESSAGE_UTIL_H_ - -#include <vector> - -#include "absl/strings/string_view.h" -#include "json/value.h" -#include "platform/base/error.h" - -// This file contains helper methods that are used by both answer and offer -// messages, but should not be publicly exposed/consumed. -namespace openscreen { -namespace cast { - -inline Error CreateParseError(const std::string& type) { - return Error(Error::Code::kJsonParseError, "Failed to parse " + type); -} - -inline Error CreateParameterError(const std::string& type) { - return Error(Error::Code::kParameterInvalid, "Invalid parameter: " + type); -} - -inline ErrorOr<bool> ParseBool(const Json::Value& parent, - const std::string& field) { - const Json::Value& value = parent[field]; - if (!value.isBool()) { - return CreateParseError("bool field " + field); - } - return value.asBool(); -} - -inline ErrorOr<int> ParseInt(const Json::Value& parent, - const std::string& field) { - const Json::Value& value = parent[field]; - if (!value.isInt()) { - return CreateParseError("integer field: " + field); - } - return value.asInt(); -} - -inline ErrorOr<uint32_t> ParseUint(const Json::Value& parent, - const std::string& field) { - const Json::Value& value = parent[field]; - if (!value.isUInt()) { - return CreateParseError("unsigned integer field: " + field); - } - return value.asUInt(); -} - -inline ErrorOr<std::string> ParseString(const Json::Value& parent, - const std::string& field) { - const Json::Value& value = parent[field]; - if (!value.isString()) { - return CreateParseError("string field: " + field); - } - return value.asString(); -} - -} // namespace cast -} // namespace openscreen - -#endif // CAST_STREAMING_MESSAGE_UTIL_H_ diff --git a/chromium/third_party/openscreen/src/cast/streaming/mock_environment.cc b/chromium/third_party/openscreen/src/cast/streaming/mock_environment.cc index 9d712537ccd..512b031e243 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/mock_environment.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/mock_environment.cc @@ -8,8 +8,10 @@ namespace openscreen { namespace cast { MockEnvironment::MockEnvironment(ClockNowFunctionPtr now_function, - TaskRunner* task_runner) - : Environment(now_function, task_runner) {} + TaskRunner* task_runner) { + task_runner_ = task_runner; + now_function_ = now_function; +} MockEnvironment::~MockEnvironment() = default; diff --git a/chromium/third_party/openscreen/src/cast/streaming/ntp_time.cc b/chromium/third_party/openscreen/src/cast/streaming/ntp_time.cc index 51a232dd964..9fd04e038c7 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/ntp_time.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/ntp_time.cc @@ -6,8 +6,6 @@ #include "util/osp_logging.h" -using std::chrono::duration_cast; - namespace openscreen { namespace cast { @@ -21,17 +19,19 @@ constexpr NtpSeconds kTimeBetweenNtpEpochAndUnixEpoch{INT64_C(2208988800)}; NtpTimeConverter::NtpTimeConverter(Clock::time_point now, std::chrono::seconds since_unix_epoch) : start_time_(now), - since_ntp_epoch_(duration_cast<NtpSeconds>(since_unix_epoch) + - kTimeBetweenNtpEpochAndUnixEpoch) {} + since_ntp_epoch_( + std::chrono::duration_cast<NtpSeconds>(since_unix_epoch) + + kTimeBetweenNtpEpochAndUnixEpoch) {} NtpTimeConverter::~NtpTimeConverter() = default; NtpTimestamp NtpTimeConverter::ToNtpTimestamp( Clock::time_point time_point) const { const Clock::duration time_since_start = time_point - start_time_; - const auto whole_seconds = duration_cast<NtpSeconds>(time_since_start); + const auto whole_seconds = + std::chrono::duration_cast<NtpSeconds>(time_since_start); const auto remainder = - duration_cast<NtpFraction>(time_since_start - whole_seconds); + std::chrono::duration_cast<NtpFraction>(time_since_start - whole_seconds); return AssembleNtpTimestamp(since_ntp_epoch_ + whole_seconds, remainder); } @@ -47,9 +47,8 @@ Clock::time_point NtpTimeConverter::ToLocalTime(NtpTimestamp timestamp) const { const auto whole_seconds = ntp_seconds - since_ntp_epoch_; const auto seconds_since_start = - duration_cast<Clock::duration>(whole_seconds) + start_time_; - const auto remainder = - duration_cast<Clock::duration>(NtpFractionPart(timestamp)); + Clock::to_duration(whole_seconds) + start_time_; + const auto remainder = Clock::to_duration(NtpFractionPart(timestamp)); return seconds_since_start + remainder; } diff --git a/chromium/third_party/openscreen/src/cast/streaming/ntp_time_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/ntp_time_unittest.cc index 1caa81ad475..325e75877f2 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/ntp_time_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/ntp_time_unittest.cc @@ -4,11 +4,10 @@ #include "cast/streaming/ntp_time.h" -#include "gtest/gtest.h" +#include <chrono> -using std::chrono::duration_cast; -using std::chrono::microseconds; -using std::chrono::milliseconds; +#include "gtest/gtest.h" +#include "util/chrono_helpers.h" namespace openscreen { namespace cast { @@ -22,8 +21,7 @@ TEST(NtpTimestampTest, SplitsIntoParts) { // 1 Jan 1900 plus 10 ms. timestamp = UINT64_C(0x00000000028f5c29); EXPECT_EQ(NtpSeconds::zero(), NtpSecondsPart(timestamp)); - EXPECT_EQ(milliseconds(10), - duration_cast<milliseconds>(NtpFractionPart(timestamp))); + EXPECT_EQ(milliseconds(10), to_microseconds(NtpFractionPart(timestamp))); // 1 Jan 1970 minus 2^-32 seconds. timestamp = UINT64_C(0x83aa7e80ffffffff); @@ -33,8 +31,7 @@ TEST(NtpTimestampTest, SplitsIntoParts) { // 2019-03-23 17:25:50.500. timestamp = UINT64_C(0xe0414d0e80000000); EXPECT_EQ(NtpSeconds(INT64_C(3762375950)), NtpSecondsPart(timestamp)); - EXPECT_EQ(milliseconds(500), - duration_cast<milliseconds>(NtpFractionPart(timestamp))); + EXPECT_EQ(milliseconds(500), to_microseconds(NtpFractionPart(timestamp))); } TEST(NtpTimestampTest, AssemblesFromParts) { @@ -43,13 +40,14 @@ TEST(NtpTimestampTest, AssemblesFromParts) { AssembleNtpTimestamp(NtpSeconds::zero(), NtpFraction::zero()); EXPECT_EQ(UINT64_C(0x0000000000000000), timestamp); - // 1 Jan 1900 plus 10 ms. Note that the duration_cast<NtpFraction>(10ms) - // truncates rather than rounds the 10ms value, so the resulting timestamp is - // one fractional tick less than the one found in the SplitsIntoParts test. - // The ~0.4 nanosecond error in the conversion is totally insignificant to a - // live system. + // 1 Jan 1900 plus 10 ms. Note that the + // std::chrono::duration_cast<NtpFraction>(10ms) truncates rather than rounds + // the 10ms value, so the resulting timestamp is one fractional tick less than + // the one found in the SplitsIntoParts test. The ~0.4 nanosecond error in the + // conversion is totally insignificant to a live system. timestamp = AssembleNtpTimestamp( - NtpSeconds::zero(), duration_cast<NtpFraction>(milliseconds(10))); + NtpSeconds::zero(), + std::chrono::duration_cast<NtpFraction>(milliseconds(10))); EXPECT_EQ(UINT64_C(0x00000000028f5c28), timestamp); // 1 Jan 1970 minus 2^-32 seconds. @@ -58,9 +56,9 @@ TEST(NtpTimestampTest, AssemblesFromParts) { EXPECT_EQ(UINT64_C(0x83aa7e7fffffffff), timestamp); // 2019-03-23 17:25:50.500. - timestamp = - AssembleNtpTimestamp(NtpSeconds(INT64_C(3762375950)), - duration_cast<NtpFraction>(milliseconds(500))); + timestamp = AssembleNtpTimestamp( + NtpSeconds(INT64_C(3762375950)), + std::chrono::duration_cast<NtpFraction>(milliseconds(500))); EXPECT_EQ(UINT64_C(0xe0414d0e80000000), timestamp); } diff --git a/chromium/third_party/openscreen/src/cast/streaming/offer_messages.cc b/chromium/third_party/openscreen/src/cast/streaming/offer_messages.cc index caa6babfa00..ff6845a31f7 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/offer_messages.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/offer_messages.cc @@ -6,6 +6,7 @@ #include <inttypes.h> +#include <limits> #include <string> #include <utility> @@ -13,10 +14,10 @@ #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" #include "cast/streaming/constants.h" -#include "cast/streaming/message_util.h" #include "cast/streaming/receiver_session.h" #include "platform/base/error.h" #include "util/big_endian.h" +#include "util/json/json_helpers.h" #include "util/json/json_serialization.h" #include "util/osp_logging.h" #include "util/stringprintf.h" @@ -33,7 +34,7 @@ constexpr char kStreamType[] = "type"; ErrorOr<RtpPayloadType> ParseRtpPayloadType(const Json::Value& parent, const std::string& field) { - auto t = ParseInt(parent, field); + auto t = json::ParseInt(parent, field); if (!t) { return t.error(); } @@ -49,14 +50,14 @@ ErrorOr<RtpPayloadType> ParseRtpPayloadType(const Json::Value& parent, ErrorOr<int> ParseRtpTimebase(const Json::Value& parent, const std::string& field) { - auto error_or_raw = ParseString(parent, field); + auto error_or_raw = json::ParseString(parent, field); if (!error_or_raw) { return error_or_raw.error(); } const auto fraction = SimpleFraction::FromString(error_or_raw.value()); if (fraction.is_error() || !fraction.value().is_positive()) { - return CreateParseError("RTP timebase"); + return json::CreateParseError("RTP timebase"); } // The spec demands a leading 1, so this isn't really a fraction. OSP_DCHECK(fraction.value().numerator == 1); @@ -71,7 +72,7 @@ constexpr int kAesStringLength = kAesBytesSize * kHexDigitsPerByte; ErrorOr<std::array<uint8_t, kAesBytesSize>> ParseAesHexBytes( const Json::Value& parent, const std::string& field) { - auto hex_string = ParseString(parent, field); + auto hex_string = json::ParseString(parent, field); if (!hex_string) { return hex_string.error(); } @@ -91,24 +92,24 @@ ErrorOr<std::array<uint8_t, kAesBytesSize>> ParseAesHexBytes( WriteBigEndian(quads[1], bytes.data() + 8); return bytes; } - return CreateParseError("AES hex string bytes"); + return json::CreateParseError("AES hex string bytes"); } ErrorOr<Stream> ParseStream(const Json::Value& value, Stream::Type type) { - auto index = ParseInt(value, "index"); + auto index = json::ParseInt(value, "index"); if (!index) { return index.error(); } // If channel is omitted, the default value is used later. - auto channels = ParseInt(value, "channels"); + auto channels = json::ParseInt(value, "channels"); if (channels.is_value() && channels.value() <= 0) { - return CreateParameterError("channel"); + return json::CreateParameterError("channel"); } - auto codec_name = ParseString(value, "codecName"); + auto codec_name = json::ParseString(value, "codecName"); if (!codec_name) { return codec_name.error(); } - auto rtp_profile = ParseString(value, "rtpProfile"); + auto rtp_profile = json::ParseString(value, "rtpProfile"); if (!rtp_profile) { return rtp_profile.error(); } @@ -116,7 +117,7 @@ ErrorOr<Stream> ParseStream(const Json::Value& value, Stream::Type type) { if (!rtp_payload_type) { return rtp_payload_type.error(); } - auto ssrc = ParseUint(value, "ssrc"); + auto ssrc = json::ParseUint(value, "ssrc"); if (!ssrc) { return ssrc.error(); } @@ -133,19 +134,19 @@ ErrorOr<Stream> ParseStream(const Json::Value& value, Stream::Type type) { return rtp_timebase.error(); } - auto target_delay = ParseInt(value, "targetDelay"); + auto target_delay = json::ParseInt(value, "targetDelay"); std::chrono::milliseconds target_delay_ms = kDefaultTargetPlayoutDelay; if (target_delay) { auto d = std::chrono::milliseconds(target_delay.value()); if (d >= kMinTargetPlayoutDelay && d <= kMaxTargetPlayoutDelay) { target_delay_ms = d; } else { - return CreateParameterError("target delay"); + return json::CreateParameterError("target delay"); } } - auto receiver_rtcp_event_log = ParseBool(value, "receiverRtcpEventLog"); - auto receiver_rtcp_dscp = ParseString(value, "receiverRtcpDscp"); + auto receiver_rtcp_event_log = json::ParseBool(value, "receiverRtcpEventLog"); + auto receiver_rtcp_dscp = json::ParseString(value, "receiverRtcpDscp"); return Stream{index.value(), type, channels.value(type == Stream::Type::kAudioSource @@ -167,28 +168,28 @@ ErrorOr<AudioStream> ParseAudioStream(const Json::Value& value) { if (!stream) { return stream.error(); } - auto bit_rate = ParseInt(value, "bitRate"); + auto bit_rate = json::ParseInt(value, "bitRate"); if (!bit_rate) { return bit_rate.error(); } // A bit rate of 0 is valid for some codec types, so we don't enforce here. if (bit_rate.value() < 0) { - return CreateParameterError("bit rate"); + return json::CreateParameterError("bit rate"); } return AudioStream{stream.value(), bit_rate.value()}; } ErrorOr<Resolution> ParseResolution(const Json::Value& value) { - auto width = ParseInt(value, "width"); + auto width = json::ParseInt(value, "width"); if (!width) { return width.error(); } - auto height = ParseInt(value, "height"); + auto height = json::ParseInt(value, "height"); if (!height) { return height.error(); } if (width.value() <= 0 || height.value() <= 0) { - return CreateParameterError("resolution"); + return json::CreateParameterError("resolution"); } return Resolution{width.value(), height.value()}; } @@ -223,7 +224,7 @@ ErrorOr<VideoStream> ParseVideoStream(const Json::Value& value) { return resolutions.error(); } - auto raw_max_frame_rate = ParseString(value, "maxFrameRate"); + auto raw_max_frame_rate = json::ParseString(value, "maxFrameRate"); SimpleFraction max_frame_rate{kDefaultMaxFrameRate, 1}; if (raw_max_frame_rate.is_value()) { auto parsed = SimpleFraction::FromString(raw_max_frame_rate.value()); @@ -232,11 +233,11 @@ ErrorOr<VideoStream> ParseVideoStream(const Json::Value& value) { } } - auto profile = ParseString(value, "profile"); - auto protection = ParseString(value, "protection"); - auto max_bit_rate = ParseInt(value, "maxBitRate"); - auto level = ParseString(value, "level"); - auto error_recovery_mode = ParseString(value, "errorRecoveryMode"); + auto profile = json::ParseString(value, "profile"); + auto protection = json::ParseString(value, "protection"); + auto max_bit_rate = json::ParseInt(value, "maxBitRate"); + auto level = json::ParseString(value, "level"); + auto error_recovery_mode = json::ParseString(value, "errorRecoveryMode"); return VideoStream{stream.value(), max_frame_rate, max_bit_rate.value(4 << 20), @@ -276,7 +277,7 @@ ErrorOr<Json::Value> Stream::ToJson() const { target_delay.count() <= 0 || target_delay.count() > std::numeric_limits<int>::max() || rtp_timebase < 1) { - return CreateParameterError("Stream"); + return json::CreateParameterError("Stream"); } Json::Value root; @@ -315,7 +316,7 @@ std::string CastMode::ToString() const { ErrorOr<Json::Value> AudioStream::ToJson() const { // A bit rate of 0 is valid for some codec types, so we don't enforce here. if (bit_rate < 0) { - return CreateParameterError("AudioStream"); + return json::CreateParameterError("AudioStream"); } auto error_or_stream = stream.ToJson(); @@ -329,7 +330,7 @@ ErrorOr<Json::Value> AudioStream::ToJson() const { ErrorOr<Json::Value> Resolution::ToJson() const { if (width <= 0 || height <= 0) { - return CreateParameterError("Resolution"); + return json::CreateParameterError("Resolution"); } Json::Value root; @@ -340,7 +341,7 @@ ErrorOr<Json::Value> Resolution::ToJson() const { ErrorOr<Json::Value> VideoStream::ToJson() const { if (max_bit_rate <= 0 || !max_frame_rate.is_positive()) { - return CreateParameterError("VideoStream"); + return json::CreateParameterError("VideoStream"); } auto error_or_stream = stream.ToJson(); @@ -372,18 +373,18 @@ ErrorOr<Json::Value> VideoStream::ToJson() const { ErrorOr<Offer> Offer::Parse(const Json::Value& root) { CastMode cast_mode = CastMode::Parse(root["castMode"].asString()); - const ErrorOr<bool> get_status = ParseBool(root, "receiverGetStatus"); + const ErrorOr<bool> get_status = json::ParseBool(root, "receiverGetStatus"); Json::Value supported_streams = root[kSupportedStreams]; if (!supported_streams.isArray()) { - return CreateParseError("supported streams in offer"); + return json::CreateParseError("supported streams in offer"); } std::vector<AudioStream> audio_streams; std::vector<VideoStream> video_streams; for (Json::ArrayIndex i = 0; i < supported_streams.size(); ++i) { const Json::Value& fields = supported_streams[i]; - auto type = ParseString(fields, kStreamType); + auto type = json::ParseString(fields, kStreamType); if (!type) { return type.error(); } diff --git a/chromium/third_party/openscreen/src/cast/streaming/offer_messages.h b/chromium/third_party/openscreen/src/cast/streaming/offer_messages.h index 319145bc3cf..f5642495906 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/offer_messages.h +++ b/chromium/third_party/openscreen/src/cast/streaming/offer_messages.h @@ -5,7 +5,7 @@ #ifndef CAST_STREAMING_OFFER_MESSAGES_H_ #define CAST_STREAMING_OFFER_MESSAGES_H_ -#include <chrono> // NOLINT +#include <chrono> #include <string> #include <vector> diff --git a/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker_unittest.cc index 5146cc8d302..eb1b45c7118 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/packet_receive_stats_tracker_unittest.cc @@ -4,10 +4,12 @@ #include "cast/streaming/packet_receive_stats_tracker.h" +#include <chrono> #include <limits> #include "cast/streaming/constants.h" #include "gtest/gtest.h" +#include "util/chrono_helpers.h" namespace openscreen { namespace cast { @@ -74,8 +76,7 @@ TEST(PacketReceiveStatsTrackerTest, PopulatesReportWithOnePacketTracked) { constexpr uint16_t kSequenceNumber = 1234; constexpr RtpTimeTicks kRtpTimestamp = RtpTimeTicks() + RtpTimeDelta::FromTicks(42); - constexpr auto kArrivalTime = - Clock::time_point() + std::chrono::seconds(3600); + constexpr auto kArrivalTime = Clock::time_point() + seconds(3600); PacketReceiveStatsTracker tracker(kSomeRtpTimebase); tracker.OnReceivedValidRtpPacket(kSequenceNumber, kRtpTimestamp, @@ -96,8 +97,7 @@ TEST(PacketReceiveStatsTrackerTest, WhenReceivingAllPackets) { std::numeric_limits<uint16_t>::max() - 2; constexpr RtpTimeTicks kFirstRtpTimestamp = RtpTimeTicks() + RtpTimeDelta::FromTicks(42); - constexpr auto kFirstArrivalTime = - Clock::time_point() + std::chrono::seconds(3600); + constexpr auto kFirstArrivalTime = Clock::time_point() + seconds(3600); PacketReceiveStatsTracker tracker(kSomeRtpTimebase); @@ -107,7 +107,7 @@ TEST(PacketReceiveStatsTrackerTest, WhenReceivingAllPackets) { tracker.OnReceivedValidRtpPacket( kFirstSequenceNumber + i, kFirstRtpTimestamp + RtpTimeDelta::FromTicks(kSomeRtpTimebase) * i, - kFirstArrivalTime + std::chrono::seconds(i)); + kFirstArrivalTime + seconds(i)); } RtcpReportBlock report = GetSentinel(); @@ -131,8 +131,7 @@ TEST(PacketReceiveStatsTrackerTest, WhenReceivingAboutHalfThePackets) { constexpr uint16_t kFirstSequenceNumber = 3; constexpr RtpTimeTicks kFirstRtpTimestamp = RtpTimeTicks() + RtpTimeDelta::FromTicks(99); - constexpr auto kFirstArrivalTime = - Clock::time_point() + std::chrono::seconds(8888); + constexpr auto kFirstArrivalTime = Clock::time_point() + seconds(8888); PacketReceiveStatsTracker tracker(kSomeRtpTimebase); @@ -145,7 +144,7 @@ TEST(PacketReceiveStatsTrackerTest, WhenReceivingAboutHalfThePackets) { tracker.OnReceivedValidRtpPacket( kFirstSequenceNumber + (i * 2 + 1), kFirstRtpTimestamp + RtpTimeDelta::FromTicks(kSomeRtpTimebase) * i, - kFirstArrivalTime + std::chrono::seconds(i)); + kFirstArrivalTime + seconds(i)); } RtcpReportBlock report = GetSentinel(); @@ -163,14 +162,12 @@ TEST(PacketReceiveStatsTrackerTest, ComputesJitterCorrectly) { constexpr uint16_t kFirstSequenceNumber = 3; constexpr RtpTimeTicks kFirstRtpTimestamp = RtpTimeTicks() + RtpTimeDelta::FromTicks(99); - constexpr auto kFirstArrivalTime = - Clock::time_point() + std::chrono::seconds(8888); + constexpr auto kFirstArrivalTime = Clock::time_point() + seconds(8888); // Record 100 packet arrivals, one second apart, where each packet's RTP // timestamps are progressing 2 seconds forward. Thus, the jitter calculation // should gradually converge towards a difference of one second. - constexpr auto kTrueJitter = - std::chrono::duration_cast<Clock::duration>(std::chrono::seconds(1)); + constexpr auto kTrueJitter = Clock::to_duration(seconds(1)); PacketReceiveStatsTracker tracker(kSomeRtpTimebase); Clock::duration last_diff = Clock::duration::max(); for (int i = 0; i < 100; ++i) { @@ -178,7 +175,7 @@ TEST(PacketReceiveStatsTrackerTest, ComputesJitterCorrectly) { kFirstSequenceNumber + i, kFirstRtpTimestamp + RtpTimeDelta::FromTicks(kSomeRtpTimebase) * (i * 2), - kFirstArrivalTime + std::chrono::seconds(i)); + kFirstArrivalTime + seconds(i)); // Expect that the jitter is becoming closer to the actual value in each // iteration. @@ -198,8 +195,7 @@ TEST(PacketReceiveStatsTrackerTest, ComputesJitterCorrectly) { tracker.PopulateNextReport(&report); const auto diff = kTrueJitter - report.jitter.ToDuration<Clock::duration>(kSomeRtpTimebase); - constexpr auto kMaxDiffAtEnd = - std::chrono::duration_cast<Clock::duration>(std::chrono::milliseconds(2)); + constexpr auto kMaxDiffAtEnd = Clock::to_duration(milliseconds(2)); EXPECT_NEAR(0, diff.count(), kMaxDiffAtEnd.count()); } diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver.cc b/chromium/third_party/openscreen/src/cast/streaming/receiver.cc index 280af7940a5..d4c86da437c 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/receiver.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/receiver.cc @@ -5,18 +5,16 @@ #include "cast/streaming/receiver.h" #include <algorithm> +#include <utility> #include "absl/types/span.h" #include "cast/streaming/constants.h" #include "cast/streaming/receiver_packet_router.h" #include "cast/streaming/session_config.h" +#include "util/chrono_helpers.h" #include "util/osp_logging.h" #include "util/std_util.h" -using std::chrono::duration_cast; -using std::chrono::microseconds; -using std::chrono::milliseconds; - namespace openscreen { namespace cast { @@ -148,14 +146,14 @@ EncodedFrame Receiver::ConsumeNextFrame(absl::Span<uint8_t> buffer) { frame.reference_time = *entry.estimated_capture_time + ResolveTargetPlayoutDelay(frame_id); - RECEIVER_VLOG - << "ConsumeNextFrame → " << frame.frame_id << ": " << frame.data.size() - << " payload bytes, RTP Timestamp " - << frame.rtp_timestamp.ToTimeSinceOrigin<microseconds>(rtp_timebase_) - .count() - << " µs, to play-out " - << duration_cast<microseconds>(frame.reference_time - now_()).count() - << " µs from now."; + RECEIVER_VLOG << "ConsumeNextFrame → " << frame.frame_id << ": " + << frame.data.size() << " payload bytes, RTP Timestamp " + << frame.rtp_timestamp + .ToTimeSinceOrigin<microseconds>(rtp_timebase_) + .count() + << " µs, to play-out " + << to_microseconds(frame.reference_time - now_()).count() + << " µs from now."; entry.Reset(); last_frame_consumed_ = frame_id; @@ -310,7 +308,7 @@ void Receiver::OnReceivedRtcpPacket(Clock::time_point arrival_time, smoothed_clock_offset_.Update(arrival_time, measured_offset); RECEIVER_VLOG << "Received Sender Report: Local clock is ahead of Sender's by " - << duration_cast<microseconds>(smoothed_clock_offset_.Current()).count() + << to_microseconds(smoothed_clock_offset_.Current()).count() << " µs (minus one-way network transit time)."; RtcpReportBlock report; diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver.h b/chromium/third_party/openscreen/src/cast/streaming/receiver.h index e63a9e4ee1f..b4c53868b03 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/receiver.h +++ b/chromium/third_party/openscreen/src/cast/streaming/receiver.h @@ -8,8 +8,9 @@ #include <stdint.h> #include <array> -#include <chrono> // NOLINT +#include <chrono> #include <memory> +#include <utility> #include <vector> #include "absl/types/optional.h" diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver_session.cc b/chromium/third_party/openscreen/src/cast/streaming/receiver_session.cc index 83ac0cf69b0..a9aad413f2c 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/receiver_session.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/receiver_session.cc @@ -4,7 +4,7 @@ #include "cast/streaming/receiver_session.h" -#include <chrono> // NOLINT +#include <chrono> #include <string> #include <utility> @@ -12,14 +12,15 @@ #include "absl/strings/numbers.h" #include "cast/streaming/environment.h" #include "cast/streaming/message_port.h" -#include "cast/streaming/message_util.h" #include "cast/streaming/offer_messages.h" #include "cast/streaming/receiver.h" +#include "util/json/json_helpers.h" #include "util/osp_logging.h" namespace openscreen { namespace cast { +/// NOTE: Constants here are all taken from the Cast V2: Mirroring Control // JSON message field values specific to the Receiver Session. static constexpr char kMessageTypeOffer[] = "OFFER"; @@ -28,25 +29,39 @@ static constexpr char kOfferMessageBody[] = "offer"; static constexpr char kKeyType[] = "type"; static constexpr char kSequenceNumber[] = "seqNum"; +/// Protocol specification: http://goto.google.com/mirroring-control-protocol +// TODO(jophba): document the protocol in a public repository. +static constexpr char kMessageKeyType[] = "type"; +static constexpr char kMessageTypeAnswer[] = "ANSWER"; + +/// ANSWER message fields. +static constexpr char kAnswerMessageBody[] = "answer"; +static constexpr char kResult[] = "result"; +static constexpr char kResultOk[] = "ok"; +static constexpr char kResultError[] = "error"; +static constexpr char kErrorMessageBody[] = "error"; +static constexpr char kErrorCode[] = "code"; +static constexpr char kErrorDescription[] = "description"; + // Using statements for constructor readability. using Preferences = ReceiverSession::Preferences; using ConfiguredReceivers = ReceiverSession::ConfiguredReceivers; namespace { -std::string GetCodecName(ReceiverSession::AudioCodec codec) { +std::string CodecToString(ReceiverSession::AudioCodec codec) { switch (codec) { case ReceiverSession::AudioCodec::kAac: return "aac"; case ReceiverSession::AudioCodec::kOpus: return "opus"; + default: + OSP_NOTREACHED() << "Codec not accounted for in switch statement."; + return {}; } - - OSP_NOTREACHED() << "Codec not accounted for in switch statement."; - return {}; } -std::string GetCodecName(ReceiverSession::VideoCodec codec) { +std::string CodecToString(ReceiverSession::VideoCodec codec) { switch (codec) { case ReceiverSession::VideoCodec::kH264: return "h264"; @@ -56,20 +71,20 @@ std::string GetCodecName(ReceiverSession::VideoCodec codec) { return "hevc"; case ReceiverSession::VideoCodec::kVp9: return "vp9"; + default: + OSP_NOTREACHED() << "Codec not accounted for in switch statement."; + return {}; } - - OSP_NOTREACHED() << "Codec not accounted for in switch statement."; - return {}; } template <typename Stream, typename Codec> const Stream* SelectStream(const std::vector<Codec>& preferred_codecs, const std::vector<Stream>& offered_streams) { - for (Codec codec : preferred_codecs) { - const std::string codec_name = GetCodecName(codec); + for (auto codec : preferred_codecs) { + const std::string codec_name = CodecToString(codec); for (const Stream& offered_stream : offered_streams) { if (offered_stream.stream.codec_name == codec_name) { - OSP_VLOG << "Selected " << codec_name << " as codec for streaming."; + OSP_DVLOG << "Selected " << codec_name << " as codec for streaming"; return &offered_stream; } } @@ -77,6 +92,27 @@ const Stream* SelectStream(const std::vector<Codec>& preferred_codecs, return nullptr; } +// Helper method that creates an invalid Answer response. +Json::Value CreateInvalidAnswerMessage(Error error) { + Json::Value message_root; + message_root[kMessageKeyType] = kMessageTypeAnswer; + message_root[kResult] = kResultError; + message_root[kErrorMessageBody][kErrorCode] = static_cast<int>(error.code()); + message_root[kErrorMessageBody][kErrorDescription] = error.message(); + + return message_root; +} + +// Helper method that creates a valid Answer response. +Json::Value CreateAnswerMessage(const Answer& answer) { + OSP_DCHECK(answer.IsValid()); + Json::Value message_root; + message_root[kMessageKeyType] = kMessageTypeAnswer; + message_root[kAnswerMessageBody] = answer.ToJson(); + message_root[kResult] = kResultOk; + return message_root; +} + } // namespace Preferences::Preferences() = default; @@ -124,29 +160,33 @@ void ReceiverSession::OnMessage(absl::string_view sender_id, if (!message_json) { client_->OnError(this, Error::Code::kJsonParseError); - OSP_LOG_WARN << "Received an invalid message: " << message; + OSP_DLOG_WARN << "Received an invalid message: " << message; return; } + OSP_DVLOG << "Received a message: " << message; // TODO(jophba): add sender connected/disconnected messaging. - auto sequence_number = ParseInt(message_json.value(), kSequenceNumber); - if (!sequence_number) { - OSP_LOG_WARN << "Invalid message sequence number"; + int sequence_number; + if (!json::ParseAndValidateInt(message_json.value()[kSequenceNumber], + &sequence_number)) { + OSP_DLOG_WARN << "Invalid message sequence number"; return; } - auto key_or_error = ParseString(message_json.value(), kKeyType); - if (!key_or_error) { - OSP_LOG_WARN << "Invalid message key"; + std::string key; + if (!json::ParseAndValidateString(message_json.value()[kKeyType], &key)) { + OSP_DLOG_WARN << "Invalid message key"; return; } Message parsed_message{sender_id.data(), message_namespace.data(), - sequence_number.value()}; - if (key_or_error.value() == kMessageTypeOffer) { + sequence_number}; + if (key == kMessageTypeOffer) { parsed_message.body = std::move(message_json.value()[kOfferMessageBody]); if (parsed_message.body.isNull()) { - OSP_LOG_WARN << "Invalid message offer body"; + client_->OnError(this, Error(Error::Code::kJsonParseError, + "Received offer missing offer body")); + OSP_DLOG_WARN << "Invalid message offer body"; return; } OnOffer(&parsed_message); @@ -154,15 +194,14 @@ void ReceiverSession::OnMessage(absl::string_view sender_id, } void ReceiverSession::OnError(Error error) { - OSP_LOG_WARN << "ReceiverSession's MessagePump encountered an error:" - << error; + OSP_DLOG_WARN << "ReceiverSession message port error: " << error; } void ReceiverSession::OnOffer(Message* message) { ErrorOr<Offer> offer = Offer::Parse(std::move(message->body)); if (!offer) { client_->OnError(this, offer.error()); - OSP_LOG_WARN << "Could not parse offer" << offer.error(); + OSP_DLOG_WARN << "Could not parse offer" << offer.error(); return; } @@ -180,19 +219,31 @@ void ReceiverSession::OnOffer(Message* message) { SelectStream(preferences_.video_codecs, offer.value().video_streams); } - cast_mode_ = offer.value().cast_mode; - auto receivers = - TrySpawningReceivers(selected_audio_stream, selected_video_stream); - if (receivers) { - const Answer answer = - ConstructAnswer(message, selected_audio_stream, selected_video_stream); - client_->OnNegotiated(this, std::move(receivers.value())); + if (!selected_audio_stream && !selected_video_stream) { + message->body = CreateInvalidAnswerMessage( + Error(Error::Code::kParseError, "No selected streams")); + OSP_DLOG_WARN << "Failed to select any streams from OFFER"; + SendMessage(message); + return; + } - message->body = answer.ToAnswerMessage(); - } else { - message->body = CreateInvalidAnswer(receivers.error()); + const Answer answer = + ConstructAnswer(message, selected_audio_stream, selected_video_stream); + if (!answer.IsValid()) { + message->body = CreateInvalidAnswerMessage( + Error(Error::Code::kParseError, "Invalid answer message")); + OSP_DLOG_WARN << "Failed to construct an ANSWER message"; + SendMessage(message); + return; } + // Only spawn receivers if we know we have a valid answer message. + ConfiguredReceivers receivers = + SpawnReceivers(selected_audio_stream, selected_video_stream); + // If the answer message is invalid, there is no point in setting up a + // negotiation because the sender won't be able to connect to it. + client_->OnNegotiated(this, std::move(receivers)); + message->body = CreateAnswerMessage(answer); SendMessage(message); } @@ -208,13 +259,9 @@ ReceiverSession::ConstructReceiver(const Stream& stream) { return std::make_pair(std::move(config), std::move(receiver)); } -ErrorOr<ConfiguredReceivers> ReceiverSession::TrySpawningReceivers( - const AudioStream* audio, - const VideoStream* video) { - if (!audio && !video) { - return Error::Code::kParameterInvalid; - } - +ConfiguredReceivers ReceiverSession::SpawnReceivers(const AudioStream* audio, + const VideoStream* video) { + OSP_DCHECK(audio || video); ResetReceivers(); absl::optional<ConfiguredReceiver<AudioStream>> audio_receiver; @@ -266,20 +313,20 @@ Answer ReceiverSession::ConstructAnswer( absl::optional<Constraints> constraints; if (preferences_.constraints) { - constraints = *preferences_.constraints; + constraints = absl::optional<Constraints>(*preferences_.constraints); } absl::optional<DisplayDescription> display; if (preferences_.display_description) { - display = *preferences_.display_description; + display = + absl::optional<DisplayDescription>(*preferences_.display_description); } - return Answer{cast_mode_, - environment_->GetBoundLocalEndpoint().port, + return Answer{environment_->GetBoundLocalEndpoint().port, std::move(stream_indexes), std::move(stream_ssrcs), - constraints, - display, + std::move(constraints), + std::move(display), std::vector<int>{}, // receiver_rtcp_event_log std::vector<int>{}, // receiver_rtcp_dscp supports_wifi_status_reporting_}; @@ -291,9 +338,14 @@ void ReceiverSession::SendMessage(Message* message) { auto body_or_error = json::Stringify(message->body); if (body_or_error.is_value()) { + OSP_DVLOG << "Sending message: SENDER[" << message->sender_id + << "], NAMESPACE[" << message->message_namespace << "], BODY:\n" + << body_or_error.value(); message_port_->PostMessage(message->sender_id, message->message_namespace, body_or_error.value()); } else { + OSP_DLOG_WARN << "Sending message failed with error:\n" + << body_or_error.error(); client_->OnError(this, body_or_error.error()); } } diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver_session.h b/chromium/third_party/openscreen/src/cast/streaming/receiver_session.h index be59f4ec2c7..53c41051f8b 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/receiver_session.h +++ b/chromium/third_party/openscreen/src/cast/streaming/receiver_session.h @@ -7,8 +7,14 @@ #include <memory> #include <string> +#include <utility> #include <vector> +// TODO(jophba): remove public abseil dependencies. Will require modifying +// either Optional or ConfiguredReceivers, as the compiler currently has an +// error. +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "cast/streaming/answer_messages.h" #include "cast/streaming/message_port.h" #include "cast/streaming/offer_messages.h" @@ -23,7 +29,7 @@ class CastSocket; class Environment; class Receiver; class VirtualConnectionRouter; -class VirtualConnection; +struct VirtualConnection; class ReceiverSession final : public MessagePort::Client { public: @@ -51,6 +57,8 @@ class ReceiverSession final : public MessagePort::Client { // If the receiver is audio- or video-only, either of the receivers // may be nullptr. However, in the majority of cases they will be populated. + // TODO(jophba): remove AudioStream, VideoStream from public API. + // TODO(jophba): remove absl::optional from public API. absl::optional<ConfiguredReceiver<AudioStream>> audio; absl::optional<ConfiguredReceiver<VideoStream>> video; }; @@ -73,8 +81,8 @@ class ReceiverSession final : public MessagePort::Client { // The embedder has the option of providing a list of prioritized // preferences for selecting from the offer. - enum class AudioCodec : int { kAac, kOpus }; - enum class VideoCodec : int { kH264, kVp8, kHevc, kVp9 }; + enum class AudioCodec { kAac, kOpus }; + enum class VideoCodec { kH264, kVp8, kHevc, kVp9 }; // Note: embedders are required to implement the following // codecs to be Cast V2 compliant: H264, VP8, AAC, Opus. @@ -126,22 +134,24 @@ class ReceiverSession final : public MessagePort::Client { Json::Value body; }; - // Message handlers + // Specific message type handler methods. void OnOffer(Message* message); + // Used by SpawnReceivers to generate a receiver for a specific stream. std::pair<SessionConfig, std::unique_ptr<Receiver>> ConstructReceiver( const Stream& stream); - // Either stream input to this method may be null, however if both - // are null this method returns error. - ErrorOr<ConfiguredReceivers> TrySpawningReceivers(const AudioStream* audio, - const VideoStream* video); + // Creates a set of configured receivers from a given pair of audio and + // video streams. NOTE: either audio or video may be null, but not both. + ConfiguredReceivers SpawnReceivers(const AudioStream* audio, + const VideoStream* video); // Callers of this method should ensure at least one stream is non-null. Answer ConstructAnswer(Message* message, const AudioStream* audio, const VideoStream* video); + // Sends a message over the message port. void SendMessage(Message* message); // Handles resetting receivers and notifying the client. @@ -152,7 +162,6 @@ class ReceiverSession final : public MessagePort::Client { MessagePort* const message_port_; const Preferences preferences_; - CastMode cast_mode_; bool supports_wifi_status_reporting_ = false; ReceiverPacketRouter packet_router_; diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver_session_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/receiver_session_unittest.cc index ce9a0079cd7..c5894165cef 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/receiver_session_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/receiver_session_unittest.cc @@ -12,6 +12,7 @@ #include "platform/base/ip_address.h" #include "platform/test/fake_clock.h" #include "platform/test/fake_task_runner.h" +#include "util/chrono_helpers.h" using ::testing::_; using ::testing::Invoke; @@ -122,6 +123,38 @@ constexpr char kNoAudioOfferMessage[] = R"({ } })"; +constexpr char kInvalidCodecOfferMessage[] = R"({ + "type": "OFFER", + "seqNum": 1337, + "offer": { + "castMode": "mirroring", + "receiverGetStatus": true, + "supportedStreams": [ + { + "index": 31338, + "type": "video_source", + "codecName": "vp12", + "rtpProfile": "cast", + "rtpPayloadType": 127, + "ssrc": 19088745, + "maxFrameRate": "60000/1000", + "timeBase": "1/90000", + "maxBitRate": 5000000, + "profile": "main", + "level": "4", + "aesKey": "040d756791711fd3adb939066e6d8690", + "aesIvMask": "9ff0f022a959150e70a2d05a6c184aed", + "resolutions": [ + { + "width": 1280, + "height": 720 + } + ] + } + ] + } +})"; + constexpr char kNoVideoOfferMessage[] = R"({ "type": "OFFER", "seqNum": 1337, @@ -158,14 +191,44 @@ constexpr char kNoAudioOrVideoOfferMessage[] = R"({ constexpr char kInvalidJsonOfferMessage[] = R"({ "type": "OFFER", - "seqNum": 1337,,, - "offer": + "seqNum": 1337, + "offer": { "castMode": "mirroring", "receiverGetStatus": true, "supportedStreams": [ } })"; +constexpr char kValidJsonInvalidFormatOfferMessage[] = R"({ + "type": "OFFER", + "seqNum": 1337, + "offer": { + "castMode": "mirroring", + "receiverGetStatus": true, + "supportedStreams": "anything" + } +})"; + +constexpr char kNullJsonOfferMessage[] = R"({ + "type": "OFFER", + "seqNum": 1337 +})"; + +constexpr char kInvalidSequenceNumberMessage[] = R"({ + "type": "OFFER", + "seqNum": "not actually a number" +})"; + +constexpr char kUnknownTypeMessage[] = R"({ + "type": "OFFER_VERSION_2", + "seqNum": 1337 +})"; + +constexpr char kInvalidTypeMessage[] = R"({ + "type": 39, + "seqNum": 1337 +})"; + class SimpleMessagePort : public MessagePort { public: ~SimpleMessagePort() override {} @@ -221,7 +284,6 @@ void ExpectIsErrorAnswerMessage(const ErrorOr<Json::Value>& message_or_error) { const Json::Value& error = message["error"]; EXPECT_TRUE(error.isObject()); EXPECT_GT(error["code"].asInt(), 0); - EXPECT_EQ("", error["description"].asString()); } } // namespace @@ -231,41 +293,38 @@ class ReceiverSessionTest : public ::testing::Test { ReceiverSessionTest() : clock_(Clock::time_point{}), task_runner_(&clock_) {} std::unique_ptr<MockEnvironment> MakeEnvironment() { - auto environment = std::make_unique<NiceMock<MockEnvironment>>( + auto environment_ = std::make_unique<NiceMock<MockEnvironment>>( &FakeClock::now, &task_runner_); - ON_CALL(*environment, GetBoundLocalEndpoint()) + ON_CALL(*environment_, GetBoundLocalEndpoint()) .WillByDefault( Return(IPEndpoint{IPAddress::Parse("127.0.0.1").value(), 12345})); - return environment; + return environment_; } - private: + void SetUp() { + message_port_ = std::make_unique<SimpleMessagePort>(); + environment_ = MakeEnvironment(); + session_ = std::make_unique<ReceiverSession>( + &client_, environment_.get(), message_port_.get(), + ReceiverSession::Preferences{}); + } + + protected: + StrictMock<FakeClient> client_; FakeClock clock_; + std::unique_ptr<MockEnvironment> environment_; + std::unique_ptr<SimpleMessagePort> message_port_; + std::unique_ptr<ReceiverSession> session_; FakeTaskRunner task_runner_; }; -TEST_F(ReceiverSessionTest, RegistersSelfOnMessagePump) { - auto message_port = std::make_unique<SimpleMessagePort>(); - // This should be safe, since the message_port location should not move - // just because of being moved into the ReceiverSession. - StrictMock<FakeClient> client; - - auto environment = MakeEnvironment(); - auto session = std::make_unique<ReceiverSession>( - &client, environment.get(), message_port.get(), - ReceiverSession::Preferences{}); - EXPECT_EQ(message_port->client(), session.get()); +TEST_F(ReceiverSessionTest, RegistersSelfOnMessagePort) { + EXPECT_EQ(message_port_->client(), session_.get()); } TEST_F(ReceiverSessionTest, CanNegotiateWithDefaultPreferences) { - auto message_port = std::make_unique<SimpleMessagePort>(); - StrictMock<FakeClient> client; - auto environment = MakeEnvironment(); - ReceiverSession session(&client, environment.get(), message_port.get(), - ReceiverSession::Preferences{}); - - EXPECT_CALL(client, OnNegotiated(&session, _)) - .WillOnce([](const ReceiverSession* session, + EXPECT_CALL(client_, OnNegotiated(session_.get(), _)) + .WillOnce([](const ReceiverSession* session_, ReceiverSession::ConfiguredReceivers cr) { EXPECT_TRUE(cr.audio); EXPECT_EQ(cr.audio.value().receiver_config.sender_ssrc, 19088747u); @@ -293,11 +352,11 @@ TEST_F(ReceiverSessionTest, CanNegotiateWithDefaultPreferences) { EXPECT_EQ(cr.video.value().selected_stream.stream.codec_name, "vp8"); EXPECT_EQ(cr.video.value().selected_stream.stream.channels, 1); }); - EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(1); + EXPECT_CALL(client_, OnConfiguredReceiversDestroyed(session_.get())).Times(1); - message_port->ReceiveMessage(kValidOfferMessage); + message_port_->ReceiveMessage(kValidOfferMessage); - const auto& messages = message_port->posted_messages(); + const auto& messages = message_port_->posted_messages(); ASSERT_EQ(1u, messages.size()); auto message_body = json::Parse(messages[0]); @@ -314,7 +373,6 @@ TEST_F(ReceiverSessionTest, CanNegotiateWithDefaultPreferences) { // Spot check the answer body fields. We have more in depth testing // of answer behavior in answer_messages_unittest, but here we can // ensure that the ReceiverSession properly configured the answer. - EXPECT_EQ("mirroring", answer_body["castMode"].asString()); EXPECT_EQ(1337, answer_body["sendIndexes"][0].asInt()); EXPECT_EQ(31338, answer_body["sendIndexes"][1].asInt()); EXPECT_LT(0, answer_body["udpPort"].asInt()); @@ -329,16 +387,13 @@ TEST_F(ReceiverSessionTest, CanNegotiateWithDefaultPreferences) { } TEST_F(ReceiverSessionTest, CanNegotiateWithCustomCodecPreferences) { - auto message_port = std::make_unique<SimpleMessagePort>(); - StrictMock<FakeClient> client; - auto environment = MakeEnvironment(); ReceiverSession session( - &client, environment.get(), message_port.get(), + &client_, environment_.get(), message_port_.get(), ReceiverSession::Preferences{{ReceiverSession::VideoCodec::kVp9}, {ReceiverSession::AudioCodec::kOpus}}); - EXPECT_CALL(client, OnNegotiated(&session, _)) - .WillOnce([](const ReceiverSession* session, + EXPECT_CALL(client_, OnNegotiated(&session, _)) + .WillOnce([](const ReceiverSession* session_, ReceiverSession::ConfiguredReceivers cr) { EXPECT_TRUE(cr.audio); EXPECT_EQ(cr.audio.value().receiver_config.sender_ssrc, 19088747u); @@ -353,37 +408,37 @@ TEST_F(ReceiverSessionTest, CanNegotiateWithCustomCodecPreferences) { EXPECT_EQ(cr.video.value().receiver_config.channels, 1); EXPECT_EQ(cr.video.value().receiver_config.rtp_timebase, 90000); }); - EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(1); - message_port->ReceiveMessage(kValidOfferMessage); + EXPECT_CALL(client_, OnConfiguredReceiversDestroyed(&session)).Times(1); + message_port_->ReceiveMessage(kValidOfferMessage); } TEST_F(ReceiverSessionTest, CanNegotiateWithCustomConstraints) { - auto message_port = std::make_unique<SimpleMessagePort>(); - StrictMock<FakeClient> client; - - auto constraints = std::unique_ptr<Constraints>{new Constraints{ + auto constraints = std::make_unique<Constraints>(Constraints{ AudioConstraints{1, 2, 3, 4}, - VideoConstraints{3.14159, Dimensions{320, 240, SimpleFraction{24, 1}}, + + VideoConstraints{3.14159, + absl::optional<Dimensions>( + Dimensions{320, 240, SimpleFraction{24, 1}}), Dimensions{1920, 1080, SimpleFraction{144, 1}}, 3000, - 90000000, std::chrono::milliseconds{1000}}}}; + 90000000, milliseconds(1000)}}); - auto display = std::unique_ptr<DisplayDescription>{new DisplayDescription{ - Dimensions{640, 480, SimpleFraction{60, 1}}, AspectRatio{16, 9}, - AspectRatioConstraint::kFixed}}; + auto display = std::make_unique<DisplayDescription>(DisplayDescription{ + absl::optional<Dimensions>(Dimensions{640, 480, SimpleFraction{60, 1}}), + absl::optional<AspectRatio>(AspectRatio{16, 9}), + absl::optional<AspectRatioConstraint>(AspectRatioConstraint::kFixed)}); - auto environment = MakeEnvironment(); ReceiverSession session( - &client, environment.get(), message_port.get(), + &client_, environment_.get(), message_port_.get(), ReceiverSession::Preferences{{ReceiverSession::VideoCodec::kVp9}, {ReceiverSession::AudioCodec::kOpus}, std::move(constraints), std::move(display)}); - EXPECT_CALL(client, OnNegotiated(&session, _)).Times(1); - EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(1); - message_port->ReceiveMessage(kValidOfferMessage); + EXPECT_CALL(client_, OnNegotiated(&session, _)).Times(1); + EXPECT_CALL(client_, OnConfiguredReceiversDestroyed(&session)).Times(1); + message_port_->ReceiveMessage(kValidOfferMessage); - const auto& messages = message_port->posted_messages(); + const auto& messages = message_port_->posted_messages(); EXPECT_EQ(1u, messages.size()); auto message_body = json::Parse(messages[0]); @@ -430,17 +485,11 @@ TEST_F(ReceiverSessionTest, CanNegotiateWithCustomConstraints) { } TEST_F(ReceiverSessionTest, HandlesNoValidAudioStream) { - auto message_port = std::make_unique<SimpleMessagePort>(); - StrictMock<FakeClient> client; - auto environment = MakeEnvironment(); - ReceiverSession session(&client, environment.get(), message_port.get(), - ReceiverSession::Preferences{}); - - EXPECT_CALL(client, OnNegotiated(&session, _)).Times(1); - EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(1); + EXPECT_CALL(client_, OnNegotiated(session_.get(), _)).Times(1); + EXPECT_CALL(client_, OnConfiguredReceiversDestroyed(session_.get())).Times(1); - message_port->ReceiveMessage(kNoAudioOfferMessage); - const auto& messages = message_port->posted_messages(); + message_port_->ReceiveMessage(kNoAudioOfferMessage); + const auto& messages = message_port_->posted_messages(); EXPECT_EQ(1u, messages.size()); auto message_body = json::Parse(messages[0]); @@ -455,18 +504,26 @@ TEST_F(ReceiverSessionTest, HandlesNoValidAudioStream) { EXPECT_EQ(19088746, answer_body["ssrcs"][0].asInt()); } -TEST_F(ReceiverSessionTest, HandlesNoValidVideoStream) { - auto message_port = std::make_unique<SimpleMessagePort>(); - StrictMock<FakeClient> client; - auto environment = MakeEnvironment(); - ReceiverSession session(&client, environment.get(), message_port.get(), - ReceiverSession::Preferences{}); +TEST_F(ReceiverSessionTest, HandlesInvalidCodec) { + // We didn't select any streams, but didn't have any errors either. + message_port_->ReceiveMessage(kInvalidCodecOfferMessage); + const auto& messages = message_port_->posted_messages(); + EXPECT_EQ(1u, messages.size()); + + auto message_body = json::Parse(messages[0]); + EXPECT_TRUE(message_body.is_value()); + + // We should have failed to produce a valid answer message due to not + // selecting any stream. + EXPECT_EQ("error", message_body.value()["result"].asString()); +} - EXPECT_CALL(client, OnNegotiated(&session, _)).Times(1); - EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(1); +TEST_F(ReceiverSessionTest, HandlesNoValidVideoStream) { + EXPECT_CALL(client_, OnNegotiated(session_.get(), _)).Times(1); + EXPECT_CALL(client_, OnConfiguredReceiversDestroyed(session_.get())).Times(1); - message_port->ReceiveMessage(kNoVideoOfferMessage); - const auto& messages = message_port->posted_messages(); + message_port_->ReceiveMessage(kNoVideoOfferMessage); + const auto& messages = message_port_->posted_messages(); EXPECT_EQ(1u, messages.size()); auto message_body = json::Parse(messages[0]); @@ -482,19 +539,9 @@ TEST_F(ReceiverSessionTest, HandlesNoValidVideoStream) { } TEST_F(ReceiverSessionTest, HandlesNoValidStreams) { - auto message_port = std::make_unique<SimpleMessagePort>(); - StrictMock<FakeClient> client; - - auto environment = MakeEnvironment(); - ReceiverSession session(&client, environment.get(), message_port.get(), - ReceiverSession::Preferences{}); - // We shouldn't call OnNegotiated if we failed to negotiate any streams. - EXPECT_CALL(client, OnNegotiated(&session, _)).Times(0); - EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(0); - - message_port->ReceiveMessage(kNoAudioOrVideoOfferMessage); - const auto& messages = message_port->posted_messages(); + message_port_->ReceiveMessage(kNoAudioOrVideoOfferMessage); + const auto& messages = message_port_->posted_messages(); EXPECT_EQ(1u, messages.size()); auto message_body = json::Parse(messages[0]); @@ -502,35 +549,77 @@ TEST_F(ReceiverSessionTest, HandlesNoValidStreams) { } TEST_F(ReceiverSessionTest, HandlesMalformedOffer) { - auto message_port = std::make_unique<SimpleMessagePort>(); - StrictMock<FakeClient> client; - auto environment = MakeEnvironment(); - ReceiverSession session(&client, environment.get(), message_port.get(), - ReceiverSession::Preferences{}); - - // We shouldn't call OnNegotiated if we failed to negotiate any streams. // Note that unlike when we simply don't select any streams, when the offer // is actually completely invalid we call OnError. - EXPECT_CALL(client, OnNegotiated(&session, _)).Times(0); - EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(0); - EXPECT_CALL(client, OnError(&session, Error(Error::Code::kJsonParseError))) + EXPECT_CALL(client_, + OnError(session_.get(), Error(Error::Code::kJsonParseError))) .Times(1); - message_port->ReceiveMessage(kInvalidJsonOfferMessage); + message_port_->ReceiveMessage(kInvalidJsonOfferMessage); +} + +TEST_F(ReceiverSessionTest, HandlesImproperlyFormattedOffer) { + EXPECT_CALL(client_, + OnError(session_.get(), + Error(Error::Code::kJsonParseError, + "Failed to parse supported streams in offer"))) + .Times(1); + + message_port_->ReceiveMessage(kValidJsonInvalidFormatOfferMessage); +} + +TEST_F(ReceiverSessionTest, HandlesNullOffer) { + EXPECT_CALL(client_, OnError(session_.get(), + Error(Error::Code::kJsonParseError, + "Received offer missing offer body"))) + .Times(1); + + message_port_->ReceiveMessage(kNullJsonOfferMessage); +} + +TEST_F(ReceiverSessionTest, HandlesInvalidSequenceNumber) { + // We should just discard messages with an invalid sequence number. + message_port_->ReceiveMessage(kInvalidSequenceNumberMessage); +} + +TEST_F(ReceiverSessionTest, HandlesUnknownTypeMessage) { + // We should just discard messages with an unknown message type. + message_port_->ReceiveMessage(kUnknownTypeMessage); +} + +TEST_F(ReceiverSessionTest, HandlesInvalidTypeMessage) { + // We should just discard messages with an invalid message type. + message_port_->ReceiveMessage(kInvalidTypeMessage); +} + +TEST_F(ReceiverSessionTest, DoesntCrashOnMessagePortError) { + message_port_->ReceiveError(Error(Error::Code::kUnknownError)); } TEST_F(ReceiverSessionTest, NotifiesReceiverDestruction) { - auto message_port = std::make_unique<SimpleMessagePort>(); - StrictMock<FakeClient> client; - auto environment = MakeEnvironment(); - ReceiverSession session(&client, environment.get(), message_port.get(), - ReceiverSession::Preferences{}); + EXPECT_CALL(client_, OnNegotiated(session_.get(), _)).Times(2); + EXPECT_CALL(client_, OnConfiguredReceiversDestroyed(session_.get())).Times(2); + + message_port_->ReceiveMessage(kNoAudioOfferMessage); + message_port_->ReceiveMessage(kValidOfferMessage); +} + +TEST_F(ReceiverSessionTest, HandlesInvalidAnswer) { + // Simulate an unbound local endpoint. + EXPECT_CALL(*environment_, GetBoundLocalEndpoint).WillOnce([]() { + return IPEndpoint{}; + }); - EXPECT_CALL(client, OnNegotiated(&session, _)).Times(2); - EXPECT_CALL(client, OnConfiguredReceiversDestroyed(&session)).Times(2); + message_port_->ReceiveMessage(kValidOfferMessage); + const auto& messages = message_port_->posted_messages(); + ASSERT_EQ(1u, messages.size()); - message_port->ReceiveMessage(kNoAudioOfferMessage); - message_port->ReceiveMessage(kValidOfferMessage); + auto message_body = json::Parse(messages[0]); + EXPECT_TRUE(message_body.is_value()); + const Json::Value answer = std::move(message_body.value()); + + EXPECT_EQ("ANSWER", answer["type"].asString()); + EXPECT_EQ("error", answer["result"].asString()); } } // namespace cast } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/receiver_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/receiver_unittest.cc index b7a6187a956..d3c77178c39 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/receiver_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/receiver_unittest.cc @@ -35,13 +35,9 @@ #include "platform/base/udp_packet.h" #include "platform/test/fake_clock.h" #include "platform/test/fake_task_runner.h" +#include "util/chrono_helpers.h" #include "util/osp_logging.h" -using std::chrono::duration_cast; -using std::chrono::microseconds; -using std::chrono::milliseconds; -using std::chrono::seconds; - using testing::_; using testing::AtLeast; using testing::Gt; @@ -404,11 +400,10 @@ TEST_F(ReceiverTest, ReceivesAndSendsRtcpPackets) { // from the wire-format NtpTimestamps. See the unit tests in // ntp_time_unittest.cc for further discussion. constexpr auto kAllowedNtpRoundingError = microseconds(2); - EXPECT_NEAR(duration_cast<microseconds>(kOneWayNetworkDelay).count(), - duration_cast<microseconds>(receiver_reference_time - - sender_reference_time) - .count(), - kAllowedNtpRoundingError.count()); + EXPECT_NEAR( + to_microseconds(kOneWayNetworkDelay).count(), + to_microseconds(receiver_reference_time - sender_reference_time).count(), + kAllowedNtpRoundingError.count()); // Without the Sender doing anything, the Receiver should continue providing // RTCP reports at regular intervals. Simulate three intervals of time, diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.cc b/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.cc index 696849bb501..ce1e42d96ec 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/rtcp_common.cc @@ -4,6 +4,7 @@ #include "cast/streaming/rtcp_common.h" +#include <algorithm> #include <limits> #include "cast/streaming/packet_util.h" @@ -187,7 +188,7 @@ void RtcpReportBlock::SetDelaySinceLastReport( // math (well, only for unusually large inputs). constexpr Delay kMaxValidReportedDelay{std::numeric_limits<uint32_t>::max()}; constexpr auto kMaxValidLocalClockDelay = - std::chrono::duration_cast<Clock::duration>(kMaxValidReportedDelay); + Clock::to_duration(kMaxValidReportedDelay); if (local_clock_delay > kMaxValidLocalClockDelay) { delay_since_last_report = kMaxValidReportedDelay; return; diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtcp_common_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/rtcp_common_unittest.cc index d593e4f02ee..14aaa7eea50 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtcp_common_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/rtcp_common_unittest.cc @@ -10,6 +10,7 @@ #include "absl/types/span.h" #include "gtest/gtest.h" #include "platform/api/time.h" +#include "util/chrono_helpers.h" namespace openscreen { namespace cast { @@ -280,28 +281,23 @@ TEST(RtcpCommonTest, ComputesDelayForReportBlocks) { // A duration less than or equal to zero should clamp to zero. EXPECT_EQ(Delay::zero(), ComputeDelay(Clock::duration::min())); - EXPECT_EQ(Delay::zero(), ComputeDelay(std::chrono::milliseconds(-1234))); + EXPECT_EQ(Delay::zero(), ComputeDelay(milliseconds{-1234})); EXPECT_EQ(Delay::zero(), ComputeDelay(Clock::duration::zero())); // Test conversion of various durations that should not clamp. EXPECT_EQ(Delay(32768 /* 1/2 second worth of ticks */), - ComputeDelay(std::chrono::milliseconds(500))); + ComputeDelay(milliseconds(500))); EXPECT_EQ(Delay(65536 /* 1 second worth of ticks */), - ComputeDelay(std::chrono::seconds(1))); + ComputeDelay(seconds(1))); EXPECT_EQ(Delay(655360 /* 10 seconds worth of ticks */), - ComputeDelay(std::chrono::seconds(10))); - EXPECT_EQ(Delay(4294967294), - ComputeDelay(std::chrono::microseconds(65535999983))); - EXPECT_EQ(Delay(4294967294), - ComputeDelay(std::chrono::microseconds(65535999984))); + ComputeDelay(seconds(10))); + EXPECT_EQ(Delay(4294967294), ComputeDelay(microseconds(65535999983))); + EXPECT_EQ(Delay(4294967294), ComputeDelay(microseconds(65535999984))); // A too-large duration should clamp to the maximum-possible Delay value. - EXPECT_EQ(Delay(4294967295), - ComputeDelay(std::chrono::microseconds(65535999985))); - EXPECT_EQ(Delay(4294967295), - ComputeDelay(std::chrono::microseconds(65535999986))); - EXPECT_EQ(Delay(4294967295), - ComputeDelay(std::chrono::microseconds(999999000000))); + EXPECT_EQ(Delay(4294967295), ComputeDelay(microseconds(65535999985))); + EXPECT_EQ(Delay(4294967295), ComputeDelay(microseconds(65535999986))); + EXPECT_EQ(Delay(4294967295), ComputeDelay(microseconds(999999000000))); EXPECT_EQ(Delay(4294967295), ComputeDelay(Clock::duration::max())); } diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer_unittest.cc index d99eefa011e..8b4710659d6 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_packetizer_unittest.cc @@ -4,12 +4,16 @@ #include "cast/streaming/rtp_packetizer.h" +#include <chrono> +#include <memory> + #include "absl/types/optional.h" #include "cast/streaming/frame_crypto.h" #include "cast/streaming/rtp_defines.h" #include "cast/streaming/rtp_packet_parser.h" #include "cast/streaming/ssrc.h" #include "gtest/gtest.h" +#include "util/chrono_helpers.h" namespace openscreen { namespace cast { @@ -34,7 +38,7 @@ class RtpPacketizerTest : public testing::Test { EncryptedFrame CreateFrame(FrameId frame_id, bool is_key_frame, - std::chrono::milliseconds new_playout_delay, + milliseconds new_playout_delay, int payload_size) const { EncodedFrame frame; frame.dependency = is_key_frame ? EncodedFrame::KEY_FRAME @@ -102,7 +106,7 @@ class RtpPacketizerTest : public testing::Test { if (packet_id == FramePacketId{0}) { EXPECT_EQ(frame.new_playout_delay, result->new_playout_delay); } else { - EXPECT_EQ(std::chrono::milliseconds(0), result->new_playout_delay); + EXPECT_EQ(milliseconds(0), result->new_playout_delay); } // Check that the RTP payload is correct for this packet. @@ -141,8 +145,8 @@ TEST_F(RtpPacketizerTest, GeneratesPacketsForSequenceOfFrames) { const bool is_key_frame = (i == 0); const int frame_payload_size = is_key_frame ? 48269 : 10000; const EncryptedFrame frame = - CreateFrame(FrameId::first() + i, is_key_frame, - std::chrono::milliseconds(0), frame_payload_size); + CreateFrame(FrameId::first() + i, is_key_frame, milliseconds(0), + frame_payload_size); SCOPED_TRACE(testing::Message() << "frame_id=" << frame.frame_id); const int num_packets = packetizer()->ComputeNumberOfPackets(frame); ASSERT_EQ(is_key_frame ? 34 : 7, num_packets); @@ -160,9 +164,8 @@ TEST_F(RtpPacketizerTest, GeneratesPacketsForSequenceOfFrames) { // delay change. Only the first packet should mention the playout delay change. TEST_F(RtpPacketizerTest, GeneratesPacketsForFrameWithLatencyChange) { const int frame_payload_size = 38383; - const EncryptedFrame frame = - CreateFrame(FrameId::first() + 42, true, std::chrono::milliseconds(543), - frame_payload_size); + const EncryptedFrame frame = CreateFrame( + FrameId::first() + 42, true, milliseconds(543), frame_payload_size); const int num_packets = packetizer()->ComputeNumberOfPackets(frame); ASSERT_EQ(27, num_packets); @@ -179,9 +182,8 @@ TEST_F(RtpPacketizerTest, GeneratesPacketsForFrameWithLatencyChange) { // silence can be represented by an empty payload). TEST_F(RtpPacketizerTest, GeneratesOnePacketForFrameWithNoPayload) { const int frame_payload_size = 0; - const EncryptedFrame frame = - CreateFrame(FrameId::first() + 99, false, std::chrono::milliseconds(0), - frame_payload_size); + const EncryptedFrame frame = CreateFrame(FrameId::first() + 99, false, + milliseconds(0), frame_payload_size); ASSERT_EQ(1, packetizer()->ComputeNumberOfPackets(frame)); TestGeneratePacket(frame, FramePacketId{0}); } @@ -190,8 +192,8 @@ TEST_F(RtpPacketizerTest, GeneratesOnePacketForFrameWithNoPayload) { // a different sequence counter value in the packet each time. TEST_F(RtpPacketizerTest, GeneratesPacketForRetransmission) { const int frame_payload_size = 16384; - const EncryptedFrame frame = CreateFrame( - FrameId::first(), true, std::chrono::milliseconds(0), frame_payload_size); + const EncryptedFrame frame = + CreateFrame(FrameId::first(), true, milliseconds(0), frame_payload_size); const int num_packets = packetizer()->ComputeNumberOfPackets(frame); ASSERT_EQ(12, num_packets); diff --git a/chromium/third_party/openscreen/src/cast/streaming/rtp_time_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/rtp_time_unittest.cc index 8427a5ea201..54173a8b0e2 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/rtp_time_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/rtp_time_unittest.cc @@ -4,7 +4,10 @@ #include "cast/streaming/rtp_time.h" +#include <chrono> + #include "gtest/gtest.h" +#include "util/chrono_helpers.h" namespace openscreen { namespace cast { @@ -13,10 +16,6 @@ namespace cast { // accurate. Note that this implicitly tests the conversions to/from // RtpTimeTicks as well due to shared implementation. TEST(RtpTimeDeltaTest, ConversionToAndFromDurations) { - using std::chrono::microseconds; - using std::chrono::milliseconds; - using std::chrono::seconds; - constexpr int kTimebase = 48000; // Origin in both timelines is equivalent. diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender.cc b/chromium/third_party/openscreen/src/cast/streaming/sender.cc index a09b73678dd..3713a199683 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/sender.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/sender.cc @@ -5,19 +5,17 @@ #include "cast/streaming/sender.h" #include <algorithm> -#include <ratio> // NOLINT +#include <chrono> +#include <ratio> #include "cast/streaming/session_config.h" +#include "util/chrono_helpers.h" #include "util/osp_logging.h" #include "util/std_util.h" namespace openscreen { namespace cast { -using std::chrono::duration_cast; -using std::chrono::microseconds; -using std::chrono::milliseconds; - using openscreen::operator<<; // For std::chrono::duration logging. Sender::Sender(Environment* environment, @@ -264,7 +262,7 @@ void Sender::OnReceiverReport(const RtcpReportBlock& receiver_report) { sender_report_builder_.GetRecentReportTime( receiver_report.last_status_report_id, rtcp_packet_arrival_time_); const auto non_network_delay = - duration_cast<Clock::duration>(receiver_report.delay_since_last_report); + Clock::to_duration(receiver_report.delay_since_last_report); // Round trip time measurement: This is the time elapsed since the Sender // Report was sent, minus the time the Receiver did other stuff before sending @@ -275,8 +273,7 @@ void Sender::OnReceiverReport(const RtcpReportBlock& receiver_report) { // true value is likely very close to zero (i.e., this is ideal network // behavior); and so just represent this as 75 µs, an optimistic // wired-Ethernet LAN ping time. - constexpr auto kNearZeroRoundTripTime = - duration_cast<Clock::duration>(microseconds(75)); + constexpr auto kNearZeroRoundTripTime = Clock::to_duration(microseconds(75)); static_assert(kNearZeroRoundTripTime > Clock::duration::zero(), "More precision in Clock::duration needed!"); const Clock::duration measurement = @@ -412,11 +409,12 @@ void Sender::OnReceiverIsMissingPackets(std::vector<PacketNack> nacks) { if (!slot) { // TODO(miu): Add tracing event here to record this. for (++nack_it; nack_it != nacks.end() && nack_it->frame_id == frame_id; - ++nack_it) - ; + ++nack_it) { + } continue; } + // NOLINTNEXTLINE latest_expected_frame_id_ = std::max(latest_expected_frame_id_, frame_id); const auto HandleIndividualNack = [&](FramePacketId packet_id) { @@ -500,8 +498,8 @@ Sender::ChosenPacketAndWhen Sender::ChooseKickstartPacket() { // arrivals. using kWaitFraction = std::ratio<1, 20>; const Clock::duration desired_kickstart_interval = - duration_cast<Clock::duration>(target_playout_delay_) * - kWaitFraction::num / kWaitFraction::den; + Clock::to_duration(target_playout_delay_) * kWaitFraction::num / + kWaitFraction::den; // The actual interval used is increased, if current network performance // warrants waiting longer. Don't send a Kickstart packet until no NACKs // have been received for two network round-trip periods. diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender.h b/chromium/third_party/openscreen/src/cast/streaming/sender.h index 48f651c04de..33ea5227318 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/sender.h +++ b/chromium/third_party/openscreen/src/cast/streaming/sender.h @@ -8,7 +8,7 @@ #include <stdint.h> #include <array> -#include <chrono> // NOLINT +#include <chrono> #include <vector> #include "absl/types/span.h" diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.cc index b505b996d7b..c3fccd1829d 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.cc @@ -9,6 +9,7 @@ #include "cast/streaming/constants.h" #include "cast/streaming/packet_util.h" +#include "util/chrono_helpers.h" #include "util/osp_logging.h" #include "util/saturate_cast.h" #include "util/stringprintf.h" @@ -16,10 +17,6 @@ namespace openscreen { namespace cast { -using std::chrono::duration_cast; -using std::chrono::milliseconds; -using std::chrono::seconds; - SenderPacketRouter::SenderPacketRouter(Environment* environment, int max_burst_bitrate) : SenderPacketRouter( @@ -227,8 +224,7 @@ int SenderPacketRouter::SendJustTheRtpPackets(Clock::time_point send_time, namespace { constexpr int kBitsPerByte = 8; -constexpr auto kOneSecondInMilliseconds = - duration_cast<milliseconds>(seconds(1)); +constexpr auto kOneSecondInMilliseconds = to_microseconds(seconds(1)); } // namespace // static diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.h b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.h index e73e73eca3b..a6c161a91f8 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.h +++ b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router.h @@ -7,7 +7,7 @@ #include <stdint.h> -#include <chrono> // NOLINT +#include <chrono> #include <memory> #include <vector> diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router_unittest.cc index 16cd6a937cd..3c1d1f6aa07 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/sender_packet_router_unittest.cc @@ -4,18 +4,19 @@ #include "cast/streaming/sender_packet_router.h" +#include <chrono> + #include "cast/streaming/constants.h" +#include "cast/streaming/mock_environment.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "platform/base/ip_address.h" #include "platform/test/fake_clock.h" #include "platform/test/fake_task_runner.h" #include "util/big_endian.h" +#include "util/chrono_helpers.h" #include "util/osp_logging.h" -using std::chrono::milliseconds; -using std::chrono::seconds; - using testing::_; using testing::Invoke; using testing::Mock; @@ -126,16 +127,6 @@ absl::Span<uint8_t> ToEmptyPacketBuffer(Clock::time_point send_time, return buffer.subspan(0, 0); } -class MockEnvironment : public Environment { - public: - MockEnvironment(ClockNowFunctionPtr now_function, TaskRunner* task_runner) - : Environment(now_function, task_runner) {} - - ~MockEnvironment() override = default; - - MOCK_METHOD1(SendPacket, void(absl::Span<const uint8_t> packet)); -}; - class MockSender : public SenderPacketRouter::Sender { public: MockSender() = default; diff --git a/chromium/third_party/openscreen/src/cast/streaming/sender_unittest.cc b/chromium/third_party/openscreen/src/cast/streaming/sender_unittest.cc index 1c93fceed8c..7b1621616f6 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/sender_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/sender_unittest.cc @@ -8,10 +8,11 @@ #include <algorithm> #include <array> -#include <chrono> // NOLINT +#include <chrono> #include <limits> #include <map> #include <set> +#include <utility> #include <vector> #include "absl/types/optional.h" @@ -36,14 +37,9 @@ #include "platform/test/fake_clock.h" #include "platform/test/fake_task_runner.h" #include "util/alarm.h" +#include "util/chrono_helpers.h" #include "util/yet_another_bit_vector.h" -using std::chrono::duration_cast; -using std::chrono::microseconds; -using std::chrono::milliseconds; -using std::chrono::nanoseconds; -using std::chrono::seconds; - using testing::_; using testing::AtLeast; using testing::Invoke; @@ -528,8 +524,7 @@ TEST_F(SenderTest, ComputesInFlightMediaDuration) { TEST_F(SenderTest, RespondsToNetworkLatencyChanges) { // The expected maximum error in time calculations is one tick of the RTCP // report block's delay type. - constexpr auto kEpsilon = - duration_cast<nanoseconds>(RtcpReportBlock::Delay(1)); + constexpr auto kEpsilon = to_nanoseconds(RtcpReportBlock::Delay(1)); // Before the Sender has the necessary information to compute the network // round-trip time, GetMaxInFlightMediaDuration() will return half the target @@ -572,8 +567,8 @@ TEST_F(SenderTest, RespondsToNetworkLatencyChanges) { // Create the Receiver Report "reply," and simulate it being sent across the // network, back to the Sender. receiver()->SetReceiverReport( - sender_report_id, - duration_cast<RtcpReportBlock::Delay>(kReceiverProcessingDelay)); + sender_report_id, std::chrono::duration_cast<RtcpReportBlock::Delay>( + kReceiverProcessingDelay)); receiver()->TransmitRtcpFeedbackPacket(); SimulateExecution(kInboundDelay); diff --git a/chromium/third_party/openscreen/src/cast/streaming/session_config.cc b/chromium/third_party/openscreen/src/cast/streaming/session_config.cc index 651170294c1..f6f4aade45e 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/session_config.cc +++ b/chromium/third_party/openscreen/src/cast/streaming/session_config.cc @@ -4,6 +4,8 @@ #include "cast/streaming/session_config.h" +#include <utility> + namespace openscreen { namespace cast { @@ -19,8 +21,15 @@ SessionConfig::SessionConfig(Ssrc sender_ssrc, rtp_timebase(rtp_timebase), channels(channels), target_playout_delay(target_playout_delay), - aes_secret_key(aes_secret_key), - aes_iv_mask(aes_iv_mask) {} + aes_secret_key(std::move(aes_secret_key)), + aes_iv_mask(std::move(aes_iv_mask)) {} + +SessionConfig::SessionConfig(const SessionConfig& other) = default; +SessionConfig::SessionConfig(SessionConfig&& other) noexcept = default; +SessionConfig& SessionConfig::operator=(const SessionConfig& other) = default; +SessionConfig& SessionConfig::operator=(SessionConfig&& other) noexcept = + default; +SessionConfig::~SessionConfig() = default; } // namespace cast } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/cast/streaming/session_config.h b/chromium/third_party/openscreen/src/cast/streaming/session_config.h index 4d611b65e2e..cf87667e8fc 100644 --- a/chromium/third_party/openscreen/src/cast/streaming/session_config.h +++ b/chromium/third_party/openscreen/src/cast/streaming/session_config.h @@ -6,7 +6,7 @@ #define CAST_STREAMING_SESSION_CONFIG_H_ #include <array> -#include <chrono> // NOLINT +#include <chrono> #include <cstdint> #include "cast/streaming/ssrc.h" @@ -25,11 +25,11 @@ struct SessionConfig final { std::chrono::milliseconds target_playout_delay, std::array<uint8_t, 16> aes_secret_key, std::array<uint8_t, 16> aes_iv_mask); - SessionConfig(const SessionConfig&) = default; - SessionConfig(SessionConfig&&) noexcept = default; - SessionConfig& operator=(const SessionConfig&) = default; - SessionConfig& operator=(SessionConfig&&) noexcept = default; - ~SessionConfig() = default; + SessionConfig(const SessionConfig& other); + SessionConfig(SessionConfig&& other) noexcept; + SessionConfig& operator=(const SessionConfig& other); + SessionConfig& operator=(SessionConfig&& other) noexcept; + ~SessionConfig(); // The sender and receiver's SSRC identifiers. Note: SSRC identifiers // are defined as unsigned 32 bit integers here: diff --git a/chromium/third_party/openscreen/src/discovery/BUILD.gn b/chromium/third_party/openscreen/src/discovery/BUILD.gn index 51f695b24b0..923b8fccfa3 100644 --- a/chromium/third_party/openscreen/src/discovery/BUILD.gn +++ b/chromium/third_party/openscreen/src/discovery/BUILD.gn @@ -11,9 +11,7 @@ source_set("common") { "common/reporting_client.h", ] - deps = [ - "../util", - ] + deps = [ "../util" ] public_deps = [ "../platform", @@ -54,9 +52,7 @@ source_set("mdns") { "mdns/public/mdns_service.h", ] - deps = [ - "../util", - ] + deps = [ "../util" ] public_deps = [ ":common", @@ -69,8 +65,8 @@ source_set("dnssd") { sources = [ "dnssd/impl/conversion_layer.cc", "dnssd/impl/conversion_layer.h", - "dnssd/impl/dns_data.cc", - "dnssd/impl/dns_data.h", + "dnssd/impl/dns_data_graph.cc", + "dnssd/impl/dns_data_graph.h", "dnssd/impl/instance_key.cc", "dnssd/impl/instance_key.h", "dnssd/impl/network_interface_config.cc", @@ -137,11 +133,12 @@ source_set("unittests") { sources = [ "dnssd/impl/conversion_layer_unittest.cc", - "dnssd/impl/dns_data_unittest.cc", + "dnssd/impl/dns_data_graph_unittest.cc", "dnssd/impl/instance_key_unittest.cc", "dnssd/impl/publisher_impl_unittest.cc", "dnssd/impl/querier_impl_unittest.cc", "dnssd/impl/service_key_unittest.cc", + "dnssd/public/dns_sd_instance_endpoint_unittest.cc", "dnssd/public/dns_sd_instance_unittest.cc", "dnssd/public/dns_sd_txt_record_unittest.cc", "mdns/mdns_probe_manager_unittest.cc", @@ -170,13 +167,9 @@ source_set("unittests") { } openscreen_fuzzer_test("mdns_fuzzer") { - sources = [ - "mdns/mdns_reader_fuzztest.cc", - ] + sources = [ "mdns/mdns_reader_fuzztest.cc" ] - deps = [ - ":mdns", - ] + deps = [ ":mdns" ] seed_corpus = "mdns/fuzzer_seeds" diff --git a/chromium/third_party/openscreen/src/discovery/common/testing/mock_reporting_client.h b/chromium/third_party/openscreen/src/discovery/common/testing/mock_reporting_client.h index 4e3063c5d3a..9ba1de28d9c 100644 --- a/chromium/third_party/openscreen/src/discovery/common/testing/mock_reporting_client.h +++ b/chromium/third_party/openscreen/src/discovery/common/testing/mock_reporting_client.h @@ -12,6 +12,7 @@ namespace openscreen { namespace discovery { class MockReportingClient : public ReportingClient { + public: MOCK_METHOD1(OnFatalError, void(Error error)); MOCK_METHOD1(OnRecoverableError, void(Error error)); }; diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/DEPS b/chromium/third_party/openscreen/src/discovery/dnssd/impl/DEPS index 0c34a54b3d9..243d363df56 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/DEPS +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/DEPS @@ -1,4 +1,4 @@ -# -*- Mode: Python; -*- +#- * - Mode : Python; - * - include_rules = [ '+discovery/dnssd/public', diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.cc index 53d745946e8..75fab14d5c5 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer.cc @@ -4,6 +4,8 @@ #include "discovery/dnssd/impl/conversion_layer.h" +#include <utility> + #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/types/optional.h" @@ -67,27 +69,33 @@ MdnsRecord CreateSrvRecord(const DnsSdInstance& instance, kSrvRecordTtl, std::move(data)); } -absl::optional<MdnsRecord> CreateARecord(const DnsSdInstanceEndpoint& endpoint, - const DomainName& domain) { - if (!endpoint.address_v4()) { - return absl::nullopt; +std::vector<MdnsRecord> CreateARecords(const DnsSdInstanceEndpoint& endpoint, + const DomainName& domain) { + std::vector<MdnsRecord> records; + for (const IPAddress& address : endpoint.addresses()) { + if (address.IsV4()) { + ARecordRdata data(address); + records.emplace_back(domain, DnsType::kA, DnsClass::kIN, + RecordType::kUnique, kARecordTtl, std::move(data)); + } } - ARecordRdata data(endpoint.address_v4()); - return MdnsRecord(domain, DnsType::kA, DnsClass::kIN, RecordType::kUnique, - kARecordTtl, std::move(data)); + return records; } -absl::optional<MdnsRecord> CreateAAAARecord( - const DnsSdInstanceEndpoint& endpoint, - const DomainName& domain) { - if (!endpoint.address_v6()) { - return absl::nullopt; +std::vector<MdnsRecord> CreateAAAARecords(const DnsSdInstanceEndpoint& endpoint, + const DomainName& domain) { + std::vector<MdnsRecord> records; + for (const IPAddress& address : endpoint.addresses()) { + if (address.IsV6()) { + AAAARecordRdata data(address); + records.emplace_back(domain, DnsType::kAAAA, DnsClass::kIN, + RecordType::kUnique, kAAAARecordTtl, + std::move(data)); + } } - AAAARecordRdata data(endpoint.address_v6()); - return MdnsRecord(domain, DnsType::kAAAA, DnsClass::kIN, RecordType::kUnique, - kAAAARecordTtl, std::move(data)); + return records; } MdnsRecord CreateTxtRecord(const DnsSdInstance& endpoint, @@ -181,15 +189,11 @@ std::vector<MdnsRecord> GetDnsRecords(const DnsSdInstanceEndpoint& endpoint) { std::vector<MdnsRecord> records = GetDnsRecords(static_cast<DnsSdInstance>(endpoint)); - auto v4 = CreateARecord(endpoint, domain); - if (v4.has_value()) { - records.push_back(std::move(v4.value())); - } + std::vector<MdnsRecord> v4 = CreateARecords(endpoint, domain); + std::vector<MdnsRecord> v6 = CreateAAAARecords(endpoint, domain); - auto v6 = CreateAAAARecord(endpoint, domain); - if (v6.has_value()) { - records.push_back(std::move(v6.value())); - } + records.insert(records.end(), v4.begin(), v4.end()); + records.insert(records.end(), v6.begin(), v6.end()); return records; } diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer_unittest.cc index 4559ec2bbd0..b43a281aaec 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/conversion_layer_unittest.cc @@ -98,12 +98,11 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsPtr) { DnsSdTxtRecord txt; DnsSdInstanceEndpoint instance_endpoint( FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName, - FakeDnsRecordFactory::kDomainName, txt, + FakeDnsRecordFactory::kDomainName, txt, 0, IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets), FakeDnsRecordFactory::kPortNum}, IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets), - FakeDnsRecordFactory::kPortNum}, - 0); + FakeDnsRecordFactory::kPortNum}); std::vector<MdnsRecord> records = GetDnsRecords(instance_endpoint); auto it = std::find_if(records.begin(), records.end(), [](const MdnsRecord& record) { @@ -135,12 +134,11 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsSrv) { DnsSdTxtRecord txt; DnsSdInstanceEndpoint instance_endpoint( FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName, - FakeDnsRecordFactory::kDomainName, txt, + FakeDnsRecordFactory::kDomainName, txt, 0, IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets), FakeDnsRecordFactory::kPortNum}, IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets), - FakeDnsRecordFactory::kPortNum}, - 0); + FakeDnsRecordFactory::kPortNum}); std::vector<MdnsRecord> records = GetDnsRecords(instance_endpoint); auto it = std::find_if(records.begin(), records.end(), [](const MdnsRecord& record) { @@ -168,12 +166,11 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsAPresent) { DnsSdTxtRecord txt; DnsSdInstanceEndpoint instance_endpoint( FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName, - FakeDnsRecordFactory::kDomainName, txt, + FakeDnsRecordFactory::kDomainName, txt, 0, IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets), FakeDnsRecordFactory::kPortNum}, IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets), - FakeDnsRecordFactory::kPortNum}, - 0); + FakeDnsRecordFactory::kPortNum}); std::vector<MdnsRecord> records = GetDnsRecords(instance_endpoint); auto it = std::find_if(records.begin(), records.end(), [](const MdnsRecord& record) { @@ -200,10 +197,9 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsANotPresent) { DnsSdTxtRecord txt; DnsSdInstanceEndpoint instance_endpoint( FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName, - FakeDnsRecordFactory::kDomainName, txt, + FakeDnsRecordFactory::kDomainName, txt, 0, IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets), - FakeDnsRecordFactory::kPortNum}, - 0); + FakeDnsRecordFactory::kPortNum}); std::vector<MdnsRecord> records = GetDnsRecords(instance_endpoint); auto it = std::find_if(records.begin(), records.end(), [](const MdnsRecord& record) { @@ -216,12 +212,11 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsAAAAPresent) { DnsSdTxtRecord txt; DnsSdInstanceEndpoint instance_endpoint( FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName, - FakeDnsRecordFactory::kDomainName, txt, + FakeDnsRecordFactory::kDomainName, txt, 0, IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets), FakeDnsRecordFactory::kPortNum}, IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets), - FakeDnsRecordFactory::kPortNum}, - 0); + FakeDnsRecordFactory::kPortNum}); std::vector<MdnsRecord> records = GetDnsRecords(instance_endpoint); auto it = std::find_if(records.begin(), records.end(), [](const MdnsRecord& record) { @@ -248,10 +243,9 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsAAAANotPresent) { DnsSdTxtRecord txt; DnsSdInstanceEndpoint instance_endpoint( FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName, - FakeDnsRecordFactory::kDomainName, txt, + FakeDnsRecordFactory::kDomainName, txt, 0, IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets), - FakeDnsRecordFactory::kPortNum}, - 0); + FakeDnsRecordFactory::kPortNum}); std::vector<MdnsRecord> records = GetDnsRecords(instance_endpoint); auto it = std::find_if(records.begin(), records.end(), [](const MdnsRecord& record) { @@ -267,12 +261,11 @@ TEST(DnsSdConversionLayerTest, GetDnsRecordsTxt) { txt.SetFlag("boolean", true); DnsSdInstanceEndpoint instance_endpoint( FakeDnsRecordFactory::kInstanceName, FakeDnsRecordFactory::kServiceName, - FakeDnsRecordFactory::kDomainName, txt, + FakeDnsRecordFactory::kDomainName, txt, 0, IPEndpoint{IPAddress(FakeDnsRecordFactory::kV4AddressOctets), FakeDnsRecordFactory::kPortNum}, IPEndpoint{IPAddress(FakeDnsRecordFactory::kV6AddressHextets), - FakeDnsRecordFactory::kPortNum}, - 0); + FakeDnsRecordFactory::kPortNum}); std::vector<MdnsRecord> records = GetDnsRecords(instance_endpoint); auto it = std::find_if(records.begin(), records.end(), [](const MdnsRecord& record) { diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data.cc deleted file mode 100644 index e06036e4239..00000000000 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data.cc +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "discovery/dnssd/impl/dns_data.h" - -#include "absl/types/optional.h" -#include "discovery/dnssd/impl/conversion_layer.h" -#include "discovery/mdns/mdns_records.h" - -namespace openscreen { -namespace discovery { -namespace { - -template <typename T> -inline Error CreateRecord(absl::optional<T>* stored, const MdnsRecord& record) { - if (stored->has_value()) { - return Error::Code::kItemAlreadyExists; - } - *stored = absl::get<T>(record.rdata()); - return Error::None(); -} - -template <typename T> -inline Error UpdateRecord(absl::optional<T>* stored, const MdnsRecord& record) { - if (!stored->has_value()) { - return Error::Code::kItemNotFound; - } - *stored = absl::get<T>(record.rdata()); - return Error::None(); -} - -template <typename T> -inline Error DeleteRecord(absl::optional<T>* stored) { - if (!stored->has_value()) { - return Error::Code::kItemNotFound; - } - *stored = absl::nullopt; - return Error::None(); -} - -template <typename T> -inline Error ProcessRecord(absl::optional<T>* stored, - const MdnsRecord& record, - RecordChangedEvent event) { - switch (event) { - case RecordChangedEvent::kCreated: - return CreateRecord(stored, record); - case RecordChangedEvent::kUpdated: - return UpdateRecord(stored, record); - case RecordChangedEvent::kExpired: - return DeleteRecord(stored); - } - return Error::Code::kUnknownError; -} - -} // namespace - -DnsData::DnsData(const InstanceKey& instance_id, - NetworkInterfaceIndex network_interface) - : instance_id_(instance_id), network_interface_(network_interface) {} - -ErrorOr<DnsSdInstanceEndpoint> DnsData::CreateEndpoint() { - if (!srv_.has_value() || !txt_.has_value() || - (!a_.has_value() && !aaaa_.has_value())) { - return Error::Code::kOperationInvalid; - } - - ErrorOr<DnsSdTxtRecord> txt_or_error = CreateFromDnsTxt(txt_.value()); - if (txt_or_error.is_error()) { - return txt_or_error.error(); - } - - if (a_.has_value() && aaaa_.has_value()) { - return DnsSdInstanceEndpoint( - instance_id_.instance_id(), instance_id_.service_id(), - instance_id_.domain_id(), std::move(txt_or_error.value()), - {a_.value().ipv4_address(), srv_.value().port()}, - {aaaa_.value().ipv6_address(), srv_.value().port()}, - network_interface_); - } else { - IPEndpoint ep = - a_.has_value() - ? IPEndpoint{a_.value().ipv4_address(), srv_.value().port()} - : IPEndpoint{aaaa_.value().ipv6_address(), srv_.value().port()}; - return DnsSdInstanceEndpoint( - instance_id_.instance_id(), instance_id_.service_id(), - instance_id_.domain_id(), std::move(txt_or_error.value()), - std::move(ep), network_interface_); - } -} - -Error DnsData::ApplyDataRecordChange(const MdnsRecord& record, - RecordChangedEvent event) { - switch (record.dns_type()) { - case DnsType::kSRV: - return ProcessRecord(&srv_, record, event); - case DnsType::kTXT: - return ProcessRecord(&txt_, record, event); - case DnsType::kA: - return ProcessRecord(&a_, record, event); - case DnsType::kAAAA: - return ProcessRecord(&aaaa_, record, event); - default: - return Error::Code::kOperationInvalid; - } -} - -} // namespace discovery -} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data.h deleted file mode 100644 index 7be02e094fe..00000000000 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef DISCOVERY_DNSSD_IMPL_DNS_DATA_H_ -#define DISCOVERY_DNSSD_IMPL_DNS_DATA_H_ - -#include "absl/types/optional.h" -#include "discovery/dnssd/impl/constants.h" -#include "discovery/dnssd/impl/instance_key.h" -#include "discovery/dnssd/public/dns_sd_instance_endpoint.h" -#include "discovery/mdns/mdns_record_changed_callback.h" -#include "discovery/mdns/mdns_records.h" - -namespace openscreen { -namespace discovery { - -// This is the set of DNS data that can be associated with a single PTR record. -class DnsData { - public: - explicit DnsData(const InstanceKey& instance_id, - NetworkInterfaceIndex network_interface); - - // Converts this DnsData to an InstanceEndpoint if enough data has been - // populated to create a valid InstanceEndpoint. Specifically, this means that - // the SRV, TXT, and either A or AAAA fields have been populated. In all other - // cases, returns an error. - ErrorOr<DnsSdInstanceEndpoint> CreateEndpoint(); - - // Modifies this entity with the provided DnsRecord. If called with a valid - // record type, the provided change will always be applied. The returned - // result will be an error if the change does not make sense from our current - // data state, and Error::None() otherwise. Valid record types with which this - // method can be called are SRV, TXT, A, and AAAA record types. - Error ApplyDataRecordChange(const MdnsRecord& record, - RecordChangedEvent event); - - private: - absl::optional<SrvRecordRdata> srv_; - absl::optional<TxtRecordRdata> txt_; - absl::optional<ARecordRdata> a_; - absl::optional<AAAARecordRdata> aaaa_; - - InstanceKey instance_id_; - - NetworkInterfaceIndex network_interface_; - - // Used in dns_data_unittest.cc. - friend class DnsDataTesting; - - // Used in querier_impl_unittest.cc. - friend class DnsDataAccessor; -}; - -} // namespace discovery -} // namespace openscreen - -#endif // DISCOVERY_DNSSD_IMPL_DNS_DATA_H_ diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_graph.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_graph.cc new file mode 100644 index 00000000000..2b46fa7aa3a --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_graph.cc @@ -0,0 +1,590 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "discovery/dnssd/impl/dns_data_graph.h" + +#include <utility> + +#include "discovery/dnssd/impl/conversion_layer.h" +#include "discovery/dnssd/impl/instance_key.h" + +namespace openscreen { +namespace discovery { +namespace { + +ErrorOr<DnsSdInstanceEndpoint> CreateEndpoint( + const DomainName& domain, + const absl::optional<ARecordRdata>& a, + const absl::optional<AAAARecordRdata>& aaaa, + const SrvRecordRdata& srv, + const TxtRecordRdata& txt, + NetworkInterfaceIndex network_interface) { + // Create the user-visible TXT record representation. + ErrorOr<DnsSdTxtRecord> txt_or_error = CreateFromDnsTxt(txt); + if (txt_or_error.is_error()) { + return txt_or_error.error(); + } + + InstanceKey instance_id(domain); + std::vector<IPEndpoint> endpoints; + if (a.has_value()) { + endpoints.push_back({a.value().ipv4_address(), srv.port()}); + } + if (aaaa.has_value()) { + endpoints.push_back({aaaa.value().ipv6_address(), srv.port()}); + } + + return DnsSdInstanceEndpoint( + instance_id.instance_id(), instance_id.service_id(), + instance_id.domain_id(), std::move(txt_or_error.value()), + network_interface, std::move(endpoints)); +} + +class DnsDataGraphImpl : public DnsDataGraph { + public: + using DnsDataGraph::DomainChangeCallback; + + explicit DnsDataGraphImpl(NetworkInterfaceIndex network_interface) + : network_interface_(network_interface) {} + DnsDataGraphImpl(const DnsDataGraphImpl& other) = delete; + DnsDataGraphImpl(DnsDataGraphImpl&& other) = delete; + + ~DnsDataGraphImpl() override { is_dtor_running_ = true; } + + DnsDataGraphImpl& operator=(const DnsDataGraphImpl& rhs) = delete; + DnsDataGraphImpl& operator=(DnsDataGraphImpl&& rhs) = delete; + + // DnsDataGraph overrides. + void StartTracking(const DomainName& domain, + DomainChangeCallback on_start_tracking) override; + + void StopTracking(const DomainName& domain, + DomainChangeCallback on_stop_tracking) override; + + std::vector<ErrorOr<DnsSdInstanceEndpoint>> CreateEndpoints( + DomainGroup domain_group, + const DomainName& name) const override; + + Error ApplyDataRecordChange(MdnsRecord record, + RecordChangedEvent event, + DomainChangeCallback on_start_tracking, + DomainChangeCallback on_stop_tracking) override; + + size_t GetTrackedDomainCount() const override { return nodes_.size(); } + + bool IsTracked(const DomainName& name) const override { + return nodes_.find(name) != nodes_.end(); + } + + private: + class NodeLifetimeHandler; + + using ScopedCallbackHandler = std::unique_ptr<NodeLifetimeHandler>; + + // A single node of the graph represented by this type. + class Node { + public: + // NOE: This class is non-copyable, non-movable because either operation + // would invalidate the pointer references or bidirectional edge states + // maintained by instances of this class. + Node(DomainName name, DnsDataGraphImpl* graph); + Node(const Node& other) = delete; + Node(Node&& other) = delete; + + ~Node(); + + Node& operator=(const Node& rhs) = delete; + Node& operator=(Node&& rhs) = delete; + + // Applies a record change for this node. + Error ApplyDataRecordChange(MdnsRecord record, RecordChangedEvent event); + + // Returns the first rdata of a record with type matching |type| in this + // node's |records_|, or absl::nullopt if no such record exists. + template <typename T> + absl::optional<T> GetRdata(DnsType type) { + auto it = FindRecord(type); + if (it == records_.end()) { + return absl::nullopt; + } else { + return std::cref(absl::get<T>(it->rdata())); + } + } + + const DomainName& name() const { return name_; } + const std::vector<Node*>& parents() const { return parents_; } + const std::vector<Node*>& children() const { return children_; } + const std::vector<MdnsRecord>& records() const { return records_; } + + private: + // Adds or removes an edge in |graph_|. + // NOTE: The same edge may be added multiple times, and one call to remove + // is needed for every such call. + void AddChild(Node* child); + void RemoveChild(Node* child); + + // Applies the specified change to domain |child| for this node. + void ApplyChildChange(DomainName child_name, RecordChangedEvent event); + + // Finds an iterator to the record of the provided type, or to + // records_.end() if no such record exists. + std::vector<MdnsRecord>::iterator FindRecord(DnsType type); + + // The domain with which the data records stored in this node are + // associated. + const DomainName name_; + + // Currently extant mDNS Records at |name_|. + std::vector<MdnsRecord> records_; + + // Nodes which contain records pointing to this node's |name|. + std::vector<Node*> parents_; + + // Nodes containing records pointed to by the records in this node. + std::vector<Node*> children_; + + // Graph containing this node. + DnsDataGraphImpl* graph_; + }; + + // Wrapper to handle the creation and deletion callbacks. When the object is + // created, it sets the callback to use, and erases the callback when it goes + // out of scope. This class allows all node creations to complete before + // calling the user-provided callback to ensure there are no race-conditions. + class NodeLifetimeHandler { + public: + NodeLifetimeHandler(DomainChangeCallback* callback_ptr, + DomainChangeCallback callback); + + // NOTE: The copy and delete ctors and operators must be deleted because + // they would invalidate the pointer logic used here. + NodeLifetimeHandler(const NodeLifetimeHandler& other) = delete; + NodeLifetimeHandler(NodeLifetimeHandler&& other) = delete; + + ~NodeLifetimeHandler(); + + NodeLifetimeHandler operator=(const NodeLifetimeHandler& other) = delete; + NodeLifetimeHandler operator=(NodeLifetimeHandler&& other) = delete; + + private: + std::vector<DomainName> domains_changed; + + DomainChangeCallback* callback_ptr_; + DomainChangeCallback callback_; + }; + + // Helpers to create the ScopedCallbackHandlers for creation and deletion + // callbacks. + ScopedCallbackHandler GetScopedCreationHandler( + DomainChangeCallback creation_callback); + ScopedCallbackHandler GetScopedDeletionHandler( + DomainChangeCallback deletion_callback); + + // Determines whether the provided node has the necessary records to be a + // valid node at the specified domain level. + static bool IsValidAddressNode(Node* node); + static bool IsValidSrvAndTxtNode(Node* node); + + // Calculates the set of DnsSdInstanceEndpoints associated with the PTR + // records present at the given |node|. + std::vector<ErrorOr<DnsSdInstanceEndpoint>> CalculatePtrRecordEndpoints( + Node* node) const; + + // Denotes whether the dtor for this instance has been called. This is + // required for validation of Node instance functionality. See the + // implementation of DnsDataGraph::Node::~Node() for more details. + bool is_dtor_running_ = false; + + // Map from domain name to the node containing all records associated with the + // name. + std::map<DomainName, std::unique_ptr<Node>> nodes_; + + const NetworkInterfaceIndex network_interface_; + + // The methods to be called when a domain name either starts or stops being + // referenced. These will only be set when a record change is ongoing, and act + // as a single source of truth for the creation and deletion callbacks that + // should be used during that operation. + DomainChangeCallback on_node_creation_; + DomainChangeCallback on_node_deletion_; +}; + +DnsDataGraphImpl::Node::Node(DomainName name, DnsDataGraphImpl* graph) + : name_(std::move(name)), graph_(graph) { + OSP_DCHECK(graph_); + + graph_->on_node_creation_(name_); +} + +DnsDataGraphImpl::Node::~Node() { + // A node should only be deleted when it has no parents. The only case where + // a deletion can occur when parents are still extant is during destruction of + // the holding graph. In that case, the state of the graph no longer matters + // and all nodes will be deleted, so no need to consider the child pointers. + if (!graph_->is_dtor_running_) { + auto it = std::find_if(parents_.begin(), parents_.end(), + [this](Node* parent) { return parent != this; }); + OSP_DCHECK(it == parents_.end()); + + // Erase all childrens' parent pointers to this node. + for (Node* child : children_) { + RemoveChild(child); + } + + OSP_DCHECK(graph_->on_node_deletion_); + graph_->on_node_deletion_(name_); + } +} + +Error DnsDataGraphImpl::Node::ApplyDataRecordChange(MdnsRecord record, + RecordChangedEvent event) { + OSP_DCHECK(record.name() == name_); + + // The child domain to which the changed record points, or none. This is only + // applicable for PTR and SRV records, and is empty in all other cases. + DomainName child_name; + + // The location of the current record. In the case of PTR records, multiple + // records are allowed for the same domain. In all other cases, this is not + // valid. + std::vector<MdnsRecord>::iterator it; + + if (record.dns_type() == DnsType::kPTR) { + child_name = absl::get<PtrRecordRdata>(record.rdata()).ptr_domain(); + it = std::find(records_.begin(), records_.end(), record); + } else { + if (record.dns_type() == DnsType::kSRV) { + child_name = absl::get<SrvRecordRdata>(record.rdata()).target(); + } + it = FindRecord(record.dns_type()); + } + + // Validate that the requested change is allowed and apply it. + switch (event) { + case RecordChangedEvent::kCreated: + if (it != records_.end()) { + return Error::Code::kItemAlreadyExists; + } + records_.push_back(std::move(record)); + break; + + case RecordChangedEvent::kUpdated: + if (it == records_.end()) { + return Error::Code::kItemNotFound; + } + *it = std::move(record); + break; + + case RecordChangedEvent::kExpired: + if (it == records_.end()) { + return Error::Code::kItemNotFound; + } + records_.erase(it); + break; + } + + // Apply any required edge changes to the graph. This is only applicable if + // a |child| was found earlier. Note that the same child can be added multiple + // times to the |children_| vector, which simplifies the code dramatically. + if (!child_name.empty()) { + ApplyChildChange(std::move(child_name), event); + } + + return Error::None(); +} + +void DnsDataGraphImpl::Node::ApplyChildChange(DomainName child_name, + RecordChangedEvent event) { + if (event == RecordChangedEvent::kCreated) { + const auto pair = + graph_->nodes_.emplace(child_name, std::unique_ptr<Node>()); + if (pair.second) { + auto new_node = std::make_unique<Node>(std::move(child_name), graph_); + pair.first->second.swap(new_node); + } + + AddChild(pair.first->second.get()); + } else if (event == RecordChangedEvent::kExpired) { + const auto it = graph_->nodes_.find(child_name); + OSP_DCHECK(it != graph_->nodes_.end()); + RemoveChild(it->second.get()); + } +} + +void DnsDataGraphImpl::Node::AddChild(Node* child) { + OSP_DCHECK(child); + children_.push_back(child); + child->parents_.push_back(this); +} + +void DnsDataGraphImpl::Node::RemoveChild(Node* child) { + OSP_DCHECK(child); + + auto it = std::find(children_.begin(), children_.end(), child); + OSP_DCHECK(it != children_.end()); + children_.erase(it); + + it = std::find(child->parents_.begin(), child->parents_.end(), this); + OSP_DCHECK(it != child->parents_.end()); + child->parents_.erase(it); + + // If the node has been orphaned, remove it. + it = std::find_if(child->parents_.begin(), child->parents_.end(), + [child](Node* parent) { return parent != child; }); + if (it == child->parents_.end()) { + DomainName child_name = child->name(); + const size_t count = graph_->nodes_.erase(child_name); + OSP_DCHECK(child == this || count); + } +} + +std::vector<MdnsRecord>::iterator DnsDataGraphImpl::Node::FindRecord( + DnsType type) { + return std::find_if( + records_.begin(), records_.end(), + [type](const MdnsRecord& record) { return record.dns_type() == type; }); +} + +DnsDataGraphImpl::NodeLifetimeHandler::NodeLifetimeHandler( + DomainChangeCallback* callback_ptr, + DomainChangeCallback callback) + : callback_ptr_(callback_ptr), callback_(callback) { + OSP_DCHECK(callback_ptr_); + OSP_DCHECK(callback); + OSP_DCHECK(*callback_ptr_ == nullptr); + *callback_ptr = [this](DomainName domain) { + domains_changed.push_back(std::move(domain)); + }; +} + +DnsDataGraphImpl::NodeLifetimeHandler::~NodeLifetimeHandler() { + *callback_ptr_ = nullptr; + for (DomainName& domain : domains_changed) { + callback_(domain); + } +} + +DnsDataGraphImpl::ScopedCallbackHandler +DnsDataGraphImpl::GetScopedCreationHandler( + DomainChangeCallback creation_callback) { + return std::make_unique<NodeLifetimeHandler>(&on_node_creation_, + std::move(creation_callback)); +} + +DnsDataGraphImpl::ScopedCallbackHandler +DnsDataGraphImpl::GetScopedDeletionHandler( + DomainChangeCallback deletion_callback) { + return std::make_unique<NodeLifetimeHandler>(&on_node_deletion_, + std::move(deletion_callback)); +} + +void DnsDataGraphImpl::StartTracking(const DomainName& domain, + DomainChangeCallback on_start_tracking) { + ScopedCallbackHandler creation_handler = + GetScopedCreationHandler(std::move(on_start_tracking)); + + auto pair = + nodes_.emplace(domain, std::make_unique<Node>(std::move(domain), this)); + + OSP_DCHECK(pair.second); + OSP_DCHECK(nodes_.find(domain) != nodes_.end()); +} + +void DnsDataGraphImpl::StopTracking(const DomainName& domain, + DomainChangeCallback on_stop_tracking) { + ScopedCallbackHandler deletion_handler = + GetScopedDeletionHandler(std::move(on_stop_tracking)); + + auto it = nodes_.find(domain); + OSP_CHECK(it != nodes_.end()); + OSP_DCHECK(it->second->parents().empty()); + it->second.reset(); + const size_t erased_count = nodes_.erase(domain); + OSP_DCHECK(erased_count); +} + +Error DnsDataGraphImpl::ApplyDataRecordChange( + MdnsRecord record, + RecordChangedEvent event, + DomainChangeCallback on_start_tracking, + DomainChangeCallback on_stop_tracking) { + ScopedCallbackHandler creation_handler = + GetScopedCreationHandler(std::move(on_start_tracking)); + ScopedCallbackHandler deletion_handler = + GetScopedDeletionHandler(std::move(on_stop_tracking)); + + auto it = nodes_.find(record.name()); + if (it == nodes_.end()) { + return Error::Code::kOperationCancelled; + } + + const auto result = + it->second->ApplyDataRecordChange(std::move(record), event); + + return result; +} + +std::vector<ErrorOr<DnsSdInstanceEndpoint>> DnsDataGraphImpl::CreateEndpoints( + DomainGroup domain_group, + const DomainName& name) const { + const auto it = nodes_.find(name); + if (it == nodes_.end()) { + return {}; + } + Node* target_node = it->second.get(); + + // NOTE: One of these will contain no more than one element, so iterating over + // them both will be fast. + std::vector<Node*> srv_and_txt_record_nodes; + std::vector<Node*> address_record_nodes; + + switch (domain_group) { + case DomainGroup::kAddress: + if (!IsValidAddressNode(target_node)) { + return {}; + } + + address_record_nodes.push_back(target_node); + srv_and_txt_record_nodes = target_node->parents(); + break; + + case DomainGroup::kSrvAndTxt: + if (!IsValidSrvAndTxtNode(target_node)) { + return {}; + } + + srv_and_txt_record_nodes.push_back(target_node); + address_record_nodes = target_node->children(); + break; + + case DomainGroup::kPtr: + return CalculatePtrRecordEndpoints(target_node); + + default: + return {}; + } + + // Iterate across all node pairs and create all possible DnsSdInstanceEndpoint + // objects. + std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints; + for (Node* srv_and_txt : srv_and_txt_record_nodes) { + for (Node* address : address_record_nodes) { + // First, there has to be a SRV record present (to provide the port + // number), and the target of that SRV record has to be the node where the + // address records are sourced from. + const absl::optional<SrvRecordRdata> srv = + srv_and_txt->GetRdata<SrvRecordRdata>(DnsType::kSRV); + if (!srv.has_value() || srv.value().target() != address->name()) { + continue; + } + + // Next, a TXT record must be present to provide additional connection + // information about the service per RFC 6763. + const absl::optional<TxtRecordRdata> txt = + srv_and_txt->GetRdata<TxtRecordRdata>(DnsType::kTXT); + if (!txt.has_value()) { + continue; + } + + // Last, at least one address record must be present to provide an + // endpoint for this instance. + const absl::optional<ARecordRdata> a = + address->GetRdata<ARecordRdata>(DnsType::kA); + const absl::optional<AAAARecordRdata> aaaa = + address->GetRdata<AAAARecordRdata>(DnsType::kAAAA); + if (!a.has_value() && !aaaa.has_value()) { + continue; + } + + // Then use the above info to create an endpoint object. If an error + // occurs, this is only related to the one endpoint and its possible that + // other endpoints may still be valid, so only the one endpoint is treated + // as failing. For instance, a bad TXT record for service A will not + // affect the endpoints for service B. + ErrorOr<DnsSdInstanceEndpoint> endpoint = + CreateEndpoint(srv_and_txt->name(), a, aaaa, srv.value(), txt.value(), + network_interface_); + endpoints.push_back(std::move(endpoint)); + } + } + + return endpoints; +} + +// static +bool DnsDataGraphImpl::IsValidAddressNode(Node* node) { + const absl::optional<ARecordRdata> a = + node->GetRdata<ARecordRdata>(DnsType::kA); + const absl::optional<AAAARecordRdata> aaaa = + node->GetRdata<AAAARecordRdata>(DnsType::kAAAA); + return a.has_value() || aaaa.has_value(); +} + +// static +bool DnsDataGraphImpl::IsValidSrvAndTxtNode(Node* node) { + const absl::optional<SrvRecordRdata> srv = + node->GetRdata<SrvRecordRdata>(DnsType::kSRV); + const absl::optional<TxtRecordRdata> txt = + node->GetRdata<TxtRecordRdata>(DnsType::kTXT); + + return srv.has_value() && txt.has_value(); +} + +std::vector<ErrorOr<DnsSdInstanceEndpoint>> +DnsDataGraphImpl::CalculatePtrRecordEndpoints(Node* node) const { + // PTR records aren't actually part of the generated endpoint objects, so + // call this method recursively on all children and + std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints; + for (const MdnsRecord& record : node->records()) { + if (record.dns_type() != DnsType::kPTR) { + continue; + } + + const DomainName domain = + absl::get<PtrRecordRdata>(record.rdata()).ptr_domain(); + const Node* child = nodes_.find(domain)->second.get(); + std::vector<ErrorOr<DnsSdInstanceEndpoint>> child_endpoints = + CreateEndpoints(DomainGroup::kSrvAndTxt, child->name()); + for (auto& endpoint_or_error : child_endpoints) { + endpoints.push_back(std::move(endpoint_or_error)); + } + } + return endpoints; +} + +} // namespace + +DnsDataGraph::~DnsDataGraph() = default; + +// static +std::unique_ptr<DnsDataGraph> DnsDataGraph::Create( + NetworkInterfaceIndex network_interface) { + return std::make_unique<DnsDataGraphImpl>(network_interface); +} + +// static +DnsDataGraphImpl::DomainGroup DnsDataGraph::GetDomainGroup(DnsType type) { + switch (type) { + case DnsType::kA: + case DnsType::kAAAA: + return DnsDataGraphImpl::DomainGroup::kAddress; + case DnsType::kSRV: + case DnsType::kTXT: + return DnsDataGraphImpl::DomainGroup::kSrvAndTxt; + case DnsType::kPTR: + return DnsDataGraphImpl::DomainGroup::kPtr; + default: + OSP_NOTREACHED(); + return DnsDataGraphImpl::DomainGroup::kNone; + } +} + +// static +DnsDataGraphImpl::DomainGroup DnsDataGraph::GetDomainGroup( + const MdnsRecord record) { + return GetDomainGroup(record.dns_type()); +} + +} // namespace discovery +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_graph.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_graph.h new file mode 100644 index 00000000000..9d686722bab --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_graph.h @@ -0,0 +1,134 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef DISCOVERY_DNSSD_IMPL_DNS_DATA_GRAPH_H_ +#define DISCOVERY_DNSSD_IMPL_DNS_DATA_GRAPH_H_ + +#include <functional> +#include <map> +#include <memory> +#include <vector> + +#include "absl/types/optional.h" +#include "discovery/dnssd/impl/constants.h" +#include "discovery/dnssd/public/dns_sd_instance_endpoint.h" +#include "discovery/mdns/mdns_record_changed_callback.h" +#include "discovery/mdns/mdns_records.h" + +namespace openscreen { +namespace discovery { + +/* + Per RFC 6763, the following mappings exist between the domains of the called + out mDNS records: + + -------------- + | PTR Record | + -------------- + /\ + / \ + / \ + / \ + / \ + -------------- -------------- + | SRV Record | | TXT Record | + -------------- -------------- + /\ + / \ + / \ + / \ + / \ + / \ + -------------- --------------- + | A Record | | AAAA Record | + -------------- --------------- + + Such that PTR records point to the domain of SRV and TXT records, and SRV + records point to the domain of A and AAAA records. Below, these 3 separate + sets are referred to as "Domain Groups". + + Though it is frequently the case that each A or AAAA record will only be + pointed to by one SRV record domain, this is not a requirement for DNS-SD and + in the wild this case does come up. On the other hand, it is expected that + each PTR record domain will point to multiple SRV records. + + To represent this data, a multigraph structure has been used. + - Each node of the graph represents a specific domain name + - Each edge represents a parent-child relationship, such that node A is a + parent of node B iff there exists some record x in A such that x points to + the domain represented by B. + In practice, it is expected that no more than one edge will ever exist + between two nodes. A multigraph is used despite this to simplify the code and + avoid a number of tricky edge cases (both literally and figuratively). + + Note the following: + - This definition allows for cycles in the multigraph (which are unexpected + but allowed by the RFC). + - This definition allows for self loops (which are expected when a SRV record + points to address records with the same domain). + - The memory requirement for this graph is bounded due to a bound on the + number of tracked records in the mDNS layer as part of + discovery/mdns/mdns_querier.h. +*/ +class DnsDataGraph { + public: + // The set of valid groups of domains, as called out in the hierarchy + // described above. + enum class DomainGroup { kNone = 0, kPtr, kSrvAndTxt, kAddress }; + + // Get the domain group associated with the provided object. + static DomainGroup GetDomainGroup(DnsType type); + static DomainGroup GetDomainGroup(const MdnsRecord record); + + // Creates a new DnsDataGraph. + static std::unique_ptr<DnsDataGraph> Create( + NetworkInterfaceIndex network_index); + + // Callback to use when a domain change occurs. + using DomainChangeCallback = std::function<void(DomainName)>; + + virtual ~DnsDataGraph(); + + // Manually starts or stops tracking the provided domain. These methods should + // only be called for top-level PTR domains. + virtual void StartTracking(const DomainName& domain, + DomainChangeCallback on_start_tracking) = 0; + virtual void StopTracking(const DomainName& domain, + DomainChangeCallback on_stop_tracking) = 0; + + // Attempts to create all DnsSdInstanceEndpoint objects with |name| associated + // with the provided |domain_group|. If all required data for one such + // endpoint has been received, and an error occurs while parsing this data, + // then an error is returned in place of that endpoint. + virtual std::vector<ErrorOr<DnsSdInstanceEndpoint>> CreateEndpoints( + DomainGroup domain_group, + const DomainName& name) const = 0; + + // Modifies this entity with the provided DnsRecord. If called with a valid + // record type, the provided change will only be applied if the provided event + // is valid at the time of calling. The returned result will be an error if + // the change does not make sense from our current data state, and + // Error::None() otherwise. Valid record types with which this method can be + // called are PTR, SRV, TXT, A, and AAAA record types. + // + // TODO(issuetracker.google.com/157822423): Allow for duplicate records of + // non-PTR types. + virtual Error ApplyDataRecordChange( + MdnsRecord record, + RecordChangedEvent event, + DomainChangeCallback on_start_tracking, + DomainChangeCallback on_stop_tracking) = 0; + + virtual size_t GetTrackedDomainCount() const = 0; + + // Returns whether the provided domain is tracked or not. This may either be + // due to a direct call to StartTracking() or due to the result of a received + // record. + virtual bool IsTracked(const DomainName& name) const = 0; +}; + +} // namespace discovery +} // namespace openscreen + +#endif // DISCOVERY_DNSSD_IMPL_DNS_DATA_GRAPH_H_ diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_graph_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_graph_unittest.cc new file mode 100644 index 00000000000..0af3aa19dbb --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_graph_unittest.cc @@ -0,0 +1,469 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "discovery/dnssd/impl/dns_data_graph.h" + +#include <utility> + +#include "discovery/mdns/testing/mdns_test_util.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "platform/base/ip_address.h" + +namespace openscreen { +namespace discovery { +namespace { + +IPAddress GetAddressV4(const DnsSdInstanceEndpoint endpoint) { + for (const IPAddress& address : endpoint.addresses()) { + if (address.IsV4()) { + return address; + } + } + return IPAddress{}; +} + +IPAddress GetAddressV6(const DnsSdInstanceEndpoint endpoint) { + for (const IPAddress& address : endpoint.addresses()) { + if (address.IsV6()) { + return address; + } + } + return IPAddress{}; +} + +} // namespace + +using testing::_; +using testing::Invoke; +using testing::Return; +using testing::StrictMock; + +class DomainChangeImpl { + public: + MOCK_METHOD1(OnStartTracking, void(const DomainName&)); + MOCK_METHOD1(OnStopTracking, void(const DomainName&)); +}; + +class DnsDataGraphTests : public testing::Test { + public: + DnsDataGraphTests() : graph_(DnsDataGraph::Create(network_interface_)) { + EXPECT_CALL(callbacks_, OnStartTracking(ptr_domain_)); + StartTracking(ptr_domain_); + testing::Mock::VerifyAndClearExpectations(&callbacks_); + EXPECT_EQ(graph_->GetTrackedDomainCount(), size_t{1}); + } + + protected: + void TriggerRecordCreation(MdnsRecord record, + Error::Code result_code = Error::Code::kNone) { + size_t size = graph_->GetTrackedDomainCount(); + Error result = + ApplyDataRecordChange(std::move(record), RecordChangedEvent::kCreated); + EXPECT_EQ(result.code(), result_code) + << "Failed with error code " << result.code(); + size_t new_size = graph_->GetTrackedDomainCount(); + EXPECT_EQ(size, new_size); + } + + void TriggerRecordCreationWithCallback(MdnsRecord record, + const DomainName& target_domain) { + EXPECT_CALL(callbacks_, OnStartTracking(target_domain)); + size_t size = graph_->GetTrackedDomainCount(); + Error result = + ApplyDataRecordChange(std::move(record), RecordChangedEvent::kCreated); + EXPECT_TRUE(result.ok()) << "Failed with error code " << result.code(); + size_t new_size = graph_->GetTrackedDomainCount(); + EXPECT_EQ(size + 1, new_size); + } + + void ExpectDomainEqual(const DnsSdInstance& instance, + const DomainName& name) { + EXPECT_EQ(name.labels().size(), size_t{4}); + EXPECT_EQ(instance.instance_id(), name.labels()[0]); + EXPECT_EQ(instance.service_id(), name.labels()[1] + "." + name.labels()[2]); + EXPECT_EQ(instance.domain_id(), name.labels()[3]); + } + + Error ApplyDataRecordChange(MdnsRecord record, RecordChangedEvent event) { + return graph_->ApplyDataRecordChange( + std::move(record), event, + [this](const DomainName& domain) { + callbacks_.OnStartTracking(domain); + }, + [this](const DomainName& domain) { + callbacks_.OnStopTracking(domain); + }); + } + + void StartTracking(const DomainName& domain) { + graph_->StartTracking(domain, [this](const DomainName& domain) { + callbacks_.OnStartTracking(domain); + }); + } + + void StopTracking(const DomainName& domain) { + graph_->StopTracking(domain, [this](const DomainName& domain) { + callbacks_.OnStopTracking(domain); + }); + } + + StrictMock<DomainChangeImpl> callbacks_; + NetworkInterfaceIndex network_interface_ = 1234; + std::unique_ptr<DnsDataGraph> graph_; + DomainName ptr_domain_{"_cast", "_tcp", "local"}; + DomainName primary_domain_{"test", "_cast", "_tcp", "local"}; + DomainName secondary_domain_{"test2", "_cast", "_tcp", "local"}; + DomainName tertiary_domain_{"test3", "_cast", "_tcp", "local"}; +}; + +TEST_F(DnsDataGraphTests, CallbacksCalledForStartStopTracking) { + EXPECT_CALL(callbacks_, OnStopTracking(ptr_domain_)); + StopTracking(ptr_domain_); + + EXPECT_EQ(graph_->GetTrackedDomainCount(), size_t{0}); +} + +TEST_F(DnsDataGraphTests, ApplyChangeForUntrackedDomainError) { + Error result = ApplyDataRecordChange(GetFakeSrvRecord(primary_domain_), + RecordChangedEvent::kCreated); + EXPECT_EQ(result.code(), Error::Code::kOperationCancelled); + EXPECT_EQ(graph_->GetTrackedDomainCount(), size_t{1}); +} + +TEST_F(DnsDataGraphTests, ChildrenStopTrackingWhenRootQueryStopped) { + auto ptr = GetFakePtrRecord(primary_domain_); + auto srv = GetFakeSrvRecord(primary_domain_, secondary_domain_); + auto a = GetFakeARecord(secondary_domain_); + + TriggerRecordCreationWithCallback(ptr, primary_domain_); + TriggerRecordCreationWithCallback(srv, secondary_domain_); + TriggerRecordCreation(a); + + EXPECT_CALL(callbacks_, OnStopTracking(ptr_domain_)); + EXPECT_CALL(callbacks_, OnStopTracking(primary_domain_)); + EXPECT_CALL(callbacks_, OnStopTracking(secondary_domain_)); + StopTracking(ptr_domain_); + testing::Mock::VerifyAndClearExpectations(&callbacks_); + + EXPECT_EQ(graph_->GetTrackedDomainCount(), size_t{0}); +} + +TEST_F(DnsDataGraphTests, CyclicSrvStopsTrackingWhenRootQueryStopped) { + auto ptr = GetFakePtrRecord(primary_domain_); + auto srv = GetFakeSrvRecord(primary_domain_); + auto a = GetFakeARecord(primary_domain_); + + TriggerRecordCreationWithCallback(ptr, primary_domain_); + TriggerRecordCreation(srv); + TriggerRecordCreation(a); + + EXPECT_CALL(callbacks_, OnStopTracking(ptr_domain_)); + EXPECT_CALL(callbacks_, OnStopTracking(primary_domain_)); + StopTracking(ptr_domain_); + testing::Mock::VerifyAndClearExpectations(&callbacks_); + + EXPECT_EQ(graph_->GetTrackedDomainCount(), size_t{0}); +} + +TEST_F(DnsDataGraphTests, ChildrenStopTrackingWhenParentDeleted) { + auto ptr = GetFakePtrRecord(primary_domain_); + auto srv = GetFakeSrvRecord(primary_domain_, secondary_domain_); + auto a = GetFakeARecord(secondary_domain_); + + TriggerRecordCreationWithCallback(ptr, primary_domain_); + TriggerRecordCreationWithCallback(srv, secondary_domain_); + TriggerRecordCreation(a); + + EXPECT_CALL(callbacks_, OnStopTracking(primary_domain_)); + EXPECT_CALL(callbacks_, OnStopTracking(secondary_domain_)); + auto result = ApplyDataRecordChange(ptr, RecordChangedEvent::kExpired); + EXPECT_TRUE(result.ok()) << "Failed with error code " << result.code(); + testing::Mock::VerifyAndClearExpectations(&callbacks_); + + EXPECT_EQ(graph_->GetTrackedDomainCount(), size_t{1}); +} + +TEST_F(DnsDataGraphTests, OnlyAffectedNodesChangedWhenParentDeleted) { + auto ptr = GetFakePtrRecord(primary_domain_); + auto srv = GetFakeSrvRecord(primary_domain_, secondary_domain_); + auto a = GetFakeARecord(secondary_domain_); + + TriggerRecordCreationWithCallback(ptr, primary_domain_); + TriggerRecordCreationWithCallback(srv, secondary_domain_); + TriggerRecordCreation(a); + + EXPECT_CALL(callbacks_, OnStopTracking(secondary_domain_)); + auto result = ApplyDataRecordChange(srv, RecordChangedEvent::kExpired); + EXPECT_TRUE(result.ok()) << "Failed with error code " << result.code(); + testing::Mock::VerifyAndClearExpectations(&callbacks_); + + EXPECT_EQ(graph_->GetTrackedDomainCount(), size_t{2}); +} + +TEST_F(DnsDataGraphTests, CreateFailsForExistingRecord) { + auto ptr = GetFakePtrRecord(primary_domain_); + auto srv = GetFakeSrvRecord(primary_domain_); + + TriggerRecordCreationWithCallback(ptr, primary_domain_); + TriggerRecordCreation(srv); + + auto result = ApplyDataRecordChange(srv, RecordChangedEvent::kCreated); + EXPECT_FALSE(result.ok()); + EXPECT_EQ(graph_->GetTrackedDomainCount(), size_t{2}); +} + +TEST_F(DnsDataGraphTests, UpdateFailsForNonExistingRecord) { + auto ptr = GetFakePtrRecord(primary_domain_); + auto srv = GetFakeSrvRecord(primary_domain_); + + TriggerRecordCreationWithCallback(ptr, primary_domain_); + + auto result = ApplyDataRecordChange(srv, RecordChangedEvent::kUpdated); + EXPECT_FALSE(result.ok()); + EXPECT_EQ(graph_->GetTrackedDomainCount(), size_t{2}); +} + +TEST_F(DnsDataGraphTests, DeleteFailsForNonExistingRecord) { + auto ptr = GetFakePtrRecord(primary_domain_); + auto srv = GetFakeSrvRecord(primary_domain_); + + TriggerRecordCreationWithCallback(ptr, primary_domain_); + + auto result = ApplyDataRecordChange(srv, RecordChangedEvent::kExpired); + EXPECT_FALSE(result.ok()); + EXPECT_EQ(graph_->GetTrackedDomainCount(), size_t{2}); +} + +TEST_F(DnsDataGraphTests, UpdateEndpointsWorksAsExpected) { + auto ptr = GetFakePtrRecord(primary_domain_); + auto srv = GetFakeSrvRecord(primary_domain_, secondary_domain_); + auto txt = GetFakeTxtRecord(primary_domain_); + auto a = GetFakeARecord(secondary_domain_); + + TriggerRecordCreationWithCallback(ptr, primary_domain_); + TriggerRecordCreation(txt); + TriggerRecordCreationWithCallback(srv, secondary_domain_); + TriggerRecordCreation(a); + + std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints = + graph_->CreateEndpoints(DnsDataGraph::GetDomainGroup(srv), + primary_domain_); + ASSERT_EQ(endpoints.size(), size_t{1}); + ErrorOr<DnsSdInstanceEndpoint> endpoint_or_error = std::move(endpoints[0]); + ASSERT_TRUE(endpoint_or_error.is_value()); + DnsSdInstanceEndpoint endpoint = std::move(endpoint_or_error.value()); + + ARecordRdata rdata(IPAddress(192, 168, 1, 2)); + MdnsRecord new_a(secondary_domain_, DnsType::kA, DnsClass::kIN, + RecordType::kUnique, std::chrono::seconds(0), + std::move(rdata)); + auto result = ApplyDataRecordChange(new_a, RecordChangedEvent::kUpdated); + + endpoints = graph_->CreateEndpoints(DnsDataGraph::GetDomainGroup(srv), + primary_domain_); + ASSERT_EQ(endpoints.size(), size_t{1}); + endpoint_or_error = std::move(endpoints[0]); + ASSERT_TRUE(endpoint_or_error.is_value()); + DnsSdInstanceEndpoint endpoint2 = std::move(endpoint_or_error.value()); + ASSERT_EQ(endpoint.addresses().size(), size_t{1}); + ASSERT_EQ(endpoint.addresses().size(), endpoint2.addresses().size()); + EXPECT_NE(endpoint.addresses()[0], endpoint2.addresses()[0]); + EXPECT_EQ(endpoint.instance_id(), endpoint2.instance_id()); + EXPECT_EQ(endpoint.service_id(), endpoint2.service_id()); + EXPECT_EQ(endpoint.domain_id(), endpoint2.domain_id()); + EXPECT_EQ(endpoint.txt(), endpoint2.txt()); + EXPECT_EQ(endpoint.port(), endpoint2.port()); +} + +TEST_F(DnsDataGraphTests, CreateEndpointsGeneratesCorrectRecords) { + auto ptr = GetFakePtrRecord(primary_domain_); + auto srv = GetFakeSrvRecord(primary_domain_, secondary_domain_); + auto txt = GetFakeTxtRecord(primary_domain_); + auto a = GetFakeARecord(secondary_domain_); + auto aaaa = GetFakeAAAARecord(secondary_domain_); + + TriggerRecordCreationWithCallback(ptr, primary_domain_); + TriggerRecordCreation(txt); + TriggerRecordCreationWithCallback(srv, secondary_domain_); + + std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints = + graph_->CreateEndpoints(DnsDataGraph::GetDomainGroup(srv), + primary_domain_); + EXPECT_EQ(endpoints.size(), size_t{0}); + + TriggerRecordCreation(a); + endpoints = graph_->CreateEndpoints(DnsDataGraph::GetDomainGroup(srv), + primary_domain_); + ASSERT_EQ(endpoints.size(), size_t{1}); + ErrorOr<DnsSdInstanceEndpoint> endpoint_or_error = std::move(endpoints[0]); + ASSERT_TRUE(endpoint_or_error.is_value()); + DnsSdInstanceEndpoint endpoint_a = std::move(endpoint_or_error.value()); + EXPECT_TRUE(GetAddressV4(endpoint_a)); + EXPECT_FALSE(GetAddressV6(endpoint_a)); + EXPECT_EQ(GetAddressV4(endpoint_a), kFakeARecordAddress); + ExpectDomainEqual(endpoint_a, primary_domain_); + EXPECT_EQ(endpoint_a.port(), kFakeSrvRecordPort); + + TriggerRecordCreation(aaaa); + endpoints = graph_->CreateEndpoints(DnsDataGraph::GetDomainGroup(srv), + primary_domain_); + ASSERT_EQ(endpoints.size(), size_t{1}); + endpoint_or_error = std::move(endpoints[0]); + ASSERT_TRUE(endpoint_or_error.is_value()); + DnsSdInstanceEndpoint endpoint_a_aaaa = std::move(endpoint_or_error.value()); + ASSERT_TRUE(GetAddressV4(endpoint_a_aaaa)); + ASSERT_TRUE(GetAddressV6(endpoint_a_aaaa)); + EXPECT_EQ(GetAddressV4(endpoint_a_aaaa), kFakeARecordAddress); + EXPECT_EQ(GetAddressV6(endpoint_a_aaaa), kFakeAAAARecordAddress); + EXPECT_EQ(static_cast<DnsSdInstance>(endpoint_a), + static_cast<DnsSdInstance>(endpoint_a_aaaa)); + + auto result = ApplyDataRecordChange(a, RecordChangedEvent::kExpired); + EXPECT_TRUE(result.ok()) << "Failed with error code " << result.code(); + endpoints = graph_->CreateEndpoints(DnsDataGraph::GetDomainGroup(srv), + primary_domain_); + ASSERT_EQ(endpoints.size(), size_t{1}); + endpoint_or_error = std::move(endpoints[0]); + ASSERT_TRUE(endpoint_or_error.is_value()); + DnsSdInstanceEndpoint endpoint_aaaa = std::move(endpoint_or_error.value()); + EXPECT_FALSE(GetAddressV4(endpoint_aaaa)); + ASSERT_TRUE(GetAddressV6(endpoint_aaaa)); + EXPECT_EQ(GetAddressV6(endpoint_aaaa), kFakeAAAARecordAddress); + EXPECT_EQ(static_cast<DnsSdInstance>(endpoint_a), + static_cast<DnsSdInstance>(endpoint_aaaa)); + + result = ApplyDataRecordChange(aaaa, RecordChangedEvent::kExpired); + EXPECT_TRUE(result.ok()) << "Failed with error code " << result.code(); + endpoints = graph_->CreateEndpoints(DnsDataGraph::GetDomainGroup(srv), + primary_domain_); + ASSERT_EQ(endpoints.size(), size_t{0}); +} + +TEST_F(DnsDataGraphTests, CreateEndpointsHandlesSelfLoops) { + auto ptr = GetFakePtrRecord(primary_domain_); + auto srv = GetFakeSrvRecord(primary_domain_, primary_domain_); + auto txt = GetFakeTxtRecord(primary_domain_); + auto a = GetFakeARecord(primary_domain_); + auto aaaa = GetFakeAAAARecord(primary_domain_); + + TriggerRecordCreationWithCallback(ptr, primary_domain_); + TriggerRecordCreation(srv); + TriggerRecordCreation(txt); + TriggerRecordCreation(a); + TriggerRecordCreation(aaaa); + + auto endpoints = graph_->CreateEndpoints(DnsDataGraph::GetDomainGroup(srv), + primary_domain_); + ASSERT_EQ(endpoints.size(), size_t{1}); + ASSERT_TRUE(endpoints[0].is_value()); + DnsSdInstanceEndpoint endpoint = std::move(endpoints[0].value()); + + EXPECT_EQ(GetAddressV4(endpoint), kFakeARecordAddress); + EXPECT_EQ(GetAddressV6(endpoint), kFakeAAAARecordAddress); + ExpectDomainEqual(endpoint, primary_domain_); + EXPECT_EQ(endpoint.port(), kFakeSrvRecordPort); + + auto endpoints2 = + graph_->CreateEndpoints(DnsDataGraph::GetDomainGroup(ptr), ptr_domain_); + ASSERT_EQ(endpoints2.size(), size_t{1}); + ASSERT_TRUE(endpoints2[0].is_value()); + DnsSdInstanceEndpoint endpoint2 = std::move(endpoints2[0].value()); + + EXPECT_EQ(GetAddressV4(endpoint2), kFakeARecordAddress); + EXPECT_EQ(GetAddressV6(endpoint2), kFakeAAAARecordAddress); + ExpectDomainEqual(endpoint2, primary_domain_); + EXPECT_EQ(endpoint2.port(), kFakeSrvRecordPort); + + EXPECT_EQ(static_cast<DnsSdInstance>(endpoint), + static_cast<DnsSdInstance>(endpoint2)); + EXPECT_EQ(endpoint, endpoint2); +} + +TEST_F(DnsDataGraphTests, CreateEndpointsWithMultipleParents) { + auto ptr = GetFakePtrRecord(primary_domain_); + auto srv = GetFakeSrvRecord(primary_domain_, tertiary_domain_); + auto txt = GetFakeTxtRecord(primary_domain_); + auto ptr2 = GetFakePtrRecord(secondary_domain_); + auto srv2 = GetFakeSrvRecord(secondary_domain_, tertiary_domain_); + auto txt2 = GetFakeTxtRecord(secondary_domain_); + auto a = GetFakeARecord(tertiary_domain_); + auto aaaa = GetFakeAAAARecord(tertiary_domain_); + + TriggerRecordCreationWithCallback(ptr, primary_domain_); + TriggerRecordCreationWithCallback(srv, tertiary_domain_); + TriggerRecordCreation(txt); + TriggerRecordCreationWithCallback(ptr2, secondary_domain_); + TriggerRecordCreation(srv2); + TriggerRecordCreation(txt2); + TriggerRecordCreation(a); + TriggerRecordCreation(aaaa); + + auto endpoints = graph_->CreateEndpoints(DnsDataGraph::GetDomainGroup(a), + tertiary_domain_); + ASSERT_EQ(endpoints.size(), size_t{2}); + ASSERT_TRUE(endpoints[0].is_value()); + ASSERT_TRUE(endpoints[1].is_value()); + + DnsSdInstanceEndpoint endpoint_a = std::move(endpoints[0].value()); + DnsSdInstanceEndpoint endpoint_b = std::move(endpoints[1].value()); + DnsSdInstanceEndpoint* endpoint_1; + DnsSdInstanceEndpoint* endpoint_2; + if (endpoint_a.instance_id() == "test") { + endpoint_1 = &endpoint_a; + endpoint_2 = &endpoint_b; + } else { + endpoint_2 = &endpoint_a; + endpoint_1 = &endpoint_b; + } + + EXPECT_EQ(GetAddressV4(*endpoint_1), kFakeARecordAddress); + EXPECT_EQ(GetAddressV6(*endpoint_1), kFakeAAAARecordAddress); + EXPECT_EQ(endpoint_1->port(), kFakeSrvRecordPort); + ExpectDomainEqual(*endpoint_1, primary_domain_); + + EXPECT_EQ(GetAddressV4(*endpoint_2), kFakeARecordAddress); + EXPECT_EQ(GetAddressV6(*endpoint_2), kFakeAAAARecordAddress); + EXPECT_EQ(endpoint_2->port(), kFakeSrvRecordPort); + ExpectDomainEqual(*endpoint_2, secondary_domain_); +} + +TEST_F(DnsDataGraphTests, FailedConversionOnlyFailsSingleEndpointCreation) { + auto ptr = GetFakePtrRecord(primary_domain_); + auto srv = GetFakeSrvRecord(primary_domain_, tertiary_domain_); + auto txt = GetFakeTxtRecord(primary_domain_); + auto ptr2 = GetFakePtrRecord(secondary_domain_); + auto srv2 = GetFakeSrvRecord(secondary_domain_, tertiary_domain_); + auto txt2 = MdnsRecord(secondary_domain_, DnsType::kTXT, DnsClass::kIN, + RecordType::kUnique, std::chrono::seconds(0), + MakeTxtRecord({"=bad_txt_record"})); + auto a = GetFakeARecord(tertiary_domain_); + auto aaaa = GetFakeAAAARecord(tertiary_domain_); + + TriggerRecordCreationWithCallback(ptr, primary_domain_); + TriggerRecordCreationWithCallback(ptr2, secondary_domain_); + TriggerRecordCreationWithCallback(srv, tertiary_domain_); + TriggerRecordCreation(srv2); + TriggerRecordCreation(txt); + TriggerRecordCreation(txt2); + TriggerRecordCreation(a); + TriggerRecordCreation(aaaa); + + auto endpoints = graph_->CreateEndpoints(DnsDataGraph::GetDomainGroup(a), + tertiary_domain_); + ASSERT_EQ(endpoints.size(), size_t{2}); + ASSERT_TRUE(endpoints[0].is_error() || endpoints[1].is_error()); + ASSERT_TRUE(endpoints[0].is_value() || endpoints[1].is_value()); + + DnsSdInstanceEndpoint endpoint = endpoints[0].is_value() + ? std::move(endpoints[0].value()) + : std::move(endpoints[1].value()); + EXPECT_EQ(GetAddressV4(endpoint), kFakeARecordAddress); + EXPECT_EQ(GetAddressV6(endpoint), kFakeAAAARecordAddress); + EXPECT_EQ(endpoint.port(), kFakeSrvRecordPort); + ExpectDomainEqual(endpoint, primary_domain_); +} + +} // namespace discovery +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_unittest.cc deleted file mode 100644 index 45118ee8e30..00000000000 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/dns_data_unittest.cc +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "discovery/dnssd/impl/dns_data.h" - -#include <chrono> // NOLINT - -#include "discovery/mdns/testing/mdns_test_util.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace openscreen { -namespace discovery { - -class DnsDataTesting : public DnsData { - public: - explicit DnsDataTesting(const InstanceKey& instance_key) - : DnsData(instance_key, 0) {} - - void set_srv(absl::optional<SrvRecordRdata> new_srv) { - SetVariable(new_srv, srv(), DnsType::kSRV); - } - - void set_txt(absl::optional<TxtRecordRdata> new_txt) { - SetVariable(new_txt, txt(), DnsType::kTXT); - } - - void set_a(absl::optional<ARecordRdata> new_a) { - SetVariable(new_a, a(), DnsType::kA); - } - - void set_aaaa(absl::optional<AAAARecordRdata> new_aaaa) { - SetVariable(new_aaaa, aaaa(), DnsType::kAAAA); - } - - const absl::optional<SrvRecordRdata>& srv() { return srv_; } - const absl::optional<TxtRecordRdata>& txt() { return txt_; } - const absl::optional<ARecordRdata>& a() { return a_; } - const absl::optional<AAAARecordRdata>& aaaa() { return aaaa_; } - - private: - template <typename T> - void SetVariable(absl::optional<T> new_val, - const absl::optional<T>& old_val, - DnsType type) { - if (!new_val.has_value() && !old_val.has_value()) { - return; - } - - if (!new_val.has_value()) { - MdnsRecord record(DomainName{"0"}, type, DnsClass::kIN, - RecordType::kUnique, std::chrono::seconds(1), - old_val.value()); - ApplyDataRecordChange(record, RecordChangedEvent::kExpired); - return; - } - - MdnsRecord record(DomainName{"0"}, type, DnsClass::kIN, RecordType::kUnique, - std::chrono::seconds(1), new_val.value()); - if (!old_val.has_value()) { - ApplyDataRecordChange(record, RecordChangedEvent::kCreated); - } else { - ApplyDataRecordChange(record, RecordChangedEvent::kUpdated); - } - } -}; - -namespace { - -const uint8_t kV4AddressOctets[4] = {192, 168, 0, 0}; -const uint16_t kV6AddressHextets[8] = {0x0102, 0x0304, 0x0506, 0x0708, - 0x090a, 0x0b0c, 0x0d0e, 0x0f10}; -const char kInstanceName[] = "instance"; -const char kServiceName[] = "_srv-name._udp"; -const char kDomainName[] = "local"; -constexpr uint16_t kServicePort = uint16_t{80}; - -} // namespace - -DnsDataTesting CreateFullyPopulatedData() { - InstanceKey instance{kInstanceName, kServiceName, kDomainName}; - DnsDataTesting data(instance); - DomainName target{kInstanceName, "_srv-name", "_udp", kDomainName}; - SrvRecordRdata srv(0, 0, kServicePort, target); - TxtRecordRdata txt = MakeTxtRecord({"name=value", "boolValue"}); - ARecordRdata a{IPAddress(kV4AddressOctets)}; - AAAARecordRdata aaaa{IPAddress(kV6AddressHextets)}; - - data.set_srv(srv); - data.set_txt(txt); - data.set_a(a); - data.set_aaaa(aaaa); - - return data; -} - -MdnsRecord CreateFullyPopulatedRecord(uint16_t port = kServicePort) { - DomainName target{kInstanceName, "_srv-name", "_udp", kDomainName}; - auto type = DnsType::kSRV; - auto clazz = DnsClass::kIN; - auto record_type = RecordType::kShared; - auto ttl = std::chrono::seconds(0); - SrvRecordRdata srv(0, 0, port, target); - return MdnsRecord(target, type, clazz, record_type, ttl, srv); -} - -// DnsData Conversions. -TEST(DnsSdDnsDataTests, TestConvertDnsDataCorrectly) { - DnsDataTesting data = CreateFullyPopulatedData(); - ErrorOr<DnsSdInstanceEndpoint> result = data.CreateEndpoint(); - ASSERT_TRUE(result.is_value()); - - DnsSdInstanceEndpoint record = result.value(); - ASSERT_TRUE(record.endpoint_v4()); - ASSERT_TRUE(record.endpoint_v6()); - EXPECT_EQ(record.instance_id(), kInstanceName); - EXPECT_EQ(record.service_id(), kServiceName); - EXPECT_EQ(record.domain_id(), kDomainName); - EXPECT_EQ(record.endpoint_v4().port, kServicePort); - EXPECT_EQ(record.endpoint_v4().address, IPAddress(kV4AddressOctets)); - EXPECT_EQ(record.endpoint_v6().port, kServicePort); - EXPECT_EQ(record.endpoint_v6().address, IPAddress(kV6AddressHextets)); - EXPECT_FALSE(record.txt().IsEmpty()); -} - -TEST(DnsSdDnsDataTests, TestConvertDnsDataMissingData) { - DnsDataTesting data = CreateFullyPopulatedData(); - EXPECT_TRUE(data.CreateEndpoint().is_value()); - - data = CreateFullyPopulatedData(); - data.set_srv(absl::nullopt); - EXPECT_FALSE(data.CreateEndpoint().is_value()); - - data = CreateFullyPopulatedData(); - data.set_txt(absl::nullopt); - EXPECT_FALSE(data.CreateEndpoint().is_value()); - - data = CreateFullyPopulatedData(); - data.set_a(absl::nullopt); - EXPECT_TRUE(data.CreateEndpoint().is_value()); - - data = CreateFullyPopulatedData(); - data.set_aaaa(absl::nullopt); - EXPECT_TRUE(data.CreateEndpoint().is_value()); - - data = CreateFullyPopulatedData(); - data.set_a(absl::nullopt); - data.set_aaaa(absl::nullopt); - EXPECT_FALSE(data.CreateEndpoint().is_value()); -} - -TEST(DnsSdDnsDataTests, TestConvertDnsDataOneAddress) { - // Address v4. - DnsDataTesting data = CreateFullyPopulatedData(); - data.set_aaaa(absl::nullopt); - ErrorOr<DnsSdInstanceEndpoint> result = data.CreateEndpoint(); - ASSERT_TRUE(result.is_value()); - - DnsSdInstanceEndpoint record = result.value(); - EXPECT_FALSE(record.endpoint_v6().address); - EXPECT_FALSE(record.endpoint_v6()); - ASSERT_TRUE(record.endpoint_v4()); - EXPECT_EQ(record.endpoint_v4().port, kServicePort); - EXPECT_EQ(record.endpoint_v4().address, IPAddress(kV4AddressOctets)); - - // Address v6. - data = CreateFullyPopulatedData(); - data.set_a(absl::nullopt); - result = data.CreateEndpoint(); - ASSERT_TRUE(result.is_value()); - - record = result.value(); - EXPECT_FALSE(record.endpoint_v4().address); - EXPECT_FALSE(record.endpoint_v4()); - ASSERT_TRUE(record.endpoint_v6()); - EXPECT_EQ(record.endpoint_v6().port, kServicePort); - EXPECT_EQ(record.endpoint_v6().address, IPAddress(kV6AddressHextets)); -} - -TEST(DnsSdDnsDataTests, TestConvertDnsDataBadTxt) { - DnsDataTesting data = CreateFullyPopulatedData(); - data.set_txt(MakeTxtRecord({"=bad_text"})); - ErrorOr<DnsSdInstanceEndpoint> result = data.CreateEndpoint(); - EXPECT_TRUE(result.is_error()); -} - -// ApplyDataRecordChange tests. -TEST(DnsSdDnsDataTests, TestApplyRecordChanges) { - MdnsRecord record = CreateFullyPopulatedRecord(kServicePort); - InstanceKey instance{kInstanceName, kServiceName, kDomainName}; - DnsDataTesting data(instance); - EXPECT_TRUE( - data.ApplyDataRecordChange(record, RecordChangedEvent::kCreated).ok()); - ASSERT_TRUE(data.srv().has_value()); - EXPECT_EQ(data.srv().value().port(), kServicePort); - - record = CreateFullyPopulatedRecord(234); - EXPECT_FALSE( - data.ApplyDataRecordChange(record, RecordChangedEvent::kCreated).ok()); - ASSERT_TRUE(data.srv().has_value()); - EXPECT_EQ(data.srv().value().port(), kServicePort); - - record = CreateFullyPopulatedRecord(345); - EXPECT_TRUE( - data.ApplyDataRecordChange(record, RecordChangedEvent::kUpdated).ok()); - ASSERT_TRUE(data.srv().has_value()); - EXPECT_EQ(data.srv().value().port(), 345); - - record = CreateFullyPopulatedRecord(); - EXPECT_TRUE( - data.ApplyDataRecordChange(record, RecordChangedEvent::kExpired).ok()); - ASSERT_FALSE(data.srv().has_value()); - - record = CreateFullyPopulatedRecord(1234); - EXPECT_FALSE( - data.ApplyDataRecordChange(record, RecordChangedEvent::kUpdated).ok()); - ASSERT_FALSE(data.srv().has_value()); -} - -} // namespace discovery -} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.cc index a3844f61778..eb64c5daaa1 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.cc @@ -4,6 +4,8 @@ #include "discovery/dnssd/impl/instance_key.h" +#include <vector> + #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "discovery/dnssd/impl/conversion_layer.h" @@ -32,7 +34,8 @@ InstanceKey::InstanceKey(absl::string_view instance, absl::string_view service, absl::string_view domain) : ServiceKey(service, domain), instance_id_(instance) { - OSP_DCHECK(IsInstanceValid(instance_id_)); + OSP_DCHECK(IsInstanceValid(instance_id_)) + << "invalid instance id" << instance; } InstanceKey::InstanceKey(const InstanceKey& other) = default; @@ -41,5 +44,11 @@ InstanceKey::InstanceKey(InstanceKey&& other) = default; InstanceKey& InstanceKey::operator=(const InstanceKey& rhs) = default; InstanceKey& InstanceKey::operator=(InstanceKey&& rhs) = default; +DomainName InstanceKey::GetName() const { + std::vector<std::string> labels = ServiceKey::GetName().labels(); + labels.insert(labels.begin(), instance_id()); + return DomainName(std::move(labels)); +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.h index 5b807107e3f..23d953b1b7f 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key.h @@ -41,6 +41,8 @@ class InstanceKey : public ServiceKey { InstanceKey& operator=(const InstanceKey& rhs); InstanceKey& operator=(InstanceKey&& rhs); + DomainName GetName() const override; + const std::string& instance_id() const { return instance_id_; } private: diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key_unittest.cc index 29e199817b1..d0886a9a20f 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/instance_key_unittest.cc @@ -96,5 +96,15 @@ TEST(DnsSdInstanceKeyTest, CreateFromRecordTest) { EXPECT_EQ(key.domain_id(), FakeDnsRecordFactory::kDomainName); } +TEST(DnsSdInstanceKeyTest, GetNameTest) { + InstanceKey key("instance", "_service._udp", "domain"); + DomainName expected{"instance", "_service", "_udp", "domain"}; + EXPECT_EQ(expected, key.GetName()); + + key = InstanceKey("foo", "_bar._tcp", "local"); + expected = DomainName{"foo", "_bar", "_tcp", "local"}; + EXPECT_EQ(expected, key.GetName()); +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.cc index 4dac0aa6d08..ab45a86b29e 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/publisher_impl.cc @@ -5,6 +5,7 @@ #include "discovery/dnssd/impl/publisher_impl.h" #include <map> +#include <string> #include <utility> #include <vector> @@ -26,20 +27,16 @@ DnsSdInstanceEndpoint CreateEndpoint( DnsSdInstance instance, InstanceKey key, const NetworkInterfaceConfig& network_config) { - if (!network_config.HasAddressV4() || !network_config.HasAddressV6()) { - const IPAddress& address = network_config.GetAddress(); - OSP_DCHECK(address); - IPEndpoint endpoint{address, instance.port()}; - return DnsSdInstanceEndpoint(key.instance_id(), key.service_id(), - key.domain_id(), instance.txt(), endpoint, - network_config.network_interface()); - } else { - IPEndpoint endpoint_v4{network_config.address_v4(), instance.port()}; - IPEndpoint endpoint_v6{network_config.address_v6(), instance.port()}; - return DnsSdInstanceEndpoint( - key.instance_id(), key.service_id(), key.domain_id(), instance.txt(), - endpoint_v4, endpoint_v6, network_config.network_interface()); + std::vector<IPEndpoint> endpoints; + if (network_config.HasAddressV4()) { + endpoints.push_back({network_config.address_v4(), instance.port()}); + } + if (network_config.HasAddressV6()) { + endpoints.push_back({network_config.address_v6(), instance.port()}); } + return DnsSdInstanceEndpoint( + key.instance_id(), key.service_id(), key.domain_id(), instance.txt(), + network_config.network_interface(), std::move(endpoints)); } DnsSdInstanceEndpoint UpdateDomain( diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.cc index 07a62755726..94f54fc91c4 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.cc @@ -4,9 +4,13 @@ #include "discovery/dnssd/impl/querier_impl.h" +#include <algorithm> #include <string> +#include <utility> #include <vector> +#include "discovery/common/reporting_client.h" +#include "discovery/dnssd/impl/conversion_layer.h" #include "discovery/dnssd/impl/network_interface_config.h" #include "platform/api/task_runner.h" #include "util/osp_logging.h" @@ -17,29 +21,200 @@ namespace { static constexpr char kLocalDomain[] = "local"; -std::vector<PendingQueryChange> GetDnsQueriesDelayed( - std::vector<DnsQueryInfo> query_infos, - QuerierImpl* callback, - PendingQueryChange::ChangeType change_type) { - std::vector<PendingQueryChange> pending_changes; - for (auto& info : query_infos) { - pending_changes.push_back({std::move(info.name), info.dns_type, - info.dns_class, callback, change_type}); +// Removes all error instances from the below records, and calls the log +// function on all errors present in |new_endpoints|. Input vectors are expected +// to be sorted in ascending order. +void ProcessErrors(std::vector<ErrorOr<DnsSdInstanceEndpoint>>* old_endpoints, + std::vector<ErrorOr<DnsSdInstanceEndpoint>>* new_endpoints, + std::function<void(Error)> log) { + OSP_DCHECK(old_endpoints); + OSP_DCHECK(new_endpoints); + + auto old_it = old_endpoints->begin(); + auto new_it = new_endpoints->begin(); + + // Iterate across both vectors and log new errors in the process. + // NOTE: In sorted order, all errors will appear before all non-errors. + while (old_it != old_endpoints->end() && new_it != new_endpoints->end()) { + ErrorOr<DnsSdInstanceEndpoint>& old_ep = *old_it; + ErrorOr<DnsSdInstanceEndpoint>& new_ep = *new_it; + + if (new_ep.is_value()) { + break; + } + + // If they are equal, the element is in both |old_endpoints| and + // |new_endpoints|, so skip it in both vectors. + if (old_ep == new_ep) { + old_it++; + new_it++; + continue; + } + + // There's an error in |old_endpoints| not in |new_endpoints|, so skip it. + if (old_ep < new_ep) { + old_it++; + continue; + } + + // There's an error in |new_endpoints| not in |old_endpoints|, so it's a new + // error from the applied changes. Log it. + log(std::move(new_ep.error())); + new_it++; + } + + // Skip all remaining errors in the old vector. + for (; old_it != old_endpoints->end() && old_it->is_error(); old_it++) { + } + + // Log all errors remaining in the new vector. + for (; new_it != new_endpoints->end() && new_it->is_error(); new_it++) { + log(std::move(new_it->error())); + } + + // Erase errors. + old_endpoints->erase(old_endpoints->begin(), old_it); + new_endpoints->erase(new_endpoints->begin(), new_it); +} + +// Returns a vector containing the value of each ErrorOr<> instance provided. +// All ErrorOr<> values are expected to be non-errors. +std::vector<DnsSdInstanceEndpoint> GetValues( + std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints) { + std::vector<DnsSdInstanceEndpoint> results; + results.reserve(endpoints.size()); + for (ErrorOr<DnsSdInstanceEndpoint>& endpoint : endpoints) { + OSP_CHECK(endpoint.is_value()); + results.push_back(std::move(endpoint.value())); + } + return results; +} + +bool IsEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint>& first, + const absl::optional<DnsSdInstanceEndpoint>& second) { + if (!first.has_value() || !second.has_value()) { + return !first.has_value() && !second.has_value(); + } + + // In the remaining case, both |first| and |second| must be values. + const DnsSdInstanceEndpoint& a = first.value(); + const DnsSdInstanceEndpoint& b = second.value(); + + // All endpoints from this querier should have the same network interface + // because the querier is only associated with a single network interface. + OSP_DCHECK_EQ(a.network_interface(), b.network_interface()); + + // Function returns true if first < second. + return a.instance_id() == b.instance_id() && + a.service_id() == b.service_id() && a.domain_id() == b.domain_id(); +} + +bool IsNotEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint>& first, + const absl::optional<DnsSdInstanceEndpoint>& second) { + return !IsEqualOrUpdate(first, second); +} + +// Calculates the created, updated, and deleted elements using the provided +// sets, appending these values to the provided vectors. Each of the input +// vectors is expected to contain only elements such that +// |element|.is_error() == false. Additionally, input vectors are expected to +// be sorted in ascending order. +// +// NOTE: A lot of operations are used to do this, but each is only O(n) so the +// resulting algorithm is still fast. +void CalculateChangeSets(std::vector<DnsSdInstanceEndpoint> old_endpoints, + std::vector<DnsSdInstanceEndpoint> new_endpoints, + std::vector<DnsSdInstanceEndpoint>* created_out, + std::vector<DnsSdInstanceEndpoint>* updated_out, + std::vector<DnsSdInstanceEndpoint>* deleted_out) { + OSP_DCHECK(created_out); + OSP_DCHECK(updated_out); + OSP_DCHECK(deleted_out); + + // Use set difference with default operators to find the elements present in + // one list but not the others. + // + // NOTE: Because absl::optional<...> types are used here and below, calls to + // the ctor and dtor for empty elements are no-ops. + const int total_count = old_endpoints.size() + new_endpoints.size(); + + // This is the set of elements that aren't in the old endpoints, meaning the + // old endpoint either didn't exist or had different TXT / Address / etc.. + std::vector<absl::optional<DnsSdInstanceEndpoint>> created_or_updated( + total_count); + auto new_end = std::set_difference(new_endpoints.begin(), new_endpoints.end(), + old_endpoints.begin(), old_endpoints.end(), + created_or_updated.begin()); + created_or_updated.erase(new_end, created_or_updated.end()); + + // This is the set of elements that are only in the old endpoints, similar to + // the above. + std::vector<absl::optional<DnsSdInstanceEndpoint>> deleted_or_updated( + total_count); + new_end = std::set_difference(old_endpoints.begin(), old_endpoints.end(), + new_endpoints.begin(), new_endpoints.end(), + deleted_or_updated.begin()); + deleted_or_updated.erase(new_end, deleted_or_updated.end()); + + // Next, find the elements which were updated. + const size_t max_count = + std::max(created_or_updated.size(), deleted_or_updated.size()); + std::vector<absl::optional<DnsSdInstanceEndpoint>> updated(max_count); + new_end = std::set_intersection( + created_or_updated.begin(), created_or_updated.end(), + deleted_or_updated.begin(), deleted_or_updated.end(), updated.begin(), + IsNotEqualOrUpdate); + updated.erase(new_end, updated.end()); + + // Use the updated elements to find all created and deleted elements. + std::vector<absl::optional<DnsSdInstanceEndpoint>> created( + created_or_updated.size()); + new_end = std::set_difference( + created_or_updated.begin(), created_or_updated.end(), updated.begin(), + updated.end(), created.begin(), IsNotEqualOrUpdate); + created.erase(new_end, created.end()); + + std::vector<absl::optional<DnsSdInstanceEndpoint>> deleted( + deleted_or_updated.size()); + new_end = std::set_difference( + deleted_or_updated.begin(), deleted_or_updated.end(), updated.begin(), + updated.end(), deleted.begin(), IsNotEqualOrUpdate); + deleted.erase(new_end, deleted.end()); + + // Return the calculated elements back to the caller in the output variables. + created_out->reserve(created.size()); + for (absl::optional<DnsSdInstanceEndpoint>& endpoint : created) { + OSP_DCHECK(endpoint.has_value()); + created_out->push_back(std::move(endpoint.value())); + } + + updated_out->reserve(updated.size()); + for (absl::optional<DnsSdInstanceEndpoint>& endpoint : updated) { + OSP_DCHECK(endpoint.has_value()); + updated_out->push_back(std::move(endpoint.value())); + } + + deleted_out->reserve(deleted.size()); + for (absl::optional<DnsSdInstanceEndpoint>& endpoint : deleted) { + OSP_DCHECK(endpoint.has_value()); + deleted_out->push_back(std::move(endpoint.value())); } - return pending_changes; } } // namespace QuerierImpl::QuerierImpl(MdnsService* mdns_querier, TaskRunner* task_runner, + ReportingClient* reporting_client, const NetworkInterfaceConfig* network_config) : mdns_querier_(mdns_querier), task_runner_(task_runner), - network_config_(network_config) { + reporting_client_(reporting_client) { OSP_DCHECK(mdns_querier_); OSP_DCHECK(task_runner_); - OSP_DCHECK(network_config_); + + OSP_DCHECK(network_config); + graph_ = DnsDataGraph::Create(network_config->network_interface()); } QuerierImpl::~QuerierImpl() = default; @@ -48,77 +223,104 @@ void QuerierImpl::StartQuery(const std::string& service, Callback* callback) { OSP_DCHECK(callback); OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); - OSP_DVLOG << "Starting query for service '" << service << "'"; + OSP_DVLOG << "Starting DNS-SD query for service '" << service << "'"; - ServiceKey key(service, kLocalDomain); - if (!IsQueryRunning(key)) { - callback_map_[key] = {callback}; - auto queries = GetDataToStartDnsQuery(std::move(key)); - StartDnsQueriesImmediately(queries); - } else { - callback_map_[key].push_back(callback); - - for (auto& kvp : received_records_) { - if (kvp.first == key) { - ErrorOr<DnsSdInstanceEndpoint> endpoint = kvp.second.CreateEndpoint(); - if (endpoint.is_value()) { - callback->OnEndpointCreated(endpoint.value()); - } - } - } + // Start tracking the new callback + const ServiceKey key(service, kLocalDomain); + auto it = + callback_map_.emplace(std::move(key), std::vector<Callback*>{}).first; + it->second.push_back(callback); + + const DomainName domain = key.GetName(); + + // If the associated service isn't tracked yet, start tracking it and start + // queries for the relevant PTR records. + if (!graph_->IsTracked(domain)) { + std::function<void(const DomainName&)> mdns_query( + [this, &domain](const DomainName& changed_domain) { + OSP_DVLOG << "Starting mDNS query for '" << domain.ToString() << "'"; + mdns_querier_->StartQuery(changed_domain, DnsType::kANY, + DnsClass::kANY, this); + }); + graph_->StartTracking(domain, std::move(mdns_query)); + return; } -} -bool QuerierImpl::IsQueryRunning(const std::string& service) const { - OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); - return IsQueryRunning(ServiceKey(service, kLocalDomain)); + // Else, it's already being tracked so fire creation callbacks for any already + // found service instances. + const std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints = + graph_->CreateEndpoints(DnsDataGraph::DomainGroup::kPtr, domain); + for (const auto& endpoint : endpoints) { + if (endpoint.is_value()) { + callback->OnEndpointCreated(endpoint.value()); + } + } } void QuerierImpl::StopQuery(const std::string& service, Callback* callback) { OSP_DCHECK(callback); OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); - OSP_DVLOG << "Stopping query for service '" << service << "'"; + OSP_DVLOG << "Stopping DNS-SD query for service '" << service << "'"; ServiceKey key(service, kLocalDomain); - auto callback_it = callback_map_.find(key); - if (callback_it == callback_map_.end()) { + const auto callbacks_it = callback_map_.find(key); + if (callbacks_it == callback_map_.end()) { return; } - std::vector<Callback*>* callbacks = &callback_it->second; - - const auto it = std::find(callbacks->begin(), callbacks->end(), callback); - if (it != callbacks->end()) { - callbacks->erase(it); - if (callbacks->empty()) { - callback_map_.erase(callback_it); - auto queries = GetDataToStopDnsQuery(std::move(key)); - StopDnsQueriesImmediately(queries); - } + + std::vector<Callback*>& callbacks = callbacks_it->second; + const auto it = std::find(callbacks.begin(), callbacks.end(), callback); + if (it == callbacks.end()) { + return; + } + + callbacks.erase(it); + if (callbacks.empty()) { + callback_map_.erase(callbacks_it); + + ServiceKey key(service, kLocalDomain); + DomainName domain = key.GetName(); + + std::function<void(const DomainName&)> stop_mdns_query( + [this](const DomainName& changed_domain) { + OSP_DVLOG << "Stopping mDNS query for '" << changed_domain.ToString() + << "'"; + mdns_querier_->StopQuery(changed_domain, DnsType::kANY, + DnsClass::kANY, this); + }); + graph_->StopTracking(domain, std::move(stop_mdns_query)); } } +bool QuerierImpl::IsQueryRunning(const std::string& service) const { + OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + const ServiceKey key(service, kLocalDomain); + return graph_->IsTracked(key.GetName()); +} + void QuerierImpl::ReinitializeQueries(const std::string& service) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DVLOG << "Re-initializing query for service '" << service << "'"; const ServiceKey key(service, kLocalDomain); - - // Stop instance-specific queries and erase all instance data received so far. - std::vector<InstanceKey> keys_to_remove; - for (const auto& pair : received_records_) { - if (key == pair.first) { - keys_to_remove.push_back(pair.first); - } - } - for (InstanceKey& ik : keys_to_remove) { - auto queries = GetDataToStopDnsQuery(std::move(ik), false); - StopDnsQueriesImmediately(queries); - } + const DomainName domain = key.GetName(); + + std::function<void(const DomainName&)> start_callback( + [this](const DomainName& domain) { + mdns_querier_->StartQuery(domain, DnsType::kANY, DnsClass::kANY, this); + }); + std::function<void(const DomainName&)> stop_callback( + [this](const DomainName& domain) { + mdns_querier_->StopQuery(domain, DnsType::kANY, DnsClass::kANY, this); + }); + graph_->StopTracking(domain, std::move(stop_callback)); // Restart top-level queries. mdns_querier_->ReinitializeQueries(GetPtrQueryInfo(key).name); + + graph_->StartTracking(domain, std::move(start_callback)); } std::vector<PendingQueryChange> QuerierImpl::OnRecordChanged( @@ -126,206 +328,147 @@ std::vector<PendingQueryChange> QuerierImpl::OnRecordChanged( RecordChangedEvent event) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); - OSP_DVLOG << "Record with name '" << record.name().ToString() - << "' and type '" << record.dns_type() - << "' has received change of type '" << event << "'"; - - if (IsPtrRecord(record)) { - ErrorOr<std::vector<PendingQueryChange>> pending_changes = - HandlePtrRecordChange(record, event); - if (pending_changes.is_error()) { - OSP_LOG << "Failed to handle PTR record change of type " << event - << " with error " << pending_changes.error(); - return {}; - } else { - return pending_changes.value(); - } - } else { - Error error = HandleNonPtrRecordChange(record, event); - if (!error.ok()) { - OSP_LOG << "Failed to handle " << record.dns_type() - << " record change of type " << event << " with error " << error; - } + OSP_DVLOG << "Record " << record.ToString() + << " has received change of type '" << event << "'"; + + std::function<void(Error)> log = [this](Error error) mutable { + reporting_client_->OnRecoverableError( + Error(Error::Code::kProcessReceivedRecordFailure)); + }; + + // Get the details to use for calling CreateEndpoints(). Special case PTR + // records to optimize performance. + const DomainName& create_endpoints_domain = + record.dns_type() != DnsType::kPTR + ? record.name() + : absl::get<PtrRecordRdata>(record.rdata()).ptr_domain(); + const DnsDataGraph::DomainGroup create_endpoints_group = + record.dns_type() != DnsType::kPTR + ? DnsDataGraph::GetDomainGroup(record) + : DnsDataGraph::DomainGroup::kSrvAndTxt; + + // Get the current set of DnsSdInstanceEndpoints prior to this change. Special + // case PTR records to avoid iterating over unrelated child domains. + std::vector<ErrorOr<DnsSdInstanceEndpoint>> old_endpoints_or_errors = + graph_->CreateEndpoints(create_endpoints_group, create_endpoints_domain); + + // Apply the changes, creating a list of all pending changes that should be + // applied afterwards. + ErrorOr<std::vector<PendingQueryChange>> pending_changes_or_error = + ApplyRecordChanges(record, event); + if (pending_changes_or_error.is_error()) { + OSP_DVLOG << "Failed to apply changes for " << record.dns_type() + << " record change of type " << event << " with error " + << pending_changes_or_error.error(); + log(std::move(pending_changes_or_error.error())); return {}; } -} - -ErrorOr<std::vector<PendingQueryChange>> QuerierImpl::HandlePtrRecordChange( - const MdnsRecord& record, - RecordChangedEvent event) { - if (!HasValidDnsRecordAddress(record)) { - // This means that the received record is malformed. - return Error::Code::kParameterInvalid; + std::vector<PendingQueryChange>& pending_changes = + pending_changes_or_error.value(); + + // Get the new set of DnsSdInstanceEndpoints following this change. + std::vector<ErrorOr<DnsSdInstanceEndpoint>> new_endpoints_or_errors = + graph_->CreateEndpoints(create_endpoints_group, create_endpoints_domain); + + // Return early if the resulting sets are equal. This will frequently be the + // case, especially when both sets are empty. + std::sort(old_endpoints_or_errors.begin(), old_endpoints_or_errors.end()); + std::sort(new_endpoints_or_errors.begin(), new_endpoints_or_errors.end()); + if (old_endpoints_or_errors.size() == new_endpoints_or_errors.size() && + std::equal(old_endpoints_or_errors.begin(), old_endpoints_or_errors.end(), + new_endpoints_or_errors.begin())) { + return pending_changes; } - std::vector<DnsQueryInfo> changes; - switch (event) { - case RecordChangedEvent::kCreated: - changes = GetDataToStartDnsQuery(InstanceKey(record)); - return StartDnsQueriesDelayed(std::move(changes)); - case RecordChangedEvent::kExpired: - changes = GetDataToStopDnsQuery(InstanceKey(record)); - return StopDnsQueriesDelayed(std::move(changes)); - case RecordChangedEvent::kUpdated: - return Error::Code::kOperationInvalid; - } - return Error::Code::kUnknownError; + // Log all errors and erase them. + ProcessErrors(&old_endpoints_or_errors, &new_endpoints_or_errors, + std::move(log)); + const size_t old_endpoints_or_errors_count = old_endpoints_or_errors.size(); + const size_t new_endpoints_or_errors_count = new_endpoints_or_errors.size(); + std::vector<DnsSdInstanceEndpoint> old_endpoints = + GetValues(std::move(old_endpoints_or_errors)); + std::vector<DnsSdInstanceEndpoint> new_endpoints = + GetValues(std::move(new_endpoints_or_errors)); + OSP_DCHECK_EQ(old_endpoints.size(), old_endpoints_or_errors_count); + OSP_DCHECK_EQ(new_endpoints.size(), new_endpoints_or_errors_count); + + // Calculate the changes and call callbacks. + // + // NOTE: As the input sets are expected to be small, the generated sets will + // also be small. + std::vector<DnsSdInstanceEndpoint> created; + std::vector<DnsSdInstanceEndpoint> updated; + std::vector<DnsSdInstanceEndpoint> deleted; + CalculateChangeSets(std::move(old_endpoints), std::move(new_endpoints), + &created, &updated, &deleted); + + InvokeChangeCallbacks(std::move(created), std::move(updated), + std::move(deleted)); + return pending_changes; } -Error QuerierImpl::HandleNonPtrRecordChange(const MdnsRecord& record, - RecordChangedEvent event) { - if (!HasValidDnsRecordAddress(record)) { - // This means that the call received had malformed data. - return Error::Code::kParameterInvalid; - } - - const ServiceKey key(record); - if (!IsQueryRunning(key)) { - // This means that the call was already queued up on the TaskRunner when the - // callback was removed. The caller no longer cares, so drop the record. - return Error::Code::kOperationCancelled; - } - const std::vector<Callback*>& callbacks = callback_map_[key]; - - // Get the current InstanceEndpoint data associated with the received record. - const InstanceKey id(record); - ErrorOr<DnsSdInstanceEndpoint> old_instance_endpoint = - Error::Code::kItemNotFound; - auto it = received_records_.find(id); - if (it == received_records_.end()) { - it = received_records_ - .emplace(id, DnsData(id, network_config_->network_interface())) - .first; +void QuerierImpl::InvokeChangeCallbacks( + std::vector<DnsSdInstanceEndpoint> created, + std::vector<DnsSdInstanceEndpoint> updated, + std::vector<DnsSdInstanceEndpoint> deleted) { + // Find an endpoint and use it to create the key, or return if there is none. + DnsSdInstanceEndpoint* some_endpoint; + if (!created.empty()) { + some_endpoint = &created.front(); + } else if (!updated.empty()) { + some_endpoint = &updated.front(); + } else if (!deleted.empty()) { + some_endpoint = &deleted.front(); } else { - old_instance_endpoint = it->second.CreateEndpoint(); - } - DnsData* data = &it->second; - - // Apply the changes specified by the received event to the stored - // InstanceEndpoint. - Error apply_result = data->ApplyDataRecordChange(record, event); - if (!apply_result.ok()) { - OSP_LOG_ERROR << "Received erroneous record change. Error: " - << apply_result; - return apply_result; - } - - // Send an update to the user, based on how the new and old records compare. - ErrorOr<DnsSdInstanceEndpoint> new_instance_endpoint = data->CreateEndpoint(); - NotifyCallbacks(callbacks, old_instance_endpoint, new_instance_endpoint); - - return Error::None(); -} - -void QuerierImpl::NotifyCallbacks( - const std::vector<Callback*>& callbacks, - const ErrorOr<DnsSdInstanceEndpoint>& old_endpoint, - const ErrorOr<DnsSdInstanceEndpoint>& new_endpoint) { - if (old_endpoint.is_value() && new_endpoint.is_value()) { - for (Callback* callback : callbacks) { - callback->OnEndpointUpdated(new_endpoint.value()); - } - } else if (old_endpoint.is_value() && !new_endpoint.is_value()) { - for (Callback* callback : callbacks) { - callback->OnEndpointDeleted(old_endpoint.value()); - } - } else if (!old_endpoint.is_value() && new_endpoint.is_value()) { - for (Callback* callback : callbacks) { - callback->OnEndpointCreated(new_endpoint.value()); - } - } -} - -std::vector<DnsQueryInfo> QuerierImpl::GetDataToStartDnsQuery(InstanceKey key) { - auto pair = received_records_.emplace( - key, DnsData(key, network_config_->network_interface())); - if (!pair.second) { - // This means that a query is already ongoing. - return {}; + return; } + ServiceKey key(some_endpoint->service_id(), some_endpoint->domain_id()); - return {GetInstanceQueryInfo(key)}; -} - -std::vector<DnsQueryInfo> QuerierImpl::GetDataToStopDnsQuery( - InstanceKey key, - bool should_inform_callbacks) { - // If the instance is not being queried for, return. - auto record_it = received_records_.find(key); - if (record_it == received_records_.end()) { - return {}; + // Find all callbacks. + auto it = callback_map_.find(key); + if (it == callback_map_.end()) { + return; } - // If the instance has enough associated data that an instance was provided to - // the higher layer, call the deleted callback for all associated callbacks. - ErrorOr<DnsSdInstanceEndpoint> instance_endpoint = - record_it->second.CreateEndpoint(); - if (should_inform_callbacks && instance_endpoint.is_value()) { - const auto it = callback_map_.find(key); - if (it != callback_map_.end()) { - for (Callback* callback : it->second) { - callback->OnEndpointDeleted(instance_endpoint.value()); - } + // Call relevant callbacks. + std::vector<Callback*>& callbacks = it->second; + for (Callback* callback : callbacks) { + for (const DnsSdInstanceEndpoint& endpoint : created) { + callback->OnEndpointCreated(endpoint); } - } - - // Erase the key to mark the instance as no longer being queried for. - received_records_.erase(record_it); - - // Call to the mDNS layer to stop the query. - return {GetInstanceQueryInfo(key)}; -} - -std::vector<DnsQueryInfo> QuerierImpl::GetDataToStartDnsQuery(ServiceKey key) { - return {GetPtrQueryInfo(key)}; -} - -std::vector<DnsQueryInfo> QuerierImpl::GetDataToStopDnsQuery(ServiceKey key) { - std::vector<DnsQueryInfo> query_infos = {GetPtrQueryInfo(key)}; - - // Stop any ongoing instance-specific queries. - std::vector<InstanceKey> keys_to_remove; - for (const auto& pair : received_records_) { - const bool key_is_service_from_query = (key == pair.first); - if (key_is_service_from_query) { - keys_to_remove.push_back(pair.first); + for (const DnsSdInstanceEndpoint& endpoint : updated) { + callback->OnEndpointUpdated(endpoint); + } + for (const DnsSdInstanceEndpoint& endpoint : deleted) { + callback->OnEndpointDeleted(endpoint); } - } - for (auto it = keys_to_remove.begin(); it != keys_to_remove.end(); it++) { - std::vector<DnsQueryInfo> instance_query_infos = - GetDataToStopDnsQuery(std::move(*it)); - query_infos.insert(query_infos.begin(), instance_query_infos.begin(), - instance_query_infos.end()); - } - - return query_infos; -} - -void QuerierImpl::StartDnsQueriesImmediately( - const std::vector<DnsQueryInfo>& query_infos) { - for (const auto& query : query_infos) { - mdns_querier_->StartQuery(query.name, query.dns_type, query.dns_class, - this); } } -void QuerierImpl::StopDnsQueriesImmediately( - const std::vector<DnsQueryInfo>& query_infos) { - for (const auto& query : query_infos) { - mdns_querier_->StopQuery(query.name, query.dns_type, query.dns_class, this); +ErrorOr<std::vector<PendingQueryChange>> QuerierImpl::ApplyRecordChanges( + const MdnsRecord& record, + RecordChangedEvent event) { + std::vector<PendingQueryChange> pending_changes; + std::function<void(DomainName)> creation_callback( + [this, &pending_changes](DomainName domain) mutable { + pending_changes.push_back({std::move(domain), DnsType::kANY, + DnsClass::kANY, this, + PendingQueryChange::kStartQuery}); + }); + std::function<void(DomainName)> deletion_callback( + [this, &pending_changes](DomainName domain) mutable { + pending_changes.push_back({std::move(domain), DnsType::kANY, + DnsClass::kANY, this, + PendingQueryChange::kStopQuery}); + }); + Error result = + graph_->ApplyDataRecordChange(record, event, std::move(creation_callback), + std::move(deletion_callback)); + if (!result.ok()) { + return result; } -} -std::vector<PendingQueryChange> QuerierImpl::StartDnsQueriesDelayed( - std::vector<DnsQueryInfo> query_infos) { - return GetDnsQueriesDelayed(std::move(query_infos), this, - PendingQueryChange::kStartQuery); -} - -std::vector<PendingQueryChange> QuerierImpl::StopDnsQueriesDelayed( - std::vector<DnsQueryInfo> query_infos) { - return GetDnsQueriesDelayed(std::move(query_infos), this, - PendingQueryChange::kStopQuery); + return pending_changes; } } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.h index e720f16634f..f6db4c8068e 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl.h @@ -6,6 +6,8 @@ #define DISCOVERY_DNSSD_IMPL_QUERIER_IMPL_H_ #include <map> +#include <memory> +#include <string> #include <unordered_map> #include <vector> @@ -13,7 +15,7 @@ #include "absl/strings/string_view.h" #include "discovery/dnssd/impl/constants.h" #include "discovery/dnssd/impl/conversion_layer.h" -#include "discovery/dnssd/impl/dns_data.h" +#include "discovery/dnssd/impl/dns_data_graph.h" #include "discovery/dnssd/impl/instance_key.h" #include "discovery/dnssd/impl/service_key.h" #include "discovery/dnssd/public/dns_sd_instance_endpoint.h" @@ -26,6 +28,7 @@ namespace openscreen { namespace discovery { class NetworkInterfaceConfig; +class ReportingClient; class QuerierImpl : public DnsSdQuerier, public MdnsRecordChangedCallback { public: @@ -33,6 +36,7 @@ class QuerierImpl : public DnsSdQuerier, public MdnsRecordChangedCallback { // instance constructed. QuerierImpl(MdnsService* querier, TaskRunner* task_runner, + ReportingClient* reporting_client, const NetworkInterfaceConfig* network_config); ~QuerierImpl() override; @@ -49,49 +53,20 @@ class QuerierImpl : public DnsSdQuerier, public MdnsRecordChangedCallback { RecordChangedEvent event) override; private: - // Process an OnRecordChanged event for a PTR record. - ErrorOr<std::vector<PendingQueryChange>> HandlePtrRecordChange( + friend class QuerierImplTesting; + + // Applies the provided record change to the underlying |graph_| instance. + ErrorOr<std::vector<PendingQueryChange>> ApplyRecordChanges( const MdnsRecord& record, RecordChangedEvent event); - // Process an OnRecordChanged event for non-PTR records (SRV, TXT, A, and AAAA - // records). - Error HandleNonPtrRecordChange(const MdnsRecord& record, - RecordChangedEvent event); - - inline bool IsQueryRunning(const ServiceKey& key) const { - return callback_map_.find(key) != callback_map_.end(); - } - - std::vector<DnsQueryInfo> GetDataToStopDnsQuery(ServiceKey key); - std::vector<DnsQueryInfo> GetDataToStartDnsQuery(ServiceKey key); - std::vector<DnsQueryInfo> GetDataToStopDnsQuery( - InstanceKey key, - bool should_inform_callbacks = true); - std::vector<DnsQueryInfo> GetDataToStartDnsQuery(InstanceKey key); - - void StartDnsQueriesImmediately(const std::vector<DnsQueryInfo>& query_infos); - void StopDnsQueriesImmediately(const std::vector<DnsQueryInfo>& query_infos); - - std::vector<PendingQueryChange> StartDnsQueriesDelayed( - std::vector<DnsQueryInfo> query_infos); - std::vector<PendingQueryChange> StopDnsQueriesDelayed( - std::vector<DnsQueryInfo> query_infos); - - // Calls the appropriate callback method based on the provided Instance - // Endpoint values. - void NotifyCallbacks(const std::vector<Callback*>& callbacks, - const ErrorOr<DnsSdInstanceEndpoint>& old_endpoint, - const ErrorOr<DnsSdInstanceEndpoint>& new_endpoint); - - // Map from a specific service instance to the data received so far about - // that instance. The keys in this map are the instances for which an - // associated PTR record has been received, and the values are the set of - // non-PTR records received which describe that service (if any). Note that, - // with this definition, it is possible for a InstanceKey to be mapped to an - // empty DnsData if the instance has no associated records yet. - std::unordered_map<InstanceKey, DnsData, absl::Hash<InstanceKey>> - received_records_; + // Informs all relevant callbacks of the provided changes. + void InvokeChangeCallbacks(std::vector<DnsSdInstanceEndpoint> created, + std::vector<DnsSdInstanceEndpoint> updated, + std::vector<DnsSdInstanceEndpoint> deleted); + + // Graph of underlying mDNS Record and their associations with each-other. + std::unique_ptr<DnsDataGraph> graph_; // Map from the (service, domain) pairs currently being queried for to the // callbacks to call when new InstanceEndpoints are available. @@ -100,9 +75,7 @@ class QuerierImpl : public DnsSdQuerier, public MdnsRecordChangedCallback { MdnsService* const mdns_querier_; TaskRunner* const task_runner_; - const NetworkInterfaceConfig* const network_config_; - - friend class QuerierImplTesting; + ReportingClient* reporting_client_; }; } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl_unittest.cc index 553ac7fa270..626ca92a034 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/querier_impl_unittest.cc @@ -11,6 +11,8 @@ #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/types/optional.h" +#include "discovery/common/testing/mock_reporting_client.h" +#include "discovery/dnssd/impl/conversion_layer.h" #include "discovery/dnssd/testing/fake_network_interface_config.h" #include "discovery/mdns/mdns_records.h" #include "discovery/mdns/testing/mdns_test_util.h" @@ -24,6 +26,8 @@ namespace openscreen { namespace discovery { namespace { +NetworkInterfaceIndex kNetworkInterface = 0; + class MockCallback : public DnsSdQuerier::Callback { public: MOCK_METHOD1(OnEndpointCreated, void(const DnsSdInstanceEndpoint&)); @@ -52,49 +56,30 @@ class MockMdnsService : public MdnsService { Error(const MdnsRecord&, const MdnsRecord&)); }; -SrvRecordRdata CreateSrvRecord() { - DomainName kDomain({"label"}); - constexpr uint16_t kPort{8080}; - return SrvRecordRdata(0, 0, kPort, std::move(kDomain)); -} - -ARecordRdata CreateARecord() { - return ARecordRdata(IPAddress{192, 168, 0, 0}); -} - -AAAARecordRdata CreateAAAARecord() { - return AAAARecordRdata(IPAddress(0x0102, 0x0304, 0x0506, 0x0708, 0x090a, - 0x0b0c, 0x0d0e, 0x0f10)); -} - -MdnsRecord CreatePtrRecord(const std::string& instance, - const std::string& service, - const std::string& domain) { - std::vector<std::string> ptr_labels; - std::vector<std::string> instance_labels{instance}; - - std::vector<std::string> service_labels = absl::StrSplit(service, '.'); - ptr_labels.insert(ptr_labels.end(), service_labels.begin(), - service_labels.end()); - instance_labels.insert(instance_labels.end(), service_labels.begin(), - service_labels.end()); - - std::vector<std::string> domain_labels = absl::StrSplit(domain, '.'); - ptr_labels.insert(ptr_labels.end(), domain_labels.begin(), - domain_labels.end()); - instance_labels.insert(instance_labels.end(), domain_labels.begin(), - domain_labels.end()); - - DomainName ptr_domain(ptr_labels); - DomainName inner_domain(instance_labels); - - PtrRecordRdata data(std::move(inner_domain)); - - // TTL specified by RFC 6762 section 10. - constexpr std::chrono::seconds kTtl(120); - return MdnsRecord(std::move(ptr_domain), DnsType::kPTR, DnsClass::kIN, - RecordType::kShared, kTtl, std::move(data)); -} +class MockDnsDataGraph : public DnsDataGraph { + public: + MOCK_METHOD2(StartTracking, + void(const DomainName& domain, + DomainChangeCallback on_start_tracking)); + MOCK_METHOD2(StopTracking, + void(const DomainName& domain, + DomainChangeCallback on_start_tracking)); + + MOCK_CONST_METHOD2( + CreateEndpoints, + std::vector<ErrorOr<DnsSdInstanceEndpoint>>(DomainGroup, + const DomainName&)); + + MOCK_METHOD4(ApplyDataRecordChange, + Error(MdnsRecord, + RecordChangedEvent, + DomainChangeCallback, + DomainChangeCallback)); + + MOCK_CONST_METHOD0(GetTrackedDomainCount, size_t()); + + MOCK_CONST_METHOD1(IsTracked, bool(const DomainName&)); +}; } // namespace @@ -103,289 +88,569 @@ using testing::ByMove; using testing::Return; using testing::StrictMock; -class DnsDataAccessor { - public: - explicit DnsDataAccessor(DnsData* data) : data_(data) {} - - void set_srv(absl::optional<SrvRecordRdata> record) { data_->srv_ = record; } - void set_txt(absl::optional<TxtRecordRdata> record) { data_->txt_ = record; } - void set_a(absl::optional<ARecordRdata> record) { data_->a_ = record; } - void set_aaaa(absl::optional<AAAARecordRdata> record) { - data_->aaaa_ = record; - } - - absl::optional<SrvRecordRdata>* srv() { return &data_->srv_; } - absl::optional<TxtRecordRdata>* txt() { return &data_->txt_; } - absl::optional<ARecordRdata>* a() { return &data_->a_; } - absl::optional<AAAARecordRdata>* aaaa() { return &data_->aaaa_; } - - bool CanCreateEndpoint() { return data_->CreateEndpoint().is_value(); } - - private: - DnsData* data_; -}; - class QuerierImplTesting : public QuerierImpl { public: QuerierImplTesting() - : QuerierImpl(&mock_service_, &task_runner_, &network_config_), + : QuerierImpl(&mock_service_, + &task_runner_, + &reporting_client_, + &network_config_), clock_(Clock::now()), task_runner_(&clock_) {} - MockMdnsService* service() { return &mock_service_; } - - DnsDataAccessor CreateDnsData(const std::string& instance, - const std::string& service, - const std::string& domain) { - InstanceKey key{instance, service, domain}; - auto it = - received_records_ - .emplace(key, DnsData(key, network_config_.network_interface())) - .first; - return DnsDataAccessor(&it->second); + StrictMock<MockMdnsService>& service() { return mock_service_; } + + StrictMock<MockReportingClient>& reporting_client() { + return reporting_client_; } - absl::optional<DnsDataAccessor> GetDnsData(const std::string& instance, - const std::string& service, - const std::string& domain) { - InstanceKey key{instance, service, domain}; - auto it = received_records_.find(key); - if (it == received_records_.end()) { - return absl::nullopt; + // NOTE: This should only be used for testing hard-to-achieve edge cases. + StrictMock<MockDnsDataGraph>& GetMockedGraph() { + if (!is_graph_mocked_) { + graph_ = std::make_unique<StrictMock<MockDnsDataGraph>>(); + is_graph_mocked_ = true; } - return DnsDataAccessor(&it->second); + + return static_cast<StrictMock<MockDnsDataGraph>&>(*graph_); } + size_t GetTrackedDomainCount() { return graph_->GetTrackedDomainCount(); } + + bool IsDomainTracked(const DomainName& domain) { + return graph_->IsTracked(domain); + } + + using QuerierImpl::OnRecordChanged; + private: FakeClock clock_; FakeTaskRunner task_runner_; FakeNetworkInterfaceConfig network_config_; StrictMock<MockMdnsService> mock_service_; + StrictMock<MockReportingClient> reporting_client_; + + bool is_graph_mocked_ = false; }; class DnsSdQuerierImplTest : public testing::Test { public: - DnsSdQuerierImplTest() { - EXPECT_FALSE(querier.IsQueryRunning(service)); - - EXPECT_CALL(*querier.service(), - StartQuery(_, DnsType::kPTR, DnsClass::kANY, _)) + DnsSdQuerierImplTest() + : querier(std::make_unique<QuerierImplTesting>()), + ptr_domain(DomainName{"_service", "_udp", domain}), + name(DomainName{instance, "_service", "_udp", domain}), + name2(DomainName{instance2, "_service", "_udp", domain}) { + EXPECT_FALSE(querier->IsQueryRunning(service)); + + EXPECT_CALL(querier->service(), + StartQuery(_, DnsType::kANY, DnsClass::kANY, _)) .Times(1); - querier.StartQuery(service, &callback); - EXPECT_TRUE(querier.IsQueryRunning(service)); - testing::Mock::VerifyAndClearExpectations(querier.service()); - - EXPECT_CALL(*querier.service(), - StartQuery(_, DnsType::kPTR, DnsClass::kANY, _)) - .Times(0); - EXPECT_TRUE(querier.IsQueryRunning(service)); + querier->StartQuery(service, &callback); + EXPECT_TRUE(querier->IsQueryRunning(service)); + testing::Mock::VerifyAndClearExpectations(&querier->service()); + + EXPECT_TRUE(querier->IsQueryRunning(service)); + testing::Mock::VerifyAndClearExpectations(&querier->service()); } protected: + void ValidateRecordChangeStartsQuery( + const std::vector<PendingQueryChange>& changes, + const DomainName& name, + size_t expected_size) { + ValidateRecordChangeResult(changes, name, expected_size, + PendingQueryChange::kStartQuery); + } + + void ValidateRecordChangeStopsQuery( + const std::vector<PendingQueryChange>& changes, + const DomainName& name, + size_t expected_size) { + ValidateRecordChangeResult(changes, name, expected_size, + PendingQueryChange::kStopQuery); + } + + void CreateServiceInstance(const DomainName& service_domain, + MockCallback* cb) { + MdnsRecord ptr = GetFakePtrRecord(service_domain); + MdnsRecord srv = GetFakeSrvRecord(service_domain); + MdnsRecord txt = GetFakeTxtRecord(service_domain); + MdnsRecord a = GetFakeARecord(service_domain); + MdnsRecord aaaa = GetFakeAAAARecord(service_domain); + + auto result = querier->OnRecordChanged(ptr, RecordChangedEvent::kCreated); + ValidateRecordChangeStartsQuery(result, service_domain, 1); + + // NOTE: This verbose iterator handling is used to avoid gcc failures. + auto it = service_domain.labels().begin(); + it++; + std::string service_name = *it; + it++; + std::string service_protocol = *it; + std::string service_id = ""; + service_id.append(std::move(service_name)) + .append(".") + .append(std::move(service_protocol)); + ASSERT_TRUE(querier->IsQueryRunning(service_id)); + + result = querier->OnRecordChanged(srv, RecordChangedEvent::kCreated); + EXPECT_EQ(result.size(), size_t{0}); + + result = querier->OnRecordChanged(a, RecordChangedEvent::kCreated); + EXPECT_EQ(result.size(), size_t{0}); + + result = querier->OnRecordChanged(aaaa, RecordChangedEvent::kCreated); + EXPECT_EQ(result.size(), size_t{0}); + + EXPECT_CALL(*cb, OnEndpointCreated(_)).Times(1); + result = querier->OnRecordChanged(txt, RecordChangedEvent::kCreated); + EXPECT_EQ(result.size(), size_t{0}); + testing::Mock::VerifyAndClearExpectations(cb); + } + std::string instance = "instance"; + std::string instance2 = "instance2"; std::string service = "_service._udp"; + std::string service2 = "_service2._udp"; std::string domain = "local"; StrictMock<MockCallback> callback; - QuerierImplTesting querier; + std::unique_ptr<QuerierImplTesting> querier; + DomainName ptr_domain; + DomainName name; + DomainName name2; + + private: + void ValidateRecordChangeResult( + const std::vector<PendingQueryChange>& changes, + const DomainName& name, + size_t expected_size, + PendingQueryChange::ChangeType change_type) { + EXPECT_EQ(changes.size(), expected_size); + auto it = std::find_if( + changes.begin(), changes.end(), + [&name, change_type](const PendingQueryChange& change) { + return change.dns_type == DnsType::kANY && + change.dns_class == DnsClass::kANY && + change.change_type == change_type && change.name == name; + }); + EXPECT_TRUE(it != changes.end()); + } }; +// Common Use Cases +// +// The below tests validate the common use cases for QuerierImpl, which we +// expect will be hit for reasonable actors on the network. For these tests, the +// real DnsDataGraph object will be used. + TEST_F(DnsSdQuerierImplTest, TestStartStopQueryCallsMdnsQueries) { + DomainName other_service_id( + DomainName{instance2, "_service2", "_udp", domain}); + StrictMock<MockCallback> callback2; + EXPECT_FALSE(querier->IsQueryRunning(service2)); - querier.StartQuery(service, &callback2); - querier.StopQuery(service, &callback); - EXPECT_TRUE(querier.IsQueryRunning(service)); + EXPECT_CALL(querier->service(), + StartQuery(_, DnsType::kANY, DnsClass::kANY, _)) + .Times(1); + querier->StartQuery(service2, &callback2); + EXPECT_TRUE(querier->IsQueryRunning(service2)); - EXPECT_CALL(*querier.service(), - StopQuery(_, DnsType::kPTR, DnsClass::kANY, _)) + EXPECT_CALL(querier->service(), + StopQuery(_, DnsType::kANY, DnsClass::kANY, _)) .Times(1); - querier.StopQuery(service, &callback2); - EXPECT_FALSE(querier.IsQueryRunning(service)); + querier->StopQuery(service2, &callback2); + EXPECT_FALSE(querier->IsQueryRunning(service2)); } TEST_F(DnsSdQuerierImplTest, TestStartDuplicateQueryFiresCallbacksWhenAble) { StrictMock<MockCallback> callback2; - - DnsDataAccessor dns_data = querier.CreateDnsData(instance, service, domain); - dns_data.set_srv(CreateSrvRecord()); - dns_data.set_txt(MakeTxtRecord({})); - dns_data.set_a(CreateARecord()); - dns_data.set_aaaa(CreateAAAARecord()); + CreateServiceInstance(name, &callback); EXPECT_CALL(callback2, OnEndpointCreated(_)).Times(1); - querier.StartQuery(service, &callback2); + querier->StartQuery(service, &callback2); + testing::Mock::VerifyAndClearExpectations(&callback2); } -TEST_F(DnsSdQuerierImplTest, TestStopQueryClearsRecords) { - querier.CreateDnsData(instance, service, domain); +TEST_F(DnsSdQuerierImplTest, TestStopQueryStopsTrackingRecords) { + CreateServiceInstance(name, &callback); - EXPECT_CALL(*querier.service(), - StopQuery(_, DnsType::kPTR, DnsClass::kANY, _)) + DomainName ptr_domain(++name.labels().begin(), name.labels().end()); + EXPECT_CALL(querier->service(), + StopQuery(ptr_domain, DnsType::kANY, DnsClass::kANY, _)) .Times(1); - EXPECT_CALL(*querier.service(), - StopQuery(_, DnsType::kANY, DnsClass::kANY, _)) + EXPECT_CALL(querier->service(), + StopQuery(name, DnsType::kANY, DnsClass::kANY, _)) + .Times(1); + querier->StopQuery(service, &callback); + EXPECT_FALSE(querier->IsDomainTracked(ptr_domain)); + EXPECT_FALSE(querier->IsDomainTracked(name)); + EXPECT_EQ(querier->GetTrackedDomainCount(), size_t{0}); + testing::Mock::VerifyAndClearExpectations(&callback); + + EXPECT_CALL(querier->service(), + StartQuery(_, DnsType::kANY, DnsClass::kANY, _)) .Times(1); - querier.StopQuery(service, &callback); - EXPECT_FALSE(querier.GetDnsData(instance, service, domain).has_value()); + querier->StartQuery(service, &callback); + EXPECT_TRUE(querier->IsQueryRunning(service)); } TEST_F(DnsSdQuerierImplTest, TestStopNonexistantQueryHasNoEffect) { StrictMock<MockCallback> callback2; - querier.CreateDnsData(instance, service, domain); - - querier.StopQuery(service, &callback2); - EXPECT_TRUE(querier.GetDnsData(instance, service, domain).has_value()); + querier->StopQuery(service, &callback2); } -TEST_F(DnsSdQuerierImplTest, TestCreateDeletePtrRecord) { - const auto ptr = CreatePtrRecord(instance, service, domain); - const auto ptr2 = CreatePtrRecord(instance, service, domain); - - auto result = querier.OnRecordChanged(ptr, RecordChangedEvent::kCreated); - ASSERT_EQ(result.size(), size_t{1}); - auto query = result[0]; - EXPECT_EQ(query.dns_type, DnsType::kANY); - EXPECT_EQ(query.dns_class, DnsClass::kANY); - EXPECT_EQ(query.change_type, PendingQueryChange::kStartQuery); - - result = querier.OnRecordChanged(ptr2, RecordChangedEvent::kExpired); - ASSERT_EQ(result.size(), size_t{1}); - query = result[0]; - EXPECT_EQ(query.dns_type, DnsType::kANY); - EXPECT_EQ(query.dns_class, DnsClass::kANY); - EXPECT_EQ(query.change_type, PendingQueryChange::kStopQuery); -} +TEST_F(DnsSdQuerierImplTest, TestAFollowingAAAAFiresSecondCallback) { + MdnsRecord ptr = GetFakePtrRecord(name); + MdnsRecord srv = GetFakeSrvRecord(name); + MdnsRecord txt = GetFakeTxtRecord(name); + MdnsRecord a = GetFakeARecord(name); + MdnsRecord aaaa = GetFakeAAAARecord(name); + + std::vector<DnsSdInstanceEndpoint> endpoints; + auto changes = querier->OnRecordChanged(ptr, RecordChangedEvent::kCreated); + ValidateRecordChangeStartsQuery(changes, name, 1); + + changes = querier->OnRecordChanged(srv, RecordChangedEvent::kCreated); + EXPECT_EQ(changes.size(), size_t{0}); + changes = querier->OnRecordChanged(txt, RecordChangedEvent::kCreated); + EXPECT_EQ(changes.size(), size_t{0}); + + EXPECT_CALL(callback, OnEndpointCreated(_)) + .WillOnce([&endpoints](const DnsSdInstanceEndpoint& ep) mutable { + endpoints.push_back(ep); + }); + changes = querier->OnRecordChanged(aaaa, RecordChangedEvent::kCreated); + EXPECT_EQ(changes.size(), size_t{0}); + testing::Mock::VerifyAndClearExpectations(&callback); -TEST_F(DnsSdQuerierImplTest, CallbackCalledWhenPtrDeleted) { - auto ptr = CreatePtrRecord(instance, service, domain); - auto result = querier.OnRecordChanged(ptr, RecordChangedEvent::kCreated); - ASSERT_EQ(result.size(), size_t{1}); - auto query = result[0]; - EXPECT_EQ(query.dns_type, DnsType::kANY); - EXPECT_EQ(query.dns_class, DnsClass::kANY); - EXPECT_EQ(query.change_type, PendingQueryChange::kStartQuery); - - DnsDataAccessor dns_data = querier.CreateDnsData(instance, service, domain); - dns_data.set_srv(CreateSrvRecord()); - dns_data.set_txt(MakeTxtRecord({})); - dns_data.set_a(CreateARecord()); - dns_data.set_aaaa(CreateAAAARecord()); - ASSERT_TRUE(dns_data.CanCreateEndpoint()); - - EXPECT_CALL(callback, OnEndpointDeleted(_)).Times(1); - result = querier.OnRecordChanged(ptr, RecordChangedEvent::kExpired); - ASSERT_EQ(result.size(), size_t{1}); - query = result[0]; - EXPECT_EQ(query.dns_type, DnsType::kANY); - EXPECT_EQ(query.dns_class, DnsClass::kANY); - EXPECT_EQ(query.change_type, PendingQueryChange::kStopQuery); - - EXPECT_FALSE(querier.GetDnsData(instance, service, domain).has_value()); -} + EXPECT_CALL(callback, OnEndpointUpdated(_)) + .WillOnce([&endpoints](const DnsSdInstanceEndpoint& ep) mutable { + endpoints.push_back(ep); + }); + changes = querier->OnRecordChanged(a, RecordChangedEvent::kCreated); + EXPECT_EQ(changes.size(), size_t{0}); + testing::Mock::VerifyAndClearExpectations(&callback); + + ASSERT_EQ(endpoints.size(), size_t{2}); + DnsSdInstanceEndpoint& created = endpoints[0]; + DnsSdInstanceEndpoint& updated = endpoints[1]; + EXPECT_EQ(static_cast<DnsSdInstance>(created), + static_cast<DnsSdInstance>(updated)); -TEST_F(DnsSdQuerierImplTest, NeitherNewNorOldValidRecords) { - DnsDataAccessor dns_data = querier.CreateDnsData(instance, service, domain); - dns_data.set_a(CreateARecord()); - dns_data.set_aaaa(CreateAAAARecord()); - - auto srv_rdata = CreateSrvRecord(); - DomainName kDomainName{"instance", "_service", "_udp", "local"}; - MdnsRecord srv_record(std::move(kDomainName), DnsType::kSRV, DnsClass::kIN, - RecordType::kUnique, std::chrono::seconds(0), - srv_rdata); - querier.OnRecordChanged(srv_record, RecordChangedEvent::kCreated); + ASSERT_EQ(created.addresses().size(), size_t{1}); + EXPECT_TRUE(created.addresses()[0].IsV6()); + + ASSERT_EQ(updated.addresses().size(), size_t{2}); + EXPECT_TRUE(created.addresses()[0] == updated.addresses()[0] || + created.addresses()[0] == updated.addresses()[1]); + EXPECT_TRUE(updated.addresses()[0].IsV4() || updated.addresses()[1].IsV4()); } -TEST_F(DnsSdQuerierImplTest, BothNewAndOldValidRecords) { - DnsDataAccessor dns_data = querier.CreateDnsData(instance, service, domain); - dns_data.set_srv(CreateSrvRecord()); - dns_data.set_txt(MakeTxtRecord({})); - dns_data.set_aaaa(CreateAAAARecord()); +TEST_F(DnsSdQuerierImplTest, TestGenerateTwoRecordsCallsCallbackTwice) { + DomainName third{"android", "local"}; + MdnsRecord ptr1 = GetFakePtrRecord(name); + MdnsRecord srv1 = GetFakeSrvRecord(name, third); + MdnsRecord txt1 = GetFakeTxtRecord(name); + MdnsRecord ptr2 = GetFakePtrRecord(name2); + MdnsRecord srv2 = GetFakeSrvRecord(name2, third); + MdnsRecord txt2 = GetFakeTxtRecord(name2); + MdnsRecord a = GetFakeARecord(third); - auto a_rdata = CreateARecord(); - const DomainName kDomainName{"instance", "_service", "_udp", "local"}; - MdnsRecord a_record(kDomainName, DnsType::kA, DnsClass::kIN, - RecordType::kUnique, std::chrono::seconds(0), a_rdata); + auto changes = querier->OnRecordChanged(ptr1, RecordChangedEvent::kCreated); + ValidateRecordChangeStartsQuery(changes, name, 1); - EXPECT_CALL(callback, OnEndpointUpdated(_)).Times(1); - querier.OnRecordChanged(a_record, RecordChangedEvent::kCreated); - testing::Mock::VerifyAndClearExpectations(&callback); + changes = querier->OnRecordChanged(srv1, RecordChangedEvent::kCreated); + ValidateRecordChangeStartsQuery(changes, third, 1); - EXPECT_CALL(callback, OnEndpointUpdated(_)).Times(1); - querier.OnRecordChanged(a_record, RecordChangedEvent::kUpdated); - testing::Mock::VerifyAndClearExpectations(&callback); + changes = querier->OnRecordChanged(txt1, RecordChangedEvent::kCreated); + EXPECT_EQ(changes.size(), size_t{0}); - auto aaaa_rdata = CreateAAAARecord(); - MdnsRecord aaaa_record(kDomainName, DnsType::kAAAA, DnsClass::kIN, - RecordType::kUnique, std::chrono::seconds(0), - aaaa_rdata); + changes = querier->OnRecordChanged(ptr2, RecordChangedEvent::kCreated); + ValidateRecordChangeStartsQuery(changes, name2, 1); - EXPECT_CALL(callback, OnEndpointUpdated(_)).Times(1); - querier.OnRecordChanged(aaaa_record, RecordChangedEvent::kUpdated); - testing::Mock::VerifyAndClearExpectations(&callback); + changes = querier->OnRecordChanged(srv2, RecordChangedEvent::kCreated); + EXPECT_EQ(changes.size(), size_t{0}); + + changes = querier->OnRecordChanged(txt2, RecordChangedEvent::kCreated); + EXPECT_EQ(changes.size(), size_t{0}); - EXPECT_CALL(callback, OnEndpointUpdated(_)).Times(1); - querier.OnRecordChanged(a_record, RecordChangedEvent::kExpired); + EXPECT_CALL(callback, OnEndpointCreated(_)).Times(2); + changes = querier->OnRecordChanged(a, RecordChangedEvent::kCreated); + EXPECT_EQ(changes.size(), size_t{0}); testing::Mock::VerifyAndClearExpectations(&callback); + + EXPECT_CALL(callback, OnEndpointDeleted(_)).Times(2); + changes = querier->OnRecordChanged(a, RecordChangedEvent::kExpired); + EXPECT_EQ(changes.size(), size_t{0}); } -TEST_F(DnsSdQuerierImplTest, OnlyNewRecordValid) { - DnsDataAccessor dns_data = querier.CreateDnsData(instance, service, domain); - dns_data.set_srv(CreateSrvRecord()); - dns_data.set_txt(MakeTxtRecord({})); +TEST_F(DnsSdQuerierImplTest, TestCreateDeletePtrRecordResults) { + const auto ptr = GetFakePtrRecord(name); - auto a_rdata = CreateARecord(); - DomainName kDomainName{"instance", "_service", "_udp", "local"}; - MdnsRecord a_record(std::move(kDomainName), DnsType::kA, DnsClass::kIN, - RecordType::kUnique, std::chrono::seconds(0), a_rdata); + auto result = querier->OnRecordChanged(ptr, RecordChangedEvent::kCreated); + ValidateRecordChangeStartsQuery(result, name, 1); - EXPECT_CALL(callback, OnEndpointCreated(_)).Times(1); - querier.OnRecordChanged(a_record, RecordChangedEvent::kCreated); + result = querier->OnRecordChanged(ptr, RecordChangedEvent::kExpired); + ValidateRecordChangeStopsQuery(result, name, 1); } -TEST_F(DnsSdQuerierImplTest, OnlyOldRecordValid) { - DnsDataAccessor dns_data = querier.CreateDnsData(instance, service, domain); - dns_data.set_srv(CreateSrvRecord()); - dns_data.set_txt(MakeTxtRecord({})); - dns_data.set_a(CreateARecord()); +TEST_F(DnsSdQuerierImplTest, CallbackCalledWhenPtrDeleted) { + MdnsRecord ptr = GetFakePtrRecord(name); + MdnsRecord srv = GetFakeSrvRecord(name, name2); + MdnsRecord txt = GetFakeTxtRecord(name); + MdnsRecord a = GetFakeARecord(name2); + + auto changes = querier->OnRecordChanged(ptr, RecordChangedEvent::kCreated); + ValidateRecordChangeStartsQuery(changes, name, 1); + + changes = querier->OnRecordChanged(srv, RecordChangedEvent::kCreated); + ValidateRecordChangeStartsQuery(changes, name2, 1); + + changes = querier->OnRecordChanged(txt, RecordChangedEvent::kCreated); + EXPECT_EQ(changes.size(), size_t{0}); - auto a_rdata = CreateARecord(); - DomainName kDomainName{"instance", "_service", "_udp", "local"}; - MdnsRecord a_record(std::move(kDomainName), DnsType::kA, DnsClass::kIN, - RecordType::kUnique, std::chrono::seconds(0), a_rdata); + EXPECT_CALL(callback, OnEndpointCreated(_)); + changes = querier->OnRecordChanged(a, RecordChangedEvent::kCreated); + EXPECT_EQ(changes.size(), size_t{0}); - EXPECT_CALL(callback, OnEndpointDeleted(_)).Times(1); - querier.OnRecordChanged(a_record, RecordChangedEvent::kExpired); + EXPECT_CALL(callback, OnEndpointDeleted(_)); + changes = querier->OnRecordChanged(ptr, RecordChangedEvent::kExpired); + ValidateRecordChangeStopsQuery(changes, name, 2); + ValidateRecordChangeStopsQuery(changes, name2, 2); } TEST_F(DnsSdQuerierImplTest, HardRefresh) { - const std::string service2 = "_service2._udp"; - - DnsDataAccessor dns_data = querier.CreateDnsData(instance, service, domain); - dns_data.set_srv(CreateSrvRecord()); - dns_data.set_txt(MakeTxtRecord({})); - dns_data.set_a(CreateARecord()); - dns_data.set_aaaa(CreateAAAARecord()); - DnsDataAccessor dns_data2 = querier.CreateDnsData(instance, service2, domain); - dns_data2.set_srv(CreateSrvRecord()); - - EXPECT_CALL(callback, OnEndpointCreated(_)).Times(1); - querier.StartQuery(service, &callback); - EXPECT_TRUE(querier.IsQueryRunning(service)); - - const DomainName ptr_domain{"_service", "_udp", "local"}; - const DomainName instance_domain{"instance", "_service", "_udp", "local"}; - EXPECT_CALL(*querier.service(), ReinitializeQueries(ptr_domain)); - EXPECT_CALL(*querier.service(), StopQuery(instance_domain, _, _, _)); - querier.ReinitializeQueries(service); - testing::Mock::VerifyAndClearExpectations(querier.service()); - - absl::optional<DnsDataAccessor> data = - querier.GetDnsData(instance, service, domain); - EXPECT_EQ(data, absl::nullopt); - data = querier.GetDnsData(instance, service2, domain); - EXPECT_NE(data, absl::nullopt); - EXPECT_TRUE(querier.IsQueryRunning(service)); + MdnsRecord ptr = GetFakePtrRecord(name); + MdnsRecord srv = GetFakeSrvRecord(name, name2); + MdnsRecord txt = GetFakeTxtRecord(name); + MdnsRecord a = GetFakeARecord(name2); + + querier->OnRecordChanged(ptr, RecordChangedEvent::kCreated); + querier->OnRecordChanged(srv, RecordChangedEvent::kCreated); + querier->OnRecordChanged(txt, RecordChangedEvent::kCreated); + + EXPECT_CALL(callback, OnEndpointCreated(_)); + querier->OnRecordChanged(a, RecordChangedEvent::kCreated); + testing::Mock::VerifyAndClearExpectations(&callback); + + EXPECT_CALL(querier->service(), + StopQuery(ptr_domain, DnsType::kANY, DnsClass::kANY, _)) + .Times(1); + EXPECT_CALL(querier->service(), + StopQuery(name, DnsType::kANY, DnsClass::kANY, _)) + .Times(1); + EXPECT_CALL(querier->service(), + StopQuery(name2, DnsType::kANY, DnsClass::kANY, _)) + .Times(1); + EXPECT_CALL(querier->service(), ReinitializeQueries(_)).Times(1); + EXPECT_CALL(querier->service(), + StartQuery(ptr_domain, DnsType::kANY, DnsClass::kANY, _)) + .Times(1); + querier->ReinitializeQueries(service); + testing::Mock::VerifyAndClearExpectations(querier.get()); +} + +// Edge Cases +// +// The below tests validate against edge cases that either either difficult to +// achieve, are not expected to be possible under normal circumstances but +// should be validated against for safety, or should only occur when either a +// bad actor or a misbehaving publisher is present on the network. To simplify +// these tests, the DnsDataGraph object will be mocked. +TEST_F(DnsSdQuerierImplTest, ErrorsOnlyAfterChangesAreLogged) { + MockDnsDataGraph& mock_graph = querier->GetMockedGraph(); + std::vector<ErrorOr<DnsSdInstanceEndpoint>> before_changes{}; + std::vector<ErrorOr<DnsSdInstanceEndpoint>> after_changes{}; + after_changes.emplace_back(Error::Code::kItemNotFound); + after_changes.emplace_back(Error::Code::kItemNotFound); + after_changes.emplace_back(Error::Code::kItemAlreadyExists); + + // Calls before and after applying record changes, then the error it logs. + EXPECT_CALL(mock_graph, CreateEndpoints(_, _)) + .WillOnce(Return(ByMove(std::move(before_changes)))) + .WillOnce(Return(ByMove(std::move(after_changes)))); + EXPECT_CALL(querier->reporting_client(), OnRecoverableError(_)).Times(3); + + // Call to apply record changes. The specifics are unimportant. + EXPECT_CALL(mock_graph, ApplyDataRecordChange(_, _, _, _)) + .WillOnce(Return(Error::None())); + + // Call with any record. The mocks make the specifics unimportant. + querier->OnRecordChanged(GetFakePtrRecord(name), + RecordChangedEvent::kCreated); +} + +TEST_F(DnsSdQuerierImplTest, ErrorsOnlyBeforeChangesNotLogged) { + MockDnsDataGraph& mock_graph = querier->GetMockedGraph(); + std::vector<ErrorOr<DnsSdInstanceEndpoint>> before_changes{}; + before_changes.emplace_back(Error::Code::kItemNotFound); + before_changes.emplace_back(Error::Code::kItemNotFound); + before_changes.emplace_back(Error::Code::kItemAlreadyExists); + std::vector<ErrorOr<DnsSdInstanceEndpoint>> after_changes{}; + + // Calls before and after applying record changes. + EXPECT_CALL(mock_graph, CreateEndpoints(_, _)) + .WillOnce(Return(ByMove(std::move(before_changes)))) + .WillOnce(Return(ByMove(std::move(after_changes)))); + + // Call to apply record changes. The specifics are unimportant. + EXPECT_CALL(mock_graph, ApplyDataRecordChange(_, _, _, _)) + .WillOnce(Return(Error::None())); + + // Call with any record. The mocks make the specifics unimportant. + querier->OnRecordChanged(GetFakePtrRecord(name), + RecordChangedEvent::kCreated); +} + +TEST_F(DnsSdQuerierImplTest, ErrorsBeforeAndAfterChangesNotLogged) { + MockDnsDataGraph& mock_graph = querier->GetMockedGraph(); + std::vector<ErrorOr<DnsSdInstanceEndpoint>> before_changes{}; + before_changes.emplace_back(Error::Code::kItemNotFound); + before_changes.emplace_back(Error::Code::kItemNotFound); + before_changes.emplace_back(Error::Code::kItemAlreadyExists); + std::vector<ErrorOr<DnsSdInstanceEndpoint>> after_changes{}; + after_changes.emplace_back(Error::Code::kItemNotFound); + after_changes.emplace_back(Error::Code::kItemAlreadyExists); + after_changes.emplace_back(Error::Code::kItemNotFound); + + // Calls before and after applying record changes. + EXPECT_CALL(mock_graph, CreateEndpoints(_, _)) + .WillOnce(Return(ByMove(std::move(before_changes)))) + .WillOnce(Return(ByMove(std::move(after_changes)))); + + // Call to apply record changes. The specifics are unimportant. + EXPECT_CALL(mock_graph, ApplyDataRecordChange(_, _, _, _)) + .WillOnce(Return(Error::None())); + + // Call with any record. The mocks make the specifics unimportant. + querier->OnRecordChanged(GetFakePtrRecord(name), + RecordChangedEvent::kCreated); +} + +TEST_F(DnsSdQuerierImplTest, OrderOfErrorsDoesNotAffectResults) { + MockDnsDataGraph& mock_graph = querier->GetMockedGraph(); + std::vector<ErrorOr<DnsSdInstanceEndpoint>> before_changes{}; + before_changes.emplace_back(Error::Code::kIndexOutOfBounds); + before_changes.emplace_back(Error::Code::kItemAlreadyExists); + before_changes.emplace_back(Error::Code::kOperationCancelled); + before_changes.emplace_back(Error::Code::kItemNotFound); + before_changes.emplace_back(Error::Code::kOperationInProgress); + std::vector<ErrorOr<DnsSdInstanceEndpoint>> after_changes{}; + after_changes.emplace_back(Error::Code::kOperationInProgress); + after_changes.emplace_back(Error::Code::kUnknownError); + after_changes.emplace_back(Error::Code::kItemNotFound); + after_changes.emplace_back(Error::Code::kItemAlreadyExists); + after_changes.emplace_back(Error::Code::kOperationCancelled); + + // Calls before and after applying record changes, then the error it logs. + EXPECT_CALL(mock_graph, CreateEndpoints(_, _)) + .WillOnce(Return(ByMove(std::move(before_changes)))) + .WillOnce(Return(ByMove(std::move(after_changes)))); + EXPECT_CALL(querier->reporting_client(), OnRecoverableError(_)).Times(1); + + // Call to apply record changes. The specifics are unimportant. + EXPECT_CALL(mock_graph, ApplyDataRecordChange(_, _, _, _)) + .WillOnce(Return(Error::None())); + + // Call with any record. The mocks make the specifics unimportant. + querier->OnRecordChanged(GetFakePtrRecord(name), + RecordChangedEvent::kCreated); +} + +TEST_F(DnsSdQuerierImplTest, ResultsWithMultipleAddressRecordsHandled) { + IPEndpoint endpointa{{192, 168, 86, 23}, 80}; + IPEndpoint endpointb{{1, 2, 3, 4, 5, 6, 7, 8}, 80}; + IPEndpoint endpointc{{192, 168, 0, 1}, 80}; + IPEndpoint endpointd{{192, 168, 0, 2}, 80}; + IPEndpoint endpointe{{192, 168, 0, 3}, 80}; + + DnsSdInstanceEndpoint instance1("instance1", "_service._udp", "local", {}, + kNetworkInterface, {endpointa, endpointb}); + DnsSdInstanceEndpoint instance2("instance2", "_service2._udp", "local", {}, + kNetworkInterface, {endpointa, endpointb}); + DnsSdInstanceEndpoint instance3("instance3", "_service._udp", "local", {}, + kNetworkInterface, {endpointc}); + DnsSdInstanceEndpoint instance4("instance1", "_service3._udp", "local", {}, + kNetworkInterface, {endpointd, endpointe}); + DnsSdInstanceEndpoint instance5("instance1", "_service3._udp", "local", {}, + kNetworkInterface, {endpointe}); + + MockDnsDataGraph& mock_graph = querier->GetMockedGraph(); + std::vector<ErrorOr<DnsSdInstanceEndpoint>> before_changes{}; + before_changes.emplace_back(instance4); + before_changes.emplace_back(instance2); + before_changes.emplace_back(instance3); + std::vector<ErrorOr<DnsSdInstanceEndpoint>> after_changes{}; + after_changes.emplace_back(instance5); + after_changes.emplace_back(instance3); + after_changes.emplace_back(instance1); + + // Calls before and after applying record changes, then the error it logs. + EXPECT_CALL(mock_graph, CreateEndpoints(_, _)) + .WillOnce(Return(ByMove(std::move(before_changes)))) + .WillOnce(Return(ByMove(std::move(after_changes)))); + EXPECT_CALL(callback, OnEndpointCreated(instance1)); + EXPECT_CALL(callback, OnEndpointUpdated(instance5)); + EXPECT_CALL(callback, OnEndpointDeleted(instance2)); + + // Call to apply record changes. The specifics are unimportant. + EXPECT_CALL(mock_graph, ApplyDataRecordChange(_, _, _, _)) + .WillOnce(Return(Error::None())); + + // Call with any record. The mocks make the specifics unimportant. + querier->OnRecordChanged(GetFakePtrRecord(name), + RecordChangedEvent::kCreated); +} + +TEST_F(DnsSdQuerierImplTest, MixOfErrorsAndSuccessesHandledCorrectly) { + DnsSdInstanceEndpoint instance1("instance1", "_service._udp", "local", {}, + kNetworkInterface, {{{192, 168, 2, 24}, 80}}); + DnsSdInstanceEndpoint instance2("instance2", "_service2._udp", "local", {}, + kNetworkInterface, {{{192, 168, 17, 2}, 80}}); + DnsSdInstanceEndpoint instance3("instance3", "_service._udp", "local", {}, + kNetworkInterface, {{{127, 0, 0, 1}, 80}}); + DnsSdInstanceEndpoint instance4("instance1", "_service3._udp", "local", {}, + kNetworkInterface, {{{127, 0, 0, 1}, 80}}); + DnsSdInstanceEndpoint instance5("instance1", "_service3._udp", "local", {}, + kNetworkInterface, + {{{127, 0, 0, 1}, 80}, {{127, 0, 0, 2}, 80}}); + + MockDnsDataGraph& mock_graph = querier->GetMockedGraph(); + std::vector<ErrorOr<DnsSdInstanceEndpoint>> before_changes{}; + before_changes.emplace_back(Error::Code::kIndexOutOfBounds); + before_changes.emplace_back(instance2); + before_changes.emplace_back(Error::Code::kItemAlreadyExists); + before_changes.emplace_back(Error::Code::kOperationCancelled); + before_changes.emplace_back(instance1); + before_changes.emplace_back(Error::Code::kItemNotFound); + before_changes.emplace_back(Error::Code::kOperationInProgress); + before_changes.emplace_back(instance4); + std::vector<ErrorOr<DnsSdInstanceEndpoint>> after_changes{}; + after_changes.emplace_back(instance1); + after_changes.emplace_back(Error::Code::kOperationInProgress); + after_changes.emplace_back(Error::Code::kUnknownError); + after_changes.emplace_back(Error::Code::kItemNotFound); + after_changes.emplace_back(Error::Code::kItemAlreadyExists); + after_changes.emplace_back(instance3); + after_changes.emplace_back(instance5); + after_changes.emplace_back(Error::Code::kOperationCancelled); + + // Calls before and after applying record changes, then the error it logs. + EXPECT_CALL(mock_graph, CreateEndpoints(_, _)) + .WillOnce(Return(ByMove(std::move(before_changes)))) + .WillOnce(Return(ByMove(std::move(after_changes)))); + EXPECT_CALL(querier->reporting_client(), OnRecoverableError(_)).Times(1); + EXPECT_CALL(callback, OnEndpointCreated(instance3)); + EXPECT_CALL(callback, OnEndpointUpdated(instance5)); + EXPECT_CALL(callback, OnEndpointDeleted(instance2)); + + // Call to apply record changes. The specifics are unimportant. + EXPECT_CALL(mock_graph, ApplyDataRecordChange(_, _, _, _)) + .WillOnce(Return(Error::None())); + + // Call with any record. The mocks make the specifics unimportant. + querier->OnRecordChanged(GetFakePtrRecord(name), + RecordChangedEvent::kCreated); } } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_instance.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_instance.cc index 97ce80654b2..991f0ffe241 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_instance.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_instance.cc @@ -36,8 +36,8 @@ ServiceInstance::ServiceInstance(TaskRunner* task_runner, !network_config_.HasAddressV6()); if (config.enable_querying) { - querier_ = std::make_unique<QuerierImpl>(mdns_service_.get(), task_runner_, - &network_config_); + querier_ = std::make_unique<QuerierImpl>( + mdns_service_.get(), task_runner_, reporting_client, &network_config_); } if (config.enable_publication) { publisher_ = std::make_unique<PublisherImpl>( diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.cc b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.cc index e5af9d3ba89..20b0986f0ad 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.cc @@ -33,8 +33,8 @@ ServiceKey::ServiceKey(const DomainName& domain) { ServiceKey::ServiceKey(absl::string_view service, absl::string_view domain) : service_id_(service.data(), service.size()), domain_id_(domain.data(), domain.size()) { - OSP_DCHECK(IsServiceValid(service_id_)); - OSP_DCHECK(IsDomainValid(domain_id_)); + OSP_DCHECK(IsServiceValid(service_id_)) << "invalid service id: " << service; + OSP_DCHECK(IsDomainValid(domain_id_)) << "invalid domain id: " << domain; } ServiceKey::ServiceKey(const ServiceKey& other) = default; @@ -43,6 +43,12 @@ ServiceKey::ServiceKey(ServiceKey&& other) = default; ServiceKey& ServiceKey::operator=(const ServiceKey& rhs) = default; ServiceKey& ServiceKey::operator=(ServiceKey&& rhs) = default; +DomainName ServiceKey::GetName() const { + std::string service_type = service_id().substr(0, service_id().size() - 5); + std::string protocol = service_id().substr(service_id().size() - 4); + return DomainName{std::move(service_type), std::move(protocol), domain_id_}; +} + // static ErrorOr<ServiceKey> ServiceKey::TryCreate(const MdnsRecord& record) { return TryCreate(GetDomainName(record)); diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.h b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.h index 4e6ff7fae8d..33ff7aea6ec 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/impl/service_key.h @@ -24,6 +24,7 @@ class ServiceKey { // NOTE: The record provided must have valid service domain labels. explicit ServiceKey(const MdnsRecord& record); explicit ServiceKey(const DomainName& domain); + virtual ~ServiceKey() = default; // NOTE: The provided service and domain labels must be valid. ServiceKey(absl::string_view service, absl::string_view domain); @@ -33,6 +34,8 @@ class ServiceKey { ServiceKey& operator=(const ServiceKey& rhs); ServiceKey& operator=(ServiceKey&& rhs); + virtual DomainName GetName() const; + const std::string& service_id() const { return service_id_; } const std::string& domain_id() const { return domain_id_; } diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance.cc b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance.cc index 5ac540c284f..2c1382b2fb2 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance.cc @@ -4,7 +4,10 @@ #include "discovery/dnssd/public/dns_sd_instance.h" +#include <algorithm> #include <cctype> +#include <utility> +#include <vector> #include "util/osp_logging.h" @@ -12,6 +15,9 @@ namespace openscreen { namespace discovery { namespace { +// Maximum number of octets allowed in a single domain name label. +constexpr size_t kMaxLabelLength = 63; + bool IsValidUtf8(const std::string& string) { for (size_t i = 0; i < string.size(); i++) { if (string[i] >> 5 == 0x06) { // 110xxxxx 10xxxxxx @@ -51,19 +57,37 @@ DnsSdInstance::DnsSdInstance(std::string instance_id, std::string service_id, std::string domain_id, DnsSdTxtRecord txt, - uint16_t port) + uint16_t port, + std::vector<Subtype> subtypes) : instance_id_(std::move(instance_id)), service_id_(std::move(service_id)), domain_id_(std::move(domain_id)), txt_(std::move(txt)), - port_(port) { - OSP_DCHECK(IsInstanceValid(instance_id_)); - OSP_DCHECK(IsServiceValid(service_id_)); - OSP_DCHECK(IsDomainValid(domain_id_)); + port_(port), + subtypes_(std::move(subtypes)) { + OSP_DCHECK(IsInstanceValid(instance_id_)) + << instance_id_ << " is an invalid instance id"; + OSP_DCHECK(IsServiceValid(service_id_)) + << service_id_ << " is an invalid service id"; + OSP_DCHECK(IsDomainValid(domain_id_)) + << domain_id_ << " is an invalid domain"; + for (const Subtype& subtype : subtypes_) { + OSP_DCHECK(IsSubtypeValid(subtype)) << subtype << " is an invalid subtype"; + } + + std::sort(subtypes_.begin(), subtypes_.end()); } +DnsSdInstance::DnsSdInstance(const DnsSdInstance& other) = default; + +DnsSdInstance::DnsSdInstance(DnsSdInstance&& other) = default; + DnsSdInstance::~DnsSdInstance() = default; +DnsSdInstance& DnsSdInstance::operator=(const DnsSdInstance& rhs) = default; + +DnsSdInstance& DnsSdInstance::operator=(DnsSdInstance&& rhs) = default; + // static bool IsInstanceValid(const std::string& instance) { // According to RFC6763, Instance names must: @@ -71,8 +95,8 @@ bool IsInstanceValid(const std::string& instance) { // - NOT contain ASCII control characters // - Be no longer than 63 octets. - return instance.size() <= 63 && !HasControlCharacters(instance) && - IsValidUtf8(instance); + return instance.size() <= kMaxLabelLength && + !HasControlCharacters(instance) && IsValidUtf8(instance); } // static @@ -134,7 +158,7 @@ bool IsDomainValid(const std::string& domain) { size_t label_start = 0; for (size_t next_dot = domain.find('.'); next_dot != std::string::npos; next_dot = domain.find('.', label_start)) { - if (next_dot - label_start > 63) { + if (next_dot - label_start > kMaxLabelLength) { return false; } label_start = next_dot + 1; @@ -143,6 +167,20 @@ bool IsDomainValid(const std::string& domain) { return !HasControlCharacters(domain) && IsValidUtf8(domain); } +// static +bool IsSubtypeValid(const DnsSdInstance::Subtype& subtype) { + // As specified in RFC6763 section 9.1, all subtypes may be arbitrary bit + // data. Despite this, this implementation has chosen to limit valid subtypes + // to only UTF8 character strings. Therefore, the subtype must: + // - Be encoded in Net-Unicode (which required UTF-8 formatting). + // - NOT contain ASCII control characters + // - Be no longer than 63 octets. + // - Be of length one label. + return subtype.size() <= kMaxLabelLength && + subtype.find('.') == std::string::npos && + !HasControlCharacters(subtype) && IsValidUtf8(subtype); +} + bool operator<(const DnsSdInstance& lhs, const DnsSdInstance& rhs) { if (lhs.port_ != rhs.port_) { return lhs.port_ < rhs.port_; @@ -163,6 +201,17 @@ bool operator<(const DnsSdInstance& lhs, const DnsSdInstance& rhs) { return comp < 0; } + if (lhs.subtypes_.size() != rhs.subtypes_.size()) { + return lhs.subtypes_.size() < rhs.subtypes_.size(); + } + + for (size_t i = 0; i < lhs.subtypes_.size(); i++) { + comp = lhs.subtypes_[i].compare(rhs.subtypes_[i]); + if (comp != 0) { + return comp < 0; + } + } + return lhs.txt_ < rhs.txt_; } diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance.h b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance.h index 3acfcdd825e..5e89c059b4a 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance.h @@ -6,9 +6,12 @@ #define DISCOVERY_DNSSD_PUBLIC_DNS_SD_INSTANCE_H_ #include <string> +#include <utility> +#include <vector> #include "discovery/dnssd/public/dns_sd_txt_record.h" #include "platform/base/ip_address.h" +#include "util/std_util.h" namespace openscreen { namespace discovery { @@ -20,14 +23,38 @@ bool IsDomainValid(const std::string& domain); // Represents the data stored in DNS records of types SRV, TXT, A, and AAAA class DnsSdInstance { public: + using Subtype = std::string; + // These ctors expect valid input, and will cause a crash if they are not. DnsSdInstance(std::string instance_id, std::string service_id, std::string domain_id, DnsSdTxtRecord txt, - uint16_t port); + uint16_t port, + std::vector<Subtype> subtypes); + + template <typename... T> + DnsSdInstance(std::string instance_id, + std::string service_id, + std::string domain_id, + DnsSdTxtRecord txt, + uint16_t port, + T... subtypes) + : DnsSdInstance(std::move(instance_id), + std::move(service_id), + std::move(domain_id), + std::move(txt), + port, + std::vector<Subtype>{std::move(subtypes)...}) {} + + DnsSdInstance(const DnsSdInstance& other); + DnsSdInstance(DnsSdInstance&& other); + virtual ~DnsSdInstance(); + DnsSdInstance& operator=(const DnsSdInstance& rhs); + DnsSdInstance& operator=(DnsSdInstance&& rhs); + // Returns the instance name for this DNS-SD record. const std::string& instance_id() const { return instance_id_; } @@ -43,6 +70,9 @@ class DnsSdInstance { // Returns the port associated with this instance record. uint16_t port() const { return port_; } + // The set of subtypes for this instance. + const std::vector<Subtype>& subtypes() { return subtypes_; } + private: std::string instance_id_; std::string service_id_; @@ -50,9 +80,21 @@ class DnsSdInstance { DnsSdTxtRecord txt_; uint16_t port_; + // Subtypes of this instance which have been received so far. + // NOTE: Subtypes are stored in sorted order to simplify comparison. + // NOTE: This vector will always be empty for incoming queries and will not be + // respected for publications. It is only present for future use. + // + // TODO(issuetracker.google.com/158533407): Implement use of this field. + std::vector<Subtype> subtypes_; + friend bool operator<(const DnsSdInstance& lhs, const DnsSdInstance& rhs); }; +bool IsSubtypeValid(const DnsSdInstance::Subtype& subtype); + +bool IsValid(const DnsSdInstance::Subtype& subtype); + bool operator<(const DnsSdInstance& lhs, const DnsSdInstance& rhs); inline bool operator>(const DnsSdInstance& lhs, const DnsSdInstance& rhs) { @@ -60,11 +102,11 @@ inline bool operator>(const DnsSdInstance& lhs, const DnsSdInstance& rhs) { } inline bool operator<=(const DnsSdInstance& lhs, const DnsSdInstance& rhs) { - return !(rhs > lhs); + return !(lhs > rhs); } inline bool operator>=(const DnsSdInstance& lhs, const DnsSdInstance& rhs) { - return !(rhs < lhs); + return !(lhs < rhs); } inline bool operator==(const DnsSdInstance& lhs, const DnsSdInstance& rhs) { diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_endpoint.cc b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_endpoint.cc index 168f80623d4..7cc9ed41787 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_endpoint.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_endpoint.cc @@ -4,7 +4,10 @@ #include "discovery/dnssd/public/dns_sd_instance_endpoint.h" +#include <algorithm> #include <cctype> +#include <utility> +#include <vector> #include "util/osp_logging.h" @@ -16,85 +19,66 @@ DnsSdInstanceEndpoint::DnsSdInstanceEndpoint( std::string service_id, std::string domain_id, DnsSdTxtRecord txt, - IPEndpoint endpoint, - NetworkInterfaceIndex network_interface) - : DnsSdInstance(std::move(instance_id), - std::move(service_id), - std::move(domain_id), - std::move(txt), - endpoint.port), - network_interface_(network_interface) { - OSP_DCHECK(endpoint); - if (endpoint.address.IsV4()) { - address_v4_ = std::move(endpoint.address); - } else if (endpoint.address.IsV6()) { - address_v6_ = std::move(endpoint.address); - } else { - OSP_NOTREACHED(); - } -} - -DnsSdInstanceEndpoint::DnsSdInstanceEndpoint( - DnsSdInstance record, - IPAddress address, - NetworkInterfaceIndex network_interface) - : DnsSdInstance(std::move(record)), network_interface_(network_interface) { - OSP_DCHECK(address); - if (address.IsV4()) { - address_v4_ = std::move(address); - } else if (address.IsV6()) { - address_v6_ = std::move(address); - } else { - OSP_NOTREACHED(); - } -} + NetworkInterfaceIndex network_interface, + std::vector<IPEndpoint> endpoints) + : DnsSdInstanceEndpoint(std::move(instance_id), + std::move(service_id), + std::move(domain_id), + std::move(txt), + network_interface, + std::move(endpoints), + std::vector<Subtype>{}) {} DnsSdInstanceEndpoint::DnsSdInstanceEndpoint( std::string instance_id, std::string service_id, std::string domain_id, DnsSdTxtRecord txt, - IPEndpoint ipv4_endpoint, - IPEndpoint ipv6_endpoint, - NetworkInterfaceIndex network_interface) + NetworkInterfaceIndex network_interface, + std::vector<IPEndpoint> endpoints, + std::vector<Subtype> subtypes) : DnsSdInstance(std::move(instance_id), std::move(service_id), std::move(domain_id), std::move(txt), - ipv4_endpoint.port), - address_v4_(std::move(ipv4_endpoint.address)), - address_v6_(std::move(ipv6_endpoint.address)), + endpoints.empty() ? 0 : endpoints[0].port, + std::move(subtypes)), + endpoints_(std::move(endpoints)), network_interface_(network_interface) { - OSP_CHECK(address_v4_); - OSP_CHECK(address_v6_); - OSP_CHECK(address_v4_.IsV4()); - OSP_CHECK(address_v6_.IsV6()); - OSP_CHECK_EQ(ipv4_endpoint.port, ipv6_endpoint.port); + InitializeEndpoints(); } DnsSdInstanceEndpoint::DnsSdInstanceEndpoint( DnsSdInstance instance, - IPAddress ipv4_address, - IPAddress ipv6_address, - NetworkInterfaceIndex network_interface) + NetworkInterfaceIndex network_interface, + std::vector<IPEndpoint> endpoints) : DnsSdInstance(std::move(instance)), - address_v4_(std::move(ipv4_address)), - address_v6_(std::move(ipv6_address)), + endpoints_(std::move(endpoints)), network_interface_(network_interface) { - OSP_CHECK(address_v4_); - OSP_CHECK(address_v6_); - OSP_CHECK(address_v4_.IsV4()); - OSP_CHECK(address_v6_.IsV6()); + InitializeEndpoints(); } +DnsSdInstanceEndpoint::DnsSdInstanceEndpoint( + const DnsSdInstanceEndpoint& other) = default; + +DnsSdInstanceEndpoint::DnsSdInstanceEndpoint(DnsSdInstanceEndpoint&& other) = + default; + DnsSdInstanceEndpoint::~DnsSdInstanceEndpoint() = default; -IPEndpoint DnsSdInstanceEndpoint::endpoint_v4() const { - return address_v4_ ? IPEndpoint{address_v4_, port()} : IPEndpoint{}; -} +DnsSdInstanceEndpoint& DnsSdInstanceEndpoint::operator=( + const DnsSdInstanceEndpoint& rhs) = default; -IPEndpoint DnsSdInstanceEndpoint::endpoint_v6() const { - return address_v6_ ? IPEndpoint{address_v6_, port()} : IPEndpoint{}; +DnsSdInstanceEndpoint& DnsSdInstanceEndpoint::operator=( + DnsSdInstanceEndpoint&& rhs) = default; + +void DnsSdInstanceEndpoint::InitializeEndpoints() { + OSP_CHECK(!endpoints_.empty()); + std::sort(endpoints_.begin(), endpoints_.end()); + for (const auto& endpoint : endpoints_) { + OSP_DCHECK_EQ(endpoint.port, port()); + addresses_.push_back(endpoint.address); + } } bool operator<(const DnsSdInstanceEndpoint& lhs, @@ -103,12 +87,14 @@ bool operator<(const DnsSdInstanceEndpoint& lhs, return lhs.network_interface_ < rhs.network_interface_; } - if (lhs.address_v4_ != rhs.address_v4_) { - return lhs.address_v4_ < rhs.address_v4_; + if (lhs.endpoints_.size() != rhs.endpoints_.size()) { + return lhs.endpoints_.size() < rhs.endpoints_.size(); } - if (lhs.address_v6_ != rhs.address_v6_) { - return lhs.address_v6_ < rhs.address_v6_; + for (int i = 0; i < static_cast<int>(lhs.endpoints_.size()); i++) { + if (lhs.endpoints_[i] != rhs.endpoints_[i]) { + return lhs.endpoints_[i] < rhs.endpoints_[i]; + } } return static_cast<const DnsSdInstance&>(lhs) < diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_endpoint.h b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_endpoint.h index c0d443a4c0e..f8d0803dda1 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_endpoint.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_endpoint.h @@ -6,6 +6,8 @@ #define DISCOVERY_DNSSD_PUBLIC_DNS_SD_INSTANCE_ENDPOINT_H_ #include <string> +#include <utility> +#include <vector> #include "discovery/dnssd/public/dns_sd_instance.h" #include "discovery/dnssd/public/dns_sd_txt_record.h" @@ -18,46 +20,128 @@ namespace discovery { // Represents the data stored in DNS records of types SRV, TXT, A, and AAAA class DnsSdInstanceEndpoint : public DnsSdInstance { public: + using DnsSdInstance::Subtype; + // These ctors expect valid input, and will cause a crash if they are not. + // Additionally, these ctors expect at least one IPAddress will be provided. + DnsSdInstanceEndpoint(DnsSdInstance record, + NetworkInterfaceIndex network_interface, + std::vector<IPEndpoint> address); + DnsSdInstanceEndpoint(std::string instance_id, std::string service_id, std::string domain_id, DnsSdTxtRecord txt, - IPEndpoint endpoint, - NetworkInterfaceIndex network_interface); - DnsSdInstanceEndpoint(DnsSdInstance record, - IPAddress address, - NetworkInterfaceIndex network_interface); + NetworkInterfaceIndex network_interface, + std::vector<IPEndpoint> endpoint); + DnsSdInstanceEndpoint(std::string instance_id, + std::string service_id, + std::string domain_id, + DnsSdTxtRecord txt, + NetworkInterfaceIndex network_interface, + std::vector<IPEndpoint> endpoint, + std::vector<Subtype> subtypes); - // NOTE: These constructors expects one endpoint to be an IPv4 address and the - // other to be an IPv6 address. + // Overloads of the above ctors to allow for simpler creation. The same + // expectations as above apply. + template <typename... TEndpoints> + DnsSdInstanceEndpoint(DnsSdInstance record, + NetworkInterfaceIndex network_interface, + TEndpoints... endpoints) + : DnsSdInstanceEndpoint( + std::move(record), + network_interface, + std::vector<IPEndpoint>{std::move(endpoints)...}) {} + + // NOTE: All subtypes must follow all IPEndpoints. + template <typename... Types> DnsSdInstanceEndpoint(std::string instance_id, std::string service_id, std::string domain_id, DnsSdTxtRecord txt, - IPEndpoint ipv4_endpoint, - IPEndpoint ipv6_endpoint, - NetworkInterfaceIndex network_interface); - DnsSdInstanceEndpoint(DnsSdInstance instance, - IPAddress address_v4, - IPAddress address_v6, - NetworkInterfaceIndex network_interface); + NetworkInterfaceIndex network_interface, + IPEndpoint endpoint, + Types... types) + : DnsSdInstanceEndpoint( + std::move(instance_id), + std::move(service_id), + std::move(domain_id), + std::move(txt), + network_interface, + GetVectorWithCapacity<IPEndpoint>(sizeof...(Types) + 1), + GetVectorWithCapacity<Subtype>(sizeof...(Types)), + std::move(endpoint), + std::move(types)...) {} + + DnsSdInstanceEndpoint(const DnsSdInstanceEndpoint& other); + DnsSdInstanceEndpoint(DnsSdInstanceEndpoint&& other); ~DnsSdInstanceEndpoint() override; + DnsSdInstanceEndpoint& operator=(const DnsSdInstanceEndpoint& rhs); + DnsSdInstanceEndpoint& operator=(DnsSdInstanceEndpoint&& rhs); + // Returns the address associated with this DNS-SD record. In any valid // record, at least one will be set. - const IPAddress& address_v4() const { return address_v4_; } - const IPAddress& address_v6() const { return address_v6_; } - IPEndpoint endpoint_v4() const; - IPEndpoint endpoint_v6() const; + const std::vector<IPAddress>& addresses() const { return addresses_; } + const std::vector<IPEndpoint>& endpoints() const { return endpoints_; } // Network Interface associated with this endpoint. NetworkInterfaceIndex network_interface() const { return network_interface_; } private: - IPAddress address_v4_; - IPAddress address_v6_; + // Pick off the first IPEndpoint then call again recursively. + template <typename... Types> + DnsSdInstanceEndpoint(std::string instance_id, + std::string service_id, + std::string domain_id, + DnsSdTxtRecord txt, + NetworkInterfaceIndex network_interface, + std::vector<IPEndpoint> endpoints, + std::vector<Subtype> subtypes, + IPEndpoint endpoint, + Types... types) + : DnsSdInstanceEndpoint(std::move(instance_id), + std::move(service_id), + std::move(domain_id), + std::move(txt), + network_interface, + Append(std::move(endpoints), std::move(endpoint)), + std::move(subtypes), + std::move(types)...) {} + + // All following arguments must be Subtypes, so pull them all off and recurse + // to a non-templated ctor. + template <typename... Types> + DnsSdInstanceEndpoint(std::string instance_id, + std::string service_id, + std::string domain_id, + DnsSdTxtRecord txt, + NetworkInterfaceIndex network_interface, + std::vector<IPEndpoint> endpoints, + std::vector<Subtype> subtypes, + Subtype subtype, + Types... types) + : DnsSdInstanceEndpoint(std::move(instance_id), + std::move(service_id), + std::move(domain_id), + std::move(txt), + network_interface, + std::move(endpoints), + Append(std::move(subtypes), + std::move(subtype), + std::move(types)...)) {} + + // Lazy Initializes the |addresses_| vector. + const std::vector<IPAddress>& CalculateAddresses() const; + + // Initialized the |endpoints_| vector after construction. + void InitializeEndpoints(); + + // NOTE: The below vector is stored in sorted order to make comparison + // simpler. + std::vector<IPEndpoint> endpoints_; + std::vector<IPAddress> addresses_; NetworkInterfaceIndex network_interface_; @@ -75,12 +159,12 @@ inline bool operator>(const DnsSdInstanceEndpoint& lhs, inline bool operator<=(const DnsSdInstanceEndpoint& lhs, const DnsSdInstanceEndpoint& rhs) { - return !(rhs > lhs); + return !(lhs > rhs); } inline bool operator>=(const DnsSdInstanceEndpoint& lhs, const DnsSdInstanceEndpoint& rhs) { - return !(rhs < lhs); + return !(lhs < rhs); } inline bool operator==(const DnsSdInstanceEndpoint& lhs, diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_endpoint_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_endpoint_unittest.cc new file mode 100644 index 00000000000..b9ee3ac8b62 --- /dev/null +++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_endpoint_unittest.cc @@ -0,0 +1,77 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "discovery/dnssd/public/dns_sd_instance_endpoint.h" + +#include "discovery/dnssd/public/dns_sd_instance.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace openscreen { +namespace discovery { + +TEST(DnsSdInstanceEndpointTests, ComparisonTests) { + constexpr NetworkInterfaceIndex kIndex0 = 0; + constexpr NetworkInterfaceIndex kIndex1 = 1; + DnsSdInstance instance("instance", "_test._tcp", "local", {}, 80); + DnsSdInstance instance2("instance", "_test._tcp", "local", {}, 79); + IPEndpoint ep1{{192, 168, 80, 32}, 80}; + IPEndpoint ep2{{192, 168, 80, 32}, 79}; + IPEndpoint ep3{{192, 168, 80, 33}, 79}; + DnsSdInstanceEndpoint endpoint1(instance, kIndex1, ep1); + DnsSdInstanceEndpoint endpoint2("instance", "_test._tcp", "local", {}, + kIndex1, ep1); + DnsSdInstanceEndpoint endpoint3(instance2, kIndex1, ep2); + DnsSdInstanceEndpoint endpoint4(instance2, kIndex0, ep2); + DnsSdInstanceEndpoint endpoint5(instance2, kIndex1, ep3); + DnsSdInstanceEndpoint endpoint6("instance", "_test._tcp", "local", {}, + kIndex1, ep1, "foo", "bar"); + DnsSdInstanceEndpoint endpoint7("instance", "_test._tcp", "local", {}, + kIndex1, ep1, "foo", "foobar"); + DnsSdInstanceEndpoint endpoint8("instance", "_test._tcp", "local", {}, + kIndex1, ep1, "foobar"); + + EXPECT_EQ(static_cast<DnsSdInstance>(endpoint1), + static_cast<DnsSdInstance>(endpoint2)); + EXPECT_EQ(endpoint1, endpoint2); + EXPECT_GE(endpoint1, endpoint3); + EXPECT_GE(endpoint1, endpoint4); + EXPECT_LE(endpoint1, endpoint5); + EXPECT_LE(endpoint1, endpoint6); + EXPECT_LE(endpoint1, endpoint7); + EXPECT_LE(endpoint1, endpoint8); + + EXPECT_GE(endpoint3, endpoint4); + EXPECT_LE(endpoint3, endpoint5); + + EXPECT_LE(endpoint4, endpoint5); + + EXPECT_LE(endpoint6, endpoint7); + EXPECT_GE(endpoint6, endpoint8); + EXPECT_GE(endpoint7, endpoint8); +} + +TEST(DnsSdInstanceEndpointTests, Constructors) { + constexpr NetworkInterfaceIndex kIndex = 0; + std::vector<std::string> subtypes{"foo", "bar", "foobar"}; + IPEndpoint endpoint1{{192, 168, 12, 21}, 80}; + IPEndpoint endpoint2{{227, 0, 0, 1}, 80}; + DnsSdInstance instance("instance", "_test._tcp", "local", {}, 80, subtypes); + + DnsSdInstanceEndpoint ep1(instance, kIndex, endpoint1, endpoint2); + DnsSdInstanceEndpoint ep2(instance, kIndex, + std::vector<IPEndpoint>{endpoint1, endpoint2}); + DnsSdInstanceEndpoint ep3("instance", "_test._tcp", "local", {}, kIndex, + endpoint1, endpoint2, "foo", "bar", "foobar"); + DnsSdInstanceEndpoint ep4("instance", "_test._tcp", "local", {}, kIndex, + std::vector<IPEndpoint>{endpoint1, endpoint2}, + subtypes); + + EXPECT_EQ(ep1, ep2); + EXPECT_EQ(ep1, ep3); + EXPECT_EQ(ep1, ep4); +} + +} // namespace discovery +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_unittest.cc b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_unittest.cc index f59052ba7d5..3479204123d 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/dnssd/public/dns_sd_instance_unittest.cc @@ -4,6 +4,8 @@ #include "discovery/dnssd/public/dns_sd_instance.h" +#include <vector> + #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -164,5 +166,86 @@ TEST(DnsSdInstanceTests, DomainUTF8) { } } +TEST(DnsSdInstanceTests, SubtypeCharacters) { + EXPECT_TRUE(IsSubtypeValid("IncludingSpecialCharacters+ =*&<<+`~\\/")); + EXPECT_TRUE(IsSubtypeValid("+ =*&<<+`~\\/ ")); + + EXPECT_FALSE(IsSubtypeValid("foo.bar")); + EXPECT_FALSE(IsSubtypeValid(std::string(1, uint8_t{0x7F}))); + EXPECT_FALSE(IsSubtypeValid(std::string("name with ") + + std::string(1, uint8_t{0x7F}) + + " in the middle")); + for (uint8_t bad_char = 0x0; bad_char <= 0x1F; bad_char++) { + EXPECT_FALSE(IsSubtypeValid(std::string(1, bad_char))); + EXPECT_FALSE(IsSubtypeValid(std::string("name with ") + + std::string(1, bad_char) + " in the middle")); + } +} + +TEST(DnsSdInstanceTests, SubtypeUTF8) { + // Sets of bytes which do not form valid UTF8 encoded chars. + std::vector<uint8_t> char_sets[] = { + {0x80}, + {0xC0}, + {0xC0, 0xFF}, + {0xE0}, + {0xE0, 0xFF}, + {0xE0, 0x80, 0x00}, + {0xF0}, + {0xF0, 0x00}, + {0xF0, 0x80, 0xFF}, + {0xF0, 0x80, 0x80, 0x0A}, + }; + + for (const auto& set : char_sets) { + std::string test_string = "start"; + for (uint8_t ch : set) { + test_string.append(std::string(1, ch)); + } + + EXPECT_FALSE(IsSubtypeValid(test_string)); + } +} + +TEST(DnsSdInstanceTests, SubtypeLength) { + std::string kCharsAlmostMaxLength = + "123456989012345678901234567890123456789012345678901234567890123"; + + ASSERT_EQ(kCharsAlmostMaxLength.size(), size_t{63}); + EXPECT_TRUE(IsSubtypeValid(kCharsAlmostMaxLength)); + EXPECT_FALSE(IsSubtypeValid(kCharsAlmostMaxLength + "4")); +} + +TEST(DnsSdInstanceTests, ComparisonTests) { + DnsSdTxtRecord set_record; + set_record.SetValue("foo", "bar"); + + DnsSdInstance kIn1("instance", "_service._tcp", "local", {}, 80); + DnsSdInstance kIn2("instance", "_service._tcp", "local", {}, 80); + DnsSdInstance kIn3("instance2", "_service._tcp", "local", {}, 80); + DnsSdInstance kIn4("instance", "_service2._tcp", "local", {}, 80); + DnsSdInstance kIn5("instance", "_service._tcp", "local2", {}, 80); + DnsSdInstance kIn6("instance", "_service._tcp", "local", set_record, 80); + DnsSdInstance kIn7("instance", "_service._tcp", "local", {}, 79); + DnsSdInstance kIn8("instance", "_service._tcp", "local", {}, 80, "foo"); + DnsSdInstance kIn9("instance", "_service._tcp", "local", {}, 80, "foobar"); + DnsSdInstance kIn10("instance", "_service._tcp", "local", {}, 80, "foo", + "bar"); + + EXPECT_EQ(kIn1, kIn2); + EXPECT_LT(kIn1, kIn3); + EXPECT_LT(kIn1, kIn4); + EXPECT_LT(kIn1, kIn5); + EXPECT_LT(kIn1, kIn6); + EXPECT_GT(kIn1, kIn7); + EXPECT_LT(kIn1, kIn8); + EXPECT_LT(kIn1, kIn9); + EXPECT_LT(kIn1, kIn10); + + EXPECT_LT(kIn8, kIn9); + EXPECT_LT(kIn8, kIn10); + EXPECT_LT(kIn9, kIn10); +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.h b/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.h index 473ec68ebc1..8c9a80fb57d 100644 --- a/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.h +++ b/chromium/third_party/openscreen/src/discovery/dnssd/testing/fake_dns_record_factory.h @@ -7,7 +7,7 @@ #include <stdint.h> -#include <chrono> // NOLINT +#include <chrono> #include "discovery/dnssd/impl/constants.h" #include "discovery/mdns/mdns_records.h" diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.cc index dc911043b0c..5e254aae36b 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe.cc @@ -93,8 +93,7 @@ void MdnsProbeImpl::Postpone(std::chrono::seconds delay) { successful_probe_queries_ = 0; alarm_.Cancel(); - alarm_.ScheduleFromNow([this]() { ProbeOnce(); }, - std::chrono::duration_cast<Clock::duration>(delay)); + alarm_.ScheduleFromNow([this]() { ProbeOnce(); }, Clock::to_duration(delay)); } void MdnsProbeImpl::OnMessageReceived(const MdnsMessage& message) { diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.h index 6e8584ff7e9..a5cbbfa2798 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_probe_manager.h @@ -6,6 +6,7 @@ #define DISCOVERY_MDNS_MDNS_PROBE_MANAGER_H_ #include <memory> +#include <utility> #include <vector> #include "discovery/mdns/mdns_domain_confirmed_provider.h" @@ -62,8 +63,8 @@ class MdnsProbeManagerImpl : public MdnsProbe::Observer, MdnsRandom* random_delay, TaskRunner* task_runner, ClockNowFunctionPtr now_function); - MdnsProbeManagerImpl(const MdnsProbeManager& other) = delete; - MdnsProbeManagerImpl(MdnsProbeManager&& other) = delete; + MdnsProbeManagerImpl(const MdnsProbeManager& other) = delete; // NOLINT + MdnsProbeManagerImpl(MdnsProbeManager&& other) = delete; // NOLINT ~MdnsProbeManagerImpl() override; MdnsProbeManagerImpl& operator=(const MdnsProbeManagerImpl& other) = delete; diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.cc index 51aa30e7dfe..4c09d169747 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.cc @@ -42,11 +42,6 @@ inline MdnsRecord CreateGoodbyeRecord(const MdnsRecord& record) { record.record_type(), kGoodbyeTtl, record.rdata()); } -inline void ValidateRecord(const MdnsRecord& record) { - OSP_DCHECK(record.dns_type() != DnsType::kANY); - OSP_DCHECK(record.dns_class() != DnsClass::kANY); -} - } // namespace MdnsPublisher::MdnsPublisher(MdnsSender* sender, @@ -74,11 +69,11 @@ MdnsPublisher::~MdnsPublisher() { Error MdnsPublisher::RegisterRecord(const MdnsRecord& record) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + OSP_DCHECK(record.dns_class() != DnsClass::kANY); - if (record.dns_type() == DnsType::kNSEC) { + if (!CanBePublished(record.dns_type())) { return Error::Code::kParameterInvalid; } - ValidateRecord(record); if (!IsRecordNameClaimed(record)) { return Error::Code::kParameterInvalid; @@ -101,11 +96,11 @@ Error MdnsPublisher::RegisterRecord(const MdnsRecord& record) { Error MdnsPublisher::UnregisterRecord(const MdnsRecord& record) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + OSP_DCHECK(record.dns_class() != DnsClass::kANY); - if (record.dns_type() == DnsType::kNSEC) { + if (!CanBePublished(record.dns_type())) { return Error::Code::kParameterInvalid; } - ValidateRecord(record); OSP_DVLOG << "Unregistering record of type '" << record.dns_type() << "'"; @@ -116,7 +111,7 @@ Error MdnsPublisher::UpdateRegisteredRecord(const MdnsRecord& old_record, const MdnsRecord& new_record) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); - if (old_record.dns_type() == DnsType::kNSEC) { + if (!CanBePublished(new_record.dns_type())) { return Error::Code::kParameterInvalid; } @@ -371,8 +366,8 @@ void MdnsPublisher::ProcessRecordQueue() { } Clock::duration MdnsPublisher::RecordAnnouncer::GetNextAnnounceDelay() { - return std::chrono::duration_cast<Clock::duration>( - kMinAnnounceDelay * pow(kIntervalIncreaseFactor, attempts_)); + return Clock::to_duration(kMinAnnounceDelay * + pow(kIntervalIncreaseFactor, attempts_)); } } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.h index 4b418312104..b9092697756 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher.h @@ -58,7 +58,13 @@ class MdnsPublisher : public MdnsResponder::RecordHandler { // ClaimExclusiveOwnership() method and for PTR records the name being pointed // to must have been claimed in the same fashion, but the domain name in the // top-level MdnsRecord entity does not. - // NOTE: NSEC records cannot be registered, and doing so will return an error. + // NOTE: This call is only valid for |dns_type| values: + // - DnsType::kA + // - DnsType::kPTR + // - DnsType::kTXT + // - DnsType::kAAAA + // - DnsType::kSRV + // - DnsType::kANY Error RegisterRecord(const MdnsRecord& record); // Updates the existing record with name matching the name of the new record. diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher_unittest.cc index c05019f7468..a739a36dde1 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_publisher_unittest.cc @@ -4,7 +4,7 @@ #include "discovery/mdns/mdns_publisher.h" -#include <chrono> // NOLINT +#include <chrono> #include <vector> #include "discovery/common/config.h" @@ -86,8 +86,7 @@ class MdnsPublisherTest : public testing::Test { ~MdnsPublisherTest() { // Clear out any remaining calls in the task runner queue. - clock_.Advance( - std::chrono::duration_cast<Clock::duration>(std::chrono::seconds(1))); + clock_.Advance(Clock::to_duration(std::chrono::seconds(1))); } protected: @@ -360,7 +359,7 @@ TEST_F(MdnsPublisherTest, RegistrationAnnouncesEightTimes) { EXPECT_CALL(probe_manager_, IsDomainClaimed(domain_)) .WillRepeatedly(Return(true)); constexpr Clock::duration kOneSecond = - std::chrono::duration_cast<Clock::duration>(std::chrono::seconds(1)); + Clock::to_duration(std::chrono::seconds(1)); // First announce, at registration. const MdnsRecord record = GetFakeARecord(domain_); diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.cc index ea5cef6e2be..2ae7260e1e1 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.cc @@ -4,6 +4,8 @@ #include "discovery/mdns/mdns_querier.h" +#include <memory> +#include <utility> #include <vector> #include "discovery/common/config.h" @@ -240,7 +242,7 @@ void MdnsQuerier::StartQuery(const DomainName& name, MdnsRecordChangedCallback* callback) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DCHECK(callback); - OSP_DCHECK(dns_type != DnsType::kNSEC); + OSP_DCHECK(CanBeQueried(dns_type)); // Add a new callback if haven't seen it before auto callbacks_it = callbacks_.equal_range(name); @@ -298,7 +300,10 @@ void MdnsQuerier::StopQuery(const DomainName& name, MdnsRecordChangedCallback* callback) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DCHECK(callback); - OSP_DCHECK(dns_type != DnsType::kNSEC); + + if (!CanBeQueried(dns_type)) { + return; + } // Find and remove the callback. int callbacks_for_key = 0; @@ -370,9 +375,7 @@ void MdnsQuerier::OnMessageReceived(const MdnsMessage& message) { for (const MdnsRecord& record : message.answers()) { if (ShouldAnswerRecordBeProcessed(record)) { ProcessRecord(record); - OSP_DVLOG << "\tProcessing answer record for domain '" - << record.name().ToString() << "' of type '" - << record.dns_type() << "'..."; + OSP_DVLOG << "\tProcessing answer record (" << record.ToString() << ")"; found_relevant_records = true; processed_count++; } @@ -383,9 +386,8 @@ void MdnsQuerier::OnMessageReceived(const MdnsMessage& message) { // individual records relevant to this querier to update the cache. for (const MdnsRecord& record : message.additional_records()) { if (found_relevant_records || ShouldAnswerRecordBeProcessed(record)) { - OSP_DVLOG << "\tProcessing additional record for domain '" - << record.name().ToString() << "' of type '" - << record.dns_type() << "'..."; + OSP_DVLOG << "\tProcessing additional record (" << record.ToString() + << ")"; ProcessRecord(record); processed_count++; } @@ -452,6 +454,11 @@ void MdnsQuerier::OnRecordExpired(const MdnsRecordTracker* tracker, void MdnsQuerier::ProcessRecord(const MdnsRecord& record) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); + // Skip all records that can't be processed. + if (!CanBeProcessed(record.dns_type())) { + return; + } + // Get the types which the received record is associated with. In most cases // this will only be the type of the provided record, but in the case of // NSEC records this will be all records which the record dictates the @@ -528,16 +535,12 @@ void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record, if (will_exist) { ProcessCallbacks(record, RecordChangedEvent::kCreated); } - } - - // There is exactly one tracker associated with this key. This is the expected - // case when a record matching this one has already been seen. - else if (num_records_for_key == size_t{1}) { + } else if (num_records_for_key == size_t{1}) { + // There is exactly one tracker associated with this key. This is the + // expected case when a record matching this one has already been seen. ProcessSinglyTrackedUniqueRecord(record, trackers[0]); - } - - // Multiple records with the same key. - else { + } else { + // Multiple records with the same key. ProcessMultiTrackedUniqueRecord(record, dns_type); } } diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.h index 8f17790bd9d..07f1cbf7489 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_querier.h @@ -7,6 +7,8 @@ #include <list> #include <map> +#include <memory> +#include <vector> #include "discovery/common/config.h" #include "discovery/mdns/mdns_receiver.h" @@ -42,7 +44,13 @@ class MdnsQuerier : public MdnsReceiver::ResponseClient { // Starts an mDNS query with the given name, DNS type, and DNS class. Updated // records are passed to |callback|. The caller must ensure |callback| // remains alive while it is registered with a query. - // NOTE: NSEC records cannot be queried for. + // NOTE: This call is only valid for |dns_type| values: + // - DnsType::kA + // - DnsType::kPTR + // - DnsType::kTXT + // - DnsType::kAAAA + // - DnsType::kSRV + // - DnsType::kANY void StartQuery(const DomainName& name, DnsType dns_type, DnsClass dns_class, @@ -164,6 +172,8 @@ class MdnsQuerier : public MdnsReceiver::ResponseClient { bool ShouldAnswerRecordBeProcessed(const MdnsRecord& answer); // Processes any record update, calling into the below methods as needed. + // NOTE: All records of type OPT are dropped, as they should not be cached per + // RFC6891. void ProcessRecord(const MdnsRecord& records); // Processes a shared record update as a record of type |type|. diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.cc index 409de085980..c6aa926b6be 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_reader.cc @@ -56,7 +56,7 @@ bool MdnsReader::Read(TxtRecordRdata::Entry* out) { } // RFC 1035: https://www.ietf.org/rfc/rfc1035.txt -// See section 4.1.4. Message compression +// See section 4.1.4. Message compression. bool MdnsReader::Read(DomainName* out) { OSP_DCHECK(out); const uint8_t* position = current(); diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.cc index b21d748f43f..bb8634d2198 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_receiver.cc @@ -4,6 +4,8 @@ #include "discovery/mdns/mdns_receiver.h" +#include <utility> + #include "discovery/mdns/mdns_reader.h" #include "util/trace_logging.h" @@ -66,6 +68,7 @@ void MdnsReceiver::OnRead(UdpSocket* socket, MdnsReader reader(config_, packet.data(), packet.size()); MdnsMessage message; if (!reader.Read(&message)) { + OSP_DVLOG << "mDNS message failed to parse..."; return; } @@ -74,13 +77,14 @@ void MdnsReceiver::OnRead(UdpSocket* socket, client->OnMessageReceived(message); } if (response_clients_.empty()) { - OSP_DVLOG << "Response message dropped. No response client registered..."; + OSP_DVLOG + << "mDNS response message dropped. No response client registered..."; } } else { if (query_callback_) { query_callback_(message, packet.source()); } else { - OSP_DVLOG << "Query message dropped. No query client registered..."; + OSP_DVLOG << "mDNS query message dropped. No query client registered..."; } } } diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.cc index a04c2694165..eadbff9c31c 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.cc @@ -4,7 +4,11 @@ #include "discovery/mdns/mdns_records.h" +#include <algorithm> #include <cctype> +#include <limits> +#include <sstream> +#include <vector> #include "absl/strings/ascii.h" #include "absl/strings/match.h" @@ -455,6 +459,83 @@ size_t NsecRecordRdata::MaxWireSize() const { return next_domain_name_.MaxWireSize() + encoded_types_.size(); } +size_t OptRecordRdata::Option::MaxWireSize() const { + // One uint16_t for each of OPTION-LENGTH and OPTION-CODE as defined in RFC + // 6891 section 6.1.2. + constexpr size_t kOptionLengthAndCodeSize = 2 * sizeof(uint16_t); + return data.size() + kOptionLengthAndCodeSize; +} + +bool OptRecordRdata::Option::operator>( + const OptRecordRdata::Option& rhs) const { + if (code != rhs.code) { + return code > rhs.code; + } else if (length != rhs.length) { + return length > rhs.length; + } else if (data.size() != rhs.data.size()) { + return data.size() > rhs.data.size(); + } + + for (int i = 0; i < static_cast<int>(data.size()); i++) { + if (data[i] != rhs.data[i]) { + return data[i] > rhs.data[i]; + } + } + + return false; +} + +bool OptRecordRdata::Option::operator<( + const OptRecordRdata::Option& rhs) const { + return rhs > *this; +} + +bool OptRecordRdata::Option::operator>=( + const OptRecordRdata::Option& rhs) const { + return !(*this < rhs); +} + +bool OptRecordRdata::Option::operator<=( + const OptRecordRdata::Option& rhs) const { + return !(*this > rhs); +} + +bool OptRecordRdata::Option::operator==( + const OptRecordRdata::Option& rhs) const { + return *this >= rhs && *this <= rhs; +} + +bool OptRecordRdata::Option::operator!=( + const OptRecordRdata::Option& rhs) const { + return !(*this == rhs); +} + +OptRecordRdata::OptRecordRdata() = default; + +OptRecordRdata::OptRecordRdata(std::vector<Option> options) + : options_(std::move(options)) { + for (const auto& option : options_) { + max_wire_size_ += option.MaxWireSize(); + } + std::sort(options_.begin(), options_.end()); +} + +OptRecordRdata::OptRecordRdata(const OptRecordRdata& other) = default; + +OptRecordRdata::OptRecordRdata(OptRecordRdata&& other) = default; + +OptRecordRdata& OptRecordRdata::operator=(const OptRecordRdata& rhs) = default; + +OptRecordRdata& OptRecordRdata::operator=(OptRecordRdata&& rhs) = default; + +bool OptRecordRdata::operator==(const OptRecordRdata& rhs) const { + return options_ == rhs.options_; +} + +bool OptRecordRdata::operator!=(const OptRecordRdata& rhs) const { + return !(*this == rhs); +} + // static ErrorOr<MdnsRecord> MdnsRecord::TryCreate(DomainName name, DnsType dns_type, @@ -500,7 +581,12 @@ bool MdnsRecord::IsValidConfig(const DomainName& name, DnsType dns_type, std::chrono::seconds ttl, const Rdata& rdata) { - return !name.empty() && ttl.count() <= std::numeric_limits<uint32_t>::max() && + // NOTE: Although the name_ field was initially expected to be non-empty, this + // validation is no longer accurate for some record types (such as OPT + // records). To ensure that future record types correctly parse into + // RawRecordData types and do not invalidate the received message, this check + // has been removed. + return ttl.count() <= std::numeric_limits<uint32_t>::max() && ((dns_type == DnsType::kSRV && absl::holds_alternative<SrvRecordRdata>(rdata)) || (dns_type == DnsType::kA && @@ -513,6 +599,8 @@ bool MdnsRecord::IsValidConfig(const DomainName& name, absl::holds_alternative<TxtRecordRdata>(rdata)) || (dns_type == DnsType::kNSEC && absl::holds_alternative<NsecRecordRdata>(rdata)) || + (dns_type == DnsType::kOPT && + absl::holds_alternative<OptRecordRdata>(rdata)) || absl::holds_alternative<RawRecordRdata>(rdata)); } @@ -562,6 +650,34 @@ size_t MdnsRecord::MaxWireSize() const { return name_.MaxWireSize() + absl::visit(wire_size_visitor, rdata_) + 8; } +std::string MdnsRecord::ToString() const { + std::stringstream ss; + ss << "name: '" << name_.ToString() << "'"; + ss << ", type: " << dns_type_; + + if (dns_type_ == DnsType::kPTR) { + const DomainName& target = absl::get<PtrRecordRdata>(rdata_).ptr_domain(); + ss << ", target: '" << target.ToString() << "'"; + } else if (dns_type_ == DnsType::kSRV) { + const DomainName& target = absl::get<SrvRecordRdata>(rdata_).target(); + ss << ", target: '" << target.ToString() << "'"; + } else if (dns_type_ == DnsType::kNSEC) { + const auto& nsec_rdata = absl::get<NsecRecordRdata>(rdata_); + std::vector<DnsType> types = nsec_rdata.types(); + ss << ", representing ["; + if (!types.empty()) { + auto it = types.begin(); + ss << *it++; + while (it != types.end()) { + ss << ", " << *it++; + } + ss << "]"; + } + } + + return ss.str(); +} + MdnsRecord CreateAddressRecord(DomainName name, const IPAddress& address) { Rdata rdata; DnsType type; @@ -741,5 +857,69 @@ uint16_t CreateMessageId() { return id++; } +bool CanBePublished(DnsType type) { + // NOTE: A 'default' switch statement has intentionally been avoided below to + // enforce that new DnsTypes added must be added below through a compile-time + // check. + switch (type) { + case DnsType::kA: + case DnsType::kAAAA: + case DnsType::kPTR: + case DnsType::kTXT: + case DnsType::kSRV: + return true; + case DnsType::kOPT: + case DnsType::kNSEC: + case DnsType::kANY: + break; + } + + return false; +} + +bool CanBePublished(const MdnsRecord& record) { + return CanBePublished(record.dns_type()); +} + +bool CanBeQueried(DnsType type) { + // NOTE: A 'default' switch statement has intentionally been avoided below to + // enforce that new DnsTypes added must be added below through a compile-time + // check. + switch (type) { + case DnsType::kA: + case DnsType::kAAAA: + case DnsType::kPTR: + case DnsType::kTXT: + case DnsType::kSRV: + case DnsType::kANY: + return true; + case DnsType::kOPT: + case DnsType::kNSEC: + break; + } + + return false; +} + +bool CanBeProcessed(DnsType type) { + // NOTE: A 'default' switch statement has intentionally been avoided below to + // enforce that new DnsTypes added must be added below through a compile-time + // check. + switch (type) { + case DnsType::kA: + case DnsType::kAAAA: + case DnsType::kPTR: + case DnsType::kTXT: + case DnsType::kSRV: + case DnsType::kNSEC: + return true; + case DnsType::kOPT: + case DnsType::kANY: + break; + } + + return false; +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.h index 0a1f76ab13c..9212b519b85 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records.h @@ -6,13 +6,14 @@ #define DISCOVERY_MDNS_MDNS_RECORDS_H_ #include <algorithm> -#include <chrono> // NOLINT +#include <chrono> #include <functional> #include <initializer_list> #include <string> #include <utility> #include <vector> +#include "absl/strings/ascii.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "discovery/mdns/public/mdns_constants.h" @@ -82,11 +83,16 @@ class DomainName { // compression the actual space taken in on-the-wire format is smaller. size_t MaxWireSize() const; bool empty() const { return labels_.empty(); } + bool IsRoot() const { return labels_.empty(); } const std::vector<std::string>& labels() const { return labels_; } template <typename H> friend H AbslHashValue(H h, const DomainName& domain_name) { - return H::combine(std::move(h), domain_name.labels_); + std::vector<std::string> labels_clone = domain_name.labels_; + for (auto& label : labels_clone) { + absl::AsciiStrToLower(&label); + } + return H::combine(std::move(h), std::move(labels_clone)); } private: @@ -129,9 +135,9 @@ class RawRecordRdata { }; // SRV record format (http://www.ietf.org/rfc/rfc2782.txt): -// 2 bytes network-order unsigned priority -// 2 bytes network-order unsigned weight -// 2 bytes network-order unsigned port +// 2 bytes network-order unsigned priority +// 2 bytes network-order unsigned weight +// 2 bytes network-order unsigned port // target: domain name (on-the-wire representation) class SrvRecordRdata { public: @@ -188,7 +194,8 @@ class ARecordRdata { template <typename H> friend H AbslHashValue(H h, const ARecordRdata& rdata) { - return H::combine(std::move(h), rdata.ipv4_address_.bytes()); + const auto& bytes = rdata.ipv4_address_.bytes(); + return H::combine_contiguous(std::move(h), bytes, 4); } private: @@ -217,7 +224,8 @@ class AAAARecordRdata { template <typename H> friend H AbslHashValue(H h, const AAAARecordRdata& rdata) { - return H::combine(std::move(h), rdata.ipv6_address_.bytes()); + const auto& bytes = rdata.ipv6_address_.bytes(); + return H::combine_contiguous(std::move(h), bytes, 16); } private: @@ -352,13 +360,84 @@ class NsecRecordRdata { DomainName next_domain_name_; }; +// The OPT pseudo-record / meta-record as defined by RFC6891. +class OptRecordRdata { + public: + // A single option as defined in RFC6891 section 6.1.2. + struct Option { + size_t MaxWireSize() const; + + bool operator>(const Option& rhs) const; + bool operator<(const Option& rhs) const; + bool operator>=(const Option& rhs) const; + bool operator<=(const Option& rhs) const; + bool operator==(const Option& rhs) const; + bool operator!=(const Option& rhs) const; + + template <typename H> + friend H AbslHashValue(H h, const Option& option) { + return H::combine(std::move(h), option.code, option.length, option.data); + } + + // Code assigned by the Expert Review process as defined by the DNSEXT + // working group and the IESG, as specified in RFC6891 section 9.1. For + // specific assignments, see: + // https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml + uint16_t code; + + // Size (in octets) of |data|. + uint16_t length; + + // Bit Field with meaning varying based on |code|. + std::vector<uint8_t> data; + }; + + OptRecordRdata(); + + // Constructor that takes zero or more Option parameters. + template <typename... Types> + explicit OptRecordRdata(Types... types) + : OptRecordRdata(std::vector<Option>{std::move(types)...}) {} + explicit OptRecordRdata(std::vector<Option> options); + OptRecordRdata(const OptRecordRdata& other); + OptRecordRdata(OptRecordRdata&& other); + + OptRecordRdata& operator=(const OptRecordRdata& rhs); + OptRecordRdata& operator=(OptRecordRdata&& rhs); + + // NOTE: Only the options field is technically considered part of the rdata, + // so only this field is considered for equality comparison. The other fields + // are included here solely because their meaning differs for OPT pseudo- + // records and normal record types. + bool operator==(const OptRecordRdata& rhs) const; + bool operator!=(const OptRecordRdata& rhs) const; + + size_t MaxWireSize() const { return max_wire_size_; } + + // Set of options stored in this OPT record. + const std::vector<Option>& options() { return options_; } + + template <typename H> + friend H AbslHashValue(H h, const OptRecordRdata& rdata) { + return H::combine(std::move(h), rdata.options_); + } + + private: + // NOTE: The elements of |options_| are stored is sorted order to simplify the + // comparison operators of OptRecordRdata. + std::vector<Option> options_; + + size_t max_wire_size_ = 0; +}; + using Rdata = absl::variant<RawRecordRdata, SrvRecordRdata, ARecordRdata, AAAARecordRdata, PtrRecordRdata, TxtRecordRdata, - NsecRecordRdata>; + NsecRecordRdata, + OptRecordRdata>; // Resource record top level format (http://www.ietf.org/rfc/rfc1035.txt): // name: the name of the node to which this resource record pertains. @@ -408,10 +487,12 @@ class MdnsRecord { template <typename H> friend H AbslHashValue(H h, const MdnsRecord& record) { return H::combine(std::move(h), record.name_, record.dns_type_, - record.dns_class_, record.record_type_, record.ttl_, - record.rdata_); + record.dns_class_, record.record_type_, + record.ttl_.count(), record.rdata_); } + std::string ToString() const; + private: static bool IsValidConfig(const DomainName& name, DnsType dns_type, @@ -476,12 +557,12 @@ class MdnsQuestion { // id: 2 bytes network-order identifier assigned by the program that generates // any kind of query. This identifier is copied to the corresponding reply and // can be used by the requester to match up replies to outstanding queries. -// flags: 2 bytes network-order flags bitfield -// questions: questions in the message -// answers: resource records that answer the questions -// authority_records: resource records that point toward authoritative name +// flags: 2 bytes network-order flags bitfield. +// questions: questions in the message. +// answers: resource records that answer the questions. +// authority_records: resource records that point toward authoritative name. // servers additional_records: additional resource records that relate to the -// query +// query. class MdnsMessage { public: static ErrorOr<MdnsMessage> TryCreate( @@ -558,6 +639,16 @@ class MdnsMessage { uint16_t CreateMessageId(); +// Determines whether a record of the given type can be published. +bool CanBePublished(DnsType type); + +// Determines whether a record of the given type can be queried for. +bool CanBeQueried(DnsType type); + +// Determines whether a record of the given type received over the network +// should be processed. +bool CanBeProcessed(DnsType type); + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records_unittest.cc index 0fab37ded24..395c9259c5d 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_records_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_records_unittest.cc @@ -4,6 +4,12 @@ #include "discovery/mdns/mdns_records.h" +#include <limits> +#include <string> +#include <utility> +#include <vector> + +#include "absl/hash/hash_testing.h" #include "discovery/mdns/mdns_reader.h" #include "discovery/mdns/mdns_writer.h" #include "discovery/mdns/testing/mdns_test_util.h" @@ -102,6 +108,9 @@ TEST(MdnsDomainNameTest, Compare) { EXPECT_FALSE(fourth < fifth); EXPECT_FALSE(fifth < fourth); + + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {first, second, third, fourth, fifth})); } TEST(MdnsDomainNameTest, CopyAndMove) { @@ -147,6 +156,9 @@ TEST(MdnsRawRecordRdataTest, Compare) { EXPECT_EQ(rdata1, rdata2); EXPECT_NE(rdata1, rdata3); + + EXPECT_TRUE( + absl::VerifyTypeImplementsAbslHashCorrectly({rdata1, rdata2, rdata3})); } TEST(MdnsRawRecordRdataTest, CopyAndMove) { @@ -185,6 +197,9 @@ TEST(MdnsSrvRecordRdataTest, Compare) { EXPECT_NE(rdata1, rdata4); EXPECT_NE(rdata1, rdata5); EXPECT_NE(rdata1, rdata6); + + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {rdata1, rdata2, rdata3, rdata4, rdata5, rdata6})); } TEST(MdnsSrvRecordRdataTest, CopyAndMove) { @@ -208,6 +223,9 @@ TEST(MdnsARecordRdataTest, Compare) { EXPECT_EQ(rdata1, rdata2); EXPECT_NE(rdata1, rdata3); + + EXPECT_TRUE( + absl::VerifyTypeImplementsAbslHashCorrectly({rdata1, rdata2, rdata3})); } TEST(MdnsARecordRdataTest, CopyAndMove) { @@ -249,6 +267,9 @@ TEST(MdnsAAAARecordRdataTest, Compare) { EXPECT_EQ(rdata1, rdata2); EXPECT_NE(rdata1, rdata3); + + EXPECT_TRUE( + absl::VerifyTypeImplementsAbslHashCorrectly({rdata1, rdata2, rdata3})); } TEST(MdnsAAAARecordRdataTest, CopyAndMove) { @@ -275,6 +296,9 @@ TEST(MdnsPtrRecordRdataTest, Compare) { EXPECT_EQ(rdata1, rdata2); EXPECT_NE(rdata1, rdata3); + + EXPECT_TRUE( + absl::VerifyTypeImplementsAbslHashCorrectly({rdata1, rdata2, rdata3})); } TEST(MdnsPtrRecordRdataTest, CopyAndMove) { @@ -300,6 +324,9 @@ TEST(MdnsTxtRecordRdataTest, Compare) { EXPECT_EQ(rdata1, rdata2); EXPECT_NE(rdata1, rdata3); EXPECT_NE(rdata1, rdata4); + + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {rdata1, rdata2, rdata3, rdata4})); } TEST(MdnsTxtRecordRdataTest, CopyAndMove) { @@ -428,6 +455,9 @@ TEST(MdnsNsecRecordRdataTest, Compare) { EXPECT_NE(rdata1, rdata3); EXPECT_NE(rdata1, rdata4); EXPECT_NE(rdata3, rdata4); + + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {rdata1, rdata2, rdata3, rdata4})); } TEST(MdnsNsecRecordRdataTest, CopyAndMove) { @@ -435,6 +465,57 @@ TEST(MdnsNsecRecordRdataTest, CopyAndMove) { DnsType::kSRV)); } +TEST(MdnsOptRecordRdataTest, Construct) { + OptRecordRdata rdata1; + EXPECT_EQ(rdata1.MaxWireSize(), size_t{0}); + EXPECT_EQ(rdata1.options().size(), size_t{0}); + + OptRecordRdata::Option opt1{12, 34, {0x12, 0x34}}; + OptRecordRdata::Option opt2{12, 34, {0x12, 0x34}}; + OptRecordRdata::Option opt3{12, 34, {0x12, 0x34, 0x56}}; + OptRecordRdata::Option opt4{34, 12, {0x00}}; + OptRecordRdata::Option opt5{12, 12, {0x12, 0x34}}; + rdata1 = OptRecordRdata(opt1, opt2, opt3, opt4, opt5); + EXPECT_EQ(rdata1.MaxWireSize(), size_t{30}); + + ASSERT_EQ(rdata1.options().size(), size_t{5}); + EXPECT_EQ(rdata1.options()[0], opt5); + EXPECT_EQ(rdata1.options()[1], opt1); + EXPECT_EQ(rdata1.options()[2], opt2); + EXPECT_EQ(rdata1.options()[3], opt3); + EXPECT_EQ(rdata1.options()[4], opt4); +} + +TEST(MdnsOptRecordRdataTest, Compare) { + OptRecordRdata::Option opt1{12, 34, {0x12, 0x34}}; + OptRecordRdata::Option opt2{12, 34, {0x12, 0x34}}; + OptRecordRdata::Option opt3{12, 34, {0x12, 0x56}}; + OptRecordRdata rdata1(opt1); + OptRecordRdata rdata2(opt2); + OptRecordRdata rdata3(opt3); + OptRecordRdata rdata4; + + EXPECT_EQ(rdata1, rdata1); + EXPECT_EQ(rdata2, rdata2); + EXPECT_EQ(rdata3, rdata3); + EXPECT_EQ(rdata4, rdata4); + + EXPECT_EQ(rdata1, rdata2); + EXPECT_NE(rdata1, rdata3); + EXPECT_NE(rdata1, rdata4); + EXPECT_NE(rdata2, rdata3); + EXPECT_NE(rdata2, rdata4); + EXPECT_NE(rdata3, rdata4); + + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {rdata1, rdata2, rdata3, rdata4})); +} + +TEST(MdnsOptRecordRdataTest, CopyAndMove) { + OptRecordRdata::Option opt1{12, 34, {0x12, 0x34}}; + TestCopyAndMove(OptRecordRdata(opt1)); +} + TEST(MdnsRecordTest, Construct) { MdnsRecord record1; EXPECT_EQ(record1.MaxWireSize(), UINT64_C(11)); @@ -489,6 +570,9 @@ TEST(MdnsRecordTest, Compare) { EXPECT_NE(record1, record5); EXPECT_NE(record1, record6); EXPECT_NE(record1, record7); + + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {record1, record2, record3, record4, record5, record6, record7})); } TEST(MdnsRecordTest, CopyAndMove) { @@ -531,6 +615,9 @@ TEST(MdnsQuestionTest, Compare) { EXPECT_NE(question1, question3); EXPECT_NE(question1, question4); EXPECT_NE(question1, question5); + + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {question1, question2, question3, question4, question5})); } TEST(MdnsQuestionTest, CopyAndMove) { @@ -656,6 +743,10 @@ TEST(MdnsMessageTest, Compare) { EXPECT_NE(message1, message6); EXPECT_NE(message1, message7); EXPECT_NE(message1, message8); + + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {message1, message2, message3, message4, message5, message6, message7, + message8})); } TEST(MdnsMessageTest, CopyAndMove) { @@ -677,5 +768,11 @@ TEST(MdnsMessageTest, CopyAndMove) { TestCopyAndMove(message); } +TEST(MdnsRecordOperations, CanBeProcessed) { + EXPECT_FALSE(CanBeProcessed(static_cast<DnsType>(1234))); + EXPECT_FALSE(CanBeProcessed(static_cast<DnsType>(222))); + EXPECT_FALSE(CanBeProcessed(static_cast<DnsType>(8973))); +} + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.cc index b97014c2650..953828e3590 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_responder.cc @@ -4,6 +4,7 @@ #include "discovery/mdns/mdns_responder.h" +#include <string> #include <utility> #include "discovery/common/config.h" @@ -24,7 +25,10 @@ const std::array<std::string, 3> kServiceEnumerationDomainLabels{ enum AddResult { kNonePresent = 0, kAdded, kAlreadyKnown }; -std::chrono::seconds GetTtlForRecordType(DnsType type) { +std::chrono::seconds GetTtlForNsecTargetingType(DnsType type) { + // NOTE: A 'default' switch statement has intentionally been avoided below to + // enforce that new DnsTypes added must be added below through a compile-time + // check. switch (type) { case DnsType::kA: return kARecordTtl; @@ -40,17 +44,23 @@ std::chrono::seconds GetTtlForRecordType(DnsType type) { // If no records are present, re-querying should happen at the minimum // of any record that might be retrieved at that time. return kSrvRecordTtl; - default: - OSP_NOTREACHED(); - return std::chrono::seconds{0}; + case DnsType::kNSEC: + case DnsType::kOPT: + // Neither of these types should ever be hit. We should never be creating + // an NSEC record for type NSEC, and OPT record querying is not supported, + // so creating NSEC records for type OPT is not valid. + break; } + + OSP_NOTREACHED() << "NSEC records do not support type " << type; + return std::chrono::seconds(0); } MdnsRecord CreateNsecRecord(DomainName target_name, DnsType target_type, DnsClass target_class) { auto rdata = NsecRecordRdata(target_name, target_type); - std::chrono::seconds ttl = GetTtlForRecordType(target_type); + std::chrono::seconds ttl = GetTtlForNsecTargetingType(target_type); return MdnsRecord(std::move(target_name), DnsType::kNSEC, target_class, RecordType::kUnique, ttl, std::move(rdata)); } @@ -185,14 +195,12 @@ void ApplyQueryResults(MdnsMessage* message, DnsType::kAAAA, clazz, target == domain); } } - } - - // Per RFC 6763 section 12.2, when querying for an SRV record, all address - // records of type A and AAAA should be added to the additional records - // section. Per RFC 6762 section 6.1, if these records are not present and - // their name and class match that which is being queried for, a negative - // response NSEC record may be added to show their non-existence. - else if (type == DnsType::kSRV) { + } else if (type == DnsType::kSRV) { + // Per RFC 6763 section 12.2, when querying for an SRV record, all address + // records of type A and AAAA should be added to the additional records + // section. Per RFC 6762 section 6.1, if these records are not present and + // their name and class match that which is being queried for, a negative + // response NSEC record may be added to show their non-existence. for (const auto& srv_record : message->answers()) { OSP_DCHECK(srv_record.dns_type() == DnsType::kSRV); @@ -203,13 +211,11 @@ void ApplyQueryResults(MdnsMessage* message, AddAdditionalRecords(message, record_handler, target, known_answers, DnsType::kAAAA, clazz, target == domain); } - } - - // Per RFC 6762 section 6.2, when querying for an address record of type A or - // AAAA, the record of the opposite type should be added to the additional - // records section if present. Else, a negative response NSEC record should be - // added to show its non-existence. - else if (type == DnsType::kA) { + } else if (type == DnsType::kA) { + // Per RFC 6762 section 6.2, when querying for an address record of type A + // or AAAA, the record of the opposite type should be added to the + // additional records section if present. Else, a negative response NSEC + // record should be added to show its non-existence. AddAdditionalRecords(message, record_handler, domain, known_answers, DnsType::kAAAA, clazz, true); } else if (type == DnsType::kAAAA) { @@ -397,7 +403,7 @@ void MdnsResponder::OnMessageReceived(const MdnsMessage& message, OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DCHECK(message.type() == MessageType::Query); - // Handle multi-packet known answer suppression + // Handle multi-packet known answer suppression. if (IsMultiPacketTruncatedQueryMessage(message)) { // If there have been an excessive number of known answers received already, // then skip them. This would most likely mean that: @@ -516,7 +522,7 @@ void MdnsResponder::ProcessQueries( for (const auto& question : questions) { OSP_DVLOG << "\tProcessing mDNS Query for domain: '" << question.name().ToString() << "', type: '" - << question.dns_type() << "'"; + << question.dns_type() << "' from '" << src << "'"; // NSEC records should not be queried for. if (question.dns_type() == DnsType::kNSEC) { @@ -589,18 +595,25 @@ void MdnsResponder::SendResponse( // method is called. Exclusive ownership cannot be gained for a record which // has previously been published, and if this host is the exclusive owner // then this method will have been called without any delay on the task - // runner + // runner. ApplyQueryResults(&message, record_handler_, question.name(), known_answers, question.dns_type(), question.dns_class(), is_exclusive_owner); } // Send the response only if it contains answers to the query. + OSP_DVLOG << "\tCompleted Processing mDNS Query for domain: '" + << question.name().ToString() << "', type: '" << question.dns_type() + << "', with " << message.answers().size() << " results:"; + for (const auto& record : message.answers()) { + OSP_DVLOG << "\t\tanswer (" << record.ToString() << ")"; + } + for (const auto& record : message.additional_records()) { + OSP_DVLOG << "\t\tadditional record ('" << record.ToString() << ")"; + } + if (!message.answers().empty()) { - OSP_DVLOG << "\tmDNS Query processed and response sent!"; send_response(message); - } else { - OSP_DVLOG << "\tmDNS Query processed and no response sent!"; } } diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.cc index b780e8a22a1..e91df71801c 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.cc @@ -51,11 +51,6 @@ MdnsServiceImpl::MdnsServiceImpl( OSP_DCHECK(socket.value()->IsIPv4()); socket_v4_ = std::move(socket.value()); - socket_v4_->SetMulticastOutboundInterface(network_interface); - socket_v4_->JoinMulticastGroup(kDefaultMulticastGroupIPv4, - network_interface); - socket_v4_->JoinMulticastGroup(kDefaultSiteLocalGroupIPv4, - network_interface); } if (supported_address_types & Config::NetworkInfo::kUseIpV6) { @@ -66,11 +61,6 @@ MdnsServiceImpl::MdnsServiceImpl( OSP_DCHECK(socket.value()->IsIPv6()); socket_v6_ = std::move(socket.value()); - socket_v6_->SetMulticastOutboundInterface(network_interface); - socket_v6_->JoinMulticastGroup(kDefaultMulticastGroupIPv6, - network_interface); - socket_v6_->JoinMulticastGroup(kDefaultSiteLocalGroupIPv6, - network_interface); } // Initialize objects which depend on the above sockets. @@ -102,9 +92,25 @@ MdnsServiceImpl::MdnsServiceImpl( // used for reading on the mDNS v4 and v6 addresses and ports. if (socket_v4_.get()) { socket_v4_->Bind(); + + // This configuration must happen after the socket is bound for + // compatibility with chromium. + socket_v4_->SetMulticastOutboundInterface(network_interface); + socket_v4_->JoinMulticastGroup(kDefaultMulticastGroupIPv4, + network_interface); + socket_v4_->JoinMulticastGroup(kDefaultSiteLocalGroupIPv4, + network_interface); } if (socket_v6_.get()) { socket_v6_->Bind(); + + // This configuration must happen after the socket is bound for + // compatibility with chromium. + socket_v6_->SetMulticastOutboundInterface(network_interface); + socket_v6_->JoinMulticastGroup(kDefaultMulticastGroupIPv6, + network_interface); + socket_v6_->JoinMulticastGroup(kDefaultSiteLocalGroupIPv6, + network_interface); } } diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.h index e87d84a2896..6a218139f0b 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_service_impl.h @@ -5,6 +5,8 @@ #ifndef DISCOVERY_MDNS_MDNS_SERVICE_IMPL_H_ #define DISCOVERY_MDNS_MDNS_SERVICE_IMPL_H_ +#include <memory> + #include "discovery/common/config.h" #include "discovery/mdns/mdns_domain_confirmed_provider.h" #include "discovery/mdns/mdns_probe_manager.h" @@ -74,7 +76,7 @@ class MdnsServiceImpl : public MdnsService, public UdpSocket::Client { MdnsRandom random_delay_; MdnsReceiver receiver_; - // Sockets to send and receive mDNS Data according to RFC 6762. + // Sockets to send and receive mDNS data. std::unique_ptr<UdpSocket> socket_v4_; std::unique_ptr<UdpSocket> socket_v6_; diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.cc index e250582197b..eb1e87fd2d7 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers.cc @@ -6,6 +6,7 @@ #include <array> #include <limits> +#include <utility> #include "discovery/common/config.h" #include "discovery/mdns/mdns_random.h" @@ -38,7 +39,7 @@ constexpr std::chrono::minutes kMaximumQueryInterval{60}; // A goodbye record is a record with TTL of 0. bool IsGoodbyeRecord(const MdnsRecord& record) { - return record.ttl() == std::chrono::seconds{0}; + return record.ttl() == std::chrono::seconds(0); } bool IsNegativeResponseForType(const MdnsRecord& record, DnsType dns_type) { @@ -284,7 +285,7 @@ Clock::time_point MdnsRecordTracker::GetNextSendTime() { } const Clock::duration delay = - std::chrono::duration_cast<Clock::duration>(record_.ttl() * ttl_fraction); + Clock::to_duration(record_.ttl() * ttl_fraction); return start_time_ + delay; } diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers_unittest.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers_unittest.cc index bb80ceb1913..0edd8259b84 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_trackers_unittest.cc @@ -22,8 +22,7 @@ namespace discovery { namespace { constexpr Clock::duration kOneSecond = - std::chrono::duration_cast<Clock::duration>(std::chrono::seconds(1)); - + Clock::to_duration(std::chrono::seconds(1)); } using testing::_; @@ -130,8 +129,7 @@ class MdnsTrackerTest : public testing::Test { constexpr double kTtlFractions[] = {0.83, 0.88, 0.93, 0.98, 1.00}; Clock::duration time_passed{0}; for (double fraction : kTtlFractions) { - Clock::duration time_till_refresh = - std::chrono::duration_cast<Clock::duration>(ttl * fraction); + Clock::duration time_till_refresh = Clock::to_duration(ttl * fraction); Clock::duration delta = time_till_refresh - time_passed; time_passed = time_till_refresh; clock_.Advance(delta); @@ -251,8 +249,7 @@ TEST_F(MdnsTrackerTest, RecordTrackerSendsMessage) { return Error::None(); }); - clock_.Advance( - std::chrono::duration_cast<Clock::duration>(a_record_.ttl() * 0.83)); + clock_.Advance(Clock::to_duration(a_record_.ttl() * 0.83)); } TEST_F(MdnsTrackerTest, RecordTrackerNoQueryAfterDestruction) { @@ -274,8 +271,7 @@ TEST_F(MdnsTrackerTest, RecordTrackerUpdateResetsTtl) { expiration_called_ = false; std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_); // Advance time by 60% of record's TTL - Clock::duration advance_time = - std::chrono::duration_cast<Clock::duration>(a_record_.ttl() * 0.6); + Clock::duration advance_time = Clock::to_duration(a_record_.ttl() * 0.6); clock_.Advance(advance_time); // Now update the record, this must reset expiration time EXPECT_EQ(tracker->Update(a_record_).value(), @@ -317,17 +313,17 @@ TEST_F(MdnsTrackerTest, RecordTrackerExpirationCallbackAfterGoodbye) { std::unique_ptr<MdnsRecordTracker> tracker = CreateRecordTracker(a_record_); MdnsRecord goodbye_record(a_record_.name(), a_record_.dns_type(), a_record_.dns_class(), a_record_.record_type(), - std::chrono::seconds{0}, a_record_.rdata()); + std::chrono::seconds(0), a_record_.rdata()); // After a goodbye record is received, expiration is schedule in a second. EXPECT_EQ(tracker->Update(goodbye_record).value(), MdnsRecordTracker::UpdateType::kGoodbye); // Advance clock to just before the expiration time of 1 second. - clock_.Advance(std::chrono::microseconds{999999}); + clock_.Advance(std::chrono::microseconds(999999)); EXPECT_FALSE(expiration_called_); // Advance clock to exactly the expiration time. - clock_.Advance(std::chrono::microseconds{1}); + clock_.Advance(std::chrono::microseconds(1)); EXPECT_TRUE(expiration_called_); } @@ -356,7 +352,7 @@ TEST_F(MdnsTrackerTest, RecordTrackerInvalidPositiveRecordUpdate) { // RDATA must match the old RDATA for goodbye records MdnsRecord invalid_rdata(a_record_.name(), a_record_.dns_type(), a_record_.dns_class(), a_record_.record_type(), - std::chrono::seconds{0}, + std::chrono::seconds(0), ARecordRdata(IPAddress{172, 0, 0, 2})); EXPECT_EQ(tracker->Update(invalid_rdata).error(), Error::Code::kParameterInvalid); @@ -420,7 +416,7 @@ TEST_F(MdnsTrackerTest, RecordTrackerUpdateNegativeResponseWithPositive) { tracker = CreateRecordTracker(nsec_record_, DnsType::kA); MdnsRecord aaaa_record(a_record_.name(), DnsType::kAAAA, a_record_.dns_class(), a_record_.record_type(), - std::chrono::seconds{0}, + std::chrono::seconds(0), AAAARecordRdata(IPAddress{0, 0, 0, 0, 0, 0, 0, 1})); result = tracker->Update(aaaa_record); EXPECT_TRUE(result.is_error()); diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.cc b/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.cc index a95605b88ed..05880cb6412 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.cc @@ -4,6 +4,11 @@ #include "discovery/mdns/mdns_writer.h" +#include <limits> +#include <string> +#include <utility> +#include <vector> + #include "absl/hash/hash.h" #include "absl/strings/ascii.h" #include "util/hashing.h" @@ -202,6 +207,12 @@ bool MdnsWriter::Write(const NsecRecordRdata& rdata) { return false; } +bool MdnsWriter::Write(const OptRecordRdata& rdata) { + // OPT records are currently not supported for outgoing messages. + OSP_UNIMPLEMENTED(); + return false; +} + bool MdnsWriter::Write(const MdnsRecord& record) { Cursor cursor(this); if (Write(record.name()) && Write(static_cast<uint16_t>(record.dns_type())) && diff --git a/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.h b/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.h index 3b8a3b08026..8dad9f06fe7 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/mdns_writer.h @@ -5,7 +5,9 @@ #ifndef DISCOVERY_MDNS_MDNS_WRITER_H_ #define DISCOVERY_MDNS_MDNS_WRITER_H_ +#include <string> #include <unordered_map> +#include <vector> #include "discovery/mdns/mdns_records.h" #include "util/big_endian.h" @@ -32,6 +34,7 @@ class MdnsWriter : public BigEndianWriter { bool Write(const PtrRecordRdata& rdata); bool Write(const TxtRecordRdata& rdata); bool Write(const NsecRecordRdata& rdata); + bool Write(const OptRecordRdata& rdata); // Writes a DNS resource record with its RDATA. // The correct type of RDATA to be written is contained in the type // specified in the record. diff --git a/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_constants.h b/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_constants.h index 76a49d1b544..ecaa18bb95f 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_constants.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/public/mdns_constants.h @@ -303,6 +303,7 @@ enum class DnsType : uint16_t { kTXT = 16, kAAAA = 28, kSRV = 33, + kOPT = 41, kNSEC = 47, kANY = 255, // Only allowed for QTYPE }; @@ -319,6 +320,8 @@ inline std::ostream& operator<<(std::ostream& output, DnsType type) { return output << "AAAA"; case DnsType::kSRV: return output << "SRV"; + case DnsType::kOPT: + return output << "OPT"; case DnsType::kNSEC: return output << "NSEC"; case DnsType::kANY: @@ -429,12 +432,32 @@ constexpr uint8_t kTXTEmptyRdata = 0; // RFC 6762 section 8.1 specifies that a probe should wait 250 ms between // subsequent probe queries. constexpr Clock::duration kDelayBetweenProbeQueries = - std::chrono::duration_cast<Clock::duration>(std::chrono::milliseconds{250}); + Clock::to_duration(std::chrono::milliseconds(250)); // RFC 6762 section 8.1 specifies that the probing phase should send out probe // requests 3 times before treating the probe as completed. constexpr int kProbeIterationCountBeforeSuccess = 3; +// ============================================================================ +// OPT Pseudo-Record Constants +// ============================================================================ + +// For OPT records, the TTL field has been re-purposed as follows: +// +// +0 (MSB) +1 (LSB) +// +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ +// 0: | EXTENDED-RCODE | VERSION | +// +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ +// 2: | DO| Z | +// +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ + +constexpr uint32_t kExtendedRcodeMask = 0xFF000000; +constexpr int kExtendedRcodeShift = 24; +constexpr uint32_t kVersionMask = 0x00FF0000; +constexpr int kVersionShift = 16; +constexpr uint32_t kDnssecOkBitMask = 0x00008000; +constexpr uint8_t kVersionBadvers = 0x10; + } // namespace discovery } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.cc b/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.cc index 6b154174a62..a6209cf08f9 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.cc +++ b/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.cc @@ -4,6 +4,7 @@ #include "discovery/mdns/testing/mdns_test_util.h" +#include <string> #include <utility> #include <vector> @@ -23,31 +24,37 @@ MdnsRecord GetFakePtrRecord(const DomainName& target, DomainName name(++target.labels().begin(), target.labels().end()); PtrRecordRdata rdata(target); return MdnsRecord(std::move(name), DnsType::kPTR, DnsClass::kIN, - RecordType::kShared, ttl, rdata); + RecordType::kShared, ttl, std::move(rdata)); } MdnsRecord GetFakeSrvRecord(const DomainName& name, std::chrono::seconds ttl) { - SrvRecordRdata rdata(0, 0, 80, name); + return GetFakeSrvRecord(name, name, ttl); +} + +MdnsRecord GetFakeSrvRecord(const DomainName& name, + const DomainName& target, + std::chrono::seconds ttl) { + SrvRecordRdata rdata(0, 0, kFakeSrvRecordPort, target); return MdnsRecord(name, DnsType::kSRV, DnsClass::kIN, RecordType::kUnique, - ttl, rdata); + ttl, std::move(rdata)); } MdnsRecord GetFakeTxtRecord(const DomainName& name, std::chrono::seconds ttl) { TxtRecordRdata rdata; return MdnsRecord(name, DnsType::kTXT, DnsClass::kIN, RecordType::kUnique, - ttl, rdata); + ttl, std::move(rdata)); } MdnsRecord GetFakeARecord(const DomainName& name, std::chrono::seconds ttl) { - ARecordRdata rdata(IPAddress(192, 168, 0, 0)); + ARecordRdata rdata(kFakeARecordAddress); return MdnsRecord(name, DnsType::kA, DnsClass::kIN, RecordType::kUnique, ttl, - rdata); + std::move(rdata)); } MdnsRecord GetFakeAAAARecord(const DomainName& name, std::chrono::seconds ttl) { - AAAARecordRdata rdata(IPAddress(1, 2, 3, 4, 5, 6, 7, 8)); + AAAARecordRdata rdata(kFakeAAAARecordAddress); return MdnsRecord(name, DnsType::kAAAA, DnsClass::kIN, RecordType::kUnique, - ttl, rdata); + ttl, std::move(rdata)); } } // namespace discovery diff --git a/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.h b/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.h index 419e1f06d9b..27bc5875e38 100644 --- a/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.h +++ b/chromium/third_party/openscreen/src/discovery/mdns/testing/mdns_test_util.h @@ -13,6 +13,10 @@ namespace openscreen { namespace discovery { +const IPAddress kFakeARecordAddress = IPAddress(192, 168, 0, 0); +const IPAddress kFakeAAAARecordAddress = IPAddress(1, 2, 3, 4, 5, 6, 7, 8); +constexpr uint16_t kFakeSrvRecordPort = 80; + TxtRecordRdata MakeTxtRecord(std::initializer_list<absl::string_view> strings); // Methods to create fake MdnsRecord entities for use in UnitTests. @@ -20,6 +24,9 @@ MdnsRecord GetFakePtrRecord(const DomainName& target, std::chrono::seconds ttl = std::chrono::seconds(1)); MdnsRecord GetFakeSrvRecord(const DomainName& name, std::chrono::seconds ttl = std::chrono::seconds(1)); +MdnsRecord GetFakeSrvRecord(const DomainName& name, + const DomainName& target, + std::chrono::seconds ttl = std::chrono::seconds(1)); MdnsRecord GetFakeTxtRecord(const DomainName& name, std::chrono::seconds ttl = std::chrono::seconds(1)); MdnsRecord GetFakeARecord(const DomainName& name, diff --git a/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher.h b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher.h index 7a9bc3d67c2..9be8e194dc0 100644 --- a/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher.h +++ b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher.h @@ -196,8 +196,9 @@ class DnsSdServiceWatcher : public DnsSdQuerier::Callback { // Set of all instance ids found so far, mapped to the T type that it // represents. unique_ptr<T> entities are used so that the const refs returned // from GetServices() and the ServicesUpdatedCallback can persist even once - // this map is resized. NOTE: Unordered map is used because this set is in - // many cases expected to be large. + // this map is resized. + // NOTE: Unordered map is used because this set is in many cases expected to + // be large. std::unordered_map<EndpointKey, std::unique_ptr<T>, PairHash> records_; // Represents whether discovery is currently running or not. diff --git a/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher_unittest.cc b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher_unittest.cc index 1e8adad6ebe..8d513253cf3 100644 --- a/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher_unittest.cc +++ b/chromium/third_party/openscreen/src/discovery/public/dns_sd_service_watcher_unittest.cc @@ -32,8 +32,9 @@ std::vector<std::string> ConvertRefs( static const IPAddress kAddressV4(192, 168, 0, 0); static const IPEndpoint kEndpointV4{kAddressV4, 0}; -static const std::string kCastServiceId = "_googlecast._tcp"; -static const std::string kCastDomainId = "local"; +constexpr char kCastServiceId[] = "_googlecast._tcp"; +constexpr char kCastDomainId[] = "local"; +constexpr NetworkInterfaceIndex kNetworkInterface = 0; class MockDnsSdService : public DnsSdService { public: @@ -260,7 +261,8 @@ TEST(DnsSdServiceWatcherTest, RefreshFailsBeforeDiscoveryStarts) { TEST_F(DnsSdServiceWatcherTests, RefreshDiscoveryWorks) { const DnsSdInstanceEndpoint record("Instance", kCastServiceId, kCastDomainId, - DnsSdTxtRecord{}, kEndpointV4, 0); + DnsSdTxtRecord{}, kNetworkInterface, + kEndpointV4); CreateNewInstance(record); // Refresh services. @@ -277,10 +279,11 @@ TEST_F(DnsSdServiceWatcherTests, RefreshDiscoveryWorks) { TEST_F(DnsSdServiceWatcherTests, CreatingUpdatingDeletingInstancesWork) { const DnsSdInstanceEndpoint record("Instance", kCastServiceId, kCastDomainId, - DnsSdTxtRecord{}, kEndpointV4, 0); + DnsSdTxtRecord{}, kNetworkInterface, + kEndpointV4); const DnsSdInstanceEndpoint record2("Instance2", kCastServiceId, kCastDomainId, DnsSdTxtRecord{}, - kEndpointV4, 0); + kNetworkInterface, kEndpointV4); EXPECT_FALSE(ContainsService(record)); EXPECT_FALSE(ContainsService(record2)); diff --git a/chromium/third_party/openscreen/src/docs/style_guide.md b/chromium/third_party/openscreen/src/docs/style_guide.md index d397873a3e3..b3e8a120a00 100644 --- a/chromium/third_party/openscreen/src/docs/style_guide.md +++ b/chromium/third_party/openscreen/src/docs/style_guide.md @@ -1,24 +1,43 @@ # Open Screen Library Style Guide -The Open Screen Library follows the -[Chromium C++ coding style](https://chromium.googlesource.com/chromium/src/+/master/styleguide/c++/c++.md). -We also follow the -[Chromium C++ Do's and Don'ts](https://sites.google.com/a/chromium.org/dev/developers/coding-style/cpp-dos-and-donts). +The Open Screen Library follows the [Chromium C++ coding style](https://chromium.googlesource.com/chromium/src/+/master/styleguide/c++/c++.md) +which, in turn, defers to the [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). +We also follow the [Chromium C++ Do's and Don'ts](https://sites.google.com/a/chromium.org/dev/developers/coding-style/cpp-dos-and-donts). C++14 language and library features are allowed in the Open Screen Library -according to the -[C++14 use in Chromium](https://chromium-cpp.appspot.com#core-whitelist) guidelines. +according to the [C++14 use in Chromium]( +https://chromium-cpp.appspot.com#core-whitelist) guidelines. ## Modifications to the Chromium C++ Guidelines - `<functional>` and `std::function` objects are allowed. - `<chrono>` is allowed and encouraged for representation of time. -- Abseil types are allowed based on the whitelist in [DEPS](https://chromium.googlesource.com/openscreen/+/refs/heads/master/DEPS). +- Abseil types are allowed based on the whitelist in [DEPS]( + https://chromium.googlesource.com/openscreen/+/refs/heads/master/DEPS). - However, Abseil types **must not be used in public APIs**. - `<thread>` and `<mutex>` are allowed, but discouraged from general use as the library only needs to handle threading in very specific places; see [threading.md](threading.md). +## Interacting with `std::chrono` + +One of the trickier parts of the Open Screen Library is using time and clock +functionality provided by `platform/api/time.h`. + +- When working extensively with `std::chrono` types in implementation code, + `util/chrono_helpers.h` header can be included for access to type aliases for + common `std::chrono` types, so they can just be referred to as `hours`, + `milliseconds`, etc. This header also includes helpful conversion functions, + such as `to_milliseconds` instead of + `std::chrono::duration_cast<std::chrono::milliseconds>`. + `util/chrono_helpers.h` cannot be used in headers exposed to embedders, and + this is enforced by DEPS. +- `Clock::duration` is defined currently as `std::chrono::microseconds`, and + thus is generally not suitable as a time type (developers generally think in + milliseconds). Prefer casting from explicit time types using + `Clock::to_duration`, e.g. `Clock::to_duration(seconds(2))` + instead of using `Clock::duration` types directly. + ## Open Screen Library Features - For public API functions that return values or errors, please return @@ -35,7 +54,7 @@ according to the ## Copy and Move Operators Use the following guidelines when deciding on copy and move semantics for -objects. +objects: - Objects with data members greater than 32 bytes should be move-able. - Known large objects (I/O buffers, etc.) should be be move-only. @@ -45,22 +64,22 @@ objects. We [prefer the use of `default` and `delete`](https://sites.google.com/a/chromium.org/dev/developers/coding-style/cpp-dos-and-donts#TOC-Prefer-to-use-default) to declare the copy and move semantics of objects. See -[Stoustrop's C++ FAQ](http://www.stroustrup.com/C++11FAQ.html#default) +[Stroustrup's C++ FAQ](http://www.stroustrup.com/C++11FAQ.html#default) for details on how to do that. ### User Defined Copy and Move Operators -Classes should follow the [rule of -three/five/zero](https://en.cppreference.com/w/cpp/language/rule_of_three), -meaning that if it has a custom destructor, copy contructor, or copy -assignment operator: +Classes should follow the [rule of three/five/zero](https://en.cppreference.com/w/cpp/language/rule_of_three). + +This means that if they implement a destructor or any of the copy or move +operators, then all five (destructor, copy & move constructors, copy & move +assignment operators) should be defined or marked as `delete`d as appropriate. +Finally, polymorphic base classes with virtual destructors should `default` all constructors, destructors, and assignment operators. -- All three operators must be defined (and not defaulted). -- It must also either: - - Have a custom move constructor *and* move assignment operator; - - Delete both of them if move semantics are not desired (in rare cases). -- Polymorphic base classes with virtual destructors should declare all - contructors, destructors and assignment operators as defaulted. +Note that operator definitions belong in the source (`.cc`) file, including +`default`, with the exception of `delete`, because it is not a definition, +rather a declaration that there is no definition, and thus belongs in the header +(`.h`) file. ## Noexcept @@ -95,9 +114,9 @@ from external inputs. Instead, one should code proper error-checking and handling for such things. OSP_CHECKs are "turned on" for all build types. However, OSP_DCHECKs are only -"turned on" in Debug builds, or in any build where the "dcheck_always_on=true" +"turned on" in Debug builds, or in any build where the `dcheck_always_on=true` GN argument is being used. In fact, at any time during development (including -Release builds), it is highly recommended to use "dcheck_always_on=true" to +Release builds), it is highly recommended to use `dcheck_always_on=true` to catch bugs. When OSP_DCHECKs are "turned off" they effectively become code comments: All @@ -107,3 +126,6 @@ strip-out unused functions and constants referenced in OSP_DCHECK expressions run-time/space overhead when the program runs. For this reason, a developer need not explicitly sprinkle "#if OSP_DCHECK_IS_ON()" guards all around any functions, variables, etc. that will be unused in "DCHECK off" builds. + +Use OSP_DCHECK and OSP_CHECK in accordance with the +[Chromium guidance for DCHECK/CHECK](https://chromium.googlesource.com/chromium/src/+/master/styleguide/c++/c++.md#check_dcheck_and-notreached).
\ No newline at end of file diff --git a/chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg b/chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg index d51fcc00bee..22a5b11392b 100644 --- a/chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg +++ b/chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg @@ -9,6 +9,7 @@ submit_options { } } config_groups { + name: "openscreen-build-config" gerrit { url: "https://chromium-review.googlesource.com" projects { @@ -22,6 +23,8 @@ config_groups { dry_run_access_list: "project-openscreen-tryjob-access" } tryjob { + # Bots declared "experiment_percentage: 100" are FYI Bots that always + # run but are not considered part of the commit queue pass/fail. builders { name: "openscreen/try/linux64_debug" } @@ -40,14 +43,16 @@ config_groups { builders { name: "openscreen/try/openscreen_presubmit" } - builders { name: "openscreen/try/chromium_linux64_debug" } builders { name: "openscreen/try/chromium_mac_debug" } - + builders { + name: "openscreen/try/linux64_coverage_debug" + experiment_percentage: 100 + } retry_config { single_quota: 1 global_quota: 2 diff --git a/chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg b/chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg index 31ef94b509a..4bab09a0aac 100644 --- a/chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg +++ b/chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg @@ -34,6 +34,13 @@ acl_sets { } builder_mixins { + name: "ci" + recipe { + properties_j: "is_ci:true" + } +} + +builder_mixins { name: "debug" recipe { properties_j: "is_debug:true" @@ -137,6 +144,31 @@ builder_mixins { } } +builder_mixins { + name: "goma_rbe" + recipe: { + properties_j: <<EOF + $build/goma: { + "server_host": "goma.chromium.org", + "rpc_extra_params": "?prod" + } + EOF + } +} + +builder_mixins { + name: "goma_rbe_ats" + recipe: { + properties_j: <<EOF + $build/goma: { + "server_host": "goma.chromium.org", + "rpc_extra_params": "?prod", + "enable_ats": true + } + EOF + } +} + buckets { name: "luci.openscreen.ci" acl_sets: "ci" @@ -148,6 +180,22 @@ buckets { cipd_package: "infra/recipe_bundles/chromium.googlesource.com/chromium/tools/build" cipd_version: "refs/heads/master" name: "openscreen" + # Note: we use bash-style heredocs to avoid having to escape everything. + properties_j: <<EOF + $depot_tools/bot_update: { + "apply_patch_on_gclient":true + } + EOF + properties_j: <<EOF + $recipe_engine/isolated: { + "server": "https://isolateserver.appspot.com" + } + EOF + properties_j: <<EOF + $recipe_engine/swarming: { + "server": "https://chromium-swarm.appspot.com" + } + EOF } service_account: "openscreen-ci-builder@chops-service-accounts.iam.gserviceaccount.com" } @@ -158,6 +206,8 @@ buckets { mixins: "debug" mixins: "x64" mixins: "asan" + mixins: "ci" + mixins: "goma_rbe_ats" } builders { @@ -166,6 +216,7 @@ buckets { mixins: "debug" mixins: "x64" mixins: "gcc" + mixins: "ci" } builders { @@ -173,6 +224,8 @@ buckets { mixins: "linux" mixins: "x64" mixins: "tsan" + mixins: "ci" + mixins: "goma_rbe_ats" } builders { @@ -181,6 +234,8 @@ buckets { mixins: "arm64" mixins: "debug" mixins: "sysroot_platform_stretch" + mixins: "ci" + mixins: "goma_rbe_ats" } builders { @@ -188,6 +243,8 @@ buckets { mixins: "mac" mixins: "debug" mixins: "x64" + mixins: "ci" + mixins: "goma_rbe" } builders { @@ -196,6 +253,8 @@ buckets { mixins: "debug" mixins: "x64" mixins: "chromium" + mixins: "ci" + mixins: "goma_rbe_ats" } builders { @@ -204,6 +263,8 @@ buckets { mixins: "debug" mixins: "x64" mixins: "chromium" + mixins: "ci" + mixins: "goma_rbe" } # TODO(issuetracker.google.com/155812080): Integrate this with existing @@ -214,6 +275,8 @@ buckets { mixins: "debug" mixins: "x64" mixins: "code_coverage" + mixins: "ci" + mixins: "goma_rbe_ats" } } } @@ -231,17 +294,17 @@ buckets: { name: "openscreen" # Note: we use bash-style heredocs to avoid having to escape everything. properties_j: <<EOF - $depot_tools/bot_update:{ + $depot_tools/bot_update: { "apply_patch_on_gclient":true } EOF properties_j: <<EOF - $recipe_engine/isolated:{ + $recipe_engine/isolated: { "server": "https://isolateserver.appspot.com" } EOF properties_j: <<EOF - $recipe_engine/swarming:{ + $recipe_engine/swarming: { "server": "https://chromium-swarm.appspot.com" } EOF @@ -255,6 +318,7 @@ buckets: { mixins: "debug" mixins: "x64" mixins: "asan" + mixins: "goma_rbe_ats" } builders { @@ -270,6 +334,7 @@ buckets: { mixins: "linux" mixins: "x64" mixins: "tsan" + mixins: "goma_rbe_ats" } builders { @@ -278,6 +343,7 @@ buckets: { mixins: "arm64" mixins: "debug" mixins: "sysroot_platform_stretch" + mixins: "goma_rbe_ats" } builders { @@ -285,6 +351,7 @@ buckets: { mixins: "mac" mixins: "debug" mixins: "x64" + mixins: "goma_rbe" } builders { @@ -304,6 +371,7 @@ buckets: { mixins: "debug" mixins: "x64" mixins: "chromium" + mixins: "goma_rbe_ats" } builders { @@ -312,6 +380,18 @@ buckets: { mixins: "debug" mixins: "x64" mixins: "chromium" + mixins: "goma_rbe" + } + + # TODO(issuetracker.google.com/155812080): Integrate this with existing + # linux64_debug bot. + builders { + name: "linux64_coverage_debug" + mixins: "linux" + mixins: "debug" + mixins: "x64" + mixins: "code_coverage" + mixins: "goma_rbe_ats" } } } diff --git a/chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg b/chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg index 974670a59f2..70df9ce0b79 100644 --- a/chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg +++ b/chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg @@ -50,7 +50,7 @@ consoles { } builders { - name: "buildbucket/luci.openscreen.ci/code_coverage" + name: "buildbucket/luci.openscreen.ci/linux64_coverage_debug" category: "linux|x64" short_name: "coverage" } @@ -106,7 +106,7 @@ consoles { } builders { - name: "buildbucket/luci.openscreen.try/code_coverage" + name: "buildbucket/luci.openscreen.try/linux64_coverage_debug" category: "linux|x64" short_name: "coverage" } diff --git a/chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg b/chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg index dfe8b6ae1d3..13612af8589 100644 --- a/chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg +++ b/chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg @@ -27,7 +27,7 @@ trigger { triggers: "linux64_tsan" triggers: "linux_arm64_debug" triggers: "mac_debug" - triggers: "code_coverage" + triggers: "linux64_coverage_debug" } trigger { @@ -113,11 +113,11 @@ job { } job { - id: "code_coverage" + id: "linux64_coverage_debug" acl_sets: "default" buildbucket: { server: "cr-buildbucket.appspot.com" bucket: "luci.openscreen.ci" - builder: "code_coverage" + builder: "linux64_coverage_debug" } } diff --git a/chromium/third_party/openscreen/src/osp/demo/osp_demo.cc b/chromium/third_party/openscreen/src/osp/demo/osp_demo.cc index c14c7f297c0..c81ce9272a7 100644 --- a/chromium/third_party/openscreen/src/osp/demo/osp_demo.cc +++ b/chromium/third_party/openscreen/src/osp/demo/osp_demo.cc @@ -415,7 +415,7 @@ void RunControllerPollLoop(Controller* controller) { request_delegate.connection->Terminate( TerminationReason::kControllerTerminateCalled); } - }; + } watch = Controller::ReceiverWatch(); } @@ -644,7 +644,8 @@ int main(int argc, char** argv) { // TODO(jophba): Mac on Mojave hangs on this command forever. openscreen::SetLogFifoOrDie(log_filename); - PlatformClientPosix::Create(Clock::duration{50}, Clock::duration{50}); + PlatformClientPosix::Create(std::chrono::milliseconds(50), + std::chrono::milliseconds(50)); if (is_receiver_demo) { OSP_LOG_INFO << "Running publisher demo..."; diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_demo.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_demo.cc index a6c0e508165..8037a16cebf 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_demo.cc +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_demo.cc @@ -58,7 +58,7 @@ struct Service { class DemoSocketClient : public UdpSocket::Client { public: - DemoSocketClient(MdnsResponderAdapterImpl* mdns) : mdns_(mdns) {} + explicit DemoSocketClient(MdnsResponderAdapterImpl* mdns) : mdns_(mdns) {} void OnError(UdpSocket* socket, Error error) override { // TODO(crbug.com/openscreen/66): Change to OSP_LOG_FATAL. @@ -361,7 +361,8 @@ int main(int argc, char** argv) { openscreen::osp::ServiceMap services; openscreen::osp::g_services = &services; - PlatformClientPosix::Create(Clock::duration{50}, Clock::duration{50}); + PlatformClientPosix::Create(std::chrono::milliseconds(50), + std::chrono::milliseconds(50)); openscreen::osp::BrowseDemo( PlatformClientPosix::GetInstance()->GetTaskRunner(), labels[0], labels[1], diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.cc index 02fa9ba6ac2..0d31b2b82cf 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.cc +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_platform.cc @@ -18,11 +18,15 @@ #include "third_party/mDNSResponder/src/mDNSCore/mDNSEmbeddedAPI.h" #include "util/osp_logging.h" +namespace { + using std::chrono::duration_cast; using std::chrono::hours; using std::chrono::milliseconds; using std::chrono::seconds; +} // namespace + extern "C" { const char ProgramName[] = "openscreen"; @@ -73,7 +77,13 @@ void mDNSPlatformLock(const mDNS* m) { void mDNSPlatformUnlock(const mDNS* m) {} void mDNSPlatformStrCopy(void* dst, const void* src) { - std::strcpy(static_cast<char*>(dst), static_cast<const char*>(src)); + const char* source = static_cast<const char*>(src); + const size_t source_len = strlen(source); + + // Unfortunately, the caller is responsible for making sure that dst + // if of sufficient length to store the src string. Otherwise we may + // cause an access violation. + std::strncpy(static_cast<char*>(dst), source, source_len); } mDNSu32 mDNSPlatformStrLen(const void* src) { diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc index 5a9891acf5b..0961d0d1fff 100644 --- a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc @@ -7,6 +7,7 @@ #include <cstdint> #include <iostream> #include <memory> +#include <utility> #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -130,7 +131,8 @@ class WrapperMdnsResponderAdapterFactory final : public MdnsResponderAdapterFactory, public FakeMdnsResponderAdapter::LifetimeObserver { public: - WrapperMdnsResponderAdapterFactory(FakeMdnsResponderAdapterFactory* ptr) + explicit WrapperMdnsResponderAdapterFactory( + FakeMdnsResponderAdapterFactory* ptr) : other_(ptr) {} std::unique_ptr<MdnsResponderAdapter> Create() override { diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_common.h b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_common.h index eed24f94536..a69919fcf90 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_common.h +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_common.h @@ -7,6 +7,7 @@ #include <algorithm> #include <memory> +#include <string> #include "osp/msgs/osp_messages.h" #include "osp/public/message_demuxer.h" @@ -34,7 +35,7 @@ MessageDemuxer* GetClientDemuxer(); class PresentationID { public: - PresentationID(const std::string presentation_id); + explicit PresentationID(const std::string presentation_id); operator bool() { return id_; } operator std::string() { return id_.value(); } diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc index bad4c66616a..343ea4d888a 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc @@ -21,7 +21,6 @@ namespace openscreen { namespace osp { -using std::chrono::seconds; using ::testing::_; using ::testing::Invoke; using ::testing::NiceMock; @@ -77,8 +76,8 @@ class MockRequestDelegate final : public RequestDelegate { class ControllerTest : public ::testing::Test { public: ControllerTest() { - fake_clock_ = - std::make_unique<FakeClock>(Clock::time_point(seconds(11111))); + fake_clock_ = std::make_unique<FakeClock>( + Clock::time_point(std::chrono::seconds(11111))); task_runner_ = std::make_unique<FakeTaskRunner>(fake_clock_.get()); quic_bridge_ = std::make_unique<FakeQuicBridge>(task_runner_.get(), FakeClock::now); diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc index 4b601b38b6f..029ed6c6cbf 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc @@ -4,7 +4,10 @@ #include "osp/impl/presentation/url_availability_requester.h" +#include <chrono> #include <memory> +#include <utility> +#include <vector> #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -17,9 +20,6 @@ #include "platform/test/fake_task_runner.h" #include "util/osp_logging.h" -using std::chrono::milliseconds; -using std::chrono::seconds; - namespace openscreen { namespace osp { @@ -46,8 +46,8 @@ class MockReceiverObserver : public ReceiverObserver { class UrlAvailabilityRequesterTest : public Test { public: UrlAvailabilityRequesterTest() { - fake_clock_ = - std::make_unique<FakeClock>(Clock::time_point(milliseconds(1298424))); + fake_clock_ = std::make_unique<FakeClock>( + Clock::time_point(std::chrono::milliseconds(1298424))); task_runner_ = std::make_unique<FakeTaskRunner>(fake_clock_.get()); quic_bridge_ = std::make_unique<FakeQuicBridge>(task_runner_.get(), FakeClock::now); @@ -546,7 +546,7 @@ TEST_F(UrlAvailabilityRequesterTest, RefreshWatches) { EXPECT_CALL(mock_observer1, OnReceiverUnavailable(_, service_id_)).Times(0); quic_bridge_->RunTasksUntilIdle(); - fake_clock_->Advance(seconds(60)); + fake_clock_->Advance(std::chrono::seconds(60)); ExpectStreamMessage(&mock_callback_, &request); listener_.RefreshWatches(); @@ -667,7 +667,7 @@ TEST_F(UrlAvailabilityRequesterTest, RemoveObserverInSteps) { quic_bridge_->RunTasksUntilIdle(); EXPECT_EQ((std::vector<std::string>{url2_}), request.urls); - fake_clock_->Advance(seconds(60)); + fake_clock_->Advance(std::chrono::seconds(60)); listener_.RefreshWatches(); EXPECT_CALL(mock_callback_, OnStreamMessage(_, _, _, _, _, _)).Times(0); diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h index e3588f6a2f0..99b8959707a 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h @@ -7,6 +7,7 @@ #include <map> #include <memory> +#include <vector> #include "osp/impl/quic/quic_connection_factory.h" #include "platform/api/udp_socket.h" @@ -22,7 +23,7 @@ class QuicTaskRunner; class QuicConnectionFactoryImpl final : public QuicConnectionFactory { public: - QuicConnectionFactoryImpl(TaskRunner* task_runner); + explicit QuicConnectionFactoryImpl(TaskRunner* task_runner); ~QuicConnectionFactoryImpl() override; // UdpSocket::Client overrides. diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h index 2aac1145359..834d2889c3e 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h @@ -5,6 +5,7 @@ #ifndef OSP_IMPL_QUIC_TESTING_FAKE_QUIC_CONNECTION_FACTORY_H_ #define OSP_IMPL_QUIC_TESTING_FAKE_QUIC_CONNECTION_FACTORY_H_ +#include <memory> #include <vector> #include "gmock/gmock.h" @@ -17,7 +18,8 @@ namespace osp { class FakeQuicConnectionFactoryBridge { public: - FakeQuicConnectionFactoryBridge(const IPEndpoint& controller_endpoint); + explicit FakeQuicConnectionFactoryBridge( + const IPEndpoint& controller_endpoint); bool server_idle() const { return server_idle_; } bool client_idle() const { return client_idle_; } diff --git a/chromium/third_party/openscreen/src/platform/BUILD.gn b/chromium/third_party/openscreen/src/platform/BUILD.gn index 375ba41a324..f510fbf8ca0 100644 --- a/chromium/third_party/openscreen/src/platform/BUILD.gn +++ b/chromium/third_party/openscreen/src/platform/BUILD.gn @@ -202,6 +202,7 @@ source_set("unittests") { "base/error_unittest.cc", "base/ip_address_unittest.cc", "base/location_unittest.cc", + "base/udp_packet_unittest.cc", ] # The socket integration tests assume that you can Bind with UDP sockets, diff --git a/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.h b/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.h index 2843be4eea5..6ee2ca975c6 100644 --- a/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.h +++ b/chromium/third_party/openscreen/src/platform/api/scoped_wake_lock.h @@ -7,6 +7,9 @@ #include <memory> +#include "platform/api/serial_delete_ptr.h" +#include "platform/api/task_runner.h" + namespace openscreen { // Ensures that the device does not got to sleep. This is used, for example, @@ -20,7 +23,7 @@ namespace openscreen { // instances have been destroyed. class ScopedWakeLock { public: - static std::unique_ptr<ScopedWakeLock> Create(); + static SerialDeletePtr<ScopedWakeLock> Create(TaskRunner* task_runner); // Instances are not copied nor moved. ScopedWakeLock(const ScopedWakeLock&) = delete; diff --git a/chromium/third_party/openscreen/src/platform/api/task_runner.h b/chromium/third_party/openscreen/src/platform/api/task_runner.h index e114db68860..c061086f8aa 100644 --- a/chromium/third_party/openscreen/src/platform/api/task_runner.h +++ b/chromium/third_party/openscreen/src/platform/api/task_runner.h @@ -5,7 +5,7 @@ #ifndef PLATFORM_API_TASK_RUNNER_H_ #define PLATFORM_API_TASK_RUNNER_H_ -#include <future> // NOLINT +#include <future> #include <utility> #include "platform/api/time.h" diff --git a/chromium/third_party/openscreen/src/platform/api/time_unittest.cc b/chromium/third_party/openscreen/src/platform/api/time_unittest.cc index 2bef8250f04..c0328526703 100644 --- a/chromium/third_party/openscreen/src/platform/api/time_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/api/time_unittest.cc @@ -4,12 +4,11 @@ #include "platform/api/time.h" +#include <chrono> #include <thread> #include "gtest/gtest.h" - -using std::chrono::microseconds; -using std::chrono::milliseconds; +#include "util/chrono_helpers.h" namespace openscreen { namespace { diff --git a/chromium/third_party/openscreen/src/platform/base/error.cc b/chromium/third_party/openscreen/src/platform/base/error.cc index e0c51d5a493..c1ea54679ae 100644 --- a/chromium/third_party/openscreen/src/platform/base/error.cc +++ b/chromium/third_party/openscreen/src/platform/base/error.cc @@ -240,6 +240,8 @@ std::ostream& operator<<(std::ostream& os, const Error::Code& code) { return os << "Failure: kUpdateReceivedRecordFailure"; case Error::Code::kRecordPublicationError: return os << "Failure: kRecordPublicationError"; + case Error::Code::kProcessReceivedRecordFailure: + return os << "Failure: ProcessReceivedRecordFailure"; } // Unused 'return' to get around failure on GCC. diff --git a/chromium/third_party/openscreen/src/platform/base/error.h b/chromium/third_party/openscreen/src/platform/base/error.h index e9aef2bb940..578db1bd568 100644 --- a/chromium/third_party/openscreen/src/platform/base/error.h +++ b/chromium/third_party/openscreen/src/platform/base/error.h @@ -161,6 +161,7 @@ class Error { // Discovery errors. kUpdateReceivedRecordFailure, kRecordPublicationError, + kProcessReceivedRecordFailure, // Generic errors. kUnknownError, @@ -341,6 +342,60 @@ class ErrorOr { const bool is_value_; }; +// Define comparison operators using SFINAE. +template <typename ValueType> +bool operator<(const ErrorOr<ValueType>& lhs, const ErrorOr<ValueType>& rhs) { + // Handle the cases where one side is an error. + if (lhs.is_error() != rhs.is_error()) { + return lhs.is_error(); + } + + // Handle the case where both sides are errors. + if (lhs.is_error()) { + return static_cast<int8_t>(lhs.error().code()) < + static_cast<int8_t>(rhs.error().code()); + } + + // Handle the case where both are values. + return lhs.value() < rhs.value(); +} + +template <typename ValueType> +bool operator>(const ErrorOr<ValueType>& lhs, const ErrorOr<ValueType>& rhs) { + return rhs < lhs; +} + +template <typename ValueType> +bool operator<=(const ErrorOr<ValueType>& lhs, const ErrorOr<ValueType>& rhs) { + return !(lhs > rhs); +} + +template <typename ValueType> +bool operator>=(const ErrorOr<ValueType>& lhs, const ErrorOr<ValueType>& rhs) { + return !(rhs < lhs); +} + +template <typename ValueType> +bool operator==(const ErrorOr<ValueType>& lhs, const ErrorOr<ValueType>& rhs) { + // Handle the cases where one side is an error. + if (lhs.is_error() != rhs.is_error()) { + return false; + } + + // Handle the case where both sides are errors. + if (lhs.is_error()) { + return lhs.error() == rhs.error(); + } + + // Handle the case where both are values. + return lhs.value() == rhs.value(); +} + +template <typename ValueType> +bool operator!=(const ErrorOr<ValueType>& lhs, const ErrorOr<ValueType>& rhs) { + return !(lhs == rhs); +} + } // namespace openscreen #endif // PLATFORM_BASE_ERROR_H_ diff --git a/chromium/third_party/openscreen/src/platform/base/error_unittest.cc b/chromium/third_party/openscreen/src/platform/base/error_unittest.cc index d8b6d839952..2af1c153f74 100644 --- a/chromium/third_party/openscreen/src/platform/base/error_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/base/error_unittest.cc @@ -130,4 +130,58 @@ TEST(ErrorOrTest, ErrorOrWithValue) { EXPECT_EQ(value.message, "Riverrun"); } +TEST(ErrorOrTest, ComparisonTests) { + ErrorOr<int> e1(7); + ErrorOr<int> e2(7); + ErrorOr<int> e3(2); + ErrorOr<int> e4(10); + + ErrorOr<int> e5(Error::Code::kAgain); + ErrorOr<int> e6(Error::Code::kCborParsing); + ErrorOr<int> e7(Error::Code::kCborEncoding); + ErrorOr<int> e8(Error::Code::kCborEncoding); + + ErrorOr<int> e9(Error::Code::kAgain, "foo"); + ErrorOr<int> e10(Error::Code::kAgain, "bar"); + + EXPECT_EQ(e1, e2); + EXPECT_EQ(e7, e8); + EXPECT_LE(e1, e2); + EXPECT_GE(e7, e8); + + EXPECT_NE(e1, e3); + EXPECT_NE(e1, e4); + EXPECT_NE(e1, e5); + EXPECT_NE(e1, e6); + EXPECT_NE(e1, e7); + EXPECT_NE(e5, e2); + EXPECT_NE(e5, e3); + EXPECT_NE(e5, e4); + EXPECT_NE(e5, e6); + EXPECT_NE(e5, e9); + EXPECT_NE(e5, e10); + EXPECT_NE(e9, e10); + + EXPECT_LT(e3, e1); + EXPECT_GT(e4, e1); + EXPECT_LT(e5, e6); + EXPECT_LE(e5, e9); + EXPECT_LE(e5, e10); + EXPECT_LE(e9, e10); + EXPECT_GT(e7, e6); + + EXPECT_GT(e1, e5); + EXPECT_GT(e2, e5); + EXPECT_GT(e3, e5); + EXPECT_GT(e4, e5); + EXPECT_GT(e1, e6); + EXPECT_GT(e2, e6); + EXPECT_GT(e3, e6); + EXPECT_GT(e4, e6); + EXPECT_GT(e1, e7); + EXPECT_GT(e2, e7); + EXPECT_GT(e3, e7); + EXPECT_GT(e4, e7); +} + } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/base/ip_address.cc b/chromium/third_party/openscreen/src/platform/base/ip_address.cc index 6fee6a3c9c1..78fcd70fad8 100644 --- a/chromium/third_party/openscreen/src/platform/base/ip_address.cc +++ b/chromium/third_party/openscreen/src/platform/base/ip_address.cc @@ -12,23 +12,32 @@ #include <cstring> #include <iomanip> #include <iterator> +#include <limits> #include <sstream> #include <utility> namespace openscreen { // static -const IPAddress IPAddress::kV4LoopbackAddress{127, 0, 0, 1}; +const IPAddress IPAddress::kAnyV4() { + return IPAddress{0, 0, 0, 0}; +} // static -const IPAddress IPAddress::kV6LoopbackAddress{0, 0, 0, 0, 0, 0, 0, 1}; +const IPAddress IPAddress::kAnyV6() { + return IPAddress{0, 0, 0, 0, 0, 0, 0, 0}; +} + +// static +const IPAddress IPAddress::kV4LoopbackAddress() { + return IPAddress{127, 0, 0, 1}; +} + +// static +const IPAddress IPAddress::kV6LoopbackAddress() { + return IPAddress{0, 0, 0, 0, 0, 0, 0, 1}; +} -IPAddress::IPAddress() : version_(Version::kV4), bytes_({}) {} -IPAddress::IPAddress(const std::array<uint8_t, 4>& bytes) - : version_(Version::kV4), - bytes_{{bytes[0], bytes[1], bytes[2], bytes[3]}} {} -IPAddress::IPAddress(const uint8_t (&b)[4]) - : version_(Version::kV4), bytes_{{b[0], b[1], b[2], b[3]}} {} IPAddress::IPAddress(Version version, const uint8_t* b) : version_(version) { if (version_ == Version::kV4) { bytes_ = {{b[0], b[1], b[2], b[3]}}; @@ -37,6 +46,14 @@ IPAddress::IPAddress(Version version, const uint8_t* b) : version_(version) { b[10], b[11], b[12], b[13], b[14], b[15]}}; } } + +IPAddress::IPAddress(const std::array<uint8_t, 4>& bytes) + : version_(Version::kV4), + bytes_{{bytes[0], bytes[1], bytes[2], bytes[3]}} {} + +IPAddress::IPAddress(const uint8_t (&b)[4]) + : version_(Version::kV4), bytes_{{b[0], b[1], b[2], b[3]}} {} + IPAddress::IPAddress(uint8_t b1, uint8_t b2, uint8_t b3, uint8_t b4) : version_(Version::kV4), bytes_{{b1, b2, b3, b4}} {} @@ -211,6 +228,16 @@ ErrorOr<IPAddress> IPAddress::Parse(const std::string& s) { return v4 ? std::move(v4) : ParseV6(s); } +// static +const IPEndpoint IPEndpoint::kAnyV4() { + return IPEndpoint{}; +} + +// static +const IPEndpoint IPEndpoint::kAnyV6() { + return IPEndpoint{IPAddress::kAnyV6(), 0}; +} + IPEndpoint::operator bool() const { return address || port; } @@ -330,7 +357,7 @@ std::ostream& operator<<(std::ostream& out, const IPEndpoint& endpoint) { std::string IPEndpoint::ToString() const { std::ostringstream name; - name << this; + name << *this; return name.str(); } diff --git a/chromium/third_party/openscreen/src/platform/base/ip_address.h b/chromium/third_party/openscreen/src/platform/base/ip_address.h index c37054d9e83..999a02652e0 100644 --- a/chromium/third_party/openscreen/src/platform/base/ip_address.h +++ b/chromium/third_party/openscreen/src/platform/base/ip_address.h @@ -22,13 +22,14 @@ class IPAddress { kV6, }; - static const IPAddress kV4LoopbackAddress; - static const IPAddress kV6LoopbackAddress; - + static const IPAddress kAnyV4(); + static const IPAddress kAnyV6(); + static const IPAddress kV4LoopbackAddress(); + static const IPAddress kV6LoopbackAddress(); static constexpr size_t kV4Size = 4; static constexpr size_t kV6Size = 16; - IPAddress(); + constexpr IPAddress() : version_(Version::kV4), bytes_({}) {} // |bytes| contains 4 octets for IPv4, or 8 hextets (16 bytes of big-endian // shorts) for IPv6. @@ -61,6 +62,9 @@ class IPAddress { bool operator==(const IPAddress& o) const; bool operator!=(const IPAddress& o) const; + // IP address comparison rules are based on the following two principles: + // 1. newer versions are greater, e.g. IPv6 > IPv4 + // 2. higher numerical values are greater, e.g. 192.168.0.1 > 10.0.0.1 bool operator<(const IPAddress& other) const; bool operator>(const IPAddress& other) const { return other < *this; } bool operator<=(const IPAddress& other) const { return !(other < *this); } @@ -95,6 +99,9 @@ struct IPEndpoint { IPAddress address; uint16_t port = 0; + // Used with various socket types to indicate "any" address. + static const IPEndpoint kAnyV4(); + static const IPEndpoint kAnyV6(); explicit operator bool() const; // Parses a text representation of an IPv4/IPv6 address and port (e.g. @@ -112,10 +119,10 @@ inline bool operator>(const IPEndpoint& a, const IPEndpoint& b) { return b < a; } inline bool operator<=(const IPEndpoint& a, const IPEndpoint& b) { - return !(b > a); + return !(a > b); } inline bool operator>=(const IPEndpoint& a, const IPEndpoint& b) { - return !(a > b); + return !(a < b); } // Outputs a string of the form: diff --git a/chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc b/chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc index 03a3c7537e5..65332eaeb2f 100644 --- a/chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc @@ -259,24 +259,10 @@ TEST(IPAddressTest, V6ParseThreeDigitValue) { TEST(IPAddressTest, IPEndpointBoolOperator) { IPEndpoint endpoint; - if (endpoint) { - FAIL(); - } - - endpoint = IPEndpoint{{192, 168, 0, 1}, 80}; - if (!endpoint) { - FAIL(); - } - - endpoint = IPEndpoint{{192, 168, 0, 1}, 0}; - if (!endpoint) { - FAIL(); - } - - endpoint = IPEndpoint{{}, 80}; - if (!endpoint) { - FAIL(); - } + ASSERT_FALSE((endpoint)); + ASSERT_TRUE((IPEndpoint{{192, 168, 0, 1}, 80})); + ASSERT_TRUE((IPEndpoint{{192, 168, 0, 1}, 0})); + ASSERT_TRUE((IPEndpoint{{}, 80})); } TEST(IPAddressTest, IPEndpointParse) { @@ -317,4 +303,115 @@ TEST(IPAddressTest, IPEndpointParse) { EXPECT_FALSE(IPEndpoint::Parse("[abcd::1]:99 ")); } +TEST(IPAddressTest, IPAddressComparisons) { + const IPAddress kV4Low{192, 168, 0, 1}; + const IPAddress kV4High{192, 168, 0, 2}; + const IPAddress kV6Low{0, 0, 0, 0, 0, 0, 0, 1}; + const IPAddress kV6High{0, 0, 1, 0, 0, 0, 0, 0}; + + EXPECT_TRUE(kV4Low == kV4Low); + EXPECT_TRUE(kV4High == kV4High); + EXPECT_TRUE(kV6Low == kV6Low); + EXPECT_TRUE(kV6High == kV6High); + EXPECT_FALSE(kV4Low == kV4High); + EXPECT_FALSE(kV4High == kV4Low); + EXPECT_FALSE(kV6Low == kV6High); + EXPECT_FALSE(kV6High == kV6Low); + + EXPECT_FALSE(kV4Low != kV4Low); + EXPECT_FALSE(kV4High != kV4High); + EXPECT_FALSE(kV6Low != kV6Low); + EXPECT_FALSE(kV6High != kV6High); + EXPECT_TRUE(kV4Low != kV4High); + EXPECT_TRUE(kV4High != kV4Low); + EXPECT_TRUE(kV6Low != kV6High); + EXPECT_TRUE(kV6High != kV6Low); + + EXPECT_TRUE(kV4Low < kV4High); + EXPECT_TRUE(kV4High < kV6Low); + EXPECT_TRUE(kV6Low < kV6High); + EXPECT_FALSE(kV6High < kV6Low); + EXPECT_FALSE(kV6Low < kV4High); + EXPECT_FALSE(kV4High < kV4Low); + + EXPECT_FALSE(kV4Low > kV4High); + EXPECT_FALSE(kV4High > kV6Low); + EXPECT_FALSE(kV6Low > kV6High); + EXPECT_TRUE(kV6High > kV6Low); + EXPECT_TRUE(kV6Low > kV4High); + EXPECT_TRUE(kV4High > kV4Low); + + EXPECT_TRUE(kV4Low <= kV4High); + EXPECT_TRUE(kV4High <= kV6Low); + EXPECT_TRUE(kV6Low <= kV6High); + EXPECT_TRUE(kV4Low <= kV4Low); + EXPECT_TRUE(kV4High <= kV4High); + EXPECT_TRUE(kV6Low <= kV6Low); + EXPECT_TRUE(kV6High <= kV6High); + EXPECT_FALSE(kV6High <= kV6Low); + EXPECT_FALSE(kV6Low <= kV4High); + EXPECT_FALSE(kV4High <= kV4Low); + + EXPECT_FALSE(kV4Low >= kV4High); + EXPECT_FALSE(kV4High >= kV6Low); + EXPECT_FALSE(kV6Low >= kV6High); + EXPECT_TRUE(kV4Low >= kV4Low); + EXPECT_TRUE(kV4High >= kV4High); + EXPECT_TRUE(kV6Low >= kV6Low); + EXPECT_TRUE(kV6High >= kV6High); + EXPECT_TRUE(kV6High >= kV6Low); + EXPECT_TRUE(kV6Low >= kV4High); + EXPECT_TRUE(kV4High >= kV4Low); +} + +TEST(IPAddressTest, IPEndpointComparisons) { + const IPEndpoint kV4LowHighPort{{192, 168, 0, 1}, 1000}; + const IPEndpoint kV4LowLowPort{{192, 168, 0, 1}, 1}; + const IPEndpoint kV4High{{192, 168, 0, 2}, 22}; + const IPEndpoint kV6Low{{0, 0, 0, 0, 0, 0, 0, 1}, 22}; + const IPEndpoint kV6High{{0, 0, 1, 0, 0, 0, 0, 0}, 22}; + + EXPECT_TRUE(kV4LowHighPort == kV4LowHighPort); + EXPECT_TRUE(kV4High == kV4High); + EXPECT_TRUE(kV6Low == kV6Low); + EXPECT_TRUE(kV6High == kV6High); + + EXPECT_TRUE(kV4LowLowPort != kV4LowHighPort); + EXPECT_TRUE(kV4LowLowPort != kV4High); + EXPECT_TRUE(kV4High != kV6Low); + EXPECT_TRUE(kV6Low != kV6High); + + EXPECT_TRUE(kV4LowLowPort < kV4LowHighPort); + EXPECT_TRUE(kV4LowLowPort < kV4High); + EXPECT_TRUE(kV4High < kV6Low); + EXPECT_TRUE(kV6Low < kV6High); + + EXPECT_TRUE(kV4LowHighPort > kV4LowLowPort); + EXPECT_TRUE(kV4High > kV4LowLowPort); + EXPECT_TRUE(kV6Low > kV4High); + EXPECT_TRUE(kV6High > kV6Low); + + EXPECT_TRUE(kV4LowLowPort <= kV4LowHighPort); + EXPECT_TRUE(kV4LowLowPort <= kV4High); + EXPECT_TRUE(kV4High <= kV6Low); + EXPECT_TRUE(kV6Low <= kV6High); + EXPECT_TRUE(kV4LowLowPort <= kV4LowHighPort); + EXPECT_TRUE(kV4LowLowPort <= kV4High); + EXPECT_TRUE(kV4High <= kV6Low); + EXPECT_TRUE(kV6Low <= kV6High); + + EXPECT_FALSE(kV4LowLowPort >= kV4LowHighPort); + EXPECT_FALSE(kV4LowLowPort >= kV4High); + EXPECT_FALSE(kV4High >= kV6Low); + EXPECT_FALSE(kV6Low >= kV6High); + EXPECT_TRUE(kV4LowHighPort >= kV4LowLowPort); + EXPECT_TRUE(kV4High >= kV4LowLowPort); + EXPECT_TRUE(kV6Low >= kV4High); + EXPECT_TRUE(kV6High >= kV6Low); + EXPECT_TRUE(kV4LowHighPort >= kV4LowLowPort); + EXPECT_TRUE(kV4High >= kV4LowLowPort); + EXPECT_TRUE(kV6Low >= kV4High); + EXPECT_TRUE(kV6High >= kV6Low); +} + } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.h b/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.h index 426a2b8a9ce..3e5c1fbf308 100644 --- a/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.h +++ b/chromium/third_party/openscreen/src/platform/base/trivial_clock_traits.h @@ -15,14 +15,23 @@ namespace openscreen { class TrivialClockTraits { public: // TrivialClock named requirements: std::chrono templates can/may use these. + // NOTE: unless you are specifically integrating with the clock, you probably + // don't want to use these types, and instead should reference the std::chrono + // types directly. using duration = std::chrono::microseconds; using rep = duration::rep; using period = duration::period; using time_point = std::chrono::time_point<TrivialClockTraits, duration>; static constexpr bool is_steady = true; + // Helper method for named requirements. + template <typename D> + static constexpr duration to_duration(D d) { + return std::chrono::duration_cast<duration>(d); + } + // Time point values from the clock use microsecond precision, as a reasonably - // high-resoulution clock is required. The time source must tick forward at + // high-resolution clock is required. The time source must tick forward at // least 10000 times per second. using kRequiredResolution = std::ratio<1, 10000>; diff --git a/chromium/third_party/openscreen/src/platform/base/udp_packet.cc b/chromium/third_party/openscreen/src/platform/base/udp_packet.cc index 6cbb0da2f83..8470893e8e2 100644 --- a/chromium/third_party/openscreen/src/platform/base/udp_packet.cc +++ b/chromium/third_party/openscreen/src/platform/base/udp_packet.cc @@ -1,10 +1,11 @@ // Copyright 2019 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file +// found in the LICENSE file. #include "platform/base/udp_packet.h" #include <cassert> +#include <sstream> namespace openscreen { @@ -26,4 +27,17 @@ UdpPacket::~UdpPacket() = default; UdpPacket& UdpPacket::operator=(UdpPacket&& other) = default; +std::string UdpPacket::ToString() const { + // TODO(issuetracker.google.com/158660166): Change to use shared hex-to-string + // method. + static constexpr char hex[] = "0123456789ABCDEF"; + std::stringstream ss; + ss << "["; + for (auto it = begin(); it != end(); it++) { + ss << hex[*it / 16] << hex[*it % 16] << " "; + } + ss << "]"; + return ss.str(); +} + } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/base/udp_packet.h b/chromium/third_party/openscreen/src/platform/base/udp_packet.h index a8fcce045a1..3a3b262e985 100644 --- a/chromium/third_party/openscreen/src/platform/base/udp_packet.h +++ b/chromium/third_party/openscreen/src/platform/base/udp_packet.h @@ -1,12 +1,13 @@ // Copyright 2019 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file +// found in the LICENSE file. #ifndef PLATFORM_BASE_UDP_PACKET_H_ #define PLATFORM_BASE_UDP_PACKET_H_ #include <stdint.h> +#include <string> #include <utility> #include <vector> @@ -44,6 +45,8 @@ class UdpPacket : public std::vector<uint8_t> { UdpSocket* socket() const { return socket_; } void set_socket(UdpSocket* socket) { socket_ = socket; } + std::string ToString() const; + static constexpr size_type kUdpMaxPacketSize = 1 << 16; private: diff --git a/chromium/third_party/openscreen/src/platform/base/udp_packet_unittest.cc b/chromium/third_party/openscreen/src/platform/base/udp_packet_unittest.cc new file mode 100644 index 00000000000..84949576857 --- /dev/null +++ b/chromium/third_party/openscreen/src/platform/base/udp_packet_unittest.cc @@ -0,0 +1,34 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "platform/base/udp_packet.h" + +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace openscreen { + +TEST(UdpPacketTest, ValidateToStringNormalCase) { + UdpPacket packet{0x73, 0xC7, 0x00, 0x14, 0xFF, 0x2C}; + std::string result = packet.ToString(); + EXPECT_EQ(result, "[73 C7 00 14 FF 2C ]"); + + UdpPacket packet2{0x1, 0x2, 0x3, 0x4, 0x5}; + result = packet2.ToString(); + EXPECT_EQ(result, "[01 02 03 04 05 ]"); + + UdpPacket packet3{0x0, 0x0, 0x0}; + result = packet3.ToString(); + EXPECT_EQ(result, "[00 00 00 ]"); +} + +TEST(UdpPacketTest, ValidateToStringEmpty) { + UdpPacket packet{}; + std::string result = packet.ToString(); + EXPECT_EQ(result, "[]"); +} + +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/network_interface.cc b/chromium/third_party/openscreen/src/platform/impl/network_interface.cc index 17e240f2e45..00809f36a4c 100644 --- a/chromium/third_party/openscreen/src/platform/impl/network_interface.cc +++ b/chromium/third_party/openscreen/src/platform/impl/network_interface.cc @@ -29,8 +29,8 @@ absl::optional<InterfaceInfo> GetLoopbackInterfaceForTesting() { std::find_if( info.addresses.begin(), info.addresses.end(), [](const IPSubnet& subnet) { - return subnet.address == IPAddress::kV4LoopbackAddress || - subnet.address == IPAddress::kV6LoopbackAddress; + return subnet.address == IPAddress::kV4LoopbackAddress() || + subnet.address == IPAddress::kV6LoopbackAddress(); }) != info.addresses.end(); }); diff --git a/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.cc b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.cc index 4dd751c623d..c253586cece 100644 --- a/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.cc +++ b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_linux.cc @@ -12,8 +12,10 @@ namespace openscreen { int ScopedWakeLockLinux::reference_count_ = 0; -std::unique_ptr<ScopedWakeLock> ScopedWakeLock::Create() { - return std::make_unique<ScopedWakeLockLinux>(); +SerialDeletePtr<ScopedWakeLock> ScopedWakeLock::Create( + TaskRunner* task_runner) { + return SerialDeletePtr<ScopedWakeLock>(task_runner, + new ScopedWakeLockLinux()); } namespace { @@ -45,12 +47,12 @@ ScopedWakeLockLinux::~ScopedWakeLockLinux() { // static void ScopedWakeLockLinux::AcquireWakeLock() { - OSP_UNIMPLEMENTED(); + OSP_VLOG << "Acquired wake lock: currently a noop"; } // static void ScopedWakeLockLinux::ReleaseWakeLock() { - OSP_UNIMPLEMENTED(); + OSP_VLOG << "Released wake lock: currently a noop"; } } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.cc b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.cc index aa873b41922..441ee6c69bf 100644 --- a/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.cc +++ b/chromium/third_party/openscreen/src/platform/impl/scoped_wake_lock_mac.cc @@ -14,8 +14,9 @@ namespace openscreen { ScopedWakeLockMac::LockState ScopedWakeLockMac::lock_state_{}; -std::unique_ptr<ScopedWakeLock> ScopedWakeLock::Create() { - return std::make_unique<ScopedWakeLockMac>(); +SerialDeletePtr<ScopedWakeLock> ScopedWakeLock::Create( + TaskRunner* task_runner) { + return SerialDeletePtr<ScopedWakeLock>(task_runner, new ScopedWakeLockMac()); } namespace { @@ -31,10 +32,11 @@ TaskRunner* GetTaskRunner() { } // namespace ScopedWakeLockMac::ScopedWakeLockMac() : ScopedWakeLock() { - OSP_DCHECK(GetTaskRunner()->IsRunningOnTaskRunner()); - if (lock_state_.reference_count++ == 0) { - AcquireWakeLock(); - } + GetTaskRunner()->PostTask([] { + if (lock_state_.reference_count++ == 0) { + AcquireWakeLock(); + } + }); } ScopedWakeLockMac::~ScopedWakeLockMac() { diff --git a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.h b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.h index a0a1c7cb65b..ab3fddc31c1 100644 --- a/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.h +++ b/chromium/third_party/openscreen/src/platform/impl/socket_handle_waiter_posix.h @@ -8,7 +8,8 @@ #include <unistd.h> #include <atomic> -#include <mutex> // NOLINT +#include <mutex> +#include <vector> #include "platform/impl/socket_handle_waiter.h" diff --git a/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.h b/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.h index 93f81744329..082b4784716 100644 --- a/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.h +++ b/chromium/third_party/openscreen/src/platform/impl/stream_socket_posix.h @@ -21,8 +21,8 @@ namespace openscreen { class StreamSocketPosix : public StreamSocket { public: - StreamSocketPosix(IPAddress::Version version); - StreamSocketPosix(const IPEndpoint& local_endpoint); + explicit StreamSocketPosix(IPAddress::Version version); + explicit StreamSocketPosix(const IPEndpoint& local_endpoint); StreamSocketPosix(SocketAddressPosix local_address, IPEndpoint remote_address, int file_descriptor); diff --git a/chromium/third_party/openscreen/src/platform/impl/task_runner.h b/chromium/third_party/openscreen/src/platform/impl/task_runner.h index 777f2959f54..65dfbf5ea34 100644 --- a/chromium/third_party/openscreen/src/platform/impl/task_runner.h +++ b/chromium/third_party/openscreen/src/platform/impl/task_runner.h @@ -8,7 +8,7 @@ #include <condition_variable> // NOLINT #include <map> #include <memory> -#include <mutex> // NOLINT +#include <mutex> #include <thread> #include <utility> #include <vector> @@ -85,7 +85,7 @@ class TaskRunnerImpl final : public TaskRunner { // NOTE: 'explicit' keyword omitted so that conversion construtor can be // used. This simplifies switching between 'Task' and 'TaskWithMetadata' // based on the compilation flag. - TaskWithMetadata(Task task) + TaskWithMetadata(Task task) // NOLINT : task_(std::move(task)), trace_ids_(TRACE_HIERARCHY) {} void operator()() { diff --git a/chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc index d7535cf34b2..910cfedb70c 100644 --- a/chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc @@ -4,19 +4,21 @@ #include "platform/impl/task_runner.h" +#include <unistd.h> + #include <atomic> -#include <thread> // NOLINT +#include <chrono> +#include <string> +#include <thread> #include "gmock/gmock.h" #include "gtest/gtest.h" #include "platform/api/time.h" #include "platform/test/fake_clock.h" - +#include "util/chrono_helpers.h" namespace openscreen { namespace { -using namespace ::testing; -using std::chrono::milliseconds; using ::testing::_; const auto kTaskRunnerSleepTime = milliseconds(1); @@ -38,7 +40,7 @@ class FakeTaskWaiter final : public TaskRunnerImpl::TaskWaiter { Clock::time_point start = now_function_(); waiting_.store(true); while (!has_event_.load() && (now_function_() - start) < timeout) { - ; + EXPECT_EQ(usleep(100 /* microseconds */), 0); } waiting_.store(false); has_event_.store(false); diff --git a/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc b/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc index 1054cef4f2e..e650fa1a396 100644 --- a/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc +++ b/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc @@ -4,8 +4,10 @@ #include "platform/impl/text_trace_logging_platform.h" +#include <limits> #include <sstream> +#include "util/chrono_helpers.h" #include "util/osp_logging.h" namespace openscreen { @@ -32,9 +34,7 @@ void TextTraceLoggingPlatform::LogTrace(const char* name, Clock::time_point end_time, TraceIdHierarchy ids, Error::Code error) { - auto total_runtime = std::chrono::duration_cast<std::chrono::microseconds>( - end_time - start_time) - .count(); + auto total_runtime = to_microseconds(end_time - start_time).count(); constexpr auto microseconds_symbol = "\u03BCs"; // Greek Mu + 's' std::stringstream ss; ss << "TRACE [" << std::hex << ids.root << ":" << ids.parent << ":" diff --git a/chromium/third_party/openscreen/src/platform/impl/time.cc b/chromium/third_party/openscreen/src/platform/impl/time.cc index 21f41c74ff6..f4a3e85f292 100644 --- a/chromium/third_party/openscreen/src/platform/impl/time.cc +++ b/chromium/third_party/openscreen/src/platform/impl/time.cc @@ -4,15 +4,14 @@ #include "platform/api/time.h" +#include <chrono> #include <ctime> #include <ratio> +#include "util/chrono_helpers.h" #include "util/osp_logging.h" -using std::chrono::duration_cast; using std::chrono::high_resolution_clock; -using std::chrono::hours; -using std::chrono::seconds; using std::chrono::steady_clock; using std::chrono::system_clock; @@ -39,10 +38,10 @@ Clock::time_point Clock::now() noexcept { // or significant math actually taking place here. if (can_use_steady_clock) { return Clock::time_point( - duration_cast<Clock::duration>(steady_clock::now().time_since_epoch())); + Clock::to_duration(steady_clock::now().time_since_epoch())); } - return Clock::time_point(duration_cast<Clock::duration>( - high_resolution_clock::now().time_since_epoch())); + return Clock::time_point( + Clock::to_duration(high_resolution_clock::now().time_since_epoch())); } std::chrono::seconds GetWallTimeSinceUnixEpoch() noexcept { @@ -58,7 +57,7 @@ std::chrono::seconds GetWallTimeSinceUnixEpoch() noexcept { if (sizeof(std::time_t) <= 4) { constexpr std::time_t a_year_before_overflow = std::numeric_limits<std::time_t>::max() - - duration_cast<seconds>(365 * hours(24)).count(); + to_seconds(365 * hours(24)).count(); OSP_DCHECK_LE(since_epoch, a_year_before_overflow); } diff --git a/chromium/third_party/openscreen/src/platform/impl/time_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/time_unittest.cc index 321bb7b32b5..0ab827a8ac0 100644 --- a/chromium/third_party/openscreen/src/platform/impl/time_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/impl/time_unittest.cc @@ -4,11 +4,11 @@ #include "platform/api/time.h" +#include <chrono> #include <ctime> #include "gtest/gtest.h" - -using std::chrono::seconds; +#include "util/chrono_helpers.h" namespace openscreen { diff --git a/chromium/third_party/openscreen/src/platform/impl/timeval_posix.cc b/chromium/third_party/openscreen/src/platform/impl/timeval_posix.cc index 28c25ffe288..2a79e681f0f 100644 --- a/chromium/third_party/openscreen/src/platform/impl/timeval_posix.cc +++ b/chromium/third_party/openscreen/src/platform/impl/timeval_posix.cc @@ -6,16 +6,15 @@ #include <chrono> +#include "util/chrono_helpers.h" + namespace openscreen { struct timeval ToTimeval(const Clock::duration& timeout) { struct timeval tv; - const auto whole_seconds = - std::chrono::duration_cast<std::chrono::seconds>(timeout); + const auto whole_seconds = to_seconds(timeout); tv.tv_sec = whole_seconds.count(); - tv.tv_usec = std::chrono::duration_cast<std::chrono::microseconds>( - timeout - whole_seconds) - .count(); + tv.tv_usec = to_microseconds(timeout - whole_seconds).count(); return tv; } diff --git a/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix_unittest.cc index b345011eb69..5e9abb11adb 100644 --- a/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/impl/tls_data_router_posix_unittest.cc @@ -32,7 +32,8 @@ class MockNetworkWaiter final : public SocketHandleWaiter { class MockSocket : public StreamSocketPosix { public: - MockSocket(int fd) : StreamSocketPosix(IPAddress::Version::kV4), handle(fd) {} + explicit MockSocket(int fd) + : StreamSocketPosix(IPAddress::Version::kV4), handle(fd) {} const SocketHandle& socket_handle() const override { return handle; } diff --git a/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.h b/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.h index 2291bd010b2..fdc06eac324 100644 --- a/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.h +++ b/chromium/third_party/openscreen/src/platform/impl/udp_socket_reader_posix.h @@ -6,7 +6,7 @@ #define PLATFORM_IMPL_UDP_SOCKET_READER_POSIX_H_ #include <map> -#include <mutex> // NOLINT +#include <mutex> #include <vector> #include "platform/api/task_runner.h" @@ -28,7 +28,7 @@ class UdpSocketReaderPosix : public SocketHandleWaiter::Subscriber { // Creates a new instance of this object. // NOTE: The provided NetworkWaiter must outlive this object. explicit UdpSocketReaderPosix(SocketHandleWaiter* waiter); - virtual ~UdpSocketReaderPosix() override; + ~UdpSocketReaderPosix() override; // Waits for |socket| to be readable and then calls the socket's // RecieveMessage(...) method to process the available packet. diff --git a/chromium/third_party/openscreen/src/testing/libfuzzer/fuzzer_test.gni b/chromium/third_party/openscreen/src/testing/libfuzzer/fuzzer_test.gni index 4de38ac46e7..cbd33178d30 100644 --- a/chromium/third_party/openscreen/src/testing/libfuzzer/fuzzer_test.gni +++ b/chromium/third_party/openscreen/src/testing/libfuzzer/fuzzer_test.gni @@ -75,9 +75,7 @@ template("openscreen_fuzzer_test") { } } - outputs = [ - out, - ] + outputs = [ out ] deps = [ "//testing/libfuzzer:seed_corpus" ] + seed_corpus_deps } @@ -92,12 +90,8 @@ template("openscreen_fuzzer_test") { if (defined(invoker.dict)) { # Copy dictionary to output. copy(target_name + "_dict_copy") { - sources = [ - invoker.dict, - ] - outputs = [ - "$root_build_dir/" + target_name + ".dict", - ] + sources = [ invoker.dict ] + outputs = [ "$root_build_dir/" + target_name + ".dict" ] } test_deps += [ ":" + target_name + "_dict_copy" ] } @@ -144,14 +138,13 @@ template("openscreen_fuzzer_test") { args += invoker.environment_variables } - outputs = [ - "$root_build_dir/$config_file_name", - ] + outputs = [ "$root_build_dir/$config_file_name" ] } test_deps += [ ":" + config_file_name ] } executable(target_name) { + testonly = true forward_variables_from(invoker, [ "cflags", diff --git a/chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn b/chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn index 5f006f96073..a4ca3d53bcc 100644 --- a/chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn +++ b/chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn @@ -7,10 +7,7 @@ import("//build_overrides/build.gni") if (build_with_chromium) { source_set("abseil") { public_deps = [ - "//third_party/abseil-cpp/absl/hash", - "//third_party/abseil-cpp/absl/strings", - "//third_party/abseil-cpp/absl/types:optional", - "//third_party/abseil-cpp/absl/types:variant", + "//third_party/abseil-cpp:absl", ] } } else { @@ -44,10 +41,12 @@ if (build_with_chromium) { "src/absl/base/port.h", "src/absl/container/internal/common.h", "src/absl/hash/hash.h", + "src/absl/hash/hash_testing.h", "src/absl/hash/internal/city.cc", "src/absl/hash/internal/city.h", "src/absl/hash/internal/hash.cc", "src/absl/hash/internal/hash.h", + "src/absl/hash/internal/spy_hash_state.h", "src/absl/memory/memory.h", "src/absl/meta/type_traits.h", "src/absl/numeric/int128.cc", @@ -75,6 +74,7 @@ if (build_with_chromium) { "src/absl/strings/numbers.h", "src/absl/strings/str_cat.cc", "src/absl/strings/str_cat.h", + "src/absl/strings/str_format.h", "src/absl/strings/str_join.h", "src/absl/strings/str_replace.cc", "src/absl/strings/str_replace.h", diff --git a/chromium/third_party/openscreen/src/third_party/libprotobuf-mutator/BUILD.gn b/chromium/third_party/openscreen/src/third_party/libprotobuf-mutator/BUILD.gn new file mode 100644 index 00000000000..cc3eeaebdcb --- /dev/null +++ b/chromium/third_party/openscreen/src/third_party/libprotobuf-mutator/BUILD.gn @@ -0,0 +1,87 @@ +# Copyright 2020 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +import("//build_overrides/build.gni") +import("//testing/libfuzzer/fuzzer_test.gni") +import("//third_party/libprotobuf-mutator/fuzzable_proto_library.gni") + +config("include_config") { + include_dirs = [ "src/" ] +} + +source_set("libprotobuf-mutator") { + testonly = true + + configs += [ ":include_config" ] + + public_configs = [ ":include_config" ] + sources = [ + "src/src/binary_format.cc", + "src/src/libfuzzer/libfuzzer_macro.cc", + "src/src/libfuzzer/libfuzzer_mutator.cc", + "src/src/mutator.cc", + "src/src/text_format.cc", + "src/src/utf8_fix.cc", + ] + + # Allow users of LPM to use protobuf reflection and other features from + # protobuf_full. + public_deps = [ "//third_party/protobuf:protobuf_full" ] +} + +# This protoc plugin, like the compiler, should only be built for the host +# architecture. +if (current_toolchain == host_toolchain) { + # This plugin will be needed to fuzz most protobuf code in Chromium. That's + # because production protobuf code must contain the line: + # "option optimize_for = LITE_RUNTIME", which instructs the proto compiler not + # to compile the proto using the full protobuf runtime. This allows Chromium + # not to depend on the full protobuf library, but prevents + # libprotobuf-mutator from fuzzing because the lite runtime lacks needed + # features (such as reflection). The plugin simply compiles a proto library + # as normal but ensures that is compiled with the full protobuf runtime. + executable("override_lite_runtime_plugin") { + sources = [ "protoc_plugin/protoc_plugin.cc" ] + deps = [ "//third_party/protobuf:protoc_lib" ] + public_configs = [ "//third_party/protobuf:protobuf_config" ] + } + # To use the plugin in a proto_library you want to fuzz, change the build + # target to fuzzable_proto_library (defined in + # //third_party/libprotobuf-mutator/fuzzable_proto_library.gni) +} + +# The CQ will try building this target without "use_libfuzzer" if it is defined. +# That will cause the build to fail, so don't define it when "use_libfuzzer" is +# is false. +if (use_libfuzzer) { + # Test that override_lite_runtime_plugin is working when built. This target + # contains files that are optimized for LITE_RUNTIME and which import other + # files that are also optimized for LITE_RUNTIME. + openscreen_fuzzer_test("override_lite_runtime_plugin_test_fuzzer") { + sources = [ "protoc_plugin/test_fuzzer.cc" ] + deps = [ + ":libprotobuf-mutator", + ":override_lite_runtime_plugin_test_fuzzer_proto", + ] + } +} + +# Proto library for override_lite_runtime_plugin_test_fuzzer +fuzzable_proto_library("override_lite_runtime_plugin_test_fuzzer_proto") { + sources = [ + "protoc_plugin/imported.proto", + "protoc_plugin/imported_publicly.proto", + "protoc_plugin/test_fuzzer_input.proto", + ] +} + +# Avoid CQ complaints on platforms we don't care about (ie: iOS). +# Also prevent people from using this to include protobuf_full into a production +# build of Chrome. +if (use_libfuzzer) { + # Component that can provide protobuf_full to non-testonly targets + static_library("protobuf_full") { + public_deps = [ "//third_party/protobuf:protobuf_full" ] + } +} diff --git a/chromium/third_party/openscreen/src/third_party/libprotobuf-mutator/fuzzable_proto_library.gni b/chromium/third_party/openscreen/src/third_party/libprotobuf-mutator/fuzzable_proto_library.gni new file mode 100644 index 00000000000..fee136c6f74 --- /dev/null +++ b/chromium/third_party/openscreen/src/third_party/libprotobuf-mutator/fuzzable_proto_library.gni @@ -0,0 +1,62 @@ +# Copyright 2020 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +# A fuzzable_proto_library is a proto_library that is the same as any other in +# non-fuzzer builds (ie: use_libfuzzer=false). However, in fuzzer builds, the +# proto_library is built with the full protobuf runtime and any "optimize_for = +# LITE_RUNTIME" options are ignored. This is done because libprotobuf-mutator +# needs the full protobuf runtime, but proto_libraries shipped in chrome must +# use the optimize for LITE_RUNTIME option which is incompatible with the full +# protobuf runtime. tl;dr: A fuzzable_proto_library is a proto_library that can +# be fuzzed with libprotobuf-mutator and shipped in Chrome. + +import("//build_overrides/build.gni") +import("//testing/libfuzzer/fuzzer_test.gni") +import("//third_party/protobuf/proto_library.gni") + +template("fuzzable_proto_library") { + # Only make the proto library fuzzable if we are doing a build that we can + # use LPM on (i.e. libFuzzer not on Chrome OS). + if (use_libfuzzer && current_toolchain != "//build/toolchain/cros:target") { + proto_library("proto_library_" + target_name) { + forward_variables_from(invoker, "*") + assert(current_toolchain == host_toolchain) + if (!defined(proto_deps)) { + proto_deps = [] + } + proto_deps += + [ "//third_party/libprotobuf-mutator:override_lite_runtime_plugin" ] + + extra_configs = [ "//third_party/protobuf:protobuf_config" ] + } + + # Inspired by proto_library.gni's handling of + # component_build_force_source_set. + if (defined(component_build_force_source_set) && + component_build_force_source_set && is_component_build) { + link_target_type = "source_set" + } else { + link_target_type = "static_library" + } + + # By making target a static_library or source_set, we can add protobuf_full + # to public_deps. + target(link_target_type, target_name) { + if (defined(invoker.testonly)) { + testonly = invoker.testonly + } + sources = [ "//third_party/libprotobuf-mutator/dummy.cc" ] + public_deps = [ + ":proto_library_" + target_name, + "//third_party/libprotobuf-mutator:protobuf_full", + ] + } + } else { + # fuzzable_proto_library should behave like a proto_library when + # !use_libfuzzer. + proto_library(target_name) { + forward_variables_from(invoker, "*") + } + } +} diff --git a/chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn b/chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn index aa898d129ce..115164806c5 100644 --- a/chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn +++ b/chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn @@ -198,6 +198,7 @@ static_library("protobuf_full") { visibility = [ ":protoc_lib", "../chromium_quic/src/third_party:quic_trace", + "//third_party/libprotobuf-mutator:*", ] } diff --git a/chromium/third_party/openscreen/src/util/BUILD.gn b/chromium/third_party/openscreen/src/util/BUILD.gn index b21003462dc..000b3ea679c 100644 --- a/chromium/third_party/openscreen/src/util/BUILD.gn +++ b/chromium/third_party/openscreen/src/util/BUILD.gn @@ -23,6 +23,7 @@ source_set("util") { "alarm.h", "big_endian.cc", "big_endian.h", + "chrono_helpers.h", "crypto/certificate_utils.cc", "crypto/certificate_utils.h", "crypto/digest_sign.cc", @@ -37,6 +38,7 @@ source_set("util") { "crypto/sha2.h", "hashing.h", "integer_division.h", + "json/json_helpers.h", "json/json_serialization.cc", "json/json_serialization.h", "json/json_value.cc", @@ -86,6 +88,7 @@ source_set("unittests") { "crypto/secure_hash_unittest.cc", "crypto/sha2_unittest.cc", "integer_division_unittest.cc", + "json/json_helpers_unittest.cc", "json/json_serialization_unittest.cc", "json/json_value_unittest.cc", "operation_loop_unittest.cc", diff --git a/chromium/third_party/openscreen/src/util/alarm_unittest.cc b/chromium/third_party/openscreen/src/util/alarm_unittest.cc index 5fad74bf7f4..094afc9c42a 100644 --- a/chromium/third_party/openscreen/src/util/alarm_unittest.cc +++ b/chromium/third_party/openscreen/src/util/alarm_unittest.cc @@ -5,10 +5,12 @@ #include "util/alarm.h" #include <algorithm> +#include <chrono> #include "gtest/gtest.h" #include "platform/test/fake_clock.h" #include "platform/test/fake_task_runner.h" +#include "util/chrono_helpers.h" namespace openscreen { namespace { @@ -26,7 +28,7 @@ class AlarmTest : public testing::Test { }; TEST_F(AlarmTest, RunsTaskAsClockAdvances) { - constexpr Clock::duration kDelay = std::chrono::milliseconds(20); + constexpr Clock::duration kDelay = milliseconds(20); const Clock::time_point alarm_time = FakeClock::now() + kDelay; Clock::time_point actual_run_time{}; @@ -61,12 +63,12 @@ TEST_F(AlarmTest, RunsTaskImmediately) { ASSERT_EQ(expected_run_time, actual_run_time); // Confirm the lambda is only run once. - clock()->Advance(std::chrono::seconds(2)); + clock()->Advance(seconds(2)); ASSERT_EQ(expected_run_time, actual_run_time); } TEST_F(AlarmTest, CancelsTaskWhenGoingOutOfScope) { - constexpr Clock::duration kDelay = std::chrono::milliseconds(20); + constexpr Clock::duration kDelay = milliseconds(20); constexpr Clock::time_point kNever{}; Clock::time_point actual_run_time{}; @@ -85,7 +87,7 @@ TEST_F(AlarmTest, CancelsTaskWhenGoingOutOfScope) { } TEST_F(AlarmTest, Cancels) { - constexpr Clock::duration kDelay = std::chrono::milliseconds(20); + constexpr Clock::duration kDelay = milliseconds(20); const Clock::time_point alarm_time = FakeClock::now() + kDelay; Clock::time_point actual_run_time{}; @@ -104,8 +106,8 @@ TEST_F(AlarmTest, Cancels) { } TEST_F(AlarmTest, CancelsAndRearms) { - constexpr Clock::duration kShorterDelay = std::chrono::milliseconds(10); - constexpr Clock::duration kLongerDelay = std::chrono::milliseconds(100); + constexpr Clock::duration kShorterDelay = milliseconds(10); + constexpr Clock::duration kLongerDelay = milliseconds(100); // Run the test twice: Once when scheduling first with a long delay, then a // shorter delay; and once when scheduling first with a short delay, then a diff --git a/chromium/third_party/openscreen/src/util/big_endian.cc b/chromium/third_party/openscreen/src/util/big_endian.cc index d56589003ae..59992d605d4 100644 --- a/chromium/third_party/openscreen/src/util/big_endian.cc +++ b/chromium/third_party/openscreen/src/util/big_endian.cc @@ -30,4 +30,4 @@ bool BigEndianWriter::Write(const void* buffer, size_t length) { return false; } -} // namespace openscreen
\ No newline at end of file +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/big_endian.h b/chromium/third_party/openscreen/src/util/big_endian.h index b2067d7537e..0953ddf08a6 100644 --- a/chromium/third_party/openscreen/src/util/big_endian.h +++ b/chromium/third_party/openscreen/src/util/big_endian.h @@ -150,7 +150,7 @@ class BigEndianBuffer { public: class Cursor { public: - Cursor(BigEndianBuffer* buffer) + explicit Cursor(BigEndianBuffer* buffer) : buffer_(buffer), origin_(buffer_->current_) {} Cursor(const Cursor& other) = delete; Cursor(Cursor&& other) = delete; diff --git a/chromium/third_party/openscreen/src/util/chrono_helpers.h b/chromium/third_party/openscreen/src/util/chrono_helpers.h new file mode 100644 index 00000000000..a3711bbebb3 --- /dev/null +++ b/chromium/third_party/openscreen/src/util/chrono_helpers.h @@ -0,0 +1,50 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef UTIL_CHRONO_HELPERS_H_ +#define UTIL_CHRONO_HELPERS_H_ + +#include <chrono> + +// This file is a collection of helpful utilities and using statement for +// working with std::chrono. In practice we previously defined these frequently, +// this header allows for a single set of convenience statements. +namespace openscreen { + +using hours = std::chrono::hours; +using microseconds = std::chrono::microseconds; +using milliseconds = std::chrono::milliseconds; +using nanoseconds = std::chrono::nanoseconds; +using seconds = std::chrono::seconds; + +// Casting statements. Note that duration_cast is not a type, it's a function, +// so its behavior is different than the using statements above. +template <typename D> +static constexpr hours to_hours(D d) { + return std::chrono::duration_cast<hours>(d); +} + +template <typename D> +static constexpr microseconds to_microseconds(D d) { + return std::chrono::duration_cast<microseconds>(d); +} + +template <typename D> +static constexpr milliseconds to_milliseconds(D d) { + return std::chrono::duration_cast<milliseconds>(d); +} + +template <typename D> +static constexpr nanoseconds to_nanoseconds(D d) { + return std::chrono::duration_cast<nanoseconds>(d); +} + +template <typename D> +static constexpr seconds to_seconds(D d) { + return std::chrono::duration_cast<seconds>(d); +} + +} // namespace openscreen + +#endif // UTIL_CHRONO_HELPERS_H_ diff --git a/chromium/third_party/openscreen/src/util/crypto/certificate_utils.cc b/chromium/third_party/openscreen/src/util/crypto/certificate_utils.cc index ddedb533bb8..8f18b7ea9d8 100644 --- a/chromium/third_party/openscreen/src/util/crypto/certificate_utils.cc +++ b/chromium/third_party/openscreen/src/util/crypto/certificate_utils.cc @@ -146,19 +146,6 @@ ErrorOr<bssl::UniquePtr<X509>> CreateSelfSignedX509Certificate( absl::string_view name, std::chrono::seconds duration, const EVP_PKEY& key_pair, - std::chrono::seconds time_since_unix_epoch) { - bssl::UniquePtr<X509> certificate = CreateCertificateInternal( - name, duration, key_pair, time_since_unix_epoch, false, nullptr, nullptr); - if (!certificate) { - return Error::Code::kCertificateCreationError; - } - return certificate; -} - -ErrorOr<bssl::UniquePtr<X509>> CreateSelfSignedX509CertificateForTest( - absl::string_view name, - std::chrono::seconds duration, - const EVP_PKEY& key_pair, std::chrono::seconds time_since_unix_epoch, bool make_ca, X509* issuer, diff --git a/chromium/third_party/openscreen/src/util/crypto/certificate_utils.h b/chromium/third_party/openscreen/src/util/crypto/certificate_utils.h index e60c28c3d26..22da0330a0b 100644 --- a/chromium/third_party/openscreen/src/util/crypto/certificate_utils.h +++ b/chromium/third_party/openscreen/src/util/crypto/certificate_utils.h @@ -23,22 +23,12 @@ namespace openscreen { // Generates a new RSA key pair with bit width |key_bits|. bssl::UniquePtr<EVP_PKEY> GenerateRsaKeyPair(int key_bits = 2048); -// Creates a new self-signed X509 certificate having the given |name| and -// |duration| until expiration, and based on the given |key_pair|, which is -// expected to contain a valid private key. -// |time_since_unix_epoch| is the current time. -ErrorOr<bssl::UniquePtr<X509>> CreateSelfSignedX509Certificate( - absl::string_view name, - std::chrono::seconds duration, - const EVP_PKEY& key_pair, - std::chrono::seconds time_since_unix_epoch = GetWallTimeSinceUnixEpoch()); - // Creates a new X509 certificate having the given |name| and |duration| until // expiration, and based on the given |key_pair|. If |issuer| and |issuer_key| // are provided, they are used to set the issuer information, otherwise it will // be self-signed. |make_ca| determines whether additional extensions are added // to make it a valid certificate authority cert. -ErrorOr<bssl::UniquePtr<X509>> CreateSelfSignedX509CertificateForTest( +ErrorOr<bssl::UniquePtr<X509>> CreateSelfSignedX509Certificate( absl::string_view name, std::chrono::seconds duration, const EVP_PKEY& key_pair, diff --git a/chromium/third_party/openscreen/src/util/crypto/secure_hash.h b/chromium/third_party/openscreen/src/util/crypto/secure_hash.h index 7c007f96350..f748cc458b5 100644 --- a/chromium/third_party/openscreen/src/util/crypto/secure_hash.h +++ b/chromium/third_party/openscreen/src/util/crypto/secure_hash.h @@ -21,7 +21,7 @@ namespace openscreen { // same as if we have the full input in advance. class SecureHash { public: - SecureHash(const EVP_MD* type); + explicit SecureHash(const EVP_MD* type); SecureHash(const SecureHash& other); SecureHash(SecureHash&& other); SecureHash& operator=(const SecureHash& other); diff --git a/chromium/third_party/openscreen/src/util/json/json_helpers.h b/chromium/third_party/openscreen/src/util/json/json_helpers.h new file mode 100644 index 00000000000..a4c43479edc --- /dev/null +++ b/chromium/third_party/openscreen/src/util/json/json_helpers.h @@ -0,0 +1,209 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef UTIL_JSON_JSON_HELPERS_H_ +#define UTIL_JSON_JSON_HELPERS_H_ + +#include <chrono> +#include <functional> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "json/value.h" +#include "platform/base/error.h" +#include "util/chrono_helpers.h" +#include "util/simple_fraction.h" + +// This file contains helper methods for parsing JSON, in an attempt to +// reduce boilerplate code when working with JsonCpp. +namespace openscreen { +namespace json { + +// TODO(jophba): remove these methods after refactoring offer messaging. +inline Error CreateParseError(const std::string& type) { + return Error(Error::Code::kJsonParseError, "Failed to parse " + type); +} + +inline Error CreateParameterError(const std::string& type) { + return Error(Error::Code::kParameterInvalid, "Invalid parameter: " + type); +} + +inline ErrorOr<bool> ParseBool(const Json::Value& parent, + const std::string& field) { + const Json::Value& value = parent[field]; + if (!value.isBool()) { + return CreateParseError("bool field " + field); + } + return value.asBool(); +} + +inline ErrorOr<int> ParseInt(const Json::Value& parent, + const std::string& field) { + const Json::Value& value = parent[field]; + if (!value.isInt()) { + return CreateParseError("integer field: " + field); + } + return value.asInt(); +} + +inline ErrorOr<uint32_t> ParseUint(const Json::Value& parent, + const std::string& field) { + const Json::Value& value = parent[field]; + if (!value.isUInt()) { + return CreateParseError("unsigned integer field: " + field); + } + return value.asUInt(); +} + +inline ErrorOr<std::string> ParseString(const Json::Value& parent, + const std::string& field) { + const Json::Value& value = parent[field]; + if (!value.isString()) { + return CreateParseError("string field: " + field); + } + return value.asString(); +} + +// TODO(jophba): offer messaging should use these methods instead. +inline bool ParseBool(const Json::Value& value, bool* out) { + if (!value.isBool()) { + return false; + } + *out = value.asBool(); + return true; +} + +// A general note about parsing primitives. "Validation" in this context +// generally means ensuring that the values are non-negative. There are +// currently no cases in our usage of JSON strings where we accept negative +// values. If this changes in the future, care must be taken to ensure +// that we don't break anything in existing code. +inline bool ParseAndValidateDouble(const Json::Value& value, double* out) { + if (!value.isDouble()) { + return false; + } + const double d = value.asDouble(); + if (d < 0) { + return false; + } + *out = d; + return true; +} + +inline bool ParseAndValidateInt(const Json::Value& value, int* out) { + if (!value.isInt()) { + return false; + } + int i = value.asInt(); + if (i < 0) { + return false; + } + *out = i; + return true; +} + +inline bool ParseAndValidateUint(const Json::Value& value, uint32_t* out) { + if (!value.isUInt()) { + return false; + } + *out = value.asUInt(); + return true; +} + +inline bool ParseAndValidateString(const Json::Value& value, std::string* out) { + if (!value.isString()) { + return false; + } + *out = value.asString(); + return true; +} + +// We want to be more robust when we parse fractions then just +// allowing strings, this will parse numeral values such as +// value: 50 as well as value: "50" and value: "100/2". +inline bool ParseAndValidateSimpleFraction(const Json::Value& value, + SimpleFraction* out) { + if (value.isInt()) { + int parsed = value.asInt(); + if (parsed < 0) { + return false; + } + *out = SimpleFraction{parsed, 1}; + return true; + } + + if (value.isString()) { + auto fraction_or_error = SimpleFraction::FromString(value.asString()); + if (!fraction_or_error) { + return false; + } + + if (!fraction_or_error.value().is_positive() || + !fraction_or_error.value().is_defined()) { + return false; + } + *out = std::move(fraction_or_error.value()); + return true; + } + return false; +} + +inline bool ParseAndValidateMilliseconds(const Json::Value& value, + milliseconds* out) { + int out_ms; + if (!ParseAndValidateInt(value, &out_ms) || out_ms < 0) { + return false; + } + *out = milliseconds(out_ms); + return true; +} + +template <typename T> +using Parser = std::function<bool(const Json::Value&, T*)>; + +// NOTE: array parsing methods reset the output vector to an empty vector in +// any error case. This is especially useful for optional arrays. +template <typename T> +bool ParseAndValidateArray(const Json::Value& value, + Parser<T> parser, + std::vector<T>* out) { + out->clear(); + if (!value.isArray() || value.empty()) { + return false; + } + + out->reserve(value.size()); + for (Json::ArrayIndex i = 0; i < value.size(); ++i) { + T v; + if (!parser(value[i], &v)) { + out->clear(); + return false; + } + out->push_back(v); + } + + return true; +} + +inline bool ParseAndValidateIntArray(const Json::Value& value, + std::vector<int>* out) { + return ParseAndValidateArray<int>(value, ParseAndValidateInt, out); +} + +inline bool ParseAndValidateUintArray(const Json::Value& value, + std::vector<uint32_t>* out) { + return ParseAndValidateArray<uint32_t>(value, ParseAndValidateUint, out); +} + +inline bool ParseAndValidateStringArray(const Json::Value& value, + std::vector<std::string>* out) { + return ParseAndValidateArray<std::string>(value, ParseAndValidateString, out); +} + +} // namespace json +} // namespace openscreen + +#endif // UTIL_JSON_JSON_HELPERS_H_ diff --git a/chromium/third_party/openscreen/src/util/json/json_helpers_unittest.cc b/chromium/third_party/openscreen/src/util/json/json_helpers_unittest.cc new file mode 100644 index 00000000000..fdac1897e4b --- /dev/null +++ b/chromium/third_party/openscreen/src/util/json/json_helpers_unittest.cc @@ -0,0 +1,209 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "util/json/json_helpers.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "util/chrono_helpers.h" + +namespace openscreen { +namespace json { +namespace { + +using ::testing::ElementsAre; + +const Json::Value kNone; +const Json::Value kEmptyString = ""; +const Json::Value kEmptyArray(Json::arrayValue); + +struct Dummy { + int value; + + constexpr bool operator==(const Dummy& other) const { + return other.value == value; + } +}; + +bool ParseAndValidateDummy(const Json::Value& value, Dummy* out) { + int value_out; + if (!ParseAndValidateInt(value, &value_out)) { + return false; + } + *out = Dummy{value_out}; + return true; +} + +} // namespace + +TEST(ParsingHelpersTest, ParseAndValidateDouble) { + const Json::Value kValid = 13.37; + const Json::Value kNotDouble = "coffee beans"; + const Json::Value kNegativeDouble = -4.2; + const Json::Value kZeroDouble = 0.0; + + double out; + EXPECT_TRUE(ParseAndValidateDouble(kValid, &out)); + EXPECT_DOUBLE_EQ(13.37, out); + EXPECT_TRUE(ParseAndValidateDouble(kZeroDouble, &out)); + EXPECT_DOUBLE_EQ(0.0, out); + EXPECT_FALSE(ParseAndValidateDouble(kNotDouble, &out)); + EXPECT_FALSE(ParseAndValidateDouble(kNegativeDouble, &out)); + EXPECT_FALSE(ParseAndValidateDouble(kNone, &out)); +} + +TEST(ParsingHelpersTest, ParseAndValidateInt) { + const Json::Value kValid = 1337; + const Json::Value kNotInt = "cold brew"; + const Json::Value kNegativeInt = -42; + const Json::Value kZeroInt = 0; + + int out; + EXPECT_TRUE(ParseAndValidateInt(kValid, &out)); + EXPECT_EQ(1337, out); + EXPECT_TRUE(ParseAndValidateInt(kZeroInt, &out)); + EXPECT_EQ(0, out); + EXPECT_FALSE(ParseAndValidateInt(kNone, &out)); + EXPECT_FALSE(ParseAndValidateInt(kNotInt, &out)); + EXPECT_FALSE(ParseAndValidateInt(kNegativeInt, &out)); +} + +TEST(ParsingHelpersTest, ParseAndValidateUint) { + const Json::Value kValid = 1337u; + const Json::Value kNotUint = "espresso"; + const Json::Value kZeroUint = 0u; + + uint32_t out; + EXPECT_TRUE(ParseAndValidateUint(kValid, &out)); + EXPECT_EQ(1337u, out); + EXPECT_TRUE(ParseAndValidateUint(kZeroUint, &out)); + EXPECT_EQ(0u, out); + EXPECT_FALSE(ParseAndValidateUint(kNone, &out)); + EXPECT_FALSE(ParseAndValidateUint(kNotUint, &out)); +} + +TEST(ParsingHelpersTest, ParseAndValidateString) { + const Json::Value kValid = "macchiato"; + const Json::Value kNotString = 42; + + std::string out; + EXPECT_TRUE(ParseAndValidateString(kValid, &out)); + EXPECT_EQ("macchiato", out); + EXPECT_TRUE(ParseAndValidateString(kEmptyString, &out)); + EXPECT_EQ("", out); + EXPECT_FALSE(ParseAndValidateString(kNone, &out)); + EXPECT_FALSE(ParseAndValidateString(kNotString, &out)); +} + +// Simple fraction validity is tested extensively in its unit tests, so we +// just check the major cases here. +TEST(ParsingHelpersTest, ParseAndValidateSimpleFraction) { + const Json::Value kValid = "42/30"; + const Json::Value kValidNumber = "42"; + const Json::Value kUndefined = "5/0"; + const Json::Value kNegative = "10/-2"; + const Json::Value kInvalidNumber = "-1"; + const Json::Value kNotSimpleFraction = "latte"; + + SimpleFraction out; + EXPECT_TRUE(ParseAndValidateSimpleFraction(kValid, &out)); + EXPECT_EQ((SimpleFraction{42, 30}), out); + EXPECT_TRUE(ParseAndValidateSimpleFraction(kValidNumber, &out)); + EXPECT_EQ((SimpleFraction{42, 1}), out); + EXPECT_FALSE(ParseAndValidateSimpleFraction(kUndefined, &out)); + EXPECT_FALSE(ParseAndValidateSimpleFraction(kNegative, &out)); + EXPECT_FALSE(ParseAndValidateSimpleFraction(kInvalidNumber, &out)); + EXPECT_FALSE(ParseAndValidateSimpleFraction(kNotSimpleFraction, &out)); + EXPECT_FALSE(ParseAndValidateSimpleFraction(kNone, &out)); + EXPECT_FALSE(ParseAndValidateSimpleFraction(kEmptyString, &out)); +} + +TEST(ParsingHelpersTest, ParseAndValidateMilliseconds) { + const Json::Value kValid = 1000; + const Json::Value kValidFloat = 500.0; + const Json::Value kNegativeNumber = -120; + const Json::Value kZeroNumber = 0; + const Json::Value kNotNumber = "affogato"; + + milliseconds out; + EXPECT_TRUE(ParseAndValidateMilliseconds(kValid, &out)); + EXPECT_EQ(milliseconds(1000), out); + EXPECT_TRUE(ParseAndValidateMilliseconds(kValidFloat, &out)); + EXPECT_EQ(milliseconds(500), out); + EXPECT_TRUE(ParseAndValidateMilliseconds(kZeroNumber, &out)); + EXPECT_EQ(milliseconds(0), out); + EXPECT_FALSE(ParseAndValidateMilliseconds(kNone, &out)); + EXPECT_FALSE(ParseAndValidateMilliseconds(kNegativeNumber, &out)); + EXPECT_FALSE(ParseAndValidateMilliseconds(kNotNumber, &out)); +} + +TEST(ParsingHelpersTest, ParseAndValidateArray) { + Json::Value valid_dummy_array; + valid_dummy_array[0] = 123; + valid_dummy_array[1] = 456; + + Json::Value invalid_dummy_array; + invalid_dummy_array[0] = "iced coffee"; + invalid_dummy_array[1] = 456; + + std::vector<Dummy> out; + EXPECT_TRUE(ParseAndValidateArray<Dummy>(valid_dummy_array, + ParseAndValidateDummy, &out)); + EXPECT_THAT(out, ElementsAre(Dummy{123}, Dummy{456})); + EXPECT_FALSE(ParseAndValidateArray<Dummy>(invalid_dummy_array, + ParseAndValidateDummy, &out)); + EXPECT_FALSE( + ParseAndValidateArray<Dummy>(kEmptyArray, ParseAndValidateDummy, &out)); +} + +TEST(ParsingHelpersTest, ParseAndValidateIntArray) { + Json::Value valid_int_array; + valid_int_array[0] = 123; + valid_int_array[1] = 456; + + Json::Value invalid_int_array; + invalid_int_array[0] = "iced coffee"; + invalid_int_array[1] = 456; + + std::vector<int> out; + EXPECT_TRUE(ParseAndValidateIntArray(valid_int_array, &out)); + EXPECT_THAT(out, ElementsAre(123, 456)); + EXPECT_FALSE(ParseAndValidateIntArray(invalid_int_array, &out)); + EXPECT_FALSE(ParseAndValidateIntArray(kEmptyArray, &out)); +} + +TEST(ParsingHelpersTest, ParseAndValidateUintArray) { + Json::Value valid_uint_array; + valid_uint_array[0] = 123u; + valid_uint_array[1] = 456u; + + Json::Value invalid_uint_array; + invalid_uint_array[0] = "breve"; + invalid_uint_array[1] = 456u; + + std::vector<uint32_t> out; + EXPECT_TRUE(ParseAndValidateUintArray(valid_uint_array, &out)); + EXPECT_THAT(out, ElementsAre(123u, 456u)); + EXPECT_FALSE(ParseAndValidateUintArray(invalid_uint_array, &out)); + EXPECT_FALSE(ParseAndValidateUintArray(kEmptyArray, &out)); +} + +TEST(ParsingHelpersTest, ParseAndValidateStringArray) { + Json::Value valid_string_array; + valid_string_array[0] = "nitro cold brew"; + valid_string_array[1] = "doppio espresso"; + + Json::Value invalid_string_array; + invalid_string_array[0] = "mocha latte"; + invalid_string_array[1] = 456; + + std::vector<std::string> out; + EXPECT_TRUE(ParseAndValidateStringArray(valid_string_array, &out)); + EXPECT_THAT(out, ElementsAre("nitro cold brew", "doppio espresso")); + EXPECT_FALSE(ParseAndValidateStringArray(invalid_string_array, &out)); + EXPECT_FALSE(ParseAndValidateStringArray(kEmptyArray, &out)); +} + +} // namespace json +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/util/simple_fraction_unittest.cc b/chromium/third_party/openscreen/src/util/simple_fraction_unittest.cc index 7cdbfeeccad..49bf3987df1 100644 --- a/chromium/third_party/openscreen/src/util/simple_fraction_unittest.cc +++ b/chromium/third_party/openscreen/src/util/simple_fraction_unittest.cc @@ -4,6 +4,7 @@ #include "util/simple_fraction.h" +#include <cmath> #include <limits> #include "gtest/gtest.h" @@ -15,16 +16,15 @@ namespace { constexpr int kMin = std::numeric_limits<int>::min(); constexpr int kMax = std::numeric_limits<int>::max(); -void ExpectFromStringEquals(absl::string_view s, - const SimpleFraction& expected) { - const ErrorOr<SimpleFraction> f = SimpleFraction::FromString(s); - EXPECT_TRUE(f.is_value()); +void ExpectFromStringEquals(const char* s, const SimpleFraction& expected) { + const ErrorOr<SimpleFraction> f = SimpleFraction::FromString(std::string(s)); + EXPECT_TRUE(f.is_value()) << "from string: '" << s << "'"; EXPECT_EQ(expected, f.value()); } -void ExpectFromStringError(absl::string_view s) { - const auto f = SimpleFraction::FromString(s); - EXPECT_TRUE(f.is_error()); +void ExpectFromStringError(const char* s) { + const auto f = SimpleFraction::FromString(std::string(s)); + EXPECT_TRUE(f.is_error()) << "from string: '" << s << "'"; } } // namespace @@ -46,6 +46,7 @@ TEST(SimpleFractionTest, FromStringErrorsOnInvalid) { ExpectFromStringError("1/"); ExpectFromStringError("/1"); ExpectFromStringError("888/"); + ExpectFromStringError("1/2/3"); ExpectFromStringError("not a fraction at all"); } @@ -91,6 +92,7 @@ TEST(SimpleFractionTest, Positivity) { TEST(SimpleFractionTest, CastToDouble) { EXPECT_DOUBLE_EQ(0.0, static_cast<double>(SimpleFraction{0, 1})); EXPECT_DOUBLE_EQ(1.0, static_cast<double>(SimpleFraction{1, 1})); + EXPECT_TRUE(std::isnan(static_cast<double>(SimpleFraction{1, 0}))); EXPECT_DOUBLE_EQ(1.0, static_cast<double>(SimpleFraction{kMax, kMax})); EXPECT_DOUBLE_EQ(1.0, static_cast<double>(SimpleFraction{kMin, kMin})); } diff --git a/chromium/third_party/openscreen/src/util/std_util.h b/chromium/third_party/openscreen/src/util/std_util.h index 5a77bb5513a..87726955f94 100644 --- a/chromium/third_party/openscreen/src/util/std_util.h +++ b/chromium/third_party/openscreen/src/util/std_util.h @@ -5,8 +5,10 @@ #ifndef UTIL_STD_UTIL_H_ #define UTIL_STD_UTIL_H_ +#include <algorithm> #include <map> #include <string> +#include <utility> #include <vector> #include "absl/algorithm/container.h" @@ -53,6 +55,34 @@ void SortAndDedupeElements(RandomAccessContainer* c) { c->erase(new_end, c->end()); } +// Append the provided elements together into a single vector. This can be +// useful when creating a vector of variadic templates in the ctor. +// +// This is the base case for the recursion +template <typename T> +std::vector<T>&& Append(std::vector<T>&& so_far) { + return std::move(so_far); +} + +// This is the recursive call. Depending on the number of remaining elements, it +// either calls into itself or into the above base case. +template <typename T, typename TFirst, typename... TOthers> +std::vector<T>&& Append(std::vector<T>&& so_far, + TFirst&& new_element, + TOthers&&... new_elements) { + so_far.push_back(std::move(new_element)); + return Append(std::move(so_far), std::move(new_elements)...); +} + +// Creates an empty vector with |size| elements reserved. Intended to be used as +// GetEmptyVectorOfSize<T>(sizeof...(variadic_input)) +template <typename T> +std::vector<T> GetVectorWithCapacity(size_t size) { + std::vector<T> results; + results.reserve(size); + return results; +} + } // namespace openscreen #endif // UTIL_STD_UTIL_H_ diff --git a/chromium/third_party/openscreen/src/util/weak_ptr.h b/chromium/third_party/openscreen/src/util/weak_ptr.h index cb6967a1d59..c941017d511 100644 --- a/chromium/third_party/openscreen/src/util/weak_ptr.h +++ b/chromium/third_party/openscreen/src/util/weak_ptr.h @@ -6,6 +6,7 @@ #define UTIL_WEAK_PTR_H_ #include <memory> +#include <utility> #include "util/osp_logging.h" @@ -92,7 +93,7 @@ class WeakPtr { } // Create/Assign from nullptr. - WeakPtr(std::nullptr_t) {} + WeakPtr(std::nullptr_t) {} // NOLINT WeakPtr& operator=(std::nullptr_t) { impl_.reset(); |