From 98c4f115782ed1bffac9910ed75bb64e630fba0e Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Mon, 22 Jul 2024 08:40:27 +0200 Subject: [PATCH] feat(core): implement a new Trezor-Host protocol --- ci/build.yml | 1 + common/protob/messages-common.proto | 2 + common/protob/messages-debug.proto | 5 + common/protob/messages-thp.proto | 183 + common/protob/messages.proto | 26 + common/protob/pb2py | 2 + core/Makefile | 10 +- core/SConscript.firmware | 19 +- core/SConscript.unix | 4 +- .../modtrezorcrypto/modtrezorcrypto-aesgcm.h | 12 +- .../extmod/modtrezorutils/modtrezorutils.c | 2 +- core/embed/rust/src/protobuf/obj.rs | 2 +- .../mocks/generated/trezorcrypto/__init__.pyi | 12 +- core/mocks/generated/trezorproto.pyi | 2 +- core/mocks/generated/trezorui2.pyi | 1 - core/mocks/generated/trezorutils.pyi | 2 +- core/src/all_modules.py | 54 + core/src/apps/base.py | 63 +- .../apps/bitcoin/sign_tx/payment_request.py | 8 +- core/src/apps/cardano/seed.py | 31 +- core/src/apps/common/authorization.py | 28 +- core/src/apps/common/backup.py | 14 +- core/src/apps/common/cache.py | 23 + core/src/apps/common/passphrase.py | 88 +- core/src/apps/common/request_pin.py | 7 +- core/src/apps/common/safety_checks.py | 12 +- core/src/apps/common/seed.py | 107 +- core/src/apps/debug/__init__.py | 57 +- core/src/apps/management/get_nonce.py | 5 +- .../apps/management/reboot_to_bootloader.py | 2 +- .../management/recovery_device/homescreen.py | 5 +- core/src/apps/management/wipe_device.py | 22 +- core/src/apps/monero/live_refresh.py | 5 +- core/src/apps/thp/create_new_session.py | 51 + core/src/apps/thp/credential_manager.py | 6 +- core/src/apps/thp/pairing.py | 412 ++ core/src/apps/webauthn/fido2.py | 2 +- core/src/apps/workflow_handlers.py | 9 +- core/src/boot.py | 5 +- core/src/storage/cache.py | 358 +- core/src/storage/cache_codec.py | 142 + core/src/storage/cache_common.py | 179 + core/src/storage/cache_thp.py | 334 ++ core/src/trezor/enums/FailureType.py | 2 + core/src/trezor/enums/MessageType.py | 18 + core/src/trezor/enums/ThpPairingMethod.py | 8 + core/src/trezor/enums/__init__.py | 26 + core/src/trezor/log.py | 2 +- core/src/trezor/messages.py | 281 ++ core/src/trezor/ui/__init__.py | 2 +- core/src/trezor/ui/layouts/common.py | 2 +- core/src/trezor/ui/layouts/homescreen.py | 6 +- .../src/trezor/ui/layouts/mercury/__init__.py | 4 +- core/src/trezor/ui/layouts/tr/__init__.py | 2 +- core/src/trezor/ui/layouts/tr/reset.py | 5 +- core/src/trezor/utils.py | 14 +- core/src/trezor/wire/__init__.py | 275 +- core/src/trezor/wire/codec_v1.py | 9 +- core/src/trezor/wire/context.py | 121 +- core/src/trezor/wire/errors.py | 6 + core/src/trezor/wire/message_handler.py | 254 ++ core/src/trezor/wire/protocol_common.py | 73 + core/src/trezor/wire/thp/__init__.py | 80 + .../wire/thp/alternating_bit_protocol.py | 102 + core/src/trezor/wire/thp/channel.py | 402 ++ core/src/trezor/wire/thp/channel_manager.py | 34 + core/src/trezor/wire/thp/checksum.py | 22 + core/src/trezor/wire/thp/control_byte.py | 47 + core/src/trezor/wire/thp/cpace.py | 36 + core/src/trezor/wire/thp/crypto.py | 211 + core/src/trezor/wire/thp/interface_manager.py | 30 + core/src/trezor/wire/thp/memory_manager.py | 174 + core/src/trezor/wire/thp/pairing_context.py | 251 ++ .../wire/thp/received_message_handler.py | 413 ++ core/src/trezor/wire/thp/session_context.py | 198 + core/src/trezor/wire/thp/session_manager.py | 37 + core/src/trezor/wire/thp/thp_messages.py | 136 + core/src/trezor/wire/thp/transmission_loop.py | 51 + core/src/trezor/wire/thp/writer.py | 91 + core/src/trezor/wire/thp_main.py | 176 + core/src/trezor/workflow.py | 10 +- core/tests/mock_wire_interface.py | 17 + core/tests/myTests.sh | 42 + core/tests/test_apps.bitcoin.approver.py | 24 +- core/tests/test_apps.bitcoin.authorization.py | 24 +- core/tests/test_apps.bitcoin.keychain.py | 52 +- core/tests/test_apps.common.keychain.py | 36 +- core/tests/test_apps.ethereum.keychain.py | 36 +- core/tests/test_storage.cache.py | 709 +++- core/tests/test_trezor.wire.codec_v1.py | 19 +- core/tests/test_trezor.wire.thp.checksum.py | 94 + ...test_trezor.wire.thp.credential_manager.py | 66 + core/tests/test_trezor.wire.thp.crypto.py | 156 + core/tests/test_trezor.wire.thp.py | 370 ++ core/tests/test_trezor.wire.thp.writer.py | 150 + core/tests/test_trezor.wire.thp_deprecated.py | 338 ++ core/tests/thp_common.py | 43 + core/tools/codegen/get_trezor_keys.py | 2 +- docs/ci/jobs.md | 34 +- legacy/firmware/fsm.c | 3 + legacy/firmware/protob/Makefile | 2 +- legacy/firmware/protob/messages-thp.proto | 1 + python/channel_data.json | 1 + python/src/trezorlib/_internal/emulator.py | 2 +- python/src/trezorlib/authentication.py | 6 +- python/src/trezorlib/binance.py | 19 +- python/src/trezorlib/btc.py | 55 +- python/src/trezorlib/cardano.py | 36 +- python/src/trezorlib/cli/__init__.py | 320 +- python/src/trezorlib/cli/binance.py | 22 +- python/src/trezorlib/cli/btc.py | 49 +- python/src/trezorlib/cli/cardano.py | 32 +- python/src/trezorlib/cli/crypto.py | 22 +- python/src/trezorlib/cli/debug.py | 97 +- python/src/trezorlib/cli/device.py | 95 +- python/src/trezorlib/cli/eos.py | 16 +- python/src/trezorlib/cli/ethereum.py | 50 +- python/src/trezorlib/cli/fido.py | 34 +- python/src/trezorlib/cli/firmware.py | 57 +- python/src/trezorlib/cli/monero.py | 16 +- python/src/trezorlib/cli/nem.py | 16 +- python/src/trezorlib/cli/ripple.py | 16 +- python/src/trezorlib/cli/settings.py | 129 +- python/src/trezorlib/cli/solana.py | 22 +- python/src/trezorlib/cli/stellar.py | 16 +- python/src/trezorlib/cli/tezos.py | 22 +- python/src/trezorlib/cli/trezorctl.py | 71 +- python/src/trezorlib/client.py | 1037 ++--- python/src/trezorlib/debuglink.py | 534 ++- python/src/trezorlib/device.py | 141 +- python/src/trezorlib/eos.py | 15 +- python/src/trezorlib/ethereum.py | 48 +- python/src/trezorlib/fido.py | 22 +- python/src/trezorlib/firmware/__init__.py | 19 +- python/src/trezorlib/mapping.py | 14 +- python/src/trezorlib/messages.py | 313 ++ python/src/trezorlib/misc.py | 26 +- python/src/trezorlib/monero.py | 10 +- python/src/trezorlib/nem.py | 10 +- python/src/trezorlib/ripple.py | 10 +- python/src/trezorlib/solana.py | 14 +- python/src/trezorlib/stellar.py | 12 +- python/src/trezorlib/tezos.py | 14 +- python/src/trezorlib/tools.py | 19 +- python/src/trezorlib/transport/__init__.py | 98 +- python/src/trezorlib/transport/bridge.py | 108 +- python/src/trezorlib/transport/hid.py | 118 +- .../src/trezorlib/transport/new/__init__.py | 0 .../transport/new/alternating_bit_protocol.py | 102 + .../trezorlib/transport/new/channel_data.py | 40 + .../transport/new/channel_database.py | 95 + .../trezorlib/transport/new/control_byte.py | 59 + .../transport/new/protocol_and_channel.py | 32 + .../trezorlib/transport/new/protocol_v1.py | 94 + .../trezorlib/transport/new/protocol_v2.py | 380 ++ python/src/trezorlib/transport/new/session.py | 0 python/src/trezorlib/transport/protocol.py | 165 - python/src/trezorlib/transport/session.py | 158 + .../src/trezorlib/transport/thp/checksum.py | 19 + .../src/trezorlib/transport/thp/curve25519.py | 116 + .../trezorlib/transport/thp/packet_header.py | 82 + python/src/trezorlib/transport/thp/thp_io.py | 86 + python/src/trezorlib/transport/udp.py | 95 +- python/src/trezorlib/transport/webusb.py | 141 +- python/tools/encfs_aes_getpass.py | 14 +- python/tools/helloworld.py | 5 +- python/tools/pwd_reader.py | 18 +- python/tools/pybridge.py | 19 +- python/tools/rng_entropy_collector.py | 9 +- python/tools/trezor-otp.py | 13 +- rust/trezor-client/src/messages/generated.rs | 18 + .../src/protos/generated/messages.rs | 197 +- .../src/protos/generated/messages_common.rs | 95 +- .../src/protos/generated/messages_debug.rs | 266 +- .../src/protos/generated/messages_thp.rs | 3413 ++++++++++++++++- tests/click_tests/record_layout.py | 4 +- tests/common.py | 10 +- tests/conftest.py | 67 +- .../device_tests/binance/test_get_address.py | 12 +- .../binance/test_get_public_key.py | 8 +- tests/device_tests/binance/test_sign_tx.py | 6 +- tests/device_tests/bitcoin/payment_req.py | 12 +- .../bitcoin/test_authorize_coinjoin.py | 195 +- tests/device_tests/bitcoin/test_bcash.py | 88 +- tests/device_tests/bitcoin/test_bgold.py | 110 +- tests/device_tests/bitcoin/test_dash.py | 24 +- tests/device_tests/bitcoin/test_decred.py | 52 +- .../device_tests/bitcoin/test_descriptors.py | 16 +- tests/device_tests/bitcoin/test_firo.py | 6 +- tests/device_tests/bitcoin/test_fujicoin.py | 6 +- tests/device_tests/bitcoin/test_getaddress.py | 146 +- .../bitcoin/test_getaddress_segwit.py | 42 +- .../bitcoin/test_getaddress_segwit_native.py | 26 +- .../bitcoin/test_getaddress_show.py | 59 +- .../bitcoin/test_getownershipproof.py | 38 +- .../device_tests/bitcoin/test_getpublickey.py | 26 +- .../bitcoin/test_getpublickey_curve.py | 14 +- tests/device_tests/bitcoin/test_grs.py | 30 +- tests/device_tests/bitcoin/test_komodo.py | 24 +- tests/device_tests/bitcoin/test_multisig.py | 52 +- .../bitcoin/test_multisig_change.py | 80 +- .../bitcoin/test_nonstandard_paths.py | 56 +- tests/device_tests/bitcoin/test_op_return.py | 28 +- tests/device_tests/bitcoin/test_peercoin.py | 18 +- .../device_tests/bitcoin/test_signmessage.py | 38 +- tests/device_tests/bitcoin/test_signtx.py | 270 +- .../bitcoin/test_signtx_amount_unit.py | 14 +- .../bitcoin/test_signtx_external.py | 144 +- .../bitcoin/test_signtx_invalid_path.py | 46 +- .../bitcoin/test_signtx_mixed_inputs.py | 26 +- .../bitcoin/test_signtx_payreq.py | 59 +- .../bitcoin/test_signtx_prevhash.py | 32 +- .../bitcoin/test_signtx_replacement.py | 90 +- .../bitcoin/test_signtx_segwit.py | 94 +- .../bitcoin/test_signtx_segwit_native.py | 160 +- .../bitcoin/test_signtx_taproot.py | 66 +- .../bitcoin/test_verifymessage.py | 48 +- .../bitcoin/test_verifymessage_segwit.py | 26 +- .../test_verifymessage_segwit_native.py | 26 +- tests/device_tests/bitcoin/test_zcash.py | 38 +- .../cardano/test_address_public_key.py | 18 +- .../device_tests/cardano/test_derivations.py | 30 +- .../cardano/test_get_native_script_hash.py | 8 +- tests/device_tests/cardano/test_sign_tx.py | 25 +- tests/device_tests/eos/test_get_public_key.py | 12 +- tests/device_tests/eos/test_signtx.py | 98 +- .../device_tests/ethereum/test_definitions.py | 84 +- .../ethereum/test_definitions_bad.py | 64 +- .../device_tests/ethereum/test_getaddress.py | 12 +- .../ethereum/test_getpublickey.py | 14 +- .../ethereum/test_sign_typed_data.py | 26 +- .../ethereum/test_sign_verify_message.py | 20 +- tests/device_tests/ethereum/test_signtx.py | 99 +- .../misc/test_msg_cipherkeyvalue.py | 42 +- .../misc/test_msg_getecdhsessionkey.py | 10 +- .../device_tests/misc/test_msg_getentropy.py | 10 +- .../misc/test_msg_signidentity.py | 16 +- tests/device_tests/monero/test_getaddress.py | 12 +- tests/device_tests/monero/test_getwatchkey.py | 10 +- tests/device_tests/nem/test_getaddress.py | 8 +- tests/device_tests/nem/test_signtx_mosaics.py | 18 +- .../device_tests/nem/test_signtx_multisig.py | 18 +- tests/device_tests/nem/test_signtx_others.py | 12 +- .../device_tests/nem/test_signtx_transfers.py | 42 +- .../test_recovery_bip39_dryrun.py | 48 +- .../reset_recovery/test_recovery_bip39_t1.py | 109 +- .../reset_recovery/test_recovery_bip39_t2.py | 40 +- .../test_recovery_slip39_advanced.py | 62 +- .../test_recovery_slip39_advanced_dryrun.py | 14 +- .../test_recovery_slip39_basic.py | 108 +- .../test_recovery_slip39_basic_dryrun.py | 14 +- .../reset_recovery/test_reset_backup.py | 84 +- .../test_reset_bip39_skipbackup.py | 58 +- .../reset_recovery/test_reset_bip39_t1.py | 103 +- .../reset_recovery/test_reset_bip39_t2.py | 76 +- .../test_reset_recovery_bip39.py | 38 +- .../test_reset_recovery_slip39_advanced.py | 42 +- .../test_reset_recovery_slip39_basic.py | 42 +- .../test_reset_slip39_advanced.py | 22 +- .../reset_recovery/test_reset_slip39_basic.py | 30 +- tests/device_tests/ripple/test_get_address.py | 18 +- tests/device_tests/ripple/test_sign_tx.py | 14 +- tests/device_tests/solana/test_address.py | 6 +- tests/device_tests/solana/test_public_key.py | 6 +- tests/device_tests/solana/test_sign_tx.py | 8 +- tests/device_tests/stellar/test_stellar.py | 16 +- .../device_tests/test_authenticate_device.py | 12 +- tests/device_tests/test_autolock.py | 93 +- tests/device_tests/test_basic.py | 42 +- tests/device_tests/test_bip32_speed.py | 26 +- tests/device_tests/test_busy_state.py | 71 +- tests/device_tests/test_cancel.py | 53 +- tests/device_tests/test_debuglink.py | 52 +- tests/device_tests/test_firmware_hash.py | 22 +- tests/device_tests/test_language.py | 274 +- tests/device_tests/test_msg_applysettings.py | 314 +- tests/device_tests/test_msg_backup_device.py | 123 +- .../test_msg_change_wipe_code_t2.py | 128 +- tests/device_tests/test_msg_changepin_t1.py | 137 +- tests/device_tests/test_msg_changepin_t2.py | 166 +- tests/device_tests/test_msg_loaddevice.py | 77 +- tests/device_tests/test_msg_ping.py | 22 +- tests/device_tests/test_msg_sd_protect.py | 65 +- .../test_msg_show_device_tutorial.py | 8 +- .../test_passphrase_slip39_advanced.py | 18 +- .../test_passphrase_slip39_basic.py | 20 +- tests/device_tests/test_pin.py | 50 +- tests/device_tests/test_protection_levels.py | 265 +- tests/device_tests/test_repeated_backup.py | 143 +- tests/device_tests/test_sdcard.py | 66 +- tests/device_tests/tezos/test_getaddress.py | 12 +- tests/device_tests/tezos/test_getpublickey.py | 8 +- tests/device_tests/tezos/test_sign_tx.py | 58 +- .../webauthn/test_msg_webauthn.py | 34 +- .../device_tests/webauthn/test_u2f_counter.py | 18 +- tests/device_tests/zcash/test_sign_tx.py | 88 +- tests/input_flows.py | 9 +- tests/translations.py | 14 +- 298 files changed, 18200 insertions(+), 6063 deletions(-) create mode 100644 core/src/apps/common/cache.py create mode 100644 core/src/apps/thp/create_new_session.py create mode 100644 core/src/apps/thp/pairing.py create mode 100644 core/src/storage/cache_codec.py create mode 100644 core/src/storage/cache_common.py create mode 100644 core/src/storage/cache_thp.py create mode 100644 core/src/trezor/enums/ThpPairingMethod.py create mode 100644 core/src/trezor/wire/message_handler.py create mode 100644 core/src/trezor/wire/protocol_common.py create mode 100644 core/src/trezor/wire/thp/__init__.py create mode 100644 core/src/trezor/wire/thp/alternating_bit_protocol.py create mode 100644 core/src/trezor/wire/thp/channel.py create mode 100644 core/src/trezor/wire/thp/channel_manager.py create mode 100644 core/src/trezor/wire/thp/checksum.py create mode 100644 core/src/trezor/wire/thp/control_byte.py create mode 100644 core/src/trezor/wire/thp/cpace.py create mode 100644 core/src/trezor/wire/thp/crypto.py create mode 100644 core/src/trezor/wire/thp/interface_manager.py create mode 100644 core/src/trezor/wire/thp/memory_manager.py create mode 100644 core/src/trezor/wire/thp/pairing_context.py create mode 100644 core/src/trezor/wire/thp/received_message_handler.py create mode 100644 core/src/trezor/wire/thp/session_context.py create mode 100644 core/src/trezor/wire/thp/session_manager.py create mode 100644 core/src/trezor/wire/thp/thp_messages.py create mode 100644 core/src/trezor/wire/thp/transmission_loop.py create mode 100644 core/src/trezor/wire/thp/writer.py create mode 100644 core/src/trezor/wire/thp_main.py create mode 100644 core/tests/mock_wire_interface.py create mode 100755 core/tests/myTests.sh create mode 100644 core/tests/test_trezor.wire.thp.checksum.py create mode 100644 core/tests/test_trezor.wire.thp.credential_manager.py create mode 100644 core/tests/test_trezor.wire.thp.crypto.py create mode 100644 core/tests/test_trezor.wire.thp.py create mode 100644 core/tests/test_trezor.wire.thp.writer.py create mode 100644 core/tests/test_trezor.wire.thp_deprecated.py create mode 100644 core/tests/thp_common.py create mode 120000 legacy/firmware/protob/messages-thp.proto create mode 100644 python/channel_data.json create mode 100644 python/src/trezorlib/transport/new/__init__.py create mode 100644 python/src/trezorlib/transport/new/alternating_bit_protocol.py create mode 100644 python/src/trezorlib/transport/new/channel_data.py create mode 100644 python/src/trezorlib/transport/new/channel_database.py create mode 100644 python/src/trezorlib/transport/new/control_byte.py create mode 100644 python/src/trezorlib/transport/new/protocol_and_channel.py create mode 100644 python/src/trezorlib/transport/new/protocol_v1.py create mode 100644 python/src/trezorlib/transport/new/protocol_v2.py create mode 100644 python/src/trezorlib/transport/new/session.py delete mode 100644 python/src/trezorlib/transport/protocol.py create mode 100644 python/src/trezorlib/transport/session.py create mode 100644 python/src/trezorlib/transport/thp/checksum.py create mode 100644 python/src/trezorlib/transport/thp/curve25519.py create mode 100644 python/src/trezorlib/transport/thp/packet_header.py create mode 100644 python/src/trezorlib/transport/thp/thp_io.py diff --git a/ci/build.yml b/ci/build.yml index 8faa98da557..3f2861edf04 100644 --- a/ci/build.yml +++ b/ci/build.yml @@ -307,6 +307,7 @@ core unix frozen debug build: needs: [] variables: PYOPT: "0" + THP: "1" script: - $NIX_SHELL --run "poetry run make -C core build_unix_frozen" artifacts: diff --git a/common/protob/messages-common.proto b/common/protob/messages-common.proto index 500ddf73894..d24fc76b3de 100644 --- a/common/protob/messages-common.proto +++ b/common/protob/messages-common.proto @@ -39,6 +39,8 @@ message Failure { Failure_PinMismatch = 12; Failure_WipeCodeMismatch = 13; Failure_InvalidSession = 14; + Failure_ThpUnallocatedSession=15; + Failure_InvalidProtocol=16; Failure_FirmwareError = 99; } } diff --git a/common/protob/messages-debug.proto b/common/protob/messages-debug.proto index 0edfd4979f2..148c01369bf 100644 --- a/common/protob/messages-debug.proto +++ b/common/protob/messages-debug.proto @@ -110,6 +110,8 @@ message DebugLinkGetState { // trezor-core only - wait until current layout changes // changed in 2.6.4: multiple wait types instead of true/false. optional DebugWaitType wait_layout = 3 [default=IMMEDIATE]; + + optional bytes thp_channel_id=4; // THP only - used to get information from particular channel } /** @@ -130,6 +132,9 @@ message DebugLinkState { optional uint32 reset_word_pos = 11; // index of mnemonic word the device is expecting during ResetDevice workflow optional management.BackupType mnemonic_type = 12; // current mnemonic type (BIP-39/SLIP-39) repeated string tokens = 13; // current layout represented as a list of string tokens + optional uint32 thp_pairing_code_entry_code = 14; + optional bytes thp_pairing_code_qr_code = 15; + optional bytes thp_pairing_code_nfc_unidirectional = 16; } /** diff --git a/common/protob/messages-thp.proto b/common/protob/messages-thp.proto index 743cf3a1eee..579a06c5382 100644 --- a/common/protob/messages-thp.proto +++ b/common/protob/messages-thp.proto @@ -9,6 +9,189 @@ option (include_in_bitcoin_only) = true; import "messages.proto"; +/** + * Numeric identifiers of pairing methods. + * @embed + */ +enum ThpPairingMethod { + NoMethod = 1; // Trust without MITM protection. + CodeEntry = 2; // User types code diplayed on Trezor into the host application. + QrCode = 3; // User scans code displayed on Trezor into host application. + NFC_Unidirectional = 4; // Trezor transmits an authentication key to the host device via NFC. +} + +/** + * @embed + */ +message ThpDeviceProperties { + optional string internal_model = 1; // Internal model name e.g. "T2B1". + optional uint32 model_variant = 2; // Encodes the device properties such as color. + optional bool bootloader_mode = 3; // Indicates whether the device is in bootloader or firmware mode. + optional uint32 protocol_version = 4; // The communication protocol version supported by the firmware. + repeated ThpPairingMethod pairing_methods = 5; // The pairing methods supported by the Trezor. +} + +/** + * @embed + */ +message ThpHandshakeCompletionReqNoisePayload { + optional bytes host_pairing_credential = 1; // Host's pairing credential + repeated ThpPairingMethod pairing_methods = 2; // The pairing methods chosen by the host +} + +/** + * Request: Ask device for a new session with given passphrase. + * @start + * @next ThpNewSession + */ +message ThpCreateNewSession{ + optional string passphrase = 1; + optional bool on_device = 2; // User wants to enter passphrase on the device + optional bool derive_cardano = 3; // If True, Cardano keys will be derived. Ignored with BTC-only +} + +/** + * Response: Contains session_id of the newly created session. + * @end + */ +message ThpNewSession{ + optional uint32 new_session_id = 1; +} + +/** + * Request: Start pairing process. + * @start + * @next ThpCodeEntryCommitment + * @next ThpPairingPreparationsFinished + */ +message ThpStartPairingRequest{ + optional string host_name = 1; // Human-readable host name +} + +/** + * Response: Pairing is ready for user input / OOB communication. + * @next ThpCodeEntryCpace + * @next ThpQrCodeTag + * @next ThpNfcUnidirectionalTag + */ + message ThpPairingPreparationsFinished{ +} + +/** + * Response: If Code Entry is an allowed pairing option, Trezor responds with a commitment. + * @next ThpCodeEntryChallenge + */ +message ThpCodeEntryCommitment { + optional bytes commitment = 1; // SHA-256 of Trezor's random 32-byte secret +} + +/** + * Response: Host responds to Trezor's Code Entry commitment with a challenge. + * @next ThpPairingPreparationsFinished + */ +message ThpCodeEntryChallenge { + optional bytes challenge = 1; // host's random 32-byte challenge +} + +/** + * Request: User selected Code Entry option in Host. Host starts CPACE protocol with Trezor. + * @next ThpCodeEntryCpaceTrezor + */ +message ThpCodeEntryCpaceHost { + optional bytes cpace_host_public_key = 1; // Host's ephemeral CPace public key +} + +/** + * Response: Trezor continues with the CPACE protocol. + * @next ThpCodeEntryTag + */ +message ThpCodeEntryCpaceTrezor { + optional bytes cpace_trezor_public_key = 1; // Trezor's ephemeral CPace public key +} + +/** + * Response: Host continues with the CPACE protocol. + * @next ThpCodeEntrySecret + */ +message ThpCodeEntryTag { + optional bytes tag = 2; // SHA-256 of shared secret +} + +/** + * Response: Trezor finishes the CPACE protocol. + * @next ThpCredentialRequest + * @next ThpEndRequest + */ +message ThpCodeEntrySecret { + optional bytes secret = 1; // Trezor's secret +} + +/** + * Request: User selected QR Code pairing option. Host sends a QR Tag. + * @next ThpQrCodeSecret + */ +message ThpQrCodeTag { + optional bytes tag = 1; // SHA-256 of shared secret +} + +/** + * Response: Trezor sends the QR secret. + * @next ThpCredentialRequest + * @next ThpEndRequest + */ +message ThpQrCodeSecret { + optional bytes secret = 1; // Trezor's secret +} + +/** + * Request: User selected Unidirectional NFC pairing option. Host sends an Unidirectional NFC Tag. + * @next ThpNfcUnidirectionalSecret + */ +message ThpNfcUnidirectionalTag { + optional bytes tag = 1; // SHA-256 of shared secret +} + +/** + * Response: Trezor sends the Unidirectioal NFC secret. + * @next ThpCredentialRequest + * @next ThpEndRequest + */ +message ThpNfcUnidirectionalSecret { + optional bytes secret = 1; // Trezor's secret +} + +/** + * Request: Host requests issuance of a new pairing credential. + * @start + * @next ThpCredentialResponse + */ +message ThpCredentialRequest { + optional bytes host_static_pubkey = 1; // Host's static public key used in the handshake. +} + +/** + * Response: Trezor issues a new pairing credential. + * @next ThpCredentialRequest + * @next ThpEndRequest + */ +message ThpCredentialResponse { + optional bytes trezor_static_pubkey = 1; // Trezor's static public key used in the handshake. + optional bytes credential = 2; // The pairing credential issued by the Trezor to the host. +} + +/** + * Request: Host requests transition to the encrypted traffic phase. + * @start + * @next ThpEndResponse + */ +message ThpEndRequest {} + +/** + * Response: Trezor approves transition to the encrypted traffic phase + * @end + */ +message ThpEndResponse {} + /** * Only for internal use. * @embed diff --git a/common/protob/messages.proto b/common/protob/messages.proto index 8156bc119cb..55bfbeaadd7 100644 --- a/common/protob/messages.proto +++ b/common/protob/messages.proto @@ -42,6 +42,10 @@ extend google.protobuf.EnumValueOptions { optional bool wire_tiny = 50006; // message is handled by Trezor when the USB stack is in tiny mode optional bool wire_bootloader = 50007; // message is only handled by Trezor Bootloader optional bool wire_no_fsm = 50008; // message is not handled by Trezor unless the USB stack is in tiny mode + optional bool channel_in = 50009; + optional bool channel_out = 50010; + optional bool pairing_in = 50011; + optional bool pairing_out = 50012; optional bool bitcoin_only = 60000; // enum value is available on BITCOIN_ONLY build // (messages not marked bitcoin_only will be EXCLUDED) @@ -376,4 +380,26 @@ enum MessageType { MessageType_SolanaAddress = 903 [(wire_out) = true]; MessageType_SolanaSignTx = 904 [(wire_in) = true]; MessageType_SolanaTxSignature = 905 [(wire_out) = true]; + + // THP + MessageType_ThpCreateNewSession = 1000[(bitcoin_only)=true, (channel_in) = true]; + MessageType_ThpNewSession = 1001[(bitcoin_only)=true, (channel_out) = true]; + MessageType_ThpStartPairingRequest = 1008 [(bitcoin_only) = true, (pairing_in) = true]; + MessageType_ThpPairingPreparationsFinished = 1009 [(bitcoin_only) = true, (pairing_out) = true]; + MessageType_ThpCredentialRequest = 1010 [(bitcoin_only) = true, (pairing_in) = true]; + MessageType_ThpCredentialResponse = 1011 [(bitcoin_only) = true, (pairing_out) = true]; + MessageType_ThpEndRequest = 1012 [(bitcoin_only) = true, (pairing_in) = true]; + MessageType_ThpEndResponse = 1013[(bitcoin_only) = true, (pairing_out) = true]; + MessageType_ThpCodeEntryCommitment = 1016[(bitcoin_only)=true, (pairing_out) = true]; + MessageType_ThpCodeEntryChallenge = 1017[(bitcoin_only)=true, (pairing_in) = true]; + MessageType_ThpCodeEntryCpaceHost = 1018[(bitcoin_only)=true, (pairing_in) = true]; + MessageType_ThpCodeEntryCpaceTrezor = 1019[(bitcoin_only)=true, (pairing_out) = true]; + MessageType_ThpCodeEntryTag = 1020[(bitcoin_only)=true, (pairing_in) = true]; + MessageType_ThpCodeEntrySecret = 1021[(bitcoin_only)=true, (pairing_out) = true]; + MessageType_ThpQrCodeTag = 1024[(bitcoin_only)=true, (pairing_in) = true]; + MessageType_ThpQrCodeSecret = 1025[(bitcoin_only)=true, (pairing_out) = true]; + MessageType_ThpNfcUnidirectionalTag = 1032[(bitcoin_only)=true, (pairing_in) = true]; + MessageType_ThpNfcUnidirectionalSecret = 1033[(bitcoin_only)=true, (pairing_in) = true]; } + + diff --git a/common/protob/pb2py b/common/protob/pb2py index 5eddadd42a2..c1b94b1c9c8 100755 --- a/common/protob/pb2py +++ b/common/protob/pb2py @@ -558,6 +558,8 @@ class RustBlobRenderer: enums = [] cursor = 0 for enum in sorted(self.descriptor.enums, key=lambda e: e.name): + if enum.name == "MessageType": + continue self.enum_map[enum.name] = cursor enum_blob = ENUM_ENTRY.build(sorted(v.number for v in enum.value)) enums.append(enum_blob) diff --git a/core/Makefile b/core/Makefile index 4b28614aec1..b7f13242041 100644 --- a/core/Makefile +++ b/core/Makefile @@ -298,14 +298,20 @@ build_unix: templates ## build unix port $(SCONS) CFLAGS="$(CFLAGS)" $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) \ TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" THP="$(THP)" \ PYOPT="0" BITCOIN_ONLY="$(BITCOIN_ONLY)" TREZOR_EMULATOR_ASAN="$(ADDRESS_SANITIZER)" \ - NEW_RENDERING="$(NEW_RENDERING)" + NEW_RENDERING="$(NEW_RENDERING)" TREZOR_MEMPERF="$(TREZOR_MEMPERF)" build_unix_frozen: templates build_cross ## build unix port with frozen modules $(SCONS) CFLAGS="$(CFLAGS)" $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) \ - TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" \ + TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" THP="$(THP)"\ PYOPT="$(PYOPT)" BITCOIN_ONLY="$(BITCOIN_ONLY)" TREZOR_EMULATOR_ASAN="$(ADDRESS_SANITIZER)" \ TREZOR_MEMPERF="$(TREZOR_MEMPERF)" TREZOR_EMULATOR_FROZEN=1 NEW_RENDERING="$(NEW_RENDERING)" +build_unix_frozen_debug: templates build_cross ## build unix port with frozen modules and DEBUG (PYOPT="0") + $(SCONS) CFLAGS="$(CFLAGS)" $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) \ + TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" THP="$(THP)"\ + PYOPT="0" BITCOIN_ONLY="$(BITCOIN_ONLY)" TREZOR_EMULATOR_ASAN="$(ADDRESS_SANITIZER)" \ + TREZOR_MEMPERF="$(TREZOR_MEMPERF)" TREZOR_EMULATOR_FROZEN=1 + build_unix_debug: templates ## build unix port $(SCONS) --max-drift=1 CFLAGS="$(CFLAGS)" $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) \ TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" \ diff --git a/core/SConscript.firmware b/core/SConscript.firmware index 7646f293068..3f357ad3dda 100644 --- a/core/SConscript.firmware +++ b/core/SConscript.firmware @@ -24,6 +24,19 @@ FEATURE_FLAGS = { "AES_GCM": False, } + +if THP: + FEATURE_FLAGS = { + "RDI": True, + "SECP256K1_ZKP": True, # required for trezor.crypto.curve.bip340 (BIP340/Taproot) + "AES_GCM": True, # Required for THP encryption + } +else: + FEATURE_FLAGS = { + "RDI": True, + "SECP256K1_ZKP": True, # required for trezor.crypto.curve.bip340 (BIP340/Taproot) + "AES_GCM": False, + } FEATURES_WANTED = ["input", "sbu", "sd_card", "rgb_led", "dma2d", "consumption_mask", "usb" ,"optiga", "haptic"] if DISABLE_OPTIGA and PYOPT == '0': FEATURES_WANTED.remove("optiga") @@ -567,6 +580,8 @@ if FROZEN: ] if not EVERYTHING else [] )) + if THP: + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/thp/*.py')) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py')) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py', @@ -616,6 +631,8 @@ if FROZEN: SOURCE_PY_DIR + 'apps/bitcoin/sign_tx/zcash_v4.py', ]) ) + if THP: + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/thp/*.py')) if EVERYTHING: SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/binance/*.py')) @@ -696,7 +713,7 @@ if FROZEN: source_files = SOURCE_MOD + SOURCE_MOD_CRYPTO + SOURCE_FIRMWARE + SOURCE_MICROPYTHON + SOURCE_MICROPYTHON_SPEED + SOURCE_HAL obj_program = [] obj_program.extend(env.Object(source=SOURCE_MOD)) -obj_program.extend(env.Object(source=SOURCE_MOD_CRYPTO, CCFLAGS='$CCFLAGS -ftrivial-auto-var-init=zero')) +obj_program.extend(env.Object(source=SOURCE_MOD_CRYPTO)) if FEATURE_FLAGS["SECP256K1_ZKP"]: obj_program.extend(env.Object(source=SOURCE_MOD_SECP256K1_ZKP, CCFLAGS='$CCFLAGS -Wno-unused-function')) source_files.extend(SOURCE_MOD_SECP256K1_ZKP) diff --git a/core/SConscript.unix b/core/SConscript.unix index 7dd9beb4746..8a958b5b329 100644 --- a/core/SConscript.unix +++ b/core/SConscript.unix @@ -482,7 +482,7 @@ if ARGUMENTS.get('TREZOR_EMULATOR_DEBUGGABLE', '0') == '1': if ARGUMENTS.get('TREZOR_MEMPERF', '0') == '1': CPPDEFINES_MOD += [ - ('MICROPY_TREZOR_MEMPERF', '\(1\)') + ('MICROPY_TREZOR_MEMPERF', '1') ] env.Replace( @@ -633,6 +633,8 @@ if FROZEN: ] if not EVERYTHING else [] )) + if THP: + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/thp/*.py')) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py')) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py', diff --git a/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-aesgcm.h b/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-aesgcm.h index b7edf784b84..80dfaa809ec 100644 --- a/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-aesgcm.h +++ b/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-aesgcm.h @@ -111,9 +111,9 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_AesGcm_encrypt_obj, mod_trezorcrypto_AesGcm_encrypt); /// def encrypt_in_place(self, data: bytearray | memoryview) -> int: -/// """ -/// Encrypt data chunk in place. Returns the length of the encrypted data. -/// """ +/// """ +/// Encrypt data chunk in place. Returns the length of the encrypted data. +/// """ STATIC mp_obj_t mod_trezorcrypto_AesGcm_encrypt_in_place(mp_obj_t self, mp_obj_t data) { mp_obj_AesGcm_t *o = MP_OBJ_TO_PTR(self); @@ -158,9 +158,9 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_AesGcm_decrypt_obj, mod_trezorcrypto_AesGcm_decrypt); /// def decrypt_in_place(self, data: bytearray | memoryview) -> int: -/// """ -/// Decrypt data chunk in place. Returns the length of the decrypted data. -/// """ +/// """ +/// Decrypt data chunk in place. Returns the length of the decrypted data. +/// """ STATIC mp_obj_t mod_trezorcrypto_AesGcm_decrypt_in_place(mp_obj_t self, mp_obj_t data) { mp_obj_AesGcm_t *o = MP_OBJ_TO_PTR(self); diff --git a/core/embed/extmod/modtrezorutils/modtrezorutils.c b/core/embed/extmod/modtrezorutils/modtrezorutils.c index bb882fc6f5b..4557c554c58 100644 --- a/core/embed/extmod/modtrezorutils/modtrezorutils.c +++ b/core/embed/extmod/modtrezorutils/modtrezorutils.c @@ -410,7 +410,7 @@ STATIC mp_obj_tuple_t mod_trezorutils_version_obj = { /// UI_LAYOUT: str /// """UI layout identifier ("tt" for model T, "tr" for models One and R).""" /// USE_THP: bool -/// """Whether the firmware supports Trezor-Host Protocol (version 3).""" +/// """Whether the firmware supports the Trezor-Host Protocol.""" STATIC const mp_rom_map_elem_t mp_module_trezorutils_globals_table[] = { {MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_trezorutils)}, diff --git a/core/embed/rust/src/protobuf/obj.rs b/core/embed/rust/src/protobuf/obj.rs index 7b4c69033cb..fc3a1f3a1dd 100644 --- a/core/embed/rust/src/protobuf/obj.rs +++ b/core/embed/rust/src/protobuf/obj.rs @@ -356,7 +356,7 @@ pub static mp_module_trezorproto: Module = obj_module! { /// """Calculate length of encoding of the specified message.""" Qstr::MP_QSTR_encoded_length => obj_fn_1!(protobuf_len).as_obj(), - /// def encode(buffer: bytearray, msg: MessageType) -> int: + /// def encode(buffer: bytearray | memoryview, msg: MessageType) -> int: /// """Encode the message into the specified buffer. Return length of /// encoding.""" Qstr::MP_QSTR_encode => obj_fn_2!(protobuf_encode).as_obj() diff --git a/core/mocks/generated/trezorcrypto/__init__.pyi b/core/mocks/generated/trezorcrypto/__init__.pyi index a7e0d95f3d8..7792fbe085e 100644 --- a/core/mocks/generated/trezorcrypto/__init__.pyi +++ b/core/mocks/generated/trezorcrypto/__init__.pyi @@ -55,9 +55,9 @@ class aesgcm: """ def encrypt_in_place(self, data: bytearray | memoryview) -> int: - """ - Encrypt data chunk in place. Returns the length of the encrypted data. - """ + """ + Encrypt data chunk in place. Returns the length of the encrypted data. + """ def decrypt(self, data: bytes) -> bytes: """ @@ -65,9 +65,9 @@ class aesgcm: """ def decrypt_in_place(self, data: bytearray | memoryview) -> int: - """ - Decrypt data chunk in place. Returns the length of the decrypted data. - """ + """ + Decrypt data chunk in place. Returns the length of the decrypted data. + """ def auth(self, data: bytes) -> None: """ diff --git a/core/mocks/generated/trezorproto.pyi b/core/mocks/generated/trezorproto.pyi index ee7c9eb72a8..530ef0963f8 100644 --- a/core/mocks/generated/trezorproto.pyi +++ b/core/mocks/generated/trezorproto.pyi @@ -42,6 +42,6 @@ def encoded_length(msg: MessageType) -> int: # rust/src/protobuf/obj.rs -def encode(buffer: bytearray, msg: MessageType) -> int: +def encode(buffer: bytearray | memoryview, msg: MessageType) -> int: """Encode the message into the specified buffer. Return length of encoding.""" diff --git a/core/mocks/generated/trezorui2.pyi b/core/mocks/generated/trezorui2.pyi index 3bb07ca7cf6..2fbb2cc73f2 100644 --- a/core/mocks/generated/trezorui2.pyi +++ b/core/mocks/generated/trezorui2.pyi @@ -1119,7 +1119,6 @@ class LayoutObj(Generic[T]): """Return (code, type) of button request made during the last event or timer pass.""" def get_transition_out(self) -> AttachType: """Return the transition type.""" - def return_value(self) -> T: """Retrieve the return value of the layout object.""" def __del__(self) -> None: diff --git a/core/mocks/generated/trezorutils.pyi b/core/mocks/generated/trezorutils.pyi index 0607b5c91a2..7116d88c7cf 100644 --- a/core/mocks/generated/trezorutils.pyi +++ b/core/mocks/generated/trezorutils.pyi @@ -151,4 +151,4 @@ BITCOIN_ONLY: bool UI_LAYOUT: str """UI layout identifier ("tt" for model T, "tr" for models One and R).""" USE_THP: bool -"""Whether the firmware supports Trezor-Host Protocol (version 3).""" +"""Whether the firmware supports the Trezor-Host Protocol.""" diff --git a/core/src/all_modules.py b/core/src/all_modules.py index f6706b0de53..17837cc60a3 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -47,6 +47,12 @@ import storage storage.cache import storage.cache +storage.cache_codec +import storage.cache_codec +storage.cache_common +import storage.cache_common +storage.cache_thp +import storage.cache_thp storage.common import storage.common storage.debug @@ -205,6 +211,10 @@ import trezor.wire.context trezor.wire.errors import trezor.wire.errors +trezor.wire.message_handler +import trezor.wire.message_handler +trezor.wire.protocol_common +import trezor.wire.protocol_common trezor.workflow import trezor.workflow apps @@ -289,6 +299,8 @@ import apps.common.backup apps.common.backup_types import apps.common.backup_types +apps.common.cache +import apps.common.cache apps.common.cbor import apps.common.cbor apps.common.coininfo @@ -381,10 +393,52 @@ import apps.workflow_handlers if utils.USE_THP: + trezor.enums.ThpPairingMethod + import trezor.enums.ThpPairingMethod + trezor.wire.thp + import trezor.wire.thp + trezor.wire.thp.alternating_bit_protocol + import trezor.wire.thp.alternating_bit_protocol + trezor.wire.thp.channel + import trezor.wire.thp.channel + trezor.wire.thp.channel_manager + import trezor.wire.thp.channel_manager + trezor.wire.thp.checksum + import trezor.wire.thp.checksum + trezor.wire.thp.control_byte + import trezor.wire.thp.control_byte + trezor.wire.thp.cpace + import trezor.wire.thp.cpace + trezor.wire.thp.crypto + import trezor.wire.thp.crypto + trezor.wire.thp.interface_manager + import trezor.wire.thp.interface_manager + trezor.wire.thp.memory_manager + import trezor.wire.thp.memory_manager + trezor.wire.thp.pairing_context + import trezor.wire.thp.pairing_context + trezor.wire.thp.received_message_handler + import trezor.wire.thp.received_message_handler + trezor.wire.thp.session_context + import trezor.wire.thp.session_context + trezor.wire.thp.session_manager + import trezor.wire.thp.session_manager + trezor.wire.thp.thp_messages + import trezor.wire.thp.thp_messages + trezor.wire.thp.transmission_loop + import trezor.wire.thp.transmission_loop + trezor.wire.thp.writer + import trezor.wire.thp.writer + trezor.wire.thp_main + import trezor.wire.thp_main apps.thp import apps.thp + apps.thp.create_new_session + import apps.thp.create_new_session apps.thp.credential_manager import apps.thp.credential_manager + apps.thp.pairing + import apps.thp.pairing if not utils.BITCOIN_ONLY: trezor.enums.BinanceOrderSide diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 25015459cf6..e9d01c28977 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -1,11 +1,15 @@ from typing import TYPE_CHECKING -import storage.cache as storage_cache +import storage.cache_codec as cache_codec import storage.device as storage_device +from storage.cache import check_thp_is_not_used +from storage.cache_common import APP_COMMON_BUSY_DEADLINE_MS, APP_COMMON_SEED from trezor import TR, config, utils, wire, workflow from trezor.enums import HomescreenFormat, MessageType from trezor.messages import Success, UnlockPath from trezor.ui.layouts import confirm_action +from trezor.wire import context +from trezor.wire.message_handler import filters, remove_filter from . import workflow_handlers @@ -34,7 +38,7 @@ def busy_expiry_ms() -> int: Returns the time left until the busy state expires or 0 if the device is not in the busy state. """ - busy_deadline_ms = storage_cache.get_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) + busy_deadline_ms = context.cache_get_int(APP_COMMON_BUSY_DEADLINE_MS) if busy_deadline_ms is None: return 0 @@ -202,13 +206,20 @@ def get_features() -> Features: return f -async def handle_Initialize(msg: Initialize) -> Features: - session_id = storage_cache.start_session(msg.session_id) +@check_thp_is_not_used +async def handle_Initialize( + msg: Initialize, +) -> Features: + session_id = cache_codec.start_session(msg.session_id) + + # TODO change cardano derivation + # ctx = context.get_context() if not utils.BITCOIN_ONLY: - derive_cardano = storage_cache.get_bool(storage_cache.APP_COMMON_DERIVE_CARDANO) - have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED) + from storage.cache_common import APP_COMMON_DERIVE_CARDANO + derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) + have_seed = context.cache_is_set(APP_COMMON_SEED) if ( have_seed and msg.derive_cardano is not None @@ -216,14 +227,12 @@ async def handle_Initialize(msg: Initialize) -> Features: ): # seed is already derived, and host wants to change derive_cardano setting # => create a new session - storage_cache.end_current_session() - session_id = storage_cache.start_session() + cache_codec.end_current_session() + session_id = cache_codec.start_session() have_seed = False if not have_seed: - storage_cache.set_bool( - storage_cache.APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano) - ) + context.cache_set_bool(APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)) features = get_features() features.session_id = session_id @@ -252,16 +261,16 @@ async def handle_SetBusy(msg: SetBusy) -> Success: import utime deadline = utime.ticks_add(utime.ticks_ms(), msg.expiry_ms) - storage_cache.set_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS, deadline) + context.cache_set_int(APP_COMMON_BUSY_DEADLINE_MS, deadline) else: - storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) + context.cache_delete(APP_COMMON_BUSY_DEADLINE_MS) set_homescreen() workflow.close_others() return Success() async def handle_EndSession(msg: EndSession) -> Success: - storage_cache.end_current_session() + cache_codec.end_current_session() return Success() @@ -276,7 +285,7 @@ async def handle_Ping(msg: Ping) -> Success: async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType: from trezor.messages import PreauthorizedRequest - from trezor.wire.context import call_any, get_context + from trezor.wire.context import call_any from apps.common import authorization @@ -289,11 +298,9 @@ async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType: req = await call_any(PreauthorizedRequest(), *wire_types) assert req.MESSAGE_WIRE_TYPE is not None - handler = workflow_handlers.find_registered_handler( - get_context().iface, req.MESSAGE_WIRE_TYPE - ) + handler = workflow_handlers.find_registered_handler(req.MESSAGE_WIRE_TYPE) if handler is None: - return wire.unexpected_message() + return wire.message_handler.unexpected_message() return await handler(req, authorization.get()) # type: ignore [Expected 1 positional argument] @@ -301,7 +308,7 @@ async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType: async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType: from trezor.crypto import hmac from trezor.messages import UnlockedPathRequest - from trezor.wire.context import call_any, get_context + from trezor.wire.context import call_any from apps.common.paths import SLIP25_PURPOSE from apps.common.seed import Slip21Node, get_seed @@ -342,9 +349,7 @@ async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType: req = await call_any(UnlockedPathRequest(mac=expected_mac), *wire_types) assert req.MESSAGE_WIRE_TYPE in wire_types - handler = workflow_handlers.find_registered_handler( - get_context().iface, req.MESSAGE_WIRE_TYPE - ) + handler = workflow_handlers.find_registered_handler(req.MESSAGE_WIRE_TYPE) assert handler is not None return await handler(req, msg) # type: ignore [Expected 1 positional argument] @@ -364,7 +369,7 @@ def set_homescreen() -> None: set_default = workflow.set_default # local_cache_attribute - if storage_cache.is_set(storage_cache.APP_COMMON_BUSY_DEADLINE_MS): + if context.cache_is_set(APP_COMMON_BUSY_DEADLINE_MS): from apps.homescreen import busyscreen set_default(busyscreen) @@ -393,7 +398,7 @@ def set_homescreen() -> None: def lock_device(interrupt_workflow: bool = True) -> None: if config.has_pin(): config.lock() - wire.filters.append(_pinlock_filter) + filters.append(_pinlock_filter) set_homescreen() if interrupt_workflow: workflow.close_others() @@ -429,7 +434,7 @@ async def unlock_device() -> None: _SCREENSAVER_IS_ON = False set_homescreen() - wire.remove_filter(_pinlock_filter) + remove_filter(_pinlock_filter) def _pinlock_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]: @@ -450,7 +455,9 @@ def reload_settings_from_storage() -> None: workflow.idle_timer.set( storage_device.get_autolock_delay_ms(), lock_device_if_unlocked ) - wire.EXPERIMENTAL_ENABLED = storage_device.get_experimental_features() + wire.message_handler.EXPERIMENTAL_ENABLED = ( + storage_device.get_experimental_features() + ) if ui.display.orientation() != storage_device.get_rotation(): ui.backlight_fade(ui.BacklightLevels.DIM) ui.display.orientation(storage_device.get_rotation()) @@ -482,4 +489,4 @@ def boot() -> None: backup.activate_repeated_backup() if not config.is_unlocked(): # pinlocked handler should always be the last one - wire.filters.append(_pinlock_filter) + filters.append(_pinlock_filter) diff --git a/core/src/apps/bitcoin/sign_tx/payment_request.py b/core/src/apps/bitcoin/sign_tx/payment_request.py index 8f2f7b88a8a..779646cc1ca 100644 --- a/core/src/apps/bitcoin/sign_tx/payment_request.py +++ b/core/src/apps/bitcoin/sign_tx/payment_request.py @@ -1,7 +1,7 @@ from micropython import const from typing import TYPE_CHECKING -from trezor.wire import DataError +from trezor.wire import DataError, context from .. import writers @@ -26,7 +26,7 @@ class PaymentRequestVerifier: def __init__( self, msg: TxAckPaymentRequest, coin: coininfo.CoinInfo, keychain: Keychain ) -> None: - from storage import cache + from storage.cache_common import APP_COMMON_NONCE from trezor.crypto.hashlib import sha256 from trezor.utils import HashWriter @@ -42,9 +42,9 @@ def __init__( if msg.nonce: nonce = bytes(msg.nonce) - if cache.get(cache.APP_COMMON_NONCE) != nonce: + if context.cache_get(APP_COMMON_NONCE) != nonce: raise DataError("Invalid nonce in payment request.") - cache.delete(cache.APP_COMMON_NONCE) + context.cache_delete(APP_COMMON_NONCE) else: nonce = b"" if msg.memos: diff --git a/core/src/apps/cardano/seed.py b/core/src/apps/cardano/seed.py index 06b662c87bb..0ddfbf8ac28 100644 --- a/core/src/apps/cardano/seed.py +++ b/core/src/apps/cardano/seed.py @@ -1,6 +1,11 @@ from typing import TYPE_CHECKING -from storage import cache, device +import storage.device as device +from storage.cache_common import ( + APP_CARDANO_ICARUS_SECRET, + APP_CARDANO_ICARUS_TREZOR_SECRET, + APP_COMMON_DERIVE_CARDANO, +) from trezor import wire from trezor.crypto import cardano @@ -15,6 +20,7 @@ from trezor import messages from trezor.crypto import bip32 from trezor.enums import CardanoDerivationType + from trezor.wire.protocol_common import Context from apps.common.keychain import Handler, MsgOut from apps.common.paths import Bip32Path @@ -110,9 +116,9 @@ def is_minting_path(path: Bip32Path) -> bool: return path[: len(MINTING_ROOT)] == MINTING_ROOT -def derive_and_store_secrets(passphrase: str) -> None: +def derive_and_store_secrets(ctx: Context, passphrase: str) -> None: assert device.is_initialized() - assert cache.get_bool(cache.APP_COMMON_DERIVE_CARDANO) + assert ctx.cache.get_bool(APP_COMMON_DERIVE_CARDANO) if not mnemonic.is_bip39(): # nothing to do for SLIP-39, where we can derive the root from the main seed @@ -132,14 +138,15 @@ def derive_and_store_secrets(passphrase: str) -> None: else: icarus_trezor_secret = icarus_secret - cache.set(cache.APP_CARDANO_ICARUS_SECRET, icarus_secret) - cache.set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) + ctx.cache.set(APP_CARDANO_ICARUS_SECRET, icarus_secret) + ctx.cache.set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain: from trezor.enums import CardanoDerivationType + from trezor.wire import context - from apps.common.seed import derive_and_store_roots + from apps.common.seed import derive_and_store_roots_legacy if not device.is_initialized(): raise wire.NotInitialized("Device is not initialized") @@ -148,19 +155,19 @@ async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychai seed = await get_seed() return Keychain(cardano.from_seed_ledger(seed)) - if not cache.get_bool(cache.APP_COMMON_DERIVE_CARDANO): + if not context.cache_get_bool(APP_COMMON_DERIVE_CARDANO): raise wire.ProcessError("Cardano derivation is not enabled for this session") if derivation_type == CardanoDerivationType.ICARUS: - cache_entry = cache.APP_CARDANO_ICARUS_SECRET + cache_entry = APP_CARDANO_ICARUS_SECRET else: - cache_entry = cache.APP_CARDANO_ICARUS_TREZOR_SECRET + cache_entry = APP_CARDANO_ICARUS_TREZOR_SECRET # _get_secret - secret = cache.get(cache_entry) + secret = context.cache_get(cache_entry) if secret is None: - await derive_and_store_roots() - secret = cache.get(cache_entry) + await derive_and_store_roots_legacy() + secret = context.cache_get(cache_entry) assert secret is not None root = cardano.from_secret(secret) diff --git a/core/src/apps/common/authorization.py b/core/src/apps/common/authorization.py index 4d6e58e4d61..08527c4565d 100644 --- a/core/src/apps/common/authorization.py +++ b/core/src/apps/common/authorization.py @@ -1,23 +1,21 @@ from typing import Iterable import storage.cache as storage_cache +from storage.cache_common import ( + APP_COMMON_AUTHORIZATION_DATA, + APP_COMMON_AUTHORIZATION_TYPE, +) from trezor import protobuf from trezor.enums import MessageType +from trezor.wire import context WIRE_TYPES: dict[int, tuple[int, ...]] = { MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof), } -APP_COMMON_AUTHORIZATION_DATA = ( - storage_cache.APP_COMMON_AUTHORIZATION_DATA -) # global_import_cache -APP_COMMON_AUTHORIZATION_TYPE = ( - storage_cache.APP_COMMON_AUTHORIZATION_TYPE -) # global_import_cache - def is_set() -> bool: - return storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE) is not None + return context.cache_get(APP_COMMON_AUTHORIZATION_TYPE) is not None def set(auth_message: protobuf.MessageType) -> None: @@ -29,16 +27,16 @@ def set(auth_message: protobuf.MessageType) -> None: # (because only wire-level messages have wire_type, which we use as identifier) ensure(auth_message.MESSAGE_WIRE_TYPE is not None) assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too - storage_cache.set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE) - storage_cache.set(APP_COMMON_AUTHORIZATION_DATA, buffer) + context.cache_set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE) + context.cache_set(APP_COMMON_AUTHORIZATION_DATA, buffer) def get() -> protobuf.MessageType | None: - stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE) + stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE) if not stored_auth_type: return None - buffer = storage_cache.get(APP_COMMON_AUTHORIZATION_DATA, b"") + buffer = context.cache_get(APP_COMMON_AUTHORIZATION_DATA, b"") return protobuf.load_message_buffer(buffer, stored_auth_type) @@ -49,7 +47,7 @@ def is_set_any_session(auth_type: MessageType) -> bool: def get_wire_types() -> Iterable[int]: - stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE) + stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE) if stored_auth_type is None: return () @@ -57,5 +55,5 @@ def get_wire_types() -> Iterable[int]: def clear() -> None: - storage_cache.delete(APP_COMMON_AUTHORIZATION_TYPE) - storage_cache.delete(APP_COMMON_AUTHORIZATION_DATA) + context.cache_delete(APP_COMMON_AUTHORIZATION_TYPE) + context.cache_delete(APP_COMMON_AUTHORIZATION_DATA) diff --git a/core/src/apps/common/backup.py b/core/src/apps/common/backup.py index 48fe93e0705..a48e84e5c33 100644 --- a/core/src/apps/common/backup.py +++ b/core/src/apps/common/backup.py @@ -1,25 +1,27 @@ from typing import TYPE_CHECKING -import storage.cache as storage_cache +import storage.cache_common as storage_cache from trezor import wire from trezor.enums import MessageType +from trezor.wire import context +from trezor.wire.message_handler import filters, remove_filter if TYPE_CHECKING: from trezor.wire import Handler, Msg def repeated_backup_enabled() -> bool: - return storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) + return context.cache_get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) def activate_repeated_backup(): - storage_cache.set_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True) - wire.filters.append(_repeated_backup_filter) + context.cache_set_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True) + filters.append(_repeated_backup_filter) def deactivate_repeated_backup(): - storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) - wire.remove_filter(_repeated_backup_filter) + context.cache_delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) + remove_filter(_repeated_backup_filter) _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( diff --git a/core/src/apps/common/cache.py b/core/src/apps/common/cache.py new file mode 100644 index 00000000000..af3dd977f34 --- /dev/null +++ b/core/src/apps/common/cache.py @@ -0,0 +1,23 @@ +from typing import TYPE_CHECKING + +from trezor.wire import context + +if TYPE_CHECKING: + from typing import Callable, ParamSpec + + P = ParamSpec("P") + ByteFunc = Callable[P, bytes] + + +def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]: + def decorator(func: ByteFunc[P]) -> ByteFunc[P]: + def wrapper(*args: P.args, **kwargs: P.kwargs): + value = context.cache_get(key) + if value is None: + value = func(*args, **kwargs) + context.cache_set(key, value) + return value + + return wrapper + + return decorator diff --git a/core/src/apps/common/passphrase.py b/core/src/apps/common/passphrase.py index da5af75d7fd..6583052ea0b 100644 --- a/core/src/apps/common/passphrase.py +++ b/core/src/apps/common/passphrase.py @@ -1,15 +1,48 @@ from micropython import const +from typing import TYPE_CHECKING import storage.device as storage_device +from storage.cache import check_thp_is_not_used from trezor.wire import DataError _MAX_PASSPHRASE_LEN = const(50) +if TYPE_CHECKING: + from trezor.messages import ThpCreateNewSession + def is_enabled() -> bool: return storage_device.is_passphrase_enabled() +async def get_passphrase(msg: ThpCreateNewSession) -> str: + if not is_enabled(): + return "" + + if msg.on_device or storage_device.get_passphrase_always_on_device(): + passphrase = await _get_on_device() + else: + passphrase = msg.passphrase or "" + if passphrase: + await _handle_displaying_passphrase_from_host(passphrase) + + if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN: + raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes") + + return passphrase + + +async def _get_on_device() -> str: + from trezor import workflow + from trezor.ui.layouts import request_passphrase_on_device + + workflow.close_others() # request exclusive UI access + passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN) + + return passphrase + + +@check_thp_is_not_used async def get() -> str: from trezor import workflow @@ -29,8 +62,8 @@ async def get() -> str: return passphrase +@check_thp_is_not_used async def _request_on_host() -> str: - from trezor import TR from trezor.messages import PassphraseAck, PassphraseRequest from trezor.ui.layouts import request_passphrase_on_host from trezor.wire.context import call @@ -55,29 +88,34 @@ async def _request_on_host() -> str: # non-empty passphrase if passphrase: - from trezor.ui.layouts import confirm_action, confirm_blob - - # We want to hide the passphrase, or show it, according to settings. - if storage_device.get_hide_passphrase_from_host(): - await confirm_action( - "passphrase_host1_hidden", - TR.passphrase__wallet, - description=TR.passphrase__from_host_not_shown, - prompt_screen=True, - prompt_title=TR.passphrase__access_wallet, - ) - else: - await confirm_action( - "passphrase_host1", - TR.passphrase__wallet, - description=TR.passphrase__next_screen_will_show_passphrase, - verb=TR.buttons__continue, - ) - - await confirm_blob( - "passphrase_host2", - TR.passphrase__title_confirm, - passphrase, - ) + await _handle_displaying_passphrase_from_host(passphrase) return passphrase + + +async def _handle_displaying_passphrase_from_host(passphrase: str) -> None: + from trezor import TR + from trezor.ui.layouts import confirm_action, confirm_blob + + # We want to hide the passphrase, or show it, according to settings. + if storage_device.get_hide_passphrase_from_host(): + await confirm_action( + "passphrase_host1_hidden", + TR.passphrase__wallet, + description=TR.passphrase__from_host_not_shown, + prompt_screen=True, + prompt_title=TR.passphrase__access_wallet, + ) + else: + await confirm_action( + "passphrase_host1", + TR.passphrase__wallet, + description=TR.passphrase__next_screen_will_show_passphrase, + verb=TR.buttons__continue, + ) + + await confirm_blob( + "passphrase_host2", + TR.passphrase__title_confirm, + passphrase, + ) diff --git a/core/src/apps/common/request_pin.py b/core/src/apps/common/request_pin.py index 95afa1b8fb0..988d828733f 100644 --- a/core/src/apps/common/request_pin.py +++ b/core/src/apps/common/request_pin.py @@ -1,9 +1,10 @@ import utime from typing import Any, NoReturn -import storage.cache as storage_cache +from storage.cache_common import APP_COMMON_REQUEST_PIN_LAST_UNLOCK from trezor import TR, config, utils, wire from trezor.ui.layouts import show_error_and_raise +from trezor.wire import context async def _request_sd_salt( @@ -77,7 +78,7 @@ async def request_pin_and_sd_salt( def _set_last_unlock_time() -> None: now = utime.ticks_ms() - storage_cache.set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now) + context.cache_set_int(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now) _DEF_ARG_PIN_ENTER: str = TR.pin__enter @@ -91,7 +92,7 @@ async def verify_user_pin( ) -> None: # _get_last_unlock_time last_unlock = int.from_bytes( - storage_cache.get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big" + context.cache_get(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big" ) if ( diff --git a/core/src/apps/common/safety_checks.py b/core/src/apps/common/safety_checks.py index dbdff4463e8..ddfe841f615 100644 --- a/core/src/apps/common/safety_checks.py +++ b/core/src/apps/common/safety_checks.py @@ -1,15 +1,15 @@ -import storage.cache as storage_cache import storage.device as storage_device -from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY +from storage.cache_common import APP_COMMON_SAFETY_CHECKS_TEMPORARY from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT from trezor.enums import SafetyCheckLevel +from trezor.wire import context def read_setting() -> SafetyCheckLevel: """ Returns the effective safety check level. """ - temporary_safety_check_level = storage_cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY) + temporary_safety_check_level = context.cache_get(APP_COMMON_SAFETY_CHECKS_TEMPORARY) if temporary_safety_check_level: return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum] else: @@ -27,14 +27,14 @@ def apply_setting(level: SafetyCheckLevel) -> None: Changes the safety level settings. """ if level == SafetyCheckLevel.Strict: - storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) + context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) elif level == SafetyCheckLevel.PromptAlways: - storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) + context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT) elif level == SafetyCheckLevel.PromptTemporarily: storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) - storage_cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big")) + context.cache_set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big")) else: raise ValueError("Unknown SafetyCheckLevel") diff --git a/core/src/apps/common/seed.py b/core/src/apps/common/seed.py index 58846b4f9db..48cabdb9f9f 100644 --- a/core/src/apps/common/seed.py +++ b/core/src/apps/common/seed.py @@ -1,18 +1,33 @@ from typing import TYPE_CHECKING -import storage.cache as storage_cache import storage.device as storage_device +from storage.cache import check_thp_is_not_used +from storage.cache_common import APP_COMMON_SEED, APP_COMMON_SEED_WITHOUT_PASSPHRASE from trezor import utils from trezor.crypto import hmac +from trezor.wire import context +from trezor.wire.context import get_context +from trezor.wire.errors import DataError + +from apps.common import cache from . import mnemonic -from .passphrase import get as get_passphrase +from .passphrase import get as get_passphrase_legacy +from .passphrase import get_passphrase as get_passphrase if TYPE_CHECKING: from trezor.crypto import bip32 + from trezor.messages import ThpCreateNewSession + from trezor.wire.protocol_common import Context from .paths import Bip32Path, Slip21Path +if not utils.BITCOIN_ONLY: + from storage.cache_common import ( + APP_CARDANO_ICARUS_SECRET, + APP_COMMON_DERIVE_CARDANO, + ) + class Slip21Node: """ @@ -45,54 +60,90 @@ def clone(self) -> "Slip21Node": return Slip21Node(data=self.data) -if not utils.BITCOIN_ONLY: +async def get_seed() -> bytes: + common_seed = context.cache_get(APP_COMMON_SEED) + assert common_seed is not None + return common_seed + + +if utils.BITCOIN_ONLY: + # === Bitcoin_only variant === + # We want to derive the normal seed ONLY + + async def derive_and_store_roots(ctx: Context, msg: ThpCreateNewSession) -> None: + + if msg.passphrase is not None and msg.on_device: + raise DataError("Passphrase provided when it shouldn't be!") + + if ctx.cache.is_set(APP_COMMON_SEED): + raise Exception("Seed is already set!") + + from trezor import wire + + if not storage_device.is_initialized(): + raise wire.NotInitialized("Device is not initialized") + + passphrase = await get_passphrase(msg) + common_seed = mnemonic.get_seed(passphrase) + ctx.cache.set(APP_COMMON_SEED, common_seed) + +else: # === Cardano variant === # We want to derive both the normal seed and the Cardano seed together, AND # expose a method for Cardano to do the same - async def derive_and_store_roots() -> None: + async def derive_and_store_roots(ctx: Context, msg: ThpCreateNewSession) -> None: + + if msg.passphrase is not None and msg.on_device: + raise DataError("Passphrase provided when it shouldn't be!") + + from trezor import wire + + if not storage_device.is_initialized(): + raise wire.NotInitialized("Device is not initialized") + + if ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET): + raise Exception("Cardano icarus secret is already set!") + + passphrase = await get_passphrase(msg) + common_seed = mnemonic.get_seed(passphrase) + ctx.cache.set(APP_COMMON_SEED, common_seed) + + if msg.derive_cardano: + from apps.cardano.seed import derive_and_store_secrets + + ctx.cache.set_bool(APP_COMMON_DERIVE_CARDANO, True) + derive_and_store_secrets(ctx, passphrase) + + @check_thp_is_not_used + async def derive_and_store_roots_legacy() -> None: from trezor import wire if not storage_device.is_initialized(): raise wire.NotInitialized("Device is not initialized") - need_seed = not storage_cache.is_set(storage_cache.APP_COMMON_SEED) - need_cardano_secret = storage_cache.get_bool( - storage_cache.APP_COMMON_DERIVE_CARDANO - ) and not storage_cache.is_set(storage_cache.APP_CARDANO_ICARUS_SECRET) + ctx = get_context() + need_seed = not ctx.cache.is_set(APP_COMMON_SEED) + need_cardano_secret = ctx.cache.get_bool( + APP_COMMON_DERIVE_CARDANO + ) and not ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET) if not need_seed and not need_cardano_secret: return - passphrase = await get_passphrase() + passphrase = await get_passphrase_legacy() if need_seed: common_seed = mnemonic.get_seed(passphrase) - storage_cache.set(storage_cache.APP_COMMON_SEED, common_seed) + ctx.cache.set(APP_COMMON_SEED, common_seed) if need_cardano_secret: from apps.cardano.seed import derive_and_store_secrets - derive_and_store_secrets(passphrase) - - @storage_cache.stored_async(storage_cache.APP_COMMON_SEED) - async def get_seed() -> bytes: - await derive_and_store_roots() - common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED) - assert common_seed is not None - return common_seed - -else: - # === Bitcoin-only variant === - # We use the simple version of `get_seed` that never needs to derive anything else. - - @storage_cache.stored_async(storage_cache.APP_COMMON_SEED) - async def get_seed() -> bytes: - passphrase = await get_passphrase() - return mnemonic.get_seed(passphrase) + derive_and_store_secrets(ctx, passphrase) -@storage_cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE) +@cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE) def _get_seed_without_passphrase() -> bytes: if not storage_device.is_initialized(): raise Exception("Device is not initialized") diff --git a/core/src/apps/debug/__init__.py b/core/src/apps/debug/__init__.py index 94dd3bb783b..6e51e03d340 100644 --- a/core/src/apps/debug/__init__.py +++ b/core/src/apps/debug/__init__.py @@ -1,3 +1,5 @@ +from trezor.wire import message_handler + if not __debug__: from trezor.utils import halt @@ -246,7 +248,11 @@ async def dispatch_DebugLinkDecision( # If no exception was raised, the layout did not shut down. That means that it # just updated itself. The update is already live for the caller to retrieve. - def _state() -> DebugLinkState: + def _state( + thp_pairing_code_entry_code: int | None = None, + thp_pairing_code_qr_code: bytes | None = None, + thp_pairing_code_nfc_unidirectional: bytes | None = None, + ) -> DebugLinkState: from trezor.messages import DebugLinkState from apps.common import mnemonic, passphrase @@ -265,13 +271,45 @@ def callback(*args: str) -> None: passphrase_protection=passphrase.is_enabled(), reset_entropy=storage.reset_internal_entropy, tokens=tokens, + thp_pairing_code_entry_code=thp_pairing_code_entry_code, + thp_pairing_code_qr_code=thp_pairing_code_qr_code, + thp_pairing_code_nfc_unidirectional=thp_pairing_code_nfc_unidirectional, ) async def dispatch_DebugLinkGetState( msg: DebugLinkGetState, ) -> DebugLinkState | None: + + thp_pairing_code_entry_code: int | None = None + thp_pairing_code_qr_code: bytes | None = None + thp_pairing_code_nfc_unidirectional: bytes | None = None + if utils.USE_THP and msg.thp_channel_id is not None: + channel_id = int.from_bytes(msg.thp_channel_id, "big") + + from trezor.wire.thp.channel import Channel + from trezor.wire.thp.pairing_context import PairingContext + from trezor.wire.thp_main import _CHANNELS + + channel: Channel | None = None + ctx: PairingContext | None = None + try: + channel = _CHANNELS[channel_id] + ctx = channel.connection_context + except KeyError: + pass + if ctx is not None and isinstance(ctx, PairingContext): + thp_pairing_code_entry_code = ctx.display_data.code_code_entry + thp_pairing_code_qr_code = ctx.display_data.code_qr_code + thp_pairing_code_nfc_unidirectional = ( + ctx.display_data.code_nfc_unidirectional + ) + if msg.wait_layout == DebugWaitType.IMMEDIATE: - return _state() + return _state( + thp_pairing_code_entry_code, + thp_pairing_code_qr_code, + thp_pairing_code_nfc_unidirectional, + ) assert DEBUG_CONTEXT is not None if msg.wait_layout == DebugWaitType.NEXT_LAYOUT: @@ -282,7 +320,11 @@ async def dispatch_DebugLinkGetState( if not layout_is_ready(): return await return_layout_change(DEBUG_CONTEXT, detect_deadlock=True) else: - return _state() + return _state( + thp_pairing_code_entry_code, + thp_pairing_code_qr_code, + thp_pairing_code_nfc_unidirectional, + ) async def dispatch_DebugLinkRecordScreen(msg: DebugLinkRecordScreen) -> Success: if msg.target_directory: @@ -362,7 +404,7 @@ async def handle_session(iface: WireInterface) -> None: global DEBUG_CONTEXT - DEBUG_CONTEXT = ctx = context.Context(iface, 0, WIRE_BUFFER_DEBUG) + DEBUG_CONTEXT = ctx = context.CodecContext(iface, WIRE_BUFFER_DEBUG) if storage.layout_watcher: try: @@ -387,14 +429,13 @@ async def handle_session(iface: WireInterface) -> None: msg_type = f"{msg.type} - unknown message type" log.debug( __name__, - "%s:%x receive: <%s>", + "%s receive: <%s>", ctx.iface.iface_num(), - ctx.sid, msg_type, ) if msg.type not in WORKFLOW_HANDLERS: - await ctx.write(wire.unexpected_message()) + await ctx.write(message_handler.unexpected_message()) continue elif req_type is None: @@ -405,7 +446,7 @@ async def handle_session(iface: WireInterface) -> None: await ctx.write(Success()) continue - req_msg = wire.wrap_protobuf_load(msg.data, req_type) + req_msg = message_handler.wrap_protobuf_load(msg.data, req_type) try: res_msg = await WORKFLOW_HANDLERS[msg.type](req_msg) except Exception as exc: diff --git a/core/src/apps/management/get_nonce.py b/core/src/apps/management/get_nonce.py index 2eb97353406..16779251cc5 100644 --- a/core/src/apps/management/get_nonce.py +++ b/core/src/apps/management/get_nonce.py @@ -5,10 +5,11 @@ async def get_nonce(msg: GetNonce) -> Nonce: - from storage import cache + from storage.cache_common import APP_COMMON_NONCE from trezor.crypto import random from trezor.messages import Nonce + from trezor.wire.context import cache_set nonce = random.bytes(32) - cache.set(cache.APP_COMMON_NONCE, nonce) + cache_set(APP_COMMON_NONCE, nonce) return Nonce(nonce=nonce) diff --git a/core/src/apps/management/reboot_to_bootloader.py b/core/src/apps/management/reboot_to_bootloader.py index 85596c0268d..2213d2c17a5 100644 --- a/core/src/apps/management/reboot_to_bootloader.py +++ b/core/src/apps/management/reboot_to_bootloader.py @@ -89,7 +89,7 @@ async def reboot_to_bootloader(msg: RebootToBootloader) -> NoReturn: boot_args = None ctx = get_context() - await ctx.write(Success(message="Rebooting")) + await ctx.write_force(Success(message="Rebooting")) # make sure the outgoing USB buffer is flushed await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE) # reboot to the bootloader, pass the firmware header hash if any diff --git a/core/src/apps/management/recovery_device/homescreen.py b/core/src/apps/management/recovery_device/homescreen.py index 68ca5297594..e375f6d866c 100644 --- a/core/src/apps/management/recovery_device/homescreen.py +++ b/core/src/apps/management/recovery_device/homescreen.py @@ -5,6 +5,7 @@ import storage.recovery_shares as storage_recovery_shares from trezor import TR, wire from trezor.messages import Success +from trezor.wire import message_handler from apps.common import backup_types @@ -38,7 +39,7 @@ async def recovery_process() -> Success: recovery_type = storage_recovery.get_type() - wire.AVOID_RESTARTING_FOR = ( + message_handler.AVOID_RESTARTING_FOR = ( MessageType.Initialize, MessageType.GetFeatures, MessageType.EndSession, @@ -59,7 +60,7 @@ async def _continue_repeated_backup() -> None: from apps.common import backup from apps.management.backup_device import perform_backup - wire.AVOID_RESTARTING_FOR = ( + message_handler.AVOID_RESTARTING_FOR = ( MessageType.Initialize, MessageType.GetFeatures, MessageType.EndSession, diff --git a/core/src/apps/management/wipe_device.py b/core/src/apps/management/wipe_device.py index b6e60057a6c..092b0ce04ba 100644 --- a/core/src/apps/management/wipe_device.py +++ b/core/src/apps/management/wipe_device.py @@ -1,12 +1,19 @@ from typing import TYPE_CHECKING +from trezor.wire.context import get_context + if TYPE_CHECKING: - from trezor.messages import Success, WipeDevice + from typing import NoReturn + + from trezor.messages import WipeDevice + +if __debug__: + from trezor import log -async def wipe_device(msg: WipeDevice) -> Success: +async def wipe_device(msg: WipeDevice) -> NoReturn: import storage - from trezor import TR, config, translations + from trezor import TR, config, loop, translations from trezor.enums import ButtonRequestType from trezor.messages import Success from trezor.pin import render_empty_loader @@ -27,8 +34,11 @@ async def wipe_device(msg: WipeDevice) -> Success: ) # start an empty progress screen so that the screen is not blank while waiting - render_empty_loader(config.StorageMessage.PROCESSING_MSG) + await get_context().write_force(Success(message="Device wiped")) + if __debug__: + log.debug(__name__, "Device wipe - start") + render_empty_loader(config.StorageMessage.PROCESSING_MSG) # wipe storage storage.wipe() # erase translations @@ -37,5 +47,7 @@ async def wipe_device(msg: WipeDevice) -> Success: # reload settings reload_settings_from_storage() + loop.clear() - return Success(message="Device wiped") + if __debug__: + log.debug(__name__, "Device wipe - finished") diff --git a/core/src/apps/monero/live_refresh.py b/core/src/apps/monero/live_refresh.py index 90d2dec6423..7cad50e2169 100644 --- a/core/src/apps/monero/live_refresh.py +++ b/core/src/apps/monero/live_refresh.py @@ -59,14 +59,15 @@ async def _init_step( ) -> MoneroLiveRefreshStartAck: import storage.cache as storage_cache from trezor.messages import MoneroLiveRefreshStartAck + from trezor.wire import context from apps.common import paths await paths.validate_path(keychain, msg.address_n) - if not storage_cache.get_bool(storage_cache.APP_MONERO_LIVE_REFRESH): + if not context.cache_get_bool(storage_cache.APP_MONERO_LIVE_REFRESH): await layout.require_confirm_live_refresh() - storage_cache.set_bool(storage_cache.APP_MONERO_LIVE_REFRESH, True) + context.cache_set_bool(storage_cache.APP_MONERO_LIVE_REFRESH, b"\x01") s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type) diff --git a/core/src/apps/thp/create_new_session.py b/core/src/apps/thp/create_new_session.py new file mode 100644 index 00000000000..fc3d31950ca --- /dev/null +++ b/core/src/apps/thp/create_new_session.py @@ -0,0 +1,51 @@ +from trezor import log, loop +from trezor.enums import FailureType +from trezor.messages import Failure, ThpCreateNewSession, ThpNewSession +from trezor.wire.context import get_context +from trezor.wire.errors import ActionCancelled, DataError +from trezor.wire.thp import SessionState + + +async def create_new_session(message: ThpCreateNewSession) -> ThpNewSession | Failure: + from trezor.wire.thp.session_manager import create_new_session + + from apps.common.seed import derive_and_store_roots + + ctx = get_context() + + # Assert that context `ctx` is ManagementSessionContext + from trezor.wire.thp.session_context import ManagementSessionContext + + assert isinstance(ctx, ManagementSessionContext) + + channel = ctx.channel + + # Do not use `ctx` beyond this point, as it is techically + # allowed to change inbetween await statements + + new_session = create_new_session(channel) + try: + await derive_and_store_roots(new_session, message) + except DataError as e: + return Failure(code=FailureType.DataError, message=e.message) + except ActionCancelled as e: + return Failure(code=FailureType.ActionCancelled, message=e.message) + # TODO handle other errors + # TODO handle BITCOIN_ONLY + + new_session.set_session_state(SessionState.ALLOCATED) + channel.sessions[new_session.session_id] = new_session + loop.schedule(new_session.handle()) + new_session_id: int = new_session.session_id + # await get_seed() TODO + + if __debug__: + log.debug( + __name__, + "create_new_session - new session created. Passphrase: %s, Session id: %d\n%s", + message.passphrase if message.passphrase is not None else "", + new_session.session_id, + str(channel.sessions), + ) + + return ThpNewSession(new_session_id=new_session_id) diff --git a/core/src/apps/thp/credential_manager.py b/core/src/apps/thp/credential_manager.py index 73c1d0abcdb..42231dc1b91 100644 --- a/core/src/apps/thp/credential_manager.py +++ b/core/src/apps/thp/credential_manager.py @@ -7,7 +7,7 @@ ThpCredentialMetadata, ThpPairingCredential, ) -from trezor.wire import wrap_protobuf_load +from trezor.wire import message_handler if TYPE_CHECKING: from apps.common.paths import Slip21Path @@ -72,7 +72,9 @@ def validate_credential( """ cred_auth_key = derive_cred_auth_key() expected_type = protobuf.type_for_name("ThpPairingCredential") - credential = wrap_protobuf_load(encoded_pairing_credential_message, expected_type) + credential = message_handler.wrap_protobuf_load( + encoded_pairing_credential_message, expected_type + ) assert ThpPairingCredential.is_type_of(credential) proto_msg = ThpAuthenticatedCredentialData( host_static_pubkey=host_static_pubkey, diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py new file mode 100644 index 00000000000..34adcea7a5a --- /dev/null +++ b/core/src/apps/thp/pairing.py @@ -0,0 +1,412 @@ +from typing import TYPE_CHECKING +from ubinascii import hexlify + +from trezor import loop, protobuf +from trezor.crypto.hashlib import sha256 +from trezor.enums import MessageType, ThpPairingMethod +from trezor.messages import ( + Cancel, + ThpCodeEntryChallenge, + ThpCodeEntryCommitment, + ThpCodeEntryCpaceHost, + ThpCodeEntryCpaceTrezor, + ThpCodeEntrySecret, + ThpCodeEntryTag, + ThpCredentialMetadata, + ThpCredentialRequest, + ThpCredentialResponse, + ThpEndRequest, + ThpEndResponse, + ThpNfcUnidirectionalSecret, + ThpNfcUnidirectionalTag, + ThpPairingPreparationsFinished, + ThpQrCodeSecret, + ThpQrCodeTag, + ThpStartPairingRequest, +) +from trezor.wire.errors import ActionCancelled, SilentError, UnexpectedMessage +from trezor.wire.thp import ChannelState, ThpError, crypto +from trezor.wire.thp.pairing_context import PairingContext + +from .credential_manager import issue_credential + +if __debug__: + from trezor import log + +if TYPE_CHECKING: + from typing import Any, Callable, Concatenate, Container, ParamSpec, Tuple + + P = ParamSpec("P") + FuncWithContext = Callable[Concatenate[PairingContext, P], Any] + +# +# Helpers - decorators + + +def check_state_and_log( + *allowed_states: ChannelState, +) -> Callable[[FuncWithContext], FuncWithContext]: + def decorator(f: FuncWithContext) -> FuncWithContext: + def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object: + _check_state(context, *allowed_states) + if __debug__: + try: + log.debug(__name__, "started %s", f.__name__) + except AttributeError: + log.debug( + __name__, + "started a function that cannot be named, because it raises AttributeError, eg. closure", + ) + return f(context, *args, **kwargs) + + return inner + + return decorator + + +def check_method_is_allowed( + pairing_method: ThpPairingMethod, +) -> Callable[[FuncWithContext], FuncWithContext]: + def decorator(f: FuncWithContext) -> FuncWithContext: + def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object: + _check_method_is_allowed(context, pairing_method) + return f(context, *args, **kwargs) + + return inner + + return decorator + + +# +# Pairing handlers + + +@check_state_and_log(ChannelState.TP1) +async def handle_pairing_request( + ctx: PairingContext, message: protobuf.MessageType +) -> ThpEndResponse: + + if not ThpStartPairingRequest.is_type_of(message): + raise UnexpectedMessage("Unexpected message") + + ctx.host_name = message.host_name or "" + + skip_pairing = _is_method_included(ctx, ThpPairingMethod.NoMethod) + if skip_pairing: + return await _end_pairing(ctx) + + await _prepare_pairing(ctx) + await ctx.write(ThpPairingPreparationsFinished()) + ctx.channel_ctx.set_channel_state(ChannelState.TP3) + response = await show_display_data( + ctx, _get_possible_pairing_methods_and_cancel(ctx) + ) + if __debug__: + from trezor.messages import DebugLinkGetState + + while DebugLinkGetState.is_type_of(response): + from apps.debug import dispatch_DebugLinkGetState + + dl_state = await dispatch_DebugLinkGetState(response) + assert dl_state is not None + await ctx.write(dl_state) + response = await show_display_data( + ctx, _get_possible_pairing_methods_and_cancel(ctx) + ) + if Cancel.is_type_of(response): + ctx.channel_ctx.clear() + raise SilentError("Action was cancelled by the Host") + # TODO disable NFC (if enabled) + response = await _handle_different_pairing_methods(ctx, response) + + while ThpCredentialRequest.is_type_of(response): + response = await _handle_credential_request(ctx, response) + + return await _handle_end_request(ctx, response) + + +async def _prepare_pairing(ctx: PairingContext) -> None: + + if _is_method_included(ctx, ThpPairingMethod.CodeEntry): + await _handle_code_entry_is_included(ctx) + + if _is_method_included(ctx, ThpPairingMethod.QrCode): + _handle_qr_code_is_included(ctx) + + if _is_method_included(ctx, ThpPairingMethod.NFC_Unidirectional): + _handle_nfc_unidirectional_is_included(ctx) + + +async def show_display_data(ctx: PairingContext, expected_types: Container[int] = ()): + from trezorui2 import CANCELLED + + read_task = ctx.read(expected_types) + cancel_task = ctx.display_data.get_display_layout() + race = loop.race(read_task, cancel_task.get_result()) + result = await race + + if result is CANCELLED: + raise ActionCancelled + + return result + + +@check_state_and_log(ChannelState.TP1) +async def _handle_code_entry_is_included(ctx: PairingContext) -> None: + commitment = sha256(ctx.secret).digest() + + challenge_message = await ctx.call( # noqa: F841 + ThpCodeEntryCommitment(commitment=commitment), ThpCodeEntryChallenge + ) + ctx.channel_ctx.set_channel_state(ChannelState.TP2) + + if not ThpCodeEntryChallenge.is_type_of(challenge_message): + raise UnexpectedMessage("Unexpected message") + + if challenge_message.challenge is None: + raise Exception("Invalid message") + sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash()) + sha_ctx.update(ctx.secret) + sha_ctx.update(challenge_message.challenge) + sha_ctx.update(bytes("PairingMethod_CodeEntry", "utf-8")) + code_code_entry_hash = sha_ctx.digest() + ctx.display_data.code_code_entry = ( + int.from_bytes(code_code_entry_hash, "big") % 1000000 + ) + + +@check_state_and_log(ChannelState.TP1, ChannelState.TP2) +def _handle_qr_code_is_included(ctx: PairingContext) -> None: + sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash()) + sha_ctx.update(ctx.secret) + sha_ctx.update(bytes("PairingMethod_QrCode", "utf-8")) + ctx.display_data.code_qr_code = sha_ctx.digest()[:16] + + +@check_state_and_log(ChannelState.TP1, ChannelState.TP2) +def _handle_nfc_unidirectional_is_included(ctx: PairingContext) -> None: + sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash()) + sha_ctx.update(ctx.secret) + sha_ctx.update(bytes("PairingMethod_NfcUnidirectional", "utf-8")) + ctx.display_data.code_nfc_unidirectional = sha_ctx.digest()[:16] + + +@check_state_and_log(ChannelState.TP3) +async def _handle_different_pairing_methods( + ctx: PairingContext, response: protobuf.MessageType +) -> protobuf.MessageType: + if ThpCodeEntryCpaceHost.is_type_of(response): + return await _handle_code_entry_cpace(ctx, response) + if ThpQrCodeTag.is_type_of(response): + return await _handle_qr_code_tag(ctx, response) + if ThpNfcUnidirectionalTag.is_type_of(response): + return await _handle_nfc_unidirectional_tag(ctx, response) + raise UnexpectedMessage("Unexpected message") + + +@check_state_and_log(ChannelState.TP3) +@check_method_is_allowed(ThpPairingMethod.CodeEntry) +async def _handle_code_entry_cpace( + ctx: PairingContext, message: protobuf.MessageType +) -> protobuf.MessageType: + from trezor.wire.thp.cpace import Cpace + + # TODO check that ThpCodeEntryCpaceHost message is valid + + if TYPE_CHECKING: + assert isinstance(message, ThpCodeEntryCpaceHost) + if message.cpace_host_public_key is None: + raise ThpError("Message ThpCodeEntryCpaceHost has no public key") + + ctx.cpace = Cpace( + message.cpace_host_public_key, + ctx.channel_ctx.get_handshake_hash(), + ) + assert ctx.display_data.code_code_entry is not None + ctx.cpace.generate_keys_and_secret( + ctx.display_data.code_code_entry.to_bytes(6, "big") + ) + + ctx.channel_ctx.set_channel_state(ChannelState.TP4) + response = await ctx.call( + ThpCodeEntryCpaceTrezor(cpace_trezor_public_key=ctx.cpace.trezor_public_key), + ThpCodeEntryTag, + ) + return await _handle_code_entry_tag(ctx, response) + + +@check_state_and_log(ChannelState.TP4) +@check_method_is_allowed(ThpPairingMethod.CodeEntry) +async def _handle_code_entry_tag( + ctx: PairingContext, message: protobuf.MessageType +) -> protobuf.MessageType: + + if TYPE_CHECKING: + assert isinstance(message, ThpCodeEntryTag) + + expected_tag = sha256(ctx.cpace.shared_secret).digest() + if expected_tag != message.tag: + print( + "expected code entry tag:", hexlify(expected_tag).decode() + ) # TODO remove after testing + print( + "expected code entry shared secret:", + hexlify(ctx.cpace.shared_secret).decode(), + ) # TODO remove after testing + raise ThpError("Unexpected Code Entry Tag") + + return await _handle_secret_reveal( + ctx, + msg=ThpCodeEntrySecret(secret=ctx.secret), + ) + + +@check_state_and_log(ChannelState.TP3) +@check_method_is_allowed(ThpPairingMethod.QrCode) +async def _handle_qr_code_tag( + ctx: PairingContext, message: protobuf.MessageType +) -> protobuf.MessageType: + if TYPE_CHECKING: + assert isinstance(message, ThpQrCodeTag) + assert ctx.display_data.code_qr_code is not None + expected_tag = sha256(ctx.display_data.code_qr_code).digest() + if expected_tag != message.tag: + print( + "expected qr code tag:", hexlify(expected_tag).decode() + ) # TODO remove after testing + print( + "expected code qr code tag:", + hexlify(ctx.display_data.code_qr_code).decode(), + ) # TODO remove after testing + print( + "expected secret:", hexlify(ctx.secret).decode() + ) # TODO remove after testing + raise ThpError("Unexpected QR Code Tag") + + return await _handle_secret_reveal( + ctx, + msg=ThpQrCodeSecret(secret=ctx.secret), + ) + + +@check_state_and_log(ChannelState.TP3) +@check_method_is_allowed(ThpPairingMethod.NFC_Unidirectional) +async def _handle_nfc_unidirectional_tag( + ctx: PairingContext, message: protobuf.MessageType +) -> protobuf.MessageType: + if TYPE_CHECKING: + assert isinstance(message, ThpNfcUnidirectionalTag) + + expected_tag = sha256(ctx.display_data.code_nfc_unidirectional).digest() + if expected_tag != message.tag: + print( + "expected nfc tag:", hexlify(expected_tag).decode() + ) # TODO remove after testing + raise ThpError("Unexpected NFC Unidirectional Tag") + + return await _handle_secret_reveal( + ctx, + msg=ThpNfcUnidirectionalSecret(secret=ctx.secret), + ) + + +@check_state_and_log(ChannelState.TP3, ChannelState.TP4) +async def _handle_secret_reveal( + ctx: PairingContext, + msg: protobuf.MessageType, +) -> protobuf.MessageType: + ctx.channel_ctx.set_channel_state(ChannelState.TC1) + return await ctx.call_any( + msg, + MessageType.ThpCredentialRequest, + MessageType.ThpEndRequest, + ) + + +@check_state_and_log(ChannelState.TC1) +async def _handle_credential_request( + ctx: PairingContext, message: protobuf.MessageType +) -> protobuf.MessageType: + ctx.secret + + if not ThpCredentialRequest.is_type_of(message): + raise UnexpectedMessage("Unexpected message") + if message.host_static_pubkey is None: + raise Exception("Invalid message") # TODO change failure type + + trezor_static_pubkey = crypto.get_trezor_static_pubkey() + credential_metadata = ThpCredentialMetadata(host_name=ctx.host_name) + credential = issue_credential(message.host_static_pubkey, credential_metadata) + + return await ctx.call_any( + ThpCredentialResponse( + trezor_static_pubkey=trezor_static_pubkey, credential=credential + ), + MessageType.ThpCredentialRequest, + MessageType.ThpEndRequest, + ) + + +@check_state_and_log(ChannelState.TC1) +async def _handle_end_request( + ctx: PairingContext, message: protobuf.MessageType +) -> ThpEndResponse: + if not ThpEndRequest.is_type_of(message): + raise UnexpectedMessage("Unexpected message") + return await _end_pairing(ctx) + + +async def _end_pairing(ctx: PairingContext) -> ThpEndResponse: + ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + return ThpEndResponse() + + +# +# Helpers - checkers + + +def _check_state(ctx: PairingContext, *allowed_states: ChannelState) -> None: + if ctx.channel_ctx.get_channel_state() not in allowed_states: + raise UnexpectedMessage("Unexpected message") + + +def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> None: + if not _is_method_included(ctx, method): + raise ThpError("Unexpected pairing method") + + +def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool: + return method in ctx.channel_ctx.selected_pairing_methods + + +# +# Helpers - getters + + +def _get_possible_pairing_methods_and_cancel(ctx: PairingContext) -> Tuple[int, ...]: + r = _get_possible_pairing_methods(ctx) + mtype = Cancel.MESSAGE_WIRE_TYPE + return r + ((mtype,) if mtype is not None else ()) + + +def _get_possible_pairing_methods(ctx: PairingContext) -> Tuple[int, ...]: + r = tuple( + _get_message_type_for_method(method) + for method in ctx.channel_ctx.selected_pairing_methods + ) + if __debug__: + from trezor.messages import DebugLinkGetState + + mtype = DebugLinkGetState.MESSAGE_WIRE_TYPE + return r + ((mtype,) if mtype is not None else ()) + return r + + +def _get_message_type_for_method(method: int) -> int: + if method is ThpPairingMethod.CodeEntry: + return MessageType.ThpCodeEntryCpaceHost + if method is ThpPairingMethod.NFC_Unidirectional: + return MessageType.ThpNfcUnidirectionalTag + if method is ThpPairingMethod.QrCode: + return MessageType.ThpQrCodeTag + raise ValueError("Unexpected pairing method - no message type available") diff --git a/core/src/apps/webauthn/fido2.py b/core/src/apps/webauthn/fido2.py index ab712e1d193..3b42dbf6db8 100644 --- a/core/src/apps/webauthn/fido2.py +++ b/core/src/apps/webauthn/fido2.py @@ -375,7 +375,7 @@ async def _read_cmd(iface: HID) -> Cmd | None: desc_cont = frame_cont() read = loop.wait(iface.iface_num() | io.POLL_READ) - # wait for incoming comand indefinitely + # wait for incoming command indefinitely buf = await read while True: ifrm = overlay_struct(bytearray(buf), desc_init) diff --git a/core/src/apps/workflow_handlers.py b/core/src/apps/workflow_handlers.py index 2bc7ac692b8..ffcf353cd7c 100644 --- a/core/src/apps/workflow_handlers.py +++ b/core/src/apps/workflow_handlers.py @@ -1,8 +1,6 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from trezorio import WireInterface - from trezor.wire import Handler, Msg @@ -31,6 +29,11 @@ def _find_message_handler_module(msg_type: int) -> str: if __debug__ and msg_type == MessageType.LoadDevice: return "apps.debug.load_device" + if utils.USE_THP: + # thp management + if msg_type == MessageType.ThpCreateNewSession: + return "apps.thp.create_new_session" + # management if msg_type == MessageType.ResetDevice: return "apps.management.reset_device" @@ -209,7 +212,7 @@ def _find_message_handler_module(msg_type: int) -> str: raise ValueError -def find_registered_handler(iface: WireInterface, msg_type: int) -> Handler | None: +def find_registered_handler(msg_type: int) -> Handler | None: if msg_type in workflow_handlers: # Message has a handler available, return it directly. return workflow_handlers[msg_type] diff --git a/core/src/boot.py b/core/src/boot.py index 01777b7a9cf..e0232714023 100644 --- a/core/src/boot.py +++ b/core/src/boot.py @@ -111,8 +111,9 @@ async def bootscreen() -> None: config.init(show_pin_timeout) translations.init() -if __debug__ and not utils.EMULATOR: - config.wipe() +# TODO return after testing +# if __debug__ and not utils.EMULATOR: +# config.wipe() loop.schedule(bootscreen()) loop.run() diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index 2228b3532b9..ca88caf97f2 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -1,153 +1,10 @@ import builtins import gc -from micropython import const from typing import TYPE_CHECKING +from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache from trezor import utils -if TYPE_CHECKING: - from typing import Sequence, TypeVar, overload - - T = TypeVar("T") - - -_MAX_SESSIONS_COUNT = const(10) -_SESSIONLESS_FLAG = const(128) -_SESSION_ID_LENGTH = const(32) - -# Traditional cache keys -APP_COMMON_SEED = const(0) -APP_COMMON_AUTHORIZATION_TYPE = const(1) -APP_COMMON_AUTHORIZATION_DATA = const(2) -APP_COMMON_NONCE = const(3) -if not utils.BITCOIN_ONLY: - APP_COMMON_DERIVE_CARDANO = const(4) - APP_CARDANO_ICARUS_SECRET = const(5) - APP_CARDANO_ICARUS_TREZOR_SECRET = const(6) - APP_MONERO_LIVE_REFRESH = const(7) - -# Keys that are valid across sessions -APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | _SESSIONLESS_FLAG) -APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | _SESSIONLESS_FLAG) -APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | _SESSIONLESS_FLAG) -APP_COMMON_BUSY_DEADLINE_MS = const(3 | _SESSIONLESS_FLAG) -APP_MISC_COSI_NONCE = const(4 | _SESSIONLESS_FLAG) -APP_MISC_COSI_COMMITMENT = const(5 | _SESSIONLESS_FLAG) -APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | _SESSIONLESS_FLAG) - -# === Homescreen storage === -# This does not logically belong to the "cache" functionality, but the cache module is -# a convenient place to put this. -# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown` -# to know whether it should render itself or whether the result of a previous instance -# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends. -HOMESCREEN_ON = object() -LOCKSCREEN_ON = object() -BUSYSCREEN_ON = object() -homescreen_shown: object | None = None - -# Timestamp of last autolock activity. -# Here to persist across main loop restart between workflows. -autolock_last_touch: int | None = None - - -class InvalidSessionError(Exception): - pass - - -class DataCache: - fields: Sequence[int] # field sizes - - def __init__(self) -> None: - self.data = [bytearray(f + 1) for f in self.fields] - - def set(self, key: int, value: bytes) -> None: - utils.ensure(key < len(self.fields)) - utils.ensure(len(value) <= self.fields[key]) - self.data[key][0] = 1 - self.data[key][1:] = value - - if TYPE_CHECKING: - - @overload - def get(self, key: int) -> bytes | None: ... - - @overload - def get(self, key: int, default: T) -> bytes | T: # noqa: F811 - ... - - def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 - utils.ensure(key < len(self.fields)) - if self.data[key][0] != 1: - return default - return bytes(self.data[key][1:]) - - def is_set(self, key: int) -> bool: - utils.ensure(key < len(self.fields)) - return self.data[key][0] == 1 - - def delete(self, key: int) -> None: - utils.ensure(key < len(self.fields)) - self.data[key][:] = b"\x00" - - def clear(self) -> None: - for i in range(len(self.fields)): - self.delete(i) - - -class SessionCache(DataCache): - def __init__(self) -> None: - self.session_id = bytearray(_SESSION_ID_LENGTH) - if utils.BITCOIN_ONLY: - self.fields = ( - 64, # APP_COMMON_SEED - 2, # APP_COMMON_AUTHORIZATION_TYPE - 128, # APP_COMMON_AUTHORIZATION_DATA - 32, # APP_COMMON_NONCE - ) - else: - self.fields = ( - 64, # APP_COMMON_SEED - 2, # APP_COMMON_AUTHORIZATION_TYPE - 128, # APP_COMMON_AUTHORIZATION_DATA - 32, # APP_COMMON_NONCE - 0, # APP_COMMON_DERIVE_CARDANO - 96, # APP_CARDANO_ICARUS_SECRET - 96, # APP_CARDANO_ICARUS_TREZOR_SECRET - 0, # APP_MONERO_LIVE_REFRESH - ) - self.last_usage = 0 - super().__init__() - - def export_session_id(self) -> bytes: - from trezorcrypto import random # avoid pulling in trezor.crypto - - # generate a new session id if we don't have it yet - if not self.session_id: - self.session_id[:] = random.bytes(_SESSION_ID_LENGTH) - # export it as immutable bytes - return bytes(self.session_id) - - def clear(self) -> None: - super().clear() - self.last_usage = 0 - self.session_id[:] = b"" - - -class SessionlessCache(DataCache): - def __init__(self) -> None: - self.fields = ( - 64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE - 1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY - 8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK - 8, # APP_COMMON_BUSY_DEADLINE_MS - 32, # APP_MISC_COSI_NONCE - 32, # APP_MISC_COSI_COMMITMENT - 0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED - ) - super().__init__() - - # XXX # Allocation notes: # Instantiation of a DataCache subclass should make as little garbage as possible, so @@ -156,210 +13,61 @@ def __init__(self) -> None: # bytearrays, then later call `clear()` on all the existing objects, which resets them # to zero length. This is producing some trash - `b[:]` allocates a slice. -_SESSIONS: list[SessionCache] = [] -for _ in range(_MAX_SESSIONS_COUNT): - _SESSIONS.append(SessionCache()) - _SESSIONLESS_CACHE = SessionlessCache() -for session in _SESSIONS: - session.clear() -_SESSIONLESS_CACHE.clear() - -gc.collect() - - -_active_session_idx: int | None = None -_session_usage_counter = 0 - - -def start_session(received_session_id: bytes | None = None) -> bytes: - global _active_session_idx - global _session_usage_counter - - if ( - received_session_id is not None - and len(received_session_id) != _SESSION_ID_LENGTH - ): - # Prevent the caller from setting received_session_id=b"" and finding a cleared - # session. More generally, short-circuit the session id search, because we know - # that wrong-length session ids should not be in cache. - # Reduce to "session id not provided" case because that's what we do when - # caller supplies an id that is not found. - received_session_id = None - - _session_usage_counter += 1 - - # attempt to find specified session id - if received_session_id: - for i in range(_MAX_SESSIONS_COUNT): - if _SESSIONS[i].session_id == received_session_id: - _active_session_idx = i - _SESSIONS[i].last_usage = _session_usage_counter - return received_session_id - # allocate least recently used session - lru_counter = _session_usage_counter - lru_session_idx = 0 - for i in range(_MAX_SESSIONS_COUNT): - if _SESSIONS[i].last_usage < lru_counter: - lru_counter = _SESSIONS[i].last_usage - lru_session_idx = i +if utils.USE_THP: + from storage import cache_thp - _active_session_idx = lru_session_idx - selected_session = _SESSIONS[lru_session_idx] - selected_session.clear() - selected_session.last_usage = _session_usage_counter - return selected_session.export_session_id() + _PROTOCOL_CACHE = cache_thp +else: + from storage import cache_codec + _PROTOCOL_CACHE = cache_codec -def end_current_session() -> None: - global _active_session_idx - - if _active_session_idx is None: - return - - _SESSIONS[_active_session_idx].clear() - _active_session_idx = None - - -def set(key: int, value: bytes) -> None: - if key & _SESSIONLESS_FLAG: - _SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value) - return - if _active_session_idx is None: - raise InvalidSessionError - _SESSIONS[_active_session_idx].set(key, value) - - -def _get_length(key: int) -> int: - if key & _SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.fields[key ^ _SESSIONLESS_FLAG] - elif _active_session_idx is None: - raise InvalidSessionError - else: - return _SESSIONS[_active_session_idx].fields[key] - - -def set_int(key: int, value: int) -> None: - length = _get_length(key) - - encoded = value.to_bytes(length, "big") - - # Ensure that the value fits within the length. Micropython's int.to_bytes() - # doesn't raise OverflowError. - assert int.from_bytes(encoded, "big") == value - - set(key, encoded) - - -def set_bool(key: int, value: bool) -> None: - assert _get_length(key) == 0 # skipping get_length in production build - if value: - set(key, b"") - else: - delete(key) - - -if TYPE_CHECKING: - - @overload - def get(key: int) -> bytes | None: ... - - @overload - def get(key: int, default: T) -> bytes | T: # noqa: F811 - ... - - -def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 - if key & _SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default) - if _active_session_idx is None: - raise InvalidSessionError - return _SESSIONS[_active_session_idx].get(key, default) - +_PROTOCOL_CACHE.initialize() +_SESSIONLESS_CACHE.clear() -def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811 - encoded = get(key) - if encoded is None: - return default - else: - return int.from_bytes(encoded, "big") +gc.collect() -def get_bool(key: int) -> bool: # noqa: F811 - return get(key) is not None +def clear_all() -> None: + global autolock_last_touch + autolock_last_touch = None + _SESSIONLESS_CACHE.clear() + _PROTOCOL_CACHE.clear_all() def get_int_all_sessions(key: int) -> builtins.set[int]: - sessions = [_SESSIONLESS_CACHE] if key & _SESSIONLESS_FLAG else _SESSIONS - values = builtins.set() - for session in sessions: - encoded = session.get(key) + if key & SESSIONLESS_FLAG: + values = builtins.set() + encoded = _SESSIONLESS_CACHE.get(key) if encoded is not None: values.add(int.from_bytes(encoded, "big")) - return values - - -def is_set(key: int) -> bool: - if key & _SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.is_set(key ^ _SESSIONLESS_FLAG) - if _active_session_idx is None: - raise InvalidSessionError - return _SESSIONS[_active_session_idx].is_set(key) + return values + return _PROTOCOL_CACHE.get_int_all_sessions(key) -def delete(key: int) -> None: - if key & _SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG) - if _active_session_idx is None: - raise InvalidSessionError - return _SESSIONS[_active_session_idx].delete(key) +def get_sessionless_cache() -> SessionlessCache: + return _SESSIONLESS_CACHE if TYPE_CHECKING: - from typing import Awaitable, Callable, ParamSpec, TypeVar + from typing import Callable, ParamSpec, TypeVar + T = TypeVar("T") P = ParamSpec("P") - ByteFunc = Callable[P, bytes] - AsyncByteFunc = Callable[P, Awaitable[bytes]] - -def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]: - def decorator(func: ByteFunc[P]) -> ByteFunc[P]: - def wrapper(*args: P.args, **kwargs: P.kwargs): - value = get(key) - if value is None: - value = func(*args, **kwargs) - set(key, value) - return value - return wrapper +def check_thp_is_not_used(f: Callable[P, T]) -> Callable[P, T]: + """A type-safe decorator to raise an exception when the function is called with THP enabled. - return decorator + This decorator should be removed after the caches for Codec_v1 and THP are properly refactored and separated. + """ + def inner(*args: P.args, **kwargs: P.kwargs) -> T: + if utils.USE_THP: + raise Exception("Cannot call this function with the new THP enabled") + return f(*args, **kwargs) -def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]: - def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]: - async def wrapper(*args: P.args, **kwargs: P.kwargs): - value = get(key) - if value is None: - value = await func(*args, **kwargs) - set(key, value) - return value - - return wrapper - - return decorator - - -def clear_all() -> None: - global _active_session_idx - global autolock_last_touch - - _active_session_idx = None - _SESSIONLESS_CACHE.clear() - for session in _SESSIONS: - session.clear() - - autolock_last_touch = None + return inner diff --git a/core/src/storage/cache_codec.py b/core/src/storage/cache_codec.py new file mode 100644 index 00000000000..9bc193f5ae7 --- /dev/null +++ b/core/src/storage/cache_codec.py @@ -0,0 +1,142 @@ +import builtins +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_common import DataCache +from trezor import utils + +if TYPE_CHECKING: + from typing import TypeVar + + T = TypeVar("T") + + +_MAX_SESSIONS_COUNT = const(10) +SESSION_ID_LENGTH = const(32) + + +class SessionCache(DataCache): + def __init__(self) -> None: + self.session_id = bytearray(SESSION_ID_LENGTH) + if utils.BITCOIN_ONLY: + self.fields = ( + 64, # APP_COMMON_SEED + 2, # APP_COMMON_AUTHORIZATION_TYPE + 128, # APP_COMMON_AUTHORIZATION_DATA + 32, # APP_COMMON_NONCE + ) + else: + self.fields = ( + 64, # APP_COMMON_SEED + 2, # APP_COMMON_AUTHORIZATION_TYPE + 128, # APP_COMMON_AUTHORIZATION_DATA + 32, # APP_COMMON_NONCE + 0, # APP_COMMON_DERIVE_CARDANO + 96, # APP_CARDANO_ICARUS_SECRET + 96, # APP_CARDANO_ICARUS_TREZOR_SECRET + 0, # APP_MONERO_LIVE_REFRESH + ) + self.last_usage = 0 + super().__init__() + + def export_session_id(self) -> bytes: + from trezorcrypto import random # avoid pulling in trezor.crypto + + # generate a new session id if we don't have it yet + if not self.session_id: + self.session_id[:] = random.bytes(SESSION_ID_LENGTH) + # export it as immutable bytes + return bytes(self.session_id) + + def clear(self) -> None: + super().clear() + self.last_usage = 0 + self.session_id[:] = b"" + + +_SESSIONS: list[SessionCache] = [] + + +def initialize() -> None: + global _SESSIONS + for _ in range(_MAX_SESSIONS_COUNT): + _SESSIONS.append(SessionCache()) + + for session in _SESSIONS: + session.clear() + + +_active_session_idx: int | None = None +_session_usage_counter = 0 + + +def get_active_session() -> SessionCache | None: + if _active_session_idx is None: + return None + return _SESSIONS[_active_session_idx] + + +def start_session(received_session_id: bytes | None = None) -> bytes: + global _active_session_idx + global _session_usage_counter + + if ( + received_session_id is not None + and len(received_session_id) != SESSION_ID_LENGTH + ): + # Prevent the caller from setting received_session_id=b"" and finding a cleared + # session. More generally, short-circuit the session id search, because we know + # that wrong-length session ids should not be in cache. + # Reduce to "session id not provided" case because that's what we do when + # caller supplies an id that is not found. + received_session_id = None + + _session_usage_counter += 1 + + # attempt to find specified session id + if received_session_id: + for i in range(_MAX_SESSIONS_COUNT): + if _SESSIONS[i].session_id == received_session_id: + _active_session_idx = i + _SESSIONS[i].last_usage = _session_usage_counter + return received_session_id + + # allocate least recently used session + lru_counter = _session_usage_counter + lru_session_idx = 0 + for i in range(_MAX_SESSIONS_COUNT): + if _SESSIONS[i].last_usage < lru_counter: + lru_counter = _SESSIONS[i].last_usage + lru_session_idx = i + + _active_session_idx = lru_session_idx + selected_session = _SESSIONS[lru_session_idx] + selected_session.clear() + selected_session.last_usage = _session_usage_counter + return selected_session.export_session_id() + + +def end_current_session() -> None: + global _active_session_idx + + if _active_session_idx is None: + return + + _SESSIONS[_active_session_idx].clear() + _active_session_idx = None + + +def get_int_all_sessions(key: int) -> builtins.set[int]: + values = builtins.set() + for session in _SESSIONS: + encoded = session.get(key) + if encoded is not None: + values.add(int.from_bytes(encoded, "big")) + return values + + +def clear_all() -> None: + global _active_session_idx + _active_session_idx = None + for session in _SESSIONS: + session.clear() diff --git a/core/src/storage/cache_common.py b/core/src/storage/cache_common.py new file mode 100644 index 00000000000..bd714fa9b8d --- /dev/null +++ b/core/src/storage/cache_common.py @@ -0,0 +1,179 @@ +from micropython import const +from typing import TYPE_CHECKING + +from trezor import utils + +# Traditional cache keys +APP_COMMON_SEED = const(0) +APP_COMMON_AUTHORIZATION_TYPE = const(1) +APP_COMMON_AUTHORIZATION_DATA = const(2) +APP_COMMON_NONCE = const(3) +if not utils.BITCOIN_ONLY: + APP_COMMON_DERIVE_CARDANO = const(4) + APP_CARDANO_ICARUS_SECRET = const(5) + APP_CARDANO_ICARUS_TREZOR_SECRET = const(6) + APP_MONERO_LIVE_REFRESH = const(7) + +# Cache keys for THP channel +if utils.USE_THP: + CHANNEL_HANDSHAKE_HASH = const(0) + CHANNEL_KEY_RECEIVE = const(1) + CHANNEL_KEY_SEND = const(2) + CHANNEL_NONCE_RECEIVE = const(3) + CHANNEL_NONCE_SEND = const(4) + +# Keys that are valid across sessions +SESSIONLESS_FLAG = const(128) +APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG) +APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG) +APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | SESSIONLESS_FLAG) +APP_COMMON_BUSY_DEADLINE_MS = const(3 | SESSIONLESS_FLAG) +APP_MISC_COSI_NONCE = const(4 | SESSIONLESS_FLAG) +APP_MISC_COSI_COMMITMENT = const(5 | SESSIONLESS_FLAG) +APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | SESSIONLESS_FLAG) + + +# === Homescreen storage === +# This does not logically belong to the "cache" functionality, but the cache module is +# a convenient place to put this. +# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown` +# to know whether it should render itself or whether the result of a previous instance +# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends. +HOMESCREEN_ON = object() +LOCKSCREEN_ON = object() +BUSYSCREEN_ON = object() +homescreen_shown: object | None = None + +# Timestamp of last autolock activity. +# Here to persist across main loop restart between workflows. +autolock_last_touch: int | None = None + + +if TYPE_CHECKING: + from typing import Sequence, TypeVar, overload + + T = TypeVar("T") + + +class InvalidSessionError(Exception): + pass + + +class DataCache: + fields: Sequence[int] + + def __init__(self) -> None: + self.data = [bytearray(f + 1) for f in self.fields] + + if TYPE_CHECKING: + + @overload + def get(self, key: int) -> bytes | None: # noqa: F811 + ... + + @overload + def get(self, key: int, default: T) -> bytes | T: # noqa: F811 + ... + + def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 + utils.ensure(key < len(self.fields)) + if self.data[key][0] != 1: + return default + return bytes(self.data[key][1:]) + + def get_bool(self, key: int) -> bool: # noqa: F811 + return self.get(key) is not None + + def get_int( + self, key: int, default: T | None = None + ) -> int | T | None: # noqa: F811 + encoded = self.get(key) + if encoded is None: + return default + else: + return int.from_bytes(encoded, "big") + + def is_set(self, key: int) -> bool: + utils.ensure(key < len(self.fields)) + return self.data[key][0] == 1 + + def set(self, key: int, value: bytes) -> None: + utils.ensure(key < len(self.fields)) + utils.ensure(len(value) <= self.fields[key]) + self.data[key][0] = 1 + self.data[key][1:] = value + + def set_bool(self, key: int, value: bool) -> None: + utils.ensure( + self._get_length(key) == 0, "Field does not have zero length!" + ) # skipping get_length in production build + if value: + self.set(key, b"") + else: + self.delete(key) + + def set_int(self, key: int, value: int) -> None: + length = self.fields[key] + encoded = value.to_bytes(length, "big") + + # Ensure that the value fits within the length. Micropython's int.to_bytes() + # doesn't raise OverflowError. + assert int.from_bytes(encoded, "big") == value + + self.set(key, encoded) + + def delete(self, key: int) -> None: + utils.ensure(key < len(self.fields)) + self.data[key][:] = b"\x00" + + def clear(self) -> None: + for i in range(len(self.fields)): + self.delete(i) + + def _get_length(self, key: int) -> int: + utils.ensure(key < len(self.fields)) + return self.fields[key] + + +class SessionlessCache(DataCache): + def __init__(self) -> None: + self.fields = ( + 64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE + 1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY + 8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK + 8, # APP_COMMON_BUSY_DEADLINE_MS + 32, # APP_MISC_COSI_NONCE + 32, # APP_MISC_COSI_COMMITMENT + 0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED + ) + super().__init__() + + def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 + return super().get(key & ~SESSIONLESS_FLAG, default) + + def get_bool(self, key: int) -> bool: # noqa: F811 + return super().get_bool(key & ~SESSIONLESS_FLAG) + + def get_int( + self, key: int, default: T | None = None + ) -> int | T | None: # noqa: F811 + return super().get_int(key & ~SESSIONLESS_FLAG, default) + + def is_set(self, key: int) -> bool: + return super().is_set(key & ~SESSIONLESS_FLAG) + + def set(self, key: int, value: bytes) -> None: + super().set(key & ~SESSIONLESS_FLAG, value) + + def set_bool(self, key: int, value: bool) -> None: + super().set_bool(key & ~SESSIONLESS_FLAG, value) + + def set_int(self, key: int, value: int) -> None: + super().set_int(key & ~SESSIONLESS_FLAG, value) + + def delete(self, key: int) -> None: + super().delete(key & ~SESSIONLESS_FLAG) + + def clear(self) -> None: + for i in range(len(self.fields)): + self.delete(i) diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py new file mode 100644 index 00000000000..2c1514df792 --- /dev/null +++ b/core/src/storage/cache_thp.py @@ -0,0 +1,334 @@ +import builtins +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_common import DataCache +from trezor import utils + +if TYPE_CHECKING: + from typing import TypeVar + + T = TypeVar("T") + +if __debug__: + from trezor import log + +# THP specific constants +_MAX_CHANNELS_COUNT = const(10) +_MAX_SESSIONS_COUNT = const(20) + + +_CHANNEL_STATE_LENGTH = const(1) +_WIRE_INTERFACE_LENGTH = const(1) +_SESSION_STATE_LENGTH = const(1) +_CHANNEL_ID_LENGTH = const(2) +SESSION_ID_LENGTH = const(1) +BROADCAST_CHANNEL_ID = const(0xFFFF) +KEY_LENGTH = const(32) +TAG_LENGTH = const(16) +_UNALLOCATED_STATE = const(0) +MANAGEMENT_SESSION_ID = const(0) + + +class ConnectionCache(DataCache): + def __init__(self) -> None: + self.channel_id = bytearray(_CHANNEL_ID_LENGTH) + self.last_usage = 0 + super().__init__() + + def clear(self) -> None: + self.channel_id[:] = b"" + self.last_usage = 0 + super().clear() + + +class ChannelCache(ConnectionCache): + def __init__(self) -> None: + self.host_ephemeral_pubkey = bytearray(KEY_LENGTH) + self.state = bytearray(_CHANNEL_STATE_LENGTH) + self.iface = bytearray(1) # TODO add decoding + self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5) + self.session_id_counter = 0x00 + self.fields = ( + 32, # CHANNEL_HANDSHAKE_HASH + 32, # CHANNEL_KEY_RECEIVE + 32, # CHANNEL_KEY_SEND + 8, # CHANNEL_NONCE_RECEIVE + 8, # CHANNEL_NONCE_SEND + ) + super().__init__() + + def clear(self) -> None: + self.state[:] = bytearray( + int.to_bytes(0, _CHANNEL_STATE_LENGTH, "big") + ) # Set state to UNALLOCATED + # TODO clear all keys + super().clear() + + +class SessionThpCache(ConnectionCache): + def __init__(self) -> None: + self.session_id = bytearray(SESSION_ID_LENGTH) + self.state = bytearray(_SESSION_STATE_LENGTH) + if utils.BITCOIN_ONLY: + self.fields = ( + 64, # APP_COMMON_SEED + 2, # APP_COMMON_AUTHORIZATION_TYPE + 128, # APP_COMMON_AUTHORIZATION_DATA + 32, # APP_COMMON_NONCE + ) + else: + self.fields = ( + 64, # APP_COMMON_SEED + 2, # APP_COMMON_AUTHORIZATION_TYPE + 128, # APP_COMMON_AUTHORIZATION_DATA + 32, # APP_COMMON_NONCE + 0, # APP_COMMON_DERIVE_CARDANO + 96, # APP_CARDANO_ICARUS_SECRET + 96, # APP_CARDANO_ICARUS_TREZOR_SECRET + 0, # APP_MONERO_LIVE_REFRESH + ) + super().__init__() + + def clear(self) -> None: + self.state[:] = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED + self.session_id[:] = b"" + super().clear() + + +_CHANNELS: list[ChannelCache] = [] +_SESSIONS: list[SessionThpCache] = [] + + +def initialize() -> None: + global _CHANNELS + global _SESSIONS + + for _ in range(_MAX_CHANNELS_COUNT): + _CHANNELS.append(ChannelCache()) + for _ in range(_MAX_SESSIONS_COUNT): + _SESSIONS.append(SessionThpCache()) + + for channel in _CHANNELS: + channel.clear() + for session in _SESSIONS: + session.clear() + + +# First unauthenticated channel will have index 0 +_usage_counter = 0 + +# with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex) +cid_counter: int = 4659 # TODO change to random value on start + + +def get_new_channel(iface: bytes) -> ChannelCache: + if len(iface) != _WIRE_INTERFACE_LENGTH: + raise Exception("Invalid WireInterface (encoded) length") + + new_cid = get_next_channel_id() + index = _get_next_unauthenticated_channel_index() + + # clear sessions from replaced channel + if _get_channel_state(_CHANNELS[index]) != _UNALLOCATED_STATE: + old_cid = _CHANNELS[index].channel_id + clear_sessions_with_channel_id(old_cid) + + _CHANNELS[index] = ChannelCache() + _CHANNELS[index].channel_id[:] = new_cid + _CHANNELS[index].last_usage = _get_usage_counter_and_increment() + _CHANNELS[index].state[:] = bytearray( + _UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big") + ) + _CHANNELS[index].iface[:] = bytearray(iface) + return _CHANNELS[index] + + +def update_channel_last_used(channel_id): + for channel in _CHANNELS: + if channel.channel_id == channel_id: + channel.last_usage = _get_usage_counter_and_increment() + return + + +def update_session_last_used(channel_id, session_id): + for session in _SESSIONS: + if session.channel_id == channel_id and session.session_id == session_id: + session.last_usage = _get_usage_counter_and_increment() + update_channel_last_used(channel_id) + return + + +def get_all_allocated_channels() -> list[ChannelCache]: + _list: list[ChannelCache] = [] + for channel in _CHANNELS: + if _get_channel_state(channel) != _UNALLOCATED_STATE: + _list.append(channel) + return _list + + +def get_allocated_sessions(channel_id: bytes) -> list[SessionThpCache]: + if __debug__: + from trezor.utils import get_bytes_as_str + _list: list[SessionThpCache] = [] + for session in _SESSIONS: + if _get_session_state(session) == _UNALLOCATED_STATE: + continue + if session.channel_id != channel_id: + continue + _list.append(session) + if __debug__: + log.debug( + __name__, + "session with channel_id: %s and session_id: %s is in ALLOCATED state", + get_bytes_as_str(session.channel_id), + get_bytes_as_str(session.session_id), + ) + + return _list + + +def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> None: + if len(key) != KEY_LENGTH: + raise Exception("Invalid key length") + channel.host_ephemeral_pubkey = key + + +def get_new_session(channel: ChannelCache): + new_sid = get_next_session_id(channel) + index = _get_next_session_index() + + _SESSIONS[index] = SessionThpCache() + _SESSIONS[index].channel_id[:] = channel.channel_id + _SESSIONS[index].session_id[:] = new_sid + _SESSIONS[index].last_usage = _get_usage_counter_and_increment() + channel.last_usage = ( + _get_usage_counter_and_increment() + ) # increment also use of the channel so it does not get replaced + _SESSIONS[index].state[:] = bytearray( + _UNALLOCATED_STATE.to_bytes(_SESSION_STATE_LENGTH, "big") + ) + return _SESSIONS[index] + + +def _get_usage_counter() -> int: + global _usage_counter + return _usage_counter + + +def _get_usage_counter_and_increment() -> int: + global _usage_counter + _usage_counter += 1 + return _usage_counter + + +def _get_next_unauthenticated_channel_index() -> int: + idx = _get_unallocated_channel_index() + if idx is not None: + return idx + return get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT) + + +def _get_next_session_index() -> int: + idx = _get_unallocated_session_index() + if idx is not None: + return idx + return get_least_recently_used_item(_SESSIONS, _MAX_SESSIONS_COUNT) + + +def _get_unallocated_channel_index() -> int | None: + for i in range(_MAX_CHANNELS_COUNT): + if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE: + return i + return None + + +def _get_unallocated_session_index() -> int | None: + for i in range(_MAX_SESSIONS_COUNT): + if (_SESSIONS[i]) is _UNALLOCATED_STATE: + return i + return None + + +def _get_channel_state(channel: ChannelCache) -> int: + return int.from_bytes(channel.state, "big") + + +def _get_session_state(session: SessionThpCache) -> int: + return int.from_bytes(session.state, "big") + + +def get_next_channel_id() -> bytes: + global cid_counter + while True: + cid_counter += 1 + if cid_counter >= BROADCAST_CHANNEL_ID: + cid_counter = 1 + if _is_cid_unique(): + break + return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big") + + +def get_next_session_id(channel: ChannelCache) -> bytes: + while True: + if channel.session_id_counter >= 255: + channel.session_id_counter = 1 + else: + channel.session_id_counter += 1 + if _is_session_id_unique(channel): + break + new_sid = channel.session_id_counter + return new_sid.to_bytes(SESSION_ID_LENGTH, "big") + + +def _is_session_id_unique(channel: ChannelCache) -> bool: + for session in _SESSIONS: + if session.channel_id == channel.channel_id: + if session.session_id == channel.session_id_counter: + return False + return True + + +def _is_cid_unique() -> bool: + for session in _SESSIONS: + if cid_counter == _get_cid(session): + return False + return True + + +def _get_cid(session: SessionThpCache) -> int: + return int.from_bytes(session.session_id[2:], "big") + + +def get_least_recently_used_item( + list: list[ChannelCache] | list[SessionThpCache], max_count: int +): + lru_counter = _get_usage_counter() + lru_item_index = 0 + for i in range(max_count): + if list[i].last_usage < lru_counter: + lru_counter = list[i].last_usage + lru_item_index = i + return lru_item_index + + +def get_int_all_sessions(key: int) -> builtins.set[int]: + values = builtins.set() + for session in _SESSIONS: + encoded = session.get(key) + if encoded is not None: + values.add(int.from_bytes(encoded, "big")) + return values + + +def clear_sessions_with_channel_id(channel_id: bytes): + for session in _SESSIONS: + if session.channel_id == channel_id: + session.clear() + + +def clear_all() -> None: + for session in _SESSIONS: + session.clear() + for channel in _CHANNELS: + channel.clear() diff --git a/core/src/trezor/enums/FailureType.py b/core/src/trezor/enums/FailureType.py index fbb2001e54c..883844307a1 100644 --- a/core/src/trezor/enums/FailureType.py +++ b/core/src/trezor/enums/FailureType.py @@ -16,4 +16,6 @@ PinMismatch = 12 WipeCodeMismatch = 13 InvalidSession = 14 +ThpUnallocatedSession = 15 +InvalidProtocol = 16 FirmwareError = 99 diff --git a/core/src/trezor/enums/MessageType.py b/core/src/trezor/enums/MessageType.py index ae94e76acee..6fe8acea397 100644 --- a/core/src/trezor/enums/MessageType.py +++ b/core/src/trezor/enums/MessageType.py @@ -95,6 +95,24 @@ DebugLinkWatchLayout = 9006 DebugLinkResetDebugEvents = 9007 DebugLinkOptigaSetSecMax = 9008 +ThpCreateNewSession = 1000 +ThpNewSession = 1001 +ThpStartPairingRequest = 1008 +ThpPairingPreparationsFinished = 1009 +ThpCredentialRequest = 1010 +ThpCredentialResponse = 1011 +ThpEndRequest = 1012 +ThpEndResponse = 1013 +ThpCodeEntryCommitment = 1016 +ThpCodeEntryChallenge = 1017 +ThpCodeEntryCpaceHost = 1018 +ThpCodeEntryCpaceTrezor = 1019 +ThpCodeEntryTag = 1020 +ThpCodeEntrySecret = 1021 +ThpQrCodeTag = 1024 +ThpQrCodeSecret = 1025 +ThpNfcUnidirectionalTag = 1032 +ThpNfcUnidirectionalSecret = 1033 if not utils.BITCOIN_ONLY: SetU2FCounter = 63 GetNextU2FCounter = 80 diff --git a/core/src/trezor/enums/ThpPairingMethod.py b/core/src/trezor/enums/ThpPairingMethod.py new file mode 100644 index 00000000000..b356cdf470b --- /dev/null +++ b/core/src/trezor/enums/ThpPairingMethod.py @@ -0,0 +1,8 @@ +# Automatically generated by pb2py +# fmt: off +# isort:skip_file + +NoMethod = 1 +CodeEntry = 2 +QrCode = 3 +NFC_Unidirectional = 4 diff --git a/core/src/trezor/enums/__init__.py b/core/src/trezor/enums/__init__.py index 5567ec9e15c..b93d6d22194 100644 --- a/core/src/trezor/enums/__init__.py +++ b/core/src/trezor/enums/__init__.py @@ -266,6 +266,24 @@ class MessageType(IntEnum): SolanaAddress = 903 SolanaSignTx = 904 SolanaTxSignature = 905 + ThpCreateNewSession = 1000 + ThpNewSession = 1001 + ThpStartPairingRequest = 1008 + ThpPairingPreparationsFinished = 1009 + ThpCredentialRequest = 1010 + ThpCredentialResponse = 1011 + ThpEndRequest = 1012 + ThpEndResponse = 1013 + ThpCodeEntryCommitment = 1016 + ThpCodeEntryChallenge = 1017 + ThpCodeEntryCpaceHost = 1018 + ThpCodeEntryCpaceTrezor = 1019 + ThpCodeEntryTag = 1020 + ThpCodeEntrySecret = 1021 + ThpQrCodeTag = 1024 + ThpQrCodeSecret = 1025 + ThpNfcUnidirectionalTag = 1032 + ThpNfcUnidirectionalSecret = 1033 class FailureType(IntEnum): UnexpectedMessage = 1 @@ -282,6 +300,8 @@ class FailureType(IntEnum): PinMismatch = 12 WipeCodeMismatch = 13 InvalidSession = 14 + ThpUnallocatedSession = 15 + InvalidProtocol = 16 FirmwareError = 99 class ButtonRequestType(IntEnum): @@ -579,3 +599,9 @@ class TezosBallotType(IntEnum): Yay = 0 Nay = 1 Pass = 2 + + class ThpPairingMethod(IntEnum): + NoMethod = 1 + CodeEntry = 2 + QrCode = 3 + NFC_Unidirectional = 4 diff --git a/core/src/trezor/log.py b/core/src/trezor/log.py index 4c3aa65d051..199ccd3efc6 100644 --- a/core/src/trezor/log.py +++ b/core/src/trezor/log.py @@ -60,7 +60,7 @@ def exception(name: str, exc: BaseException) -> None: name, _DEBUG, "ui.Result: %s", - exc.value, # type: ignore[Cannot access attribute "value" for class "BaseException"] + exc.value, # type: ignore [Cannot access attribute "value" for class "BaseException"] ) elif exc.__class__.__name__ == "Cancelled": _log(name, _DEBUG, "ui.Cancelled") diff --git a/core/src/trezor/messages.py b/core/src/trezor/messages.py index 3c5e2d1f290..0836f679e14 100644 --- a/core/src/trezor/messages.py +++ b/core/src/trezor/messages.py @@ -66,6 +66,7 @@ def __getattr__(name: str) -> Any: from trezor.enums import StellarSignerType # noqa: F401 from trezor.enums import TezosBallotType # noqa: F401 from trezor.enums import TezosContractType # noqa: F401 + from trezor.enums import ThpPairingMethod # noqa: F401 from trezor.enums import WordRequestType # noqa: F401 class BinanceGetAddress(protobuf.MessageType): @@ -2812,11 +2813,13 @@ def is_type_of(cls, msg: Any) -> TypeGuard["DebugLinkRecordScreen"]: class DebugLinkGetState(protobuf.MessageType): wait_layout: "DebugWaitType" + thp_channel_id: "bytes | None" def __init__( self, *, wait_layout: "DebugWaitType | None" = None, + thp_channel_id: "bytes | None" = None, ) -> None: pass @@ -2838,6 +2841,9 @@ class DebugLinkState(protobuf.MessageType): reset_word_pos: "int | None" mnemonic_type: "BackupType | None" tokens: "list[str]" + thp_pairing_code_entry_code: "int | None" + thp_pairing_code_qr_code: "bytes | None" + thp_pairing_code_nfc_unidirectional: "bytes | None" def __init__( self, @@ -2855,6 +2861,9 @@ def __init__( recovery_word_pos: "int | None" = None, reset_word_pos: "int | None" = None, mnemonic_type: "BackupType | None" = None, + thp_pairing_code_entry_code: "int | None" = None, + thp_pairing_code_qr_code: "bytes | None" = None, + thp_pairing_code_nfc_unidirectional: "bytes | None" = None, ) -> None: pass @@ -6076,6 +6085,278 @@ def __init__( def is_type_of(cls, msg: Any) -> TypeGuard["TezosManagerTransfer"]: return isinstance(msg, cls) + class ThpDeviceProperties(protobuf.MessageType): + internal_model: "str | None" + model_variant: "int | None" + bootloader_mode: "bool | None" + protocol_version: "int | None" + pairing_methods: "list[ThpPairingMethod]" + + def __init__( + self, + *, + pairing_methods: "list[ThpPairingMethod] | None" = None, + internal_model: "str | None" = None, + model_variant: "int | None" = None, + bootloader_mode: "bool | None" = None, + protocol_version: "int | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpDeviceProperties"]: + return isinstance(msg, cls) + + class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType): + host_pairing_credential: "bytes | None" + pairing_methods: "list[ThpPairingMethod]" + + def __init__( + self, + *, + pairing_methods: "list[ThpPairingMethod] | None" = None, + host_pairing_credential: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpHandshakeCompletionReqNoisePayload"]: + return isinstance(msg, cls) + + class ThpCreateNewSession(protobuf.MessageType): + passphrase: "str | None" + on_device: "bool | None" + derive_cardano: "bool | None" + + def __init__( + self, + *, + passphrase: "str | None" = None, + on_device: "bool | None" = None, + derive_cardano: "bool | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCreateNewSession"]: + return isinstance(msg, cls) + + class ThpNewSession(protobuf.MessageType): + new_session_id: "int | None" + + def __init__( + self, + *, + new_session_id: "int | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpNewSession"]: + return isinstance(msg, cls) + + class ThpStartPairingRequest(protobuf.MessageType): + host_name: "str | None" + + def __init__( + self, + *, + host_name: "str | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpStartPairingRequest"]: + return isinstance(msg, cls) + + class ThpPairingPreparationsFinished(protobuf.MessageType): + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpPairingPreparationsFinished"]: + return isinstance(msg, cls) + + class ThpCodeEntryCommitment(protobuf.MessageType): + commitment: "bytes | None" + + def __init__( + self, + *, + commitment: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCommitment"]: + return isinstance(msg, cls) + + class ThpCodeEntryChallenge(protobuf.MessageType): + challenge: "bytes | None" + + def __init__( + self, + *, + challenge: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryChallenge"]: + return isinstance(msg, cls) + + class ThpCodeEntryCpaceHost(protobuf.MessageType): + cpace_host_public_key: "bytes | None" + + def __init__( + self, + *, + cpace_host_public_key: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceHost"]: + return isinstance(msg, cls) + + class ThpCodeEntryCpaceTrezor(protobuf.MessageType): + cpace_trezor_public_key: "bytes | None" + + def __init__( + self, + *, + cpace_trezor_public_key: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceTrezor"]: + return isinstance(msg, cls) + + class ThpCodeEntryTag(protobuf.MessageType): + tag: "bytes | None" + + def __init__( + self, + *, + tag: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryTag"]: + return isinstance(msg, cls) + + class ThpCodeEntrySecret(protobuf.MessageType): + secret: "bytes | None" + + def __init__( + self, + *, + secret: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntrySecret"]: + return isinstance(msg, cls) + + class ThpQrCodeTag(protobuf.MessageType): + tag: "bytes | None" + + def __init__( + self, + *, + tag: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeTag"]: + return isinstance(msg, cls) + + class ThpQrCodeSecret(protobuf.MessageType): + secret: "bytes | None" + + def __init__( + self, + *, + secret: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeSecret"]: + return isinstance(msg, cls) + + class ThpNfcUnidirectionalTag(protobuf.MessageType): + tag: "bytes | None" + + def __init__( + self, + *, + tag: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcUnidirectionalTag"]: + return isinstance(msg, cls) + + class ThpNfcUnidirectionalSecret(protobuf.MessageType): + secret: "bytes | None" + + def __init__( + self, + *, + secret: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcUnidirectionalSecret"]: + return isinstance(msg, cls) + + class ThpCredentialRequest(protobuf.MessageType): + host_static_pubkey: "bytes | None" + + def __init__( + self, + *, + host_static_pubkey: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialRequest"]: + return isinstance(msg, cls) + + class ThpCredentialResponse(protobuf.MessageType): + trezor_static_pubkey: "bytes | None" + credential: "bytes | None" + + def __init__( + self, + *, + trezor_static_pubkey: "bytes | None" = None, + credential: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialResponse"]: + return isinstance(msg, cls) + + class ThpEndRequest(protobuf.MessageType): + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndRequest"]: + return isinstance(msg, cls) + + class ThpEndResponse(protobuf.MessageType): + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndResponse"]: + return isinstance(msg, cls) + class ThpCredentialMetadata(protobuf.MessageType): host_name: "str | None" diff --git a/core/src/trezor/ui/__init__.py b/core/src/trezor/ui/__init__.py index ba9882d10e5..4dcb2904dcc 100644 --- a/core/src/trezor/ui/__init__.py +++ b/core/src/trezor/ui/__init__.py @@ -320,7 +320,7 @@ def _button_request(self) -> None: def _paint(self) -> None: """Paint the layout and ensure that homescreen cache is properly invalidated.""" - import storage.cache as storage_cache + import storage.cache_common as storage_cache painted = self.layout.paint() if painted: diff --git a/core/src/trezor/ui/layouts/common.py b/core/src/trezor/ui/layouts/common.py index 9047cd0a3cf..289eea009d1 100644 --- a/core/src/trezor/ui/layouts/common.py +++ b/core/src/trezor/ui/layouts/common.py @@ -57,7 +57,7 @@ def raise_if_not_confirmed( exc: ExceptionType = ActionCancelled, ) -> Awaitable[None]: action = interact(layout_obj, br_name, br_code, exc) - return action # type: ignore ["UiResult" is incompatible with "None"] + return action # type: ignore [Expression of type "Coroutine[Any, Any, UiResult]" is incompatible with return type "Awaitable[None]"] async def with_info( diff --git a/core/src/trezor/ui/layouts/homescreen.py b/core/src/trezor/ui/layouts/homescreen.py index 1e15e4ed92c..4b6f76a4321 100644 --- a/core/src/trezor/ui/layouts/homescreen.py +++ b/core/src/trezor/ui/layouts/homescreen.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING -import storage.cache as storage_cache +import storage.cache_common as storage_cache import trezorui2 from trezor import TR, ui @@ -123,11 +123,13 @@ def __init__(self, delay_ms: int) -> None: ) async def get_result(self) -> Any: + from trezor.wire import context + from apps.base import set_homescreen # Handle timeout. result = await super().get_result() assert result == trezorui2.CANCELLED - storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) + context.cache_delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) set_homescreen() return result diff --git a/core/src/trezor/ui/layouts/mercury/__init__.py b/core/src/trezor/ui/layouts/mercury/__init__.py index 02bfd2c8c76..4ac1a69ad0f 100644 --- a/core/src/trezor/ui/layouts/mercury/__init__.py +++ b/core/src/trezor/ui/layouts/mercury/__init__.py @@ -1053,7 +1053,7 @@ def request_passphrase_on_device(max_len: int) -> Awaitable[str]: ButtonRequestType.PassphraseEntry, raise_on_cancel=ActionCancelled("Passphrase entry cancelled"), ) - return result # type: ignore ["UiResult" is incompatible with "str"] + return result # type: ignore [Expression of type "Coroutine[Any, Any, str | UiResult]" is incompatible with return type "Awaitable[str]"] def request_pin_on_device( @@ -1082,7 +1082,7 @@ def request_pin_on_device( ButtonRequestType.PinEntry, raise_on_cancel=PinCancelled, ) - return result # type: ignore ["UiResult" is incompatible with "str"] + return result # type: ignore [Expression of type "Coroutine[Any, Any, str | UiResult]" is incompatible with return type "Awaitable[str]"] async def confirm_reenter_pin(is_wipe_code: bool = False) -> None: diff --git a/core/src/trezor/ui/layouts/tr/__init__.py b/core/src/trezor/ui/layouts/tr/__init__.py index 3991eed3a65..79cd8d9743a 100644 --- a/core/src/trezor/ui/layouts/tr/__init__.py +++ b/core/src/trezor/ui/layouts/tr/__init__.py @@ -1199,7 +1199,7 @@ def pin_mismatch_popup(is_wipe_code: bool = False) -> Awaitable[None]: TR.buttons__check_again, BR_CODE_OTHER, ) - return layout # type: ignore ["UiResult" is incompatible with "None"] + return layout # type: ignore [Expression of type "Awaitable[UiResult]" is incompatible with return type "Awaitable[None]"] def wipe_code_same_as_pin_popup() -> Awaitable[None]: diff --git a/core/src/trezor/ui/layouts/tr/reset.py b/core/src/trezor/ui/layouts/tr/reset.py index f21e2613ddf..f33cd3eeb94 100644 --- a/core/src/trezor/ui/layouts/tr/reset.py +++ b/core/src/trezor/ui/layouts/tr/reset.py @@ -1,4 +1,4 @@ -from typing import Awaitable, Sequence +from typing import TYPE_CHECKING import trezorui2 from trezor import TR @@ -9,6 +9,9 @@ CONFIRMED = trezorui2.CONFIRMED # global_import_cache +if TYPE_CHECKING: + from typing import Awaitable, Sequence + async def show_share_words( share_words: Sequence[str], diff --git a/core/src/trezor/utils.py b/core/src/trezor/utils.py index 7021759ba9f..b97fadd5b77 100644 --- a/core/src/trezor/utils.py +++ b/core/src/trezor/utils.py @@ -35,6 +35,8 @@ DISABLE_ANIMATION = 0 +DISABLE_ENCRYPTION: bool = False + if __debug__: if EMULATOR: import uos @@ -45,7 +47,13 @@ LOG_MEMORY = 0 if TYPE_CHECKING: - from typing import Any, Iterator, Protocol, Sequence, TypeVar + from typing import ( # pyright: ignore[reportShadowedImports] + Any, + Iterator, + Protocol, + Sequence, + TypeVar, + ) from trezor.protobuf import MessageType @@ -111,6 +119,7 @@ def presize_module(modname: str, size: int) -> None: if __debug__: + from ubinascii import hexlify def mem_dump(filename: str) -> None: from micropython import mem_info @@ -127,6 +136,9 @@ def mem_dump(filename: str) -> None: else: mem_info(True) + def get_bytes_as_str(a): + return hexlify(a).decode("utf-8") + def ensure(cond: bool, msg: str | None = None) -> None: if not cond: diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index df5ee773776..8cc1e8bf094 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -5,7 +5,7 @@ - Request / response. - Protobuf-encoded, see `protobuf.py`. -- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`. +- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py` or `trezor/wire/thp_main.py`. - Transferred over USB interface, or UDP in case of Unix emulation. This module: @@ -16,22 +16,23 @@ ## Session handler -When the `wire.setup` is called the `handle_session` coroutine is scheduled. The +When the `wire.setup` is called the `handle_session` (or `handle_thp_session`) coroutine is scheduled. The `handle_session` waits for some messages to be received on some particular interface and reads the message's header. When the message type is known the first handler is called. This way the `handle_session` goes through all the workflows. """ -from micropython import const from typing import TYPE_CHECKING -from storage.cache import InvalidSessionError -from trezor import log, loop, protobuf, utils, workflow -from trezor.enums import FailureType -from trezor.messages import Failure -from trezor.wire import codec_v1, context -from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage +from trezor import log, loop, protobuf, utils +from trezor.wire import context, message_handler, protocol_common + +if utils.USE_THP: + from trezor.wire import thp_main + from trezor.wire.message_handler import WIRE_BUFFER_2 +from trezor.wire.context import UnexpectedMessageException +from trezor.wire.message_handler import WIRE_BUFFER, failure, find_handler # Import all errors into namespace, so that `wire.Error` is available from # other packages. @@ -40,158 +41,52 @@ if TYPE_CHECKING: from trezorio import WireInterface - from typing import Any, Callable, Container, Coroutine, TypeVar + from typing import Any, Callable, Coroutine, TypeVar Msg = TypeVar("Msg", bound=protobuf.MessageType) HandlerTask = Coroutine[Any, Any, protobuf.MessageType] Handler = Callable[[Msg], HandlerTask] - LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) -# If set to False protobuf messages marked with "experimental_message" option are rejected. -EXPERIMENTAL_ENABLED = False - - def setup(iface: WireInterface) -> None: - """Initialize the wire stack on passed USB interface.""" - loop.schedule(handle_session(iface, codec_v1.SESSION_ID)) - - -def wrap_protobuf_load( - buffer: bytes, - expected_type: type[LoadedMessageType], -) -> LoadedMessageType: - try: - msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED) - if __debug__ and utils.EMULATOR: - log.debug( - __name__, "received message contents:\n%s", utils.dump_protobuf(msg) - ) - return msg - except Exception as e: - if __debug__: - log.exception(__name__, e) - if e.args: - raise DataError("Failed to decode message: " + " ".join(e.args)) - else: - raise DataError("Failed to decode message") - - -_PROTOBUF_BUFFER_SIZE = const(8192) - -WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) - -if __debug__: - PROTOBUF_BUFFER_SIZE_DEBUG = 1024 - WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG) - - -async def _handle_single_message(ctx: context.Context, msg: codec_v1.Message) -> bool: - """Handle a message that was loaded from USB by the caller. - - Find the appropriate handler, run it and write its result on the wire. In case - a problem is encountered at any point, write the appropriate error on the wire. - - The return value indicates whether to override the default restarting behavior. If - `False` is returned, the caller is allowed to clear the loop and restart the - MicroPython machine (see `session.py`). This would lose all state and incurs a cost - in terms of repeated startup time. When handling the message didn't cause any - significant fragmentation (e.g., if decoding the message was skipped), or if - the type of message is supposed to be optimized and not disrupt the running state, - this function will return `True`. - """ - if __debug__: - try: - msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME - except Exception: - msg_type = f"{msg.type} - unknown message type" - log.debug( - __name__, - "%s:%x receive: <%s>", - ctx.iface.iface_num(), - ctx.sid, - msg_type, - ) - - res_msg: protobuf.MessageType | None = None - - # We need to find a handler for this message type. - try: - handler = find_handler(ctx.iface, msg.type) - except Error as exc: - # Handlers are allowed to exception out. In that case, we can skip decoding - # and return the error. - await ctx.write(failure(exc)) - return True - - if msg.type in workflow.ALLOW_WHILE_LOCKED: - workflow.autolock_interrupts_workflow = False - - # Here we make sure we always respond with a Failure response - # in case of any errors. - try: - # Find a protobuf.MessageType subclass that describes this - # message. Raises if the type is not found. - req_type = protobuf.type_for_wire(msg.type) - - # Try to decode the message according to schema from - # `req_type`. Raises if the message is malformed. - req_msg = wrap_protobuf_load(msg.data, req_type) + """Initialize the wire stack on passed WireInterface.""" + if utils.USE_THP: + loop.schedule(handle_thp_session(iface)) + else: + loop.schedule(handle_session(iface)) - # Create the handler task. - task = handler(req_msg) - # Run the workflow task. Workflow can do more on-the-wire - # communication inside, but it should eventually return a - # response message, or raise an exception (a rather common - # thing to do). Exceptions are handled in the code below. - res_msg = await workflow.spawn(context.with_context(ctx, task)) +if utils.USE_THP: - except context.UnexpectedMessage: - # Workflow was trying to read a message from the wire, and - # something unexpected came in. See Context.read() for - # example, which expects some particular message and raises - # UnexpectedMessage if another one comes in. - # - # We process the unexpected message by aborting the current workflow and - # possibly starting a new one, initiated by that message. (The main usecase - # being, the host does not finish the workflow, we want other callers to - # be able to do their own thing.) - # - # The message is stored in the exception, which we re-raise for the caller - # to process. It is not a standard exception that should be logged and a result - # sent to the wire. - raise + async def handle_thp_session(iface: WireInterface): - except BaseException as exc: - # Either: - # - the message had a type that has a registered handler, but does not have - # a protobuf class - # - the message was not valid protobuf - # - workflow raised some kind of an exception while running - # - something canceled the workflow from the outside - if __debug__: - if isinstance(exc, ActionCancelled): - log.debug(__name__, "cancelled: %s", exc.message) - elif isinstance(exc, loop.TaskClosed): - log.debug(__name__, "cancelled: loop task was closed") - else: - log.exception(__name__, exc) - res_msg = failure(exc) + thp_main.set_read_buffer(WIRE_BUFFER) + thp_main.set_write_buffer(WIRE_BUFFER_2) - if res_msg is not None: - # perform the write outside the big try-except block, so that usb write - # problem bubbles up - await ctx.write(res_msg) + # Take a mark of modules that are imported at this point, so we can + # roll back and un-import any others. + modules = utils.unimport_begin() - # Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting. - return msg.type in AVOID_RESTARTING_FOR + while True: + try: + await thp_main.thp_main_loop(iface) + except Exception as exc: + # Log and try again. + if __debug__: + log.exception(__name__, exc) + finally: + # Unload modules imported by the workflow. Should not raise. + if __debug__: + log.debug(__name__, "utils.unimport_end(modules) and loop.clear()") + utils.unimport_end(modules) + loop.clear() + return # pylint: disable=lost-exception -async def handle_session(iface: WireInterface, session_id: int) -> None: - ctx = context.Context(iface, session_id, WIRE_BUFFER) - next_msg: codec_v1.Message | None = None +async def handle_session(iface: WireInterface) -> None: + ctx = context.CodecContext(iface, WIRE_BUFFER) + next_msg: protocol_common.Message | None = None # Take a mark of modules that are imported at this point, so we can # roll back and un-import any others. @@ -203,7 +98,7 @@ async def handle_session(iface: WireInterface, session_id: int) -> None: # wait for a new one coming from the wire. try: msg = await ctx.read_from_wire() - except codec_v1.CodecError as exc: + except protocol_common.WireError as exc: if __debug__: log.exception(__name__, exc) await ctx.write(failure(exc)) @@ -216,8 +111,10 @@ async def handle_session(iface: WireInterface, session_id: int) -> None: do_not_restart = False try: - do_not_restart = await _handle_single_message(ctx, msg) - except context.UnexpectedMessage as unexpected: + do_not_restart = await message_handler.handle_single_message( + ctx, msg, handler_finder=find_handler + ) + except UnexpectedMessageException as unexpected: # The workflow was interrupted by an unexpected message. We need to # process it as if it was a new message... next_msg = unexpected.msg @@ -230,11 +127,13 @@ async def handle_session(iface: WireInterface, session_id: int) -> None: if __debug__: log.exception(__name__, exc) finally: - # Unload modules imported by the workflow. Should not raise. + # Unload modules imported by the workflow. Should not raise. utils.unimport_end(modules) if not do_not_restart: # Let the session be restarted from `main`. + if __debug__: + log.debug(__name__, "loop.clear()") loop.clear() return # pylint: disable=lost-exception @@ -243,81 +142,3 @@ async def handle_session(iface: WireInterface, session_id: int) -> None: # loop.clear() above. if __debug__: log.exception(__name__, exc) - - -def find_handler(iface: WireInterface, msg_type: int) -> Handler: - import usb - - from apps import workflow_handlers - - handler = workflow_handlers.find_registered_handler(iface, msg_type) - if handler is None: - raise UnexpectedMessage("Unexpected message") - - if __debug__ and iface is usb.iface_debug: - # no filtering allowed for debuglink - return handler - - for filter in filters: - handler = filter(msg_type, handler) - - return handler - - -filters: list[Callable[[int, Handler], Handler]] = [] -"""Filters for the wire handler. - -Filters are applied in order. Each filter gets a message id and a preceding handler. It -must either return a handler (the same one or a modified one), or raise an exception -that gets sent to wire directly. - -Filters are not applied to debug sessions. - -The filters are designed for: - * rejecting messages -- while in Recovery mode, most messages are not allowed - * adding additional behavior -- while device is soft-locked, a PIN screen will be shown - before allowing a message to trigger its original behavior. - -For this, the filters are effectively deny-first. If an earlier filter rejects the -message, the later filters are not called. But if a filter adds behavior, the latest -filter "wins" and the latest behavior triggers first. -Please note that this behavior is really unsuited to anything other than what we are -using it for now. It might be necessary to modify the semantics if we need more complex -usecases. - -NB: `filters` is currently public so callers can have control over where they insert -new filters, but removal should be done using `remove_filter`! -We should, however, change it such that filters must be added using an `add_filter` -and `filters` becomes private! -""" - - -def remove_filter(filter): - try: - filters.remove(filter) - except ValueError: - pass - - -AVOID_RESTARTING_FOR: Container[int] = () - - -def failure(exc: BaseException) -> Failure: - if isinstance(exc, Error): - return Failure(code=exc.code, message=exc.message) - elif isinstance(exc, loop.TaskClosed): - return Failure(code=FailureType.ActionCancelled, message="Cancelled") - elif isinstance(exc, InvalidSessionError): - return Failure(code=FailureType.InvalidSession, message="Invalid session") - else: - # NOTE: when receiving generic `FirmwareError` on non-debug build, - # change the `if __debug__` to `if True` to get the full error message. - if __debug__: - message = str(exc) - else: - message = "Firmware error" - return Failure(code=FailureType.FirmwareError, message=message) - - -def unexpected_message() -> Failure: - return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message") diff --git a/core/src/trezor/wire/codec_v1.py b/core/src/trezor/wire/codec_v1.py index 54c0871b999..c600201d56a 100644 --- a/core/src/trezor/wire/codec_v1.py +++ b/core/src/trezor/wire/codec_v1.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING from trezor import io, loop, utils +from trezor.wire.protocol_common import Message, WireError if TYPE_CHECKING: from trezorio import WireInterface @@ -18,16 +19,10 @@ SESSION_ID = const(0) -class CodecError(Exception): +class CodecError(WireError): pass -class Message: - def __init__(self, mtype: int, mdata: bytes) -> None: - self.type = mtype - self.data = mdata - - async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: read = loop.wait(iface.iface_num() | io.POLL_READ) diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index e45f609849f..6a7a4c79d42 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -15,9 +15,12 @@ from typing import TYPE_CHECKING +from storage import cache, cache_codec +from storage.cache_common import SESSIONLESS_FLAG, InvalidSessionError from trezor import log, loop, protobuf +from trezor.wire import codec_v1 -from . import codec_v1 +from .protocol_common import Context, Message if TYPE_CHECKING: from trezorio import WireInterface @@ -32,6 +35,8 @@ overload, ) + from storage.cache_common import DataCache + Msg = TypeVar("Msg", bound=protobuf.MessageType) HandlerTask = Coroutine[Any, Any, protobuf.MessageType] Handler = Callable[["Context", Msg], HandlerTask] @@ -41,31 +46,35 @@ T = TypeVar("T") -class UnexpectedMessage(Exception): +class UnexpectedMessageException(Exception): """A message was received that is not part of the current workflow. Utility exception to inform the session handler that the current workflow should be aborted and a new one started as if `msg` was the first message. """ - def __init__(self, msg: codec_v1.Message) -> None: + def __init__(self, msg: Message) -> None: super().__init__() self.msg = msg -class Context: +class CodecContext(Context): """Wire context. - Represents USB communication inside a particular session on a particular interface + Represents USB communication inside a particular session (channel) on a particular interface (i.e., wire, debug, single BT connection, etc.) """ - def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None: + def __init__( + self, + iface: WireInterface, + buffer: bytearray, + ) -> None: self.iface = iface - self.sid = sid self.buffer = buffer + super().__init__(iface, codec_v1.SESSION_ID.to_bytes(2, "big")) - def read_from_wire(self) -> Awaitable[codec_v1.Message]: + def read_from_wire(self) -> Awaitable[Message]: """Read a whole message from the wire without parsing it.""" return codec_v1.read_message(self.iface, self.buffer) @@ -81,6 +90,8 @@ async def read( self, expected_types: Container[int], expected_type: type[LoadedMessageType] ) -> LoadedMessageType: ... + reading: bool = False + async def read( self, expected_types: Container[int], @@ -95,9 +106,8 @@ async def read( if __debug__: log.debug( __name__, - "%s:%x expect: %s", + "%s: expect: %s", self.iface.iface_num(), - self.sid, expected_type.MESSAGE_NAME if expected_type else expected_types, ) @@ -107,7 +117,7 @@ async def read( # If we got a message with unexpected type, raise the message via # `UnexpectedMessageError` and let the session handler deal with it. if msg.type not in expected_types: - raise UnexpectedMessage(msg) + raise UnexpectedMessageException(msg) if expected_type is None: expected_type = protobuf.type_for_wire(msg.type) @@ -115,14 +125,14 @@ async def read( if __debug__: log.debug( __name__, - "%s:%x read: %s", + "%s: read: %s", self.iface.iface_num(), - self.sid, expected_type.MESSAGE_NAME, ) # look up the protobuf class and parse the message - from . import wrap_protobuf_load + from . import message_handler # noqa: F401 + from .message_handler import wrap_protobuf_load return wrap_protobuf_load(msg.data, expected_type) @@ -131,9 +141,8 @@ async def write(self, msg: protobuf.MessageType) -> None: if __debug__: log.debug( __name__, - "%s:%x write: %s", + "%s: write: %s", self.iface.iface_num(), - self.sid, msg.MESSAGE_NAME, ) @@ -150,23 +159,19 @@ async def write(self, msg: protobuf.MessageType) -> None: buffer = bytearray(msg_size) msg_size = protobuf.encode(buffer, msg) - await codec_v1.write_message( self.iface, msg.MESSAGE_WIRE_TYPE, memoryview(buffer)[:msg_size], ) - async def call( - self, - msg: protobuf.MessageType, - expected_type: type[LoadedMessageType], - ) -> LoadedMessageType: - assert expected_type.MESSAGE_WIRE_TYPE is not None - - await self.write(msg) - del msg - return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type) + # ACCESS TO CACHE + @property + def cache(self) -> DataCache: + c = cache_codec.get_active_session() + if c is None: + raise InvalidSessionError() + return c CURRENT_CONTEXT: Context | None = None @@ -258,3 +263,65 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator: send_exc = e else: send_exc = None + + +# ACCESS TO CACHE + +if TYPE_CHECKING: + T = TypeVar("T") + + @overload + def cache_get(key: int) -> bytes | None: # noqa: F811 + ... + + @overload + def cache_get(key: int, default: T) -> bytes | T: # noqa: F811 + ... + + +def cache_get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 + cache = _get_cache_for_key(key) + return cache.get(key, default) + + +def cache_get_bool(key: int) -> bool: # noqa: F811 + cache = _get_cache_for_key(key) + return cache.get_bool(key) + + +def cache_get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811 + cache = _get_cache_for_key(key) + return cache.get_int(key, default) + + +def cache_is_set(key: int) -> bool: + cache = _get_cache_for_key(key) + return cache.is_set(key) + + +def cache_set(key: int, value: bytes) -> None: + cache = _get_cache_for_key(key) + cache.set(key, value) + + +def cache_set_bool(key: int, value: bool) -> None: + cache = _get_cache_for_key(key) + cache.set_bool(key, value) + + +def cache_set_int(key: int, value: int) -> None: + cache = _get_cache_for_key(key) + cache.set_int(key, value) + + +def cache_delete(key: int) -> None: + cache = _get_cache_for_key(key) + cache.delete(key) + + +def _get_cache_for_key(key) -> DataCache: + if key & SESSIONLESS_FLAG: + return cache.get_sessionless_cache() + if CURRENT_CONTEXT: + return CURRENT_CONTEXT.cache + raise Exception("No wire context") diff --git a/core/src/trezor/wire/errors.py b/core/src/trezor/wire/errors.py index 376820b5834..e8b2d3feb45 100644 --- a/core/src/trezor/wire/errors.py +++ b/core/src/trezor/wire/errors.py @@ -8,6 +8,12 @@ def __init__(self, code: FailureType, message: str) -> None: self.message = message +class SilentError(Exception): + def __init__(self, message: str) -> None: + super().__init__() + self.message = message + + class UnexpectedMessage(Error): def __init__(self, message: str) -> None: super().__init__(FailureType.UnexpectedMessage, message) diff --git a/core/src/trezor/wire/message_handler.py b/core/src/trezor/wire/message_handler.py new file mode 100644 index 00000000000..cc3f55edfcf --- /dev/null +++ b/core/src/trezor/wire/message_handler.py @@ -0,0 +1,254 @@ +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_common import InvalidSessionError +from trezor import log, loop, protobuf, utils, workflow +from trezor.enums import FailureType +from trezor.messages import Failure +from trezor.wire.context import Context, UnexpectedMessageException, with_context +from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage +from trezor.wire.protocol_common import Message + +# Import all errors into namespace, so that `wire.Error` is available from +# other packages. +from trezor.wire.errors import * # isort:skip # noqa: F401,F403 + + +if TYPE_CHECKING: + from typing import Any, Callable, Container + + from trezor.wire import Handler, LoadedMessageType + + HandlerFinder = Callable[[Any], Handler | None] + +# If set to False protobuf messages marked with "experimental_message" option are rejected. +EXPERIMENTAL_ENABLED = False + + +def wrap_protobuf_load( + buffer: bytes, + expected_type: type[LoadedMessageType], +) -> LoadedMessageType: + try: + if __debug__: + log.debug( + __name__, + "Buffer to be parsed to a LoadedMessage: %s", + utils.get_bytes_as_str(buffer), + ) + msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED) + if __debug__ and utils.EMULATOR: + log.debug( + __name__, "received message contents:\n%s", utils.dump_protobuf(msg) + ) + return msg + except Exception as e: + if __debug__: + log.exception(__name__, e) + if e.args: + raise DataError("Failed to decode message: " + " ".join(e.args)) + else: + raise DataError("Failed to decode message") + + +_PROTOBUF_BUFFER_SIZE = const(8192) + +WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) +if utils.USE_THP: + WIRE_BUFFER_2 = bytearray(_PROTOBUF_BUFFER_SIZE) + + +async def handle_single_message( + ctx: Context, + msg: Message, + handler_finder: HandlerFinder, +) -> bool: + """Handle a message that was loaded from USB by the caller. + + Find the appropriate handler, run it and write its result on the wire. In case + a problem is encountered at any point, write the appropriate error on the wire. + + The return value indicates whether to override the default restarting behavior. If + `False` is returned, the caller is allowed to clear the loop and restart the + MicroPython machine (see `session.py`). This would lose all state and incurs a cost + in terms of repeated startup time. When handling the message didn't cause any + significant fragmentation (e.g., if decoding the message was skipped), or if + the type of message is supposed to be optimized and not disrupt the running state, + this function will return `True`. + """ + if __debug__: + try: + msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME + except Exception: + msg_type = f"{msg.type} - unknown message type" + if ctx.channel_id is not None: + sid = int.from_bytes(ctx.channel_id, "big") + log.debug( + __name__, + "%s:%x receive: <%s>", + ctx.iface.iface_num(), + sid, + msg_type, + ) + else: + log.debug( + __name__, + "%s:unknown_sid receive: <%s>", + ctx.iface.iface_num(), + msg_type, + ) + + res_msg: protobuf.MessageType | None = None + + # We need to find a handler for this message type. + try: + handler: Handler | None = handler_finder(msg.type) + except Error as exc: + # Handlers are allowed to exception out. In that case, we can skip decoding + # and return the error. + await ctx.write(failure(exc)) + return True + + if handler is None: + # If no handler is found, we can skip decoding and directly + # respond with failure. + await ctx.write(unexpected_message()) + return True + + if msg.type in workflow.ALLOW_WHILE_LOCKED: + workflow.autolock_interrupts_workflow = False + + # Here we make sure we always respond with a Failure response + # in case of any errors. + try: + # Find a protobuf.MessageType subclass that describes this + # message. Raises if the type is not found. + req_type = protobuf.type_for_wire(msg.type) + + # Try to decode the message according to schema from + # `req_type`. Raises if the message is malformed. + req_msg = wrap_protobuf_load(msg.data, req_type) + + # Create the handler task. + task = handler(req_msg) + + # Run the workflow task. Workflow can do more on-the-wire + # communication inside, but it should eventually return a + # response message, or raise an exception (a rather common + # thing to do). Exceptions are handled in the code below. + + # Spawn a workflow around the task. This ensures that concurrent + # workflows are shut down. + res_msg = await workflow.spawn(with_context(ctx, task)) + + except UnexpectedMessageException: + # Workflow was trying to read a message from the wire, and + # something unexpected came in. See Context.read() for + # example, which expects some particular message and raises + # UnexpectedMessage if another one comes in. + # In order not to lose the message, we return it to the caller. + + # We process the unexpected message by aborting the current workflow and + # possibly starting a new one, initiated by that message. (The main usecase + # being, the host does not finish the workflow, we want other callers to + # be able to do their own thing.) + # + # The message is stored in the exception, which we re-raise for the caller + # to process. It is not a standard exception that should be logged and a result + # sent to the wire. + raise + except BaseException as exc: + # Either: + # - the message had a type that has a registered handler, but does not have + # a protobuf class + # - the message was not valid protobuf + # - workflow raised some kind of an exception while running + # - something canceled the workflow from the outside + if __debug__: + if isinstance(exc, ActionCancelled): + log.debug(__name__, "cancelled: %s", exc.message) + elif isinstance(exc, loop.TaskClosed): + log.debug(__name__, "cancelled: loop task was closed") + else: + log.exception(__name__, exc) + res_msg = failure(exc) + + if res_msg is not None: + # perform the write outside the big try-except block, so that usb write + # problem bubbles up + await ctx.write(res_msg) + + # Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting. + return msg.type in AVOID_RESTARTING_FOR + + +AVOID_RESTARTING_FOR: Container[int] = () + + +def failure(exc: BaseException) -> Failure: + if isinstance(exc, Error): + return Failure(code=exc.code, message=exc.message) + elif isinstance(exc, loop.TaskClosed): + return Failure(code=FailureType.ActionCancelled, message="Cancelled") + elif isinstance(exc, InvalidSessionError): + return Failure(code=FailureType.InvalidSession, message="Invalid session") + else: + # NOTE: when receiving generic `FirmwareError` on non-debug build, + # change the `if __debug__` to `if True` to get the full error message. + if __debug__: + message = str(exc) + else: + message = "Firmware error" + return Failure(code=FailureType.FirmwareError, message=message) + + +def unexpected_message() -> Failure: + return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message") + + +def find_handler(msg_type: int) -> Handler: + from apps import workflow_handlers + + handler = workflow_handlers.find_registered_handler(msg_type) + if handler is None: + raise UnexpectedMessage("Unexpected message") + + for filter in filters: + handler = filter(msg_type, handler) + + return handler + + +filters: list[Callable[[int, Handler], Handler]] = [] +"""Filters for the wire handler. + +Filters are applied in order. Each filter gets a message id and a preceding handler. It +must either return a handler (the same one or a modified one), or raise an exception +that gets sent to wire directly. + +Filters are not applied to debug sessions. + +The filters are designed for: + * rejecting messages -- while in Recovery mode, most messages are not allowed + * adding additional behavior -- while device is soft-locked, a PIN screen will be shown + before allowing a message to trigger its original behavior. + +For this, the filters are effectively deny-first. If an earlier filter rejects the +message, the later filters are not called. But if a filter adds behavior, the latest +filter "wins" and the latest behavior triggers first. +Please note that this behavior is really unsuited to anything other than what we are +using it for now. It might be necessary to modify the semantics if we need more complex +usecases. + +NB: `filters` is currently public so callers can have control over where they insert +new filters, but removal should be done using `remove_filter`! +We should, however, change it such that filters must be added using an `add_filter` +and `filters` becomes private! +""" + + +def remove_filter(filter): + try: + filters.remove(filter) + except ValueError: + pass diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py new file mode 100644 index 00000000000..d3c02c66a0d --- /dev/null +++ b/core/src/trezor/wire/protocol_common.py @@ -0,0 +1,73 @@ +from typing import TYPE_CHECKING + +from trezor import protobuf + +if TYPE_CHECKING: + from trezorio import WireInterface + from typing import Container, TypeVar, overload + + from storage.cache_common import DataCache + + LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) + T = TypeVar("T") + + +class Message: + + def __init__( + self, + message_type: int, + message_data: bytes, + ) -> None: + self.data = message_data + self.type = message_type + + def to_bytes(self): + return self.type.to_bytes(2, "big") + self.data + + +class Context: + def __init__(self, iface: WireInterface, channel_id: bytes) -> None: + self.iface: WireInterface = iface + self.channel_id: bytes = channel_id + + if TYPE_CHECKING: + + @overload + async def read( + self, expected_types: Container[int] + ) -> protobuf.MessageType: ... + + @overload + async def read( + self, expected_types: Container[int], expected_type: type[LoadedMessageType] + ) -> LoadedMessageType: ... + + async def read( + self, + expected_types: Container[int], + expected_type: type[protobuf.MessageType] | None = None, + ) -> protobuf.MessageType: ... + + async def write(self, msg: protobuf.MessageType) -> None: ... + + async def write_force(self, msg: protobuf.MessageType) -> None: + await self.write(msg) + + async def call( + self, + msg: protobuf.MessageType, + expected_type: type[LoadedMessageType], + ) -> LoadedMessageType: + assert expected_type.MESSAGE_WIRE_TYPE is not None + + await self.write(msg) + del msg + return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type) + + @property + def cache(self) -> DataCache: ... + + +class WireError(Exception): + pass diff --git a/core/src/trezor/wire/thp/__init__.py b/core/src/trezor/wire/thp/__init__.py new file mode 100644 index 00000000000..aa620197638 --- /dev/null +++ b/core/src/trezor/wire/thp/__init__.py @@ -0,0 +1,80 @@ +from typing import TYPE_CHECKING + +from trezor.wire.protocol_common import WireError + + +class ThpError(WireError): + pass + + +class ThpDecryptionError(ThpError): + pass + + +class ThpInvalidDataError(ThpError): + pass + + +class ThpUnallocatedSessionError(ThpError): + def __init__(self, session_id: int): + self.session_id = session_id + + +if TYPE_CHECKING: + from enum import IntEnum +else: + IntEnum = object + + +class ThpErrorType(IntEnum): + TRANSPORT_BUSY = 1 + UNALLOCATED_CHANNEL = 2 + DECRYPTION_FAILED = 3 + INVALID_DATA = 4 + + +class ChannelState(IntEnum): + UNALLOCATED = 0 + TH1 = 1 + TH2 = 2 + TP1 = 3 + TP2 = 4 + TP3 = 5 + TP4 = 6 + TC1 = 7 + ENCRYPTED_TRANSPORT = 8 + + +class SessionState(IntEnum): + UNALLOCATED = 0 + ALLOCATED = 1 + MANAGEMENT = 2 + + +class WireInterfaceType(IntEnum): + MOCK = 0 + USB = 1 + BLE = 2 + + +def is_channel_state_pairing(state: int) -> bool: + if state in ( + ChannelState.TP1, + ChannelState.TP2, + ChannelState.TP3, + ChannelState.TP4, + ChannelState.TC1, + ): + return True + return False + + +if __debug__: + + def state_to_str(state: int) -> str: + name = { + v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__") + }.get(state) + if name is not None: + return name + return "UNKNOWN_STATE" diff --git a/core/src/trezor/wire/thp/alternating_bit_protocol.py b/core/src/trezor/wire/thp/alternating_bit_protocol.py new file mode 100644 index 00000000000..17f27c0d97f --- /dev/null +++ b/core/src/trezor/wire/thp/alternating_bit_protocol.py @@ -0,0 +1,102 @@ +from storage.cache_thp import ChannelCache +from trezor import log +from trezor.wire.thp import ThpError + + +def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool: + """ + Checks if: + - an ACK message is expected + - the received ACK message acknowledges correct sequence number (bit) + """ + if not _is_ack_expected(cache): + return False + + if not _has_ack_correct_sync_bit(cache, ack_bit): + return False + + return True + + +def _is_ack_expected(cache: ChannelCache) -> bool: + is_expected: bool = not is_sending_allowed(cache) + if __debug__ and not is_expected: + log.debug(__name__, "Received unexpected ACK message") + return is_expected + + +def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool: + is_correct: bool = get_send_seq_bit(cache) == sync_bit + if __debug__ and not is_correct: + log.debug(__name__, "Received ACK message with wrong ack bit") + return is_correct + + +def is_sending_allowed(cache: ChannelCache) -> bool: + """ + Checks whether sending a message in the provided channel is allowed. + + Note: Sending a message in a channel before receipt of ACK message for the previously + sent message (in the channel) is prohibited, as it can lead to desynchronization. + """ + return bool(cache.sync >> 7) + + +def get_send_seq_bit(cache: ChannelCache) -> int: + """ + Returns the sequential number (bit) of the next message to be sent + in the provided channel. + """ + return (cache.sync & 0x20) >> 5 + + +def get_expected_receive_seq_bit(cache: ChannelCache) -> int: + """ + Returns the (expected) sequential number (bit) of the next message + to be received in the provided channel. + """ + return (cache.sync & 0x40) >> 6 + + +def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None: + """ + Set the flag whether sending a message in this channel is allowed or not. + """ + cache.sync &= 0x7F + if sending_allowed: + cache.sync |= 0x80 + + +def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None: + """ + Set the expected sequential number (bit) of the next message to be received + in the provided channel + """ + if __debug__: + log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit) + if seq_bit not in (0, 1): + raise ThpError("Unexpected receive sync bit") + + # set second bit to "seq_bit" value + cache.sync &= 0xBF + if seq_bit: + cache.sync |= 0x40 + + +def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None: + if seq_bit not in (0, 1): + raise ThpError("Unexpected send seq bit") + if __debug__: + log.debug(__name__, "setting sync send seq bit to %d", seq_bit) + # set third bit to "seq_bit" value + cache.sync &= 0xDF + if seq_bit: + cache.sync |= 0x20 + + +def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None: + """ + Set the sequential bit of the "next message to be send" to the opposite value, + i.e. 1 -> 0 and 0 -> 1 + """ + _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache)) diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py new file mode 100644 index 00000000000..07bf2b07b80 --- /dev/null +++ b/core/src/trezor/wire/thp/channel.py @@ -0,0 +1,402 @@ +import ustruct +from typing import TYPE_CHECKING + +from storage.cache_common import ( + CHANNEL_HANDSHAKE_HASH, + CHANNEL_KEY_RECEIVE, + CHANNEL_KEY_SEND, + CHANNEL_NONCE_RECEIVE, + CHANNEL_NONCE_SEND, +) +from storage.cache_thp import TAG_LENGTH, ChannelCache, clear_sessions_with_channel_id +from trezor import log, loop, protobuf, utils, workflow +from trezor.wire.thp.transmission_loop import TransmissionLoop + +from . import ChannelState, ThpDecryptionError, ThpError +from . import alternating_bit_protocol as ABP +from . import ( + control_byte, + crypto, + interface_manager, + memory_manager, + received_message_handler, +) +from .checksum import CHECKSUM_LENGTH +from .thp_messages import ENCRYPTED_TRANSPORT, PacketHeader +from .writer import ( + CONT_HEADER_LENGTH, + INIT_HEADER_LENGTH, + write_payload_to_wire_and_add_checksum, +) + +if __debug__: + from ubinascii import hexlify + + from . import state_to_str + +if TYPE_CHECKING: + from trezorio import WireInterface + from typing import Awaitable + + from .pairing_context import PairingContext + from .session_context import GenericSessionContext + + +class Channel: + def __init__(self, channel_cache: ChannelCache) -> None: + if __debug__: + log.debug(__name__, "channel initialization") + self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface) + self.channel_cache: ChannelCache = channel_cache + self.is_cont_packet_expected: bool = False + self.expected_payload_length: int = 0 + self.bytes_read: int = 0 + self.buffer: utils.BufferType + self.channel_id: bytes = channel_cache.channel_id + self.selected_pairing_methods = [] + self.sessions: dict[int, GenericSessionContext] = {} + self.write_task_spawn: loop.spawn | None = None + self.connection_context: PairingContext | None = None + self.transmission_loop: TransmissionLoop | None = None + self.handshake: crypto.Handshake | None = None + + def clear(self): + clear_sessions_with_channel_id(self.channel_id) + self.channel_cache.clear() + + # ACCESS TO CHANNEL_DATA + def get_channel_id_int(self) -> int: + return int.from_bytes(self.channel_id, "big") + + def get_channel_state(self) -> int: + state = int.from_bytes(self.channel_cache.state, "big") + if __debug__: + log.debug( + __name__, + "(cid: %s) get_channel_state: %s", + utils.get_bytes_as_str(self.channel_id), + state_to_str(state), + ) + return state + + def get_handshake_hash(self) -> bytes: + h = self.channel_cache.get(CHANNEL_HANDSHAKE_HASH) + assert h is not None + return h + + def set_channel_state(self, state: ChannelState) -> None: + self.channel_cache.state = bytearray(state.to_bytes(1, "big")) + if __debug__: + log.debug( + __name__, + "(cid: %s) set_channel_state: %s", + utils.get_bytes_as_str(self.channel_id), + state_to_str(state), + ) + + def set_buffer(self, buffer: utils.BufferType) -> None: + self.buffer = buffer + if __debug__: + log.debug( + __name__, + "(cid: %s) set_buffer: %s", + utils.get_bytes_as_str(self.channel_id), + type(self.buffer), + ) + + # CALLED BY THP_MAIN_LOOP + + def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None: + if __debug__: + log.debug( + __name__, + "(cid: %s) receive_packet", + utils.get_bytes_as_str(self.channel_id), + ) + + self._handle_received_packet(packet) + + if __debug__: + log.debug( + __name__, + "(cid: %s) self.buffer: %s", + utils.get_bytes_as_str(self.channel_id), + utils.get_bytes_as_str(self.buffer), + ) + + if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read: + self._finish_message() + return received_message_handler.handle_received_message(self, self.buffer) + elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read: + self.is_cont_packet_expected = True + else: + raise ThpError( + "Read more bytes than is the expected length of the message!" + ) + return None + + def _handle_received_packet(self, packet: utils.BufferType) -> None: + ctrl_byte = packet[0] + if control_byte.is_continuation(ctrl_byte): + return self._handle_cont_packet(packet) + return self._handle_init_packet(packet) + + def _handle_init_packet(self, packet: utils.BufferType) -> None: + if __debug__: + log.debug( + __name__, + "(cid: %s) handle_init_packet", + utils.get_bytes_as_str(self.channel_id), + ) + # ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) # TODO use this with single packet decryption + _, _, payload_length = ustruct.unpack(">BHH", packet) + self.expected_payload_length = payload_length + packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:] + + # If the channel does not "own" the buffer lock, decrypt first packet + # TODO do it only when needed! + # TODO FIX: If "_decrypt_single_packet_payload" is implemented, it will (possibly) break "decrypt_buffer" and nonces incrementation. + # On the other hand, without the single packet decryption, the "advanced" buffer selection cannot be implemented + # in "memory_manager.select_buffer", because the session id is unknown (encrypted). + + # if control_byte.is_encrypted_transport(ctrl_byte): + # packet_payload = self._decrypt_single_packet_payload(packet_payload) + + self.buffer = memory_manager.select_buffer( + self.get_channel_state(), + self.buffer, + packet_payload, + payload_length, + ) + + if __debug__: + log.debug( + __name__, + "(cid: %s) handle_init_packet - payload len: %d", + utils.get_bytes_as_str(self.channel_id), + payload_length, + ) + log.debug( + __name__, + "(cid: %s) handle_init_packet - buffer len: %d", + utils.get_bytes_as_str(self.channel_id), + len(self.buffer), + ) + return self._buffer_packet_data(self.buffer, packet, 0) + + def _handle_cont_packet(self, packet: utils.BufferType) -> None: + if __debug__: + log.debug( + __name__, + "(cid: %s) handle_cont_packet", + utils.get_bytes_as_str(self.channel_id), + ) + if not self.is_cont_packet_expected: + raise ThpError("Continuation packet is not expected, ignoring") + return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH) + + def _decrypt_single_packet_payload( + self, payload: utils.BufferType + ) -> utils.BufferType: + # crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload)) + return payload + + def decrypt_buffer( + self, message_length: int, offset: int = INIT_HEADER_LENGTH + ) -> None: + noise_buffer = memoryview(self.buffer)[ + offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH + ] + tag = self.buffer[ + message_length + - CHECKSUM_LENGTH + - TAG_LENGTH : message_length + - CHECKSUM_LENGTH + ] + if utils.DISABLE_ENCRYPTION: + is_tag_valid = tag == crypto.DUMMY_TAG + else: + key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE) + nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE) + + assert key_receive is not None + assert nonce_receive is not None + if __debug__: + log.debug( + __name__, + "(cid: %s) Buffer before decryption: %s", + utils.get_bytes_as_str(self.channel_id), + hexlify(noise_buffer), + ) + is_tag_valid = crypto.dec( + noise_buffer, tag, key_receive, nonce_receive, b"" + ) + if __debug__: + log.debug( + __name__, + "(cid: %s) Buffer after decryption: %s", + utils.get_bytes_as_str(self.channel_id), + hexlify(noise_buffer), + ) + + self.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, nonce_receive + 1) + + if __debug__: + log.debug( + __name__, + "(cid: %s) Is decrypted tag valid? %s", + utils.get_bytes_as_str(self.channel_id), + str(is_tag_valid), + ) + log.debug( + __name__, + "(cid: %s) Received tag: %s", + utils.get_bytes_as_str(self.channel_id), + (hexlify(tag).decode()), + ) + log.debug( + __name__, + "(cid: %s) New nonce_receive: %i", + utils.get_bytes_as_str(self.channel_id), + nonce_receive + 1, + ) + + if not is_tag_valid: + raise ThpDecryptionError() + + def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None: + if __debug__: + log.debug( + __name__, "(cid: %s) encrypt", utils.get_bytes_as_str(self.channel_id) + ) + assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH + + noise_buffer = memoryview(buffer)[0:noise_payload_len] + + if utils.DISABLE_ENCRYPTION: + tag = crypto.DUMMY_TAG + else: + key_send = self.channel_cache.get(CHANNEL_KEY_SEND) + nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND) + + assert key_send is not None + assert nonce_send is not None + + tag = crypto.enc(noise_buffer, key_send, nonce_send, b"") + + self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1) + if __debug__: + log.debug(__name__, "New nonce_send: %i", nonce_send + 1) + + buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag + + def _buffer_packet_data( + self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int + ): + self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset) + + def _finish_message(self): + self.bytes_read = 0 + self.expected_payload_length = 0 + self.is_cont_packet_expected = False + + # CALLED BY WORKFLOW / SESSION CONTEXT + + async def write( + self, + msg: protobuf.MessageType, + session_id: int = 0, + force: bool = False, + ) -> None: + if __debug__ and utils.EMULATOR: + log.debug( + __name__, + "(cid: %s) write message: %s\n%s", + utils.get_bytes_as_str(self.channel_id), + msg.MESSAGE_NAME, + utils.dump_protobuf(msg), + ) + + self.buffer = memory_manager.get_write_buffer(self.buffer, msg) + noise_payload_len = memory_manager.encode_into_buffer( + self.buffer, msg, session_id + ) + task = self.write_and_encrypt(self.buffer[:noise_payload_len], force) + if task is not None: + await task + + def write_error(self, err_type: int) -> Awaitable[None]: + msg_data = err_type.to_bytes(1, "big") + length = len(msg_data) + CHECKSUM_LENGTH + header = PacketHeader.get_error_header(self.get_channel_id_int(), length) + return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data) + + def write_and_encrypt( + self, payload: bytes, force: bool = False + ) -> Awaitable[None] | None: + payload_length = len(payload) + self._encrypt(self.buffer, payload_length) + payload_length = payload_length + TAG_LENGTH + + if self.write_task_spawn is not None: + self.write_task_spawn.close() # UPS TODO might break something + print("\nCLOSED\n") + self._prepare_write() + if force: + if __debug__: + log.debug( + __name__, "Writing FORCE message (without async or retransmission)." + ) + return self._write_encrypted_payload_loop( + ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length]) + ) + self.write_task_spawn = loop.spawn( + self._write_encrypted_payload_loop( + ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length]) + ) + ) + return None + + def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: + self._prepare_write() + self.write_task_spawn = loop.spawn( + self._write_encrypted_payload_loop(ctrl_byte, payload) + ) + + def _prepare_write(self) -> None: + # TODO add condition that disallows to write when can_send_message is false + ABP.set_sending_allowed(self.channel_cache, False) + + async def _write_encrypted_payload_loop( + self, ctrl_byte: int, payload: bytes + ) -> None: + if __debug__: + log.debug( + __name__, + "(cid %s) write_encrypted_payload_loop", + utils.get_bytes_as_str(self.channel_id), + ) + payload_len = len(payload) + CHECKSUM_LENGTH + sync_bit = ABP.get_send_seq_bit(self.channel_cache) + ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit) + header = PacketHeader(ctrl_byte, self.get_channel_id_int(), payload_len) + self.transmission_loop = TransmissionLoop(self, header, payload) + await self.transmission_loop.start() + + ABP.set_send_seq_bit_to_opposite(self.channel_cache) + + # Let the main loop be restarted and clear loop, if there is no other + # workflow and the state is ENCRYPTED_TRANSPORT + if self._can_clear_loop(): + if __debug__: + log.debug( + __name__, + "(cid: %s) clearing loop from channel", + utils.get_bytes_as_str(self.channel_id), + ) + loop.clear() + + def _can_clear_loop(self) -> bool: + return ( + not workflow.tasks + ) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT diff --git a/core/src/trezor/wire/thp/channel_manager.py b/core/src/trezor/wire/thp/channel_manager.py new file mode 100644 index 00000000000..a48f6d7fdb4 --- /dev/null +++ b/core/src/trezor/wire/thp/channel_manager.py @@ -0,0 +1,34 @@ +from typing import TYPE_CHECKING + +from storage import cache_thp +from trezor import utils + +from . import ChannelState, interface_manager +from .channel import Channel + +if TYPE_CHECKING: + from trezorio import WireInterface + + +def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> Channel: + """ + Creates a new channel for the interface `iface` with the buffer `buffer`. + """ + channel_cache = cache_thp.get_new_channel(interface_manager.encode_iface(iface)) + r = Channel(channel_cache) + r.set_buffer(buffer) + r.set_channel_state(ChannelState.TH1) + return r + + +def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: + """ + Returns all allocated channels from cache. + """ + channels: dict[int, Channel] = {} + cached_channels = cache_thp.get_all_allocated_channels() + for c in cached_channels: + channels[int.from_bytes(c.channel_id, "big")] = Channel(c) + for c in channels.values(): + c.set_buffer(buffer) + return channels diff --git a/core/src/trezor/wire/thp/checksum.py b/core/src/trezor/wire/thp/checksum.py new file mode 100644 index 00000000000..9c28f2e78d8 --- /dev/null +++ b/core/src/trezor/wire/thp/checksum.py @@ -0,0 +1,22 @@ +from micropython import const + +from trezor import utils +from trezor.crypto import crc + +CHECKSUM_LENGTH = const(4) + + +def compute(data: bytes | utils.BufferType) -> bytes: + """ + Returns a CRC-32 checksum of the provided `data`. + """ + return crc.crc32(data).to_bytes(CHECKSUM_LENGTH, "big") + + +def is_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool: + """ + Checks whether the CRC-32 checksum of the `data` is the same + as the checksum provided in `checksum`. + """ + data_checksum = compute(data) + return checksum == data_checksum diff --git a/core/src/trezor/wire/thp/control_byte.py b/core/src/trezor/wire/thp/control_byte.py new file mode 100644 index 00000000000..cf3162587bc --- /dev/null +++ b/core/src/trezor/wire/thp/control_byte.py @@ -0,0 +1,47 @@ +from trezor.wire.thp import ThpError +from trezor.wire.thp.thp_messages import ( + ACK_MASK, + ACK_MESSAGE, + CONTINUATION_PACKET, + CONTINUATION_PACKET_MASK, + DATA_MASK, + ENCRYPTED_TRANSPORT, + HANDSHAKE_COMP_REQ, + HANDSHAKE_INIT_REQ, +) + + +def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int: + if seq_bit == 0: + return ctrl_byte & 0xEF + if seq_bit == 1: + return ctrl_byte | 0x10 + raise ThpError("Unexpected sequence bit") + + +def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int: + if ack_bit == 0: + return ctrl_byte & 0xF7 + if ack_bit == 1: + return ctrl_byte | 0x08 + raise ThpError("Unexpected acknowledgement bit") + + +def is_ack(ctrl_byte: int) -> bool: + return ctrl_byte & ACK_MASK == ACK_MESSAGE + + +def is_continuation(ctrl_byte: int) -> bool: + return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET + + +def is_encrypted_transport(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT + + +def is_handshake_init_req(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ + + +def is_handshake_comp_req(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ diff --git a/core/src/trezor/wire/thp/cpace.py b/core/src/trezor/wire/thp/cpace.py new file mode 100644 index 00000000000..302dd3e5e37 --- /dev/null +++ b/core/src/trezor/wire/thp/cpace.py @@ -0,0 +1,36 @@ +from trezor.crypto import elligator2, random +from trezor.crypto.curve import curve25519 +from trezor.crypto.hashlib import sha512 + +_PREFIX = b"\x08\x43\x50\x61\x63\x65\x32\x35\x35\x06" +_PADDING = b"\x6f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x20" + + +class Cpace: + """ + CPace, a balanced composable PAKE: https://datatracker.ietf.org/doc/draft-irtf-cfrg-cpace/ + """ + + def __init__(self, cpace_host_public_key: bytes, handshake_hash: bytes) -> None: + self.handshake_hash: bytes = handshake_hash + self.host_public_key: bytes = cpace_host_public_key + self.shared_secret: bytes + self.trezor_private_key: bytes + self.trezor_public_key: bytes + + def generate_keys_and_secret(self, code_code_entry: bytes) -> None: + """ + Generate ephemeral key pair and a shared secret using Elligator2 with X25519. + """ + sha_ctx = sha512(_PREFIX) + sha_ctx.update(code_code_entry) + sha_ctx.update(_PADDING) + sha_ctx.update(self.handshake_hash) + sha_ctx.update(b"\x00") + pregenerator = sha_ctx.digest()[:32] + generator = elligator2.map_to_curve25519(pregenerator) + self.trezor_private_key = random.bytes(32) + self.trezor_public_key = curve25519.multiply(self.trezor_private_key, generator) + self.shared_secret = curve25519.multiply( + self.trezor_private_key, self.host_public_key + ) diff --git a/core/src/trezor/wire/thp/crypto.py b/core/src/trezor/wire/thp/crypto.py new file mode 100644 index 00000000000..ba211490f62 --- /dev/null +++ b/core/src/trezor/wire/thp/crypto.py @@ -0,0 +1,211 @@ +from micropython import const +from trezorcrypto import aesgcm, bip32, curve25519, hmac + +from storage import device +from trezor import log, utils +from trezor.crypto.hashlib import sha256 +from trezor.wire.thp import ThpDecryptionError + +# The HARDENED flag is taken from apps.common.paths +# It is not imported to save on resources +HARDENED = const(0x8000_0000) +PUBKEY_LENGTH = const(32) +if utils.DISABLE_ENCRYPTION: + DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5" + +if __debug__: + from ubinascii import hexlify + + +def enc(buffer: utils.BufferType, key: bytes, nonce: int, auth_data: bytes) -> bytes: + """ + Encrypts the provided `buffer` with AES-GCM (in place). + Returns a 16-byte long encryption tag. + """ + if __debug__: + log.debug(__name__, "enc (key: %s, nonce: %d)", hexlify(key), nonce) + iv = _get_iv_from_nonce(nonce) + aes_ctx = aesgcm(key, iv) + aes_ctx.auth(auth_data) + aes_ctx.encrypt_in_place(buffer) + return aes_ctx.finish() + + +def dec( + buffer: utils.BufferType, tag: bytes, key: bytes, nonce: int, auth_data: bytes +) -> bool: + """ + Decrypts the provided buffer (in place). Returns `True` if the provided authentication `tag` is the same as + the tag computed in decryption, otherwise it returns `False`. + """ + iv = _get_iv_from_nonce(nonce) + if __debug__: + log.debug(__name__, "dec (key: %s, nonce: %d)", hexlify(key), nonce) + aes_ctx = aesgcm(key, iv) + aes_ctx.auth(auth_data) + aes_ctx.decrypt_in_place(buffer) + computed_tag = aes_ctx.finish() + return computed_tag == tag + + +class BusyDecoder: + def __init__(self, key: bytes, nonce: int, auth_data: bytes) -> None: + iv = _get_iv_from_nonce(nonce) + self.aes_ctx = aesgcm(key, iv) + self.aes_ctx.auth(auth_data) + + def decrypt_part(self, part: utils.BufferType) -> None: + self.aes_ctx.decrypt_in_place(part) + + def finish_and_check_tag(self, tag: bytes) -> bool: + computed_tag = self.aes_ctx.finish() + return computed_tag == tag + + +PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" +IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + + +class Handshake: + """ + `Handshake` holds (temporary) values and keys that are used during the creation of an encrypted channel. + The following values should be saved for future use before disposing of this object: + - `h` (handshake hash, can be used to bind other values to the channel) + - `key_receive` (key for decrypting incoming communication) + - `key_send` (key for encrypting outgoing communication) + """ + + def __init__(self) -> None: + self.trezor_ephemeral_privkey: bytes + self.ck: bytes + self.k: bytes + self.h: bytes + self.key_receive: bytes + self.key_send: bytes + + def handle_th1_crypto( + self, + device_properties: bytes, + host_ephemeral_pubkey: bytes, + ) -> tuple[bytes, bytes, bytes]: + + trezor_static_privkey, trezor_static_pubkey = _derive_static_key_pair() + self.trezor_ephemeral_privkey = curve25519.generate_secret() + trezor_ephemeral_pubkey = curve25519.publickey(self.trezor_ephemeral_privkey) + self.h = _hash_of_two(PROTOCOL_NAME, device_properties) + self.h = _hash_of_two(self.h, host_ephemeral_pubkey) + self.h = _hash_of_two(self.h, trezor_ephemeral_pubkey) + point = curve25519.multiply( + self.trezor_ephemeral_privkey, host_ephemeral_pubkey + ) + self.ck, self.k = _hkdf(PROTOCOL_NAME, point) + mask = _hash_of_two(trezor_static_pubkey, trezor_ephemeral_pubkey) + trezor_masked_static_pubkey = curve25519.multiply(mask, trezor_static_pubkey) + aes_ctx = aesgcm(self.k, IV_1) + encrypted_trezor_static_pubkey = aes_ctx.encrypt(trezor_masked_static_pubkey) + if __debug__: + log.debug(__name__, "th1 - enc (key: %s, nonce: %d)", hexlify(self.k), 0) + aes_ctx.auth(self.h) + tag_to_encrypted_key = aes_ctx.finish() + encrypted_trezor_static_pubkey = ( + encrypted_trezor_static_pubkey + tag_to_encrypted_key + ) + self.h = _hash_of_two(self.h, encrypted_trezor_static_pubkey) + point = curve25519.multiply(trezor_static_privkey, host_ephemeral_pubkey) + self.ck, self.k = _hkdf(self.ck, curve25519.multiply(mask, point)) + aes_ctx = aesgcm(self.k, IV_1) + aes_ctx.auth(self.h) + tag = aes_ctx.finish() + self.h = _hash_of_two(self.h, tag) + return (trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag) + + def handle_th2_crypto( + self, + encrypted_host_static_pubkey: utils.BufferType, + encrypted_payload: utils.BufferType, + ): + + aes_ctx = aesgcm(self.k, IV_2) + + # The new value of hash `h` MUST be computed before the `encrypted_host_static_pubkey` is decrypted. + # However, decryption of `encrypted_host_static_pubkey` MUST use the previous value of `h` for + # authentication of the gcm tag. + aes_ctx.auth(self.h) # Authenticate with the previous value of `h` + self.h = _hash_of_two(self.h, encrypted_host_static_pubkey) # Compute new value + aes_ctx.decrypt_in_place( + memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH] + ) + if __debug__: + log.debug(__name__, "th2 - dec (key: %s, nonce: %d)", hexlify(self.k), 1) + host_static_pubkey = memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH] + tag = aes_ctx.finish() + if tag != encrypted_host_static_pubkey[-16:]: + raise ThpDecryptionError() + + self.ck, self.k = _hkdf( + self.ck, + curve25519.multiply(self.trezor_ephemeral_privkey, host_static_pubkey), + ) + aes_ctx = aesgcm(self.k, IV_1) + aes_ctx.auth(self.h) + aes_ctx.decrypt_in_place(memoryview(encrypted_payload)[:-16]) + if __debug__: + log.debug(__name__, "th2 - dec (key: %s, nonce: %d)", hexlify(self.k), 0) + tag = aes_ctx.finish() + if tag != encrypted_payload[-16:]: + raise ThpDecryptionError() + + self.h = _hash_of_two(self.h, memoryview(encrypted_payload)[:-16]) + self.key_receive, self.key_send = _hkdf(self.ck, b"") + if __debug__: + log.debug( + __name__, + "(key_receive: %s, key_send: %s)", + hexlify(self.key_receive), + hexlify(self.key_send), + ) + + def get_handshake_completion_response(self, trezor_state: bytes) -> bytes: + aes_ctx = aesgcm(self.key_send, IV_1) + encrypted_trezor_state = aes_ctx.encrypt(trezor_state) + tag = aes_ctx.finish() + return encrypted_trezor_state + tag + + +def _derive_static_key_pair() -> tuple[bytes, bytes]: + node_int = HARDENED | int.from_bytes(b"\x00THP", "big") + node = bip32.from_seed(device.get_device_secret(), "curve25519") + node.derive(node_int) + + trezor_static_privkey = node.private_key() + trezor_static_pubkey = node.public_key()[1:33] + # Note: the first byte (\x01) of the public key is removed, as it + # only indicates the type of the elliptic curve used + + return trezor_static_privkey, trezor_static_pubkey + + +def get_trezor_static_pubkey() -> bytes: + _, pubkey = _derive_static_key_pair() + return pubkey + + +def _hkdf(chaining_key, input: bytes): + temp_key = hmac(hmac.SHA256, chaining_key, input).digest() + output_1 = hmac(hmac.SHA256, temp_key, b"\x01").digest() + ctx_output_2 = hmac(hmac.SHA256, temp_key, output_1) + ctx_output_2.update(b"\x02") + output_2 = ctx_output_2.digest() + return (output_1, output_2) + + +def _hash_of_two(part_1: bytes, part_2: bytes) -> bytes: + ctx = sha256(part_1) + ctx.update(part_2) + return ctx.digest() + + +def _get_iv_from_nonce(nonce: int) -> bytes: + utils.ensure(nonce <= 0xFFFFFFFFFFFFFFFF, "Nonce overflow, terminate the channel") + return bytes(4) + nonce.to_bytes(8, "big") diff --git a/core/src/trezor/wire/thp/interface_manager.py b/core/src/trezor/wire/thp/interface_manager.py new file mode 100644 index 00000000000..f71dae0d615 --- /dev/null +++ b/core/src/trezor/wire/thp/interface_manager.py @@ -0,0 +1,30 @@ +from typing import TYPE_CHECKING + +import usb + +_MOCK_INTERFACE_HID = b"\x00" +_WIRE_INTERFACE_USB = b"\x01" + +if TYPE_CHECKING: + from trezorio import WireInterface + + +def decode_iface(cached_iface: bytes) -> WireInterface: + """Decode the cached wire interface.""" + if cached_iface == _WIRE_INTERFACE_USB: + iface = usb.iface_wire + if iface is None: + raise RuntimeError("There is no valid USB WireInterface") + return iface + # TODO implement bluetooth interface + raise Exception("Unknown WireInterface") + + +def encode_iface(iface: WireInterface) -> bytes: + """Encode wire interface into bytes.""" + if iface is usb.iface_wire: + return _WIRE_INTERFACE_USB + # TODO implement bluetooth interface + if __debug__: + return _MOCK_INTERFACE_HID + raise Exception("Unknown WireInterface") diff --git a/core/src/trezor/wire/thp/memory_manager.py b/core/src/trezor/wire/thp/memory_manager.py new file mode 100644 index 00000000000..cbaef4a0e8b --- /dev/null +++ b/core/src/trezor/wire/thp/memory_manager.py @@ -0,0 +1,174 @@ +from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH +from trezor import log, protobuf, utils + +from . import ChannelState, ThpError +from .checksum import CHECKSUM_LENGTH +from .writer import ( + INIT_HEADER_LENGTH, + MAX_PAYLOAD_LEN, + MESSAGE_TYPE_LENGTH, + PACKET_LENGTH, +) + + +def select_buffer( + channel_state: int, + channel_buffer: utils.BufferType, + packet_payload: utils.BufferType, + payload_length: int, +) -> utils.BufferType: + + if channel_state is ChannelState.ENCRYPTED_TRANSPORT: + session_id = packet_payload[0] + if session_id == 0: + pass + # TODO use small buffer + else: + pass + # TODO use big buffer but only if the channel owns the buffer lock. + # Otherwise send BUSY message and return + else: + pass + # TODO use small buffer + try: + # TODO for now, we create a new big buffer every time. It should be changed + buffer: utils.BufferType = _get_buffer_for_read(payload_length, channel_buffer) + return buffer + except Exception as e: + if __debug__: + log.exception(__name__, e) + raise Exception("Failed to create a buffer for channel") # TODO handle better + + +def get_write_buffer( + buffer: utils.BufferType, msg: protobuf.MessageType +) -> utils.BufferType: + msg_size = protobuf.encoded_length(msg) + payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size + required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH + + if required_min_size > len(buffer): + return _get_buffer_for_write(required_min_size, buffer) + return buffer + + +def encode_into_buffer( + buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int +) -> int: + + # cannot write message without wire type + assert msg.MESSAGE_WIRE_TYPE is not None + + msg_size = protobuf.encoded_length(msg) + payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size + + _encode_session_into_buffer(memoryview(buffer), session_id) + _encode_message_type_into_buffer( + memoryview(buffer), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH + ) + _encode_message_into_buffer( + memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + ) + + return payload_size + + +def _encode_session_into_buffer( + buffer: memoryview, session_id: int, buffer_offset: int = 0 +) -> None: + session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big") + utils.memcpy(buffer, buffer_offset, session_id_bytes, 0) + + +def _encode_message_type_into_buffer( + buffer: memoryview, message_type: int, offset: int = 0 +) -> None: + msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big") + utils.memcpy(buffer, offset, msg_type_bytes, 0) + + +def _encode_message_into_buffer( + buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0 +) -> None: + protobuf.encode(memoryview(buffer[buffer_offset:]), message) + + +def _get_buffer_for_read( + payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN +) -> utils.BufferType: + length = payload_length + INIT_HEADER_LENGTH + if __debug__: + log.debug( + __name__, + "get_buffer_for_read - length: %d, %s %s", + length, + "existing buffer type:", + type(existing_buffer), + ) + if length > max_length: + raise ThpError("Message too large") + + if length > len(existing_buffer): + if __debug__: + log.debug(__name__, "Allocating a new buffer") + + from ..thp_main import get_raw_read_buffer + + if length > len(get_raw_read_buffer()): + if __debug__: + log.debug( + __name__, + "Required length is %d, where raw buffer has capacity only %d", + length, + len(get_raw_read_buffer()), + ) + raise ThpError("Message is too large") + + try: + payload: utils.BufferType = memoryview(get_raw_read_buffer())[:length] + except MemoryError: + payload = memoryview(get_raw_read_buffer())[:PACKET_LENGTH] + raise ThpError("Message is too large") + return payload + + # reuse a part of the supplied buffer + if __debug__: + log.debug(__name__, "Reusing already allocated buffer") + return memoryview(existing_buffer)[:length] + + +def _get_buffer_for_write( + payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN +) -> utils.BufferType: + length = payload_length + INIT_HEADER_LENGTH + if __debug__: + log.debug( + __name__, + "get_buffer_for_write - length: %d, %s %s", + length, + "existing buffer type:", + type(existing_buffer), + ) + if length > max_length: + raise ThpError("Message too large") + + if length > len(existing_buffer): + if __debug__: + log.debug(__name__, "Creating a new write buffer from raw write buffer") + + from ..thp_main import get_raw_write_buffer + + if length > len(get_raw_write_buffer()): + raise ThpError("Message is too large") + + try: + payload: utils.BufferType = memoryview(get_raw_write_buffer())[:length] + except MemoryError: + payload = memoryview(get_raw_write_buffer())[:PACKET_LENGTH] + raise ThpError("Message is too large") + return payload + + # reuse a part of the supplied buffer + if __debug__: + log.debug(__name__, "Reusing already allocated buffer") + return memoryview(existing_buffer)[:length] diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py new file mode 100644 index 00000000000..4a13ec6ea01 --- /dev/null +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -0,0 +1,251 @@ +from typing import TYPE_CHECKING +from ubinascii import hexlify + +import trezorui2 +from trezor import loop, protobuf, workflow +from trezor.crypto import random +from trezor.wire import context, message_handler, protocol_common +from trezor.wire.context import UnexpectedMessageException +from trezor.wire.errors import ActionCancelled, SilentError +from trezor.wire.protocol_common import Context, Message + +if TYPE_CHECKING: + from typing import Container + + from trezor import ui + + from .channel import Channel + from .cpace import Cpace + + pass + +if __debug__: + from trezor import log + + +class PairingDisplayData: + + def __init__(self) -> None: + self.code_code_entry: int | None = None + self.code_qr_code: bytes | None = None + self.code_nfc_unidirectional: bytes | None = None + + def get_display_layout(self) -> ui.Layout: + from trezor import ui + + # TODO have different layouts when there is only QR code or only Code Entry + qr_str = "" + code_str = "" + if self.code_qr_code is not None: + qr_str = self._get_code_qr_code_str() + if self.code_code_entry is not None: + code_str = self._get_code_code_entry_str() + + return ui.Layout( + trezorui2.show_address_details( # noqa + qr_title="Scan QR code to pair", + address=qr_str, + case_sensitive=True, + details_title="", + account="Code to rewrite:\n" + code_str, + path="", + xpubs=[], + ) + ) + + def _get_code_code_entry_str(self) -> str: + if self.code_code_entry is not None: + code_str = f"{self.code_code_entry:06}" + if __debug__: + log.debug(__name__, "code_code_entry: %s", code_str) + + return code_str[:3] + " " + code_str[3:] + raise Exception("Code entry string is not available") + + def _get_code_qr_code_str(self) -> str: + if self.code_qr_code is not None: + code_str = (hexlify(self.code_qr_code)).decode("utf-8") + if __debug__: + log.debug(__name__, "code_qr_code_hexlified: %s", code_str) + return code_str + raise Exception("QR code string is not available") + + +class PairingContext(Context): + + def __init__(self, channel_ctx: Channel) -> None: + super().__init__(channel_ctx.iface, channel_ctx.channel_id) + self.channel_ctx: Channel = channel_ctx + self.incoming_message = loop.chan() + self.secret: bytes = random.bytes(16) + + self.display_data: PairingDisplayData = PairingDisplayData() + self.cpace: Cpace + self.host_name: str + + async def handle(self, is_debug_session: bool = False) -> None: + # if __debug__: + # log.debug(__name__, "handle - start") + # if is_debug_session: + # import apps.debug + + # apps.debug.DEBUG_CONTEXT = self + + take = self.incoming_message.take() + next_message: Message | None = None + + while True: + try: + if next_message is None: + # If the previous run did not keep an unprocessed message for us, + # wait for a new one. + try: + message: Message = await take + except protocol_common.WireError as e: + if __debug__: + log.exception(__name__, e) + await self.write(message_handler.failure(e)) + continue + else: + # Process the message from previous run. + message = next_message + next_message = None + + try: + next_message = await handle_pairing_request_message(self, message) + except Exception as exc: + # Log and ignore. The session handler can only exit explicitly in the + # following finally block. + if __debug__: + log.exception(__name__, exc) + finally: + # Unload modules imported by the workflow. Should not raise. + # This is not done for the debug session because the snapshot taken + # in a debug session would clear modules which are in use by the + # workflow running on wire. + # TODO utils.unimport_end(modules) + + if next_message is None: + + # Shut down the loop if there is no next message waiting. + return # pylint: disable=lost-exception + + except Exception as exc: + # Log and try again. The session handler can only exit explicitly via + # loop.clear() above. # TODO not updated comments + if __debug__: + log.exception(__name__, exc) + + async def read( + self, + expected_types: Container[int], + expected_type: type[protobuf.MessageType] | None = None, + ) -> protobuf.MessageType: + if __debug__: + exp_type: str = str(expected_type) + if expected_type is not None: + exp_type = expected_type.MESSAGE_NAME + log.debug( + __name__, + "Read - with expected types %s and expected type %s", + str(expected_types), + exp_type, + ) + + message: Message = await self.incoming_message.take() + + if message.type not in expected_types: + raise UnexpectedMessageException(message) + + if expected_type is None: + expected_type = protobuf.type_for_wire(message.type) + + return message_handler.wrap_protobuf_load(message.data, expected_type) + + async def write(self, msg: protobuf.MessageType) -> None: + return await self.channel_ctx.write(msg) + + async def call( + self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType] + ) -> protobuf.MessageType: + assert expected_type.MESSAGE_WIRE_TYPE is not None + + await self.write(msg) + del msg + + return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type) + + async def call_any( + self, msg: protobuf.MessageType, *expected_types: int + ) -> protobuf.MessageType: + await self.write(msg) + del msg + return await self.read(expected_types) + + +async def handle_pairing_request_message( + pairing_ctx: PairingContext, + msg: protocol_common.Message, +) -> protocol_common.Message | None: + + res_msg: protobuf.MessageType | None = None + + from apps.thp.pairing import handle_pairing_request + + if msg.type in workflow.ALLOW_WHILE_LOCKED: + workflow.autolock_interrupts_workflow = False + + # Here we make sure we always respond with a Failure response + # in case of any errors. + try: + # Find a protobuf.MessageType subclass that describes this + # message. Raises if the type is not found. + req_type = protobuf.type_for_wire(msg.type) + + # Try to decode the message according to schema from + # `req_type`. Raises if the message is malformed. + req_msg = message_handler.wrap_protobuf_load(msg.data, req_type) + + # Create the handler task. + task = handle_pairing_request(pairing_ctx, req_msg) + + # Run the workflow task. Workflow can do more on-the-wire + # communication inside, but it should eventually return a + # response message, or raise an exception (a rather common + # thing to do). Exceptions are handled in the code below. + res_msg = await workflow.spawn(context.with_context(pairing_ctx, task)) + + except UnexpectedMessageException as exc: + # Workflow was trying to read a message from the wire, and + # something unexpected came in. See Context.read() for + # example, which expects some particular message and raises + # UnexpectedMessage if another one comes in. + # In order not to lose the message, we return it to the caller. + # TODO: + # We might handle only the few common cases here, like + # Initialize and Cancel. + return exc.msg + except SilentError as exc: + if __debug__: + log.error(__name__, "SilentError: %s", exc.message) + except BaseException as exc: + # Either: + # - the message had a type that has a registered handler, but does not have + # a protobuf class + # - the message was not valid protobuf + # - workflow raised some kind of an exception while running + # - something canceled the workflow from the outside + if __debug__: + if isinstance(exc, ActionCancelled): + log.debug(__name__, "cancelled: %s", exc.message) + elif isinstance(exc, loop.TaskClosed): + log.debug(__name__, "cancelled: loop task was closed") + else: + log.exception(__name__, exc) + res_msg = message_handler.failure(exc) + + if res_msg is not None: + # perform the write outside the big try-except block, so that usb write + # problem bubbles up + await pairing_ctx.write(res_msg) + return None diff --git a/core/src/trezor/wire/thp/received_message_handler.py b/core/src/trezor/wire/thp/received_message_handler.py new file mode 100644 index 00000000000..d6ed3e05cfa --- /dev/null +++ b/core/src/trezor/wire/thp/received_message_handler.py @@ -0,0 +1,413 @@ +import ustruct +from typing import TYPE_CHECKING + +from storage.cache_common import ( + CHANNEL_HANDSHAKE_HASH, + CHANNEL_KEY_RECEIVE, + CHANNEL_KEY_SEND, + CHANNEL_NONCE_RECEIVE, + CHANNEL_NONCE_SEND, +) +from storage.cache_thp import ( + KEY_LENGTH, + MANAGEMENT_SESSION_ID, + SESSION_ID_LENGTH, + TAG_LENGTH, + update_channel_last_used, + update_session_last_used, +) +from trezor import log, loop, utils +from trezor.enums import FailureType +from trezor.messages import Failure +from trezor.wire.thp import session_manager + +from ..errors import DataError +from ..protocol_common import Message +from . import ( + ChannelState, + SessionState, + ThpDecryptionError, + ThpError, + ThpErrorType, + ThpInvalidDataError, + ThpUnallocatedSessionError, +) +from . import alternating_bit_protocol as ABP +from . import checksum, control_byte, is_channel_state_pairing, thp_messages +from .checksum import CHECKSUM_LENGTH +from .crypto import PUBKEY_LENGTH, Handshake +from .thp_messages import ( + ACK_MESSAGE, + HANDSHAKE_COMP_RES, + HANDSHAKE_INIT_RES, + PacketHeader, +) +from .writer import ( + INIT_HEADER_LENGTH, + MESSAGE_TYPE_LENGTH, + write_payload_to_wire_and_add_checksum, +) + +if TYPE_CHECKING: + from typing import Awaitable + + from trezor.messages import ThpHandshakeCompletionReqNoisePayload + + from .channel import Channel + +if __debug__: + from ubinascii import hexlify + + from . import state_to_str + + +async def handle_received_message( + ctx: Channel, message_buffer: utils.BufferType +) -> None: + """Handle a message received from the channel.""" + + if __debug__: + log.debug(__name__, "handle_received_message") + try: + import micropython + + print("micropython.mem_info() from received_message_handler.py") + micropython.mem_info() + print( + "Allocation count:", micropython.alloc_count() # type: ignore ["alloc_count" is not a known attribute of module "micropython"] + ) + except AttributeError: + print("To show allocation count, create the build with TREZOR_MEMPERF=1") + ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer) + message_length = payload_length + INIT_HEADER_LENGTH + + _check_checksum(message_length, message_buffer) + + # Synchronization process + seq_bit = (ctrl_byte & 0x10) >> 4 + ack_bit = (ctrl_byte & 0x08) >> 3 + if __debug__: + log.debug( + __name__, + "handle_completed_message - seq bit of message: %d, ack bit of message: %d", + seq_bit, + ack_bit, + ) + # 0: Update "last-time used" + update_channel_last_used(ctx.channel_id) + + # 1: Handle ACKs + if control_byte.is_ack(ctrl_byte): + await _handle_ack(ctx, ack_bit) + return + + if _should_have_ctrl_byte_encrypted_transport( + ctx + ) and not control_byte.is_encrypted_transport(ctrl_byte): + raise ThpError("Message is not encrypted. Ignoring") + + # 2: Handle message with unexpected sequential bit + if seq_bit != ABP.get_expected_receive_seq_bit(ctx.channel_cache): + if __debug__: + log.debug(__name__, "Received message with an unexpected sequential bit") + await _send_ack(ctx, ack_bit=seq_bit) + raise ThpError("Received message with an unexpected sequential bit") + + # 3: Send ACK in response + await _send_ack(ctx, ack_bit=seq_bit) + + ABP.set_expected_receive_seq_bit(ctx.channel_cache, 1 - seq_bit) + + try: + await _handle_message_to_app_or_channel( + ctx, payload_length, message_length, ctrl_byte + ) + except ThpUnallocatedSessionError as e: + error_message = Failure(code=FailureType.ThpUnallocatedSession) + await ctx.write(error_message, e.session_id) + except ThpDecryptionError: + await ctx.write_error(ThpErrorType.DECRYPTION_FAILED) + ctx.clear() + except ThpInvalidDataError: + await ctx.write_error(ThpErrorType.INVALID_DATA) + ctx.clear() + if __debug__: + log.debug(__name__, "handle_received_message - end") + + +def _send_ack(ctx: Channel, ack_bit: int) -> Awaitable[None]: + ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit) + header = PacketHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH) + if __debug__: + log.debug( + __name__, + "Writing ACK message to a channel with id: %d, ack_bit: %d", + ctx.get_channel_id_int(), + ack_bit, + ) + return write_payload_to_wire_and_add_checksum(ctx.iface, header, b"") + + +def _check_checksum(message_length: int, message_buffer: utils.BufferType): + if __debug__: + log.debug(__name__, "check_checksum") + if not checksum.is_valid( + checksum=message_buffer[message_length - CHECKSUM_LENGTH : message_length], + data=memoryview(message_buffer)[: message_length - CHECKSUM_LENGTH], + ): + if __debug__: + log.debug(__name__, "Invalid checksum, ignoring message.") + raise ThpError("Invalid checksum, ignoring message.") + + +async def _handle_ack(ctx: Channel, ack_bit: int): + if not ABP.is_ack_valid(ctx.channel_cache, ack_bit): + return + # ACK is expected and it has correct sync bit + if __debug__: + log.debug(__name__, "Received ACK message with correct ack bit") + if ctx.transmission_loop is not None: + ctx.transmission_loop.stop_immediately() + if __debug__: + log.debug(__name__, "Stopped transmission loop") + + ABP.set_sending_allowed(ctx.channel_cache, True) + + if ctx.write_task_spawn is not None: + if __debug__: + log.debug(__name__, 'Control to "write_encrypted_payload_loop" task') + await ctx.write_task_spawn + # Note that no the write_task_spawn could result in loop.clear(), + # which will result in termination of this function - any code after + # this await might not be executed + + +def _handle_message_to_app_or_channel( + ctx: Channel, + payload_length: int, + message_length: int, + ctrl_byte: int, +) -> Awaitable[None]: + state = ctx.get_channel_state() + if __debug__: + log.debug(__name__, "state: %s", state_to_str(state)) + + if state is ChannelState.ENCRYPTED_TRANSPORT: + return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length) + + if state is ChannelState.TH1: + return _handle_state_TH1(ctx, payload_length, message_length, ctrl_byte) + + if state is ChannelState.TH2: + return _handle_state_TH2(ctx, message_length, ctrl_byte) + + if is_channel_state_pairing(state): + return _handle_pairing(ctx, message_length) + + raise ThpError("Unimplemented channel state") + + +async def _handle_state_TH1( + ctx: Channel, + payload_length: int, + message_length: int, + ctrl_byte: int, +) -> None: + if __debug__: + log.debug(__name__, "handle_state_TH1") + if not control_byte.is_handshake_init_req(ctrl_byte): + raise ThpError("Message received is not a handshake init request!") + if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH: + raise ThpError("Message received is not a valid handshake init request!") + + ctx.handshake = Handshake() + + host_ephemeral_pubkey = bytearray( + ctx.buffer[INIT_HEADER_LENGTH : message_length - CHECKSUM_LENGTH] + ) + trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag = ( + ctx.handshake.handle_th1_crypto( + thp_messages.get_encoded_device_properties(ctx.iface), host_ephemeral_pubkey + ) + ) + + if __debug__: + log.debug( + __name__, + "trezor ephemeral pubkey: %s", + hexlify(trezor_ephemeral_pubkey).decode(), + ) + log.debug( + __name__, + "encrypted trezor masked static pubkey: %s", + hexlify(encrypted_trezor_static_pubkey).decode(), + ) + log.debug(__name__, "tag: %s", hexlify(tag)) + + payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag + + # send handshake init response message + ctx.write_handshake_message(HANDSHAKE_INIT_RES, payload) + ctx.set_channel_state(ChannelState.TH2) + return + + +async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None: + from apps.thp.credential_manager import validate_credential + + if __debug__: + log.debug(__name__, "handle_state_TH2") + if not control_byte.is_handshake_comp_req(ctrl_byte): + raise ThpError("Message received is not a handshake completion request!") + if ctx.handshake is None: + raise Exception("Handshake object is not prepared. Retry handshake.") + + host_encrypted_static_pubkey = memoryview(ctx.buffer)[ + INIT_HEADER_LENGTH : INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH + ] + handshake_completion_request_noise_payload = memoryview(ctx.buffer)[ + INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH + ] + + ctx.handshake.handle_th2_crypto( + host_encrypted_static_pubkey, handshake_completion_request_noise_payload + ) + + ctx.channel_cache.set(CHANNEL_KEY_RECEIVE, ctx.handshake.key_receive) + ctx.channel_cache.set(CHANNEL_KEY_SEND, ctx.handshake.key_send) + ctx.channel_cache.set(CHANNEL_HANDSHAKE_HASH, ctx.handshake.h) + ctx.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0) + ctx.channel_cache.set_int(CHANNEL_NONCE_SEND, 1) + + noise_payload = thp_messages.decode_message( + ctx.buffer[ + INIT_HEADER_LENGTH + + KEY_LENGTH + + TAG_LENGTH : message_length + - CHECKSUM_LENGTH + - TAG_LENGTH + ], + 0, + "ThpHandshakeCompletionReqNoisePayload", + ) + if TYPE_CHECKING: + assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload) + enabled_methods = thp_messages.get_enabled_pairing_methods(ctx.iface) + for method in noise_payload.pairing_methods: + if method not in enabled_methods: + raise ThpInvalidDataError() + if method not in ctx.selected_pairing_methods: + ctx.selected_pairing_methods.append(method) + if __debug__: + log.debug( + __name__, + "host static pubkey: %s, noise payload: %s", + utils.get_bytes_as_str(host_encrypted_static_pubkey), + utils.get_bytes_as_str(handshake_completion_request_noise_payload), + ) + + # key is decoded in handshake._handle_th2_crypto + host_static_pubkey = host_encrypted_static_pubkey[:PUBKEY_LENGTH] + + paired: bool = False + + if noise_payload.host_pairing_credential is not None: + try: # TODO change try-except for something better + paired = validate_credential( + noise_payload.host_pairing_credential, + host_static_pubkey, + ) + except DataError as e: + if __debug__: + log.exception(__name__, e) + pass + + trezor_state = thp_messages.TREZOR_STATE_UNPAIRED + if paired: + trezor_state = thp_messages.TREZOR_STATE_PAIRED + # send hanshake completion response + ctx.write_handshake_message( + HANDSHAKE_COMP_RES, + ctx.handshake.get_handshake_completion_response(trezor_state), + ) + + ctx.handshake = None + + if paired: + ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + else: + ctx.set_channel_state(ChannelState.TP1) + + +async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -> None: + if __debug__: + log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT") + + ctx.decrypt_buffer(message_length) + session_id, message_type = ustruct.unpack( + ">BH", memoryview(ctx.buffer)[INIT_HEADER_LENGTH:] + ) + if session_id not in ctx.sessions: + if session_id == MANAGEMENT_SESSION_ID: + s = session_manager.create_new_management_session(ctx) + else: + s = session_manager.get_session_from_cache(ctx, session_id) + if s is None: + raise ThpUnallocatedSessionError(session_id) + ctx.sessions[session_id] = s + loop.schedule(s.handle()) + + elif ctx.sessions[session_id].get_session_state() is SessionState.UNALLOCATED: + raise ThpUnallocatedSessionError(session_id) + + s = ctx.sessions[session_id] + update_session_last_used(s.channel_id, s.session_id) + + s.incoming_message.publish( + Message( + message_type, + ctx.buffer[ + INIT_HEADER_LENGTH + + MESSAGE_TYPE_LENGTH + + SESSION_ID_LENGTH : message_length + - CHECKSUM_LENGTH + - TAG_LENGTH + ], + ) + ) + + +async def _handle_pairing(ctx: Channel, message_length: int) -> None: + from .pairing_context import PairingContext + + if ctx.connection_context is None: + ctx.connection_context = PairingContext(ctx) + loop.schedule(ctx.connection_context.handle()) + + ctx.decrypt_buffer(message_length) + message_type = ustruct.unpack( + ">H", ctx.buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :] + )[0] + + ctx.connection_context.incoming_message.publish( + Message( + message_type, + ctx.buffer[ + INIT_HEADER_LENGTH + + MESSAGE_TYPE_LENGTH + + SESSION_ID_LENGTH : message_length + - CHECKSUM_LENGTH + - TAG_LENGTH + ], + ) + ) + + +def _should_have_ctrl_byte_encrypted_transport(ctx: Channel) -> bool: + if ctx.get_channel_state() in [ + ChannelState.UNALLOCATED, + ChannelState.TH1, + ChannelState.TH2, + ]: + return False + return True diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py new file mode 100644 index 00000000000..728ee5f4e97 --- /dev/null +++ b/core/src/trezor/wire/thp/session_context.py @@ -0,0 +1,198 @@ +from typing import TYPE_CHECKING + +from storage.cache_thp import MANAGEMENT_SESSION_ID, SessionThpCache +from trezor import log, loop, protobuf +from trezor.wire import message_handler, protocol_common +from trezor.wire.context import UnexpectedMessageException +from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure, find_handler + +from ..protocol_common import Context, Message +from . import SessionState + +if TYPE_CHECKING: + from typing import Any, Awaitable, Container + + from storage.cache_common import DataCache + + from ..message_handler import HandlerFinder + from .channel import Channel + + pass + +_EXIT_LOOP = True +_REPEAT_LOOP = False + +if __debug__: + from trezor.utils import get_bytes_as_str + + +class GenericSessionContext(Context): + + def __init__(self, channel: Channel, session_id: int) -> None: + super().__init__(channel.iface, channel.channel_id) + self.channel: Channel = channel + self.session_id: int = session_id + self.incoming_message = loop.chan() + self.handler_finder: HandlerFinder = find_handler + + async def handle(self) -> None: + if __debug__: + self._handle_debug() + + take = self.incoming_message.take() + next_message: Message | None = None + + while True: + try: + if await self._handle_message(take, next_message): + loop.schedule(self.handle()) + return + except UnexpectedMessageException as unexpected: + # The workflow was interrupted by an unexpected message. We need to + # process it as if it was a new message... + next_message = unexpected.msg + continue + except Exception as exc: + # Log and try again. + if __debug__: + log.exception(__name__, exc) + + def _handle_debug(self) -> None: + log.debug( + __name__, + "handle - start (channel_id (bytes): %s, session_id: %d)", + get_bytes_as_str(self.channel_id), + self.session_id, + ) + # if is_debug_session: + # import apps.debug + + # apps.debug.DEBUG_CONTEXT = self + + async def _handle_message( + self, + take: Awaitable[Any], + next_message: Message | None, + ) -> bool: + + try: + message = await self._get_message(take, next_message) + except protocol_common.WireError as e: + if __debug__: + log.exception(__name__, e) + await self.write(failure(e)) + return _REPEAT_LOOP + + try: + await message_handler.handle_single_message( + self, + message, + self.handler_finder, + ) + except UnexpectedMessageException: + raise + except Exception as exc: + # Log and ignore. The session handler can only exit explicitly in the + # following finally block. + if __debug__: + log.exception(__name__, exc) + finally: + # Unload modules imported by the workflow. Should not raise. + # This is not done for the debug session because the snapshot taken + # in a debug session would clear modules which are in use by the + # workflow running on wire. + # TODO utils.unimport_end(modules) + + if next_message is None and message.type not in AVOID_RESTARTING_FOR: + # Shut down the loop if there is no next message waiting. + return _EXIT_LOOP # pylint: disable=lost-exception + return _REPEAT_LOOP # pylint: disable=lost-exception + + async def _get_message( + self, take: Awaitable[Any], next_message: Message | None + ) -> Message: + if next_message is None: + # If the previous run did not keep an unprocessed message for us, + # wait for a new one. + message: Message = await take + else: + # Process the message from previous run. + message = next_message + next_message = None + return message + + async def read( + self, + expected_types: Container[int], + expected_type: type[protobuf.MessageType] | None = None, + ) -> protobuf.MessageType: + if __debug__: + exp_type: str = str(expected_type) + if expected_type is not None: + exp_type = expected_type.MESSAGE_NAME + log.debug( + __name__, + "Read - with expected types %s and expected type %s", + str(expected_types), + exp_type, + ) + message: Message = await self.incoming_message.take() + if message.type not in expected_types: + if __debug__: + log.debug( + __name__, + "EXPECTED TYPES: %s\nRECEIVED TYPE: %s", + str(expected_types), + str(message.type), + ) + raise UnexpectedMessageException(message) + + if expected_type is None: + expected_type = protobuf.type_for_wire(message.type) + + return message_handler.wrap_protobuf_load(message.data, expected_type) + + async def write(self, msg: protobuf.MessageType) -> None: + return await self.channel.write(msg, self.session_id) + + async def write_force(self, msg: protobuf.MessageType) -> None: + return await self.channel.write(msg, self.session_id, force=True) + + def get_session_state(self) -> SessionState: ... + + +class ManagementSessionContext(GenericSessionContext): + + def __init__( + self, channel_ctx: Channel, session_id: int = MANAGEMENT_SESSION_ID + ) -> None: + super().__init__(channel_ctx, session_id) + + def get_session_state(self) -> SessionState: + return SessionState.MANAGEMENT + + +class SessionContext(GenericSessionContext): + + def __init__(self, channel_ctx: Channel, session_cache: SessionThpCache) -> None: + if channel_ctx.channel_id != session_cache.channel_id: + raise Exception( + "The session has different channel id than the provided channel context!" + ) + session_id = int.from_bytes(session_cache.session_id, "big") + super().__init__(channel_ctx, session_id) + self.session_cache = session_cache + + # ACCESS TO SESSION DATA + + def get_session_state(self) -> SessionState: + state = int.from_bytes(self.session_cache.state, "big") + return SessionState(state) + + def set_session_state(self, state: SessionState) -> None: + self.session_cache.state = bytearray(state.to_bytes(1, "big")) + + # ACCESS TO CACHE + @property + def cache(self) -> DataCache: + return self.session_cache diff --git a/core/src/trezor/wire/thp/session_manager.py b/core/src/trezor/wire/thp/session_manager.py new file mode 100644 index 00000000000..d7ab1762d6f --- /dev/null +++ b/core/src/trezor/wire/thp/session_manager.py @@ -0,0 +1,37 @@ +from typing import TYPE_CHECKING + +from storage import cache_thp + +from .session_context import ( + GenericSessionContext, + ManagementSessionContext, + SessionContext, +) + +if TYPE_CHECKING: + from .channel import Channel + + +def create_new_session(channel_ctx: Channel) -> SessionContext: + session_cache = cache_thp.get_new_session(channel_ctx.channel_cache) + return SessionContext(channel_ctx, session_cache) + + +def create_new_management_session( + channel_ctx: Channel, session_id: int = cache_thp.MANAGEMENT_SESSION_ID +) -> ManagementSessionContext: + return ManagementSessionContext(channel_ctx, session_id) + + +def get_session_from_cache( + channel_ctx: Channel, session_id: int +) -> GenericSessionContext | None: + cached_sessions = cache_thp.get_allocated_sessions(channel_ctx.channel_id) + for s in cached_sessions: + print(s, s.channel_id, int.from_bytes(s.session_id, "big")) + if ( + s.channel_id == channel_ctx.channel_id + and int.from_bytes(s.session_id, "big") == session_id + ): + return SessionContext(channel_ctx, s) + return None diff --git a/core/src/trezor/wire/thp/thp_messages.py b/core/src/trezor/wire/thp/thp_messages.py new file mode 100644 index 00000000000..a1985f58dd3 --- /dev/null +++ b/core/src/trezor/wire/thp/thp_messages.py @@ -0,0 +1,136 @@ +import ustruct +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_thp import BROADCAST_CHANNEL_ID +from trezor import protobuf, utils +from trezor.enums import ThpPairingMethod +from trezor.messages import ThpDeviceProperties + +from .. import message_handler + +CODEC_V1 = const(0x3F) +CONTINUATION_PACKET = const(0x80) +HANDSHAKE_INIT_REQ = const(0x00) +HANDSHAKE_INIT_RES = const(0x01) +HANDSHAKE_COMP_REQ = const(0x02) +HANDSHAKE_COMP_RES = const(0x03) +ENCRYPTED_TRANSPORT = const(0x04) + +CONTINUATION_PACKET_MASK = const(0x80) +ACK_MASK = const(0xF7) +DATA_MASK = const(0xE7) + +ACK_MESSAGE = const(0x20) +_ERROR = const(0x42) +CHANNEL_ALLOCATION_REQ = const(0x40) +_CHANNEL_ALLOCATION_RES = const(0x41) + +TREZOR_STATE_UNPAIRED = b"\x00" +TREZOR_STATE_PAIRED = b"\x01" + +if __debug__: + from trezor import log + +if TYPE_CHECKING: + from trezor.wire import WireInterface + + +class PacketHeader: + format_str_init = ">BHH" + format_str_cont = ">BH" + + def __init__(self, ctrl_byte: int, cid: int, length: int) -> None: + self.ctrl_byte = ctrl_byte + self.cid = cid + self.length = length + + def to_bytes(self) -> bytes: + return ustruct.pack(self.format_str_init, self.ctrl_byte, self.cid, self.length) + + def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + ustruct.pack_into( + self.format_str_init, + buffer, + buffer_offset, + self.ctrl_byte, + self.cid, + self.length, + ) + + def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + ustruct.pack_into( + self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid + ) + + @classmethod + def get_error_header(cls, cid: int, length: int): + return cls(_ERROR, cid, length) + + @classmethod + def get_channel_allocation_response_header(cls, length: int): + return cls(_CHANNEL_ALLOCATION_RES, BROADCAST_CHANNEL_ID, length) + + +_DEFAULT_ENABLED_PAIRING_METHODS = [ + ThpPairingMethod.CodeEntry, + ThpPairingMethod.QrCode, + ThpPairingMethod.NFC_Unidirectional, +] + + +def get_enabled_pairing_methods( + iface: WireInterface | None = None, +) -> list[ThpPairingMethod]: + import usb + + l = _DEFAULT_ENABLED_PAIRING_METHODS.copy() + if iface is not None and iface is usb.iface_wire: + l.append(ThpPairingMethod.NoMethod) + return l + + +def _get_device_properties(iface: WireInterface) -> ThpDeviceProperties: + # TODO define model variants + return ThpDeviceProperties( + pairing_methods=get_enabled_pairing_methods(iface), + internal_model=utils.INTERNAL_MODEL, + model_variant=0, + bootloader_mode=False, + protocol_version=2, + ) + + +def get_encoded_device_properties(iface: WireInterface) -> bytes: + props = _get_device_properties(iface) + length = protobuf.encoded_length(props) + encoded_properties = bytearray(length) + protobuf.encode(encoded_properties, props) + return encoded_properties + + +def get_channel_allocation_response( + nonce: bytes, new_cid: bytes, iface: WireInterface +) -> bytes: + props_msg = get_encoded_device_properties(iface) + return nonce + new_cid + props_msg + + +def get_codec_v1_error_message() -> bytes: + # Codec_v1 magic constant "?##" + Failure message type + msg_size + # + msg_data (code = "Failure_InvalidProtocol") + padding to 64 B + ERROR_MSG = b"\x3f\x23\x23\x00\x03\x00\x00\x00\x14\x08\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + return ERROR_MSG + + +def decode_message( + buffer: bytes, msg_type: int, message_name: str | None = None +) -> protobuf.MessageType: + if __debug__: + log.debug(__name__, "decode message") + if message_name is not None: + expected_type = protobuf.type_for_name(message_name) + else: + expected_type = protobuf.type_for_wire(msg_type) + x = message_handler.wrap_protobuf_load(buffer, expected_type) + return x diff --git a/core/src/trezor/wire/thp/transmission_loop.py b/core/src/trezor/wire/thp/transmission_loop.py new file mode 100644 index 00000000000..529499f84f5 --- /dev/null +++ b/core/src/trezor/wire/thp/transmission_loop.py @@ -0,0 +1,51 @@ +from micropython import const +from typing import TYPE_CHECKING + +from trezor import loop +from trezor.wire.thp.thp_messages import PacketHeader +from trezor.wire.thp.writer import write_payload_to_wire_and_add_checksum + +if TYPE_CHECKING: + from trezor.wire.thp.channel import Channel + +MAX_RETRANSMISSION_COUNT = const(50) +MIN_RETRANSMISSION_COUNT = const(2) + + +class TransmissionLoop: + + def __init__( + self, channel: Channel, header: PacketHeader, transport_payload: bytes + ) -> None: + self.channel: Channel = channel + self.header: PacketHeader = header + self.transport_payload: bytes = transport_payload + self.wait_task: loop.spawn | None = None + self.min_retransmisson_count_achieved: bool = False + + async def start(self, max_retransmission_count: int = MAX_RETRANSMISSION_COUNT): + self.min_retransmisson_count_achieved = False + for i in range(max_retransmission_count): + if i >= MIN_RETRANSMISSION_COUNT: + self.min_retransmisson_count_achieved = True + await write_payload_to_wire_and_add_checksum( + self.channel.iface, self.header, self.transport_payload + ) + self.wait_task = loop.spawn(self._wait(i)) + try: + await self.wait_task + except loop.TaskClosed: + self.wait_task = None + break + + def stop_immediately(self): + if self.wait_task is not None: + self.wait_task.close() + self.wait_task = None + + async def _wait(self, counter: int = 0) -> None: + timeout_ms = round(10200 - 1010000 / (counter + 100)) + await loop.sleep(timeout_ms) + + def __del__(self): + self.stop_immediately() diff --git a/core/src/trezor/wire/thp/writer.py b/core/src/trezor/wire/thp/writer.py new file mode 100644 index 00000000000..0cd32cbc8ab --- /dev/null +++ b/core/src/trezor/wire/thp/writer.py @@ -0,0 +1,91 @@ +from micropython import const +from trezorcrypto import crc +from typing import TYPE_CHECKING + +from trezor import io, log, loop, utils +from trezor.wire.thp.thp_messages import PacketHeader + +INIT_HEADER_LENGTH = const(5) +CONT_HEADER_LENGTH = const(3) +PACKET_LENGTH = const(64) +CHECKSUM_LENGTH = const(4) +MAX_PAYLOAD_LEN = const(60000) +MESSAGE_TYPE_LENGTH = const(2) + +if TYPE_CHECKING: + from trezorio import WireInterface + from typing import Awaitable, Sequence + + +def write_payload_to_wire_and_add_checksum( + iface: WireInterface, header: PacketHeader, transport_payload: bytes +) -> Awaitable[None]: + header_checksum: int = crc.crc32(header.to_bytes()) + checksum: bytes = crc.crc32(transport_payload, header_checksum).to_bytes( + CHECKSUM_LENGTH, "big" + ) + data = (transport_payload, checksum) + return write_payloads_to_wire(iface, header, data) + + +async def write_payloads_to_wire( + iface: WireInterface, header: PacketHeader, data: Sequence[bytes] +): + n_of_data = len(data) + total_length = sum(len(item) for item in data) + + current_data_idx = 0 + current_data_offset = 0 + + packet = bytearray(PACKET_LENGTH) + header.pack_to_init_buffer(packet) + packet_offset: int = INIT_HEADER_LENGTH + packet_number = 0 + nwritten = 0 + while nwritten < total_length: + if packet_number == 1: + header.pack_to_cont_buffer(packet) + if packet_number >= 1 and nwritten >= total_length - PACKET_LENGTH: + packet[:] = bytearray(PACKET_LENGTH) + header.pack_to_cont_buffer(packet) + while True: + n = utils.memcpy( + packet, packet_offset, data[current_data_idx], current_data_offset + ) + packet_offset += n + current_data_offset += n + nwritten += n + + if packet_offset < PACKET_LENGTH: + current_data_idx += 1 + current_data_offset = 0 + if current_data_idx >= n_of_data: + break + elif packet_offset == PACKET_LENGTH: + break + else: + raise Exception("Should not happen!!!") + packet_number += 1 + packet_offset = CONT_HEADER_LENGTH + + # write packet to wire (in-lined) + if __debug__: + log.debug( + __name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet) + ) + written_by_iface: int = 0 + while written_by_iface < len(packet): + await loop.wait(iface.iface_num() | io.POLL_WRITE) + written_by_iface = iface.write(packet) + + +async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None: + while True: + await loop.wait(iface.iface_num() | io.POLL_WRITE) + if __debug__: + log.debug( + __name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet) + ) + n_written = iface.write(packet) + if n_written == len(packet): + return diff --git a/core/src/trezor/wire/thp_main.py b/core/src/trezor/wire/thp_main.py new file mode 100644 index 00000000000..d25fa1d8cd2 --- /dev/null +++ b/core/src/trezor/wire/thp_main.py @@ -0,0 +1,176 @@ +import ustruct +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_thp import BROADCAST_CHANNEL_ID +from trezor import io, log, loop, utils +from trezor.wire.thp import writer + +from .thp import ( + ChannelState, + ThpError, + ThpErrorType, + channel_manager, + checksum, + thp_messages, +) +from .thp.channel import Channel +from .thp.checksum import CHECKSUM_LENGTH +from .thp.thp_messages import CHANNEL_ALLOCATION_REQ, CODEC_V1, PacketHeader +from .thp.writer import ( + INIT_HEADER_LENGTH, + MAX_PAYLOAD_LEN, + PACKET_LENGTH, + write_payload_to_wire_and_add_checksum, +) + +if TYPE_CHECKING: + from trezorio import WireInterface + +_CID_REQ_PAYLOAD_LENGTH = const(12) +_READ_BUFFER: bytearray +_WRITE_BUFFER: bytearray +_CHANNELS: dict[int, Channel] = {} + + +def set_read_buffer(buffer: bytearray): + global _READ_BUFFER + _READ_BUFFER = buffer + + +def set_write_buffer(buffer: bytearray): + global _WRITE_BUFFER + _WRITE_BUFFER = buffer + + +def get_raw_read_buffer() -> bytearray: + global _READ_BUFFER + return _READ_BUFFER + + +def get_raw_write_buffer() -> bytearray: + global _WRITE_BUFFER + return _WRITE_BUFFER + + +async def thp_main_loop(iface: WireInterface): + global _CHANNELS + global _READ_BUFFER + _CHANNELS = channel_manager.load_cached_channels(_READ_BUFFER) + + read = loop.wait(iface.iface_num() | io.POLL_READ) + + while True: + try: + if __debug__: + log.debug(__name__, "thp_main_loop") + packet = await read + ctrl_byte, cid = ustruct.unpack(">BH", packet) + + if ctrl_byte == CODEC_V1: + await _handle_codec_v1(iface, packet) + continue + + if cid == BROADCAST_CHANNEL_ID: + await _handle_broadcast(iface, ctrl_byte, packet) + continue + + if cid in _CHANNELS: + await _handle_allocated(iface, cid, packet) + else: + await _handle_unallocated(iface, cid) + + except ThpError as e: + if __debug__: + log.exception(__name__, e) + + +async def _handle_codec_v1(iface: WireInterface, packet): + # If the received packet is not initial codec_v1 packet, do not send error message + if not packet[1:3] == b"##": + return + if __debug__: + log.debug(__name__, "Received codec_v1 message, returning error") + error_message = thp_messages.get_codec_v1_error_message() + await writer.write_packet_to_wire(iface, error_message) + + +async def _handle_broadcast( + iface: WireInterface, ctrl_byte: int, packet: utils.BufferType +) -> None: + global _READ_BUFFER + if ctrl_byte != CHANNEL_ALLOCATION_REQ: + raise ThpError("Unexpected ctrl_byte in a broadcast channel packet") + if __debug__: + log.debug(__name__, "Received valid message on the broadcast channel") + + length, nonce = ustruct.unpack(">H8s", packet[3:]) + payload = _get_buffer_for_payload(length, packet[5:], _CID_REQ_PAYLOAD_LENGTH) + if not checksum.is_valid( + payload[-4:], + packet[: _CID_REQ_PAYLOAD_LENGTH + INIT_HEADER_LENGTH - CHECKSUM_LENGTH], + ): + raise ThpError("Checksum is not valid") + + new_channel: Channel = channel_manager.create_new_channel(iface, _READ_BUFFER) + cid = int.from_bytes(new_channel.channel_id, "big") + _CHANNELS[cid] = new_channel + + response_data = thp_messages.get_channel_allocation_response( + nonce, new_channel.channel_id, iface + ) + response_header = PacketHeader.get_channel_allocation_response_header( + len(response_data) + CHECKSUM_LENGTH, + ) + if __debug__: + log.debug(__name__, "New channel allocated with id %d", cid) + + await write_payload_to_wire_and_add_checksum(iface, response_header, response_data) + + +async def _handle_allocated( + iface: WireInterface, cid: int, packet: utils.BufferType +) -> None: + channel = _CHANNELS[cid] + if channel is None: + await _handle_unallocated(iface, cid) + raise ThpError("Invalid state of a channel") + if channel.iface is not iface: + # TODO send error message to wire + raise ThpError("Channel has different WireInterface") + + if channel.get_channel_state() != ChannelState.UNALLOCATED: + x = channel.receive_packet(packet) + if x is not None: + await x + + +async def _handle_unallocated(iface, cid) -> None: + data = (ThpErrorType.UNALLOCATED_CHANNEL).to_bytes(1, "big") + header = PacketHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH) + await write_payload_to_wire_and_add_checksum(iface, header, data) + + +def _get_buffer_for_payload( + payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN +) -> utils.BufferType: + if payload_length > max_length: + raise ThpError("Message too large") + if payload_length > len(existing_buffer): + return _try_allocate_new_buffer(payload_length) + return _reuse_existing_buffer(payload_length, existing_buffer) + + +def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType: + try: + payload: utils.BufferType = bytearray(payload_length) + except MemoryError: + payload = bytearray(PACKET_LENGTH) + raise ThpError("Message too large") + return payload + + +def _reuse_existing_buffer( + payload_length: int, existing_buffer: utils.BufferType +) -> utils.BufferType: + return memoryview(existing_buffer)[:payload_length] diff --git a/core/src/trezor/workflow.py b/core/src/trezor/workflow.py index 1252a1bf5f1..62ce2726efb 100644 --- a/core/src/trezor/workflow.py +++ b/core/src/trezor/workflow.py @@ -1,7 +1,7 @@ import utime from typing import TYPE_CHECKING -import storage.cache +import storage.cache_common as cache_common from trezor import log, loop from trezor.enums import MessageType @@ -153,7 +153,7 @@ def close_others() -> None: if not task.is_running(): task.close() - storage.cache.homescreen_shown = None + cache_common.homescreen_shown = None # if tasks were running, closing the last of them will run start_default @@ -211,11 +211,11 @@ def touch(self, _restore_from_cache: bool = False) -> None: time and saves it to storage.cache. This is done to avoid losing an active timer when workflow restart happens and tasks are lost. """ - if _restore_from_cache and storage.cache.autolock_last_touch is not None: - now = storage.cache.autolock_last_touch + if _restore_from_cache and cache_common.autolock_last_touch is not None: + now = cache_common.autolock_last_touch else: now = utime.ticks_ms() - storage.cache.autolock_last_touch = now + cache_common.autolock_last_touch = now for callback, task in self.tasks.items(): timeout_us = self.timeouts[callback] diff --git a/core/tests/mock_wire_interface.py b/core/tests/mock_wire_interface.py new file mode 100644 index 00000000000..b74b2150643 --- /dev/null +++ b/core/tests/mock_wire_interface.py @@ -0,0 +1,17 @@ +from trezor.loop import wait + + +class MockHID: + def __init__(self, num): + self.num = num + self.data = [] + + def iface_num(self): + return self.num + + def write(self, msg): + self.data.append(bytearray(msg)) + return len(msg) + + def wait_object(self, mode): + return wait(mode | self.num) diff --git a/core/tests/myTests.sh b/core/tests/myTests.sh new file mode 100755 index 00000000000..1c29c1fd01b --- /dev/null +++ b/core/tests/myTests.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +declare -a results +declare -i passed=0 failed=0 exit_code=0 +declare COLOR_GREEN='\e[32m' COLOR_RED='\e[91m' COLOR_RESET='\e[39m' +MICROPYTHON="${MICROPYTHON:-../build/unix/trezor-emu-core -X heapsize=2M}" +print_summary() { + echo + echo 'Summary:' + echo '-------------------' + printf '%b\n' "${results[@]}" + if [ $exit_code == 0 ]; then + echo -e "${COLOR_GREEN}PASSED:${COLOR_RESET} $passed/$num_of_tests tests OK!" + else + echo -e "${COLOR_RED}FAILED:${COLOR_RESET} $failed/$num_of_tests tests failed!" + fi +} + +trap 'print_summary; echo -e "${COLOR_RED}Interrupted by user!${COLOR_RESET}"; exit 1' SIGINT + +cd $(dirname $0) + +[ -z "$*" ] && tests=(test_trezor.wire.t*.py ) || tests=($*) + +declare -i num_of_tests=${#tests[@]} + +for test_case in ${tests[@]}; do + echo ${MICROPYTHON} + echo ${test_case} + echo + if $MICROPYTHON $test_case; then + results+=("${COLOR_GREEN}OK:${COLOR_RESET} $test_case") + ((passed++)) + else + results+=("${COLOR_RED}FAIL:${COLOR_RESET} $test_case") + ((failed++)) + exit_code=1 + fi +done + +print_summary +exit $exit_code diff --git a/core/tests/test_apps.bitcoin.approver.py b/core/tests/test_apps.bitcoin.approver.py index 7354a846b1d..05750e594ec 100644 --- a/core/tests/test_apps.bitcoin.approver.py +++ b/core/tests/test_apps.bitcoin.approver.py @@ -1,4 +1,4 @@ -from common import H_, await_result, unittest # isort:skip +from common import * # isort:skip import storage.cache from trezor import wire @@ -11,6 +11,7 @@ TxInput, TxOutput, ) +from trezor.wire import context from apps.bitcoin.authorization import FEE_RATE_DECIMALS, CoinJoinAuthorization from apps.bitcoin.sign_tx.approvers import CoinJoinApprover @@ -18,8 +19,26 @@ from apps.bitcoin.sign_tx.tx_info import TxInfo from apps.common import coins +if utils.USE_THP: + import thp_common +else: + import storage.cache_codec class TestApprover(unittest.TestCase): + if utils.USE_THP: + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + super().__init__() + + else: + + def __init__(self): + context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64)) + super().__init__() + def setUp(self): self.coin = coins.by_name("Bitcoin") self.fee_rate_percent = 0.3 @@ -47,7 +66,8 @@ def setUp(self): coin_name=self.coin.coin_name, script_type=InputScriptType.SPENDTAPROOT, ) - storage.cache.start_session() + if not utils.USE_THP: + storage.cache_codec.start_session() def make_coinjoin_request(self, inputs): return CoinJoinRequest( diff --git a/core/tests/test_apps.bitcoin.authorization.py b/core/tests/test_apps.bitcoin.authorization.py index 503c181569c..dc30d177dd3 100644 --- a/core/tests/test_apps.bitcoin.authorization.py +++ b/core/tests/test_apps.bitcoin.authorization.py @@ -1,16 +1,35 @@ -from common import H_, unittest # isort:skip +from common import * # isort:skip import storage.cache from trezor.enums import InputScriptType from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx +from trezor.wire import context from apps.bitcoin.authorization import CoinJoinAuthorization from apps.common import coins _ROUND_ID_LEN = 32 +if utils.USE_THP: + import thp_common +else: + import storage.cache_codec + class TestAuthorization(unittest.TestCase): + if utils.USE_THP: + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + super().__init__() + + else: + + def __init__(self): + context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64)) + super().__init__() coin = coins.by_name("Bitcoin") @@ -26,7 +45,8 @@ def setUp(self): ) self.authorization = CoinJoinAuthorization(self.msg_auth) - storage.cache.start_session() + if not utils.USE_THP: + storage.cache_codec.start_session() def test_ownership_proof_account_depth_mismatch(self): # Account depth mismatch. diff --git a/core/tests/test_apps.bitcoin.keychain.py b/core/tests/test_apps.bitcoin.keychain.py index 3828a3ebbcf..fdad1c56568 100644 --- a/core/tests/test_apps.bitcoin.keychain.py +++ b/core/tests/test_apps.bitcoin.keychain.py @@ -1,17 +1,41 @@ from common import * # isort:skip -from storage import cache +from storage import cache_common from trezor import wire from trezor.crypto import bip39 +from trezor.wire import context from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin +if utils.USE_THP: + import thp_common +else: + from storage import cache_codec + class TestBitcoinKeychain(unittest.TestCase): - def setUp(self): - cache.start_session() - seed = bip39.seed(" ".join(["all"] * 12), "") - cache.set(cache.APP_COMMON_SEED, seed) + if utils.USE_THP: + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + super().__init__() + + def setUp(self): + seed = bip39.seed(" ".join(["all"] * 12), "") + context.cache_set(cache_common.APP_COMMON_SEED, seed) + + else: + + def __init__(self): + context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64)) + super().__init__() + + def setUp(self): + cache_codec.start_session() + seed = bip39.seed(" ".join(["all"] * 12), "") + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def test_bitcoin(self): coin = _get_coin_by_name("Bitcoin") @@ -88,10 +112,20 @@ def test_unknown(self): @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") class TestAltcoinKeychains(unittest.TestCase): - def setUp(self): - cache.start_session() - seed = bip39.seed(" ".join(["all"] * 12), "") - cache.set(cache.APP_COMMON_SEED, seed) + if not utils.USE_THP: + + def __init__(self): + # Context is needed to test decorators and handleInitialize + # It allows access to codec cache from different parts of the code + from trezor.wire import context + + context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64)) + super().__init__() + + def setUp(self): + cache_codec.start_session() + seed = bip39.seed(" ".join(["all"] * 12), "") + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def test_bcash(self): coin = _get_coin_by_name("Bcash") diff --git a/core/tests/test_apps.common.keychain.py b/core/tests/test_apps.common.keychain.py index 84681a0b01e..2e89fbc1b92 100644 --- a/core/tests/test_apps.common.keychain.py +++ b/core/tests/test_apps.common.keychain.py @@ -1,19 +1,43 @@ from common import * # isort:skip from mock_storage import mock_storage -from storage import cache -from trezor import wire +from storage import cache, cache_common +from trezor import utils, wire from trezor.crypto import bip39 from trezor.enums import SafetyCheckLevel +from trezor.wire import context from apps.common import safety_checks from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain from apps.common.paths import PATTERN_SEP5, PathSchema +if utils.USE_THP: + import thp_common +if not utils.USE_THP: + from storage import cache_codec + class TestKeychain(unittest.TestCase): - def setUp(self): - cache.start_session() + + if utils.USE_THP: + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + super().__init__() + + else: + + def __init__(self): + context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64)) + super().__init__() + + def setUp(self): + cache_codec.start_session() + + def cache_set(self, key: int, value: bytes) -> None: + context.cache_set(key, value) def tearDown(self): cache.clear_all() @@ -71,7 +95,7 @@ def test_no_schemas(self): def test_get_keychain(self): seed = bip39.seed(" ".join(["all"] * 12), "") - cache.set(cache.APP_COMMON_SEED, seed) + self.cache_set(cache_common.APP_COMMON_SEED, seed) schema = PathSchema.parse("m/44'/1'", 0) keychain = await_result(get_keychain("secp256k1", [schema])) @@ -85,7 +109,7 @@ def test_get_keychain(self): def test_with_slip44(self): seed = bip39.seed(" ".join(["all"] * 12), "") - cache.set(cache.APP_COMMON_SEED, seed) + self.cache_set(cache_common.APP_COMMON_SEED, seed) slip44_id = 42 valid_path = [H_(44), H_(slip44_id), H_(0)] diff --git a/core/tests/test_apps.ethereum.keychain.py b/core/tests/test_apps.ethereum.keychain.py index 53affef1b73..2507e5502af 100644 --- a/core/tests/test_apps.ethereum.keychain.py +++ b/core/tests/test_apps.ethereum.keychain.py @@ -2,13 +2,20 @@ import unittest -from storage import cache -from trezor import utils, wire +from storage import cache_common +from trezor import wire from trezor.crypto import bip39 +from trezor.wire import context from apps.common.keychain import get_keychain from apps.common.paths import HARDENED +if utils.USE_THP: + import thp_common +else: + from storage import cache_codec + + if not utils.BITCOIN_ONLY: from ethereum_common import encode_network, make_network from trezor.messages import ( @@ -71,10 +78,27 @@ def _check_keychain(self, keychain, slip44_id): addr, ) - def setUp(self): - cache.start_session() - seed = bip39.seed(" ".join(["all"] * 12), "") - cache.set(cache.APP_COMMON_SEED, seed) + if utils.USE_THP: + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + super().__init__() + + def setUp(self): + seed = bip39.seed(" ".join(["all"] * 12), "") + context.cache_set(cache_common.APP_COMMON_SEED, seed) + + else: + + def __init__(self): + context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64)) + super().__init__() + + def setUp(self): + cache_codec.start_session() + seed = bip39.seed(" ".join(["all"] * 12), "") + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def from_address_n(self, address_n): slip44 = _slip44_from_address_n(address_n) diff --git a/core/tests/test_storage.cache.py b/core/tests/test_storage.cache.py index 76fe29655b9..45d8fb09657 100644 --- a/core/tests/test_storage.cache.py +++ b/core/tests/test_storage.cache.py @@ -1,230 +1,517 @@ -from common import * # isort:skip +from common import * # isort:skip # noqa: F403 from mock_storage import mock_storage -from storage import cache +from storage import cache, cache_codec from trezor.messages import EndSession, Initialize from apps.base import handle_EndSession, handle_Initialize KEY = 0 +if utils.USE_THP: + import thp_common + from mock_wire_interface import MockHID + from storage import cache_thp + from trezor.wire.thp import ChannelState + from trezor.wire.thp.session_context import ManagementSessionContext, SessionContext -# Function moved from cache.py, as it was not used there -def is_session_started() -> bool: - return cache._active_session_idx is not None + _PROTOCOL_CACHE = cache_thp +else: + _PROTOCOL_CACHE = cache_codec -class TestStorageCache(unittest.TestCase): - def setUp(self): - cache.clear_all() + def is_session_started() -> bool: + return cache_codec.get_active_session() is not None - def test_start_session(self): - session_id_a = cache.start_session() - self.assertIsNotNone(session_id_a) - session_id_b = cache.start_session() - self.assertNotEqual(session_id_a, session_id_b) + def get_active_session(): + return cache_codec.get_active_session() - cache.clear_all() - with self.assertRaises(cache.InvalidSessionError): - cache.set(KEY, "something") - with self.assertRaises(cache.InvalidSessionError): - cache.get(KEY) - - def test_end_session(self): - session_id = cache.start_session() - self.assertTrue(is_session_started()) - cache.set(KEY, b"A") - cache.end_current_session() - self.assertFalse(is_session_started()) - self.assertRaises(cache.InvalidSessionError, cache.get, KEY) - - # ending an ended session should be a no-op - cache.end_current_session() - self.assertFalse(is_session_started()) - - session_id_a = cache.start_session(session_id) - # original session no longer exists - self.assertNotEqual(session_id_a, session_id) - # original session data no longer exists - self.assertIsNone(cache.get(KEY)) - - # create a new session - session_id_b = cache.start_session() - # switch back to original session - session_id = cache.start_session(session_id_a) - self.assertEqual(session_id, session_id_a) - # end original session - cache.end_current_session() - # switch back to B - session_id = cache.start_session(session_id_b) - self.assertEqual(session_id, session_id_b) - - def test_session_queue(self): - session_id = cache.start_session() - self.assertEqual(cache.start_session(session_id), session_id) - cache.set(KEY, b"A") - for i in range(cache._MAX_SESSIONS_COUNT): - cache.start_session() - self.assertNotEqual(cache.start_session(session_id), session_id) - self.assertIsNone(cache.get(KEY)) - - def test_get_set(self): - session_id1 = cache.start_session() - cache.set(KEY, b"hello") - self.assertEqual(cache.get(KEY), b"hello") - - session_id2 = cache.start_session() - cache.set(KEY, b"world") - self.assertEqual(cache.get(KEY), b"world") - - cache.start_session(session_id2) - self.assertEqual(cache.get(KEY), b"world") - cache.start_session(session_id1) - self.assertEqual(cache.get(KEY), b"hello") +class TestStorageCache( + unittest.TestCase +): # noqa: F405 # pyright: ignore[reportUndefinedVariable] + def setUp(self): cache.clear_all() - with self.assertRaises(cache.InvalidSessionError): - cache.get(KEY) - def test_get_set_int(self): - session_id1 = cache.start_session() - cache.set_int(KEY, 1234) - self.assertEqual(cache.get_int(KEY), 1234) - - session_id2 = cache.start_session() - cache.set_int(KEY, 5678) - self.assertEqual(cache.get_int(KEY), 5678) - - cache.start_session(session_id2) - self.assertEqual(cache.get_int(KEY), 5678) - cache.start_session(session_id1) - self.assertEqual(cache.get_int(KEY), 1234) - - cache.clear_all() - with self.assertRaises(cache.InvalidSessionError): - cache.get_int(KEY) - - def test_delete(self): - session_id1 = cache.start_session() - self.assertIsNone(cache.get(KEY)) - cache.set(KEY, b"hello") - self.assertEqual(cache.get(KEY), b"hello") - cache.delete(KEY) - self.assertIsNone(cache.get(KEY)) - - cache.set(KEY, b"hello") - cache.start_session() - self.assertIsNone(cache.get(KEY)) - cache.set(KEY, b"hello") - self.assertEqual(cache.get(KEY), b"hello") - cache.delete(KEY) - self.assertIsNone(cache.get(KEY)) - - cache.start_session(session_id1) - self.assertEqual(cache.get(KEY), b"hello") - - def test_decorators(self): - run_count = 0 - cache.start_session() - - @cache.stored(KEY) - def func(): - nonlocal run_count - run_count += 1 - return b"foo" - - # cache is empty - self.assertIsNone(cache.get(KEY)) - self.assertEqual(run_count, 0) - self.assertEqual(func(), b"foo") - # function was run - self.assertEqual(run_count, 1) - self.assertEqual(cache.get(KEY), b"foo") - # function does not run again but returns cached value - self.assertEqual(func(), b"foo") - self.assertEqual(run_count, 1) - - @cache.stored_async(KEY) - async def async_func(): - nonlocal run_count - run_count += 1 - return b"bar" - - # cache is still full - self.assertEqual(await_result(async_func()), b"foo") - self.assertEqual(run_count, 1) - - cache.start_session() - self.assertEqual(await_result(async_func()), b"bar") - self.assertEqual(run_count, 2) - # awaitable is also run only once - self.assertEqual(await_result(async_func()), b"bar") - self.assertEqual(run_count, 2) - - def test_empty_value(self): - cache.start_session() - - self.assertIsNone(cache.get(KEY)) - cache.set(KEY, b"") - self.assertEqual(cache.get(KEY), b"") - - cache.delete(KEY) - run_count = 0 - - @cache.stored(KEY) - def func(): - nonlocal run_count - run_count += 1 - return b"" - - self.assertEqual(func(), b"") - # function gets called once - self.assertEqual(run_count, 1) - self.assertEqual(func(), b"") - # function is not called for a second time - self.assertEqual(run_count, 1) - - @mock_storage - def test_Initialize(self): - def call_Initialize(**kwargs): - msg = Initialize(**kwargs) - return await_result(handle_Initialize(msg)) - - # calling Initialize without an ID allocates a new one - session_id = cache.start_session() - features = call_Initialize() - self.assertNotEqual(session_id, features.session_id) - - # calling Initialize with the current ID does not allocate a new one - features = call_Initialize(session_id=session_id) - self.assertEqual(session_id, features.session_id) - - # store "hello" - cache.set(KEY, b"hello") - # check that it is cleared - features = call_Initialize() - session_id = features.session_id - self.assertIsNone(cache.get(KEY)) - # store "hello" again - cache.set(KEY, b"hello") - self.assertEqual(cache.get(KEY), b"hello") - - # supplying a different session ID starts a new cache - call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH) - self.assertIsNone(cache.get(KEY)) - - # but resuming a session loads the previous one - call_Initialize(session_id=session_id) - self.assertEqual(cache.get(KEY), b"hello") - - def test_EndSession(self): - self.assertRaises(cache.InvalidSessionError, cache.get, KEY) - cache.start_session() - self.assertTrue(is_session_started()) - self.assertIsNone(cache.get(KEY)) - await_result(handle_EndSession(EndSession())) - self.assertFalse(is_session_started()) - self.assertRaises(cache.InvalidSessionError, cache.get, KEY) + if utils.USE_THP: + + def __init__(self): + thp_common.suppres_debug_log() + # xthp_common.prepare_context() + # config.init() + super().__init__() + + def setUp(self): + self.interface = MockHID(0xDEADBEEF) + cache.clear_all() + + def test_new_channel_and_session(self): + channel = thp_common.get_new_channel(self.interface) + + # Assert that channel is created without any sessions + self.assertEqual(len(channel.sessions), 0) + + cid_1 = channel.channel_id + session_cache_1 = cache_thp.get_new_session(channel.channel_cache) + session_1 = SessionContext(channel, session_cache_1) + self.assertEqual(session_1.channel_id, cid_1) + + session_cache_2 = cache_thp.get_new_session(channel.channel_cache) + session_2 = SessionContext(channel, session_cache_2) + self.assertEqual(session_2.channel_id, cid_1) + self.assertEqual(session_1.channel_id, session_2.channel_id) + self.assertNotEqual(session_1.session_id, session_2.session_id) + + channel_2 = thp_common.get_new_channel(self.interface) + cid_2 = channel_2.channel_id + self.assertNotEqual(cid_1, cid_2) + + session_cache_3 = cache_thp.get_new_session(channel_2.channel_cache) + session_3 = SessionContext(channel_2, session_cache_3) + self.assertEqual(session_3.channel_id, cid_2) + + # Sessions 1 and 3 should have different channel_id, but the same session_id + self.assertNotEqual(session_1.channel_id, session_3.channel_id) + self.assertEqual(session_1.session_id, session_3.session_id) + + self.assertEqual(cache_thp._SESSIONS[0], session_cache_1) + self.assertNotEqual(cache_thp._SESSIONS[0], session_cache_2) + self.assertEqual(cache_thp._SESSIONS[0].channel_id, session_1.channel_id) + + # Check that session data IS in cache for created sessions ONLY + for i in range(3): + self.assertNotEqual(cache_thp._SESSIONS[i].channel_id, b"") + self.assertNotEqual(cache_thp._SESSIONS[i].session_id, b"") + self.assertNotEqual(cache_thp._SESSIONS[i].last_usage, 0) + for i in range(3, cache_thp._MAX_SESSIONS_COUNT): + self.assertEqual(cache_thp._SESSIONS[i].channel_id, b"") + self.assertEqual(cache_thp._SESSIONS[i].session_id, b"") + self.assertEqual(cache_thp._SESSIONS[i].last_usage, 0) + + # Check that session data IS NOT in cache after cache.clear_all() + cache.clear_all() + for session in cache_thp._SESSIONS: + self.assertEqual(session.channel_id, b"") + self.assertEqual(session.session_id, b"") + self.assertEqual(session.last_usage, 0) + self.assertEqual(session.state, b"\x00") + + def test_channel_capacity_in_cache(self): + self.assertTrue(cache_thp._MAX_CHANNELS_COUNT >= 3) + channels = [] + for i in range(cache_thp._MAX_CHANNELS_COUNT): + channels.append(thp_common.get_new_channel(self.interface)) + channel_ids = [channel.channel_cache.channel_id for channel in channels] + + # Assert that each channel_id is unique and that cache and list of channels + # have the same "channels" on the same indexes + for i in range(len(channel_ids)): + self.assertEqual(cache_thp._CHANNELS[i].channel_id, channel_ids[i]) + for j in range(i + 1, len(channel_ids)): + self.assertNotEqual(channel_ids[i], channel_ids[j]) + + # Create a new channel that is over the capacity + new_channel = thp_common.get_new_channel(self.interface) + for c in channels: + self.assertNotEqual(c.channel_id, new_channel.channel_id) + + # Test that the oldest (least used) channel was replaced (_CHANNELS[0]) + self.assertNotEqual(cache_thp._CHANNELS[0].channel_id, channel_ids[0]) + self.assertEqual(cache_thp._CHANNELS[0].channel_id, new_channel.channel_id) + + # Update the "last used" value of the second channel in cache (_CHANNELS[1]) and + # assert that it is not replaced when creating a new channel + cache_thp.update_channel_last_used(channel_ids[1]) + new_new_channel = thp_common.get_new_channel(self.interface) + self.assertEqual(cache_thp._CHANNELS[1].channel_id, channel_ids[1]) + + # Assert that it was in fact the _CHANNEL[2] that was replaced + self.assertNotEqual(cache_thp._CHANNELS[2].channel_id, channel_ids[2]) + self.assertEqual( + cache_thp._CHANNELS[2].channel_id, new_new_channel.channel_id + ) + + def test_session_capacity_in_cache(self): + self.assertTrue(cache_thp._MAX_SESSIONS_COUNT >= 4) + channel_cache_A = thp_common.get_new_channel(self.interface).channel_cache + channel_cache_B = thp_common.get_new_channel(self.interface).channel_cache + + sesions_A = [] + cid = [] + sid = [] + for i in range(3): + sesions_A.append(cache_thp.get_new_session(channel_cache_A)) + cid.append(sesions_A[i].channel_id) + sid.append(sesions_A[i].session_id) + + sessions_B = [] + for i in range(cache_thp._MAX_SESSIONS_COUNT - 3): + sessions_B.append(cache_thp.get_new_session(channel_cache_B)) + + for i in range(3): + self.assertEqual(sesions_A[i], cache_thp._SESSIONS[i]) + self.assertEqual(cid[i], cache_thp._SESSIONS[i].channel_id) + self.assertEqual(sid[i], cache_thp._SESSIONS[i].session_id) + for i in range(3, cache_thp._MAX_SESSIONS_COUNT): + self.assertEqual(sessions_B[i - 3], cache_thp._SESSIONS[i]) + + # Assert that new session replaces the oldest (least used) one (_SESSOIONS[0]) + new_session = cache_thp.get_new_session(channel_cache_B) + self.assertEqual(new_session, cache_thp._SESSIONS[0]) + self.assertNotEqual(new_session.channel_id, cid[0]) + self.assertNotEqual(new_session.session_id, sid[0]) + + # Assert that updating "last used" for session on channel A increases also + # the "last usage" of channel A. + self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage) + cache_thp.update_session_last_used( + channel_cache_A.channel_id, sesions_A[1].session_id + ) + self.assertTrue(channel_cache_A.last_usage > channel_cache_B.last_usage) + + new_new_session = cache_thp.get_new_session(channel_cache_B) + + # Assert that creating a new session on channel B shifts the "last usage" again + # and that _SESSIONS[1] was not replaced, but that _SESSIONS[2] was replaced + self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage) + self.assertEqual(sesions_A[1], cache_thp._SESSIONS[1]) + self.assertNotEqual(sesions_A[2], cache_thp._SESSIONS[2]) + self.assertEqual(new_new_session, cache_thp._SESSIONS[2]) + + def test_clear(self): + channel_A = thp_common.get_new_channel(self.interface) + channel_B = thp_common.get_new_channel(self.interface) + cid_A = channel_A.channel_id + cid_B = channel_B.channel_id + sessions = [] + + for i in range(3): + sessions.append(cache_thp.get_new_session(channel_A.channel_cache)) + sessions.append(cache_thp.get_new_session(channel_B.channel_cache)) + + self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, cid_A) + self.assertNotEqual(cache_thp._SESSIONS[2 * i].last_usage, 0) + + self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B) + self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0) + + # Assert that clearing of channel A works + self.assertNotEqual(channel_A.channel_cache.channel_id, b"") + self.assertNotEqual(channel_A.channel_cache.last_usage, 0) + self.assertEqual(channel_A.get_channel_state(), ChannelState.TH1) + + channel_A.clear() + + self.assertEqual(channel_A.channel_cache.channel_id, b"") + self.assertEqual(channel_A.channel_cache.last_usage, 0) + self.assertEqual(channel_A.get_channel_state(), ChannelState.UNALLOCATED) + + # Assert that clearing channel A also cleared all its sessions + for i in range(3): + self.assertEqual(cache_thp._SESSIONS[2 * i].last_usage, 0) + self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, b"") + + self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0) + self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B) + + cache.clear_all() + for session in cache_thp._SESSIONS: + self.assertEqual(session.last_usage, 0) + self.assertEqual(session.channel_id, b"") + for channel in cache_thp._CHANNELS: + self.assertEqual(channel.channel_id, b"") + self.assertEqual(channel.last_usage, 0) + self.assertEqual( + cache_thp._get_channel_state(channel), ChannelState.UNALLOCATED + ) + + def test_get_set(self): + channel = thp_common.get_new_channel(self.interface) + + session_1 = cache_thp.get_new_session(channel.channel_cache) + session_1.set(KEY, b"hello") + self.assertEqual(session_1.get(KEY), b"hello") + + session_2 = cache_thp.get_new_session(channel.channel_cache) + session_2.set(KEY, b"world") + self.assertEqual(session_2.get(KEY), b"world") + + self.assertEqual(session_1.get(KEY), b"hello") + + cache.clear_all() + self.assertIsNone(session_1.get(KEY)) + self.assertIsNone(session_2.get(KEY)) + + def test_get_set_int(self): + channel = thp_common.get_new_channel(self.interface) + + session_1 = cache_thp.get_new_session(channel.channel_cache) + session_1.set_int(KEY, 1234) + + self.assertEqual(session_1.get_int(KEY), 1234) + + session_2 = cache_thp.get_new_session(channel.channel_cache) + session_2.set_int(KEY, 5678) + self.assertEqual(session_2.get_int(KEY), 5678) + + self.assertEqual(session_1.get_int(KEY), 1234) + + cache.clear_all() + self.assertIsNone(session_1.get_int(KEY)) + self.assertIsNone(session_2.get_int(KEY)) + + def test_get_set_bool(self): + channel = thp_common.get_new_channel(self.interface) + + session_1 = cache_thp.get_new_session(channel.channel_cache) + with self.assertRaises(AssertionError) as e: + session_1.set_bool(KEY, True) + self.assertEqual(e.value.value, "Field does not have zero length!") + + # Change length of first session field to 0 so that the length check passes + session_1.fields = (0,) + session_1.fields[1:] + + # with self.assertRaises(AssertionError) as e: + session_1.set_bool(KEY, True) + self.assertEqual(session_1.get_bool(KEY), True) + + session_2 = cache_thp.get_new_session(channel.channel_cache) + session_2.fields = session_2.fields = (0,) + session_2.fields[1:] + session_2.set_bool(KEY, False) + self.assertEqual(session_2.get_bool(KEY), False) + + self.assertEqual(session_1.get_bool(KEY), True) + + cache.clear_all() + + # Default value is False + self.assertFalse(session_1.get_bool(KEY)) + self.assertFalse(session_2.get_bool(KEY)) + + def test_delete(self): + channel = thp_common.get_new_channel(self.interface) + session_1 = cache_thp.get_new_session(channel.channel_cache) + + self.assertIsNone(session_1.get(KEY)) + session_1.set(KEY, b"hello") + self.assertEqual(session_1.get(KEY), b"hello") + session_1.delete(KEY) + self.assertIsNone(session_1.get(KEY)) + + session_1.set(KEY, b"hello") + session_2 = cache_thp.get_new_session(channel.channel_cache) + + self.assertIsNone(session_2.get(KEY)) + session_2.set(KEY, b"hello") + self.assertEqual(session_2.get(KEY), b"hello") + session_2.delete(KEY) + self.assertIsNone(session_2.get(KEY)) + + self.assertEqual(session_1.get(KEY), b"hello") + + else: + + def __init__(self): + # Context is needed to test decorators and handleInitialize + # It allows access to codec cache from different parts of the code + from trezor.wire import context + + context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64)) + super().__init__() + + def test_start_session(self): + session_id_a = cache_codec.start_session() + self.assertIsNotNone(session_id_a) + session_id_b = cache_codec.start_session() + self.assertNotEqual(session_id_a, session_id_b) + + cache.clear_all() + self.assertIsNone(get_active_session()) + for session in cache_codec._SESSIONS: + self.assertEqual(session.session_id, b"") + self.assertEqual(session.last_usage, 0) + + def test_end_session(self): + session_id = cache_codec.start_session() + self.assertTrue(is_session_started()) + get_active_session().set(KEY, b"A") + cache_codec.end_current_session() + self.assertFalse(is_session_started()) + self.assertIsNone(get_active_session()) + + # ending an ended session should be a no-op + cache_codec.end_current_session() + self.assertFalse(is_session_started()) + + session_id_a = cache_codec.start_session(session_id) + # original session no longer exists + self.assertNotEqual(session_id_a, session_id) + # original session data no longer exists + self.assertIsNone(get_active_session().get(KEY)) + + # create a new session + session_id_b = cache_codec.start_session() + # switch back to original session + session_id = cache_codec.start_session(session_id_a) + self.assertEqual(session_id, session_id_a) + # end original session + cache_codec.end_current_session() + # switch back to B + session_id = cache_codec.start_session(session_id_b) + self.assertEqual(session_id, session_id_b) + + def test_session_queue(self): + session_id = cache_codec.start_session() + self.assertEqual(cache_codec.start_session(session_id), session_id) + get_active_session().set(KEY, b"A") + for i in range(_PROTOCOL_CACHE._MAX_SESSIONS_COUNT): + cache_codec.start_session() + self.assertNotEqual(cache_codec.start_session(session_id), session_id) + self.assertIsNone(get_active_session().get(KEY)) + + def test_get_set(self): + session_id1 = cache_codec.start_session() + cache_codec.get_active_session().set(KEY, b"hello") + self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello") + + session_id2 = cache_codec.start_session() + cache_codec.get_active_session().set(KEY, b"world") + self.assertEqual(cache_codec.get_active_session().get(KEY), b"world") + + cache_codec.start_session(session_id2) + self.assertEqual(cache_codec.get_active_session().get(KEY), b"world") + cache_codec.start_session(session_id1) + self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello") + + cache.clear_all() + self.assertIsNone(cache_codec.get_active_session()) + + def test_get_set_int(self): + session_id1 = cache_codec.start_session() + get_active_session().set_int(KEY, 1234) + self.assertEqual(get_active_session().get_int(KEY), 1234) + + session_id2 = cache_codec.start_session() + get_active_session().set_int(KEY, 5678) + self.assertEqual(get_active_session().get_int(KEY), 5678) + + cache_codec.start_session(session_id2) + self.assertEqual(get_active_session().get_int(KEY), 5678) + cache_codec.start_session(session_id1) + self.assertEqual(get_active_session().get_int(KEY), 1234) + + cache.clear_all() + self.assertIsNone(get_active_session()) + + def test_delete(self): + session_id1 = cache_codec.start_session() + self.assertIsNone(get_active_session().get(KEY)) + get_active_session().set(KEY, b"hello") + self.assertEqual(get_active_session().get(KEY), b"hello") + get_active_session().delete(KEY) + self.assertIsNone(get_active_session().get(KEY)) + + get_active_session().set(KEY, b"hello") + cache_codec.start_session() + self.assertIsNone(get_active_session().get(KEY)) + get_active_session().set(KEY, b"hello") + self.assertEqual(get_active_session().get(KEY), b"hello") + get_active_session().delete(KEY) + self.assertIsNone(get_active_session().get(KEY)) + + cache_codec.start_session(session_id1) + self.assertEqual(get_active_session().get(KEY), b"hello") + + def test_decorators(self): + run_count = 0 + cache_codec.start_session() + from apps.common.cache import stored + + @stored(KEY) + def func(): + nonlocal run_count + run_count += 1 + return b"foo" + + # cache is empty + self.assertIsNone(get_active_session().get(KEY)) + self.assertEqual(run_count, 0) + self.assertEqual(func(), b"foo") + # function was run + self.assertEqual(run_count, 1) + self.assertEqual(get_active_session().get(KEY), b"foo") + # function does not run again but returns cached value + self.assertEqual(func(), b"foo") + self.assertEqual(run_count, 1) + + def test_empty_value(self): + cache_codec.start_session() + + self.assertIsNone(get_active_session().get(KEY)) + get_active_session().set(KEY, b"") + self.assertEqual(get_active_session().get(KEY), b"") + + get_active_session().delete(KEY) + run_count = 0 + + from apps.common.cache import stored + + @stored(KEY) + def func(): + nonlocal run_count + run_count += 1 + return b"" + + self.assertEqual(func(), b"") + # function gets called once + self.assertEqual(run_count, 1) + self.assertEqual(func(), b"") + # function is not called for a second time + self.assertEqual(run_count, 1) + + @mock_storage + def test_Initialize(self): + + def call_Initialize(**kwargs): + msg = Initialize(**kwargs) + return await_result(handle_Initialize(msg)) + + # calling Initialize without an ID allocates a new one + session_id = cache_codec.start_session() + features = call_Initialize() + self.assertNotEqual(session_id, features.session_id) + + # calling Initialize with the current ID does not allocate a new one + features = call_Initialize(session_id=session_id) + self.assertEqual(session_id, features.session_id) + + # store "hello" + get_active_session().set(KEY, b"hello") + # check that it is cleared + features = call_Initialize() + session_id = features.session_id + self.assertIsNone(get_active_session().get(KEY)) + # store "hello" again + get_active_session().set(KEY, b"hello") + self.assertEqual(get_active_session().get(KEY), b"hello") + + # supplying a different session ID starts a new session + call_Initialize(session_id=b"A" * _PROTOCOL_CACHE.SESSION_ID_LENGTH) + self.assertIsNone(get_active_session().get(KEY)) + + # but resuming a session loads the previous one + call_Initialize(session_id=session_id) + self.assertEqual(get_active_session().get(KEY), b"hello") + + def test_EndSession(self): + + self.assertIsNone(get_active_session()) + cache_codec.start_session() + self.assertTrue(is_session_started()) + self.assertIsNone(get_active_session().get(KEY)) + await_result(handle_EndSession(EndSession())) + self.assertFalse(is_session_started()) + self.assertIsNone(cache_codec.get_active_session()) if __name__ == "__main__": diff --git a/core/tests/test_trezor.wire.codec_v1.py b/core/tests/test_trezor.wire.codec_v1.py index 1da0ea896b4..6c8373aef4e 100644 --- a/core/tests/test_trezor.wire.codec_v1.py +++ b/core/tests/test_trezor.wire.codec_v1.py @@ -2,28 +2,11 @@ import ustruct +from mock_wire_interface import MockHID from trezor import io -from trezor.loop import wait from trezor.utils import chunks from trezor.wire import codec_v1 - -class MockHID: - def __init__(self, num): - self.num = num - self.data = [] - - def iface_num(self): - return self.num - - def write(self, msg): - self.data.append(bytearray(msg)) - return len(msg) - - def wait_object(self, mode): - return wait(mode | self.num) - - MESSAGE_TYPE = 0x4242 HEADER_PAYLOAD_LENGTH = codec_v1._REP_LEN - 3 - ustruct.calcsize(">HL") diff --git a/core/tests/test_trezor.wire.thp.checksum.py b/core/tests/test_trezor.wire.thp.checksum.py new file mode 100644 index 00000000000..41c93250012 --- /dev/null +++ b/core/tests/test_trezor.wire.thp.checksum.py @@ -0,0 +1,94 @@ +from common import * # isort:skip + +if utils.USE_THP: + from trezor.wire.thp import checksum + + +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestTrezorHostProtocolChecksum(unittest.TestCase): + vectors_correct = [ + ( + b"", + b"\x00\x00\x00\x00", + ), + ( + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + b"\x19\x0A\x55\xAD", + ), + ( + bytes("a", "ascii"), + b"\xE8\xB7\xBE\x43", + ), + ( + bytes("abc", "ascii"), + b"\x35\x24\x41\xC2", + ), + ( + bytes("123456789", "ascii"), + b"\xCB\xF4\x39\x26", + ), + ( + bytes( + "12345678901234567890123456789012345678901234567890123456789012345678901234567890", + "ascii", + ), + b"\x7C\xA9\x4A\x72", + ), + ( + b"\x76\x61\x72\x69\x6F\x75\x73\x20\x43\x52\x43\x20\x61\x6C\x67\x6F\x72\x69\x74\x68\x6D\x73\x20\x69\x6E\x70\x75\x74\x20\x64\x61\x74\x61", + b"\x9B\xD3\x66\xAE", + ), + ( + b"\x67\x3a\x5f\x0e\x39\xc0\x3c\x79\x58\x22\x74\x76\x64\x9e\x36\xe9\x0b\x04\x8c\xd2\xc0\x4d\x76\x63\x1a\xa2\x17\x85\xe8\x50\xa7\x14\x18\xfb\x86\xed\xa3\x59\x2d\x62\x62\x49\x64\x62\x26\x12\xdb\x95\x3d\xd6\xb5\xca\x4b\x22\x0d\xc5\x78\xb2\x12\x97\x8e\x54\x4e\x06\xb7\x9c\x90\xf5\xa0\x21\xa6\xc7\xd8\x39\xfd\xea\x3a\xf1\x7b\xa2\xe8\x71\x41\xd6\xcb\x1e\x5b\x0e\x29\xf7\x0c\xc7\x57\x8b\x53\x20\x1d\x2b\x41\x1c\x25\xf9\x07\xbb\xb4\x37\x79\x6a\x13\x1f\x6c\x43\x71\xc1\x1e\x70\xe6\x74\xd3\x9c\xbf\x32\x15\xee\xf2\xa7\x86\xbe\x59\x99\xc4\x10\x09\x8a\x6a\xaa\xd4\xd1\xd0\x71\xd2\x06\x1a\xdd\x2a\xa0\x08\xeb\x08\x6c\xfb\xd2\x2d\xfb\xaa\x72\x56\xeb\xd1\x92\x92\xe5\x0e\x95\x67\xf8\x38\xc3\xab\x59\x37\xe6\xfd\x42\xb0\xd0\x31\xd0\xcb\x8a\x66\xce\x2d\x53\x72\x1e\x72\xd3\x84\x25\xb0\xb8\x93\xd2\x61\x5b\x32\xd5\xe7\xe4\x0e\x31\x11\xaf\xdc\xb4\xb8\xee\xa4\x55\x16\x5f\x78\x86\x8b\x50\x4d\xc5\x6d\x6e\xfc\xe1\x6b\x06\x5b\x37\x84\x2a\x67\x95\x28\x00\xa4\xd1\x32\x9f\xbf\xe1\x64\xf8\x17\x47\xe1\xad\x8b\x72\xd2\xd9\x45\x5b\x73\x43\x3c\xe6\x21\xf7\x53\xa3\x73\xf9\x2a\xb0\xe9\x75\x5e\xa6\xbe\x9a\xad\xfc\xed\xb5\x46\x5b\x9f\xa9\x5a\x4f\xcb\xb6\x60\x96\x31\x91\x42\xca\xaf\xee\xa5\x0c\xe0\xab\x3e\x83\xb8\xac\x88\x10\x2c\x63\xd3\xc9\xd2\xf2\x44\xef\xea\x3d\x19\x24\x3c\x5b\xe7\x0c\x52\xfd\xfe\x47\x41\x14\xd5\x4c\x67\x8d\xdb\xe5\xd9\xfa\x67\x9c\x06\x31\x01\x92\xba\x96\xc4\x0d\xef\xf7\xc1\xe9\x23\x28\x0f\xae\x27\x9b\xff\x28\x0b\x3e\x85\x0c\xae\x02\xda\x27\xb6\x04\x51\x04\x43\x04\x99\x8c\xa3\x97\x1d\x84\xec\x55\x59\xfb\xf3\x84\xe5\xf8\x40\xf8\x5f\x81\x65\x92\x4c\x92\x7a\x07\x51\x8d\x6f\xff\x8d\x15\x36\x5c\x57\x7a\x5b\x3a\x63\x1c\x87\x65\xee\x54\xd5\x96\x50\x73\x1a\x9c\xff\x59\xe5\xea\x6f\x89\xd2\xbb\xa9\x6a\x12\x21\xf5\x08\x8e\x8a\xc0\xd8\xf5\x14\xe9\x9d\x7e\x99\x13\x88\x29\xa8\xb4\x22\x2a\x41\x7c\xc5\x10\xdf\x11\x5e\xf8\x8d\x0e\xd9\x98\xd5\xaf\xa8\xf9\x55\x1e\xe3\x29\xcd\x2c\x51\x7b\x8a\x8d\x52\xaa\x8b\x87\xae\x8e\xb2\xfa\x31\x27\x60\x90\xcb\x01\x6f\x7a\x79\x38\x04\x05\x7c\x11\x79\x10\x40\x33\x70\x75\xfd\x0b\x88\xa5\xcd\x35\xd8\xa6\x3b\xb0\x45\x82\x64\xd1\xb5\xdc\x06\xc9\x89\xf4\x16\x3e\xc7\xb3\xf1\x9d\xd3\xc5\xe3\xaf\xe8\x25\x86\x7a\x4a\xfd\x10\x5d\x20\xe5\x76\x5a\x22\x5f\x8f\xbc\xaa\x97\xee\xf2\xc2\x4c\x0e\xdc\x7b\xc4\xee\x53\xa3\xe0\xfa\xcd\x1e\x4e\x54\x1d\x5e\xe1\x51\x17\x1f\x1a\x75\x7f\xed\x12\xd7\xf7\xe3\x18\x56\x24\xcf\xc6\x96\x30\x77\x0d\x73\x98\x9c\x09\x69\xa3\xbc\x96\x5e\xaf\xde\x76\xa4\x66\x04\x6b\x36\x2a\xac\x6d\x37\xf8\x1e\xe1\x2a\x3e\x42\x2d\x1d\xe6\x46\xdd\x28\xb9\x08\x44\xa1\x9e\xb2\x22\x7a\x45\x8a\x37\x39\x74\xb4\xae\xc8\x3b\x40\xf7\xec\xbf\xfd\xe5\xde\xb2\x83\x5e\xa4\x46\x19\xa6\x9d\xb0\xe8\x76\x80\xbd\xc1\x80\x7a\xd9\xeb\xe7\x90\x5b\x81\x25\x21\xd9\x5b\x4a\x80\x48\x92\x71\x77\x04\xb2\xac\x05\xc9\xdf\x5e\x44\x5a\xae\x6e\xb3\xd8\x30\x5e\xdc\x77\x2f\x79\xc2\x8e\x8b\x28\x24\x06\x1b\x6f\x8d\x88\x53\x80\x55\x0c\x3a\x7b\x85\xb8\x96\x85\xe9\xf0\x57\x63\xfe\x32\x80\xff\x57\xc9\x3c\xdb\xf6\xcd\x67\x14\x47\x6c\x43\x3d\x6d\x48\x3f\x9c\x00\x60\x0e\xf5\x94\xe4\x52\x97\x86\xcd\xac\xbc\xe4\xe3\xe7\xee\xa2\x91\x6e\x92\xbb\xd1\x55\x0c\x5c\x0d\x63\xdb\x6b\xb8\x6e\x45\x48\x0f\xdf\x44\x48\xd2\xf5\xf7\x4d\x7b\xd4\x4d\xd3\xcd\xcd\x5b\x40\x60\xb1\xb2\x8e\xc9\x9a\x65\xc5\x06\x24\xcf\xe9\xcc\x5e\x2c\x49\x47\x38\x45\x5d\xc5\xc0\x0d\x8a\x07\x1c\xb3\xbb\xb1\x69\xf5\x6d\x0e\x9c\x96\x14\x93\x58\x0c\xc9\x48\x74\xfc\x35\xda\x7d\x4e\x32\x73\xa3\x77\x4a\x9e\xc5\xd1\x08\xfe\xa6\xa0\xf1\x66\x72\xea\xc7\xae\x21\x81\x0e\x8a\xba\x99\x06\x97\xfc\xc6\x2b\x69\x53\xc6\x67\xec\x5d\xa1\xfc\xa1\x3b\xdd\x2a\xd6\x8f\x31\xa7\x8d\xec\xfe\x0a\x3b\x6b\x39\x70\x70\x09\x72\x12\xbc\x84\x67\xca\xd2\x4a\x17\x33\x94\x45\x25\xc7\xfd\x1e\xa2\x4a\x9e\x27\x9d\xfb\x87\xea\xe4\xfd\xb0\x11\x06\x9d\x72\xb9\x1d\xea\x9b\x81\x2e\x6a\x36\x76\x62\xfa\xbe\x96\x67\x7d\x35\xdd\x5e\x5c\x4f\x41\x0d\xce\xdb\x13\xb0\x46\x89\x92\x45\x02\x39\x0f\xe6\xd1\x20\x96\x1c\x34\x00\x8c\xc9\xdf\xe3\xf0\xb6\x92\x3a\xda\x5c\x96\xd9\x0b\x7d\x57\xf5\x78\x11\xc0\xcf\xbf\xb0\x92\x3d\xe5\x6a\x67\x34\xce\xd9\x16\x08\xa0\x09\x42\x0b\x07\x13\x7c\x73\x0c\xc6\x50\x17\x42\xcf\xd9\x85\xd9\x23\x3c\xb1\x40\x40\x0f\x94\x20\xed\x2d\xbf\x10\x44\x6e\x64\x65\xe5\x1d\x5f\xec\x24\xd8\x4b\xe8\xc2\xfb\x06\x11\x24\x3f\xdf\x54\x2d\xe8\x4d\xc2\x1c\x27\x11\xb8\xb3\xd4", + b"\x6B\xA4\xEC\x92", + ), + ] + vectors_incorrect = [ + ( + b"", + b"\x00\x00\x00\x00\x00", + ), + ( + b"", + b"", + ), + ( + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + b"\x19\x0A\x55\xAE", + ), + ( + bytes("A", "ascii"), + b"\xE8\xB7\xBE\x43", + ), + ( + bytes("abc ", "ascii"), + b"\x35\x24\x41\xC2", + ), + ( + bytes("1234567890", "ascii"), + b"\xCB\xF4\x39\x26", + ), + ( + bytes( + "1234567890123456789012345678901234567890123456789012345678901234567890123456789", + "ascii", + ), + b"\x7C\xA9\x4A\x72", + ), + ] + + def test_computation(self): + for data, chksum in self.vectors_correct: + self.assertEqual(checksum.compute(data), chksum) + + def test_validation_correct(self): + for data, chksum in self.vectors_correct: + self.assertTrue(checksum.is_valid(chksum, data)) + + def test_validation_incorrect(self): + for data, chksum in self.vectors_incorrect: + self.assertFalse(checksum.is_valid(chksum, data)) + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.wire.thp.credential_manager.py b/core/tests/test_trezor.wire.thp.credential_manager.py new file mode 100644 index 00000000000..59631979d6a --- /dev/null +++ b/core/tests/test_trezor.wire.thp.credential_manager.py @@ -0,0 +1,66 @@ +from common import * # isort:skip + + +if utils.USE_THP: + import thp_common + from trezor import config + from trezor.messages import ThpCredentialMetadata + + from apps.thp import credential_manager + + def _issue_credential(host_name: str, host_static_pubkey: bytes) -> bytes: + metadata = ThpCredentialMetadata(host_name=host_name) + return credential_manager.issue_credential(host_static_pubkey, metadata) + + +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestTrezorHostProtocolCredentialManager(unittest.TestCase): + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + super().__init__() + + def setUp(self): + config.init() + config.wipe() + + def test_derive_cred_auth_key(self): + key1 = credential_manager.derive_cred_auth_key() + key2 = credential_manager.derive_cred_auth_key() + self.assertEqual(len(key1), 32) + self.assertEqual(key1, key2) + + def test_invalidate_cred_auth_key(self): + key1 = credential_manager.derive_cred_auth_key() + credential_manager.invalidate_cred_auth_key() + key2 = credential_manager.derive_cred_auth_key() + self.assertNotEqual(key1, key2) + + def test_credentials(self): + DUMMY_KEY_1 = b"\x00\x00" + DUMMY_KEY_2 = b"\xff\xff" + HOST_NAME_1 = "host_name" + HOST_NAME_2 = "different host_name" + + cred_1 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1) + cred_2 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1) + self.assertEqual(cred_1, cred_2) + + cred_3 = _issue_credential(HOST_NAME_2, DUMMY_KEY_1) + self.assertNotEqual(cred_1, cred_3) + + self.assertTrue(credential_manager.validate_credential(cred_1, DUMMY_KEY_1)) + self.assertTrue(credential_manager.validate_credential(cred_3, DUMMY_KEY_1)) + self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_2)) + + credential_manager.invalidate_cred_auth_key() + cred_4 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1) + self.assertNotEqual(cred_1, cred_4) + self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_1)) + self.assertFalse(credential_manager.validate_credential(cred_3, DUMMY_KEY_1)) + self.assertTrue(credential_manager.validate_credential(cred_4, DUMMY_KEY_1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.wire.thp.crypto.py b/core/tests/test_trezor.wire.thp.crypto.py new file mode 100644 index 00000000000..d26785ce65e --- /dev/null +++ b/core/tests/test_trezor.wire.thp.crypto.py @@ -0,0 +1,156 @@ +from common import * # isort:skip +from trezorcrypto import aesgcm, curve25519 + +import storage + +if utils.USE_THP: + import thp_common + from trezor.wire.thp import crypto + from trezor.wire.thp.crypto import IV_1, IV_2, Handshake + + def get_dummy_device_secret(): + return b"\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08" + + +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestTrezorHostProtocolCrypto(unittest.TestCase): + if utils.USE_THP: + handshake = Handshake() + key_1 = b"\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07" + # 0:key, 1:nonce, 2:auth_data, 3:plaintext, 4:expected_ciphertext, 5:expected_tag + vectors_enc = [ + ( + key_1, + 0, + b"\x55\x64", + b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09", + b"e2c9dd152fbee5821ea7", + b"10625812de81b14a46b9f1e5100a6d0c", + ), + ( + key_1, + 1, + b"\x55\x64", + b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09", + b"79811619ddb07c2b99f8", + b"71c6b872cdc499a7e9a3c7441f053214", + ), + ( + key_1, + 369, + b"\x55\x64", + b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + b"03bd030390f2dfe815a61c2b157a064f", + b"c1200f8a7ae9a6d32cef0fff878d55c2", + ), + ( + key_1, + 369, + b"\x55\x64\x73\x82\x91", + b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + b"03bd030390f2dfe815a61c2b157a064f", + b"693ac160cd93a20f7fc255f049d808d0", + ), + ] + # 0:chaining key, 1:input, 2:output_1, 3:output:2 + vectors_hkdf = [ + ( + crypto.PROTOCOL_NAME, + b"\x01\x02", + b"c784373a217d6be057cddc6068e6748f255fc8beb6f99b7b90cbc64aad947514", + b"12695451e29bf08ffe5e4e6ab734b0c3d7cdd99b16cd409f57bd4eaa874944ba", + ), + ( + b"\xc7\x84\x37\x3a\x21\x7d\x6b\xe0\x57\xcd\xdc\x60\x68\xe6\x74\x8f\x25\x5f\xc8\xbe\xb6\xf9\x9b\x7b\x90\xcb\xc6\x4a\xad\x94\x75\x14", + b"\x31\x41\x59\x26\x52\x12\x34\x56\x78\x89\x04\xaa", + b"f88c1e08d5c3bae8f6e4a3d3324c8cbc60a805603e399e69c4bf4eacb27c2f48", + b"5f0216bdb7110ee05372286974da8c9c8b96e2efa15b4af430755f462bd79a76", + ), + ] + vectors_iv = [ + (0, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"), + (1, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"), + (7, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07"), + (1025, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x01"), + (4294967295, b"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"), + (0xFFFFFFFFFFFFFFFF, b"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"), + ] + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + super().__init__() + + def setUp(self): + utils.DISABLE_ENCRYPTION = False + + def test_encryption(self): + for v in self.vectors_enc: + buffer = bytearray(v[3]) + tag = crypto.enc(buffer, v[0], v[1], v[2]) + self.assertEqual(hexlify(buffer), v[4]) + self.assertEqual(hexlify(tag), v[5]) + self.assertTrue(crypto.dec(buffer, tag, v[0], v[1], v[2])) + self.assertEqual(buffer, v[3]) + + def test_hkdf(self): + for v in self.vectors_hkdf: + ck, k = crypto._hkdf(v[0], v[1]) + self.assertEqual(hexlify(ck), v[2]) + self.assertEqual(hexlify(k), v[3]) + + def test_iv_from_nonce(self): + for v in self.vectors_iv: + x = v[0] + y = x.to_bytes(8, "big") + iv = crypto._get_iv_from_nonce(v[0]) + self.assertEqual(iv, v[1]) + with self.assertRaises(AssertionError) as e: + iv = crypto._get_iv_from_nonce(0xFFFFFFFFFFFFFFFF + 1) + self.assertEqual(e.value.value, "Nonce overflow, terminate the channel") + + def test_incorrect_vectors(self): + pass + + def test_th1_crypto(self): + storage.device.get_device_secret = get_dummy_device_secret + handshake = self.handshake + + host_ephemeral_privkey = curve25519.generate_secret() + host_ephemeral_pubkey = curve25519.publickey(host_ephemeral_privkey) + handshake.handle_th1_crypto(b"", host_ephemeral_pubkey) + + def test_th2_crypto(self): + handshake = self.handshake + + host_static_privkey = curve25519.generate_secret() + host_static_pubkey = curve25519.publickey(host_static_privkey) + aes_ctx = aesgcm(handshake.k, IV_2) + aes_ctx.auth(handshake.h) + encrypted_host_static_pubkey = bytearray( + aes_ctx.encrypt(host_static_pubkey) + aes_ctx.finish() + ) + + # Code to encrypt Host's noise encrypted payload correctly: + protomsg = bytearray(b"\x10\x02\x10\x03") + temp_k = handshake.k + temp_h = handshake.h + + temp_h = crypto._hash_of_two(temp_h, encrypted_host_static_pubkey) + _, temp_k = crypto._hkdf( + handshake.ck, + curve25519.multiply(handshake.trezor_ephemeral_privkey, host_static_pubkey), + ) + aes_ctx = aesgcm(temp_k, IV_1) + aes_ctx.encrypt_in_place(protomsg) + aes_ctx.auth(temp_h) + tag = aes_ctx.finish() + encrypted_payload = bytearray(protomsg + tag) + # end of encrypted payload generation + + handshake.handle_th2_crypto(encrypted_host_static_pubkey, encrypted_payload) + self.assertEqual(encrypted_payload[:4], b"\x10\x02\x10\x03") + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.wire.thp.py b/core/tests/test_trezor.wire.thp.py new file mode 100644 index 00000000000..dac468dae37 --- /dev/null +++ b/core/tests/test_trezor.wire.thp.py @@ -0,0 +1,370 @@ +from common import * # isort:skip +from mock_wire_interface import MockHID +from trezor import config, io, protobuf +from trezor.crypto.curve import curve25519 +from trezor.enums import MessageType +from trezor.wire.errors import UnexpectedMessage +from trezor.wire.protocol_common import Message + +if utils.USE_THP: + from typing import TYPE_CHECKING + + import thp_common + from storage import cache_thp + from storage.cache_common import ( + CHANNEL_HANDSHAKE_HASH, + CHANNEL_KEY_RECEIVE, + CHANNEL_KEY_SEND, + CHANNEL_NONCE_RECEIVE, + CHANNEL_NONCE_SEND, + ) + from trezor.crypto import elligator2 + from trezor.enums import ThpPairingMethod + from trezor.messages import ( + ThpCodeEntryChallenge, + ThpCodeEntryCpaceHost, + ThpCodeEntryTag, + ThpCredentialRequest, + ThpEndRequest, + ThpStartPairingRequest, + ) + from trezor.wire import thp_main + from trezor.wire.thp import ChannelState, checksum, interface_manager + from trezor.wire.thp.crypto import Handshake + from trezor.wire.thp.pairing_context import PairingContext + + from apps.thp import pairing + + if TYPE_CHECKING: + from trezor.wire import WireInterface + + def get_dummy_key() -> bytes: + return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x01\x02\x03\x04\x05\x06\x07\x08\x09\x20\x01\x02\x03\x04\x05\x06\x07\x08\x09\x30\x31" + + def send_channel_allocation_request( + interface: WireInterface, nonce: bytes | None = None + ) -> bytes: + if nonce is None or len(nonce) != 8: + nonce = b"\x00\x11\x22\x33\x44\x55\x66\x77" + header = b"\x40\xff\xff\x00\x0c" + chksum = checksum.compute(header + nonce) + cid_req = header + nonce + chksum + gen = thp_main.thp_main_loop(interface) + gen.send(None) + gen.send(cid_req) + gen.send(None) + response_data = ( + b"\x0a\x04\x54\x32\x54\x31\x10\x00\x18\x00\x20\x02\x28\x02\x28\x03\x28\x04" + ) + response_without_crc = ( + b"\x41\xff\xff\x00\x20" + + nonce + + cache_thp.cid_counter.to_bytes(2, "big") + + response_data + ) + chkcsum = checksum.compute(response_without_crc) + expected_response = response_without_crc + chkcsum + b"\x00" * 27 + return expected_response + + def get_channel_id_from_response(channel_allocation_response: bytes) -> int: + return int.from_bytes(channel_allocation_response[13:15], "big") + + def get_ack(channel_id: bytes) -> bytes: + if len(channel_id) != 2: + raise Exception("Channel id should by two bytes long") + return ( + b"\x20" + + channel_id + + b"\x00\x04" + + checksum.compute(b"\x20" + channel_id + b"\x00\x04") + + b"\x00" * 55 + ) + + +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestTrezorHostProtocol(unittest.TestCase): + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + super().__init__() + + def setUp(self): + self.interface = MockHID(0xDEADBEEF) + buffer = bytearray(64) + buffer2 = bytearray(256) + thp_main.set_read_buffer(buffer) + thp_main.set_write_buffer(buffer2) + interface_manager.decode_iface = thp_common.dummy_decode_iface + + def test_codec_message(self): + self.assertEqual(len(self.interface.data), 0) + gen = thp_main.thp_main_loop(self.interface) + gen.send(None) + + # There should be a failiure response to received init packet (starts with "?##") + test_codec_message = b"?## Some data" + gen.send(test_codec_message) + gen.send(None) + self.assertEqual(len(self.interface.data), 1) + + expected_response = b"?##\x00\x03\x00\x00\x00\x14\x08\x10" + self.assertEqual( + self.interface.data[-1][: len(expected_response)], expected_response + ) + + # There should be no response for continuation packet (starts with "?" only) + test_codec_message_2 = b"? Cont packet" + gen.send(test_codec_message_2) + with self.assertRaises(TypeError) as e: + gen.send(None) + self.assertEqual(e.value.value, "object with buffer protocol required") + self.assertEqual(len(self.interface.data), 1) + + def test_message_on_unallocated_channel(self): + gen = thp_main.thp_main_loop(self.interface) + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + message_to_channel_789a = ( + b"\x04\x78\x9a\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c" + ) + gen.send(message_to_channel_789a) + gen.send(None) + unallocated_chanel_error_on_channel_789a = "42789a0005027b743563000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + self.assertEqual( + utils.get_bytes_as_str(self.interface.data[-1]), + unallocated_chanel_error_on_channel_789a, + ) + + def test_channel_allocation(self): + test_counter = cache_thp.cid_counter + 1 + self.assertEqual(len(thp_main._CHANNELS), 0) + self.assertFalse(test_counter in thp_main._CHANNELS) + + expected_response = send_channel_allocation_request(self.interface) + self.assertEqual(self.interface.data[-1], expected_response) + + self.assertTrue(test_counter in thp_main._CHANNELS) + self.assertEqual(len(thp_main._CHANNELS), 1) + + # test channel's default state is TH1: + cid = get_channel_id_from_response(self.interface.data[-1]) + self.assertEqual(thp_main._CHANNELS[cid].get_channel_state(), ChannelState.TH1) + + def test_invalid_encrypted_tag(self): + gen = thp_main.thp_main_loop(self.interface) + gen.send(None) + # prepare 2 new channels + expected_response_1 = send_channel_allocation_request(self.interface) + expected_response_2 = send_channel_allocation_request(self.interface) + self.assertEqual(self.interface.data[-2], expected_response_1) + self.assertEqual(self.interface.data[-1], expected_response_2) + + # test invalid encryption tag + config.init() + config.wipe() + cid_1 = get_channel_id_from_response(expected_response_1) + channel = thp_main._CHANNELS[cid_1] + channel.iface = self.interface + channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + header = b"\x04" + channel.channel_id + b"\x00\x14" + + tag = b"\x00" * 16 + chksum = checksum.compute(header + tag) + message_with_invalid_tag = header + tag + chksum + + channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key()) + channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0) + + cid_1_bytes = int.to_bytes(cid_1, 2, "big") + expected_ack_on_received_message = get_ack(cid_1_bytes) + + gen.send(message_with_invalid_tag) + gen.send(None) + + self.assertEqual( + self.interface.data[-1], + expected_ack_on_received_message, + ) + error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03" + chksum_err = checksum.compute(error_without_crc) + gen.send(None) + + decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54 + + self.assertEqual( + self.interface.data[-1], + decryption_failed_error, + ) + + def test_channel_errors(self): + gen = thp_main.thp_main_loop(self.interface) + gen.send(None) + # prepare 2 new channels + expected_response_1 = send_channel_allocation_request(self.interface) + expected_response_2 = send_channel_allocation_request(self.interface) + self.assertEqual(self.interface.data[-2], expected_response_1) + self.assertEqual(self.interface.data[-1], expected_response_2) + + # test invalid encryption tag + config.init() + config.wipe() + cid_1 = get_channel_id_from_response(expected_response_1) + channel = thp_main._CHANNELS[cid_1] + channel.iface = self.interface + channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + header = b"\x04" + channel.channel_id + b"\x00\x14" + + tag = b"\x00" * 16 + chksum = checksum.compute(header + tag) + message_with_invalid_tag = header + tag + chksum + + channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key()) + channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0) + + cid_1_bytes = int.to_bytes(cid_1, 2, "big") + expected_ack_on_received_message = get_ack(cid_1_bytes) + + gen.send(message_with_invalid_tag) + gen.send(None) + + self.assertEqual( + self.interface.data[-1], + expected_ack_on_received_message, + ) + error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03" + chksum_err = checksum.compute(error_without_crc) + gen.send(None) + + decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54 + + self.assertEqual( + self.interface.data[-1], + decryption_failed_error, + ) + + # test invalid tag in handshake phase + cid_2 = get_channel_id_from_response(expected_response_1) + cid_2_bytes = cid_2.to_bytes(2, "big") + channel = thp_main._CHANNELS[cid_2] + channel.iface = self.interface + + channel.set_channel_state(ChannelState.TH2) + + message_with_invalid_tag = b"\x0a\x12\x36\x00\x14\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x91\x65\x4c\xf9" + + channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key()) + channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0) + + # gen.send(message_with_invalid_tag) + # gen.send(None) + # gen.send(None) + # for i in self.interface.data: + # print(utils.get_bytes_as_str(i)) + + def test_skip_pairing(self): + config.init() + config.wipe() + channel = thp_main._CHANNELS[4660] + channel.selected_pairing_methods = [ + ThpPairingMethod.NoMethod, + ThpPairingMethod.CodeEntry, + ThpPairingMethod.NFC_Unidirectional, + ThpPairingMethod.QrCode, + ] + pairing_ctx = PairingContext(channel) + request_message = ThpStartPairingRequest() + channel.set_channel_state(ChannelState.TP1) + gen = pairing.handle_pairing_request(pairing_ctx, request_message) + + with self.assertRaises(StopIteration): + gen.send(None) + self.assertEqual(channel.get_channel_state(), ChannelState.ENCRYPTED_TRANSPORT) + + # Teardown: set back initial channel state value + channel.set_channel_state(ChannelState.TH1) + + def test_pairing(self): + config.init() + config.wipe() + cid = get_channel_id_from_response( + send_channel_allocation_request(self.interface) + ) + channel = thp_main._CHANNELS[cid] + channel.selected_pairing_methods = [ + ThpPairingMethod.CodeEntry, + ThpPairingMethod.NFC_Unidirectional, + ThpPairingMethod.QrCode, + ] + pairing_ctx = PairingContext(channel) + request_message = ThpStartPairingRequest() + with self.assertRaises(UnexpectedMessage) as e: + pairing.handle_pairing_request(pairing_ctx, request_message) + print(e.value.message) + channel.set_channel_state(ChannelState.TP1) + gen = pairing.handle_pairing_request(pairing_ctx, request_message) + + channel.channel_cache.set(CHANNEL_KEY_SEND, get_dummy_key()) + channel.channel_cache.set_int(CHANNEL_NONCE_SEND, 0) + channel.channel_cache.set(CHANNEL_HANDSHAKE_HASH, b"") + + gen.send(None) + + async def _dummy(ctx: PairingContext, expected_types): + return await ctx.read([1018, 1024]) + + pairing.show_display_data = _dummy + + msg_code_entry = ThpCodeEntryChallenge(challenge=b"\x12\x34") + buffer: bytearray = bytearray(protobuf.encoded_length(msg_code_entry)) + protobuf.encode(buffer, msg_code_entry) + code_entry_challenge = Message(MessageType.ThpCodeEntryChallenge, buffer) + gen.send(code_entry_challenge) + + # tag_qrc = b"\x55\xdf\x6c\xba\x0b\xe9\x5e\xd1\x4b\x78\x61\xec\xfa\x07\x9b\x5d\x37\x60\xd8\x79\x9c\xd7\x89\xb4\x22\xc1\x6f\x39\xde\x8f\x3b\xc3" + # tag_nfc = b"\x8f\xf0\xfa\x37\x0a\x5b\xdb\x29\x32\x21\xd8\x2f\x95\xdd\xb6\xb8\xee\xfd\x28\x6f\x56\x9f\xa9\x0b\x64\x8c\xfc\x62\x46\x5a\xdd\xd0" + + pregenerator_host = b"\xf6\x94\xc3\x6f\xb3\xbd\xfb\xba\x2f\xfd\x0c\xd0\x71\xed\x54\x76\x73\x64\x37\xfa\x25\x85\x12\x8d\xcf\xb5\x6c\x02\xaf\x9d\xe8\xbe" + generator_host = elligator2.map_to_curve25519(pregenerator_host) + cpace_host_private_key = b"\x02\x80\x70\x3c\x06\x45\x19\x75\x87\x0c\x82\xe1\x64\x11\xc0\x18\x13\xb2\x29\x04\xb3\xf0\xe4\x1e\x6b\xfd\x77\x63\x11\x73\x07\xa9" + cpace_host_public_key: bytes = curve25519.multiply( + cpace_host_private_key, generator_host + ) + msg = ThpCodeEntryCpaceHost(cpace_host_public_key=cpace_host_public_key) + + # msg = ThpQrCodeTag(tag=tag_qrc) + # msg = ThpNfcUnidirectionalTag(tag=tag_nfc) + buffer: bytearray = bytearray(protobuf.encoded_length(msg)) + + protobuf.encode(buffer, msg) + user_message = Message(MessageType.ThpCodeEntryCpaceHost, buffer) + gen.send(user_message) + + tag_ent = b"\xd0\x15\xd6\x72\x7c\xa6\x9b\x2a\x07\xfa\x30\xee\x03\xf0\x2d\x04\xdc\x96\x06\x77\x0c\xbd\xb4\xaa\x77\xc7\x68\x6f\xae\xa9\xdd\x81" + msg = ThpCodeEntryTag(tag=tag_ent) + + buffer: bytearray = bytearray(protobuf.encoded_length(msg)) + + protobuf.encode(buffer, msg) + user_message = Message(MessageType.ThpCodeEntryTag, buffer) + gen.send(user_message) + + host_static_pubkey = b"\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77" + msg = ThpCredentialRequest(host_static_pubkey=host_static_pubkey) + buffer: bytearray = bytearray(protobuf.encoded_length(msg)) + protobuf.encode(buffer, msg) + credential_request = Message(MessageType.ThpCredentialRequest, buffer) + gen.send(credential_request) + + msg = ThpEndRequest() + + buffer: bytearray = bytearray(protobuf.encoded_length(msg)) + protobuf.encode(buffer, msg) + end_request = Message(1012, buffer) + with self.assertRaises(StopIteration) as e: + gen.send(end_request) + print("response message:", e.value.value.MESSAGE_NAME) + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.wire.thp.writer.py b/core/tests/test_trezor.wire.thp.writer.py new file mode 100644 index 00000000000..84f6ac50ab9 --- /dev/null +++ b/core/tests/test_trezor.wire.thp.writer.py @@ -0,0 +1,150 @@ +from common import * # isort:skip + +from typing import Any, Awaitable + +if utils.USE_THP: + import thp_common + from mock_wire_interface import MockHID + from trezor.wire.thp import writer + from trezor.wire.thp.thp_messages import ENCRYPTED_TRANSPORT, PacketHeader + + +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestTrezorHostProtocolWriter(unittest.TestCase): + short_payload_expected = b"04123400050700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + longer_payload_expected = [ + b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a", + b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677", + b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4", + b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1", + b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + ] + eight_longer_payloads_expected = [ + b"0412340800000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a", + b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677", + b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4", + b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1", + b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e", + b"8012342f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b", + b"8012346c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8", + b"801234a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5", + b"801234e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122", + b"801234232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f", + b"801234606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c", + b"8012349d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9", + b"801234dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f10111213141516", + b"8012341718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f50515253", + b"8012345455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f90", + b"8012349192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccd", + b"801234cecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a", + b"8012340b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f4041424344454647", + b"80123448494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f8081828384", + b"80123485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1", + b"801234c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfe", + b"801234ff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b", + b"8012343c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778", + b"801234797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5", + b"801234b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2", + b"801234f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f", + b"801234303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c", + b"8012346d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9", + b"801234aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6", + b"801234e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20212223", + b"8012342425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f60", + b"8012346162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d", + b"8012349e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9da", + b"801234dbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000000000000000000000000000000000000000000000000", + ] + empty_payload_with_checksum_expected = b"0412340004edbd479c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + longer_payload_with_checksum_expected = [ + b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a", + b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677", + b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4", + b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1", + b"801234f2f3f4f5f6f7f8f9fafbfcfdfefff40c65ee00000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + ] + + def await_until_result(self, task: Awaitable) -> Any: + with self.assertRaises(StopIteration): + while True: + task.send(None) + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + super().__init__() + + def setUp(self): + self.interface = MockHID(0xDEADBEEF) + + def test_write_empty_packet(self): + self.await_until_result(writer.write_packet_to_wire(self.interface, b"")) + + print(self.interface.data[0]) + self.assertEqual(len(self.interface.data), 1) + self.assertEqual(self.interface.data[0], b"") + + def test_write_empty_payload(self): + header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 4) + await_result(writer.write_payloads_to_wire(self.interface, header, (b"",))) + self.assertEqual(len(self.interface.data), 0) + + def test_write_short_payload(self): + header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 5) + data = b"\x07" + self.await_until_result( + writer.write_payloads_to_wire(self.interface, header, (data,)) + ) + self.assertEqual(hexlify(self.interface.data[0]), self.short_payload_expected) + + def test_write_longer_payload(self): + data = bytearray(range(256)) + header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 256) + self.await_until_result( + writer.write_payloads_to_wire(self.interface, header, (data,)) + ) + + for i in range(len(self.longer_payload_expected)): + self.assertEqual( + hexlify(self.interface.data[i]), self.longer_payload_expected[i] + ) + + def test_write_eight_longer_payloads(self): + data = bytearray(range(256)) + header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 2048) + self.await_until_result( + writer.write_payloads_to_wire( + self.interface, header, (data, data, data, data, data, data, data, data) + ) + ) + for i in range(len(self.eight_longer_payloads_expected)): + self.assertEqual( + hexlify(self.interface.data[i]), self.eight_longer_payloads_expected[i] + ) + + def test_write_empty_payload_with_checksum(self): + header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 4) + self.await_until_result( + writer.write_payload_to_wire_and_add_checksum(self.interface, header, b"") + ) + + self.assertEqual( + hexlify(self.interface.data[0]), self.empty_payload_with_checksum_expected + ) + + def test_write_longer_payload_with_checksum(self): + data = bytearray(range(256)) + header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 256) + self.await_until_result( + writer.write_payload_to_wire_and_add_checksum(self.interface, header, data) + ) + + for i in range(len(self.longer_payload_with_checksum_expected)): + self.assertEqual( + hexlify(self.interface.data[i]), + self.longer_payload_with_checksum_expected[i], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.wire.thp_deprecated.py b/core/tests/test_trezor.wire.thp_deprecated.py new file mode 100644 index 00000000000..04a1e13c44c --- /dev/null +++ b/core/tests/test_trezor.wire.thp_deprecated.py @@ -0,0 +1,338 @@ +from common import * # isort:skip +import ustruct +from typing import TYPE_CHECKING + +from mock_wire_interface import MockHID +from storage.cache_thp import BROADCAST_CHANNEL_ID +from trezor import io +from trezor.utils import chunks +from trezor.wire.protocol_common import Message + +if utils.USE_THP: + import thp_common + import trezor.wire.thp + from trezor.wire import thp_main + from trezor.wire.thp import alternating_bit_protocol as ABP + from trezor.wire.thp import checksum + from trezor.wire.thp.checksum import CHECKSUM_LENGTH + from trezor.wire.thp.writer import PACKET_LENGTH + +if TYPE_CHECKING: + from trezorio import WireInterface + + +MESSAGE_TYPE = 0x4242 +MESSAGE_TYPE_BYTES = b"\x42\x42" +_MESSAGE_TYPE_LEN = 2 +PLAINTEXT_0 = 0x01 +PLAINTEXT_1 = 0x11 +COMMON_CID = 4660 +CONT = 0x80 + +HEADER_INIT_LENGTH = 5 +HEADER_CONT_LENGTH = 3 +if utils.USE_THP: + INIT_MESSAGE_DATA_LENGTH = PACKET_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN + + +def make_header(ctrl_byte, cid, length): + return ustruct.pack(">BHH", ctrl_byte, cid, length) + + +def make_cont_header(): + return ustruct.pack(">BH", CONT, COMMON_CID) + + +def makeSimpleMessage(header, message_type, message_data): + return header + ustruct.pack(">H", message_type) + message_data + + +def makeCidRequest(header, message_data): + return header + message_data + + +def getPlaintext() -> bytes: + if ABP.get_expected_receive_seq_bit(THP.get_active_session()) == 1: + return PLAINTEXT_1 + return PLAINTEXT_0 + + +async def deprecated_read_message( + iface: WireInterface, buffer: utils.BufferType +) -> Message: + return Message(-1, b"\x00") + + +async def deprecated_write_message( + iface: WireInterface, message: Message, is_retransmission: bool = False +) -> None: + pass + + +# This test suite is an adaptation of test_trezor.wire.codec_v1 +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestWireTrezorHostProtocolV1(unittest.TestCase): + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + super().__init__() + + def setUp(self): + self.interface = MockHID(0xDEADBEEF) + + def _simple(self): + cid_req_header = make_header( + ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12 + ) + cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c" + cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data) + + message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18) + cid_request_dummy_data_checksum = b"\x67\x8e\xac\xe0" + message = makeSimpleMessage( + message_header, + MESSAGE_TYPE, + cid_request_dummy_data + cid_request_dummy_data_checksum, + ) + + buffer = bytearray(64) + gen = deprecated_read_message(self.interface, buffer) + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + gen.send(cid_req_message) + gen.send(None) + gen.send(message) + with self.assertRaises(StopIteration) as e: + gen.send(None) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, cid_request_dummy_data) + + buffer_without_zeroes = buffer[: len(message) - 5] + message_without_header = message[5:] + # message should have been read into the buffer + self.assertEqual(buffer_without_zeroes, message_without_header) + + def _read_one_packet(self): + # zero length message - just a header + PLAINTEXT = getPlaintext() + header = make_header( + PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + CHECKSUM_LENGTH + ) + chksum = checksum.compute(header + MESSAGE_TYPE_BYTES) + message = header + MESSAGE_TYPE_BYTES + chksum + + buffer = bytearray(64) + gen = deprecated_read_message(self.interface, buffer) + + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + gen.send(message) + with self.assertRaises(StopIteration) as e: + gen.send(None) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, b"") + + # message should have been read into the buffer + self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58) + + def _read_many_packets(self): + message = bytes(range(256)) + header = make_header( + getPlaintext(), + COMMON_CID, + len(message) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH, + ) + chksum = checksum.compute(header + MESSAGE_TYPE_BYTES + message) + # message = MESSAGE_TYPE_BYTES + message + checksum + + # first packet is init header + 59 bytes of data + # other packets are cont header + 61 bytes of data + cont_header = make_cont_header() + packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [ + cont_header + chunk + for chunk in chunks( + message[INIT_MESSAGE_DATA_LENGTH:] + chksum, + 64 - HEADER_CONT_LENGTH, + ) + ] + buffer = bytearray(262) + gen = deprecated_read_message(self.interface, buffer) + query = gen.send(None) + for packet in packets: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + query = gen.send(packet) + + # last packet will stop + with self.assertRaises(StopIteration) as e: + gen.send(None) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, message) + + # message should have been read into the buffer ) + self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum) + + def _read_large_message(self): + message = b"hello world" + header = make_header( + getPlaintext(), + COMMON_CID, + _MESSAGE_TYPE_LEN + len(message) + CHECKSUM_LENGTH, + ) + + packet = ( + header + + MESSAGE_TYPE_BYTES + + message + + checksum.compute(header + MESSAGE_TYPE_BYTES + message) + ) + + # make sure we fit into one packet, to make this easier + self.assertTrue(len(packet) <= thp_main.PACKET_LENGTH) + + buffer = bytearray(1) + self.assertTrue(len(buffer) <= len(packet)) + + gen = deprecated_read_message(self.interface, buffer) + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + gen.send(packet) + with self.assertRaises(StopIteration) as e: + gen.send(None) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, message) + + # read should have allocated its own buffer and not touch ours + self.assertEqual(buffer, b"\x00") + + def _roundtrip(self): + message_payload = bytes(range(256)) + message = Message( + MESSAGE_TYPE, message_payload, 1 + ) # TODO use different session id + gen = deprecated_write_message(self.interface, message) + # exhaust the iterator: + # (XXX we can only do this because the iterator is only accepting None and returns None) + for query in gen: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) + + buffer = bytearray(1024) + gen = deprecated_read_message(self.interface, buffer) + query = gen.send(None) + for packet in self.interface.data: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + print(utils.get_bytes_as_str(packet)) + query = gen.send(packet) + + with self.assertRaises(StopIteration) as e: + gen.send(None) + + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, message.data) + + def _write_one_packet(self): + message = Message(MESSAGE_TYPE, b"") + gen = deprecated_write_message(self.interface, message) + + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) + with self.assertRaises(StopIteration): + gen.send(None) + + header = make_header( + getPlaintext(), COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH + ) + expected_message = ( + header + + MESSAGE_TYPE_BYTES + + checksum.compute(header + MESSAGE_TYPE_BYTES) + + b"\x00" * (INIT_MESSAGE_DATA_LENGTH - CHECKSUM_LENGTH) + ) + self.assertTrue(self.interface.data == [expected_message]) + + def _write_multiple_packets(self): + message_payload = bytes(range(256)) + message = Message(MESSAGE_TYPE, message_payload) + gen = deprecated_write_message(self.interface, message) + + header = make_header( + PLAINTEXT_1, + COMMON_CID, + len(message.data) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH, + ) + cont_header = make_cont_header() + chksum = checksum.compute( + header + message.type.to_bytes(2, "big") + message.data + ) + packets = [ + header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH] + ] + [ + cont_header + chunk + for chunk in chunks( + message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum, + thp_main.PACKET_LENGTH - HEADER_CONT_LENGTH, + ) + ] + + for _ in packets: + # we receive as many queries as there are packets + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) + + # the first sent None only started the generator. the len(packets)-th None + # will finish writing and raise StopIteration + with self.assertRaises(StopIteration): + gen.send(None) + + # packets must be identical up to the last one + self.assertListEqual(packets[:-1], self.interface.data[:-1]) + # last packet must be identical up to message length. remaining bytes in + # the 64-byte packets are garbage -- in particular, it's the bytes of the + # previous packet + last_packet = packets[-1] + packets[-2][len(packets[-1]) :] + self.assertEqual(last_packet, self.interface.data[-1]) + + def _read_huge_packet(self): + PACKET_COUNT = 1180 + # message that takes up 1 180 USB packets + message_size = (PACKET_COUNT - 1) * ( + PACKET_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN + ) + INIT_MESSAGE_DATA_LENGTH + + # ensure that a message this big won't fit into memory + # Note: this control is changed, because THP has only 2 byte length field + self.assertTrue(message_size > thp_main.MAX_PAYLOAD_LEN) + # self.assertRaises(MemoryError, bytearray, message_size) + header = make_header(PLAINTEXT_1, COMMON_CID, message_size) + packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH) + buffer = bytearray(65536) + gen = deprecated_read_message(self.interface, buffer) + + query = gen.send(None) + + # THP returns "Message too large" error after reading the message size, + # it is different from codec_v1 as it does not allow big enough messages + # to raise MemoryError in this test + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + with self.assertRaises(trezor.wire.thp.ThpError) as e: + query = gen.send(packet) + + self.assertEqual(e.value.args[0], "Message too large") + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/thp_common.py b/core/tests/thp_common.py new file mode 100644 index 00000000000..1ac432514b8 --- /dev/null +++ b/core/tests/thp_common.py @@ -0,0 +1,43 @@ +from trezor import utils +from trezor.wire.thp import ChannelState + +if utils.USE_THP: + import unittest + from typing import TYPE_CHECKING, Any, Awaitable + + from mock_wire_interface import MockHID + from storage import cache_thp + from trezor.wire import context + from trezor.wire.thp import interface_manager + from trezor.wire.thp.channel import Channel + from trezor.wire.thp.interface_manager import _MOCK_INTERFACE_HID + from trezor.wire.thp.session_context import SessionContext + + if TYPE_CHECKING: + from trezor.wire import WireInterface + + def dummy_decode_iface(cached_iface: bytes): + return MockHID(0xDEADBEEF) + + def get_new_channel(channel_iface: WireInterface | None = None) -> Channel: + interface_manager.decode_iface = dummy_decode_iface + channel_cache = cache_thp.get_new_channel(_MOCK_INTERFACE_HID) + channel = Channel(channel_cache) + channel.set_channel_state(ChannelState.TH1) + if channel_iface is not None: + channel.iface = channel_iface + return channel + + def prepare_context() -> None: + channel = get_new_channel() + session_cache = cache_thp.get_new_session(channel.channel_cache) + session_ctx = SessionContext(channel, session_cache) + context.CURRENT_CONTEXT = session_ctx + + +if __debug__: + # Disable log.debug + def suppres_debug_log() -> None: + from trezor import log + + log.debug = lambda name, msg, *args: None diff --git a/core/tools/codegen/get_trezor_keys.py b/core/tools/codegen/get_trezor_keys.py index 31c40fef1fe..b511abd807d 100755 --- a/core/tools/codegen/get_trezor_keys.py +++ b/core/tools/codegen/get_trezor_keys.py @@ -2,7 +2,7 @@ import binascii from trezorlib.client import TrezorClient -from trezorlib.transport_hid import HidTransport +from trezorlib.transport.hid import HidTransport devices = HidTransport.enumerate() if len(devices) > 0: diff --git a/docs/ci/jobs.md b/docs/ci/jobs.md index 7a57340f24d..2325549c2a9 100644 --- a/docs/ci/jobs.md +++ b/docs/ci/jobs.md @@ -106,44 +106,44 @@ Frozen version. That means you do not need any other files to run it, it is just a single binary file that you can execute directly. **Are you looking for a Trezor T emulator? This is most likely it.** -### [core unix frozen R debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L317) +### [core unix frozen R debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L318) -### [core unix frozen T3T1 debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L332) +### [core unix frozen T3T1 debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L333) -### [core unix frozen R debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L346) +### [core unix frozen R debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L347) -### [core unix frozen T3T1 debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L369) +### [core unix frozen T3T1 debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L370) -### [core unix frozen debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L392) +### [core unix frozen debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L393) -### [core unix frozen debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L408) +### [core unix frozen debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L409) -### [core macos frozen regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L430) +### [core macos frozen regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L431) -### [crypto build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L455) +### [crypto build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L456) Build of our cryptographic library, which is then incorporated into the other builds. -### [legacy fw regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L485) +### [legacy fw regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L486) -### [legacy fw regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L501) +### [legacy fw regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L502) -### [legacy fw btconly build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L518) +### [legacy fw btconly build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L519) -### [legacy fw btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L537) +### [legacy fw btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L538) -### [legacy emu regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L558) +### [legacy emu regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L559) Regular version (not only Bitcoin) of above. **Are you looking for a Trezor One emulator? This is most likely it.** -### [legacy emu regular debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L573) +### [legacy emu regular debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L574) -### [legacy emu regular debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L591) +### [legacy emu regular debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L592) -### [legacy emu btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L617) +### [legacy emu btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L618) Build of Legacy into UNIX emulator. Use keyboard arrows to emulate button presses. Bitcoin-only version. -### [legacy emu btconly debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L634) +### [legacy emu btconly debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L635) --- ## TEST stage - [test.yml](https://github.com/trezor/trezor-firmware/blob/master/ci/test.yml) diff --git a/legacy/firmware/fsm.c b/legacy/firmware/fsm.c index 07c4c24b1cd..439bbd91713 100644 --- a/legacy/firmware/fsm.c +++ b/legacy/firmware/fsm.c @@ -191,6 +191,9 @@ void fsm_sendFailure(FailureType code, const char *text) case FailureType_Failure_InvalidSession: text = _("Invalid session"); break; + case FailureType_Failure_ThpUnallocatedSession: + text = _("Unallocated session"); + break; case FailureType_Failure_FirmwareError: text = _("Firmware error"); break; diff --git a/legacy/firmware/protob/Makefile b/legacy/firmware/protob/Makefile index 1acdb5348b2..1f48060e24e 100644 --- a/legacy/firmware/protob/Makefile +++ b/legacy/firmware/protob/Makefile @@ -10,7 +10,7 @@ SKIPPED_MESSAGES := Binance Cardano DebugMonero Eos Monero Ontology Ripple SdPro EthereumTypedDataValueRequest EthereumTypedDataValueAck ShowDeviceTutorial \ UnlockBootloader AuthenticateDevice AuthenticityProof \ Solana StellarClaimClaimableBalanceOp \ - ChangeLanguage TranslationDataRequest TranslationDataAck \ + ChangeLanguage TranslationDataRequest TranslationDataAck Thp \ SetBrightness DebugLinkOptigaSetSecMax \ ifeq ($(BITCOIN_ONLY), 1) diff --git a/legacy/firmware/protob/messages-thp.proto b/legacy/firmware/protob/messages-thp.proto new file mode 120000 index 00000000000..4799efe83ae --- /dev/null +++ b/legacy/firmware/protob/messages-thp.proto @@ -0,0 +1 @@ +../../vendor/trezor-common/protob/messages-thp.proto \ No newline at end of file diff --git a/python/channel_data.json b/python/channel_data.json new file mode 100644 index 00000000000..0637a088a01 --- /dev/null +++ b/python/channel_data.json @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/python/src/trezorlib/_internal/emulator.py b/python/src/trezorlib/_internal/emulator.py index 4f6d56f8ed1..0899692eaa7 100644 --- a/python/src/trezorlib/_internal/emulator.py +++ b/python/src/trezorlib/_internal/emulator.py @@ -112,7 +112,7 @@ def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None: start = time.monotonic() try: while True: - if transport._ping(): + if transport.ping(): break if self.process.poll() is not None: raise RuntimeError("Emulator process died") diff --git a/python/src/trezorlib/authentication.py b/python/src/trezorlib/authentication.py index 39e26f569fc..2e4a530af5b 100644 --- a/python/src/trezorlib/authentication.py +++ b/python/src/trezorlib/authentication.py @@ -7,7 +7,7 @@ from importlib import metadata from . import device -from .client import TrezorClient +from .transport.session import Session try: cryptography_version = metadata.version("cryptography") @@ -361,7 +361,7 @@ def verify_authentication_response( def authenticate_device( - client: TrezorClient, + session: Session, challenge: bytes | None = None, *, whitelist: t.Collection[bytes] | None = None, @@ -371,7 +371,7 @@ def authenticate_device( if challenge is None: challenge = secrets.token_bytes(16) - resp = device.authenticate(client, challenge) + resp = device.authenticate(session, challenge) return verify_authentication_response( challenge, diff --git a/python/src/trezorlib/binance.py b/python/src/trezorlib/binance.py index d2e4b97912c..afe251a06c3 100644 --- a/python/src/trezorlib/binance.py +++ b/python/src/trezorlib/binance.py @@ -18,22 +18,22 @@ from . import messages from .protobuf import dict_to_proto -from .tools import expect, session +from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session @expect(messages.BinanceAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.BinanceGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ) @@ -42,16 +42,15 @@ def get_address( @expect(messages.BinancePublicKey, field="public_key", ret_type=bytes) def get_public_key( - client: "TrezorClient", address_n: "Address", show_display: bool = False + session: "Session", address_n: "Address", show_display: bool = False ) -> "MessageType": - return client.call( + return session.call( messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display) ) -@session def sign_tx( - client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False + session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False ) -> messages.BinanceSignedTx: msg = tx_json["msgs"][0] tx_msg = tx_json.copy() @@ -60,7 +59,7 @@ def sign_tx( tx_msg["chunkify"] = chunkify envelope = dict_to_proto(messages.BinanceSignTx, tx_msg) - response = client.call(envelope) + response = session.call(envelope) if not isinstance(response, messages.BinanceTxRequest): raise RuntimeError( @@ -77,7 +76,7 @@ def sign_tx( else: raise ValueError("can not determine msg type") - response = client.call(msg) + response = session.call(msg) if not isinstance(response, messages.BinanceSignedTx): raise RuntimeError( diff --git a/python/src/trezorlib/btc.py b/python/src/trezorlib/btc.py index a71ead2adc2..3ccb1a95959 100644 --- a/python/src/trezorlib/btc.py +++ b/python/src/trezorlib/btc.py @@ -13,7 +13,6 @@ # # You should have received a copy of the License along with this library. # If not, see . - import warnings from copy import copy from decimal import Decimal @@ -23,12 +22,12 @@ from typing_extensions import Protocol, TypedDict from . import exceptions, messages -from .tools import expect, prepare_message_bytes, session +from .tools import expect, prepare_message_bytes if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session class ScriptSig(TypedDict): asm: str @@ -105,7 +104,7 @@ def make_bin_output(vout: "Vout") -> messages.TxOutputBinType: @expect(messages.PublicKey) def get_public_node( - client: "TrezorClient", + session: "Session", n: "Address", ecdsa_curve_name: Optional[str] = None, show_display: bool = False, @@ -116,13 +115,13 @@ def get_public_node( unlock_path_mac: Optional[bytes] = None, ) -> "MessageType": if unlock_path: - res = client.call( + res = session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) ) if not isinstance(res, messages.UnlockedPathRequest): raise exceptions.TrezorException("Unexpected message") - return client.call( + return session.call( messages.GetPublicKey( address_n=n, ecdsa_curve_name=ecdsa_curve_name, @@ -141,7 +140,7 @@ def get_address(*args: Any, **kwargs: Any): @expect(messages.Address) def get_authenticated_address( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", show_display: bool = False, @@ -153,13 +152,13 @@ def get_authenticated_address( chunkify: bool = False, ) -> "MessageType": if unlock_path: - res = client.call( + res = session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) ) if not isinstance(res, messages.UnlockedPathRequest): raise exceptions.TrezorException("Unexpected message") - return client.call( + return session.call( messages.GetAddress( address_n=n, coin_name=coin_name, @@ -172,15 +171,16 @@ def get_authenticated_address( ) +# TODO this is used by tests only @expect(messages.OwnershipId, field="ownership_id", ret_type=bytes) def get_ownership_id( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", multisig: Optional[messages.MultisigRedeemScriptType] = None, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, ) -> "MessageType": - return client.call( + return session.call( messages.GetOwnershipId( address_n=n, coin_name=coin_name, @@ -190,8 +190,9 @@ def get_ownership_id( ) +# TODO this is used by tests only def get_ownership_proof( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", multisig: Optional[messages.MultisigRedeemScriptType] = None, @@ -202,11 +203,11 @@ def get_ownership_proof( preauthorized: bool = False, ) -> Tuple[bytes, bytes]: if preauthorized: - res = client.call(messages.DoPreauthorized()) + res = session.call(messages.DoPreauthorized()) if not isinstance(res, messages.PreauthorizedRequest): raise exceptions.TrezorException("Unexpected message") - res = client.call( + res = session.call( messages.GetOwnershipProof( address_n=n, coin_name=coin_name, @@ -226,7 +227,7 @@ def get_ownership_proof( @expect(messages.MessageSignature) def sign_message( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", message: AnyStr, @@ -234,7 +235,7 @@ def sign_message( no_script_type: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.SignMessage( coin_name=coin_name, address_n=n, @@ -247,7 +248,7 @@ def sign_message( def verify_message( - client: "TrezorClient", + session: "Session", coin_name: str, address: str, signature: bytes, @@ -255,7 +256,7 @@ def verify_message( chunkify: bool = False, ) -> bool: try: - resp = client.call( + resp = session.call( messages.VerifyMessage( address=address, signature=signature, @@ -269,9 +270,9 @@ def verify_message( return isinstance(resp, messages.Success) -@session +# @session def sign_tx( - client: "TrezorClient", + session: "Session", coin_name: str, inputs: Sequence[messages.TxInputType], outputs: Sequence[messages.TxOutputType], @@ -319,17 +320,17 @@ def sign_tx( setattr(signtx, name, value) if unlock_path: - res = client.call( + res = session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) ) if not isinstance(res, messages.UnlockedPathRequest): raise exceptions.TrezorException("Unexpected message") elif preauthorized: - res = client.call(messages.DoPreauthorized()) + res = session.call(messages.DoPreauthorized()) if not isinstance(res, messages.PreauthorizedRequest): raise exceptions.TrezorException("Unexpected message") - res = client.call(signtx) + res = session.call(signtx) # Prepare structure for signatures signatures: List[Optional[bytes]] = [None] * len(inputs) @@ -388,7 +389,7 @@ def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType: if res.request_type == R.TXPAYMENTREQ: assert res.details.request_index is not None msg = payment_reqs[res.details.request_index] - res = client.call(msg) + res = session.call(msg) else: msg = messages.TransactionType() if res.request_type == R.TXMETA: @@ -418,7 +419,7 @@ def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType: f"Unknown request type - {res.request_type}." ) - res = client.call(messages.TxAck(tx=msg)) + res = session.call(messages.TxAck(tx=msg)) if not isinstance(res, messages.TxRequest): raise exceptions.TrezorException("Unexpected message") @@ -432,7 +433,7 @@ def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType: @expect(messages.Success, field="message", ret_type=str) def authorize_coinjoin( - client: "TrezorClient", + session: "Session", coordinator: str, max_rounds: int, max_coordinator_fee_rate: int, @@ -441,7 +442,7 @@ def authorize_coinjoin( coin_name: str, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, ) -> "MessageType": - return client.call( + return session.call( messages.AuthorizeCoinJoin( coordinator=coordinator, max_rounds=max_rounds, diff --git a/python/src/trezorlib/cardano.py b/python/src/trezorlib/cardano.py index 49d2c6463f8..f39cfb42221 100644 --- a/python/src/trezorlib/cardano.py +++ b/python/src/trezorlib/cardano.py @@ -35,8 +35,8 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType + from .transport.session import Session PROTOCOL_MAGICS = { "mainnet": 764824073, @@ -825,7 +825,7 @@ def _get_collateral_inputs_items( @expect(messages.CardanoAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", address_parameters: messages.CardanoAddressParametersType, protocol_magic: int = PROTOCOL_MAGICS["mainnet"], network_id: int = NETWORK_IDS["mainnet"], @@ -833,7 +833,7 @@ def get_address( derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.CardanoGetAddress( address_parameters=address_parameters, protocol_magic=protocol_magic, @@ -847,12 +847,12 @@ def get_address( @expect(messages.CardanoPublicKey) def get_public_key( - client: "TrezorClient", + session: "Session", address_n: List[int], derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, show_display: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.CardanoGetPublicKey( address_n=address_n, derivation_type=derivation_type, @@ -863,12 +863,12 @@ def get_public_key( @expect(messages.CardanoNativeScriptHash) def get_native_script_hash( - client: "TrezorClient", + session: "Session", native_script: messages.CardanoNativeScript, display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, ) -> "MessageType": - return client.call( + return session.call( messages.CardanoGetNativeScriptHash( script=native_script, display_format=display_format, @@ -878,7 +878,7 @@ def get_native_script_hash( def sign_tx( - client: "TrezorClient", + session: "Session", signing_mode: messages.CardanoTxSigningMode, inputs: List[InputWithPath], outputs: List[OutputWithData], @@ -915,7 +915,7 @@ def sign_tx( signing_mode, ) - response = client.call( + response = session.call( messages.CardanoSignTxInit( signing_mode=signing_mode, inputs_count=len(inputs), @@ -951,14 +951,14 @@ def sign_tx( _get_certificates_items(certificates), withdrawals, ): - response = client.call(tx_item) + response = session.call(tx_item) if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR sign_tx_response: Dict[str, Any] = {} if auxiliary_data is not None: - auxiliary_data_supplement = client.call(auxiliary_data) + auxiliary_data_supplement = session.call(auxiliary_data) if not isinstance( auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement ): @@ -971,7 +971,7 @@ def sign_tx( auxiliary_data_supplement.__dict__ ) - response = client.call(messages.CardanoTxHostAck()) + response = session.call(messages.CardanoTxHostAck()) if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR @@ -980,24 +980,24 @@ def sign_tx( _get_collateral_inputs_items(collateral_inputs), required_signers, ): - response = client.call(tx_item) + response = session.call(tx_item) if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR if collateral_return is not None: for tx_item in _get_output_items(collateral_return): - response = client.call(tx_item) + response = session.call(tx_item) if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR for reference_input in reference_inputs: - response = client.call(reference_input) + response = session.call(reference_input) if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR sign_tx_response["witnesses"] = [] for witness_request in witness_requests: - response = client.call(witness_request) + response = session.call(witness_request) if not isinstance(response, messages.CardanoTxWitnessResponse): raise UNEXPECTED_RESPONSE_ERROR sign_tx_response["witnesses"].append( @@ -1009,12 +1009,12 @@ def sign_tx( } ) - response = client.call(messages.CardanoTxHostAck()) + response = session.call(messages.CardanoTxHostAck()) if not isinstance(response, messages.CardanoTxBodyHash): raise UNEXPECTED_RESPONSE_ERROR sign_tx_response["tx_hash"] = response.tx_hash - response = client.call(messages.CardanoTxHostAck()) + response = session.call(messages.CardanoTxHostAck()) if not isinstance(response, messages.CardanoSignTxFinished): raise UNEXPECTED_RESPONSE_ERROR diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 050e3788f78..c5af2b57082 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -14,33 +14,41 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import functools +import logging +import os import sys +import typing as t from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional import click -from .. import exceptions, transport +from .. import exceptions, transport, ui from ..client import TrezorClient -from ..ui import ClickUI, ScriptUI +from ..messages import Capability +from ..transport import Transport +from ..transport.new import channel_database + +LOG = logging.getLogger(__name__) -if TYPE_CHECKING: +if t.TYPE_CHECKING: # Needed to enforce a return value from decorators # More details: https://www.python.org/dev/peps/pep-0612/ from typing import TypeVar from typing_extensions import Concatenate, ParamSpec - from ..transport import Transport - from ..ui import TrezorClientUI - P = ParamSpec("P") R = TypeVar("R") class ChoiceType(click.Choice): - def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None: + + def __init__( + self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True + ) -> None: super().__init__(list(typemap.keys())) self.case_sensitive = case_sensitive if case_sensitive: @@ -48,7 +56,7 @@ def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None else: self.typemap = {k.lower(): v for k, v in typemap.items()} - def convert(self, value: Any, param: Any, ctx: click.Context) -> Any: + def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any: if value in self.typemap.values(): return value value = super().convert(value, param, ctx) @@ -57,11 +65,48 @@ def convert(self, value: Any, param: Any, ctx: click.Context) -> Any: return self.typemap[value] -class TrezorConnection: +def get_passphrase( + passphrase_on_host: bool, available_on_device: bool +) -> t.Union[str, object]: + if available_on_device and not passphrase_on_host: + return ui.PASSPHRASE_ON_DEVICE + + env_passphrase = os.getenv("PASSPHRASE") + if env_passphrase is not None: + ui.echo("Passphrase required. Using PASSPHRASE environment variable.") + return env_passphrase + + while True: + try: + passphrase = ui.prompt( + "Passphrase required", + hide_input=True, + default="", + show_default=False, + ) + # In case user sees the input on the screen, we do not need confirmation + if not ui.CAN_HANDLE_HIDDEN_INPUT: + return passphrase + second = ui.prompt( + "Confirm your passphrase", + hide_input=True, + default="", + show_default=False, + ) + if passphrase == second: + return passphrase + else: + ui.echo("Passphrase did not match. Please try again.") + except click.Abort: + raise exceptions.Cancelled from None + + +class NewTrezorConnection: + def __init__( self, path: str, - session_id: Optional[bytes], + session_id: bytes | None, passphrase_on_host: bool, script: bool, ) -> None: @@ -70,6 +115,29 @@ def __init__( self.passphrase_on_host = passphrase_on_host self.script = script + def get_session(self, derive_cardano: bool = False): + client = self.get_client() + + if self.session_id is not None: + pass # TODO Try resume - be careful of cardano derivation settings! + features = client.protocol.get_features() + + passphrase_enabled = True # TODO what to do here? + + if not passphrase_enabled: + return client.get_session(derive_cardano=derive_cardano) + + # TODO Passphrase empty by default - ??? + available_on_device = Capability.PassphraseEntry in features.capabilities + passphrase = get_passphrase(available_on_device, self.passphrase_on_host) + # TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") + session = client.get_session( + passphrase=passphrase, derive_cardano=derive_cardano + ) + return session + def get_transport(self) -> "Transport": try: # look for transport without prefix search @@ -82,19 +150,33 @@ def get_transport(self) -> "Transport": # if this fails, we want the exception to bubble up to the caller return transport.get_transport(self.path, prefix_search=True) - def get_ui(self) -> "TrezorClientUI": - if self.script: - # It is alright to return just the class object instead of instance, - # as the ScriptUI class object itself is the implementation of TrezorClientUI - # (ScriptUI is just a set of staticmethods) - return ScriptUI - else: - return ClickUI(passphrase_on_host=self.passphrase_on_host) - def get_client(self) -> TrezorClient: transport = self.get_transport() - ui = self.get_ui() - return TrezorClient(transport, ui=ui, session_id=self.session_id) + + stored_channels = channel_database.load_stored_channels() + stored_transport_paths = [ch.transport_path for ch in stored_channels] + path = transport.get_path() + if path in stored_transport_paths: + stored_channel_with_correct_transport_path = next( + ch for ch in stored_channels if ch.transport_path == path + ) + try: + client = TrezorClient.resume( + transport, stored_channel_with_correct_transport_path + ) + except Exception: + LOG.debug("Failed to resume a channel. Replacing by a new one.") + channel_database.remove_channel(path) + client = TrezorClient(transport) + else: + client = TrezorClient(transport) + + return client + + def get_management_session(self) -> Session: + client = self.get_client() + management_session = client.get_management_session() + return management_session @contextmanager def client_context(self): @@ -128,7 +210,131 @@ def client_context(self): # other exceptions may cause a traceback -def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]": +# class TrezorConnection: + +# def __init__( +# self, +# path: str, +# session_id: bytes | None, +# passphrase_on_host: bool, +# script: bool, +# ) -> None: +# self.path = path +# self.session_id = session_id +# self.passphrase_on_host = passphrase_on_host +# self.script = script + +# def get_transport(self) -> "Transport": +# try: +# # look for transport without prefix search +# return transport.get_transport(self.path, prefix_search=False) +# except Exception: +# # most likely not found. try again below. +# pass + +# # look for transport with prefix search +# # if this fails, we want the exception to bubble up to the caller +# return transport.get_transport(self.path, prefix_search=True) + +# def get_ui(self) -> "TrezorClientUI": +# if self.script: +# # It is alright to return just the class object instead of instance, +# # as the ScriptUI class object itself is the implementation of TrezorClientUI +# # (ScriptUI is just a set of staticmethods) +# return ScriptUI +# else: +# return ClickUI(passphrase_on_host=self.passphrase_on_host) + +# def get_client(self) -> TrezorClient: +# transport = self.get_transport() +# ui = self.get_ui() +# return TrezorClient(transport, ui=ui, session_id=self.session_id) + +# @contextmanager +# def client_context(self): +# """Get a client instance as a context manager. Handle errors in a manner +# appropriate for end-users. + +# Usage: +# >>> with obj.client_context() as client: +# >>> do_your_actions_here() +# """ +# try: +# client = self.get_client() +# except transport.DeviceIsBusy: +# click.echo("Device is in use by another process.") +# sys.exit(1) +# except Exception: +# click.echo("Failed to find a Trezor device.") +# if self.path is not None: +# click.echo(f"Using path: {self.path}") +# sys.exit(1) + +# try: +# yield client +# except exceptions.Cancelled: +# # handle cancel action +# click.echo("Action was cancelled.") +# sys.exit(1) +# except exceptions.TrezorException as e: +# # handle any Trezor-sent exceptions as user-readable +# raise click.ClickException(str(e)) from e +# # other exceptions may cause a traceback + +from ..transport.session import Session + + +def with_cardano_session( + func: "t.Callable[Concatenate[Session, P], R]", +) -> "t.Callable[P, R]": + return with_session(func=func, derive_cardano=True) + + +def with_session( + func: "t.Callable[Concatenate[Session, P], R]", derive_cardano: bool = False +) -> "t.Callable[P, R]": + + @click.pass_obj + @functools.wraps(func) + def function_with_session( + obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs" + ) -> "R": + session = obj.get_session(derive_cardano) + try: + return func(session, *args, **kwargs) + finally: + pass + # TODO try end session if not resumed + + # the return type of @click.pass_obj is improperly specified and pyright doesn't + # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs) + return function_with_session # type: ignore [is incompatible with return type] + + +def with_management_session( + func: "t.Callable[Concatenate[Session, P], R]", +) -> "t.Callable[P, R]": + + @click.pass_obj + @functools.wraps(func) + def function_with_management_session( + obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs" + ) -> "R": + session = obj.get_management_session() + try: + return func(session, *args, **kwargs) + finally: + pass + # TODO try end session if not resumed + + # the return type of @click.pass_obj is improperly specified and pyright doesn't + # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs) + return function_with_management_session # type: ignore [is incompatible with return type] + + +def with_client( + func: "t.Callable[Concatenate[TrezorClient, P], R]", +) -> "t.Callable[P, R]": """Wrap a Click command in `with obj.client_context() as client`. Sessions are handled transparently. The user is warned when session did not resume @@ -139,28 +345,66 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[ @click.pass_obj @functools.wraps(func) def trezorctl_command_with_client( - obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" + obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs" ) -> "R": with obj.client_context() as client: - session_was_resumed = obj.session_id == client.session_id - if not session_was_resumed and obj.session_id is not None: - # tried to resume but failed - click.echo("Warning: failed to resume session.", err=True) - + # session_was_resumed = obj.session_id == client.session_id + # if not session_was_resumed and obj.session_id is not None: + # # tried to resume but failed + # click.echo("Warning: failed to resume session.", err=True) + click.echo( + "Warning: resume session detection is not implemented yet!", err=True + ) try: return func(client, *args, **kwargs) finally: - if not session_was_resumed: - try: - client.end_session() - except Exception: - pass + channel_database.save_channel(client.protocol) + # if not session_was_resumed: + # try: + # client.end_session() + # except Exception: + # pass # the return type of @click.pass_obj is improperly specified and pyright doesn't # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs) return trezorctl_command_with_client # type: ignore [is incompatible with return type] +# def with_client( +# func: "t.Callable[Concatenate[TrezorClient, P], R]", +# ) -> "t.Callable[P, R]": +# """Wrap a Click command in `with obj.client_context() as client`. + +# Sessions are handled transparently. The user is warned when session did not resume +# cleanly. The session is closed after the command completes - unless the session +# was resumed, in which case it should remain open. +# """ + +# @click.pass_obj +# @functools.wraps(func) +# def trezorctl_command_with_client( +# obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" +# ) -> "R": +# with obj.client_context() as client: +# session_was_resumed = obj.session_id == client.session_id +# if not session_was_resumed and obj.session_id is not None: +# # tried to resume but failed +# click.echo("Warning: failed to resume session.", err=True) + +# try: +# return func(client, *args, **kwargs) +# finally: +# if not session_was_resumed: +# try: +# client.end_session() +# except Exception: +# pass + +# # the return type of @click.pass_obj is improperly specified and pyright doesn't +# # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs) +# return trezorctl_command_with_client + + class AliasedGroup(click.Group): """Command group that handles aliases and Click 6.x compatibility. @@ -190,14 +434,14 @@ class AliasedGroup(click.Group): def __init__( self, - aliases: Optional[Dict[str, click.Command]] = None, - *args: Any, - **kwargs: Any, + aliases: t.Dict[str, click.Command] | None = None, + *args: t.Any, + **kwargs: t.Any, ) -> None: super().__init__(*args, **kwargs) self.aliases = aliases or {} - def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]: + def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: cmd_name = cmd_name.replace("_", "-") # try to look up the real name cmd = super().get_command(ctx, cmd_name) diff --git a/python/src/trezorlib/cli/binance.py b/python/src/trezorlib/cli/binance.py index a3139fb2711..d8097b3e900 100644 --- a/python/src/trezorlib/cli/binance.py +++ b/python/src/trezorlib/cli/binance.py @@ -20,11 +20,11 @@ import click from .. import binance, tools -from . import with_client +from ..transport.session import Session +from . import with_session if TYPE_CHECKING: from .. import messages - from ..client import TrezorClient PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0" @@ -39,23 +39,23 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Binance address for specified path.""" address_n = tools.parse_path(address) - return binance.get_address(client, address_n, show_display, chunkify) + return binance.get_address(session, address_n, show_display, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Binance public key.""" address_n = tools.parse_path(address) - return binance.get_public_key(client, address_n, show_display).hex() + return binance.get_public_key(session, address_n, show_display).hex() @cli.command() @@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> "messages.BinanceSignedTx": """Sign Binance transaction. Transaction must be provided as a JSON file. """ address_n = tools.parse_path(address) - return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify) + return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify) diff --git a/python/src/trezorlib/cli/btc.py b/python/src/trezorlib/cli/btc.py index dde59a6bc63..66f7b6870a5 100644 --- a/python/src/trezorlib/cli/btc.py +++ b/python/src/trezorlib/cli/btc.py @@ -13,6 +13,7 @@ # # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations import base64 import json @@ -22,10 +23,10 @@ import construct as c from .. import btc, messages, protobuf, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PURPOSE_BIP44 = 44 PURPOSE_BIP48 = 48 @@ -168,15 +169,15 @@ def cli() -> None: default=2, ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", coin: str, address: str, - script_type: Optional[messages.InputScriptType], + script_type: messages.InputScriptType | None, show_display: bool, multisig_xpub: List[str], - multisig_threshold: Optional[int], + multisig_threshold: int | None, multisig_suffix_length: int, chunkify: bool, ) -> str: @@ -220,7 +221,7 @@ def get_address( multisig = None return btc.get_address( - client, + session, coin, address_n, show_display, @@ -237,9 +238,9 @@ def get_address( @click.option("-e", "--curve") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_public_node( - client: "TrezorClient", + session: "Session", coin: str, address: str, curve: Optional[str], @@ -251,7 +252,7 @@ def get_public_node( if script_type is None: script_type = guess_script_type_from_path(address_n) result = btc.get_public_node( - client, + session, address_n, ecdsa_curve_name=curve, show_display=show_display, @@ -277,7 +278,7 @@ def _append_descriptor_checksum(desc: str) -> str: def _get_descriptor( - client: "TrezorClient", + session: "Session", coin: Optional[str], account: int, purpose: Optional[int], @@ -311,7 +312,7 @@ def _get_descriptor( n = tools.parse_path(path) pub = btc.get_public_node( - client, + session, n, show_display=show_display, coin_name=coin, @@ -348,9 +349,9 @@ def _get_descriptor( @click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE)) @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_descriptor( - client: "TrezorClient", + session: "Session", coin: Optional[str], account: int, account_type: Optional[int], @@ -360,7 +361,7 @@ def get_descriptor( """Get descriptor of given account.""" try: return _get_descriptor( - client, coin, account, account_type, script_type, show_display + session, coin, account, account_type, script_type, show_display ) except ValueError as e: raise click.ClickException(str(e)) @@ -375,8 +376,8 @@ def get_descriptor( @click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) @click.argument("json_file", type=click.File()) -@with_client -def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: +@with_session +def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None: """Sign transaction. Transaction data must be provided in a JSON file. See `transaction-format.md` for @@ -401,7 +402,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: } _, serialized_tx = btc.sign_tx( - client, + session, coin, inputs, outputs, @@ -432,9 +433,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: ) @click.option("-C", "--chunkify", is_flag=True) @click.argument("message") -@with_client +@with_session def sign_message( - client: "TrezorClient", + session: "Session", coin: str, address: str, message: str, @@ -447,7 +448,7 @@ def sign_message( if script_type is None: script_type = guess_script_type_from_path(address_n) res = btc.sign_message( - client, + session, coin, address_n, message, @@ -468,9 +469,9 @@ def sign_message( @click.argument("address") @click.argument("signature") @click.argument("message") -@with_client +@with_session def verify_message( - client: "TrezorClient", + session: "Session", coin: str, address: str, signature: str, @@ -480,7 +481,7 @@ def verify_message( """Verify message.""" signature_bytes = base64.b64decode(signature) return btc.verify_message( - client, coin, address, signature_bytes, message, chunkify=chunkify + session, coin, address, signature_bytes, message, chunkify=chunkify ) diff --git a/python/src/trezorlib/cli/cardano.py b/python/src/trezorlib/cli/cardano.py index 23647ab8eaa..5b160ea6cf3 100644 --- a/python/src/trezorlib/cli/cardano.py +++ b/python/src/trezorlib/cli/cardano.py @@ -20,10 +20,10 @@ import click from .. import cardano, messages, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_cardano_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0" @@ -62,9 +62,9 @@ def cli() -> None: @click.option("-i", "--include-network-id", is_flag=True) @click.option("-C", "chunkify", is_flag=True) @click.option("-T", "--tag-cbor-sets", is_flag=True) -@with_client +@with_cardano_session def sign_tx( - client: "TrezorClient", + session: "Session", file: TextIO, signing_mode: messages.CardanoTxSigningMode, protocol_magic: int, @@ -123,9 +123,8 @@ def sign_tx( for p in transaction["additional_witness_requests"] ] - client.init_device(derive_cardano=True) sign_tx_response = cardano.sign_tx( - client, + session, signing_mode, inputs, outputs, @@ -209,9 +208,9 @@ def sign_tx( default=messages.CardanoDerivationType.ICARUS, ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_cardano_session def get_address( - client: "TrezorClient", + session: "Session", address: str, address_type: messages.CardanoAddressType, staking_address: str, @@ -262,9 +261,8 @@ def get_address( script_staking_hash_bytes, ) - client.init_device(derive_cardano=True) return cardano.get_address( - client, + session, address_parameters, protocol_magic, network_id, @@ -283,18 +281,17 @@ def get_address( default=messages.CardanoDerivationType.ICARUS, ) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_cardano_session def get_public_key( - client: "TrezorClient", + session: "Session", address: str, derivation_type: messages.CardanoDerivationType, show_display: bool, ) -> messages.CardanoPublicKey: """Get Cardano public key.""" address_n = tools.parse_path(address) - client.init_device(derive_cardano=True) return cardano.get_public_key( - client, address_n, derivation_type=derivation_type, show_display=show_display + session, address_n, derivation_type=derivation_type, show_display=show_display ) @@ -312,9 +309,9 @@ def get_public_key( type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}), default=messages.CardanoDerivationType.ICARUS, ) -@with_client +@with_cardano_session def get_native_script_hash( - client: "TrezorClient", + session: "Session", file: TextIO, display_format: messages.CardanoNativeScriptHashDisplayFormat, derivation_type: messages.CardanoDerivationType, @@ -323,7 +320,6 @@ def get_native_script_hash( native_script_json = json.load(file) native_script = cardano.parse_native_script(native_script_json) - client.init_device(derive_cardano=True) return cardano.get_native_script_hash( - client, native_script, display_format, derivation_type=derivation_type + session, native_script, display_format, derivation_type=derivation_type ) diff --git a/python/src/trezorlib/cli/crypto.py b/python/src/trezorlib/cli/crypto.py index a58b80d4b69..b8fd2cdcb19 100644 --- a/python/src/trezorlib/cli/crypto.py +++ b/python/src/trezorlib/cli/crypto.py @@ -19,10 +19,10 @@ import click from .. import misc, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PROMPT_TYPE = ChoiceType( @@ -42,10 +42,10 @@ def cli() -> None: @cli.command() @click.argument("size", type=int) -@with_client -def get_entropy(client: "TrezorClient", size: int) -> str: +@with_session +def get_entropy(session: "Session", size: int) -> str: """Get random bytes from device.""" - return misc.get_entropy(client, size).hex() + return misc.get_entropy(session, size).hex() @cli.command() @@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str: ) @click.argument("key") @click.argument("value") -@with_client +@with_session def encrypt_keyvalue( - client: "TrezorClient", + session: "Session", address: str, key: str, value: str, @@ -75,7 +75,7 @@ def encrypt_keyvalue( ask_on_encrypt, ask_on_decrypt = prompt address_n = tools.parse_path(address) return misc.encrypt_keyvalue( - client, + session, address_n, key, value.encode(), @@ -91,9 +91,9 @@ def encrypt_keyvalue( ) @click.argument("key") @click.argument("value") -@with_client +@with_session def decrypt_keyvalue( - client: "TrezorClient", + session: "Session", address: str, key: str, value: str, @@ -112,7 +112,7 @@ def decrypt_keyvalue( ask_on_encrypt, ask_on_decrypt = prompt address_n = tools.parse_path(address) return misc.decrypt_keyvalue( - client, + session, address_n, key, bytes.fromhex(value), diff --git a/python/src/trezorlib/cli/debug.py b/python/src/trezorlib/cli/debug.py index 50613a04eee..a0db79a52fa 100644 --- a/python/src/trezorlib/cli/debug.py +++ b/python/src/trezorlib/cli/debug.py @@ -18,16 +18,15 @@ import click -from .. import mapping, messages, protobuf -from ..client import TrezorClient from ..debuglink import TrezorClientDebugLink from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max from ..debuglink import prodtest_t1 as debuglink_prodtest_t1 from ..debuglink import record_screen -from . import with_client +from ..transport.session import Session +from . import with_management_session if TYPE_CHECKING: - from . import TrezorConnection + from . import NewTrezorConnection @click.group(name="debug") @@ -35,58 +34,58 @@ def cli() -> None: """Miscellaneous debug features.""" -@cli.command() -@click.argument("message_name_or_type") -@click.argument("hex_data") -@click.pass_obj -def send_bytes( - obj: "TrezorConnection", message_name_or_type: str, hex_data: str -) -> None: - """Send raw bytes to Trezor. +# @cli.command() +# @click.argument("message_name_or_type") +# @click.argument("hex_data") +# @click.pass_obj +# def send_bytes( +# obj: "NewTrezorConnection", message_name_or_type: str, hex_data: str +# ) -> None: +# """Send raw bytes to Trezor. - Message type and message data must be specified separately, due to how message - chunking works on the transport level. Message length is calculated and sent - automatically, and it is currently impossible to explicitly specify invalid length. +# Message type and message data must be specified separately, due to how message +# chunking works on the transport level. Message length is calculated and sent +# automatically, and it is currently impossible to explicitly specify invalid length. - MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum, - in which case the value of that enum is used. - """ - if message_name_or_type.isdigit(): - message_type = int(message_name_or_type) - else: - message_type = getattr(messages.MessageType, message_name_or_type) +# MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum, +# in which case the value of that enum is used. +# """ +# if message_name_or_type.isdigit(): +# message_type = int(message_name_or_type) +# else: +# message_type = getattr(messages.MessageType, message_name_or_type) - if not isinstance(message_type, int): - raise click.ClickException("Invalid message type.") +# if not isinstance(message_type, int): +# raise click.ClickException("Invalid message type.") - try: - message_data = bytes.fromhex(hex_data) - except Exception as e: - raise click.ClickException("Invalid hex data.") from e +# try: +# message_data = bytes.fromhex(hex_data) +# except Exception as e: +# raise click.ClickException("Invalid hex data.") from e - transport = obj.get_transport() - transport.begin_session() - transport.write(message_type, message_data) +# transport = obj.get_transport() +# transport.deprecated_begin_session() +# transport.write(message_type, message_data) - response_type, response_data = transport.read() - transport.end_session() +# response_type, response_data = transport.read() +# transport.deprecated_end_session() - click.echo(f"Response type: {response_type}") - click.echo(f"Response data: {response_data.hex()}") +# click.echo(f"Response type: {response_type}") +# click.echo(f"Response data: {response_data.hex()}") - try: - msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data) - click.echo("Parsed message:") - click.echo(protobuf.format_message(msg)) - except Exception as e: - click.echo(f"Could not parse response: {e}") +# try: +# msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data) +# click.echo("Parsed message:") +# click.echo(protobuf.format_message(msg)) +# except Exception as e: +# click.echo(f"Could not parse response: {e}") @cli.command() @click.argument("directory", required=False) @click.option("-s", "--stop", is_flag=True, help="Stop the recording") @click.pass_obj -def record(obj: "TrezorConnection", directory: Union[str, None], stop: bool) -> None: +def record(obj: "NewTrezorConnection", directory: Union[str, None], stop: bool) -> None: """Record screen changes into a specified directory. Recording can be stopped with `-s / --stop` option. @@ -95,7 +94,7 @@ def record(obj: "TrezorConnection", directory: Union[str, None], stop: bool) -> def record_screen_from_connection( - obj: "TrezorConnection", directory: Union[str, None] + obj: "NewTrezorConnection", directory: Union[str, None] ) -> None: """Record screen helper to transform TrezorConnection into TrezorClientDebugLink.""" transport = obj.get_transport() @@ -106,17 +105,17 @@ def record_screen_from_connection( @cli.command() -@with_client -def prodtest_t1(client: "TrezorClient") -> str: +@with_management_session +def prodtest_t1(session: "Session") -> str: """Perform a prodtest on Model One. Only available on PRODTEST firmware and on T1B1. Formerly named self-test. """ - return debuglink_prodtest_t1(client) + return debuglink_prodtest_t1(session) @cli.command() -@with_client -def optiga_set_sec_max(client: "TrezorClient") -> str: +@with_management_session +def optiga_set_sec_max(session: "Session") -> str: """Set Optiga's security event counter to maximum.""" - return debuglink_optiga_set_sec_max(client) + return debuglink_optiga_set_sec_max(session) diff --git a/python/src/trezorlib/cli/device.py b/python/src/trezorlib/cli/device.py index 6f2bd14883c..ceb3822a176 100644 --- a/python/src/trezorlib/cli/device.py +++ b/python/src/trezorlib/cli/device.py @@ -24,12 +24,12 @@ import requests from .. import debuglink, device, exceptions, messages, ui -from . import ChoiceType, with_client +from . import ChoiceType, with_management_session if t.TYPE_CHECKING: - from ..client import TrezorClient from ..protobuf import MessageType - from . import TrezorConnection + from ..transport.session import Session + from . import NewTrezorConnection RECOVERY_DEVICE_INPUT_METHOD = { "scrambled": messages.RecoveryDeviceInputMethod.ScrambledWords, @@ -64,17 +64,18 @@ def cli() -> None: help="Wipe device in bootloader mode. This also erases the firmware.", is_flag=True, ) -@with_client -def wipe(client: "TrezorClient", bootloader: bool) -> str: +@with_management_session +def wipe(session: "Session", bootloader: bool) -> str: """Reset device to factory defaults and remove all private data.""" + features = session.features if bootloader: - if not client.features.bootloader_mode: + if not features.bootloader_mode: click.echo("Please switch your device to bootloader mode.") sys.exit(1) else: click.echo("Wiping user data and firmware!") else: - if client.features.bootloader_mode: + if features.bootloader_mode: click.echo( "Your device is in bootloader mode. This operation would also erase firmware." ) @@ -87,7 +88,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str: click.echo("Wiping user data!") try: - return device.wipe(client) + return device.wipe( + session + ) # TODO decide where the wipe should happen - management or regular session except exceptions.TrezorFailure as e: click.echo("Action failed: {} {}".format(*e.args)) sys.exit(3) @@ -102,9 +105,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str: @click.option("-s", "--slip0014", is_flag=True) @click.option("-b", "--needs-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True) -@with_client +@with_management_session def load( - client: "TrezorClient", + session: "Session", mnemonic: t.Sequence[str], pin: str, passphrase_protection: bool, @@ -128,7 +131,7 @@ def load( try: return debuglink.load_device( - client, + session, mnemonic=list(mnemonic), pin=pin, passphrase_protection=passphrase_protection, @@ -163,9 +166,9 @@ def load( ) @click.option("-d", "--dry-run", is_flag=True) @click.option("-b", "--unlock-repeated-backup", is_flag=True) -@with_client +@with_management_session def recover( - client: "TrezorClient", + session: "Session", words: str, expand: bool, pin_protection: bool, @@ -193,7 +196,7 @@ def recover( type = messages.RecoveryType.UnlockRepeatedBackup return device.recover( - client, + session, word_count=int(words), passphrase_protection=passphrase_protection, pin_protection=pin_protection, @@ -214,9 +217,9 @@ def recover( @click.option("-s", "--skip-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True) @click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE)) -@with_client +@with_management_session def setup( - client: "TrezorClient", + session: "Session", strength: int | None, passphrase_protection: bool, pin_protection: bool, @@ -233,7 +236,7 @@ def setup( BT = messages.BackupType if backup_type is None: - if client.version >= (2, 7, 1): + if session.version >= (2, 7, 1): # SLIP39 extendable was introduced in 2.7.1 backup_type = BT.Slip39_Single_Extendable else: @@ -243,10 +246,10 @@ def setup( if ( backup_type in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable) - and messages.Capability.Shamir not in client.features.capabilities + and messages.Capability.Shamir not in session.features.capabilities ) or ( backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable) - and messages.Capability.ShamirGroups not in client.features.capabilities + and messages.Capability.ShamirGroups not in session.features.capabilities ): click.echo( "WARNING: Your Trezor device does not indicate support for the requested\n" @@ -254,7 +257,7 @@ def setup( ) return device.reset( - client, + session, strength=strength, passphrase_protection=passphrase_protection, pin_protection=pin_protection, @@ -269,23 +272,21 @@ def setup( @cli.command() @click.option("-t", "--group-threshold", type=int) @click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N") -@with_client +@with_management_session def backup( - client: "TrezorClient", + session: "Session", group_threshold: int | None = None, groups: t.Sequence[tuple[int, int]] = (), ) -> str: """Perform device seed backup.""" - return device.backup(client, group_threshold, groups) + return device.backup(session, group_threshold, groups) @cli.command() @click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS)) -@with_client -def sd_protect( - client: "TrezorClient", operation: messages.SdProtectOperationType -) -> str: +@with_management_session +def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> str: """Secure the device with SD card protection. When SD card protection is enabled, a randomly generated secret is stored @@ -299,36 +300,36 @@ def sd_protect( off - Remove SD card secret protection. refresh - Replace the current SD card secret with a new one. """ - if client.features.model == "1": + if session.features.model == "1": raise click.ClickException("Trezor One does not support SD card protection.") - return device.sd_protect(client, operation) + return device.sd_protect(session, operation) @cli.command() @click.pass_obj -def reboot_to_bootloader(obj: "TrezorConnection") -> str: +def reboot_to_bootloader(obj: "NewTrezorConnection") -> str: """Reboot device into bootloader mode. Currently only supported on Trezor Model One. """ - # avoid using @with_client because it closes the session afterwards, + # avoid using @with_management_session because it closes the session afterwards, # which triggers double prompt on device with obj.client_context() as client: - return device.reboot_to_bootloader(client) + return device.reboot_to_bootloader(client.get_management_session()) @cli.command() -@with_client -def tutorial(client: "TrezorClient") -> str: +@with_management_session +def tutorial(session: "Session") -> str: """Show on-device tutorial.""" - return device.show_device_tutorial(client) + return device.show_device_tutorial(session) @cli.command() -@with_client -def unlock_bootloader(client: "TrezorClient") -> str: +@with_management_session +def unlock_bootloader(session: "Session") -> str: """Unlocks bootloader. Irreversible.""" - return device.unlock_bootloader(client) + return device.unlock_bootloader(session) @cli.command() @@ -339,11 +340,11 @@ def unlock_bootloader(client: "TrezorClient") -> str: type=int, help="Dialog expiry in seconds.", ) -@with_client -def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> str: +@with_management_session +def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> str: """Show a "Do not disconnect" dialog.""" if enable is False: - return device.set_busy(client, None) + return device.set_busy(session, None) if expiry is None: raise click.ClickException("Missing option '-e' / '--expiry'.") @@ -353,7 +354,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer." ) - return device.set_busy(client, expiry * 1000) + return device.set_busy(session, expiry * 1000) PUBKEY_WHITELIST_URL_TEMPLATE = ( @@ -373,9 +374,9 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> is_flag=True, help="Do not check intermediate certificates against the whitelist.", ) -@with_client +@with_management_session def authenticate( - client: "TrezorClient", + session: "Session", hex_challenge: str | None, root: t.BinaryIO | None, raw: bool | None, @@ -400,7 +401,7 @@ def authenticate( challenge = bytes.fromhex(hex_challenge) if raw: - msg = device.authenticate(client, challenge) + msg = device.authenticate(session, challenge) click.echo(f"Challenge: {hex_challenge}") click.echo(f"Signature of challenge: {msg.signature.hex()}") @@ -448,14 +449,14 @@ def format(self, record: logging.LogRecord) -> str: else: whitelist_json = requests.get( PUBKEY_WHITELIST_URL_TEMPLATE.format( - model=client.model.internal_name.lower() + model=session.model.internal_name.lower() ) ).json() whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]] try: authentication.authenticate_device( - client, challenge, root_pubkey=root_bytes, whitelist=whitelist + session, challenge, root_pubkey=root_bytes, whitelist=whitelist ) except authentication.DeviceNotAuthentic: click.echo("Device is not authentic.") diff --git a/python/src/trezorlib/cli/eos.py b/python/src/trezorlib/cli/eos.py index 84c248c4a4c..27d461d8b0b 100644 --- a/python/src/trezorlib/cli/eos.py +++ b/python/src/trezorlib/cli/eos.py @@ -20,11 +20,11 @@ import click from .. import eos, tools -from . import with_client +from . import with_session if TYPE_CHECKING: from .. import messages - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0" @@ -37,11 +37,11 @@ def cli() -> None: @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Eos public key in base58 encoding.""" address_n = tools.parse_path(address) - res = eos.get_public_key(client, address_n, show_display) + res = eos.get_public_key(session, address_n, show_display) return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}" @@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_transaction( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> "messages.EosSignedTx": """Sign EOS transaction.""" tx_json = json.load(file) address_n = tools.parse_path(address) return eos.sign_tx( - client, + session, address_n, tx_json["transaction"], tx_json["chain_id"], diff --git a/python/src/trezorlib/cli/ethereum.py b/python/src/trezorlib/cli/ethereum.py index 6bbfc0d356d..d810d2bf2d1 100644 --- a/python/src/trezorlib/cli/ethereum.py +++ b/python/src/trezorlib/cli/ethereum.py @@ -26,14 +26,14 @@ from .. import _rlp, definitions, ethereum, tools from ..messages import EthereumDefinitions -from . import with_client +from . import with_session if TYPE_CHECKING: import web3 from eth_typing import ChecksumAddress # noqa: I900 from web3.types import Wei - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0" @@ -268,24 +268,24 @@ def cli( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Ethereum address in hex encoding.""" address_n = tools.parse_path(address) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) - return ethereum.get_address(client, address_n, show_display, network, chunkify) + return ethereum.get_address(session, address_n, show_display, network, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict: +@with_session +def get_public_node(session: "Session", address: str, show_display: bool) -> dict: """Get Ethereum public node of given path.""" address_n = tools.parse_path(address) - result = ethereum.get_public_node(client, address_n, show_display=show_display) + result = ethereum.get_public_node(session, address_n, show_display=show_display) return { "node": { "depth": result.node.depth, @@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-C", "--chunkify", is_flag=True) @click.argument("to_address") @click.argument("amount", callback=_amount_to_int) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", chain_id: int, address: str, amount: int, @@ -400,7 +400,7 @@ def sign_tx( encoded_network = DEFINITIONS_SOURCE.get_network(chain_id) address_n = tools.parse_path(address) from_address = ethereum.get_address( - client, address_n, encoded_network=encoded_network + session, address_n, encoded_network=encoded_network ) if token: @@ -446,7 +446,7 @@ def sign_tx( assert max_gas_fee is not None assert max_priority_fee is not None sig = ethereum.sign_tx_eip1559( - client, + session, n=address_n, nonce=nonce, gas_limit=gas_limit, @@ -465,7 +465,7 @@ def sign_tx( gas_price = _get_web3().eth.gas_price assert gas_price is not None sig = ethereum.sign_tx( - client, + session, n=address_n, tx_type=tx_type, nonce=nonce, @@ -526,14 +526,14 @@ def sign_tx( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-C", "--chunkify", is_flag=True) @click.argument("message") -@with_client +@with_session def sign_message( - client: "TrezorClient", address: str, message: str, chunkify: bool + session: "Session", address: str, message: str, chunkify: bool ) -> Dict[str, str]: """Sign message with Ethereum address.""" address_n = tools.parse_path(address) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) - ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify) + ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify) output = { "message": message, "address": ret.address, @@ -550,9 +550,9 @@ def sign_message( help="Be compatible with Metamask's signTypedData_v4 implementation", ) @click.argument("file", type=click.File("r")) -@with_client +@with_session def sign_typed_data( - client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO + session: "Session", address: str, metamask_v4_compat: bool, file: TextIO ) -> Dict[str, str]: """Sign typed data (EIP-712) with Ethereum address. @@ -565,7 +565,7 @@ def sign_typed_data( defs = EthereumDefinitions(encoded_network=network) data = json.loads(file.read()) ret = ethereum.sign_typed_data( - client, + session, address_n, data, metamask_v4_compat=metamask_v4_compat, @@ -583,9 +583,9 @@ def sign_typed_data( @click.argument("address") @click.argument("signature") @click.argument("message") -@with_client +@with_session def verify_message( - client: "TrezorClient", + session: "Session", address: str, signature: str, message: str, @@ -594,7 +594,7 @@ def verify_message( """Verify message signed with Ethereum address.""" signature_bytes = ethereum.decode_hex(signature) return ethereum.verify_message( - client, address, signature_bytes, message, chunkify=chunkify + session, address, signature_bytes, message, chunkify=chunkify ) @@ -602,9 +602,9 @@ def verify_message( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.argument("domain_hash_hex") @click.argument("message_hash_hex") -@with_client +@with_session def sign_typed_data_hash( - client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str + session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str ) -> Dict[str, str]: """ Sign hash of typed data (EIP-712) with Ethereum address. @@ -618,7 +618,7 @@ def sign_typed_data_hash( message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) ret = ethereum.sign_typed_data_hash( - client, address_n, domain_hash, message_hash, network + session, address_n, domain_hash, message_hash, network ) output = { "domain_hash": domain_hash_hex, diff --git a/python/src/trezorlib/cli/fido.py b/python/src/trezorlib/cli/fido.py index 5983c572493..8d5e5628ba7 100644 --- a/python/src/trezorlib/cli/fido.py +++ b/python/src/trezorlib/cli/fido.py @@ -19,10 +19,10 @@ import click from .. import fido -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"} @@ -40,10 +40,10 @@ def credentials() -> None: @credentials.command(name="list") -@with_client -def credentials_list(client: "TrezorClient") -> None: +@with_session +def credentials_list(session: "Session") -> None: """List all resident credentials on the device.""" - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) for cred in creds: click.echo("") click.echo(f"WebAuthn credential at index {cred.index}:") @@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None: @credentials.command(name="add") @click.argument("hex_credential_id") -@with_client -def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str: +@with_session +def credentials_add(session: "Session", hex_credential_id: str) -> str: """Add the credential with the given ID as a resident credential. HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string. """ - return fido.add_credential(client, bytes.fromhex(hex_credential_id)) + return fido.add_credential(session, bytes.fromhex(hex_credential_id)) @credentials.command(name="remove") @click.option( "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." ) -@with_client -def credentials_remove(client: "TrezorClient", index: int) -> str: +@with_session +def credentials_remove(session: "Session", index: int) -> str: """Remove the resident credential at the given index.""" - return fido.remove_credential(client, index) + return fido.remove_credential(session, index) # @@ -110,19 +110,19 @@ def counter() -> None: @counter.command(name="set") @click.argument("counter", type=int) -@with_client -def counter_set(client: "TrezorClient", counter: int) -> str: +@with_session +def counter_set(session: "Session", counter: int) -> str: """Set FIDO/U2F counter value.""" - return fido.set_counter(client, counter) + return fido.set_counter(session, counter) @counter.command(name="get-next") -@with_client -def counter_get_next(client: "TrezorClient") -> int: +@with_session +def counter_get_next(session: "Session") -> int: """Get-and-increase value of FIDO/U2F counter. FIDO counter value cannot be read directly. On each U2F exchange, the counter value is returned and atomically increased. This command performs the same operation and returns the counter value. """ - return fido.get_next_counter(client) + return fido.get_next_counter(session) diff --git a/python/src/trezorlib/cli/firmware.py b/python/src/trezorlib/cli/firmware.py index 4376a4f2839..73449f380bc 100644 --- a/python/src/trezorlib/cli/firmware.py +++ b/python/src/trezorlib/cli/firmware.py @@ -37,11 +37,12 @@ from .. import device, exceptions, firmware, messages, models from ..firmware import models as fw_models from ..models import TrezorModel -from . import ChoiceType, with_client +from . import ChoiceType, with_management_session if TYPE_CHECKING: from ..client import TrezorClient - from . import TrezorConnection + from ..transport.session import Session + from . import NewTrezorConnection MODEL_CHOICE = ChoiceType( { @@ -74,9 +75,9 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool: This is the case from bootloader version 1.8.0, and also holds for firmware version 1.8.0 because that installs the appropriate bootloader. """ - f = client.features - version = (f.major_version, f.minor_version, f.patch_version) - bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0) + features = client.features + version = client.version + bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0) return bootloader_onev2 @@ -306,25 +307,26 @@ def find_best_firmware_version( If the specified version is not found, prints the closest available version (higher than the specified one, if existing). """ + features = client.features + model = client.model + if bitcoin_only is None: - bitcoin_only = _should_use_bitcoin_only(client.features) + bitcoin_only = _should_use_bitcoin_only(features) def version_str(version: Iterable[int]) -> str: return ".".join(map(str, version)) - f = client.features - - releases = get_all_firmware_releases(client.model, bitcoin_only, beta) + releases = get_all_firmware_releases(model, bitcoin_only, beta) highest_version = releases[0]["version"] if version: want_version = [int(x) for x in version.split(".")] if len(want_version) != 3: click.echo("Please use the 'X.Y.Z' version format.") - if want_version[0] != f.major_version: + if want_version[0] != features.major_version: click.echo( - f"Warning: Trezor {client.model.name} firmware version should be " - f"{f.major_version}.X.Y (requested: {version})" + f"Warning: Trezor {model.name} firmware version should be " + f"{features.major_version}.X.Y (requested: {version})" ) else: want_version = highest_version @@ -359,8 +361,8 @@ def version_str(version: Iterable[int]) -> str: # to the newer one, in that case update to the minimal # compatible version first # Choosing the version key to compare based on (not) being in BL mode - client_version = [f.major_version, f.minor_version, f.patch_version] - if f.bootloader_mode: + client_version = client.version + if features.bootloader_mode: key_to_compare = "min_bootloader_version" else: key_to_compare = "min_firmware_version" @@ -447,11 +449,11 @@ def extract_embedded_fw( def upload_firmware_into_device( - client: "TrezorClient", + session: "Session", firmware_data: bytes, ) -> None: """Perform the final act of loading the firmware into Trezor.""" - f = client.features + f = session.features try: if f.major_version == 1 and f.firmware_present is not False: # Trezor One does not send ButtonRequest @@ -461,7 +463,7 @@ def upload_firmware_into_device( with click.progressbar( label="Uploading", length=len(firmware_data), show_eta=False ) as bar: - firmware.update(client, firmware_data, bar.update) + firmware.update(session, firmware_data, bar.update) except exceptions.Cancelled: click.echo("Update aborted on device.") except exceptions.TrezorException as e: @@ -519,7 +521,7 @@ def cli() -> None: @click.pass_obj # fmt: on def verify( - obj: "TrezorConnection", + obj: "NewTrezorConnection", filename: BinaryIO, check_device: bool, fingerprint: Optional[str], @@ -564,7 +566,7 @@ def verify( @click.pass_obj # fmt: on def download( - obj: "TrezorConnection", + obj: "NewTrezorConnection", output: Optional[BinaryIO], model: Optional[TrezorModel], version: Optional[str], @@ -630,7 +632,7 @@ def download( # fmt: on @click.pass_obj def update( - obj: "TrezorConnection", + obj: "NewTrezorConnection", filename: Optional[BinaryIO], url: Optional[str], version: Optional[str], @@ -654,6 +656,7 @@ def update( against data.trezor.io information, if available. """ with obj.client_context() as client: + management_session = client.get_management_session() if sum(bool(x) for x in (filename, url, version)) > 1: click.echo("You can use only one of: filename, url, version.") sys.exit(1) @@ -709,7 +712,7 @@ def update( if _is_strict_update(client, firmware_data): header_size = _get_firmware_header_size(firmware_data) device.reboot_to_bootloader( - client, + management_session, boot_command=messages.BootCommand.INSTALL_UPGRADE, firmware_header=firmware_data[:header_size], language_data=language_data, @@ -719,7 +722,7 @@ def update( click.echo( "WARNING: Seamless installation not possible, language data will not be uploaded." ) - device.reboot_to_bootloader(client) + device.reboot_to_bootloader(management_session) click.echo("Waiting for bootloader...") while True: @@ -735,13 +738,15 @@ def update( click.echo("Please switch your device to bootloader mode.") sys.exit(1) - upload_firmware_into_device(client=client, firmware_data=firmware_data) + upload_firmware_into_device( + session=client.get_management_session(), firmware_data=firmware_data + ) @cli.command() @click.argument("hex_challenge", required=False) -@with_client -def get_hash(client: "TrezorClient", hex_challenge: Optional[str]) -> str: +@with_management_session +def get_hash(session: "Session", hex_challenge: Optional[str]) -> str: """Get a hash of the installed firmware combined with the optional challenge.""" challenge = bytes.fromhex(hex_challenge) if hex_challenge else None - return firmware.get_hash(client, challenge).hex() + return firmware.get_hash(session, challenge).hex() diff --git a/python/src/trezorlib/cli/monero.py b/python/src/trezorlib/cli/monero.py index 355c562ae39..0441ebc09b4 100644 --- a/python/src/trezorlib/cli/monero.py +++ b/python/src/trezorlib/cli/monero.py @@ -19,10 +19,10 @@ import click from .. import messages, monero, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h" @@ -42,9 +42,9 @@ def cli() -> None: default=messages.MoneroNetworkType.MAINNET, ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, network_type: messages.MoneroNetworkType, @@ -52,7 +52,7 @@ def get_address( ) -> bytes: """Get Monero address for specified path.""" address_n = tools.parse_path(address) - return monero.get_address(client, address_n, show_display, network_type, chunkify) + return monero.get_address(session, address_n, show_display, network_type, chunkify) @cli.command() @@ -63,13 +63,13 @@ def get_address( type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}), default=messages.MoneroNetworkType.MAINNET, ) -@with_client +@with_session def get_watch_key( - client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType + session: "Session", address: str, network_type: messages.MoneroNetworkType ) -> Dict[str, str]: """Get Monero watch key for specified path.""" address_n = tools.parse_path(address) - res = monero.get_watch_key(client, address_n, network_type) + res = monero.get_watch_key(session, address_n, network_type) # TODO: could be made required in MoneroWatchKey assert res.address is not None assert res.watch_key is not None diff --git a/python/src/trezorlib/cli/nem.py b/python/src/trezorlib/cli/nem.py index 746ad187236..eac16c2d8c2 100644 --- a/python/src/trezorlib/cli/nem.py +++ b/python/src/trezorlib/cli/nem.py @@ -21,10 +21,10 @@ import requests from .. import nem, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h" @@ -39,9 +39,9 @@ def cli() -> None: @click.option("-N", "--network", type=int, default=0x68) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, network: int, show_display: bool, @@ -49,7 +49,7 @@ def get_address( ) -> str: """Get NEM address for specified path.""" address_n = tools.parse_path(address) - return nem.get_address(client, address_n, network, show_display, chunkify) + return nem.get_address(session, address_n, network, show_display, chunkify) @cli.command() @@ -58,9 +58,9 @@ def get_address( @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-b", "--broadcast", help="NIS to announce transaction to") @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", address: str, file: TextIO, broadcast: Optional[str], @@ -71,7 +71,7 @@ def sign_tx( Transaction file is expected in the NIS (RequestPrepareAnnounce) format. """ address_n = tools.parse_path(address) - transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify) + transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify) payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()} diff --git a/python/src/trezorlib/cli/ripple.py b/python/src/trezorlib/cli/ripple.py index e4bcc0b3503..634a92028e6 100644 --- a/python/src/trezorlib/cli/ripple.py +++ b/python/src/trezorlib/cli/ripple.py @@ -20,10 +20,10 @@ import click from .. import ripple, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0" @@ -37,13 +37,13 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Ripple address""" address_n = tools.parse_path(address) - return ripple.get_address(client, address_n, show_display, chunkify) + return ripple.get_address(session, address_n, show_display, chunkify) @cli.command() @@ -51,13 +51,13 @@ def get_address( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client -def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None: +@with_session +def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None: """Sign Ripple transaction""" address_n = tools.parse_path(address) msg = ripple.create_sign_tx_msg(json.load(file)) - result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify) + result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify) click.echo("Signature:") click.echo(result.signature.hex()) click.echo() diff --git a/python/src/trezorlib/cli/settings.py b/python/src/trezorlib/cli/settings.py index 01e9ca68f1d..5535ae4e34c 100644 --- a/python/src/trezorlib/cli/settings.py +++ b/python/src/trezorlib/cli/settings.py @@ -24,10 +24,11 @@ import requests from .. import device, messages, toif -from . import AliasedGroup, ChoiceType, with_client +from ..transport.session import Session +from . import AliasedGroup, ChoiceType, with_management_session if TYPE_CHECKING: - from ..client import TrezorClient + pass try: from PIL import Image @@ -174,18 +175,18 @@ def cli() -> None: @cli.command() @click.option("-r", "--remove", is_flag=True, hidden=True) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) -@with_client -def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str: +@with_management_session +def pin(session: "Session", enable: Optional[bool], remove: bool) -> str: """Set, change or remove PIN.""" # Remove argument is there for backwards compatibility - return device.change_pin(client, remove=_should_remove(enable, remove)) + return device.change_pin(session, remove=_should_remove(enable, remove)) @cli.command() @click.option("-r", "--remove", is_flag=True, hidden=True) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) -@with_client -def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str: +@with_management_session +def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> str: """Set or remove the wipe code. The wipe code functions as a "self-destruct PIN". If the wipe code is ever @@ -193,32 +194,32 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> s removed and the device will be reset to factory defaults. """ # Remove argument is there for backwards compatibility - return device.change_wipe_code(client, remove=_should_remove(enable, remove)) + return device.change_wipe_code(session, remove=_should_remove(enable, remove)) @cli.command() # keep the deprecated -l/--label option, make it do nothing @click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.argument("label") -@with_client -def label(client: "TrezorClient", label: str) -> str: +@with_management_session +def label(session: "Session", label: str) -> str: """Set new device label.""" - return device.apply_settings(client, label=label) + return device.apply_settings(session, label=label) @cli.command() -@with_client -def brightness(client: "TrezorClient") -> str: +@with_management_session +def brightness(session: "Session") -> str: """Set display brightness.""" - return device.set_brightness(client) + return device.set_brightness(session) @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) -@with_client -def haptic_feedback(client: "TrezorClient", enable: bool) -> str: +@with_management_session +def haptic_feedback(session: "Session", enable: bool) -> str: """Enable or disable haptic feedback.""" - return device.apply_settings(client, haptic_feedback=enable) + return device.apply_settings(session, haptic_feedback=enable) @cli.command() @@ -227,9 +228,9 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> str: "-r", "--remove", is_flag=True, default=False, help="Switch back to english." ) @click.option("-d/-D", "--display/--no-display", default=None) -@with_client +@with_management_session def language( - client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None + session: "Session", path_or_url: str | None, remove: bool, display: bool | None ) -> str: """Set new language with translations.""" if remove != (path_or_url is None): @@ -254,29 +255,29 @@ def language( f"Failed to load translations from {path_or_url}" ) from None return device.change_language( - client, language_data=language_data, show_display=display + session, language_data=language_data, show_display=display ) @cli.command() @click.argument("rotation", type=ChoiceType(ROTATION)) -@with_client -def display_rotation(client: "TrezorClient", rotation: int) -> str: +@with_management_session +def display_rotation(session: "Session", rotation: int) -> str: """Set display rotation. Configure display rotation for Trezor Model T. The options are north, east, south or west. """ - return device.apply_settings(client, display_rotation=rotation) + return device.apply_settings(session, display_rotation=rotation) @cli.command() @click.argument("delay", type=str) -@with_client -def auto_lock_delay(client: "TrezorClient", delay: str) -> str: +@with_management_session +def auto_lock_delay(session: "Session", delay: str) -> str: """Set auto-lock delay (in seconds).""" - if not client.features.pin_protection: + if not session.features.pin_protection: raise click.ClickException("Set up a PIN first") value, unit = delay[:-1], delay[-1:] @@ -285,13 +286,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> str: seconds = float(value) * units[unit] else: seconds = float(delay) # assume seconds if no unit is specified - return device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000)) + return device.apply_settings(session, auto_lock_delay_ms=int(seconds * 1000)) @cli.command() @click.argument("flags") -@with_client -def flags(client: "TrezorClient", flags: str) -> str: +@with_management_session +def flags(session: "Session", flags: str) -> str: """Set device flags.""" if flags.lower().startswith("0b"): flags_int = int(flags, 2) @@ -299,7 +300,7 @@ def flags(client: "TrezorClient", flags: str) -> str: flags_int = int(flags, 16) else: flags_int = int(flags) - return device.apply_flags(client, flags=flags_int) + return device.apply_flags(session, flags=flags_int) @cli.command() @@ -308,8 +309,8 @@ def flags(client: "TrezorClient", flags: str) -> str: "-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False ) @click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)") -@with_client -def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: +@with_management_session +def homescreen(session: "Session", filename: str, quality: int) -> str: """Set new homescreen. To revert to default homescreen, use 'trezorctl set homescreen default' @@ -321,39 +322,39 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: if not path.exists() or not path.is_file(): raise click.ClickException("Cannot open file") - if client.features.model == "1": + if session.features.model == "1": img = image_to_t1(path) else: - if client.features.homescreen_format == messages.HomescreenFormat.Jpeg: + if session.features.homescreen_format == messages.HomescreenFormat.Jpeg: width = ( - client.features.homescreen_width - if client.features.homescreen_width is not None + session.features.homescreen_width + if session.features.homescreen_width is not None else 240 ) height = ( - client.features.homescreen_height - if client.features.homescreen_height is not None + session.features.homescreen_height + if session.features.homescreen_height is not None else 240 ) img = image_to_jpeg(path, width, height, quality) - elif client.features.homescreen_format == messages.HomescreenFormat.ToiG: - width = client.features.homescreen_width - height = client.features.homescreen_height + elif session.features.homescreen_format == messages.HomescreenFormat.ToiG: + width = session.features.homescreen_width + height = session.features.homescreen_height if width is None or height is None: raise click.ClickException("Device did not report homescreen size.") img = image_to_toif(path, width, height, True) elif ( - client.features.homescreen_format == messages.HomescreenFormat.Toif - or client.features.homescreen_format is None + session.features.homescreen_format == messages.HomescreenFormat.Toif + or session.features.homescreen_format is None ): width = ( - client.features.homescreen_width - if client.features.homescreen_width is not None + session.features.homescreen_width + if session.features.homescreen_width is not None else 144 ) height = ( - client.features.homescreen_height - if client.features.homescreen_height is not None + session.features.homescreen_height + if session.features.homescreen_height is not None else 144 ) img = image_to_toif(path, width, height, False) @@ -363,7 +364,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: "Unknown image format requested by the device." ) - return device.apply_settings(client, homescreen=img) + return device.apply_settings(session, homescreen=img) @cli.command() @@ -371,9 +372,9 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: "--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.' ) @click.argument("level", type=ChoiceType(SAFETY_LEVELS)) -@with_client +@with_management_session def safety_checks( - client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel + session: "Session", always: bool, level: messages.SafetyCheckLevel ) -> str: """Set safety check level. @@ -386,18 +387,18 @@ def safety_checks( """ if always and level == messages.SafetyCheckLevel.PromptTemporarily: level = messages.SafetyCheckLevel.PromptAlways - return device.apply_settings(client, safety_checks=level) + return device.apply_settings(session, safety_checks=level) @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) -@with_client -def experimental_features(client: "TrezorClient", enable: bool) -> str: +@with_management_session +def experimental_features(session: "Session", enable: bool) -> str: """Enable or disable experimental message types. This is a developer feature. Use with caution. """ - return device.apply_settings(client, experimental_features=enable) + return device.apply_settings(session, experimental_features=enable) # @@ -420,25 +421,25 @@ def passphrase_main() -> None: @passphrase.command(name="on") @click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None) -@with_client -def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> str: +@with_management_session +def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> str: """Enable passphrase.""" - if client.features.passphrase_protection is not True: + if session.features.passphrase_protection is not True: use_passphrase = True else: use_passphrase = None return device.apply_settings( - client, + session, use_passphrase=use_passphrase, passphrase_always_on_device=force_on_device, ) @passphrase.command(name="off") -@with_client -def passphrase_off(client: "TrezorClient") -> str: +@with_management_session +def passphrase_off(session: "Session") -> str: """Disable passphrase.""" - return device.apply_settings(client, use_passphrase=False) + return device.apply_settings(session, use_passphrase=False) # Registering the aliases for backwards compatibility @@ -451,10 +452,10 @@ def passphrase_off(client: "TrezorClient") -> str: @passphrase.command(name="hide") @click.argument("hide", type=ChoiceType({"on": True, "off": False})) -@with_client -def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> str: +@with_management_session +def hide_passphrase_from_host(session: "Session", hide: bool) -> str: """Enable or disable hiding passphrase coming from host. This is a developer feature. Use with caution. """ - return device.apply_settings(client, hide_passphrase_from_host=hide) + return device.apply_settings(session, hide_passphrase_from_host=hide) diff --git a/python/src/trezorlib/cli/solana.py b/python/src/trezorlib/cli/solana.py index 3fe80a51646..8152116b550 100644 --- a/python/src/trezorlib/cli/solana.py +++ b/python/src/trezorlib/cli/solana.py @@ -4,10 +4,10 @@ import click from .. import messages, solana, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h" DEFAULT_PATH = "m/44h/501h/0h/0h" @@ -21,40 +21,40 @@ def cli() -> None: @cli.command() @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_public_key( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, ) -> messages.SolanaPublicKey: """Get Solana public key.""" address_n = tools.parse_path(address) - return solana.get_public_key(client, address_n, show_display) + return solana.get_public_key(session, address_n, show_display) @cli.command() @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, chunkify: bool, ) -> messages.SolanaAddress: """Get Solana address.""" address_n = tools.parse_path(address) - return solana.get_address(client, address_n, show_display, chunkify) + return solana.get_address(session, address_n, show_display, chunkify) @cli.command() @click.argument("serialized_tx", type=str) @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-a", "--additional-information-file", type=click.File("r")) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", address: str, serialized_tx: str, additional_information_file: Optional[TextIO], @@ -78,7 +78,7 @@ def sign_tx( ) return solana.sign_tx( - client, + session, address_n, bytes.fromhex(serialized_tx), additional_information, diff --git a/python/src/trezorlib/cli/stellar.py b/python/src/trezorlib/cli/stellar.py index 77ce700ee5b..9acb6a57ed7 100644 --- a/python/src/trezorlib/cli/stellar.py +++ b/python/src/trezorlib/cli/stellar.py @@ -21,10 +21,10 @@ import click from .. import stellar, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session try: from stellar_sdk import ( @@ -52,13 +52,13 @@ def cli() -> None: ) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Stellar public address.""" address_n = tools.parse_path(address) - return stellar.get_address(client, address_n, show_display, chunkify) + return stellar.get_address(session, address_n, show_display, chunkify) @cli.command() @@ -77,9 +77,9 @@ def get_address( help="Network passphrase (blank for public network).", ) @click.argument("b64envelope") -@with_client +@with_session def sign_transaction( - client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str + session: "Session", b64envelope: str, address: str, network_passphrase: str ) -> bytes: """Sign a base64-encoded transaction envelope. @@ -109,6 +109,6 @@ def sign_transaction( address_n = tools.parse_path(address) tx, operations = stellar.from_envelope(envelope) - resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase) + resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase) return base64.b64encode(resp.signature) diff --git a/python/src/trezorlib/cli/tezos.py b/python/src/trezorlib/cli/tezos.py index 7dcd1ab9db1..e4f0c1a877d 100644 --- a/python/src/trezorlib/cli/tezos.py +++ b/python/src/trezorlib/cli/tezos.py @@ -20,10 +20,10 @@ import click from .. import messages, protobuf, tezos, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h" @@ -37,23 +37,23 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Tezos address for specified path.""" address_n = tools.parse_path(address) - return tezos.get_address(client, address_n, show_display, chunkify) + return tezos.get_address(session, address_n, show_display, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Tezos public key.""" address_n = tools.parse_path(address) - return tezos.get_public_key(client, address_n, show_display) + return tezos.get_public_key(session, address_n, show_display) @cli.command() @@ -61,11 +61,11 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> messages.TezosSignedTx: """Sign Tezos transaction.""" address_n = tools.parse_path(address) msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file)) - return tezos.sign_tx(client, address_n, msg, chunkify=chunkify) + return tezos.sign_tx(session, address_n, msg, chunkify=chunkify) diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 625baab5ee1..bc93a448d59 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -24,13 +24,15 @@ import click -from .. import __version__, log, messages, protobuf, ui +from .. import __version__, log, messages, protobuf from ..client import TrezorClient from ..transport import DeviceIsBusy, enumerate_devices +from ..transport.new import channel_database +from ..transport.session import Session from ..transport.udp import UdpTransport from . import ( AliasedGroup, - TrezorConnection, + NewTrezorConnection, binance, btc, cardano, @@ -49,6 +51,7 @@ stellar, tezos, with_client, + with_session, ) F = TypeVar("F", bound=Callable) @@ -213,7 +216,8 @@ def cli_main( except ValueError: raise click.ClickException(f"Not a valid session id: {session_id}") - ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script) + # ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script) + ctx.obj = NewTrezorConnection(path, bytes_session_id, passphrase_on_host, script) # Optionally record the screen into a specified directory. if record: @@ -254,7 +258,7 @@ def print_result(res: Any, is_json: bool, script: bool, **kwargs: Any) -> None: @cli.set_result_callback() @click.pass_obj -def stop_recording_action(obj: TrezorConnection, *args: Any, **kwargs: Any) -> None: +def stop_recording_action(obj: NewTrezorConnection, *args: Any, **kwargs: Any) -> None: """Stop recording screen changes when the recording was started by `cli_main`. (When user used the `-r / --record` option of `trezorctl` command.) @@ -286,16 +290,31 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: if no_resolve: return enumerate_devices() + stored_channels = channel_database.load_stored_channels() + stored_transport_paths = [ch.transport_path for ch in stored_channels] for transport in enumerate_devices(): try: - client = TrezorClient(transport, ui=ui.ClickUI()) + path = transport.get_path() + if path in stored_transport_paths: + stored_channel_with_correct_transport_path = next( + ch for ch in stored_channels if ch.transport_path == path + ) + client = TrezorClient.resume( + transport, stored_channel_with_correct_transport_path + ) + else: + client = TrezorClient(transport) + description = format_device_name(client.features) - client.end_session() + # json_string = channel_database.channel_to_str(client.protocol) + # print(json_string) + channel_database.save_channel(client.protocol) + # client.end_session() except DeviceIsBusy: description = "Device is in use by another process" except Exception: description = "Failed to read details" - click.echo(f"{transport} - {description}") + click.echo(f"{transport.get_path()} - {description}") return None @@ -313,15 +332,21 @@ def version() -> str: @cli.command() @click.argument("message") @click.option("-b", "--button-protection", is_flag=True) -@with_client -def ping(client: "TrezorClient", message: str, button_protection: bool) -> str: +@with_session +def ping(session: "Session", message: str, button_protection: bool) -> str: """Send ping message.""" - return client.ping(message, button_protection=button_protection) + + # TODO return short-circuit from old client for old Trezors + return session.call( + messages.Ping(message=message, button_protection=button_protection) + ) @cli.command() @click.pass_obj -def get_session(obj: TrezorConnection) -> str: +def get_session( + obj: NewTrezorConnection, passphrase: str = "", derive_cardano: bool = False +) -> str: """Get a session ID for subsequent commands. Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with @@ -335,23 +360,35 @@ def get_session(obj: TrezorConnection) -> str: obj.session_id = None with obj.client_context() as client: + if client.features.model == "1" and client.version < (1, 9, 0): raise click.ClickException( "Upgrade your firmware to enable session support." ) client.ensure_unlocked() - if client.session_id is None: + session = client.get_session( + passphrase=passphrase, derive_cardano=derive_cardano + ) + if session.id is None: raise click.ClickException("Passphrase not enabled or firmware too old.") else: - return client.session_id.hex() + return session.id.hex() @cli.command() -@with_client -def clear_session(client: "TrezorClient") -> None: +@with_session +def clear_session(session: "Session") -> None: """Clear session (remove cached PIN, passphrase, etc.).""" - return client.clear_session() + # TODO something like old: return client.clear_session() + print("NOT IMPLEMENTED") + raise NotImplementedError + + +@cli.command() +def new_clear_session() -> None: + """New Clear session (remove cached channels from trezorlib).""" + channel_database.clear_stored_channels() @cli.command() @@ -376,7 +413,7 @@ def usb_reset() -> None: @cli.command() @click.option("-t", "--timeout", type=float, default=10, help="Timeout in seconds") @click.pass_obj -def wait_for_emulator(obj: TrezorConnection, timeout: float) -> None: +def wait_for_emulator(obj: NewTrezorConnection, timeout: float) -> None: """Wait until Trezor Emulator comes up. Tries to connect to emulator and returns when it succeeds. diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 7fd6dae32ba..236a62e8226 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -13,25 +13,23 @@ # # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations import logging import os -import warnings -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +import typing as t -from mnemonic import Mnemonic +from . import mapping, messages, models +from .mapping import ProtobufMapping +from .tools import parse_path +from .transport import Transport, get_transport +from .transport.new.channel_data import ChannelData +from .transport.new.protocol_and_channel import ProtocolAndChannel +from .transport.new.protocol_v1 import ProtocolV1 +from .transport.new.protocol_v2 import ProtocolV2 -from . import exceptions, mapping, messages, models -from .log import DUMP_BYTES -from .messages import Capability -from .tools import expect, parse_path, session - -if TYPE_CHECKING: - from .protobuf import MessageType - from .transport import Transport - from .ui import TrezorClientUI - -UI = TypeVar("UI", bound="TrezorClientUI") +if t.TYPE_CHECKING: + from .transport.session import Session LOG = logging.getLogger(__name__) @@ -48,445 +46,612 @@ """.strip() -def get_default_client( - path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any -) -> "TrezorClient": - """Get a client for a connected Trezor device. - - Returns a TrezorClient instance with minimum fuss. - - If path is specified, does a prefix-search for the specified device. Otherwise, uses - the value of TREZOR_PATH env variable, or finds first connected Trezor. - If no UI is supplied, instantiates the default CLI UI. - """ - from .transport import get_transport - from .ui import ClickUI - - if path is None: - path = os.getenv("TREZOR_PATH") - - transport = get_transport(path, prefix_search=True) - if ui is None: - ui = ClickUI() - - return TrezorClient(transport, ui, **kwargs) - +LOG = logging.getLogger(__name__) -class TrezorClient(Generic[UI]): - """Trezor client, a connection to a Trezor device. - This class allows you to manage connection state, send and receive protobuf - messages, handle user interactions, and perform some generic tasks - (send a cancel message, initialize or clear a session, ping the device). - """ +class TrezorClient: + button_callback: t.Callable[[Session, t.Any], t.Any] | None = None + pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None - model: models.TrezorModel - transport: "Transport" - session_id: Optional[bytes] - ui: UI - features: messages.Features + _management_session: Session | None = None + _features: messages.Features | None = None def __init__( self, - transport: "Transport", - ui: UI, - session_id: Optional[bytes] = None, - derive_cardano: Optional[bool] = None, - model: Optional[models.TrezorModel] = None, - _init_device: bool = True, + transport: Transport, + protobuf_mapping: ProtobufMapping | None = None, + protocol: ProtocolAndChannel | None = None, ) -> None: - """Create a TrezorClient instance. - - You have to provide a `transport`, i.e., a raw connection to the device. You can - use `trezorlib.transport.get_transport` to find one. - - You have to provide an UI implementation for the three kinds of interaction: - - button request (notify the user that their interaction is needed) - - PIN request (on T1, ask the user to input numbers for a PIN matrix) - - passphrase request (ask the user to enter a passphrase) See `trezorlib.ui` for - details. - - You can supply a `session_id` you might have saved in the previous session. If - you do, the user might not need to enter their passphrase again. - - You can provide Trezor model information. If not provided, it is detected from - the model name reported at initialization time. - - By default, the instance will open a connection to the Trezor device, send an - `Initialize` message, set up the `features` field from the response, and connect - to a session. By specifying `_init_device=False`, this step is skipped. Notably, - this means that `client.features` is unset. Use `client.init_device()` or - `client.refresh_features()` to fix that, otherwise A LOT OF THINGS will break. - Only use this if you are _sure_ that you know what you are doing. This feature - might be removed at any time. - """ - LOG.info(f"creating client instance for device: {transport.get_path()}") - # Here, self.model could be set to None. Unless _init_device is False, it will - # get correctly reconfigured as part of the init_device flow. - self.model = model # type: ignore ["None" is incompatible with "TrezorModel"] - if self.model: - self.mapping = self.model.default_mapping - else: - self.mapping = mapping.DEFAULT_MAPPING self.transport = transport - self.ui = ui - self.session_counter = 0 - self.session_id = session_id - if _init_device: - self.init_device(session_id=session_id, derive_cardano=derive_cardano) - - def open(self) -> None: - if self.session_counter == 0: - self.transport.begin_session() - self.session_counter += 1 - - def close(self) -> None: - self.session_counter = max(self.session_counter - 1, 0) - if self.session_counter == 0: - # TODO call EndSession here? - self.transport.end_session() - - def cancel(self) -> None: - self._raw_write(messages.Cancel()) - - def call_raw(self, msg: "MessageType") -> "MessageType": - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - self._raw_write(msg) - return self._raw_read() - - def _raw_write(self, msg: "MessageType") -> None: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - LOG.debug( - f"sending message: {msg.__class__.__name__}", - extra={"protobuf": msg}, - ) - msg_type, msg_bytes = self.mapping.encode(msg) - LOG.log( - DUMP_BYTES, - f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", - ) - self.transport.write(msg_type, msg_bytes) - - def _raw_read(self) -> "MessageType": - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - msg_type, msg_bytes = self.transport.read() - LOG.log( - DUMP_BYTES, - f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", - ) - msg = self.mapping.decode(msg_type, msg_bytes) - LOG.debug( - f"received message: {msg.__class__.__name__}", - extra={"protobuf": msg}, - ) - return msg - - def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType": - try: - pin = self.ui.get_pin(msg.type) - except exceptions.Cancelled: - self.call_raw(messages.Cancel()) - raise - - if any(d not in "123456789" for d in pin) or not ( - 1 <= len(pin) <= MAX_PIN_LENGTH - ): - self.call_raw(messages.Cancel()) - raise ValueError("Invalid PIN provided") - - resp = self.call_raw(messages.PinMatrixAck(pin=pin)) - if isinstance(resp, messages.Failure) and resp.code in ( - messages.FailureType.PinInvalid, - messages.FailureType.PinCancelled, - messages.FailureType.PinExpected, - ): - raise exceptions.PinException(resp.code, resp.message) + + if protobuf_mapping is None: + self.mapping = mapping.DEFAULT_MAPPING else: - return resp - - def _callback_passphrase(self, msg: messages.PassphraseRequest) -> "MessageType": - available_on_device = Capability.PassphraseEntry in self.features.capabilities - - def send_passphrase( - passphrase: Optional[str] = None, on_device: Optional[bool] = None - ) -> "MessageType": - msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) - resp = self.call_raw(msg) - if isinstance(resp, messages.Deprecated_PassphraseStateRequest): - self.session_id = resp.state - resp = self.call_raw(messages.Deprecated_PassphraseStateAck()) - return resp - - # short-circuit old style entry - if msg._on_device is True: - return send_passphrase(None, None) - - try: - passphrase = self.ui.get_passphrase(available_on_device=available_on_device) - except exceptions.Cancelled: - self.call_raw(messages.Cancel()) - raise - - if passphrase is PASSPHRASE_ON_DEVICE: - if not available_on_device: - self.call_raw(messages.Cancel()) - raise RuntimeError("Device is not capable of entering passphrase") - else: - return send_passphrase(on_device=True) - - # else process host-entered passphrase - if not isinstance(passphrase, str): - raise RuntimeError("Passphrase must be a str") - passphrase = Mnemonic.normalize_string(passphrase) - if len(passphrase) > MAX_PASSPHRASE_LENGTH: - self.call_raw(messages.Cancel()) - raise ValueError("Passphrase too long") - - return send_passphrase(passphrase, on_device=False) - - def _callback_button(self, msg: messages.ButtonRequest) -> "MessageType": - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - # do this raw - send ButtonAck first, notify UI later - self._raw_write(messages.ButtonAck()) - self.ui.button_request(msg) - return self._raw_read() - - @session - def call(self, msg: "MessageType") -> "MessageType": - self.check_firmware_version() - resp = self.call_raw(msg) - while True: - if isinstance(resp, messages.PinMatrixRequest): - resp = self._callback_pin(resp) - elif isinstance(resp, messages.PassphraseRequest): - resp = self._callback_passphrase(resp) - elif isinstance(resp, messages.ButtonRequest): - resp = self._callback_button(resp) - elif isinstance(resp, messages.Failure): - if resp.code == messages.FailureType.ActionCancelled: - raise exceptions.Cancelled - raise exceptions.TrezorFailure(resp) + self.mapping = protobuf_mapping + if protocol is None: + try: + self.protocol = self._get_protocol() + except Exception as e: + print(e) + else: + self.protocol = protocol + self.protocol.mapping = self.mapping + + @classmethod + def resume( + cls, + transport: Transport, + channel_data: ChannelData, + protobuf_mapping: ProtobufMapping | None = None, + ) -> TrezorClient: + if protobuf_mapping is None: + protobuf_mapping = mapping.DEFAULT_MAPPING + protocol_v1 = ProtocolV1(transport, protobuf_mapping) + if channel_data.protocol_version == 2: + try: + protocol_v1.write(messages.Ping(message="Sanity check - to resume")) + except Exception as e: + print(type(e)) + response = protocol_v1.read() + if ( + isinstance(response, messages.Failure) + and response.code == messages.FailureType.InvalidProtocol + ): + protocol = ProtocolV2(transport, protobuf_mapping, channel_data) + protocol.write(0, messages.Ping()) + response = protocol.read(0) + if not isinstance(response, messages.Success): + LOG.debug("Failed to resume ProtocolV2") + raise Exception("Failed to resume ProtocolV2") + LOG.debug("Protocol V2 detected - can be resumed") else: - return resp - - def _refresh_features(self, features: messages.Features) -> None: - """Update internal fields based on passed-in Features message.""" - - if not self.model: - # Trezor Model One bootloader 1.8.0 or older does not send model name - model = models.by_internal_name(features.internal_model) - if model is None: - model = models.by_name(features.model or "1") - if model is None: - raise RuntimeError( - "Unsupported Trezor model" - f" (internal_model: {features.internal_model}, model: {features.model})" - ) - self.model = model - - if features.vendor not in self.model.vendors: - raise RuntimeError("Unsupported device") - - self.features = features - self.version = ( - self.features.major_version, - self.features.minor_version, - self.features.patch_version, - ) - self.check_firmware_version(warn_only=True) - if self.features.session_id is not None: - self.session_id = self.features.session_id - self.features.session_id = None - - @session - def refresh_features(self) -> messages.Features: - """Reload features from the device. - - Should be called after changing settings or performing operations that affect - device state. - """ - resp = self.call_raw(messages.GetFeatures()) - if not isinstance(resp, messages.Features): - raise exceptions.TrezorException("Unexpected response to GetFeatures") - self._refresh_features(resp) - return resp - - @session - def init_device( + LOG.debug("Failed to resume ProtocolV2") + raise Exception("Failed to resume ProtocolV2") + else: + protocol = ProtocolV1(transport, protobuf_mapping, channel_data) + return TrezorClient(transport, protobuf_mapping, protocol) + + def get_session( self, - *, - session_id: Optional[bytes] = None, - new_session: bool = False, - derive_cardano: Optional[bool] = None, - ) -> Optional[bytes]: - """Initialize the device and return a session ID. - - You can optionally specify a session ID. If the session still exists on the - device, the same session ID will be returned and the session is resumed. - Otherwise a different session ID is returned. - - Specify `new_session=True` to open a fresh session. Since firmware version - 1.9.0/2.3.0, the previous session will remain cached on the device, and can be - resumed by calling `init_device` again with the appropriate session ID. - - If neither `new_session` nor `session_id` is specified, the current session ID - will be reused. If no session ID was cached, a new session ID will be allocated - and returned. - - # Version notes: - - Trezor One older than 1.9.0 does not have session management. Optional arguments - have no effect and the function returns None - - Trezor T older than 2.3.0 does not have session cache. Requesting a new session - will overwrite the old one. In addition, this function will always return None. - A valid session_id can be obtained from the `session_id` attribute, but only - after a passphrase-protected call is performed. You can use the following code: - - >>> client.init_device() - >>> client.ensure_unlocked() - >>> valid_session_id = client.session_id - """ - if new_session: - self.session_id = None - elif session_id is not None: - self.session_id = session_id - - resp = self.call_raw( - messages.Initialize( - session_id=self.session_id, - derive_cardano=derive_cardano, + passphrase: str | None = None, + derive_cardano: bool = False, + ) -> Session: + from .transport.session import SessionV1, SessionV2 + + if isinstance(self.protocol, ProtocolV1): + return SessionV1.new(self, passphrase, derive_cardano) + if isinstance(self.protocol, ProtocolV2): + return SessionV2.new(self, passphrase, derive_cardano) + raise NotImplementedError # TODO + + def get_management_session(self, new_session: bool = False) -> Session: + from .transport.session import SessionV1, SessionV2 + + if not new_session and self._management_session is not None: + return self._management_session + if isinstance(self.protocol, ProtocolV1): + self._management_session = SessionV1.new(self, "", False) + elif isinstance(self.protocol, ProtocolV2): + self._management_session = SessionV2(self, b"\x00") + assert self._management_session is not None + return self._management_session + + @property + def features(self) -> messages.Features: + if self._features is None: + self._features = self.protocol.get_features() + assert self._features is not None + return self._features + + @property + def model(self) -> models.TrezorModel: + f = self.features + model = models.by_name(f.model or "1") + + if model is None: + raise RuntimeError( + "Unsupported Trezor model" + f" (internal_model: {f.internal_model}, model: {f.model})" ) + return model + + @property + def version(self) -> tuple[int, int, int]: + f = self.features + ver = ( + f.major_version, + f.minor_version, + f.patch_version, ) - if isinstance(resp, messages.Failure): - # can happen if `derive_cardano` does not match the current session - raise exceptions.TrezorFailure(resp) - if not isinstance(resp, messages.Features): - raise exceptions.TrezorException("Unexpected response to Initialize") - - if self.session_id is not None and resp.session_id == self.session_id: - LOG.info("Successfully resumed session") - elif session_id is not None: - LOG.info("Failed to resume session") - - # TT < 2.3.0 compatibility: - # _refresh_features will clear out the session_id field. We want this function - # to return its value, so that callers can rely on it being either a valid - # session_id, or None if we can't do that. - # Older TT FW does not report session_id in Features and self.session_id might - # be invalid because TT will not allocate a session_id until a passphrase - # exchange happens. - reported_session_id = resp.session_id - self._refresh_features(resp) - return reported_session_id - - def is_outdated(self) -> bool: - if self.features.bootloader_mode: - return False - return self.version < self.model.minimum_version - - def check_firmware_version(self, warn_only: bool = False) -> None: - if self.is_outdated(): - if warn_only: - warnings.warn("Firmware is out of date", stacklevel=2) - else: - raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) + return ver - @expect(messages.Success, field="message", ret_type=str) - def ping( - self, - msg: str, - button_protection: bool = False, - ) -> "MessageType": - # We would like ping to work on any valid TrezorClient instance, but - # due to the protection modes, we need to go through self.call, and that will - # raise an exception if the firmware is too old. - # So we short-circuit the simplest variant of ping with call_raw. - if not button_protection: - # XXX this should be: `with self:` - try: - self.open() - resp = self.call_raw(messages.Ping(message=msg)) - if isinstance(resp, messages.ButtonRequest): - # device is PIN-locked. - # respond and hope for the best - resp = self._callback_button(resp) - return resp - finally: - self.close() - - return self.call( - messages.Ping(message=msg, button_protection=button_protection) - ) + def refresh_features(self) -> None: + self.protocol.update_features() + self._features = self.protocol.get_features() - def get_device_id(self) -> Optional[str]: - return self.features.device_id + def ensure_unlocked(self) -> None: + # TODO implement + raise NotImplementedError - @session - def lock(self, *, _refresh_features: bool = True) -> None: - """Lock the device. + def resume_session(self, session_id: bytes) -> Session: + raise NotImplementedError # TODO - If the device does not have a PIN configured, this will do nothing. - Otherwise, a lock screen will be shown and the device will prompt for PIN - before further actions. + def _get_protocol(self) -> ProtocolAndChannel: + self.transport.open() - This call does _not_ invalidate passphrase cache. If passphrase is in use, - the device will not prompt for it after unlocking. + protocol = ProtocolV1(self.transport, mapping.DEFAULT_MAPPING) - To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate - passphrase cache, use `clear_session()`. - """ - # Private argument _refresh_features can be used internally to avoid - # refreshing in cases where we will refresh soon anyway. This is used - # in TrezorClient.clear_session() - self.call(messages.LockDevice()) - if _refresh_features: - self.refresh_features() + protocol.write(messages.Initialize()) - @session - def ensure_unlocked(self) -> None: - """Ensure the device is unlocked and a passphrase is cached. - - If the device is locked, this will prompt for PIN. If passphrase is enabled - and no passphrase is cached for the current session, the device will also - prompt for passphrase. - - After calling this method, further actions on the device will not prompt for - PIN or passphrase until the device is locked or the session becomes invalid. - """ - from .btc import get_address - - get_address(self, "Testnet", PASSPHRASE_TEST_PATH) - self.refresh_features() - - def end_session(self) -> None: - """Close the current session and clear cached passphrase. - - The session will become invalid until `init_device()` is called again. - If passphrase is enabled, further actions will prompt for it again. - - This is a no-op in bootloader mode, as it does not support session management. - """ - # since: 2.3.4, 1.9.4 - try: - if not self.features.bootloader_mode: - self.call(messages.EndSession()) - except exceptions.TrezorFailure: - # A failure most likely means that the FW version does not support - # the EndSession call. We ignore the failure and clear the local session_id. - # The client-side end result is identical. - pass - self.session_id = None - - @session - def clear_session(self) -> None: - """Lock the device and present a fresh session. - - The current session will be invalidated and a new one will be started. If the - device has PIN enabled, it will become locked. - - Equivalent to calling `lock()`, `end_session()` and `init_device()`. - """ - self.lock(_refresh_features=False) - self.end_session() - self.init_device(new_session=True) + response = protocol.read() + self.transport.close() + if isinstance(response, messages.Failure): + if response.code == messages.FailureType.InvalidProtocol: + LOG.debug("Protocol V2 detected") + protocol = ProtocolV2(self.transport, self.mapping) + return protocol + + +def get_default_client( + path: t.Optional[str] = None, + **kwargs: t.Any, +) -> "TrezorClient": + """Get a client for a connected Trezor device. + + Returns a TrezorClient instance with minimum fuss. + + If path is specified, does a prefix-search for the specified device. Otherwise, uses + the value of TREZOR_PATH env variable, or finds first connected Trezor. + If no UI is supplied, instantiates the default CLI UI. + """ + + if path is None: + path = os.getenv("TREZOR_PATH") + + transport = get_transport(path, prefix_search=True) + + return TrezorClient(transport, **kwargs) + + +# class TrezorClient(t.Generic[UI]): +# """Trezor client, a connection to a Trezor device. + +# This class allows you to manage connection state, send and receive protobuf +# messages, handle user interactions, and perform some generic tasks +# (send a cancel message, initialize or clear a session, ping the device). +# """ + +# model: models.TrezorModel +# transport: "Transport" +# session_id: t.Optional[bytes] +# ui: UI +# features: messages.Features + +# def __init__( +# self, +# transport: "Transport", +# ui: UI, +# session_id: t.Optional[bytes] = None, +# derive_cardano: t.Optional[bool] = None, +# model: t.Optional[models.TrezorModel] = None, +# _init_device: bool = True, +# ) -> None: +# """Create a TrezorClient instance. + +# You have to provide a `transport`, i.e., a raw connection to the device. You can +# use `trezorlib.transport.get_transport` to find one. + +# You have to provide an UI implementation for the three kinds of interaction: +# - button request (notify the user that their interaction is needed) +# - PIN request (on T1, ask the user to input numbers for a PIN matrix) +# - passphrase request (ask the user to enter a passphrase) See `trezorlib.ui` for +# details. + +# You can supply a `session_id` you might have saved in the previous session. If +# you do, the user might not need to enter their passphrase again. + +# You can provide Trezor model information. If not provided, it is detected from +# the model name reported at initialization time. + +# By default, the instance will open a connection to the Trezor device, send an +# `Initialize` message, set up the `features` field from the response, and connect +# to a session. By specifying `_init_device=False`, this step is skipped. Notably, +# this means that `client.features` is unset. Use `client.init_device()` or +# `client.refresh_features()` to fix that, otherwise A LOT OF THINGS will break. +# Only use this if you are _sure_ that you know what you are doing. This feature +# might be removed at any time. +# """ +# LOG.info(f"creating client instance for device: {transport.get_path()}") +# # Here, self.model could be set to None. Unless _init_device is False, it will +# # get correctly reconfigured as part of the init_device flow. +# self.model = model # type: ignre ["None" is incompatible with "TrezorModel"] +# if self.model: +# self.mapping = self.model.default_mapping +# else: +# self.mapping = mapping.DEFAULT_MAPPING +# self.transport = transport +# self.ui = ui +# self.session_counter = 0 +# self.session_id = session_id +# if _init_device: +# self.init_device(session_id=session_id, derive_cardano=derive_cardano) +# self.resume_session() + +# def open(self) -> None: +# if self.session_counter == 0: +# session_id = self.transport.resume_session(b"") +# if self.session_id != session_id: +# print("Failed to resume session, allocated a new session") +# self.session_id = session_id +# self.transport.deprecated_begin_session() +# self.session_counter += 1 + +# def resume_session(self) -> None: +# new_id = self.transport.resume_session(self.session_id or b"") +# if self.session_id != new_id: +# print("Failed to resume session, allocated a new session") +# self.session_id = new_id + +# def close(self) -> None: +# self.session_counter = max(self.session_counter - 1, 0) +# if self.session_counter == 0: +# # TODO call EndSession here? +# self.transport.deprecated_end_session() + +# def cancel(self) -> None: +# self._raw_write(messages.Cancel()) + +# def call_raw(self, msg: "MessageType") -> "MessageType": +# __tracebackhide__ = True # for pytest # pylint: disable=W0612 + +# self._raw_write(msg) +# x = self._raw_read() +# return x + +# def _raw_write(self, msg: "MessageType") -> None: +# __tracebackhide__ = True # for pytest # pylint: disable=W0612 +# LOG.debug( +# f"sending message: {msg.__class__.__name__}", +# extra={"protobuf": msg}, +# ) +# msg_type, msg_bytes = self.mapping.encode(msg) +# LOG.log( +# DUMP_BYTES, +# f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", +# ) +# self.transport.write(msg_type, msg_bytes) + +# def _raw_read(self) -> "MessageType": +# __tracebackhide__ = True # for pytest # pylint: disable=W0612 +# msg_type, msg_bytes = self.transport.read() +# print("type/data", msg_type, msg_bytes) +# LOG.log( +# DUMP_BYTES, +# f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", +# ) +# msg = self.mapping.decode(msg_type, msg_bytes) +# LOG.debug( +# f"received message: {msg.__class__.__name__}", +# extra={"protobuf": msg}, +# ) +# return msg + +# def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType": +# try: +# pin = self.ui.get_pin(msg.type) +# except exceptions.Cancelled: +# self.call_raw(messages.Cancel()) +# raise + +# if any(d not in "123456789" for d in pin) or not ( +# 1 <= len(pin) <= MAX_PIN_LENGTH +# ): +# self.call_raw(messages.Cancel()) +# raise ValueError("Invalid PIN provided") + +# resp = self.call_raw(messages.PinMatrixAck(pin=pin)) +# if isinstance(resp, messages.Failure) and resp.code in ( +# messages.FailureType.PinInvalid, +# messages.FailureType.PinCancelled, +# messages.FailureType.PinExpected, +# ): +# raise exceptions.PinException(resp.code, resp.message) +# else: +# return resp + +# def _callback_passphrase(self, msg: messages.PassphraseRequest) -> "MessageType": +# available_on_device = Capability.PassphraseEntry in self.features.capabilities + +# def send_passphrase( +# passphrase: t.Optional[str] = None, on_device: t.Optional[bool] = None +# ) -> "MessageType": +# msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) +# resp = self.call_raw(msg) +# if isinstance(resp, messages.Deprecated_PassphraseStateRequest): +# self.session_id = resp.state +# resp = self.call_raw(messages.Deprecated_PassphraseStateAck()) +# return resp + +# # short-circuit old style entry +# if msg._on_device is True: +# return send_passphrase(None, None) + +# try: +# passphrase = self.ui.get_passphrase(available_on_device=available_on_device) +# except exceptions.Cancelled: +# self.call_raw(messages.Cancel()) +# raise + +# if passphrase is PASSPHRASE_ON_DEVICE: +# if not available_on_device: +# self.call_raw(messages.Cancel()) +# raise RuntimeError("Device is not capable of entering passphrase") +# else: +# return send_passphrase(on_device=True) + +# # else process host-entered passphrase +# if not isinstance(passphrase, str): +# raise RuntimeError("Passphrase must be a str") +# passphrase = Mnemonic.normalize_string(passphrase) +# if len(passphrase) > MAX_PASSPHRASE_LENGTH: +# self.call_raw(messages.Cancel()) +# raise ValueError("Passphrase too long") + +# return send_passphrase(passphrase, on_device=False) + +# def _callback_button(self, msg: messages.ButtonRequest) -> "MessageType": +# __tracebackhide__ = True # for pytest # pylint: disable=W0612 +# # do this raw - send ButtonAck first, notify UI later +# self._raw_write(messages.ButtonAck()) +# self.ui.button_request(msg) +# return self._raw_read() + +# @session +# def call(self, msg: "MessageType") -> "MessageType": +# self.check_firmware_version() +# resp = self.call_raw(msg) +# while True: +# if isinstance(resp, messages.PinMatrixRequest): +# resp = self._callback_pin(resp) +# elif isinstance(resp, messages.PassphraseRequest): +# resp = self._callback_passphrase(resp) +# elif isinstance(resp, messages.ButtonRequest): +# resp = self._callback_button(resp) +# elif isinstance(resp, messages.Failure): +# print("self.call-failure") + +# if resp.code == messages.FailureType.ActionCancelled: +# raise exceptions.Cancelled +# raise exceptions.TrezorFailure(resp) +# else: +# print("self.call-end") +# return resp + +# def _refresh_features(self, features: messages.Features) -> None: +# """Update internal fields based on passed-in Features message.""" + +# if not self.model: +# # Trezor Model One bootloader 1.8.0 or older does not send model name +# model = models.by_internal_name(features.internal_model) +# if model is None: +# model = models.by_name(features.model or "1") +# if model is None: +# raise RuntimeError( +# "Unsupported Trezor model" +# f" (internal_model: {features.internal_model}, model: {features.model})" +# ) +# self.model = model + +# if features.vendor not in self.model.vendors: +# raise RuntimeError("Unsupported device") + +# self.features = features +# self.version = ( +# self.features.major_version, +# self.features.minor_version, +# self.features.patch_version, +# ) +# self.check_firmware_version(warn_only=True) +# if self.features.session_id is not None: +# self.session_id = self.features.session_id +# self.features.session_id = None + +# @session +# def refresh_features(self) -> messages.Features: +# """Reload features from the device. + +# Should be called after changing settings or performing operations that affect +# device state. +# """ +# resp = self.call_raw(messages.GetFeatures()) +# if not isinstance(resp, messages.Features): +# raise exceptions.TrezorException("Unexpected response to GetFeatures") +# self._refresh_features(resp) +# return resp + +# def init_device( +# self, +# *, +# session_id: t.Optional[bytes] = None, +# new_session: bool = False, +# derive_cardano: t.Optional[bool] = None, +# ) -> t.Optional[bytes]: +# """Initialize the device and return a session ID. + +# You can optionally specify a session ID. If the session still exists on the +# device, the same session ID will be returned and the session is resumed. +# Otherwise a different session ID is returned. + +# Specify `new_session=True` to open a fresh session. Since firmware version +# 1.9.0/2.3.0, the previous session will remain cached on the device, and can be +# resumed by calling `init_device` again with the appropriate session ID. + +# If neither `new_session` nor `session_id` is specified, the current session ID +# will be reused. If no session ID was cached, a new session ID will be allocated +# and returned. + +# # Version notes: + +# Trezor One older than 1.9.0 does not have session management. Optional arguments +# have no effect and the function returns None + +# Trezor T older than 2.3.0 does not have session cache. Requesting a new session +# will overwrite the old one. In addition, this function will always return None. +# A valid session_id can be obtained from the `session_id` attribute, but only +# after a passphrase-protected call is performed. You can use the following code: + +# >>> client.init_device() +# >>> client.ensure_unlocked() +# >>> valid_session_id = client.session_id +# """ +# if new_session: +# self.session_id = None +# elif session_id is not None: +# self.session_id = session_id + +# print("before init conn") + +# resp = self.transport.initialize_connection( +# mapping=self.mapping, +# session_id=session_id, +# derive_cardano=derive_cardano, +# ) +# print("here") +# if isinstance(resp, messages.Failure): +# # can happen if `derive_cardano` does not match the current session +# raise exceptions.TrezorFailure(resp) +# if not isinstance(resp, messages.Features): +# raise exceptions.TrezorException("Unexpected response to Initialize") + +# if self.session_id is not None and resp.session_id == self.session_id: +# LOG.info("Successfully resumed session") +# elif session_id is not None: +# LOG.info("Failed to resume session") + +# # TT < 2.3.0 compatibility: +# # _refresh_features will clear out the session_id field. We want this function +# # to return its value, so that callers can rely on it being either a valid +# # session_id, or None if we can't do that. +# # Older TT FW does not report session_id in Features and self.session_id might +# # be invalid because TT will not allocate a session_id until a passphrase +# # exchange happens. +# reported_session_id = resp.session_id +# self._refresh_features(resp) +# print("there:", reported_session_id) +# return reported_session_id + +# def is_outdated(self) -> bool: +# if self.features.bootloader_mode: +# return False +# return self.version < self.model.minimum_version + +# def check_firmware_version(self, warn_only: bool = False) -> None: +# if self.is_outdated(): +# if warn_only: +# warnings.warn("Firmware is out of date", stacklevel=2) +# else: +# raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) + +# @expect(messages.Success, field="message", ret_type=str) +# def ping( +# self, +# msg: str, +# button_protection: bool = False, +# ) -> "MessageType": +# # We would like ping to work on any valid TrezorClient instance, but +# # due to the protection modes, we need to go through self.call, and that will +# # raise an exception if the firmware is too old. +# # So we short-circuit the simplest variant of ping with call_raw. +# if not button_protection: +# # XXX this should be: `with self:` +# try: +# self.open() +# resp = self.call_raw(messages.Ping(message=msg)) +# if isinstance(resp, messages.ButtonRequest): +# # device is PIN-locked. +# # respond and hope for the best +# resp = self._callback_button(resp) +# return resp +# finally: +# self.close() + +# return self.call( +# messages.Ping(message=msg, button_protection=button_protection) +# ) + +# def get_device_id(self) -> t.Optional[str]: +# return self.features.device_id + +# @session +# def lock(self, *, _refresh_features: bool = True) -> None: +# """Lock the device. + +# If the device does not have a PIN configured, this will do nothing. +# Otherwise, a lock screen will be shown and the device will prompt for PIN +# before further actions. + +# This call does _not_ invalidate passphrase cache. If passphrase is in use, +# the device will not prompt for it after unlocking. + +# To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate +# passphrase cache, use `clear_session()`. +# """ +# # Private argument _refresh_features can be used internally to avoid +# # refreshing in cases where we will refresh soon anyway. This is used +# # in TrezorClient.clear_session() +# self.call(messages.LockDevice()) +# if _refresh_features: +# self.refresh_features() + +# @session +# def ensure_unlocked(self) -> None: +# """Ensure the device is unlocked and a passphrase is cached. + +# If the device is locked, this will prompt for PIN. If passphrase is enabled +# and no passphrase is cached for the current session, the device will also +# prompt for passphrase. + +# After calling this method, further actions on the device will not prompt for +# PIN or passphrase until the device is locked or the session becomes invalid. +# """ +# from .btc import get_address + +# get_address(self, "Testnet", PASSPHRASE_TEST_PATH) +# self.refresh_features() + +# def end_session(self) -> None: +# """Close the current session and clear cached passphrase. + +# The session will become invalid until `init_device()` is called again. +# If passphrase is enabled, further actions will prompt for it again. + +# This is a no-op in bootloader mode, as it does not support session management. +# """ +# # since: 2.3.4, 1.9.4 +# print("end session") +# try: +# if not self.features.bootloader_mode: +# self.transport.end_session(self.session_id or b"") +# # self.call(messages.EndSession()) +# except exceptions.TrezorFailure: +# # A failure most likely means that the FW version does not support +# # the EndSession call. We ignore the failure and clear the local session_id. +# # The client-side end result is identical. +# pass +# except ValueError as e: +# print(e) +# print(e.args) +# self.session_id = None + +# @session +# def clear_session(self) -> None: +# """Lock the device and present a fresh session. + +# The current session will be invalidated and a new one will be started. If the +# device has PIN enabled, it will become locked. + +# Equivalent to calling `lock()`, `end_session()` and `init_device()`. +# """ +# self.lock(_refresh_features=False) +# self.end_session() +# self.init_device(new_session=True) diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index c5373d2b006..2b52408ad00 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -14,32 +14,20 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import json import logging import re import textwrap import time +import typing as t from contextlib import contextmanager from copy import deepcopy from datetime import datetime from enum import Enum, IntEnum, auto from itertools import zip_longest from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Generator, - Iterable, - Iterator, - List, - Optional, - Sequence, - Tuple, - Type, - Union, -) from mnemonic import Mnemonic @@ -49,24 +37,27 @@ from .log import DUMP_BYTES from .messages import DebugWaitType from .tools import expect +from .transport.new.protocol_v1 import ProtocolV1 +from .transport.session import Session -if TYPE_CHECKING: +if t.TYPE_CHECKING: from typing_extensions import Protocol from .messages import PinMatrixRequestType from .transport import Transport - ExpectedMessage = Union[ - protobuf.MessageType, Type[protobuf.MessageType], "MessageFilter" + ExpectedMessage = t.Union[ + protobuf.MessageType, t.Type[protobuf.MessageType], "MessageFilter" ] - AnyDict = Dict[str, Any] + AnyDict = t.Dict[str, t.Any] class InputFunc(Protocol): + def __call__( self, - hold_ms: Optional[int] = None, - wait: Optional[bool] = None, + hold_ms: int | None = None, + wait: bool | None = None, ) -> "LayoutContent": ... @@ -105,11 +96,13 @@ def __init__(self, json_str: str) -> None: except json.JSONDecodeError: self.dict = {} - def top_level_value(self, key: str) -> Any: + def top_level_value(self, key: str) -> t.Any: return self.dict.get(key) - def find_objects_with_key_and_value(self, key: str, value: Any) -> List["AnyDict"]: - def recursively_find(data: Any) -> Iterator[Any]: + def find_objects_with_key_and_value( + self, key: str, value: t.Any + ) -> t.List["AnyDict"]: + def recursively_find(data: t.Any) -> t.Iterator[t.Any]: if isinstance(data, dict): if data.get(key) == value: yield data @@ -122,8 +115,8 @@ def recursively_find(data: Any) -> Iterator[Any]: return list(recursively_find(self.dict)) def find_unique_object_with_key_and_value( - self, key: str, value: Any - ) -> Optional["AnyDict"]: + self, key: str, value: t.Any + ) -> "AnyDict | None": objects = self.find_objects_with_key_and_value(key, value) if not objects: return None @@ -131,9 +124,9 @@ def find_unique_object_with_key_and_value( return objects[0] def find_values_by_key( - self, key: str, only_type: Optional[type] = None - ) -> List[Any]: - def recursively_find(data: Any) -> Iterator[Any]: + self, key: str, only_type: type | None = None + ) -> t.List[t.Any]: + def recursively_find(data: t.Any) -> t.Iterator[t.Any]: if isinstance(data, dict): if key in data: yield data[key] @@ -151,8 +144,8 @@ def recursively_find(data: Any) -> Iterator[Any]: return values def find_unique_value_by_key( - self, key: str, default: Any, only_type: Optional[type] = None - ) -> Any: + self, key: str, default: t.Any, only_type: type | None = None + ) -> t.Any: values = self.find_values_by_key(key, only_type=only_type) if not values: return default @@ -163,7 +156,7 @@ def find_unique_value_by_key( class LayoutContent(UnstructuredJSONReader): """Contains helper functions to extract specific parts of the layout.""" - def __init__(self, json_tokens: Sequence[str]) -> None: + def __init__(self, json_tokens: t.Sequence[str]) -> None: json_str = "".join(json_tokens) super().__init__(json_str) @@ -171,7 +164,7 @@ def main_component(self) -> str: """Getting the main component of the layout.""" return self.top_level_value("component") or "no main component" - def all_components(self) -> List[str]: + def all_components(self) -> t.List[str]: """Getting all components of the layout.""" return self.find_values_by_key("component", only_type=str) @@ -209,7 +202,7 @@ def _get_str_or_dict_text(self, key: str) -> str: def title(self) -> str: """Getting text that is displayed as a title and potentially subtitle.""" # There could be possibly subtitle as well - title_parts: List[str] = [] + title_parts: t.List[str] = [] title = self._get_str_or_dict_text("title") if title: @@ -244,7 +237,7 @@ def screen_content(self) -> str: # Look for paragraphs first (will match most of the time for TT) paragraphs = self.raw_content_paragraphs() if paragraphs: - main_text_blocks: List[str] = [] + main_text_blocks: t.List[str] = [] for par in paragraphs: par_content = "" for line_or_newline in par: @@ -294,13 +287,13 @@ def screen_content(self) -> str: # Default when not finding anything return self.main_component() - def raw_content_paragraphs(self) -> Optional[List[List[str]]]: + def raw_content_paragraphs(self) -> t.List[t.List[str]] | None: """Getting raw paragraphs as sent from Rust.""" return self.find_unique_value_by_key("paragraphs", default=None, only_type=list) - def tt_check_seed_button_contents(self) -> List[str]: + def tt_check_seed_button_contents(self) -> t.List[str]: """Getting list of button contents.""" - buttons: List[str] = [] + buttons: t.List[str] = [] button_objects = self.find_objects_with_key_and_value("component", "Button") for button in button_objects: if button.get("icon"): @@ -309,7 +302,7 @@ def tt_check_seed_button_contents(self) -> List[str]: buttons.append(button["text"]) return buttons - def button_contents(self) -> List[str]: + def button_contents(self) -> t.List[str]: """Getting list of button contents.""" buttons = self.find_unique_value_by_key("buttons", default={}, only_type=dict) @@ -331,13 +324,13 @@ def get_button_content(btn_key: str) -> str: button_keys = ("left_btn", "middle_btn", "right_btn") return [get_button_content(btn_key) for btn_key in button_keys] - def seed_words(self) -> List[str]: + def seed_words(self) -> t.List[str]: """Get all the seed words on the screen in order. Example content: "1. ladybug\n2. acid\n3. academic\n4. afraid" -> ["ladybug", "acid", "academic", "afraid"] """ - words: List[str] = [] + words: t.List[str] = [] for line in self.screen_content().split("\n"): # Dot after index is optional (present on TT, not on TR) match = re.match(r"^\s*\d+\.? (\w+)$", line) @@ -377,7 +370,7 @@ def get_middle_choice(self) -> str: """What is the choice being selected right now.""" return self.choice_items()[1] - def choice_items(self) -> Tuple[str, str, str]: + def choice_items(self) -> t.Tuple[str, str, str]: """Getting actions for all three possible buttons.""" choice_obj = self.find_unique_value_by_key( "choice_page", default={}, only_type=dict @@ -396,15 +389,15 @@ def footer(self) -> str: return footer.get("description", "") + " " + footer.get("instruction", "") -def multipage_content(layouts: List[LayoutContent]) -> str: +def multipage_content(layouts: t.List[LayoutContent]) -> str: """Get overall content from multiple-page layout.""" return "".join(layout.text_content() for layout in layouts) def _make_input_func( - button: Optional[messages.DebugButton] = None, - physical_button: Optional[messages.DebugPhysicalButton] = None, - swipe: Optional[messages.DebugSwipeDirection] = None, + button: messages.DebugButton | None = None, + physical_button: messages.DebugPhysicalButton | None = None, + swipe: messages.DebugSwipeDirection | None = None, ) -> "InputFunc": decision = messages.DebugLinkDecision( button=button, @@ -414,8 +407,8 @@ def _make_input_func( def input_func( self: "DebugLink", - hold_ms: Optional[int] = None, - wait: Optional[bool] = None, + hold_ms: int | None = None, + wait: bool | None = None, ) -> LayoutContent: __tracebackhide__ = True # for pytest # pylint: disable=W0612 decision.hold_ms = hold_ms @@ -425,24 +418,26 @@ def input_func( class DebugLink: + def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: self.transport = transport self.allow_interactions = auto_interact self.mapping = mapping.DEFAULT_MAPPING + self.protocol = ProtocolV1(self.transport, self.mapping) # To be set by TrezorClientDebugLink (is not known during creation time) - self.model: Optional[models.TrezorModel] = None - self.version: Tuple[int, int, int] = (0, 0, 0) + self.model: models.TrezorModel | None = None + self.version: t.Tuple[int, int, int] = (0, 0, 0) # Where screenshots are being saved - self.screenshot_recording_dir: Optional[str] = None + self.screenshot_recording_dir: str | None = None # For T1 screenshotting functionality in DebugUI - self.t1_screenshot_directory: Optional[Path] = None + self.t1_screenshot_directory: Path | None = None self.t1_screenshot_counter = 0 # Optional file for saving text representation of the screen - self.screen_text_file: Optional[Path] = None + self.screen_text_file: Path | None = None self.last_screen_content = "" self.waiting_for_layout_change = False @@ -465,16 +460,22 @@ def layout_type(self) -> LayoutType: assert self.model is not None return LayoutType.from_model(self.model) - def set_screen_text_file(self, file_path: Optional[Path]) -> None: + def set_screen_text_file(self, file_path: Path | None) -> None: if file_path is not None: file_path.write_bytes(b"") self.screen_text_file = file_path def open(self) -> None: - self.transport.begin_session() + self.transport.open() + # raise NotImplementedError + # TODO is this needed? + # self.transport.deprecated_begin_session() def close(self) -> None: - self.transport.end_session() + pass + # raise NotImplementedError + # TODO is this needed? + # self.transport.deprecated_end_session() def _write(self, msg: protobuf.MessageType) -> None: if self.waiting_for_layout_change: @@ -491,15 +492,10 @@ def _write(self, msg: protobuf.MessageType) -> None: DUMP_BYTES, f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", ) - self.transport.write(msg_type, msg_bytes) + self.protocol.write(msg) def _read(self) -> protobuf.MessageType: - ret_type, ret_bytes = self.transport.read() - LOG.log( - DUMP_BYTES, - f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}", - ) - msg = self.mapping.decode(ret_type, ret_bytes) + msg = self.protocol.read() # Collapse tokens to make log use less lines. msg_for_log = msg @@ -513,14 +509,21 @@ def _read(self) -> protobuf.MessageType: ) return msg - def _call(self, msg: protobuf.MessageType) -> Any: + def _call(self, msg: protobuf.MessageType) -> t.Any: self._write(msg) return self._read() def state( - self, wait_type: DebugWaitType = DebugWaitType.CURRENT_LAYOUT + self, + wait_type: DebugWaitType = DebugWaitType.CURRENT_LAYOUT, + thp_channel_id: bytes | None = None, ) -> messages.DebugLinkState: - result = self._call(messages.DebugLinkGetState(wait_layout=wait_type)) + result = self._call( + messages.DebugLinkGetState( + wait_layout=wait_type, + thp_channel_id=thp_channel_id, + ) + ) while not isinstance(result, (messages.Failure, messages.DebugLinkState)): result = self._read() if isinstance(result, messages.Failure): @@ -548,7 +551,7 @@ def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent: return LayoutContent(obj.tokens) @contextmanager - def wait_for_layout_change(self) -> Iterator[LayoutContent]: + def wait_for_layout_change(self) -> t.Iterator[LayoutContent]: # set up a dummy layout content object to be yielded layout_content = LayoutContent( ["DUMMY CONTENT, WAIT UNTIL THE END OF THE BLOCK :("] @@ -596,7 +599,7 @@ def watch_layout(self, watch: bool) -> None: """ self._call(messages.DebugLinkWatchLayout(watch=watch)) - def encode_pin(self, pin: str, matrix: Optional[str] = None) -> str: + def encode_pin(self, pin: str, matrix: str | None = None) -> str: """Transform correct PIN according to the displayed matrix.""" if matrix is None: matrix = self.state().matrix @@ -606,7 +609,7 @@ def encode_pin(self, pin: str, matrix: Optional[str] = None) -> str: return "".join([str(matrix.index(p) + 1) for p in pin]) - def read_recovery_word(self) -> Tuple[Optional[str], Optional[int]]: + def read_recovery_word(self) -> t.Tuple[str | None, int | None]: state = self.state() return (state.recovery_fake_word, state.recovery_word_pos) @@ -615,7 +618,7 @@ def read_reset_word(self) -> str: return state.reset_word def _decision( - self, decision: messages.DebugLinkDecision, wait: Optional[bool] = None + self, decision: messages.DebugLinkDecision, wait: bool | None = None ) -> LayoutContent: """Send a debuglink decision and returns the resulting layout. @@ -677,15 +680,15 @@ def _decision( ) """Press right button. See `_decision` for more details.""" - def input(self, word: str, wait: Optional[bool] = None) -> LayoutContent: + def input(self, word: str, wait: bool | None = None) -> LayoutContent: """Send text input to the device. See `_decision` for more details.""" return self._decision(messages.DebugLinkDecision(input=word), wait) def click( self, - click: Tuple[int, int], - hold_ms: Optional[int] = None, - wait: Optional[bool] = None, + click: t.Tuple[int, int], + hold_ms: int | None = None, + wait: bool | None = None, ) -> LayoutContent: """Send a click to the device. See `_decision` for more details.""" x, y = click @@ -736,9 +739,7 @@ def stop(self) -> None: def reseed(self, value: int) -> protobuf.MessageType: return self._call(messages.DebugLinkReseedRandom(value=value)) - def start_recording( - self, directory: str, refresh_index: Optional[int] = None - ) -> None: + def start_recording(self, directory: str, refresh_index: int | None = None) -> None: self.screenshot_recording_dir = directory # Different recording logic between core and legacy if self.model is not models.T1B1: @@ -793,7 +794,7 @@ def _save_screenshot_t1(self, data: bytes) -> None: assert len(data) == 128 * 64 // 8 - pixels: List[int] = [] + pixels: t.List[int] = [] for byteline in range(64 // 8): offset = byteline * 128 row = data[offset : offset + 128] @@ -826,7 +827,7 @@ def close(self) -> None: def _call( self, msg: protobuf.MessageType, nowait: bool = False - ) -> Optional[messages.DebugLinkState]: + ) -> messages.DebugLinkState | None: if not nowait: if isinstance(msg, messages.DebugLinkGetState): return messages.DebugLinkState() @@ -844,10 +845,10 @@ def __init__(self, debuglink: DebugLink) -> None: self.clear() def clear(self) -> None: - self.pins: Optional[Iterator[str]] = None + self.pins: t.Iterator[str] | None = None self.passphrase = "" - self.input_flow: Union[ - Generator[None, messages.ButtonRequest, None], object, None + self.input_flow: t.Union[ + t.Generator[None, messages.ButtonRequest, None], object, None ] = None def _default_input_flow(self, br: messages.ButtonRequest) -> None: @@ -869,6 +870,11 @@ def _default_input_flow(self, br: messages.ButtonRequest) -> None: else: self.debuglink.press_yes() + def debug_callback_button(self, session: Session, msg: t.Any) -> t.Any: + session._write(messages.ButtonAck()) + self.button_request(msg) + return session._read() + def button_request(self, br: messages.ButtonRequest) -> None: self.debuglink.snapshot_legacy() @@ -878,12 +884,12 @@ def button_request(self, br: messages.ButtonRequest) -> None: raise AssertionError("input flow ended prematurely") else: try: - assert isinstance(self.input_flow, Generator) + assert isinstance(self.input_flow, t.Generator) self.input_flow.send(br) except StopIteration: self.input_flow = self.INPUT_FLOW_DONE - def get_pin(self, code: Optional["PinMatrixRequestType"] = None) -> str: + def get_pin(self, code: "PinMatrixRequestType|None" = None) -> str: self.debuglink.snapshot_legacy() if self.pins is None: @@ -900,12 +906,15 @@ def get_passphrase(self, available_on_device: bool) -> str: class MessageFilter: - def __init__(self, message_type: Type[protobuf.MessageType], **fields: Any) -> None: + + def __init__( + self, message_type: t.Type[protobuf.MessageType], **fields: t.Any + ) -> None: self.message_type = message_type - self.fields: Dict[str, Any] = {} + self.fields: t.Dict[str, t.Any] = {} self.update_fields(**fields) - def update_fields(self, **fields: Any) -> "MessageFilter": + def update_fields(self, **fields: t.Any) -> "MessageFilter": for name, value in fields.items(): try: self.fields[name] = self.from_message_or_type(value) @@ -953,7 +962,7 @@ def match(self, message: protobuf.MessageType) -> bool: return True def to_string(self, maxwidth: int = 80) -> str: - fields: List[Tuple[str, str]] = [] + fields: t.List[t.Tuple[str, str]] = [] for field in self.message_type.FIELDS.values(): if field.name not in self.fields: continue @@ -974,7 +983,7 @@ def to_string(self, maxwidth: int = 80) -> str: if len(oneline_str) < maxwidth: return f"{self.message_type.__name__}({oneline_str})" else: - item: List[str] = [] + item: t.List[str] = [] item.append(f"{self.message_type.__name__}(") for pair in pairs: item.append(f" {pair}") @@ -983,7 +992,8 @@ def to_string(self, maxwidth: int = 80) -> str: class MessageFilterGenerator: - def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]: + + def __getattr__(self, key: str) -> t.Callable[..., "MessageFilter"]: message_type = getattr(messages, key) return MessageFilter(message_type).update_fields @@ -991,6 +1001,227 @@ def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]: message_filters = MessageFilterGenerator() +class SessionDebugWrapper(Session): + def __init__(self, session: Session) -> None: + self._session = session + self.reset_debug_features() + + @property + def client(self) -> TrezorClientDebugLink: + assert isinstance(self._session.client, TrezorClientDebugLink) + return self._session.client + + @property + def id(self) -> bytes: + return self._session.id + + def _write(self, msg: t.Any) -> None: + print("writing message:", type(msg)) + self._session._write(msg) + + def _read(self) -> t.Any: + resp = self._session._read() + print("reading message:", type(resp)) + if self.actual_responses is not None: + self.actual_responses.append(resp) + return resp + + def set_expected_responses( + self, + expected: t.List["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]], + ) -> None: + """Set a sequence of expected responses to session calls. + + Within a given with-block, the list of received responses from device must + match the list of expected responses, otherwise an ``AssertionError`` is raised. + + If an expected response is given a field value other than ``None``, that field value + must exactly match the received field value. If a given field is ``None`` + (or unspecified) in the expected response, the received field value is not + checked. + + Each expected response can also be a tuple ``(bool, message)``. In that case, the + expected response is only evaluated if the first field is ``True``. + This is useful for differentiating sequences between Trezor models: + + >>> trezor_one = session.features.model == "1" + >>> session.set_expected_responses([ + >>> messages.ButtonRequest(code=ConfirmOutput), + >>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)), + >>> messages.Success(), + >>> ]) + """ + if not self.in_with_statement: + raise RuntimeError("Must be called inside 'with' statement") + + # make sure all items are (bool, message) tuples + expected_with_validity = ( + e if isinstance(e, tuple) else (True, e) for e in expected + ) + + # only apply those items that are (True, message) + self.expected_responses = [ + MessageFilter.from_message_or_type(expected) + for valid, expected in expected_with_validity + if valid + ] + self.actual_responses = [] + + def lock(self, *, _refresh_features: bool = True) -> None: + """Lock the device. + + If the device does not have a PIN configured, this will do nothing. + Otherwise, a lock screen will be shown and the device will prompt for PIN + before further actions. + + This call does _not_ invalidate passphrase cache. If passphrase is in use, + the device will not prompt for it after unlocking. + + To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate + passphrase cache, use `clear_session()`. + """ + # TODO update the documentation above + # Private argument _refresh_features can be used internally to avoid + # refreshing in cases where we will refresh soon anyway. This is used + # in TrezorClient.clear_session() + self.call(messages.LockDevice()) + if _refresh_features: + self.refresh_features() + + def cancel(self) -> None: + self._write(messages.Cancel()) + + def set_filter( + self, + message_type: t.Type[protobuf.MessageType], + callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, + ) -> None: + """Configure a filter function for a specified message type. + + The `callback` must be a function that accepts a protobuf message, and returns + a (possibly modified) protobuf message of the same type. Whenever a message + is sent or received that matches `message_type`, `callback` is invoked on the + message and its result is substituted for the original. + + Useful for test scenarios with an active malicious actor on the wire. + """ + if not self.in_with_statement: + raise RuntimeError("Must be called inside 'with' statement") + + self.filters[message_type] = callback + + def reset_debug_features(self) -> None: + """Prepare the debugging session for a new testcase. + + Clears all debugging state that might have been modified by a testcase. + """ + self.in_with_statement = False + self.expected_responses: t.List[MessageFilter] | None = None + self.actual_responses: t.List[protobuf.MessageType] | None = None + self.filters: t.Dict[ + t.Type[protobuf.MessageType], + t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, + ] = {} + self.button_callback = self.client.button_callback + self.pin_callback = self.client.pin_callback + + def __enter__(self) -> "SessionDebugWrapper": + # For usage in with/expected_responses + if self.in_with_statement: + raise RuntimeError("Do not nest!") + self.in_with_statement = True + return self + + def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + + # copy expected/actual responses before clearing them + expected_responses = self.expected_responses + actual_responses = self.actual_responses + + # grab a copy of the inputflow generator to raise an exception through it + if isinstance(self.client.ui, DebugUI): + input_flow = self.client.ui.input_flow + else: + input_flow = None + + self.reset_debug_features() + + if exc_type is None: + # If no other exception was raised, evaluate missed responses + # (raises AssertionError on mismatch) + self._verify_responses(expected_responses, actual_responses) + if isinstance(input_flow, t.Generator): + # Ensure that the input flow is exhausted + try: + input_flow.throw( + AssertionError("input flow continues past end of test") + ) + except StopIteration: + pass + + elif isinstance(input_flow, t.Generator): + # Propagate the exception through the input flow, so that we see in + # traceback where it is stuck. + input_flow.throw(exc_type, value, traceback) + + @classmethod + def _verify_responses( + cls, + expected: t.List[MessageFilter] | None, + actual: t.List[protobuf.MessageType] | None, + ) -> None: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + + if expected is None and actual is None: + return + + assert expected is not None + assert actual is not None + + for i, (exp, act) in enumerate(zip_longest(expected, actual)): + if exp is None: + output = cls._expectation_lines(expected, i) + output.append("No more messages were expected, but we got:") + for resp in actual[i:]: + output.append( + textwrap.indent(protobuf.format_message(resp), " ") + ) + raise AssertionError("\n".join(output)) + + if act is None: + output = cls._expectation_lines(expected, i) + output.append("This and the following message was not received.") + raise AssertionError("\n".join(output)) + + if not exp.match(act): + output = cls._expectation_lines(expected, i) + output.append("Actually received:") + output.append(textwrap.indent(protobuf.format_message(act), " ")) + raise AssertionError("\n".join(output)) + + @staticmethod + def _expectation_lines( + expected: t.List[MessageFilter], current: int + ) -> t.List[str]: + start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) + stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) + output: t.List[str] = [] + output.append("Expected responses:") + if start_at > 0: + output.append(f" (...{start_at} previous responses omitted)") + for i in range(start_at, stop_at): + exp = expected[i] + prefix = " " if i != current else ">>> " + output.append(textwrap.indent(exp.to_string(), prefix)) + if stop_at < len(expected): + omitted = len(expected) - stop_at + output.append(f" (...{omitted} following responses omitted)") + + output.append("") + return output + + class TrezorClientDebugLink(TrezorClient): # This class implements automatic responses # and other functionality for unit tests @@ -1016,16 +1247,18 @@ def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: raise # set transport explicitly so that sync_responses can work + super().__init__(transport) + self.transport = transport + self.ui: DebugUI = DebugUI(self.debug) self.reset_debug_features() self.sync_responses() - super().__init__(transport, ui=self.ui) - # So that we can choose right screenshotting logic (T1 vs TT) # and know the supported debug capabilities self.debug.model = self.model self.debug.version = self.version + self.passphrase: str | None = None @property def layout_type(self) -> LayoutType: @@ -1037,33 +1270,41 @@ def reset_debug_features(self) -> None: Clears all debugging state that might have been modified by a testcase. """ self.ui: DebugUI = DebugUI(self.debug) + # self.pin_callback = self.ui.debug_callback_button self.in_with_statement = False - self.expected_responses: Optional[List[MessageFilter]] = None - self.actual_responses: Optional[List[protobuf.MessageType]] = None - self.filters: Dict[ - Type[protobuf.MessageType], - Optional[Callable[[protobuf.MessageType], protobuf.MessageType]], + self.expected_responses: t.List[MessageFilter] | None = None + self.actual_responses: t.List[protobuf.MessageType] | None = None + self.filters: t.Dict[ + t.Type[protobuf.MessageType], + t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, ] = {} + self._management_session = self.get_management_session(new_session=True) + + @property + def button_callback(self): + return self.ui.debug_callback_button + def ensure_open(self) -> None: """Only open session if there isn't already an open one.""" - if self.session_counter == 0: - self.open() + # if self.session_counter == 0: + # self.open() + # TODO check if is this needed def open(self) -> None: - super().open() - if self.session_counter == 1: - self.debug.open() + pass + # TODO is this needed? + # self.debug.open() def close(self) -> None: - if self.session_counter == 1: - self.debug.close() - super().close() + pass + # TODO is this needed? + # self.debug.close() def set_filter( self, - message_type: Type[protobuf.MessageType], - callback: Optional[Callable[[protobuf.MessageType], protobuf.MessageType]], + message_type: t.Type[protobuf.MessageType], + callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, ) -> None: """Configure a filter function for a specified message type. @@ -1088,7 +1329,8 @@ def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType: return msg def set_input_flow( - self, input_flow: Generator[None, Optional[messages.ButtonRequest], None] + self, + input_flow: t.Generator[None, messages.ButtonRequest | None, None], ) -> None: """Configure a sequence of input events for the current with-block. @@ -1122,6 +1364,7 @@ def set_input_flow( if not hasattr(input_flow, "send"): raise RuntimeError("input_flow should be a generator function") self.ui.input_flow = input_flow + assert input_flow is not None input_flow.send(None) # start the generator def watch_layout(self, watch: bool = True) -> None: @@ -1144,7 +1387,7 @@ def __enter__(self) -> "TrezorClientDebugLink": self.in_with_statement = True return self - def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None: + def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 # copy expected/actual responses before clearing them @@ -1164,13 +1407,14 @@ def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None: # (raises AssertionError on mismatch) self._verify_responses(expected_responses, actual_responses) - elif isinstance(input_flow, Generator): + elif isinstance(input_flow, t.Generator): # Propagate the exception through the input flow, so that we see in # traceback where it is stuck. input_flow.throw(exc_type, value, traceback) def set_expected_responses( - self, expected: List[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]] + self, + expected: t.List["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]], ) -> None: """Set a sequence of expected responses to client calls. @@ -1209,7 +1453,7 @@ def set_expected_responses( ] self.actual_responses = [] - def use_pin_sequence(self, pins: Iterable[str]) -> None: + def use_pin_sequence(self, pins: t.Iterable[str]) -> None: """Respond to PIN prompts from device with the provided PINs. The sequence must be at least as long as the expected number of PIN prompts. """ @@ -1217,6 +1461,7 @@ def use_pin_sequence(self, pins: Iterable[str]) -> None: def use_passphrase(self, passphrase: str) -> None: """Respond to passphrase prompts from device with the provided passphrase.""" + self.passphrase = passphrase self.ui.passphrase = Mnemonic.normalize_string(passphrase) def use_mnemonic(self, mnemonic: str) -> None: @@ -1226,21 +1471,22 @@ def use_mnemonic(self, mnemonic: str) -> None: def _raw_read(self) -> protobuf.MessageType: __tracebackhide__ = True # for pytest # pylint: disable=W0612 - - resp = super()._raw_read() + resp = self.get_management_session()._read() resp = self._filter_message(resp) if self.actual_responses is not None: self.actual_responses.append(resp) return resp def _raw_write(self, msg: protobuf.MessageType) -> None: - return super()._raw_write(self._filter_message(msg)) + return self.get_management_session()._write(self._filter_message(msg)) @staticmethod - def _expectation_lines(expected: List[MessageFilter], current: int) -> List[str]: + def _expectation_lines( + expected: t.List[MessageFilter], current: int + ) -> t.List[str]: start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) - output: List[str] = [] + output: t.List[str] = [] output.append("Expected responses:") if start_at > 0: output.append(f" (...{start_at} previous responses omitted)") @@ -1258,8 +1504,8 @@ def _expectation_lines(expected: List[MessageFilter], current: int) -> List[str] @classmethod def _verify_responses( cls, - expected: Optional[List[MessageFilter]], - actual: Optional[List[protobuf.MessageType]], + expected: t.List[MessageFilter] | None, + actual: t.List[protobuf.MessageType] | None, ) -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 @@ -1304,23 +1550,25 @@ def sync_responses(self) -> None: # Start by canceling whatever is on screen. This will work to cancel T1 PIN # prompt, which is in TINY mode and does not respond to `Ping`. - cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel()) - self.transport.begin_session() + # TODO REMOVE: cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel()) + self.transport.open() try: - self.transport.write(*cancel_msg) - + # self.protocol.write(messages.Cancel()) message = "SYNC" + secrets.token_hex(8) - ping_msg = mapping.DEFAULT_MAPPING.encode(messages.Ping(message=message)) - self.transport.write(*ping_msg) + self.get_management_session()._write(messages.Ping(message=message)) resp = None while resp != messages.Success(message=message): - msg_id, msg_bytes = self.transport.read() try: - resp = mapping.DEFAULT_MAPPING.decode(msg_id, msg_bytes) + resp = self.get_management_session()._read() + + raise Exception + except Exception: pass + finally: - self.transport.end_session() + pass # TODO fix + # self.transport.end_session(self.session_id or b"") def mnemonic_callback(self, _) -> str: word, pos = self.debug.read_recovery_word() @@ -1334,11 +1582,11 @@ def mnemonic_callback(self, _) -> str: @expect(messages.Success, field="message", ret_type=str) def load_device( - client: "TrezorClient", - mnemonic: Union[str, Iterable[str]], - pin: Optional[str], + session: "Session", + mnemonic: str | t.Iterable[str], + pin: str | None, passphrase_protection: bool, - label: Optional[str], + label: str | None, skip_checksum: bool = False, needs_backup: bool = False, no_backup: bool = False, @@ -1348,12 +1596,12 @@ def load_device( mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic] - if client.features.initialized: + if session.features.initialized: raise RuntimeError( "Device is initialized already. Call device.wipe() and try again." ) - resp = client.call( + resp = session.call( messages.LoadDevice( mnemonics=mnemonics, pin=pin, @@ -1364,7 +1612,7 @@ def load_device( no_backup=no_backup, ) ) - client.init_device() + session.refresh_features() return resp @@ -1373,11 +1621,11 @@ def load_device( @expect(messages.Success, field="message", ret_type=str) -def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType: - if client.features.bootloader_mode is not True: +def prodtest_t1(session: "Session") -> protobuf.MessageType: + if session.features.bootloader_mode is not True: raise RuntimeError("Device must be in bootloader mode") - return client.call( + return session.call( messages.ProdTestT1( payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC" ) @@ -1386,8 +1634,8 @@ def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType: def record_screen( debug_client: "TrezorClientDebugLink", - directory: Union[str, None], - report_func: Union[Callable[[str], None], None] = None, + directory: str | None, + report_func: t.Callable[[str], None] | None = None, ) -> None: """Record screen changes into a specified directory. @@ -1433,5 +1681,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool: @expect(messages.Success, field="message", ret_type=str) -def optiga_set_sec_max(client: "TrezorClient") -> protobuf.MessageType: - return client.call(messages.DebugLinkOptigaSetSecMax()) +def optiga_set_sec_max(session: "Session") -> protobuf.MessageType: + return session.call(messages.DebugLinkOptigaSetSecMax()) diff --git a/python/src/trezorlib/device.py b/python/src/trezorlib/device.py index 799168d618c..b54cbfc0fa9 100644 --- a/python/src/trezorlib/device.py +++ b/python/src/trezorlib/device.py @@ -23,20 +23,19 @@ from . import messages from .exceptions import Cancelled, TrezorException -from .tools import Address, expect, session +from .tools import Address, expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType + from .transport.session import Session RECOVERY_BACK = "\x08" # backspace character, sent literally @expect(messages.Success, field="message", ret_type=str) -@session def apply_settings( - client: "TrezorClient", + session: "Session", label: Optional[str] = None, language: Optional[str] = None, use_passphrase: Optional[bool] = None, @@ -67,13 +66,13 @@ def apply_settings( haptic_feedback=haptic_feedback, ) - out = client.call(settings) - client.refresh_features() + out = session.call(settings) + session.refresh_features() return out def _send_language_data( - client: "TrezorClient", + session: "Session", request: "messages.TranslationDataRequest", language_data: bytes, ) -> "MessageType": @@ -83,76 +82,70 @@ def _send_language_data( data_length = response.data_length data_offset = response.data_offset chunk = language_data[data_offset : data_offset + data_length] - response = client.call(messages.TranslationDataAck(data_chunk=chunk)) + response = session.call(messages.TranslationDataAck(data_chunk=chunk)) return response @expect(messages.Success, field="message", ret_type=str) -@session def change_language( - client: "TrezorClient", + session: "Session", language_data: bytes, show_display: bool | None = None, ) -> "MessageType": data_length = len(language_data) msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display) - response = client.call(msg) + response = session.call(msg) if data_length > 0: assert isinstance(response, messages.TranslationDataRequest) - response = _send_language_data(client, response, language_data) + response = _send_language_data(session, response, language_data) assert isinstance(response, messages.Success) - client.refresh_features() # changing the language in features + session.refresh_features() # changing the language in features return response @expect(messages.Success, field="message", ret_type=str) -@session -def apply_flags(client: "TrezorClient", flags: int) -> "MessageType": - out = client.call(messages.ApplyFlags(flags=flags)) - client.refresh_features() +def apply_flags(session: "Session", flags: int) -> "MessageType": + out = session.call(messages.ApplyFlags(flags=flags)) + session.refresh_features() return out @expect(messages.Success, field="message", ret_type=str) -@session -def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType": - ret = client.call(messages.ChangePin(remove=remove)) - client.refresh_features() +def change_pin(session: "Session", remove: bool = False) -> "MessageType": + ret = session.call(messages.ChangePin(remove=remove)) + session.refresh_features() return ret @expect(messages.Success, field="message", ret_type=str) -@session -def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType": - ret = client.call(messages.ChangeWipeCode(remove=remove)) - client.refresh_features() +def change_wipe_code(session: "Session", remove: bool = False) -> "MessageType": + ret = session.call(messages.ChangeWipeCode(remove=remove)) + session.refresh_features() return ret @expect(messages.Success, field="message", ret_type=str) -@session def sd_protect( - client: "TrezorClient", operation: messages.SdProtectOperationType + session: "Session", operation: messages.SdProtectOperationType ) -> "MessageType": - ret = client.call(messages.SdProtect(operation=operation)) - client.refresh_features() + ret = session.call(messages.SdProtect(operation=operation)) + session.refresh_features() return ret @expect(messages.Success, field="message", ret_type=str) -@session -def wipe(client: "TrezorClient") -> "MessageType": - ret = client.call(messages.WipeDevice()) - if not client.features.bootloader_mode: - client.init_device() +def wipe(session: "Session") -> "MessageType": + + ret = session.call(messages.WipeDevice()) + # if not session.features.bootloader_mode: + # session.refresh_features() return ret -@session def recover( - client: "TrezorClient", + session: "Session", word_count: int = 24, passphrase_protection: bool = False, pin_protection: bool = True, @@ -188,13 +181,13 @@ def recover( if type is None: type = messages.RecoveryType.NormalRecovery - if client.features.model == "1" and input_callback is None: + if session.features.model == "1" and input_callback is None: raise RuntimeError("Input callback required for Trezor One") if word_count not in (12, 18, 24): raise ValueError("Invalid word count. Use 12/18/24") - if client.features.initialized and type == messages.RecoveryType.NormalRecovery: + if session.features.initialized and type == messages.RecoveryType.NormalRecovery: raise RuntimeError( "Device already initialized. Call device.wipe() and try again." ) @@ -216,24 +209,23 @@ def recover( msg.label = label msg.u2f_counter = u2f_counter - res = client.call(msg) + res = session.call(msg) while isinstance(res, messages.WordRequest): try: assert input_callback is not None inp = input_callback(res.type) - res = client.call(messages.WordAck(word=inp)) + res = session.call(messages.WordAck(word=inp)) except Cancelled: - res = client.call(messages.Cancel()) + res = session.call(messages.Cancel()) - client.init_device() + session.refresh_features() return res @expect(messages.Success, field="message", ret_type=str) -@session def reset( - client: "TrezorClient", + session: "Session", display_random: bool = False, strength: Optional[int] = None, passphrase_protection: bool = False, @@ -257,13 +249,13 @@ def reset( DeprecationWarning, ) - if client.features.initialized: + if session.features.initialized: raise RuntimeError( "Device is initialized already. Call wipe_device() and try again." ) if strength is None: - if client.features.model == "1": + if session.features.model == "1": strength = 256 else: strength = 128 @@ -280,25 +272,24 @@ def reset( backup_type=backup_type, ) - resp = client.call(msg) + resp = session.call(msg) if not isinstance(resp, messages.EntropyRequest): raise RuntimeError("Invalid response, expected EntropyRequest") external_entropy = os.urandom(32) # LOG.debug("Computer generated entropy: " + external_entropy.hex()) - ret = client.call(messages.EntropyAck(entropy=external_entropy)) - client.init_device() + ret = session.call(messages.EntropyAck(entropy=external_entropy)) + session.refresh_features() # TODO is necessary? return ret @expect(messages.Success, field="message", ret_type=str) -@session def backup( - client: "TrezorClient", + session: "Session", group_threshold: Optional[int] = None, groups: Iterable[tuple[int, int]] = (), ) -> "MessageType": - ret = client.call( + ret = session.call( messages.BackupDevice( group_threshold=group_threshold, groups=[ @@ -307,37 +298,36 @@ def backup( ], ) ) - client.refresh_features() + session.refresh_features() return ret @expect(messages.Success, field="message", ret_type=str) -def cancel_authorization(client: "TrezorClient") -> "MessageType": - return client.call(messages.CancelAuthorization()) +def cancel_authorization(session: "Session") -> "MessageType": + return session.call(messages.CancelAuthorization()) @expect(messages.UnlockedPathRequest, field="mac", ret_type=bytes) -def unlock_path(client: "TrezorClient", n: "Address") -> "MessageType": - resp = client.call(messages.UnlockPath(address_n=n)) +def unlock_path(session: "Session", n: "Address") -> "MessageType": + resp = session.call(messages.UnlockPath(address_n=n)) # Cancel the UnlockPath workflow now that we have the authentication code. try: - client.call(messages.Cancel()) + session.call(messages.Cancel()) except Cancelled: return resp else: raise TrezorException("Unexpected response in UnlockPath flow") -@session @expect(messages.Success, field="message", ret_type=str) def reboot_to_bootloader( - client: "TrezorClient", + session: "Session", boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT, firmware_header: Optional[bytes] = None, language_data: bytes = b"", ) -> "MessageType": - response = client.call( + response = session.call( messages.RebootToBootloader( boot_command=boot_command, firmware_header=firmware_header, @@ -345,42 +335,37 @@ def reboot_to_bootloader( ) ) if isinstance(response, messages.TranslationDataRequest): - response = _send_language_data(client, response, language_data) + response = _send_language_data(session, response, language_data) return response -@session @expect(messages.Success, field="message", ret_type=str) -def show_device_tutorial(client: "TrezorClient") -> "MessageType": - return client.call(messages.ShowDeviceTutorial()) +def show_device_tutorial(session: "Session") -> "MessageType": + return session.call(messages.ShowDeviceTutorial()) -@session @expect(messages.Success, field="message", ret_type=str) -def unlock_bootloader(client: "TrezorClient") -> "MessageType": - return client.call(messages.UnlockBootloader()) +def unlock_bootloader(session: "Session") -> "MessageType": + return session.call(messages.UnlockBootloader()) @expect(messages.Success, field="message", ret_type=str) -@session -def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> "MessageType": +def set_busy(session: "Session", expiry_ms: Optional[int]) -> "MessageType": """Sets or clears the busy state of the device. In the busy state the device shows a "Do not disconnect" message instead of the homescreen. Setting `expiry_ms=None` clears the busy state. """ - ret = client.call(messages.SetBusy(expiry_ms=expiry_ms)) - client.refresh_features() + ret = session.call(messages.SetBusy(expiry_ms=expiry_ms)) + session.refresh_features() return ret @expect(messages.AuthenticityProof) -def authenticate(client: "TrezorClient", challenge: bytes): - return client.call(messages.AuthenticateDevice(challenge=challenge)) +def authenticate(session: "Session", challenge: bytes): + return session.call(messages.AuthenticateDevice(challenge=challenge)) @expect(messages.Success, field="message", ret_type=str) -def set_brightness( - client: "TrezorClient", value: Optional[int] = None -) -> "MessageType": - return client.call(messages.SetBrightness(value=value)) +def set_brightness(session: "Session", value: Optional[int] = None) -> "MessageType": + return session.call(messages.SetBrightness(value=value)) diff --git a/python/src/trezorlib/eos.py b/python/src/trezorlib/eos.py index 1ffaafb4ab7..fffe6f0adc2 100644 --- a/python/src/trezorlib/eos.py +++ b/python/src/trezorlib/eos.py @@ -18,12 +18,12 @@ from typing import TYPE_CHECKING, List, Tuple from . import exceptions, messages -from .tools import b58decode, expect, session +from .tools import b58decode, expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session def name_to_number(name: str) -> int: @@ -321,17 +321,16 @@ def parse_transaction_json( @expect(messages.EosPublicKey) def get_public_key( - client: "TrezorClient", n: "Address", show_display: bool = False + session: "Session", n: "Address", show_display: bool = False ) -> "MessageType": - response = client.call( + response = session.call( messages.EosGetPublicKey(address_n=n, show_display=show_display) ) return response -@session def sign_tx( - client: "TrezorClient", + session: "Session", address: "Address", transaction: dict, chain_id: str, @@ -347,11 +346,11 @@ def sign_tx( chunkify=chunkify, ) - response = client.call(msg) + response = session.call(msg) try: while isinstance(response, messages.EosTxActionRequest): - response = client.call(actions.pop(0)) + response = session.call(actions.pop(0)) except IndexError: # pop from empty list raise exceptions.TrezorException( diff --git a/python/src/trezorlib/ethereum.py b/python/src/trezorlib/ethereum.py index 1cf2eeeaed1..60eaa3366ba 100644 --- a/python/src/trezorlib/ethereum.py +++ b/python/src/trezorlib/ethereum.py @@ -18,12 +18,12 @@ from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple from . import definitions, exceptions, messages -from .tools import expect, prepare_message_bytes, session, unharden +from .tools import expect, prepare_message_bytes, unharden if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session def int_to_big_endian(value: int) -> bytes: @@ -163,13 +163,13 @@ def network_from_address_n( @expect(messages.EthereumAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", n: "Address", show_display: bool = False, encoded_network: Optional[bytes] = None, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.EthereumGetAddress( address_n=n, show_display=show_display, @@ -181,16 +181,15 @@ def get_address( @expect(messages.EthereumPublicKey) def get_public_node( - client: "TrezorClient", n: "Address", show_display: bool = False + session: "Session", n: "Address", show_display: bool = False ) -> "MessageType": - return client.call( + return session.call( messages.EthereumGetPublicKey(address_n=n, show_display=show_display) ) -@session def sign_tx( - client: "TrezorClient", + session: "Session", n: "Address", nonce: int, gas_price: int, @@ -226,13 +225,13 @@ def sign_tx( data, chunk = data[1024:], data[:1024] msg.data_initial_chunk = chunk - response = client.call(msg) + response = session.call(msg) assert isinstance(response, messages.EthereumTxRequest) while response.data_length is not None: data_length = response.data_length data, chunk = data[data_length:], data[:data_length] - response = client.call(messages.EthereumTxAck(data_chunk=chunk)) + response = session.call(messages.EthereumTxAck(data_chunk=chunk)) assert isinstance(response, messages.EthereumTxRequest) assert response.signature_v is not None @@ -247,9 +246,8 @@ def sign_tx( return response.signature_v, response.signature_r, response.signature_s -@session def sign_tx_eip1559( - client: "TrezorClient", + session: "Session", n: "Address", *, nonce: int, @@ -282,13 +280,13 @@ def sign_tx_eip1559( chunkify=chunkify, ) - response = client.call(msg) + response = session.call(msg) assert isinstance(response, messages.EthereumTxRequest) while response.data_length is not None: data_length = response.data_length data, chunk = data[data_length:], data[:data_length] - response = client.call(messages.EthereumTxAck(data_chunk=chunk)) + response = session.call(messages.EthereumTxAck(data_chunk=chunk)) assert isinstance(response, messages.EthereumTxRequest) assert response.signature_v is not None @@ -299,13 +297,13 @@ def sign_tx_eip1559( @expect(messages.EthereumMessageSignature) def sign_message( - client: "TrezorClient", + session: "Session", n: "Address", message: AnyStr, encoded_network: Optional[bytes] = None, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.EthereumSignMessage( address_n=n, message=prepare_message_bytes(message), @@ -317,7 +315,7 @@ def sign_message( @expect(messages.EthereumTypedDataSignature) def sign_typed_data( - client: "TrezorClient", + session: "Session", n: "Address", data: Dict[str, Any], *, @@ -333,7 +331,7 @@ def sign_typed_data( metamask_v4_compat=metamask_v4_compat, definitions=definitions, ) - response = client.call(request) + response = session.call(request) # Sending all the types while isinstance(response, messages.EthereumTypedDataStructRequest): @@ -349,7 +347,7 @@ def sign_typed_data( members.append(struct_member) request = messages.EthereumTypedDataStructAck(members=members) - response = client.call(request) + response = session.call(request) # Sending the whole message that should be signed while isinstance(response, messages.EthereumTypedDataValueRequest): @@ -362,7 +360,7 @@ def sign_typed_data( member_typename = data["primaryType"] member_data = data["message"] else: - client.cancel() + # TODO session.cancel() raise exceptions.TrezorException("Root index can only be 0 or 1") # It can be asking for a nested structure (the member path being [X, Y, Z, ...]) @@ -385,20 +383,20 @@ def sign_typed_data( encoded_data = encode_data(member_data, member_typename) request = messages.EthereumTypedDataValueAck(value=encoded_data) - response = client.call(request) + response = session.call(request) return response def verify_message( - client: "TrezorClient", + session: "Session", address: str, signature: bytes, message: AnyStr, chunkify: bool = False, ) -> bool: try: - resp = client.call( + resp = session.call( messages.EthereumVerifyMessage( address=address, signature=signature, @@ -413,13 +411,13 @@ def verify_message( @expect(messages.EthereumTypedDataSignature) def sign_typed_data_hash( - client: "TrezorClient", + session: "Session", n: "Address", domain_hash: bytes, message_hash: Optional[bytes], encoded_network: Optional[bytes] = None, ) -> "MessageType": - return client.call( + return session.call( messages.EthereumSignTypedHash( address_n=n, domain_separator_hash=domain_hash, diff --git a/python/src/trezorlib/fido.py b/python/src/trezorlib/fido.py index 4ed6f22951f..90064bb238c 100644 --- a/python/src/trezorlib/fido.py +++ b/python/src/trezorlib/fido.py @@ -20,8 +20,8 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType + from .transport.session import Session @expect( @@ -29,27 +29,27 @@ field="credentials", ret_type=List[messages.WebAuthnCredential], ) -def list_credentials(client: "TrezorClient") -> "MessageType": - return client.call(messages.WebAuthnListResidentCredentials()) +def list_credentials(session: "Session") -> "MessageType": + return session.call(messages.WebAuthnListResidentCredentials()) @expect(messages.Success, field="message", ret_type=str) -def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType": - return client.call( +def add_credential(session: "Session", credential_id: bytes) -> "MessageType": + return session.call( messages.WebAuthnAddResidentCredential(credential_id=credential_id) ) @expect(messages.Success, field="message", ret_type=str) -def remove_credential(client: "TrezorClient", index: int) -> "MessageType": - return client.call(messages.WebAuthnRemoveResidentCredential(index=index)) +def remove_credential(session: "Session", index: int) -> "MessageType": + return session.call(messages.WebAuthnRemoveResidentCredential(index=index)) @expect(messages.Success, field="message", ret_type=str) -def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType": - return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter)) +def set_counter(session: "Session", u2f_counter: int) -> "MessageType": + return session.call(messages.SetU2FCounter(u2f_counter=u2f_counter)) @expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int) -def get_next_counter(client: "TrezorClient") -> "MessageType": - return client.call(messages.GetNextU2FCounter()) +def get_next_counter(session: "Session") -> "MessageType": + return session.call(messages.GetNextU2FCounter()) diff --git a/python/src/trezorlib/firmware/__init__.py b/python/src/trezorlib/firmware/__init__.py index 5cc5d8830cb..a588b160e1e 100644 --- a/python/src/trezorlib/firmware/__init__.py +++ b/python/src/trezorlib/firmware/__init__.py @@ -20,7 +20,7 @@ from typing_extensions import Protocol, TypeGuard from .. import messages -from ..tools import expect, session +from ..tools import expect from .core import VendorFirmware from .legacy import LegacyFirmware, LegacyV2Firmware @@ -38,7 +38,7 @@ from .vendor import * # noqa: F401, F403 if t.TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session T = t.TypeVar("T", bound="FirmwareType") @@ -72,20 +72,19 @@ def is_onev2(fw: "FirmwareType") -> TypeGuard[LegacyFirmware]: # ====== Client functions ====== # -@session def update( - client: "TrezorClient", + session: "Session", data: bytes, progress_update: t.Callable[[int], t.Any] = lambda _: None, ): - if client.features.bootloader_mode is False: + if session.features.bootloader_mode is False: raise RuntimeError("Device must be in bootloader mode") - resp = client.call(messages.FirmwareErase(length=len(data))) + resp = session.call(messages.FirmwareErase(length=len(data))) # TREZORv1 method if isinstance(resp, messages.Success): - resp = client.call(messages.FirmwareUpload(payload=data)) + resp = session.call(messages.FirmwareUpload(payload=data)) progress_update(len(data)) if isinstance(resp, messages.Success): return @@ -97,7 +96,7 @@ def update( length = resp.length payload = data[resp.offset : resp.offset + length] digest = blake2s(payload).digest() - resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest)) + resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest)) progress_update(length) if isinstance(resp, messages.Success): @@ -107,5 +106,5 @@ def update( @expect(messages.FirmwareHash, field="hash", ret_type=bytes) -def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]): - return client.call(messages.GetFirmwareHash(challenge=challenge)) +def get_hash(session: "Session", challenge: t.Optional[bytes]): + return session.call(messages.GetFirmwareHash(challenge=challenge)) diff --git a/python/src/trezorlib/mapping.py b/python/src/trezorlib/mapping.py index d50324d5868..b0bcd344a9c 100644 --- a/python/src/trezorlib/mapping.py +++ b/python/src/trezorlib/mapping.py @@ -17,6 +17,7 @@ from __future__ import annotations import io +import logging from types import ModuleType from typing import Dict, Optional, Tuple, Type, TypeVar @@ -25,6 +26,7 @@ from . import messages, protobuf T = TypeVar("T") +LOG = logging.getLogger(__name__) class ProtobufMapping: @@ -63,11 +65,21 @@ def encode(self, msg: protobuf.MessageType) -> Tuple[int, bytes]: wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE) if wire_type is None: raise ValueError("Cannot encode class without wire type") - + LOG.debug("encoding wire type %d", wire_type) buf = io.BytesIO() protobuf.dump_message(buf, msg) return wire_type, buf.getvalue() + def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes: + """Serialize a Python protobuf class. + + Returns the byte representation of the protobuf message. + """ + + buf = io.BytesIO() + protobuf.dump_message(buf, msg) + return buf.getvalue() + def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType: """Deserialize a protobuf message into a Python class.""" cls = self.type_to_class[msg_wire_type] diff --git a/python/src/trezorlib/messages.py b/python/src/trezorlib/messages.py index e178d7aacd4..2029bee9f38 100644 --- a/python/src/trezorlib/messages.py +++ b/python/src/trezorlib/messages.py @@ -270,6 +270,24 @@ class MessageType(IntEnum): SolanaAddress = 903 SolanaSignTx = 904 SolanaTxSignature = 905 + ThpCreateNewSession = 1000 + ThpNewSession = 1001 + ThpStartPairingRequest = 1008 + ThpPairingPreparationsFinished = 1009 + ThpCredentialRequest = 1010 + ThpCredentialResponse = 1011 + ThpEndRequest = 1012 + ThpEndResponse = 1013 + ThpCodeEntryCommitment = 1016 + ThpCodeEntryChallenge = 1017 + ThpCodeEntryCpaceHost = 1018 + ThpCodeEntryCpaceTrezor = 1019 + ThpCodeEntryTag = 1020 + ThpCodeEntrySecret = 1021 + ThpQrCodeTag = 1024 + ThpQrCodeSecret = 1025 + ThpNfcUnidirectionalTag = 1032 + ThpNfcUnidirectionalSecret = 1033 class FailureType(IntEnum): @@ -287,6 +305,8 @@ class FailureType(IntEnum): PinMismatch = 12 WipeCodeMismatch = 13 InvalidSession = 14 + ThpUnallocatedSession = 15 + InvalidProtocol = 16 FirmwareError = 99 @@ -632,6 +652,13 @@ class TezosBallotType(IntEnum): Pass = 2 +class ThpPairingMethod(IntEnum): + NoMethod = 1 + CodeEntry = 2 + QrCode = 3 + NFC_Unidirectional = 4 + + class BinanceGetAddress(protobuf.MessageType): MESSAGE_WIRE_TYPE = 700 FIELDS = { @@ -4032,6 +4059,7 @@ class DebugLinkGetState(protobuf.MessageType): 1: protobuf.Field("wait_word_list", "bool", repeated=False, required=False, default=None), 2: protobuf.Field("wait_word_pos", "bool", repeated=False, required=False, default=None), 3: protobuf.Field("wait_layout", "DebugWaitType", repeated=False, required=False, default=DebugWaitType.IMMEDIATE), + 4: protobuf.Field("thp_channel_id", "bytes", repeated=False, required=False, default=None), } def __init__( @@ -4040,10 +4068,12 @@ def __init__( wait_word_list: Optional["bool"] = None, wait_word_pos: Optional["bool"] = None, wait_layout: Optional["DebugWaitType"] = DebugWaitType.IMMEDIATE, + thp_channel_id: Optional["bytes"] = None, ) -> None: self.wait_word_list = wait_word_list self.wait_word_pos = wait_word_pos self.wait_layout = wait_layout + self.thp_channel_id = thp_channel_id class DebugLinkState(protobuf.MessageType): @@ -4062,6 +4092,9 @@ class DebugLinkState(protobuf.MessageType): 11: protobuf.Field("reset_word_pos", "uint32", repeated=False, required=False, default=None), 12: protobuf.Field("mnemonic_type", "BackupType", repeated=False, required=False, default=None), 13: protobuf.Field("tokens", "string", repeated=True, required=False, default=None), + 14: protobuf.Field("thp_pairing_code_entry_code", "uint32", repeated=False, required=False, default=None), + 15: protobuf.Field("thp_pairing_code_qr_code", "bytes", repeated=False, required=False, default=None), + 16: protobuf.Field("thp_pairing_code_nfc_unidirectional", "bytes", repeated=False, required=False, default=None), } def __init__( @@ -4080,6 +4113,9 @@ def __init__( recovery_word_pos: Optional["int"] = None, reset_word_pos: Optional["int"] = None, mnemonic_type: Optional["BackupType"] = None, + thp_pairing_code_entry_code: Optional["int"] = None, + thp_pairing_code_qr_code: Optional["bytes"] = None, + thp_pairing_code_nfc_unidirectional: Optional["bytes"] = None, ) -> None: self.tokens: Sequence["str"] = tokens if tokens is not None else [] self.layout = layout @@ -4094,6 +4130,9 @@ def __init__( self.recovery_word_pos = recovery_word_pos self.reset_word_pos = reset_word_pos self.mnemonic_type = mnemonic_type + self.thp_pairing_code_entry_code = thp_pairing_code_entry_code + self.thp_pairing_code_qr_code = thp_pairing_code_qr_code + self.thp_pairing_code_nfc_unidirectional = thp_pairing_code_nfc_unidirectional class DebugLinkStop(protobuf.MessageType): @@ -7756,6 +7795,280 @@ def __init__( self.amount = amount +class ThpDeviceProperties(protobuf.MessageType): + MESSAGE_WIRE_TYPE = None + FIELDS = { + 1: protobuf.Field("internal_model", "string", repeated=False, required=False, default=None), + 2: protobuf.Field("model_variant", "uint32", repeated=False, required=False, default=None), + 3: protobuf.Field("bootloader_mode", "bool", repeated=False, required=False, default=None), + 4: protobuf.Field("protocol_version", "uint32", repeated=False, required=False, default=None), + 5: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None), + } + + def __init__( + self, + *, + pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None, + internal_model: Optional["str"] = None, + model_variant: Optional["int"] = None, + bootloader_mode: Optional["bool"] = None, + protocol_version: Optional["int"] = None, + ) -> None: + self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else [] + self.internal_model = internal_model + self.model_variant = model_variant + self.bootloader_mode = bootloader_mode + self.protocol_version = protocol_version + + +class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType): + MESSAGE_WIRE_TYPE = None + FIELDS = { + 1: protobuf.Field("host_pairing_credential", "bytes", repeated=False, required=False, default=None), + 2: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None), + } + + def __init__( + self, + *, + pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None, + host_pairing_credential: Optional["bytes"] = None, + ) -> None: + self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else [] + self.host_pairing_credential = host_pairing_credential + + +class ThpCreateNewSession(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1000 + FIELDS = { + 1: protobuf.Field("passphrase", "string", repeated=False, required=False, default=None), + 2: protobuf.Field("on_device", "bool", repeated=False, required=False, default=None), + 3: protobuf.Field("derive_cardano", "bool", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + passphrase: Optional["str"] = None, + on_device: Optional["bool"] = None, + derive_cardano: Optional["bool"] = None, + ) -> None: + self.passphrase = passphrase + self.on_device = on_device + self.derive_cardano = derive_cardano + + +class ThpNewSession(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1001 + FIELDS = { + 1: protobuf.Field("new_session_id", "uint32", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + new_session_id: Optional["int"] = None, + ) -> None: + self.new_session_id = new_session_id + + +class ThpStartPairingRequest(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1008 + FIELDS = { + 1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + host_name: Optional["str"] = None, + ) -> None: + self.host_name = host_name + + +class ThpPairingPreparationsFinished(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1009 + + +class ThpCodeEntryCommitment(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1016 + FIELDS = { + 1: protobuf.Field("commitment", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + commitment: Optional["bytes"] = None, + ) -> None: + self.commitment = commitment + + +class ThpCodeEntryChallenge(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1017 + FIELDS = { + 1: protobuf.Field("challenge", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + challenge: Optional["bytes"] = None, + ) -> None: + self.challenge = challenge + + +class ThpCodeEntryCpaceHost(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1018 + FIELDS = { + 1: protobuf.Field("cpace_host_public_key", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + cpace_host_public_key: Optional["bytes"] = None, + ) -> None: + self.cpace_host_public_key = cpace_host_public_key + + +class ThpCodeEntryCpaceTrezor(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1019 + FIELDS = { + 1: protobuf.Field("cpace_trezor_public_key", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + cpace_trezor_public_key: Optional["bytes"] = None, + ) -> None: + self.cpace_trezor_public_key = cpace_trezor_public_key + + +class ThpCodeEntryTag(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1020 + FIELDS = { + 2: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + tag: Optional["bytes"] = None, + ) -> None: + self.tag = tag + + +class ThpCodeEntrySecret(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1021 + FIELDS = { + 1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + secret: Optional["bytes"] = None, + ) -> None: + self.secret = secret + + +class ThpQrCodeTag(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1024 + FIELDS = { + 1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + tag: Optional["bytes"] = None, + ) -> None: + self.tag = tag + + +class ThpQrCodeSecret(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1025 + FIELDS = { + 1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + secret: Optional["bytes"] = None, + ) -> None: + self.secret = secret + + +class ThpNfcUnidirectionalTag(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1032 + FIELDS = { + 1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + tag: Optional["bytes"] = None, + ) -> None: + self.tag = tag + + +class ThpNfcUnidirectionalSecret(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1033 + FIELDS = { + 1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + secret: Optional["bytes"] = None, + ) -> None: + self.secret = secret + + +class ThpCredentialRequest(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1010 + FIELDS = { + 1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + host_static_pubkey: Optional["bytes"] = None, + ) -> None: + self.host_static_pubkey = host_static_pubkey + + +class ThpCredentialResponse(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1011 + FIELDS = { + 1: protobuf.Field("trezor_static_pubkey", "bytes", repeated=False, required=False, default=None), + 2: protobuf.Field("credential", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + trezor_static_pubkey: Optional["bytes"] = None, + credential: Optional["bytes"] = None, + ) -> None: + self.trezor_static_pubkey = trezor_static_pubkey + self.credential = credential + + +class ThpEndRequest(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1012 + + +class ThpEndResponse(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1013 + + class ThpCredentialMetadata(protobuf.MessageType): MESSAGE_WIRE_TYPE = None FIELDS = { diff --git a/python/src/trezorlib/misc.py b/python/src/trezorlib/misc.py index 4ed6f5aa81c..d951c52d7cd 100644 --- a/python/src/trezorlib/misc.py +++ b/python/src/trezorlib/misc.py @@ -20,25 +20,25 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session @expect(messages.Entropy, field="entropy", ret_type=bytes) -def get_entropy(client: "TrezorClient", size: int) -> "MessageType": - return client.call(messages.GetEntropy(size=size)) +def get_entropy(session: "Session", size: int) -> "MessageType": + return session.call(messages.GetEntropy(size=size)) @expect(messages.SignedIdentity) def sign_identity( - client: "TrezorClient", + session: "Session", identity: messages.IdentityType, challenge_hidden: bytes, challenge_visual: str, ecdsa_curve_name: Optional[str] = None, ) -> "MessageType": - return client.call( + return session.call( messages.SignIdentity( identity=identity, challenge_hidden=challenge_hidden, @@ -50,12 +50,12 @@ def sign_identity( @expect(messages.ECDHSessionKey) def get_ecdh_session_key( - client: "TrezorClient", + session: "Session", identity: messages.IdentityType, peer_public_key: bytes, ecdsa_curve_name: Optional[str] = None, ) -> "MessageType": - return client.call( + return session.call( messages.GetECDHSessionKey( identity=identity, peer_public_key=peer_public_key, @@ -66,7 +66,7 @@ def get_ecdh_session_key( @expect(messages.CipheredKeyValue, field="value", ret_type=bytes) def encrypt_keyvalue( - client: "TrezorClient", + session: "Session", n: "Address", key: str, value: bytes, @@ -74,7 +74,7 @@ def encrypt_keyvalue( ask_on_decrypt: bool = True, iv: bytes = b"", ) -> "MessageType": - return client.call( + return session.call( messages.CipherKeyValue( address_n=n, key=key, @@ -89,7 +89,7 @@ def encrypt_keyvalue( @expect(messages.CipheredKeyValue, field="value", ret_type=bytes) def decrypt_keyvalue( - client: "TrezorClient", + session: "Session", n: "Address", key: str, value: bytes, @@ -97,7 +97,7 @@ def decrypt_keyvalue( ask_on_decrypt: bool = True, iv: bytes = b"", ) -> "MessageType": - return client.call( + return session.call( messages.CipherKeyValue( address_n=n, key=key, @@ -111,5 +111,5 @@ def decrypt_keyvalue( @expect(messages.Nonce, field="nonce", ret_type=bytes) -def get_nonce(client: "TrezorClient"): - return client.call(messages.GetNonce()) +def get_nonce(session: "Session"): + return session.call(messages.GetNonce()) diff --git a/python/src/trezorlib/monero.py b/python/src/trezorlib/monero.py index 5bce7574e82..5b071626b48 100644 --- a/python/src/trezorlib/monero.py +++ b/python/src/trezorlib/monero.py @@ -20,9 +20,9 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session # MAINNET = 0 @@ -33,13 +33,13 @@ @expect(messages.MoneroAddress, field="address", ret_type=bytes) def get_address( - client: "TrezorClient", + session: "Session", n: "Address", show_display: bool = False, network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.MoneroGetAddress( address_n=n, show_display=show_display, @@ -51,10 +51,10 @@ def get_address( @expect(messages.MoneroWatchKey) def get_watch_key( - client: "TrezorClient", + session: "Session", n: "Address", network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, ) -> "MessageType": - return client.call( + return session.call( messages.MoneroGetWatchKey(address_n=n, network_type=network_type) ) diff --git a/python/src/trezorlib/nem.py b/python/src/trezorlib/nem.py index 3a67aec72c2..6aa087757a6 100644 --- a/python/src/trezorlib/nem.py +++ b/python/src/trezorlib/nem.py @@ -21,9 +21,9 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session TYPE_TRANSACTION_TRANSFER = 0x0101 TYPE_IMPORTANCE_TRANSFER = 0x0801 @@ -198,13 +198,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig @expect(messages.NEMAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", n: "Address", network: int, show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.NEMGetAddress( address_n=n, network=network, show_display=show_display, chunkify=chunkify ) @@ -213,7 +213,7 @@ def get_address( @expect(messages.NEMSignedTx) def sign_tx( - client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False + session: "Session", n: "Address", transaction: dict, chunkify: bool = False ) -> "MessageType": try: msg = create_sign_tx(transaction, chunkify=chunkify) @@ -222,4 +222,4 @@ def sign_tx( assert msg.transaction is not None msg.transaction.address_n = n - return client.call(msg) + return session.call(msg) diff --git a/python/src/trezorlib/ripple.py b/python/src/trezorlib/ripple.py index 7a953b8fac5..f026236c071 100644 --- a/python/src/trezorlib/ripple.py +++ b/python/src/trezorlib/ripple.py @@ -21,9 +21,9 @@ from .tools import dict_from_camelcase, expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment") REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination") @@ -31,12 +31,12 @@ @expect(messages.RippleAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.RippleGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ) @@ -45,14 +45,14 @@ def get_address( @expect(messages.RippleSignedTx) def sign_tx( - client: "TrezorClient", + session: "Session", address_n: "Address", msg: messages.RippleSignTx, chunkify: bool = False, ) -> "MessageType": msg.address_n = address_n msg.chunkify = chunkify - return client.call(msg) + return session.call(msg) def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx: diff --git a/python/src/trezorlib/solana.py b/python/src/trezorlib/solana.py index be7f2e5fcb5..1a228b2f957 100644 --- a/python/src/trezorlib/solana.py +++ b/python/src/trezorlib/solana.py @@ -4,29 +4,29 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType + from .transport.session import Session @expect(messages.SolanaPublicKey) def get_public_key( - client: "TrezorClient", + session: "Session", address_n: List[int], show_display: bool, ) -> "MessageType": - return client.call( + return session.call( messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display) ) @expect(messages.SolanaAddress) def get_address( - client: "TrezorClient", + session: "Session", address_n: List[int], show_display: bool, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.SolanaGetAddress( address_n=address_n, show_display=show_display, @@ -37,12 +37,12 @@ def get_address( @expect(messages.SolanaTxSignature) def sign_tx( - client: "TrezorClient", + session: "Session", address_n: List[int], serialized_tx: bytes, additional_info: Optional[messages.SolanaTxAdditionalInfo], ) -> "MessageType": - return client.call( + return session.call( messages.SolanaSignTx( address_n=address_n, serialized_tx=serialized_tx, diff --git a/python/src/trezorlib/stellar.py b/python/src/trezorlib/stellar.py index ebf81e4fd04..12a75ca5d8a 100644 --- a/python/src/trezorlib/stellar.py +++ b/python/src/trezorlib/stellar.py @@ -21,9 +21,9 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session StellarMessageType = Union[ messages.StellarAccountMergeOp, @@ -325,12 +325,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset: @expect(messages.StellarAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.StellarGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ) @@ -338,7 +338,7 @@ def get_address( def sign_tx( - client: "TrezorClient", + session: "Session", tx: messages.StellarSignTx, operations: List["StellarMessageType"], address_n: "Address", @@ -354,10 +354,10 @@ def sign_tx( # 3. Receive a StellarTxOpRequest message # 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message # 5. The final message received will be StellarSignedTx which is returned from this method - resp = client.call(tx) + resp = session.call(tx) try: while isinstance(resp, messages.StellarTxOpRequest): - resp = client.call(operations.pop(0)) + resp = session.call(operations.pop(0)) except IndexError: # pop from empty list raise exceptions.TrezorException( diff --git a/python/src/trezorlib/tezos.py b/python/src/trezorlib/tezos.py index cff06ed6c83..b74dc562599 100644 --- a/python/src/trezorlib/tezos.py +++ b/python/src/trezorlib/tezos.py @@ -20,19 +20,19 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session @expect(messages.TezosAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.TezosGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ) @@ -41,12 +41,12 @@ def get_address( @expect(messages.TezosPublicKey, field="public_key", ret_type=str) def get_public_key( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.TezosGetPublicKey( address_n=address_n, show_display=show_display, chunkify=chunkify ) @@ -55,11 +55,11 @@ def get_public_key( @expect(messages.TezosSignedTx) def sign_tx( - client: "TrezorClient", + session: "Session", address_n: "Address", sign_tx_msg: messages.TezosSignTx, chunkify: bool = False, ) -> "MessageType": sign_tx_msg.address_n = address_n sign_tx_msg.chunkify = chunkify - return client.call(sign_tx_msg) + return session.call(sign_tx_msg) diff --git a/python/src/trezorlib/tools.py b/python/src/trezorlib/tools.py index 4fd1558ec29..3e9bd1c5608 100644 --- a/python/src/trezorlib/tools.py +++ b/python/src/trezorlib/tools.py @@ -40,7 +40,7 @@ # More details: https://www.python.org/dev/peps/pep-0612/ from typing import TypeVar - from typing_extensions import Concatenate, ParamSpec + from typing_extensions import ParamSpec from . import client from .protobuf import MessageType @@ -284,23 +284,6 @@ def wrapped_f(*args: "P.args", **kwargs: "P.kwargs") -> "Union[MT, R]": return decorator -def session( - f: "Callable[Concatenate[TrezorClient, P], R]", -) -> "Callable[Concatenate[TrezorClient, P], R]": - # Decorator wraps a BaseClient method - # with session activation / deactivation - @functools.wraps(f) - def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R": - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - client.open() - try: - return f(client, *args, **kwargs) - finally: - client.close() - - return wrapped_f - - # de-camelcasifier # https://stackoverflow.com/a/1176023/222189 diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index b04876b6b77..e8a9960fdf0 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -14,24 +14,18 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging -from typing import ( - TYPE_CHECKING, - Iterable, - List, - Optional, - Sequence, - Tuple, - Type, - TypeVar, -) +import typing as t from ..exceptions import TrezorException -if TYPE_CHECKING: +if t.TYPE_CHECKING: from ..models import TrezorModel - T = TypeVar("T", bound="Transport") + T = t.TypeVar("T", bound="Transport") + LOG = logging.getLogger(__name__) @@ -41,7 +35,7 @@ """.strip() -MessagePayload = Tuple[int, bytes] +MessagePayload = t.Tuple[int, bytes] class TransportException(TrezorException): @@ -53,73 +47,55 @@ class DeviceIsBusy(TransportException): class Transport: - """Raw connection to a Trezor device. + PATH_PREFIX: str - Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB - or USB-HID connection, or UDP socket of listening emulator(s). - It can also enumerate devices available over this communication link, and return - them as instances. + @classmethod + def enumerate( + cls: t.Type["T"], models: t.Iterable["TrezorModel"] | None = None + ) -> t.Iterable["T"]: + raise NotImplementedError - Transport instance is a thing that: - - can be identified and requested by a string URI-like path - - can open and close sessions, which enclose related operations - - can read and write protobuf messages + @classmethod + def find_by_path(cls: t.Type["T"], path: str, prefix_search: bool = False) -> "T": + for device in cls.enumerate(): - You need to implement a new Transport subclass if you invent a new way to connect - a Trezor device to a computer. - """ + if device.get_path() == path: + return device - PATH_PREFIX: str - ENABLED = False + if prefix_search and device.get_path().startswith(path): + return device - def __str__(self) -> str: - return self.get_path() + raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") def get_path(self) -> str: raise NotImplementedError - def begin_session(self) -> None: - raise NotImplementedError - - def end_session(self) -> None: + def find_debug(self: "T") -> "T": raise NotImplementedError - def read(self) -> MessagePayload: + def open(self) -> None: raise NotImplementedError - def write(self, message_type: int, message_data: bytes) -> None: + def close(self) -> None: raise NotImplementedError - def find_debug(self: "T") -> "T": + def write_chunk(self, chunk: bytes) -> None: raise NotImplementedError - @classmethod - def enumerate( - cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None - ) -> Iterable["T"]: + def read_chunk(self) -> bytes: raise NotImplementedError - @classmethod - def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T": - for device in cls.enumerate(): - if ( - path is None - or device.get_path() == path - or (prefix_search and device.get_path().startswith(path)) - ): - return device - - raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") + CHUNK_SIZE: t.ClassVar[int] -def all_transports() -> Iterable[Type["Transport"]]: - from .bridge import BridgeTransport +def all_transports() -> t.Iterable[t.Type["Transport"]]: + # from .bridge import BridgeTransport from .hid import HidTransport from .udp import UdpTransport from .webusb import WebUsbTransport - transports: Tuple[Type["Transport"], ...] = ( - BridgeTransport, + transports: t.Tuple[t.Type["Transport"], ...] = ( + # BridgeTransport, HidTransport, UdpTransport, WebUsbTransport, @@ -128,9 +104,9 @@ def all_transports() -> Iterable[Type["Transport"]]: def enumerate_devices( - models: Optional[Iterable["TrezorModel"]] = None, -) -> Sequence["Transport"]: - devices: List["Transport"] = [] + models: t.Iterable["TrezorModel"] | None = None, +) -> t.Sequence["Transport"]: + devices: t.List["Transport"] = [] for transport in all_transports(): name = transport.__name__ try: @@ -145,9 +121,7 @@ def enumerate_devices( return devices -def get_transport( - path: Optional[str] = None, prefix_search: bool = False -) -> "Transport": +def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport": if path is None: try: return next(iter(enumerate_devices())) diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index e0c34a8f701..8d69e5b253f 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -14,24 +14,30 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import struct -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional +import typing as t import requests from ..log import DUMP_PACKETS from . import DeviceIsBusy, MessagePayload, Transport, TransportException -if TYPE_CHECKING: +if t.TYPE_CHECKING: from ..models import TrezorModel LOG = logging.getLogger(__name__) +PROTOCOL_VERSION_1 = 1 +PROTOCOL_VERSION_2 = 2 + TREZORD_HOST = "http://127.0.0.1:21325" TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"} TREZORD_VERSION_MODERN = (2, 0, 25) +TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value CONNECTION = requests.Session() CONNECTION.headers.update(TREZORD_ORIGIN_HEADER) @@ -45,7 +51,7 @@ def __init__(self, path: str, status: int, message: str) -> None: super().__init__(f"trezord: {path} failed with code {status}: {message}") -def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: +def call_bridge(path: str, data: str | None = None) -> requests.Response: url = TREZORD_HOST + "/" + path r = CONNECTION.post(url, data=data) if r.status_code != 200: @@ -53,10 +59,54 @@ def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: return r -def is_legacy_bridge() -> bool: +def get_bridge_version() -> t.Tuple[int, ...]: config = call_bridge("configure").json() - version_tuple = tuple(map(int, config["version"].split("."))) - return version_tuple < TREZORD_VERSION_MODERN + return tuple(map(int, config["version"].split("."))) + + +def is_legacy_bridge() -> bool: + return get_bridge_version() < TREZORD_VERSION_MODERN + + +def supports_protocolV2() -> bool: + return get_bridge_version() >= TREZORD_VERSION_THP_SUPPORT + + +def detect_protocol_version(transport: "BridgeTransport") -> int: + from .. import mapping, messages + from ..messages import FailureType + + protocol_version = PROTOCOL_VERSION_1 + request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize()) + transport.deprecated_begin_session() + transport.deprecated_write(request_type, request_data) + + response_type, response_data = transport.deprecated_read() + response = mapping.DEFAULT_MAPPING.decode(response_type, response_data) + transport.deprecated_begin_session() + if isinstance(response, messages.Failure): + if response.code == FailureType.InvalidProtocol: + LOG.debug("Protocol V2 detected") + protocol_version = PROTOCOL_VERSION_2 + + return protocol_version + + +def _is_transport_valid(transport: "BridgeTransport") -> bool: + is_valid = ( + supports_protocolV2() + or detect_protocol_version(transport) == PROTOCOL_VERSION_1 + ) + if not is_valid: + LOG.warning("Detected unsupported Bridge transport!") + return is_valid + + +def filter_invalid_bridge_transports( + transports: t.Iterable["BridgeTransport"], +) -> t.Sequence["BridgeTransport"]: + """Filters out invalid bridge transports. Keeps only valid ones.""" + return [t for t in transports if _is_transport_valid(t)] class BridgeHandle: @@ -84,7 +134,7 @@ def read_buf(self) -> bytes: class BridgeHandleLegacy(BridgeHandle): def __init__(self, transport: "BridgeTransport") -> None: super().__init__(transport) - self.request: Optional[str] = None + self.request: str | None = None def write_buf(self, buf: bytes) -> None: if self.request is not None: @@ -112,13 +162,12 @@ class BridgeTransport(Transport): ENABLED: bool = True def __init__( - self, device: Dict[str, Any], legacy: bool, debug: bool = False + self, device: t.Dict[str, t.Any], legacy: bool, debug: bool = False ) -> None: if legacy and debug: raise TransportException("Debugging not supported on legacy Bridge") - self.device = device - self.session: Optional[str] = None + self.session: str | None = device["session"] self.debug = debug self.legacy = legacy @@ -135,7 +184,7 @@ def find_debug(self) -> "BridgeTransport": raise TransportException("Debug device not available") return BridgeTransport(self.device, self.legacy, debug=True) - def _call(self, action: str, data: Optional[str] = None) -> requests.Response: + def _call(self, action: str, data: str | None = None) -> requests.Response: session = self.session or "null" uri = action + "/" + str(session) if self.debug: @@ -144,17 +193,20 @@ def _call(self, action: str, data: Optional[str] = None) -> requests.Response: @classmethod def enumerate( - cls, _models: Optional[Iterable["TrezorModel"]] = None - ) -> Iterable["BridgeTransport"]: + cls, _models: t.Iterable["TrezorModel"] | None = None + ) -> t.Iterable["BridgeTransport"]: try: legacy = is_legacy_bridge() - return [ - BridgeTransport(dev, legacy) for dev in call_bridge("enumerate").json() - ] + return filter_invalid_bridge_transports( + [ + BridgeTransport(dev, legacy) + for dev in call_bridge("enumerate").json() + ] + ) except Exception: return [] - def begin_session(self) -> None: + def deprecated_begin_session(self) -> None: try: data = self._call("acquire/" + self.device["path"]) except BridgeException as e: @@ -163,18 +215,32 @@ def begin_session(self) -> None: raise self.session = data.json()["session"] - def end_session(self) -> None: + def deprecated_end_session(self) -> None: if not self.session: return self._call("release") self.session = None - def write(self, message_type: int, message_data: bytes) -> None: + def deprecated_write(self, message_type: int, message_data: bytes) -> None: header = struct.pack(">HL", message_type, len(message_data)) self.handle.write_buf(header + message_data) - def read(self) -> MessagePayload: + def deprecated_read(self) -> MessagePayload: data = self.handle.read_buf() headerlen = struct.calcsize(">HL") msg_type, datalen = struct.unpack(">HL", data[:headerlen]) return msg_type, data[headerlen : headerlen + datalen] + + def open(self) -> None: + pass + # TODO self.handle.open() + + def close(self) -> None: + pass + # TODO self.handle.close() + + def write_chunk(self, chunk: bytes) -> None: # TODO check if it works :) + self.handle.write_buf(chunk) + + def read_chunk(self) -> bytes: # TODO check if it works :) + return self.handle.read_buf() diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index 65fa08ccd70..995fd6960ca 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -14,15 +14,16 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import sys import time -from typing import Any, Dict, Iterable, List, Optional +import typing as t from ..log import DUMP_PACKETS from ..models import TREZOR_ONE, TrezorModel -from . import UDEV_RULES_STR, TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import UDEV_RULES_STR, Transport, TransportException LOG = logging.getLogger(__name__) @@ -35,23 +36,61 @@ HID_IMPORTED = False -HidDevice = Dict[str, Any] -HidDeviceHandle = Any +HidDevice = t.Dict[str, t.Any] +HidDeviceHandle = t.Any + + +class HidTransport(Transport): + """ + HidTransport implements transport over USB HID interface. + """ + PATH_PREFIX = "hid" + ENABLED = HID_IMPORTED -class HidHandle: - def __init__( - self, path: bytes, serial: str, probe_hid_version: bool = False - ) -> None: - self.path = path - self.serial = serial + def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None: + self.device = device + self.device_path = device["path"] + self.device_serial_number = device["serial_number"] self.handle: HidDeviceHandle = None self.hid_version = None if probe_hid_version else 2 + def get_path(self) -> str: + return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" + + @classmethod + def enumerate( + cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False + ) -> t.Iterable["HidTransport"]: + if models is None: + models = {TREZOR_ONE} + usb_ids = [id for model in models for id in model.usb_ids] + + devices: t.List["HidTransport"] = [] + for dev in hid.enumerate(0, 0): + usb_id = (dev["vendor_id"], dev["product_id"]) + if usb_id not in usb_ids: + continue + if debug: + if not is_debuglink(dev): + continue + else: + if not is_wirelink(dev): + continue + devices.append(HidTransport(dev)) + return devices + + def find_debug(self) -> "HidTransport": + # For v1 protocol, find debug USB interface for the same serial number + for debug in HidTransport.enumerate(debug=True): + if debug.device["serial_number"] == self.device["serial_number"]: + return debug + raise TransportException("Debug HID device not found") + def open(self) -> None: self.handle = hid.device() try: - self.handle.open_path(self.path) + self.handle.open_path(self.device_path) except (IOError, OSError) as e: if sys.platform.startswith("linux"): e.args = e.args + (UDEV_RULES_STR,) @@ -62,11 +101,11 @@ def open(self) -> None: # and we wouldn't even know. # So we check that the serial matches what we expect. serial = self.handle.get_serial_number_string() - if serial != self.serial: + if serial != self.device_serial_number: self.handle.close() self.handle = None raise TransportException( - f"Unexpected device {serial} on path {self.path.decode()}" + f"Unexpected device {serial} on path {self.device_path.decode()}" ) self.handle.set_nonblocking(True) @@ -77,7 +116,7 @@ def open(self) -> None: def close(self) -> None: if self.handle is not None: # reload serial, because device.wipe() can reset it - self.serial = self.handle.get_serial_number_string() + self.device_serial_number = self.handle.get_serial_number_string() self.handle.close() self.handle = None @@ -115,53 +154,6 @@ def probe_hid_version(self) -> int: raise TransportException("Unknown HID version") -class HidTransport(ProtocolBasedTransport): - """ - HidTransport implements transport over USB HID interface. - """ - - PATH_PREFIX = "hid" - ENABLED = HID_IMPORTED - - def __init__(self, device: HidDevice) -> None: - self.device = device - self.handle = HidHandle(device["path"], device["serial_number"]) - - super().__init__(protocol=ProtocolV1(self.handle)) - - def get_path(self) -> str: - return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" - - @classmethod - def enumerate( - cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False - ) -> Iterable["HidTransport"]: - if models is None: - models = {TREZOR_ONE} - usb_ids = [id for model in models for id in model.usb_ids] - - devices: List["HidTransport"] = [] - for dev in hid.enumerate(0, 0): - usb_id = (dev["vendor_id"], dev["product_id"]) - if usb_id not in usb_ids: - continue - if debug: - if not is_debuglink(dev): - continue - else: - if not is_wirelink(dev): - continue - devices.append(HidTransport(dev)) - return devices - - def find_debug(self) -> "HidTransport": - # For v1 protocol, find debug USB interface for the same serial number - for debug in HidTransport.enumerate(debug=True): - if debug.device["serial_number"] == self.device["serial_number"]: - return debug - raise TransportException("Debug HID device not found") - - def is_wirelink(dev: HidDevice) -> bool: return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0 diff --git a/python/src/trezorlib/transport/new/__init__.py b/python/src/trezorlib/transport/new/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/src/trezorlib/transport/new/alternating_bit_protocol.py b/python/src/trezorlib/transport/new/alternating_bit_protocol.py new file mode 100644 index 00000000000..62fb650fab0 --- /dev/null +++ b/python/src/trezorlib/transport/new/alternating_bit_protocol.py @@ -0,0 +1,102 @@ +# from storage.cache_thp import ChannelCache +# from trezor import log +# from trezor.wire.thp import ThpError + + +# def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool: +# """ +# Checks if: +# - an ACK message is expected +# - the received ACK message acknowledges correct sequence number (bit) +# """ +# if not _is_ack_expected(cache): +# return False + +# if not _has_ack_correct_sync_bit(cache, ack_bit): +# return False + +# return True + + +# def _is_ack_expected(cache: ChannelCache) -> bool: +# is_expected: bool = not is_sending_allowed(cache) +# if __debug__ and not is_expected: +# log.debug(__name__, "Received unexpected ACK message") +# return is_expected + + +# def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool: +# is_correct: bool = get_send_seq_bit(cache) == sync_bit +# if __debug__ and not is_correct: +# log.debug(__name__, "Received ACK message with wrong ack bit") +# return is_correct + + +# def is_sending_allowed(cache: ChannelCache) -> bool: +# """ +# Checks whether sending a message in the provided channel is allowed. + +# Note: Sending a message in a channel before receipt of ACK message for the previously +# sent message (in the channel) is prohibited, as it can lead to desynchronization. +# """ +# return bool(cache.sync >> 7) + + +# def get_send_seq_bit(cache: ChannelCache) -> int: +# """ +# Returns the sequential number (bit) of the next message to be sent +# in the provided channel. +# """ +# return (cache.sync & 0x20) >> 5 + + +# def get_expected_receive_seq_bit(cache: ChannelCache) -> int: +# """ +# Returns the (expected) sequential number (bit) of the next message +# to be received in the provided channel. +# """ +# return (cache.sync & 0x40) >> 6 + + +# def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None: +# """ +# Set the flag whether sending a message in this channel is allowed or not. +# """ +# cache.sync &= 0x7F +# if sending_allowed: +# cache.sync |= 0x80 + + +# def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None: +# """ +# Set the expected sequential number (bit) of the next message to be received +# in the provided channel +# """ +# if __debug__: +# log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit) +# if seq_bit not in (0, 1): +# raise ThpError("Unexpected receive sync bit") + +# # set second bit to "seq_bit" value +# cache.sync &= 0xBF +# if seq_bit: +# cache.sync |= 0x40 + + +# def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None: +# if seq_bit not in (0, 1): +# raise ThpError("Unexpected send seq bit") +# if __debug__: +# log.debug(__name__, "setting sync send seq bit to %d", seq_bit) +# # set third bit to "seq_bit" value +# cache.sync &= 0xDF +# if seq_bit: +# cache.sync |= 0x20 + + +# def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None: +# """ +# Set the sequential bit of the "next message to be send" to the opposite value, +# i.e. 1 -> 0 and 0 -> 1 +# """ +# _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache)) diff --git a/python/src/trezorlib/transport/new/channel_data.py b/python/src/trezorlib/transport/new/channel_data.py new file mode 100644 index 00000000000..3d70deecafd --- /dev/null +++ b/python/src/trezorlib/transport/new/channel_data.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from binascii import hexlify + + +class ChannelData: + def __init__( + self, + protocol_version: int, + transport_path: str, + channel_id: int, + key_request: bytes, + key_response: bytes, + nonce_request: int, + nonce_response: int, + sync_bit_send: int, + sync_bit_receive: int, + ) -> None: + self.protocol_version: int = protocol_version + self.transport_path: str = transport_path + self.channel_id: int = channel_id + self.key_request: str = hexlify(key_request).decode() + self.key_response: str = hexlify(key_response).decode() + self.nonce_request: int = nonce_request + self.nonce_response: int = nonce_response + self.sync_bit_receive: int = sync_bit_receive + self.sync_bit_send: int = sync_bit_send + + def to_dict(self): + return { + "protocol_version": self.protocol_version, + "transport_path": self.transport_path, + "channel_id": self.channel_id, + "key_request": self.key_request, + "key_response": self.key_response, + "nonce_request": self.nonce_request, + "nonce_response": self.nonce_response, + "sync_bit_send": self.sync_bit_send, + "sync_bit_receive": self.sync_bit_receive, + } diff --git a/python/src/trezorlib/transport/new/channel_database.py b/python/src/trezorlib/transport/new/channel_database.py new file mode 100644 index 00000000000..2de48bba03a --- /dev/null +++ b/python/src/trezorlib/transport/new/channel_database.py @@ -0,0 +1,95 @@ +import json +import logging +import os +import typing as t + +from .channel_data import ChannelData +from .protocol_and_channel import ProtocolAndChannel + +LOG = logging.getLogger(__name__) + +FILE_PATH = "channel_data.json" + + +def load_stored_channels() -> t.List[ChannelData]: + dicts = read_all_channels() + return [dict_to_channel_data(d) for d in dicts] + + +def channel_to_str(channel: ProtocolAndChannel) -> str: + return json.dumps(channel.get_channel_data().to_dict()) + + +def str_to_channel_data(channel_data: str) -> ChannelData: + return dict_to_channel_data(json.loads(channel_data)) + + +def dict_to_channel_data(dict: t.Dict) -> ChannelData: + return ChannelData( + protocol_version=dict["protocol_version"], + transport_path=dict["transport_path"], + channel_id=dict["channel_id"], + key_request=bytes.fromhex(dict["key_request"]), + key_response=bytes.fromhex(dict["key_response"]), + nonce_request=dict["nonce_request"], + nonce_response=dict["nonce_response"], + sync_bit_send=dict["sync_bit_send"], + sync_bit_receive=dict["sync_bit_receive"], + ) + + +def ensure_file_exists() -> None: + LOG.debug("checking if file %s exists", FILE_PATH) + if not os.path.exists(FILE_PATH): + LOG.debug("File %s does not exist. Creating a new one.", FILE_PATH) + with open(FILE_PATH, "w") as f: + json.dump([], f) + + +def clear_stored_channels() -> None: + LOG.debug("Clearing contents of %s - to empty list.", FILE_PATH) + with open(FILE_PATH, "w") as f: + json.dump([], f) + + +def read_all_channels() -> t.List: + ensure_file_exists() + with open(FILE_PATH, "r") as f: + return json.load(f) + + +def save_all_channels(channels: t.List[t.Dict]) -> None: + LOG.debug("saving all channels") + with open(FILE_PATH, "w") as f: + json.dump(channels, f, indent=4) + + +def save_channel(new_channel: ProtocolAndChannel): + LOG.debug("save channel") + channels = read_all_channels() + transport_path = new_channel.transport.get_path() + + # If channel is modified: replace the old by the new + for i, channel in enumerate(channels): + if channel["transport_path"] == transport_path: + LOG.debug("Modified channel entry for %s", transport_path) + channels[i] = new_channel.get_channel_data().to_dict() + save_all_channels(channels) + return + + # Else: add a new channel entry + LOG.debug("Created a new channel entry on path %s", transport_path) + channels.append(new_channel.get_channel_data().to_dict()) + save_all_channels(channels) + + +def remove_channel(transport_path: str) -> None: + LOG.debug( + "Removing channel with path %s from the channel database.", + transport_path, + ) + channels = read_all_channels() + remaining_channels = [ + ch for ch in channels if ch["transport_path"] != transport_path + ] + save_all_channels(remaining_channels) diff --git a/python/src/trezorlib/transport/new/control_byte.py b/python/src/trezorlib/transport/new/control_byte.py new file mode 100644 index 00000000000..ce7f6066f98 --- /dev/null +++ b/python/src/trezorlib/transport/new/control_byte.py @@ -0,0 +1,59 @@ +CODEC_V1 = 0x3F +CONTINUATION_PACKET = 0x80 +HANDSHAKE_INIT_REQ = 0x00 +HANDSHAKE_INIT_RES = 0x01 +HANDSHAKE_COMP_REQ = 0x02 +HANDSHAKE_COMP_RES = 0x03 +ENCRYPTED_TRANSPORT = 0x04 + +CONTINUATION_PACKET_MASK = 0x80 +ACK_MASK = 0xF7 +DATA_MASK = 0xE7 + +ACK_MESSAGE = 0x20 +_ERROR = 0x42 +CHANNEL_ALLOCATION_REQ = 0x40 +_CHANNEL_ALLOCATION_RES = 0x41 + +TREZOR_STATE_UNPAIRED = b"\x00" +TREZOR_STATE_PAIRED = b"\x01" + + +def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int: + if seq_bit == 0: + return ctrl_byte & 0xEF + if seq_bit == 1: + return ctrl_byte | 0x10 + raise Exception("Unexpected sequence bit") + + +def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int: + if ack_bit == 0: + return ctrl_byte & 0xF7 + if ack_bit == 1: + return ctrl_byte | 0x08 + raise Exception("Unexpected acknowledgement bit") + + +def get_seq_bit(ctrl_byte: int) -> int: + return (ctrl_byte & 0x10) >> 4 + + +def is_ack(ctrl_byte: int) -> bool: + return ctrl_byte & ACK_MASK == ACK_MESSAGE + + +def is_continuation(ctrl_byte: int) -> bool: + return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET + + +def is_encrypted_transport(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT + + +def is_handshake_init_req(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ + + +def is_handshake_comp_req(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ diff --git a/python/src/trezorlib/transport/new/protocol_and_channel.py b/python/src/trezorlib/transport/new/protocol_and_channel.py new file mode 100644 index 00000000000..a2c847caff7 --- /dev/null +++ b/python/src/trezorlib/transport/new/protocol_and_channel.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import logging + +from ... import messages +from ...mapping import ProtobufMapping +from .. import Transport +from .channel_data import ChannelData + +LOG = logging.getLogger(__name__) + + +class ProtocolAndChannel: + + def __init__( + self, + transport: Transport, + mapping: ProtobufMapping, + channel_data: ChannelData | None = None, + ) -> None: + self.transport = transport + self.mapping = mapping + self.channel_keys = channel_data + + def get_features(self) -> messages.Features: + raise NotImplementedError() + + def get_channel_data(self) -> ChannelData: + raise NotImplementedError + + def update_features(self) -> None: + raise NotImplementedError diff --git a/python/src/trezorlib/transport/new/protocol_v1.py b/python/src/trezorlib/transport/new/protocol_v1.py new file mode 100644 index 00000000000..ead78ce4c5f --- /dev/null +++ b/python/src/trezorlib/transport/new/protocol_v1.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import struct +import typing as t + +from ... import exceptions, messages +from ...log import DUMP_BYTES +from .protocol_and_channel import LOG, ProtocolAndChannel + + +class ProtocolV1(ProtocolAndChannel): + HEADER_LEN = struct.calcsize(">HL") + _features: messages.Features | None = None + + def get_features(self) -> messages.Features: + if self._features is None: + self.update_features() + assert self._features is not None + return self._features + + def update_features(self) -> None: + self.write(messages.GetFeatures()) + resp = self.read() + if not isinstance(resp, messages.Features): + raise exceptions.TrezorException("Unexpected response to GetFeatures") + self._features = resp + + def read(self) -> t.Any: + msg_type, msg_bytes = self._read() + LOG.log( + DUMP_BYTES, + f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", + ) + msg = self.mapping.decode(msg_type, msg_bytes) + LOG.debug( + f"received message: {msg.__class__.__name__}", + extra={"protobuf": msg}, + ) + self.transport.close() + return msg + + def write(self, msg: t.Any) -> None: + LOG.debug( + f"sending message: {msg.__class__.__name__}", + extra={"protobuf": msg}, + ) + msg_type, msg_bytes = self.mapping.encode(msg) + LOG.log( + DUMP_BYTES, + f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", + ) + self._write(msg_type, msg_bytes) + + def _write(self, message_type: int, message_data: bytes) -> None: + chunk_size = self.transport.CHUNK_SIZE + header = struct.pack(">HL", message_type, len(message_data)) + buffer = bytearray(b"##" + header + message_data) + + while buffer: + # Report ID, data padded to 63 bytes + chunk = b"?" + buffer[: chunk_size - 1] + chunk = chunk.ljust(chunk_size, b"\x00") + self.transport.write_chunk(chunk) + buffer = buffer[63:] + + def _read(self) -> t.Tuple[int, bytes]: + buffer = bytearray() + # Read header with first part of message data + msg_type, datalen, first_chunk = self.read_first() + buffer.extend(first_chunk) + + # Read the rest of the message + while len(buffer) < datalen: + buffer.extend(self.read_next()) + + return msg_type, buffer[:datalen] + + def read_first(self) -> t.Tuple[int, int, bytes]: + chunk = self.transport.read_chunk() + if chunk[:3] != b"?##": + raise RuntimeError("Unexpected magic characters") + try: + msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) + except Exception: + raise RuntimeError("Cannot parse header") + + data = chunk[3 + self.HEADER_LEN :] + return msg_type, datalen, data + + def read_next(self) -> bytes: + chunk = self.transport.read_chunk() + if chunk[:1] != b"?": + raise RuntimeError("Unexpected magic characters") + return chunk[1:] diff --git a/python/src/trezorlib/transport/new/protocol_v2.py b/python/src/trezorlib/transport/new/protocol_v2.py new file mode 100644 index 00000000000..f6960fef2fa --- /dev/null +++ b/python/src/trezorlib/transport/new/protocol_v2.py @@ -0,0 +1,380 @@ +from __future__ import annotations + +import hashlib +import hmac +import logging +import os +import typing as t +from binascii import hexlify +from enum import IntEnum + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +from ... import exceptions, messages +from ...mapping import ProtobufMapping +from .. import Transport +from ..thp import checksum, curve25519, thp_io +from ..thp.checksum import CHECKSUM_LENGTH +from ..thp.packet_header import PacketHeader +from . import channel_database, control_byte +from .channel_data import ChannelData +from .protocol_and_channel import ProtocolAndChannel + +LOG = logging.getLogger(__name__) + +MANAGEMENT_SESSION_ID: int = 0 + + +def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes: + hash = hashlib.sha256(val_1) + hash.update(val_2) + return hash.digest() + + +def _hkdf(chaining_key: bytes, input: bytes): + temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest() + output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest() + ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256) + ctx_output_2.update(b"\x02") + output_2 = ctx_output_2.digest() + return (output_1, output_2) + + +def _get_iv_from_nonce(nonce: int) -> bytes: + if not nonce <= 0xFFFFFFFFFFFFFFFF: + raise ValueError("Nonce overflow, terminate the channel") + return bytes(4) + nonce.to_bytes(8, "big") + + +class ProtocolV2(ProtocolAndChannel): + channel_id: int + + key_request: bytes + key_response: bytes + nonce_request: int + nonce_response: int + sync_bit_send: int + sync_bit_receive: int + + _has_valid_channel: bool = False + _features: messages.Features | None = None + + def __init__( + self, + transport: Transport, + mapping: ProtobufMapping, + channel_data: ChannelData | None = None, + ) -> None: + super().__init__(transport, mapping, channel_data) + if channel_data is not None: + self.channel_id = channel_data.channel_id + self.key_request = bytes.fromhex(channel_data.key_request) + self.key_response = bytes.fromhex(channel_data.key_response) + self.nonce_request = channel_data.nonce_request + self.nonce_response = channel_data.nonce_response + self.sync_bit_receive = channel_data.sync_bit_receive + self.sync_bit_send = channel_data.sync_bit_send + self._has_valid_channel = True + + def get_channel(self) -> ProtocolV2: + if not self._has_valid_channel: + self._establish_new_channel() + return self + + def get_channel_data(self) -> ChannelData: + return ChannelData( + protocol_version=2, + transport_path=self.transport.get_path(), + channel_id=self.channel_id, + key_request=self.key_request, + key_response=self.key_response, + nonce_request=self.nonce_request, + nonce_response=self.nonce_response, + sync_bit_receive=self.sync_bit_receive, + sync_bit_send=self.sync_bit_send, + ) + + def read(self, session_id: int) -> t.Any: + sid, msg_type, msg_data = self.read_and_decrypt() + if sid != session_id: + raise Exception("Received messsage on different session.") + channel_database.save_channel(self) + return self.mapping.decode(msg_type, msg_data) + + def write(self, session_id: int, msg: t.Any) -> None: + msg_type, msg_data = self.mapping.encode(msg) + self._encrypt_and_write(session_id, msg_type, msg_data) + channel_database.save_channel(self) + + def get_features(self) -> messages.Features: + if not self._has_valid_channel: + self._establish_new_channel() + if self._features is None: + self.update_features() + assert self._features is not None + return self._features + + def update_features(self) -> None: + message = messages.GetFeatures() + message_type, message_data = self.mapping.encode(message) + self.session_id: int = 0 + self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) + _ = self._read_until_valid_crc_check() # TODO check ACK + _, msg_type, msg_data = self.read_and_decrypt() + features = self.mapping.decode(msg_type, msg_data) + if not isinstance(features, messages.Features): + raise exceptions.TrezorException("Unexpected response to GetFeatures") + self._features = features + + def _establish_new_channel(self) -> None: + self.sync_bit_send = 0 + self.sync_bit_receive = 0 + # Send channel allocation request + channel_id_request_nonce = os.urandom(8) + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, + PacketHeader.get_channel_allocation_request_header(12), + channel_id_request_nonce, + ) + + # Read channel allocation response + header, payload = self._read_until_valid_crc_check() + if not self._is_valid_channel_allocation_response( + header, payload, channel_id_request_nonce + ): + print("TODO raise exception here, I guess") + + self.channel_id = int.from_bytes(payload[8:10], "big") + self.device_properties = payload[10:] + + # Send handshake init request + ha_init_req_header = PacketHeader(0, self.channel_id, 36) + host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) + host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) + + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, ha_init_req_header, host_ephemeral_pubkey + ) + + # Read ACK + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + print("Received message is not a valid ACK ") + + # Read handshake init response + header, payload = self._read_until_valid_crc_check() + self._send_ack_0() + + if not header.is_handshake_init_response(): + print("Received message is not a valid handshake init response message") + + trezor_ephemeral_pubkey = payload[:32] + encrypted_trezor_static_pubkey = payload[32:80] + noise_tag = payload[80:96] + + # TODO check noise tag + LOG.debug("noise_tag: %s", hexlify(noise_tag).decode()) + + # Prepare and send handshake completion request + PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" + IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + h = _sha256_of_two(PROTOCOL_NAME, self.device_properties) + h = _sha256_of_two(h, host_ephemeral_pubkey) + h = _sha256_of_two(h, trezor_ephemeral_pubkey) + ck, k = _hkdf( + PROTOCOL_NAME, + curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey), + ) + + aes_ctx = AESGCM(k) + try: + trezor_masked_static_pubkey = aes_ctx.decrypt( + IV_1, encrypted_trezor_static_pubkey, h + ) + except Exception as e: + print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik + h = _sha256_of_two(h, encrypted_trezor_static_pubkey) + ck, k = _hkdf( + ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey) + ) + aes_ctx = AESGCM(k) + + tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h) + h = _sha256_of_two(h, tag_of_empty_string) + # TODO: search for saved credentials (or possibly not, as we skip pairing phase) + + zeroes_32 = int.to_bytes(0, 32, "little") + temp_host_static_privkey = curve25519.get_private_key(zeroes_32) + temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey) + aes_ctx = AESGCM(k) + encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h) + h = _sha256_of_two(h, encrypted_host_static_pubkey) + ck, k = _hkdf( + ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey) + ) + msg_data = self.mapping.encode_without_wire_type( + messages.ThpHandshakeCompletionReqNoisePayload( + pairing_methods=[ + messages.ThpPairingMethod.NoMethod, + ] + ) + ) + + aes_ctx = AESGCM(k) + + encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h) + h = _sha256_of_two(h, encrypted_payload) + ha_completion_req_header = PacketHeader( + 0x12, + self.channel_id, + len(encrypted_host_static_pubkey) + + len(encrypted_payload) + + CHECKSUM_LENGTH, + ) + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, + ha_completion_req_header, + encrypted_host_static_pubkey + encrypted_payload, + ) + + # Read ACK + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + print("Received message is not a valid ACK ") + + # Read handshake completion response, ignore payload as we do not care about the state + header, _ = self._read_until_valid_crc_check() + if not header.is_handshake_comp_response(): + print("Received message is not a valid handshake completion response") + self._send_ack_1() + + self.key_request, self.key_response = _hkdf(ck, b"") + self.nonce_request = 0 + self.nonce_response = 1 + + # Send StartPairingReqest message + message = messages.ThpStartPairingRequest() + message_type, message_data = self.mapping.encode(message) + + self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) + + # Read ACK + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + print("Received message is not a valid ACK ") + + # Read + _, msg_type, msg_data = self.read_and_decrypt() + maaa = self.mapping.decode(msg_type, msg_data) + + assert isinstance(maaa, messages.ThpEndResponse) + self._has_valid_channel = True + + def _send_ack_0(self): + LOG.debug("sending ack 0") + header = PacketHeader(0x20, self.channel_id, 4) + thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"") + + def _send_ack_1(self): + LOG.debug("sending ack 1") + header = PacketHeader(0x28, self.channel_id, 4) + thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"") + + def _encrypt_and_write( + self, + session_id: int, + message_type: int, + message_data: bytes, + ctrl_byte: int | None = None, + ) -> None: + assert self.key_request is not None + aes_ctx = AESGCM(self.key_request) + + if ctrl_byte is None: + ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send) + self.sync_bit_send = 1 - self.sync_bit_send + + sid = session_id.to_bytes(1, "big") + msg_type = message_type.to_bytes(2, "big") + data = sid + msg_type + message_data + nonce = _get_iv_from_nonce(self.nonce_request) + self.nonce_request += 1 + encrypted_message = aes_ctx.encrypt(nonce, data, b"") + header = PacketHeader( + ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH + ) + + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, header, encrypted_message + ) + + def read_and_decrypt(self) -> t.Tuple[int, int, bytes]: + header, raw_payload = self._read_until_valid_crc_check() + if control_byte.is_ack(header.ctrl_byte): + return self.read_and_decrypt() + if not header.is_encrypted_transport(): + print("Trying to decrypt not encrypted message!") + print( + hexlify(header.to_bytes_init()).decode(), hexlify(raw_payload).decode() + ) + + if not control_byte.is_ack(header.ctrl_byte): + LOG.debug( + "--> Get sequence bit %d %s %s", + control_byte.get_seq_bit(header.ctrl_byte), + "from control byte", + hexlify(header.ctrl_byte.to_bytes(1, "big")).decode(), + ) + if control_byte.get_seq_bit(header.ctrl_byte): + self._send_ack_1() + else: + self._send_ack_0() + aes_ctx = AESGCM(self.key_response) + nonce = _get_iv_from_nonce(self.nonce_response) + self.nonce_response += 1 + + message = aes_ctx.decrypt(nonce, raw_payload, b"") + session_id = message[0] + message_type = message[1:3] + message_data = message[3:] + return ( + session_id, + int.from_bytes(message_type, "big"), + message_data, + ) + + def _read_until_valid_crc_check( + self, + ) -> t.Tuple[PacketHeader, bytes]: + is_valid = False + header, payload, chksum = thp_io.read(self.transport) + while not is_valid: + is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload) + if not is_valid: + print(hexlify(header.to_bytes_init() + payload + chksum)) + LOG.debug("Received a message with invalid checksum") + header, payload, chksum = thp_io.read(self.transport) + + return header, payload + + def _is_valid_channel_allocation_response( + self, header: PacketHeader, payload: bytes, original_nonce: bytes + ) -> bool: + if not header.is_channel_allocation_response(): + print("Received message is not a channel allocation response") + return False + if len(payload) < 10: + print("Invalid channel allocation response payload") + return False + if payload[:8] != original_nonce: + print("Invalid channel allocation response payload (nonce mismatch)") + return False + return True + + class ControlByteType(IntEnum): + CHANNEL_ALLOCATION_RES = 1 + HANDSHAKE_INIT_RES = 2 + HANDSHAKE_COMP_RES = 3 + ACK = 4 + ENCRYPTED_TRANSPORT = 5 diff --git a/python/src/trezorlib/transport/new/session.py b/python/src/trezorlib/transport/new/session.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/src/trezorlib/transport/protocol.py b/python/src/trezorlib/transport/protocol.py deleted file mode 100644 index a5a0ee6be4d..00000000000 --- a/python/src/trezorlib/transport/protocol.py +++ /dev/null @@ -1,165 +0,0 @@ -# This file is part of the Trezor project. -# -# Copyright (C) 2012-2022 SatoshiLabs and contributors -# -# This library is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# This library is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the License along with this library. -# If not, see . - -import logging -import struct -from typing import Tuple - -from typing_extensions import Protocol as StructuralType - -from . import MessagePayload, Transport - -REPLEN = 64 - -V2_FIRST_CHUNK = 0x01 -V2_NEXT_CHUNK = 0x02 -V2_BEGIN_SESSION = 0x03 -V2_END_SESSION = 0x04 - -LOG = logging.getLogger(__name__) - - -class Handle(StructuralType): - """PEP 544 structural type for Handle functionality. - (called a "Protocol" in the proposed PEP, name which is impractical here) - - Handle is a "physical" layer for a protocol. - It can open/close a connection and read/write bare data in 64-byte chunks. - - Functionally we gain nothing from making this an (abstract) base class for handle - implementations, so this definition is for type hinting purposes only. You can, - but don't have to, inherit from it. - """ - - def open(self) -> None: ... - - def close(self) -> None: ... - - def read_chunk(self) -> bytes: ... - - def write_chunk(self, chunk: bytes) -> None: ... - - -class Protocol: - """Wire protocol that can communicate with a Trezor device, given a Handle. - - A Protocol implements the part of the Transport API that relates to communicating - logical messages over a physical layer. It is a thing that can: - - open and close sessions, - - send and receive protobuf messages, - given the ability to: - - open and close physical connections, - - and send and receive binary chunks. - - For now, the class also handles session counting and opening the underlying Handle. - This will probably be removed in the future. - - We will need a new Protocol class if we change the way a Trezor device encapsulates - its messages. - """ - - def __init__(self, handle: Handle) -> None: - self.handle = handle - self.session_counter = 0 - - # XXX we might be able to remove this now that TrezorClient does session handling - def begin_session(self) -> None: - if self.session_counter == 0: - self.handle.open() - self.session_counter += 1 - - def end_session(self) -> None: - self.session_counter = max(self.session_counter - 1, 0) - if self.session_counter == 0: - self.handle.close() - - def read(self) -> MessagePayload: - raise NotImplementedError - - def write(self, message_type: int, message_data: bytes) -> None: - raise NotImplementedError - - -class ProtocolBasedTransport(Transport): - """Transport that implements its communications through a Protocol. - - Intended as a base class for implementations that proxy their communication - operations to a Protocol. - """ - - def __init__(self, protocol: Protocol) -> None: - self.protocol = protocol - - def write(self, message_type: int, message_data: bytes) -> None: - self.protocol.write(message_type, message_data) - - def read(self) -> MessagePayload: - return self.protocol.read() - - def begin_session(self) -> None: - self.protocol.begin_session() - - def end_session(self) -> None: - self.protocol.end_session() - - -class ProtocolV1(Protocol): - """Protocol version 1. Currently (11/2018) in use on all Trezors. - Does not understand sessions. - """ - - HEADER_LEN = struct.calcsize(">HL") - - def write(self, message_type: int, message_data: bytes) -> None: - header = struct.pack(">HL", message_type, len(message_data)) - buffer = bytearray(b"##" + header + message_data) - - while buffer: - # Report ID, data padded to 63 bytes - chunk = b"?" + buffer[: REPLEN - 1] - chunk = chunk.ljust(REPLEN, b"\x00") - self.handle.write_chunk(chunk) - buffer = buffer[63:] - - def read(self) -> MessagePayload: - buffer = bytearray() - # Read header with first part of message data - msg_type, datalen, first_chunk = self.read_first() - buffer.extend(first_chunk) - - # Read the rest of the message - while len(buffer) < datalen: - buffer.extend(self.read_next()) - - return msg_type, buffer[:datalen] - - def read_first(self) -> Tuple[int, int, bytes]: - chunk = self.handle.read_chunk() - if chunk[:3] != b"?##": - raise RuntimeError("Unexpected magic characters") - try: - msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) - except Exception: - raise RuntimeError("Cannot parse header") - - data = chunk[3 + self.HEADER_LEN :] - return msg_type, datalen, data - - def read_next(self) -> bytes: - chunk = self.handle.read_chunk() - if chunk[:1] != b"?": - raise RuntimeError("Unexpected magic characters") - return chunk[1:] diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py new file mode 100644 index 00000000000..9713ae708b7 --- /dev/null +++ b/python/src/trezorlib/transport/session.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import logging +import typing as t + +from .. import exceptions, messages, models +from .new.protocol_v1 import ProtocolV1 +from .new.protocol_v2 import ProtocolV2 + +if t.TYPE_CHECKING: + from ..client import TrezorClient + +LOG = logging.getLogger(__name__) + + +class Session: + button_callback: t.Callable[[Session, t.Any], t.Any] | None = None + pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None + + def __init__(self, client: TrezorClient, id: bytes) -> None: + self.client = client + self._id = id + + @classmethod + def new( + cls, client: TrezorClient, passphrase: str | None, derive_cardano: bool + ) -> Session: + raise NotImplementedError + + def call(self, msg: t.Any) -> t.Any: + # TODO self.check_firmware_version() + resp = self.call_raw(msg) + + while True: + if isinstance(resp, messages.PinMatrixRequest): + if self.pin_callback is None: + raise Exception # TODO + resp = self.pin_callback(self, resp) + elif isinstance(resp, messages.PassphraseRequest): + raise NotImplementedError + # resp = self._callback_passphrase(resp) + elif isinstance(resp, messages.ButtonRequest): + if self.button_callback is None: + raise Exception # TODO + resp = self.button_callback(self, resp) + elif isinstance(resp, messages.Failure): + if resp.code == messages.FailureType.ActionCancelled: + raise exceptions.Cancelled + raise exceptions.TrezorFailure(resp) + else: + return resp + + def call_raw(self, msg: t.Any) -> t.Any: + self._write(msg) + return self._read() + + def _write(self, msg: t.Any) -> None: + raise NotImplementedError + + def _read(self) -> t.Any: + raise NotImplementedError + + def refresh_features(self) -> None: + self.client.refresh_features() + + def end(self) -> None: + raise NotImplementedError + + @property + def features(self) -> messages.Features: + return self.client.features + + @property + def model(self) -> models.TrezorModel: + return self.client.model + + @property + def version(self) -> t.Tuple[int, int, int]: + return self.client.version + + @property + def id(self) -> bytes: + return self._id + + +class SessionV1(Session): + @classmethod + def new( + cls, client: TrezorClient, passphrase: str | None, derive_cardano: bool + ) -> SessionV1: + assert isinstance(client.protocol, ProtocolV1) + session_id = client.features.session_id + if session_id is None: + LOG.debug("warning, session id of protocol_v1 session is None") + session = SessionV1(client, id=b"") + else: + session = SessionV1(client, session_id) + session.button_callback = client.button_callback + session.pin_callback = client.pin_callback + return session + + def _write(self, msg: t.Any) -> None: + if t.TYPE_CHECKING: + assert isinstance(self.client.protocol, ProtocolV1) + self.client.protocol.write(msg) + + def _read(self) -> t.Any: + if t.TYPE_CHECKING: + assert isinstance(self.client.protocol, ProtocolV1) + return self.client.protocol.read() + + +def _callback_button(session: Session, msg: t.Any) -> t.Any: + print("Please confirm action on your Trezor device.") # TODO how to handle UI? + return session.call(messages.ButtonAck()) + + +class SessionV2(Session): + + @classmethod + def new( + cls, client: TrezorClient, passphrase: str | None, derive_cardano: bool + ) -> SessionV2: + assert isinstance(client.protocol, ProtocolV2) + session = cls(client, b"\x00") + new_session: messages.ThpNewSession = session.call( + messages.ThpCreateNewSession( + passphrase=passphrase, derive_cardano=derive_cardano + ) + ) + assert new_session.new_session_id is not None + session_id = new_session.new_session_id + session.update_id_and_sid(session_id.to_bytes(1, "big")) + return session + + def __init__(self, client: TrezorClient, id: bytes) -> None: + super().__init__(client, id) + assert isinstance(client.protocol, ProtocolV2) + + self.pin_callback = client.pin_callback + self.button_callback = client.button_callback + if self.button_callback is None: + self.button_callback = _callback_button + self.channel: ProtocolV2 = client.protocol.get_channel() + self.update_id_and_sid(id) + + def _write(self, msg: t.Any) -> None: + LOG.debug("writing message %s", type(msg)) + self.channel.write(self.sid, msg) + + def _read(self) -> t.Any: + msg = self.channel.read(self.sid) + LOG.debug("reading message %s", type(msg)) + return msg + + def update_id_and_sid(self, id: bytes) -> None: + self._id = id + self.sid = int.from_bytes(id, "big") # TODO update to extract only sid diff --git a/python/src/trezorlib/transport/thp/checksum.py b/python/src/trezorlib/transport/thp/checksum.py new file mode 100644 index 00000000000..8e0f32f0132 --- /dev/null +++ b/python/src/trezorlib/transport/thp/checksum.py @@ -0,0 +1,19 @@ +import zlib + +CHECKSUM_LENGTH = 4 + + +def compute(data: bytes) -> bytes: + """ + Returns a CRC-32 checksum of the provided `data`. + """ + return zlib.crc32(data).to_bytes(CHECKSUM_LENGTH, "big") + + +def is_valid(checksum: bytes, data: bytes) -> bool: + """ + Checks whether the CRC-32 checksum of the `data` is the same + as the checksum provided in `checksum`. + """ + data_checksum = compute(data) + return checksum == data_checksum diff --git a/python/src/trezorlib/transport/thp/curve25519.py b/python/src/trezorlib/transport/thp/curve25519.py new file mode 100644 index 00000000000..43127c49e57 --- /dev/null +++ b/python/src/trezorlib/transport/thp/curve25519.py @@ -0,0 +1,116 @@ +from typing import Tuple + +p = 2**255 - 19 +J = 486662 + +c3 = 19681161376707505956807079304988542015446066515923890162744021073123829784752 # sqrt(-1) +c4 = 7237005577332262213973186563042994240829374041602535252466099000494570602493 # (p - 5) // 8 +a24 = 121666 # (J + 2) // 4 + + +def decode_scalar(scalar: bytes) -> int: + # decodeScalar25519 from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + + if len(scalar) != 32: + raise ValueError("Invalid length of scalar") + + array = bytearray(scalar) + array[0] &= 248 + array[31] &= 127 + array[31] |= 64 + + return int.from_bytes(array, "little") + + +def decode_coordinate(coordinate: bytes) -> int: + # decodeUCoordinate from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + if len(coordinate) != 32: + raise ValueError("Invalid length of coordinate") + + array = bytearray(coordinate) + array[-1] &= 0x7F + return int.from_bytes(array, "little") % p + + +def encode_coordinate(coordinate: int) -> bytes: + # encodeUCoordinate from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + return coordinate.to_bytes(32, "little") + + +def get_private_key(secret: bytes) -> bytes: + return decode_scalar(secret).to_bytes(32, "little") + + +def get_public_key(private_key: bytes) -> bytes: + base_point = int.to_bytes(9, 32, "little") + return multiply(private_key, base_point) + + +def multiply(private_scalar: bytes, public_point: bytes): + # X25519 from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + + def ladder_operation( + x1: int, x2: int, z2: int, x3: int, z3: int + ) -> Tuple[int, int, int, int]: + # https://hyperelliptic.org/EFD/g1p/auto-montgom-xz.html#ladder-ladd-1987-m-3 + # (x4, z4) = 2 * (x2, z2) + # (x5, z5) = (x2, z2) + (x3, z3) + # where (x1, 1) = (x3, z3) - (x2, z2) + + a = (x2 + z2) % p + aa = (a * a) % p + b = (x2 - z2) % p + bb = (b * b) % p + e = (aa - bb) % p + c = (x3 + z3) % p + d = (x3 - z3) % p + da = (d * a) % p + cb = (c * b) % p + t0 = (da + cb) % p + x5 = (t0 * t0) % p + t1 = (da - cb) % p + t2 = (t1 * t1) % p + z5 = (x1 * t2) % p + x4 = (aa * bb) % p + t3 = (a24 * e) % p + t4 = (bb + t3) % p + z4 = (e * t4) % p + + return x4, z4, x5, z5 + + def conditional_swap(first: int, second: int, condition: int): + # Returns (second, first) if condition is true and (first, second) otherwise + # Must be implemented in a way that it is constant time + true_mask = -condition + false_mask = ~true_mask + return (first & false_mask) | (second & true_mask), (second & false_mask) | ( + first & true_mask + ) + + k = decode_scalar(private_scalar) + u = decode_coordinate(public_point) + + x_1 = u + x_2 = 1 + z_2 = 0 + x_3 = u + z_3 = 1 + swap = 0 + + for i in reversed(range(256)): + bit = (k >> i) & 1 + swap = bit ^ swap + (x_2, x_3) = conditional_swap(x_2, x_3, swap) + (z_2, z_3) = conditional_swap(z_2, z_3, swap) + swap = bit + x_2, z_2, x_3, z_3 = ladder_operation(x_1, x_2, z_2, x_3, z_3) + + (x_2, x_3) = conditional_swap(x_2, x_3, swap) + (z_2, z_3) = conditional_swap(z_2, z_3, swap) + + x = pow(z_2, p - 2, p) * x_2 % p + return encode_coordinate(x) diff --git a/python/src/trezorlib/transport/thp/packet_header.py b/python/src/trezorlib/transport/thp/packet_header.py new file mode 100644 index 00000000000..b282f9f46fc --- /dev/null +++ b/python/src/trezorlib/transport/thp/packet_header.py @@ -0,0 +1,82 @@ +import struct + +CODEC_V1 = 0x3F +CONTINUATION_PACKET = 0x80 +HANDSHAKE_INIT_REQ = 0x00 +HANDSHAKE_INIT_RES = 0x01 +HANDSHAKE_COMP_REQ = 0x02 +HANDSHAKE_COMP_RES = 0x03 +ENCRYPTED_TRANSPORT = 0x04 + +CONTINUATION_PACKET_MASK = 0x80 +ACK_MASK = 0xF7 +DATA_MASK = 0xE7 + +ACK_MESSAGE = 0x20 +_ERROR = 0x42 +CHANNEL_ALLOCATION_REQ = 0x40 +_CHANNEL_ALLOCATION_RES = 0x41 + +TREZOR_STATE_UNPAIRED = b"\x00" +TREZOR_STATE_PAIRED = b"\x01" + +BROADCAST_CHANNEL_ID = 0xFFFF + + +class PacketHeader: + format_str_init = ">BHH" + format_str_cont = ">BH" + + def __init__(self, ctrl_byte: int, cid: int, length: int) -> None: + self.ctrl_byte = ctrl_byte + self.cid = cid + self.data_length = length + + def to_bytes_init(self) -> bytes: + return struct.pack( + self.format_str_init, self.ctrl_byte, self.cid, self.data_length + ) + + def to_bytes_cont(self) -> bytes: + return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid) + + def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + struct.pack_into( + self.format_str_init, + buffer, + buffer_offset, + self.ctrl_byte, + self.cid, + self.data_length, + ) + + def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + struct.pack_into( + self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid + ) + + def is_ack(self) -> bool: + return self.ctrl_byte & ACK_MASK == ACK_MESSAGE + + def is_channel_allocation_response(self): + return ( + self.cid == BROADCAST_CHANNEL_ID + and self.ctrl_byte == _CHANNEL_ALLOCATION_RES + ) + + def is_handshake_init_response(self) -> bool: + return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES + + def is_handshake_comp_response(self) -> bool: + return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES + + def is_encrypted_transport(self) -> bool: + return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT + + @classmethod + def get_error_header(cls, cid: int, length: int): + return cls(_ERROR, cid, length) + + @classmethod + def get_channel_allocation_request_header(cls, length: int): + return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length) diff --git a/python/src/trezorlib/transport/thp/thp_io.py b/python/src/trezorlib/transport/thp/thp_io.py new file mode 100644 index 00000000000..0ac33344011 --- /dev/null +++ b/python/src/trezorlib/transport/thp/thp_io.py @@ -0,0 +1,86 @@ +import struct +from typing import Tuple + +from .. import Transport +from ..thp import checksum +from .packet_header import PacketHeader + +INIT_HEADER_LENGTH = 5 +CONT_HEADER_LENGTH = 3 +MAX_PAYLOAD_LEN = 60000 +MESSAGE_TYPE_LENGTH = 2 + +CONTINUATION_PACKET = 0x80 + + +def write_payload_to_wire_and_add_checksum( + transport: Transport, header: PacketHeader, transport_payload: bytes +): + chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload) + data = transport_payload + chksum + write_payload_to_wire(transport, header, data) + + +def write_payload_to_wire( + transport: Transport, header: PacketHeader, transport_payload: bytes +): + transport.open() + buffer = bytearray(transport_payload) + chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH] + chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00") + transport.write_chunk(chunk) + + buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :] + while buffer: + chunk = ( + header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH] + ) + chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00") + transport.write_chunk(chunk) + buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :] + + +def read(transport: Transport) -> Tuple[PacketHeader, bytes, bytes]: + buffer = bytearray() + # Read header with first part of message data + header, first_chunk = read_first(transport) + buffer.extend(first_chunk) + + # Read the rest of the message + while len(buffer) < header.data_length: + buffer.extend(read_next(transport, header.cid)) + # print("buffer read (data):", hexlify(buffer).decode()) + # print("buffer len (data):", datalen) + # TODO check checksum?? or do not strip ? + data_len = header.data_length - checksum.CHECKSUM_LENGTH + return ( + header, + buffer[:data_len], + buffer[data_len : data_len + checksum.CHECKSUM_LENGTH], + ) + + +def read_first(transport: Transport) -> Tuple[PacketHeader, bytes]: + chunk = transport.read_chunk() + try: + ctrl_byte, cid, data_length = struct.unpack( + PacketHeader.format_str_init, chunk[:INIT_HEADER_LENGTH] + ) + except Exception: + raise RuntimeError("Cannot parse header") + + data = chunk[INIT_HEADER_LENGTH:] + return PacketHeader(ctrl_byte, cid, data_length), data + + +def read_next(transport: Transport, cid: int) -> bytes: + chunk = transport.read_chunk() + ctrl_byte, read_cid = struct.unpack( + PacketHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH] + ) + if ctrl_byte != CONTINUATION_PACKET: + raise RuntimeError("Continuation packet with incorrect control byte") + if read_cid != cid: + raise RuntimeError("Continuation packet for different channel") + + return chunk[CONT_HEADER_LENGTH:] diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index 7e4c4614c63..2960df89945 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -14,14 +14,15 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import socket import time -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable, Tuple from ..log import DUMP_PACKETS -from . import TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import Transport, TransportException if TYPE_CHECKING: from ..models import TrezorModel @@ -31,14 +32,18 @@ LOG = logging.getLogger(__name__) -class UdpTransport(ProtocolBasedTransport): +class UdpTransport(Transport): DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 21324 PATH_PREFIX = "udp" ENABLED: bool = True + CHUNK_SIZE = 64 - def __init__(self, device: Optional[str] = None) -> None: + def __init__( + self, + device: str | None = None, + ) -> None: if not device: host = UdpTransport.DEFAULT_HOST port = UdpTransport.DEFAULT_PORT @@ -46,24 +51,17 @@ def __init__(self, device: Optional[str] = None) -> None: devparts = device.split(":") host = devparts[0] port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT - self.device = (host, port) - self.socket: Optional[socket.socket] = None - - super().__init__(protocol=ProtocolV1(self)) - - def get_path(self) -> str: - return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) + self.device: Tuple[str, int] = (host, port) - def find_debug(self) -> "UdpTransport": - host, port = self.device - return UdpTransport(f"{host}:{port + 1}") + self.socket: socket.socket | None = None + super().__init__() @classmethod def _try_path(cls, path: str) -> "UdpTransport": d = cls(path) try: d.open() - if d._ping(): + if d.ping(): return d else: raise TransportException( @@ -77,7 +75,7 @@ def _try_path(cls, path: str) -> "UdpTransport": @classmethod def enumerate( - cls, _models: Optional[Iterable["TrezorModel"]] = None + cls, _models: Iterable["TrezorModel"] | None = None ) -> Iterable["UdpTransport"]: default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}" try: @@ -99,20 +97,8 @@ def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport": else: raise TransportException(f"No UDP device at {path}") - def wait_until_ready(self, timeout: float = 10) -> None: - try: - self.open() - start = time.monotonic() - while True: - if self._ping(): - break - elapsed = time.monotonic() - start - if elapsed >= timeout: - raise TransportException("Timed out waiting for connection.") - - time.sleep(0.05) - finally: - self.close() + def get_path(self) -> str: + return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) def open(self) -> None: self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -124,18 +110,9 @@ def close(self) -> None: self.socket.close() self.socket = None - def _ping(self) -> bool: - """Test if the device is listening.""" - assert self.socket is not None - resp = None - try: - self.socket.sendall(b"PINGPING") - resp = self.socket.recv(8) - except Exception: - pass - return resp == b"PONGPONG" - def write_chunk(self, chunk: bytes) -> None: + if self.socket is None: + self.open() assert self.socket is not None if len(chunk) != 64: raise TransportException("Unexpected data length") @@ -143,6 +120,8 @@ def write_chunk(self, chunk: bytes) -> None: self.socket.sendall(chunk) def read_chunk(self) -> bytes: + if self.socket is None: + self.open() assert self.socket is not None while True: try: @@ -154,3 +133,33 @@ def read_chunk(self) -> bytes: if len(chunk) != 64: raise TransportException(f"Unexpected chunk size: {len(chunk)}") return bytearray(chunk) + + def find_debug(self) -> "UdpTransport": + host, port = self.device + return UdpTransport(f"{host}:{port + 1}") + + def wait_until_ready(self, timeout: float = 10) -> None: + try: + self.open() + start = time.monotonic() + while True: + if self.ping(): + break + elapsed = time.monotonic() - start + if elapsed >= timeout: + raise TransportException("Timed out waiting for connection.") + + time.sleep(0.05) + finally: + self.close() + + def ping(self) -> bool: + """Test if the device is listening.""" + assert self.socket is not None + resp = None + try: + self.socket.sendall(b"PINGPING") + resp = self.socket.recv(8) + except Exception: + pass + return resp == b"PONGPONG" diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index 8e2d08147a6..3ad47c6eb25 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -14,16 +14,17 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import atexit import logging import sys import time -from typing import Iterable, List, Optional +from typing import Iterable, List from ..log import DUMP_PACKETS from ..models import TREZORS, TrezorModel -from . import UDEV_RULES_STR, DeviceIsBusy, TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import UDEV_RULES_STR, DeviceIsBusy, Transport, TransportException LOG = logging.getLogger(__name__) @@ -44,13 +45,69 @@ WEBUSB_CHUNK_SIZE = 64 -class WebUsbHandle: - def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None: +class WebUsbTransport(Transport): + """ + WebUsbTransport implements transport over WebUSB interface. + """ + + PATH_PREFIX = "webusb" + ENABLED = USB_IMPORTED + context = None + CHUNK_SIZE = 64 + + def __init__( + self, + device: "usb1.USBDevice", + debug: bool = False, + ) -> None: + self.device = device + self.debug = debug + self.interface = DEBUG_INTERFACE if debug else INTERFACE self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT - self.count = 0 - self.handle: Optional["usb1.USBDeviceHandle"] = None + self.handle: usb1.USBDeviceHandle | None = None + + super().__init__() + + @classmethod + def enumerate( + cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False + ) -> Iterable["WebUsbTransport"]: + if cls.context is None: + cls.context = usb1.USBContext() + cls.context.open() + atexit.register(cls.context.close) + + if models is None: + models = TREZORS + usb_ids = [id for model in models for id in model.usb_ids] + devices: List["WebUsbTransport"] = [] + for dev in cls.context.getDeviceIterator(skip_on_error=True): + usb_id = (dev.getVendorID(), dev.getProductID()) + if usb_id not in usb_ids: + continue + if not is_vendor_class(dev): + continue + if usb_reset: + handle = dev.open() + handle.resetDevice() + handle.close() + continue + try: + # workaround for issue #223: + # on certain combinations of Windows USB drivers and libusb versions, + # Trezor is returned twice (possibly because Windows know it as both + # a HID and a WebUSB device), and one of the returned devices is + # non-functional. + dev.getProduct() + devices.append(WebUsbTransport(dev)) + except usb1.USBErrorNotSupported: + pass + return devices + + def get_path(self) -> str: + return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" def open(self) -> None: self.handle = self.device.open() @@ -75,6 +132,8 @@ def close(self) -> None: self.handle = None def write_chunk(self, chunk: bytes) -> None: + if self.handle is None: + self.open() assert self.handle is not None if len(chunk) != WEBUSB_CHUNK_SIZE: raise TransportException(f"Unexpected chunk size: {len(chunk)}") @@ -97,6 +156,8 @@ def write_chunk(self, chunk: bytes) -> None: return def read_chunk(self) -> bytes: + if self.handle is None: + self.open() assert self.handle is not None endpoint = 0x80 | self.endpoint while True: @@ -117,70 +178,6 @@ def read_chunk(self) -> bytes: raise TransportException(f"Unexpected chunk size: {len(chunk)}") return chunk - -class WebUsbTransport(ProtocolBasedTransport): - """ - WebUsbTransport implements transport over WebUSB interface. - """ - - PATH_PREFIX = "webusb" - ENABLED = USB_IMPORTED - context = None - - def __init__( - self, - device: "usb1.USBDevice", - handle: Optional[WebUsbHandle] = None, - debug: bool = False, - ) -> None: - if handle is None: - handle = WebUsbHandle(device, debug) - - self.device = device - self.handle = handle - self.debug = debug - - super().__init__(protocol=ProtocolV1(handle)) - - def get_path(self) -> str: - return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" - - @classmethod - def enumerate( - cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False - ) -> Iterable["WebUsbTransport"]: - if cls.context is None: - cls.context = usb1.USBContext() - cls.context.open() - atexit.register(cls.context.close) - - if models is None: - models = TREZORS - usb_ids = [id for model in models for id in model.usb_ids] - devices: List["WebUsbTransport"] = [] - for dev in cls.context.getDeviceIterator(skip_on_error=True): - usb_id = (dev.getVendorID(), dev.getProductID()) - if usb_id not in usb_ids: - continue - if not is_vendor_class(dev): - continue - try: - # workaround for issue #223: - # on certain combinations of Windows USB drivers and libusb versions, - # Trezor is returned twice (possibly because Windows know it as both - # a HID and a WebUSB device), and one of the returned devices is - # non-functional. - dev.getProduct() - devices.append(WebUsbTransport(dev)) - except usb1.USBErrorNotSupported: - pass - except usb1.USBErrorPipe: - if usb_reset: - handle = dev.open() - handle.resetDevice() - handle.close() - return devices - def find_debug(self) -> "WebUsbTransport": # For v1 protocol, find debug USB interface for the same serial number return WebUsbTransport(self.device, debug=True) diff --git a/python/tools/encfs_aes_getpass.py b/python/tools/encfs_aes_getpass.py index 82773e50fa7..37a221154cc 100755 --- a/python/tools/encfs_aes_getpass.py +++ b/python/tools/encfs_aes_getpass.py @@ -35,7 +35,6 @@ from trezorlib.client import TrezorClient from trezorlib.tools import Address from trezorlib.transport import enumerate_devices -from trezorlib.ui import ClickUI version_tuple = tuple(map(int, trezorlib.__version__.split("."))) if not (0, 11) <= version_tuple < (0, 14): @@ -71,7 +70,7 @@ def choose_device(devices: Sequence["Transport"]) -> "Transport": sys.stderr.write("Available devices:\n") for d in devices: try: - client = TrezorClient(d, ui=ClickUI()) + client = TrezorClient(d) except IOError: sys.stderr.write("[-] \n") continue @@ -80,7 +79,7 @@ def choose_device(devices: Sequence["Transport"]) -> "Transport": sys.stderr.write(f"[{i}] {client.features.label}\n") else: sys.stderr.write(f"[{i}] \n") - client.close() + # TODO client.close() i += 1 sys.stderr.write("----------------------------\n") @@ -106,7 +105,8 @@ def main() -> None: devices = wait_for_devices() transport = choose_device(devices) - client = TrezorClient(transport, ui=ClickUI()) + client = TrezorClient(transport) + session = client.get_management_session() rootdir = os.environ["encfs_root"] # Read "man encfs" for more passw_file = os.path.join(rootdir, "password.dat") @@ -120,7 +120,7 @@ def main() -> None: sys.stderr.write("Computer asked Trezor for new strong password.\n") # 32 bytes, good for AES - trezor_entropy = trezorlib.misc.get_entropy(client, 32) + trezor_entropy = trezorlib.misc.get_entropy(session, 32) urandom_entropy = os.urandom(32) passw = hashlib.sha256(trezor_entropy + urandom_entropy).digest() @@ -129,7 +129,7 @@ def main() -> None: bip32_path = Address([10, 0]) passw_encrypted = trezorlib.misc.encrypt_keyvalue( - client, bip32_path, label, passw, False, True + session, bip32_path, label, passw, False, True ) data = { @@ -144,7 +144,7 @@ def main() -> None: data = json.load(open(passw_file, "r")) passw = trezorlib.misc.decrypt_keyvalue( - client, + session, data["bip32_path"], data["label"], bytes.fromhex(data["password_encrypted_hex"]), diff --git a/python/tools/helloworld.py b/python/tools/helloworld.py index 76b4502da2d..b8711dbb00a 100755 --- a/python/tools/helloworld.py +++ b/python/tools/helloworld.py @@ -24,13 +24,14 @@ def main() -> None: # Use first connected device client = get_default_client() + session = client.get_session(derive_cardano=True) # Print out Trezor's features and settings - print(client.features) + print(session.features) # Get the first address of first BIP44 account bip32_path = parse_path("44h/0h/0h/0/0") - address = btc.get_address(client, "Bitcoin", bip32_path, True) + address = btc.get_address(session, "Bitcoin", bip32_path, False) print("Bitcoin address:", address) diff --git a/python/tools/pwd_reader.py b/python/tools/pwd_reader.py index afd405e1642..1c012c7abf4 100755 --- a/python/tools/pwd_reader.py +++ b/python/tools/pwd_reader.py @@ -26,23 +26,24 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from trezorlib import misc, ui +from trezorlib import misc from trezorlib.client import TrezorClient from trezorlib.tools import parse_path from trezorlib.transport import get_transport +from trezorlib.transport.session import Session # Return path by BIP-32 BIP32_PATH = parse_path("10016h/0") # Deriving master key -def getMasterKey(client: TrezorClient) -> str: +def getMasterKey(session: Session) -> str: bip32_path = BIP32_PATH ENC_KEY = "Activate TREZOR Password Manager?" ENC_VALUE = bytes.fromhex( "2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee" ) - key = misc.encrypt_keyvalue(client, bip32_path, ENC_KEY, ENC_VALUE, True, True) + key = misc.encrypt_keyvalue(session, bip32_path, ENC_KEY, ENC_VALUE, True, True) return key.hex() @@ -101,7 +102,7 @@ def decryptEntryValue(nonce: str, val: bytes) -> dict: # Decrypt give entry nonce -def getDecryptedNonce(client: TrezorClient, entry: dict) -> str: +def getDecryptedNonce(session: Session, entry: dict) -> str: print() print("Waiting for Trezor input ...") print() @@ -117,7 +118,7 @@ def getDecryptedNonce(client: TrezorClient, entry: dict) -> str: ENC_KEY = f"Unlock {item} for user {entry['username']}?" ENC_VALUE = entry["nonce"] decrypted_nonce = misc.decrypt_keyvalue( - client, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True + session, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True ) return decrypted_nonce.hex() @@ -144,13 +145,14 @@ def main() -> None: print(e) return - client = TrezorClient(transport=transport, ui=ui.ClickUI()) + client = TrezorClient(transport=transport) + session = client.get_management_session() print() print("Confirm operation on Trezor") print() - masterKey = getMasterKey(client) + masterKey = getMasterKey(session) # print('master key:', masterKey) fileName = getFileEncKey(masterKey)[0] @@ -173,7 +175,7 @@ def main() -> None: entry_id = input("Select entry number to decrypt: ") entry_id = str(entry_id) - plain_nonce = getDecryptedNonce(client, entries[entry_id]) + plain_nonce = getDecryptedNonce(session, entries[entry_id]) pwdArr = entries[entry_id]["password"]["data"] pwdHex = "".join([hex(x)[2:].zfill(2) for x in pwdArr]) diff --git a/python/tools/pybridge.py b/python/tools/pybridge.py index 30d69bbc9b1..a48850a80c9 100644 --- a/python/tools/pybridge.py +++ b/python/tools/pybridge.py @@ -24,6 +24,9 @@ from gevent import monkey +import trezorlib.transport.new +import trezorlib.transport.new.transport + monkey.patch_all() import json @@ -103,11 +106,11 @@ def __init__(self, transport: trezorlib.transport.Transport) -> None: self.session: Session | None = None self.transport = transport - client = TrezorClient(transport, ui=SilentUI()) + client = TrezorClient(transport) # TODO add silent UI? self.model = ( trezorlib.models.by_name(client.features.model) or trezorlib.models.TREZOR_T ) - client.end_session() + # TODO client.end_session() def acquire(self, sid: str) -> str: if self.session_id() != sid: @@ -116,11 +119,11 @@ def acquire(self, sid: str) -> str: self.session.release() self.session = Session(self) - self.transport.begin_session() + # TODO self.transport.deprecated_begin_session() return self.session.id def release(self) -> None: - self.transport.end_session() + # TODO self.transport.deprecated_end_session() self.session = None def session_id(self) -> str | None: @@ -141,10 +144,14 @@ def to_json(self) -> dict: } def write(self, msg_id: int, data: bytes) -> None: - self.transport.write(msg_id, data) + raise NotImplementedError + # TODO + # self.transport.write(msg_id, data) def read(self) -> tuple[int, bytes]: - return self.transport.read() + raise NotImplementedError + # TODO + # return self.transport.read() @classmethod def find(cls, path: str) -> Transport | None: diff --git a/python/tools/rng_entropy_collector.py b/python/tools/rng_entropy_collector.py index 2b0a5b80d79..437561b1549 100755 --- a/python/tools/rng_entropy_collector.py +++ b/python/tools/rng_entropy_collector.py @@ -7,14 +7,15 @@ import io import sys -from trezorlib import misc, ui +from trezorlib import misc from trezorlib.client import TrezorClient from trezorlib.transport import get_transport def main() -> None: try: - client = TrezorClient(get_transport(), ui=ui.ClickUI()) + client = TrezorClient(get_transport()) + session = client.get_management_session() except Exception as e: print(e) return @@ -25,11 +26,9 @@ def main() -> None: with io.open(arg1, "wb") as f: for _ in range(0, arg2, step): - entropy = misc.get_entropy(client, step) + entropy = misc.get_entropy(session, step) f.write(entropy) - client.close() - if __name__ == "__main__": main() diff --git a/python/tools/trezor-otp.py b/python/tools/trezor-otp.py index bc0b66daa97..a88f745b412 100755 --- a/python/tools/trezor-otp.py +++ b/python/tools/trezor-otp.py @@ -27,26 +27,25 @@ from trezorlib.misc import decrypt_keyvalue, encrypt_keyvalue from trezorlib.tools import parse_path from trezorlib.transport import get_transport -from trezorlib.ui import ClickUI BIP32_PATH = parse_path("10016h/0") def encrypt(type: str, domain: str, secret: str) -> str: transport = get_transport() - client = TrezorClient(transport, ClickUI()) + client = TrezorClient(transport) + session = client.get_management_session() dom = type.upper() + ": " + domain - enc = encrypt_keyvalue(client, BIP32_PATH, dom, secret.encode(), False, True) - client.close() + enc = encrypt_keyvalue(session, BIP32_PATH, dom, secret.encode(), False, True) return enc.hex() def decrypt(type: str, domain: str, secret: bytes) -> bytes: transport = get_transport() - client = TrezorClient(transport, ClickUI()) + client = TrezorClient(transport) + session = client.get_management_session() dom = type.upper() + ": " + domain - dec = decrypt_keyvalue(client, BIP32_PATH, dom, secret, False, True) - client.close() + dec = decrypt_keyvalue(session, BIP32_PATH, dom, secret, False, True) return dec diff --git a/rust/trezor-client/src/messages/generated.rs b/rust/trezor-client/src/messages/generated.rs index 1493a938678..969f8565429 100644 --- a/rust/trezor-client/src/messages/generated.rs +++ b/rust/trezor-client/src/messages/generated.rs @@ -80,6 +80,24 @@ trezor_message_impl! { DebugLinkWatchLayout => MessageType_DebugLinkWatchLayout, DebugLinkResetDebugEvents => MessageType_DebugLinkResetDebugEvents, DebugLinkOptigaSetSecMax => MessageType_DebugLinkOptigaSetSecMax, + ThpCreateNewSession => MessageType_ThpCreateNewSession, + ThpNewSession => MessageType_ThpNewSession, + ThpStartPairingRequest => MessageType_ThpStartPairingRequest, + ThpPairingPreparationsFinished => MessageType_ThpPairingPreparationsFinished, + ThpCredentialRequest => MessageType_ThpCredentialRequest, + ThpCredentialResponse => MessageType_ThpCredentialResponse, + ThpEndRequest => MessageType_ThpEndRequest, + ThpEndResponse => MessageType_ThpEndResponse, + ThpCodeEntryCommitment => MessageType_ThpCodeEntryCommitment, + ThpCodeEntryChallenge => MessageType_ThpCodeEntryChallenge, + ThpCodeEntryCpaceHost => MessageType_ThpCodeEntryCpaceHost, + ThpCodeEntryCpaceTrezor => MessageType_ThpCodeEntryCpaceTrezor, + ThpCodeEntryTag => MessageType_ThpCodeEntryTag, + ThpCodeEntrySecret => MessageType_ThpCodeEntrySecret, + ThpQrCodeTag => MessageType_ThpQrCodeTag, + ThpQrCodeSecret => MessageType_ThpQrCodeSecret, + ThpNfcUnidirectionalTag => MessageType_ThpNfcUnidirectionalTag, + ThpNfcUnidirectionalSecret => MessageType_ThpNfcUnidirectionalSecret, } #[cfg(feature = "binance")] diff --git a/rust/trezor-client/src/protos/generated/messages.rs b/rust/trezor-client/src/protos/generated/messages.rs index 7cf263a1fed..3f719f429e7 100644 --- a/rust/trezor-client/src/protos/generated/messages.rs +++ b/rust/trezor-client/src/protos/generated/messages.rs @@ -510,6 +510,42 @@ pub enum MessageType { MessageType_SolanaSignTx = 904, // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_SolanaTxSignature) MessageType_SolanaTxSignature = 905, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCreateNewSession) + MessageType_ThpCreateNewSession = 1000, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpNewSession) + MessageType_ThpNewSession = 1001, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpStartPairingRequest) + MessageType_ThpStartPairingRequest = 1008, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpPairingPreparationsFinished) + MessageType_ThpPairingPreparationsFinished = 1009, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCredentialRequest) + MessageType_ThpCredentialRequest = 1010, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCredentialResponse) + MessageType_ThpCredentialResponse = 1011, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpEndRequest) + MessageType_ThpEndRequest = 1012, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpEndResponse) + MessageType_ThpEndResponse = 1013, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCodeEntryCommitment) + MessageType_ThpCodeEntryCommitment = 1016, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCodeEntryChallenge) + MessageType_ThpCodeEntryChallenge = 1017, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCodeEntryCpaceHost) + MessageType_ThpCodeEntryCpaceHost = 1018, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCodeEntryCpaceTrezor) + MessageType_ThpCodeEntryCpaceTrezor = 1019, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCodeEntryTag) + MessageType_ThpCodeEntryTag = 1020, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCodeEntrySecret) + MessageType_ThpCodeEntrySecret = 1021, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpQrCodeTag) + MessageType_ThpQrCodeTag = 1024, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpQrCodeSecret) + MessageType_ThpQrCodeSecret = 1025, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpNfcUnidirectionalTag) + MessageType_ThpNfcUnidirectionalTag = 1032, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpNfcUnidirectionalSecret) + MessageType_ThpNfcUnidirectionalSecret = 1033, } impl ::protobuf::Enum for MessageType { @@ -762,6 +798,24 @@ impl ::protobuf::Enum for MessageType { 903 => ::std::option::Option::Some(MessageType::MessageType_SolanaAddress), 904 => ::std::option::Option::Some(MessageType::MessageType_SolanaSignTx), 905 => ::std::option::Option::Some(MessageType::MessageType_SolanaTxSignature), + 1000 => ::std::option::Option::Some(MessageType::MessageType_ThpCreateNewSession), + 1001 => ::std::option::Option::Some(MessageType::MessageType_ThpNewSession), + 1008 => ::std::option::Option::Some(MessageType::MessageType_ThpStartPairingRequest), + 1009 => ::std::option::Option::Some(MessageType::MessageType_ThpPairingPreparationsFinished), + 1010 => ::std::option::Option::Some(MessageType::MessageType_ThpCredentialRequest), + 1011 => ::std::option::Option::Some(MessageType::MessageType_ThpCredentialResponse), + 1012 => ::std::option::Option::Some(MessageType::MessageType_ThpEndRequest), + 1013 => ::std::option::Option::Some(MessageType::MessageType_ThpEndResponse), + 1016 => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryCommitment), + 1017 => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryChallenge), + 1018 => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryCpaceHost), + 1019 => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryCpaceTrezor), + 1020 => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryTag), + 1021 => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntrySecret), + 1024 => ::std::option::Option::Some(MessageType::MessageType_ThpQrCodeTag), + 1025 => ::std::option::Option::Some(MessageType::MessageType_ThpQrCodeSecret), + 1032 => ::std::option::Option::Some(MessageType::MessageType_ThpNfcUnidirectionalTag), + 1033 => ::std::option::Option::Some(MessageType::MessageType_ThpNfcUnidirectionalSecret), _ => ::std::option::Option::None } } @@ -1009,6 +1063,24 @@ impl ::protobuf::Enum for MessageType { "MessageType_SolanaAddress" => ::std::option::Option::Some(MessageType::MessageType_SolanaAddress), "MessageType_SolanaSignTx" => ::std::option::Option::Some(MessageType::MessageType_SolanaSignTx), "MessageType_SolanaTxSignature" => ::std::option::Option::Some(MessageType::MessageType_SolanaTxSignature), + "MessageType_ThpCreateNewSession" => ::std::option::Option::Some(MessageType::MessageType_ThpCreateNewSession), + "MessageType_ThpNewSession" => ::std::option::Option::Some(MessageType::MessageType_ThpNewSession), + "MessageType_ThpStartPairingRequest" => ::std::option::Option::Some(MessageType::MessageType_ThpStartPairingRequest), + "MessageType_ThpPairingPreparationsFinished" => ::std::option::Option::Some(MessageType::MessageType_ThpPairingPreparationsFinished), + "MessageType_ThpCredentialRequest" => ::std::option::Option::Some(MessageType::MessageType_ThpCredentialRequest), + "MessageType_ThpCredentialResponse" => ::std::option::Option::Some(MessageType::MessageType_ThpCredentialResponse), + "MessageType_ThpEndRequest" => ::std::option::Option::Some(MessageType::MessageType_ThpEndRequest), + "MessageType_ThpEndResponse" => ::std::option::Option::Some(MessageType::MessageType_ThpEndResponse), + "MessageType_ThpCodeEntryCommitment" => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryCommitment), + "MessageType_ThpCodeEntryChallenge" => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryChallenge), + "MessageType_ThpCodeEntryCpaceHost" => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryCpaceHost), + "MessageType_ThpCodeEntryCpaceTrezor" => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryCpaceTrezor), + "MessageType_ThpCodeEntryTag" => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryTag), + "MessageType_ThpCodeEntrySecret" => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntrySecret), + "MessageType_ThpQrCodeTag" => ::std::option::Option::Some(MessageType::MessageType_ThpQrCodeTag), + "MessageType_ThpQrCodeSecret" => ::std::option::Option::Some(MessageType::MessageType_ThpQrCodeSecret), + "MessageType_ThpNfcUnidirectionalTag" => ::std::option::Option::Some(MessageType::MessageType_ThpNfcUnidirectionalTag), + "MessageType_ThpNfcUnidirectionalSecret" => ::std::option::Option::Some(MessageType::MessageType_ThpNfcUnidirectionalSecret), _ => ::std::option::Option::None } } @@ -1255,6 +1327,24 @@ impl ::protobuf::Enum for MessageType { MessageType::MessageType_SolanaAddress, MessageType::MessageType_SolanaSignTx, MessageType::MessageType_SolanaTxSignature, + MessageType::MessageType_ThpCreateNewSession, + MessageType::MessageType_ThpNewSession, + MessageType::MessageType_ThpStartPairingRequest, + MessageType::MessageType_ThpPairingPreparationsFinished, + MessageType::MessageType_ThpCredentialRequest, + MessageType::MessageType_ThpCredentialResponse, + MessageType::MessageType_ThpEndRequest, + MessageType::MessageType_ThpEndResponse, + MessageType::MessageType_ThpCodeEntryCommitment, + MessageType::MessageType_ThpCodeEntryChallenge, + MessageType::MessageType_ThpCodeEntryCpaceHost, + MessageType::MessageType_ThpCodeEntryCpaceTrezor, + MessageType::MessageType_ThpCodeEntryTag, + MessageType::MessageType_ThpCodeEntrySecret, + MessageType::MessageType_ThpQrCodeTag, + MessageType::MessageType_ThpQrCodeSecret, + MessageType::MessageType_ThpNfcUnidirectionalTag, + MessageType::MessageType_ThpNfcUnidirectionalSecret, ]; } @@ -1507,6 +1597,24 @@ impl ::protobuf::EnumFull for MessageType { MessageType::MessageType_SolanaAddress => 238, MessageType::MessageType_SolanaSignTx => 239, MessageType::MessageType_SolanaTxSignature => 240, + MessageType::MessageType_ThpCreateNewSession => 241, + MessageType::MessageType_ThpNewSession => 242, + MessageType::MessageType_ThpStartPairingRequest => 243, + MessageType::MessageType_ThpPairingPreparationsFinished => 244, + MessageType::MessageType_ThpCredentialRequest => 245, + MessageType::MessageType_ThpCredentialResponse => 246, + MessageType::MessageType_ThpEndRequest => 247, + MessageType::MessageType_ThpEndResponse => 248, + MessageType::MessageType_ThpCodeEntryCommitment => 249, + MessageType::MessageType_ThpCodeEntryChallenge => 250, + MessageType::MessageType_ThpCodeEntryCpaceHost => 251, + MessageType::MessageType_ThpCodeEntryCpaceTrezor => 252, + MessageType::MessageType_ThpCodeEntryTag => 253, + MessageType::MessageType_ThpCodeEntrySecret => 254, + MessageType::MessageType_ThpQrCodeTag => 255, + MessageType::MessageType_ThpQrCodeSecret => 256, + MessageType::MessageType_ThpNfcUnidirectionalTag => 257, + MessageType::MessageType_ThpNfcUnidirectionalSecret => 258, }; Self::enum_descriptor().value_by_index(index) } @@ -1541,6 +1649,14 @@ pub mod exts { pub const wire_no_fsm: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(50008, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL); + pub const channel_in: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(50009, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL); + + pub const channel_out: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(50010, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL); + + pub const pairing_in: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(50011, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL); + + pub const pairing_out: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(50012, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL); + pub const bitcoin_only: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(60000, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL); pub const has_bitcoin_only_values: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(51001, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL); @@ -1558,7 +1674,7 @@ pub mod exts { static file_descriptor_proto_data: &'static [u8] = b"\ \n\x0emessages.proto\x12\x12hw.trezor.messages\x1a\x20google/protobuf/de\ - scriptor.proto*\xe2S\n\x0bMessageType\x12(\n\x16MessageType_Initialize\ + scriptor.proto*\xc8Z\n\x0bMessageType\x12(\n\x16MessageType_Initialize\ \x10\0\x1a\x0c\x80\xa6\x1d\x01\xb0\xb5\x18\x01\x90\xb5\x18\x01\x12\x1e\n\ \x10MessageType_Ping\x10\x01\x1a\x08\x80\xa6\x1d\x01\x90\xb5\x18\x01\x12\ %\n\x13MessageType_Success\x10\x02\x1a\x0c\x80\xa6\x1d\x01\xa8\xb5\x18\ @@ -1831,31 +1947,60 @@ static file_descriptor_proto_data: &'static [u8] = b"\ \x07\x1a\x04\x90\xb5\x18\x01\x12$\n\x19MessageType_SolanaAddress\x10\x87\ \x07\x1a\x04\x98\xb5\x18\x01\x12#\n\x18MessageType_SolanaSignTx\x10\x88\ \x07\x1a\x04\x90\xb5\x18\x01\x12(\n\x1dMessageType_SolanaTxSignature\x10\ - \x89\x07\x1a\x04\x98\xb5\x18\x01\x1a\x04\xc8\xf3\x18\x01\"\x04\x08Z\x10\ - \\\"\x04\x08G\x10J\"\x04\x08r\x10z\"\x06\x08\xdb\x01\x10\xdb\x01\"\x06\ - \x08\xe0\x01\x10\xe0\x01\"\x06\x08\xac\x02\x10\xb0\x02\"\x06\x08\xb5\x02\ - \x10\xb8\x02:<\n\x07wire_in\x18\xd2\x86\x03\x20\x01(\x08\x12!.google.pro\ - tobuf.EnumValueOptionsR\x06wireIn:>\n\x08wire_out\x18\xd3\x86\x03\x20\ - \x01(\x08\x12!.google.protobuf.EnumValueOptionsR\x07wireOut:G\n\rwire_de\ - bug_in\x18\xd4\x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOption\ - sR\x0bwireDebugIn:I\n\x0ewire_debug_out\x18\xd5\x86\x03\x20\x01(\x08\x12\ - !.google.protobuf.EnumValueOptionsR\x0cwireDebugOut:@\n\twire_tiny\x18\ - \xd6\x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\x08wire\ - Tiny:L\n\x0fwire_bootloader\x18\xd7\x86\x03\x20\x01(\x08\x12!.google.pro\ - tobuf.EnumValueOptionsR\x0ewireBootloader:C\n\x0bwire_no_fsm\x18\xd8\x86\ - \x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\twireNoFsm:F\n\ - \x0cbitcoin_only\x18\xe0\xd4\x03\x20\x01(\x08\x12!.google.protobuf.EnumV\ - alueOptionsR\x0bbitcoinOnly:U\n\x17has_bitcoin_only_values\x18\xb9\x8e\ - \x03\x20\x01(\x08\x12\x1c.google.protobuf.EnumOptionsR\x14hasBitcoinOnly\ - Values:T\n\x14experimental_message\x18\xa1\x96\x03\x20\x01(\x08\x12\x1f.\ - google.protobuf.MessageOptionsR\x13experimentalMessage:>\n\twire_type\ - \x18\xa2\x96\x03\x20\x01(\r\x12\x1f.google.protobuf.MessageOptionsR\x08w\ - ireType:F\n\rinternal_only\x18\xa3\x96\x03\x20\x01(\x08\x12\x1f.google.p\ - rotobuf.MessageOptionsR\x0cinternalOnly:N\n\x12experimental_field\x18\ - \x89\x9e\x03\x20\x01(\x08\x12\x1d.google.protobuf.FieldOptionsR\x11exper\ - imentalField:U\n\x17include_in_bitcoin_only\x18\xe0\xd4\x03\x20\x01(\x08\ - \x12\x1c.google.protobuf.FileOptionsR\x14includeInBitcoinOnlyB8\n#com.sa\ - toshilabs.trezor.lib.protobufB\rTrezorMessage\x80\xa6\x1d\x01\ + \x89\x07\x1a\x04\x98\xb5\x18\x01\x12.\n\x1fMessageType_ThpCreateNewSessi\ + on\x10\xe8\x07\x1a\x08\x80\xa6\x1d\x01\xc8\xb5\x18\x01\x12(\n\x19Message\ + Type_ThpNewSession\x10\xe9\x07\x1a\x08\x80\xa6\x1d\x01\xd0\xb5\x18\x01\ + \x121\n\"MessageType_ThpStartPairingRequest\x10\xf0\x07\x1a\x08\x80\xa6\ + \x1d\x01\xd8\xb5\x18\x01\x129\n*MessageType_ThpPairingPreparationsFinish\ + ed\x10\xf1\x07\x1a\x08\x80\xa6\x1d\x01\xe0\xb5\x18\x01\x12/\n\x20Message\ + Type_ThpCredentialRequest\x10\xf2\x07\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\ + \x18\x01\x120\n!MessageType_ThpCredentialResponse\x10\xf3\x07\x1a\x08\ + \x80\xa6\x1d\x01\xe0\xb5\x18\x01\x12(\n\x19MessageType_ThpEndRequest\x10\ + \xf4\x07\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x12)\n\x1aMessageType_T\ + hpEndResponse\x10\xf5\x07\x1a\x08\x80\xa6\x1d\x01\xe0\xb5\x18\x01\x121\n\ + \"MessageType_ThpCodeEntryCommitment\x10\xf8\x07\x1a\x08\x80\xa6\x1d\x01\ + \xe0\xb5\x18\x01\x120\n!MessageType_ThpCodeEntryChallenge\x10\xf9\x07\ + \x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x120\n!MessageType_ThpCodeEntry\ + CpaceHost\x10\xfa\x07\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x122\n#Mes\ + sageType_ThpCodeEntryCpaceTrezor\x10\xfb\x07\x1a\x08\x80\xa6\x1d\x01\xe0\ + \xb5\x18\x01\x12*\n\x1bMessageType_ThpCodeEntryTag\x10\xfc\x07\x1a\x08\ + \x80\xa6\x1d\x01\xd8\xb5\x18\x01\x12-\n\x1eMessageType_ThpCodeEntrySecre\ + t\x10\xfd\x07\x1a\x08\x80\xa6\x1d\x01\xe0\xb5\x18\x01\x12'\n\x18MessageT\ + ype_ThpQrCodeTag\x10\x80\x08\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x12\ + *\n\x1bMessageType_ThpQrCodeSecret\x10\x81\x08\x1a\x08\x80\xa6\x1d\x01\ + \xe0\xb5\x18\x01\x122\n#MessageType_ThpNfcUnidirectionalTag\x10\x88\x08\ + \x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x125\n&MessageType_ThpNfcUnidir\ + ectionalSecret\x10\x89\x08\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x1a\ + \x04\xc8\xf3\x18\x01\"\x04\x08Z\x10\\\"\x04\x08G\x10J\"\x04\x08r\x10z\"\ + \x06\x08\xdb\x01\x10\xdb\x01\"\x06\x08\xe0\x01\x10\xe0\x01\"\x06\x08\xac\ + \x02\x10\xb0\x02\"\x06\x08\xb5\x02\x10\xb8\x02:<\n\x07wire_in\x18\xd2\ + \x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\x06wireIn:>\ + \n\x08wire_out\x18\xd3\x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumVal\ + ueOptionsR\x07wireOut:G\n\rwire_debug_in\x18\xd4\x86\x03\x20\x01(\x08\ + \x12!.google.protobuf.EnumValueOptionsR\x0bwireDebugIn:I\n\x0ewire_debug\ + _out\x18\xd5\x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\ + \x0cwireDebugOut:@\n\twire_tiny\x18\xd6\x86\x03\x20\x01(\x08\x12!.google\ + .protobuf.EnumValueOptionsR\x08wireTiny:L\n\x0fwire_bootloader\x18\xd7\ + \x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\x0ewireBoot\ + loader:C\n\x0bwire_no_fsm\x18\xd8\x86\x03\x20\x01(\x08\x12!.google.proto\ + buf.EnumValueOptionsR\twireNoFsm:B\n\nchannel_in\x18\xd9\x86\x03\x20\x01\ + (\x08\x12!.google.protobuf.EnumValueOptionsR\tchannelIn:D\n\x0bchannel_o\ + ut\x18\xda\x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\n\ + channelOut:B\n\npairing_in\x18\xdb\x86\x03\x20\x01(\x08\x12!.google.prot\ + obuf.EnumValueOptionsR\tpairingIn:D\n\x0bpairing_out\x18\xdc\x86\x03\x20\ + \x01(\x08\x12!.google.protobuf.EnumValueOptionsR\npairingOut:F\n\x0cbitc\ + oin_only\x18\xe0\xd4\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOpti\ + onsR\x0bbitcoinOnly:U\n\x17has_bitcoin_only_values\x18\xb9\x8e\x03\x20\ + \x01(\x08\x12\x1c.google.protobuf.EnumOptionsR\x14hasBitcoinOnlyValues:T\ + \n\x14experimental_message\x18\xa1\x96\x03\x20\x01(\x08\x12\x1f.google.p\ + rotobuf.MessageOptionsR\x13experimentalMessage:>\n\twire_type\x18\xa2\ + \x96\x03\x20\x01(\r\x12\x1f.google.protobuf.MessageOptionsR\x08wireType:\ + F\n\rinternal_only\x18\xa3\x96\x03\x20\x01(\x08\x12\x1f.google.protobuf.\ + MessageOptionsR\x0cinternalOnly:N\n\x12experimental_field\x18\x89\x9e\ + \x03\x20\x01(\x08\x12\x1d.google.protobuf.FieldOptionsR\x11experimentalF\ + ield:U\n\x17include_in_bitcoin_only\x18\xe0\xd4\x03\x20\x01(\x08\x12\x1c\ + .google.protobuf.FileOptionsR\x14includeInBitcoinOnlyB8\n#com.satoshilab\ + s.trezor.lib.protobufB\rTrezorMessage\x80\xa6\x1d\x01\ "; /// `FileDescriptorProto` object which was a source for this generated file diff --git a/rust/trezor-client/src/protos/generated/messages_common.rs b/rust/trezor-client/src/protos/generated/messages_common.rs index 46405da6692..1d0f3fde43e 100644 --- a/rust/trezor-client/src/protos/generated/messages_common.rs +++ b/rust/trezor-client/src/protos/generated/messages_common.rs @@ -414,6 +414,10 @@ pub mod failure { Failure_WipeCodeMismatch = 13, // @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_InvalidSession) Failure_InvalidSession = 14, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_ThpUnallocatedSession) + Failure_ThpUnallocatedSession = 15, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_InvalidProtocol) + Failure_InvalidProtocol = 16, // @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_FirmwareError) Failure_FirmwareError = 99, } @@ -441,6 +445,8 @@ pub mod failure { 12 => ::std::option::Option::Some(FailureType::Failure_PinMismatch), 13 => ::std::option::Option::Some(FailureType::Failure_WipeCodeMismatch), 14 => ::std::option::Option::Some(FailureType::Failure_InvalidSession), + 15 => ::std::option::Option::Some(FailureType::Failure_ThpUnallocatedSession), + 16 => ::std::option::Option::Some(FailureType::Failure_InvalidProtocol), 99 => ::std::option::Option::Some(FailureType::Failure_FirmwareError), _ => ::std::option::Option::None } @@ -462,6 +468,8 @@ pub mod failure { "Failure_PinMismatch" => ::std::option::Option::Some(FailureType::Failure_PinMismatch), "Failure_WipeCodeMismatch" => ::std::option::Option::Some(FailureType::Failure_WipeCodeMismatch), "Failure_InvalidSession" => ::std::option::Option::Some(FailureType::Failure_InvalidSession), + "Failure_ThpUnallocatedSession" => ::std::option::Option::Some(FailureType::Failure_ThpUnallocatedSession), + "Failure_InvalidProtocol" => ::std::option::Option::Some(FailureType::Failure_InvalidProtocol), "Failure_FirmwareError" => ::std::option::Option::Some(FailureType::Failure_FirmwareError), _ => ::std::option::Option::None } @@ -482,6 +490,8 @@ pub mod failure { FailureType::Failure_PinMismatch, FailureType::Failure_WipeCodeMismatch, FailureType::Failure_InvalidSession, + FailureType::Failure_ThpUnallocatedSession, + FailureType::Failure_InvalidProtocol, FailureType::Failure_FirmwareError, ]; } @@ -508,7 +518,9 @@ pub mod failure { FailureType::Failure_PinMismatch => 11, FailureType::Failure_WipeCodeMismatch => 12, FailureType::Failure_InvalidSession => 13, - FailureType::Failure_FirmwareError => 14, + FailureType::Failure_ThpUnallocatedSession => 14, + FailureType::Failure_InvalidProtocol => 15, + FailureType::Failure_FirmwareError => 16, }; Self::enum_descriptor().value_by_index(index) } @@ -2481,9 +2493,9 @@ impl ::protobuf::reflect::ProtobufValue for HDNodeType { static file_descriptor_proto_data: &'static [u8] = b"\ \n\x15messages-common.proto\x12\x19hw.trezor.messages.common\x1a\x0emess\ ages.proto\"%\n\x07Success\x12\x1a\n\x07message\x18\x01\x20\x01(\t:\0R\ - \x07message\"\x8f\x04\n\x07Failure\x12B\n\x04code\x18\x01\x20\x01(\x0e2.\ + \x07message\"\xcf\x04\n\x07Failure\x12B\n\x04code\x18\x01\x20\x01(\x0e2.\ .hw.trezor.messages.common.Failure.FailureTypeR\x04code\x12\x18\n\x07mes\ - sage\x18\x02\x20\x01(\tR\x07message\"\xa5\x03\n\x0bFailureType\x12\x1d\n\ + sage\x18\x02\x20\x01(\tR\x07message\"\xe5\x03\n\x0bFailureType\x12\x1d\n\ \x19Failure_UnexpectedMessage\x10\x01\x12\x1a\n\x16Failure_ButtonExpecte\ d\x10\x02\x12\x15\n\x11Failure_DataError\x10\x03\x12\x1b\n\x17Failure_Ac\ tionCancelled\x10\x04\x12\x17\n\x13Failure_PinExpected\x10\x05\x12\x18\n\ @@ -2492,44 +2504,45 @@ static file_descriptor_proto_data: &'static [u8] = b"\ essError\x10\t\x12\x1a\n\x16Failure_NotEnoughFunds\x10\n\x12\x1a\n\x16Fa\ ilure_NotInitialized\x10\x0b\x12\x17\n\x13Failure_PinMismatch\x10\x0c\ \x12\x1c\n\x18Failure_WipeCodeMismatch\x10\r\x12\x1a\n\x16Failure_Invali\ - dSession\x10\x0e\x12\x19\n\x15Failure_FirmwareError\x10c\"\xab\x06\n\rBu\ - ttonRequest\x12N\n\x04code\x18\x01\x20\x01(\x0e2:.hw.trezor.messages.com\ - mon.ButtonRequest.ButtonRequestTypeR\x04code\x12\x14\n\x05pages\x18\x02\ - \x20\x01(\rR\x05pages\x12\x12\n\x04name\x18\x04\x20\x01(\tR\x04name\"\ - \x99\x05\n\x11ButtonRequestType\x12\x17\n\x13ButtonRequest_Other\x10\x01\ - \x12\"\n\x1eButtonRequest_FeeOverThreshold\x10\x02\x12\x1f\n\x1bButtonRe\ - quest_ConfirmOutput\x10\x03\x12\x1d\n\x19ButtonRequest_ResetDevice\x10\ - \x04\x12\x1d\n\x19ButtonRequest_ConfirmWord\x10\x05\x12\x1c\n\x18ButtonR\ - equest_WipeDevice\x10\x06\x12\x1d\n\x19ButtonRequest_ProtectCall\x10\x07\ - \x12\x18\n\x14ButtonRequest_SignTx\x10\x08\x12\x1f\n\x1bButtonRequest_Fi\ - rmwareCheck\x10\t\x12\x19\n\x15ButtonRequest_Address\x10\n\x12\x1b\n\x17\ - ButtonRequest_PublicKey\x10\x0b\x12#\n\x1fButtonRequest_MnemonicWordCoun\ - t\x10\x0c\x12\x1f\n\x1bButtonRequest_MnemonicInput\x10\r\x120\n(_Depreca\ - ted_ButtonRequest_PassphraseType\x10\x0e\x1a\x02\x08\x01\x12'\n#ButtonRe\ - quest_UnknownDerivationPath\x10\x0f\x12\"\n\x1eButtonRequest_RecoveryHom\ - epage\x10\x10\x12\x19\n\x15ButtonRequest_Success\x10\x11\x12\x19\n\x15Bu\ - ttonRequest_Warning\x10\x12\x12!\n\x1dButtonRequest_PassphraseEntry\x10\ - \x13\x12\x1a\n\x16ButtonRequest_PinEntry\x10\x14J\x04\x08\x03\x10\x04\"\ - \x0b\n\tButtonAck\"\xbb\x02\n\x10PinMatrixRequest\x12T\n\x04type\x18\x01\ - \x20\x01(\x0e2@.hw.trezor.messages.common.PinMatrixRequest.PinMatrixRequ\ - estTypeR\x04type\"\xd0\x01\n\x14PinMatrixRequestType\x12\x20\n\x1cPinMat\ - rixRequestType_Current\x10\x01\x12!\n\x1dPinMatrixRequestType_NewFirst\ - \x10\x02\x12\"\n\x1ePinMatrixRequestType_NewSecond\x10\x03\x12&\n\"PinMa\ - trixRequestType_WipeCodeFirst\x10\x04\x12'\n#PinMatrixRequestType_WipeCo\ - deSecond\x10\x05\"\x20\n\x0cPinMatrixAck\x12\x10\n\x03pin\x18\x01\x20\ - \x02(\tR\x03pin\"5\n\x11PassphraseRequest\x12\x20\n\n_on_device\x18\x01\ - \x20\x01(\x08R\x08OnDeviceB\x02\x18\x01\"g\n\rPassphraseAck\x12\x1e\n\np\ - assphrase\x18\x01\x20\x01(\tR\npassphrase\x12\x19\n\x06_state\x18\x02\ - \x20\x01(\x0cR\x05StateB\x02\x18\x01\x12\x1b\n\ton_device\x18\x03\x20\ - \x01(\x08R\x08onDevice\"=\n!Deprecated_PassphraseStateRequest\x12\x14\n\ - \x05state\x18\x01\x20\x01(\x0cR\x05state:\x02\x18\x01\"#\n\x1dDeprecated\ - _PassphraseStateAck:\x02\x18\x01\"\xc0\x01\n\nHDNodeType\x12\x14\n\x05de\ - pth\x18\x01\x20\x02(\rR\x05depth\x12\x20\n\x0bfingerprint\x18\x02\x20\ - \x02(\rR\x0bfingerprint\x12\x1b\n\tchild_num\x18\x03\x20\x02(\rR\x08chil\ - dNum\x12\x1d\n\nchain_code\x18\x04\x20\x02(\x0cR\tchainCode\x12\x1f\n\ - \x0bprivate_key\x18\x05\x20\x01(\x0cR\nprivateKey\x12\x1d\n\npublic_key\ - \x18\x06\x20\x02(\x0cR\tpublicKeyB>\n#com.satoshilabs.trezor.lib.protobu\ - fB\x13TrezorMessageCommon\x80\xa6\x1d\x01\ + dSession\x10\x0e\x12!\n\x1dFailure_ThpUnallocatedSession\x10\x0f\x12\x1b\ + \n\x17Failure_InvalidProtocol\x10\x10\x12\x19\n\x15Failure_FirmwareError\ + \x10c\"\xab\x06\n\rButtonRequest\x12N\n\x04code\x18\x01\x20\x01(\x0e2:.h\ + w.trezor.messages.common.ButtonRequest.ButtonRequestTypeR\x04code\x12\ + \x14\n\x05pages\x18\x02\x20\x01(\rR\x05pages\x12\x12\n\x04name\x18\x04\ + \x20\x01(\tR\x04name\"\x99\x05\n\x11ButtonRequestType\x12\x17\n\x13Butto\ + nRequest_Other\x10\x01\x12\"\n\x1eButtonRequest_FeeOverThreshold\x10\x02\ + \x12\x1f\n\x1bButtonRequest_ConfirmOutput\x10\x03\x12\x1d\n\x19ButtonReq\ + uest_ResetDevice\x10\x04\x12\x1d\n\x19ButtonRequest_ConfirmWord\x10\x05\ + \x12\x1c\n\x18ButtonRequest_WipeDevice\x10\x06\x12\x1d\n\x19ButtonReques\ + t_ProtectCall\x10\x07\x12\x18\n\x14ButtonRequest_SignTx\x10\x08\x12\x1f\ + \n\x1bButtonRequest_FirmwareCheck\x10\t\x12\x19\n\x15ButtonRequest_Addre\ + ss\x10\n\x12\x1b\n\x17ButtonRequest_PublicKey\x10\x0b\x12#\n\x1fButtonRe\ + quest_MnemonicWordCount\x10\x0c\x12\x1f\n\x1bButtonRequest_MnemonicInput\ + \x10\r\x120\n(_Deprecated_ButtonRequest_PassphraseType\x10\x0e\x1a\x02\ + \x08\x01\x12'\n#ButtonRequest_UnknownDerivationPath\x10\x0f\x12\"\n\x1eB\ + uttonRequest_RecoveryHomepage\x10\x10\x12\x19\n\x15ButtonRequest_Success\ + \x10\x11\x12\x19\n\x15ButtonRequest_Warning\x10\x12\x12!\n\x1dButtonRequ\ + est_PassphraseEntry\x10\x13\x12\x1a\n\x16ButtonRequest_PinEntry\x10\x14J\ + \x04\x08\x03\x10\x04\"\x0b\n\tButtonAck\"\xbb\x02\n\x10PinMatrixRequest\ + \x12T\n\x04type\x18\x01\x20\x01(\x0e2@.hw.trezor.messages.common.PinMatr\ + ixRequest.PinMatrixRequestTypeR\x04type\"\xd0\x01\n\x14PinMatrixRequestT\ + ype\x12\x20\n\x1cPinMatrixRequestType_Current\x10\x01\x12!\n\x1dPinMatri\ + xRequestType_NewFirst\x10\x02\x12\"\n\x1ePinMatrixRequestType_NewSecond\ + \x10\x03\x12&\n\"PinMatrixRequestType_WipeCodeFirst\x10\x04\x12'\n#PinMa\ + trixRequestType_WipeCodeSecond\x10\x05\"\x20\n\x0cPinMatrixAck\x12\x10\n\ + \x03pin\x18\x01\x20\x02(\tR\x03pin\"5\n\x11PassphraseRequest\x12\x20\n\n\ + _on_device\x18\x01\x20\x01(\x08R\x08OnDeviceB\x02\x18\x01\"g\n\rPassphra\ + seAck\x12\x1e\n\npassphrase\x18\x01\x20\x01(\tR\npassphrase\x12\x19\n\ + \x06_state\x18\x02\x20\x01(\x0cR\x05StateB\x02\x18\x01\x12\x1b\n\ton_dev\ + ice\x18\x03\x20\x01(\x08R\x08onDevice\"=\n!Deprecated_PassphraseStateReq\ + uest\x12\x14\n\x05state\x18\x01\x20\x01(\x0cR\x05state:\x02\x18\x01\"#\n\ + \x1dDeprecated_PassphraseStateAck:\x02\x18\x01\"\xc0\x01\n\nHDNodeType\ + \x12\x14\n\x05depth\x18\x01\x20\x02(\rR\x05depth\x12\x20\n\x0bfingerprin\ + t\x18\x02\x20\x02(\rR\x0bfingerprint\x12\x1b\n\tchild_num\x18\x03\x20\ + \x02(\rR\x08childNum\x12\x1d\n\nchain_code\x18\x04\x20\x02(\x0cR\tchainC\ + ode\x12\x1f\n\x0bprivate_key\x18\x05\x20\x01(\x0cR\nprivateKey\x12\x1d\n\ + \npublic_key\x18\x06\x20\x02(\x0cR\tpublicKeyB>\n#com.satoshilabs.trezor\ + .lib.protobufB\x13TrezorMessageCommon\x80\xa6\x1d\x01\ "; /// `FileDescriptorProto` object which was a source for this generated file diff --git a/rust/trezor-client/src/protos/generated/messages_debug.rs b/rust/trezor-client/src/protos/generated/messages_debug.rs index 3ced865e262..4021a1d098c 100644 --- a/rust/trezor-client/src/protos/generated/messages_debug.rs +++ b/rust/trezor-client/src/protos/generated/messages_debug.rs @@ -1128,6 +1128,8 @@ pub struct DebugLinkGetState { pub wait_word_pos: ::std::option::Option, // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkGetState.wait_layout) pub wait_layout: ::std::option::Option<::protobuf::EnumOrUnknown>, + // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkGetState.thp_channel_id) + pub thp_channel_id: ::std::option::Option<::std::vec::Vec>, // special fields // @@protoc_insertion_point(special_field:hw.trezor.messages.debug.DebugLinkGetState.special_fields) pub special_fields: ::protobuf::SpecialFields, @@ -1204,8 +1206,44 @@ impl DebugLinkGetState { self.wait_layout = ::std::option::Option::Some(::protobuf::EnumOrUnknown::new(v)); } + // optional bytes thp_channel_id = 4; + + pub fn thp_channel_id(&self) -> &[u8] { + match self.thp_channel_id.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_thp_channel_id(&mut self) { + self.thp_channel_id = ::std::option::Option::None; + } + + pub fn has_thp_channel_id(&self) -> bool { + self.thp_channel_id.is_some() + } + + // Param is passed by value, moved + pub fn set_thp_channel_id(&mut self, v: ::std::vec::Vec) { + self.thp_channel_id = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_thp_channel_id(&mut self) -> &mut ::std::vec::Vec { + if self.thp_channel_id.is_none() { + self.thp_channel_id = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.thp_channel_id.as_mut().unwrap() + } + + // Take field + pub fn take_thp_channel_id(&mut self) -> ::std::vec::Vec { + self.thp_channel_id.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { - let mut fields = ::std::vec::Vec::with_capacity(3); + let mut fields = ::std::vec::Vec::with_capacity(4); let mut oneofs = ::std::vec::Vec::with_capacity(0); fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( "wait_word_list", @@ -1222,6 +1260,11 @@ impl DebugLinkGetState { |m: &DebugLinkGetState| { &m.wait_layout }, |m: &mut DebugLinkGetState| { &mut m.wait_layout }, )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "thp_channel_id", + |m: &DebugLinkGetState| { &m.thp_channel_id }, + |m: &mut DebugLinkGetState| { &mut m.thp_channel_id }, + )); ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( "DebugLinkGetState", fields, @@ -1249,6 +1292,9 @@ impl ::protobuf::Message for DebugLinkGetState { 24 => { self.wait_layout = ::std::option::Option::Some(is.read_enum_or_unknown()?); }, + 34 => { + self.thp_channel_id = ::std::option::Option::Some(is.read_bytes()?); + }, tag => { ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; }, @@ -1270,6 +1316,9 @@ impl ::protobuf::Message for DebugLinkGetState { if let Some(v) = self.wait_layout { my_size += ::protobuf::rt::int32_size(3, v.value()); } + if let Some(v) = self.thp_channel_id.as_ref() { + my_size += ::protobuf::rt::bytes_size(4, &v); + } my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); self.special_fields.cached_size().set(my_size as u32); my_size @@ -1285,6 +1334,9 @@ impl ::protobuf::Message for DebugLinkGetState { if let Some(v) = self.wait_layout { os.write_enum(3, ::protobuf::EnumOrUnknown::value(&v))?; } + if let Some(v) = self.thp_channel_id.as_ref() { + os.write_bytes(4, v)?; + } os.write_unknown_fields(self.special_fields.unknown_fields())?; ::std::result::Result::Ok(()) } @@ -1305,6 +1357,7 @@ impl ::protobuf::Message for DebugLinkGetState { self.wait_word_list = ::std::option::Option::None; self.wait_word_pos = ::std::option::Option::None; self.wait_layout = ::std::option::Option::None; + self.thp_channel_id = ::std::option::Option::None; self.special_fields.clear(); } @@ -1313,6 +1366,7 @@ impl ::protobuf::Message for DebugLinkGetState { wait_word_list: ::std::option::Option::None, wait_word_pos: ::std::option::Option::None, wait_layout: ::std::option::Option::None, + thp_channel_id: ::std::option::Option::None, special_fields: ::protobuf::SpecialFields::new(), }; &instance @@ -1436,6 +1490,12 @@ pub struct DebugLinkState { pub mnemonic_type: ::std::option::Option<::protobuf::EnumOrUnknown>, // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkState.tokens) pub tokens: ::std::vec::Vec<::std::string::String>, + // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkState.thp_pairing_code_entry_code) + pub thp_pairing_code_entry_code: ::std::option::Option, + // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkState.thp_pairing_code_qr_code) + pub thp_pairing_code_qr_code: ::std::option::Option<::std::vec::Vec>, + // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkState.thp_pairing_code_nfc_unidirectional) + pub thp_pairing_code_nfc_unidirectional: ::std::option::Option<::std::vec::Vec>, // special fields // @@protoc_insertion_point(special_field:hw.trezor.messages.debug.DebugLinkState.special_fields) pub special_fields: ::protobuf::SpecialFields, @@ -1783,8 +1843,99 @@ impl DebugLinkState { self.mnemonic_type = ::std::option::Option::Some(::protobuf::EnumOrUnknown::new(v)); } + // optional uint32 thp_pairing_code_entry_code = 14; + + pub fn thp_pairing_code_entry_code(&self) -> u32 { + self.thp_pairing_code_entry_code.unwrap_or(0) + } + + pub fn clear_thp_pairing_code_entry_code(&mut self) { + self.thp_pairing_code_entry_code = ::std::option::Option::None; + } + + pub fn has_thp_pairing_code_entry_code(&self) -> bool { + self.thp_pairing_code_entry_code.is_some() + } + + // Param is passed by value, moved + pub fn set_thp_pairing_code_entry_code(&mut self, v: u32) { + self.thp_pairing_code_entry_code = ::std::option::Option::Some(v); + } + + // optional bytes thp_pairing_code_qr_code = 15; + + pub fn thp_pairing_code_qr_code(&self) -> &[u8] { + match self.thp_pairing_code_qr_code.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_thp_pairing_code_qr_code(&mut self) { + self.thp_pairing_code_qr_code = ::std::option::Option::None; + } + + pub fn has_thp_pairing_code_qr_code(&self) -> bool { + self.thp_pairing_code_qr_code.is_some() + } + + // Param is passed by value, moved + pub fn set_thp_pairing_code_qr_code(&mut self, v: ::std::vec::Vec) { + self.thp_pairing_code_qr_code = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_thp_pairing_code_qr_code(&mut self) -> &mut ::std::vec::Vec { + if self.thp_pairing_code_qr_code.is_none() { + self.thp_pairing_code_qr_code = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.thp_pairing_code_qr_code.as_mut().unwrap() + } + + // Take field + pub fn take_thp_pairing_code_qr_code(&mut self) -> ::std::vec::Vec { + self.thp_pairing_code_qr_code.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + // optional bytes thp_pairing_code_nfc_unidirectional = 16; + + pub fn thp_pairing_code_nfc_unidirectional(&self) -> &[u8] { + match self.thp_pairing_code_nfc_unidirectional.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_thp_pairing_code_nfc_unidirectional(&mut self) { + self.thp_pairing_code_nfc_unidirectional = ::std::option::Option::None; + } + + pub fn has_thp_pairing_code_nfc_unidirectional(&self) -> bool { + self.thp_pairing_code_nfc_unidirectional.is_some() + } + + // Param is passed by value, moved + pub fn set_thp_pairing_code_nfc_unidirectional(&mut self, v: ::std::vec::Vec) { + self.thp_pairing_code_nfc_unidirectional = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_thp_pairing_code_nfc_unidirectional(&mut self) -> &mut ::std::vec::Vec { + if self.thp_pairing_code_nfc_unidirectional.is_none() { + self.thp_pairing_code_nfc_unidirectional = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.thp_pairing_code_nfc_unidirectional.as_mut().unwrap() + } + + // Take field + pub fn take_thp_pairing_code_nfc_unidirectional(&mut self) -> ::std::vec::Vec { + self.thp_pairing_code_nfc_unidirectional.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { - let mut fields = ::std::vec::Vec::with_capacity(13); + let mut fields = ::std::vec::Vec::with_capacity(16); let mut oneofs = ::std::vec::Vec::with_capacity(0); fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( "layout", @@ -1851,6 +2002,21 @@ impl DebugLinkState { |m: &DebugLinkState| { &m.tokens }, |m: &mut DebugLinkState| { &mut m.tokens }, )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "thp_pairing_code_entry_code", + |m: &DebugLinkState| { &m.thp_pairing_code_entry_code }, + |m: &mut DebugLinkState| { &mut m.thp_pairing_code_entry_code }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "thp_pairing_code_qr_code", + |m: &DebugLinkState| { &m.thp_pairing_code_qr_code }, + |m: &mut DebugLinkState| { &mut m.thp_pairing_code_qr_code }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "thp_pairing_code_nfc_unidirectional", + |m: &DebugLinkState| { &m.thp_pairing_code_nfc_unidirectional }, + |m: &mut DebugLinkState| { &mut m.thp_pairing_code_nfc_unidirectional }, + )); ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( "DebugLinkState", fields, @@ -1913,6 +2079,15 @@ impl ::protobuf::Message for DebugLinkState { 106 => { self.tokens.push(is.read_string()?); }, + 112 => { + self.thp_pairing_code_entry_code = ::std::option::Option::Some(is.read_uint32()?); + }, + 122 => { + self.thp_pairing_code_qr_code = ::std::option::Option::Some(is.read_bytes()?); + }, + 130 => { + self.thp_pairing_code_nfc_unidirectional = ::std::option::Option::Some(is.read_bytes()?); + }, tag => { ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; }, @@ -1965,6 +2140,15 @@ impl ::protobuf::Message for DebugLinkState { for value in &self.tokens { my_size += ::protobuf::rt::string_size(13, &value); }; + if let Some(v) = self.thp_pairing_code_entry_code { + my_size += ::protobuf::rt::uint32_size(14, v); + } + if let Some(v) = self.thp_pairing_code_qr_code.as_ref() { + my_size += ::protobuf::rt::bytes_size(15, &v); + } + if let Some(v) = self.thp_pairing_code_nfc_unidirectional.as_ref() { + my_size += ::protobuf::rt::bytes_size(16, &v); + } my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); self.special_fields.cached_size().set(my_size as u32); my_size @@ -2010,6 +2194,15 @@ impl ::protobuf::Message for DebugLinkState { for v in &self.tokens { os.write_string(13, &v)?; }; + if let Some(v) = self.thp_pairing_code_entry_code { + os.write_uint32(14, v)?; + } + if let Some(v) = self.thp_pairing_code_qr_code.as_ref() { + os.write_bytes(15, v)?; + } + if let Some(v) = self.thp_pairing_code_nfc_unidirectional.as_ref() { + os.write_bytes(16, v)?; + } os.write_unknown_fields(self.special_fields.unknown_fields())?; ::std::result::Result::Ok(()) } @@ -2040,6 +2233,9 @@ impl ::protobuf::Message for DebugLinkState { self.reset_word_pos = ::std::option::Option::None; self.mnemonic_type = ::std::option::Option::None; self.tokens.clear(); + self.thp_pairing_code_entry_code = ::std::option::Option::None; + self.thp_pairing_code_qr_code = ::std::option::Option::None; + self.thp_pairing_code_nfc_unidirectional = ::std::option::Option::None; self.special_fields.clear(); } @@ -2058,6 +2254,9 @@ impl ::protobuf::Message for DebugLinkState { reset_word_pos: ::std::option::Option::None, mnemonic_type: ::std::option::Option::None, tokens: ::std::vec::Vec::new(), + thp_pairing_code_entry_code: ::std::option::Option::None, + thp_pairing_code_qr_code: ::std::option::Option::None, + thp_pairing_code_nfc_unidirectional: ::std::option::Option::None, special_fields: ::protobuf::SpecialFields::new(), }; &instance @@ -3650,39 +3849,44 @@ static file_descriptor_proto_data: &'static [u8] = b"\ \x01\x20\x03(\tR\x06tokens:\x02\x18\x01\"-\n\x15DebugLinkReseedRandom\ \x12\x14\n\x05value\x18\x01\x20\x01(\rR\x05value\"j\n\x15DebugLinkRecord\ Screen\x12)\n\x10target_directory\x18\x01\x20\x01(\tR\x0ftargetDirectory\ - \x12&\n\rrefresh_index\x18\x02\x20\x01(\r:\x010R\x0crefreshIndex\"\x91\ + \x12&\n\rrefresh_index\x18\x02\x20\x01(\r:\x010R\x0crefreshIndex\"\xb7\ \x02\n\x11DebugLinkGetState\x12(\n\x0ewait_word_list\x18\x01\x20\x01(\ \x08R\x0cwaitWordListB\x02\x18\x01\x12&\n\rwait_word_pos\x18\x02\x20\x01\ (\x08R\x0bwaitWordPosB\x02\x18\x01\x12e\n\x0bwait_layout\x18\x03\x20\x01\ (\x0e29.hw.trezor.messages.debug.DebugLinkGetState.DebugWaitType:\tIMMED\ - IATER\nwaitLayout\"C\n\rDebugWaitType\x12\r\n\tIMMEDIATE\x10\0\x12\x0f\n\ - \x0bNEXT_LAYOUT\x10\x01\x12\x12\n\x0eCURRENT_LAYOUT\x10\x02\"\x97\x04\n\ - \x0eDebugLinkState\x12\x16\n\x06layout\x18\x01\x20\x01(\x0cR\x06layout\ - \x12\x10\n\x03pin\x18\x02\x20\x01(\tR\x03pin\x12\x16\n\x06matrix\x18\x03\ - \x20\x01(\tR\x06matrix\x12'\n\x0fmnemonic_secret\x18\x04\x20\x01(\x0cR\ - \x0emnemonicSecret\x129\n\x04node\x18\x05\x20\x01(\x0b2%.hw.trezor.messa\ - ges.common.HDNodeTypeR\x04node\x123\n\x15passphrase_protection\x18\x06\ - \x20\x01(\x08R\x14passphraseProtection\x12\x1d\n\nreset_word\x18\x07\x20\ - \x01(\tR\tresetWord\x12#\n\rreset_entropy\x18\x08\x20\x01(\x0cR\x0creset\ - Entropy\x12,\n\x12recovery_fake_word\x18\t\x20\x01(\tR\x10recoveryFakeWo\ - rd\x12*\n\x11recovery_word_pos\x18\n\x20\x01(\rR\x0frecoveryWordPos\x12$\ - \n\x0ereset_word_pos\x18\x0b\x20\x01(\rR\x0cresetWordPos\x12N\n\rmnemoni\ - c_type\x18\x0c\x20\x01(\x0e2).hw.trezor.messages.management.BackupTypeR\ - \x0cmnemonicType\x12\x16\n\x06tokens\x18\r\x20\x03(\tR\x06tokens\"\x0f\n\ - \rDebugLinkStop\"P\n\x0cDebugLinkLog\x12\x14\n\x05level\x18\x01\x20\x01(\ - \rR\x05level\x12\x16\n\x06bucket\x18\x02\x20\x01(\tR\x06bucket\x12\x12\n\ - \x04text\x18\x03\x20\x01(\tR\x04text\"G\n\x13DebugLinkMemoryRead\x12\x18\ - \n\x07address\x18\x01\x20\x01(\rR\x07address\x12\x16\n\x06length\x18\x02\ - \x20\x01(\rR\x06length\")\n\x0fDebugLinkMemory\x12\x16\n\x06memory\x18\ - \x01\x20\x01(\x0cR\x06memory\"^\n\x14DebugLinkMemoryWrite\x12\x18\n\x07a\ - ddress\x18\x01\x20\x01(\rR\x07address\x12\x16\n\x06memory\x18\x02\x20\ - \x01(\x0cR\x06memory\x12\x14\n\x05flash\x18\x03\x20\x01(\x08R\x05flash\"\ - -\n\x13DebugLinkFlashErase\x12\x16\n\x06sector\x18\x01\x20\x01(\rR\x06se\ - ctor\".\n\x14DebugLinkEraseSdCard\x12\x16\n\x06format\x18\x01\x20\x01(\ - \x08R\x06format\"0\n\x14DebugLinkWatchLayout\x12\x14\n\x05watch\x18\x01\ - \x20\x01(\x08R\x05watch:\x02\x18\x01\"\x1f\n\x19DebugLinkResetDebugEvent\ - s:\x02\x18\x01\"\x1a\n\x18DebugLinkOptigaSetSecMaxB=\n#com.satoshilabs.t\ - rezor.lib.protobufB\x12TrezorMessageDebug\x80\xa6\x1d\x01\ + IATER\nwaitLayout\x12$\n\x0ethp_channel_id\x18\x04\x20\x01(\x0cR\x0cthpC\ + hannelId\"C\n\rDebugWaitType\x12\r\n\tIMMEDIATE\x10\0\x12\x0f\n\x0bNEXT_\ + LAYOUT\x10\x01\x12\x12\n\x0eCURRENT_LAYOUT\x10\x02\"\xdb\x05\n\x0eDebugL\ + inkState\x12\x16\n\x06layout\x18\x01\x20\x01(\x0cR\x06layout\x12\x10\n\ + \x03pin\x18\x02\x20\x01(\tR\x03pin\x12\x16\n\x06matrix\x18\x03\x20\x01(\ + \tR\x06matrix\x12'\n\x0fmnemonic_secret\x18\x04\x20\x01(\x0cR\x0emnemoni\ + cSecret\x129\n\x04node\x18\x05\x20\x01(\x0b2%.hw.trezor.messages.common.\ + HDNodeTypeR\x04node\x123\n\x15passphrase_protection\x18\x06\x20\x01(\x08\ + R\x14passphraseProtection\x12\x1d\n\nreset_word\x18\x07\x20\x01(\tR\tres\ + etWord\x12#\n\rreset_entropy\x18\x08\x20\x01(\x0cR\x0cresetEntropy\x12,\ + \n\x12recovery_fake_word\x18\t\x20\x01(\tR\x10recoveryFakeWord\x12*\n\ + \x11recovery_word_pos\x18\n\x20\x01(\rR\x0frecoveryWordPos\x12$\n\x0eres\ + et_word_pos\x18\x0b\x20\x01(\rR\x0cresetWordPos\x12N\n\rmnemonic_type\ + \x18\x0c\x20\x01(\x0e2).hw.trezor.messages.management.BackupTypeR\x0cmne\ + monicType\x12\x16\n\x06tokens\x18\r\x20\x03(\tR\x06tokens\x12<\n\x1bthp_\ + pairing_code_entry_code\x18\x0e\x20\x01(\rR\x17thpPairingCodeEntryCode\ + \x126\n\x18thp_pairing_code_qr_code\x18\x0f\x20\x01(\x0cR\x14thpPairingC\ + odeQrCode\x12L\n#thp_pairing_code_nfc_unidirectional\x18\x10\x20\x01(\ + \x0cR\x1fthpPairingCodeNfcUnidirectional\"\x0f\n\rDebugLinkStop\"P\n\x0c\ + DebugLinkLog\x12\x14\n\x05level\x18\x01\x20\x01(\rR\x05level\x12\x16\n\ + \x06bucket\x18\x02\x20\x01(\tR\x06bucket\x12\x12\n\x04text\x18\x03\x20\ + \x01(\tR\x04text\"G\n\x13DebugLinkMemoryRead\x12\x18\n\x07address\x18\ + \x01\x20\x01(\rR\x07address\x12\x16\n\x06length\x18\x02\x20\x01(\rR\x06l\ + ength\")\n\x0fDebugLinkMemory\x12\x16\n\x06memory\x18\x01\x20\x01(\x0cR\ + \x06memory\"^\n\x14DebugLinkMemoryWrite\x12\x18\n\x07address\x18\x01\x20\ + \x01(\rR\x07address\x12\x16\n\x06memory\x18\x02\x20\x01(\x0cR\x06memory\ + \x12\x14\n\x05flash\x18\x03\x20\x01(\x08R\x05flash\"-\n\x13DebugLinkFlas\ + hErase\x12\x16\n\x06sector\x18\x01\x20\x01(\rR\x06sector\".\n\x14DebugLi\ + nkEraseSdCard\x12\x16\n\x06format\x18\x01\x20\x01(\x08R\x06format\"0\n\ + \x14DebugLinkWatchLayout\x12\x14\n\x05watch\x18\x01\x20\x01(\x08R\x05wat\ + ch:\x02\x18\x01\"\x1f\n\x19DebugLinkResetDebugEvents:\x02\x18\x01\"\x1a\ + \n\x18DebugLinkOptigaSetSecMaxB=\n#com.satoshilabs.trezor.lib.protobufB\ + \x12TrezorMessageDebug\x80\xa6\x1d\x01\ "; /// `FileDescriptorProto` object which was a source for this generated file diff --git a/rust/trezor-client/src/protos/generated/messages_thp.rs b/rust/trezor-client/src/protos/generated/messages_thp.rs index b42c245b3c8..6c5d1c92f38 100644 --- a/rust/trezor-client/src/protos/generated/messages_thp.rs +++ b/rust/trezor-client/src/protos/generated/messages_thp.rs @@ -25,6 +25,3265 @@ /// of protobuf runtime. const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_3_3_0; +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpDeviceProperties) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpDeviceProperties { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpDeviceProperties.internal_model) + pub internal_model: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpDeviceProperties.model_variant) + pub model_variant: ::std::option::Option, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpDeviceProperties.bootloader_mode) + pub bootloader_mode: ::std::option::Option, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpDeviceProperties.protocol_version) + pub protocol_version: ::std::option::Option, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpDeviceProperties.pairing_methods) + pub pairing_methods: ::std::vec::Vec<::protobuf::EnumOrUnknown>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpDeviceProperties.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpDeviceProperties { + fn default() -> &'a ThpDeviceProperties { + ::default_instance() + } +} + +impl ThpDeviceProperties { + pub fn new() -> ThpDeviceProperties { + ::std::default::Default::default() + } + + // optional string internal_model = 1; + + pub fn internal_model(&self) -> &str { + match self.internal_model.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_internal_model(&mut self) { + self.internal_model = ::std::option::Option::None; + } + + pub fn has_internal_model(&self) -> bool { + self.internal_model.is_some() + } + + // Param is passed by value, moved + pub fn set_internal_model(&mut self, v: ::std::string::String) { + self.internal_model = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_internal_model(&mut self) -> &mut ::std::string::String { + if self.internal_model.is_none() { + self.internal_model = ::std::option::Option::Some(::std::string::String::new()); + } + self.internal_model.as_mut().unwrap() + } + + // Take field + pub fn take_internal_model(&mut self) -> ::std::string::String { + self.internal_model.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional uint32 model_variant = 2; + + pub fn model_variant(&self) -> u32 { + self.model_variant.unwrap_or(0) + } + + pub fn clear_model_variant(&mut self) { + self.model_variant = ::std::option::Option::None; + } + + pub fn has_model_variant(&self) -> bool { + self.model_variant.is_some() + } + + // Param is passed by value, moved + pub fn set_model_variant(&mut self, v: u32) { + self.model_variant = ::std::option::Option::Some(v); + } + + // optional bool bootloader_mode = 3; + + pub fn bootloader_mode(&self) -> bool { + self.bootloader_mode.unwrap_or(false) + } + + pub fn clear_bootloader_mode(&mut self) { + self.bootloader_mode = ::std::option::Option::None; + } + + pub fn has_bootloader_mode(&self) -> bool { + self.bootloader_mode.is_some() + } + + // Param is passed by value, moved + pub fn set_bootloader_mode(&mut self, v: bool) { + self.bootloader_mode = ::std::option::Option::Some(v); + } + + // optional uint32 protocol_version = 4; + + pub fn protocol_version(&self) -> u32 { + self.protocol_version.unwrap_or(0) + } + + pub fn clear_protocol_version(&mut self) { + self.protocol_version = ::std::option::Option::None; + } + + pub fn has_protocol_version(&self) -> bool { + self.protocol_version.is_some() + } + + // Param is passed by value, moved + pub fn set_protocol_version(&mut self, v: u32) { + self.protocol_version = ::std::option::Option::Some(v); + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(5); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "internal_model", + |m: &ThpDeviceProperties| { &m.internal_model }, + |m: &mut ThpDeviceProperties| { &mut m.internal_model }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "model_variant", + |m: &ThpDeviceProperties| { &m.model_variant }, + |m: &mut ThpDeviceProperties| { &mut m.model_variant }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "bootloader_mode", + |m: &ThpDeviceProperties| { &m.bootloader_mode }, + |m: &mut ThpDeviceProperties| { &mut m.bootloader_mode }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "protocol_version", + |m: &ThpDeviceProperties| { &m.protocol_version }, + |m: &mut ThpDeviceProperties| { &mut m.protocol_version }, + )); + fields.push(::protobuf::reflect::rt::v2::make_vec_simpler_accessor::<_, _>( + "pairing_methods", + |m: &ThpDeviceProperties| { &m.pairing_methods }, + |m: &mut ThpDeviceProperties| { &mut m.pairing_methods }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpDeviceProperties", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpDeviceProperties { + const NAME: &'static str = "ThpDeviceProperties"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.internal_model = ::std::option::Option::Some(is.read_string()?); + }, + 16 => { + self.model_variant = ::std::option::Option::Some(is.read_uint32()?); + }, + 24 => { + self.bootloader_mode = ::std::option::Option::Some(is.read_bool()?); + }, + 32 => { + self.protocol_version = ::std::option::Option::Some(is.read_uint32()?); + }, + 40 => { + self.pairing_methods.push(is.read_enum_or_unknown()?); + }, + 42 => { + ::protobuf::rt::read_repeated_packed_enum_or_unknown_into(is, &mut self.pairing_methods)? + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.internal_model.as_ref() { + my_size += ::protobuf::rt::string_size(1, &v); + } + if let Some(v) = self.model_variant { + my_size += ::protobuf::rt::uint32_size(2, v); + } + if let Some(v) = self.bootloader_mode { + my_size += 1 + 1; + } + if let Some(v) = self.protocol_version { + my_size += ::protobuf::rt::uint32_size(4, v); + } + for value in &self.pairing_methods { + my_size += ::protobuf::rt::int32_size(5, value.value()); + }; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.internal_model.as_ref() { + os.write_string(1, v)?; + } + if let Some(v) = self.model_variant { + os.write_uint32(2, v)?; + } + if let Some(v) = self.bootloader_mode { + os.write_bool(3, v)?; + } + if let Some(v) = self.protocol_version { + os.write_uint32(4, v)?; + } + for v in &self.pairing_methods { + os.write_enum(5, ::protobuf::EnumOrUnknown::value(v))?; + }; + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpDeviceProperties { + ThpDeviceProperties::new() + } + + fn clear(&mut self) { + self.internal_model = ::std::option::Option::None; + self.model_variant = ::std::option::Option::None; + self.bootloader_mode = ::std::option::Option::None; + self.protocol_version = ::std::option::Option::None; + self.pairing_methods.clear(); + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpDeviceProperties { + static instance: ThpDeviceProperties = ThpDeviceProperties { + internal_model: ::std::option::Option::None, + model_variant: ::std::option::Option::None, + bootloader_mode: ::std::option::Option::None, + protocol_version: ::std::option::Option::None, + pairing_methods: ::std::vec::Vec::new(), + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpDeviceProperties { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpDeviceProperties").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpDeviceProperties { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpDeviceProperties { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpHandshakeCompletionReqNoisePayload) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpHandshakeCompletionReqNoisePayload { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpHandshakeCompletionReqNoisePayload.host_pairing_credential) + pub host_pairing_credential: ::std::option::Option<::std::vec::Vec>, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpHandshakeCompletionReqNoisePayload.pairing_methods) + pub pairing_methods: ::std::vec::Vec<::protobuf::EnumOrUnknown>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpHandshakeCompletionReqNoisePayload.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpHandshakeCompletionReqNoisePayload { + fn default() -> &'a ThpHandshakeCompletionReqNoisePayload { + ::default_instance() + } +} + +impl ThpHandshakeCompletionReqNoisePayload { + pub fn new() -> ThpHandshakeCompletionReqNoisePayload { + ::std::default::Default::default() + } + + // optional bytes host_pairing_credential = 1; + + pub fn host_pairing_credential(&self) -> &[u8] { + match self.host_pairing_credential.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_host_pairing_credential(&mut self) { + self.host_pairing_credential = ::std::option::Option::None; + } + + pub fn has_host_pairing_credential(&self) -> bool { + self.host_pairing_credential.is_some() + } + + // Param is passed by value, moved + pub fn set_host_pairing_credential(&mut self, v: ::std::vec::Vec) { + self.host_pairing_credential = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_host_pairing_credential(&mut self) -> &mut ::std::vec::Vec { + if self.host_pairing_credential.is_none() { + self.host_pairing_credential = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.host_pairing_credential.as_mut().unwrap() + } + + // Take field + pub fn take_host_pairing_credential(&mut self) -> ::std::vec::Vec { + self.host_pairing_credential.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(2); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "host_pairing_credential", + |m: &ThpHandshakeCompletionReqNoisePayload| { &m.host_pairing_credential }, + |m: &mut ThpHandshakeCompletionReqNoisePayload| { &mut m.host_pairing_credential }, + )); + fields.push(::protobuf::reflect::rt::v2::make_vec_simpler_accessor::<_, _>( + "pairing_methods", + |m: &ThpHandshakeCompletionReqNoisePayload| { &m.pairing_methods }, + |m: &mut ThpHandshakeCompletionReqNoisePayload| { &mut m.pairing_methods }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpHandshakeCompletionReqNoisePayload", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpHandshakeCompletionReqNoisePayload { + const NAME: &'static str = "ThpHandshakeCompletionReqNoisePayload"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.host_pairing_credential = ::std::option::Option::Some(is.read_bytes()?); + }, + 16 => { + self.pairing_methods.push(is.read_enum_or_unknown()?); + }, + 18 => { + ::protobuf::rt::read_repeated_packed_enum_or_unknown_into(is, &mut self.pairing_methods)? + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.host_pairing_credential.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + for value in &self.pairing_methods { + my_size += ::protobuf::rt::int32_size(2, value.value()); + }; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.host_pairing_credential.as_ref() { + os.write_bytes(1, v)?; + } + for v in &self.pairing_methods { + os.write_enum(2, ::protobuf::EnumOrUnknown::value(v))?; + }; + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpHandshakeCompletionReqNoisePayload { + ThpHandshakeCompletionReqNoisePayload::new() + } + + fn clear(&mut self) { + self.host_pairing_credential = ::std::option::Option::None; + self.pairing_methods.clear(); + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpHandshakeCompletionReqNoisePayload { + static instance: ThpHandshakeCompletionReqNoisePayload = ThpHandshakeCompletionReqNoisePayload { + host_pairing_credential: ::std::option::Option::None, + pairing_methods: ::std::vec::Vec::new(), + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpHandshakeCompletionReqNoisePayload { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpHandshakeCompletionReqNoisePayload").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpHandshakeCompletionReqNoisePayload { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpHandshakeCompletionReqNoisePayload { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCreateNewSession) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCreateNewSession { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCreateNewSession.passphrase) + pub passphrase: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCreateNewSession.on_device) + pub on_device: ::std::option::Option, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCreateNewSession.derive_cardano) + pub derive_cardano: ::std::option::Option, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCreateNewSession.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCreateNewSession { + fn default() -> &'a ThpCreateNewSession { + ::default_instance() + } +} + +impl ThpCreateNewSession { + pub fn new() -> ThpCreateNewSession { + ::std::default::Default::default() + } + + // optional string passphrase = 1; + + pub fn passphrase(&self) -> &str { + match self.passphrase.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_passphrase(&mut self) { + self.passphrase = ::std::option::Option::None; + } + + pub fn has_passphrase(&self) -> bool { + self.passphrase.is_some() + } + + // Param is passed by value, moved + pub fn set_passphrase(&mut self, v: ::std::string::String) { + self.passphrase = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_passphrase(&mut self) -> &mut ::std::string::String { + if self.passphrase.is_none() { + self.passphrase = ::std::option::Option::Some(::std::string::String::new()); + } + self.passphrase.as_mut().unwrap() + } + + // Take field + pub fn take_passphrase(&mut self) -> ::std::string::String { + self.passphrase.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional bool on_device = 2; + + pub fn on_device(&self) -> bool { + self.on_device.unwrap_or(false) + } + + pub fn clear_on_device(&mut self) { + self.on_device = ::std::option::Option::None; + } + + pub fn has_on_device(&self) -> bool { + self.on_device.is_some() + } + + // Param is passed by value, moved + pub fn set_on_device(&mut self, v: bool) { + self.on_device = ::std::option::Option::Some(v); + } + + // optional bool derive_cardano = 3; + + pub fn derive_cardano(&self) -> bool { + self.derive_cardano.unwrap_or(false) + } + + pub fn clear_derive_cardano(&mut self) { + self.derive_cardano = ::std::option::Option::None; + } + + pub fn has_derive_cardano(&self) -> bool { + self.derive_cardano.is_some() + } + + // Param is passed by value, moved + pub fn set_derive_cardano(&mut self, v: bool) { + self.derive_cardano = ::std::option::Option::Some(v); + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(3); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "passphrase", + |m: &ThpCreateNewSession| { &m.passphrase }, + |m: &mut ThpCreateNewSession| { &mut m.passphrase }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "on_device", + |m: &ThpCreateNewSession| { &m.on_device }, + |m: &mut ThpCreateNewSession| { &mut m.on_device }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "derive_cardano", + |m: &ThpCreateNewSession| { &m.derive_cardano }, + |m: &mut ThpCreateNewSession| { &mut m.derive_cardano }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCreateNewSession", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCreateNewSession { + const NAME: &'static str = "ThpCreateNewSession"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.passphrase = ::std::option::Option::Some(is.read_string()?); + }, + 16 => { + self.on_device = ::std::option::Option::Some(is.read_bool()?); + }, + 24 => { + self.derive_cardano = ::std::option::Option::Some(is.read_bool()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.passphrase.as_ref() { + my_size += ::protobuf::rt::string_size(1, &v); + } + if let Some(v) = self.on_device { + my_size += 1 + 1; + } + if let Some(v) = self.derive_cardano { + my_size += 1 + 1; + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.passphrase.as_ref() { + os.write_string(1, v)?; + } + if let Some(v) = self.on_device { + os.write_bool(2, v)?; + } + if let Some(v) = self.derive_cardano { + os.write_bool(3, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCreateNewSession { + ThpCreateNewSession::new() + } + + fn clear(&mut self) { + self.passphrase = ::std::option::Option::None; + self.on_device = ::std::option::Option::None; + self.derive_cardano = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCreateNewSession { + static instance: ThpCreateNewSession = ThpCreateNewSession { + passphrase: ::std::option::Option::None, + on_device: ::std::option::Option::None, + derive_cardano: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCreateNewSession { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCreateNewSession").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCreateNewSession { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCreateNewSession { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpNewSession) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpNewSession { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpNewSession.new_session_id) + pub new_session_id: ::std::option::Option, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpNewSession.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpNewSession { + fn default() -> &'a ThpNewSession { + ::default_instance() + } +} + +impl ThpNewSession { + pub fn new() -> ThpNewSession { + ::std::default::Default::default() + } + + // optional uint32 new_session_id = 1; + + pub fn new_session_id(&self) -> u32 { + self.new_session_id.unwrap_or(0) + } + + pub fn clear_new_session_id(&mut self) { + self.new_session_id = ::std::option::Option::None; + } + + pub fn has_new_session_id(&self) -> bool { + self.new_session_id.is_some() + } + + // Param is passed by value, moved + pub fn set_new_session_id(&mut self, v: u32) { + self.new_session_id = ::std::option::Option::Some(v); + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "new_session_id", + |m: &ThpNewSession| { &m.new_session_id }, + |m: &mut ThpNewSession| { &mut m.new_session_id }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpNewSession", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpNewSession { + const NAME: &'static str = "ThpNewSession"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 8 => { + self.new_session_id = ::std::option::Option::Some(is.read_uint32()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.new_session_id { + my_size += ::protobuf::rt::uint32_size(1, v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.new_session_id { + os.write_uint32(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpNewSession { + ThpNewSession::new() + } + + fn clear(&mut self) { + self.new_session_id = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpNewSession { + static instance: ThpNewSession = ThpNewSession { + new_session_id: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpNewSession { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpNewSession").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpNewSession { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpNewSession { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpStartPairingRequest) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpStartPairingRequest { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpStartPairingRequest.host_name) + pub host_name: ::std::option::Option<::std::string::String>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpStartPairingRequest.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpStartPairingRequest { + fn default() -> &'a ThpStartPairingRequest { + ::default_instance() + } +} + +impl ThpStartPairingRequest { + pub fn new() -> ThpStartPairingRequest { + ::std::default::Default::default() + } + + // optional string host_name = 1; + + pub fn host_name(&self) -> &str { + match self.host_name.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_host_name(&mut self) { + self.host_name = ::std::option::Option::None; + } + + pub fn has_host_name(&self) -> bool { + self.host_name.is_some() + } + + // Param is passed by value, moved + pub fn set_host_name(&mut self, v: ::std::string::String) { + self.host_name = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_host_name(&mut self) -> &mut ::std::string::String { + if self.host_name.is_none() { + self.host_name = ::std::option::Option::Some(::std::string::String::new()); + } + self.host_name.as_mut().unwrap() + } + + // Take field + pub fn take_host_name(&mut self) -> ::std::string::String { + self.host_name.take().unwrap_or_else(|| ::std::string::String::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "host_name", + |m: &ThpStartPairingRequest| { &m.host_name }, + |m: &mut ThpStartPairingRequest| { &mut m.host_name }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpStartPairingRequest", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpStartPairingRequest { + const NAME: &'static str = "ThpStartPairingRequest"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.host_name = ::std::option::Option::Some(is.read_string()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.host_name.as_ref() { + my_size += ::protobuf::rt::string_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.host_name.as_ref() { + os.write_string(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpStartPairingRequest { + ThpStartPairingRequest::new() + } + + fn clear(&mut self) { + self.host_name = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpStartPairingRequest { + static instance: ThpStartPairingRequest = ThpStartPairingRequest { + host_name: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpStartPairingRequest { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpStartPairingRequest").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpStartPairingRequest { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpStartPairingRequest { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpPairingPreparationsFinished) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpPairingPreparationsFinished { + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpPairingPreparationsFinished.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpPairingPreparationsFinished { + fn default() -> &'a ThpPairingPreparationsFinished { + ::default_instance() + } +} + +impl ThpPairingPreparationsFinished { + pub fn new() -> ThpPairingPreparationsFinished { + ::std::default::Default::default() + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(0); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpPairingPreparationsFinished", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpPairingPreparationsFinished { + const NAME: &'static str = "ThpPairingPreparationsFinished"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpPairingPreparationsFinished { + ThpPairingPreparationsFinished::new() + } + + fn clear(&mut self) { + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpPairingPreparationsFinished { + static instance: ThpPairingPreparationsFinished = ThpPairingPreparationsFinished { + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpPairingPreparationsFinished { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpPairingPreparationsFinished").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpPairingPreparationsFinished { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpPairingPreparationsFinished { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntryCommitment) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntryCommitment { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntryCommitment.commitment) + pub commitment: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntryCommitment.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntryCommitment { + fn default() -> &'a ThpCodeEntryCommitment { + ::default_instance() + } +} + +impl ThpCodeEntryCommitment { + pub fn new() -> ThpCodeEntryCommitment { + ::std::default::Default::default() + } + + // optional bytes commitment = 1; + + pub fn commitment(&self) -> &[u8] { + match self.commitment.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_commitment(&mut self) { + self.commitment = ::std::option::Option::None; + } + + pub fn has_commitment(&self) -> bool { + self.commitment.is_some() + } + + // Param is passed by value, moved + pub fn set_commitment(&mut self, v: ::std::vec::Vec) { + self.commitment = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_commitment(&mut self) -> &mut ::std::vec::Vec { + if self.commitment.is_none() { + self.commitment = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.commitment.as_mut().unwrap() + } + + // Take field + pub fn take_commitment(&mut self) -> ::std::vec::Vec { + self.commitment.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "commitment", + |m: &ThpCodeEntryCommitment| { &m.commitment }, + |m: &mut ThpCodeEntryCommitment| { &mut m.commitment }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntryCommitment", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntryCommitment { + const NAME: &'static str = "ThpCodeEntryCommitment"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.commitment = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.commitment.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.commitment.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntryCommitment { + ThpCodeEntryCommitment::new() + } + + fn clear(&mut self) { + self.commitment = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntryCommitment { + static instance: ThpCodeEntryCommitment = ThpCodeEntryCommitment { + commitment: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntryCommitment { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntryCommitment").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntryCommitment { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntryCommitment { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntryChallenge) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntryChallenge { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntryChallenge.challenge) + pub challenge: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntryChallenge.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntryChallenge { + fn default() -> &'a ThpCodeEntryChallenge { + ::default_instance() + } +} + +impl ThpCodeEntryChallenge { + pub fn new() -> ThpCodeEntryChallenge { + ::std::default::Default::default() + } + + // optional bytes challenge = 1; + + pub fn challenge(&self) -> &[u8] { + match self.challenge.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_challenge(&mut self) { + self.challenge = ::std::option::Option::None; + } + + pub fn has_challenge(&self) -> bool { + self.challenge.is_some() + } + + // Param is passed by value, moved + pub fn set_challenge(&mut self, v: ::std::vec::Vec) { + self.challenge = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_challenge(&mut self) -> &mut ::std::vec::Vec { + if self.challenge.is_none() { + self.challenge = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.challenge.as_mut().unwrap() + } + + // Take field + pub fn take_challenge(&mut self) -> ::std::vec::Vec { + self.challenge.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "challenge", + |m: &ThpCodeEntryChallenge| { &m.challenge }, + |m: &mut ThpCodeEntryChallenge| { &mut m.challenge }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntryChallenge", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntryChallenge { + const NAME: &'static str = "ThpCodeEntryChallenge"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.challenge = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.challenge.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.challenge.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntryChallenge { + ThpCodeEntryChallenge::new() + } + + fn clear(&mut self) { + self.challenge = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntryChallenge { + static instance: ThpCodeEntryChallenge = ThpCodeEntryChallenge { + challenge: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntryChallenge { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntryChallenge").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntryChallenge { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntryChallenge { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntryCpaceHost) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntryCpaceHost { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntryCpaceHost.cpace_host_public_key) + pub cpace_host_public_key: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntryCpaceHost.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntryCpaceHost { + fn default() -> &'a ThpCodeEntryCpaceHost { + ::default_instance() + } +} + +impl ThpCodeEntryCpaceHost { + pub fn new() -> ThpCodeEntryCpaceHost { + ::std::default::Default::default() + } + + // optional bytes cpace_host_public_key = 1; + + pub fn cpace_host_public_key(&self) -> &[u8] { + match self.cpace_host_public_key.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_cpace_host_public_key(&mut self) { + self.cpace_host_public_key = ::std::option::Option::None; + } + + pub fn has_cpace_host_public_key(&self) -> bool { + self.cpace_host_public_key.is_some() + } + + // Param is passed by value, moved + pub fn set_cpace_host_public_key(&mut self, v: ::std::vec::Vec) { + self.cpace_host_public_key = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_cpace_host_public_key(&mut self) -> &mut ::std::vec::Vec { + if self.cpace_host_public_key.is_none() { + self.cpace_host_public_key = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.cpace_host_public_key.as_mut().unwrap() + } + + // Take field + pub fn take_cpace_host_public_key(&mut self) -> ::std::vec::Vec { + self.cpace_host_public_key.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "cpace_host_public_key", + |m: &ThpCodeEntryCpaceHost| { &m.cpace_host_public_key }, + |m: &mut ThpCodeEntryCpaceHost| { &mut m.cpace_host_public_key }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntryCpaceHost", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntryCpaceHost { + const NAME: &'static str = "ThpCodeEntryCpaceHost"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.cpace_host_public_key = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.cpace_host_public_key.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.cpace_host_public_key.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntryCpaceHost { + ThpCodeEntryCpaceHost::new() + } + + fn clear(&mut self) { + self.cpace_host_public_key = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntryCpaceHost { + static instance: ThpCodeEntryCpaceHost = ThpCodeEntryCpaceHost { + cpace_host_public_key: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntryCpaceHost { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntryCpaceHost").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntryCpaceHost { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntryCpaceHost { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntryCpaceTrezor) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntryCpaceTrezor { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntryCpaceTrezor.cpace_trezor_public_key) + pub cpace_trezor_public_key: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntryCpaceTrezor.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntryCpaceTrezor { + fn default() -> &'a ThpCodeEntryCpaceTrezor { + ::default_instance() + } +} + +impl ThpCodeEntryCpaceTrezor { + pub fn new() -> ThpCodeEntryCpaceTrezor { + ::std::default::Default::default() + } + + // optional bytes cpace_trezor_public_key = 1; + + pub fn cpace_trezor_public_key(&self) -> &[u8] { + match self.cpace_trezor_public_key.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_cpace_trezor_public_key(&mut self) { + self.cpace_trezor_public_key = ::std::option::Option::None; + } + + pub fn has_cpace_trezor_public_key(&self) -> bool { + self.cpace_trezor_public_key.is_some() + } + + // Param is passed by value, moved + pub fn set_cpace_trezor_public_key(&mut self, v: ::std::vec::Vec) { + self.cpace_trezor_public_key = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_cpace_trezor_public_key(&mut self) -> &mut ::std::vec::Vec { + if self.cpace_trezor_public_key.is_none() { + self.cpace_trezor_public_key = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.cpace_trezor_public_key.as_mut().unwrap() + } + + // Take field + pub fn take_cpace_trezor_public_key(&mut self) -> ::std::vec::Vec { + self.cpace_trezor_public_key.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "cpace_trezor_public_key", + |m: &ThpCodeEntryCpaceTrezor| { &m.cpace_trezor_public_key }, + |m: &mut ThpCodeEntryCpaceTrezor| { &mut m.cpace_trezor_public_key }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntryCpaceTrezor", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntryCpaceTrezor { + const NAME: &'static str = "ThpCodeEntryCpaceTrezor"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.cpace_trezor_public_key = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.cpace_trezor_public_key.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.cpace_trezor_public_key.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntryCpaceTrezor { + ThpCodeEntryCpaceTrezor::new() + } + + fn clear(&mut self) { + self.cpace_trezor_public_key = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntryCpaceTrezor { + static instance: ThpCodeEntryCpaceTrezor = ThpCodeEntryCpaceTrezor { + cpace_trezor_public_key: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntryCpaceTrezor { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntryCpaceTrezor").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntryCpaceTrezor { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntryCpaceTrezor { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntryTag) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntryTag { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntryTag.tag) + pub tag: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntryTag.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntryTag { + fn default() -> &'a ThpCodeEntryTag { + ::default_instance() + } +} + +impl ThpCodeEntryTag { + pub fn new() -> ThpCodeEntryTag { + ::std::default::Default::default() + } + + // optional bytes tag = 2; + + pub fn tag(&self) -> &[u8] { + match self.tag.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_tag(&mut self) { + self.tag = ::std::option::Option::None; + } + + pub fn has_tag(&self) -> bool { + self.tag.is_some() + } + + // Param is passed by value, moved + pub fn set_tag(&mut self, v: ::std::vec::Vec) { + self.tag = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_tag(&mut self) -> &mut ::std::vec::Vec { + if self.tag.is_none() { + self.tag = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.tag.as_mut().unwrap() + } + + // Take field + pub fn take_tag(&mut self) -> ::std::vec::Vec { + self.tag.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "tag", + |m: &ThpCodeEntryTag| { &m.tag }, + |m: &mut ThpCodeEntryTag| { &mut m.tag }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntryTag", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntryTag { + const NAME: &'static str = "ThpCodeEntryTag"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 18 => { + self.tag = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.tag.as_ref() { + my_size += ::protobuf::rt::bytes_size(2, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.tag.as_ref() { + os.write_bytes(2, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntryTag { + ThpCodeEntryTag::new() + } + + fn clear(&mut self) { + self.tag = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntryTag { + static instance: ThpCodeEntryTag = ThpCodeEntryTag { + tag: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntryTag { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntryTag").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntryTag { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntryTag { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntrySecret) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntrySecret { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntrySecret.secret) + pub secret: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntrySecret.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntrySecret { + fn default() -> &'a ThpCodeEntrySecret { + ::default_instance() + } +} + +impl ThpCodeEntrySecret { + pub fn new() -> ThpCodeEntrySecret { + ::std::default::Default::default() + } + + // optional bytes secret = 1; + + pub fn secret(&self) -> &[u8] { + match self.secret.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_secret(&mut self) { + self.secret = ::std::option::Option::None; + } + + pub fn has_secret(&self) -> bool { + self.secret.is_some() + } + + // Param is passed by value, moved + pub fn set_secret(&mut self, v: ::std::vec::Vec) { + self.secret = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_secret(&mut self) -> &mut ::std::vec::Vec { + if self.secret.is_none() { + self.secret = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.secret.as_mut().unwrap() + } + + // Take field + pub fn take_secret(&mut self) -> ::std::vec::Vec { + self.secret.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "secret", + |m: &ThpCodeEntrySecret| { &m.secret }, + |m: &mut ThpCodeEntrySecret| { &mut m.secret }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntrySecret", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntrySecret { + const NAME: &'static str = "ThpCodeEntrySecret"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.secret = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.secret.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.secret.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntrySecret { + ThpCodeEntrySecret::new() + } + + fn clear(&mut self) { + self.secret = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntrySecret { + static instance: ThpCodeEntrySecret = ThpCodeEntrySecret { + secret: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntrySecret { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntrySecret").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntrySecret { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntrySecret { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpQrCodeTag) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpQrCodeTag { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpQrCodeTag.tag) + pub tag: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpQrCodeTag.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpQrCodeTag { + fn default() -> &'a ThpQrCodeTag { + ::default_instance() + } +} + +impl ThpQrCodeTag { + pub fn new() -> ThpQrCodeTag { + ::std::default::Default::default() + } + + // optional bytes tag = 1; + + pub fn tag(&self) -> &[u8] { + match self.tag.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_tag(&mut self) { + self.tag = ::std::option::Option::None; + } + + pub fn has_tag(&self) -> bool { + self.tag.is_some() + } + + // Param is passed by value, moved + pub fn set_tag(&mut self, v: ::std::vec::Vec) { + self.tag = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_tag(&mut self) -> &mut ::std::vec::Vec { + if self.tag.is_none() { + self.tag = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.tag.as_mut().unwrap() + } + + // Take field + pub fn take_tag(&mut self) -> ::std::vec::Vec { + self.tag.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "tag", + |m: &ThpQrCodeTag| { &m.tag }, + |m: &mut ThpQrCodeTag| { &mut m.tag }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpQrCodeTag", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpQrCodeTag { + const NAME: &'static str = "ThpQrCodeTag"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.tag = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.tag.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.tag.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpQrCodeTag { + ThpQrCodeTag::new() + } + + fn clear(&mut self) { + self.tag = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpQrCodeTag { + static instance: ThpQrCodeTag = ThpQrCodeTag { + tag: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpQrCodeTag { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpQrCodeTag").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpQrCodeTag { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpQrCodeTag { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpQrCodeSecret) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpQrCodeSecret { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpQrCodeSecret.secret) + pub secret: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpQrCodeSecret.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpQrCodeSecret { + fn default() -> &'a ThpQrCodeSecret { + ::default_instance() + } +} + +impl ThpQrCodeSecret { + pub fn new() -> ThpQrCodeSecret { + ::std::default::Default::default() + } + + // optional bytes secret = 1; + + pub fn secret(&self) -> &[u8] { + match self.secret.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_secret(&mut self) { + self.secret = ::std::option::Option::None; + } + + pub fn has_secret(&self) -> bool { + self.secret.is_some() + } + + // Param is passed by value, moved + pub fn set_secret(&mut self, v: ::std::vec::Vec) { + self.secret = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_secret(&mut self) -> &mut ::std::vec::Vec { + if self.secret.is_none() { + self.secret = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.secret.as_mut().unwrap() + } + + // Take field + pub fn take_secret(&mut self) -> ::std::vec::Vec { + self.secret.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "secret", + |m: &ThpQrCodeSecret| { &m.secret }, + |m: &mut ThpQrCodeSecret| { &mut m.secret }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpQrCodeSecret", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpQrCodeSecret { + const NAME: &'static str = "ThpQrCodeSecret"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.secret = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.secret.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.secret.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpQrCodeSecret { + ThpQrCodeSecret::new() + } + + fn clear(&mut self) { + self.secret = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpQrCodeSecret { + static instance: ThpQrCodeSecret = ThpQrCodeSecret { + secret: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpQrCodeSecret { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpQrCodeSecret").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpQrCodeSecret { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpQrCodeSecret { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpNfcUnidirectionalTag) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpNfcUnidirectionalTag { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpNfcUnidirectionalTag.tag) + pub tag: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpNfcUnidirectionalTag.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpNfcUnidirectionalTag { + fn default() -> &'a ThpNfcUnidirectionalTag { + ::default_instance() + } +} + +impl ThpNfcUnidirectionalTag { + pub fn new() -> ThpNfcUnidirectionalTag { + ::std::default::Default::default() + } + + // optional bytes tag = 1; + + pub fn tag(&self) -> &[u8] { + match self.tag.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_tag(&mut self) { + self.tag = ::std::option::Option::None; + } + + pub fn has_tag(&self) -> bool { + self.tag.is_some() + } + + // Param is passed by value, moved + pub fn set_tag(&mut self, v: ::std::vec::Vec) { + self.tag = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_tag(&mut self) -> &mut ::std::vec::Vec { + if self.tag.is_none() { + self.tag = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.tag.as_mut().unwrap() + } + + // Take field + pub fn take_tag(&mut self) -> ::std::vec::Vec { + self.tag.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "tag", + |m: &ThpNfcUnidirectionalTag| { &m.tag }, + |m: &mut ThpNfcUnidirectionalTag| { &mut m.tag }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpNfcUnidirectionalTag", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpNfcUnidirectionalTag { + const NAME: &'static str = "ThpNfcUnidirectionalTag"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.tag = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.tag.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.tag.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpNfcUnidirectionalTag { + ThpNfcUnidirectionalTag::new() + } + + fn clear(&mut self) { + self.tag = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpNfcUnidirectionalTag { + static instance: ThpNfcUnidirectionalTag = ThpNfcUnidirectionalTag { + tag: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpNfcUnidirectionalTag { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpNfcUnidirectionalTag").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpNfcUnidirectionalTag { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpNfcUnidirectionalTag { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpNfcUnidirectionalSecret) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpNfcUnidirectionalSecret { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpNfcUnidirectionalSecret.secret) + pub secret: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpNfcUnidirectionalSecret.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpNfcUnidirectionalSecret { + fn default() -> &'a ThpNfcUnidirectionalSecret { + ::default_instance() + } +} + +impl ThpNfcUnidirectionalSecret { + pub fn new() -> ThpNfcUnidirectionalSecret { + ::std::default::Default::default() + } + + // optional bytes secret = 1; + + pub fn secret(&self) -> &[u8] { + match self.secret.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_secret(&mut self) { + self.secret = ::std::option::Option::None; + } + + pub fn has_secret(&self) -> bool { + self.secret.is_some() + } + + // Param is passed by value, moved + pub fn set_secret(&mut self, v: ::std::vec::Vec) { + self.secret = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_secret(&mut self) -> &mut ::std::vec::Vec { + if self.secret.is_none() { + self.secret = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.secret.as_mut().unwrap() + } + + // Take field + pub fn take_secret(&mut self) -> ::std::vec::Vec { + self.secret.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "secret", + |m: &ThpNfcUnidirectionalSecret| { &m.secret }, + |m: &mut ThpNfcUnidirectionalSecret| { &mut m.secret }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpNfcUnidirectionalSecret", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpNfcUnidirectionalSecret { + const NAME: &'static str = "ThpNfcUnidirectionalSecret"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.secret = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.secret.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.secret.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpNfcUnidirectionalSecret { + ThpNfcUnidirectionalSecret::new() + } + + fn clear(&mut self) { + self.secret = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpNfcUnidirectionalSecret { + static instance: ThpNfcUnidirectionalSecret = ThpNfcUnidirectionalSecret { + secret: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpNfcUnidirectionalSecret { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpNfcUnidirectionalSecret").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpNfcUnidirectionalSecret { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpNfcUnidirectionalSecret { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCredentialRequest) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCredentialRequest { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCredentialRequest.host_static_pubkey) + pub host_static_pubkey: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCredentialRequest.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCredentialRequest { + fn default() -> &'a ThpCredentialRequest { + ::default_instance() + } +} + +impl ThpCredentialRequest { + pub fn new() -> ThpCredentialRequest { + ::std::default::Default::default() + } + + // optional bytes host_static_pubkey = 1; + + pub fn host_static_pubkey(&self) -> &[u8] { + match self.host_static_pubkey.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_host_static_pubkey(&mut self) { + self.host_static_pubkey = ::std::option::Option::None; + } + + pub fn has_host_static_pubkey(&self) -> bool { + self.host_static_pubkey.is_some() + } + + // Param is passed by value, moved + pub fn set_host_static_pubkey(&mut self, v: ::std::vec::Vec) { + self.host_static_pubkey = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_host_static_pubkey(&mut self) -> &mut ::std::vec::Vec { + if self.host_static_pubkey.is_none() { + self.host_static_pubkey = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.host_static_pubkey.as_mut().unwrap() + } + + // Take field + pub fn take_host_static_pubkey(&mut self) -> ::std::vec::Vec { + self.host_static_pubkey.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "host_static_pubkey", + |m: &ThpCredentialRequest| { &m.host_static_pubkey }, + |m: &mut ThpCredentialRequest| { &mut m.host_static_pubkey }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCredentialRequest", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCredentialRequest { + const NAME: &'static str = "ThpCredentialRequest"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.host_static_pubkey = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.host_static_pubkey.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.host_static_pubkey.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCredentialRequest { + ThpCredentialRequest::new() + } + + fn clear(&mut self) { + self.host_static_pubkey = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCredentialRequest { + static instance: ThpCredentialRequest = ThpCredentialRequest { + host_static_pubkey: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCredentialRequest { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCredentialRequest").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCredentialRequest { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCredentialRequest { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCredentialResponse) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCredentialResponse { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCredentialResponse.trezor_static_pubkey) + pub trezor_static_pubkey: ::std::option::Option<::std::vec::Vec>, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCredentialResponse.credential) + pub credential: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCredentialResponse.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCredentialResponse { + fn default() -> &'a ThpCredentialResponse { + ::default_instance() + } +} + +impl ThpCredentialResponse { + pub fn new() -> ThpCredentialResponse { + ::std::default::Default::default() + } + + // optional bytes trezor_static_pubkey = 1; + + pub fn trezor_static_pubkey(&self) -> &[u8] { + match self.trezor_static_pubkey.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_trezor_static_pubkey(&mut self) { + self.trezor_static_pubkey = ::std::option::Option::None; + } + + pub fn has_trezor_static_pubkey(&self) -> bool { + self.trezor_static_pubkey.is_some() + } + + // Param is passed by value, moved + pub fn set_trezor_static_pubkey(&mut self, v: ::std::vec::Vec) { + self.trezor_static_pubkey = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_trezor_static_pubkey(&mut self) -> &mut ::std::vec::Vec { + if self.trezor_static_pubkey.is_none() { + self.trezor_static_pubkey = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.trezor_static_pubkey.as_mut().unwrap() + } + + // Take field + pub fn take_trezor_static_pubkey(&mut self) -> ::std::vec::Vec { + self.trezor_static_pubkey.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + // optional bytes credential = 2; + + pub fn credential(&self) -> &[u8] { + match self.credential.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_credential(&mut self) { + self.credential = ::std::option::Option::None; + } + + pub fn has_credential(&self) -> bool { + self.credential.is_some() + } + + // Param is passed by value, moved + pub fn set_credential(&mut self, v: ::std::vec::Vec) { + self.credential = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_credential(&mut self) -> &mut ::std::vec::Vec { + if self.credential.is_none() { + self.credential = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.credential.as_mut().unwrap() + } + + // Take field + pub fn take_credential(&mut self) -> ::std::vec::Vec { + self.credential.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(2); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "trezor_static_pubkey", + |m: &ThpCredentialResponse| { &m.trezor_static_pubkey }, + |m: &mut ThpCredentialResponse| { &mut m.trezor_static_pubkey }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "credential", + |m: &ThpCredentialResponse| { &m.credential }, + |m: &mut ThpCredentialResponse| { &mut m.credential }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCredentialResponse", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCredentialResponse { + const NAME: &'static str = "ThpCredentialResponse"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.trezor_static_pubkey = ::std::option::Option::Some(is.read_bytes()?); + }, + 18 => { + self.credential = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.trezor_static_pubkey.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + if let Some(v) = self.credential.as_ref() { + my_size += ::protobuf::rt::bytes_size(2, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.trezor_static_pubkey.as_ref() { + os.write_bytes(1, v)?; + } + if let Some(v) = self.credential.as_ref() { + os.write_bytes(2, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCredentialResponse { + ThpCredentialResponse::new() + } + + fn clear(&mut self) { + self.trezor_static_pubkey = ::std::option::Option::None; + self.credential = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCredentialResponse { + static instance: ThpCredentialResponse = ThpCredentialResponse { + trezor_static_pubkey: ::std::option::Option::None, + credential: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCredentialResponse { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCredentialResponse").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCredentialResponse { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCredentialResponse { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpEndRequest) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpEndRequest { + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpEndRequest.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpEndRequest { + fn default() -> &'a ThpEndRequest { + ::default_instance() + } +} + +impl ThpEndRequest { + pub fn new() -> ThpEndRequest { + ::std::default::Default::default() + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(0); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpEndRequest", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpEndRequest { + const NAME: &'static str = "ThpEndRequest"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpEndRequest { + ThpEndRequest::new() + } + + fn clear(&mut self) { + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpEndRequest { + static instance: ThpEndRequest = ThpEndRequest { + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpEndRequest { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpEndRequest").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpEndRequest { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpEndRequest { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpEndResponse) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpEndResponse { + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpEndResponse.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpEndResponse { + fn default() -> &'a ThpEndResponse { + ::default_instance() + } +} + +impl ThpEndResponse { + pub fn new() -> ThpEndResponse { + ::std::default::Default::default() + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(0); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpEndResponse", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpEndResponse { + const NAME: &'static str = "ThpEndResponse"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpEndResponse { + ThpEndResponse::new() + } + + fn clear(&mut self) { + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpEndResponse { + static instance: ThpEndResponse = ThpEndResponse { + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpEndResponse { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpEndResponse").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpEndResponse { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpEndResponse { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + // @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCredentialMetadata) #[derive(PartialEq,Clone,Default,Debug)] pub struct ThpCredentialMetadata { @@ -537,17 +3796,128 @@ impl ::protobuf::reflect::ProtobufValue for ThpAuthenticatedCredentialData { type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; } +#[derive(Clone,Copy,PartialEq,Eq,Debug,Hash)] +// @@protoc_insertion_point(enum:hw.trezor.messages.thp.ThpPairingMethod) +pub enum ThpPairingMethod { + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpPairingMethod.NoMethod) + NoMethod = 1, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpPairingMethod.CodeEntry) + CodeEntry = 2, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpPairingMethod.QrCode) + QrCode = 3, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpPairingMethod.NFC_Unidirectional) + NFC_Unidirectional = 4, +} + +impl ::protobuf::Enum for ThpPairingMethod { + const NAME: &'static str = "ThpPairingMethod"; + + fn value(&self) -> i32 { + *self as i32 + } + + fn from_i32(value: i32) -> ::std::option::Option { + match value { + 1 => ::std::option::Option::Some(ThpPairingMethod::NoMethod), + 2 => ::std::option::Option::Some(ThpPairingMethod::CodeEntry), + 3 => ::std::option::Option::Some(ThpPairingMethod::QrCode), + 4 => ::std::option::Option::Some(ThpPairingMethod::NFC_Unidirectional), + _ => ::std::option::Option::None + } + } + + fn from_str(str: &str) -> ::std::option::Option { + match str { + "NoMethod" => ::std::option::Option::Some(ThpPairingMethod::NoMethod), + "CodeEntry" => ::std::option::Option::Some(ThpPairingMethod::CodeEntry), + "QrCode" => ::std::option::Option::Some(ThpPairingMethod::QrCode), + "NFC_Unidirectional" => ::std::option::Option::Some(ThpPairingMethod::NFC_Unidirectional), + _ => ::std::option::Option::None + } + } + + const VALUES: &'static [ThpPairingMethod] = &[ + ThpPairingMethod::NoMethod, + ThpPairingMethod::CodeEntry, + ThpPairingMethod::QrCode, + ThpPairingMethod::NFC_Unidirectional, + ]; +} + +impl ::protobuf::EnumFull for ThpPairingMethod { + fn enum_descriptor() -> ::protobuf::reflect::EnumDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::EnumDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().enum_by_package_relative_name("ThpPairingMethod").unwrap()).clone() + } + + fn descriptor(&self) -> ::protobuf::reflect::EnumValueDescriptor { + let index = match self { + ThpPairingMethod::NoMethod => 0, + ThpPairingMethod::CodeEntry => 1, + ThpPairingMethod::QrCode => 2, + ThpPairingMethod::NFC_Unidirectional => 3, + }; + Self::enum_descriptor().value_by_index(index) + } +} + +// Note, `Default` is implemented although default value is not 0 +impl ::std::default::Default for ThpPairingMethod { + fn default() -> Self { + ThpPairingMethod::NoMethod + } +} + +impl ThpPairingMethod { + fn generated_enum_descriptor_data() -> ::protobuf::reflect::GeneratedEnumDescriptorData { + ::protobuf::reflect::GeneratedEnumDescriptorData::new::("ThpPairingMethod") + } +} + static file_descriptor_proto_data: &'static [u8] = b"\ \n\x12messages-thp.proto\x12\x16hw.trezor.messages.thp\x1a\x0emessages.p\ - roto\":\n\x15ThpCredentialMetadata\x12\x1b\n\thost_name\x18\x01\x20\x01(\ - \tR\x08hostName:\x04\x98\xb2\x19\x01\"\x82\x01\n\x14ThpPairingCredential\ - \x12R\n\rcred_metadata\x18\x01\x20\x01(\x0b2-.hw.trezor.messages.thp.Thp\ - CredentialMetadataR\x0ccredMetadata\x12\x10\n\x03mac\x18\x02\x20\x01(\ - \x0cR\x03mac:\x04\x98\xb2\x19\x01\"\xa8\x01\n\x1eThpAuthenticatedCredent\ - ialData\x12,\n\x12host_static_pubkey\x18\x01\x20\x01(\x0cR\x10hostStatic\ - Pubkey\x12R\n\rcred_metadata\x18\x02\x20\x01(\x0b2-.hw.trezor.messages.t\ - hp.ThpCredentialMetadataR\x0ccredMetadata:\x04\x98\xb2\x19\x01B;\n#com.s\ - atoshilabs.trezor.lib.protobufB\x10TrezorMessageThp\x80\xa6\x1d\x01\ + roto\"\x88\x02\n\x13ThpDeviceProperties\x12%\n\x0einternal_model\x18\x01\ + \x20\x01(\tR\rinternalModel\x12#\n\rmodel_variant\x18\x02\x20\x01(\rR\ + \x0cmodelVariant\x12'\n\x0fbootloader_mode\x18\x03\x20\x01(\x08R\x0eboot\ + loaderMode\x12)\n\x10protocol_version\x18\x04\x20\x01(\rR\x0fprotocolVer\ + sion\x12Q\n\x0fpairing_methods\x18\x05\x20\x03(\x0e2(.hw.trezor.messages\ + .thp.ThpPairingMethodR\x0epairingMethods\"\xb2\x01\n%ThpHandshakeComplet\ + ionReqNoisePayload\x126\n\x17host_pairing_credential\x18\x01\x20\x01(\ + \x0cR\x15hostPairingCredential\x12Q\n\x0fpairing_methods\x18\x02\x20\x03\ + (\x0e2(.hw.trezor.messages.thp.ThpPairingMethodR\x0epairingMethods\"y\n\ + \x13ThpCreateNewSession\x12\x1e\n\npassphrase\x18\x01\x20\x01(\tR\npassp\ + hrase\x12\x1b\n\ton_device\x18\x02\x20\x01(\x08R\x08onDevice\x12%\n\x0ed\ + erive_cardano\x18\x03\x20\x01(\x08R\rderiveCardano\"5\n\rThpNewSession\ + \x12$\n\x0enew_session_id\x18\x01\x20\x01(\rR\x0cnewSessionId\"5\n\x16Th\ + pStartPairingRequest\x12\x1b\n\thost_name\x18\x01\x20\x01(\tR\x08hostNam\ + e\"\x20\n\x1eThpPairingPreparationsFinished\"8\n\x16ThpCodeEntryCommitme\ + nt\x12\x1e\n\ncommitment\x18\x01\x20\x01(\x0cR\ncommitment\"5\n\x15ThpCo\ + deEntryChallenge\x12\x1c\n\tchallenge\x18\x01\x20\x01(\x0cR\tchallenge\"\ + J\n\x15ThpCodeEntryCpaceHost\x121\n\x15cpace_host_public_key\x18\x01\x20\ + \x01(\x0cR\x12cpaceHostPublicKey\"P\n\x17ThpCodeEntryCpaceTrezor\x125\n\ + \x17cpace_trezor_public_key\x18\x01\x20\x01(\x0cR\x14cpaceTrezorPublicKe\ + y\"#\n\x0fThpCodeEntryTag\x12\x10\n\x03tag\x18\x02\x20\x01(\x0cR\x03tag\ + \",\n\x12ThpCodeEntrySecret\x12\x16\n\x06secret\x18\x01\x20\x01(\x0cR\ + \x06secret\"\x20\n\x0cThpQrCodeTag\x12\x10\n\x03tag\x18\x01\x20\x01(\x0c\ + R\x03tag\")\n\x0fThpQrCodeSecret\x12\x16\n\x06secret\x18\x01\x20\x01(\ + \x0cR\x06secret\"+\n\x17ThpNfcUnidirectionalTag\x12\x10\n\x03tag\x18\x01\ + \x20\x01(\x0cR\x03tag\"4\n\x1aThpNfcUnidirectionalSecret\x12\x16\n\x06se\ + cret\x18\x01\x20\x01(\x0cR\x06secret\"D\n\x14ThpCredentialRequest\x12,\n\ + \x12host_static_pubkey\x18\x01\x20\x01(\x0cR\x10hostStaticPubkey\"i\n\ + \x15ThpCredentialResponse\x120\n\x14trezor_static_pubkey\x18\x01\x20\x01\ + (\x0cR\x12trezorStaticPubkey\x12\x1e\n\ncredential\x18\x02\x20\x01(\x0cR\ + \ncredential\"\x0f\n\rThpEndRequest\"\x10\n\x0eThpEndResponse\":\n\x15Th\ + pCredentialMetadata\x12\x1b\n\thost_name\x18\x01\x20\x01(\tR\x08hostName\ + :\x04\x98\xb2\x19\x01\"\x82\x01\n\x14ThpPairingCredential\x12R\n\rcred_m\ + etadata\x18\x01\x20\x01(\x0b2-.hw.trezor.messages.thp.ThpCredentialMetad\ + ataR\x0ccredMetadata\x12\x10\n\x03mac\x18\x02\x20\x01(\x0cR\x03mac:\x04\ + \x98\xb2\x19\x01\"\xa8\x01\n\x1eThpAuthenticatedCredentialData\x12,\n\ + \x12host_static_pubkey\x18\x01\x20\x01(\x0cR\x10hostStaticPubkey\x12R\n\ + \rcred_metadata\x18\x02\x20\x01(\x0b2-.hw.trezor.messages.thp.ThpCredent\ + ialMetadataR\x0ccredMetadata:\x04\x98\xb2\x19\x01*S\n\x10ThpPairingMetho\ + d\x12\x0c\n\x08NoMethod\x10\x01\x12\r\n\tCodeEntry\x10\x02\x12\n\n\x06Qr\ + Code\x10\x03\x12\x16\n\x12NFC_Unidirectional\x10\x04B;\n#com.satoshilabs\ + .trezor.lib.protobufB\x10TrezorMessageThp\x80\xa6\x1d\x01\ "; /// `FileDescriptorProto` object which was a source for this generated file @@ -566,11 +3936,32 @@ pub fn file_descriptor() -> &'static ::protobuf::reflect::FileDescriptor { let generated_file_descriptor = generated_file_descriptor_lazy.get(|| { let mut deps = ::std::vec::Vec::with_capacity(1); deps.push(super::messages::file_descriptor().clone()); - let mut messages = ::std::vec::Vec::with_capacity(3); + let mut messages = ::std::vec::Vec::with_capacity(23); + messages.push(ThpDeviceProperties::generated_message_descriptor_data()); + messages.push(ThpHandshakeCompletionReqNoisePayload::generated_message_descriptor_data()); + messages.push(ThpCreateNewSession::generated_message_descriptor_data()); + messages.push(ThpNewSession::generated_message_descriptor_data()); + messages.push(ThpStartPairingRequest::generated_message_descriptor_data()); + messages.push(ThpPairingPreparationsFinished::generated_message_descriptor_data()); + messages.push(ThpCodeEntryCommitment::generated_message_descriptor_data()); + messages.push(ThpCodeEntryChallenge::generated_message_descriptor_data()); + messages.push(ThpCodeEntryCpaceHost::generated_message_descriptor_data()); + messages.push(ThpCodeEntryCpaceTrezor::generated_message_descriptor_data()); + messages.push(ThpCodeEntryTag::generated_message_descriptor_data()); + messages.push(ThpCodeEntrySecret::generated_message_descriptor_data()); + messages.push(ThpQrCodeTag::generated_message_descriptor_data()); + messages.push(ThpQrCodeSecret::generated_message_descriptor_data()); + messages.push(ThpNfcUnidirectionalTag::generated_message_descriptor_data()); + messages.push(ThpNfcUnidirectionalSecret::generated_message_descriptor_data()); + messages.push(ThpCredentialRequest::generated_message_descriptor_data()); + messages.push(ThpCredentialResponse::generated_message_descriptor_data()); + messages.push(ThpEndRequest::generated_message_descriptor_data()); + messages.push(ThpEndResponse::generated_message_descriptor_data()); messages.push(ThpCredentialMetadata::generated_message_descriptor_data()); messages.push(ThpPairingCredential::generated_message_descriptor_data()); messages.push(ThpAuthenticatedCredentialData::generated_message_descriptor_data()); - let mut enums = ::std::vec::Vec::with_capacity(0); + let mut enums = ::std::vec::Vec::with_capacity(1); + enums.push(ThpPairingMethod::generated_enum_descriptor_data()); ::protobuf::reflect::GeneratedFileDescriptor::new_generated( file_descriptor_proto(), deps, diff --git a/tests/click_tests/record_layout.py b/tests/click_tests/record_layout.py index 71a590e7975..d93391912c0 100644 --- a/tests/click_tests/record_layout.py +++ b/tests/click_tests/record_layout.py @@ -63,7 +63,7 @@ CALLS_DONE = [] DEBUGLINK = None -get_client_orig = cli.TrezorConnection.get_client +get_client_orig = cli.NewTrezorConnection.get_client def get_client(conn): @@ -75,7 +75,7 @@ def get_client(conn): return client -cli.TrezorConnection.get_client = get_client +cli.NewTrezorConnection.get_client = get_client def scan_layouts(dest): diff --git a/tests/common.py b/tests/common.py index b2a20bb39dd..41fd00f4d54 100644 --- a/tests/common.py +++ b/tests/common.py @@ -34,8 +34,8 @@ from _pytest.mark.structures import MarkDecorator from trezorlib.debuglink import DebugLink - from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import ButtonRequest + from trezorlib.transport.session import Session PRIVATE_KEYS_DEV = [byte * 32 for byte in (b"\xdd", b"\xde", b"\xdf")] @@ -338,10 +338,10 @@ def check_pin_backoff_time(attempts: int, start: float) -> None: assert got >= expected -def get_test_address(client: "Client") -> str: +def get_test_address(session: "Session") -> str: """Fetch a testnet address on a fixed path. Useful to make a pin/passphrase protected call, or to identify the root secret (seed+passphrase)""" - return btc.get_address(client, "Testnet", TEST_ADDRESS_N) + return btc.get_address(session, "Testnet", TEST_ADDRESS_N) def compact_size(n: int) -> bytes: @@ -380,5 +380,5 @@ def swipe_till_the_end(debug: "DebugLink", br: messages.ButtonRequest) -> None: debug.swipe_up() -def is_core(client: "Client") -> bool: - return client.model is not models.T1B1 +def is_core(session: "Session") -> bool: + return session.model is not models.T1B1 diff --git a/tests/conftest.py b/tests/conftest.py index 00dacb52926..eeb6a65be95 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,7 @@ import typing as t from enum import IntEnum from pathlib import Path +from time import sleep import pytest import xdist @@ -27,10 +28,12 @@ from _pytest.reports import TestReport from trezorlib import debuglink, log, models +from trezorlib.debuglink import SessionDebugWrapper from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.device import apply_settings from trezorlib.device import wipe as wipe_device from trezorlib.transport import enumerate_devices, get_transport +from trezorlib.transport.new.protocol_v1 import ProtocolV1 # register rewrites before importing from local package # so that we see details of failed asserts from this module @@ -285,12 +288,13 @@ def client( _raw_client.reset_debug_features() _raw_client.open() - try: - _raw_client.sync_responses() - _raw_client.init_device() - except Exception: - request.session.shouldstop = "Failed to communicate with Trezor" - pytest.fail("Failed to communicate with Trezor") + if isinstance(_raw_client.protocol, ProtocolV1): + try: + _raw_client.sync_responses() + # TODO _raw_client.init_device() + except Exception: + request.session.shouldstop = "Failed to communicate with Trezor" + pytest.fail("Failed to communicate with Trezor") # Resetting all the debug events to not be influenced by previous test _raw_client.debug.reset_debug_events() @@ -302,8 +306,20 @@ def client( if sd_marker: should_format = sd_marker.kwargs.get("formatted", True) _raw_client.debug.erase_sd_card(format=should_format) + session = _raw_client.get_management_session() - wipe_device(_raw_client) + wipe_device(session) + sleep(1) # Makes tests more stable (wait for wipe to finish) + from trezorlib.transport.new import channel_database + + channel_database.clear_stored_channels() + _raw_client.protocol = None + _raw_client.__init__( + transport=_raw_client.transport, + auto_interact=_raw_client.debug.allow_interactions, + ) + if not _raw_client.features.bootloader_mode: + _raw_client.refresh_features() # Load language again, as it got erased in wipe if _raw_client.model is not models.T1B1: @@ -328,10 +344,10 @@ def client( use_passphrase = setup_params["passphrase"] is True or isinstance( setup_params["passphrase"], str ) - if not setup_params["uninitialized"]: + session = _raw_client.get_management_session(new_session=True) debuglink.load_device( - _raw_client, + session, mnemonic=setup_params["mnemonic"], # type: ignore pin=setup_params["pin"], # type: ignore passphrase_protection=use_passphrase, @@ -341,12 +357,12 @@ def client( ) if request.node.get_closest_marker("experimental"): - apply_settings(_raw_client, experimental_features=True) + apply_settings(session, experimental_features=True) if use_passphrase and isinstance(setup_params["passphrase"], str): _raw_client.use_passphrase(setup_params["passphrase"]) - _raw_client.clear_session() + # TODO _raw_client.clear_session() with ui_tests.screen_recording(_raw_client, request): yield _raw_client @@ -354,6 +370,35 @@ def client( _raw_client.close() +@pytest.fixture(scope="function") +def session( + request: pytest.FixtureRequest, client: Client +) -> t.Generator[SessionDebugWrapper, None, None]: + derive_cardano = bool(request.node.get_closest_marker("cardano")) + passphrase = client.passphrase or "" + session = client.get_session(derive_cardano=derive_cardano, passphrase=passphrase) + try: + yield SessionDebugWrapper(session) + finally: + pass + # TODO + # session.end() + + +@pytest.fixture(scope="function") +def uninitialized_session( + request: pytest.FixtureRequest, + client: Client, +) -> t.Generator[SessionDebugWrapper, None, None]: + session = client.get_management_session() + try: + yield SessionDebugWrapper(session) + finally: + pass + # TODO + # session.end() + + def _is_main_runner(session_or_request: pytest.Session | pytest.FixtureRequest) -> bool: """Return True if the current process is the main test runner. diff --git a/tests/device_tests/binance/test_get_address.py b/tests/device_tests/binance/test_get_address.py index cdb6e722713..6b5a0247676 100644 --- a/tests/device_tests/binance/test_get_address.py +++ b/tests/device_tests/binance/test_get_address.py @@ -17,7 +17,7 @@ import pytest from trezorlib.binance import get_address -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...input_flows import InputFlowShowAddressQRCode @@ -38,23 +38,23 @@ @pytest.mark.parametrize("path, expected_address", BINANCE_ADDRESS_TEST_VECTORS) -def test_binance_get_address(client: Client, path: str, expected_address: str): +def test_binance_get_address(session: Session, path: str, expected_address: str): # data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50 - address = get_address(client, parse_path(path), show_display=True) + address = get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", BINANCE_ADDRESS_TEST_VECTORS) def test_binance_get_address_chunkify_details( - client: Client, path: str, expected_address: str + session: Session, path: str, expected_address: str ): # data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50 - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address diff --git a/tests/device_tests/binance/test_get_public_key.py b/tests/device_tests/binance/test_get_public_key.py index ea04fdbd88f..f65baa5dd83 100644 --- a/tests/device_tests/binance/test_get_public_key.py +++ b/tests/device_tests/binance/test_get_public_key.py @@ -17,7 +17,7 @@ import pytest from trezorlib import binance -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...input_flows import InputFlowShowXpubQRCode @@ -31,11 +31,11 @@ @pytest.mark.setup_client( mnemonic="offer caution gift cross surge pretty orange during eye soldier popular holiday mention east eight office fashion ill parrot vault rent devote earth cousin" ) -def test_binance_get_public_key(client: Client): - with client: +def test_binance_get_public_key(session: Session): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) - sig = binance.get_public_key(client, BINANCE_PATH, show_display=True) + sig = binance.get_public_key(session, BINANCE_PATH, show_display=True) assert ( sig.hex() == "029729a52e4e3c2b4a4e52aa74033eedaf8ba1df5ab6d1f518fd69e67bbd309b0e" diff --git a/tests/device_tests/binance/test_sign_tx.py b/tests/device_tests/binance/test_sign_tx.py index ceb06924650..1665e005a46 100644 --- a/tests/device_tests/binance/test_sign_tx.py +++ b/tests/device_tests/binance/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import binance -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path BINANCE_TEST_VECTORS = [ @@ -110,10 +110,10 @@ @pytest.mark.parametrize("message, expected_response", BINANCE_TEST_VECTORS) @pytest.mark.parametrize("chunkify", (True, False)) def test_binance_sign_message( - client: Client, chunkify: bool, message: dict, expected_response: dict + session: Session, chunkify: bool, message: dict, expected_response: dict ): response = binance.sign_tx( - client, parse_path("m/44h/714h/0h/0/0"), message, chunkify=chunkify + session, parse_path("m/44h/714h/0h/0/0"), message, chunkify=chunkify ) assert response.public_key.hex() == expected_response["public_key"] diff --git a/tests/device_tests/bitcoin/payment_req.py b/tests/device_tests/bitcoin/payment_req.py index 73d98859ba1..f928a5fa8e8 100644 --- a/tests/device_tests/bitcoin/payment_req.py +++ b/tests/device_tests/bitcoin/payment_req.py @@ -4,6 +4,7 @@ from ecdsa import SECP256k1, SigningKey from trezorlib import btc, messages +from trezorlib.transport.session import Session from ...common import compact_size @@ -27,7 +28,12 @@ def hash_bytes_prefixed(hasher, data): def make_payment_request( - client, recipient_name, outputs, change_addresses=None, memos=None, nonce=None + session: Session, + recipient_name, + outputs, + change_addresses=None, + memos=None, + nonce=None, ): h_pr = sha256(b"SL\x00\x24") @@ -52,7 +58,7 @@ def make_payment_request( hash_bytes_prefixed(h_pr, memo.text.encode()) elif isinstance(memo, RefundMemo): address_resp = btc.get_authenticated_address( - client, "Testnet", memo.address_n + session, "Testnet", memo.address_n ) msg_memo = messages.RefundMemo( address=address_resp.address, mac=address_resp.mac @@ -63,7 +69,7 @@ def make_payment_request( hash_bytes_prefixed(h_pr, address_resp.address.encode()) elif isinstance(memo, CoinPurchaseMemo): address_resp = btc.get_authenticated_address( - client, memo.coin_name, memo.address_n + session, memo.coin_name, memo.address_n ) msg_memo = messages.CoinPurchaseMemo( coin_type=memo.slip44, diff --git a/tests/device_tests/bitcoin/test_authorize_coinjoin.py b/tests/device_tests/bitcoin/test_authorize_coinjoin.py index b149ff53d16..e1abeb943ed 100644 --- a/tests/device_tests/bitcoin/test_authorize_coinjoin.py +++ b/tests/device_tests/bitcoin/test_authorize_coinjoin.py @@ -19,7 +19,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -59,15 +59,15 @@ @pytest.mark.parametrize("chunkify", (True, False)) @pytest.mark.setup_client(pin=PIN) -def test_sign_tx(client: Client, chunkify: bool): +def test_sign_tx(session: Session, chunkify: bool): # NOTE: FAKE input tx commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big") - with client: + with session.client as client: client.use_pin_sequence([PIN]) btc.authorize_coinjoin( - client, + session, coordinator="www.example.com", max_rounds=2, max_coordinator_fee_rate=500_000, # 0.5 % @@ -77,14 +77,14 @@ def test_sign_tx(client: Client, chunkify: bool): script_type=messages.InputScriptType.SPENDTAPROOT, ) - client.call(messages.LockDevice()) + session.call(messages.LockDevice()) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [messages.PreauthorizedRequest, messages.OwnershipProof] ) btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -93,12 +93,12 @@ def test_sign_tx(client: Client, chunkify: bool): preauthorized=True, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [messages.PreauthorizedRequest, messages.OwnershipProof] ) btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/5"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -206,8 +206,8 @@ def test_sign_tx(client: Client, chunkify: bool): no_fee_indices=[], ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.PreauthorizedRequest(), request_input(0), @@ -222,7 +222,7 @@ def test_sign_tx(client: Client, chunkify: bool): ] ) signatures, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -243,7 +243,7 @@ def test_sign_tx(client: Client, chunkify: bool): # Test for a second time. btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -256,7 +256,7 @@ def test_sign_tx(client: Client, chunkify: bool): # Test for a third time, number of rounds should be exceeded. with pytest.raises(TrezorFailure, match="No preauthorized operation"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -267,7 +267,7 @@ def test_sign_tx(client: Client, chunkify: bool): ) -def test_sign_tx_large(client: Client): +def test_sign_tx_large(session: Session): # NOTE: FAKE input tx commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big") @@ -278,17 +278,16 @@ def test_sign_tx_large(client: Client): output_denom = 10_000 # sats max_expected_delay = 60 # seconds - with client: - btc.authorize_coinjoin( - client, - coordinator="www.example.com", - max_rounds=2, - max_coordinator_fee_rate=500_000, # 0.5 % - max_fee_per_kvbyte=3500, - n=parse_path("m/10025h/1h/0h/1h"), - coin_name="Testnet", - script_type=messages.InputScriptType.SPENDTAPROOT, - ) + btc.authorize_coinjoin( + session, + coordinator="www.example.com", + max_rounds=2, + max_coordinator_fee_rate=500_000, # 0.5 % + max_fee_per_kvbyte=3500, + n=parse_path("m/10025h/1h/0h/1h"), + coin_name="Testnet", + script_type=messages.InputScriptType.SPENDTAPROOT, + ) # INPUTS. @@ -399,22 +398,21 @@ def test_sign_tx_large(client: Client): ) start = time.time() - with client: - btc.sign_tx( - client, - "Testnet", - inputs, - outputs, - prev_txes=TX_CACHE_TESTNET, - coinjoin_request=coinjoin_req, - preauthorized=True, - serialize=False, - ) + btc.sign_tx( + session, + "Testnet", + inputs, + outputs, + prev_txes=TX_CACHE_TESTNET, + coinjoin_request=coinjoin_req, + preauthorized=True, + serialize=False, + ) delay = time.time() - start assert delay <= max_expected_delay -def test_sign_tx_spend(client: Client): +def test_sign_tx_spend(session: Session): # NOTE: FAKE input tx inputs = [ @@ -446,15 +444,15 @@ def test_sign_tx_spend(client: Client): # Ensure that Trezor refuses to spend from CoinJoin without user authorization. with pytest.raises(TrezorFailure, match="Forbidden key path"): _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, prev_txes=TX_CACHE_TESTNET, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest(), @@ -462,7 +460,7 @@ def test_sign_tx_spend(client: Client): request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -472,7 +470,7 @@ def test_sign_tx_spend(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -487,7 +485,7 @@ def test_sign_tx_spend(client: Client): ) -def test_sign_tx_migration(client: Client): +def test_sign_tx_migration(session: Session): inputs = [ messages.TxInputType( address_n=parse_path("m/84h/1h/3h/0/12"), @@ -520,15 +518,15 @@ def test_sign_tx_migration(client: Client): # Ensure that Trezor refuses to receive to CoinJoin path without the user first authorizing access to CoinJoin paths. with pytest.raises(TrezorFailure, match="Forbidden key path"): _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, prev_txes=TX_CACHE_TESTNET, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest(), @@ -536,7 +534,7 @@ def test_sign_tx_migration(client: Client): request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_2cc3c1), @@ -558,7 +556,7 @@ def test_sign_tx_migration(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -573,11 +571,11 @@ def test_sign_tx_migration(client: Client): ) -def test_wrong_coordinator(client: Client): +def test_wrong_coordinator(session: Session): # Ensure that a preauthorized GetOwnershipProof fails if the commitment_data doesn't match the coordinator. btc.authorize_coinjoin( - client, + session, coordinator="www.example.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -589,7 +587,7 @@ def test_wrong_coordinator(client: Client): with pytest.raises(TrezorFailure, match="Unauthorized operation"): btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -599,9 +597,9 @@ def test_wrong_coordinator(client: Client): ) -def test_wrong_account_type(client: Client): +def test_wrong_account_type(session: Session): params = { - "client": client, + "session": session, "coordinator": "www.example.com", "max_rounds": 10, "max_coordinator_fee_rate": 500_000, # 0.5 % @@ -625,11 +623,11 @@ def test_wrong_account_type(client: Client): ) -def test_cancel_authorization(client: Client): +def test_cancel_authorization(session: Session): # Ensure that a preauthorized GetOwnershipProof fails if the commitment_data doesn't match the coordinator. btc.authorize_coinjoin( - client, + session, coordinator="www.example.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -639,11 +637,11 @@ def test_cancel_authorization(client: Client): script_type=messages.InputScriptType.SPENDTAPROOT, ) - device.cancel_authorization(client) + device.cancel_authorization(session) with pytest.raises(TrezorFailure, match="No preauthorized operation"): btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -653,35 +651,39 @@ def test_cancel_authorization(client: Client): ) -def test_get_public_key(client: Client): +def test_get_public_key(session: Session): + raise Exception( + "Test fails on: unlock_path_mac = device.unlock_path(session, n=SLIP25_PATH)" + ) # TODO + ACCOUNT_PATH = parse_path("m/10025h/1h/0h/1h") EXPECTED_XPUB = "tpubDEMKm4M3S2Grx5DHTfbX9et5HQb9KhdjDCkUYdH9gvVofvPTE6yb2MH52P9uc4mx6eFohUmfN1f4hhHNK28GaZnWRXr3b8KkfFcySo1SmXU" # Ensure that user cannot access SLIP-25 path without UnlockPath. with pytest.raises(TrezorFailure, match="Forbidden key path"): resp = btc.get_public_node( - client, + session, ACCOUNT_PATH, coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, ) # Get unlock path MAC. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, messages.Failure(code=messages.FailureType.ActionCancelled), ] ) - unlock_path_mac = device.unlock_path(client, n=SLIP25_PATH) + unlock_path_mac = device.unlock_path(session, n=SLIP25_PATH) # Ensure that UnlockPath fails with invalid MAC. invalid_unlock_path_mac = bytes([unlock_path_mac[0] ^ 1]) + unlock_path_mac[1:] with pytest.raises(TrezorFailure, match="Invalid MAC"): resp = btc.get_public_node( - client, + session, ACCOUNT_PATH, coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, @@ -690,15 +692,15 @@ def test_get_public_key(client: Client): ) # Ensure that user does not need to confirm access when path unlock is requested with MAC. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.UnlockedPathRequest, messages.PublicKey, ] ) resp = btc.get_public_node( - client, + session, ACCOUNT_PATH, coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, @@ -708,11 +710,15 @@ def test_get_public_key(client: Client): assert resp.xpub == EXPECTED_XPUB -def test_get_address(client: Client): +def test_get_address(session: Session): + raise Exception( + "Test fails on: unlock_path_mac = device.unlock_path(session, SLIP25_PATH)" + ) # TODO + # Ensure that the SLIP-0025 external chain is inaccessible without user confirmation. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/0/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -720,20 +726,20 @@ def test_get_address(client: Client): ) # Unlock CoinJoin path. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, messages.Failure(code=messages.FailureType.ActionCancelled), ] ) - unlock_path_mac = device.unlock_path(client, SLIP25_PATH) + unlock_path_mac = device.unlock_path(session, SLIP25_PATH) # Ensure that the SLIP-0025 external chain is accessible after user confirmation. for chunkify in (True, False): resp = btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/0/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -745,7 +751,7 @@ def test_get_address(client: Client): assert resp == "tb1pl3y9gf7xk2ryvmav5ar66ra0d2hk7lhh9mmusx3qvn0n09kmaghqh32ru7" resp = btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/0/1"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -758,7 +764,7 @@ def test_get_address(client: Client): # Ensure that the SLIP-0025 internal chain is inaccessible even with user authorization. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -769,7 +775,7 @@ def test_get_address(client: Client): with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/1"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -781,7 +787,7 @@ def test_get_address(client: Client): # Ensure that another SLIP-0025 account is inaccessible with the same MAC. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/1h/1h/0/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -791,10 +797,11 @@ def test_get_address(client: Client): ) -def test_multisession_authorization(client: Client): +def test_multisession_authorization(session: Session): + raise Exception("Test is not functional with the new session handling") # TODO # Authorize CoinJoin with www.example1.com in session 1. btc.authorize_coinjoin( - client, + session, coordinator="www.example1.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -805,12 +812,12 @@ def test_multisession_authorization(client: Client): ) # Open a second session. - session_id1 = client.session_id - client.init_device(new_session=True) + # session_id1 = session.session_id + # TODO client.init_device(new_session=True) # Authorize CoinJoin with www.example2.com in session 2. btc.authorize_coinjoin( - client, + session, coordinator="www.example2.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -823,7 +830,7 @@ def test_multisession_authorization(client: Client): # Requesting a preauthorized ownership proof for www.example1.com should fail in session 2. with pytest.raises(TrezorFailure, match="Unauthorized operation"): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -834,7 +841,7 @@ def test_multisession_authorization(client: Client): # Requesting a preauthorized ownership proof for www.example2.com should succeed in session 2. ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -849,12 +856,12 @@ def test_multisession_authorization(client: Client): ) # Switch back to the first session. - session_id2 = client.session_id - client.init_device(session_id=session_id1) + # session_id2 = session.session_id + # TODO client.init_device(session_id=session_id1) # Requesting a preauthorized ownership proof for www.example1.com should succeed in session 1. ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -871,7 +878,7 @@ def test_multisession_authorization(client: Client): # Requesting a preauthorized ownership proof for www.example2.com should fail in session 1. with pytest.raises(TrezorFailure, match="Unauthorized operation"): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -881,12 +888,12 @@ def test_multisession_authorization(client: Client): ) # Cancel the authorization in session 1. - device.cancel_authorization(client) + device.cancel_authorization(session) # Requesting a preauthorized ownership proof should fail now. with pytest.raises(TrezorFailure, match="No preauthorized operation"): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -896,11 +903,11 @@ def test_multisession_authorization(client: Client): ) # Switch to the second session. - client.init_device(session_id=session_id2) + # TODO client.init_device(session_id=session_id2) # Requesting a preauthorized ownership proof for www.example2.com should still succeed in session 2. ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, diff --git a/tests/device_tests/bitcoin/test_bcash.py b/tests/device_tests/bitcoin/test_bcash.py index 76538828632..d1f0129741c 100644 --- a/tests/device_tests/bitcoin/test_bcash.py +++ b/tests/device_tests/bitcoin/test_bcash.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -53,7 +53,7 @@ pytestmark = pytest.mark.altcoin -def test_send_bch_change(client: Client): +def test_send_bch_change(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/145h/0h/0/0"), # bitcoincash:qr08q88p9etk89wgv05nwlrkm4l0urz4cyl36hh9sv @@ -72,14 +72,14 @@ def test_send_bch_change(client: Client): amount=73_452, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_bc37c2), @@ -92,9 +92,9 @@ def test_send_bch_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API + session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API ) - + # raise Exception(hexlify(serialized_tx)) assert_tx_matches( serialized_tx, hash_link="https://bch1.trezor.io/api/tx/502e8577b237b0152843a416f8f1ab0c63321b1be7a8cad7bf5c5c216fcf062c", @@ -102,7 +102,7 @@ def test_send_bch_change(client: Client): ) -def test_send_bch_nochange(client: Client): +def test_send_bch_nochange(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/145h/0h/1/0"), # bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw @@ -124,14 +124,14 @@ def test_send_bch_nochange(client: Client): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_502e85), @@ -150,7 +150,7 @@ def test_send_bch_nochange(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API ) assert_tx_matches( @@ -160,7 +160,7 @@ def test_send_bch_nochange(client: Client): ) -def test_send_bch_oldaddr(client: Client): +def test_send_bch_oldaddr(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/145h/0h/1/0"), # bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw @@ -182,14 +182,14 @@ def test_send_bch_oldaddr(client: Client): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_502e85), @@ -208,7 +208,7 @@ def test_send_bch_oldaddr(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API ) assert_tx_matches( @@ -218,7 +218,7 @@ def test_send_bch_oldaddr(client: Client): ) -def test_attack_change_input(client: Client): +def test_attack_change_input(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -252,15 +252,15 @@ def attack_processor(msg): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_bd32ff), @@ -271,16 +271,16 @@ def attack_processor(msg): ] ) with pytest.raises(TrezorFailure): - btc.sign_tx(client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API) + btc.sign_tx(session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API) @pytest.mark.multisig -def test_send_bch_multisig_wrongchange(client: Client): +def test_send_bch_multisig_wrongchange(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" + session, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" ).node for i in range(1, 4) ] @@ -327,13 +327,13 @@ def getmultisig(chain, nr, signatures): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=23_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_062fbd), @@ -346,7 +346,7 @@ def getmultisig(chain, nr, signatures): ] ) (signatures1, serialized_tx) = btc.sign_tx( - client, "Bcash", [inp1], [out1], prev_txes=TX_API + session, "Bcash", [inp1], [out1], prev_txes=TX_API ) assert ( signatures1[0].hex() @@ -359,12 +359,12 @@ def getmultisig(chain, nr, signatures): @pytest.mark.multisig -def test_send_bch_multisig_change(client: Client): +def test_send_bch_multisig_change(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" + session, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" ).node for i in range(1, 4) ] @@ -395,13 +395,13 @@ def getmultisig(chain, nr, signatures): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=24_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -415,7 +415,7 @@ def getmultisig(chain, nr, signatures): ] ) (signatures1, serialized_tx) = btc.sign_tx( - client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API + session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -434,13 +434,13 @@ def getmultisig(chain, nr, signatures): ) out2.address_n[2] = H_(1) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -454,7 +454,7 @@ def getmultisig(chain, nr, signatures): ] ) (signatures1, serialized_tx) = btc.sign_tx( - client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API + session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -468,7 +468,7 @@ def getmultisig(chain, nr, signatures): @pytest.mark.models("core") -def test_send_bch_external_presigned(client: Client): +def test_send_bch_external_presigned(session: Session): inp1 = messages.TxInputType( # address_n=parse_path("44'/145'/0'/1/0"), # bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw @@ -496,14 +496,14 @@ def test_send_bch_external_presigned(client: Client): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_502e85), @@ -522,7 +522,7 @@ def test_send_bch_external_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API ) assert_tx_matches( diff --git a/tests/device_tests/bitcoin/test_bgold.py b/tests/device_tests/bitcoin/test_bgold.py index 71c1a6c3ad4..831ea216cbd 100644 --- a/tests/device_tests/bitcoin/test_bgold.py +++ b/tests/device_tests/bitcoin/test_bgold.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path, tx_hash @@ -51,7 +51,7 @@ # All data taken from T1 -def test_send_bitcoin_gold_change(client: Client): +def test_send_bitcoin_gold_change(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -71,14 +71,14 @@ def test_send_bitcoin_gold_change(client: Client): amount=1_252_382_934 - 1_896_050 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -92,7 +92,7 @@ def test_send_bitcoin_gold_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -101,7 +101,7 @@ def test_send_bitcoin_gold_change(client: Client): ) -def test_send_bitcoin_gold_nochange(client: Client): +def test_send_bitcoin_gold_nochange(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -124,14 +124,14 @@ def test_send_bitcoin_gold_nochange(client: Client): amount=1_252_382_934 + 38_448_607 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -150,7 +150,7 @@ def test_send_bitcoin_gold_nochange(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( @@ -159,7 +159,7 @@ def test_send_bitcoin_gold_nochange(client: Client): ) -def test_attack_change_input(client: Client): +def test_attack_change_input(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -193,15 +193,15 @@ def attack_processor(msg): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -213,16 +213,16 @@ def attack_processor(msg): ] ) with pytest.raises(TrezorFailure): - btc.sign_tx(client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API) + btc.sign_tx(session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API) @pytest.mark.multisig -def test_send_btg_multisig_change(client: Client): +def test_send_btg_multisig_change(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/156h/{i}h/0h"), coin_name="Bgold" + session, parse_path(f"m/48h/156h/{i}h/0h"), coin_name="Bgold" ).node for i in range(1, 4) ] @@ -254,13 +254,13 @@ def getmultisig(chain, nr, signatures): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=1_252_382_934 - 24_000 - 1_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -275,7 +275,7 @@ def getmultisig(chain, nr, signatures): ] ) signatures, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -293,13 +293,13 @@ def getmultisig(chain, nr, signatures): ) out2.address_n[2] = H_(1) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -314,7 +314,7 @@ def getmultisig(chain, nr, signatures): ] ) signatures, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -327,7 +327,7 @@ def getmultisig(chain, nr, signatures): ) -def test_send_p2sh(client: Client): +def test_send_p2sh(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -347,16 +347,16 @@ def test_send_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=1_252_382_934 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_db7239), @@ -371,7 +371,7 @@ def test_send_p2sh(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -380,7 +380,7 @@ def test_send_p2sh(client: Client): ) -def test_send_p2sh_witness_change(client: Client): +def test_send_p2sh_witness_change(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -400,13 +400,13 @@ def test_send_p2sh_witness_change(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=1_252_382_934 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -422,7 +422,7 @@ def test_send_p2sh_witness_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -432,12 +432,12 @@ def test_send_p2sh_witness_change(client: Client): @pytest.mark.multisig -def test_send_multisig_1(client: Client): +def test_send_multisig_1(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/156h/{i}h"), coin_name="Bgold" + session, parse_path(f"m/49h/156h/{i}h"), coin_name="Bgold" ).node for i in range(1, 4) ] @@ -460,13 +460,13 @@ def test_send_multisig_1(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_7f1f6b), @@ -479,17 +479,17 @@ def test_send_multisig_1(client: Client): request_finished(), ] ) - signatures, _ = btc.sign_tx(client, "Bgold", [inp1], [out1], prev_txes=TX_API) + signatures, _ = btc.sign_tx(session, "Bgold", [inp1], [out1], prev_txes=TX_API) # store signature inp1.multisig.signatures[0] = signatures[0] # sign with third key inp1.address_n[2] = H_(3) - client.set_expected_responses( + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_7f1f6b), @@ -503,7 +503,7 @@ def test_send_multisig_1(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1], prev_txes=TX_API + session, "Bgold", [inp1], [out1], prev_txes=TX_API ) assert ( @@ -512,7 +512,7 @@ def test_send_multisig_1(client: Client): ) -def test_send_mixed_inputs(client: Client): +def test_send_mixed_inputs(session: Session): # NOTE: fake input tx used # First is non-segwit, second is segwit. @@ -537,9 +537,9 @@ def test_send_mixed_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( @@ -549,7 +549,7 @@ def test_send_mixed_inputs(client: Client): @pytest.mark.models("core") -def test_send_btg_external_presigned(client: Client): +def test_send_btg_external_presigned(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -577,14 +577,14 @@ def test_send_btg_external_presigned(client: Client): amount=1_252_382_934 + 58_456 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -603,7 +603,7 @@ def test_send_btg_external_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( diff --git a/tests/device_tests/bitcoin/test_dash.py b/tests/device_tests/bitcoin/test_dash.py index 4dde98bfbfd..06b335c1487 100644 --- a/tests/device_tests/bitcoin/test_dash.py +++ b/tests/device_tests/bitcoin/test_dash.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -43,7 +43,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.models("t1b1", "t2t1")] -def test_send_dash(client: Client): +def test_send_dash(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/5h/0h/0/0"), # dash:XdTw4G5AWW4cogGd7ayybyBNDbuB45UpgH @@ -57,13 +57,13 @@ def test_send_dash(client: Client): amount=999_999_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(inp1.prev_hash), @@ -77,7 +77,9 @@ def test_send_dash(client: Client): request_finished(), ] ) - _, serialized_tx = btc.sign_tx(client, "Dash", [inp1], [out1], prev_txes=TX_API) + _, serialized_tx = btc.sign_tx( + session, "Dash", [inp1], [out1], prev_txes=TX_API + ) assert ( serialized_tx.hex() @@ -85,7 +87,7 @@ def test_send_dash(client: Client): ) -def test_send_dash_dip2_input(client: Client): +def test_send_dash_dip2_input(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/5h/0h/0/0"), # dash:XdTw4G5AWW4cogGd7ayybyBNDbuB45UpgH @@ -104,14 +106,14 @@ def test_send_dash_dip2_input(client: Client): amount=95_000_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(inp1.prev_hash), @@ -128,7 +130,7 @@ def test_send_dash_dip2_input(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Dash", [inp1], [out1, out2], prev_txes=TX_API + session, "Dash", [inp1], [out1, out2], prev_txes=TX_API ) assert ( diff --git a/tests/device_tests/bitcoin/test_decred.py b/tests/device_tests/bitcoin/test_decred.py index 78bb1b0c3af..204d0559280 100644 --- a/tests/device_tests/bitcoin/test_decred.py +++ b/tests/device_tests/bitcoin/test_decred.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -57,7 +57,7 @@ ] -def test_send_decred(client: Client): +def test_send_decred(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -76,13 +76,13 @@ def test_send_decred(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.FeeOverThreshold), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -95,7 +95,7 @@ def test_send_decred(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Decred Testnet", [inp1], [out1], prev_txes=TX_API + session, "Decred Testnet", [inp1], [out1], prev_txes=TX_API ) assert ( @@ -105,7 +105,7 @@ def test_send_decred(client: Client): @pytest.mark.models("core") -def test_purchase_ticket_decred(client: Client): +def test_purchase_ticket_decred(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -133,8 +133,8 @@ def test_purchase_ticket_decred(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), @@ -153,7 +153,7 @@ def test_purchase_ticket_decred(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Decred Testnet", [inp1], [out1, out2, out3], @@ -168,7 +168,7 @@ def test_purchase_ticket_decred(client: Client): @pytest.mark.models("core") -def test_spend_from_stake_generation_and_revocation_decred(client: Client): +def test_spend_from_stake_generation_and_revocation_decred(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -197,14 +197,14 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_8b6890), @@ -223,7 +223,7 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Decred Testnet", [inp1, inp2], [out1], prev_txes=TX_API + session, "Decred Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( @@ -232,7 +232,7 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client): ) -def test_send_decred_change(client: Client): +def test_send_decred_change(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -278,15 +278,15 @@ def test_send_decred_change(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_input(2), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -311,7 +311,7 @@ def test_send_decred_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Decred Testnet", [inp1, inp2, inp3], [out1, out2], @@ -325,12 +325,12 @@ def test_send_decred_change(client: Client): @pytest.mark.multisig -def test_decred_multisig_change(client: Client): +def test_decred_multisig_change(session: Session): # NOTE: fake input tx used paths = [parse_path(f"m/48h/1h/{index}'/0'") for index in range(3)] nodes = [ - btc.get_public_node(client, address_n, coin_name="Decred Testnet").node + btc.get_public_node(session, address_n, coin_name="Decred Testnet").node for address_n in paths ] @@ -384,15 +384,15 @@ def test_multisig(index): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_9ac7d2), @@ -410,7 +410,7 @@ def test_multisig(index): ] ) signature, serialized_tx = btc.sign_tx( - client, + session, "Decred Testnet", [inp1, inp2], [out1, out2], diff --git a/tests/device_tests/bitcoin/test_descriptors.py b/tests/device_tests/bitcoin/test_descriptors.py index 6efdd99ed82..7a077b20527 100644 --- a/tests/device_tests/bitcoin/test_descriptors.py +++ b/tests/device_tests/bitcoin/test_descriptors.py @@ -18,7 +18,7 @@ from trezorlib import btc, messages, models from trezorlib.cli import btc as btc_cli -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import H_ from ...input_flows import InputFlowShowXpubQRCode @@ -165,14 +165,16 @@ def _address_n(purpose, coin, account, script_type): @pytest.mark.parametrize( "coin, account, purpose, script_type, descriptors", VECTORS_DESCRIPTORS ) -def test_descriptors(client: Client, coin, account, purpose, script_type, descriptors): - with client: +def test_descriptors( + session: Session, coin, account, purpose, script_type, descriptors +): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) address_n = _address_n(purpose, coin, account, script_type) res = btc.get_public_node( - client, + session, _address_n(purpose, coin, account, script_type), show_display=True, coin_name=coin, @@ -187,13 +189,13 @@ def test_descriptors(client: Client, coin, account, purpose, script_type, descri "coin, account, purpose, script_type, descriptors", VECTORS_DESCRIPTORS ) def test_descriptors_trezorlib( - client: Client, coin, account, purpose, script_type, descriptors + session: Session, coin, account, purpose, script_type, descriptors ): - with client: + with session.client as client: if client.model != models.T1B1: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) res = btc_cli._get_descriptor( - client, coin, account, purpose, script_type, show_display=True + session, coin, account, purpose, script_type, show_display=True ) assert res == descriptors diff --git a/tests/device_tests/bitcoin/test_firo.py b/tests/device_tests/bitcoin/test_firo.py index 52db787957d..2ceeb2c2d77 100644 --- a/tests/device_tests/bitcoin/test_firo.py +++ b/tests/device_tests/bitcoin/test_firo.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -30,7 +30,7 @@ @pytest.mark.altcoin -def test_spend_lelantus(client: Client): +def test_spend_lelantus(session: Session): inp1 = messages.TxInputType( # THgGLVqfzJcaxRVPWE5fd8YJ1GpVePq2Uk address_n=parse_path("m/44h/1h/0h/0/4"), @@ -45,7 +45,7 @@ def test_spend_lelantus(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Firo Testnet", [inp1], [out1], prev_txes=TX_API + session, "Firo Testnet", [inp1], [out1], prev_txes=TX_API ) assert_tx_matches( serialized_tx, diff --git a/tests/device_tests/bitcoin/test_fujicoin.py b/tests/device_tests/bitcoin/test_fujicoin.py index f28747c7173..45886e8603b 100644 --- a/tests/device_tests/bitcoin/test_fujicoin.py +++ b/tests/device_tests/bitcoin/test_fujicoin.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path TXHASH_33043a = bytes.fromhex( @@ -27,7 +27,7 @@ pytestmark = pytest.mark.altcoin -def test_send_p2tr(client: Client): +def test_send_p2tr(session: Session): inp1 = messages.TxInputType( # fc1prr07akly3xjtmggue0p04vghr8pdcgxrye2s00sahptwjeawxrkq2rxzr7 address_n=parse_path("m/86h/75h/0h/0/1"), @@ -42,7 +42,7 @@ def test_send_p2tr(client: Client): amount=99_996_670_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - _, serialized_tx = btc.sign_tx(client, "Fujicoin", [inp1], [out1]) + _, serialized_tx = btc.sign_tx(session, "Fujicoin", [inp1], [out1]) # Transaction hex changed with fix #2085, all other details are the same as this tx: # https://explorer.fujicoin.org/tx/a1c6a81f5e8023b17e6e3e51e2596d5b5e1d4914ea13c0c31cef90b3c3edee86 assert ( diff --git a/tests/device_tests/bitcoin/test_getaddress.py b/tests/device_tests/bitcoin/test_getaddress.py index f668a22bf54..846c4212778 100644 --- a/tests/device_tests/bitcoin/test_getaddress.py +++ b/tests/device_tests/bitcoin/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import SafetyCheckLevel from trezorlib.tools import parse_path @@ -36,112 +36,112 @@ def getmultisig(chain, nr, xpubs): ) -def test_btc(client: Client): +def test_btc(session: Session): assert ( - btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) == "1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL" ) assert ( - btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/1")) + btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/1")) == "1GWFxtwWmNVqotUPXLcKVL2mUKpshuJYo" ) assert ( - btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) + btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE" ) @pytest.mark.altcoin -def test_ltc(client: Client): +def test_ltc(session: Session): assert ( - btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/0/0")) + btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/0/0")) == "LcubERmHD31PWup1fbozpKuiqjHZ4anxcL" ) assert ( - btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/0/1")) + btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/0/1")) == "LVWBmHBkCGNjSPHucvL2PmnuRAJnucmRE6" ) assert ( - btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/1/0")) + btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/1/0")) == "LWj6ApswZxay4cJEJES2sGe7fLMLRvvv8h" ) -def test_tbtc(client: Client): +def test_tbtc(session: Session): assert ( - btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/0/0")) + btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/0/0")) == "mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q" ) assert ( - btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/0/1")) + btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/0/1")) == "mopZWqZZyQc3F2Sy33cvDtJchSAMsnLi7b" ) assert ( - btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/1/0")) + btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/1/0")) == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1NLjCbZ" ) @pytest.mark.altcoin -def test_bch(client: Client): +def test_bch(session: Session): assert ( - btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/0/0")) + btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/0/0")) == "bitcoincash:qr08q88p9etk89wgv05nwlrkm4l0urz4cyl36hh9sv" ) assert ( - btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/0/1")) + btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/0/1")) == "bitcoincash:qr23ajjfd9wd73l87j642puf8cad20lfmqdgwvpat4" ) assert ( - btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/1/0")) + btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/1/0")) == "bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw" ) @pytest.mark.altcoin -def test_grs(client: Client): +def test_grs(session: Session): assert ( - btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/0/0")) + btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/0/0")) == "Fj62rBJi8LvbmWu2jzkaUX1NFXLEqDLoZM" ) assert ( - btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/1/0")) + btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/1/0")) == "FmRaqvVBRrAp2Umfqx9V1ectZy8gw54QDN" ) assert ( - btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/1/1")) + btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/1/1")) == "Fmhtxeh7YdCBkyQF7AQG4QnY8y3rJg89di" ) @pytest.mark.altcoin -def test_tgrs(client: Client): +def test_tgrs(session: Session): assert ( - btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0")) + btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0")) == "mvbu1Gdy8SUjTenqerxUaZyYjmvedc787y" ) assert ( - btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/0")) + btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/0")) == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1LMq8cN" ) assert ( - btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/1")) + btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/1")) == "mjXZwmEi1z1MzveZrKUAo4DBgbdq6ZhGD6" ) @pytest.mark.altcoin -def test_elements(client: Client): +def test_elements(session: Session): assert ( - btc.get_address(client, "Elements", parse_path("m/44h/1h/0h/0/0")) + btc.get_address(session, "Elements", parse_path("m/44h/1h/0h/0/0")) == "2dpWh6jbhAowNsQ5agtFzi7j6nKscj6UnEr" ) @pytest.mark.models("core") -def test_address_mac(client: Client): +def test_address_mac(session: Session): resp = btc.get_authenticated_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/1/0") + session, "Bitcoin", parse_path("m/44h/0h/0h/1/0") ) assert resp.address == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE" assert ( @@ -150,7 +150,7 @@ def test_address_mac(client: Client): ) resp = btc.get_authenticated_address( - client, "Testnet", parse_path("m/44h/1h/0h/1/0") + session, "Testnet", parse_path("m/44h/1h/0h/1/0") ) assert resp.address == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1NLjCbZ" assert ( @@ -160,16 +160,16 @@ def test_address_mac(client: Client): # Script type mismatch. resp = btc.get_authenticated_address( - client, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=False + session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=False ) assert resp.mac is None @pytest.mark.models("core") @pytest.mark.altcoin -def test_altcoin_address_mac(client: Client): +def test_altcoin_address_mac(session: Session): resp = btc.get_authenticated_address( - client, "Litecoin", parse_path("m/44h/2h/0h/1/0") + session, "Litecoin", parse_path("m/44h/2h/0h/1/0") ) assert resp.address == "LWj6ApswZxay4cJEJES2sGe7fLMLRvvv8h" assert ( @@ -178,7 +178,7 @@ def test_altcoin_address_mac(client: Client): ) resp = btc.get_authenticated_address( - client, "Bcash", parse_path("m/44h/145h/0h/1/0") + session, "Bcash", parse_path("m/44h/145h/0h/1/0") ) assert resp.address == "bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw" assert ( @@ -187,7 +187,7 @@ def test_altcoin_address_mac(client: Client): ) resp = btc.get_authenticated_address( - client, "Groestlcoin", parse_path("m/44h/17h/0h/1/1") + session, "Groestlcoin", parse_path("m/44h/17h/0h/1/1") ) assert resp.address == "Fmhtxeh7YdCBkyQF7AQG4QnY8y3rJg89di" assert ( @@ -197,20 +197,20 @@ def test_altcoin_address_mac(client: Client): @pytest.mark.multisig -def test_multisig(client: Client): +def test_multisig(session: Session): xpubs = [] for n in range(1, 4): - node = btc.get_public_node(client, parse_path(f"m/44h/0h/{n}h")) + node = btc.get_public_node(session, parse_path(f"m/44h/0h/{n}h")) xpubs.append(node.xpub) for nr in range(1, 4): - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", parse_path(f"m/44h/0h/{nr}h/0/0"), show_display=(nr == 1), @@ -220,7 +220,7 @@ def test_multisig(client: Client): ) assert ( btc.get_address( - client, + session, "Bitcoin", parse_path(f"m/44h/0h/{nr}h/1/0"), show_display=(nr == 1), @@ -232,12 +232,12 @@ def test_multisig(client: Client): @pytest.mark.multisig @pytest.mark.parametrize("show_display", (True, False)) -def test_multisig_missing(client: Client, show_display): +def test_multisig_missing(session: Session, show_display): # Multisig with global suffix specification. # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/44h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/44h/0h/{i}h")).node for i in range(1, 4) ] multisig1 = messages.MultisigRedeemScriptType( @@ -246,7 +246,7 @@ def test_multisig_missing(client: Client, show_display): # Multisig with per-node suffix specification. node = btc.get_public_node( - client, parse_path("m/44h/0h/0h/0"), coin_name="Bitcoin" + session, parse_path("m/44h/0h/0h/0"), coin_name="Bitcoin" ).node multisig2 = messages.MultisigRedeemScriptType( @@ -260,12 +260,12 @@ def test_multisig_missing(client: Client, show_display): ) for multisig in (multisig1, multisig2): - with client, pytest.raises(TrezorFailure): - if is_core(client): + with session.client as client, pytest.raises(TrezorFailure): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=show_display, @@ -275,22 +275,22 @@ def test_multisig_missing(client: Client, show_display): @pytest.mark.altcoin @pytest.mark.multisig -def test_bch_multisig(client: Client): +def test_bch_multisig(session: Session): xpubs = [] for n in range(1, 4): node = btc.get_public_node( - client, parse_path(f"m/44h/145h/{n}h"), coin_name="Bcash" + session, parse_path(f"m/44h/145h/{n}h"), coin_name="Bcash" ) xpubs.append(node.xpub) for nr in range(1, 4): - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bcash", parse_path(f"m/44h/145h/{nr}h/0/0"), show_display=(nr == 1), @@ -300,7 +300,7 @@ def test_bch_multisig(client: Client): ) assert ( btc.get_address( - client, + session, "Bcash", parse_path(f"m/44h/145h/{nr}h/1/0"), show_display=(nr == 1), @@ -310,43 +310,43 @@ def test_bch_multisig(client: Client): ) -def test_public_ckd(client: Client): - node = btc.get_public_node(client, parse_path("m/44h/0h/0h")).node - node_sub1 = btc.get_public_node(client, parse_path("m/44h/0h/0h/1/0")).node +def test_public_ckd(session: Session): + node = btc.get_public_node(session, parse_path("m/44h/0h/0h")).node + node_sub1 = btc.get_public_node(session, parse_path("m/44h/0h/0h/1/0")).node node_sub2 = bip32.public_ckd(node, [1, 0]) assert node_sub1.chain_code == node_sub2.chain_code assert node_sub1.public_key == node_sub2.public_key - address1 = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) + address1 = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) address2 = bip32.get_address(node_sub2, 0) assert address2 == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE" assert address1 == address2 -def test_invalid_path(client: Client): +def test_invalid_path(session: Session): with pytest.raises(TrezorFailure, match="Forbidden key path"): # slip44 id mismatch btc.get_address( - client, "Bitcoin", parse_path("m/44h/111h/0h/0/0"), show_display=True + session, "Bitcoin", parse_path("m/44h/111h/0h/0/0"), show_display=True ) -def test_unknown_path(client: Client): +def test_unknown_path(session: Session): UNKNOWN_PATH = parse_path("m/44h/9h/0h/0/0") - with client: - client.set_expected_responses([messages.Failure]) + with session: + session.set_expected_responses([messages.Failure]) with pytest.raises(TrezorFailure, match="Forbidden key path"): # account number is too high - btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True) + btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True) # disable safety checks - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( [ messages.ButtonRequest( code=messages.ButtonRequestType.UnknownDerivationPath @@ -355,21 +355,21 @@ def test_unknown_path(client: Client): messages.Address, ] ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) # try again with a warning - btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True) + btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True) - with client: + with session: # no warning is displayed when the call is silent - client.set_expected_responses([messages.Address]) - btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=False) + session.set_expected_responses([messages.Address]) + btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=False) @pytest.mark.altcoin -def test_crw(client: Client): +def test_crw(session: Session): assert ( - btc.get_address(client, "Crown", parse_path("m/44h/72h/0h/0/0")) + btc.get_address(session, "Crown", parse_path("m/44h/72h/0h/0/0")) == "CRWYdvZM1yXMKQxeN3hRsAbwa7drfvTwys48" ) diff --git a/tests/device_tests/bitcoin/test_getaddress_segwit.py b/tests/device_tests/bitcoin/test_getaddress_segwit.py index 0958facda29..8514082828c 100644 --- a/tests/device_tests/bitcoin/test_getaddress_segwit.py +++ b/tests/device_tests/bitcoin/test_getaddress_segwit.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -25,10 +25,10 @@ from ...input_flows import InputFlowConfirmAllWarnings -def test_show_segwit(client: Client): +def test_show_segwit(session: Session): assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/49h/1h/0h/1/0"), True, @@ -39,7 +39,7 @@ def test_show_segwit(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/49h/1h/0h/0/0"), False, @@ -50,7 +50,7 @@ def test_show_segwit(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/44h/1h/0h/0/0"), False, @@ -61,7 +61,7 @@ def test_show_segwit(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/44h/1h/0h/0/0"), False, @@ -73,14 +73,14 @@ def test_show_segwit(client: Client): @pytest.mark.altcoin -def test_show_segwit_altcoin(client: Client): - with client: - if is_core(client): +def test_show_segwit_altcoin(session: Session): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/49h/1h/0h/1/0"), True, @@ -91,7 +91,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/49h/1h/0h/0/0"), True, @@ -102,7 +102,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0"), True, @@ -113,7 +113,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0"), True, @@ -124,7 +124,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Elements", parse_path("m/49h/1h/0h/0/0"), True, @@ -136,10 +136,10 @@ def test_show_segwit_altcoin(client: Client): @pytest.mark.multisig -def test_show_multisig_3(client: Client): +def test_show_multisig_3(session: Session): nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" ).node for i in range(1, 4) ] @@ -155,7 +155,7 @@ def test_show_multisig_3(client: Client): for i in [1, 2, 3]: assert ( btc.get_address( - client, + session, "Testnet", parse_path(f"m/49h/1h/{i}h/0/7"), False, @@ -168,12 +168,12 @@ def test_show_multisig_3(client: Client): @pytest.mark.multisig @pytest.mark.parametrize("show_display", (True, False)) -def test_multisig_missing(client: Client, show_display): +def test_multisig_missing(session: Session, show_display): # Multisig with global suffix specification. # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/49h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/49h/0h/{i}h")).node for i in range(1, 4) ] multisig1 = messages.MultisigRedeemScriptType( @@ -182,7 +182,7 @@ def test_multisig_missing(client: Client, show_display): # Multisig with per-node suffix specification. node = btc.get_public_node( - client, parse_path("m/49h/0h/0h/0"), coin_name="Bitcoin" + session, parse_path("m/49h/0h/0h/0"), coin_name="Bitcoin" ).node multisig2 = messages.MultisigRedeemScriptType( pubkeys=[ @@ -197,7 +197,7 @@ def test_multisig_missing(client: Client, show_display): for multisig in (multisig1, multisig2): with pytest.raises(TrezorFailure): btc.get_address( - client, + session, "Bitcoin", parse_path("m/49h/0h/0h/0/0"), show_display=show_display, diff --git a/tests/device_tests/bitcoin/test_getaddress_segwit_native.py b/tests/device_tests/bitcoin/test_getaddress_segwit_native.py index f1416300533..7c1fdc77428 100644 --- a/tests/device_tests/bitcoin/test_getaddress_segwit_native.py +++ b/tests/device_tests/bitcoin/test_getaddress_segwit_native.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -141,7 +141,7 @@ @pytest.mark.parametrize("show_display", (True, False)) @pytest.mark.parametrize("coin, path, script_type, address", VECTORS) def test_show_segwit( - client: Client, + session: Session, show_display: bool, coin: str, path: str, @@ -150,7 +150,7 @@ def test_show_segwit( ): assert ( btc.get_address( - client, + session, coin, parse_path(path), show_display, @@ -166,10 +166,10 @@ def test_show_segwit( mnemonic="abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about" ) @pytest.mark.parametrize("path, address", BIP86_VECTORS) -def test_bip86(client: Client, path: str, address: str): +def test_bip86(session: Session, path: str, address: str): assert ( btc.get_address( - client, + session, "Bitcoin", parse_path(path), False, @@ -181,10 +181,10 @@ def test_bip86(client: Client, path: str, address: str): @pytest.mark.multisig -def test_show_multisig_3(client: Client): +def test_show_multisig_3(session: Session): nodes = [ btc.get_public_node( - client, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" ).node for index in range(1, 4) ] @@ -197,7 +197,7 @@ def test_show_multisig_3(client: Client): for i in [1, 2, 3]: assert ( btc.get_address( - client, + session, "Testnet", parse_path(f"m/84h/1h/{i}h/0/1"), False, @@ -208,7 +208,7 @@ def test_show_multisig_3(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path(f"m/84h/1h/{i}h/0/0"), False, @@ -221,12 +221,12 @@ def test_show_multisig_3(client: Client): @pytest.mark.multisig @pytest.mark.parametrize("show_display", (True, False)) -def test_multisig_missing(client: Client, show_display: bool): +def test_multisig_missing(session: Session, show_display: bool): # Multisig with global suffix specification. # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/84h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/84h/0h/{i}h")).node for i in range(1, 4) ] multisig1 = messages.MultisigRedeemScriptType( @@ -235,7 +235,7 @@ def test_multisig_missing(client: Client, show_display: bool): # Multisig with per-node suffix specification. node = btc.get_public_node( - client, parse_path("m/84h/0h/0h/0"), coin_name="Bitcoin" + session, parse_path("m/84h/0h/0h/0"), coin_name="Bitcoin" ).node multisig2 = messages.MultisigRedeemScriptType( pubkeys=[ @@ -250,7 +250,7 @@ def test_multisig_missing(client: Client, show_display: bool): for multisig in (multisig1, multisig2): with pytest.raises(TrezorFailure): btc.get_address( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=show_display, diff --git a/tests/device_tests/bitcoin/test_getaddress_show.py b/tests/device_tests/bitcoin/test_getaddress_show.py index ef143ff3624..24bac454d2c 100644 --- a/tests/device_tests/bitcoin/test_getaddress_show.py +++ b/tests/device_tests/bitcoin/test_getaddress_show.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, tools -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from ...common import is_core @@ -55,20 +55,20 @@ @pytest.mark.models("legacy") @pytest.mark.parametrize("path, script_type, address", VECTORS) def test_show_t1( - client: Client, path: str, script_type: messages.InputScriptType, address: str + session: Session, path: str, script_type: messages.InputScriptType, address: str ): def input_flow_t1(): yield - client.debug.press_no() + session.debug.press_no() yield - client.debug.press_yes() + session.debug.press_yes() - with client: + with session: # This is the only place where even T1 is using input flow - client.set_input_flow(input_flow_t1) + session.set_input_flow(input_flow_t1) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(path), script_type=script_type, @@ -82,18 +82,18 @@ def input_flow_t1(): @pytest.mark.parametrize("chunkify", (True, False)) @pytest.mark.parametrize("path, script_type, address", VECTORS) def test_show_tt( - client: Client, + session: Session, chunkify: bool, path: str, script_type: messages.InputScriptType, address: str, ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(path), script_type=script_type, @@ -107,13 +107,13 @@ def test_show_tt( @pytest.mark.models("core") @pytest.mark.parametrize("path, script_type, address", VECTORS) def test_show_cancel( - client: Client, path: str, script_type: messages.InputScriptType, address: str + session: Session, path: str, script_type: messages.InputScriptType, address: str ): - with client, pytest.raises(Cancelled): + with session.client as client, pytest.raises(Cancelled): IF = InputFlowShowAddressQRCodeCancel(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", tools.parse_path(path), script_type=script_type, @@ -121,10 +121,10 @@ def test_show_cancel( ) -def test_show_unrecognized_path(client: Client): +def test_show_unrecognized_path(session: Session): with pytest.raises(TrezorFailure): btc.get_address( - client, + session, "Bitcoin", tools.parse_path("m/24684621h/516582h/5156h/21/856"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -133,9 +133,9 @@ def test_show_unrecognized_path(client: Client): @pytest.mark.multisig -def test_show_multisig_3(client: Client): +def test_show_multisig_3(session: Session): node = btc.get_public_node( - client, tools.parse_path("m/45h/0/0"), coin_name="Bitcoin" + session, tools.parse_path("m/45h/0/0"), coin_name="Bitcoin" ).node multisig = messages.MultisigRedeemScriptType( pubkeys=[ @@ -148,13 +148,13 @@ def test_show_multisig_3(client: Client): ) for i in [1, 2, 3]: - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(f"m/45h/0/0/{i}"), show_display=True, @@ -241,16 +241,17 @@ def test_show_multisig_3(client: Client): "script_type, bip48_type, address, xpubs, ignore_xpub_magic", VECTORS_MULTISIG ) def test_show_multisig_xpubs( - client: Client, + session: Session, script_type: messages.InputScriptType, bip48_type: int, address: str, xpubs: list[str], ignore_xpub_magic: bool, ): + raise Exception("Does not work") nodes = [ btc.get_public_node( - client, + session, tools.parse_path(f"m/48h/0h/{i}h/{bip48_type}h"), coin_name="Bitcoin", ) @@ -264,13 +265,13 @@ def test_show_multisig_xpubs( ) for i in range(3): - with client: + with session.client as client: IF = InputFlowShowMultisigXPUBs(client, address, xpubs, i) client.set_input_flow(IF.get()) client.debug.synchronize_at("Homescreen") client.watch_layout() btc.get_address( - client, + session, "Bitcoin", tools.parse_path(f"m/48h/0h/{i}h/{bip48_type}h/0/0"), show_display=True, @@ -281,9 +282,9 @@ def test_show_multisig_xpubs( @pytest.mark.multisig -def test_show_multisig_15(client: Client): +def test_show_multisig_15(session: Session): node = btc.get_public_node( - client, tools.parse_path("m/45h/0/0"), coin_name="Bitcoin" + session, tools.parse_path("m/45h/0/0"), coin_name="Bitcoin" ).node pubs = [messages.HDNodePathType(node=node, address_n=[x]) for x in range(15)] @@ -293,13 +294,13 @@ def test_show_multisig_15(client: Client): ) for i in range(15): - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(f"m/45h/0/0/{i}"), show_display=True, diff --git a/tests/device_tests/bitcoin/test_getownershipproof.py b/tests/device_tests/bitcoin/test_getownershipproof.py index b21fe944b0e..51309eb625d 100644 --- a/tests/device_tests/bitcoin/test_getownershipproof.py +++ b/tests/device_tests/bitcoin/test_getownershipproof.py @@ -17,14 +17,14 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path -def test_p2wpkh_ownership_id(client: Client): +def test_p2wpkh_ownership_id(session: Session): ownership_id = btc.get_ownership_id( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -35,9 +35,9 @@ def test_p2wpkh_ownership_id(client: Client): ) -def test_p2tr_ownership_id(client: Client): +def test_p2tr_ownership_id(session: Session): ownership_id = btc.get_ownership_id( - client, + session, "Bitcoin", parse_path("m/86h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -48,12 +48,12 @@ def test_p2tr_ownership_id(client: Client): ) -def test_attack_ownership_id(client: Client): +def test_attack_ownership_id(session: Session): # Multisig with global suffix specification. # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/84h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/84h/0h/{i}h")).node for i in range(1, 4) ] multisig1 = messages.MultisigRedeemScriptType( @@ -62,7 +62,7 @@ def test_attack_ownership_id(client: Client): # Multisig with per-node suffix specification. node = btc.get_public_node( - client, parse_path("m/84h/0h/0h/0"), coin_name="Bitcoin" + session, parse_path("m/84h/0h/0h/0"), coin_name="Bitcoin" ).node multisig2 = messages.MultisigRedeemScriptType( pubkeys=[ @@ -77,7 +77,7 @@ def test_attack_ownership_id(client: Client): for multisig in (multisig1, multisig2): with pytest.raises(TrezorFailure): btc.get_ownership_id( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), multisig=multisig, @@ -85,9 +85,9 @@ def test_attack_ownership_id(client: Client): ) -def test_p2wpkh_ownership_proof(client: Client): +def test_p2wpkh_ownership_proof(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -98,9 +98,9 @@ def test_p2wpkh_ownership_proof(client: Client): ) -def test_p2tr_ownership_proof(client: Client): +def test_p2tr_ownership_proof(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/86h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -111,10 +111,10 @@ def test_p2tr_ownership_proof(client: Client): ) -def test_fake_ownership_id(client: Client): +def test_fake_ownership_id(session: Session): with pytest.raises(TrezorFailure, match="Invalid ownership identifier"): btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -124,9 +124,9 @@ def test_fake_ownership_id(client: Client): ) -def test_confirm_ownership_proof(client: Client): +def test_confirm_ownership_proof(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -139,9 +139,9 @@ def test_confirm_ownership_proof(client: Client): ) -def test_confirm_ownership_proof_with_data(client: Client): +def test_confirm_ownership_proof_with_data(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, diff --git a/tests/device_tests/bitcoin/test_getpublickey.py b/tests/device_tests/bitcoin/test_getpublickey.py index e8b90cbb487..81dadf8a60a 100644 --- a/tests/device_tests/bitcoin/test_getpublickey.py +++ b/tests/device_tests/bitcoin/test_getpublickey.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -110,35 +110,35 @@ @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) -def test_get_public_node(client: Client, coin_name, xpub_magic, path, xpub): - res = btc.get_public_node(client, path, coin_name=coin_name) +def test_get_public_node(session: Session, coin_name, xpub_magic, path, xpub): + res = btc.get_public_node(session, path, coin_name=coin_name) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub @pytest.mark.models("core") @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) -def test_get_public_node_show(client: Client, coin_name, xpub_magic, path, xpub): - with client: +def test_get_public_node_show(session: Session, coin_name, xpub_magic, path, xpub): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) - res = btc.get_public_node(client, path, coin_name=coin_name, show_display=True) + res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub @pytest.mark.xfail(reason="Currently path validation on get_public_node is disabled.") @pytest.mark.parametrize("coin_name, path", VECTORS_INVALID) -def test_invalid_path(client: Client, coin_name, path): +def test_invalid_path(session: Session, coin_name, path): with pytest.raises(TrezorFailure, match="Forbidden key path"): - btc.get_public_node(client, path, coin_name=coin_name) + btc.get_public_node(session, path, coin_name=coin_name) -def test_slip25_path(client: Client): +def test_slip25_path(session: Session): # Ensure that CoinJoin XPUBs are inaccessible without user authorization. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_public_node( - client, + session, parse_path("m/10025h/0h/0h/1h"), script_type=messages.InputScriptType.SPENDTAPROOT, ) @@ -169,14 +169,14 @@ def test_slip25_path(client: Client): @pytest.mark.parametrize("script_type, xpub, xpub_ignored_magic", VECTORS_SCRIPT_TYPES) -def test_script_type(client: Client, script_type, xpub, xpub_ignored_magic): +def test_script_type(session: Session, script_type, xpub, xpub_ignored_magic): path = parse_path("m/44h/0h/0") res = btc.get_public_node( - client, path, coin_name="Bitcoin", script_type=script_type + session, path, coin_name="Bitcoin", script_type=script_type ) assert res.xpub == xpub res = btc.get_public_node( - client, + session, path, coin_name="Bitcoin", script_type=script_type, diff --git a/tests/device_tests/bitcoin/test_getpublickey_curve.py b/tests/device_tests/bitcoin/test_getpublickey_curve.py index 8b8cba68871..393afca61c8 100644 --- a/tests/device_tests/bitcoin/test_getpublickey_curve.py +++ b/tests/device_tests/bitcoin/test_getpublickey_curve.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -54,21 +54,21 @@ @pytest.mark.parametrize("curve, path, pubkey", VECTORS) -def test_publickey_curve(client: Client, curve, path, pubkey): - resp = btc.get_public_node(client, path, ecdsa_curve_name=curve) +def test_publickey_curve(session: Session, curve, path, pubkey): + resp = btc.get_public_node(session, path, ecdsa_curve_name=curve) assert resp.node.public_key.hex() == pubkey -def test_ed25519_public(client: Client): +def test_ed25519_public(session: Session): with pytest.raises(TrezorFailure): - btc.get_public_node(client, PATH_PUBLIC, ecdsa_curve_name="ed25519") + btc.get_public_node(session, PATH_PUBLIC, ecdsa_curve_name="ed25519") @pytest.mark.xfail(reason="Currently path validation on get_public_node is disabled.") -def test_coin_and_curve(client: Client): +def test_coin_and_curve(session: Session): with pytest.raises( TrezorFailure, match="Cannot use coin_name or script_type with ecdsa_curve_name" ): btc.get_public_node( - client, PATH_PRIVATE, coin_name="Bitcoin", ecdsa_curve_name="ed25519" + session, PATH_PRIVATE, coin_name="Bitcoin", ecdsa_curve_name="ed25519" ) diff --git a/tests/device_tests/bitcoin/test_grs.py b/tests/device_tests/bitcoin/test_grs.py index d25ffd20f00..ff2b5c4cdfc 100644 --- a/tests/device_tests/bitcoin/test_grs.py +++ b/tests/device_tests/bitcoin/test_grs.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -42,7 +42,7 @@ pytestmark = pytest.mark.altcoin -def test_legacy(client: Client): +def test_legacy(session: Session): inp1 = messages.TxInputType( # FXHDsC5ZqWQHkDmShzgRVZ1MatpWhwxTAA address_n=parse_path("m/44h/17h/0h/0/2"), @@ -56,7 +56,7 @@ def test_legacy(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Groestlcoin", [inp1], [out1], prev_txes=TX_API + session, "Groestlcoin", [inp1], [out1], prev_txes=TX_API ) assert ( serialized_tx.hex() @@ -64,7 +64,7 @@ def test_legacy(client: Client): ) -def test_legacy_change(client: Client): +def test_legacy_change(session: Session): inp1 = messages.TxInputType( # FXHDsC5ZqWQHkDmShzgRVZ1MatpWhwxTAA address_n=parse_path("m/44h/17h/0h/0/2"), @@ -78,7 +78,7 @@ def test_legacy_change(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Groestlcoin", [inp1], [out1], prev_txes=TX_API + session, "Groestlcoin", [inp1], [out1], prev_txes=TX_API ) assert ( serialized_tx.hex() @@ -86,7 +86,7 @@ def test_legacy_change(client: Client): ) -def test_send_segwit_p2sh(client: Client): +def test_send_segwit_p2sh(session: Session): inp1 = messages.TxInputType( # 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7 address_n=parse_path("m/49h/1h/0h/1/0"), @@ -107,7 +107,7 @@ def test_send_segwit_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -120,7 +120,7 @@ def test_send_segwit_p2sh(client: Client): ) -def test_send_segwit_p2sh_change(client: Client): +def test_send_segwit_p2sh_change(session: Session): inp1 = messages.TxInputType( # 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7 address_n=parse_path("m/49h/1h/0h/1/0"), @@ -141,7 +141,7 @@ def test_send_segwit_p2sh_change(client: Client): amount=123_456_789 - 11_000 - 12_300_000, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -154,7 +154,7 @@ def test_send_segwit_p2sh_change(client: Client): ) -def test_send_segwit_native(client: Client): +def test_send_segwit_native(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/0"), amount=12_300_000, @@ -174,7 +174,7 @@ def test_send_segwit_native(client: Client): amount=12_300_000 - 11_000 - 5_000_000, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -187,7 +187,7 @@ def test_send_segwit_native(client: Client): ) -def test_send_segwit_native_change(client: Client): +def test_send_segwit_native_change(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/0"), amount=12_300_000, @@ -207,7 +207,7 @@ def test_send_segwit_native_change(client: Client): amount=12_300_000 - 11_000 - 5_000_000, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -220,7 +220,7 @@ def test_send_segwit_native_change(client: Client): ) -def test_send_p2tr(client: Client): +def test_send_p2tr(session: Session): inp1 = messages.TxInputType( # tgrs1paxhjl357yzctuf3fe58fcdx6nul026hhh6kyldpfsf3tckj9a3wsvuqrgn address_n=parse_path("m/86h/1h/1h/0/0"), @@ -236,7 +236,7 @@ def test_send_p2tr(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Groestlcoin Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Groestlcoin Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # Transaction hex changed with fix #2085, all other details are the same as this tx: # https://blockbook-test.groestlcoin.org/tx/c66a79075044aaab3dba17daffb23f48addee87d7c87c7bc88e2997ce38a74ee diff --git a/tests/device_tests/bitcoin/test_komodo.py b/tests/device_tests/bitcoin/test_komodo.py index f883afc7bcd..111acefc6f3 100644 --- a/tests/device_tests/bitcoin/test_komodo.py +++ b/tests/device_tests/bitcoin/test_komodo.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -43,7 +43,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.komodo] -def test_one_one_fee_sapling(client: Client): +def test_one_one_fee_sapling(session: Session): # prevout: 2807c5b126ec8e2b078cab0f12e4c8b4ce1d7724905f8ebef8dca26b0c8e0f1d:0 # input 1: 10.9998 KMD @@ -61,13 +61,13 @@ def test_one_one_fee_sapling(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -82,7 +82,7 @@ def test_one_one_fee_sapling(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Komodo", [inp1], [out1], @@ -100,7 +100,7 @@ def test_one_one_fee_sapling(client: Client): ) -def test_one_one_rewards_claim(client: Client): +def test_one_one_rewards_claim(session: Session): # prevout: 7b28bd91119e9776f0d4ebd80e570165818a829bbf4477cd1afe5149dbcd34b1:0 # input 1: 10.9997 KMD @@ -125,16 +125,16 @@ def test_one_one_rewards_claim(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -150,7 +150,7 @@ def test_one_one_rewards_claim(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Komodo", [inp1], [out1, out2], diff --git a/tests/device_tests/bitcoin/test_multisig.py b/tests/device_tests/bitcoin/test_multisig.py index 927720bfb2a..30a544d9446 100644 --- a/tests/device_tests/bitcoin/test_multisig.py +++ b/tests/device_tests/bitcoin/test_multisig.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -54,12 +54,12 @@ @pytest.mark.multisig @pytest.mark.parametrize("chunkify", (True, False)) -def test_2_of_3(client: Client, chunkify: bool): +def test_2_of_3(session: Session, chunkify: bool): # input tx: 6b07c1321b52d9c85743f9695e13eb431b41708cdf4e1585258d51208e5b93fc nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Testnet" + session, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Testnet" ).node for index in range(1, 4) ] @@ -88,7 +88,7 @@ def test_2_of_3(client: Client, chunkify: bool): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_6b07c1), @@ -100,12 +100,12 @@ def test_2_of_3(client: Client, chunkify: bool): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) # Now we have first signature signatures1, _ = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1], @@ -142,10 +142,10 @@ def test_2_of_3(client: Client, chunkify: bool): multisig=multisig, ) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures2, serialized_tx = btc.sign_tx( - client, "Testnet", [inp3], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp3], [out1], prev_txes=TX_API_TESTNET ) assert ( @@ -161,11 +161,11 @@ def test_2_of_3(client: Client, chunkify: bool): @pytest.mark.multisig -def test_15_of_15(client: Client): +def test_15_of_15(session: Session): # input tx: 0d5b5648d47b5650edea1af3d47bbe5624213abb577cf1b1c96f98321f75cdbc node = btc.get_public_node( - client, parse_path("m/48h/1h/1h/0h"), coin_name="Testnet" + session, parse_path("m/48h/1h/1h/0h"), coin_name="Testnet" ).node pubs = [messages.HDNodePathType(node=node, address_n=[0, x]) for x in range(15)] @@ -191,9 +191,9 @@ def test_15_of_15(client: Client): multisig=multisig, ) - with client: + with session: sig, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) signatures[x] = sig[0] @@ -205,9 +205,9 @@ def test_15_of_15(client: Client): @pytest.mark.multisig @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_missing_pubkey(client: Client): +def test_missing_pubkey(session: Session): node = btc.get_public_node( - client, parse_path("m/48h/0h/1h/0h/0"), coin_name="Bitcoin" + session, parse_path("m/48h/0h/1h/0h/0"), coin_name="Bitcoin" ).node multisig = messages.MultisigRedeemScriptType( @@ -237,16 +237,16 @@ def test_missing_pubkey(client: Client): ) with pytest.raises(TrezorFailure) as exc: - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_API) - if client.model is models.T1B1: + if session.model is models.T1B1: assert exc.value.message.endswith("Failed to derive scriptPubKey") else: assert exc.value.message.endswith("Pubkey not found in multisig script") @pytest.mark.multisig -def test_attack_change_input(client: Client): +def test_attack_change_input(session: Session): """ In Phases 1 and 2 the attacker replaces a non-multisig input `input_real` with a multisig input `input_fake`, which allows the @@ -269,7 +269,7 @@ def test_attack_change_input(client: Client): multisig_fake = messages.MultisigRedeemScriptType( m=1, nodes=[ - btc.get_public_node(client, address_n, coin_name="Testnet").node, + btc.get_public_node(session, address_n, coin_name="Testnet").node, messages.HDNodeType( depth=0, fingerprint=0, @@ -304,12 +304,12 @@ def test_attack_change_input(client: Client): ) # Transaction can be signed without the attack processor - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Testnet", [input_real], [output_payee, output_change], @@ -326,11 +326,11 @@ def attack_processor(msg): attack_count -= 1 return msg - with client: - client.set_filter(messages.TxAck, attack_processor) + with session: + session.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( - client, + session, "Testnet", [input_real], [output_payee, output_change], diff --git a/tests/device_tests/bitcoin/test_multisig_change.py b/tests/device_tests/bitcoin/test_multisig_change.py index e94eb60fd31..d2387c8b950 100644 --- a/tests/device_tests/bitcoin/test_multisig_change.py +++ b/tests/device_tests/bitcoin/test_multisig_change.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import H_, parse_path from ... import bip32 @@ -140,7 +140,7 @@ def _responses( - client: Client, + session: Session, INP1: messages.TxInputType, INP2: messages.TxInputType, change: int = 0, @@ -154,7 +154,7 @@ def _responses( if change != 1: resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) - if is_core(client): + if is_core(session): resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) elif foreign: resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath)) @@ -163,7 +163,7 @@ def _responses( if change != 2: resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) - if is_core(client): + if is_core(session): resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) elif foreign: resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath)) @@ -196,7 +196,7 @@ def _responses( # both outputs are external -def test_external_external(client: Client): +def test_external_external(session: Session): out1 = messages.TxOutputType( address="1F8yBZB2NZhPZvJekhjTwjhQRRvQeTjjXr", amount=40_000_000, @@ -209,10 +209,10 @@ def test_external_external(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -227,7 +227,7 @@ def test_external_external(client: Client): # first external, second internal -def test_external_internal(client: Client): +def test_external_internal(session: Session): out1 = messages.TxOutputType( address="1F8yBZB2NZhPZvJekhjTwjhQRRvQeTjjXr", amount=40_000_000, @@ -240,15 +240,15 @@ def test_external_internal(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( - _responses(client, INP1, INP2, change=2, foreign=True) + with session, session.client as client: + session.set_expected_responses( + _responses(session, INP1, INP2, change=2, foreign=True) ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -263,7 +263,7 @@ def test_external_internal(client: Client): # first internal, second external -def test_internal_external(client: Client): +def test_internal_external(session: Session): out1 = messages.TxOutputType( address_n=parse_path("m/45h/0/1/0"), amount=40_000_000, @@ -276,15 +276,15 @@ def test_internal_external(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( - _responses(client, INP1, INP2, change=1, foreign=True) + with session, session.client as client: + session.set_expected_responses( + _responses(session, INP1, INP2, change=1, foreign=True) ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -299,7 +299,7 @@ def test_internal_external(client: Client): # both outputs are external -def test_multisig_external_external(client: Client): +def test_multisig_external_external(session: Session): out1 = messages.TxOutputType( address="3B23k4kFBRtu49zvpG3Z9xuFzfpHvxBcwt", amount=40_000_000, @@ -312,10 +312,10 @@ def test_multisig_external_external(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -330,7 +330,7 @@ def test_multisig_external_external(client: Client): # inputs match, change matches (first is change) -def test_multisig_change_match_first(client: Client): +def test_multisig_change_match_first(session: Session): multisig_out1 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT2, NODE_EXT1, NODE_INT], address_n=[1, 0], @@ -351,10 +351,10 @@ def test_multisig_change_match_first(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2, change=1)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2, change=1)) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -369,7 +369,7 @@ def test_multisig_change_match_first(client: Client): # inputs match, change matches (second is change) -def test_multisig_change_match_second(client: Client): +def test_multisig_change_match_second(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_EXT2, NODE_INT], address_n=[1, 1], @@ -390,10 +390,10 @@ def test_multisig_change_match_second(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2, change=2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2, change=2)) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -408,7 +408,7 @@ def test_multisig_change_match_second(client: Client): # inputs match, change mismatches (second tries to be change but isn't) -def test_multisig_mismatch_change(client: Client): +def test_multisig_mismatch_change(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_INT, NODE_EXT3], address_n=[1, 0], @@ -429,10 +429,10 @@ def test_multisig_mismatch_change(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -447,7 +447,7 @@ def test_multisig_mismatch_change(client: Client): # inputs mismatch, change matches with first input -def test_multisig_mismatch_inputs(client: Client): +def test_multisig_mismatch_inputs(session: Session): multisig_out1 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT2, NODE_EXT1, NODE_INT], address_n=[1, 0], @@ -468,10 +468,10 @@ def test_multisig_mismatch_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP3)) + with session: + session.set_expected_responses(_responses(session, INP1, INP3)) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP3], [out1, out2], diff --git a/tests/device_tests/bitcoin/test_nonstandard_paths.py b/tests/device_tests/bitcoin/test_nonstandard_paths.py index 96457da386c..156e61c3962 100644 --- a/tests/device_tests/bitcoin/test_nonstandard_paths.py +++ b/tests/device_tests/bitcoin/test_nonstandard_paths.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -96,11 +96,11 @@ # accepted in case we make this more restrictive in the future. @pytest.mark.parametrize("path, script_types", VECTORS) def test_getpublicnode( - client: Client, path: str, script_types: list[messages.InputScriptType] + session: Session, path: str, script_types: list[messages.InputScriptType] ): for script_type in script_types: res = btc.get_public_node( - client, parse_path(path), coin_name="Bitcoin", script_type=script_type + session, parse_path(path), coin_name="Bitcoin", script_type=script_type ) assert res.xpub @@ -109,18 +109,18 @@ def test_getpublicnode( @pytest.mark.parametrize("chunkify", (True, False)) @pytest.mark.parametrize("path, script_types", VECTORS) def test_getaddress( - client: Client, + session: Session, chunkify: bool, path: str, script_types: list[messages.InputScriptType], ): for script_type in script_types: - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) res = btc.get_address( - client, + session, "Bitcoin", parse_path(path), show_display=True, @@ -133,16 +133,16 @@ def test_getaddress( @pytest.mark.parametrize("path, script_types", VECTORS) def test_signmessage( - client: Client, path: str, script_types: list[messages.InputScriptType] + session: Session, path: str, script_types: list[messages.InputScriptType] ): for script_type in script_types: - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) sig = btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path(path), script_type=script_type, @@ -154,12 +154,14 @@ def test_signmessage( @pytest.mark.parametrize("path, script_types", VECTORS) def test_signtx( - client: Client, path: str, script_types: list[messages.InputScriptType] + session: Session, path: str, script_types: list[messages.InputScriptType] ): address_n = parse_path(path) for script_type in script_types: - address = btc.get_address(client, "Bitcoin", address_n, script_type=script_type) + address = btc.get_address( + session, "Bitcoin", address_n, script_type=script_type + ) prevhash, prevtx = forge_prevtx([(address, 390_000)]) inp1 = messages.TxInputType( address_n=address_n, @@ -175,12 +177,12 @@ def test_signtx( script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} + session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} ) assert serialized_tx.hex() @@ -189,12 +191,12 @@ def test_signtx( @pytest.mark.multisig @pytest.mark.parametrize("paths, address_index", VECTORS_MULTISIG) def test_getaddress_multisig( - client: Client, paths: list[str], address_index: list[int] + session: Session, paths: list[str], address_index: list[int] ): pubs = [ messages.HDNodePathType( node=btc.get_public_node( - client, parse_path(path), coin_name="Bitcoin" + session, parse_path(path), coin_name="Bitcoin" ).node, address_n=address_index, ) @@ -202,12 +204,12 @@ def test_getaddress_multisig( ] multisig = messages.MultisigRedeemScriptType(pubkeys=pubs, m=2) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) address = btc.get_address( - client, + session, "Bitcoin", parse_path(paths[0]) + address_index, show_display=True, @@ -220,11 +222,11 @@ def test_getaddress_multisig( @pytest.mark.multisig @pytest.mark.parametrize("paths, address_index", VECTORS_MULTISIG) -def test_signtx_multisig(client: Client, paths: list[str], address_index: list[int]): +def test_signtx_multisig(session: Session, paths: list[str], address_index: list[int]): pubs = [ messages.HDNodePathType( node=btc.get_public_node( - client, parse_path(path), coin_name="Bitcoin" + session, parse_path(path), coin_name="Bitcoin" ).node, address_n=address_index, ) @@ -237,7 +239,7 @@ def test_signtx_multisig(client: Client, paths: list[str], address_index: list[i address_n = parse_path(paths[0]) + address_index address = btc.get_address( - client, + session, "Bitcoin", address_n, multisig=multisig, @@ -261,12 +263,12 @@ def test_signtx_multisig(client: Client, paths: list[str], address_index: list[i script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) sig, _ = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} + session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} ) assert sig[0] diff --git a/tests/device_tests/bitcoin/test_op_return.py b/tests/device_tests/bitcoin/test_op_return.py index b5063891993..0aa8acb0802 100644 --- a/tests/device_tests/bitcoin/test_op_return.py +++ b/tests/device_tests/bitcoin/test_op_return.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -43,7 +43,7 @@ ) -def test_opreturn(client: Client): +def test_opreturn(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/1h/1h/0/21"), # myGMXcCxmuDooMdzZFPMmvHviijzqYKhza amount=89_581, @@ -63,13 +63,13 @@ def test_opreturn(client: Client): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.SignTx), @@ -86,7 +86,7 @@ def test_opreturn(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -96,7 +96,7 @@ def test_opreturn(client: Client): ) -def test_nonzero_opreturn(client: Client): +def test_nonzero_opreturn(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/10h/0/5"), amount=390_000, @@ -110,18 +110,18 @@ def test_nonzero_opreturn(client: Client): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [request_input(0), request_output(0), messages.Failure()] ) with pytest.raises( TrezorFailure, match="OP_RETURN output with non-zero amount" ): - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_API) -def test_opreturn_address(client: Client): +def test_opreturn_address(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/2"), amount=390_000, @@ -136,11 +136,11 @@ def test_opreturn_address(client: Client): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [request_input(0), request_output(0), messages.Failure()] ) with pytest.raises( TrezorFailure, match="Output's address_n provided but not expected." ): - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_API) diff --git a/tests/device_tests/bitcoin/test_peercoin.py b/tests/device_tests/bitcoin/test_peercoin.py index b1b62e49e55..b3de714e26e 100644 --- a/tests/device_tests/bitcoin/test_peercoin.py +++ b/tests/device_tests/bitcoin/test_peercoin.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -32,7 +32,7 @@ @pytest.mark.altcoin @pytest.mark.peercoin -def test_timestamp_included(client: Client): +def test_timestamp_included(session: Session): # tx: 41b29ad615d8eea40a4654a052d18bb10cd08f203c351f4d241f88b031357d3d # input 0: 0.1 PPC @@ -50,7 +50,7 @@ def test_timestamp_included(client: Client): ) _, timestamp_tx = btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -66,7 +66,7 @@ def test_timestamp_included(client: Client): @pytest.mark.altcoin @pytest.mark.peercoin -def test_timestamp_missing(client: Client): +def test_timestamp_missing(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/6h/0h/0/0"), amount=100_000, @@ -81,7 +81,7 @@ def test_timestamp_missing(client: Client): with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -92,7 +92,7 @@ def test_timestamp_missing(client: Client): with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -104,7 +104,7 @@ def test_timestamp_missing(client: Client): @pytest.mark.altcoin @pytest.mark.peercoin -def test_timestamp_missing_prevtx(client: Client): +def test_timestamp_missing_prevtx(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/6h/0h/0/0"), amount=100_000, @@ -122,7 +122,7 @@ def test_timestamp_missing_prevtx(client: Client): with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -134,7 +134,7 @@ def test_timestamp_missing_prevtx(client: Client): prevtx.timestamp = None with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], diff --git a/tests/device_tests/bitcoin/test_signmessage.py b/tests/device_tests/bitcoin/test_signmessage.py index 211bed200e9..c2c78ec00af 100644 --- a/tests/device_tests/bitcoin/test_signmessage.py +++ b/tests/device_tests/bitcoin/test_signmessage.py @@ -20,7 +20,7 @@ from trezorlib import btc, messages from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import message_filters from trezorlib.exceptions import Cancelled from trezorlib.tools import parse_path @@ -286,7 +286,7 @@ def case( "coin_name, path, script_type, no_script_type, address, message, signature", VECTORS ) def test_signmessage( - client: Client, + session: Session, coin_name: str, path: str, script_type: messages.InputScriptType, @@ -296,7 +296,7 @@ def test_signmessage( signature: str, ): sig = btc.sign_message( - client, + session, coin_name=coin_name, n=parse_path(path), script_type=script_type, @@ -314,7 +314,7 @@ def test_signmessage( "coin_name, path, script_type, no_script_type, address, message, signature", VECTORS ) def test_signmessage_info( - client: Client, + session: Session, coin_name: str, path: str, script_type: messages.InputScriptType, @@ -323,11 +323,11 @@ def test_signmessage_info( message: str, signature: str, ): - with client, pytest.raises(Cancelled): + with session.client as client, pytest.raises(Cancelled): IF = InputFlowSignMessageInfo(client) client.set_input_flow(IF.get()) sig = btc.sign_message( - client, + session, coin_name=coin_name, n=parse_path(path), script_type=script_type, @@ -354,12 +354,12 @@ def test_signmessage_info( @pytest.mark.models("core") @pytest.mark.parametrize("message", MESSAGE_LENGTHS) -def test_signmessage_pagination(client: Client, message: str): - with client: +def test_signmessage_pagination(session: Session, message: str): + with session.client as client: IF = InputFlowSignMessagePagination(client) client.set_input_flow(IF.get()) btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path("m/44h/0h/0h/0/0"), message=message, @@ -367,19 +367,19 @@ def test_signmessage_pagination(client: Client, message: str): # We cannot differentiate between a newline and space in the message read from Trezor. # TODO: do the check also for T2B1 - if client.layout_type in (LayoutType.TT, LayoutType.Mercury): + if session.client.layout_type in (LayoutType.TT, LayoutType.Mercury): message_read = IF.message_read.replace(" ", "").replace("...", "") signed_message = message.replace("\n", "").replace(" ", "") assert signed_message in message_read @pytest.mark.models("t2t1", reason="Tailored to TT fonts and screen size") -def test_signmessage_pagination_trailing_newline(client: Client): +def test_signmessage_pagination_trailing_newline(session: Session): message = "THIS\nMUST\nNOT\nBE\nPAGINATED\n" # The trailing newline must not cause a new paginated screen to appear. # The UI must be a single dialog without pagination. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ # expect address confirmation message_filters.ButtonRequest(code=messages.ButtonRequestType.Other), @@ -389,18 +389,18 @@ def test_signmessage_pagination_trailing_newline(client: Client): ] ) btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path("m/44h/0h/0h/0/0"), message=message, ) -def test_signmessage_path_warning(client: Client): +def test_signmessage_path_warning(session: Session): message = "This is an example of a signed message." - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( [ # expect a path warning message_filters.ButtonRequest( @@ -411,11 +411,11 @@ def test_signmessage_path_warning(client: Client): messages.MessageSignature, ] ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path("m/86h/0h/0h/0/0"), message=message, diff --git a/tests/device_tests/bitcoin/test_signtx.py b/tests/device_tests/bitcoin/test_signtx.py index 96fc4edc691..135992224e6 100644 --- a/tests/device_tests/bitcoin/test_signtx.py +++ b/tests/device_tests/bitcoin/test_signtx.py @@ -19,7 +19,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from trezorlib.tools import H_, parse_path @@ -111,7 +111,7 @@ CORNER_BUTTON = (215, 25) -def test_one_one_fee(client: Client): +def test_one_one_fee(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -127,13 +127,13 @@ def test_one_one_fee(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_0dac36), @@ -148,7 +148,7 @@ def test_one_one_fee(client: Client): ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET ) assert_tx_matches( @@ -158,7 +158,7 @@ def test_one_one_fee(client: Client): ) -def test_testnet_one_two_fee(client: Client): +def test_testnet_one_two_fee(session: Session): # input tx: e5040e1bc1ae7667ffb9e5248e90b2fb93cd9150234151ce90e14ab2f5933bcd inp1 = messages.TxInputType( @@ -180,13 +180,13 @@ def test_testnet_one_two_fee(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -203,7 +203,7 @@ def test_testnet_one_two_fee(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -213,7 +213,7 @@ def test_testnet_one_two_fee(client: Client): ) -def test_testnet_fee_high_warning(client: Client): +def test_testnet_fee_high_warning(session: Session): # input tx: 25fee583181847cbe9d9fd9a483a8b8626c99854a72d01de848ef40508d0f3bc # (The "25fee" tx hash is very suitable for testing high fees) @@ -230,13 +230,13 @@ def test_testnet_fee_high_warning(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.FeeOverThreshold), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -250,7 +250,7 @@ def test_testnet_fee_high_warning(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -260,7 +260,7 @@ def test_testnet_fee_high_warning(client: Client): ) -def test_one_two_fee(client: Client): +def test_one_two_fee(session: Session): # input tx: 50f6f1209ca92d7359564be803cb2c932cde7d370f7cee50fd1fad6790f6206d inp1 = messages.TxInputType( @@ -282,14 +282,14 @@ def test_one_two_fee(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_50f6f1), @@ -305,7 +305,7 @@ def test_one_two_fee(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1, out2], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1], [out1, out2], prev_txes=TX_CACHE_MAINNET ) assert_tx_matches( @@ -316,7 +316,7 @@ def test_one_two_fee(client: Client): @pytest.mark.parametrize("chunkify", (True, False)) -def test_one_three_fee(client: Client, chunkify: bool): +def test_one_three_fee(session: Session, chunkify: bool): # input tx: bb5169091f09e833e155b291b662019df56870effe388c626221c5ea84274bc4 inp1 = messages.TxInputType( @@ -344,16 +344,16 @@ def test_one_three_fee(client: Client, chunkify: bool): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -371,7 +371,7 @@ def test_one_three_fee(client: Client, chunkify: bool): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2, out3], @@ -386,7 +386,7 @@ def test_one_three_fee(client: Client, chunkify: bool): ) -def test_two_two(client: Client): +def test_two_two(session: Session): # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a inp1 = messages.TxInputType( @@ -415,15 +415,15 @@ def test_two_two(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ac4ca0), @@ -449,7 +449,7 @@ def test_two_two(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [inp1, inp2], [out1, out2], @@ -464,7 +464,7 @@ def test_two_two(client: Client): @pytest.mark.slow -def test_lots_of_inputs(client: Client): +def test_lots_of_inputs(session: Session): # Tests if device implements serialization of len(inputs) correctly # input tx: 3019487f064329247daad245aed7a75349d09c14b1d24f170947690e030f5b20 @@ -485,7 +485,7 @@ def test_lots_of_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Testnet", inputs, [out], prev_txes=TX_CACHE_TESTNET + session, "Testnet", inputs, [out], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -495,7 +495,7 @@ def test_lots_of_inputs(client: Client): @pytest.mark.slow -def test_lots_of_outputs(client: Client): +def test_lots_of_outputs(session: Session): # Tests if device implements serialization of len(outputs) correctly # input tx: 58d56a5d1325cf83543ee4c87fd73a784e4ba1499ced574be359fa2bdcb9ac8e @@ -518,7 +518,7 @@ def test_lots_of_outputs(client: Client): outputs.append(out) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -528,7 +528,7 @@ def test_lots_of_outputs(client: Client): @pytest.mark.slow -def test_lots_of_change(client: Client): +def test_lots_of_change(session: Session): # Tests if device implements prompting for multiple change addresses correctly # input tx: 892d06cb3394b8e6006eec9a2aa90692b718a29be6844b6c6a9e89ec3aa6aac4 @@ -559,13 +559,13 @@ def test_lots_of_change(client: Client): request_change_outputs = [request_output(i + 1) for i in range(cnt)] - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), ] + request_change_outputs + [ @@ -585,7 +585,7 @@ def test_lots_of_change(client: Client): ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -594,7 +594,7 @@ def test_lots_of_change(client: Client): ) -def test_fee_high_warning(client: Client): +def test_fee_high_warning(session: Session): # input tx: 1f326f65768d55ef146efbb345bd87abe84ac7185726d0457a026fc347a26ef3 inp1 = messages.TxInputType( @@ -610,13 +610,13 @@ def test_fee_high_warning(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.FeeOverThreshold), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -631,7 +631,7 @@ def test_fee_high_warning(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -642,7 +642,7 @@ def test_fee_high_warning(client: Client): @pytest.mark.models("core") -def test_fee_high_hardfail(client: Client): +def test_fee_high_hardfail(session: Session): # input tx: 25fee583181847cbe9d9fd9a483a8b8626c99854a72d01de848ef40508d0f3bc # (The "25fee" tx hash is very suitable for testing high fees) @@ -660,18 +660,18 @@ def test_fee_high_hardfail(client: Client): ) with pytest.raises(TrezorFailure, match="fee is unexpectedly large"): - btc.sign_tx(client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx(session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET) # set SafetyCheckLevel to PromptTemporarily and try again device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) - with client: + with session.client as client: IF = InputFlowSignTxHighFee(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) assert IF.finished @@ -682,7 +682,7 @@ def test_fee_high_hardfail(client: Client): ) -def test_not_enough_funds(client: Client): +def test_not_enough_funds(session: Session): # input tx: d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882 inp1 = messages.TxInputType( @@ -698,21 +698,21 @@ def test_not_enough_funds(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.Failure(code=messages.FailureType.NotEnoughFunds), ] ) with pytest.raises(TrezorFailure, match="NotEnoughFunds"): - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) -def test_p2sh(client: Client): +def test_p2sh(session: Session): # input tx: 58d56a5d1325cf83543ee4c87fd73a784e4ba1499ced574be359fa2bdcb9ac8e inp1 = messages.TxInputType( @@ -728,13 +728,13 @@ def test_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOSCRIPTHASH, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_58d56a), @@ -748,7 +748,7 @@ def test_p2sh(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -758,7 +758,7 @@ def test_p2sh(client: Client): ) -def test_testnet_big_amount(client: Client): +def test_testnet_big_amount(session: Session): # This test is testing transaction with amount bigger than fits to uint32 # input tx: 074b0070939db4c2635c1bef0c8e68412ccc8d3c8782137547c7a2bbde073fc0 @@ -775,7 +775,7 @@ def test_testnet_big_amount(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -785,7 +785,7 @@ def test_testnet_big_amount(client: Client): ) -def test_attack_change_outputs(client: Client): +def test_attack_change_outputs(session: Session): # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a inp1 = messages.TxInputType( @@ -815,15 +815,15 @@ def test_attack_change_outputs(client: Client): ) # Test if the transaction can be signed normally - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ac4ca0), @@ -849,7 +849,7 @@ def test_attack_change_outputs(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_MAINNET ) assert_tx_matches( @@ -871,14 +871,14 @@ def attack_processor(msg): return msg - with client, pytest.raises( + with session, pytest.raises( TrezorFailure, match="Transaction has changed during signing" ): # Set up attack processors - client.set_filter(messages.TxAck, attack_processor) + session.set_filter(messages.TxAck, attack_processor) btc.sign_tx( - client, + session, "Bitcoin", [inp1, inp2], [out1, out2], @@ -886,7 +886,7 @@ def attack_processor(msg): ) -def test_attack_modify_change_address(client: Client): +def test_attack_modify_change_address(session: Session): # Ensure that if the change output is modified after the user confirms the # transaction, then signing fails. @@ -926,16 +926,18 @@ def attack_processor(msg): return msg - with client, pytest.raises( + with session, pytest.raises( TrezorFailure, match="Transaction has changed during signing" ): # Set up attack processors - client.set_filter(messages.TxAck, attack_processor) + session.set_filter(messages.TxAck, attack_processor) - btc.sign_tx(client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + ) -def test_attack_change_input_address(client: Client): +def test_attack_change_input_address(session: Session): # input tx: d2dcdaf547ea7f57a713c607f15e883ddc4a98167ee2c43ed953c53cb5153e24 inp1 = messages.TxInputType( @@ -960,7 +962,7 @@ def test_attack_change_input_address(client: Client): # Test if the transaction can be signed normally _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -982,14 +984,14 @@ def attack_processor(msg): return msg # Now run the attack, must trigger the exception - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -1004,7 +1006,7 @@ def attack_processor(msg): # Now run the attack, must trigger the exception with pytest.raises(TrezorFailure) as exc: btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -1015,7 +1017,7 @@ def attack_processor(msg): assert exc.value.message.endswith("Transaction has changed during signing") -def test_spend_coinbase(client: Client): +def test_spend_coinbase(session: Session): # NOTE: the input transaction is not real # We did not have any coinbase transaction at connected with `all all` seed, # so it was artificially created for the test purpose @@ -1033,13 +1035,13 @@ def test_spend_coinbase(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_005f6f), @@ -1052,7 +1054,7 @@ def test_spend_coinbase(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -1062,7 +1064,7 @@ def test_spend_coinbase(client: Client): ) -def test_two_changes(client: Client): +def test_two_changes(session: Session): # input tx: e5040e1bc1ae7667ffb9e5248e90b2fb93cd9150234151ce90e14ab2f5933bcd # see 87be0736f202f7c2bff0781b42bad3e0cdcb54761939da69ea793a3735552c56 @@ -1091,13 +1093,13 @@ def test_two_changes(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), request_output(2), messages.ButtonRequest(code=B.SignTx), @@ -1118,7 +1120,7 @@ def test_two_changes(client: Client): ) btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out_change1, out_change2], @@ -1126,7 +1128,7 @@ def test_two_changes(client: Client): ) -def test_change_on_main_chain_allowed(client: Client): +def test_change_on_main_chain_allowed(session: Session): # input tx: e5040e1bc1ae7667ffb9e5248e90b2fb93cd9150234151ce90e14ab2f5933bcd # see 87be0736f202f7c2bff0781b42bad3e0cdcb54761939da69ea793a3735552c56 @@ -1150,13 +1152,13 @@ def test_change_on_main_chain_allowed(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -1174,7 +1176,7 @@ def test_change_on_main_chain_allowed(client: Client): ) btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out_change], @@ -1182,7 +1184,7 @@ def test_change_on_main_chain_allowed(client: Client): ) -def test_not_enough_vouts(client: Client): +def test_not_enough_vouts(session: Session): # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a prev_tx = TX_CACHE_MAINNET[TXHASH_ac4ca0] @@ -1222,7 +1224,7 @@ def test_not_enough_vouts(client: Client): TrezorFailure, match="Not enough outputs in previous transaction." ): btc.sign_tx( - client, + session, "Bitcoin", [inp0, inp1, inp2], [out1], @@ -1240,7 +1242,7 @@ def test_not_enough_vouts(client: Client): ("branch_id", 13), ), ) -def test_prevtx_forbidden_fields(client: Client, field, value): +def test_prevtx_forbidden_fields(session: Session, field, value): inp0 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), # 1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL prev_hash=TXHASH_157041, @@ -1258,7 +1260,7 @@ def test_prevtx_forbidden_fields(client: Client, field, value): name = field.replace("_", " ") with pytest.raises(TrezorFailure, match=rf"(?i){name} not enabled on this coin"): btc.sign_tx( - client, "Bitcoin", [inp0], [out1], prev_txes={TXHASH_157041: prev_tx} + session, "Bitcoin", [inp0], [out1], prev_txes={TXHASH_157041: prev_tx} ) @@ -1266,7 +1268,7 @@ def test_prevtx_forbidden_fields(client: Client, field, value): "field, value", (("expiry", 9), ("timestamp", 42), ("version_group_id", 69), ("branch_id", 13)), ) -def test_signtx_forbidden_fields(client: Client, field: str, value: int): +def test_signtx_forbidden_fields(session: Session, field: str, value: int): inp0 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), # 1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL prev_hash=TXHASH_157041, @@ -1283,7 +1285,7 @@ def test_signtx_forbidden_fields(client: Client, field: str, value: int): name = field.replace("_", " ") with pytest.raises(TrezorFailure, match=rf"(?i){name} not enabled on this coin"): btc.sign_tx( - client, "Bitcoin", [inp0], [out1], prev_txes=TX_CACHE_MAINNET, **kwargs + session, "Bitcoin", [inp0], [out1], prev_txes=TX_CACHE_MAINNET, **kwargs ) @@ -1291,7 +1293,7 @@ def test_signtx_forbidden_fields(client: Client, field: str, value: int): "script_type", (messages.InputScriptType.SPENDADDRESS, messages.InputScriptType.EXTERNAL), ) -def test_incorrect_input_script_type(client: Client, script_type): +def test_incorrect_input_script_type(session: Session, script_type): address_n = parse_path("m/44h/1h/0h/0/0") # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q attacker_multisig_public_key = bytes.fromhex( "030e669acac1f280d1ddf441cd2ba5e97417bf2689e4bbec86df4f831bf9f7ffd0" @@ -1300,7 +1302,7 @@ def test_incorrect_input_script_type(client: Client, script_type): multisig = messages.MultisigRedeemScriptType( m=1, nodes=[ - btc.get_public_node(client, address_n, coin_name="Testnet").node, + btc.get_public_node(session, address_n, coin_name="Testnet").node, messages.HDNodeType( depth=0, fingerprint=0, @@ -1335,7 +1337,9 @@ def test_incorrect_input_script_type(client: Client, script_type): with pytest.raises( TrezorFailure, match="Multisig field provided but not expected." ): - btc.sign_tx(client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + ) @pytest.mark.parametrize( @@ -1346,7 +1350,7 @@ def test_incorrect_input_script_type(client: Client, script_type): ), ) def test_incorrect_output_script_type( - client: Client, script_type: messages.OutputScriptType + session: Session, script_type: messages.OutputScriptType ): address_n = parse_path("m/44h/1h/0h/0/0") # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q attacker_multisig_public_key = bytes.fromhex( @@ -1356,7 +1360,7 @@ def test_incorrect_output_script_type( multisig = messages.MultisigRedeemScriptType( m=1, nodes=[ - btc.get_public_node(client, address_n, coin_name="Testnet").node, + btc.get_public_node(session, address_n, coin_name="Testnet").node, messages.HDNodeType( depth=0, fingerprint=0, @@ -1390,14 +1394,16 @@ def test_incorrect_output_script_type( with pytest.raises( TrezorFailure, match="Multisig field provided but not expected." ): - btc.sign_tx(client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + ) @pytest.mark.parametrize( "lock_time, sequence", ((499_999_999, 0xFFFFFFFE), (500_000_000, 0xFFFFFFFE), (1, 0xFFFFFFFF)), ) -def test_lock_time(client: Client, lock_time: int, sequence: int): +def test_lock_time(session: Session, lock_time: int, sequence: int): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1414,13 +1420,13 @@ def test_lock_time(client: Client, lock_time: int, sequence: int): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -1436,7 +1442,7 @@ def test_lock_time(client: Client, lock_time: int, sequence: int): ) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1446,7 +1452,7 @@ def test_lock_time(client: Client, lock_time: int, sequence: int): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_lock_time_blockheight(client: Client): +def test_lock_time_blockheight(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1463,12 +1469,12 @@ def test_lock_time_blockheight(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session.client as client: IF = InputFlowLockTimeBlockHeight(client, "499999999") client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1481,7 +1487,7 @@ def test_lock_time_blockheight(client: Client): @pytest.mark.parametrize( "lock_time_str", ("1985-11-05 00:53:20", "2048-08-16 22:14:00") ) -def test_lock_time_datetime(client: Client, lock_time_str: str): +def test_lock_time_datetime(session: Session, lock_time_str: str): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1502,12 +1508,12 @@ def test_lock_time_datetime(client: Client, lock_time_str: str): lock_time_utc = lock_time_naive.replace(tzinfo=timezone.utc) lock_time_timestamp = int(lock_time_utc.timestamp()) - with client: + with session.client as client: IF = InputFlowLockTimeDatetime(client, lock_time_str) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1517,7 +1523,7 @@ def test_lock_time_datetime(client: Client, lock_time_str: str): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_information(client: Client): +def test_information(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1534,12 +1540,12 @@ def test_information(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session.client as client: IF = InputFlowSignTxInformation(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1548,7 +1554,7 @@ def test_information(client: Client): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_information_mixed(client: Client): +def test_information_mixed(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/1h/0h/0/0"), # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q amount=31_000_000, @@ -1569,12 +1575,12 @@ def test_information_mixed(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session.client as client: IF = InputFlowSignTxInformationMixed(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -1583,7 +1589,7 @@ def test_information_mixed(client: Client): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_information_cancel(client: Client): +def test_information_cancel(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1600,12 +1606,12 @@ def test_information_cancel(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client, pytest.raises(Cancelled): + with session.client as client, pytest.raises(Cancelled): IF = InputFlowSignTxInformationCancel(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1618,7 +1624,7 @@ def test_information_cancel(client: Client): skip="mercury", reason="Cannot test layouts on T1, not implemented in mercury UI", ) -def test_information_replacement(client: Client): +def test_information_replacement(session: Session): # Use the change output and an external output to bump the fee. # Originally fee was 3780, now 108060 (94280 from change and 10000 from external). @@ -1650,12 +1656,12 @@ def test_information_replacement(client: Client): orig_index=0, ) - with client: + with session.client as client: IF = InputFlowSignTxInformationReplacement(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], diff --git a/tests/device_tests/bitcoin/test_signtx_amount_unit.py b/tests/device_tests/bitcoin/test_signtx_amount_unit.py index d3dfa3d00ec..50cc19151b6 100644 --- a/tests/device_tests/bitcoin/test_signtx_amount_unit.py +++ b/tests/device_tests/bitcoin/test_signtx_amount_unit.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -42,7 +42,7 @@ @pytest.mark.parametrize("amount_unit", VECTORS) -def test_signtx_testnet(client: Client, amount_unit): +def test_signtx_testnet(session: Session, amount_unit): inp1 = messages.TxInputType( # tb1qajr3a3y5uz27lkxrmn7ck8lp22dgytvagr5nqy address_n=parse_path("m/84h/1h/0h/0/87"), @@ -61,9 +61,9 @@ def test_signtx_testnet(client: Client, amount_unit): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=100_000 - 40_000 - 10_000, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -79,7 +79,7 @@ def test_signtx_testnet(client: Client, amount_unit): @pytest.mark.parametrize("amount_unit", VECTORS) -def test_signtx_btc(client: Client, amount_unit): +def test_signtx_btc(session: Session, amount_unit): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -95,9 +95,9 @@ def test_signtx_btc(client: Client, amount_unit): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], diff --git a/tests/device_tests/bitcoin/test_signtx_external.py b/tests/device_tests/bitcoin/test_signtx_external.py index fd8e0cff3e9..4d44e3ec763 100644 --- a/tests/device_tests/bitcoin/test_signtx_external.py +++ b/tests/device_tests/bitcoin/test_signtx_external.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import SafetyCheckLevel from trezorlib.tools import parse_path @@ -82,7 +82,7 @@ @pytest.mark.models("core") -def test_p2pkh_presigned(client: Client): +def test_p2pkh_presigned(session: Session): inp1 = messages.TxInputType( # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q address_n=parse_path("m/44h/1h/0h/0/0"), @@ -142,9 +142,9 @@ def test_p2pkh_presigned(client: Client): ) # Test with first input as pre-signed external. - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1ext, inp2], [out1, out2], @@ -155,9 +155,9 @@ def test_p2pkh_presigned(client: Client): assert serialized_tx.hex() == expected_tx # Test with second input as pre-signed external. - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2ext], [out1, out2], @@ -170,7 +170,7 @@ def test_p2pkh_presigned(client: Client): inp2ext.script_sig[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature"): _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2ext], [out1, out2], @@ -179,7 +179,7 @@ def test_p2pkh_presigned(client: Client): @pytest.mark.models("core") -def test_p2wpkh_in_p2sh_presigned(client: Client): +def test_p2wpkh_in_p2sh_presigned(session: Session): inp1 = messages.TxInputType( # 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX amount=123_456_789, @@ -216,20 +216,20 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -252,7 +252,7 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -267,20 +267,20 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): # Test corrupted script hash in scriptsig. inp1.script_sig[10] ^= 1 - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -293,7 +293,7 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): with pytest.raises(TrezorFailure, match="Invalid public key hash"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -302,7 +302,7 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): @pytest.mark.models("core") -def test_p2wpkh_presigned(client: Client): +def test_p2wpkh_presigned(session: Session): inp1 = messages.TxInputType( # tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9 address_n=parse_path("m/84h/1h/0h/0/0"), @@ -339,9 +339,9 @@ def test_p2wpkh_presigned(client: Client): ) # Test with second input as pre-signed external. - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -358,7 +358,7 @@ def test_p2wpkh_presigned(client: Client): inp2.witness[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -367,7 +367,7 @@ def test_p2wpkh_presigned(client: Client): @pytest.mark.models("core") -def test_p2wsh_external_presigned(client: Client): +def test_p2wsh_external_presigned(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/0"), amount=10_000, @@ -399,14 +399,14 @@ def test_p2wsh_external_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ec16dc), @@ -429,7 +429,7 @@ def test_p2wsh_external_presigned(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -444,14 +444,14 @@ def test_p2wsh_external_presigned(client: Client): # Test corrupted signature in witness. inp2.witness[10] ^= 1 - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ec16dc), @@ -470,12 +470,12 @@ def test_p2wsh_external_presigned(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET ) @pytest.mark.models("core") -def test_p2tr_external_presigned(client: Client): +def test_p2tr_external_presigned(session: Session): inp1 = messages.TxInputType( # tb1p8tvmvsvhsee73rhym86wt435qrqm92psfsyhy6a3n5gw455znnpqm8wald address_n=parse_path("m/86h/1h/0h/0/1"), @@ -509,14 +509,14 @@ def test_p2tr_external_presigned(client: Client): amount=4_600, script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(1), @@ -530,7 +530,7 @@ def test_p2tr_external_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -541,14 +541,14 @@ def test_p2tr_external_presigned(client: Client): # Test corrupted signature in witness. inp2.witness[10] ^= 1 - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(1), @@ -558,7 +558,7 @@ def test_p2tr_external_presigned(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -567,18 +567,18 @@ def test_p2tr_external_presigned(client: Client): @pytest.mark.models("core") -def test_p2pkh_with_proof(client: Client): +def test_p2pkh_with_proof(session: Session): # TODO pass @pytest.mark.models("core") -def test_p2wpkh_in_p2sh_with_proof(client: Client): +def test_p2wpkh_in_p2sh_with_proof(session: Session): # TODO pass -def test_p2wpkh_with_proof(client: Client): +def test_p2wpkh_with_proof(session: Session): inp1 = messages.TxInputType( # seed "alcohol woman abuse must during monitor noble actual mixed trade anger aisle" # 84'/1'/0'/0/0 @@ -610,18 +610,18 @@ def test_p2wpkh_with_proof(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + with session: + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_e5b7e2), @@ -643,7 +643,7 @@ def test_p2wpkh_with_proof(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -660,7 +660,7 @@ def test_p2wpkh_with_proof(client: Client): inp1.ownership_proof[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature|Invalid external input"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -671,7 +671,7 @@ def test_p2wpkh_with_proof(client: Client): @pytest.mark.setup_client( mnemonic="abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about" ) -def test_p2tr_with_proof(client: Client): +def test_p2tr_with_proof(session: Session): # Resulting TXID 48ec6dc7bb772ff18cbce0135fedda7c0e85212c7b2f85a5d0cc7a917d77c48a inp1 = messages.TxInputType( @@ -703,15 +703,15 @@ def test_p2tr_with_proof(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + with session: + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_input(1), @@ -722,7 +722,7 @@ def test_p2tr_with_proof(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -736,10 +736,12 @@ def test_p2tr_with_proof(client: Client): # Test corrupted ownership proof. inp1.ownership_proof[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature|Invalid external input"): - btc.sign_tx(client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET + ) -def test_p2wpkh_with_false_proof(client: Client): +def test_p2wpkh_with_false_proof(session: Session): inp1 = messages.TxInputType( # tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9 address_n=parse_path("m/84h/1h/0h/0/0"), @@ -768,8 +770,8 @@ def test_p2wpkh_with_false_proof(client: Client): script_type=messages.OutputScriptType.PAYTOWITNESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), @@ -779,7 +781,7 @@ def test_p2wpkh_with_false_proof(client: Client): with pytest.raises(TrezorFailure, match="Invalid external input"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -787,7 +789,7 @@ def test_p2wpkh_with_false_proof(client: Client): ) -def test_p2tr_external_unverified(client: Client): +def test_p2tr_external_unverified(session: Session): inp1 = messages.TxInputType( # tb1pswrqtykue8r89t9u4rprjs0gt4qzkdfuursfnvqaa3f2yql07zmq8s8a5u address_n=parse_path("m/86h/1h/0h/0/0"), @@ -823,13 +825,13 @@ def test_p2tr_external_unverified(client: Client): # Unverified external inputs should be rejected when safety checks are enabled. with pytest.raises(TrezorFailure, match="[Ee]xternal input"): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Signing should succeed after disabling safety checks. - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Second witness is missing from the serialized transaction. @@ -840,7 +842,7 @@ def test_p2tr_external_unverified(client: Client): ) -def test_p2wpkh_external_unverified(client: Client): +def test_p2wpkh_external_unverified(session: Session): inp1 = messages.TxInputType( # tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9 address_n=parse_path("m/84h/1h/0h/0/0"), @@ -875,13 +877,13 @@ def test_p2wpkh_external_unverified(client: Client): # Unverified external inputs should be rejected when safety checks are enabled. with pytest.raises(TrezorFailure, match="[Ee]xternal input"): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Signing should succeed after disabling safety checks. - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Second witness is missing from the serialized transaction. diff --git a/tests/device_tests/bitcoin/test_signtx_invalid_path.py b/tests/device_tests/bitcoin/test_signtx_invalid_path.py index 5ef4ba0389c..27f0599de9b 100644 --- a/tests/device_tests/bitcoin/test_signtx_invalid_path.py +++ b/tests/device_tests/bitcoin/test_signtx_invalid_path.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -36,7 +36,7 @@ # Litecoin does not have strong replay protection using SIGHASH_FORKID, # spending from Bitcoin path should fail. @pytest.mark.altcoin -def test_invalid_path_fail(client: Client): +def test_invalid_path_fail(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), amount=390_000, @@ -52,7 +52,7 @@ def test_invalid_path_fail(client: Client): ) with pytest.raises(TrezorFailure) as exc: - btc.sign_tx(client, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) assert exc.value.code == messages.FailureType.DataError assert exc.value.message.endswith("Forbidden key path") @@ -61,7 +61,7 @@ def test_invalid_path_fail(client: Client): # Litecoin does not have strong replay protection using SIGHASH_FORKID, but # spending from Bitcoin path should pass with safety checks set to prompt. @pytest.mark.altcoin -def test_invalid_path_prompt(client: Client): +def test_invalid_path_prompt(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), amount=390_000, @@ -77,21 +77,21 @@ def test_invalid_path_prompt(client: Client): ) device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) # Bcash does have strong replay protection using SIGHASH_FORKID, # spending from Bitcoin path should work. @pytest.mark.altcoin -def test_invalid_path_pass_forkid(client: Client): +def test_invalid_path_pass_forkid(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), amount=390_000, @@ -106,32 +106,32 @@ def test_invalid_path_pass_forkid(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) -def test_attack_path_segwit(client: Client): +def test_attack_path_segwit(session: Session): # Scenario: The attacker falsely claims that the transaction uses Testnet paths to # avoid the path warning dialog, but in step6_sign_segwit_inputs() uses Bitcoin paths # to get a valid signature. device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) # Generate keys address_a = btc.get_address( - client, + session, "Testnet", parse_path("m/84h/0h/0h/0/0"), script_type=messages.InputScriptType.SPENDWITNESS, ) address_b = btc.get_address( - client, + session, "Testnet", parse_path("m/84h/0h/1h/0/1"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -178,15 +178,15 @@ def attack_processor(msg): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) + with session: + session.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes={prev_hash: prev_tx} + session, "Testnet", [inp1, inp2], [out1], prev_txes={prev_hash: prev_tx} ) -def test_invalid_path_fail_asap(client: Client): +def test_invalid_path_fail_asap(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/0"), amount=1_000_000, @@ -202,14 +202,14 @@ def test_invalid_path_fail_asap(client: Client): script_type=messages.OutputScriptType.PAYTOWITNESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), messages.Failure(code=messages.FailureType.DataError), ] ) try: - btc.sign_tx(client, "Testnet", [inp1], [out1]) + btc.sign_tx(session, "Testnet", [inp1], [out1]) except TrezorFailure: pass diff --git a/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py b/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py index de0f3807689..d3ab1cf37b0 100644 --- a/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py +++ b/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py @@ -15,7 +15,7 @@ # If not, see . from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -34,7 +34,7 @@ ) -def test_non_segwit_segwit_inputs(client: Client): +def test_non_segwit_segwit_inputs(session: Session): # First is non-segwit, second is segwit. inp1 = messages.TxInputType( @@ -58,9 +58,9 @@ def test_non_segwit_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) assert len(signatures) == 2 @@ -71,7 +71,7 @@ def test_non_segwit_segwit_inputs(client: Client): ) -def test_segwit_non_segwit_inputs(client: Client): +def test_segwit_non_segwit_inputs(session: Session): # First is segwit, second is non-segwit. inp1 = messages.TxInputType( @@ -94,9 +94,9 @@ def test_segwit_non_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) assert len(signatures) == 2 @@ -107,7 +107,7 @@ def test_segwit_non_segwit_inputs(client: Client): ) -def test_segwit_non_segwit_segwit_inputs(client: Client): +def test_segwit_non_segwit_segwit_inputs(session: Session): # First is segwit, second is non-segwit and third is segwit again. inp1 = messages.TxInputType( @@ -138,9 +138,9 @@ def test_segwit_non_segwit_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API ) assert len(signatures) == 3 @@ -151,7 +151,7 @@ def test_segwit_non_segwit_segwit_inputs(client: Client): ) -def test_non_segwit_segwit_non_segwit_inputs(client: Client): +def test_non_segwit_segwit_non_segwit_inputs(session: Session): # First is non-segwit, second is segwit and third is non-segwit again. inp1 = messages.TxInputType( @@ -180,9 +180,9 @@ def test_non_segwit_segwit_non_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API ) assert len(signatures) == 3 diff --git a/tests/device_tests/bitcoin/test_signtx_payreq.py b/tests/device_tests/bitcoin/test_signtx_payreq.py index e02cb2b6c6b..32c90d05e0c 100644 --- a/tests/device_tests/bitcoin/test_signtx_payreq.py +++ b/tests/device_tests/bitcoin/test_signtx_payreq.py @@ -18,8 +18,8 @@ import pytest -from trezorlib import btc, messages, misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib import btc, messages, misc, models +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -138,7 +138,7 @@ def case(id, *args, altcoin: bool = False, models: str | None = None): case("out12", (PaymentRequestParams([1, 2], [], get_nonce=True),)), ), ) -def test_payment_request(client: Client, payment_request_params): +def test_payment_request(session: Session, payment_request_params): for txo in outputs: txo.payment_req_index = None @@ -148,10 +148,10 @@ def test_payment_request(client: Client, payment_request_params): for txo_index in params.txo_indices: outputs[txo_index].payment_req_index = i request_outputs.append(outputs[txo_index]) - nonce = misc.get_nonce(client) if params.get_nonce else None + nonce = misc.get_nonce(session) if params.get_nonce else None payment_reqs.append( make_payment_request( - client, + session, recipient_name="trezor.io", outputs=request_outputs, change_addresses=["tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9"], @@ -161,7 +161,7 @@ def test_payment_request(client: Client, payment_request_params): ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -174,7 +174,7 @@ def test_payment_request(client: Client, payment_request_params): # Ensure that the nonce has been invalidated. with pytest.raises(TrezorFailure, match="Invalid nonce in payment request"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -184,15 +184,18 @@ def test_payment_request(client: Client, payment_request_params): @pytest.mark.models(skip="safe3") -def test_payment_request_details(client: Client): +def test_payment_request_details(session: Session): + if session.model is models.T2B1: + pytest.skip("Details not implemented on T2B1") + # Test that payment request details are shown when requested. outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None - nonce = misc.get_nonce(client) + nonce = misc.get_nonce(session) payment_reqs = [ make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], memos=[TextMemo("Invoice #87654321.")], @@ -200,12 +203,12 @@ def test_payment_request_details(client: Client): ) ] - with client: + with session.client as client: IF = InputFlowPaymentRequestDetails(client, outputs) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -216,16 +219,16 @@ def test_payment_request_details(client: Client): assert serialized_tx.hex() == SERIALIZED_TX -def test_payment_req_wrong_amount(client: Client): +def test_payment_req_wrong_amount(session: Session): # Test wrong total amount in payment request. outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Decrease the total amount of the payment request. @@ -233,7 +236,7 @@ def test_payment_req_wrong_amount(client: Client): with pytest.raises(TrezorFailure, match="Invalid amount in payment request"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -242,18 +245,18 @@ def test_payment_req_wrong_amount(client: Client): ) -def test_payment_req_wrong_mac_refund(client: Client): +def test_payment_req_wrong_mac_refund(session: Session): # Test wrong MAC in payment request memo. memo = RefundMemo(parse_path("m/44h/1h/0h/1/0")) outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], memos=[memo], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Corrupt the MAC value. @@ -263,7 +266,7 @@ def test_payment_req_wrong_mac_refund(client: Client): with pytest.raises(TrezorFailure, match="Invalid address MAC"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -274,7 +277,7 @@ def test_payment_req_wrong_mac_refund(client: Client): @pytest.mark.altcoin @pytest.mark.models("t2t1", reason="Dash not supported on Safe family") -def test_payment_req_wrong_mac_purchase(client: Client): +def test_payment_req_wrong_mac_purchase(session: Session): # Test wrong MAC in payment request memo. memo = CoinPurchaseMemo( amount="22.34904 DASH", @@ -286,11 +289,11 @@ def test_payment_req_wrong_mac_purchase(client: Client): outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], memos=[memo], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Corrupt the MAC value. @@ -300,7 +303,7 @@ def test_payment_req_wrong_mac_purchase(client: Client): with pytest.raises(TrezorFailure, match="Invalid address MAC"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -309,16 +312,16 @@ def test_payment_req_wrong_mac_purchase(client: Client): ) -def test_payment_req_wrong_output(client: Client): +def test_payment_req_wrong_output(session: Session): # Test wrong output in payment request. outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Use a different address in the second output. @@ -335,7 +338,7 @@ def test_payment_req_wrong_output(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature in payment request"): btc.sign_tx( - client, + session, "Testnet", inputs, fake_outputs, diff --git a/tests/device_tests/bitcoin/test_signtx_prevhash.py b/tests/device_tests/bitcoin/test_signtx_prevhash.py index 307823a9f3f..a2f96c04ed1 100644 --- a/tests/device_tests/bitcoin/test_signtx_prevhash.py +++ b/tests/device_tests/bitcoin/test_signtx_prevhash.py @@ -5,7 +5,7 @@ import pytest from trezorlib import btc, messages, models, tools -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...common import is_core @@ -78,7 +78,7 @@ def _check_error_message(value: bytes, model: models.TrezorModel, message: str): @with_bad_prevhashes -def test_invalid_prev_hash(client: Client, prev_hash): +def test_invalid_prev_hash(session: Session, prev_hash): inp1 = messages.TxInputType( address_n=tools.parse_path("m/44h/0h/0h/0/0"), amount=123_456_789, @@ -93,12 +93,12 @@ def test_invalid_prev_hash(client: Client, prev_hash): ) with pytest.raises(TrezorFailure) as e: - btc.sign_tx(client, "Testnet", [inp1], [out1], prev_txes={}) - _check_error_message(prev_hash, client.model, e.value.message) + btc.sign_tx(session, "Testnet", [inp1], [out1], prev_txes={}) + _check_error_message(prev_hash, session.model, e.value.message) @with_bad_prevhashes -def test_invalid_prev_hash_attack(client: Client, prev_hash): +def test_invalid_prev_hash_attack(session: Session, prev_hash): # prepare input with a valid prev-hash inp1 = messages.TxInputType( address_n=tools.parse_path("m/44h/0h/0h/0/0"), @@ -130,20 +130,20 @@ def attack_filter(msg): msg.tx.inputs[0].prev_hash = prev_hash return msg - with client, pytest.raises(TrezorFailure) as e: - client.set_filter(messages.TxAck, attack_filter) - if is_core(client): + with session, session.client as client, pytest.raises(TrezorFailure) as e: + session.set_filter(messages.TxAck, attack_filter) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES) # check that injection was performed assert counter == 0 - _check_error_message(prev_hash, client.model, e.value.message) + _check_error_message(prev_hash, session.model, e.value.message) @with_bad_prevhashes -def test_invalid_prev_hash_in_prevtx(client: Client, prev_hash): +def test_invalid_prev_hash_in_prevtx(session: Session, prev_hash): prev_tx = copy(PREV_TX) # smoke check: replace prev_hash with all zeros, reserialize and hash, try to sign @@ -161,16 +161,16 @@ def test_invalid_prev_hash_in_prevtx(client: Client, prev_hash): amount=99_000_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - btc.sign_tx(client, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) + btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) # attack: replace prev_hash with an invalid value prev_tx.inputs[0].prev_hash = prev_hash tx_hash = hash_tx(serialize_tx(prev_tx)) inp0.prev_hash = tx_hash - with client, pytest.raises(TrezorFailure) as e: - if client.model is not models.T1B1: + with session, session.client as client, pytest.raises(TrezorFailure) as e: + if session.model is not models.T1B1: IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) - _check_error_message(prev_hash, client.model, e.value.message) + btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) + _check_error_message(prev_hash, session.model, e.value.message) diff --git a/tests/device_tests/bitcoin/test_signtx_replacement.py b/tests/device_tests/bitcoin/test_signtx_replacement.py index 97fe7e2d873..fd5db6a5027 100644 --- a/tests/device_tests/bitcoin/test_signtx_replacement.py +++ b/tests/device_tests/bitcoin/test_signtx_replacement.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -90,7 +90,7 @@ ) -def test_p2pkh_fee_bump(client: Client): +def test_p2pkh_fee_bump(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/4"), amount=174_998, @@ -116,8 +116,8 @@ def test_p2pkh_fee_bump(client: Client): orig_index=1, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_50f6f1), @@ -132,7 +132,7 @@ def test_p2pkh_fee_bump(client: Client): request_meta(TXHASH_beafc7), request_input(0, TXHASH_beafc7), request_output(0, TXHASH_beafc7), - (is_core(client), request_orig_input(0, TXHASH_50f6f1)), + (is_core(session), request_orig_input(0, TXHASH_50f6f1)), request_orig_input(0, TXHASH_50f6f1), request_orig_output(0, TXHASH_50f6f1), request_orig_output(1, TXHASH_50f6f1), @@ -145,7 +145,7 @@ def test_p2pkh_fee_bump(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1, out2], @@ -159,7 +159,7 @@ def test_p2pkh_fee_bump(client: Client): ) -def test_p2wpkh_op_return_fee_bump(client: Client): +def test_p2wpkh_op_return_fee_bump(session: Session): # Original input. inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/1h/0/14"), @@ -190,9 +190,9 @@ def test_p2wpkh_op_return_fee_bump(client: Client): orig_index=1, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -207,7 +207,7 @@ def test_p2wpkh_op_return_fee_bump(client: Client): # txid 48bc29fc42a64b43d043b0b7b99b21aa39654234754608f791c60bcbd91a8e92 -def test_p2tr_fee_bump(client: Client): +def test_p2tr_fee_bump(session: Session): inp1 = messages.TxInputType( # tb1p8tvmvsvhsee73rhym86wt435qrqm92psfsyhy6a3n5gw455znnpqm8wald address_n=parse_path("m/86h/1h/0h/0/1"), @@ -243,8 +243,8 @@ def test_p2tr_fee_bump(client: Client): orig_index=1, script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_8e4af7), @@ -269,7 +269,7 @@ def test_p2tr_fee_bump(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -281,7 +281,7 @@ def test_p2tr_fee_bump(client: Client): ) -def test_p2wpkh_finalize(client: Client): +def test_p2wpkh_finalize(session: Session): # Original input with disabled RBF opt-in, i.e. we finalize the transaction. inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/2"), @@ -312,8 +312,8 @@ def test_p2wpkh_finalize(client: Client): orig_index=1, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_70f987), @@ -339,7 +339,7 @@ def test_p2wpkh_finalize(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -401,7 +401,7 @@ def test_p2wpkh_finalize(client: Client): ), ) def test_p2wpkh_payjoin( - client, out1_amount, out2_amount, copayer_witness, fee_confirm, expected_tx + session, out1_amount, out2_amount, copayer_witness, fee_confirm, expected_tx ): # Original input. inp1 = messages.TxInputType( @@ -444,8 +444,8 @@ def test_p2wpkh_payjoin( orig_index=1, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_65b768), @@ -478,7 +478,7 @@ def test_p2wpkh_payjoin( ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -489,7 +489,7 @@ def test_p2wpkh_payjoin( assert serialized_tx.hex() == expected_tx -def test_p2wpkh_in_p2sh_remove_change(client: Client): +def test_p2wpkh_in_p2sh_remove_change(session: Session): # Test fee bump with change-output removal. Originally fee was 3780, now 98060. inp1 = messages.TxInputType( @@ -520,8 +520,8 @@ def test_p2wpkh_in_p2sh_remove_change(client: Client): orig_index=0, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -553,7 +553,7 @@ def test_p2wpkh_in_p2sh_remove_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -567,7 +567,7 @@ def test_p2wpkh_in_p2sh_remove_change(client: Client): ) -def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): +def test_p2wpkh_in_p2sh_fee_bump_from_external(session: Session): # Use the change output and an external output to bump the fee. # Originally fee was 3780, now 108060 (94280 from change and 10000 from external). @@ -599,8 +599,8 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): orig_index=0, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -634,7 +634,7 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -649,7 +649,7 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): @pytest.mark.models("core") -def test_tx_meld(client: Client): +def test_tx_meld(session: Session): # Meld two original transactions into one, joining the change-outputs into a different one. inp1 = messages.TxInputType( @@ -720,8 +720,8 @@ def test_tx_meld(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -785,7 +785,7 @@ def test_tx_meld(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2, inp3, inp4], [out1, out2, out3], @@ -799,7 +799,7 @@ def test_tx_meld(client: Client): ) -def test_attack_steal_change(client: Client): +def test_attack_steal_change(session: Session): # Attempt to steal amount equivalent to the change in the original transaction by # hiding the fact that an output in the original transaction is a change-output. @@ -860,7 +860,7 @@ def test_attack_steal_change(client: Client): TrezorFailure, match="Original output is missing change-output parameters" ): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -870,7 +870,7 @@ def test_attack_steal_change(client: Client): @pytest.mark.models("core") -def test_attack_false_internal(client: Client): +def test_attack_false_internal(session: Session): # Falsely claim that an external input is internal in the original transaction. # If this were possible, it would allow an attacker to make it look like the # user was spending more in the original than they actually were, making it @@ -914,7 +914,7 @@ def test_attack_false_internal(client: Client): TrezorFailure, match="Original input does not match current input" ): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -922,7 +922,7 @@ def test_attack_false_internal(client: Client): ) -def test_attack_fake_int_input_amount(client: Client): +def test_attack_fake_int_input_amount(session: Session): # Give a fake input amount for an original internal input while giving the correct # amount for the replacement input. If an attacker could increase the amount of an # internal input in the original transaction, then they could bump the fee of the @@ -968,7 +968,7 @@ def test_attack_fake_int_input_amount(client: Client): TrezorFailure, match="Original input does not match current input" ): btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1, out2], @@ -977,7 +977,7 @@ def test_attack_fake_int_input_amount(client: Client): @pytest.mark.models("core") -def test_attack_fake_ext_input_amount(client: Client): +def test_attack_fake_ext_input_amount(session: Session): # Give a fake input amount for an original external input while giving the correct # amount for the replacement input. If an attacker could decrease the amount of an # external input in the original transaction, then they could steal the fee from @@ -1044,7 +1044,7 @@ def test_attack_fake_ext_input_amount(client: Client): TrezorFailure, match="Original input does not match current input" ): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -1052,7 +1052,7 @@ def test_attack_fake_ext_input_amount(client: Client): ) -def test_p2wpkh_invalid_signature(client: Client): +def test_p2wpkh_invalid_signature(session: Session): # Ensure that transaction replacement fails when the original signature is invalid. # Original input with disabled RBF opt-in, i.e. we finalize the transaction. @@ -1096,7 +1096,7 @@ def test_p2wpkh_invalid_signature(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -1105,7 +1105,7 @@ def test_p2wpkh_invalid_signature(client: Client): ) -def test_p2tr_invalid_signature(client: Client): +def test_p2tr_invalid_signature(session: Session): # Ensure that transaction replacement fails when the original signature is invalid. inp1 = messages.TxInputType( @@ -1151,4 +1151,4 @@ def test_p2tr_invalid_signature(client: Client): prev_txes = {TXHASH_8e4af7: prev_tx_invalid} with pytest.raises(TrezorFailure, match="Invalid signature"): - btc.sign_tx(client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=prev_txes) + btc.sign_tx(session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=prev_txes) diff --git a/tests/device_tests/bitcoin/test_signtx_segwit.py b/tests/device_tests/bitcoin/test_signtx_segwit.py index 763626caef0..ef8c988ff39 100644 --- a/tests/device_tests/bitcoin/test_signtx_segwit.py +++ b/tests/device_tests/bitcoin/test_signtx_segwit.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -47,7 +47,7 @@ @pytest.mark.parametrize("chunkify", (True, False)) -def test_send_p2sh(client: Client, chunkify: bool): +def test_send_p2sh(session: Session, chunkify: bool): inp1 = messages.TxInputType( address_n=parse_path("m/49h/1h/0h/1/0"), # 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX @@ -66,16 +66,16 @@ def test_send_p2sh(client: Client, chunkify: bool): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -90,7 +90,7 @@ def test_send_p2sh(client: Client, chunkify: bool): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -105,7 +105,7 @@ def test_send_p2sh(client: Client, chunkify: bool): ) -def test_send_p2sh_change(client: Client): +def test_send_p2sh_change(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/49h/1h/0h/1/0"), # 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX @@ -124,13 +124,13 @@ def test_send_p2sh_change(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -146,7 +146,7 @@ def test_send_p2sh_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -156,11 +156,11 @@ def test_send_p2sh_change(client: Client): ) -def test_testnet_segwit_big_amount(client: Client): +def test_testnet_segwit_big_amount(session: Session): # This test is testing transaction with amount bigger than fits to uint32 address_n = parse_path("m/49h/1h/0h/0/0") address = btc.get_address( - client, + session, "Testnet", address_n, script_type=messages.InputScriptType.SPENDP2SHWITNESS, @@ -179,13 +179,13 @@ def test_testnet_segwit_big_amount(client: Client): amount=2**32 + 1, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(prev_hash), @@ -198,7 +198,7 @@ def test_testnet_segwit_big_amount(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes={prev_hash: prev_tx} + session, "Testnet", [inp1], [out1], prev_txes={prev_hash: prev_tx} ) # Transaction does not exist on the blockchain, not using assert_tx_matches() assert ( @@ -208,12 +208,12 @@ def test_testnet_segwit_big_amount(client: Client): @pytest.mark.multisig -def test_send_multisig_1(client: Client): +def test_send_multisig_1(session: Session): # input: 338e2d02e0eaf8848e38925904e51546cf22e58db5b1860c4a0e72b69c56afe5 nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" ).node for i in range(1, 4) ] @@ -241,7 +241,7 @@ def test_send_multisig_1(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_338e2d), @@ -254,10 +254,10 @@ def test_send_multisig_1(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -265,10 +265,10 @@ def test_send_multisig_1(client: Client): # sign with third key inp1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -278,7 +278,7 @@ def test_send_multisig_1(client: Client): ) -def test_attack_change_input_address(client: Client): +def test_attack_change_input_address(session: Session): # Simulates an attack where the user is coerced into unknowingly # transferring funds from one account to another one of their accounts, # potentially resulting in privacy issues. @@ -303,17 +303,17 @@ def test_attack_change_input_address(client: Client): ) # Test if the transaction can be signed normally. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), # The user is required to confirm transfer to another account. messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -328,7 +328,7 @@ def test_attack_change_input_address(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -349,15 +349,15 @@ def attack_processor(msg): return msg # Now run the attack, must trigger the exception - with client: - client.set_filter(messages.TxAck, attack_processor) + with session: + session.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) -def test_attack_mixed_inputs(client: Client): +def test_attack_mixed_inputs(session: Session): TRUE_AMOUNT = 123_456_789 FAKE_AMOUNT = 120_000_000 @@ -389,11 +389,11 @@ def test_attack_mixed_inputs(client: Client): request_output(0), messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ( - is_core(client), + is_core(session), messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ), ( - is_core(client), + is_core(session), messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), ), messages.ButtonRequest(code=messages.ButtonRequestType.FeeOverThreshold), @@ -417,16 +417,16 @@ def test_attack_mixed_inputs(client: Client): request_finished(), ] - if client.model is models.T1B1: + if session.model is models.T1B1: # T1 asks for first input for witness again expected_responses.insert(-2, request_input(0)) - with client: + with session: # Sign unmodified transaction. # "Fee over threshold" warning is displayed - fee is the whole TRUE_AMOUNT - client.set_expected_responses(expected_responses) + session.set_expected_responses(expected_responses) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -436,7 +436,7 @@ def test_attack_mixed_inputs(client: Client): # In Phase 1 make the user confirm a lower value of the segwit input. inp2.amount = FAKE_AMOUNT - if client.model is models.T1B1: + if session.model is models.T1B1: # T1 fails as soon as it encounters the fake amount. expected_responses = ( expected_responses[:4] + expected_responses[5:15] + [messages.Failure()] @@ -446,10 +446,10 @@ def test_attack_mixed_inputs(client: Client): expected_responses[:4] + expected_responses[5:16] + [messages.Failure()] ) - with pytest.raises(TrezorFailure) as e, client: - client.set_expected_responses(expected_responses) + with pytest.raises(TrezorFailure) as e, session: + session.set_expected_responses(expected_responses) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], diff --git a/tests/device_tests/bitcoin/test_signtx_segwit_native.py b/tests/device_tests/bitcoin/test_signtx_segwit_native.py index 0c779c777ef..920b0bf48b7 100644 --- a/tests/device_tests/bitcoin/test_signtx_segwit_native.py +++ b/tests/device_tests/bitcoin/test_signtx_segwit_native.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import H_, parse_path from ...bip32 import deserialize @@ -61,7 +61,7 @@ ) -def test_send_p2sh(client: Client): +def test_send_p2sh(session: Session): # input tx: 20912f98ea3ed849042efed0fdac8cb4fc301961c5988cba56902d8ffb61c337 inp1 = messages.TxInputType( @@ -82,16 +82,16 @@ def test_send_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -106,7 +106,7 @@ def test_send_p2sh(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -116,7 +116,7 @@ def test_send_p2sh(client: Client): ) -def test_send_p2sh_change(client: Client): +def test_send_p2sh_change(session: Session): # input tx: 20912f98ea3ed849042efed0fdac8cb4fc301961c5988cba56902d8ffb61c337 inp1 = messages.TxInputType( @@ -137,13 +137,13 @@ def test_send_p2sh_change(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -159,7 +159,7 @@ def test_send_p2sh_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -169,7 +169,7 @@ def test_send_p2sh_change(client: Client): ) -def test_send_native(client: Client): +def test_send_native(session: Session): # input tx: b36780ceb86807ca6e7535a6fd418b1b788cb9b227d2c8a26a0de295e523219e inp1 = messages.TxInputType( @@ -190,16 +190,16 @@ def test_send_native(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=100_000 - 40_000 - 10_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b36780), @@ -214,7 +214,7 @@ def test_send_native(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -224,7 +224,7 @@ def test_send_native(client: Client): ) -def test_send_to_taproot(client: Client): +def test_send_to_taproot(session: Session): # input tx: ec16dc5a539c5d60001a7471c37dbb0b5294c289c77df8bd07870b30d73e2231 inp1 = messages.TxInputType( @@ -244,9 +244,9 @@ def test_send_to_taproot(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=10_000 - 7_000 - 200, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -256,7 +256,7 @@ def test_send_to_taproot(client: Client): ) -def test_send_native_change(client: Client): +def test_send_native_change(session: Session): # input tx: fcb3f5436224900afdba50e9e763d98b920dfed056e552040d99ea9bc03a9d83 inp1 = messages.TxInputType( @@ -277,13 +277,13 @@ def test_send_native_change(client: Client): script_type=messages.OutputScriptType.PAYTOWITNESS, amount=100_000 - 40_000 - 10_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -300,7 +300,7 @@ def test_send_native_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -310,7 +310,7 @@ def test_send_native_change(client: Client): ) -def test_send_both(client: Client): +def test_send_both(session: Session): # input 1 tx: 65047a2b107d6301d72d4a1e49e7aea9cf06903fdc4ae74a4a9bba9bc1a414d2 # input 2 tx: d159fd2fcb5854a7c8b275d598765a446f1e2ff510bf077545a404a0c9db65f7 @@ -344,21 +344,21 @@ def test_send_both(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), - (is_core(client), messages.ButtonRequest(code=B.SignTx)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.SignTx)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_65047a), @@ -382,7 +382,7 @@ def test_send_both(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -397,12 +397,12 @@ def test_send_both(client: Client): @pytest.mark.multisig -def test_send_multisig_1(client: Client): +def test_send_multisig_1(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -433,7 +433,7 @@ def test_send_multisig_1(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -449,10 +449,10 @@ def test_send_multisig_1(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -460,10 +460,10 @@ def test_send_multisig_1(client: Client): # sign with third key inp1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -474,12 +474,12 @@ def test_send_multisig_1(client: Client): @pytest.mark.multisig -def test_send_multisig_2(client: Client): +def test_send_multisig_2(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -510,7 +510,7 @@ def test_send_multisig_2(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -526,10 +526,10 @@ def test_send_multisig_2(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -537,10 +537,10 @@ def test_send_multisig_2(client: Client): # sign with first key inp1.address_n[2] = H_(1) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -551,12 +551,12 @@ def test_send_multisig_2(client: Client): @pytest.mark.multisig -def test_send_multisig_3_change(client: Client): +def test_send_multisig_3_change(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -595,7 +595,7 @@ def test_send_multisig_3_change(client: Client): request_output(0), messages.ButtonRequest(code=B.UnknownDerivationPath), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -611,13 +611,13 @@ def test_send_multisig_3_change(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -626,13 +626,13 @@ def test_send_multisig_3_change(client: Client): inp1.address_n[2] = H_(3) out1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -643,12 +643,12 @@ def test_send_multisig_3_change(client: Client): @pytest.mark.multisig -def test_send_multisig_4_change(client: Client): +def test_send_multisig_4_change(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -687,7 +687,7 @@ def test_send_multisig_4_change(client: Client): request_output(0), messages.ButtonRequest(code=B.UnknownDerivationPath), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -703,13 +703,13 @@ def test_send_multisig_4_change(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -718,13 +718,13 @@ def test_send_multisig_4_change(client: Client): inp1.address_n[2] = H_(3) out1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -734,7 +734,7 @@ def test_send_multisig_4_change(client: Client): ) -def test_multisig_mismatch_inputs_single(client: Client): +def test_multisig_mismatch_inputs_single(session: Session): # Ensure that if there is a non-multisig input, then a multisig output # will not be identified as a change output. @@ -788,18 +788,18 @@ def test_multisig_mismatch_inputs_single(client: Client): amount=100_000 + 100_000 - 50_000 - 10_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), # Ensure that the multisig output is not identified as a change output. messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_1c022d), @@ -824,7 +824,7 @@ def test_multisig_mismatch_inputs_single(client: Client): ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( diff --git a/tests/device_tests/bitcoin/test_signtx_taproot.py b/tests/device_tests/bitcoin/test_signtx_taproot.py index f548154ae70..74c475bcc67 100644 --- a/tests/device_tests/bitcoin/test_signtx_taproot.py +++ b/tests/device_tests/bitcoin/test_signtx_taproot.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -64,7 +64,7 @@ @pytest.mark.parametrize("chunkify", (True, False)) -def test_send_p2tr(client: Client, chunkify: bool): +def test_send_p2tr(session: Session, chunkify: bool): inp1 = messages.TxInputType( # tb1pn2d0yjeedavnkd8z8lhm566p0f2utm3lgvxrsdehnl94y34txmts5s7t4c address_n=parse_path("m/86h/1h/0h/1/0"), @@ -79,13 +79,13 @@ def test_send_p2tr(client: Client, chunkify: bool): amount=4_450, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -94,7 +94,7 @@ def test_send_p2tr(client: Client, chunkify: bool): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API, chunkify=chunkify + session, "Testnet", [inp1], [out1], prev_txes=TX_API, chunkify=chunkify ) assert_tx_matches( @@ -104,7 +104,7 @@ def test_send_p2tr(client: Client, chunkify: bool): ) -def test_send_two_with_change(client: Client): +def test_send_two_with_change(session: Session): inp1 = messages.TxInputType( # tb1pswrqtykue8r89t9u4rprjs0gt4qzkdfuursfnvqaa3f2yql07zmq8s8a5u address_n=parse_path("m/86h/1h/0h/0/0"), @@ -133,14 +133,14 @@ def test_send_two_with_change(client: Client): script_type=messages.OutputScriptType.PAYTOTAPROOT, amount=6_800 + 13_000 - 200 - 15_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -153,7 +153,7 @@ def test_send_two_with_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API ) assert_tx_matches( @@ -163,7 +163,7 @@ def test_send_two_with_change(client: Client): ) -def test_send_mixed(client: Client): +def test_send_mixed(session: Session): inp1 = messages.TxInputType( # 2MutHjgAXkqo3jxX2DZWorLAckAnwTxSM9V address_n=parse_path("m/49h/1h/1h/0/0"), @@ -222,8 +222,8 @@ def test_send_mixed(client: Client): script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ # process inputs request_input(0), @@ -233,19 +233,19 @@ def test_send_mixed(client: Client): # approve outputs request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(3), messages.ButtonRequest(code=B.ConfirmOutput), request_output(4), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), - (is_core(client), messages.ButtonRequest(code=B.SignTx)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.SignTx)), messages.ButtonRequest(code=B.SignTx), # verify inputs request_input(0), @@ -293,12 +293,12 @@ def test_send_mixed(client: Client): request_input(0), request_input(1), request_input(2), - (client.model is models.T1B1, request_input(3)), + (session.model is models.T1B1, request_input(3)), request_finished(), ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2, inp3, inp4], [out1, out2, out3, out4, out5], @@ -312,13 +312,13 @@ def test_send_mixed(client: Client): ) -def test_attack_script_type(client: Client): +def test_attack_script_type(session: Session): # Scenario: The attacker falsely claims that the transaction is Taproot-only to # avoid prev tx streaming and gives a lower amount for one of the inputs. The # correct input types and amounts are revelaled only in step6_sign_segwit_inputs() # to get a valid signature. This results in a transaction which pays a fee much # larger than what the user confirmed. - + raise Exception("THIS TEST FAILS, TO DO FIX") inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/1/0"), amount=7_289_000, @@ -354,16 +354,16 @@ def attack_processor(msg): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), - (is_core(client), messages.ButtonRequest(code=B.SignTx)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.SignTx)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_input(1), @@ -374,7 +374,7 @@ def attack_processor(msg): ] ) with pytest.raises(TrezorFailure) as exc: - btc.sign_tx(client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API) assert exc.value.code == messages.FailureType.ProcessError assert exc.value.message.endswith("Transaction has changed during signing") @@ -392,7 +392,7 @@ def attack_processor(msg): "tb1pllllllllllllllllllllllllllllllllllllllllllllallllscqgl4zhn", ), ) -def test_send_invalid_address(client: Client, address: str): +def test_send_invalid_address(session: Session, address: str): inp1 = messages.TxInputType( # tb1pn2d0yjeedavnkd8z8lhm566p0f2utm3lgvxrsdehnl94y34txmts5s7t4c address_n=parse_path("m/86h/1h/0h/1/0"), @@ -407,12 +407,12 @@ def test_send_invalid_address(client: Client, address: str): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client, pytest.raises(TrezorFailure): - client.set_expected_responses( + with session, pytest.raises(TrezorFailure): + session.set_expected_responses( [ request_input(0), request_output(0), messages.Failure, ] ) - btc.sign_tx(client, "Testnet", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Testnet", [inp1], [out1], prev_txes=TX_API) diff --git a/tests/device_tests/bitcoin/test_verifymessage.py b/tests/device_tests/bitcoin/test_verifymessage.py index 86389d8a515..88907e318dc 100644 --- a/tests/device_tests/bitcoin/test_verifymessage.py +++ b/tests/device_tests/bitcoin/test_verifymessage.py @@ -19,12 +19,12 @@ import pytest from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_message_long(client: Client): +def test_message_long(session: Session): ret = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -35,9 +35,9 @@ def test_message_long(client: Client): assert ret is True -def test_message_testnet(client: Client): +def test_message_testnet(session: Session): ret = btc.verify_message( - client, + session, "Testnet", "mirio8q3gtv7fhdnmb3TpZ4EuafdzSs7zL", bytes.fromhex( @@ -49,9 +49,9 @@ def test_message_testnet(client: Client): @pytest.mark.altcoin -def test_message_grs(client: Client): +def test_message_grs(session: Session): ret = btc.verify_message( - client, + session, "Groestlcoin", "Fj62rBJi8LvbmWu2jzkaUX1NFXLEqDLoZM", base64.b64decode( @@ -62,9 +62,9 @@ def test_message_grs(client: Client): assert ret is True -def test_message_verify(client: Client): +def test_message_verify(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "1JwSSubhmg6iPtRjtyqhUYYH7bZg3Lfy1T", bytes.fromhex( @@ -76,7 +76,7 @@ def test_message_verify(client: Client): # uncompressed pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "1JwSSubhmg6iPtRjtyqhUYYH7bZg3Lfy1T", bytes.fromhex( @@ -88,7 +88,7 @@ def test_message_verify(client: Client): # uncompressed pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "1JwSSubhmg6iPtRjtyqhUYYH7bZg3Lfy1T", bytes.fromhex( @@ -100,7 +100,7 @@ def test_message_verify(client: Client): # compressed pubkey - OK res = btc.verify_message( - client, + session, "Bitcoin", "1C7zdTfnkzmr13HfA2vNm5SJYRK6nEKyq8", bytes.fromhex( @@ -112,7 +112,7 @@ def test_message_verify(client: Client): # compressed pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "1C7zdTfnkzmr13HfA2vNm5SJYRK6nEKyq8", bytes.fromhex( @@ -124,7 +124,7 @@ def test_message_verify(client: Client): # compressed pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "1C7zdTfnkzmr13HfA2vNm5SJYRK6nEKyq8", bytes.fromhex( @@ -136,7 +136,7 @@ def test_message_verify(client: Client): # trezor pubkey - OK res = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -148,7 +148,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -160,7 +160,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -172,9 +172,9 @@ def test_message_verify(client: Client): @pytest.mark.altcoin -def test_message_verify_bcash(client: Client): +def test_message_verify_bcash(session: Session): res = btc.verify_message( - client, + session, "Bcash", "bitcoincash:qqj22md58nm09vpwsw82fyletkxkq36zxyxh322pru", bytes.fromhex( @@ -185,9 +185,9 @@ def test_message_verify_bcash(client: Client): assert res is True -def test_verify_bitcoind(client: Client): +def test_verify_bitcoind(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "1KzXE97kV7DrpxCViCN3HbGbiKhzzPM7TQ", bytes.fromhex( @@ -199,12 +199,12 @@ def test_verify_bitcoind(client: Client): assert res is True -def test_verify_utf(client: Client): +def test_verify_utf(session: Session): words_nfkd = "Pr\u030ci\u0301s\u030cerne\u030c z\u030clut\u030couc\u030cky\u0301 ku\u030an\u030c u\u0301pe\u030cl d\u030ca\u0301belske\u0301 o\u0301dy za\u0301ker\u030cny\u0301 uc\u030cen\u030c be\u030cz\u030ci\u0301 pode\u0301l zo\u0301ny u\u0301lu\u030a" words_nfc = "P\u0159\xed\u0161ern\u011b \u017elu\u0165ou\u010dk\xfd k\u016f\u0148 \xfap\u011bl \u010f\xe1belsk\xe9 \xf3dy z\xe1ke\u0159n\xfd u\u010de\u0148 b\u011b\u017e\xed pod\xe9l z\xf3ny \xfal\u016f" res_nfkd = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -214,7 +214,7 @@ def test_verify_utf(client: Client): ) res_nfc = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( diff --git a/tests/device_tests/bitcoin/test_verifymessage_segwit.py b/tests/device_tests/bitcoin/test_verifymessage_segwit.py index 84f04442646..9c3169e0c78 100644 --- a/tests/device_tests/bitcoin/test_verifymessage_segwit.py +++ b/tests/device_tests/bitcoin/test_verifymessage_segwit.py @@ -15,12 +15,12 @@ # If not, see . from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_message_long(client: Client): +def test_message_long(session: Session): ret = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -31,9 +31,9 @@ def test_message_long(client: Client): assert ret is True -def test_message_testnet(client: Client): +def test_message_testnet(session: Session): ret = btc.verify_message( - client, + session, "Testnet", "2N4VkePSzKH2sv5YBikLHGvzUYvfPxV6zS9", bytes.fromhex( @@ -44,9 +44,9 @@ def test_message_testnet(client: Client): assert ret is True -def test_message_verify(client: Client): +def test_message_verify(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -58,7 +58,7 @@ def test_message_verify(client: Client): # no script type res = btc.verify_message( - client, + session, "Bitcoin", "3L6TyTisPBmrDAj6RoKmDzNnj4eQi54gD2", bytes.fromhex( @@ -70,7 +70,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -82,7 +82,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -93,12 +93,12 @@ def test_message_verify(client: Client): assert res is False -def test_verify_utf(client: Client): +def test_verify_utf(session: Session): words_nfkd = "Pr\u030ci\u0301s\u030cerne\u030c z\u030clut\u030couc\u030cky\u0301 ku\u030an\u030c u\u0301pe\u030cl d\u030ca\u0301belske\u0301 o\u0301dy za\u0301ker\u030cny\u0301 uc\u030cen\u030c be\u030cz\u030ci\u0301 pode\u0301l zo\u0301ny u\u0301lu\u030a" words_nfc = "P\u0159\xed\u0161ern\u011b \u017elu\u0165ou\u010dk\xfd k\u016f\u0148 \xfap\u011bl \u010f\xe1belsk\xe9 \xf3dy z\xe1ke\u0159n\xfd u\u010de\u0148 b\u011b\u017e\xed pod\xe9l z\xf3ny \xfal\u016f" res_nfkd = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -108,7 +108,7 @@ def test_verify_utf(client: Client): ) res_nfc = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( diff --git a/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py b/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py index 5bea51f7dc1..3a4ed68e5da 100644 --- a/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py +++ b/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py @@ -15,12 +15,12 @@ # If not, see . from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_message_long(client: Client): +def test_message_long(session: Session): ret = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -31,9 +31,9 @@ def test_message_long(client: Client): assert ret is True -def test_message_testnet(client: Client): +def test_message_testnet(session: Session): ret = btc.verify_message( - client, + session, "Testnet", "tb1qyjjkmdpu7metqt5r36jf872a34syws336p3n3p", bytes.fromhex( @@ -44,9 +44,9 @@ def test_message_testnet(client: Client): assert ret is True -def test_message_verify(client: Client): +def test_message_verify(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -58,7 +58,7 @@ def test_message_verify(client: Client): # no script type res = btc.verify_message( - client, + session, "Bitcoin", "bc1qannfxke2tfd4l7vhepehpvt05y83v3qsf6nfkk", bytes.fromhex( @@ -70,7 +70,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -82,7 +82,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -93,12 +93,12 @@ def test_message_verify(client: Client): assert res is False -def test_verify_utf(client: Client): +def test_verify_utf(session: Session): words_nfkd = "Pr\u030ci\u0301s\u030cerne\u030c z\u030clut\u030couc\u030cky\u0301 ku\u030an\u030c u\u0301pe\u030cl d\u030ca\u0301belske\u0301 o\u0301dy za\u0301ker\u030cny\u0301 uc\u030cen\u030c be\u030cz\u030ci\u0301 pode\u0301l zo\u0301ny u\u0301lu\u030a" words_nfc = "P\u0159\xed\u0161ern\u011b \u017elu\u0165ou\u010dk\xfd k\u016f\u0148 \xfap\u011bl \u010f\xe1belsk\xe9 \xf3dy z\xe1ke\u0159n\xfd u\u010de\u0148 b\u011b\u017e\xed pod\xe9l z\xf3ny \xfal\u016f" res_nfkd = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -108,7 +108,7 @@ def test_verify_utf(client: Client): ) res_nfc = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( diff --git a/tests/device_tests/bitcoin/test_zcash.py b/tests/device_tests/bitcoin/test_zcash.py index dc959199a35..adb99589150 100644 --- a/tests/device_tests/bitcoin/test_zcash.py +++ b/tests/device_tests/bitcoin/test_zcash.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -57,7 +57,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.zcash] -def test_v3_not_supported(client: Client): +def test_v3_not_supported(session: Session): # prevout: aaf51e4606c264e47e5c42c958fe4cf1539c5172684721e38e69f4ef634d75dc:1 # input 1: 3.0 TAZ @@ -75,9 +75,9 @@ def test_v3_not_supported(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client, pytest.raises(TrezorFailure, match="DataError"): + with session, pytest.raises(TrezorFailure, match="DataError"): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -88,7 +88,7 @@ def test_v3_not_supported(client: Client): ) -def test_one_one_fee_sapling(client: Client): +def test_one_one_fee_sapling(session: Session): # prevout: e3820602226974b1dd87b7113cc8aea8c63e5ae29293991e7bfa80c126930368:0 # input 1: 3.0 TAZ @@ -106,13 +106,13 @@ def test_one_one_fee_sapling(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_e38206), @@ -128,7 +128,7 @@ def test_one_one_fee_sapling(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -145,7 +145,7 @@ def test_one_one_fee_sapling(client: Client): ) -def test_version_group_id_missing(client: Client): +def test_version_group_id_missing(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -161,7 +161,7 @@ def test_version_group_id_missing(client: Client): with pytest.raises(TrezorFailure, match="Version group ID must be set."): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -170,7 +170,7 @@ def test_version_group_id_missing(client: Client): ) -def test_spend_old_versions(client: Client): +def test_spend_old_versions(session: Session): # NOTE: fake input tx used input_v1 = messages.TxInputType( @@ -210,9 +210,9 @@ def test_spend_old_versions(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", inputs, [output], @@ -229,7 +229,7 @@ def test_spend_old_versions(client: Client): @pytest.mark.models("core") -def test_external_presigned(client: Client): +def test_external_presigned(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -259,14 +259,14 @@ def test_external_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_e38206), @@ -289,7 +289,7 @@ def test_external_presigned(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1, inp2], [out1], diff --git a/tests/device_tests/cardano/test_address_public_key.py b/tests/device_tests/cardano/test_address_public_key.py index d7c02e6b6dc..95f42ecec13 100644 --- a/tests/device_tests/cardano/test_address_public_key.py +++ b/tests/device_tests/cardano/test_address_public_key.py @@ -22,7 +22,7 @@ get_public_key, parse_optional_bytes, ) -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import CardanoAddressType, CardanoDerivationType from trezorlib.tools import parse_path @@ -48,15 +48,15 @@ "cardano/get_base_address.derivations.json", ) @pytest.mark.parametrize("chunkify", (True, False)) -def test_cardano_get_address(client: Client, chunkify: bool, parameters, result): - client.init_device(new_session=True, derive_cardano=True) +def test_cardano_get_address(session: Session, chunkify: bool, parameters, result): + # session.init_device(new_session=True, derive_cardano=True) derivation_type = CardanoDerivationType.__members__[ parameters.get("derivation_type", "ICARUS_TREZOR") ] address = get_address( - client, + session, address_parameters=create_address_parameters( address_type=getattr( CardanoAddressType, parameters["address_type"].upper() @@ -94,17 +94,17 @@ def test_cardano_get_address(client: Client, chunkify: bool, parameters, result) "cardano/get_public_key.slip39.json", "cardano/get_public_key.derivations.json", ) -def test_cardano_get_public_key(client: Client, parameters, result): - with client: - IF = InputFlowShowXpubQRCode(client, passphrase=bool(client.ui.passphrase)) +def test_cardano_get_public_key(session: Session, parameters, result): + with session, session.client as client: + IF = InputFlowShowXpubQRCode(client, passphrase=False) client.set_input_flow(IF.get()) - client.init_device(new_session=True, derive_cardano=True) + # session.init_device(new_session=True, derive_cardano=True) derivation_type = CardanoDerivationType.__members__[ parameters.get("derivation_type", "ICARUS_TREZOR") ] key = get_public_key( - client, parse_path(parameters["path"]), derivation_type, show_display=True + session, parse_path(parameters["path"]), derivation_type, show_display=True ) assert key.node.public_key.hex() == result["public_key"] diff --git a/tests/device_tests/cardano/test_derivations.py b/tests/device_tests/cardano/test_derivations.py index 656c31a8bde..a7e01f01c47 100644 --- a/tests/device_tests/cardano/test_derivations.py +++ b/tests/device_tests/cardano/test_derivations.py @@ -17,7 +17,7 @@ import pytest from trezorlib.cardano import get_public_key -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import CardanoDerivationType as D from trezorlib.tools import parse_path @@ -33,28 +33,26 @@ ADDRESS_N = parse_path("m/1852h/1815h/0h") -def test_bad_session(client: Client): - client.init_device(new_session=True) +def test_bad_session(session: Session): with pytest.raises(TrezorFailure, match="not enabled"): - get_public_key(client, ADDRESS_N, derivation_type=D.ICARUS) + get_public_key(session, ADDRESS_N, derivation_type=D.ICARUS) - client.init_device(new_session=True, derive_cardano=False) - with pytest.raises(TrezorFailure, match="not enabled"): - get_public_key(client, ADDRESS_N, derivation_type=D.ICARUS) +def test_ledger_available_without_cardano(session: Session): + # session.init_device(new_session=True, derive_cardano=False) + get_public_key(session, ADDRESS_N, derivation_type=D.LEDGER) -def test_ledger_available_always(client: Client): - client.init_device(new_session=True, derive_cardano=False) - get_public_key(client, ADDRESS_N, derivation_type=D.LEDGER) - client.init_device(new_session=True, derive_cardano=True) - get_public_key(client, ADDRESS_N, derivation_type=D.LEDGER) +@pytest.mark.cardano +def test_ledger_available_with_cardano(session: Session): + # session.init_device(new_session=True, derive_cardano=True) + get_public_key(session, ADDRESS_N, derivation_type=D.LEDGER) @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) @pytest.mark.parametrize("derivation_type", D) # try ALL derivation types -def test_derivation_irrelevant_on_slip39(client: Client, derivation_type): - client.init_device(new_session=True, derive_cardano=False) - pubkey = get_public_key(client, ADDRESS_N, derivation_type=D.ICARUS) - test_pubkey = get_public_key(client, ADDRESS_N, derivation_type=derivation_type) +def test_derivation_irrelevant_on_slip39(session: Session, derivation_type): + # session.init_device(new_session=True, derive_cardano=False) + pubkey = get_public_key(session, ADDRESS_N, derivation_type=D.ICARUS) + test_pubkey = get_public_key(session, ADDRESS_N, derivation_type=derivation_type) assert pubkey == test_pubkey diff --git a/tests/device_tests/cardano/test_get_native_script_hash.py b/tests/device_tests/cardano/test_get_native_script_hash.py index 63ee56d16fb..2859d69a41a 100644 --- a/tests/device_tests/cardano/test_get_native_script_hash.py +++ b/tests/device_tests/cardano/test_get_native_script_hash.py @@ -18,7 +18,7 @@ from trezorlib import messages from trezorlib.cardano import get_native_script_hash, parse_native_script -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import parametrize_using_common_fixtures @@ -32,11 +32,9 @@ @parametrize_using_common_fixtures( "cardano/get_native_script_hash.json", ) -def test_cardano_get_native_script_hash(client: Client, parameters, result): - client.init_device(new_session=True, derive_cardano=True) - +def test_cardano_get_native_script_hash(session: Session, parameters, result): native_script_hash = get_native_script_hash( - client, + session, native_script=parse_native_script(parameters["native_script"]), display_format=messages.CardanoNativeScriptHashDisplayFormat.__members__[ parameters["display_format"] diff --git a/tests/device_tests/cardano/test_sign_tx.py b/tests/device_tests/cardano/test_sign_tx.py index ec78719c7f4..6a7eb8afba9 100644 --- a/tests/device_tests/cardano/test_sign_tx.py +++ b/tests/device_tests/cardano/test_sign_tx.py @@ -18,6 +18,7 @@ from trezorlib import cardano, device, messages from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure @@ -54,9 +55,9 @@ def show_details_input_flow(client: Client): "cardano/sign_tx.plutus.json", "cardano/sign_tx.slip39.json", ) -def test_cardano_sign_tx(client: Client, parameters, result): +def test_cardano_sign_tx(session: Session, parameters, result): response = call_sign_tx( - client, + session, parameters, input_flow=lambda client: InputFlowConfirmAllWarnings(client).get(), ) @@ -65,8 +66,8 @@ def test_cardano_sign_tx(client: Client, parameters, result): @pytest.mark.models(skip="mercury", reason="Not yet implemented in new UI") @parametrize_using_common_fixtures("cardano/sign_tx.show_details.json") -def test_cardano_sign_tx_show_details(client: Client, parameters, result): - response = call_sign_tx(client, parameters, show_details_input_flow, chunkify=True) +def test_cardano_sign_tx_show_details(session: Session, parameters, result): + response = call_sign_tx(session, parameters, show_details_input_flow, chunkify=True) assert response == _transform_expected_result(result) @@ -76,13 +77,13 @@ def test_cardano_sign_tx_show_details(client: Client, parameters, result): "cardano/sign_tx.multisig.failed.json", "cardano/sign_tx.plutus.failed.json", ) -def test_cardano_sign_tx_failed(client: Client, parameters, result): +def test_cardano_sign_tx_failed(session: Session, parameters, result): with pytest.raises(TrezorFailure, match=result["error_message"]): - call_sign_tx(client, parameters, None) + call_sign_tx(session, parameters, None) -def call_sign_tx(client: Client, parameters, input_flow=None, chunkify: bool = False): - client.init_device(new_session=True, derive_cardano=True) +def call_sign_tx(session: Session, parameters, input_flow=None, chunkify: bool = False): + # session.init_device(new_session=True, derive_cardano=True) signing_mode = messages.CardanoTxSigningMode.__members__[parameters["signing_mode"]] inputs = [cardano.parse_input(i) for i in parameters["inputs"]] @@ -113,18 +114,18 @@ def call_sign_tx(client: Client, parameters, input_flow=None, chunkify: bool = F if parameters.get("security_checks") == "prompt": device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) else: - device.apply_settings(client, safety_checks=messages.SafetyCheckLevel.Strict) + device.apply_settings(session, safety_checks=messages.SafetyCheckLevel.Strict) - with client: + with session.client as client: if input_flow is not None: client.watch_layout() client.set_input_flow(input_flow(client)) return cardano.sign_tx( - client=client, + session=session, signing_mode=signing_mode, inputs=inputs, outputs=outputs, diff --git a/tests/device_tests/eos/test_get_public_key.py b/tests/device_tests/eos/test_get_public_key.py index 1b518e95f2e..d99c54cb2b6 100644 --- a/tests/device_tests/eos/test_get_public_key.py +++ b/tests/device_tests/eos/test_get_public_key.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.eos import get_public_key from trezorlib.tools import parse_path @@ -28,12 +28,12 @@ @pytest.mark.eos @pytest.mark.models("t2t1") @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_eos_get_public_key(client: Client): - with client: +def test_eos_get_public_key(session: Session): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) public_key = get_public_key( - client, parse_path("m/44h/194h/0h/0/0"), show_display=True + session, parse_path("m/44h/194h/0h/0/0"), show_display=True ) assert ( public_key.wif_public_key @@ -43,7 +43,7 @@ def test_eos_get_public_key(client: Client): public_key.raw_public_key.hex() == "02015fabe197c955036bab25f4e7c16558f9f672f9f625314ab1ec8f64f7b1198e" ) - public_key = get_public_key(client, parse_path("m/44h/194h/0h/0/1")) + public_key = get_public_key(session, parse_path("m/44h/194h/0h/0/1")) assert ( public_key.wif_public_key == "EOS5d1VP15RKxT4dSakWu2TFuEgnmaGC2ckfSvQwND7pZC1tXkfLP" @@ -52,7 +52,7 @@ def test_eos_get_public_key(client: Client): public_key.raw_public_key.hex() == "02608bc2c431521dee0b9d5f2fe34053e15fc3b20d2895e0abda857b9ed8e77a78" ) - public_key = get_public_key(client, parse_path("m/44h/194h/1h/0/0")) + public_key = get_public_key(session, parse_path("m/44h/194h/1h/0/0")) assert ( public_key.wif_public_key == "EOS7UuNeTf13nfcG85rDB7AHGugZi4C4wJ4ft12QRotqNfxdV2NvP" diff --git a/tests/device_tests/eos/test_signtx.py b/tests/device_tests/eos/test_signtx.py index 57fd051bb4a..54ebece6a9a 100644 --- a/tests/device_tests/eos/test_signtx.py +++ b/tests/device_tests/eos/test_signtx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import eos -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import EosSignedTx from trezorlib.tools import parse_path @@ -35,7 +35,7 @@ @pytest.mark.parametrize("chunkify", (True, False)) -def test_eos_signtx_transfer_token(client: Client, chunkify: bool): +def test_eos_signtx_transfer_token(session: Session, chunkify: bool): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -60,8 +60,8 @@ def test_eos_signtx_transfer_token(client: Client, chunkify: bool): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID, chunkify=chunkify) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID, chunkify=chunkify) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -69,7 +69,7 @@ def test_eos_signtx_transfer_token(client: Client, chunkify: bool): ) -def test_eos_signtx_buyram(client: Client): +def test_eos_signtx_buyram(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -93,8 +93,8 @@ def test_eos_signtx_buyram(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -102,7 +102,7 @@ def test_eos_signtx_buyram(client: Client): ) -def test_eos_signtx_buyrambytes(client: Client): +def test_eos_signtx_buyrambytes(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -126,8 +126,8 @@ def test_eos_signtx_buyrambytes(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -135,7 +135,7 @@ def test_eos_signtx_buyrambytes(client: Client): ) -def test_eos_signtx_sellram(client: Client): +def test_eos_signtx_sellram(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -155,8 +155,8 @@ def test_eos_signtx_sellram(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -164,7 +164,7 @@ def test_eos_signtx_sellram(client: Client): ) -def test_eos_signtx_delegate(client: Client): +def test_eos_signtx_delegate(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -190,8 +190,8 @@ def test_eos_signtx_delegate(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -199,7 +199,7 @@ def test_eos_signtx_delegate(client: Client): ) -def test_eos_signtx_undelegate(client: Client): +def test_eos_signtx_undelegate(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -224,8 +224,8 @@ def test_eos_signtx_undelegate(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -233,7 +233,7 @@ def test_eos_signtx_undelegate(client: Client): ) -def test_eos_signtx_refund(client: Client): +def test_eos_signtx_refund(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -253,8 +253,8 @@ def test_eos_signtx_refund(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -262,7 +262,7 @@ def test_eos_signtx_refund(client: Client): ) -def test_eos_signtx_linkauth(client: Client): +def test_eos_signtx_linkauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -287,8 +287,8 @@ def test_eos_signtx_linkauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -296,7 +296,7 @@ def test_eos_signtx_linkauth(client: Client): ) -def test_eos_signtx_unlinkauth(client: Client): +def test_eos_signtx_unlinkauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -320,8 +320,8 @@ def test_eos_signtx_unlinkauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -329,7 +329,7 @@ def test_eos_signtx_unlinkauth(client: Client): ) -def test_eos_signtx_updateauth(client: Client): +def test_eos_signtx_updateauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -376,8 +376,8 @@ def test_eos_signtx_updateauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -385,7 +385,7 @@ def test_eos_signtx_updateauth(client: Client): ) -def test_eos_signtx_deleteauth(client: Client): +def test_eos_signtx_deleteauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -405,8 +405,8 @@ def test_eos_signtx_deleteauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -414,7 +414,7 @@ def test_eos_signtx_deleteauth(client: Client): ) -def test_eos_signtx_vote(client: Client): +def test_eos_signtx_vote(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -468,8 +468,8 @@ def test_eos_signtx_vote(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -477,7 +477,7 @@ def test_eos_signtx_vote(client: Client): ) -def test_eos_signtx_vote_proxy(client: Client): +def test_eos_signtx_vote_proxy(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -497,8 +497,8 @@ def test_eos_signtx_vote_proxy(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -506,7 +506,7 @@ def test_eos_signtx_vote_proxy(client: Client): ) -def test_eos_signtx_unknown(client: Client): +def test_eos_signtx_unknown(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -526,8 +526,8 @@ def test_eos_signtx_unknown(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -535,7 +535,7 @@ def test_eos_signtx_unknown(client: Client): ) -def test_eos_signtx_newaccount(client: Client): +def test_eos_signtx_newaccount(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -602,8 +602,8 @@ def test_eos_signtx_newaccount(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -611,7 +611,7 @@ def test_eos_signtx_newaccount(client: Client): ) -def test_eos_signtx_setcontract(client: Client): +def test_eos_signtx_setcontract(session: Session): transaction = { "expiration": "2018-06-19T13:29:53", "ref_block_num": 30587, @@ -638,8 +638,8 @@ def test_eos_signtx_setcontract(client: Client): "context_free_data": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature diff --git a/tests/device_tests/ethereum/test_definitions.py b/tests/device_tests/ethereum/test_definitions.py index 74d30bb3a99..edb7833c58d 100644 --- a/tests/device_tests/ethereum/test_definitions.py +++ b/tests/device_tests/ethereum/test_definitions.py @@ -5,7 +5,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -39,60 +39,60 @@ } -def test_builtin(client: Client) -> None: +def test_builtin(session: Session) -> None: # Ethereum (SLIP-44 60, chain_id 1) will sign without any definitions provided - ethereum.sign_tx(client, **DEFAULT_TX_PARAMS) + ethereum.sign_tx(session, **DEFAULT_TX_PARAMS) -def test_chain_id_allowed(client: Client) -> None: +def test_chain_id_allowed(session: Session) -> None: # Any chain id is allowed as long as the SLIP44 stays the same params = DEFAULT_TX_PARAMS.copy() params.update(chain_id=222222) - ethereum.sign_tx(client, **params) + ethereum.sign_tx(session, **params) -def test_slip44_disallowed(client: Client) -> None: +def test_slip44_disallowed(session: Session) -> None: # SLIP44 is not allowed without a valid network definition params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/66666h/0h/0/0")) with pytest.raises(TrezorFailure, match="Forbidden key path"): - ethereum.sign_tx(client, **params) + ethereum.sign_tx(session, **params) -def test_slip44_external(client: Client) -> None: +def test_slip44_external(session: Session) -> None: # to use a non-default SLIP44, a valid network definition must be provided network = common.encode_network(chain_id=66666, slip44=66666) params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/66666h/0h/0/0"), chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) -def test_slip44_external_disallowed(client: Client) -> None: +def test_slip44_external_disallowed(session: Session) -> None: # network definition does not allow a different SLIP44 network = common.encode_network(chain_id=66666, slip44=66666) params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/55555h/0h/0/0"), chain_id=66666) with pytest.raises(TrezorFailure, match="Forbidden key path"): - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) -def test_chain_id_mismatch(client: Client) -> None: +def test_chain_id_mismatch(session: Session) -> None: # network definition for a different chain id will be rejected network = common.encode_network(chain_id=66666, slip44=60) params = DEFAULT_TX_PARAMS.copy() params.update(chain_id=55555) with pytest.raises(TrezorFailure, match="Network definition mismatch"): - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) -def test_definition_does_not_override_builtin(client: Client) -> None: +def test_definition_does_not_override_builtin(session: Session) -> None: # The builtin definition for Ethereum (SLIP44 60, chain_id 1) will be used # even if a valid definition with a different SLIP44 is provided network = common.encode_network(chain_id=1, slip44=66666) params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/66666h/0h/0/0"), chain_id=1) with pytest.raises(TrezorFailure, match="Forbidden key path"): - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) # TODO: test that the builtin definition will not show different symbol @@ -101,77 +101,77 @@ def test_definition_does_not_override_builtin(client: Client) -> None: # all tokens are currently accepted, we would need to check the screenshots -def test_builtin_token(client: Client) -> None: +def test_builtin_token(session: Session) -> None: # The builtin definition for USDT (ERC20) will be used even if not provided params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_BUILTIN_TOKEN) - ethereum.sign_tx(client, **params) + ethereum.sign_tx(session, **params) # TODO check that USDT symbol is shown # TODO: test_builtin_token_not_overriden (builtin definition is used even if a custom one is provided) -def test_external_token(client: Client) -> None: +def test_external_token(session: Session) -> None: # A valid token definition must be provided to use a non-builtin token token = common.encode_token(address=ERC20_FAKE_ADDRESS, chain_id=1, decimals=8) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS) - ethereum.sign_tx(client, **params, definitions=common.make_defs(None, token)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(None, token)) # TODO check that FakeTok symbol is shown -def test_external_chain_without_token(client: Client) -> None: +def test_external_chain_without_token(session: Session) -> None: # when using an external chains, unknown tokens are allowed network = common.encode_network(chain_id=66666, slip44=60) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_BUILTIN_TOKEN, chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) # TODO check that UNKN token is used, FAKE network -def test_external_chain_token_ok(client: Client) -> None: +def test_external_chain_token_ok(session: Session) -> None: # when providing an external chain and matching token, everything works network = common.encode_network(chain_id=66666, slip44=60) token = common.encode_token(address=ERC20_FAKE_ADDRESS, chain_id=66666, decimals=8) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS, chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, token)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, token)) # TODO check that FakeTok is used, FAKE network -def test_external_chain_token_mismatch(client: Client) -> None: +def test_external_chain_token_mismatch(session: Session) -> None: # when providing external defs, we explicitly allow, but not use, tokens # from other chains network = common.encode_network(chain_id=66666, slip44=60) token = common.encode_token(address=ERC20_FAKE_ADDRESS, chain_id=55555, decimals=8) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS, chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, token)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, token)) # TODO check that UNKN is used for token, FAKE for network -def _call_getaddress(client: Client, slip44: int, network: bytes | None) -> None: +def _call_getaddress(session: Session, slip44: int, network: bytes | None) -> None: ethereum.get_address( - client, + session, parse_path(f"m/44h/{slip44}h/0h"), show_display=False, encoded_network=network, ) -def _call_signmessage(client: Client, slip44: int, network: bytes | None) -> None: +def _call_signmessage(session: Session, slip44: int, network: bytes | None) -> None: ethereum.sign_message( - client, + session, parse_path(f"m/44h/{slip44}h/0h"), b"hello", encoded_network=network, ) -def _call_sign_typed_data(client: Client, slip44: int, network: bytes | None) -> None: +def _call_sign_typed_data(session: Session, slip44: int, network: bytes | None) -> None: ethereum.sign_typed_data( - client, + session, parse_path(f"m/44h/{slip44}h/0h/0/0"), TYPED_DATA, metamask_v4_compat=True, @@ -180,10 +180,10 @@ def _call_sign_typed_data(client: Client, slip44: int, network: bytes | None) -> def _call_sign_typed_data_hash( - client: Client, slip44: int, network: bytes | None + session: Session, slip44: int, network: bytes | None ) -> None: ethereum.sign_typed_data_hash( - client, + session, parse_path(f"m/44h/{slip44}h/0h/0/0"), b"\x00" * 32, b"\xff" * 32, @@ -191,7 +191,7 @@ def _call_sign_typed_data_hash( ) -MethodType = Callable[[Client, int, "bytes | None"], None] +MethodType = Callable[[Session, int, "bytes | None"], None] METHODS = ( @@ -203,29 +203,29 @@ def _call_sign_typed_data_hash( @pytest.mark.parametrize("method", METHODS) -def test_method_builtin(client: Client, method: MethodType) -> None: +def test_method_builtin(session: Session, method: MethodType) -> None: # calling a method with a builtin slip44 will work - method(client, 60, None) + method(session, 60, None) @pytest.mark.parametrize("method", METHODS) -def test_method_def_missing(client: Client, method: MethodType) -> None: +def test_method_def_missing(session: Session, method: MethodType) -> None: # calling a method with a slip44 that has no definition will fail with pytest.raises(TrezorFailure, match="Forbidden key path"): - method(client, 66666, None) + method(session, 66666, None) @pytest.mark.parametrize("method", METHODS) -def test_method_external(client: Client, method: MethodType) -> None: +def test_method_external(session: Session, method: MethodType) -> None: # calling a method with a slip44 that has an external definition will work network = common.encode_network(slip44=66666) - method(client, 66666, network) + method(session, 66666, network) @pytest.mark.parametrize("method", METHODS) -def test_method_external_mismatch(client: Client, method: MethodType) -> None: +def test_method_external_mismatch(session: Session, method: MethodType) -> None: # calling a method with a slip44 that has an external definition that does not match # the slip44 will fail network = common.encode_network(slip44=77777) with pytest.raises(TrezorFailure, match="Network definition mismatch"): - method(client, 66666, network) + method(session, 66666, network) diff --git a/tests/device_tests/ethereum/test_definitions_bad.py b/tests/device_tests/ethereum/test_definitions_bad.py index 3f21195643f..ae917105ae9 100644 --- a/tests/device_tests/ethereum/test_definitions_bad.py +++ b/tests/device_tests/ethereum/test_definitions_bad.py @@ -5,7 +5,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import EthereumDefinitionType from trezorlib.tools import parse_path @@ -16,99 +16,99 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum] -def fails(client: Client, network: bytes, match: str) -> None: +def fails(session: Session, network: bytes, match: str) -> None: with pytest.raises(TrezorFailure, match=match): ethereum.get_address( - client, + session, parse_path("m/44h/666666h/0h"), show_display=False, encoded_network=network, ) -def test_short_message(client: Client) -> None: - fails(client, b"\x00", "Invalid Ethereum definition") +def test_short_message(session: Session) -> None: + fails(session, b"\x00", "Invalid Ethereum definition") -def test_mangled_signature(client: Client) -> None: +def test_mangled_signature(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) bad_signature = signature[:-1] + b"\xff" - fails(client, payload + proof + bad_signature, "Invalid definition signature") + fails(session, payload + proof + bad_signature, "Invalid definition signature") -def test_not_enough_signatures(client: Client) -> None: +def test_not_enough_signatures(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, [], threshold=1) - fails(client, payload + proof + signature, "Invalid definition signature") + fails(session, payload + proof + signature, "Invalid definition signature") -def test_missing_signature(client: Client) -> None: +def test_missing_signature(session: Session) -> None: payload = make_payload() proof, _ = sign_payload(payload, []) - fails(client, payload + proof, "Invalid Ethereum definition") + fails(session, payload + proof, "Invalid Ethereum definition") -def test_mangled_payload(client: Client) -> None: +def test_mangled_payload(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) bad_payload = payload[:-1] + b"\xff" - fails(client, bad_payload + proof + signature, "Invalid definition signature") + fails(session, bad_payload + proof + signature, "Invalid definition signature") -def test_proof_length_mismatch(client: Client) -> None: +def test_proof_length_mismatch(session: Session) -> None: payload = make_payload() _, signature = sign_payload(payload, []) bad_proof = b"\x01" - fails(client, payload + bad_proof + signature, "Invalid Ethereum definition") + fails(session, payload + bad_proof + signature, "Invalid Ethereum definition") -def test_bad_proof(client: Client) -> None: +def test_bad_proof(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, [sha256(b"x").digest()]) bad_proof = proof[:-1] + b"\xff" - fails(client, payload + bad_proof + signature, "Invalid definition signature") + fails(session, payload + bad_proof + signature, "Invalid definition signature") -def test_trimmed_proof(client: Client) -> None: +def test_trimmed_proof(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) bad_proof = proof[:-1] - fails(client, payload + bad_proof + signature, "Invalid Ethereum definition") + fails(session, payload + bad_proof + signature, "Invalid Ethereum definition") -def test_bad_prefix(client: Client) -> None: +def test_bad_prefix(session: Session) -> None: payload = make_payload() payload = b"trzd2" + payload[5:] proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Invalid Ethereum definition") + fails(session, payload + proof + signature, "Invalid Ethereum definition") -def test_bad_type(client: Client) -> None: +def test_bad_type(session: Session) -> None: # assuming we expect a network definition payload = make_payload(data_type=EthereumDefinitionType.TOKEN, message=make_token()) proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Definition type mismatch") + fails(session, payload + proof + signature, "Definition type mismatch") -def test_outdated(client: Client) -> None: +def test_outdated(session: Session) -> None: payload = make_payload(timestamp=0) proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Definition is outdated") + fails(session, payload + proof + signature, "Definition is outdated") -def test_malformed_protobuf(client: Client) -> None: +def test_malformed_protobuf(session: Session) -> None: payload = make_payload(message=b"\x00") proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Invalid Ethereum definition") + fails(session, payload + proof + signature, "Invalid Ethereum definition") -def test_protobuf_mismatch(client: Client) -> None: +def test_protobuf_mismatch(session: Session) -> None: payload = make_payload( data_type=EthereumDefinitionType.NETWORK, message=make_token() ) proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Invalid Ethereum definition") + fails(session, payload + proof + signature, "Invalid Ethereum definition") payload = make_payload( data_type=EthereumDefinitionType.TOKEN, message=make_network() @@ -119,13 +119,13 @@ def test_protobuf_mismatch(client: Client) -> None: params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS) ethereum.sign_tx( - client, + session, **params, definitions=make_defs(None, payload + proof + signature), ) -def test_trailing_garbage(client: Client) -> None: +def test_trailing_garbage(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature + b"\x00", "Invalid Ethereum definition") + fails(session, payload + proof + signature + b"\x00", "Invalid Ethereum definition") diff --git a/tests/device_tests/ethereum/test_getaddress.py b/tests/device_tests/ethereum/test_getaddress.py index 3add0ad92fb..b57fcd6afd3 100644 --- a/tests/device_tests/ethereum/test_getaddress.py +++ b/tests/device_tests/ethereum/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -27,21 +27,21 @@ @parametrize_using_common_fixtures("ethereum/getaddress.json") -def test_getaddress(client: Client, parameters, result): +def test_getaddress(session: Session, parameters, result): address_n = parse_path(parameters["path"]) assert ( - ethereum.get_address(client, address_n, show_display=True) == result["address"] + ethereum.get_address(session, address_n, show_display=True) == result["address"] ) @pytest.mark.models("core", reason="No input flow for T1") @parametrize_using_common_fixtures("ethereum/getaddress.json") -def test_getaddress_chunkify_details(client: Client, parameters, result): - with client: +def test_getaddress_chunkify_details(session: Session, parameters, result): + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address_n = parse_path(parameters["path"]) assert ( - ethereum.get_address(client, address_n, show_display=True, chunkify=True) + ethereum.get_address(session, address_n, show_display=True, chunkify=True) == result["address"] ) diff --git a/tests/device_tests/ethereum/test_getpublickey.py b/tests/device_tests/ethereum/test_getpublickey.py index 103b261f579..586abf736d7 100644 --- a/tests/device_tests/ethereum/test_getpublickey.py +++ b/tests/device_tests/ethereum/test_getpublickey.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -27,9 +27,9 @@ @parametrize_using_common_fixtures("ethereum/getpublickey.json") -def test_ethereum_getpublickey(client: Client, parameters, result): +def test_ethereum_getpublickey(session: Session, parameters, result): path = parse_path(parameters["path"]) - res = ethereum.get_public_node(client, path) + res = ethereum.get_public_node(session, path) assert res.node.depth == len(path) assert res.node.fingerprint == result["fingerprint"] assert res.node.child_num == result["child_num"] @@ -38,14 +38,14 @@ def test_ethereum_getpublickey(client: Client, parameters, result): assert res.xpub == result["xpub"] -def test_slip25_disallowed(client: Client): +def test_slip25_disallowed(session: Session): path = parse_path("m/10025'/60'/0'/0/0") with pytest.raises(TrezorFailure): - ethereum.get_public_node(client, path) + ethereum.get_public_node(session, path) @pytest.mark.models("legacy") -def test_legacy_restrictions(client: Client): +def test_legacy_restrictions(session: Session): path = parse_path("m/46'") with pytest.raises(TrezorFailure, match="Invalid path for EthereumGetPublicKey"): - ethereum.get_public_node(client, path) + ethereum.get_public_node(session, path) diff --git a/tests/device_tests/ethereum/test_sign_typed_data.py b/tests/device_tests/ethereum/test_sign_typed_data.py index 14dda4bdbe0..43b872af4b2 100644 --- a/tests/device_tests/ethereum/test_sign_typed_data.py +++ b/tests/device_tests/ethereum/test_sign_typed_data.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum, exceptions -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -28,11 +28,11 @@ @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_typed_data.json") -def test_ethereum_sign_typed_data(client: Client, parameters, result): - with client: +def test_ethereum_sign_typed_data(session: Session, parameters, result): + with session: address_n = parse_path(parameters["path"]) ret = ethereum.sign_typed_data( - client, + session, address_n, parameters["data"], metamask_v4_compat=parameters["metamask_v4_compat"], @@ -43,11 +43,11 @@ def test_ethereum_sign_typed_data(client: Client, parameters, result): @pytest.mark.models("legacy") @parametrize_using_common_fixtures("ethereum/sign_typed_data.json") -def test_ethereum_sign_typed_data_blind(client: Client, parameters, result): - with client: +def test_ethereum_sign_typed_data_blind(session: Session, parameters, result): + with session: address_n = parse_path(parameters["path"]) ret = ethereum.sign_typed_data_hash( - client, + session, address_n, ethereum.decode_hex(parameters["domain_separator_hash"]), # message hash is empty for domain-only hashes @@ -96,13 +96,13 @@ def test_ethereum_sign_typed_data_blind(client: Client, parameters, result): @pytest.mark.models("core", skip="mercury", reason="Not yet implemented in new UI") -def test_ethereum_sign_typed_data_show_more_button(client: Client): - with client: +def test_ethereum_sign_typed_data_show_more_button(session: Session): + with session.client as client: client.watch_layout() IF = InputFlowEIP712ShowMore(client) client.set_input_flow(IF.get()) ethereum.sign_typed_data( - client, + session, parse_path("m/44h/60h/0h/0/0"), DATA, metamask_v4_compat=True, @@ -110,13 +110,13 @@ def test_ethereum_sign_typed_data_show_more_button(client: Client): @pytest.mark.models("core") -def test_ethereum_sign_typed_data_cancel(client: Client): - with client, pytest.raises(exceptions.Cancelled): +def test_ethereum_sign_typed_data_cancel(session: Session): + with session.client as client, pytest.raises(exceptions.Cancelled): client.watch_layout() IF = InputFlowEIP712Cancel(client) client.set_input_flow(IF.get()) ethereum.sign_typed_data( - client, + session, parse_path("m/44h/60h/0h/0/0"), DATA, metamask_v4_compat=True, diff --git a/tests/device_tests/ethereum/test_sign_verify_message.py b/tests/device_tests/ethereum/test_sign_verify_message.py index 8cf2680ad82..7e50bd205ad 100644 --- a/tests/device_tests/ethereum/test_sign_verify_message.py +++ b/tests/device_tests/ethereum/test_sign_verify_message.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -26,18 +26,18 @@ @parametrize_using_common_fixtures("ethereum/signmessage.json") -def test_signmessage(client: Client, parameters, result): +def test_signmessage(session: Session, parameters, result): res = ethereum.sign_message( - client, parse_path(parameters["path"]), parameters["msg"] + session, parse_path(parameters["path"]), parameters["msg"] ) assert res.address == result["address"] assert res.signature.hex() == result["sig"] @parametrize_using_common_fixtures("ethereum/verifymessage.json") -def test_verify(client: Client, parameters, result): +def test_verify(session: Session, parameters, result): res = ethereum.verify_message( - client, + session, parameters["address"], bytes.fromhex(parameters["sig"]), parameters["msg"], @@ -45,7 +45,7 @@ def test_verify(client: Client, parameters, result): assert res is True -def test_verify_invalid(client: Client): +def test_verify_invalid(session: Session): # First vector from the verifymessage JSON fixture msg = "This is an example of a signed message." address = "0xEa53AF85525B1779eE99ece1a5560C0b78537C3b" @@ -54,7 +54,7 @@ def test_verify_invalid(client: Client): ) res = ethereum.verify_message( - client, + session, address, sig, msg, @@ -63,7 +63,7 @@ def test_verify_invalid(client: Client): # Changing the signature, expecting failure res = ethereum.verify_message( - client, + session, address, sig[:-1] + b"\x00", msg, @@ -72,7 +72,7 @@ def test_verify_invalid(client: Client): # Changing the message, expecting failure res = ethereum.verify_message( - client, + session, address, sig, msg + "abc", @@ -81,7 +81,7 @@ def test_verify_invalid(client: Client): # Changing the address, expecting failure res = ethereum.verify_message( - client, + session, address[:-1] + "a", sig, msg, diff --git a/tests/device_tests/ethereum/test_signtx.py b/tests/device_tests/ethereum/test_signtx.py index b12e7146cf9..cf2c6ec39d9 100644 --- a/tests/device_tests/ethereum/test_signtx.py +++ b/tests/device_tests/ethereum/test_signtx.py @@ -17,6 +17,7 @@ import pytest from trezorlib import ethereum, exceptions, messages, models +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import message_filters from trezorlib.exceptions import TrezorFailure @@ -55,23 +56,23 @@ def make_defs(parameters: dict) -> messages.EthereumDefinitions: "ethereum/sign_tx_eip155.json", ) @pytest.mark.parametrize("chunkify", (True, False)) -def test_signtx(client: Client, chunkify: bool, parameters: dict, result: dict): - _do_test_signtx(client, parameters, result, chunkify=chunkify) +def test_signtx(session: Session, chunkify: bool, parameters: dict, result: dict): + _do_test_signtx(session, parameters, result, chunkify=chunkify) def _do_test_signtx( - client: Client, + session: Session, parameters: dict, result: dict, input_flow=None, chunkify: bool = False, ): - with client: + with session.client as client: if input_flow: client.watch_layout() client.set_input_flow(input_flow) sig_v, sig_r, sig_s = ethereum.sign_tx( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), gas_price=int(parameters["gas_price"], 16), @@ -114,10 +115,10 @@ def _do_test_signtx( @pytest.mark.models("core", reason="T1 does not support input flows") -def test_signtx_fee_info(client: Client): - input_flow = InputFlowEthereumSignTxShowFeeInfo(client).get() +def test_signtx_fee_info(session: Session): + input_flow = InputFlowEthereumSignTxShowFeeInfo(session.client).get() _do_test_signtx( - client, + session, example_input_data["parameters"], example_input_data["result"], input_flow, @@ -129,10 +130,10 @@ def test_signtx_fee_info(client: Client): skip="mercury", reason="T1 does not support input flows; Mercury can't send Cancel on Summary", ) -def test_signtx_go_back_from_summary(client: Client): - input_flow = InputFlowEthereumSignTxGoBackFromSummary(client).get() +def test_signtx_go_back_from_summary(session: Session): + input_flow = InputFlowEthereumSignTxGoBackFromSummary(session.client).get() _do_test_signtx( - client, + session, example_input_data["parameters"], example_input_data["result"], input_flow, @@ -141,10 +142,12 @@ def test_signtx_go_back_from_summary(client: Client): @parametrize_using_common_fixtures("ethereum/sign_tx_eip1559.json") @pytest.mark.parametrize("chunkify", (True, False)) -def test_signtx_eip1559(client: Client, chunkify: bool, parameters: dict, result: dict): - with client: +def test_signtx_eip1559( + session: Session, chunkify: bool, parameters: dict, result: dict +): + with session: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), gas_limit=int(parameters["gas_limit"], 16), @@ -163,14 +166,14 @@ def test_signtx_eip1559(client: Client, chunkify: bool, parameters: dict, result assert sig_v == result["sig_v"] -def test_sanity_checks(client: Client): +def test_sanity_checks(session: Session): """Is not vectorized because these are internal-only tests that do not need to be exposed to the public. """ # contract creation without data should fail. with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=123_456, gas_price=20_000, @@ -183,7 +186,7 @@ def test_sanity_checks(client: Client): # gas overflow with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=123_456, gas_price=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, @@ -196,7 +199,7 @@ def test_sanity_checks(client: Client): # bad chain ID with pytest.raises(TrezorFailure, match=r"Chain ID out of bounds"): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=123_456, gas_price=20_000, @@ -207,13 +210,13 @@ def test_sanity_checks(client: Client): ) -def test_data_streaming(client: Client): +def test_data_streaming(session: Session): """Only verifying the expected responses, the signatures are checked in vectorized function above. """ - with client: - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + with session: + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), (is_t1, messages.ButtonRequest(code=messages.ButtonRequestType.SignTx)), @@ -251,7 +254,7 @@ def test_data_streaming(client: Client): ) ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=0, gas_price=20_000, @@ -263,11 +266,11 @@ def test_data_streaming(client: Client): ) -def test_signtx_eip1559_access_list(client: Client): - with client: +def test_signtx_eip1559_access_list(session: Session): + with session: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -302,11 +305,11 @@ def test_signtx_eip1559_access_list(client: Client): ) -def test_signtx_eip1559_access_list_larger(client: Client): - with client: +def test_signtx_eip1559_access_list_larger(session: Session): + with session: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -355,14 +358,14 @@ def test_signtx_eip1559_access_list_larger(client: Client): ) -def test_sanity_checks_eip1559(client: Client): +def test_sanity_checks_eip1559(session: Session): """Is not vectorized because these are internal-only tests that do not need to be exposed to the public. """ # contract creation without data should fail. with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -376,7 +379,7 @@ def test_sanity_checks_eip1559(client: Client): # max fee overflow with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, @@ -390,7 +393,7 @@ def test_sanity_checks_eip1559(client: Client): # priority fee overflow with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, @@ -404,7 +407,7 @@ def test_sanity_checks_eip1559(client: Client): # bad chain ID with pytest.raises(TrezorFailure, match=r"Chain ID out of bounds"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -435,10 +438,12 @@ def input_flow_data_go_back(client: Client, cancel: bool = False): "flow", (input_flow_data_skip, input_flow_data_scroll_down, input_flow_data_go_back) ) @pytest.mark.models("core", skip="mercury", reason="Not yet implemented in new UI") -def test_signtx_data_pagination(client: Client, flow): +def test_signtx_data_pagination(session: Session, flow): + raise Exception("TEST DOES NOT WORK - TODO FIX") + def _sign_tx_call(): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=0x0, gas_price=0x14, @@ -450,12 +455,12 @@ def _sign_tx_call(): data=bytes.fromhex(HEXDATA), ) - with client: + with session.client as client: client.watch_layout() client.set_input_flow(flow(client)) _sign_tx_call() - with client, pytest.raises(exceptions.Cancelled): + with session.client as client, pytest.raises(exceptions.Cancelled): client.watch_layout() client.set_input_flow(flow(client, cancel=True)) _sign_tx_call() @@ -464,20 +469,22 @@ def _sign_tx_call(): @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_tx_staking.json") @pytest.mark.parametrize("chunkify", (True, False)) -def test_signtx_staking(client: Client, chunkify: bool, parameters: dict, result: dict): - input_flow = InputFlowEthereumSignTxStaking(client).get() +def test_signtx_staking( + session: Session, chunkify: bool, parameters: dict, result: dict +): + input_flow = InputFlowEthereumSignTxStaking(session.client).get() _do_test_signtx( - client, parameters, result, input_flow=input_flow, chunkify=chunkify + session, parameters, result, input_flow=input_flow, chunkify=chunkify ) @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_tx_staking_data_error.json") -def test_signtx_staking_bad_inputs(client: Client, parameters: dict, result: dict): +def test_signtx_staking_bad_inputs(session: Session, parameters: dict, result: dict): # result not needed with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), gas_price=int(parameters["gas_price"], 16), @@ -494,10 +501,10 @@ def test_signtx_staking_bad_inputs(client: Client, parameters: dict, result: dic @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_tx_staking_eip1559.json") -def test_signtx_staking_eip1559(client: Client, parameters: dict, result: dict): - with client: +def test_signtx_staking_eip1559(session: Session, parameters: dict, result: dict): + with session: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), max_gas_fee=int(parameters["max_gas_fee"], 16), diff --git a/tests/device_tests/misc/test_msg_cipherkeyvalue.py b/tests/device_tests/misc/test_msg_cipherkeyvalue.py index 7a9fe664206..4efec7ab060 100644 --- a/tests/device_tests/misc/test_msg_cipherkeyvalue.py +++ b/tests/device_tests/misc/test_msg_cipherkeyvalue.py @@ -17,15 +17,15 @@ import pytest from trezorlib import misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_encrypt(client: Client): +def test_encrypt(session: Session): res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -35,7 +35,7 @@ def test_encrypt(client: Client): assert res.hex() == "676faf8f13272af601776bc31bc14e8f" res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -45,7 +45,7 @@ def test_encrypt(client: Client): assert res.hex() == "5aa0fbcb9d7fa669880745479d80c622" res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -55,7 +55,7 @@ def test_encrypt(client: Client): assert res.hex() == "958d4f63269b61044aaedc900c8d6208" res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -66,7 +66,7 @@ def test_encrypt(client: Client): # different key res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test2", b"testing message!", @@ -77,7 +77,7 @@ def test_encrypt(client: Client): # different message res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message! it is different", @@ -90,7 +90,7 @@ def test_encrypt(client: Client): # different path res = misc.encrypt_keyvalue( - client, + session, [0, 1, 3], "test", b"testing message!", @@ -101,9 +101,9 @@ def test_encrypt(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_decrypt(client: Client): +def test_decrypt(session: Session): res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("676faf8f13272af601776bc31bc14e8f"), @@ -113,7 +113,7 @@ def test_decrypt(client: Client): assert res == b"testing message!" res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("5aa0fbcb9d7fa669880745479d80c622"), @@ -123,7 +123,7 @@ def test_decrypt(client: Client): assert res == b"testing message!" res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("958d4f63269b61044aaedc900c8d6208"), @@ -133,7 +133,7 @@ def test_decrypt(client: Client): assert res == b"testing message!" res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("e0cf0eb0425947000eb546cc3994bc6c"), @@ -144,7 +144,7 @@ def test_decrypt(client: Client): # different key res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test2", bytes.fromhex("de247a6aa6be77a134bb3f3f925f13af"), @@ -155,7 +155,7 @@ def test_decrypt(client: Client): # different message res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex( @@ -168,7 +168,7 @@ def test_decrypt(client: Client): # different path res = misc.decrypt_keyvalue( - client, + session, [0, 1, 3], "test", bytes.fromhex("b4811a9d492f5355a5186ddbfccaae7b"), @@ -178,11 +178,11 @@ def test_decrypt(client: Client): assert res == b"testing message!" -def test_encrypt_badlen(client: Client): +def test_encrypt_badlen(session: Session): with pytest.raises(Exception): - misc.encrypt_keyvalue(client, [0, 1, 2], "test", b"testing") + misc.encrypt_keyvalue(session, [0, 1, 2], "test", b"testing") -def test_decrypt_badlen(client: Client): +def test_decrypt_badlen(session: Session): with pytest.raises(Exception): - misc.decrypt_keyvalue(client, [0, 1, 2], "test", b"testing") + misc.decrypt_keyvalue(session, [0, 1, 2], "test", b"testing") diff --git a/tests/device_tests/misc/test_msg_getecdhsessionkey.py b/tests/device_tests/misc/test_msg_getecdhsessionkey.py index 8c38f612b1d..d7c532dc5a0 100644 --- a/tests/device_tests/misc/test_msg_getecdhsessionkey.py +++ b/tests/device_tests/misc/test_msg_getecdhsessionkey.py @@ -17,13 +17,13 @@ import pytest from trezorlib import messages, misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_ecdh(client: Client): +def test_ecdh(session: Session): identity = messages.IdentityType( proto="gpg", user="", @@ -37,7 +37,7 @@ def test_ecdh(client: Client): "0407f2c6e5becf3213c1d07df0cfbe8e39f70a8c643df7575e5c56859ec52c45ca950499c019719dae0fda04248d851e52cf9d66eeb211d89a77be40de22b6c89d" ) result = misc.get_ecdh_session_key( - client, + session, identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name="secp256k1", @@ -55,7 +55,7 @@ def test_ecdh(client: Client): "04811a6c2bd2a547d0dd84747297fec47719e7c3f9b0024f027c2b237be99aac39a9230acbd163d0cb1524a0f5ea4bfed6058cec6f18368f72a12aa0c4d083ff64" ) result = misc.get_ecdh_session_key( - client, + session, identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name="nist256p1", @@ -73,7 +73,7 @@ def test_ecdh(client: Client): "40a8cf4b6a64c4314e80f15a8ea55812bd735fbb365936a48b2d78807b575fa17a" ) result = misc.get_ecdh_session_key( - client, + session, identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name="curve25519", diff --git a/tests/device_tests/misc/test_msg_getentropy.py b/tests/device_tests/misc/test_msg_getentropy.py index 593fb1a76c1..d5d19425f9b 100644 --- a/tests/device_tests/misc/test_msg_getentropy.py +++ b/tests/device_tests/misc/test_msg_getentropy.py @@ -20,7 +20,7 @@ from trezorlib import messages as m from trezorlib import misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session ENTROPY_LENGTHS_POW2 = [2**l for l in range(10)] ENTROPY_LENGTHS_POW2_1 = [2**l + 1 for l in range(10)] @@ -40,11 +40,11 @@ def entropy(data): @pytest.mark.parametrize("entropy_length", ENTROPY_LENGTHS) -def test_entropy(client: Client, entropy_length): - with client: - client.set_expected_responses( +def test_entropy(session: Session, entropy_length): + with session: + session.set_expected_responses( [m.ButtonRequest(code=m.ButtonRequestType.ProtectCall), m.Entropy] ) - ent = misc.get_entropy(client, entropy_length) + ent = misc.get_entropy(session, entropy_length) assert len(ent) == entropy_length print(f"{entropy_length} bytes: entropy = {entropy(ent)}") diff --git a/tests/device_tests/misc/test_msg_signidentity.py b/tests/device_tests/misc/test_msg_signidentity.py index bc9e7f5bd4e..6715387d387 100644 --- a/tests/device_tests/misc/test_msg_signidentity.py +++ b/tests/device_tests/misc/test_msg_signidentity.py @@ -17,13 +17,13 @@ import pytest from trezorlib import messages, misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_sign(client: Client): +def test_sign(session: Session): hidden = bytes.fromhex( "cd8552569d6e4509266ef137584d1e62c7579b5b8ed69bbafa4b864c6521e7c2" ) @@ -40,7 +40,7 @@ def test_sign(client: Client): path="/login", index=0, ) - sig = misc.sign_identity(client, identity, hidden, visual) + sig = misc.sign_identity(session, identity, hidden, visual) assert sig.address == "17F17smBTX9VTZA9Mj8LM5QGYNZnmziCjL" assert ( sig.public_key.hex() @@ -62,7 +62,7 @@ def test_sign(client: Client): path="/pub", index=3, ) - sig = misc.sign_identity(client, identity, hidden, visual) + sig = misc.sign_identity(session, identity, hidden, visual) assert sig.address == "1KAr6r5qF2kADL8bAaRQBjGKYEGxn9WrbS" assert ( sig.public_key.hex() @@ -80,7 +80,7 @@ def test_sign(client: Client): proto="ssh", user="satoshi", host="bitcoin.org", port="", path="", index=47 ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="nist256p1" + session, identity, hidden, visual, ecdsa_curve_name="nist256p1" ) assert sig.address is None assert ( @@ -99,7 +99,7 @@ def test_sign(client: Client): proto="ssh", user="satoshi", host="bitcoin.org", port="", path="", index=47 ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="ed25519" + session, identity, hidden, visual, ecdsa_curve_name="ed25519" ) assert sig.address is None assert ( @@ -116,7 +116,7 @@ def test_sign(client: Client): proto="gpg", user="satoshi", host="bitcoin.org", port="", path="" ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="ed25519" + session, identity, hidden, visual, ecdsa_curve_name="ed25519" ) assert sig.address is None assert ( @@ -133,7 +133,7 @@ def test_sign(client: Client): proto="signify", user="satoshi", host="bitcoin.org", port="", path="" ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="ed25519" + session, identity, hidden, visual, ecdsa_curve_name="ed25519" ) assert sig.address is None assert ( diff --git a/tests/device_tests/monero/test_getaddress.py b/tests/device_tests/monero/test_getaddress.py index dfd0ce5ab09..1a6d3ffc01c 100644 --- a/tests/device_tests/monero/test_getaddress.py +++ b/tests/device_tests/monero/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import monero -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -47,19 +47,19 @@ @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) -def test_monero_getaddress(client: Client, path: str, expected_address: bytes): - address = monero.get_address(client, parse_path(path), show_display=True) +def test_monero_getaddress(session: Session, path: str, expected_address: bytes): + address = monero.get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) def test_monero_getaddress_chunkify_details( - client: Client, path: str, expected_address: bytes + session: Session, path: str, expected_address: bytes ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = monero.get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address diff --git a/tests/device_tests/monero/test_getwatchkey.py b/tests/device_tests/monero/test_getwatchkey.py index eee83d0445f..30e3d7b1140 100644 --- a/tests/device_tests/monero/test_getwatchkey.py +++ b/tests/device_tests/monero/test_getwatchkey.py @@ -17,7 +17,7 @@ import pytest from trezorlib import monero -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -27,8 +27,8 @@ @pytest.mark.monero @pytest.mark.models("core") @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_monero_getwatchkey(client: Client): - res = monero.get_watch_key(client, parse_path("m/44h/128h/0h")) +def test_monero_getwatchkey(session: Session): + res = monero.get_watch_key(session, parse_path("m/44h/128h/0h")) assert ( res.address == b"4Ahp23WfMrMFK3wYL2hLWQFGt87ZTeRkufS6JoQZu6MEFDokAQeGWmu9MA3GFq1yVLSJQbKJqVAn9F9DLYGpRzRAEXqAXKM" @@ -37,7 +37,7 @@ def test_monero_getwatchkey(client: Client): res.watch_key.hex() == "8722520a581e2a50cc1adab4a1692401effd37b0d63b9d9b60fd7f34ea2b950e" ) - res = monero.get_watch_key(client, parse_path("m/44h/128h/1h")) + res = monero.get_watch_key(session, parse_path("m/44h/128h/1h")) assert ( res.address == b"44iAazhoAkv5a5RqLNVyh82a1n3ceNggmN4Ho7bUBJ14WkEVR8uFTe9f7v5rNnJ2kEbVXxfXiRzsD5Jtc6NvBi4D6WNHPie" @@ -46,7 +46,7 @@ def test_monero_getwatchkey(client: Client): res.watch_key.hex() == "1f70b7d9e86c11b7a5bee883b75c43d6be189c8f812726ea1ecd94b06bb7db04" ) - res = monero.get_watch_key(client, parse_path("m/44h/128h/2h")) + res = monero.get_watch_key(session, parse_path("m/44h/128h/2h")) assert ( res.address == b"47ejhmbZ4wHUhXaqA4b7PN667oPMkokf4ZkNdWrMSPy9TNaLVr7vLqVUQHh2MnmaAEiyrvLsX8xUf99q3j1iAeMV8YvSFcH" diff --git a/tests/device_tests/nem/test_getaddress.py b/tests/device_tests/nem/test_getaddress.py index b2b20c529ec..920dd974904 100644 --- a/tests/device_tests/nem/test_getaddress.py +++ b/tests/device_tests/nem/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -28,10 +28,10 @@ @pytest.mark.models("t1b1", "t2t1") @pytest.mark.setup_client(mnemonic=MNEMONIC12) @pytest.mark.parametrize("chunkify", (True, False)) -def test_nem_getaddress(client: Client, chunkify: bool): +def test_nem_getaddress(session: Session, chunkify: bool): assert ( nem.get_address( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), 0x68, show_display=True, @@ -41,7 +41,7 @@ def test_nem_getaddress(client: Client, chunkify: bool): ) assert ( nem.get_address( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), 0x98, show_display=True, diff --git a/tests/device_tests/nem/test_signtx_mosaics.py b/tests/device_tests/nem/test_signtx_mosaics.py index 51cfd556a77..3e6b835f953 100644 --- a/tests/device_tests/nem/test_signtx_mosaics.py +++ b/tests/device_tests/nem/test_signtx_mosaics.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -32,9 +32,9 @@ ] -def test_nem_signtx_mosaic_supply_change(client: Client): +def test_nem_signtx_mosaic_supply_change(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, @@ -61,9 +61,9 @@ def test_nem_signtx_mosaic_supply_change(client: Client): ) -def test_nem_signtx_mosaic_creation(client: Client): +def test_nem_signtx_mosaic_creation(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, @@ -93,9 +93,9 @@ def test_nem_signtx_mosaic_creation(client: Client): ) -def test_nem_signtx_mosaic_creation_properties(client: Client): +def test_nem_signtx_mosaic_creation_properties(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, @@ -130,9 +130,9 @@ def test_nem_signtx_mosaic_creation_properties(client: Client): ) -def test_nem_signtx_mosaic_creation_levy(client: Client): +def test_nem_signtx_mosaic_creation_levy(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, diff --git a/tests/device_tests/nem/test_signtx_multisig.py b/tests/device_tests/nem/test_signtx_multisig.py index d153547c424..ef641e52f39 100644 --- a/tests/device_tests/nem/test_signtx_multisig.py +++ b/tests/device_tests/nem/test_signtx_multisig.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -31,9 +31,9 @@ # assertion data from T1 -def test_nem_signtx_aggregate_modification(client: Client): +def test_nem_signtx_aggregate_modification(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -61,9 +61,9 @@ def test_nem_signtx_aggregate_modification(client: Client): ) -def test_nem_signtx_multisig(client: Client): +def test_nem_signtx_multisig(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 1, @@ -98,7 +98,7 @@ def test_nem_signtx_multisig(client: Client): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -132,9 +132,9 @@ def test_nem_signtx_multisig(client: Client): ) -def test_nem_signtx_multisig_signer(client: Client): +def test_nem_signtx_multisig_signer(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 333, @@ -169,7 +169,7 @@ def test_nem_signtx_multisig_signer(client: Client): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 900000, diff --git a/tests/device_tests/nem/test_signtx_others.py b/tests/device_tests/nem/test_signtx_others.py index f775c60cdf6..9760d8c5235 100644 --- a/tests/device_tests/nem/test_signtx_others.py +++ b/tests/device_tests/nem/test_signtx_others.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -31,10 +31,10 @@ # assertion data from T1 -def test_nem_signtx_importance_transfer(client: Client): - with client: +def test_nem_signtx_importance_transfer(session: Session): + with session: tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 12349215, @@ -60,9 +60,9 @@ def test_nem_signtx_importance_transfer(client: Client): ) -def test_nem_signtx_provision_namespace(client: Client): +def test_nem_signtx_provision_namespace(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, diff --git a/tests/device_tests/nem/test_signtx_transfers.py b/tests/device_tests/nem/test_signtx_transfers.py index 0388b30ffb4..2df62b55936 100644 --- a/tests/device_tests/nem/test_signtx_transfers.py +++ b/tests/device_tests/nem/test_signtx_transfers.py @@ -17,7 +17,7 @@ import pytest from trezorlib import messages, nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12, is_core @@ -32,16 +32,16 @@ # assertion data from T1 @pytest.mark.parametrize("chunkify", (True, False)) -def test_nem_signtx_simple(client: Client, chunkify: bool): - with client: - client.set_expected_responses( +def test_nem_signtx_simple(session: Session, chunkify: bool): + with session: + session.set_expected_responses( [ # Confirm transfer and network fee messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), # Unencrypted message messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ( - is_core(client), + is_core(session), messages.ButtonRequest( code=messages.ButtonRequestType.ConfirmOutput ), @@ -53,7 +53,7 @@ def test_nem_signtx_simple(client: Client, chunkify: bool): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -82,16 +82,16 @@ def test_nem_signtx_simple(client: Client, chunkify: bool): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_encrypted_payload(client: Client): - with client: - client.set_expected_responses( +def test_nem_signtx_encrypted_payload(session: Session): + with session: + session.set_expected_responses( [ # Confirm transfer and network fee messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), # Ask for encryption messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ( - is_core(client), + is_core(session), messages.ButtonRequest( code=messages.ButtonRequestType.ConfirmOutput ), @@ -103,7 +103,7 @@ def test_nem_signtx_encrypted_payload(client: Client): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -134,9 +134,9 @@ def test_nem_signtx_encrypted_payload(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_xem_as_mosaic(client: Client): +def test_nem_signtx_xem_as_mosaic(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -168,9 +168,9 @@ def test_nem_signtx_xem_as_mosaic(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_unknown_mosaic(client: Client): +def test_nem_signtx_unknown_mosaic(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -202,9 +202,9 @@ def test_nem_signtx_unknown_mosaic(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_known_mosaic(client: Client): +def test_nem_signtx_known_mosaic(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -236,9 +236,9 @@ def test_nem_signtx_known_mosaic(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_known_mosaic_with_levy(client: Client): +def test_nem_signtx_known_mosaic_with_levy(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -270,9 +270,9 @@ def test_nem_signtx_known_mosaic_with_levy(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_multiple_mosaics(client: Client): +def test_nem_signtx_multiple_mosaics(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py index 416fef78eac..084c65059d5 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py @@ -19,7 +19,7 @@ import pytest from trezorlib import device, exceptions, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 from ...input_flows import ( @@ -28,9 +28,9 @@ ) -def do_recover_legacy(client: Client, mnemonic: list[str]): +def do_recover_legacy(session: Session, mnemonic: list[str]): def input_callback(_): - word, pos = client.debug.read_recovery_word() + word, pos = session.client.debug.read_recovery_word() if pos != 0 and pos is not None: word = mnemonic[pos - 1] mnemonic[pos - 1] = None @@ -39,7 +39,7 @@ def input_callback(_): return word ret = device.recover( - client, + session, type=messages.RecoveryType.DryRun, word_count=len(mnemonic), input_method=messages.RecoveryDeviceInputMethod.ScrambledWords, @@ -50,58 +50,58 @@ def input_callback(_): return ret -def do_recover_core(client: Client, mnemonic: list[str], mismatch: bool = False): - with client: +def do_recover_core(session: Session, mnemonic: list[str], mismatch: bool = False): + with session.client as client: client.watch_layout() IF = InputFlowBip39RecoveryDryRun(client, mnemonic, mismatch=mismatch) client.set_input_flow(IF.get()) - return device.recover(client, type=messages.RecoveryType.DryRun) + return device.recover(session, type=messages.RecoveryType.DryRun) -def do_recover(client: Client, mnemonic: list[str], mismatch: bool = False): - if client.model is models.T1B1: - return do_recover_legacy(client, mnemonic) +def do_recover(session: Session, mnemonic: list[str], mismatch: bool = False): + if session.model is models.T1B1: + return do_recover_legacy(session, mnemonic) else: - return do_recover_core(client, mnemonic, mismatch) + return do_recover_core(session, mnemonic, mismatch) @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_dry_run(client: Client): - ret = do_recover(client, MNEMONIC12.split(" ")) +def test_dry_run(session: Session): + ret = do_recover(session, MNEMONIC12.split(" ")) assert isinstance(ret, messages.Success) @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_seed_mismatch(client: Client): +def test_seed_mismatch(session: Session): with pytest.raises( exceptions.TrezorFailure, match="does not match the one in the device" ): - do_recover(client, ["all"] * 12, mismatch=True) + do_recover(session, ["all"] * 12, mismatch=True) @pytest.mark.models("legacy") -def test_invalid_seed_t1(client: Client): +def test_invalid_seed_t1(session: Session): with pytest.raises(exceptions.TrezorFailure, match="Invalid seed"): - do_recover(client, ["stick"] * 12) + do_recover(session, ["stick"] * 12) @pytest.mark.models("core") -def test_invalid_seed_core(client: Client): - with client: +def test_invalid_seed_core(session: Session): + with session.client as client: client.watch_layout() IF = InputFlowBip39RecoveryDryRunInvalid(client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): return device.recover( - client, + session, type=messages.RecoveryType.DryRun, ) @pytest.mark.setup_client(uninitialized=True) -def test_uninitialized(client: Client): +def test_uninitialized(session: Session): with pytest.raises(exceptions.TrezorFailure, match="not initialized"): - do_recover(client, ["all"] * 12) + do_recover(session, ["all"] * 12) DRY_RUN_ALLOWED_FIELDS = ( @@ -140,7 +140,7 @@ def _make_bad_params(): @pytest.mark.parametrize("field_name, field_value", _make_bad_params()) -def test_bad_parameters(client: Client, field_name: str, field_value: Any): +def test_bad_parameters(session: Session, field_name: str, field_value: Any): msg = messages.RecoveryDevice( type=messages.RecoveryType.DryRun, word_count=12, @@ -152,4 +152,4 @@ def test_bad_parameters(client: Client, field_name: str, field_value: Any): exceptions.TrezorFailure, match="Forbidden field set in dry-run", ): - client.call(msg) + session.call(msg) diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py b/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py index 4f2eab6147b..7ddc634b8d4 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -29,9 +29,10 @@ @pytest.mark.setup_client(uninitialized=True) -def test_pin_passphrase(client: Client): +def test_pin_passphrase(session: Session): + debug = session.client.debug mnemonic = MNEMONIC12.split(" ") - ret = client.call_raw( + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=True, @@ -43,30 +44,30 @@ def test_pin_passphrase(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin(PIN6) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN6) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time - pin_encoded = client.debug.encode_pin(PIN6) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN6) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) fakes = 0 for _ in range(int(12 * 2)): assert isinstance(ret, messages.WordRequest) - (word, pos) = client.debug.read_recovery_word() + (word, pos) = debug.read_recovery_word() if pos != 0: - ret = client.call_raw(messages.WordAck(word=mnemonic[pos - 1])) + ret = session.call_raw(messages.WordAck(word=mnemonic[pos - 1])) mnemonic[pos - 1] = None else: - ret = client.call_raw(messages.WordAck(word=word)) + ret = session.call_raw(messages.WordAck(word=word)) fakes += 1 # Workflow succesfully ended @@ -76,23 +77,26 @@ def test_pin_passphrase(client: Client): assert fakes == 12 assert mnemonic == [None] * 12 + raise Exception("TEST IS USING INIT MESSAGE - TODO CHANGE") # Mnemonic is the same - client.init_device() - assert client.debug.state().mnemonic_secret == MNEMONIC12.encode() + session.init_device() + assert debug.state().mnemonic_secret == MNEMONIC12.encode() - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True + assert session.features.pin_protection is True + assert session.features.passphrase_protection is True # Do passphrase-protected action, PassphraseRequest should be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.PassphraseRequest) - client.call_raw(messages.Cancel()) + session.call_raw(messages.Cancel()) @pytest.mark.setup_client(uninitialized=True) -def test_nopin_nopassphrase(client: Client): +def test_nopin_nopassphrase(session: Session): mnemonic = MNEMONIC12.split(" ") - ret = client.call_raw( + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=False, @@ -104,19 +108,20 @@ def test_nopin_nopassphrase(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug = session.client.debug + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) fakes = 0 for _ in range(int(12 * 2)): assert isinstance(ret, messages.WordRequest) - (word, pos) = client.debug.read_recovery_word() + (word, pos) = debug.read_recovery_word() if pos != 0: - ret = client.call_raw(messages.WordAck(word=mnemonic[pos - 1])) + ret = session.call_raw(messages.WordAck(word=mnemonic[pos - 1])) mnemonic[pos - 1] = None else: - ret = client.call_raw(messages.WordAck(word=word)) + ret = session.call_raw(messages.WordAck(word=word)) fakes += 1 # Workflow succesfully ended @@ -126,21 +131,26 @@ def test_nopin_nopassphrase(client: Client): assert fakes == 12 assert mnemonic == [None] * 12 + raise Exception("TEST IS USING INIT MESSAGE - TODO CHANGE") + # Mnemonic is the same - client.init_device() - assert client.debug.state().mnemonic_secret == MNEMONIC12.encode() + # session.init_device() + assert debug.state().mnemonic_secret == MNEMONIC12.encode() - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False # Do pin & passphrase-protected action, PassphraseRequest should NOT be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.Address) @pytest.mark.setup_client(uninitialized=True) -def test_word_fail(client: Client): - ret = client.call_raw( +def test_word_fail(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=False, @@ -152,23 +162,24 @@ def test_word_fail(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.WordRequest) for _ in range(int(12 * 2)): - (word, pos) = client.debug.read_recovery_word() + (word, pos) = debug.read_recovery_word() if pos != 0: - ret = client.call_raw(messages.WordAck(word="kwyjibo")) + ret = session.call_raw(messages.WordAck(word="kwyjibo")) assert isinstance(ret, messages.Failure) break else: - client.call_raw(messages.WordAck(word=word)) + session.call_raw(messages.WordAck(word=word)) @pytest.mark.setup_client(uninitialized=True) -def test_pin_fail(client: Client): - ret = client.call_raw( +def test_pin_fail(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=True, @@ -180,36 +191,36 @@ def test_pin_fail(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin(PIN4) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN4) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time, but different one - pin_encoded = client.debug.encode_pin(PIN6) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN6) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) # Failure should be raised assert isinstance(ret, messages.Failure) -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(RuntimeError): device.recover( - client, + session, word_count=12, pin_protection=False, passphrase_protection=False, label="label", - input_callback=client.mnemonic_callback, + input_callback=session.client.mnemonic_callback, ) - ret = client.call_raw( + ret = session.call_raw( messages.RecoveryDevice( word_count=12, input_method=messages.RecoveryDeviceInputMethod.ScrambledWords, diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py b/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py index 6046e85ca78..fe1c2e52733 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, exceptions, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 from ...input_flows import InputFlowBip39Recovery @@ -26,47 +26,47 @@ @pytest.mark.setup_client(uninitialized=True) -def test_tt_pin_passphrase(client: Client): - with client: +def test_tt_pin_passphrase(session: Session): + with session.client as client: IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" "), pin="654") client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=True, passphrase_protection=True, label="hello", ) - assert client.debug.state().mnemonic_secret.decode() == MNEMONIC12 + assert session.client.debug.state().mnemonic_secret.decode() == MNEMONIC12 - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True - assert client.features.backup_type is messages.BackupType.Bip39 - assert client.features.label == "hello" + assert session.features.pin_protection is True + assert session.features.passphrase_protection is True + assert session.features.backup_type is messages.BackupType.Bip39 + assert session.features.label == "hello" @pytest.mark.setup_client(uninitialized=True) -def test_tt_nopin_nopassphrase(client: Client): - with client: +def test_tt_nopin_nopassphrase(session: Session): + with session.client as client: IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" ")) client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=False, passphrase_protection=False, label="hello", ) - assert client.debug.state().mnemonic_secret.decode() == MNEMONIC12 - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Bip39 - assert client.features.label == "hello" + assert session.client.debug.state().mnemonic_secret.decode() == MNEMONIC12 + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is messages.BackupType.Bip39 + assert session.features.label == "hello" -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(RuntimeError): - device.recover(client) + device.recover(session) with pytest.raises(exceptions.TrezorFailure, match="Already initialized"): - client.call(messages.RecoveryDevice()) + session.call(messages.RecoveryDevice()) diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py index fa181117357..80139dd2c41 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, exceptions, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC_SLIP39_ADVANCED_20, MNEMONIC_SLIP39_ADVANCED_33 from ...input_flows import ( @@ -46,13 +46,13 @@ # To allow reusing functionality for multiple tests def _test_secret( - client: Client, shares: list[str], secret: str, click_info: bool = False + session: Session, shares: list[str], secret: str, click_info: bool = False ): - with client: + with session.client as client: IF = InputFlowSlip39AdvancedRecovery(client, shares, click_info=click_info) client.set_input_flow(IF.get()) ret = device.recover( - client, + session, pin_protection=False, passphrase_protection=False, label="label", @@ -60,86 +60,86 @@ def _test_secret( # Workflow succesfully ended assert ret == messages.Success(message="Device recovered") - assert client.features.initialized is True - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Slip39_Advanced - assert client.debug.state().mnemonic_secret.hex() == secret + assert session.features.initialized is True + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is messages.BackupType.Slip39_Advanced + assert session.client.debug.state().mnemonic_secret.hex() == secret @pytest.mark.parametrize("shares, secret", VECTORS) @pytest.mark.setup_client(uninitialized=True) -def test_secret(client: Client, shares: list[str], secret: str): - _test_secret(client, shares, secret) +def test_secret(session: Session, shares: list[str], secret: str): + _test_secret(session, shares, secret) @pytest.mark.parametrize("shares, secret", VECTORS) @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models(skip="safe3", reason="safe3 does not have info button") -def test_secret_click_info_button(client: Client, shares: list[str], secret: str): - _test_secret(client, shares, secret, click_info=True) +def test_secret_click_info_button(session: Session, shares: list[str], secret: str): + _test_secret(session, shares, secret, click_info=True) @pytest.mark.setup_client(uninitialized=True) -def test_extra_share_entered(client: Client): +def test_extra_share_entered(session: Session): _test_secret( - client, + session, shares=EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20, secret=VECTORS[0][1], ) @pytest.mark.setup_client(uninitialized=True) -def test_abort(client: Client): - with client: +def test_abort(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryAbort(client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is False @pytest.mark.setup_client(uninitialized=True) -def test_noabort(client: Client): - with client: +def test_noabort(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryNoAbort( client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 ) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is True + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is True @pytest.mark.setup_client(uninitialized=True) -def test_same_share(client: Client): +def test_same_share(session: Session): # we choose the second share from the fixture because # the 1st is 1of1 and group threshold condition is reached first first_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ") # second share is first 4 words of first second_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")[:4] - with client: + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryShareAlreadyEntered( client, first_share, second_share ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) -def test_group_threshold_reached(client: Client): +def test_group_threshold_reached(session: Session): # first share in the fixture is 1of1 so we choose that first_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ") # second share is first 3 words of first second_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")[:3] - with client: + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryThresholdReached( client, first_share, second_share ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py index 73e18a8686c..136be18bb6c 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...common import MNEMONIC_SLIP39_ADVANCED_20 @@ -39,14 +39,14 @@ @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20, passphrase=False) -def test_2of3_dryrun(client: Client): - with client: +def test_2of3_dryrun(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryDryRun( client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 ) client.set_input_flow(IF.get()) ret = device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", @@ -60,9 +60,9 @@ def test_2of3_dryrun(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20) -def test_2of3_invalid_seed_dryrun(client: Client): +def test_2of3_invalid_seed_dryrun(session: Session): # test fails because of different seed on device - with client, pytest.raises( + with session.client as client, pytest.raises( TrezorFailure, match=r"The seed does not match the one in the device" ): IF = InputFlowSlip39AdvancedRecoveryDryRun( @@ -70,7 +70,7 @@ def test_2of3_invalid_seed_dryrun(client: Client): ) client.set_input_flow(IF.get()) device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py b/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py index 3f7ed75e730..c5db44e91b1 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, exceptions, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import ( MNEMONIC_SLIP39_BASIC_20_3of6, @@ -70,32 +70,32 @@ @pytest.mark.setup_client(uninitialized=True) @pytest.mark.parametrize("shares, secret, backup_type", VECTORS) def test_secret( - client: Client, shares: list[str], secret: str, backup_type: messages.BackupType + session: Session, shares: list[str], secret: str, backup_type: messages.BackupType ): - with client: + with session.client as client: IF = InputFlowSlip39BasicRecovery(client, shares) client.set_input_flow(IF.get()) - ret = device.recover(client, pin_protection=False, label="label") + ret = device.recover(session, pin_protection=False, label="label") # Workflow successfully ended assert ret == messages.Success(message="Device recovered") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is backup_type + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is backup_type # Check mnemonic - assert client.debug.state().mnemonic_secret.hex() == secret + assert session.client.debug.state().mnemonic_secret.hex() == secret @pytest.mark.setup_client(uninitialized=True) -def test_recover_with_pin_passphrase(client: Client): - with client: +def test_recover_with_pin_passphrase(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecovery( client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654" ) client.set_input_flow(IF.get()) ret = device.recover( - client, + session, pin_protection=True, passphrase_protection=True, label="label", @@ -103,99 +103,99 @@ def test_recover_with_pin_passphrase(client: Client): # Workflow successfully ended assert ret == messages.Success(message="Device recovered") - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True - assert client.features.backup_type is messages.BackupType.Slip39_Basic + assert session.features.pin_protection is True + assert session.features.passphrase_protection is True + assert session.features.backup_type is messages.BackupType.Slip39_Basic @pytest.mark.setup_client(uninitialized=True) -def test_abort(client: Client): - with client: +def test_abort(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryAbort(client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False - assert client.features.recovery_status is messages.RecoveryStatus.Nothing + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is False + assert session.features.recovery_status is messages.RecoveryStatus.Nothing @pytest.mark.setup_client(uninitialized=True) -def test_abort_between_shares(client: Client): - with client: +def test_abort_between_shares(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryAbortBetweenShares( client, MNEMONIC_SLIP39_BASIC_20_3of6 ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False - assert client.features.recovery_status is messages.RecoveryStatus.Nothing + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is False + assert session.features.recovery_status is messages.RecoveryStatus.Nothing @pytest.mark.setup_client(uninitialized=True) -def test_noabort(client: Client): - with client: +def test_noabort(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryNoAbort(client, MNEMONIC_SLIP39_BASIC_20_3of6) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is True + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is True @pytest.mark.setup_client(uninitialized=True) -def test_invalid_mnemonic_first_share(client: Client): - with client: +def test_invalid_mnemonic_first_share(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is False @pytest.mark.setup_client(uninitialized=True) -def test_invalid_mnemonic_second_share(client: Client): - with client: +def test_invalid_mnemonic_second_share(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryInvalidSecondShare( client, MNEMONIC_SLIP39_BASIC_20_3of6 ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is False @pytest.mark.setup_client(uninitialized=True) @pytest.mark.parametrize("nth_word", range(3)) -def test_wrong_nth_word(client: Client, nth_word: int): +def test_wrong_nth_word(session: Session, nth_word: int): share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") - with client: + with session.client as client: IF = InputFlowSlip39BasicRecoveryWrongNthWord(client, share, nth_word) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) -def test_same_share(client: Client): +def test_same_share(session: Session): share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") - with client: + with session.client as client: IF = InputFlowSlip39BasicRecoverySameShare(client, share) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) -def test_1of1(client: Client): - with client: +def test_1of1(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecovery(client, MNEMONIC_SLIP39_BASIC_20_1of1) client.set_input_flow(IF.get()) ret = device.recover( - client, + session, pin_protection=False, passphrase_protection=False, label="label", @@ -203,7 +203,7 @@ def test_1of1(client: Client): # Workflow successfully ended assert ret == messages.Success(message="Device recovered") - assert client.features.initialized is True - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Slip39_Basic + assert session.features.initialized is True + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is messages.BackupType.Slip39_Basic diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py index 4c9ddf8036f..3fcd7b51bd2 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...input_flows import InputFlowSlip39BasicRecoveryDryRun @@ -37,12 +37,12 @@ @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) -def test_2of3_dryrun(client: Client): - with client: +def test_2of3_dryrun(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun(client, SHARES_20_2of3[1:3]) client.set_input_flow(IF.get()) ret = device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", @@ -56,9 +56,9 @@ def test_2of3_dryrun(client: Client): @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) -def test_2of3_invalid_seed_dryrun(client: Client): +def test_2of3_invalid_seed_dryrun(session: Session): # test fails because of different seed on device - with client, pytest.raises( + with session.client as client, pytest.raises( TrezorFailure, match=r"The seed does not match the one in the device" ): IF = InputFlowSlip39BasicRecoveryDryRun( @@ -66,7 +66,7 @@ def test_2of3_invalid_seed_dryrun(client: Client): ) client.set_input_flow(IF.get()) device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", diff --git a/tests/device_tests/reset_recovery/test_reset_backup.py b/tests/device_tests/reset_recovery/test_reset_backup.py index 148087d4f4b..c2b25490d5a 100644 --- a/tests/device_tests/reset_recovery/test_reset_backup.py +++ b/tests/device_tests/reset_recovery/test_reset_backup.py @@ -19,7 +19,7 @@ from shamir_mnemonic import shamir from trezorlib import device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import BackupAvailability, BackupType from ...common import WITH_MOCK_URANDOM @@ -31,32 +31,32 @@ ) -def backup_flow_bip39(client: Client) -> bytes: - with client: +def backup_flow_bip39(session: Session) -> bytes: + with session.client as client: IF = InputFlowBip39Backup(client) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) assert IF.mnemonic is not None return IF.mnemonic.encode() -def backup_flow_slip39_basic(client: Client): - with client: +def backup_flow_slip39_basic(session: Session): + with session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) groups = shamir.decode_mnemonics(IF.mnemonics[:3]) ems = shamir.recover_ems(groups) return ems.ciphertext -def backup_flow_slip39_advanced(client: Client): - with client: +def backup_flow_slip39_advanced(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics[0:3] + IF.mnemonics[5:8] + IF.mnemonics[10:13] groups = shamir.decode_mnemonics(mnemonics) @@ -74,32 +74,34 @@ def backup_flow_slip39_advanced(client: Client): @pytest.mark.models("core") @pytest.mark.parametrize("backup_type, backup_flow", VECTORS) @pytest.mark.setup_client(uninitialized=True) -def test_skip_backup_msg(client: Client, backup_type, backup_flow): - with WITH_MOCK_URANDOM, client: +def test_skip_backup_msg(session: Session, backup_type, backup_flow): + with WITH_MOCK_URANDOM, session: device.reset( - client, + session, skip_backup=True, passphrase_protection=False, pin_protection=False, backup_type=backup_type, ) - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.Required - assert client.features.unfinished_backup is False - assert client.features.no_backup is False - assert client.features.backup_type is backup_type + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.Required + assert session.features.unfinished_backup is False + assert session.features.no_backup is False + assert session.features.backup_type is backup_type - secret = backup_flow(client) + secret = backup_flow(session) - client.init_device() - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.unfinished_backup is False - assert client.features.backup_type is backup_type + raise Exception("NOT VALID TEST, INIT IS REMOVED, NEEDS TO BE REMADE") + + session.init_device() + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.unfinished_backup is False + assert session.features.backup_type is backup_type assert secret is not None - state = client.debug.state() + state = session.debug.state() assert state.mnemonic_type is backup_type assert state.mnemonic_secret == secret @@ -107,32 +109,34 @@ def test_skip_backup_msg(client: Client, backup_type, backup_flow): @pytest.mark.models("core") @pytest.mark.parametrize("backup_type, backup_flow", VECTORS) @pytest.mark.setup_client(uninitialized=True) -def test_skip_backup_manual(client: Client, backup_type: BackupType, backup_flow): - with WITH_MOCK_URANDOM, client: +def test_skip_backup_manual(session: Session, backup_type: BackupType, backup_flow): + with WITH_MOCK_URANDOM, session.client as client: IF = InputFlowResetSkipBackup(client) client.set_input_flow(IF.get()) device.reset( - client, + session, pin_protection=False, passphrase_protection=False, backup_type=backup_type, ) - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.Required - assert client.features.unfinished_backup is False - assert client.features.no_backup is False - assert client.features.backup_type is backup_type + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.Required + assert session.features.unfinished_backup is False + assert session.features.no_backup is False + assert session.features.backup_type is backup_type + + secret = backup_flow(session) - secret = backup_flow(client) + raise Exception("NOT VALID TEST, init_device IS REMOVED, NEEDS TO BE REMADE") - client.init_device() - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.unfinished_backup is False - assert client.features.backup_type is backup_type + session.init_device() + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.unfinished_backup is False + assert session.features.backup_type is backup_type assert secret is not None - state = client.debug.state() + state = session.debug.state() assert state.mnemonic_type is backup_type assert state.mnemonic_secret == secret diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py b/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py index 803818b3752..c37839b77f7 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py @@ -18,7 +18,7 @@ from mnemonic import Mnemonic from trezorlib import messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import EXTERNAL_ENTROPY, generate_entropy @@ -28,8 +28,9 @@ @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_skip_backup(client: Client): - ret = client.call_raw( +def test_reset_device_skip_backup(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.ResetDevice( strength=STRENGTH, passphrase_protection=False, @@ -40,17 +41,17 @@ def test_reset_device_skip_backup(client: Client): ) assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - internal_entropy = client.debug.state().reset_entropy - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + internal_entropy = debug.state().reset_entropy + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) assert isinstance(ret, messages.Success) # Check if device is properly initialized - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.GetFeatures()) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.Required assert ret.unfinished_backup is False @@ -61,14 +62,14 @@ def test_reset_device_skip_backup(client: Client): expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) # start Backup workflow - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) mnemonic = [] for _ in range(STRENGTH // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + session.call_raw(messages.ButtonAck()) mnemonic = " ".join(mnemonic) @@ -78,9 +79,9 @@ def test_reset_device_skip_backup(client: Client): mnemonic = [] for _ in range(STRENGTH // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.Success) @@ -90,13 +91,14 @@ def test_reset_device_skip_backup(client: Client): assert mnemonic == expected_mnemonic # start backup again - should fail - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) assert isinstance(ret, messages.Failure) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_skip_backup_break(client: Client): - ret = client.call_raw( +def test_reset_device_skip_backup_break(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.ResetDevice( strength=STRENGTH, passphrase_protection=False, @@ -107,26 +109,26 @@ def test_reset_device_skip_backup_break(client: Client): ) assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) assert isinstance(ret, messages.Success) # Check if device is properly initialized - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.GetFeatures()) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.Required assert ret.unfinished_backup is False assert ret.no_backup is False # start Backup workflow - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) # send Initialize -> break workflow - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.Initialize()) assert isinstance(ret, messages.Features) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.NotAvailable @@ -134,11 +136,11 @@ def test_reset_device_skip_backup_break(client: Client): assert ret.no_backup is False # start backup again - should fail - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) assert isinstance(ret, messages.Failure) # read Features again - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.Initialize()) assert isinstance(ret, messages.Features) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.NotAvailable @@ -146,6 +148,6 @@ def test_reset_device_skip_backup_break(client: Client): assert ret.no_backup is False -def test_initialized_device_backup_fail(client: Client): - ret = client.call_raw(messages.BackupDevice()) +def test_initialized_device_backup_fail(session: Session): + ret = session.call_raw(messages.BackupDevice()) assert isinstance(ret, messages.Failure) diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_t1.py b/tests/device_tests/reset_recovery/test_reset_bip39_t1.py index 689b81b0d61..05e5f6e073b 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_t1.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_t1.py @@ -18,7 +18,7 @@ from mnemonic import Mnemonic from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import EXTERNAL_ENTROPY, generate_entropy @@ -26,9 +26,10 @@ pytestmark = pytest.mark.models("legacy") -def reset_device(client: Client, strength: int): +def reset_device(session: Session, strength: int): + debug = session.client.debug # No PIN, no passphrase - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice( strength=strength, passphrase_protection=False, @@ -38,13 +39,13 @@ def reset_device(client: Client, strength: int): ) assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - internal_entropy = client.debug.state().reset_entropy - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + internal_entropy = debug.state().reset_entropy + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) # Generate mnemonic locally entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) @@ -53,9 +54,9 @@ def reset_device(client: Client, strength: int): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - client.call_raw(messages.ButtonAck()) + mnemonic.append(session.debug.read_reset_word()) + session.debug.press_yes() + session.call_raw(messages.ButtonAck()) mnemonic = " ".join(mnemonic) @@ -65,9 +66,9 @@ def reset_device(client: Client, strength: int): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - resp = client.call_raw(messages.ButtonAck()) + mnemonic.append(session.debug.read_reset_word()) + debug.press_yes() + resp = session.call_raw(messages.ButtonAck()) assert isinstance(resp, messages.Success) @@ -77,32 +78,35 @@ def reset_device(client: Client, strength: int): assert mnemonic == expected_mnemonic # Check if device is properly initialized - resp = client.call_raw(messages.Initialize()) + resp = session.call_raw(messages.GetFeatures()) assert resp.initialized is True assert resp.backup_availability == messages.BackupAvailability.NotAvailable assert resp.pin_protection is False assert resp.passphrase_protection is False # Do pin & passphrase-protected action, PassphraseRequest should NOT be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.Address) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_128(client: Client): - reset_device(client, 128) +def test_reset_device_128(session: Session): + reset_device(session, 128) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_192(client: Client): - reset_device(client, 192) +def test_reset_device_192(session: Session): + reset_device(session, 192) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_256_pin(client: Client): +def test_reset_device_256_pin(session: Session): + debug = session.client.debug strength = 256 - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice( strength=strength, passphrase_protection=True, @@ -113,24 +117,24 @@ def test_reset_device_256_pin(client: Client): # Do you want ... ? assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin("654") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("654") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time - pin_encoded = client.debug.encode_pin("654") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("654") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - internal_entropy = client.debug.state().reset_entropy - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + internal_entropy = debug.state().reset_entropy + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) # Generate mnemonic locally entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) @@ -139,9 +143,9 @@ def test_reset_device_256_pin(client: Client): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + session.call_raw(messages.ButtonAck()) mnemonic = " ".join(mnemonic) @@ -151,9 +155,9 @@ def test_reset_device_256_pin(client: Client): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - resp = client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + resp = session.call_raw(messages.ButtonAck()) assert isinstance(resp, messages.Success) @@ -163,23 +167,26 @@ def test_reset_device_256_pin(client: Client): assert mnemonic == expected_mnemonic # Check if device is properly initialized - resp = client.call_raw(messages.Initialize()) + resp = session.call_raw(messages.GetFeatures()) assert resp.initialized is True assert resp.backup_availability == messages.BackupAvailability.NotAvailable assert resp.pin_protection is True assert resp.passphrase_protection is True # Do passphrase-protected action, PassphraseRequest should be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.PassphraseRequest) - client.call_raw(messages.Cancel()) + session.call_raw(messages.Cancel()) @pytest.mark.setup_client(uninitialized=True) -def test_failed_pin(client: Client): +def test_failed_pin(session: Session): + debug = session.client.debug strength = 128 - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice( strength=strength, passphrase_protection=True, @@ -190,27 +197,27 @@ def test_failed_pin(client: Client): # Do you want ... ? assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin("1234") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("1234") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time - pin_encoded = client.debug.encode_pin("6789") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("6789") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.Failure) -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(Exception): device.reset( - client, + session, strength=128, passphrase_protection=True, pin_protection=True, diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_t2.py b/tests/device_tests/reset_recovery/test_reset_bip39_t2.py index 9715d118b3b..932d622700a 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_t2.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_t2.py @@ -19,7 +19,7 @@ from trezorlib import device, messages from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...common import EXTERNAL_ENTROPY, MNEMONIC12, WITH_MOCK_URANDOM, generate_entropy @@ -32,14 +32,15 @@ pytestmark = pytest.mark.models("core") -def reset_device(client: Client, strength: int): - with WITH_MOCK_URANDOM, client: +def reset_device(session: Session, strength: int): + debug = session.client.debug + with WITH_MOCK_URANDOM, session.client as client: IF = InputFlowBip39ResetBackup(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.reset( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -47,7 +48,7 @@ def reset_device(client: Client, strength: int): ) # generate mnemonic locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = debug.state().reset_entropy entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -55,7 +56,7 @@ def reset_device(client: Client, strength: int): assert IF.mnemonic == expected_mnemonic # Check if device is properly initialized - resp = client.call_raw(messages.Initialize()) + resp = session.call_raw(messages.GetFeatures()) assert resp.initialized is True assert resp.backup_availability == messages.BackupAvailability.NotAvailable assert resp.pin_protection is False @@ -64,30 +65,31 @@ def reset_device(client: Client, strength: int): # backup attempt fails because backup was done in reset with pytest.raises(TrezorFailure, match="ProcessError: Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device(client: Client): - reset_device(client, 128) # 12 words +def test_reset_device(session: Session): + reset_device(session, 128) # 12 words @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_192(client: Client): - reset_device(client, 192) # 18 words +def test_reset_device_192(session: Session): + reset_device(session, 192) # 18 words @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_pin(client: Client): +def test_reset_device_pin(session: Session): + debug = session.client.debug strength = 256 # 24 words - with WITH_MOCK_URANDOM, client: + with WITH_MOCK_URANDOM, session.client as client: IF = InputFlowBip39ResetPIN(client) client.set_input_flow(IF.get()) # PIN, passphrase, display random device.reset( - client, + session, strength=strength, passphrase_protection=True, pin_protection=True, @@ -95,7 +97,7 @@ def test_reset_device_pin(client: Client): ) # generate mnemonic locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = debug.state().reset_entropy entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -103,7 +105,7 @@ def test_reset_device_pin(client: Client): assert IF.mnemonic == expected_mnemonic # Check if device is properly initialized - resp = client.call_raw(messages.Initialize()) + resp = session.call_raw(messages.GetFeatures()) assert resp.initialized is True assert resp.backup_availability == messages.BackupAvailability.NotAvailable assert resp.pin_protection is True @@ -111,16 +113,17 @@ def test_reset_device_pin(client: Client): @pytest.mark.setup_client(uninitialized=True) -def test_reset_failed_check(client: Client): +def test_reset_failed_check(session: Session): + debug = session.client.debug strength = 256 # 24 words - with WITH_MOCK_URANDOM, client: + with WITH_MOCK_URANDOM, session.client as client: IF = InputFlowBip39ResetFailedCheck(client) client.set_input_flow(IF.get()) # PIN, passphrase, display random device.reset( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -128,7 +131,7 @@ def test_reset_failed_check(client: Client): ) # generate mnemonic locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = debug.state().reset_entropy entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -136,7 +139,7 @@ def test_reset_failed_check(client: Client): assert IF.mnemonic == expected_mnemonic # Check if device is properly initialized - resp = client.call_raw(messages.Initialize()) + resp = session.call_raw(messages.GetFeatures()) assert resp.initialized is True assert resp.backup_availability == messages.BackupAvailability.NotAvailable assert resp.pin_protection is False @@ -145,46 +148,47 @@ def test_reset_failed_check(client: Client): @pytest.mark.setup_client(uninitialized=True) -def test_failed_pin(client: Client): +def test_failed_pin(session: Session): + debug = session.client.debug strength = 128 - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice(strength=strength, pin_protection=True, label="test") ) # Confirm Reset assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Enter PIN for first time assert isinstance(ret, messages.ButtonRequest) - client.debug.input("654") - ret = client.call_raw(messages.ButtonAck()) + debug.input("654") + ret = session.call_raw(messages.ButtonAck()) # Re-enter PIN for TR - if client.layout_type is LayoutType.TR: + if session.client.layout_type is LayoutType.TR: assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Enter PIN for second time assert isinstance(ret, messages.ButtonRequest) - client.debug.input("456") - ret = client.call_raw(messages.ButtonAck()) + debug.input("456") + ret = session.call_raw(messages.ButtonAck()) # PIN mismatch assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.ButtonRequest) @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(Exception): device.reset( - client, + session, strength=128, passphrase_protection=True, pin_protection=True, diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py b/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py index ea69e07970f..0f0f3d9a65e 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import BackupType from trezorlib.tools import parse_path @@ -27,24 +27,24 @@ @pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) -def test_reset_recovery(client: Client): - mnemonic = reset(client) - address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) +def test_reset_recovery(session: Session): + mnemonic = reset(session) + address_before = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) - device.wipe(client) - recover(client, mnemonic) - address_after = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + device.wipe(session) + recover(session, mnemonic) + address_after = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) assert address_before == address_after -def reset(client: Client, strength: int = 128, skip_backup: bool = False) -> str: - with WITH_MOCK_URANDOM, client: +def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> str: + with WITH_MOCK_URANDOM, session.client as client: IF = InputFlowBip39ResetBackup(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.reset( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -53,26 +53,26 @@ def reset(client: Client, strength: int = 128, skip_backup: bool = False) -> str ) # Check if device is properly initialized - assert client.features.initialized is True + assert session.features.initialized is True assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False assert IF.mnemonic is not None return IF.mnemonic -def recover(client: Client, mnemonic: str): +def recover(session: Session, mnemonic: str): words = mnemonic.split(" ") - with client: + with session.client as client: IF = InputFlowBip39Recovery(client, words) client.set_input_flow(IF.get()) client.watch_layout() - ret = device.recover(client, pin_protection=False, label="label") + ret = device.recover(session, pin_protection=False, label="label") # Workflow successfully ended assert ret == messages.Success(message="Device recovered") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py index e37b4f50994..783927394e6 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import BackupType from trezorlib.tools import parse_path @@ -30,9 +30,9 @@ @pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) -def test_reset_recovery(client: Client): - mnemonics = reset(client) - address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) +def test_reset_recovery(session: Session): + mnemonics = reset(session) + address_before = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) # we're generating 3of5 groups 3of5 shares each test_combinations = [ mnemonics[0:3] # shares 1-3 from groups 1-3 @@ -49,22 +49,22 @@ def test_reset_recovery(client: Client): + mnemonics[22:25], ] for combination in test_combinations: - device.wipe(client) - recover(client, combination) + device.wipe(session) + recover(session, combination) address_after = btc.get_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0") + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0") ) assert address_before == address_after -def reset(client: Client, strength: int = 128) -> list[str]: - with WITH_MOCK_URANDOM, client: +def reset(session: Session, strength: int = 128) -> list[str]: + with WITH_MOCK_URANDOM, session.client as client: IF = InputFlowSlip39AdvancedResetRecovery(client, False) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.reset( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -73,25 +73,25 @@ def reset(client: Client, strength: int = 128) -> list[str]: ) # Check if device is properly initialized - assert client.features.initialized is True + assert session.features.initialized is True assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Advanced_Extendable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Advanced_Extendable return IF.mnemonics -def recover(client: Client, shares: list[str]): - with client: +def recover(session: Session, shares: list[str]): + with session.client as client: IF = InputFlowSlip39AdvancedRecovery(client, shares, False) client.set_input_flow(IF.get()) - ret = device.recover(client, pin_protection=False, label="label") + ret = device.recover(session, pin_protection=False, label="label") # Workflow successfully ended assert ret == messages.Success(message="Device recovered") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Advanced_Extendable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Advanced_Extendable diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py index 3980b1a1495..2ba67572363 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py @@ -19,7 +19,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import BackupType from trezorlib.tools import parse_path @@ -33,28 +33,28 @@ @pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) @WITH_MOCK_URANDOM -def test_reset_recovery(client: Client): - mnemonics = reset(client) - address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) +def test_reset_recovery(session: Session): + mnemonics = reset(session) + address_before = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) for share_subset in itertools.combinations(mnemonics, 3): - device.wipe(client) + device.wipe(session) selected_mnemonics = share_subset - recover(client, selected_mnemonics) + recover(session, selected_mnemonics) address_after = btc.get_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0") + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0") ) assert address_before == address_after -def reset(client: Client, strength: int = 128) -> list[str]: - with client: +def reset(session: Session, strength: int = 128) -> list[str]: + with session.client as client: IF = InputFlowSlip39BasicResetRecovery(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.reset( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -63,25 +63,25 @@ def reset(client: Client, strength: int = 128) -> list[str]: ) # Check if device is properly initialized - assert client.features.initialized is True + assert session.features.initialized is True assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable return IF.mnemonics -def recover(client: Client, shares: list[str]): - with client: +def recover(session: Session, shares: list[str]): + with session.client as client: IF = InputFlowSlip39BasicRecovery(client, shares) client.set_input_flow(IF.get()) - ret = device.recover(client, pin_protection=False, label="label") + ret = device.recover(session, pin_protection=False, label="label") # Workflow successfully ended assert ret == messages.Success(message="Device recovered") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py b/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py index 6aa9d2bf3d0..e161ac5a1b7 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py @@ -18,7 +18,7 @@ from shamir_mnemonic import shamir from trezorlib import device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import BackupAvailability, BackupType @@ -30,17 +30,17 @@ # TODO: test with different options @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_slip39_advanced(client: Client): +def test_reset_device_slip39_advanced(session: Session): strength = 128 member_threshold = 3 - with WITH_MOCK_URANDOM, client: + with WITH_MOCK_URANDOM, session.client as client: IF = InputFlowSlip39AdvancedResetRecovery(client, False) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.reset( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -49,22 +49,22 @@ def test_reset_device_slip39_advanced(client: Client): ) # generate secret locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = session.client.debug.state().reset_entropy secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) # validate that all combinations will result in the correct master secret validate_mnemonics(IF.mnemonics, member_threshold, secret) # Check if device is properly initialized - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Advanced_Extendable + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Advanced_Extendable # backup attempt fails because backup was done in reset with pytest.raises(TrezorFailure, match="ProcessError: Seed already backed up"): - device.backup(client) + device.backup(session) def validate_mnemonics( diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py index 8eb5d7830fa..0ebaf431193 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py @@ -20,7 +20,7 @@ from shamir_mnemonic import MnemonicError, shamir from trezorlib import device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import BackupAvailability, BackupType @@ -30,16 +30,16 @@ pytestmark = pytest.mark.models("core") -def reset_device(client: Client, strength: int): +def reset_device(session: Session, strength: int): member_threshold = 3 - with WITH_MOCK_URANDOM, client: + with WITH_MOCK_URANDOM, session.client as client: IF = InputFlowSlip39BasicResetRecovery(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.reset( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -48,32 +48,32 @@ def reset_device(client: Client, strength: int): ) # generate secret locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = session.debug.state().reset_entropy secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) # validate that all combinations will result in the correct master secret validate_mnemonics(IF.mnemonics, member_threshold, secret) # Check if device is properly initialized - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable # backup attempt fails because backup was done in reset with pytest.raises(TrezorFailure, match="ProcessError: Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_slip39_basic(client: Client): - reset_device(client, 128) +def test_reset_device_slip39_basic(session: Session): + reset_device(session, 128) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_slip39_basic_256(client: Client): - reset_device(client, 256) +def test_reset_device_slip39_basic_256(session: Session): + reset_device(session, 256) def validate_mnemonics(mnemonics, threshold, expected_ems): diff --git a/tests/device_tests/ripple/test_get_address.py b/tests/device_tests/ripple/test_get_address.py index 0d35b6c5b93..2a066926cd8 100644 --- a/tests/device_tests/ripple/test_get_address.py +++ b/tests/device_tests/ripple/test_get_address.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.ripple import get_address from trezorlib.tools import parse_path @@ -43,28 +43,28 @@ @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) -def test_ripple_get_address(client: Client, path: str, expected_address: str): - address = get_address(client, parse_path(path), show_display=True) +def test_ripple_get_address(session: Session, path: str, expected_address: str): + address = get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) def test_ripple_get_address_chunkify_details( - client: Client, path: str, expected_address: str + session: Session, path: str, expected_address: str ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address @pytest.mark.setup_client(mnemonic=CUSTOM_MNEMONIC) -def test_ripple_get_address_other(client: Client): +def test_ripple_get_address_other(session: Session): # data from https://github.com/you21979/node-ripple-bip32/blob/master/test/test.js - address = get_address(client, parse_path("m/44h/144h/0h/0/0")) + address = get_address(session, parse_path("m/44h/144h/0h/0/0")) assert address == "r4ocGE47gm4G4LkA9mriVHQqzpMLBTgnTY" - address = get_address(client, parse_path("m/44h/144h/0h/0/1")) + address = get_address(session, parse_path("m/44h/144h/0h/0/1")) assert address == "rUt9ULSrUvfCmke8HTFU1szbmFpWzVbBXW" diff --git a/tests/device_tests/ripple/test_sign_tx.py b/tests/device_tests/ripple/test_sign_tx.py index a03a29d4bec..82911c8abe8 100644 --- a/tests/device_tests/ripple/test_sign_tx.py +++ b/tests/device_tests/ripple/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ripple -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -29,7 +29,7 @@ @pytest.mark.parametrize("chunkify", (True, False)) -def test_ripple_sign_simple_tx(client: Client, chunkify: bool): +def test_ripple_sign_simple_tx(session: Session, chunkify: bool): msg = ripple.create_sign_tx_msg( { "TransactionType": "Payment", @@ -43,7 +43,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): } ) resp = ripple.sign_tx( - client, parse_path("m/44h/144h/0h/0/0"), msg, chunkify=chunkify + session, parse_path("m/44h/144h/0h/0/0"), msg, chunkify=chunkify ) assert ( resp.signature.hex() @@ -66,7 +66,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): } ) resp = ripple.sign_tx( - client, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify + session, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify ) assert ( resp.signature.hex() @@ -92,7 +92,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): } ) resp = ripple.sign_tx( - client, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify + session, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify ) assert ( resp.signature.hex() @@ -104,7 +104,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): ) -def test_ripple_sign_invalid_fee(client: Client): +def test_ripple_sign_invalid_fee(session: Session): msg = ripple.create_sign_tx_msg( { "TransactionType": "Payment", @@ -121,4 +121,4 @@ def test_ripple_sign_invalid_fee(client: Client): TrezorFailure, match="ProcessError: Fee must be in the range of 10 to 10,000 drops", ): - ripple.sign_tx(client, parse_path("m/44h/144h/0h/0/2"), msg) + ripple.sign_tx(session, parse_path("m/44h/144h/0h/0/2"), msg) diff --git a/tests/device_tests/solana/test_address.py b/tests/device_tests/solana/test_address.py index dca1126c056..ce17d7c2a3b 100644 --- a/tests/device_tests/solana/test_address.py +++ b/tests/device_tests/solana/test_address.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.solana import get_address from trezorlib.tools import parse_path @@ -32,9 +32,9 @@ @parametrize_using_common_fixtures( "solana/get_address.json", ) -def test_solana_get_address(client: Client, parameters, result): +def test_solana_get_address(session: Session, parameters, result): actual_result = get_address( - client, address_n=parse_path(parameters["path"]), show_display=True + session, address_n=parse_path(parameters["path"]), show_display=True ) assert actual_result.address == result["expected_address"] diff --git a/tests/device_tests/solana/test_public_key.py b/tests/device_tests/solana/test_public_key.py index 864852b116a..abe24dfc8f1 100644 --- a/tests/device_tests/solana/test_public_key.py +++ b/tests/device_tests/solana/test_public_key.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.solana import get_public_key from trezorlib.tools import parse_path @@ -32,9 +32,9 @@ @parametrize_using_common_fixtures( "solana/get_public_key.json", ) -def test_solana_get_public_key(client: Client, parameters, result): +def test_solana_get_public_key(session: Session, parameters, result): actual_result = get_public_key( - client, address_n=parse_path(parameters["path"]), show_display=True + session, address_n=parse_path(parameters["path"]), show_display=True ) assert actual_result.public_key.hex() == result["expected_public_key"] diff --git a/tests/device_tests/solana/test_sign_tx.py b/tests/device_tests/solana/test_sign_tx.py index 241a3d3b34f..d5685e1ed75 100644 --- a/tests/device_tests/solana/test_sign_tx.py +++ b/tests/device_tests/solana/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.solana import sign_tx from trezorlib.tools import parse_path @@ -42,13 +42,11 @@ "solana/sign_tx.unknown_instructions.json", "solana/sign_tx.predefined_transactions.json", ) -def test_solana_sign_tx(client: Client, parameters, result): - client.init_device(new_session=True) - +def test_solana_sign_tx(session: Session, parameters, result): serialized_tx = _serialize_tx(parameters["construct"]) actual_result = sign_tx( - client, + session, address_n=parse_path(parameters["address"]), serialized_tx=serialized_tx, additional_info=( diff --git a/tests/device_tests/stellar/test_stellar.py b/tests/device_tests/stellar/test_stellar.py index 8e214ab1135..1d5c59e1f8e 100644 --- a/tests/device_tests/stellar/test_stellar.py +++ b/tests/device_tests/stellar/test_stellar.py @@ -55,7 +55,7 @@ import pytest from trezorlib import messages, protobuf, stellar -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -87,10 +87,10 @@ def make_op(operation_data): @parametrize_using_common_fixtures("stellar/sign_tx.json") -def test_sign_tx(client: Client, parameters, result): +def test_sign_tx(session: Session, parameters, result): tx, operations = parameters_to_proto(parameters) response = stellar.sign_tx( - client, tx, operations, tx.address_n, tx.network_passphrase + session, tx, operations, tx.address_n, tx.network_passphrase ) assert response.public_key.hex() == result["public_key"] assert b64encode(response.signature).decode() == result["signature"] @@ -113,20 +113,20 @@ def test_xdr(parameters, result): @parametrize_using_common_fixtures("stellar/get_address.json") -def test_get_address(client: Client, parameters, result): +def test_get_address(session: Session, parameters, result): address_n = parse_path(parameters["path"]) - address = stellar.get_address(client, address_n, show_display=True) + address = stellar.get_address(session, address_n, show_display=True) assert address == result["address"] @pytest.mark.models("core") @parametrize_using_common_fixtures("stellar/get_address.json") -def test_get_address_chunkify_details(client: Client, parameters, result): - with client: +def test_get_address_chunkify_details(session: Session, parameters, result): + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address_n = parse_path(parameters["path"]) address = stellar.get_address( - client, address_n, show_display=True, chunkify=True + session, address_n, show_display=True, chunkify=True ) assert address == result["address"] diff --git a/tests/device_tests/test_authenticate_device.py b/tests/device_tests/test_authenticate_device.py index f2ffb5d7157..5e697b4f070 100644 --- a/tests/device_tests/test_authenticate_device.py +++ b/tests/device_tests/test_authenticate_device.py @@ -5,7 +5,7 @@ from cryptography.x509 import extensions as ext from trezorlib import device, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ..common import compact_size @@ -35,16 +35,16 @@ ), ), ) -def test_authenticate_device(client: Client, challenge: bytes) -> None: +def test_authenticate_device(session: Session, challenge: bytes) -> None: # NOTE Applications must generate a random challenge for each request. # Issue an AuthenticateDevice challenge to Trezor. - proof = device.authenticate(client, challenge) + proof = device.authenticate(session, challenge) certs = [x509.load_der_x509_certificate(cert) for cert in proof.certificates] # Verify the last certificate in the certificate chain against trust anchor. root_public_key = ec.EllipticCurvePublicKey.from_encoded_point( - ec.SECP256R1(), ROOT_PUBLIC_KEY[client.model] + ec.SECP256R1(), ROOT_PUBLIC_KEY[session.model] ) root_public_key.verify( certs[-1].signature, @@ -78,11 +78,11 @@ def test_authenticate_device(client: Client, challenge: bytes) -> None: # Verify that the common name matches the Trezor model. common_name = cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)[0] - if client.model == models.T3B1: + if session.model == models.T3B1: # XXX TODO replace as soon as we have T3B1 staging internal_model = "T2B1" else: - internal_model = client.model.internal_name + internal_model = session.model.internal_name assert common_name.value.startswith(internal_model) # Verify the signature of the challenge. diff --git a/tests/device_tests/test_autolock.py b/tests/device_tests/test_autolock.py index dc0f69a1df9..a0412b77025 100644 --- a/tests/device_tests/test_autolock.py +++ b/tests/device_tests/test_autolock.py @@ -19,7 +19,7 @@ import pytest from trezorlib import device, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ..common import TEST_ADDRESS_N, get_test_address @@ -29,42 +29,42 @@ pytestmark = pytest.mark.setup_client(pin=PIN4) -def pin_request(client: Client): +def pin_request(session: Session): return ( messages.PinMatrixRequest - if client.model is models.T1B1 + if session.model is models.T1B1 else messages.ButtonRequest ) -def set_autolock_delay(client: Client, delay): - with client: +def set_autolock_delay(session: Session, delay): + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - pin_request(client), + pin_request(session), messages.ButtonRequest, messages.Success, messages.Features, ] ) - device.apply_settings(client, auto_lock_delay_ms=delay) + device.apply_settings(session, auto_lock_delay_ms=delay) -def test_apply_auto_lock_delay(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_apply_auto_lock_delay(session: Session): + set_autolock_delay(session, 10 * 1000) time.sleep(0.1) # sleep less than auto-lock delay - with client: + with session: # No PIN protection is required. - client.set_expected_responses([messages.Address]) - get_test_address(client) + session.set_expected_responses([messages.Address]) + get_test_address(session) time.sleep(10.5) # sleep more than auto-lock delay - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses([pin_request(client), messages.Address]) - get_test_address(client) + session.set_expected_responses([pin_request(session), messages.Address]) + get_test_address(session) @pytest.mark.parametrize( @@ -78,44 +78,44 @@ def test_apply_auto_lock_delay(client: Client): 536870, # 149 hours, maximum ], ) -def test_apply_auto_lock_delay_valid(client: Client, seconds): - set_autolock_delay(client, seconds * 1000) - assert client.features.auto_lock_delay_ms == seconds * 1000 +def test_apply_auto_lock_delay_valid(session: Session, seconds): + set_autolock_delay(session, seconds * 1000) + assert session.features.auto_lock_delay_ms == seconds * 1000 -def test_autolock_default_value(client: Client): - assert client.features.auto_lock_delay_ms is None - with client: +def test_autolock_default_value(session: Session): + assert session.features.auto_lock_delay_ms is None + with session, session.client as client: client.use_pin_sequence([PIN4]) - device.apply_settings(client, label="pls unlock") - client.refresh_features() - assert client.features.auto_lock_delay_ms == 60 * 10 * 1000 + device.apply_settings(session, label="pls unlock") + session.refresh_features() + assert session.features.auto_lock_delay_ms == 60 * 10 * 1000 @pytest.mark.parametrize( "seconds", [0, 1, 9, 536871, 2**22], ) -def test_apply_auto_lock_delay_out_of_range(client: Client, seconds): - with client: +def test_apply_auto_lock_delay_out_of_range(session: Session, seconds): + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - pin_request(client), + pin_request(session), messages.Failure(code=messages.FailureType.ProcessError), ] ) delay = seconds * 1000 with pytest.raises(TrezorFailure): - device.apply_settings(client, auto_lock_delay_ms=delay) + device.apply_settings(session, auto_lock_delay_ms=delay) @pytest.mark.models("core") -def test_autolock_cancels_ui(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_autolock_cancels_ui(session: Session): + set_autolock_delay(session, 10 * 1000) - resp = client.call_raw( + resp = session.call_raw( messages.GetAddress( coin_name="Testnet", address_n=TEST_ADDRESS_N, @@ -126,44 +126,45 @@ def test_autolock_cancels_ui(client: Client): assert isinstance(resp, messages.ButtonRequest) # send an ack, do not read response - client._raw_write(messages.ButtonAck()) + session._write(messages.ButtonAck()) # sleep more than auto-lock delay time.sleep(10.5) - resp = client._raw_read() + resp = session._read() assert isinstance(resp, messages.Failure) assert resp.code == messages.FailureType.ActionCancelled -def test_autolock_ignores_initialize(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_autolock_ignores_initialize(session: Session): + set_autolock_delay(session, 10 * 1000) - assert client.features.unlocked is True + assert session.features.unlocked is True start = time.monotonic() while time.monotonic() - start < 11: # init_device should always work even if locked - client.init_device() + raise Exception("THIS SHOULD BE REMADE") + session.init_device() time.sleep(0.1) # after 11 seconds we are definitely locked - assert client.features.unlocked is False + assert session.features.unlocked is False -def test_autolock_ignores_getaddress(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_autolock_ignores_getaddress(session: Session): + set_autolock_delay(session, 10 * 1000) - assert client.features.unlocked is True + assert session.features.unlocked is True start = time.monotonic() # let's continue for 8 seconds to give a little leeway to the slow CI while time.monotonic() - start < 8: - get_test_address(client) + get_test_address(session) time.sleep(0.1) # sleep 3 more seconds to wait for autolock time.sleep(3) # after 11 seconds we are definitely locked - client.refresh_features() - assert client.features.unlocked is False + session.refresh_features() + assert session.features.unlocked is False diff --git a/tests/device_tests/test_basic.py b/tests/device_tests/test_basic.py index c2d1202eb52..a985000bace 100644 --- a/tests/device_tests/test_basic.py +++ b/tests/device_tests/test_basic.py @@ -15,32 +15,35 @@ # If not, see . from trezorlib import device, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_features(client: Client): - f0 = client.features - # client erases session_id from its features - f0.session_id = client.session_id - f1 = client.call(messages.Initialize(session_id=f0.session_id)) +def test_features(session: Session): + raise Exception("TEST NEEDS TO BE REMADE") + f0 = session.features + # session erases session_id from its features + f0.session_id = session.session_id + f1 = session.call(messages.Initialize(session_id=f0.session_id)) assert f0 == f1 -def test_capabilities(client: Client): - assert (messages.Capability.Translations in client.features.capabilities) == ( - client.model is not models.T1B1 +def test_capabilities(session: Session): + assert (messages.Capability.Translations in session.features.capabilities) == ( + session.model is not models.T1B1 ) -def test_ping(client: Client): - ping = client.call(messages.Ping(message="ahoj!")) +def test_ping(session: Session): + ping = session.call(messages.Ping(message="ahoj!")) assert ping == messages.Success(message="ahoj!") -def test_device_id_same(client: Client): - id1 = client.get_device_id() - client.init_device() - id2 = client.get_device_id() +def test_device_id_same(session: Session): + raise Exception("TEST NEEDS TO BE REMADE") + + id1 = session.client.get_device_id() + # session.init_device() + id2 = session.client.get_device_id() # ID must be at least 12 characters assert len(id1) >= 12 @@ -49,10 +52,11 @@ def test_device_id_same(client: Client): assert id1 == id2 -def test_device_id_different(client: Client): - id1 = client.get_device_id() - device.wipe(client) - id2 = client.get_device_id() +def test_device_id_different(session: Session): + raise Exception("TEST NEEDS TO BE REMADE") + id1 = session.get_device_id() + device.wipe(session) + id2 = session.get_device_id() # Device ID must be fresh after every reset assert id1 != id2 diff --git a/tests/device_tests/test_bip32_speed.py b/tests/device_tests/test_bip32_speed.py index 1d184c7e4a8..84d8cf9ae59 100644 --- a/tests/device_tests/test_bip32_speed.py +++ b/tests/device_tests/test_bip32_speed.py @@ -19,7 +19,7 @@ import pytest from trezorlib import btc, device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import SafetyCheckLevel from trezorlib.tools import H_ @@ -29,47 +29,47 @@ ] -def test_public_ckd(client: Client): +def test_public_ckd(session: Session): # disable safety checks to access non-standard paths - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) - btc.get_address(client, "Bitcoin", []) # to compute root node via BIP39 + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) + btc.get_address(session, "Bitcoin", []) # to compute root node via BIP39 for depth in range(8): start = time.time() - btc.get_address(client, "Bitcoin", range(depth)) + btc.get_address(session, "Bitcoin", range(depth)) delay = time.time() - start expected = (depth + 1) * 0.26 print("DEPTH", depth, "EXPECTED DELAY", expected, "REAL DELAY", delay) assert delay <= expected -def test_private_ckd(client: Client): +def test_private_ckd(session: Session): # disable safety checks to access non-standard paths - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) - btc.get_address(client, "Bitcoin", []) # to compute root node via BIP39 + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) + btc.get_address(session, "Bitcoin", []) # to compute root node via BIP39 for depth in range(8): start = time.time() address_n = [H_(-i) for i in range(-depth, 0)] - btc.get_address(client, "Bitcoin", address_n) + btc.get_address(session, "Bitcoin", address_n) delay = time.time() - start expected = (depth + 1) * 0.26 print("DEPTH", depth, "EXPECTED DELAY", expected, "REAL DELAY", delay) assert delay <= expected -def test_cache(client: Client): +def test_cache(session: Session): # disable safety checks to access non-standard paths - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) start = time.time() for x in range(10): - btc.get_address(client, "Bitcoin", [x, 2, 3, 4, 5, 6, 7, 8]) + btc.get_address(session, "Bitcoin", [x, 2, 3, 4, 5, 6, 7, 8]) nocache_time = time.time() - start start = time.time() for x in range(10): - btc.get_address(client, "Bitcoin", [1, 2, 3, 4, 5, 6, 7, x]) + btc.get_address(session, "Bitcoin", [1, 2, 3, 4, 5, 6, 7, x]) cache_time = time.time() - start print("NOCACHE TIME", nocache_time) diff --git a/tests/device_tests/test_busy_state.py b/tests/device_tests/test_busy_state.py index 706745a1981..ce2286487de 100644 --- a/tests/device_tests/test_busy_state.py +++ b/tests/device_tests/test_busy_state.py @@ -20,62 +20,65 @@ from trezorlib import btc, device from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path PIN = "1234" -def _assert_busy(client: Client, should_be_busy: bool, screen: str = "Homescreen"): - assert client.features.busy is should_be_busy - if client.layout_type is not LayoutType.T1: +def _assert_busy(session: Session, should_be_busy: bool, screen: str = "Homescreen"): + assert session.features.busy is should_be_busy + if session.client.layout_type is not LayoutType.T1: if should_be_busy: - assert "CoinJoinProgress" in client.debug.read_layout().all_components() + assert ( + "CoinJoinProgress" + in session.client.debug.read_layout().all_components() + ) else: - assert client.debug.read_layout().main_component() == screen + assert session.client.debug.read_layout().main_component() == screen @pytest.mark.setup_client(pin=PIN) -def test_busy_state(client: Client): - _assert_busy(client, False, "Lockscreen") - assert client.features.unlocked is False +def test_busy_state(session: Session): + _assert_busy(session, False, "Lockscreen") + assert session.features.unlocked is False # Show busy dialog for 1 minute. - device.set_busy(client, expiry_ms=60 * 1000) - _assert_busy(client, True) - assert client.features.unlocked is False + device.set_busy(session, expiry_ms=60 * 1000) + _assert_busy(session, True) + assert session.features.unlocked is False - with client: + with session.client as client: client.use_pin_sequence([PIN]) btc.get_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=True + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=True ) - client.refresh_features() - _assert_busy(client, True) - assert client.features.unlocked is True + session.refresh_features() + _assert_busy(session, True) + assert session.features.unlocked is True # Hide the busy dialog. - device.set_busy(client, None) + device.set_busy(session, None) - _assert_busy(client, False) - assert client.features.unlocked is True + _assert_busy(session, False) + assert session.features.unlocked is True @pytest.mark.models("core") -def test_busy_expiry_core(client: Client): +def test_busy_expiry_core(session: Session): WAIT_TIME_MS = 1500 TOLERANCE = 1000 - _assert_busy(client, False) + _assert_busy(session, False) # Start a timer start = time.monotonic() # Show the busy dialog. - device.set_busy(client, expiry_ms=WAIT_TIME_MS) - _assert_busy(client, True) + device.set_busy(session, expiry_ms=WAIT_TIME_MS) + _assert_busy(session, True) # Wait until the layout changes - client.debug.wait_layout() + session.client.debug.wait_layout() end = time.monotonic() # Check that the busy dialog was shown for at least WAIT_TIME_MS. @@ -84,26 +87,26 @@ def test_busy_expiry_core(client: Client): # Check that the device is no longer busy. # Also needs to come back to Homescreen (for UI tests). - client.refresh_features() - _assert_busy(client, False) + session.refresh_features() + _assert_busy(session, False) @pytest.mark.flaky(max_runs=5) @pytest.mark.models("legacy") -def test_busy_expiry_legacy(client: Client): - _assert_busy(client, False) +def test_busy_expiry_legacy(session: Session): + _assert_busy(session, False) # Show the busy dialog. - device.set_busy(client, expiry_ms=1500) - _assert_busy(client, True) + device.set_busy(session, expiry_ms=1500) + _assert_busy(session, True) # Hasn't expired yet. time.sleep(0.1) - _assert_busy(client, True) + _assert_busy(session, True) # Wait for it to expire. Add some tolerance to account for CI/hardware slowness. time.sleep(4.0) # Check that the device is no longer busy. # Also needs to come back to Homescreen (for UI tests). - client.refresh_features() - _assert_busy(client, False) + session.refresh_features() + _assert_busy(session, False) diff --git a/tests/device_tests/test_cancel.py b/tests/device_tests/test_cancel.py index 108a7c10349..1f5aabb5537 100644 --- a/tests/device_tests/test_cancel.py +++ b/tests/device_tests/test_cancel.py @@ -18,7 +18,7 @@ import trezorlib.messages as m from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled from ..common import TEST_ADDRESS_N @@ -36,15 +36,17 @@ ), ], ) -def test_cancel_message_via_cancel(client: Client, message): +def test_cancel_message_via_cancel(session: Session, message): + raise Exception("TREZOR FW CANNOT HANDLE CANCEL FOR NOW, TODO FIX") + def input_flow(): yield - client.cancel() + session.cancel() - with client, pytest.raises(Cancelled): - client.set_expected_responses([m.ButtonRequest(), m.Failure()]) + with session, session.client as client, pytest.raises(Cancelled): + session.set_expected_responses([m.ButtonRequest(), m.Failure()]) client.set_input_flow(input_flow) - client.call(message) + session.call(message) @pytest.mark.parametrize( @@ -59,51 +61,56 @@ def input_flow(): ), ], ) -def test_cancel_message_via_initialize(client: Client, message): - resp = client.call_raw(message) +def test_cancel_message_via_initialize(session: Session, message): + raise Exception("TREZOR FW CANNOT HANDLE CANCEL FOR NOW, TODO FIX") + + resp = session.call_raw(message) assert isinstance(resp, m.ButtonRequest) - client._raw_write(m.ButtonAck()) - client._raw_write(m.Initialize()) + session._write(m.ButtonAck()) + session._write(m.Initialize()) - resp = client._raw_read() + resp = session._read() assert isinstance(resp, m.Features) @pytest.mark.models("core") -def test_cancel_on_paginated(client: Client): +def test_cancel_on_paginated(session: Session): """Check that device is responsive on paginated screen. See #1708.""" # In #1708, the device would ignore USB (or UDP) events while waiting for the user # to page through the screen. This means that this testcase, instead of failing, # would get stuck waiting for the _raw_read result. # I'm not spending the effort to modify the testcase to cause a _failure_ if that # happens again. Just be advised that this should not get stuck. + + raise Exception("TREZOR FW CANNOT HANDLE CANCEL FOR NOW, TODO FIX") + message = m.SignMessage( message=b"hello" * 64, address_n=TEST_ADDRESS_N, coin_name="Testnet", ) - resp = client.call_raw(message) + resp = session.call_raw(message) assert isinstance(resp, m.ButtonRequest) - client._raw_write(m.ButtonAck()) - client.debug.press_yes() + session._write(m.ButtonAck()) + session.client.debug.press_yes() - resp = client._raw_read() + resp = session._read() assert isinstance(resp, m.ButtonRequest) # In T2B1, confirm message is no longer paginated by default, # user needs to click info button - if client.layout_type is LayoutType.TR: - client._raw_write(m.ButtonAck()) - client.debug.press_right() - resp = client._raw_read() + if session.client.layout_type is LayoutType.TR: + session._write(m.ButtonAck()) + session.client.debug.press_right() + resp = session._read() assert isinstance(resp, m.ButtonRequest) assert resp.pages is not None - client._raw_write(m.ButtonAck()) + session._write(m.ButtonAck()) - client._raw_write(m.Cancel()) - resp = client._raw_read() + session._write(m.Cancel()) + resp = session._read() assert isinstance(resp, m.Failure) assert resp.code == m.FailureType.ActionCancelled diff --git a/tests/device_tests/test_debuglink.py b/tests/device_tests/test_debuglink.py index 747613db127..9ba6b42eaf9 100644 --- a/tests/device_tests/test_debuglink.py +++ b/tests/device_tests/test_debuglink.py @@ -17,6 +17,7 @@ import pytest from trezorlib import debuglink, device, messages, misc +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.tools import parse_path from trezorlib.transport import udp @@ -40,27 +41,32 @@ def test_mnemonic(client: Client): @pytest.mark.models("legacy") @pytest.mark.setup_client(mnemonic=MNEMONIC12, pin="1234", passphrase="") -def test_pin(client: Client): - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) +def test_pin(session: Session): + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.PinMatrixRequest) - state = client.debug.state() - assert state.pin == "1234" - assert state.matrix != "" + with session.client as client: + state = client.debug.state() + assert state.pin == "1234" + assert state.matrix != "" - pin_encoded = client.debug.encode_pin("1234") - resp = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) - assert isinstance(resp, messages.PassphraseRequest) + pin_encoded = client.debug.encode_pin("1234") + resp = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + assert isinstance(resp, messages.PassphraseRequest) - resp = client.call_raw(messages.PassphraseAck(passphrase="")) - assert isinstance(resp, messages.Address) + resp = session.call_raw(messages.PassphraseAck(passphrase="")) + assert isinstance(resp, messages.Address) @pytest.mark.models("core") -def test_softlock_instability(client: Client): +def test_softlock_instability(session: Session): + raise Exception("THIS NEEDS TO BE FIXED") + def load_device(): debuglink.load_device( - client, + session, mnemonic=MNEMONIC12, pin="1234", passphrase_protection=False, @@ -68,27 +74,27 @@ def load_device(): ) # start from a clean slate: - resp = client.debug.reseed(0) + resp = session.client.debug.reseed(0) if isinstance(resp, messages.Failure) and not isinstance( - client.transport, udp.UdpTransport + session.client.transport, udp.UdpTransport ): pytest.xfail("reseed only supported on emulator") - device.wipe(client) - entropy_after_wipe = misc.get_entropy(client, 16) + device.wipe(session) + entropy_after_wipe = misc.get_entropy(session, 16) # configure and wipe the device load_device() - client.debug.reseed(0) - device.wipe(client) - assert misc.get_entropy(client, 16) == entropy_after_wipe + session.client.debug.reseed(0) + device.wipe(session) + assert misc.get_entropy(session, 16) == entropy_after_wipe load_device() # the device has PIN -> lock it - client.call(messages.LockDevice()) - client.debug.reseed(0) + session.call(messages.LockDevice()) + session.client.debug.reseed(0) # wipe_device should succeed with no need to unlock - device.wipe(client) + device.wipe(session) # the device is now trying to run the lockscreen, which attempts to unlock. # If the device actually called config.unlock(), it would use additional randomness. # That is undesirable. Assert that the returned entropy is still the same. - assert misc.get_entropy(client, 16) == entropy_after_wipe + assert misc.get_entropy(session, 16) == entropy_after_wipe diff --git a/tests/device_tests/test_firmware_hash.py b/tests/device_tests/test_firmware_hash.py index 50eb063c2b3..217be1c45d9 100644 --- a/tests/device_tests/test_firmware_hash.py +++ b/tests/device_tests/test_firmware_hash.py @@ -3,7 +3,7 @@ import pytest from trezorlib import firmware, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session # size of FIRMWARE_AREA, see core/embed/models/model_*_layout.c FIRMWARE_LENGTHS = { @@ -15,35 +15,35 @@ } -def test_firmware_hash_emu(client: Client) -> None: - if client.features.fw_vendor != "EMULATOR": +def test_firmware_hash_emu(session: Session) -> None: + if session.features.fw_vendor != "EMULATOR": pytest.skip("Only for emulator") - data = b"\xff" * FIRMWARE_LENGTHS[client.model] + data = b"\xff" * FIRMWARE_LENGTHS[session.model] expected_hash = blake2s(data).digest() - hash = firmware.get_hash(client, None) + hash = firmware.get_hash(session, None) assert hash == expected_hash challenge = b"Hello Trezor" expected_hash = blake2s(data, key=challenge).digest() - hash = firmware.get_hash(client, challenge) + hash = firmware.get_hash(session, challenge) assert hash == expected_hash -def test_firmware_hash_hw(client: Client) -> None: - if client.features.fw_vendor == "EMULATOR": +def test_firmware_hash_hw(session: Session) -> None: + if session.features.fw_vendor == "EMULATOR": pytest.skip("Only for hardware") # TODO get firmware image from outside the environment, check for actual result challenge = b"Hello Trezor" - empty_data = b"\xff" * FIRMWARE_LENGTHS[client.model] + empty_data = b"\xff" * FIRMWARE_LENGTHS[session.model] empty_hash = blake2s(empty_data).digest() empty_hash_challenge = blake2s(empty_data, key=challenge).digest() - hash = firmware.get_hash(client, None) + hash = firmware.get_hash(session, None) assert hash != empty_hash - hash2 = firmware.get_hash(client, challenge) + hash2 = firmware.get_hash(session, challenge) assert hash != hash2 assert hash2 != empty_hash_challenge diff --git a/tests/device_tests/test_language.py b/tests/device_tests/test_language.py index d313608ee20..a80ab92ecf1 100644 --- a/tests/device_tests/test_language.py +++ b/tests/device_tests/test_language.py @@ -23,7 +23,7 @@ from trezorlib import debuglink, device, exceptions, messages, models from trezorlib._internal import translations -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import message_filters from ..translations import ( @@ -57,228 +57,236 @@ def get_ping_title(lang: str) -> str: @pytest.fixture -def client(client: Client) -> Iterator[Client]: - lang_before = client.features.language or "" +def session(session: Session) -> Iterator[Session]: + lang_before = session.features.language or "" try: - set_language(client, "en") - yield client + set_language(session, "en") + yield session finally: - set_language(client, lang_before[:2]) + set_language(session, lang_before[:2]) -def _check_ping_screen_texts(client: Client, title: str, right_button: str) -> None: - def ping_input_flow(client: Client, title: str, right_button: str): +def _check_ping_screen_texts(session: Session, title: str, right_button: str) -> None: + def ping_input_flow(session: Session, title: str, right_button: str): yield - layout = client.debug.read_layout() + layout = session.client.debug.read_layout() assert layout.title().upper() == title.upper() assert layout.button_contents()[-1].upper() == right_button.upper() - client.debug.press_yes() + session.client.debug.press_yes() # TT does not have a right button text (but a green OK tick) - if client.model in (models.T2T1, models.T3T1): + if session.model in (models.T2T1, models.T3T1): right_button = "-" - with client: + with session.client as client: client.watch_layout(True) - client.set_input_flow(ping_input_flow(client, title, right_button)) - ping = client.call(messages.Ping(message="ahoj!", button_protection=True)) + client.set_input_flow(ping_input_flow(session, title, right_button)) + ping = session.call(messages.Ping(message="ahoj!", button_protection=True)) assert ping == messages.Success(message="ahoj!") -def test_error_too_long(client: Client): - assert client.features.language == "en-US" +def test_error_too_long(session: Session): + assert session.features.language == "en-US" # Translations too long # Sending more than allowed by the flash capacity - max_length = MAX_DATA_LENGTH[client.model] - with pytest.raises(exceptions.TrezorFailure, match="Translations too long"), client: + max_length = MAX_DATA_LENGTH[session.model] + with pytest.raises( + exceptions.TrezorFailure, match="Translations too long" + ), session: bad_data = (max_length + 1) * b"a" - device.change_language(client, language_data=bad_data) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + device.change_language(session, language_data=bad_data) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_data_length(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_data_length(session: Session): + assert session.features.language == "en-US" # Invalid data length # Sending more data than advertised in the header - with pytest.raises(exceptions.TrezorFailure, match="Invalid data length"), client: - good_data = build_and_sign_blob("cs", client) + with pytest.raises(exceptions.TrezorFailure, match="Invalid data length"), session: + good_data = build_and_sign_blob("cs", session) bad_data = good_data + b"abcd" - device.change_language(client, language_data=bad_data) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + device.change_language(session, language_data=bad_data) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_header_magic(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_header_magic(session: Session): + assert session.features.language == "en-US" # Invalid header magic # Does not match the expected magic with pytest.raises( exceptions.TrezorFailure, match="Invalid translations data" - ), client: - good_data = build_and_sign_blob("cs", client) + ), session: + good_data = build_and_sign_blob("cs", session) bad_data = 4 * b"a" + good_data[4:] - device.change_language(client, language_data=bad_data) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + device.change_language(session, language_data=bad_data) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_data_hash(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_data_hash(session: Session): + assert session.features.language == "en-US" # Invalid data hash # Changing the data after their hash has been calculated with pytest.raises( exceptions.TrezorFailure, match="Translation data verification failed" - ), client: - good_data = build_and_sign_blob("cs", client) + ), session: + good_data = build_and_sign_blob("cs", session) bad_data = good_data[:-8] + 8 * b"a" device.change_language( - client, + session, language_data=bad_data, ) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_version_mismatch(client: Client): - assert client.features.language == "en-US" +def test_error_version_mismatch(session: Session): + assert session.features.language == "en-US" # Translations version mismatch # Change the version to one not matching the current device with pytest.raises( exceptions.TrezorFailure, match="Translations version mismatch" - ), client: - blob = prepare_blob("cs", client.model, (3, 5, 4, 0)) + ), session: + blob = prepare_blob("cs", session.model, (3, 5, 4, 0)) device.change_language( - client, + session, language_data=sign_blob(blob), ) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_signature(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_signature(session: Session): + assert session.features.language == "en-US" # Invalid signature # Changing the data in the signature section with pytest.raises( exceptions.TrezorFailure, match="Invalid translations data" - ), client: - blob = prepare_blob("cs", client.model, client.version) + ), session: + blob = prepare_blob("cs", session.model, session.version) blob.proof = translations.Proof( merkle_proof=[], sigmask=0b011, signature=b"a" * 64, ) device.change_language( - client, + session, language_data=blob.build(), ) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) @pytest.mark.parametrize("lang", LANGUAGES) -def test_full_language_change(client: Client, lang: str): - assert client.features.language == "en-US" - assert client.features.language_version_matches is True +def test_full_language_change(session: Session, lang: str): + raise Exception("Investigate why it fails") + assert session.features.language == "en-US" + assert session.features.language_version_matches is True # Setting selected language - set_language(client, lang) - assert client.features.language[:2] == lang - assert client.features.language_version_matches is True - _check_ping_screen_texts(client, get_ping_title(lang), get_ping_button(lang)) + set_language(session, lang) + assert session.features.language[:2] == lang + assert session.features.language_version_matches is True + _check_ping_screen_texts(session, get_ping_title(lang), get_ping_button(lang)) # Setting the default language via empty data - set_language(client, "en") - assert client.features.language == "en-US" - assert client.features.language_version_matches is True - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + set_language(session, "en") + assert session.features.language == "en-US" + assert session.features.language_version_matches is True + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_language_is_removed_after_wipe(client: Client): - assert client.features.language == "en-US" +def test_language_is_removed_after_wipe(session: Session): + raise Exception("Test is not ressurected after") + assert session.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) # Setting cs language - set_language(client, "cs") - assert client.features.language == "cs-CZ" + set_language(session, "cs") + assert session.features.language == "cs-CZ" - _check_ping_screen_texts(client, get_ping_title("cs"), get_ping_button("cs")) + _check_ping_screen_texts(session, get_ping_title("cs"), get_ping_button("cs")) # Wipe device - device.wipe(client) - assert client.features.language == "en-US" + device.wipe(session) + assert session.features.language == "en-US" # Load it again debuglink.load_device( - client, + session, mnemonic=" ".join(["all"] * 12), pin=None, passphrase_protection=False, label="test", ) - assert client.features.language == "en-US" + assert session.features.language == "en-US" + + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) +def test_translations_renders_on_screen(session: Session): + raise Exception("Investigate why it fails") -def test_translations_renders_on_screen(client: Client): czech_data = get_lang_json("cs") # Setting some values of words__confirm key and checking that in ping screen title - assert client.features.language == "en-US" + assert session.features.language == "en-US" # Normal english - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) # Normal czech - set_language(client, "cs") - assert client.features.language == "cs-CZ" - _check_ping_screen_texts(client, get_ping_title("cs"), get_ping_button("cs")) + set_language(session, "cs") + assert session.features.language == "cs-CZ" + _check_ping_screen_texts(session, get_ping_title("cs"), get_ping_button("cs")) # Modified czech - changed value czech_data_copy = deepcopy(czech_data) new_czech_confirm = "ABCD" czech_data_copy["translations"]["words__confirm"] = new_czech_confirm device.change_language( - client, - language_data=build_and_sign_blob(czech_data_copy, client), + session, + language_data=build_and_sign_blob(czech_data_copy, session), ) - _check_ping_screen_texts(client, new_czech_confirm, get_ping_button("cs")) + _check_ping_screen_texts(session, new_czech_confirm, get_ping_button("cs")) # Modified czech - key deleted completely, english is shown czech_data_copy = deepcopy(czech_data) del czech_data_copy["translations"]["words__confirm"] device.change_language( - client, - language_data=build_and_sign_blob(czech_data_copy, client), + session, + language_data=build_and_sign_blob(czech_data_copy, session), ) - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("cs")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("cs")) -def test_reject_update(client: Client): - assert client.features.language == "en-US" +def test_reject_update(session: Session): + raise Exception("Investigate why it fails") + + assert session.features.language == "en-US" lang = "cs" - language_data = build_and_sign_blob(lang, client) + language_data = build_and_sign_blob(lang, session) def input_flow_reject(): yield - client.debug.press_no() + session.client.debug.press_no() - with pytest.raises(exceptions.Cancelled), client: + with pytest.raises(exceptions.Cancelled), session, session.client as client: client.set_input_flow(input_flow_reject) - device.change_language(client, language_data) + device.change_language(session, language_data) - assert client.features.language == "en-US" + assert session.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) def _maybe_confirm_set_language( - client: Client, lang: str, show_display: bool | None, is_displayed: bool + session: Session, lang: str, show_display: bool | None, is_displayed: bool ) -> None: - language_data = build_and_sign_blob(lang, client) + language_data = build_and_sign_blob(lang, session) CHUNK_SIZE = 1024 @@ -302,21 +310,21 @@ def chunks(data, size): else: expected_responses = expected_responses_silent - with client: - client.set_expected_responses(expected_responses) - device.change_language(client, language_data, show_display=show_display) - assert client.features.language is not None - assert client.features.language[:2] == lang + with session: + session.set_expected_responses(expected_responses) + device.change_language(session, language_data, show_display=show_display) + assert session.features.language is not None + assert session.features.language[:2] == lang # explicitly handle the cases when expected_responses are correct for # change_language but incorrect for selected is_displayed mode (otherwise the # user would get an unhelpful generic expected_responses mismatch) - if is_displayed and client.actual_responses == expected_responses_silent: + if is_displayed and session.actual_responses == expected_responses_silent: raise AssertionError("Change should have been visible but was silent") - if not is_displayed and client.actual_responses == expected_responses_confirm: + if not is_displayed and session.actual_responses == expected_responses_confirm: raise AssertionError("Change should have been silent but was visible") # if the expected_responses do not match either, the generic error message will - # be raised by the client context manager + # be raised by the session context manager @pytest.mark.parametrize( @@ -328,61 +336,63 @@ def chunks(data, size): ], ) @pytest.mark.setup_client(uninitialized=True) -def test_silent_first_install(client: Client, show_display: bool, is_displayed: bool): - assert not client.features.initialized - _maybe_confirm_set_language(client, "cs", show_display, is_displayed) +def test_silent_first_install(session: Session, show_display: bool, is_displayed: bool): + assert not session.features.initialized + _maybe_confirm_set_language(session, "cs", show_display, is_displayed) @pytest.mark.parametrize("show_display", (True, None)) -def test_switch_from_english(client: Client, show_display: bool | None): - assert client.features.initialized - assert client.features.language == "en-US" - _maybe_confirm_set_language(client, "cs", show_display, True) +def test_switch_from_english(session: Session, show_display: bool | None): + assert session.features.initialized + assert session.features.language == "en-US" + _maybe_confirm_set_language(session, "cs", show_display, True) -def test_switch_from_english_not_silent(client: Client): - assert client.features.initialized - assert client.features.language == "en-US" +def test_switch_from_english_not_silent(session: Session): + assert session.features.initialized + assert session.features.language == "en-US" with pytest.raises( exceptions.TrezorFailure, match="Cannot change language without user prompt" ): - _maybe_confirm_set_language(client, "cs", False, False) + _maybe_confirm_set_language(session, "cs", False, False) @pytest.mark.setup_client(uninitialized=True) -def test_switch_language(client: Client): - assert not client.features.initialized - assert client.features.language == "en-US" +def test_switch_language(session: Session): + assert not session.features.initialized + assert session.features.language == "en-US" # switch to Czech silently - _maybe_confirm_set_language(client, "cs", False, False) + _maybe_confirm_set_language(session, "cs", False, False) # switch to French silently with pytest.raises( exceptions.TrezorFailure, match="Cannot change language without user prompt" ): - _maybe_confirm_set_language(client, "fr", False, False) + _maybe_confirm_set_language(session, "fr", False, False) # switch to French with display, explicitly - _maybe_confirm_set_language(client, "fr", True, True) + _maybe_confirm_set_language(session, "fr", True, True) # switch back to Czech with display, implicitly - _maybe_confirm_set_language(client, "cs", None, True) + _maybe_confirm_set_language(session, "cs", None, True) -def test_header_trailing_data(client: Client): +def test_header_trailing_data(session: Session): """Adding trailing data to _header_ section specifically must be accepted by firmware, as long as the blob is otherwise valid and signed. (this ensures forwards compatibility if we extend the header) """ - assert client.features.language == "en-US" + raise Exception("Investigate why it fails") + + assert session.features.language == "en-US" lang = "cs" - blob = prepare_blob(lang, client.model, client.version) + blob = prepare_blob(lang, session.model, session.version) blob.header_bytes += b"trailing dataa" assert len(blob.header_bytes) % 2 == 0, "Trailing data must keep the 2-alignment" language_data = sign_blob(blob) - device.change_language(client, language_data) - assert client.features.language == "cs-CZ" - _check_ping_screen_texts(client, get_ping_title(lang), get_ping_button(lang)) + device.change_language(session, language_data) + assert session.features.language == "cs-CZ" + _check_ping_screen_texts(session, get_ping_title(lang), get_ping_button(lang)) diff --git a/tests/device_tests/test_msg_applysettings.py b/tests/device_tests/test_msg_applysettings.py index 5ff88b017f4..04adf56f957 100644 --- a/tests/device_tests/test_msg_applysettings.py +++ b/tests/device_tests/test_msg_applysettings.py @@ -19,7 +19,7 @@ import pytest from trezorlib import btc, device, exceptions, messages, misc, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ..input_flows import InputFlowConfirmAllWarnings @@ -30,7 +30,7 @@ EXPECTED_RESPONSES_NOPIN = [ messages.ButtonRequest(), messages.Success, - messages.Features, + # messages.Features, ] EXPECTED_RESPONSES_PIN_T1 = [messages.PinMatrixRequest()] + EXPECTED_RESPONSES_NOPIN EXPECTED_RESPONSES_PIN_TT = [messages.ButtonRequest()] + EXPECTED_RESPONSES_NOPIN @@ -38,7 +38,7 @@ EXPECTED_RESPONSES_EXPERIMENTAL_FEATURES = [ messages.ButtonRequest, messages.Success, - messages.Features, + # messages.Features, ] PIN4 = "1234" @@ -50,173 +50,180 @@ TR_HOMESCREEN = b"TOIG\x80\x00@\x00\x0c\x04\x00\x00\xa5RY\x96\xdc0\x08\xe4\x06\xdc\xff\x96\xdc\x80\xa8\x16\x90z\xd2y\xf9\x18{\xc0\xf1\xe5\xc9y\x0f\x95\x7f;C\xfe\xd0\xe1K\xefS\x96o\xf9\xb739\x1a\n\xc7\xde\x89\xff\x11\xd8=\xd5\xcf\xb1\x9f\xf7U\xf2\xa3spx\xb0&t\xe4\xaf3x\xcaT\xec\xe50k\xb4\xe8\nl\x16\xbf`'\xf3\xa7Z\x8d-\x98h\x1c\x03\x07\xf0\xcf\xf0\x8aD\x13\xec\x1f@y\x9e\xd8\xa3\xc6\x84F*\x1dx\x02U\x00\x10\xd3\x8cF\xbb\x97y\x18J\xa5T\x18x\x1c\x02\xc6\x90\xfd\xdc\x89\x1a\x94\xb3\xeb\x01\xdc\x9f2\x8c/\xe9/\x8c$\xc6\x9c\x1e\xf8C\x8f@\x17Q\x1d\x11F\x02g\xe4A \xebO\xad\xc6\xe3F\xa7\x8b\xf830R\x82\x0b\x8e\x16\x1dL,\x14\xce\x057tht^\xfe\x00\x9e\x86\xc2\x86\xa3b~^Bl\x18\x1f\xb9+w\x11\x14\xceO\xe9\xb6W\xd8\x85\xbeX\x17\xc2\x13,M`y\xd1~\xa3/\xcd0\xed6\xda\xf5b\x15\xb5\x18\x0f_\xf6\xe2\xdc\x8d\x8ez\xdd\xd5\r^O\x9e\xb6|\xc4e\x0f\x1f\xff0k\xd4\xb8\n\x12{\x8d\x8a>\x0b5\xa2o\xf2jZ\xe5\xee\xdc\x14\xd1\xbd\xd5\xad\x95\xbe\x8c\t\x8f\xb9\xde\xc4\xa551,#`\x94'\x1b\xe7\xd53u\x8fq\xbd4v>3\x8f\xcc\x1d\xbcV>\x90^\xb3L\xc3\xde0]\x05\xec\x83\xd0\x07\xd2(\xbb\xcf+\xd0\xc7ru\xecn\x14k-\xc0|\xd2\x0e\xe8\xe08\xa8<\xdaQ+{\xad\x01\x02#\x16\x12+\xc8\xe0P\x06\xedD7\xae\xd0\xa4\x97\x84\xe32\xca;]\xd04x:\x94`\xbe\xca\x89\xe2\xcb\xc5L\x03\xac|\xe7\xd5\x1f\xe3\x08_\xee!\x04\xd2\xef\x00\xd8\xea\x91p)\xed^#\xb1\xa78eJ\x00F*\xc7\xf1\x0c\x1a\x04\xf5l\xcc\xfc\xa4\x83,c\x1e\xb1>\xc5q\x8b\xe6Y9\xc7\x07\xfa\xcf\xf9\x15\x8a\xdd\x11\x1f\x98\x82\xbe>\xbe+u#g]aC\\\x1bC\xb1\xe8P\xce2\xd6\xb6r\x12\x1c*\xd3\x92\x9d9\xf9cB\x82\xf9S.\xc2B\xe7\x9d\xcf\xdb\xf3\xfd#\xfd\x94x9p\x8d%\x14\xa5\xb3\xe9p5\xa1;~4:\xcd\xe0&\x11\x1d\xe9\xf6\xa1\x1fw\xf54\x95eWx\xda\xd0u\x91\x86\xb8\xbc\xdf\xdc\x008f\x15\xc6\xf6\x7f\xf0T\xb8\xc1\xa3\xc5_A\xc0G\x930\xe7\xdc=\xd5\xa7\xc1\xbcI\x16\xb8s\x9c&\xaa\x06\xc1}\x8b\x19\x9d'c\xc3\xe3^\xc3m\xb6n\xb0(\x16\xf6\xdeg\xb3\x96:i\xe5\x9c\x02\x93\x9fF\x9f-\xa7\"w\xf3X\x9f\x87\x08\x84\"v,\xab!9:. from trezorlib import messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_ping(client: Client): - with client: - client.set_expected_responses([messages.Success]) - res = client.ping("random data") - assert res == "random data" +def test_ping(session: Session): + with session: + session.set_expected_responses([messages.Success]) + res = session.call(messages.Ping(message="random data")) + assert res.message == "random data" - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), messages.Success, ] ) - res = client.ping("random data", button_protection=True) - assert res == "random data" + res = session.call( + messages.Ping(message="random data 2", button_protection=True) + ) + assert res.message == "random data 2" diff --git a/tests/device_tests/test_msg_sd_protect.py b/tests/device_tests/test_msg_sd_protect.py index fb305613825..2e1f3b5bf53 100644 --- a/tests/device_tests/test_msg_sd_protect.py +++ b/tests/device_tests/test_msg_sd_protect.py @@ -17,7 +17,7 @@ import pytest from trezorlib import debuglink, device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import SdProtectOperationType as Op @@ -26,64 +26,69 @@ pytestmark = [pytest.mark.models("core", skip="safe3"), pytest.mark.sd_card] -def test_enable_disable(client: Client): - assert client.features.sd_protection is False +def test_enable_disable(session: Session): + assert session.features.sd_protection is False # Disabling SD protection should fail with pytest.raises(TrezorFailure): - device.sd_protect(client, Op.DISABLE) + device.sd_protect(session, Op.DISABLE) # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Enabling SD protection should fail with pytest.raises(TrezorFailure): - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Disable SD protection - device.sd_protect(client, Op.DISABLE) - assert client.features.sd_protection is False + device.sd_protect(session, Op.DISABLE) + assert session.features.sd_protection is False -def test_refresh(client: Client): - assert client.features.sd_protection is False +def test_refresh(session: Session): + assert session.features.sd_protection is False # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Refresh SD protection - device.sd_protect(client, Op.REFRESH) - assert client.features.sd_protection is True + device.sd_protect(session, Op.REFRESH) + assert session.features.sd_protection is True # Disable SD protection - device.sd_protect(client, Op.DISABLE) - assert client.features.sd_protection is False + device.sd_protect(session, Op.DISABLE) + assert session.features.sd_protection is False # Refreshing SD protection should fail with pytest.raises(TrezorFailure): - device.sd_protect(client, Op.REFRESH) - assert client.features.sd_protection is False + device.sd_protect(session, Op.REFRESH) + assert session.features.sd_protection is False -def test_wipe(client: Client): +def test_wipe(session: Session): # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Wipe device (this wipes internal storage) - device.wipe(client) - assert client.features.sd_protection is False + raise Exception("TEST FAILS AFTER WIPE DEVICE") + device.wipe(session) + assert session.features.sd_protection is False # Restore device to working status debuglink.load_device( - client, mnemonic=MNEMONIC12, pin=None, passphrase_protection=False, label="test" + session, + mnemonic=MNEMONIC12, + pin=None, + passphrase_protection=False, + label="test", ) - assert client.features.sd_protection is False + assert session.features.sd_protection is False # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Refresh SD protection - device.sd_protect(client, Op.REFRESH) + device.sd_protect(session, Op.REFRESH) diff --git a/tests/device_tests/test_msg_show_device_tutorial.py b/tests/device_tests/test_msg_show_device_tutorial.py index 52904c50c50..f6a083879f3 100644 --- a/tests/device_tests/test_msg_show_device_tutorial.py +++ b/tests/device_tests/test_msg_show_device_tutorial.py @@ -17,11 +17,11 @@ import pytest from trezorlib import device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models("safe") -def test_tutorial(client: Client): - device.show_device_tutorial(client) - assert client.features.initialized is False +def test_tutorial(session: Session): + device.show_device_tutorial(session) + assert session.features.initialized is False diff --git a/tests/device_tests/test_passphrase_slip39_advanced.py b/tests/device_tests/test_passphrase_slip39_advanced.py index 64ef1f5e577..697ec442703 100644 --- a/tests/device_tests/test_passphrase_slip39_advanced.py +++ b/tests/device_tests/test_passphrase_slip39_advanced.py @@ -34,13 +34,12 @@ def test_128bit_passphrase(client: Client): xprv9s21ZrQH143K3dzDLfeY3cMp23u5vDeFYftu5RPYZPucKc99mNEddU4w99GxdgUGcSfMpVDxhnR1XpJzZNXRN1m6xNgnzFS5MwMP6QyBRKV """ assert client.features.passphrase_protection is True - client.use_passphrase("TREZOR") - address = get_test_address(client) + session = client.get_session(passphrase="TREZOR") + address = get_test_address(session) assert address == "mkKDUMRR1CcK8eLAzCZAjKnNbCquPoWPxN" - client.clear_session() - client.use_passphrase("ROZERT") - address_compare = get_test_address(client) + session = client.get_session(passphrase="ROZERT") + address_compare = get_test_address(session) assert address != address_compare @@ -53,11 +52,10 @@ def test_256bit_passphrase(client: Client): xprv9s21ZrQH143K2UspC9FRPfQC9NcDB4HPkx1XG9UEtuceYtpcCZ6ypNZWdgfxQ9dAFVeD1F4Zg4roY7nZm2LB7THPD6kaCege3M7EuS8v85c """ assert client.features.passphrase_protection is True - client.use_passphrase("TREZOR") - address = get_test_address(client) + session = client.get_session(passphrase="TREZOR") + address = get_test_address(session) assert address == "mxVtGxUJ898WLzPMmy6PT1FDHD1GUCWGm7" - client.clear_session() - client.use_passphrase("ROZERT") - address_compare = get_test_address(client) + session = client.get_session(passphrase="ROZERT") + address_compare = get_test_address(session) assert address != address_compare diff --git a/tests/device_tests/test_passphrase_slip39_basic.py b/tests/device_tests/test_passphrase_slip39_basic.py index de0e7a734b2..120a6f556ec 100644 --- a/tests/device_tests/test_passphrase_slip39_basic.py +++ b/tests/device_tests/test_passphrase_slip39_basic.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ..common import ( MNEMONIC_SLIP39_BASIC_20_3of6, @@ -28,14 +28,14 @@ @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6, passphrase="TREZOR") -def test_3of6_passphrase(client: Client): +def test_3of6_passphrase(session: Session): """ BIP32 Root Key for passphrase TREZOR: provided by Andrew, address calculated via https://iancoleman.io/bip39/ xprv9s21ZrQH143K2pMWi8jrTawHaj16uKk4CSbvo4Zt61tcrmuUDMx2o1Byzcr3saXNGNvHP8zZgXVdJHsXVdzYFPavxvCyaGyGr1WkAYG83ce """ - assert client.features.passphrase_protection is True - address = get_test_address(client) + assert session.features.passphrase_protection is True + address = get_test_address(session) assert address == "mi4HXfRJAqCDyEdet5veunBvXLTKSxpuim" @@ -46,25 +46,25 @@ def test_3of6_passphrase(client: Client): ), passphrase="TREZOR", ) -def test_2of5_passphrase(client: Client): +def test_2of5_passphrase(session: Session): """ BIP32 Root Key for passphrase TREZOR: provided by Andrew, address calculated via https://iancoleman.io/bip39/ xprv9s21ZrQH143K2o6EXEHpVy8TCYoMmkBnDCCESLdR2ieKwmcNG48ck2XJQY4waS7RUQcXqR9N7HnQbUVEDMWYyREdF1idQqxFHuCfK7fqFni """ - assert client.features.passphrase_protection is True - address = get_test_address(client) + assert session.features.passphrase_protection is True + address = get_test_address(session) assert address == "mjXH4pN7TtbHp3tWLqVKktKuaQeByHMoBZ" @pytest.mark.setup_client( mnemonic=MNEMONIC_SLIP39_BASIC_EXT_20_2of3, passphrase="TREZOR" ) -def test_2of3_ext_passphrase(client: Client): +def test_2of3_ext_passphrase(session: Session): """ BIP32 Root Key for passphrase TREZOR: xprv9s21ZrQH143K4FS1qQdXYAFVAHiSAnjj21YAKGh2CqUPJ2yQhMmYGT4e5a2tyGLiVsRgTEvajXkxhg92zJ8zmWZas9LguQWz7WZShfJg6RS """ - assert client.features.passphrase_protection is True - address = get_test_address(client) + assert session.features.passphrase_protection is True + address = get_test_address(session) assert address == "moELJhDbGK41k6J2ePYh2U8uc5qskC663C" diff --git a/tests/device_tests/test_pin.py b/tests/device_tests/test_pin.py index ee58790c046..963825192e4 100644 --- a/tests/device_tests/test_pin.py +++ b/tests/device_tests/test_pin.py @@ -19,7 +19,7 @@ import pytest from trezorlib import messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import PinException from ..common import check_pin_backoff_time, get_test_address @@ -32,18 +32,19 @@ @pytest.mark.setup_client(pin=None) -def test_no_protection(client: Client): - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) +def test_no_protection(session: Session): + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) -def test_correct_pin(client: Client): - with client: +def test_correct_pin(session: Session): + with session, session.client as client: client.use_pin_sequence([PIN4]) + session.lock() # TODO is the lock here necessary/correctly? # Expected responses differ between T1 and TT - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ (is_t1, messages.PinMatrixRequest), ( @@ -53,45 +54,46 @@ def test_correct_pin(client: Client): messages.Address, ] ) - # client.set_expected_responses([messages.ButtonRequest, messages.Address]) - get_test_address(client) + get_test_address(session) @pytest.mark.models("legacy") -def test_incorrect_pin_t1(client: Client): +def test_incorrect_pin_t1(session: Session): with pytest.raises(PinException): - client.use_pin_sequence([BAD_PIN]) - get_test_address(client) + session.client.use_pin_sequence([BAD_PIN]) + get_test_address(session) @pytest.mark.models("core") -def test_incorrect_pin_t2(client: Client): - with client: +def test_incorrect_pin_t2(session: Session): + session.lock() # TODO is the lock here necessary/correctly? + with session, session.client as client: # After first incorrect attempt, TT will not raise an error, but instead ask for another attempt client.use_pin_sequence([BAD_PIN, PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), messages.Address, ] ) - get_test_address(client) + get_test_address(session) @pytest.mark.models("legacy") -def test_exponential_backoff_t1(client: Client): +def test_exponential_backoff_t1(session: Session): for attempt in range(3): start = time.time() - with client, pytest.raises(PinException): + with session, session.client as client, pytest.raises(PinException): client.use_pin_sequence([BAD_PIN]) - get_test_address(client) + get_test_address(session) check_pin_backoff_time(attempt, start) @pytest.mark.models("core") -def test_exponential_backoff_t2(client: Client): - with client: +def test_exponential_backoff_t2(session: Session): + with session.client as client: IF = InputFlowPINBackoff(client, BAD_PIN, PIN4) client.set_input_flow(IF.get()) - get_test_address(client) + session.lock() + get_test_address(session) diff --git a/tests/device_tests/test_protection_levels.py b/tests/device_tests/test_protection_levels.py index 22ffb13b7f9..74e507c1201 100644 --- a/tests/device_tests/test_protection_levels.py +++ b/tests/device_tests/test_protection_levels.py @@ -18,6 +18,7 @@ from trezorlib import btc, device, messages, misc, models from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -43,9 +44,9 @@ pytestmark = pytest.mark.setup_client(pin=PIN4, passphrase=True) -def _pin_request(client: Client): +def _pin_request(session: Session): """Get appropriate PIN request for each model""" - if client.model is models.T1B1: + if session.model is models.T1B1: return messages.PinMatrixRequest else: return messages.ButtonRequest(code=B.PinEntry) @@ -58,171 +59,173 @@ def _assert_protection( with client: client.use_pin_sequence([PIN4]) client.ensure_unlocked() + client.refresh_features() assert client.features.pin_protection is pin assert client.features.passphrase_protection is passphrase - client.clear_session() + # TODO session.clear_session() -def test_initialize(client: Client): - _assert_protection(client) - with client: - client.set_expected_responses([messages.Features]) - client.init_device() +def test_initialize(session: Session): + _assert_protection(session.client) + with session: + session.set_expected_responses([messages.Features]) + raise Exception("INITIALIZE IS DISABLED") + # TODO session.init_device() @pytest.mark.models("core") @pytest.mark.setup_client(pin=PIN4) @pytest.mark.parametrize("passphrase", (True, False)) -def test_passphrase_reporting(client: Client, passphrase): +def test_passphrase_reporting(session: Session, passphrase): """On TT, passphrase_protection is a private setting, so a locked device should report passphrase_protection=None. """ - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - device.apply_settings(client, use_passphrase=passphrase) + device.apply_settings(session, use_passphrase=passphrase) - client.lock() + session.lock() # on a locked device, passphrase_protection should be None - assert client.features.unlocked is False - assert client.features.passphrase_protection is None + assert session.features.unlocked is False + assert session.features.passphrase_protection is None # on an unlocked device, protection should be reported accurately - _assert_protection(client, pin=True, passphrase=passphrase) + _assert_protection(session.client, pin=True, passphrase=passphrase) # after re-locking, the setting should be hidden again - client.lock() - assert client.features.unlocked is False - assert client.features.passphrase_protection is None + session.lock() + assert session.features.unlocked is False + assert session.features.passphrase_protection is None -def test_apply_settings(client: Client): - _assert_protection(client) - with client: +def test_apply_settings(session: Session): + _assert_protection(session.client) + with session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.ButtonRequest, messages.Success, messages.Features, ] - ) # TrezorClient reinitializes device - device.apply_settings(client, label="nazdar") + ) # TrezorSession reinitializes device + device.apply_settings(session, label="nazdar") @pytest.mark.models("legacy") -def test_change_pin_t1(client: Client): - _assert_protection(client) - with client: +def test_change_pin_t1(session: Session): + _assert_protection(session.client) + with session.client as client: client.use_pin_sequence([PIN4, PIN4, PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ messages.ButtonRequest, - _pin_request(client), - _pin_request(client), - _pin_request(client), + _pin_request(session), + _pin_request(session), + _pin_request(session), messages.Success, messages.Features, ] ) - device.change_pin(client) + device.change_pin(session) @pytest.mark.models("core") -def test_change_pin_t2(client: Client): - _assert_protection(client) - with client: +def test_change_pin_t2(session: Session): + _assert_protection(session.client) + with session.client as client: client.use_pin_sequence([PIN4, PIN4, PIN4, PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.ButtonRequest, - _pin_request(client), - _pin_request(client), - (client.layout_type is LayoutType.TR, messages.ButtonRequest), - _pin_request(client), + _pin_request(session), + _pin_request(session), + (session.client.layout_type is LayoutType.TR, messages.ButtonRequest), + _pin_request(session), messages.ButtonRequest, messages.Success, messages.Features, ] ) - device.change_pin(client) + device.change_pin(session) @pytest.mark.setup_client(pin=None, passphrase=False) -def test_ping(client: Client): - _assert_protection(client, pin=False, passphrase=False) - with client: - client.set_expected_responses([messages.ButtonRequest, messages.Success]) - client.ping("msg", True) +def test_ping(session: Session): + _assert_protection(session.client, pin=False, passphrase=False) + with session: + session.set_expected_responses([messages.ButtonRequest, messages.Success]) + session.call(messages.Ping("msg", True)) -def test_get_entropy(client: Client): - _assert_protection(client) - with client: +def test_get_entropy(session: Session): + _assert_protection(session.client) + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.ButtonRequest(code=B.ProtectCall), messages.Entropy, ] ) - misc.get_entropy(client, 10) + misc.get_entropy(session, 10) -def test_get_public_key(client: Client): - _assert_protection(client) - with client: +def test_get_public_key(session: Session): + _assert_protection(session.client) + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.PassphraseRequest, messages.PublicKey, ] ) - btc.get_public_node(client, []) + btc.get_public_node(session, []) -def test_get_address(client: Client): - _assert_protection(client) - with client: +def test_get_address(session: Session): + _assert_protection(session.client) + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.PassphraseRequest, messages.Address, ] ) - get_test_address(client) + get_test_address(session) -def test_wipe_device(client: Client): - _assert_protection(client) - with client: - client.set_expected_responses( +def test_wipe_device(session: Session): + _assert_protection(session.client) + with session: + session.set_expected_responses( [messages.ButtonRequest, messages.Success, messages.Features] ) - device.wipe(client) + device.wipe(session) @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models("legacy") -def test_reset_device(client: Client): - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - with WITH_MOCK_URANDOM, client: - client.set_expected_responses( +def test_reset_device(session: Session): + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + with WITH_MOCK_URANDOM, session: + session.set_expected_responses( [messages.ButtonRequest] + [messages.EntropyRequest] + [messages.ButtonRequest] * 24 + [messages.Success, messages.Features] ) device.reset( - client, + session, strength=128, passphrase_protection=True, pin_protection=False, @@ -232,7 +235,7 @@ def test_reset_device(client: Client): with pytest.raises(TrezorFailure): # This must fail, because device is already initialized # Using direct call because `device.reset` has its own check - client.call( + session.call( messages.ResetDevice( strength=128, passphrase_protection=True, @@ -244,30 +247,30 @@ def test_reset_device(client: Client): @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models("legacy") -def test_recovery_device(client: Client): - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - client.use_mnemonic(MNEMONIC12) - with client: - client.set_expected_responses( +def test_recovery_device(session: Session): + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + session.client.use_mnemonic(MNEMONIC12) + with session: + session.set_expected_responses( [messages.ButtonRequest] + [messages.WordRequest] * 24 + [messages.Success, messages.Features] ) device.recover( - client, + session, 12, False, False, "label", - input_callback=client.mnemonic_callback, + input_callback=session.client.mnemonic_callback, ) with pytest.raises(TrezorFailure): # This must fail, because device is already initialized # Using direct call because `device.recover` has its own check - client.call( + session.call( messages.RecoveryDevice( word_count=12, passphrase_protection=False, @@ -277,13 +280,13 @@ def test_recovery_device(client: Client): ) -def test_sign_message(client: Client): - _assert_protection(client) - with client: +def test_sign_message(session: Session): + _assert_protection(session.client) + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.PassphraseRequest, messages.ButtonRequest, messages.ButtonRequest, @@ -291,15 +294,15 @@ def test_sign_message(client: Client): ] ) btc.sign_message( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), "testing message" + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), "testing message" ) @pytest.mark.models("legacy") -def test_verify_message_t1(client: Client): - _assert_protection(client) - with client: - client.set_expected_responses( +def test_verify_message_t1(session: Session): + _assert_protection(session.client) + with session: + session.set_expected_responses( [ messages.ButtonRequest, messages.ButtonRequest, @@ -308,7 +311,7 @@ def test_verify_message_t1(client: Client): ] ) btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -319,13 +322,13 @@ def test_verify_message_t1(client: Client): @pytest.mark.models("core") -def test_verify_message_t2(client: Client): - _assert_protection(client) - with client: +def test_verify_message_t2(session: Session): + _assert_protection(session.client) + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.ButtonRequest, messages.ButtonRequest, messages.ButtonRequest, @@ -333,7 +336,7 @@ def test_verify_message_t2(client: Client): ] ) btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -343,7 +346,7 @@ def test_verify_message_t2(client: Client): ) -def test_signtx(client: Client): +def test_signtx(session: Session): # input tx: 50f6f1209ca92d7359564be803cb2c932cde7d370f7cee50fd1fad6790f6206d inp1 = messages.TxInputType( @@ -359,17 +362,17 @@ def test_signtx(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - _assert_protection(client) - with client: + _assert_protection(session.client) + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.PassphraseRequest, request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_50f6f1), @@ -382,7 +385,7 @@ def test_signtx(client: Client): request_finished(), ] ) - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TxCache("Bitcoin")) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TxCache("Bitcoin")) # def test_firmware_erase(): @@ -393,29 +396,29 @@ def test_signtx(client: Client): @pytest.mark.setup_client(pin=PIN4, passphrase=False) -def test_unlocked(client: Client): - assert client.features.unlocked is False +def test_unlocked(session: Session): + assert session.features.unlocked is False - _assert_protection(client, passphrase=False) - with client: + _assert_protection(session.client, passphrase=False) + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses([_pin_request(client), messages.Address]) - get_test_address(client) + session.set_expected_responses([_pin_request(session), messages.Address]) + get_test_address(session) - client.init_device() - assert client.features.unlocked is True - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) + # TODO session.init_device() + assert session.features.unlocked is True + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) @pytest.mark.setup_client(pin=None, passphrase=True) -def test_passphrase_cached(client: Client): - _assert_protection(client, pin=False) - with client: - client.set_expected_responses([messages.PassphraseRequest, messages.Address]) - get_test_address(client) - - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) +def test_passphrase_cached(session: Session): + _assert_protection(session.client, pin=False) + with session: + session.set_expected_responses([messages.PassphraseRequest, messages.Address]) + get_test_address(session) + + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) diff --git a/tests/device_tests/test_repeated_backup.py b/tests/device_tests/test_repeated_backup.py index 36afeae9059..01cc6d49f7c 100644 --- a/tests/device_tests/test_repeated_backup.py +++ b/tests/device_tests/test_repeated_backup.py @@ -17,8 +17,8 @@ import pytest -from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib import device, exceptions, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from .. import translations as TR @@ -35,194 +35,198 @@ @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) @WITH_MOCK_URANDOM -def test_repeated_backup(client: Client): - assert client.features.backup_availability == messages.BackupAvailability.Required - assert client.features.recovery_status == messages.RecoveryStatus.Nothing +def test_repeated_backup(session: Session): + assert session.features.backup_availability == messages.BackupAvailability.Required + assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics assert len(mnemonics) == 5 # cannot backup, since we already just did that! assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - ret = device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + ret = device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ret == messages.Success(message="Backup unlocked") assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup # we can now perform another backup - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False, repeated=True) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) # the backup feature is locked again... assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_SINGLE_EXT_20) @WITH_MOCK_URANDOM -def test_repeated_backup_upgrade_single(client: Client): +def test_repeated_backup_upgrade_single(session: Session): assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing - assert client.features.backup_type == messages.BackupType.Slip39_Single_Extendable + assert session.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.backup_type == messages.BackupType.Slip39_Single_Extendable # unlock repeated backup by entering the single share - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, MNEMONIC_SLIP39_SINGLE_EXT_20, unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - ret = device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + ret = device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ret == messages.Success(message="Backup unlocked") assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup # we can now perform another backup - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False, repeated=True) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) # backup type was upgraded: - assert client.features.backup_type == messages.BackupType.Slip39_Basic_Extendable + assert session.features.backup_type == messages.BackupType.Slip39_Basic_Extendable # the backup feature is locked again... assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) @WITH_MOCK_URANDOM -def test_repeated_backup_cancel(client: Client): - assert client.features.backup_availability == messages.BackupAvailability.Required - assert client.features.recovery_status == messages.RecoveryStatus.Nothing +def test_repeated_backup_cancel(session: Session): + assert session.features.backup_availability == messages.BackupAvailability.Required + assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics assert len(mnemonics) == 5 # cannot backup, since we already just did that! assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - ret = device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + ret = device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ret == messages.Success(message="Backup unlocked") assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup - layout = client.debug.read_layout() + layout = session.client.debug.read_layout() TR.assert_in(layout.text_content(), "recovery__unlock_repeated_backup") # send a Cancel message with pytest.raises(Cancelled): - client.call(messages.Cancel()) + session.call(messages.Cancel()) - client.refresh_features() + session.refresh_features() # the backup feature is locked again... assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) @WITH_MOCK_URANDOM -def test_repeated_backup_send_disallowed_message(client: Client): - assert client.features.backup_availability == messages.BackupAvailability.Required - assert client.features.recovery_status == messages.RecoveryStatus.Nothing +def test_repeated_backup_send_disallowed_message(session: Session): + assert session.features.backup_availability == messages.BackupAvailability.Required + assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics assert len(mnemonics) == 5 # cannot backup, since we already just did that! assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - ret = device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + ret = device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ret == messages.Success(message="Backup unlocked") assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup - layout = client.debug.read_layout() + layout = session.client.debug.read_layout() TR.assert_in(layout.text_content(), "recovery__unlock_repeated_backup") # send a GetAddress message - resp = client.call_raw( + resp = session.call_raw( messages.GetAddress( coin_name="Testnet", address_n=TEST_ADDRESS_N, @@ -233,10 +237,13 @@ def test_repeated_backup_send_disallowed_message(client: Client): assert isinstance(resp, messages.Failure) assert "not allowed" in resp.message - assert client.features.backup_availability == messages.BackupAvailability.Available - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.backup_availability == messages.BackupAvailability.Available + assert session.features.recovery_status == messages.RecoveryStatus.Backup # we are still on the confirmation screen! TR.assert_in( - client.debug.read_layout().text_content(), "recovery__unlock_repeated_backup" + session.client.debug.read_layout().text_content(), + "recovery__unlock_repeated_backup", ) + with pytest.raises(exceptions.Cancelled): + session.call(messages.Cancel()) diff --git a/tests/device_tests/test_sdcard.py b/tests/device_tests/test_sdcard.py index e0c13af944c..c210d11f359 100644 --- a/tests/device_tests/test_sdcard.py +++ b/tests/device_tests/test_sdcard.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import SdProtectOperationType as Op @@ -27,101 +27,105 @@ @pytest.mark.sd_card(formatted=False) -def test_sd_format(client: Client): - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True +def test_sd_format(session: Session): + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True @pytest.mark.sd_card(formatted=False) -def test_sd_no_format(client: Client): +def test_sd_no_format(session: Session): + debug = session.client.debug + def input_flow(): yield # enable SD protection? - client.debug.press_yes() + debug.press_yes() yield # format SD card - client.debug.press_no() + debug.press_no() - with pytest.raises(TrezorFailure) as e, client: + with session, session.client as client, pytest.raises(TrezorFailure) as e: client.set_input_flow(input_flow) - device.sd_protect(client, Op.ENABLE) + device.sd_protect(session, Op.ENABLE) assert e.value.code == messages.FailureType.ProcessError @pytest.mark.sd_card @pytest.mark.setup_client(pin="1234") -def test_sd_protect_unlock(client: Client): - layout = client.debug.read_layout +def test_sd_protect_unlock(session: Session): + raise Exception("FAILS, NOT SURE WHY") + debug = session.client.debug + layout = debug.read_layout def input_flow_enable_sd_protect(): + # debug.press_yes() yield # Enter PIN to unlock device assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input("1234") yield # do you really want to enable SD protection TR.assert_in(layout().text_content(), "sd_card__enable") - client.debug.press_yes() + debug.press_yes() yield # enter current PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input("1234") yield # you have successfully enabled SD protection TR.assert_in(layout().text_content(), "sd_card__enabled") - client.debug.press_yes() + debug.press_yes() - with client: + with session, session.client as client: client.watch_layout() client.set_input_flow(input_flow_enable_sd_protect) - device.sd_protect(client, Op.ENABLE) + device.sd_protect(session, Op.ENABLE) def input_flow_change_pin(): yield # do you really want to change PIN? TR.assert_equals(layout().title(), "pin__title_settings") - client.debug.press_yes() + debug.press_yes() yield # enter current PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input("1234") yield # enter new PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input("1234") yield # enter new PIN again assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input("1234") yield # Pin change successful TR.assert_in(layout().text_content(), "pin__changed") - client.debug.press_yes() + debug.press_yes() - with client: + with session.client as client: client.watch_layout() client.set_input_flow(input_flow_change_pin) - device.change_pin(client) + device.change_pin(session) - client.debug.erase_sd_card(format=False) + debug.erase_sd_card(format=False) def input_flow_change_pin_format(): yield # do you really want to change PIN? TR.assert_equals(layout().title(), "pin__title_settings") - client.debug.press_yes() - + debug.press_yes() yield # enter current PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input("1234") yield # SD card problem TR.assert_in_multiple( layout().text_content(), ["sd_card__unplug_and_insert_correct", "sd_card__insert_correct_card"], ) - client.debug.press_no() # close + debug.press_no() # close - with client, pytest.raises(TrezorFailure) as e: + with session, session.client as client, pytest.raises(TrezorFailure) as e: client.watch_layout() client.set_input_flow(input_flow_change_pin_format) - device.change_pin(client) + device.change_pin(session) assert e.value.code == messages.FailureType.ProcessError diff --git a/tests/device_tests/tezos/test_getaddress.py b/tests/device_tests/tezos/test_getaddress.py index 3e6b5423938..9f35118370d 100644 --- a/tests/device_tests/tezos/test_getaddress.py +++ b/tests/device_tests/tezos/test_getaddress.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tezos import get_address from trezorlib.tools import parse_path @@ -35,19 +35,19 @@ @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) -def test_tezos_get_address(client: Client, path: str, expected_address: str): - address = get_address(client, parse_path(path), show_display=True) +def test_tezos_get_address(session: Session, path: str, expected_address: str): + address = get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) def test_tezos_get_address_chunkify_details( - client: Client, path: str, expected_address: str + session: Session, path: str, expected_address: str ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address diff --git a/tests/device_tests/tezos/test_getpublickey.py b/tests/device_tests/tezos/test_getpublickey.py index 9f5bfcd0f74..8b1e72609d7 100644 --- a/tests/device_tests/tezos/test_getpublickey.py +++ b/tests/device_tests/tezos/test_getpublickey.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tezos import get_public_key from trezorlib.tools import parse_path @@ -24,11 +24,11 @@ @pytest.mark.altcoin @pytest.mark.tezos @pytest.mark.models("core") -def test_tezos_get_public_key(client: Client): +def test_tezos_get_public_key(session: Session): path = parse_path("m/44h/1729h/0h") - pk = get_public_key(client, path) + pk = get_public_key(session, path) assert pk == "edpkttLhEbVfMC3DhyVVFzdwh8ncRnEWiLD1x8TAuPU7vSJak7RtBX" path = parse_path("m/44h/1729h/1h") - pk = get_public_key(client, path) + pk = get_public_key(session, path) assert pk == "edpkuTPqWjcApwyD3VdJhviKM5C13zGk8c4m87crgFarQboF3Mp56f" diff --git a/tests/device_tests/tezos/test_sign_tx.py b/tests/device_tests/tezos/test_sign_tx.py index 06e17304db6..f70a4934d9c 100644 --- a/tests/device_tests/tezos/test_sign_tx.py +++ b/tests/device_tests/tezos/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import messages, tezos -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.protobuf import dict_to_proto from trezorlib.tools import parse_path @@ -32,10 +32,10 @@ ] -def test_tezos_sign_tx_proposal(client: Client): - with client: +def test_tezos_sign_tx_proposal(session: Session): + with session: resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -63,10 +63,10 @@ def test_tezos_sign_tx_proposal(client: Client): assert resp.operation_hash == "opLqntFUu984M7LnGsFvfGW6kWe9QjAz4AfPDqQvwJ1wPM4Si4c" -def test_tezos_sign_tx_multiple_proposals(client: Client): - with client: +def test_tezos_sign_tx_multiple_proposals(session: Session): + with session: resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -95,9 +95,9 @@ def test_tezos_sign_tx_multiple_proposals(client: Client): assert resp.operation_hash == "onobSyNgiitGXxSVFJN6949MhUomkkxvH4ZJ2owgWwNeDdntF9Y" -def test_tezos_sing_tx_ballot_yay(client: Client): +def test_tezos_sing_tx_ballot_yay(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -119,9 +119,9 @@ def test_tezos_sing_tx_ballot_yay(client: Client): ) -def test_tezos_sing_tx_ballot_nay(client: Client): +def test_tezos_sing_tx_ballot_nay(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -142,9 +142,9 @@ def test_tezos_sing_tx_ballot_nay(client: Client): ) -def test_tezos_sing_tx_ballot_pass(client: Client): +def test_tezos_sing_tx_ballot_pass(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -167,9 +167,9 @@ def test_tezos_sing_tx_ballot_pass(client: Client): @pytest.mark.parametrize("chunkify", (True, False)) -def test_tezos_sign_tx_tranasaction(client: Client, chunkify: bool): +def test_tezos_sign_tx_tranasaction(session: Session, chunkify: bool): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -202,9 +202,9 @@ def test_tezos_sign_tx_tranasaction(client: Client, chunkify: bool): assert resp.operation_hash == "oon8PNUsPETGKzfESv1Epv4535rviGS7RdCfAEKcPvzojrcuufb" -def test_tezos_sign_tx_delegation(client: Client): +def test_tezos_sign_tx_delegation(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_15, dict_to_proto( messages.TezosSignTx, @@ -232,9 +232,9 @@ def test_tezos_sign_tx_delegation(client: Client): assert resp.operation_hash == "op79C1tR7wkUgYNid2zC1WNXmGorS38mTXZwtAjmCQm2kG7XG59" -def test_tezos_sign_tx_origination(client: Client): +def test_tezos_sign_tx_origination(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -263,9 +263,9 @@ def test_tezos_sign_tx_origination(client: Client): assert resp.operation_hash == "onmq9FFZzvG2zghNdr1bgv9jzdbzNycXjSSNmCVhXCGSnV3WA9g" -def test_tezos_sign_tx_reveal(client: Client): +def test_tezos_sign_tx_reveal(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH, dict_to_proto( messages.TezosSignTx, @@ -305,9 +305,9 @@ def test_tezos_sign_tx_reveal(client: Client): assert resp.operation_hash == "oo9JFiWTnTSvUZfajMNwQe1VyFN2pqwiJzZPkpSAGfGD57Z6mZJ" -def test_tezos_smart_contract_delegation(client: Client): +def test_tezos_smart_contract_delegation(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -342,9 +342,9 @@ def test_tezos_smart_contract_delegation(client: Client): assert resp.operation_hash == "oo75gfQGGPEPChXZzcPPAGtYqCpsg2BS5q9gmhrU3NQP7CEffpU" -def test_tezos_kt_remove_delegation(client: Client): +def test_tezos_kt_remove_delegation(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -377,9 +377,9 @@ def test_tezos_kt_remove_delegation(client: Client): assert resp.operation_hash == "ootMi1tXbfoVgFyzJa8iXyR4mnHd5TxLm9hmxVzMVRkbyVjKaHt" -def test_tezos_smart_contract_transfer(client: Client): +def test_tezos_smart_contract_transfer(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -420,9 +420,9 @@ def test_tezos_smart_contract_transfer(client: Client): assert resp.operation_hash == "ooRGGtCmoQDgB36XvQqmM7govc3yb77YDUoa7p2QS7on27wGRns" -def test_tezos_smart_contract_transfer_to_contract(client: Client): +def test_tezos_smart_contract_transfer_to_contract(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, diff --git a/tests/device_tests/webauthn/test_msg_webauthn.py b/tests/device_tests/webauthn/test_msg_webauthn.py index 3fd7ca7fd95..9fe849fc3ee 100644 --- a/tests/device_tests/webauthn/test_msg_webauthn.py +++ b/tests/device_tests/webauthn/test_msg_webauthn.py @@ -17,7 +17,7 @@ import pytest from trezorlib import fido -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from ...common import MNEMONIC12 @@ -30,23 +30,23 @@ @pytest.mark.models("core") @pytest.mark.altcoin @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_add_remove(client: Client): - with client: +def test_add_remove(session: Session): + with session, session.client as client: IF = InputFlowFidoConfirm(client) client.set_input_flow(IF.get()) # Remove index 0 should fail. with pytest.raises(TrezorFailure): - fido.remove_credential(client, 0) + fido.remove_credential(session, 0) # List should be empty. - assert fido.list_credentials(client) == [] + assert fido.list_credentials(session) == [] # Add valid credential #1. - fido.add_credential(client, CRED1) + fido.add_credential(session, CRED1) # Check that the credential was added and parameters are correct. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 1 assert creds[0].rp_id == "example.com" assert creds[0].rp_name == "Example" @@ -59,10 +59,10 @@ def test_add_remove(client: Client): assert creds[0].hmac_secret is True # Add valid credential #2, which has same rpId and userId as credential #1. - fido.add_credential(client, CRED2) + fido.add_credential(session, CRED2) # Check that the credential #2 replaced credential #1 and parameters are correct. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 1 assert creds[0].rp_id == "example.com" assert creds[0].rp_name is None @@ -79,29 +79,29 @@ def test_add_remove(client: Client): fido.add_credential(client, CRED1[:-2]) # Check that the invalid credential was not added. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 1 # Add valid credential, which has same userId as #2, but different rpId. - fido.add_credential(client, CRED3) + fido.add_credential(session, CRED3) # Check that the credential was added. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 2 # Fill up the credential storage to maximum capacity. for cred in CREDS[: RK_CAPACITY - 2]: - fido.add_credential(client, cred) + fido.add_credential(session, cred) # Adding one more valid credential to full storage should fail. with pytest.raises(TrezorFailure): - fido.add_credential(client, CREDS[-1]) + fido.add_credential(session, CREDS[-1]) # Removing the index, which is one past the end, should fail. with pytest.raises(TrezorFailure): - fido.remove_credential(client, RK_CAPACITY) + fido.remove_credential(session, RK_CAPACITY) # Remove index 2. - fido.remove_credential(client, 2) + fido.remove_credential(session, 2) # Adding another valid credential should succeed now. - fido.add_credential(client, CREDS[-1]) + fido.add_credential(session, CREDS[-1]) diff --git a/tests/device_tests/webauthn/test_u2f_counter.py b/tests/device_tests/webauthn/test_u2f_counter.py index d99467f2b9d..c140ba54578 100644 --- a/tests/device_tests/webauthn/test_u2f_counter.py +++ b/tests/device_tests/webauthn/test_u2f_counter.py @@ -17,15 +17,15 @@ import pytest from trezorlib import fido -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session @pytest.mark.altcoin -def test_u2f_counter(client: Client): - assert fido.get_next_counter(client) == 0 - assert fido.get_next_counter(client) == 1 - fido.set_counter(client, 111111) - assert fido.get_next_counter(client) == 111112 - assert fido.get_next_counter(client) == 111113 - fido.set_counter(client, 0) - assert fido.get_next_counter(client) == 1 +def test_u2f_counter(session: Session): + assert fido.get_next_counter(session) == 0 + assert fido.get_next_counter(session) == 1 + fido.set_counter(session, 111111) + assert fido.get_next_counter(session) == 111112 + assert fido.get_next_counter(session) == 111113 + fido.set_counter(session, 0) + assert fido.get_next_counter(session) == 1 diff --git a/tests/device_tests/zcash/test_sign_tx.py b/tests/device_tests/zcash/test_sign_tx.py index d689c8af969..4d7df800903 100644 --- a/tests/device_tests/zcash/test_sign_tx.py +++ b/tests/device_tests/zcash/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -53,7 +53,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.zcash] -def test_version_group_id_missing(client: Client): +def test_version_group_id_missing(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -69,7 +69,7 @@ def test_version_group_id_missing(client: Client): with pytest.raises(TrezorFailure, match="Version group ID must be set."): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -77,7 +77,7 @@ def test_version_group_id_missing(client: Client): ) -def test_spend_v4_input(client: Client): +def test_spend_v4_input(session: Session): # 4b6cecb81c825180786ebe07b65bcc76078afc5be0f1c64e08d764005012380d is a v4 tx inp1 = messages.TxInputType( @@ -95,13 +95,13 @@ def test_spend_v4_input(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -110,7 +110,7 @@ def test_spend_v4_input(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -126,7 +126,7 @@ def test_spend_v4_input(client: Client): ) -def test_send_to_multisig(client: Client): +def test_send_to_multisig(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/8"), @@ -143,13 +143,13 @@ def test_send_to_multisig(client: Client): script_type=messages.OutputScriptType.PAYTOSCRIPTHASH, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -158,7 +158,7 @@ def test_send_to_multisig(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -174,7 +174,7 @@ def test_send_to_multisig(client: Client): ) -def test_spend_v5_input(client: Client): +def test_spend_v5_input(session: Session): inp1 = messages.TxInputType( # tmBMyeJebzkP5naji8XUKqLyL1NDwNkgJFt address_n=parse_path("m/44h/1h/0h/0/9"), @@ -190,13 +190,13 @@ def test_spend_v5_input(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -205,7 +205,7 @@ def test_spend_v5_input(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -221,7 +221,7 @@ def test_spend_v5_input(client: Client): ) -def test_one_two(client: Client): +def test_one_two(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -243,13 +243,13 @@ def test_one_two(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -260,7 +260,7 @@ def test_one_two(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1, out2], @@ -277,7 +277,7 @@ def test_one_two(client: Client): @pytest.mark.models("core") -def test_unified_address(client: Client): +def test_unified_address(session: Session): # identical to the test_one_two # but receiver address is unified with an orchard address inp1 = messages.TxInputType( @@ -301,13 +301,13 @@ def test_unified_address(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -318,7 +318,7 @@ def test_unified_address(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1, out2], @@ -335,7 +335,7 @@ def test_unified_address(client: Client): @pytest.mark.models("core") -def test_external_presigned(client: Client): +def test_external_presigned(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -365,14 +365,14 @@ def test_external_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(1), request_input(0), @@ -383,7 +383,7 @@ def test_external_presigned(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1, inp2], [out1], @@ -399,7 +399,7 @@ def test_external_presigned(client: Client): ) -def test_refuse_replacement_tx(client: Client): +def test_refuse_replacement_tx(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/1h/0h/0/4"), amount=174998, @@ -437,7 +437,7 @@ def test_refuse_replacement_tx(client: Client): TrezorFailure, match="Replacement transactions are not supported." ): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1, out2], @@ -447,12 +447,12 @@ def test_refuse_replacement_tx(client: Client): ) -def test_spend_multisig(client: Client): +def test_spend_multisig(session: Session): # Cloned from tests/device_tests/bitcoin/test_multisig.py::test_2_of_3 nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Zcash Testnet" + session, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Zcash Testnet" ).node for index in range(1, 4) ] @@ -482,17 +482,17 @@ def test_spend_multisig(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures1, _ = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -529,10 +529,10 @@ def test_spend_multisig(client: Client): multisig=multisig, ) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures2, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp3], [out1], diff --git a/tests/input_flows.py b/tests/input_flows.py index bb43b8ef189..da93f2f9f01 100644 --- a/tests/input_flows.py +++ b/tests/input_flows.py @@ -16,6 +16,7 @@ from trezorlib import messages from trezorlib.debuglink import DebugLink, LayoutContent, LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import multipage_content @@ -155,13 +156,15 @@ def input_two_different_pins() -> BRGeneratorType: class InputFlowCodeChangeFail(InputFlowBase): + def __init__( - self, client: Client, current_pin: str, new_pin_1: str, new_pin_2: str + self, session: Session, current_pin: str, new_pin_1: str, new_pin_2: str ): - super().__init__(client) + super().__init__(session.client) self.current_pin = current_pin self.new_pin_1 = new_pin_1 self.new_pin_2 = new_pin_2 + self.session = session def input_flow_common(self) -> BRGeneratorType: yield # do you want to change pin? @@ -176,7 +179,7 @@ def input_flow_common(self) -> BRGeneratorType: # failed retry yield # enter current pin again - self.client.cancel() + self.session.cancel() class InputFlowWrongPIN(InputFlowBase): diff --git a/tests/translations.py b/tests/translations.py index b411162ae6c..52f857856e3 100644 --- a/tests/translations.py +++ b/tests/translations.py @@ -5,7 +5,7 @@ from trezorlib import cosi, device, models from trezorlib._internal import translations -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from . import common @@ -53,19 +53,19 @@ def sign_blob(blob: translations.TranslationsBlob) -> bytes: def build_and_sign_blob( lang_or_def: translations.JsonDef | Path | str, - client: Client, + session: Session, ) -> bytes: - blob = prepare_blob(lang_or_def, client.model, client.version) + blob = prepare_blob(lang_or_def, session.model, session.version) return sign_blob(blob) -def set_language(client: Client, lang: str): +def set_language(session: Session, lang: str): if lang.startswith("en"): language_data = b"" else: - language_data = build_and_sign_blob(lang, client) - with client: - device.change_language(client, language_data) # type: ignore + language_data = build_and_sign_blob(lang, session) + with session: + device.change_language(session, language_data) # type: ignore def get_lang_json(lang: str) -> translations.JsonDef: