diff --git a/.github/workflows/core.yml b/.github/workflows/core.yml index 2b9ea58130d..f6f530a35de 100644 --- a/.github/workflows/core.yml +++ b/.github/workflows/core.yml @@ -49,7 +49,7 @@ jobs: cat $GITHUB_OUTPUT core_firmware: - name: Build firmware + name: Build firmware (${{ matrix.model }}, ${{ matrix.coins }}, ${{ matrix.type }}${{ matrix.protocol=='v2' && ', THP' || ''}}) runs-on: ubuntu-latest strategy: fail-fast: false @@ -57,10 +57,27 @@ jobs: model: [T2T1, T3B1, T3T1, T3W1] coins: [universal, btconly] type: ${{ fromJSON(github.event_name == 'schedule' && '["normal", "debuglink", "production"]' || '["normal", "debuglink"]') }} + protocol: [v1] include: - model: D001 coins: universal type: normal + - model: T2T1 + coins: universal + type: debuglink + protocol: v2 + - model: T2T1 + coins: btconly + type: debuglink + protocol: v2 + - model: T3T1 + coins: universal + type: debuglink + protocol: v2 + - model: T3T1 + coins: btconly + type: debuglink + protocol: v2 exclude: - model: T3W1 type: production @@ -70,6 +87,7 @@ jobs: PYOPT: ${{ matrix.type == 'debuglink' && '0' || '1' }} PRODUCTION: ${{ matrix.type == 'production' && '1' || '0' }} BOOTLOADER_DEVEL: ${{ matrix.model == 'T3W1' && '1' || '0' }} + THP: ${{ matrix.protocol == 'v2' && '1' || '0'}} steps: - uses: actions/checkout@v4 with: @@ -90,7 +108,7 @@ jobs: if: matrix.coins == 'btconly' && matrix.type != 'debuglink' - uses: actions/upload-artifact@v4 with: - name: core-firmware-${{ matrix.model }}-${{ matrix.coins }}-${{ matrix.type }} + name: core-firmware-${{ matrix.model }}-${{ matrix.coins }}-${{ matrix.type }}-protocol_${{ matrix.protocol }} path: | core/build/boardloader/*.bin core/build/bootloader/*.bin @@ -101,7 +119,7 @@ jobs: retention-days: 7 core_emu: - name: Build emu + name: Build emu (${{ matrix.model }}, ${{ matrix.coins }}, ${{ matrix.type }}, ${{ matrix.asan }}${{ matrix.protocol=='v2' && ', THP' || ''}}) runs-on: ubuntu-latest needs: param strategy: @@ -112,15 +130,38 @@ jobs: # type: [normal, debuglink] type: [debuglink] asan: ${{ fromJSON(needs.param.outputs.asan) }} + protocol: [v1] exclude: - type: normal asan: asan + include: + - model: T2T1 + coins: universal + type: debuglink + asan: noasan + protocol: v2 + - model: T2T1 + coins: btconly + type: debuglink + asan: noasan + protocol: v2 + - model: T3T1 + coins: universal + type: debuglink + asan: noasan + protocol: v2 + - model: T3T1 + coins: btconly + type: debuglink + asan: noasan + protocol: v2 env: TREZOR_MODEL: ${{ matrix.model }} BITCOIN_ONLY: ${{ matrix.coins == 'universal' && '0' || '1' }} PYOPT: ${{ matrix.type == 'debuglink' && '0' || '1' }} ADDRESS_SANITIZER: ${{ matrix.asan == 'asan' && '1' || '0' }} LSAN_OPTIONS: "suppressions=../../asan_suppressions.txt" + THP: ${{ matrix.protocol == 'v2' && '1' || '0'}} steps: - uses: actions/checkout@v4 with: @@ -132,7 +173,7 @@ jobs: - run: cp core/build/unix/trezor-emu-core core/build/unix/trezor-emu-core-${{ matrix.model }}-${{ matrix.coins }} - uses: actions/upload-artifact@v4 with: - name: core-emu-${{ matrix.model }}-${{ matrix.coins }}-${{ matrix.type }}-${{ matrix.asan }} + name: core-emu-${{ matrix.model }}-${{ matrix.coins }}-${{ matrix.type }}-${{ matrix.asan }}-protocol_${{ matrix.protocol }} path: | core/build/unix/trezor-emu-core* core/build/bootloader_emu/bootloader.elf @@ -177,7 +218,7 @@ jobs: retention-days: 2 core_unit_python_test: - name: Python unit tests + name: Python unit tests (${{ matrix.model }}, ${{ matrix.asan }}${{ matrix.protocol=='v2' && ', THP' || ''}}) runs-on: ubuntu-latest needs: param strategy: @@ -185,10 +226,12 @@ jobs: matrix: model: [T2T1, T3B1, T3T1, T3W1] asan: ${{ fromJSON(needs.param.outputs.asan) }} + protocol: [v1, v2] env: TREZOR_MODEL: ${{ matrix.model }} ADDRESS_SANITIZER: ${{ matrix.asan == 'asan' && '1' || '0' }} LSAN_OPTIONS: "suppressions=../../asan_suppressions.txt" + THP: ${{ matrix.protocol == 'v2' && '1' || '0'}} steps: - uses: actions/checkout@v4 with: @@ -198,7 +241,7 @@ jobs: - run: nix-shell --run "poetry run make -C core test" core_unit_rust_test: - name: Rust unit tests + name: Rust unit tests (${{ matrix.model }}, ${{ matrix.asan }}${{ matrix.protocol=='v2' && ', THP' || ''}}) runs-on: ubuntu-latest needs: - param @@ -208,12 +251,14 @@ jobs: matrix: model: [T2T1, T3B1, T3T1] asan: ${{ fromJSON(needs.param.outputs.asan) }} + protocol: [v1, v2] env: TREZOR_MODEL: ${{ matrix.model }} ADDRESS_SANITIZER: ${{ matrix.asan == 'asan' && '1' || '0' }} RUSTC_BOOTSTRAP: ${{ matrix.asan == 'asan' && '1' || '0' }} RUSTFLAGS: ${{ matrix.asan == 'asan' && '-Z sanitizer=address' || '' }} LSAN_OPTIONS: "suppressions=../../asan_suppressions.txt" + THP: ${{ matrix.protocol == 'v2' && '1' || '0'}} steps: - uses: actions/checkout@v4 with: @@ -237,7 +282,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-universal-debuglink-noasan + name: core-emu-${{ matrix.model }}-universal-debuglink-noasan-protocol_v1 path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: ./.github/actions/environment @@ -248,7 +293,7 @@ jobs: # See artifacts for a comprehensive report of UI. # See [docs/tests/ui-tests](../tests/ui-tests.md) for more info. core_device_test: - name: Device tests + name: Device tests (${{ matrix.model }}, ${{ matrix.coins }}, ${{ matrix.asan }}, ${{ matrix.lang }}${{ matrix.protocol=='v2' && ', THP' || ''}}) runs-on: ubuntu-latest needs: - param @@ -260,6 +305,13 @@ jobs: coins: [universal, btconly] asan: ${{ fromJSON(needs.param.outputs.asan) }} lang: ${{ fromJSON(needs.param.outputs.test_lang) }} + protocol: [v1] + include: + - model: T2T1 + coins: universal + asan: noasan + lang: en + protocol: v2 env: TREZOR_PROFILING: ${{ matrix.asan == 'noasan' && '1' || '0' }} TREZOR_MODEL: ${{ matrix.model }} @@ -268,13 +320,14 @@ jobs: PYTEST_TIMEOUT: ${{ matrix.asan == 'asan' && 600 || 400 }} ACTIONS_DO_UI_TEST: ${{ matrix.coins == 'universal' && matrix.asan == 'noasan' }} TEST_LANG: ${{ matrix.lang }} + THP: ${{ matrix.protocol == 'v2' && '1' || '0'}} steps: - uses: actions/checkout@v4 with: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-${{ matrix.coins }}-debuglink-${{ matrix.asan }} + name: core-emu-${{ matrix.model }}-${{ matrix.coins }}-debuglink-${{ matrix.asan }}-protocol_${{ matrix.protocol }} path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: ./.github/actions/environment @@ -283,7 +336,7 @@ jobs: if: failure() - uses: actions/upload-artifact@v4 with: - name: core-test-device-${{ matrix.model }}-${{ matrix.coins }}-${{ matrix.lang }}-${{ matrix.asan }} + name: core-test-device-${{ matrix.model }}-${{ matrix.coins }}-${{ matrix.lang }}-${{ matrix.asan }}-protocol_${{ matrix.protocol }} path: tests/trezor.log retention-days: 7 if: always() @@ -299,7 +352,7 @@ jobs: # Click tests - UI. # See [docs/tests/click-tests](../tests/click-tests.md) for more info. core_click_test: - name: Click tests + name: Click tests (${{ matrix.model }}, ${{ matrix.asan }}, ${{ matrix.lang }}${{ matrix.protocol=='v2' && ', THP' || ''}}) runs-on: ubuntu-latest needs: - param @@ -311,6 +364,12 @@ jobs: model: [T2T1, T3B1, T3T1] asan: ${{ fromJSON(needs.param.outputs.asan) }} lang: ${{ fromJSON(needs.param.outputs.test_lang) }} + protocol: [v1] + include: + - model: T2T1 + asan: noasan + lang: en + protocol: v2 env: TREZOR_PROFILING: ${{ matrix.asan == 'noasan' && '1' || '0' }} # MULTICORE: 4 # more could interfere with other jobs @@ -322,7 +381,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }} + name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }}-protocol_${{matrix.protocol}} path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: ./.github/actions/environment @@ -332,7 +391,7 @@ jobs: if: ${{ matrix.asan == 'asan' }} - uses: actions/upload-artifact@v4 with: - name: core-test-click-${{ matrix.model }}-${{ matrix.lang }}-${{ matrix.asan }} + name: core-test-click-${{ matrix.model }}-${{ matrix.lang }}-${{ matrix.asan }}-protocol_${{matrix.protocol}} path: tests/trezor.log retention-days: 7 if: always() @@ -349,7 +408,7 @@ jobs: # Upgrade tests. # See [docs/tests/upgrade-tests](../tests/upgrade-tests.md) for more info. core_upgrade_test: - name: Upgrade tests + name: Upgrade tests (${{ matrix.model }}, ${{ matrix.asan }}${{ matrix.protocol=='v2' && ', THP' || ''}}) runs-on: ubuntu-latest needs: - param @@ -361,6 +420,7 @@ jobs: # FIXME: T3T1 https://github.com/trezor/trezor-firmware/issues/3595 model: [T2T1] asan: ${{ fromJSON(needs.param.outputs.asan) }} + protocol: [v1, v2] env: TREZOR_UPGRADE_TEST: core PYTEST_TIMEOUT: 400 @@ -370,7 +430,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }} + name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }}-protocol_${{matrix.protocol}} path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: ./.github/actions/environment @@ -380,7 +440,7 @@ jobs: # Persistence tests - UI. core_persistence_test: - name: Persistence tests + name: Persistence tests (${{ matrix.model }}, ${{ matrix.asan }}${{ matrix.protocol=='v2' && ', THP' || ''}}) runs-on: ubuntu-latest needs: - param @@ -391,6 +451,11 @@ jobs: matrix: model: [T2T1, T3B1, T3T1] asan: ${{ fromJSON(needs.param.outputs.asan) }} + protocol: [v1] + include: + - model: T2T1 + asan: noasan + protocol: v2 env: TREZOR_PROFILING: ${{ matrix.asan == 'noasan' && '1' || '0' }} PYTEST_TIMEOUT: 400 @@ -400,7 +465,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }} + name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }}-protocol_${{matrix.protocol}} path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: ./.github/actions/environment @@ -433,7 +498,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-universal-debuglink-noasan + name: core-emu-${{ matrix.model }}-universal-debuglink-noasan-protocol_v1 path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: ./.github/actions/environment # XXX poetry maybe not needed @@ -491,7 +556,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-firmware-${{ matrix.model }}-universal-normal # FIXME: s/normal/debuglink/ + name: core-firmware-${{ matrix.model }}-universal-normal-protocol_v1 # FIXME: s/normal/debuglink/ path: core/build - uses: ./.github/actions/environment - run: nix-shell --run "poetry run core/tools/size/checker.py core/build/firmware/firmware.elf" @@ -515,7 +580,7 @@ jobs: fetch-depth: 0 - uses: actions/download-artifact@v4 with: - name: core-firmware-${{ matrix.model }}-universal-normal + name: core-firmware-${{ matrix.model }}-universal-normal-protocol_v1 path: core/build - uses: ./.github/actions/environment - run: nix-shell --run "poetry run core/tools/size/compare_master.py core/build/firmware/firmware.elf -r firmware_elf_size_report.txt" @@ -527,7 +592,7 @@ jobs: # Monero tests. core_monero_test: - name: Monero test + name: Monero test (${{ matrix.model }}, ${{ matrix.asan }}${{ matrix.protocol=='v2' && ', THP' || ''}}) runs-on: ubuntu-latest needs: - param @@ -537,6 +602,11 @@ jobs: matrix: model: [T2T1, T3B1, T3T1] asan: ${{ fromJSON(needs.param.outputs.asan) }} + protocol: [v1] + include: + - model: T2T1 + asan: noasan + protocol: v2 env: TREZOR_PROFILING: ${{ matrix.asan == 'noasan' && '1' || '0' }} PYTEST_TIMEOUT: 400 @@ -546,7 +616,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }} + name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }}-protocol_${{matrix.protocol}} path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: cachix/install-nix-action@v23 @@ -557,7 +627,7 @@ jobs: - run: nix-shell --arg fullDeps true --run "unset _PYTHON_SYSCONFIGDATA_NAME && poetry run make -C core test_emu_monero" - uses: actions/upload-artifact@v4 with: - name: core-test-monero-${{ matrix.model }}-${{ matrix.asan }} + name: core-test-monero-${{ matrix.model }}-${{ matrix.asan }}-protocol_${{matrix.protocol}} path: | tests/trezor.log core/tests/trezor_monero_tests.log @@ -568,7 +638,7 @@ jobs: # Tests for U2F and HID. core_u2f_test: - name: U2F test + name: U2F test (${{ matrix.model }}, ${{ matrix.asan }}${{ matrix.protocol=='v2' && ', THP' || ''}}) runs-on: ubuntu-latest needs: - param @@ -578,6 +648,11 @@ jobs: matrix: model: [T2T1, T3B1, T3T1] asan: ${{ fromJSON(needs.param.outputs.asan) }} + protocol: [v1] + include: + - model: T2T1 + asan: noasan + protocol: v2 env: TREZOR_PROFILING: ${{ matrix.asan == 'noasan' && '1' || '0' }} PYTEST_TIMEOUT: 400 @@ -587,7 +662,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }} + name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }}-protocol_${{matrix.protocol}} path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: ./.github/actions/environment @@ -595,7 +670,7 @@ jobs: - run: nix-shell --run "poetry run make -C core test_emu_u2f" - uses: actions/upload-artifact@v4 with: - name: core-test-u2f-${{ matrix.model }}-${{ matrix.asan }} + name: core-test-u2f-${{ matrix.model }}-${{ matrix.asan }}-protocol_${{matrix.protocol}} path: tests/trezor.log retention-days: 7 if: always() @@ -603,7 +678,7 @@ jobs: # FIDO2 device tests. core_fido2_test: - name: FIDO2 test + name: FIDO2 test (${{ matrix.model }}, ${{ matrix.asan }}${{ matrix.protocol=='v2' && ', THP' || ''}}) runs-on: ubuntu-latest needs: - param @@ -613,6 +688,11 @@ jobs: matrix: model: [T2T1, T3T1] # XXX T3B1 https://github.com/trezor/trezor-firmware/issues/2724 asan: ${{ fromJSON(needs.param.outputs.asan) }} + protocol: [v1] + include: + - model: T2T1 + asan: noasan + protocol: v2 env: TREZOR_PROFILING: ${{ matrix.asan == 'noasan' && '1' || '0' }} PYTEST_TIMEOUT: 400 @@ -622,14 +702,14 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }} + name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }}-protocol_${{matrix.protocol}} path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: ./.github/actions/environment - run: nix-shell --run "poetry run make -C core test_emu_fido2" - uses: actions/upload-artifact@v4 with: - name: core-test-fido2-${{ matrix.model }}-${{ matrix.asan }} + name: core-test-fido2-${{ matrix.model }}-${{ matrix.asan }}-protocol_${{matrix.protocol}} path: | tests/trezor.log retention-days: 7 @@ -728,7 +808,7 @@ jobs: steps: - uses: actions/download-artifact@v4 with: - pattern: core-emu*debuglink-noasan + pattern: core-emu*debuglink-noasan-protocol_v* merge-multiple: true - name: Configure aws credentials uses: aws-actions/configure-aws-credentials@v4 @@ -751,7 +831,7 @@ jobs: steps: - uses: actions/download-artifact@v4 with: - pattern: core-emu*debuglink-noasan + pattern: core-emu*debuglink-noasan-protocol_v* merge-multiple: true - name: Configure aws credentials uses: aws-actions/configure-aws-credentials@v4 diff --git a/common/protob/messages-common.proto b/common/protob/messages-common.proto index 3e8cb9537ce..4dd16add5b1 100644 --- a/common/protob/messages-common.proto +++ b/common/protob/messages-common.proto @@ -39,6 +39,10 @@ message Failure { Failure_PinMismatch = 12; Failure_WipeCodeMismatch = 13; Failure_InvalidSession = 14; + Failure_ThpUnallocatedSession = 15; + Failure_InvalidProtocol = 16; + Failure_BufferError = 17; + Failure_DeviceIsBusy = 18; Failure_FirmwareError = 99; } } diff --git a/common/protob/messages-debug.proto b/common/protob/messages-debug.proto index 3727b6243f4..f91d47b7e44 100644 --- a/common/protob/messages-debug.proto +++ b/common/protob/messages-debug.proto @@ -85,7 +85,7 @@ message DebugLinkRecordScreen { } /** - * Request: Computer asks for device state + * Request: Host asks for device state * @start * @next DebugLinkState */ @@ -134,6 +134,29 @@ message DebugLinkState { repeated string tokens = 13; // current layout represented as a list of string tokens } +/** + * Request: Host asks for device pairing info + * @start + * @next DebugLinkPairingInfo + */ + message DebugLinkGetPairingInfo { + optional bytes channel_id = 1; // ID of the THP channel to get pairing info from + optional bytes handshake_hash = 2; // handshake hash of the THP channel + optional bytes nfc_secret_host = 3; // host's NFC secret (In case of NFC pairing) +} + + /** + * Response: Device pairing info + * @end + */ + message DebugLinkPairingInfo { + optional bytes channel_id = 1; // ID of the THP channel the pairing info is from + optional bytes handshake_hash = 2; // handshake hash of the THP channel + optional uint32 code_entry_code = 3; // CodeEntry pairing code + optional bytes code_qr_code = 4; // QrCode pairing code + optional bytes nfc_secret_trezor = 5; // NFC secret used in NFC pairing +} + /** * Request: Ask device to restart * @start diff --git a/common/protob/messages-thp.proto b/common/protob/messages-thp.proto index c05d9f64d71..d798f882fdd 100644 --- a/common/protob/messages-thp.proto +++ b/common/protob/messages-thp.proto @@ -9,6 +9,224 @@ import "options.proto"; option (include_in_bitcoin_only) = true; +/** + * Mapping between Trezor wire identifier (uint) and a Thp protobuf message + */ +enum ThpMessageType { + reserved 0 to 999; // Values reserved by other messages, see messages.proto + + ThpMessageType_ThpCreateNewSession = 1000 [(bitcoin_only)=true]; + ThpMessageType_ThpPairingRequest = 1006 [(bitcoin_only) = true]; + ThpMessageType_ThpPairingRequestApproved = 1007 [(bitcoin_only) = true]; + ThpMessageType_ThpSelectMethod = 1008 [(bitcoin_only) = true]; + ThpMessageType_ThpPairingPreparationsFinished = 1009 [(bitcoin_only) = true]; + ThpMessageType_ThpCredentialRequest = 1010 [(bitcoin_only) = true]; + ThpMessageType_ThpCredentialResponse = 1011 [(bitcoin_only) = true]; + ThpMessageType_ThpEndRequest = 1012 [(bitcoin_only) = true]; + ThpMessageType_ThpEndResponse = 1013 [(bitcoin_only) = true]; + ThpMessageType_ThpCodeEntryCommitment = 1016 [(bitcoin_only)=true]; + ThpMessageType_ThpCodeEntryChallenge = 1017 [(bitcoin_only)=true]; + ThpMessageType_ThpCodeEntryCpaceTrezor = 1018 [(bitcoin_only)=true]; + ThpMessageType_ThpCodeEntryCpaceHostTag = 1019 [(bitcoin_only)=true]; + ThpMessageType_ThpCodeEntrySecret = 1020 [(bitcoin_only)=true]; + ThpMessageType_ThpQrCodeTag = 1024 [(bitcoin_only)=true]; + ThpMessageType_ThpQrCodeSecret = 1025 [(bitcoin_only)=true]; + ThpMessageType_ThpNfcTagHost = 1032 [(bitcoin_only)=true]; + ThpMessageType_ThpNfcTagTrezor = 1033 [(bitcoin_only)=true]; + + reserved 1100 to 2147483647; // Values reserved by other messages, see messages.proto +} + + +/** + * Numeric identifiers of pairing methods. + * @embed + */ +enum ThpPairingMethod { + SkipPairing = 1; // Trust without MITM protection. + CodeEntry = 2; // User types code diplayed on Trezor into the host application. + QrCode = 3; // User scans code displayed on Trezor into host application. + NFC = 4; // Trezor and host application exchange authentication secrets via NFC. +} + +/** + * @embed + */ +message ThpDeviceProperties { + optional string internal_model = 1; // Internal model name e.g. "T2B1". + optional uint32 model_variant = 2; // Encodes the device properties such as color. + optional uint32 protocol_version_major = 3; // The major version of the communication protocol used by the firmware. + optional uint32 protocol_version_minor = 4; // The minor version of the communication protocol used by the firmware. + repeated ThpPairingMethod pairing_methods = 5; // The pairing methods supported by the Trezor. +} + +/** + * @embed + */ +message ThpHandshakeCompletionReqNoisePayload { + optional bytes host_pairing_credential = 1; // Host's pairing credential +} + +/** + * Request: Ask device for a new session with given passphrase. + * @start + * @next Success + */ +message ThpCreateNewSession{ + optional string passphrase = 1; + optional bool on_device = 2; // User wants to enter passphrase on the device + optional bool derive_cardano = 3; // If True, Cardano keys will be derived. Ignored with BTC-only +} + + +/** + * Request: Start pairing process. + * @start + * @next ThpPairingRequestApproved + */ +message ThpPairingRequest{ + optional string host_name = 1; // Human-readable host name +} + +/** + * Response: Host is allowed to start pairing process. + * @start + * @next ThpSelectMethod + */ +message ThpPairingRequestApproved{ +} + +/** + * Request: Start pairing using the method selected. + * @start + * @next ThpPairingPreparationsFinished + * @next ThpCodeEntryCommitment + */ +message ThpSelectMethod { + optional ThpPairingMethod selected_pairing_method = 1; +} + +/** + * Response: Pairing is ready for user input / OOB communication. + * @next ThpCodeEntryCpace + * @next ThpQrCodeTag + * @next ThpNfcTagHost + */ +message ThpPairingPreparationsFinished{ +} + +/** + * Response: If Code Entry is an allowed pairing option, Trezor responds with a commitment. + * @next ThpCodeEntryChallenge + */ +message ThpCodeEntryCommitment { + optional bytes commitment = 1; // SHA-256 of Trezor's random 32-byte secret +} + +/** + * Response: Host responds to Trezor's Code Entry commitment with a challenge. + * @next ThpCodeEntryCpaceTrezor + */ +message ThpCodeEntryChallenge { + optional bytes challenge = 1; // Host's random 32-byte challenge +} + + + +/** + * Response: Trezor continues with the CPACE protocol. + * @next ThpCodeEntryCpaceHostTag + */ +message ThpCodeEntryCpaceTrezor { + optional bytes cpace_trezor_public_key = 1; // Trezor's ephemeral CPace public key +} + +/** + * Request: User selected Code Entry option in Host. Host starts CPACE protocol with Trezor. + * @next ThpCodeEntrySecret + */ + message ThpCodeEntryCpaceHostTag { + optional bytes cpace_host_public_key = 1; // Host's ephemeral CPace public key + optional bytes tag = 2; // SHA-256 of shared secret + +} + +/** + * Response: Trezor finishes the CPACE protocol. + * @next ThpCredentialRequest + * @next ThpEndRequest + */ +message ThpCodeEntrySecret { + optional bytes secret = 1; // Trezor's secret +} + +/** + * Request: User selected QR Code pairing option. Host sends a QR Tag. + * @next ThpQrCodeSecret + */ +message ThpQrCodeTag { + optional bytes tag = 1; // SHA-256 of shared secret +} + +/** + * Response: Trezor sends the QR secret. + * @next ThpCredentialRequest + * @next ThpEndRequest + */ +message ThpQrCodeSecret { + optional bytes secret = 1; // Trezor's secret +} + +/** + * Request: User selected Unidirectional NFC pairing option. Host sends an Unidirectional NFC Tag. + * @next ThpNfcTagTrezor + */ +message ThpNfcTagHost { + optional bytes tag = 1; // Host's tag +} + +/** + * Response: Trezor sends the Unidirectioal NFC secret. + * @next ThpCredentialRequest + * @next ThpEndRequest + */ +message ThpNfcTagTrezor { + optional bytes tag = 1; // Trezor's tag +} + +/** + * Request: Host requests issuance of a new pairing credential. + * @start + * @next ThpCredentialResponse + */ +message ThpCredentialRequest { + optional bytes host_static_pubkey = 1; // Host's static public key used in the handshake. + optional bool autoconnect = 2; // Whether host wants to autoconnect without user confirmation +} + +/** + * Response: Trezor issues a new pairing credential. + * @next ThpCredentialRequest + * @next ThpEndRequest + */ +message ThpCredentialResponse { + optional bytes trezor_static_pubkey = 1; // Trezor's static public key used in the handshake. + optional bytes credential = 2; // The pairing credential issued by the Trezor to the host. +} + +/** + * Request: Host requests transition to the encrypted traffic phase. + * @start + * @next ThpEndResponse + */ +message ThpEndRequest {} + +/** + * Response: Trezor approves transition to the encrypted traffic phase + * @end + */ +message ThpEndResponse {} + /** * Only for internal use. * @embed @@ -16,6 +234,7 @@ option (include_in_bitcoin_only) = true; message ThpCredentialMetadata { option (internal_only) = true; optional string host_name = 1; // Human-readable host name + optional bool autoconnect = 2; // Whether host is allowed to autoconnect without user confirmation } /** diff --git a/common/protob/messages.proto b/common/protob/messages.proto index ae7fdfda599..f807d8c477b 100644 --- a/common/protob/messages.proto +++ b/common/protob/messages.proto @@ -134,6 +134,8 @@ enum MessageType { MessageType_DebugLinkWatchLayout = 9006 [(bitcoin_only) = true, (wire_debug_in) = true]; MessageType_DebugLinkResetDebugEvents = 9007 [(bitcoin_only) = true, (wire_debug_in) = true]; MessageType_DebugLinkOptigaSetSecMax = 9008 [(bitcoin_only) = true, (wire_debug_in) = true]; + MessageType_DebugLinkGetPairingInfo = 9009 [(bitcoin_only) = true, (wire_debug_in) = true]; + MessageType_DebugLinkPairingInfo = 9010 [(bitcoin_only) = true, (wire_debug_out) = true]; // Ethereum MessageType_EthereumGetPublicKey = 450 [(wire_in) = true]; diff --git a/common/protob/pb2py b/common/protob/pb2py index d6cbbde171e..ea12f8d9c1d 100755 --- a/common/protob/pb2py +++ b/common/protob/pb2py @@ -62,6 +62,7 @@ INT_TYPES = ( ) MESSAGE_TYPE_ENUM = "MessageType" +THP_MESSAGE_TYPE_ENUM = "ThpMessageType" LengthDelimited = c.Struct( "len" / c.VarInt, @@ -239,6 +240,9 @@ class ProtoMessage: @classmethod def from_message(cls, descriptor: "Descriptor", message): message_type = find_by_name(descriptor.message_type_enum.value, message.name) + thp_message_type = None + if not isinstance(descriptor.thp_message_type_enum,tuple): + thp_message_type = find_by_name(descriptor.thp_message_type_enum.value, message.name) # use extensions set on the message_type entry (if any) extensions = descriptor.get_extensions(message_type) # override with extensions set on the message itself @@ -248,6 +252,8 @@ class ProtoMessage: wire_type = extensions["wire_type"] elif message_type is not None: wire_type = message_type.number + elif thp_message_type is not None: + wire_type = thp_message_type.number else: wire_type = None @@ -351,10 +357,13 @@ class Descriptor: ] logging.debug(f"found {len(self.files)} bitcoin-only files") - # find message_type enum + # find message_type and thp_message_type enum top_level_enums = itertools.chain.from_iterable(f.enum_type for f in self.files) self.message_type_enum = find_by_name(top_level_enums, MESSAGE_TYPE_ENUM, ()) + top_level_enums = itertools.chain.from_iterable(f.enum_type for f in self.files) + self.thp_message_type_enum = find_by_name(top_level_enums, THP_MESSAGE_TYPE_ENUM, ()) self.convert_enum_value_names(self.message_type_enum) + self.convert_enum_value_names(self.thp_message_type_enum) # find messages and enums self.messages = [] @@ -423,6 +432,8 @@ class Descriptor: self._nested_types_from_message(nested.orig) def convert_enum_value_names(self, enum): + if isinstance(enum,tuple): + return for value in enum.value: value.name = strip_enum_prefix(enum.name, value.name) @@ -558,6 +569,8 @@ class RustBlobRenderer: enums = [] cursor = 0 for enum in sorted(self.descriptor.enums, key=lambda e: e.name): + if enum.name == "MessageType": + continue self.enum_map[enum.name] = cursor enum_blob = ENUM_ENTRY.build(sorted(v.number for v in enum.value)) enums.append(enum_blob) diff --git a/core/Makefile b/core/Makefile index 3bca1bb31d0..6f987a48b55 100644 --- a/core/Makefile +++ b/core/Makefile @@ -296,6 +296,10 @@ build_unix: templates ## build unix port build_unix_frozen: templates build_cross ## build unix port with frozen modules $(SCONS) $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) TREZOR_EMULATOR_FROZEN=1 +build_unix_frozen_debug: templates build_cross ## build unix port with frozen modules and DEBUG (PYOPT="0") + $(SCONS) $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) TREZOR_EMULATOR_FROZEN=1 \ + PYOPT=0 + build_unix_debug: templates ## build unix port $(SCONS) --max-drift=1 $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) \ TREZOR_EMULATOR_ASAN=1 TREZOR_EMULATOR_DEBUGGABLE=1 diff --git a/core/SConscript.firmware b/core/SConscript.firmware index 57aeb9b6773..e5248431658 100644 --- a/core/SConscript.firmware +++ b/core/SConscript.firmware @@ -578,14 +578,23 @@ if FROZEN: ] if not EVERYTHING else [] )) + if THP: + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/thp/*.py')) + if not THP or PYOPT == '0': + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py')) + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py')) - SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py')) - SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py', - exclude=[ - SOURCE_PY_DIR + 'storage/sd_salt.py', - ] if not SDCARD else [] - )) + + exclude_list = [] + if 'sd_card' not in FEATURES_AVAILABLE: + exclude_list.append(SOURCE_PY_DIR + 'storage/sd_salt.py') + if THP and PYOPT == '1': + exclude_list.append(SOURCE_PY_DIR + 'storage/cache_codec.py') + if not THP: + exclude_list.append(SOURCE_PY_DIR + 'storage/cache_thp.py') + + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py', exclude=exclude_list)) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/messages/__init__.py')) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/enums/*.py', diff --git a/core/SConscript.unix b/core/SConscript.unix index 6b0fd17ff88..13369d2c660 100644 --- a/core/SConscript.unix +++ b/core/SConscript.unix @@ -637,14 +637,23 @@ if FROZEN: ] if not EVERYTHING else [] )) + if THP: + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/thp/*.py')) + if not THP or PYOPT == '0': + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py')) + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py')) - SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py')) - SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py', - exclude=[ - SOURCE_PY_DIR + 'storage/sd_salt.py', - ] if 'sd_card' not in FEATURES_AVAILABLE else [] - )) + + exclude_list = [] + if 'sd_card' not in FEATURES_AVAILABLE: + exclude_list.append(SOURCE_PY_DIR + 'storage/sd_salt.py') + if THP and PYOPT == '1': + exclude_list.append(SOURCE_PY_DIR + 'storage/cache_codec.py') + if not THP: + exclude_list.append(SOURCE_PY_DIR + 'storage/cache_thp.py') + + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py', exclude=exclude_list)) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/messages/__init__.py')) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/enums/*.py', diff --git a/core/emu.py b/core/emu.py index 0cf88a6ca91..3dbd8426560 100755 --- a/core/emu.py +++ b/core/emu.py @@ -282,9 +282,10 @@ def cli( label = "Emulator" assert emulator.client is not None - trezorlib.device.wipe(emulator.client) + trezorlib.device.wipe(emulator.client.get_seedless_session()) + trezorlib.debuglink.load_device( - emulator.client, + emulator.client.get_seedless_session(), mnemonics, pin=None, passphrase_protection=False, diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 70651ea3a89..2a23e58aeec 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -51,6 +51,8 @@ import storage.cache_codec storage.cache_common import storage.cache_common +storage.cache_thp +import storage.cache_thp storage.common import storage.common storage.debug @@ -419,10 +421,50 @@ import apps.workflow_handlers if utils.USE_THP: + trezor.enums.ThpMessageType + import trezor.enums.ThpMessageType + trezor.enums.ThpPairingMethod + import trezor.enums.ThpPairingMethod + trezor.wire.thp + import trezor.wire.thp + trezor.wire.thp.alternating_bit_protocol + import trezor.wire.thp.alternating_bit_protocol + trezor.wire.thp.channel + import trezor.wire.thp.channel + trezor.wire.thp.channel_manager + import trezor.wire.thp.channel_manager + trezor.wire.thp.checksum + import trezor.wire.thp.checksum + trezor.wire.thp.control_byte + import trezor.wire.thp.control_byte + trezor.wire.thp.cpace + import trezor.wire.thp.cpace + trezor.wire.thp.crypto + import trezor.wire.thp.crypto + trezor.wire.thp.interface_manager + import trezor.wire.thp.interface_manager + trezor.wire.thp.memory_manager + import trezor.wire.thp.memory_manager + trezor.wire.thp.pairing_context + import trezor.wire.thp.pairing_context + trezor.wire.thp.received_message_handler + import trezor.wire.thp.received_message_handler + trezor.wire.thp.session_context + import trezor.wire.thp.session_context + trezor.wire.thp.session_manager + import trezor.wire.thp.session_manager + trezor.wire.thp.thp_main + import trezor.wire.thp.thp_main + trezor.wire.thp.transmission_loop + import trezor.wire.thp.transmission_loop + trezor.wire.thp.writer + import trezor.wire.thp.writer apps.thp import apps.thp apps.thp.credential_manager import apps.thp.credential_manager + apps.thp.pairing + import apps.thp.pairing if not utils.BITCOIN_ONLY: trezor.enums.BinanceOrderSide diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 5552fc86ba9..6f388d6e49e 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -3,7 +3,7 @@ import storage.device as storage_device from storage.cache_common import APP_COMMON_BUSY_DEADLINE_MS, APP_COMMON_SEED from trezor import TR, config, utils, wire, workflow -from trezor.enums import HomescreenFormat, MessageType +from trezor.enums import HomescreenFormat, MessageType, ThpMessageType from trezor.messages import Success, UnlockPath from trezor.ui.layouts import confirm_action from trezor.wire import context @@ -27,6 +27,9 @@ ) from trezor.wire import Handler, Msg + if utils.USE_THP: + from trezor.messages import Failure, ThpCreateNewSession + _SCREENSAVER_IS_ON = False @@ -204,33 +207,103 @@ def get_features() -> Features: return f -async def handle_Initialize(msg: Initialize) -> Features: - import storage.cache_codec as cache_codec +if utils.USE_THP: - session_id = cache_codec.start_session(msg.session_id) + async def handle_ThpCreateNewSession( + message: ThpCreateNewSession, + ) -> Success | Failure: + """ + Creates a new `ThpSession` based on the provided parameters and returns a + `Success` message on success. - if not utils.BITCOIN_ONLY: - from storage.cache_common import APP_COMMON_DERIVE_CARDANO + Returns an appropriate `Failure` message if session creation fails. + """ + from trezor import log, loop + from trezor.enums import FailureType + from trezor.messages import Failure + from trezor.wire import NotInitialized + from trezor.wire.context import get_context + from trezor.wire.errors import ActionCancelled, DataError + from trezor.wire.thp.session_context import GenericSessionContext + from trezor.wire.thp.session_manager import get_new_session_context - derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) - have_seed = context.cache_is_set(APP_COMMON_SEED) - if ( - have_seed - and msg.derive_cardano is not None - and msg.derive_cardano != bool(derive_cardano) - ): - # seed is already derived, and host wants to change derive_cardano setting - # => create a new session - cache_codec.end_current_session() - session_id = cache_codec.start_session() - have_seed = False + from apps.common.seed import derive_and_store_roots + + ctx = get_context() + + # Assert that context `ctx` is `GenericSessionContext` + assert isinstance(ctx, GenericSessionContext) + + channel = ctx.channel + session_id = ctx.session_id - if not have_seed: - context.cache_set_bool(APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)) + # Do not use `ctx` beyond this point, as it is techically + # allowed to change in between await statements + + if not 0 <= session_id <= 255: + return Failure( + code=FailureType.DataError, + message="Invalid session_id for session creation.", + ) + + new_session = get_new_session_context( + channel_ctx=channel, session_id=session_id + ) + try: + await unlock_device() + await derive_and_store_roots(new_session, message) + except DataError as e: + return Failure(code=FailureType.DataError, message=e.message) + except ActionCancelled as e: + return Failure(code=FailureType.ActionCancelled, message=e.message) + except NotInitialized as e: + return Failure(code=FailureType.NotInitialized, message=e.message) + # TODO handle other errors (`Exception` when "Cardano icarus secret is already set!") + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "New session with sid %d and passphrase %s created.", + session_id, + message.passphrase if message.passphrase is not None else "", + ) - features = get_features() - features.session_id = session_id - return features + channel.sessions[new_session.session_id] = new_session + loop.schedule(new_session.handle()) + + return Success(message="New session created.") + +else: + + async def handle_Initialize(msg: Initialize) -> Features: + import storage.cache_codec as cache_codec + + session_id = cache_codec.start_session(msg.session_id) + + if not utils.BITCOIN_ONLY: + from storage.cache_common import APP_COMMON_DERIVE_CARDANO + + derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) + have_seed = context.cache_is_set(APP_COMMON_SEED) + if ( + have_seed + and msg.derive_cardano is not None + and msg.derive_cardano != bool(derive_cardano) + ): + # seed is already derived, and host wants to change derive_cardano setting + # => create a new session + cache_codec.end_current_session() + session_id = cache_codec.start_session() + have_seed = False + + if not have_seed: + context.cache_set_bool( + APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano) + ) + + features = get_features() + features.session_id = session_id + return features async def handle_GetFeatures(msg: GetFeatures) -> Features: @@ -464,8 +537,12 @@ def boot() -> None: MT = MessageType # local_cache_global # Register workflow handlers + if utils.USE_THP: + TMT = ThpMessageType + workflow_handlers.register(TMT.ThpCreateNewSession, handle_ThpCreateNewSession) + else: + workflow_handlers.register(MT.Initialize, handle_Initialize) for msg_type, handler in [ - (MT.Initialize, handle_Initialize), (MT.GetFeatures, handle_GetFeatures), (MT.Cancel, handle_Cancel), (MT.LockDevice, handle_LockDevice), diff --git a/core/src/apps/cardano/seed.py b/core/src/apps/cardano/seed.py index 35f6b3f60ce..781b29b59f5 100644 --- a/core/src/apps/cardano/seed.py +++ b/core/src/apps/cardano/seed.py @@ -6,9 +6,8 @@ APP_CARDANO_ICARUS_TREZOR_SECRET, APP_COMMON_DERIVE_CARDANO, ) -from trezor import wire +from trezor import utils, wire from trezor.crypto import cardano -from trezor.wire import context from apps.common import mnemonic from apps.common.seed import get_seed @@ -21,6 +20,7 @@ from trezor import messages from trezor.crypto import bip32 from trezor.enums import CardanoDerivationType + from trezor.wire.protocol_common import Context from apps.common.keychain import Handler, MsgOut from apps.common.paths import Bip32Path @@ -116,9 +116,9 @@ def is_minting_path(path: Bip32Path) -> bool: return path[: len(MINTING_ROOT)] == MINTING_ROOT -def derive_and_store_secrets(passphrase: str) -> None: +def derive_and_store_secrets(ctx: Context, passphrase: str) -> None: assert device.is_initialized() - assert context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) + assert ctx.cache.get_bool(APP_COMMON_DERIVE_CARDANO) if not mnemonic.is_bip39(): # nothing to do for SLIP-39, where we can derive the root from the main seed @@ -138,14 +138,13 @@ def derive_and_store_secrets(passphrase: str) -> None: else: icarus_trezor_secret = icarus_secret - context.cache_set(APP_CARDANO_ICARUS_SECRET, icarus_secret) - context.cache_set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) + ctx.cache.set(APP_CARDANO_ICARUS_SECRET, icarus_secret) + ctx.cache.set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain: from trezor.enums import CardanoDerivationType - - from apps.common.seed import derive_and_store_roots + from trezor.wire import context if not device.is_initialized(): raise wire.NotInitialized("Device is not initialized") @@ -164,10 +163,13 @@ async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychai # _get_secret secret = context.cache_get(cache_entry) - if secret is None: - await derive_and_store_roots() - secret = context.cache_get(cache_entry) - assert secret is not None + if not utils.USE_THP: + if secret is None: + from apps.common.seed import derive_and_store_roots_legacy + + await derive_and_store_roots_legacy() + secret = context.cache_get(cache_entry) + assert secret is not None root = cardano.from_secret(secret) return Keychain(root) diff --git a/core/src/apps/common/backup.py b/core/src/apps/common/backup.py index fc56f42f9b7..8037aba6987 100644 --- a/core/src/apps/common/backup.py +++ b/core/src/apps/common/backup.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING from storage.cache_common import APP_RECOVERY_REPEATED_BACKUP_UNLOCKED -from trezor import wire +from trezor import utils, wire from trezor.enums import MessageType from trezor.wire import context from trezor.wire.message_handler import filters, remove_filter @@ -24,14 +24,23 @@ def deactivate_repeated_backup() -> None: remove_filter(_repeated_backup_filter) -_ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( - MessageType.Initialize, - MessageType.GetFeatures, - MessageType.EndSession, - MessageType.BackupDevice, - MessageType.WipeDevice, - MessageType.Cancel, -) +if utils.USE_THP: + _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( + MessageType.GetFeatures, + MessageType.EndSession, + MessageType.BackupDevice, + MessageType.WipeDevice, + MessageType.Cancel, + ) +else: + _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( + MessageType.Initialize, + MessageType.GetFeatures, + MessageType.EndSession, + MessageType.BackupDevice, + MessageType.WipeDevice, + MessageType.Cancel, + ) def _repeated_backup_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]: diff --git a/core/src/apps/common/keychain.py b/core/src/apps/common/keychain.py index 16913d1529e..7959789b251 100644 --- a/core/src/apps/common/keychain.py +++ b/core/src/apps/common/keychain.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING +from trezor import utils from trezor.crypto import bip32 from trezor.wire import DataError @@ -172,6 +173,9 @@ async def get_keychain( ) -> Keychain: from .seed import get_seed + if not utils.USE_THP: + pass + # try to ask for passphrase here seed = await get_seed() keychain = Keychain(seed, curve, schemas, slip21_namespaces) return keychain diff --git a/core/src/apps/common/passphrase.py b/core/src/apps/common/passphrase.py index ef8bb5b1850..d150dd47369 100644 --- a/core/src/apps/common/passphrase.py +++ b/core/src/apps/common/passphrase.py @@ -1,84 +1,122 @@ from micropython import const +from typing import TYPE_CHECKING import storage.device as storage_device +from trezor import utils from trezor.wire import DataError _MAX_PASSPHRASE_LEN = const(50) +if TYPE_CHECKING: + from trezor.messages import ThpCreateNewSession + def is_enabled() -> bool: return storage_device.is_passphrase_enabled() -async def get() -> str: - from trezor import workflow - +async def get_passphrase(msg: ThpCreateNewSession) -> str: if not is_enabled(): return "" + + if msg.on_device or storage_device.get_passphrase_always_on_device(): + passphrase = await _get_on_device() else: - workflow.close_others() # request exclusive UI access - if storage_device.get_passphrase_always_on_device(): - from trezor.ui.layouts import request_passphrase_on_device + passphrase = msg.passphrase or "" + if passphrase: + await _handle_displaying_passphrase_from_host(passphrase) - passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN) - else: - passphrase = await _request_on_host() - if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN: - raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes") + if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN: + raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes") - return passphrase + return passphrase -async def _request_on_host() -> str: - from trezor import TR - from trezor.messages import PassphraseAck, PassphraseRequest - from trezor.ui.layouts import request_passphrase_on_host - from trezor.wire.context import call +async def _get_on_device() -> str: + from trezor import workflow + from trezor.ui.layouts import request_passphrase_on_device - request_passphrase_on_host() + workflow.close_others() # request exclusive UI access + passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN) - request = PassphraseRequest() - ack = await call(request, PassphraseAck) - passphrase = ack.passphrase # local_cache_attribute + return passphrase - if ack.on_device: - from trezor.ui.layouts import request_passphrase_on_device - if passphrase is not None: - raise DataError("Passphrase provided when it should not be") - return await request_passphrase_on_device(_MAX_PASSPHRASE_LEN) +async def _handle_displaying_passphrase_from_host(passphrase: str) -> None: + from trezor import TR + from trezor.ui.layouts import confirm_action, confirm_blob + + # We want to hide the passphrase, or show it, according to settings. + if storage_device.get_hide_passphrase_from_host(): + await confirm_action( + "passphrase_host1_hidden", + TR.passphrase__wallet, + description=TR.passphrase__from_host_not_shown, + prompt_screen=True, + prompt_title=TR.passphrase__access_wallet, + ) + else: + await confirm_action( + "passphrase_host1", + TR.passphrase__wallet, + description=TR.passphrase__next_screen_will_show_passphrase, + verb=TR.buttons__continue, + ) - if passphrase is None: - raise DataError( - "Passphrase not provided and on_device is False. Use empty string to set an empty passphrase." + await confirm_blob( + "passphrase_host2", + TR.passphrase__title_confirm, + passphrase, ) - # non-empty passphrase - if passphrase: - from trezor.ui.layouts import confirm_action, confirm_blob - - # We want to hide the passphrase, or show it, according to settings. - if storage_device.get_hide_passphrase_from_host(): - await confirm_action( - "passphrase_host1_hidden", - TR.passphrase__wallet, - description=TR.passphrase__from_host_not_shown, - prompt_screen=True, - prompt_title=TR.passphrase__access_wallet, - ) + +if not utils.USE_THP: + + async def get() -> str: + from trezor import workflow + + if not is_enabled(): + return "" else: - await confirm_action( - "passphrase_host1", - TR.passphrase__wallet, - description=TR.passphrase__next_screen_will_show_passphrase, - verb=TR.buttons__continue, - ) + workflow.close_others() # request exclusive UI access + if storage_device.get_passphrase_always_on_device(): + from trezor.ui.layouts import request_passphrase_on_device + + passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN) + else: + passphrase = await _request_on_host() + if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN: + raise DataError( + f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes" + ) + + return passphrase + + async def _request_on_host() -> str: + from trezor.messages import PassphraseAck, PassphraseRequest + from trezor.ui.layouts import request_passphrase_on_host + from trezor.wire.context import call + + request_passphrase_on_host() - await confirm_blob( - "passphrase_host2", - TR.passphrase__title_confirm, - passphrase, - info=False, + request = PassphraseRequest() + ack = await call(request, PassphraseAck) + passphrase = ack.passphrase # local_cache_attribute + + if ack.on_device: + from trezor.ui.layouts import request_passphrase_on_device + + if passphrase is not None: + raise DataError("Passphrase provided when it should not be") + return await request_passphrase_on_device(_MAX_PASSPHRASE_LEN) + + if passphrase is None: + raise DataError( + "Passphrase not provided and on_device is False. Use empty string to set an empty passphrase." ) - return passphrase + # non-empty passphrase + if passphrase: + await _handle_displaying_passphrase_from_host(passphrase) + + return passphrase diff --git a/core/src/apps/common/seed.py b/core/src/apps/common/seed.py index b09004ae698..4bb15184f80 100644 --- a/core/src/apps/common/seed.py +++ b/core/src/apps/common/seed.py @@ -5,14 +5,18 @@ from trezor import utils from trezor.crypto import hmac from trezor.wire import context +from trezor.wire.context import get_context +from trezor.wire.errors import DataError from apps.common import cache from . import mnemonic -from .passphrase import get as get_passphrase +from .passphrase import get_passphrase as get_passphrase if TYPE_CHECKING: from trezor.crypto import bip32 + from trezor.messages import ThpCreateNewSession + from trezor.wire.protocol_common import Context from .paths import Bip32Path, Slip21Path @@ -22,6 +26,9 @@ APP_COMMON_DERIVE_CARDANO, ) +if not utils.USE_THP: + from .passphrase import get as get_passphrase_legacy + class Slip21Node: """ @@ -54,51 +61,111 @@ def clone(self) -> "Slip21Node": return Slip21Node(data=self.data) -if not utils.BITCOIN_ONLY: - # === Cardano variant === - # We want to derive both the normal seed and the Cardano seed together, AND - # expose a method for Cardano to do the same +if utils.USE_THP: + + async def get_seed() -> bytes: # type: ignore [Function declaration "get_seed" is obscured by a declaration of the same name] + common_seed = context.cache_get(APP_COMMON_SEED) + assert common_seed is not None + return common_seed + + if utils.BITCOIN_ONLY: + # === Bitcoin_only variant === + # We want to derive the normal seed ONLY - async def derive_and_store_roots() -> None: - from trezor import wire + async def derive_and_store_roots( + ctx: Context, msg: ThpCreateNewSession + ) -> None: - if not storage_device.is_initialized(): - raise wire.NotInitialized("Device is not initialized") + if msg.passphrase is not None and msg.on_device: + raise DataError("Passphrase provided when it shouldn't be!") - need_seed = not context.cache_is_set(APP_COMMON_SEED) - need_cardano_secret = context.cache_get_bool( - APP_COMMON_DERIVE_CARDANO - ) and not context.cache_is_set(APP_CARDANO_ICARUS_SECRET) + if ctx.cache.is_set(APP_COMMON_SEED): + raise Exception("Seed is already set!") - if not need_seed and not need_cardano_secret: - return + from trezor import wire - passphrase = await get_passphrase() + if not storage_device.is_initialized(): + raise wire.NotInitialized("Device is not initialized") - if need_seed: + passphrase = await get_passphrase(msg) common_seed = mnemonic.get_seed(passphrase) - context.cache_set(APP_COMMON_SEED, common_seed) + ctx.cache.set(APP_COMMON_SEED, common_seed) - if need_cardano_secret: - from apps.cardano.seed import derive_and_store_secrets + else: + # === Cardano variant === + # We want to derive both the normal seed and the Cardano seed together + async def derive_and_store_roots( + ctx: Context, msg: ThpCreateNewSession + ) -> None: - derive_and_store_secrets(passphrase) + if msg.passphrase is not None and msg.on_device: + raise DataError("Passphrase provided when it shouldn't be!") - @cache.stored_async(APP_COMMON_SEED) - async def get_seed() -> bytes: - await derive_and_store_roots() - common_seed = context.cache_get(APP_COMMON_SEED) - assert common_seed is not None - return common_seed + from trezor import wire + + if not storage_device.is_initialized(): + raise wire.NotInitialized("Device is not initialized") + + if ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET): + raise Exception("Cardano icarus secret is already set!") + + passphrase = await get_passphrase(msg) + common_seed = mnemonic.get_seed(passphrase) + ctx.cache.set(APP_COMMON_SEED, common_seed) + + if msg.derive_cardano: + from apps.cardano.seed import derive_and_store_secrets + + ctx.cache.set_bool(APP_COMMON_DERIVE_CARDANO, True) + derive_and_store_secrets(ctx, passphrase) else: - # === Bitcoin-only variant === - # We use the simple version of `get_seed` that never needs to derive anything else. + if utils.BITCOIN_ONLY: + # === Bitcoin-only variant === + # We use the simple version of `get_seed` that never needs to derive anything else. + + @cache.stored_async(APP_COMMON_SEED) + async def get_seed() -> bytes: + passphrase = await get_passphrase_legacy() + return mnemonic.get_seed(passphrase=passphrase) + + else: + # === Cardano variant === + # We want to derive both the normal seed and the Cardano seed together, AND + # expose a method for Cardano to do the same + + @cache.stored_async(APP_COMMON_SEED) + async def get_seed() -> bytes: + await derive_and_store_roots_legacy() + common_seed = context.cache_get(APP_COMMON_SEED) + assert common_seed is not None + return common_seed + + async def derive_and_store_roots_legacy() -> None: + from trezor import wire + + if not storage_device.is_initialized(): + raise wire.NotInitialized("Device is not initialized") + + ctx = get_context() + need_seed = not ctx.cache.is_set(APP_COMMON_SEED) + need_cardano_secret = ctx.cache.get_bool( + APP_COMMON_DERIVE_CARDANO + ) and not ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET) + + if not need_seed and not need_cardano_secret: + return + + passphrase = await get_passphrase_legacy() + + if need_seed: + common_seed = mnemonic.get_seed(passphrase) + ctx.cache.set(APP_COMMON_SEED, common_seed) + + if need_cardano_secret: + from apps.cardano.seed import derive_and_store_secrets - @cache.stored_async(APP_COMMON_SEED) - async def get_seed() -> bytes: - passphrase = await get_passphrase() - return mnemonic.get_seed(passphrase) + derive_and_store_secrets(ctx, passphrase) @cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE) diff --git a/core/src/apps/debug/__init__.py b/core/src/apps/debug/__init__.py index 21ffc9f1bd2..1b5475ffa59 100644 --- a/core/src/apps/debug/__init__.py +++ b/core/src/apps/debug/__init__.py @@ -1,3 +1,5 @@ +from trezor.wire import message_handler + if not __debug__: from trezor.utils import halt @@ -22,20 +24,23 @@ from trezor.messages import ( DebugLinkDecision, DebugLinkEraseSdCard, + DebugLinkGetPairingInfo, DebugLinkGetState, DebugLinkOptigaSetSecMax, + DebugLinkPairingInfo, DebugLinkRecordScreen, DebugLinkReseedRandom, DebugLinkState, ) from trezor.ui import Layout - from trezor.wire import WireInterface, context + from trezor.wire import WireInterface + from trezor.wire.protocol_common import Context Handler = Callable[[Any], Awaitable[Any]] layout_change_box = loop.mailbox() - DEBUG_CONTEXT: context.Context | None = None + DEBUG_CONTEXT: Context | None = None REFRESH_INDEX = 0 @@ -70,9 +75,7 @@ def wait_until_layout_is_running(timeout: int | None = _DEADLOCK_SLEEP_MS) -> Aw "layout deadlock detected (did you send a ButtonAck?)" ) - async def return_layout_change( - ctx: wire.protocol_common.Context, detect_deadlock: bool = False - ) -> None: + async def return_layout_change(ctx: Context, detect_deadlock: bool = False) -> None: # set up the wait storage.layout_watcher = True @@ -265,6 +268,42 @@ def callback(*args: str) -> None: tokens=tokens, ) + async def dispatch_DebugLinkGetPairingInfo( + msg: DebugLinkGetPairingInfo, + ) -> DebugLinkPairingInfo | None: + if not utils.USE_THP: + raise RuntimeError("Trezor does not support THP") + if msg.channel_id is None: + raise RuntimeError("Invalid DebugLinkGetPairingInfo message") + + from trezor.wire.thp.channel import Channel + from trezor.wire.thp.pairing_context import PairingContext + from trezor.wire.thp.thp_main import _CHANNELS + + channel_id = int.from_bytes(msg.channel_id, "big") + channel: Channel | None = None + ctx: PairingContext | None = None + try: + channel = _CHANNELS[channel_id] + ctx = channel.connection_context + except KeyError: + pass + + if ctx is None or not isinstance(ctx, PairingContext): + raise RuntimeError("Trezor is not in pairing mode") + + ctx.nfc_secret_host = msg.nfc_secret_host + ctx.handshake_hash_host = msg.handshake_hash + from trezor.messages import DebugLinkPairingInfo + + return DebugLinkPairingInfo( + channel_id=ctx.channel_id, + handshake_hash=ctx.channel_ctx.get_handshake_hash(), + code_entry_code=ctx.code_code_entry, + code_qr_code=ctx.code_qr_code, + nfc_secret_trezor=ctx.nfc_secret, + ) + async def dispatch_DebugLinkGetState( msg: DebugLinkGetState, ) -> DebugLinkState | None: @@ -395,7 +434,6 @@ async def handle_session(iface: WireInterface) -> None: ctx.iface.iface_num(), msg_type, ) - if msg.type not in WORKFLOW_HANDLERS: await ctx.write(wire.message_handler.unexpected_message()) continue @@ -408,7 +446,7 @@ async def handle_session(iface: WireInterface) -> None: await ctx.write(Success()) continue - req_msg = wire.message_handler.wrap_protobuf_load(msg.data, req_type) + req_msg = message_handler.wrap_protobuf_load(msg.data, req_type) try: res_msg = await WORKFLOW_HANDLERS[msg.type](req_msg) except Exception as exc: @@ -427,6 +465,7 @@ async def handle_session(iface: WireInterface) -> None: WORKFLOW_HANDLERS: dict[int, Handler] = { MessageType.DebugLinkDecision: dispatch_DebugLinkDecision, MessageType.DebugLinkGetState: dispatch_DebugLinkGetState, + MessageType.DebugLinkGetPairingInfo: dispatch_DebugLinkGetPairingInfo, MessageType.DebugLinkReseedRandom: dispatch_DebugLinkReseedRandom, MessageType.DebugLinkRecordScreen: dispatch_DebugLinkRecordScreen, MessageType.DebugLinkEraseSdCard: dispatch_DebugLinkEraseSdCard, diff --git a/core/src/apps/management/reboot_to_bootloader.py b/core/src/apps/management/reboot_to_bootloader.py index 85596c0268d..2213d2c17a5 100644 --- a/core/src/apps/management/reboot_to_bootloader.py +++ b/core/src/apps/management/reboot_to_bootloader.py @@ -89,7 +89,7 @@ async def reboot_to_bootloader(msg: RebootToBootloader) -> NoReturn: boot_args = None ctx = get_context() - await ctx.write(Success(message="Rebooting")) + await ctx.write_force(Success(message="Rebooting")) # make sure the outgoing USB buffer is flushed await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE) # reboot to the bootloader, pass the firmware header hash if any diff --git a/core/src/apps/management/recovery_device/__init__.py b/core/src/apps/management/recovery_device/__init__.py index 10ca5f63777..f722a63c7fe 100644 --- a/core/src/apps/management/recovery_device/__init__.py +++ b/core/src/apps/management/recovery_device/__init__.py @@ -24,6 +24,7 @@ async def recovery_device(msg: RecoveryDevice) -> Success: from trezor import TR, config, wire, workflow from trezor.enums import BackupType, ButtonRequestType from trezor.ui.layouts import confirm_action, confirm_reset_device + from trezor.wire.context import try_get_ctx_ids from apps.common import mnemonic from apps.common.request_pin import ( @@ -69,8 +70,8 @@ async def recovery_device(msg: RecoveryDevice) -> Success: if recovery_type == RecoveryType.NormalRecovery: await confirm_reset_device(recovery=True) - # wipe storage to make sure the device is in a clear state - storage.reset() + # wipe storage to make sure the device is in a clear state (except protocol cache) + storage.reset(excluded=try_get_ctx_ids()) # set up pin if requested if msg.pin_protection: diff --git a/core/src/apps/management/recovery_device/homescreen.py b/core/src/apps/management/recovery_device/homescreen.py index 7ad56a47422..face532e762 100644 --- a/core/src/apps/management/recovery_device/homescreen.py +++ b/core/src/apps/management/recovery_device/homescreen.py @@ -3,8 +3,9 @@ import storage.device as storage_device import storage.recovery as storage_recovery import storage.recovery_shares as storage_recovery_shares -from trezor import TR, wire +from trezor import TR, utils, wire from trezor.messages import Success +from trezor.wire import message_handler from apps.common import backup_types @@ -38,18 +39,27 @@ async def recovery_process() -> Success: recovery_type = storage_recovery.get_type() - wire.message_handler.AVOID_RESTARTING_FOR = ( - MessageType.Initialize, - MessageType.GetFeatures, - MessageType.EndSession, - ) + if utils.USE_THP: + message_handler.AVOID_RESTARTING_FOR = ( + MessageType.GetFeatures, + MessageType.EndSession, + ) + else: + message_handler.AVOID_RESTARTING_FOR = ( + MessageType.Initialize, + MessageType.GetFeatures, + MessageType.EndSession, + ) try: return await _continue_recovery_process() except recover.RecoveryAborted: storage_recovery.end_progress() backup.deactivate_repeated_backup() if recovery_type == RecoveryType.NormalRecovery: - storage.wipe() + from trezor.wire.context import try_get_ctx_ids + + storage.wipe(clear_cache=False) + storage.wipe_cache(excluded=try_get_ctx_ids()) raise wire.ActionCancelled @@ -59,11 +69,17 @@ async def _continue_repeated_backup() -> None: from apps.common import backup from apps.management.backup_device import perform_backup - wire.message_handler.AVOID_RESTARTING_FOR = ( - MessageType.Initialize, - MessageType.GetFeatures, - MessageType.EndSession, - ) + if utils.USE_THP: + message_handler.AVOID_RESTARTING_FOR = ( + MessageType.GetFeatures, + MessageType.EndSession, + ) + else: + message_handler.AVOID_RESTARTING_FOR = ( + MessageType.Initialize, + MessageType.GetFeatures, + MessageType.EndSession, + ) try: await perform_backup(is_repeated_backup=True) diff --git a/core/src/apps/management/reset_device/__init__.py b/core/src/apps/management/reset_device/__init__.py index 4b3d8bf2efc..4840d31a21f 100644 --- a/core/src/apps/management/reset_device/__init__.py +++ b/core/src/apps/management/reset_device/__init__.py @@ -38,7 +38,7 @@ async def reset_device(msg: ResetDevice) -> Success: prompt_backup, show_wallet_created_success, ) - from trezor.wire.context import call + from trezor.wire.context import call, try_get_ctx_ids from apps.common.request_pin import request_pin_confirm @@ -60,8 +60,8 @@ async def reset_device(msg: ResetDevice) -> Success: # Rendering empty loader so users do not feel a freezing screen render_empty_loader(config.StorageMessage.PROCESSING_MSG) - # wipe storage to make sure the device is in a clear state - storage.reset() + # wipe storage to make sure the device is in a clear state (except protocol cache) + storage.reset(excluded=try_get_ctx_ids()) # Check backup type, perform type-specific handling if backup_types.is_slip39_backup_type(backup_type): @@ -139,7 +139,7 @@ async def reset_device(msg: ResetDevice) -> Success: if perform_backup: await layout.show_backup_success() - return Success(message="Initialized") + return Success(message="Initialized") # TODO: Why "Initialized?" async def _entropy_check(secret: bytes) -> bool: diff --git a/core/src/apps/management/wipe_device.py b/core/src/apps/management/wipe_device.py index b6e60057a6c..e6f787b2c79 100644 --- a/core/src/apps/management/wipe_device.py +++ b/core/src/apps/management/wipe_device.py @@ -1,12 +1,19 @@ from typing import TYPE_CHECKING +from trezor.wire.context import get_context, try_get_ctx_ids + if TYPE_CHECKING: - from trezor.messages import Success, WipeDevice + from typing import NoReturn + + from trezor.messages import WipeDevice +if __debug__: + from trezor import log -async def wipe_device(msg: WipeDevice) -> Success: + +async def wipe_device(msg: WipeDevice) -> NoReturn: import storage - from trezor import TR, config, translations + from trezor import TR, config, loop, translations from trezor.enums import ButtonRequestType from trezor.messages import Success from trezor.pin import render_empty_loader @@ -26,16 +33,31 @@ async def wipe_device(msg: WipeDevice) -> Success: br_code=ButtonRequestType.WipeDevice, ) + if __debug__: + log.debug(__name__, "Device wipe - start") + # start an empty progress screen so that the screen is not blank while waiting render_empty_loader(config.StorageMessage.PROCESSING_MSG) # wipe storage - storage.wipe() + storage.wipe(clear_cache=False) + + # clear cache - exclude current context + storage.wipe_cache(excluded=try_get_ctx_ids()) + # erase translations translations.deinit() translations.erase() + try: + await get_context().write_force(Success(message="Device wiped")) + except Exception: + if __debug__: + log.debug(__name__, "Failed to send Success message after wipe.") + pass + storage.wipe_cache() # reload settings reload_settings_from_storage() - - return Success(message="Device wiped") + loop.clear() + if __debug__: + log.debug(__name__, "Device wipe - finished") diff --git a/core/src/apps/thp/credential_manager.py b/core/src/apps/thp/credential_manager.py index adf2ba62409..9170a06f9ec 100644 --- a/core/src/apps/thp/credential_manager.py +++ b/core/src/apps/thp/credential_manager.py @@ -63,17 +63,26 @@ def issue_credential( return credential_raw -def validate_credential( +def decode_credential( encoded_pairing_credential_message: bytes, +) -> ThpPairingCredential: + """ + Decode a protobuf encoded pairing credential. + """ + expected_type = protobuf.type_for_name("ThpPairingCredential") + credential = wrap_protobuf_load(encoded_pairing_credential_message, expected_type) + assert ThpPairingCredential.is_type_of(credential) + return credential + + +def validate_credential( + credential: ThpPairingCredential, host_static_pubkey: bytes, ) -> bool: """ Validate a pairing credential binded to the provided host static public key. """ cred_auth_key = derive_cred_auth_key() - expected_type = protobuf.type_for_name("ThpPairingCredential") - credential = wrap_protobuf_load(encoded_pairing_credential_message, expected_type) - assert ThpPairingCredential.is_type_of(credential) proto_msg = ThpAuthenticatedCredentialData( host_static_pubkey=host_static_pubkey, cred_metadata=credential.cred_metadata, @@ -83,6 +92,27 @@ def validate_credential( return mac == credential.mac +def decode_and_validate_credential( + encoded_pairing_credential_message: bytes, + host_static_pubkey: bytes, +) -> bool: + """ + Decode a protobuf encoded pairing credential and validate it + binded to the provided host static public key. + """ + credential = decode_credential(encoded_pairing_credential_message) + return validate_credential(credential, host_static_pubkey) + + +def is_credential_autoconnect(credential: ThpPairingCredential) -> bool: + assert ThpPairingCredential.is_type_of(credential) + if credential.cred_metadata is None: + return False + if credential.cred_metadata.autoconnect is None: + return False + return credential.cred_metadata.autoconnect + + def _encode_message_into_new_buffer(msg: protobuf.MessageType) -> bytes: msg_len = protobuf.encoded_length(msg) new_buffer = bytearray(msg_len) diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py new file mode 100644 index 00000000000..0031c934208 --- /dev/null +++ b/core/src/apps/thp/pairing.py @@ -0,0 +1,469 @@ +from typing import TYPE_CHECKING +from ubinascii import hexlify + +from trezor import protobuf +from trezor.crypto import random +from trezor.crypto.hashlib import sha256 +from trezor.enums import ThpMessageType, ThpPairingMethod +from trezor.messages import ( + Cancel, + ThpCodeEntryChallenge, + ThpCodeEntryCommitment, + ThpCodeEntryCpaceHostTag, + ThpCodeEntryCpaceTrezor, + ThpCodeEntrySecret, + ThpCredentialMetadata, + ThpCredentialRequest, + ThpCredentialResponse, + ThpEndRequest, + ThpEndResponse, + ThpNfcTagHost, + ThpNfcTagTrezor, + ThpPairingPreparationsFinished, + ThpPairingRequest, + ThpQrCodeSecret, + ThpQrCodeTag, + ThpSelectMethod, +) +from trezor.wire import message_handler +from trezor.wire.context import UnexpectedMessageException +from trezor.wire.errors import SilentError, UnexpectedMessage +from trezor.wire.thp import ChannelState, ThpError, crypto, get_enabled_pairing_methods +from trezor.wire.thp.pairing_context import PairingContext + +from .credential_manager import is_credential_autoconnect, issue_credential + +if __debug__: + from trezor import log + +if TYPE_CHECKING: + from typing import Any, Callable, Concatenate, ParamSpec, Tuple + + P = ParamSpec("P") + FuncWithContext = Callable[Concatenate[PairingContext, P], Any] + +# +# Helpers - decorators + + +def check_state_and_log( + *allowed_states: ChannelState, +) -> Callable[[FuncWithContext], FuncWithContext]: + def decorator(f: FuncWithContext) -> FuncWithContext: + def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object: + _check_state(context, *allowed_states) + if __debug__: + try: + log.debug(__name__, "started %s", f.__name__) + except AttributeError: + log.debug( + __name__, + "started a function that cannot be named, because it raises AttributeError, eg. closure", + ) + return f(context, *args, **kwargs) + + return inner + + return decorator + + +def check_method_is_allowed_and_selected( + pairing_method: ThpPairingMethod, +) -> Callable[[FuncWithContext], FuncWithContext]: + def decorator(f: FuncWithContext) -> FuncWithContext: + def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object: + _check_method_is_allowed(context, pairing_method) + _check_method_is_selected(context, pairing_method) + return f(context, *args, **kwargs) + + return inner + + return decorator + + +# +# Pairing handlers + + +@check_state_and_log(ChannelState.TP0) +async def handle_pairing_request( + ctx: PairingContext, message: protobuf.MessageType +) -> ThpEndResponse: + + if not ThpPairingRequest.is_type_of(message): + raise UnexpectedMessage("Unexpected message") + + ctx.host_name = message.host_name or "" + + await ctx.show_pairing_dialogue() + assert ThpSelectMethod.MESSAGE_WIRE_TYPE is not None + select_method_msg = await ctx.read( + [ + ThpSelectMethod.MESSAGE_WIRE_TYPE, + ] + ) + + assert ThpSelectMethod.is_type_of(select_method_msg) + assert select_method_msg.selected_pairing_method is not None + + ctx.set_selected_method(select_method_msg.selected_pairing_method) + + if ctx.selected_method == ThpPairingMethod.SkipPairing: + return await _end_pairing(ctx) + + while True: + await _prepare_pairing(ctx) + + ctx.channel_ctx.set_channel_state(ChannelState.TP3) + try: + # Should raise UnexpectedMessageException + await ctx.show_pairing_method_screen() + except UnexpectedMessageException as e: + raw_response = e.msg + name = message_handler.get_msg_name(raw_response.type) + if name is None: + req_type = protobuf.type_for_wire(raw_response.type) + else: + req_type = protobuf.type_for_name(name) + response = message_handler.wrap_protobuf_load(raw_response.data, req_type) + + if Cancel.is_type_of(response): + ctx.channel_ctx.clear() + raise SilentError("Action was cancelled by the Host") + + if ThpSelectMethod.is_type_of(response): + assert response.selected_pairing_method is not None + ctx.set_selected_method(response.selected_pairing_method) + ctx.channel_ctx.set_channel_state(ChannelState.TP1) + else: + break + + response: protobuf.MessageType = await _handle_different_pairing_methods( + ctx, response + ) + return await handle_credential_phase( + ctx, + message=response, + show_connection_dialog=False, + ) + + +@check_state_and_log(ChannelState.TC1) +async def handle_credential_phase( + ctx: PairingContext, + message: protobuf.MessageType, + show_connection_dialog: bool = True, +) -> ThpEndResponse: + autoconnect: bool = False + credential = ctx.channel_ctx.credential + + if credential is not None: + autoconnect = is_credential_autoconnect(credential) + if credential.cred_metadata is not None: + ctx.host_name = credential.cred_metadata.host_name + if ctx.host_name is None: + raise Exception("Credential does not have a hostname") + + if show_connection_dialog and not autoconnect: + await ctx.show_connection_dialogue() + + while ThpCredentialRequest.is_type_of(message): + message = await _handle_credential_request(ctx, message) + + return await _handle_end_request(ctx, message) + + +async def _prepare_pairing(ctx: PairingContext) -> None: + ctx.channel_ctx.set_channel_state(ChannelState.TP1) + + if ctx.selected_method == ThpPairingMethod.CodeEntry: + await _handle_code_entry_is_selected(ctx) + elif ctx.selected_method == ThpPairingMethod.NFC: + await _handle_nfc_is_selected(ctx) + elif ctx.selected_method == ThpPairingMethod.QrCode: + await _handle_qr_code_is_selected(ctx) + else: + raise Exception() # TODO unknown pairing method + + +@check_state_and_log(ChannelState.TP1) +async def _handle_code_entry_is_selected(ctx: PairingContext) -> None: + if ctx.code_entry_secret is None: + await _handle_code_entry_is_selected_first_time(ctx) + else: + await ctx.write_force(ThpPairingPreparationsFinished()) + + +async def _handle_code_entry_is_selected_first_time(ctx: PairingContext) -> None: + from trezor.wire.thp.cpace import Cpace + + ctx.code_entry_secret = random.bytes(16) + commitment = sha256(ctx.code_entry_secret).digest() + + challenge_message = await ctx.call( + ThpCodeEntryCommitment(commitment=commitment), ThpCodeEntryChallenge + ) + ctx.channel_ctx.set_channel_state(ChannelState.TP2) + + if not ThpCodeEntryChallenge.is_type_of(challenge_message): + raise UnexpectedMessage("Unexpected message") + + if challenge_message.challenge is None: + raise Exception("Invalid message") + sha_ctx = sha256(ThpPairingMethod.CodeEntry.to_bytes(1, "big")) + sha_ctx.update(ctx.channel_ctx.get_handshake_hash()) + sha_ctx.update(ctx.code_entry_secret) + sha_ctx.update(challenge_message.challenge) + code_code_entry_hash = sha_ctx.digest() + ctx.code_code_entry = int.from_bytes(code_code_entry_hash, "big") % 1000000 + ctx.cpace = Cpace( + ctx.channel_ctx.get_handshake_hash(), + ) + assert ctx.code_code_entry is not None + ctx.cpace.generate_keys_and_secret(ctx.code_code_entry.to_bytes(6, "big")) + await ctx.write_force( + ThpCodeEntryCpaceTrezor(cpace_trezor_public_key=ctx.cpace.trezor_public_key) + ) + + +@check_state_and_log(ChannelState.TP1) +async def _handle_nfc_is_selected(ctx: PairingContext) -> None: + ctx.nfc_secret = random.bytes(16) + await ctx.write_force(ThpPairingPreparationsFinished()) + + +@check_state_and_log(ChannelState.TP1) +async def _handle_qr_code_is_selected(ctx: PairingContext) -> None: + ctx.qr_code_secret = random.bytes(16) + + sha_ctx = sha256(ThpPairingMethod.QrCode.to_bytes(1, "big")) + sha_ctx.update(ctx.channel_ctx.get_handshake_hash()) + sha_ctx.update(ctx.qr_code_secret) + + ctx.code_qr_code = sha_ctx.digest()[:16] + await ctx.write_force(ThpPairingPreparationsFinished()) + + +@check_state_and_log(ChannelState.TP3) +async def _handle_different_pairing_methods( + ctx: PairingContext, response: protobuf.MessageType +) -> protobuf.MessageType: + if ThpCodeEntryCpaceHostTag.is_type_of(response): + return await _handle_code_entry_cpace(ctx, response) + if ThpQrCodeTag.is_type_of(response): + return await _handle_qr_code_tag(ctx, response) + if ThpNfcTagHost.is_type_of(response): + return await _handle_nfc_tag(ctx, response) + raise UnexpectedMessage("Unexpected message" + str(response)) + + +@check_state_and_log(ChannelState.TP3) +@check_method_is_allowed_and_selected(ThpPairingMethod.CodeEntry) +async def _handle_code_entry_cpace( + ctx: PairingContext, message: protobuf.MessageType +) -> protobuf.MessageType: + + if TYPE_CHECKING: + assert ThpCodeEntryCpaceHostTag.is_type_of(message) + if message.cpace_host_public_key is None: + raise ThpError( + "Message ThpCodeEntryCpaceHostTag is missing cpace_host_public_key" + ) + if message.tag is None: + raise ThpError("Message ThpCodeEntryCpaceHostTag is missing tag") + + ctx.cpace.compute_shared_secret(message.cpace_host_public_key) + expected_tag = sha256(ctx.cpace.shared_secret).digest() + if expected_tag != message.tag: + print( + "expected code entry tag:", hexlify(expected_tag).decode() + ) # TODO remove after testing + print( + "expected code entry shared secret:", + hexlify(ctx.cpace.shared_secret).decode(), + ) # TODO remove after testing + raise ThpError("Unexpected Code Entry Tag") + + return await _handle_secret_reveal( + ctx, + msg=ThpCodeEntrySecret(secret=ctx.code_entry_secret), + ) + + +@check_state_and_log(ChannelState.TP3) +@check_method_is_allowed_and_selected(ThpPairingMethod.QrCode) +async def _handle_qr_code_tag( + ctx: PairingContext, message: protobuf.MessageType +) -> protobuf.MessageType: + if TYPE_CHECKING: + assert isinstance(message, ThpQrCodeTag) + assert ctx.code_qr_code is not None + sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash()) + sha_ctx.update(ctx.code_qr_code) + expected_tag = sha_ctx.digest() + if expected_tag != message.tag: + print( + "expected qr code tag:", hexlify(expected_tag).decode() + ) # TODO remove after testing + print( + "expected handshake hash:", + hexlify(ctx.channel_ctx.get_handshake_hash()).decode(), + ) # TODO remove after testing + print( + "expected code qr code:", + hexlify(ctx.code_qr_code).decode(), + ) # TODO remove after testing + print( + "expected secret:", hexlify(ctx.qr_code_secret or b"").decode() + ) # TODO remove after testing + raise ThpError("Unexpected QR Code Tag") + + return await _handle_secret_reveal( + ctx, + msg=ThpQrCodeSecret(secret=ctx.qr_code_secret), + ) + + +@check_state_and_log(ChannelState.TP3) +@check_method_is_allowed_and_selected(ThpPairingMethod.NFC) +async def _handle_nfc_tag( + ctx: PairingContext, message: protobuf.MessageType +) -> protobuf.MessageType: + if TYPE_CHECKING: + assert isinstance(message, ThpNfcTagHost) + + assert ctx.nfc_secret is not None + assert ctx.handshake_hash_host is not None + assert ctx.nfc_secret_host is not None + assert len(ctx.nfc_secret_host) == 16 + + sha_ctx = sha256(ThpPairingMethod.NFC.to_bytes(1, "big")) + sha_ctx.update(ctx.channel_ctx.get_handshake_hash()) + sha_ctx.update(ctx.nfc_secret) + expected_tag = sha_ctx.digest() + if expected_tag != message.tag: + print( + "expected nfc tag:", hexlify(expected_tag).decode() + ) # TODO remove after testing + raise ThpError("Unexpected NFC Unidirectional Tag") + + if ctx.handshake_hash_host[:16] != ctx.channel_ctx.get_handshake_hash()[:16]: + raise ThpError("Handshake hash mismatch") + + sha_ctx = sha256(ThpPairingMethod.NFC.to_bytes(1, "big")) + sha_ctx.update(ctx.channel_ctx.get_handshake_hash()) + sha_ctx.update(ctx.nfc_secret_host) + trezor_tag = sha_ctx.digest() + return await _handle_secret_reveal( + ctx, + msg=ThpNfcTagTrezor(tag=trezor_tag), + ) + + +@check_state_and_log(ChannelState.TP3, ChannelState.TP4) +async def _handle_secret_reveal( + ctx: PairingContext, + msg: protobuf.MessageType, +) -> protobuf.MessageType: + ctx.channel_ctx.set_channel_state(ChannelState.TC1) + return await ctx.call_any( + msg, + ThpMessageType.ThpCredentialRequest, + ThpMessageType.ThpEndRequest, + ) + + +@check_state_and_log(ChannelState.TC1) +async def _handle_credential_request( + ctx: PairingContext, message: protobuf.MessageType +) -> protobuf.MessageType: + + if not ThpCredentialRequest.is_type_of(message): + raise UnexpectedMessage("Unexpected message") + if message.host_static_pubkey is None: + raise Exception("Invalid message") # TODO change failure type + + autoconnect: bool = False + if message.autoconnect is not None: + autoconnect = message.autoconnect + + trezor_static_pubkey = crypto.get_trezor_static_pubkey() + credential_metadata = ThpCredentialMetadata( + host_name=ctx.host_name, + autoconnect=autoconnect, + ) + credential = issue_credential(message.host_static_pubkey, credential_metadata) + + return await ctx.call_any( + ThpCredentialResponse( + trezor_static_pubkey=trezor_static_pubkey, credential=credential + ), + ThpMessageType.ThpCredentialRequest, + ThpMessageType.ThpEndRequest, + ) + + +@check_state_and_log(ChannelState.TC1) +async def _handle_end_request( + ctx: PairingContext, message: protobuf.MessageType +) -> ThpEndResponse: + if not ThpEndRequest.is_type_of(message): + raise UnexpectedMessage("Unexpected message") + return await _end_pairing(ctx) + + +async def _end_pairing(ctx: PairingContext) -> ThpEndResponse: + ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + return ThpEndResponse() + + +# +# Helpers - checkers + + +def _check_state(ctx: PairingContext, *allowed_states: ChannelState) -> None: + if ctx.channel_ctx.get_channel_state() not in allowed_states: + raise UnexpectedMessage("Unexpected message") + + +def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> None: + if method not in get_enabled_pairing_methods(ctx.iface): + raise ThpError("Unexpected pairing method") + + +def _check_method_is_selected(ctx: PairingContext, method: ThpPairingMethod) -> None: + if method is not ctx.selected_method: + raise ThpError("Not selected pairing method") + + +# +# Helpers - getters + + +def _get_accepted_messages(ctx: PairingContext) -> Tuple[int, ...]: + r = _get_possible_pairing_methods(ctx) + mtype = Cancel.MESSAGE_WIRE_TYPE + r += (mtype,) if mtype is not None else () + mtype = ThpSelectMethod.MESSAGE_WIRE_TYPE + r += (mtype,) if mtype is not None else () + + return r + + +def _get_possible_pairing_methods(ctx: PairingContext) -> Tuple[int, ...]: + r = tuple( + [ + _get_message_type_for_method(ctx.selected_method), + ] + ) + return r + + +def _get_message_type_for_method(method: int) -> int: + if method is ThpPairingMethod.CodeEntry: + return ThpMessageType.ThpCodeEntryCpaceHostTag + if method is ThpPairingMethod.NFC: + return ThpMessageType.ThpNfcTagHost + if method is ThpPairingMethod.QrCode: + return ThpMessageType.ThpQrCodeTag + raise ValueError("Unexpected pairing method - no message type available") diff --git a/core/src/storage/__init__.py b/core/src/storage/__init__.py index 3a012874f3d..12532490901 100644 --- a/core/src/storage/__init__.py +++ b/core/src/storage/__init__.py @@ -1,12 +1,29 @@ # make sure to import cache unconditionally at top level so that it is imported (and retained) together with the storage module +from typing import TYPE_CHECKING + from storage import cache, common, device +if TYPE_CHECKING: + from typing import Tuple + + pass -def wipe() -> None: + +def wipe(clear_cache: bool = True) -> None: + """ + Wipes the storage. + If the device should communicate after wipe, use `clear_cache=False` and clear cache manually later using + `wipe_cache()`. + """ from trezor import config config.wipe() - cache.clear_all() + if clear_cache: + cache.clear_all() + + +def wipe_cache(excluded: Tuple[bytes, bytes] | None = None) -> None: + cache.clear_all(excluded) def init_unlocked() -> None: @@ -21,12 +38,13 @@ def init_unlocked() -> None: common.set_bool(common.APP_DEVICE, device.INITIALIZED, True, public=True) -def reset() -> None: +def reset(excluded: Tuple[bytes, bytes] | None) -> None: """ Wipes storage but keeps the device id unchanged. """ device_id = device.get_device_id() - wipe() + wipe(clear_cache=False) + wipe_cache(excluded) common.set(common.APP_DEVICE, device.DEVICE_ID, device_id.encode(), public=True) diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index 72d8a1e4188..6db224a782d 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -1,26 +1,47 @@ import builtins import gc +from typing import TYPE_CHECKING -from storage import cache_codec from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache +from trezor import utils + +if TYPE_CHECKING: + from typing import Tuple + + pass # Cache initialization _SESSIONLESS_CACHE = SessionlessCache() -_PROTOCOL_CACHE = cache_codec + + +if utils.USE_THP: + from storage import cache_thp + + _PROTOCOL_CACHE = cache_thp +else: + from storage import cache_codec + + _PROTOCOL_CACHE = cache_codec + _PROTOCOL_CACHE.initialize() _SESSIONLESS_CACHE.clear() gc.collect() -def clear_all() -> None: +def clear_all(excluded: Tuple[bytes, bytes] | None = None) -> None: """ Clears all data from both the protocol cache and the sessionless cache. """ global autolock_last_touch autolock_last_touch = None _SESSIONLESS_CACHE.clear() - _PROTOCOL_CACHE.clear_all() + + if utils.USE_THP and excluded is not None: + # If we want to keep THP connection alive, we do not clear communication keys + cache_thp.clear_all_except_one_session_keys(excluded) + else: + _PROTOCOL_CACHE.clear_all() def get_int_all_sessions(key: int) -> builtins.set[int]: diff --git a/core/src/storage/cache_common.py b/core/src/storage/cache_common.py index 90cead81db5..40eee905ccd 100644 --- a/core/src/storage/cache_common.py +++ b/core/src/storage/cache_common.py @@ -14,6 +14,14 @@ APP_CARDANO_ICARUS_TREZOR_SECRET = const(6) APP_MONERO_LIVE_REFRESH = const(7) +# Cache keys for THP channel +if utils.USE_THP: + CHANNEL_HANDSHAKE_HASH = const(0) + CHANNEL_KEY_RECEIVE = const(1) + CHANNEL_KEY_SEND = const(2) + CHANNEL_NONCE_RECEIVE = const(3) + CHANNEL_NONCE_SEND = const(4) + # Keys that are valid across sessions SESSIONLESS_FLAG = const(128) APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG) diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py new file mode 100644 index 00000000000..53ab1107550 --- /dev/null +++ b/core/src/storage/cache_thp.py @@ -0,0 +1,356 @@ +import builtins +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_common import DataCache + +if TYPE_CHECKING: + from typing import Tuple + + pass + + +# THP specific constants +_MAX_CHANNELS_COUNT = const(10) +_MAX_SESSIONS_COUNT = const(20) + + +_CHANNEL_STATE_LENGTH = const(1) +_WIRE_INTERFACE_LENGTH = const(1) +_SESSION_STATE_LENGTH = const(1) +_CHANNEL_ID_LENGTH = const(2) +SESSION_ID_LENGTH = const(1) +BROADCAST_CHANNEL_ID = const(0xFFFF) +KEY_LENGTH = const(32) +TAG_LENGTH = const(16) +_UNALLOCATED_STATE = const(0) +_ALLOCATED_STATE = const(1) +_SEEDLESS_STATE = const(2) + + +class ThpDataCache(DataCache): + def __init__(self) -> None: + self.channel_id = bytearray(_CHANNEL_ID_LENGTH) + self.last_usage = 0 + super().__init__() + + def clear(self) -> None: + self.channel_id[:] = b"" + self.last_usage = 0 + super().clear() + + +class ChannelCache(ThpDataCache): + def __init__(self) -> None: + self.host_ephemeral_pubkey = bytearray(KEY_LENGTH) + self.state = bytearray(_CHANNEL_STATE_LENGTH) + self.iface = bytearray(1) # TODO add decoding + self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5) + self.session_id_counter = 0x00 + self.fields = ( + 32, # CHANNEL_HANDSHAKE_HASH + 32, # CHANNEL_KEY_RECEIVE + 32, # CHANNEL_KEY_SEND + 8, # CHANNEL_NONCE_RECEIVE + 8, # CHANNEL_NONCE_SEND + ) + super().__init__() + + def clear(self) -> None: + self.state[:] = bytearray( + int.to_bytes(0, _CHANNEL_STATE_LENGTH, "big") + ) # Set state to UNALLOCATED + self.host_ephemeral_pubkey[:] = bytearray(KEY_LENGTH) + self.state[:] = bytearray(_CHANNEL_STATE_LENGTH) + self.iface[:] = bytearray(1) + super().clear() + + +class SessionThpCache(ThpDataCache): + def __init__(self) -> None: + from trezor import utils + + self.session_id = bytearray(SESSION_ID_LENGTH) + self.state = bytearray(_SESSION_STATE_LENGTH) + if utils.BITCOIN_ONLY: + self.fields = ( + 64, # APP_COMMON_SEED + 2, # APP_COMMON_AUTHORIZATION_TYPE + 128, # APP_COMMON_AUTHORIZATION_DATA + 32, # APP_COMMON_NONCE + ) + else: + self.fields = ( + 64, # APP_COMMON_SEED + 2, # APP_COMMON_AUTHORIZATION_TYPE + 128, # APP_COMMON_AUTHORIZATION_DATA + 32, # APP_COMMON_NONCE + 0, # APP_COMMON_DERIVE_CARDANO + 96, # APP_CARDANO_ICARUS_SECRET + 96, # APP_CARDANO_ICARUS_TREZOR_SECRET + 0, # APP_MONERO_LIVE_REFRESH + ) + super().__init__() + + def clear(self) -> None: + self.state[:] = bytearray( + int.to_bytes(0, _SESSION_STATE_LENGTH, "big") + ) # Set state to UNALLOCATED + self.session_id[:] = b"" + super().clear() + + +_CHANNELS: list[ChannelCache] = [] +_SESSIONS: list[SessionThpCache] = [] +cid_counter: int = 0 + +# Last-used counter +_usage_counter = 0 + + +def initialize() -> None: + global _CHANNELS + global _SESSIONS + global cid_counter + + for _ in range(_MAX_CHANNELS_COUNT): + _CHANNELS.append(ChannelCache()) + for _ in range(_MAX_SESSIONS_COUNT): + _SESSIONS.append(SessionThpCache()) + + for channel in _CHANNELS: + channel.clear() + for session in _SESSIONS: + session.clear() + + from trezorcrypto import random + + cid_counter = random.uniform(0xFFFE) + + +def get_new_channel(iface: bytes) -> ChannelCache: + if len(iface) != _WIRE_INTERFACE_LENGTH: + raise Exception("Invalid WireInterface (encoded) length") + + new_cid = get_next_channel_id() + index = _get_next_channel_index() + + # clear sessions from replaced channel + if _get_channel_state(_CHANNELS[index]) != _UNALLOCATED_STATE: + old_cid = _CHANNELS[index].channel_id + clear_sessions_with_channel_id(old_cid) + + _CHANNELS[index] = ChannelCache() + _CHANNELS[index].channel_id[:] = new_cid + _CHANNELS[index].last_usage = _get_usage_counter_and_increment() + _CHANNELS[index].state[:] = bytearray( + _UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big") + ) + _CHANNELS[index].iface[:] = bytearray(iface) + return _CHANNELS[index] + + +def update_channel_last_used(channel_id: bytes) -> None: + for channel in _CHANNELS: + if channel.channel_id == channel_id: + channel.last_usage = _get_usage_counter_and_increment() + return + + +def update_session_last_used(channel_id: bytes, session_id: bytes) -> None: + for session in _SESSIONS: + if session.channel_id == channel_id and session.session_id == session_id: + session.last_usage = _get_usage_counter_and_increment() + update_channel_last_used(channel_id) + return + + +def get_all_allocated_channels() -> list[ChannelCache]: + _list: list[ChannelCache] = [] + for channel in _CHANNELS: + if _get_channel_state(channel) != _UNALLOCATED_STATE: + _list.append(channel) + return _list + + +def get_allocated_session( + channel_id: bytes, session_id: bytes +) -> SessionThpCache | None: + index = get_allocated_session_index(channel_id, session_id) + if index is None: + return index + return _SESSIONS[index] + + +def get_allocated_session_index(channel_id: bytes, session_id: bytes) -> int | None: + """ + Finds and returns index of the first allocated session matching the given `channel_id` + and `session_id`, or `None` if no match is found. + + Raises `Exception` if either channel_id or session_id has an invalid length. + """ + if len(channel_id) != _CHANNEL_ID_LENGTH or len(session_id) != SESSION_ID_LENGTH: + raise Exception("At least one of arguments has invalid length") + + for i in range(_MAX_SESSIONS_COUNT): + if _get_session_state(_SESSIONS[i]) == _UNALLOCATED_STATE: + continue + if _SESSIONS[i].channel_id != channel_id: + continue + if _SESSIONS[i].session_id != session_id: + continue + return i + return None + + +def is_seedless_session(session_cache: SessionThpCache) -> bool: + return _get_session_state(session_cache) == _SEEDLESS_STATE + + +def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> None: + if len(key) != KEY_LENGTH: + raise Exception("Invalid key length") + channel.host_ephemeral_pubkey = key + + +def create_or_replace_session( + channel: ChannelCache, session_id: bytes +) -> SessionThpCache: + index = get_allocated_session_index(channel.channel_id, session_id) + if index is None: + index = _get_next_session_index() + + _SESSIONS[index] = SessionThpCache() + _SESSIONS[index].channel_id[:] = channel.channel_id + _SESSIONS[index].session_id[:] = session_id + _SESSIONS[index].last_usage = _get_usage_counter_and_increment() + channel.last_usage = ( + _get_usage_counter_and_increment() + ) # increment also use of the channel so it does not get replaced + + _SESSIONS[index].state[:] = bytearray( + _ALLOCATED_STATE.to_bytes(_SESSION_STATE_LENGTH, "big") + ) + return _SESSIONS[index] + + +def _get_usage_counter_and_increment() -> int: + global _usage_counter + _usage_counter += 1 + return _usage_counter + + +def _get_next_channel_index() -> int: + idx = _get_unallocated_channel_index() + if idx is not None: + return idx + return _get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT) + + +def _get_next_session_index() -> int: + idx = _get_unallocated_session_index() + if idx is not None: + return idx + return _get_least_recently_used_item(_SESSIONS, max_count=_MAX_SESSIONS_COUNT) + + +def _get_unallocated_channel_index() -> int | None: + for i in range(_MAX_CHANNELS_COUNT): + if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE: + return i + return None + + +def _get_unallocated_session_index() -> int | None: + for i in range(_MAX_SESSIONS_COUNT): + if (_SESSIONS[i]) is _UNALLOCATED_STATE: + return i + return None + + +def _get_channel_state(channel: ChannelCache) -> int: + return int.from_bytes(channel.state, "big") + + +def _get_session_state(session: SessionThpCache) -> int: + return int.from_bytes(session.state, "big") + + +def get_next_channel_id() -> bytes: + global cid_counter + while True: + cid_counter += 1 + if cid_counter >= BROADCAST_CHANNEL_ID: + cid_counter = 1 + if _is_cid_unique(): + break + return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big") + + +def _is_cid_unique() -> bool: + global cid_counter + cid_counter_bytes = cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big") + for channel in _CHANNELS: + if channel.channel_id == cid_counter_bytes: + return False + return True + + +def _get_least_recently_used_item( + list: list[ChannelCache] | list[SessionThpCache], max_count: int +) -> int: + global _usage_counter + lru_counter = _usage_counter + 1 + lru_item_index = 0 + for i in range(max_count): + if list[i].last_usage < lru_counter: + lru_counter = list[i].last_usage + lru_item_index = i + return lru_item_index + + +def get_int_all_sessions(key: int) -> builtins.set[int]: + values = builtins.set() + for session in _SESSIONS: + encoded = session.get(key) + if encoded is not None: + values.add(int.from_bytes(encoded, "big")) + return values + + +def clear_sessions_with_channel_id(channel_id: bytes) -> None: + for session in _SESSIONS: + if session.channel_id == channel_id: + session.clear() + + +def clear_session(session: SessionThpCache) -> None: + for s in _SESSIONS: + if s.channel_id == session.channel_id and s.session_id == session.session_id: + session.clear() + + +def clear_all() -> None: + for session in _SESSIONS: + session.clear() + for channel in _CHANNELS: + channel.clear() + + +def clear_all_except_one_session_keys(excluded: Tuple[bytes, bytes]) -> None: + cid, sid = excluded + + for channel in _CHANNELS: + if channel.channel_id != cid: + channel.clear() + + for session in _SESSIONS: + if session.channel_id != cid and session.session_id != sid: + session.clear() + else: + s_last_usage = session.last_usage + session.clear() + session.last_usage = s_last_usage + session.state = bytearray(_SEEDLESS_STATE.to_bytes(1, "big")) + session.session_id[:] = bytearray(sid) + session.channel_id[:] = bytearray(cid) diff --git a/core/src/trezor/enums/FailureType.py b/core/src/trezor/enums/FailureType.py index fbb2001e54c..e95dcb803fc 100644 --- a/core/src/trezor/enums/FailureType.py +++ b/core/src/trezor/enums/FailureType.py @@ -16,4 +16,8 @@ PinMismatch = 12 WipeCodeMismatch = 13 InvalidSession = 14 +ThpUnallocatedSession = 15 +InvalidProtocol = 16 +BufferError = 17 +DeviceIsBusy = 18 FirmwareError = 99 diff --git a/core/src/trezor/enums/MessageType.py b/core/src/trezor/enums/MessageType.py index ed569795b0b..1a955fbcb7f 100644 --- a/core/src/trezor/enums/MessageType.py +++ b/core/src/trezor/enums/MessageType.py @@ -97,6 +97,8 @@ DebugLinkWatchLayout = 9006 DebugLinkResetDebugEvents = 9007 DebugLinkOptigaSetSecMax = 9008 +DebugLinkGetPairingInfo = 9009 +DebugLinkPairingInfo = 9010 BenchmarkListNames = 9100 BenchmarkNames = 9101 BenchmarkRun = 9102 diff --git a/core/src/trezor/enums/ThpMessageType.py b/core/src/trezor/enums/ThpMessageType.py new file mode 100644 index 00000000000..3ca4a27c358 --- /dev/null +++ b/core/src/trezor/enums/ThpMessageType.py @@ -0,0 +1,22 @@ +# Automatically generated by pb2py +# fmt: off +# isort:skip_file + +ThpCreateNewSession = 1000 +ThpPairingRequest = 1006 +ThpPairingRequestApproved = 1007 +ThpSelectMethod = 1008 +ThpPairingPreparationsFinished = 1009 +ThpCredentialRequest = 1010 +ThpCredentialResponse = 1011 +ThpEndRequest = 1012 +ThpEndResponse = 1013 +ThpCodeEntryCommitment = 1016 +ThpCodeEntryChallenge = 1017 +ThpCodeEntryCpaceTrezor = 1018 +ThpCodeEntryCpaceHostTag = 1019 +ThpCodeEntrySecret = 1020 +ThpQrCodeTag = 1024 +ThpQrCodeSecret = 1025 +ThpNfcTagHost = 1032 +ThpNfcTagTrezor = 1033 diff --git a/core/src/trezor/enums/ThpPairingMethod.py b/core/src/trezor/enums/ThpPairingMethod.py new file mode 100644 index 00000000000..0af2487182a --- /dev/null +++ b/core/src/trezor/enums/ThpPairingMethod.py @@ -0,0 +1,8 @@ +# Automatically generated by pb2py +# fmt: off +# isort:skip_file + +SkipPairing = 1 +CodeEntry = 2 +QrCode = 3 +NFC = 4 diff --git a/core/src/trezor/enums/__init__.py b/core/src/trezor/enums/__init__.py index d16c3c4a660..c39574773f6 100644 --- a/core/src/trezor/enums/__init__.py +++ b/core/src/trezor/enums/__init__.py @@ -39,6 +39,10 @@ class FailureType(IntEnum): PinMismatch = 12 WipeCodeMismatch = 13 InvalidSession = 14 + ThpUnallocatedSession = 15 + InvalidProtocol = 16 + BufferError = 17 + DeviceIsBusy = 18 FirmwareError = 99 class ButtonRequestType(IntEnum): @@ -347,6 +351,32 @@ class TezosBallotType(IntEnum): Nay = 1 Pass = 2 + class ThpMessageType(IntEnum): + ThpCreateNewSession = 1000 + ThpPairingRequest = 1006 + ThpPairingRequestApproved = 1007 + ThpSelectMethod = 1008 + ThpPairingPreparationsFinished = 1009 + ThpCredentialRequest = 1010 + ThpCredentialResponse = 1011 + ThpEndRequest = 1012 + ThpEndResponse = 1013 + ThpCodeEntryCommitment = 1016 + ThpCodeEntryChallenge = 1017 + ThpCodeEntryCpaceTrezor = 1018 + ThpCodeEntryCpaceHostTag = 1019 + ThpCodeEntrySecret = 1020 + ThpQrCodeTag = 1024 + ThpQrCodeSecret = 1025 + ThpNfcTagHost = 1032 + ThpNfcTagTrezor = 1033 + + class ThpPairingMethod(IntEnum): + SkipPairing = 1 + CodeEntry = 2 + QrCode = 3 + NFC = 4 + class MessageType(IntEnum): Initialize = 0 Ping = 1 @@ -447,6 +477,8 @@ class MessageType(IntEnum): DebugLinkWatchLayout = 9006 DebugLinkResetDebugEvents = 9007 DebugLinkOptigaSetSecMax = 9008 + DebugLinkGetPairingInfo = 9009 + DebugLinkPairingInfo = 9010 EthereumGetPublicKey = 450 EthereumPublicKey = 451 EthereumGetAddress = 56 diff --git a/core/src/trezor/messages.py b/core/src/trezor/messages.py index 1dbfcbc4078..9d0f58203fd 100644 --- a/core/src/trezor/messages.py +++ b/core/src/trezor/messages.py @@ -68,6 +68,8 @@ def __getattr__(name: str) -> Any: from trezor.enums import StellarSignerType # noqa: F401 from trezor.enums import TezosBallotType # noqa: F401 from trezor.enums import TezosContractType # noqa: F401 + from trezor.enums import ThpMessageType # noqa: F401 + from trezor.enums import ThpPairingMethod # noqa: F401 from trezor.enums import WordRequestType # noqa: F401 class BenchmarkListNames(protobuf.MessageType): @@ -2950,6 +2952,46 @@ def __init__( def is_type_of(cls, msg: Any) -> TypeGuard["DebugLinkState"]: return isinstance(msg, cls) + class DebugLinkGetPairingInfo(protobuf.MessageType): + channel_id: "bytes | None" + handshake_hash: "bytes | None" + nfc_secret_host: "bytes | None" + + def __init__( + self, + *, + channel_id: "bytes | None" = None, + handshake_hash: "bytes | None" = None, + nfc_secret_host: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["DebugLinkGetPairingInfo"]: + return isinstance(msg, cls) + + class DebugLinkPairingInfo(protobuf.MessageType): + channel_id: "bytes | None" + handshake_hash: "bytes | None" + code_entry_code: "int | None" + code_qr_code: "bytes | None" + nfc_secret_trezor: "bytes | None" + + def __init__( + self, + *, + channel_id: "bytes | None" = None, + handshake_hash: "bytes | None" = None, + code_entry_code: "int | None" = None, + code_qr_code: "bytes | None" = None, + nfc_secret_trezor: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["DebugLinkPairingInfo"]: + return isinstance(msg, cls) + class DebugLinkStop(protobuf.MessageType): @classmethod @@ -6164,13 +6206,281 @@ def __init__( def is_type_of(cls, msg: Any) -> TypeGuard["TezosManagerTransfer"]: return isinstance(msg, cls) + class ThpDeviceProperties(protobuf.MessageType): + internal_model: "str | None" + model_variant: "int | None" + protocol_version_major: "int | None" + protocol_version_minor: "int | None" + pairing_methods: "list[ThpPairingMethod]" + + def __init__( + self, + *, + pairing_methods: "list[ThpPairingMethod] | None" = None, + internal_model: "str | None" = None, + model_variant: "int | None" = None, + protocol_version_major: "int | None" = None, + protocol_version_minor: "int | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpDeviceProperties"]: + return isinstance(msg, cls) + + class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType): + host_pairing_credential: "bytes | None" + + def __init__( + self, + *, + host_pairing_credential: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpHandshakeCompletionReqNoisePayload"]: + return isinstance(msg, cls) + + class ThpCreateNewSession(protobuf.MessageType): + passphrase: "str | None" + on_device: "bool | None" + derive_cardano: "bool | None" + + def __init__( + self, + *, + passphrase: "str | None" = None, + on_device: "bool | None" = None, + derive_cardano: "bool | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCreateNewSession"]: + return isinstance(msg, cls) + + class ThpPairingRequest(protobuf.MessageType): + host_name: "str | None" + + def __init__( + self, + *, + host_name: "str | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpPairingRequest"]: + return isinstance(msg, cls) + + class ThpPairingRequestApproved(protobuf.MessageType): + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpPairingRequestApproved"]: + return isinstance(msg, cls) + + class ThpSelectMethod(protobuf.MessageType): + selected_pairing_method: "ThpPairingMethod | None" + + def __init__( + self, + *, + selected_pairing_method: "ThpPairingMethod | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpSelectMethod"]: + return isinstance(msg, cls) + + class ThpPairingPreparationsFinished(protobuf.MessageType): + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpPairingPreparationsFinished"]: + return isinstance(msg, cls) + + class ThpCodeEntryCommitment(protobuf.MessageType): + commitment: "bytes | None" + + def __init__( + self, + *, + commitment: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCommitment"]: + return isinstance(msg, cls) + + class ThpCodeEntryChallenge(protobuf.MessageType): + challenge: "bytes | None" + + def __init__( + self, + *, + challenge: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryChallenge"]: + return isinstance(msg, cls) + + class ThpCodeEntryCpaceTrezor(protobuf.MessageType): + cpace_trezor_public_key: "bytes | None" + + def __init__( + self, + *, + cpace_trezor_public_key: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceTrezor"]: + return isinstance(msg, cls) + + class ThpCodeEntryCpaceHostTag(protobuf.MessageType): + cpace_host_public_key: "bytes | None" + tag: "bytes | None" + + def __init__( + self, + *, + cpace_host_public_key: "bytes | None" = None, + tag: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceHostTag"]: + return isinstance(msg, cls) + + class ThpCodeEntrySecret(protobuf.MessageType): + secret: "bytes | None" + + def __init__( + self, + *, + secret: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntrySecret"]: + return isinstance(msg, cls) + + class ThpQrCodeTag(protobuf.MessageType): + tag: "bytes | None" + + def __init__( + self, + *, + tag: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeTag"]: + return isinstance(msg, cls) + + class ThpQrCodeSecret(protobuf.MessageType): + secret: "bytes | None" + + def __init__( + self, + *, + secret: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeSecret"]: + return isinstance(msg, cls) + + class ThpNfcTagHost(protobuf.MessageType): + tag: "bytes | None" + + def __init__( + self, + *, + tag: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcTagHost"]: + return isinstance(msg, cls) + + class ThpNfcTagTrezor(protobuf.MessageType): + tag: "bytes | None" + + def __init__( + self, + *, + tag: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcTagTrezor"]: + return isinstance(msg, cls) + + class ThpCredentialRequest(protobuf.MessageType): + host_static_pubkey: "bytes | None" + autoconnect: "bool | None" + + def __init__( + self, + *, + host_static_pubkey: "bytes | None" = None, + autoconnect: "bool | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialRequest"]: + return isinstance(msg, cls) + + class ThpCredentialResponse(protobuf.MessageType): + trezor_static_pubkey: "bytes | None" + credential: "bytes | None" + + def __init__( + self, + *, + trezor_static_pubkey: "bytes | None" = None, + credential: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialResponse"]: + return isinstance(msg, cls) + + class ThpEndRequest(protobuf.MessageType): + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndRequest"]: + return isinstance(msg, cls) + + class ThpEndResponse(protobuf.MessageType): + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndResponse"]: + return isinstance(msg, cls) + class ThpCredentialMetadata(protobuf.MessageType): host_name: "str | None" + autoconnect: "bool | None" def __init__( self, *, host_name: "str | None" = None, + autoconnect: "bool | None" = None, ) -> None: pass diff --git a/core/src/trezor/utils.py b/core/src/trezor/utils.py index 11ef808ff96..f3012e8da21 100644 --- a/core/src/trezor/utils.py +++ b/core/src/trezor/utils.py @@ -35,6 +35,10 @@ ) from typing import TYPE_CHECKING +DISABLE_ENCRYPTION: bool = False + +ALLOW_DEBUG_MESSAGES: bool = True + if __debug__: if EMULATOR: import uos diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 2662a5610aa..287ab3377b1 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -5,7 +5,7 @@ - Request / response. - Protobuf-encoded, see `protobuf.py`. -- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py`. +- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py` or `trezor/wire/thp/thp_main.py`. - Transferred over USB interface, or UDP in case of Unix emulation. This module: @@ -29,7 +29,12 @@ from trezor import log, loop, protobuf, utils from . import message_handler, protocol_common -from .codec.codec_context import CodecContext + +if utils.USE_THP: + from .thp import thp_main +else: + from .codec.codec_context import CodecContext + from .context import UnexpectedMessageException from .message_handler import failure @@ -37,10 +42,6 @@ # other packages. from .errors import * # isort:skip # noqa: F401,F403 -_PROTOBUF_BUFFER_SIZE = const(8192) - -WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) - if TYPE_CHECKING: from trezorio import WireInterface from typing import Any, Callable, Coroutine, TypeVar @@ -57,57 +58,91 @@ def setup(iface: WireInterface) -> None: loop.schedule(handle_session(iface)) -async def handle_session(iface: WireInterface) -> None: - ctx = CodecContext(iface, WIRE_BUFFER) - next_msg: protocol_common.Message | None = None +if utils.USE_THP: + # memory_manager is imported to create READ/WRITE buffers + # in more stable area of memory + from .thp import memory_manager # noqa: F401 - # Take a mark of modules that are imported at this point, so we can - # roll back and un-import any others. - modules = utils.unimport_begin() - while True: - try: - if next_msg is None: - # If the previous run did not keep an unprocessed message for us, - # wait for a new one coming from the wire. - try: - msg = await ctx.read_from_wire() - except protocol_common.WireError as exc: - if __debug__: - log.exception(__name__, exc) - await ctx.write(failure(exc)) - continue + async def handle_session(iface: WireInterface) -> None: - else: - # Process the message from previous run. - msg = next_msg - next_msg = None + # Take a mark of modules that are imported at this point, so we can + # roll back and un-import any others. + modules = utils.unimport_begin() - do_not_restart = False + while True: try: - do_not_restart = await message_handler.handle_single_message(ctx, msg) - except UnexpectedMessageException as unexpected: - # The workflow was interrupted by an unexpected message. We need to - # process it as if it was a new message... - next_msg = unexpected.msg - # ...and we must not restart because that would lose the message. - do_not_restart = True - continue + await thp_main.thp_main_loop(iface) except Exception as exc: - # Log and ignore. The session handler can only exit explicitly in the - # following finally block. + # Log and try again. if __debug__: log.exception(__name__, exc) finally: # Unload modules imported by the workflow. Should not raise. + if __debug__: + log.debug(__name__, "utils.unimport_end(modules) and loop.clear()") utils.unimport_end(modules) + loop.clear() + return # pylint: disable=lost-exception + +else: + _PROTOBUF_BUFFER_SIZE = const(8192) + WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) - if not do_not_restart: - # Let the session be restarted from `main`. - loop.clear() - return # pylint: disable=lost-exception + async def handle_session(iface: WireInterface) -> None: + ctx = CodecContext(iface, WIRE_BUFFER) + next_msg: protocol_common.Message | None = None - except Exception as exc: - # Log and try again. The session handler can only exit explicitly via - # loop.clear() above. - if __debug__: - log.exception(__name__, exc) + # Take a mark of modules that are imported at this point, so we can + # roll back and un-import any others. + modules = utils.unimport_begin() + while True: + try: + if next_msg is None: + # If the previous run did not keep an unprocessed message for us, + # wait for a new one coming from the wire. + try: + msg = await ctx.read_from_wire() + except protocol_common.WireError as exc: + if __debug__: + log.exception(__name__, exc) + await ctx.write(failure(exc)) + continue + + else: + # Process the message from previous run. + msg = next_msg + next_msg = None + + do_not_restart = False + try: + do_not_restart = await message_handler.handle_single_message( + ctx, msg + ) + except UnexpectedMessageException as unexpected: + # The workflow was interrupted by an unexpected message. We need to + # process it as if it was a new message... + next_msg = unexpected.msg + # ...and we must not restart because that would lose the message. + do_not_restart = True + continue + except Exception as exc: + # Log and ignore. The session handler can only exit explicitly in the + # following finally block. + if __debug__: + log.exception(__name__, exc) + finally: + # Unload modules imported by the workflow. Should not raise. + utils.unimport_end(modules) + + if not do_not_restart: + # Let the session be restarted from `main`. + if __debug__: + log.debug(__name__, "loop.clear()") + loop.clear() + return # pylint: disable=lost-exception + + except Exception as exc: + # Log and try again. The session handler can only exit explicitly via + # loop.clear() above. + if __debug__: + log.exception(__name__, exc) diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index 56df34fbc58..00bfeb77d4f 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -17,7 +17,7 @@ from storage import cache from storage.cache_common import SESSIONLESS_FLAG -from trezor import loop, protobuf +from trezor import loop, protobuf, utils from .protocol_common import Context, Message @@ -138,6 +138,17 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator: send_exc = None +def try_get_ctx_ids() -> tuple[bytes, bytes] | None: + ids = None + if utils.USE_THP: + from trezor.wire.thp.session_context import GenericSessionContext + + ctx = get_context() + if isinstance(ctx, GenericSessionContext): + ids = (ctx.channel_id, ctx.session_id.to_bytes(1, "big")) + return ids + + # ACCESS TO CACHE if TYPE_CHECKING: diff --git a/core/src/trezor/wire/errors.py b/core/src/trezor/wire/errors.py index 376820b5834..8f572fcf0c2 100644 --- a/core/src/trezor/wire/errors.py +++ b/core/src/trezor/wire/errors.py @@ -8,6 +8,17 @@ def __init__(self, code: FailureType, message: str) -> None: self.message = message +class SilentError(Exception): + def __init__(self, message: str) -> None: + super().__init__() + self.message = message + + +class WireBufferError(Error): + def __init__(self, message: str = "Buffer error") -> None: + super().__init__(FailureType.BufferError, message) + + class UnexpectedMessage(Error): def __init__(self, message: str) -> None: super().__init__(FailureType.UnexpectedMessage, message) diff --git a/core/src/trezor/wire/message_handler.py b/core/src/trezor/wire/message_handler.py index 21c901dc90e..ace23d34ac2 100644 --- a/core/src/trezor/wire/message_handler.py +++ b/core/src/trezor/wire/message_handler.py @@ -25,7 +25,12 @@ def wrap_protobuf_load( expected_type: type[LoadedMessageType], ) -> LoadedMessageType: try: - if __debug__ and utils.EMULATOR and utils.USE_THP: + if ( + __debug__ + and utils.EMULATOR + and utils.USE_THP + and utils.ALLOW_DEBUG_MESSAGES + ): log.debug( __name__, "Buffer to be parsed to a LoadedMessage: %s", @@ -38,7 +43,7 @@ def wrap_protobuf_load( ) return msg except Exception as e: - if __debug__: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: log.exception(__name__, e) if e.args: raise DataError("Failed to decode message: " + " ".join(e.args)) @@ -46,6 +51,25 @@ def wrap_protobuf_load( raise DataError("Failed to decode message") +if utils.USE_THP: + from trezor.enums import ThpMessageType + + def get_msg_name(msg_type: int) -> str | None: + for name in dir(ThpMessageType): + if not name.startswith("__"): # Skip built-in attributes + value = getattr(ThpMessageType, name) + if isinstance(value, int): + if value == msg_type: + return name + return None + + def get_msg_type(msg_name: str) -> int | None: + value = getattr(ThpMessageType, msg_name) + if isinstance(value, int): + return value + return None + + async def handle_single_message(ctx: Context, msg: Message) -> bool: """Handle a message that was loaded from a WireInterface by the caller. @@ -60,17 +84,27 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool: the type of message is supposed to be optimized and not disrupt the running state, this function will return `True`. """ - if __debug__: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: try: msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME except Exception: msg_type = f"{msg.type} - unknown message type" - log.debug( - __name__, - "%d receive: <%s>", - ctx.iface.iface_num(), - msg_type, - ) + if utils.USE_THP: + cid = int.from_bytes(ctx.channel_id, "big") + log.debug( + __name__, + "%d:%d receive: <%s>", + ctx.iface.iface_num(), + cid, + msg_type, + ) + else: + log.debug( + __name__, + "%d receive: <%s>", + ctx.iface.iface_num(), + msg_type, + ) res_msg: protobuf.MessageType | None = None @@ -132,7 +166,7 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool: # - the message was not valid protobuf # - workflow raised some kind of an exception while running # - something canceled the workflow from the outside - if __debug__: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if isinstance(exc, ActionCancelled): log.debug(__name__, "cancelled: %s", exc.message) elif isinstance(exc, loop.TaskClosed): diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py index ed4105517b1..0e54afe8c3c 100644 --- a/core/src/trezor/wire/protocol_common.py +++ b/core/src/trezor/wire/protocol_common.py @@ -4,7 +4,7 @@ if TYPE_CHECKING: from trezorio import WireInterface - from typing import Container, TypeVar, overload + from typing import Awaitable, Container, TypeVar, overload from storage.cache_common import DataCache @@ -72,6 +72,9 @@ async def write(self, msg: protobuf.MessageType) -> None: """Write a message to the wire.""" ... + def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]: + return self.write(msg) + async def call( self, msg: protobuf.MessageType, diff --git a/core/src/trezor/wire/thp/__init__.py b/core/src/trezor/wire/thp/__init__.py new file mode 100644 index 00000000000..bbc5a62c6c2 --- /dev/null +++ b/core/src/trezor/wire/thp/__init__.py @@ -0,0 +1,189 @@ +import ustruct +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_thp import BROADCAST_CHANNEL_ID +from trezor import protobuf, utils +from trezor.enums import ThpPairingMethod +from trezor.messages import ThpDeviceProperties + +from ..protocol_common import WireError + +if TYPE_CHECKING: + from enum import IntEnum + + from trezor.wire import WireInterface + from typing_extensions import Self +else: + IntEnum = object + +CODEC_V1 = const(0x3F) + +HANDSHAKE_INIT_REQ = const(0x00) +HANDSHAKE_INIT_RES = const(0x01) +HANDSHAKE_COMP_REQ = const(0x02) +HANDSHAKE_COMP_RES = const(0x03) +ENCRYPTED = const(0x04) + +ACK_MESSAGE = const(0x20) +CHANNEL_ALLOCATION_REQ = const(0x40) +_CHANNEL_ALLOCATION_RES = const(0x41) +_ERROR = const(0x42) +CONTINUATION_PACKET = const(0x80) + + +class ThpError(WireError): + pass + + +class ThpDecryptionError(ThpError): + pass + + +class ThpInvalidDataError(ThpError): + pass + + +class ThpDeviceLockedError(ThpError): + pass + + +class ThpUnallocatedSessionError(ThpError): + + def __init__(self, session_id: int) -> None: + self.session_id = session_id + + +class ThpErrorType(IntEnum): + TRANSPORT_BUSY = 1 + UNALLOCATED_CHANNEL = 2 + DECRYPTION_FAILED = 3 + INVALID_DATA = 4 + DEVICE_LOCKED = 5 + + +class ChannelState(IntEnum): + UNALLOCATED = 0 + TH1 = 1 + TH2 = 2 + TP0 = 3 + TP1 = 4 + TP2 = 5 + TP3 = 6 + TP4 = 7 + TC1 = 8 + ENCRYPTED_TRANSPORT = 9 + INVALIDATED = 10 + + +class SessionState(IntEnum): + UNALLOCATED = 0 + ALLOCATED = 1 + SEEDLESS = 2 + + +class PacketHeader: + format_str_init = ">BHH" + format_str_cont = ">BH" + + def __init__(self, ctrl_byte: int, cid: int, length: int) -> None: + self.ctrl_byte = ctrl_byte + self.cid = cid + self.length = length + + def to_bytes(self) -> bytes: + return ustruct.pack(self.format_str_init, self.ctrl_byte, self.cid, self.length) + + def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + """ + Packs header information in the form of **intial** packet + into the provided buffer. + """ + ustruct.pack_into( + self.format_str_init, + buffer, + buffer_offset, + self.ctrl_byte, + self.cid, + self.length, + ) + + def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + """ + Packs header information in the form of **continuation** packet header + into the provided buffer. + """ + ustruct.pack_into( + self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid + ) + + @classmethod + def get_error_header(cls, cid: int, length: int) -> Self: + """ + Returns header for protocol-level error messages. + """ + return cls(_ERROR, cid, length) + + @classmethod + def get_channel_allocation_response_header(cls, length: int) -> Self: + """ + Returns header for allocation response handshake message. + """ + return cls(_CHANNEL_ALLOCATION_RES, BROADCAST_CHANNEL_ID, length) + + +_DEFAULT_ENABLED_PAIRING_METHODS = [ + ThpPairingMethod.CodeEntry, + ThpPairingMethod.QrCode, + ThpPairingMethod.NFC, +] + + +def get_enabled_pairing_methods( + iface: WireInterface | None = None, +) -> list[ThpPairingMethod]: + """ + Returns pairing methods that are currently allowed by the device + with respect to the wire interface the host communicates on. + """ + methods = _DEFAULT_ENABLED_PAIRING_METHODS.copy() + if __debug__: + methods.append(ThpPairingMethod.SkipPairing) + return methods + + +def _get_device_properties(iface: WireInterface) -> ThpDeviceProperties: + # TODO define model variants + return ThpDeviceProperties( + pairing_methods=get_enabled_pairing_methods(iface), + internal_model=utils.INTERNAL_MODEL, + model_variant=0, + protocol_version_major=2, + protocol_version_minor=0, + ) + + +def get_encoded_device_properties(iface: WireInterface) -> bytes: + props = _get_device_properties(iface) + length = protobuf.encoded_length(props) + encoded_properties = bytearray(length) + protobuf.encode(encoded_properties, props) + return encoded_properties + + +def get_channel_allocation_response( + nonce: bytes, new_cid: bytes, iface: WireInterface +) -> bytes: + props_msg = get_encoded_device_properties(iface) + return nonce + new_cid + props_msg + + +if __debug__: + + def state_to_str(state: int) -> str: + name = { + v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__") + }.get(state) + if name is not None: + return name + return "UNKNOWN_STATE" diff --git a/core/src/trezor/wire/thp/alternating_bit_protocol.py b/core/src/trezor/wire/thp/alternating_bit_protocol.py new file mode 100644 index 00000000000..d8ba60c5b23 --- /dev/null +++ b/core/src/trezor/wire/thp/alternating_bit_protocol.py @@ -0,0 +1,102 @@ +from storage.cache_thp import ChannelCache +from trezor import log, utils +from trezor.wire.thp import ThpError + + +def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool: + """ + Checks if: + - an ACK message is expected + - the received ACK message acknowledges correct sequence number (bit) + """ + if not _is_ack_expected(cache): + return False + + if not _has_ack_correct_sync_bit(cache, ack_bit): + return False + + return True + + +def _is_ack_expected(cache: ChannelCache) -> bool: + is_expected: bool = not is_sending_allowed(cache) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES and not is_expected: + log.debug(__name__, "Received unexpected ACK message") + return is_expected + + +def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool: + is_correct: bool = get_send_seq_bit(cache) == sync_bit + if __debug__ and utils.ALLOW_DEBUG_MESSAGES and not is_correct: + log.debug(__name__, "Received ACK message with wrong ack bit") + return is_correct + + +def is_sending_allowed(cache: ChannelCache) -> bool: + """ + Checks whether sending a message in the provided channel is allowed. + + Note: Sending a message in a channel before receipt of ACK message for the previously + sent message (in the channel) is prohibited, as it can lead to desynchronization. + """ + return bool(cache.sync >> 7) + + +def get_send_seq_bit(cache: ChannelCache) -> int: + """ + Returns the sequential number (bit) of the next message to be sent + in the provided channel. + """ + return (cache.sync & 0x20) >> 5 + + +def get_expected_receive_seq_bit(cache: ChannelCache) -> int: + """ + Returns the (expected) sequential number (bit) of the next message + to be received in the provided channel. + """ + return (cache.sync & 0x40) >> 6 + + +def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None: + """ + Set the flag whether sending a message in this channel is allowed or not. + """ + cache.sync &= 0x7F + if sending_allowed: + cache.sync |= 0x80 + + +def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None: + """ + Set the expected sequential number (bit) of the next message to be received + in the provided channel + """ + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit) + if seq_bit not in (0, 1): + raise ThpError("Unexpected receive sync bit") + + # set second bit to "seq_bit" value + cache.sync &= 0xBF + if seq_bit: + cache.sync |= 0x40 + + +def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None: + if seq_bit not in (0, 1): + raise ThpError("Unexpected send seq bit") + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "setting sync send seq bit to %d", seq_bit) + # set third bit to "seq_bit" value + cache.sync &= 0xDF + if seq_bit: + cache.sync |= 0x20 + + +def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None: + """ + Set the sequential bit of the "next message to be send" to the opposite value, + i.e. 1 -> 0 and 0 -> 1 + """ + _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache)) diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py new file mode 100644 index 00000000000..39baa27e491 --- /dev/null +++ b/core/src/trezor/wire/thp/channel.py @@ -0,0 +1,522 @@ +import ustruct +from typing import TYPE_CHECKING + +from storage.cache_common import ( + CHANNEL_HANDSHAKE_HASH, + CHANNEL_KEY_RECEIVE, + CHANNEL_KEY_SEND, + CHANNEL_NONCE_RECEIVE, + CHANNEL_NONCE_SEND, +) +from storage.cache_thp import ( + SESSION_ID_LENGTH, + TAG_LENGTH, + ChannelCache, + clear_sessions_with_channel_id, +) +from trezor import log, loop, protobuf, utils, workflow +from trezor.wire.errors import WireBufferError + +from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError +from . import alternating_bit_protocol as ABP +from . import ( + checksum, + control_byte, + crypto, + interface_manager, + memory_manager, + received_message_handler, +) +from .checksum import CHECKSUM_LENGTH +from .transmission_loop import TransmissionLoop +from .writer import ( + CONT_HEADER_LENGTH, + INIT_HEADER_LENGTH, + MESSAGE_TYPE_LENGTH, + write_payload_to_wire_and_add_checksum, +) + +if __debug__: + from trezor.utils import get_bytes_as_str + + from . import state_to_str + +if TYPE_CHECKING: + from trezorio import WireInterface + from typing import Awaitable + + from trezor.messages import ThpPairingCredential + + from .pairing_context import PairingContext + from .session_context import GenericSessionContext + + +class Channel: + """ + THP protocol encrypted communication channel. + """ + + def __init__(self, channel_cache: ChannelCache) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "channel initialization") + + # Channel properties + self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface) + self.channel_cache: ChannelCache = channel_cache + self.channel_id: bytes = channel_cache.channel_id + + # Shared variables + self.buffer: utils.BufferType = bytearray(self.iface.TX_PACKET_LEN) + self.fallback_decrypt: bool = False + self.bytes_read: int = 0 + self.expected_payload_length: int = 0 + self.is_cont_packet_expected: bool = False + self.sessions: dict[int, GenericSessionContext] = {} + + # Objects for writing a message to a wire + self.transmission_loop: TransmissionLoop | None = None + self.write_task_spawn: loop.spawn | None = None + + # Temporary objects + self.handshake: crypto.Handshake | None = None + self.credential: ThpPairingCredential | None = None + self.connection_context: PairingContext | None = None + self.busy_decoder: crypto.BusyDecoder | None = None + self.temp_crc: int | None = None + self.temp_crc_compare: bytearray | None = None + self.temp_tag: bytearray | None = None + + def clear(self) -> None: + clear_sessions_with_channel_id(self.channel_id) + self.channel_cache.clear() + + # ACCESS TO CHANNEL_DATA + + def get_channel_id_int(self) -> int: + return int.from_bytes(self.channel_id, "big") + + def get_channel_state(self) -> int: + state = int.from_bytes(self.channel_cache.state, "big") + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("get_channel_state: ", state_to_str(state)) + return state + + def get_handshake_hash(self) -> bytes: + h = self.channel_cache.get(CHANNEL_HANDSHAKE_HASH) + assert h is not None + return h + + def set_channel_state(self, state: ChannelState) -> None: + self.channel_cache.state = bytearray(state.to_bytes(1, "big")) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("set_channel_state: ", state_to_str(state)) + + # READ and DECRYPT + + def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("receive packet") + + self._handle_received_packet(packet) + + try: + buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) + except WireBufferError: + pass # TODO ?? + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + try: + self._log("self.buffer: ", get_bytes_as_str(buffer)) + except Exception: + pass # TODO handle nicer - happens in fallback_decrypt + + if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read: + self._finish_message() + if self.fallback_decrypt: + # TODO Check CRC and if valid, check tag, if valid update nonces + self._finish_fallback() + # TODO self.write() failure device is busy - use channel buffer to send this failure message!! + return None + return received_message_handler.handle_received_message(self, buffer) + elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read: + self.is_cont_packet_expected = True + else: + raise ThpError( + "Read more bytes than is the expected length of the message!" + ) + return None + + def _handle_received_packet(self, packet: utils.BufferType) -> None: + ctrl_byte = packet[0] + if control_byte.is_continuation(ctrl_byte): + self._handle_cont_packet(packet) + return + self._handle_init_packet(packet) + + def _handle_init_packet(self, packet: utils.BufferType) -> None: + self.fallback_decrypt = False + self.bytes_read = 0 + self.expected_payload_length = 0 + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("handle_init_packet") + + _, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet) + self.expected_payload_length = payload_length + + # If the channel does not "own" the buffer lock, decrypt first packet + # TODO do it only when needed! + # TODO FIX: If "_decrypt_single_packet_payload" is implemented, it will (possibly) break "decrypt_buffer" and nonces incrementation. + # On the other hand, without the single packet decryption, the "advanced" buffer selection cannot be implemented + # in "memory_manager.select_buffer", because the session id is unknown (encrypted). + + # if control_byte.is_encrypted_transport(ctrl_byte): + # packet_payload = self._decrypt_single_packet_payload(packet_payload) + + cid = self.get_channel_id_int() + length = payload_length + INIT_HEADER_LENGTH + try: + buffer = memory_manager.get_new_read_buffer(cid, length) + except WireBufferError: + # TODO handle not encrypted/(short??), eg. ACK + + self.fallback_decrypt = True + + self._prepare_fallback() + + to_read_len = min(len(packet) - INIT_HEADER_LENGTH, payload_length) + buf = memoryview(self.buffer)[:to_read_len] + utils.memcpy(buf, 0, packet, INIT_HEADER_LENGTH) + + # CRC CHECK + self._handle_fallback_crc(buf) + + # TAG CHECK + self._handle_fallback_decryption(buf) + + self.bytes_read += to_read_len + return + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("handle_init_packet - payload len: ", str(payload_length)) + self._log("handle_init_packet - buffer len: ", str(len(buffer))) + + self._buffer_packet_data(buffer, packet, 0) + + def _handle_fallback_crc(self, buf: memoryview) -> None: + assert self.temp_crc is not None + assert self.temp_crc_compare is not None + if self.expected_payload_length > len(buf) + self.bytes_read + CHECKSUM_LENGTH: + # The CRC checksum is not in this packet, compute crc over whole buffer + self.temp_crc = checksum.compute_int(buf, self.temp_crc) + elif self.expected_payload_length >= len(buf) + self.bytes_read: + # At least a part of the CRC checksum is in this packet, compute CRC over + # the first (max(0, crc_copy_len)) bytes and add the rest of the bytes + # (max 4) as the checksum from message into temp_crc_compare + crc_copy_len = ( + self.expected_payload_length - self.bytes_read - CHECKSUM_LENGTH + ) + self.temp_crc = checksum.compute_int(buf[:crc_copy_len], self.temp_crc) + + crc_checksum = buf[ + self.expected_payload_length + - CHECKSUM_LENGTH + - len(buf) + - self.bytes_read : + ] + offset = CHECKSUM_LENGTH - len(buf[-CHECKSUM_LENGTH:]) + utils.memcpy(self.temp_crc_compare, offset, crc_checksum, 0) + else: + raise Exception( + f"Buffer (+bytes_read) ({len(buf)}+{self.bytes_read})should not be bigger than payload{self.expected_payload_length}" + ) + + def _handle_fallback_decryption(self, buf: memoryview) -> None: + assert self.busy_decoder is not None + assert self.temp_tag is not None + if ( + self.expected_payload_length + > len(buf) + self.bytes_read + CHECKSUM_LENGTH + TAG_LENGTH + ): + # The noise tag is not in this packet, decrypt the whole buffer + self.busy_decoder.decrypt_part(buf) + elif self.expected_payload_length >= len(buf) + self.bytes_read: + # At least a part of the noise tag is in this packet, decrypt + # the first (max(0, dec_len)) bytes and add the rest of the bytes + # as the noise_tag from message into temp_tag + dec_len = ( + self.expected_payload_length + - self.bytes_read + - TAG_LENGTH + - CHECKSUM_LENGTH + ) + self.busy_decoder.decrypt_part(buf[:dec_len]) + + noise_tag = buf[ + self.expected_payload_length + - CHECKSUM_LENGTH + - TAG_LENGTH + - len(buf) + - self.bytes_read : + ] + offset = ( + TAG_LENGTH + CHECKSUM_LENGTH - len(buf[-CHECKSUM_LENGTH - TAG_LENGTH :]) + ) + utils.memcpy(self.temp_tag, offset, noise_tag, 0) + else: + raise Exception("Buffer (+bytes_read) should not be bigger than payload") + + def _handle_cont_packet(self, packet: utils.BufferType) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("handle_cont_packet") + + if not self.is_cont_packet_expected: + raise ThpError("Continuation packet is not expected, ignoring") + + if self.fallback_decrypt: + to_read_len = min( + len(packet) - CONT_HEADER_LENGTH, + self.expected_payload_length - self.bytes_read, + ) + buf = memoryview(self.buffer)[:to_read_len] + utils.memcpy(buf, 0, packet, CONT_HEADER_LENGTH) + + # CRC CHECK + self._handle_fallback_crc(buf) + + # TAG CHECK + self._handle_fallback_decryption(buf) + + self.bytes_read += to_read_len + return + try: + buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) + except WireBufferError: + self.set_channel_state(ChannelState.INVALIDATED) + pass # TODO handle device busy, channel kaput + self._buffer_packet_data(buffer, packet, CONT_HEADER_LENGTH) + + def _buffer_packet_data( + self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int + ) -> None: + self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset) + + def _finish_message(self) -> None: + self.bytes_read = 0 + self.expected_payload_length = 0 + self.is_cont_packet_expected = False + + def _finish_fallback(self) -> None: + self.fallback_decrypt = False + self.busy_decoder = None + + def _decrypt_single_packet_payload( + self, payload: utils.BufferType + ) -> utils.BufferType: + # crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload)) + return payload + + def _prepare_fallback(self) -> None: + # prepare busy decoder + key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE) + nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE) + + assert key_receive is not None + assert nonce_receive is not None + + self.busy_decoder = crypto.BusyDecoder(key_receive, nonce_receive) + + # prepare temp channel values + self.temp_crc = 0 + self.temp_crc_compare = bytearray(4) + self.temp_tag = bytearray(16) + # self.bytes_read = INIT_HEADER_LENGTH + + def decrypt_buffer( + self, message_length: int, offset: int = INIT_HEADER_LENGTH + ) -> None: + buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) + # if buffer is WireBufferError: + # pass # TODO handle deviceBUSY + noise_buffer = memoryview(buffer)[ + offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH + ] + tag = buffer[ + message_length + - CHECKSUM_LENGTH + - TAG_LENGTH : message_length + - CHECKSUM_LENGTH + ] + + if utils.DISABLE_ENCRYPTION: + is_tag_valid = tag == crypto.DUMMY_TAG + else: + key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE) + nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE) + + assert key_receive is not None + assert nonce_receive is not None + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("Buffer before decryption: ", get_bytes_as_str(noise_buffer)) + + is_tag_valid = crypto.dec(noise_buffer, tag, key_receive, nonce_receive) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("Buffer after decryption: ", get_bytes_as_str(noise_buffer)) + + self.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, nonce_receive + 1) + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("Is decrypted tag valid? ", str(is_tag_valid)) + self._log("Received tag: ", get_bytes_as_str(tag)) + self._log("New nonce_receive: ", str((nonce_receive + 1))) + + if not is_tag_valid: + raise ThpDecryptionError() + + # WRITE and ENCRYPT + + async def write( + self, + msg: protobuf.MessageType, + session_id: int = 0, + force: bool = False, + ) -> None: + if __debug__ and utils.EMULATOR: + self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg)) + + cid = self.get_channel_id_int() + msg_size = protobuf.encoded_length(msg) + payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size + length = payload_size + CHECKSUM_LENGTH + TAG_LENGTH + INIT_HEADER_LENGTH + try: + buffer = memory_manager.get_new_write_buffer(cid, length) + noise_payload_len = memory_manager.encode_into_buffer( + buffer, msg, session_id + ) + except WireBufferError: + from trezor.messages import Failure, FailureType + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("Failed to get write buffer, killing channel.") + + noise_payload_len = memory_manager.encode_into_buffer( + self.buffer, + Failure( + code=FailureType.FirmwareError, + message="Failed to obtain write buffer.", + ), + session_id, + ) + self.set_channel_state(ChannelState.INVALIDATED) + task = self._write_and_encrypt(noise_payload_len, force) + if task is not None: + await task + + def write_error(self, err_type: int) -> Awaitable[None]: + msg_data = err_type.to_bytes(1, "big") + length = len(msg_data) + CHECKSUM_LENGTH + header = PacketHeader.get_error_header(self.get_channel_id_int(), length) + return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data) + + def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: + self._prepare_write() + self.write_task_spawn = loop.spawn( + self._write_encrypted_payload_loop(ctrl_byte, payload) + ) + + def _write_and_encrypt( + self, noise_payload_len: int, force: bool = False + ) -> Awaitable[None] | None: + buffer = memory_manager.get_existing_write_buffer(self.get_channel_id_int()) + # if buffer is WireBufferError: + # pass # TODO handle deviceBUSY + + self._encrypt(buffer, noise_payload_len) + payload_length = noise_payload_len + TAG_LENGTH + + if self.write_task_spawn is not None: + self.write_task_spawn.close() # UPS TODO might break something + print("\nCLOSED\n") + self._prepare_write() + if force: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("Writing FORCE message (without async or retransmission).") + + return self._write_encrypted_payload_loop( + ENCRYPTED, memoryview(buffer[:payload_length]) + ) + self.write_task_spawn = loop.spawn( + self._write_encrypted_payload_loop( + ENCRYPTED, memoryview(buffer[:payload_length]) + ) + ) + return None + + def _prepare_write(self) -> None: + # TODO add condition that disallows to write when can_send_message is false + ABP.set_sending_allowed(self.channel_cache, False) + + async def _write_encrypted_payload_loop( + self, ctrl_byte: int, payload: bytes + ) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("write_encrypted_payload_loop") + + payload_len = len(payload) + CHECKSUM_LENGTH + sync_bit = ABP.get_send_seq_bit(self.channel_cache) + ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit) + header = PacketHeader(ctrl_byte, self.get_channel_id_int(), payload_len) + self.transmission_loop = TransmissionLoop(self, header, payload) + await self.transmission_loop.start() + + ABP.set_send_seq_bit_to_opposite(self.channel_cache) + + # Let the main loop be restarted and clear loop, if there is no other + # workflow and the state is ENCRYPTED_TRANSPORT + if self._can_clear_loop(): + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("clearing loop from channel") + + loop.clear() + + def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("encrypt") + + assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH + + noise_buffer = memoryview(buffer)[0:noise_payload_len] + + if utils.DISABLE_ENCRYPTION: + tag = crypto.DUMMY_TAG + else: + key_send = self.channel_cache.get(CHANNEL_KEY_SEND) + nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND) + + assert key_send is not None + assert nonce_send is not None + + tag = crypto.enc(noise_buffer, key_send, nonce_send) + + self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("New nonce_send: ", str((nonce_send + 1))) + + buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag + + def _can_clear_loop(self) -> bool: + return ( + not workflow.tasks + ) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT + + if __debug__: + + def _log(self, text_1: str, text_2: str = "") -> None: + log.debug( + __name__, + "(cid: %s) %s%s", + get_bytes_as_str(self.channel_id), + text_1, + text_2, + ) diff --git a/core/src/trezor/wire/thp/channel_manager.py b/core/src/trezor/wire/thp/channel_manager.py new file mode 100644 index 00000000000..75de9485f95 --- /dev/null +++ b/core/src/trezor/wire/thp/channel_manager.py @@ -0,0 +1,30 @@ +from typing import TYPE_CHECKING + +from storage import cache_thp + +from . import ChannelState, interface_manager +from .channel import Channel + +if TYPE_CHECKING: + from trezorio import WireInterface + + +def create_new_channel(iface: WireInterface) -> Channel: + """ + Creates a new channel for the interface `iface`. + """ + channel_cache = cache_thp.get_new_channel(interface_manager.encode_iface(iface)) + channel = Channel(channel_cache) + channel.set_channel_state(ChannelState.TH1) + return channel + + +def load_cached_channels() -> dict[int, Channel]: + """ + Returns all allocated channels from cache. + """ + channels: dict[int, Channel] = {} + cached_channels = cache_thp.get_all_allocated_channels() + for channel in cached_channels: + channels[int.from_bytes(channel.channel_id, "big")] = Channel(channel) + return channels diff --git a/core/src/trezor/wire/thp/checksum.py b/core/src/trezor/wire/thp/checksum.py new file mode 100644 index 00000000000..44aab466309 --- /dev/null +++ b/core/src/trezor/wire/thp/checksum.py @@ -0,0 +1,33 @@ +from micropython import const + +from trezor import utils +from trezor.crypto import crc + +CHECKSUM_LENGTH = const(4) + + +def compute(data: bytes | utils.BufferType, crc_chain: int = 0) -> bytes: + """ + Returns a CRC-32 checksum of the provided `data`. Allows for for chaining + computations over multiple data segments using `crc_chain` (optional). + """ + return crc.crc32(data, crc_chain).to_bytes(CHECKSUM_LENGTH, "big") + + +def compute_int(data: bytes | utils.BufferType, crc_chain: int = 0) -> int: + """ + Returns a CRC-32 checksum of the provided `data`. Allows for for chaining + computations over multiple data segments using `crc_chain` (optional). + + Returns checksum in the form of `int`. + """ + return crc.crc32(data, crc_chain) + + +def is_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool: + """ + Checks whether the CRC-32 checksum of the `data` is the same + as the checksum provided in `checksum`. + """ + data_checksum = compute(data) + return checksum == data_checksum diff --git a/core/src/trezor/wire/thp/control_byte.py b/core/src/trezor/wire/thp/control_byte.py new file mode 100644 index 00000000000..5d4d69b0400 --- /dev/null +++ b/core/src/trezor/wire/thp/control_byte.py @@ -0,0 +1,50 @@ +from micropython import const + +from . import ( + ACK_MESSAGE, + CONTINUATION_PACKET, + ENCRYPTED, + HANDSHAKE_COMP_REQ, + HANDSHAKE_INIT_REQ, + ThpError, +) + +_CONTINUATION_PACKET_MASK = const(0x80) +_ACK_MASK = const(0xF7) +_DATA_MASK = const(0xE7) + + +def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int: + if seq_bit == 0: + return ctrl_byte & 0xEF + if seq_bit == 1: + return ctrl_byte | 0x10 + raise ThpError("Unexpected sequence bit") + + +def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int: + if ack_bit == 0: + return ctrl_byte & 0xF7 + if ack_bit == 1: + return ctrl_byte | 0x08 + raise ThpError("Unexpected acknowledgement bit") + + +def is_ack(ctrl_byte: int) -> bool: + return ctrl_byte & _ACK_MASK == ACK_MESSAGE + + +def is_continuation(ctrl_byte: int) -> bool: + return ctrl_byte & _CONTINUATION_PACKET_MASK == CONTINUATION_PACKET + + +def is_encrypted_transport(ctrl_byte: int) -> bool: + return ctrl_byte & _DATA_MASK == ENCRYPTED + + +def is_handshake_init_req(ctrl_byte: int) -> bool: + return ctrl_byte & _DATA_MASK == HANDSHAKE_INIT_REQ + + +def is_handshake_comp_req(ctrl_byte: int) -> bool: + return ctrl_byte & _DATA_MASK == HANDSHAKE_COMP_REQ diff --git a/core/src/trezor/wire/thp/cpace.py b/core/src/trezor/wire/thp/cpace.py new file mode 100644 index 00000000000..fad0f705d3b --- /dev/null +++ b/core/src/trezor/wire/thp/cpace.py @@ -0,0 +1,37 @@ +from trezor.crypto import elligator2, random +from trezor.crypto.curve import curve25519 +from trezor.crypto.hashlib import sha512 + +_PREFIX = b"\x08\x43\x50\x61\x63\x65\x32\x35\x35\x06" +_PADDING = b"\x6f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x20" + + +class Cpace: + """ + CPace, a balanced composable PAKE: https://datatracker.ietf.org/doc/draft-irtf-cfrg-cpace/ + """ + + def __init__(self, handshake_hash: bytes) -> None: + self.handshake_hash: bytes = handshake_hash + self.shared_secret: bytes + self.trezor_private_key: bytes + self.trezor_public_key: bytes + + def generate_keys_and_secret(self, code_code_entry: bytes) -> None: + """ + Generate ephemeral key pair and a shared secret using Elligator2 with X25519. + """ + sha_ctx = sha512(_PREFIX) + sha_ctx.update(code_code_entry) + sha_ctx.update(_PADDING) + sha_ctx.update(self.handshake_hash) + sha_ctx.update(b"\x00") + pregenerator = sha_ctx.digest()[:32] + generator = elligator2.map_to_curve25519(pregenerator) + self.trezor_private_key = random.bytes(32) + self.trezor_public_key = curve25519.multiply(self.trezor_private_key, generator) + + def compute_shared_secret(self, host_public_key: bytes) -> None: + self.shared_secret = curve25519.multiply( + self.trezor_private_key, host_public_key + ) diff --git a/core/src/trezor/wire/thp/crypto.py b/core/src/trezor/wire/thp/crypto.py new file mode 100644 index 00000000000..4ba7fc71c9e --- /dev/null +++ b/core/src/trezor/wire/thp/crypto.py @@ -0,0 +1,220 @@ +from micropython import const +from trezorcrypto import aesgcm, bip32, curve25519, hmac + +from storage import device +from trezor import log, utils +from trezor.crypto.hashlib import sha256 +from trezor.wire.thp import ThpDecryptionError + +# The HARDENED flag is taken from apps.common.paths +# It is not imported to save on resources +HARDENED = const(0x8000_0000) +PUBKEY_LENGTH = const(32) +if utils.DISABLE_ENCRYPTION: + DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5" + +if __debug__: + from trezor.utils import get_bytes_as_str + + +def enc( + buffer: utils.BufferType, key: bytes, nonce: int, auth_data: bytes = b"" +) -> bytes: + """ + Encrypts the provided `buffer` with AES-GCM (in place). + Returns a 16-byte long encryption tag. + """ + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "enc (key: %s, nonce: %d)", get_bytes_as_str(key), nonce) + iv = _get_iv_from_nonce(nonce) + aes_ctx = aesgcm(key, iv) + aes_ctx.auth(auth_data) + aes_ctx.encrypt_in_place(buffer) + return aes_ctx.finish() + + +def dec( + buffer: utils.BufferType, tag: bytes, key: bytes, nonce: int, auth_data: bytes = b"" +) -> bool: + """ + Decrypts the provided buffer (in place). Returns `True` if the provided authentication `tag` is the same as + the tag computed in decryption, otherwise it returns `False`. + """ + iv = _get_iv_from_nonce(nonce) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "dec (key: %s, nonce: %d)", get_bytes_as_str(key), nonce) + aes_ctx = aesgcm(key, iv) + aes_ctx.auth(auth_data) + aes_ctx.decrypt_in_place(buffer) + computed_tag = aes_ctx.finish() + return computed_tag == tag + + +class BusyDecoder: + + def __init__(self, key: bytes, nonce: int, auth_data: bytes = b"") -> None: + iv = _get_iv_from_nonce(nonce) + self.aes_ctx = aesgcm(key, iv) + self.aes_ctx.auth(auth_data) + + def decrypt_part(self, part: utils.BufferType) -> None: + self.aes_ctx.decrypt_in_place(part) + + def finish_and_check_tag(self, tag: bytes) -> bool: + computed_tag = self.aes_ctx.finish() + return computed_tag == tag + + +PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" +IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + + +class Handshake: + """ + `Handshake` holds (temporary) values and keys that are used during the creation of an encrypted channel. + The following values should be saved for future use before disposing of this object: + - `h` - handshake hash, can be used to bind other values to the channel + - `key_receive` - key for decrypting incoming communication + - `key_send` - key for encrypting outgoing communication + """ + + def __init__(self) -> None: + self.trezor_ephemeral_privkey: bytes + self.ck: bytes + self.k: bytes + self.h: bytes + self.key_receive: bytes + self.key_send: bytes + + def handle_th1_crypto( + self, + device_properties: bytes, + host_ephemeral_pubkey: bytes, + ) -> tuple[bytes, bytes, bytes]: + + trezor_static_privkey, trezor_static_pubkey = _derive_static_key_pair() + self.trezor_ephemeral_privkey = curve25519.generate_secret() + trezor_ephemeral_pubkey = curve25519.publickey(self.trezor_ephemeral_privkey) + self.h = _hash_of_two(PROTOCOL_NAME, device_properties) + self.h = _hash_of_two(self.h, host_ephemeral_pubkey) + self.h = _hash_of_two(self.h, trezor_ephemeral_pubkey) + point = curve25519.multiply( + self.trezor_ephemeral_privkey, host_ephemeral_pubkey + ) + self.ck, self.k = _hkdf(PROTOCOL_NAME, point) + mask = _hash_of_two(trezor_static_pubkey, trezor_ephemeral_pubkey) + trezor_masked_static_pubkey = curve25519.multiply(mask, trezor_static_pubkey) + aes_ctx = aesgcm(self.k, IV_1) + encrypted_trezor_static_pubkey = aes_ctx.encrypt(trezor_masked_static_pubkey) + if __debug__: + log.debug( + __name__, "th1 - enc (key: %s, nonce: %d)", get_bytes_as_str(self.k), 0 + ) + aes_ctx.auth(self.h) + tag_to_encrypted_key = aes_ctx.finish() + encrypted_trezor_static_pubkey = ( + encrypted_trezor_static_pubkey + tag_to_encrypted_key + ) + self.h = _hash_of_two(self.h, encrypted_trezor_static_pubkey) + point = curve25519.multiply(trezor_static_privkey, host_ephemeral_pubkey) + self.ck, self.k = _hkdf(self.ck, curve25519.multiply(mask, point)) + aes_ctx = aesgcm(self.k, IV_1) + aes_ctx.auth(self.h) + tag = aes_ctx.finish() + self.h = _hash_of_two(self.h, tag) + return (trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag) + + def handle_th2_crypto( + self, + encrypted_host_static_pubkey: utils.BufferType, + encrypted_payload: utils.BufferType, + ) -> None: + + aes_ctx = aesgcm(self.k, IV_2) + + # The new value of hash `h` MUST be computed before the `encrypted_host_static_pubkey` is decrypted. + # However, decryption of `encrypted_host_static_pubkey` MUST use the previous value of `h` for + # authentication of the gcm tag. + aes_ctx.auth(self.h) # Authenticate with the previous value of `h` + self.h = _hash_of_two(self.h, encrypted_host_static_pubkey) # Compute new value + aes_ctx.decrypt_in_place( + memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH] + ) + if __debug__: + log.debug( + __name__, "th2 - dec (key: %s, nonce: %d)", get_bytes_as_str(self.k), 1 + ) + host_static_pubkey = memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH] + tag = aes_ctx.finish() + if tag != encrypted_host_static_pubkey[-16:]: + raise ThpDecryptionError() + + self.ck, self.k = _hkdf( + self.ck, + curve25519.multiply(self.trezor_ephemeral_privkey, host_static_pubkey), + ) + aes_ctx = aesgcm(self.k, IV_1) + aes_ctx.auth(self.h) + aes_ctx.decrypt_in_place(memoryview(encrypted_payload)[:-16]) + if __debug__: + log.debug( + __name__, "th2 - dec (key: %s, nonce: %d)", get_bytes_as_str(self.k), 0 + ) + tag = aes_ctx.finish() + if tag != encrypted_payload[-16:]: + raise ThpDecryptionError() + + self.h = _hash_of_two(self.h, memoryview(encrypted_payload)[:-16]) + self.key_receive, self.key_send = _hkdf(self.ck, b"") + if __debug__: + log.debug( + __name__, + "(key_receive: %s, key_send: %s)", + get_bytes_as_str(self.key_receive), + get_bytes_as_str(self.key_send), + ) + + def get_handshake_completion_response(self, trezor_state: bytes) -> bytes: + aes_ctx = aesgcm(self.key_send, IV_1) + encrypted_trezor_state = aes_ctx.encrypt(trezor_state) + tag = aes_ctx.finish() + return encrypted_trezor_state + tag + + +def _derive_static_key_pair() -> tuple[bytes, bytes]: + node_int = HARDENED | int.from_bytes(b"\x00THP", "big") + node = bip32.from_seed(device.get_device_secret(), "curve25519") + node.derive(node_int) + + trezor_static_privkey = node.private_key() + trezor_static_pubkey = node.public_key()[1:33] + # Note: the first byte (\x01) of the public key is removed, as it + # only indicates the type of the elliptic curve used + + return trezor_static_privkey, trezor_static_pubkey + + +def get_trezor_static_pubkey() -> bytes: + _, pubkey = _derive_static_key_pair() + return pubkey + + +def _hkdf(chaining_key: bytes, input: bytes) -> tuple[bytes, bytes]: + temp_key = hmac(hmac.SHA256, chaining_key, input).digest() + output_1 = hmac(hmac.SHA256, temp_key, b"\x01").digest() + ctx_output_2 = hmac(hmac.SHA256, temp_key, output_1) + ctx_output_2.update(b"\x02") + output_2 = ctx_output_2.digest() + return (output_1, output_2) + + +def _hash_of_two(part_1: bytes, part_2: bytes) -> bytes: + ctx = sha256(part_1) + ctx.update(part_2) + return ctx.digest() + + +def _get_iv_from_nonce(nonce: int) -> bytes: + utils.ensure(nonce <= 0xFFFFFFFFFFFFFFFF, "Nonce overflow, terminate the channel") + return bytes(4) + nonce.to_bytes(8, "big") diff --git a/core/src/trezor/wire/thp/interface_manager.py b/core/src/trezor/wire/thp/interface_manager.py new file mode 100644 index 00000000000..a1fecfe7d64 --- /dev/null +++ b/core/src/trezor/wire/thp/interface_manager.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING + +import usb + +_WIRE_INTERFACE_USB = b"\x01" +# TODO _WIRE_INTERFACE_BLE = b"\x02" + +if TYPE_CHECKING: + from trezorio import WireInterface + + +def decode_iface(cached_iface: bytes) -> WireInterface: + """Decode the cached wire interface.""" + if cached_iface == _WIRE_INTERFACE_USB: + iface = usb.iface_wire + if iface is None: + raise RuntimeError("There is no valid USB WireInterface") + return iface + # TODO implement bluetooth interface + raise Exception("Unknown WireInterface") + + +def encode_iface(iface: WireInterface) -> bytes: + """Encode wire interface into bytes.""" + if iface is usb.iface_wire: + return _WIRE_INTERFACE_USB + # TODO implement bluetooth interface + raise Exception("Unknown WireInterface") diff --git a/core/src/trezor/wire/thp/memory_manager.py b/core/src/trezor/wire/thp/memory_manager.py new file mode 100644 index 00000000000..681642d8fc6 --- /dev/null +++ b/core/src/trezor/wire/thp/memory_manager.py @@ -0,0 +1,180 @@ +import utime +from micropython import const + +from storage.cache_thp import SESSION_ID_LENGTH +from trezor import protobuf, utils +from trezor.wire.errors import WireBufferError +from trezor.wire.message_handler import get_msg_type + +from . import ThpError +from .writer import MAX_PAYLOAD_LEN, MESSAGE_TYPE_LENGTH + +_PROTOBUF_BUFFER_SIZE = 8192 +READ_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) +WRITE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) +LOCK_TIMEOUT = 200 # miliseconds + + +lock_owner_cid: int | None = None +lock_time: int = 0 + +READ_BUFFER_SLICE: memoryview | None = None +WRITE_BUFFER_SLICE: memoryview | None = None + +# Buffer types +_READ: int = const(0) +_WRITE: int = const(1) + + +# +# Access to buffer slices + + +def get_new_read_buffer(channel_id: int, length: int) -> memoryview: + return _get_new_buffer(_READ, channel_id, length) + + +def get_new_write_buffer(channel_id: int, length: int) -> memoryview: + return _get_new_buffer(_WRITE, channel_id, length) + + +def get_existing_read_buffer(channel_id: int) -> memoryview: + return _get_existing_buffer(_READ, channel_id) + + +def get_existing_write_buffer(channel_id: int) -> memoryview: + return _get_existing_buffer(_WRITE, channel_id) + + +def _get_new_buffer(buffer_type: int, channel_id: int, length: int) -> memoryview: + if is_locked(): + if not is_owner(channel_id): + raise WireBufferError + update_lock_time() + else: + update_lock(channel_id) + + if buffer_type == _READ: + global READ_BUFFER + buffer = READ_BUFFER + elif buffer_type == _WRITE: + global WRITE_BUFFER + buffer = WRITE_BUFFER + else: + raise Exception("Invalid buffer_type") + + if length > MAX_PAYLOAD_LEN or length > len(buffer): + raise ThpError("Message is too large") # TODO reword + + if buffer_type == _READ: + global READ_BUFFER_SLICE + READ_BUFFER_SLICE = memoryview(READ_BUFFER)[:length] + return READ_BUFFER_SLICE + + if buffer_type == _WRITE: + global WRITE_BUFFER_SLICE + WRITE_BUFFER_SLICE = memoryview(WRITE_BUFFER)[:length] + return WRITE_BUFFER_SLICE + + raise Exception("Invalid buffer_type") + + +def _get_existing_buffer(buffer_type: int, channel_id: int) -> memoryview: + if not is_owner(channel_id): + raise WireBufferError + update_lock_time() + + if buffer_type == _READ: + global READ_BUFFER_SLICE + if READ_BUFFER_SLICE is None: + raise WireBufferError + return READ_BUFFER_SLICE + + if buffer_type == _WRITE: + global WRITE_BUFFER_SLICE + if WRITE_BUFFER_SLICE is None: + raise WireBufferError + return WRITE_BUFFER_SLICE + + raise Exception("Invalid buffer_type") + + +# +# Buffer locking + + +def is_locked() -> bool: + global lock_owner_cid + global lock_time + + time_diff = utime.ticks_diff(utime.ticks_ms(), lock_time) + return lock_owner_cid is not None and time_diff < LOCK_TIMEOUT + + +def is_owner(channel_id: int) -> bool: + global lock_owner_cid + return lock_owner_cid is not None and lock_owner_cid == channel_id + + +def update_lock(channel_id: int) -> None: + set_owner(channel_id) + update_lock_time() + + +def set_owner(channel_id: int) -> None: + global lock_owner_cid + lock_owner_cid = channel_id + + +def update_lock_time() -> None: + global lock_time + lock_time = utime.ticks_ms() + + +# +# Helper for encoding messages into buffer + + +def encode_into_buffer( + buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int +) -> int: + """Encode protobuf message `msg` into the `buffer`, including session id + an messages's wire type. Will fail if provided message has no wire type.""" + + # cannot write message without wire type + msg_type = msg.MESSAGE_WIRE_TYPE + if msg_type is None: + msg_type = get_msg_type(msg.MESSAGE_NAME) + if msg_type is None: + raise Exception("Message has no wire type.") + + msg_size = protobuf.encoded_length(msg) + payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size + + _encode_session_into_buffer(memoryview(buffer), session_id) + _encode_message_type_into_buffer(memoryview(buffer), msg_type, SESSION_ID_LENGTH) + _encode_message_into_buffer( + memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + ) + + return payload_size + + +def _encode_session_into_buffer( + buffer: memoryview, session_id: int, buffer_offset: int = 0 +) -> None: + session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big") + utils.memcpy(buffer, buffer_offset, session_id_bytes, 0) + + +def _encode_message_type_into_buffer( + buffer: memoryview, message_type: int, offset: int = 0 +) -> None: + msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big") + utils.memcpy(buffer, offset, msg_type_bytes, 0) + + +def _encode_message_into_buffer( + buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0 +) -> None: + protobuf.encode(memoryview(buffer[buffer_offset:]), message) diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py new file mode 100644 index 00000000000..081a745890d --- /dev/null +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -0,0 +1,329 @@ +from typing import TYPE_CHECKING +from ubinascii import hexlify + +import trezorui_api +from trezor import loop, protobuf, workflow +from trezor.enums import ButtonRequestType +from trezor.wire import context, message_handler, protocol_common +from trezor.wire.context import UnexpectedMessageException +from trezor.wire.errors import ActionCancelled, SilentError +from trezor.wire.protocol_common import Context, Message +from trezor.wire.thp import ChannelState, get_enabled_pairing_methods + +if TYPE_CHECKING: + from typing import Awaitable, Container + + from trezor.enums import ThpPairingMethod + from trezorui_api import UiResult + + from .channel import Channel + from .cpace import Cpace + + pass + +if __debug__: + from trezor import log + + +class PairingContext(Context): + + def __init__(self, channel_ctx: Channel) -> None: + super().__init__(channel_ctx.iface, channel_ctx.channel_id) + self.channel_ctx: Channel = channel_ctx + self.incoming_message = loop.mailbox() + self.nfc_secret: bytes | None = None + self.qr_code_secret: bytes | None = None + self.code_entry_secret: bytes | None = None + + self.selected_method: ThpPairingMethod + + self.code_code_entry: int | None = None + self.code_qr_code: bytes | None = None + self.code_nfc: bytes | None = None + # The 2 following attributes are important for NFC pairing + self.nfc_secret_host: bytes | None = None + self.handshake_hash_host: bytes | None = None + + self.cpace: Cpace + self.host_name: str | None + + async def handle(self) -> None: + next_message: Message | None = None + + while True: + try: + if next_message is None: + # If the previous run did not keep an unprocessed message for us, + # wait for a new one. + try: + message: Message = await self.incoming_message + except protocol_common.WireError as e: + if __debug__: + log.exception(__name__, e) + await self.write(message_handler.failure(e)) + continue + else: + # Process the message from previous run. + message = next_message + next_message = None + + try: + next_message = await handle_message(self, message) + except Exception as exc: + # Log and ignore. The context handler can only exit explicitly in the + # following finally block. + if __debug__: + log.exception(__name__, exc) + finally: + if next_message is None: + + # Shut down the loop if there is no next message waiting. + return # pylint: disable=lost-exception + + except Exception as exc: + # Log and try again. The context handler can only exit explicitly via + # finally block above + if __debug__: + log.exception(__name__, exc) + + async def read( + self, + expected_types: Container[int], + expected_type: type[protobuf.MessageType] | None = None, + ) -> protobuf.MessageType: + if __debug__: + exp_type: str = str(expected_type) + if expected_type is not None: + exp_type = expected_type.MESSAGE_NAME + log.debug( + __name__, + "Read - with expected types %s and expected type %s", + str(expected_types), + exp_type, + ) + + message: Message = await self.incoming_message + + if message.type not in expected_types: + raise UnexpectedMessageException(message) + + if expected_type is None: + name = message_handler.get_msg_name(message.type) + if name is None: + expected_type = protobuf.type_for_wire(message.type) + else: + expected_type = protobuf.type_for_name(name) + + return message_handler.wrap_protobuf_load(message.data, expected_type) + + async def write(self, msg: protobuf.MessageType) -> None: + return await self.channel_ctx.write(msg) + + def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]: + return self.channel_ctx.write(msg, force=True) + + async def call( + self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType] + ) -> protobuf.MessageType: + expected_wire_type = expected_type.MESSAGE_WIRE_TYPE + if expected_wire_type is None: + expected_wire_type = message_handler.get_msg_type( + expected_type.MESSAGE_NAME + ) + + assert expected_wire_type is not None + + await self.write(msg) + del msg + + return await self.read((expected_wire_type,), expected_type) + + async def call_any( + self, msg: protobuf.MessageType, *expected_types: int + ) -> protobuf.MessageType: + await self.write(msg) + del msg + return await self.read(expected_types) + + def set_selected_method(self, selected_method: ThpPairingMethod) -> None: + if selected_method not in get_enabled_pairing_methods(self.iface): + raise Exception("Not allowed to set this method") + self.selected_method = selected_method + + async def show_pairing_dialogue(self) -> None: + from trezor.messages import ThpPairingRequestApproved + from trezor.ui.layouts.common import interact + + result = await interact( + trezorui_api.confirm_action( + title="Pairing dialogue", + action="Do you want to start pairing?", + description="Choose wisely!", + ), + br_name="pairing_request", + br_code=ButtonRequestType.Other, + ) + if result == trezorui_api.CONFIRMED: + await self.write(ThpPairingRequestApproved()) + + async def show_connection_dialogue(self) -> None: + from trezor.ui.layouts.common import interact + + await interact( + trezorui_api.confirm_action( + title="Connection dialogue", + action="Do you want previously connected device to connect?", + description="Choose wisely! (or not)", + ), + br_name="connection_request", + br_code=ButtonRequestType.Other, + ) + + async def show_pairing_method_screen( + self, selected_method: ThpPairingMethod | None = None + ) -> UiResult: + from trezor.enums import ThpPairingMethod + + if selected_method is None: + selected_method = self.selected_method + if selected_method is ThpPairingMethod.CodeEntry: + return await self._show_code_entry_screen() + elif selected_method is ThpPairingMethod.NFC: + return await self._show_nfc_screen() + elif selected_method is ThpPairingMethod.QrCode: + return await self._show_qr_code_screen() + else: + raise Exception("Unknown pairing method") + + async def _show_code_entry_screen(self) -> UiResult: + from trezor.ui.layouts.common import interact + + return await interact( + trezorui_api.show_simple( + title="Copy the following", + text=self._get_code_code_entry_str(), + button="Cancel", + ), + br_name=None, + ) + + async def _show_nfc_screen(self) -> UiResult: + from trezor.ui.layouts.common import interact + + return await interact( + trezorui_api.show_simple( + title="NFC Pairing", + text="Move your device close to Trezor", + button="Cancel", + ), + br_name=None, + ) + + async def _show_qr_code_screen(self) -> UiResult: + from trezor.ui.layouts.common import interact + + return await interact( + trezorui_api.show_address_details( # noqa + qr_title="Scan QR code to pair", + address=self._get_code_qr_code_str(), + case_sensitive=True, + details_title="", + account="", + path="", + xpubs=[], + ), + br_name=None, + br_code=ButtonRequestType.Other, + ) + + def _get_code_code_entry_str(self) -> str: + if self.code_code_entry is not None: + code_str = f"{self.code_code_entry:06}" + if __debug__: + log.debug(__name__, "code_code_entry: %s", code_str) + + return code_str[:3] + " " + code_str[3:] + raise Exception("Code entry string is not available") + + def _get_code_qr_code_str(self) -> str: + if self.code_qr_code is not None: + code_str = (hexlify(self.code_qr_code)).decode("utf-8") + if __debug__: + log.debug(__name__, "code_qr_code_hexlified: %s", code_str) + return code_str + raise Exception("QR code string is not available") + + +async def handle_message( + pairing_ctx: PairingContext, + msg: protocol_common.Message, +) -> protocol_common.Message | None: + + res_msg: protobuf.MessageType | None = None + + from apps.thp.pairing import handle_credential_phase, handle_pairing_request + + if msg.type in workflow.ALLOW_WHILE_LOCKED: + workflow.autolock_interrupts_workflow = False + + # Here we make sure we always respond with a Failure response + # in case of any errors. + try: + # Find a protobuf.MessageType subclass that describes this + # message. Raises if the type is not found. + name = message_handler.get_msg_name(msg.type) + if name is None: + req_type = protobuf.type_for_wire(msg.type) + else: + req_type = protobuf.type_for_name(name) + + # Try to decode the message according to schema from + # `req_type`. Raises if the message is malformed. + req_msg = message_handler.wrap_protobuf_load(msg.data, req_type) + + # Create the handler task. + if pairing_ctx.channel_ctx.get_channel_state() == ChannelState.TC1: + task = handle_credential_phase(pairing_ctx, req_msg) + else: + task = handle_pairing_request(pairing_ctx, req_msg) + + # Run the workflow task. Workflow can do more on-the-wire + # communication inside, but it should eventually return a + # response message, or raise an exception (a rather common + # thing to do). Exceptions are handled in the code below. + res_msg = await workflow.spawn(context.with_context(pairing_ctx, task)) + + except UnexpectedMessageException as exc: + # Workflow was trying to read a message from the wire, and + # something unexpected came in. See Context.read() for + # example, which expects some particular message and raises + # UnexpectedMessage if another one comes in. + # In order not to lose the message, we return it to the caller. + # TODO: + # We might handle only the few common cases here, like + # Initialize and Cancel. + return exc.msg + except SilentError as exc: + if __debug__: + log.error(__name__, "SilentError: %s", exc.message) + except BaseException as exc: + # Either: + # - the message had a type that has a registered handler, but does not have + # a protobuf class + # - the message was not valid protobuf + # - workflow raised some kind of an exception while running + # - something canceled the workflow from the outside + if __debug__: + if isinstance(exc, ActionCancelled): + log.debug(__name__, "cancelled: %s", exc.message) + elif isinstance(exc, loop.TaskClosed): + log.debug(__name__, "cancelled: loop task was closed") + else: + log.exception(__name__, exc) + res_msg = message_handler.failure(exc) + + if res_msg is not None: + # perform the write outside the big try-except block, so that usb write + # problem bubbles up + await pairing_ctx.write(res_msg) + return None diff --git a/core/src/trezor/wire/thp/received_message_handler.py b/core/src/trezor/wire/thp/received_message_handler.py new file mode 100644 index 00000000000..f1766153873 --- /dev/null +++ b/core/src/trezor/wire/thp/received_message_handler.py @@ -0,0 +1,462 @@ +import ustruct +from typing import TYPE_CHECKING + +from storage.cache_common import ( + CHANNEL_HANDSHAKE_HASH, + CHANNEL_KEY_RECEIVE, + CHANNEL_KEY_SEND, + CHANNEL_NONCE_RECEIVE, + CHANNEL_NONCE_SEND, +) +from storage.cache_thp import ( + KEY_LENGTH, + SESSION_ID_LENGTH, + TAG_LENGTH, + update_channel_last_used, + update_session_last_used, +) +from trezor import config, log, loop, protobuf, utils +from trezor.enums import FailureType +from trezor.messages import Failure +from trezor.wire.thp import memory_manager + +from .. import message_handler +from ..errors import DataError +from ..protocol_common import Message +from . import ( + ACK_MESSAGE, + HANDSHAKE_COMP_RES, + HANDSHAKE_INIT_RES, + ChannelState, + PacketHeader, + SessionState, + ThpDecryptionError, + ThpDeviceLockedError, + ThpError, + ThpErrorType, + ThpInvalidDataError, + ThpUnallocatedSessionError, +) +from . import alternating_bit_protocol as ABP +from . import checksum, control_byte, get_encoded_device_properties, session_manager +from .checksum import CHECKSUM_LENGTH +from .crypto import PUBKEY_LENGTH, Handshake +from .session_context import SeedlessSessionContext +from .writer import ( + INIT_HEADER_LENGTH, + MESSAGE_TYPE_LENGTH, + write_payload_to_wire_and_add_checksum, +) + +if TYPE_CHECKING: + from typing import Awaitable + + from trezor.messages import ThpHandshakeCompletionReqNoisePayload + + from .channel import Channel + +if __debug__: + from trezor.utils import get_bytes_as_str + + +_TREZOR_STATE_UNPAIRED = b"\x00" +_TREZOR_STATE_PAIRED = b"\x01" +_TREZOR_STATE_PAIRED_AUTOCONNECT = b"\x02" + + +async def handle_received_message( + ctx: Channel, message_buffer: utils.BufferType +) -> None: + """Handle a message received from the channel.""" + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "handle_received_message") + if utils.ALLOW_DEBUG_MESSAGES: # TODO remove after performance tests are done + try: + import micropython + + print("micropython.mem_info() from received_message_handler.py") + micropython.mem_info() + print("Allocation count:", micropython.alloc_count()) # type: ignore ["alloc_count" is not a known attribute of module "micropython"] + except AttributeError: + print( + "To show allocation count, create the build with TREZOR_MEMPERF=1" + ) + ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer) + message_length = payload_length + INIT_HEADER_LENGTH + + _check_checksum(message_length, message_buffer) + + # Synchronization process + seq_bit = (ctrl_byte & 0x10) >> 4 + ack_bit = (ctrl_byte & 0x08) >> 3 + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "handle_completed_message - seq bit of message: %d, ack bit of message: %d", + seq_bit, + ack_bit, + ) + # 0: Update "last-time used" + update_channel_last_used(ctx.channel_id) + + # 1: Handle ACKs + if control_byte.is_ack(ctrl_byte): + await _handle_ack(ctx, ack_bit) + return + + if _should_have_ctrl_byte_encrypted_transport( + ctx + ) and not control_byte.is_encrypted_transport(ctrl_byte): + raise ThpError("Message is not encrypted. Ignoring") + + # 2: Handle message with unexpected sequential bit + if seq_bit != ABP.get_expected_receive_seq_bit(ctx.channel_cache): + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "Received message with an unexpected sequential bit") + await _send_ack(ctx, ack_bit=seq_bit) + raise ThpError("Received message with an unexpected sequential bit") + + # 3: Send ACK in response + await _send_ack(ctx, ack_bit=seq_bit) + + ABP.set_expected_receive_seq_bit(ctx.channel_cache, 1 - seq_bit) + + try: + await _handle_message_to_app_or_channel( + ctx, payload_length, message_length, ctrl_byte + ) + except ThpUnallocatedSessionError as e: + error_message = Failure(code=FailureType.ThpUnallocatedSession) + await ctx.write(error_message, e.session_id) + except ThpDecryptionError: + await ctx.write_error(ThpErrorType.DECRYPTION_FAILED) + ctx.clear() + except ThpInvalidDataError: + await ctx.write_error(ThpErrorType.INVALID_DATA) + ctx.clear() + except ThpDeviceLockedError: + await ctx.write_error(ThpErrorType.DEVICE_LOCKED) + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "handle_received_message - end") + + +def _send_ack(ctx: Channel, ack_bit: int) -> Awaitable[None]: + ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit) + header = PacketHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "Writing ACK message to a channel with id: %d, ack_bit: %d", + ctx.get_channel_id_int(), + ack_bit, + ) + return write_payload_to_wire_and_add_checksum(ctx.iface, header, b"") + + +def _check_checksum(message_length: int, message_buffer: utils.BufferType) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "check_checksum") + if not checksum.is_valid( + checksum=message_buffer[message_length - CHECKSUM_LENGTH : message_length], + data=memoryview(message_buffer)[: message_length - CHECKSUM_LENGTH], + ): + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "Invalid checksum, ignoring message.") + raise ThpError("Invalid checksum, ignoring message.") + + +async def _handle_ack(ctx: Channel, ack_bit: int) -> None: + if not ABP.is_ack_valid(ctx.channel_cache, ack_bit): + return + # ACK is expected and it has correct sync bit + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "Received ACK message with correct ack bit") + if ctx.transmission_loop is not None: + ctx.transmission_loop.stop_immediately() + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "Stopped transmission loop") + + ABP.set_sending_allowed(ctx.channel_cache, True) + + if ctx.write_task_spawn is not None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, 'Control to "write_encrypted_payload_loop" task') + await ctx.write_task_spawn + # Note that no the write_task_spawn could result in loop.clear(), + # which will result in termination of this function - any code after + # this await might not be executed + + +def _handle_message_to_app_or_channel( + ctx: Channel, + payload_length: int, + message_length: int, + ctrl_byte: int, +) -> Awaitable[None]: + state = ctx.get_channel_state() + + if state is ChannelState.ENCRYPTED_TRANSPORT: + return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length) + + if state is ChannelState.TH1: + return _handle_state_TH1(ctx, payload_length, message_length, ctrl_byte) + + if state is ChannelState.TH2: + return _handle_state_TH2(ctx, message_length, ctrl_byte) + + if _is_channel_state_pairing(state): + return _handle_pairing(ctx, message_length) + + raise ThpError("Unimplemented channel state") + + +async def _handle_state_TH1( + ctx: Channel, + payload_length: int, + message_length: int, + ctrl_byte: int, +) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "handle_state_TH1") + if not control_byte.is_handshake_init_req(ctrl_byte): + raise ThpError("Message received is not a handshake init request!") + if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH: + raise ThpError("Message received is not a valid handshake init request!") + + if not config.is_unlocked(): + raise ThpDeviceLockedError + + ctx.handshake = Handshake() + + buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int()) + # if buffer is BufferError: + # pass # TODO buffer is gone :/ + + host_ephemeral_pubkey = bytearray( + buffer[INIT_HEADER_LENGTH : message_length - CHECKSUM_LENGTH] + ) + trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag = ( + ctx.handshake.handle_th1_crypto( + get_encoded_device_properties(ctx.iface), host_ephemeral_pubkey + ) + ) + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "trezor ephemeral pubkey: %s", + get_bytes_as_str(trezor_ephemeral_pubkey), + ) + log.debug( + __name__, + "encrypted trezor masked static pubkey: %s", + get_bytes_as_str(encrypted_trezor_static_pubkey), + ) + log.debug(__name__, "tag: %s", get_bytes_as_str(tag)) + + payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag + + # send handshake init response message + ctx.write_handshake_message(HANDSHAKE_INIT_RES, payload) + ctx.set_channel_state(ChannelState.TH2) + return + + +async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None: + from apps.thp.credential_manager import decode_credential, validate_credential + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "handle_state_TH2") + if not control_byte.is_handshake_comp_req(ctrl_byte): + raise ThpError("Message received is not a handshake completion request!") + if ctx.handshake is None: + raise Exception("Handshake object is not prepared. Retry handshake.") + + if not config.is_unlocked(): + raise ThpDeviceLockedError + + buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int()) + # if buffer is BufferError: + # pass # TODO handle + host_encrypted_static_pubkey = buffer[ + INIT_HEADER_LENGTH : INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH + ] + handshake_completion_request_noise_payload = buffer[ + INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH + ] + + ctx.handshake.handle_th2_crypto( + host_encrypted_static_pubkey, handshake_completion_request_noise_payload + ) + + ctx.channel_cache.set(CHANNEL_KEY_RECEIVE, ctx.handshake.key_receive) + ctx.channel_cache.set(CHANNEL_KEY_SEND, ctx.handshake.key_send) + ctx.channel_cache.set(CHANNEL_HANDSHAKE_HASH, ctx.handshake.h) + ctx.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0) + ctx.channel_cache.set_int(CHANNEL_NONCE_SEND, 1) + + noise_payload = _decode_message( + buffer[ + INIT_HEADER_LENGTH + + KEY_LENGTH + + TAG_LENGTH : message_length + - CHECKSUM_LENGTH + - TAG_LENGTH + ], + 0, + "ThpHandshakeCompletionReqNoisePayload", + ) + if TYPE_CHECKING: + assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload) + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "host static pubkey: %s, noise payload: %s", + utils.get_bytes_as_str(host_encrypted_static_pubkey), + utils.get_bytes_as_str(handshake_completion_request_noise_payload), + ) + + # key is decoded in handshake._handle_th2_crypto + host_static_pubkey = host_encrypted_static_pubkey[:PUBKEY_LENGTH] + + paired: bool = False + trezor_state = _TREZOR_STATE_UNPAIRED + + if noise_payload.host_pairing_credential is not None: + try: # TODO change try-except for something better + credential = decode_credential(noise_payload.host_pairing_credential) + paired = validate_credential( + credential, + host_static_pubkey, + ) + if paired: + trezor_state = _TREZOR_STATE_PAIRED + ctx.credential = credential + else: + ctx.credential = None + except DataError as e: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.exception(__name__, e) + pass + + # send hanshake completion response + ctx.write_handshake_message( + HANDSHAKE_COMP_RES, + ctx.handshake.get_handshake_completion_response(trezor_state), + ) + + ctx.handshake = None + + if paired: + ctx.set_channel_state(ChannelState.TC1) + else: + ctx.set_channel_state(ChannelState.TP0) + + +async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT") + + ctx.decrypt_buffer(message_length) + + buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int()) + # if buffer is BufferError: + # pass # TODO handle + session_id, message_type = ustruct.unpack( + ">BH", memoryview(buffer)[INIT_HEADER_LENGTH:] + ) + if session_id not in ctx.sessions: + + s = session_manager.get_session_from_cache(ctx, session_id) + + if s is None: + s = SeedlessSessionContext(ctx, session_id) + + ctx.sessions[session_id] = s + loop.schedule(s.handle()) + + elif ctx.sessions[session_id].get_session_state() is SessionState.UNALLOCATED: + raise ThpUnallocatedSessionError(session_id) + + s = ctx.sessions[session_id] + update_session_last_used(s.channel_id, (s.session_id).to_bytes(1, "big")) + + s.incoming_message.put( + Message( + message_type, + buffer[ + INIT_HEADER_LENGTH + + MESSAGE_TYPE_LENGTH + + SESSION_ID_LENGTH : message_length + - CHECKSUM_LENGTH + - TAG_LENGTH + ], + ) + ) + + +async def _handle_pairing(ctx: Channel, message_length: int) -> None: + from .pairing_context import PairingContext + + if ctx.connection_context is None: + ctx.connection_context = PairingContext(ctx) + loop.schedule(ctx.connection_context.handle()) + + ctx.decrypt_buffer(message_length) + buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int()) + # if buffer is BufferError: + # pass # TODO handle + message_type = ustruct.unpack( + ">H", buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :] + )[0] + + ctx.connection_context.incoming_message.put( + Message( + message_type, + buffer[ + INIT_HEADER_LENGTH + + MESSAGE_TYPE_LENGTH + + SESSION_ID_LENGTH : message_length + - CHECKSUM_LENGTH + - TAG_LENGTH + ], + ) + ) + + +def _should_have_ctrl_byte_encrypted_transport(ctx: Channel) -> bool: + if ctx.get_channel_state() in [ + ChannelState.UNALLOCATED, + ChannelState.TH1, + ChannelState.TH2, + ]: + return False + return True + + +def _decode_message( + buffer: bytes, msg_type: int, message_name: str | None = None +) -> protobuf.MessageType: + if __debug__: + log.debug(__name__, "decode message") + if message_name is not None: + expected_type = protobuf.type_for_name(message_name) + else: + expected_type = protobuf.type_for_wire(msg_type) + return message_handler.wrap_protobuf_load(buffer, expected_type) + + +def _is_channel_state_pairing(state: int) -> bool: + if state in ( + ChannelState.TP0, + ChannelState.TP1, + ChannelState.TP2, + ChannelState.TP3, + ChannelState.TP4, + ChannelState.TC1, + ): + return True + return False diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py new file mode 100644 index 00000000000..0601c1f7b7b --- /dev/null +++ b/core/src/trezor/wire/thp/session_context.py @@ -0,0 +1,169 @@ +from typing import TYPE_CHECKING + +from storage import cache_thp +from storage.cache_common import InvalidSessionError +from storage.cache_thp import SessionThpCache +from trezor import log, loop, protobuf, utils +from trezor.wire import message_handler, protocol_common +from trezor.wire.context import UnexpectedMessageException +from trezor.wire.message_handler import failure + +from ..protocol_common import Context, Message +from . import SessionState + +if TYPE_CHECKING: + from typing import Awaitable, Container + + from storage.cache_common import DataCache + + from .channel import Channel + + pass + +_EXIT_LOOP = True +_REPEAT_LOOP = False + +if __debug__: + from trezor.utils import get_bytes_as_str + + +class GenericSessionContext(Context): + + def __init__(self, channel: Channel, session_id: int) -> None: + super().__init__(channel.iface, channel.channel_id) + self.channel: Channel = channel + self.session_id: int = session_id + self.incoming_message = loop.mailbox() + + async def handle(self) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "handle - start (channel_id (bytes): %s, session_id: %d)", + get_bytes_as_str(self.channel_id), + self.session_id, + ) + + next_message: Message | None = None + + while True: + message = next_message + next_message = None + try: + if await self._handle_message(message): + loop.schedule(self.handle()) + return + except UnexpectedMessageException as unexpected: + # The workflow was interrupted by an unexpected message. We need to + # process it as if it was a new message... + next_message = unexpected.msg + continue + except Exception as exc: + # Log and try again. + if __debug__: + log.exception(__name__, exc) + + async def _handle_message( + self, + next_message: Message | None, + ) -> bool: + + try: + if next_message is not None: + # Process the message from previous run. + message = next_message + next_message = None + else: + # Wait for a new message from wire + message = await self.incoming_message + + except protocol_common.WireError as e: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.exception(__name__, e) + await self.write(failure(e)) + return _REPEAT_LOOP + + await message_handler.handle_single_message(self, message) + return _EXIT_LOOP + + async def read( + self, + expected_types: Container[int], + expected_type: type[protobuf.MessageType] | None = None, + ) -> protobuf.MessageType: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + exp_type: str = str(expected_type) + if expected_type is not None: + exp_type = expected_type.MESSAGE_NAME + log.debug( + __name__, + "Read - with expected types %s and expected type %s", + str(expected_types), + exp_type, + ) + message: Message = await self.incoming_message + if message.type not in expected_types: + if __debug__: + log.debug( + __name__, + "EXPECTED TYPES: %s\nRECEIVED TYPE: %s", + str(expected_types), + str(message.type), + ) + raise UnexpectedMessageException(message) + + if expected_type is None: + expected_type = protobuf.type_for_wire(message.type) + + return message_handler.wrap_protobuf_load(message.data, expected_type) + + async def write(self, msg: protobuf.MessageType) -> None: + return await self.channel.write(msg, self.session_id) + + def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]: + return self.channel.write(msg, self.session_id, force=True) + + def get_session_state(self) -> SessionState: ... + + +class SeedlessSessionContext(GenericSessionContext): + + def __init__(self, channel_ctx: Channel, session_id: int) -> None: + super().__init__(channel_ctx, session_id) + + def get_session_state(self) -> SessionState: + return SessionState.SEEDLESS + + @property + def cache(self) -> DataCache: + raise InvalidSessionError + + +class SessionContext(GenericSessionContext): + + def __init__(self, channel_ctx: Channel, session_cache: SessionThpCache) -> None: + if channel_ctx.channel_id != session_cache.channel_id: + raise Exception( + "The session has different channel id than the provided channel context!" + ) + session_id = int.from_bytes(session_cache.session_id, "big") + super().__init__(channel_ctx, session_id) + self.session_cache = session_cache + + # ACCESS TO SESSION DATA + + def get_session_state(self) -> SessionState: + state = int.from_bytes(self.session_cache.state, "big") + return SessionState(state) + + def set_session_state(self, state: SessionState) -> None: + self.session_cache.state = bytearray(state.to_bytes(1, "big")) + + def release(self) -> None: + if self.session_cache is not None: + cache_thp.clear_session(self.session_cache) + + # ACCESS TO CACHE + @property + def cache(self) -> DataCache: + return self.session_cache diff --git a/core/src/trezor/wire/thp/session_manager.py b/core/src/trezor/wire/thp/session_manager.py new file mode 100644 index 00000000000..a50a1a332a1 --- /dev/null +++ b/core/src/trezor/wire/thp/session_manager.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from storage import cache_thp + +from .session_context import ( + GenericSessionContext, + SeedlessSessionContext, + SessionContext, +) + +if TYPE_CHECKING: + from .channel import Channel + + +def get_new_session_context( + channel_ctx: Channel, + session_id: int, +) -> SessionContext: + session_cache = cache_thp.create_or_replace_session( + channel=channel_ctx.channel_cache, + session_id=session_id.to_bytes(1, "big"), + ) + return SessionContext(channel_ctx, session_cache) + + +def get_new_seedless_session_ctx( + channel_ctx: Channel, session_id: int +) -> SeedlessSessionContext: + """ + Creates new `SeedlessSessionContext` that is not backed by a cache entry. + + Seed cannot be derived with this type of session. + """ + return SeedlessSessionContext(channel_ctx, session_id) + + +def get_session_from_cache( + channel_ctx: Channel, session_id: int +) -> GenericSessionContext | None: + """ + Returns a `SessionContext` (or `SeedlessSessionContext`) reconstructed from a cache or `None` if backing cache is not found. + """ + session_id_bytes = session_id.to_bytes(1, "big") + session_cache = cache_thp.get_allocated_session( + channel_ctx.channel_id, session_id_bytes + ) + if session_cache is None: + return None + elif cache_thp.is_seedless_session(session_cache): + return SeedlessSessionContext(channel_ctx, session_id) + return SessionContext(channel_ctx, session_cache) diff --git a/core/src/trezor/wire/thp/thp_main.py b/core/src/trezor/wire/thp/thp_main.py new file mode 100644 index 00000000000..5482cf396c0 --- /dev/null +++ b/core/src/trezor/wire/thp/thp_main.py @@ -0,0 +1,165 @@ +import ustruct +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_thp import BROADCAST_CHANNEL_ID +from trezor import io, log, loop, utils + +from . import ( + CHANNEL_ALLOCATION_REQ, + CODEC_V1, + ChannelState, + PacketHeader, + ThpError, + ThpErrorType, + channel_manager, + checksum, + control_byte, + get_channel_allocation_response, + writer, +) +from .channel import Channel +from .checksum import CHECKSUM_LENGTH +from .writer import ( + INIT_HEADER_LENGTH, + MAX_PAYLOAD_LEN, + write_payload_to_wire_and_add_checksum, +) + +if TYPE_CHECKING: + from trezorio import WireInterface + +_CID_REQ_PAYLOAD_LENGTH = const(12) +_CHANNELS: dict[int, Channel] = {} + + +async def thp_main_loop(iface: WireInterface) -> None: + global _CHANNELS + _CHANNELS = channel_manager.load_cached_channels() + + read = loop.wait(iface.iface_num() | io.POLL_READ) + packet = bytearray(iface.RX_PACKET_LEN) + while True: + try: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "thp_main_loop") + packet_len = await read + assert packet_len == len(packet) + iface.read(packet, 0) + + if _get_ctrl_byte(packet) == CODEC_V1: + await _handle_codec_v1(iface, packet) + continue + + cid = ustruct.unpack(">BH", packet)[1] + + if cid == BROADCAST_CHANNEL_ID: + await _handle_broadcast(iface, packet) + continue + + if cid in _CHANNELS: + await _handle_allocated(iface, cid, packet) + else: + await _handle_unallocated(iface, cid, packet) + + except ThpError as e: + if __debug__: + log.exception(__name__, e) + + +async def _handle_codec_v1(iface: WireInterface, packet: bytes) -> None: + # If the received packet is not an initial codec_v1 packet, do not send error message + if not packet[1:3] == b"##": + return + if __debug__: + log.debug(__name__, "Received codec_v1 message, returning error") + error_message = _get_codec_v1_error_message() + await writer.write_packet_to_wire(iface, error_message) + + +async def _handle_broadcast(iface: WireInterface, packet: utils.BufferType) -> None: + if _get_ctrl_byte(packet) != CHANNEL_ALLOCATION_REQ: + raise ThpError("Unexpected ctrl_byte in a broadcast channel packet") + if __debug__: + log.debug(__name__, "Received valid message on the broadcast channel") + + length, nonce = ustruct.unpack(">H8s", packet[3:]) + payload = _get_buffer_for_payload(length, packet[5:], _CID_REQ_PAYLOAD_LENGTH) + if not checksum.is_valid( + payload[-4:], + packet[: _CID_REQ_PAYLOAD_LENGTH + INIT_HEADER_LENGTH - CHECKSUM_LENGTH], + ): + raise ThpError("Checksum is not valid") + + new_channel: Channel = channel_manager.create_new_channel(iface) + cid = int.from_bytes(new_channel.channel_id, "big") + _CHANNELS[cid] = new_channel + + response_data = get_channel_allocation_response( + nonce, new_channel.channel_id, iface + ) + response_header = PacketHeader.get_channel_allocation_response_header( + len(response_data) + CHECKSUM_LENGTH, + ) + if __debug__: + log.debug(__name__, "New channel allocated with id %d", cid) + + await write_payload_to_wire_and_add_checksum(iface, response_header, response_data) + + +async def _handle_allocated( + iface: WireInterface, cid: int, packet: utils.BufferType +) -> None: + channel = _CHANNELS[cid] + if channel is None: + await _handle_unallocated(iface, cid, packet) + raise ThpError("Invalid state of a channel") + if channel.iface is not iface: + # TODO send error message to wire + raise ThpError("Channel has different WireInterface") + + if channel.get_channel_state() != ChannelState.UNALLOCATED: + x = channel.receive_packet(packet) + if x is not None: + await x + + +async def _handle_unallocated(iface: WireInterface, cid: int, packet: bytes) -> None: + if control_byte.is_continuation(_get_ctrl_byte(packet)): + return + data = (ThpErrorType.UNALLOCATED_CHANNEL).to_bytes(1, "big") + header = PacketHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH) + await write_payload_to_wire_and_add_checksum(iface, header, data) + + +def _get_buffer_for_payload( + payload_length: int, + existing_buffer: utils.BufferType, + max_length: int = MAX_PAYLOAD_LEN, +) -> utils.BufferType: + if payload_length > max_length: + raise ThpError("Message too large") + if payload_length > len(existing_buffer): + try: + new_buffer = bytearray(payload_length) + except MemoryError: + raise ThpError("Message too large") + return new_buffer + return _reuse_existing_buffer(payload_length, existing_buffer) + + +def _reuse_existing_buffer( + payload_length: int, existing_buffer: utils.BufferType +) -> utils.BufferType: + return memoryview(existing_buffer)[:payload_length] + + +def _get_ctrl_byte(packet: bytes) -> int: + return packet[0] + + +def _get_codec_v1_error_message() -> bytes: + # Codec_v1 magic constant "?##" + Failure message type + msg_size + # + msg_data (code = "Failure_InvalidProtocol") + padding to 64 B + ERROR_MSG = b"\x3f\x23\x23\x00\x03\x00\x00\x00\x14\x08\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + return ERROR_MSG diff --git a/core/src/trezor/wire/thp/transmission_loop.py b/core/src/trezor/wire/thp/transmission_loop.py new file mode 100644 index 00000000000..cd3e3ba2f8d --- /dev/null +++ b/core/src/trezor/wire/thp/transmission_loop.py @@ -0,0 +1,54 @@ +from micropython import const +from typing import TYPE_CHECKING + +from trezor import loop + +from .writer import write_payload_to_wire_and_add_checksum + +if TYPE_CHECKING: + from . import PacketHeader + from .channel import Channel + +MAX_RETRANSMISSION_COUNT = const(50) +MIN_RETRANSMISSION_COUNT = const(2) + + +class TransmissionLoop: + + def __init__( + self, channel: Channel, header: PacketHeader, transport_payload: bytes + ) -> None: + self.channel: Channel = channel + self.header: PacketHeader = header + self.transport_payload: bytes = transport_payload + self.wait_task: loop.spawn | None = None + self.min_retransmisson_count_achieved: bool = False + + async def start( + self, max_retransmission_count: int = MAX_RETRANSMISSION_COUNT + ) -> None: + self.min_retransmisson_count_achieved = False + for i in range(max_retransmission_count): + if i >= MIN_RETRANSMISSION_COUNT: + self.min_retransmisson_count_achieved = True + await write_payload_to_wire_and_add_checksum( + self.channel.iface, self.header, self.transport_payload + ) + self.wait_task = loop.spawn(self._wait(i)) + try: + await self.wait_task + except loop.TaskClosed: + self.wait_task = None + break + + def stop_immediately(self) -> None: + if self.wait_task is not None: + self.wait_task.close() + self.wait_task = None + + async def _wait(self, counter: int = 0) -> None: + timeout_ms = round(10200 - 1010000 / (counter + 100)) + await loop.sleep(timeout_ms) + + def __del__(self) -> None: + self.stop_immediately() diff --git a/core/src/trezor/wire/thp/writer.py b/core/src/trezor/wire/thp/writer.py new file mode 100644 index 00000000000..03aedf36906 --- /dev/null +++ b/core/src/trezor/wire/thp/writer.py @@ -0,0 +1,91 @@ +from micropython import const +from trezorcrypto import crc +from typing import TYPE_CHECKING + +from trezor import io, log, loop, utils + +from . import PacketHeader + +INIT_HEADER_LENGTH = const(5) +CONT_HEADER_LENGTH = const(3) +CHECKSUM_LENGTH = const(4) +MAX_PAYLOAD_LEN = const(60000) +MESSAGE_TYPE_LENGTH = const(2) + +if TYPE_CHECKING: + from trezorio import WireInterface + from typing import Awaitable, Sequence + + +def write_payload_to_wire_and_add_checksum( + iface: WireInterface, header: PacketHeader, transport_payload: bytes +) -> Awaitable[None]: + header_checksum: int = crc.crc32(header.to_bytes()) + checksum: bytes = crc.crc32(transport_payload, header_checksum).to_bytes( + CHECKSUM_LENGTH, "big" + ) + data = (transport_payload, checksum) + return write_payloads_to_wire(iface, header, data) + + +async def write_payloads_to_wire( + iface: WireInterface, header: PacketHeader, data: Sequence[bytes] +) -> None: + n_of_data = len(data) + total_length = sum(len(item) for item in data) + + current_data_idx = 0 + current_data_offset = 0 + + packet = bytearray(iface.TX_PACKET_LEN) + header.pack_to_init_buffer(packet) + packet_offset: int = INIT_HEADER_LENGTH + packet_number = 0 + nwritten = 0 + while nwritten < total_length: + if packet_number == 1: + header.pack_to_cont_buffer(packet) + if packet_number >= 1 and nwritten >= total_length - iface.TX_PACKET_LEN: + packet[:] = bytearray(iface.TX_PACKET_LEN) + header.pack_to_cont_buffer(packet) + while True: + n = utils.memcpy( + packet, packet_offset, data[current_data_idx], current_data_offset + ) + packet_offset += n + current_data_offset += n + nwritten += n + + if packet_offset < iface.TX_PACKET_LEN: + current_data_idx += 1 + current_data_offset = 0 + if current_data_idx >= n_of_data: + break + elif packet_offset == iface.TX_PACKET_LEN: + break + else: + raise Exception("Should not happen!!!") + packet_number += 1 + packet_offset = CONT_HEADER_LENGTH + + # write packet to wire (in-lined) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet) + ) + written_by_iface: int = 0 + while written_by_iface < len(packet): + await loop.wait(iface.iface_num() | io.POLL_WRITE) + written_by_iface = iface.write(packet) + + +async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None: + while True: + await loop.wait(iface.iface_num() | io.POLL_WRITE) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet) + ) + n_written = iface.write(packet) + if n_written == len(packet): + return diff --git a/core/src/trezor/workflow.py b/core/src/trezor/workflow.py index 67b88f8e684..a7a4791b1d4 100644 --- a/core/src/trezor/workflow.py +++ b/core/src/trezor/workflow.py @@ -3,7 +3,7 @@ import storage.cache as storage_cache from trezor import log, loop -from trezor.enums import MessageType +from trezor.enums import MessageType, ThpMessageType if TYPE_CHECKING: from typing import Callable @@ -17,18 +17,30 @@ from trezor import utils - -ALLOW_WHILE_LOCKED = ( - MessageType.Initialize, - MessageType.EndSession, - MessageType.GetFeatures, - MessageType.Cancel, - MessageType.LockDevice, - MessageType.DoPreauthorized, - MessageType.WipeDevice, - MessageType.SetBusy, - MessageType.Ping, -) +if utils.USE_THP: + ALLOW_WHILE_LOCKED = ( + ThpMessageType.ThpCreateNewSession, + MessageType.EndSession, + MessageType.GetFeatures, + MessageType.Cancel, + MessageType.LockDevice, + MessageType.DoPreauthorized, + MessageType.WipeDevice, + MessageType.SetBusy, + MessageType.Ping, + ) +else: + ALLOW_WHILE_LOCKED = ( + MessageType.Initialize, + MessageType.EndSession, + MessageType.GetFeatures, + MessageType.Cancel, + MessageType.LockDevice, + MessageType.DoPreauthorized, + MessageType.WipeDevice, + MessageType.SetBusy, + MessageType.Ping, + ) # Set of workflow tasks. Multiple workflows can be running at the same time. diff --git a/core/tests/mock_wire_interface.py b/core/tests/mock_wire_interface.py new file mode 100644 index 00000000000..13cd0333757 --- /dev/null +++ b/core/tests/mock_wire_interface.py @@ -0,0 +1,50 @@ +from trezor.loop import wait + + +class MockHID: + + TX_PACKET_LEN = 64 + RX_PACKET_LEN = 64 + + def __init__(self, num): + self.num = num + self.data = [] + self.packet = None + + def pad_packet(self, data): + if len(data) > self.RX_PACKET_LEN: + raise Exception("Too long packet") + padding_length = self.RX_PACKET_LEN - len(data) + return data + b"\x00" * padding_length + + def iface_num(self): + return self.num + + def write(self, msg): + self.data.append(bytearray(msg)) + return len(msg) + + def mock_read(self, packet, gen): + self.packet = self.pad_packet(packet) + return gen.send(self.RX_PACKET_LEN) + + def read(self, buffer, offset=0): + if self.packet is None: + raise Exception("No packet to read") + + if offset > len(buffer): + raise Exception("Offset out of bounds") + + buffer_space = len(buffer) - offset + + if len(self.packet) > buffer_space: + raise Exception("Buffer too small") + else: + end = offset + len(self.packet) + buffer[offset:end] = self.packet + read = len(self.packet) + self.packet = None + return read + + def wait_object(self, mode): + return wait(mode | self.num) diff --git a/core/tests/myTests.sh b/core/tests/myTests.sh new file mode 100755 index 00000000000..1c29c1fd01b --- /dev/null +++ b/core/tests/myTests.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +declare -a results +declare -i passed=0 failed=0 exit_code=0 +declare COLOR_GREEN='\e[32m' COLOR_RED='\e[91m' COLOR_RESET='\e[39m' +MICROPYTHON="${MICROPYTHON:-../build/unix/trezor-emu-core -X heapsize=2M}" +print_summary() { + echo + echo 'Summary:' + echo '-------------------' + printf '%b\n' "${results[@]}" + if [ $exit_code == 0 ]; then + echo -e "${COLOR_GREEN}PASSED:${COLOR_RESET} $passed/$num_of_tests tests OK!" + else + echo -e "${COLOR_RED}FAILED:${COLOR_RESET} $failed/$num_of_tests tests failed!" + fi +} + +trap 'print_summary; echo -e "${COLOR_RED}Interrupted by user!${COLOR_RESET}"; exit 1' SIGINT + +cd $(dirname $0) + +[ -z "$*" ] && tests=(test_trezor.wire.t*.py ) || tests=($*) + +declare -i num_of_tests=${#tests[@]} + +for test_case in ${tests[@]}; do + echo ${MICROPYTHON} + echo ${test_case} + echo + if $MICROPYTHON $test_case; then + results+=("${COLOR_GREEN}OK:${COLOR_RESET} $test_case") + ((passed++)) + else + results+=("${COLOR_RED}FAIL:${COLOR_RESET} $test_case") + ((failed++)) + exit_code=1 + fi +done + +print_summary +exit $exit_code diff --git a/core/tests/test_apps.bitcoin.approver.py b/core/tests/test_apps.bitcoin.approver.py index 8086fd8e2d5..408f96d7eef 100644 --- a/core/tests/test_apps.bitcoin.approver.py +++ b/core/tests/test_apps.bitcoin.approver.py @@ -1,4 +1,5 @@ -from common import H_, await_result, unittest # isort:skip +# flake8: noqa: F403,F405 +from common import * # isort:skip import storage.cache_codec from trezor import wire @@ -20,11 +21,25 @@ from apps.bitcoin.sign_tx.tx_info import TxInfo from apps.common import coins +if utils.USE_THP: + import thp_common +else: + import storage.cache_codec + from trezor.wire.codec.codec_context import CodecContext + class TestApprover(unittest.TestCase): + if utils.USE_THP: + + def setUpClass(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + + else: - def setUpClass(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) def tearDownClass(self): context.CURRENT_CONTEXT = None @@ -54,7 +69,8 @@ def setUp(self): coin_name=self.coin.coin_name, script_type=InputScriptType.SPENDTAPROOT, ) - storage.cache_codec.start_session() + if not utils.USE_THP: + storage.cache_codec.start_session() def make_coinjoin_request(self, inputs): return CoinJoinRequest( diff --git a/core/tests/test_apps.bitcoin.authorization.py b/core/tests/test_apps.bitcoin.authorization.py index 03d32651c70..aedcadba92a 100644 --- a/core/tests/test_apps.bitcoin.authorization.py +++ b/core/tests/test_apps.bitcoin.authorization.py @@ -1,23 +1,38 @@ -from common import H_, unittest # isort:skip +# flake8: noqa: F403,F405 +from common import * # isort:skip import storage.cache_codec from trezor.enums import InputScriptType from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx from trezor.wire import context -from trezor.wire.codec.codec_context import CodecContext from apps.bitcoin.authorization import CoinJoinAuthorization from apps.common import coins _ROUND_ID_LEN = 32 +if utils.USE_THP: + import thp_common +else: + import storage.cache_codec + from trezor.wire.codec.codec_context import CodecContext + class TestAuthorization(unittest.TestCase): coin = coins.by_name("Bitcoin") - def setUpClass(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + if utils.USE_THP: + + def setUpClass(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + + else: + + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) def tearDownClass(self): context.CURRENT_CONTEXT = None @@ -34,7 +49,8 @@ def setUp(self): ) self.authorization = CoinJoinAuthorization(self.msg_auth) - storage.cache_codec.start_session() + if not utils.USE_THP: + storage.cache_codec.start_session() def test_ownership_proof_account_depth_mismatch(self): # Account depth mismatch. diff --git a/core/tests/test_apps.bitcoin.keychain.py b/core/tests/test_apps.bitcoin.keychain.py index 232d2bf01dc..25239dad8cc 100644 --- a/core/tests/test_apps.bitcoin.keychain.py +++ b/core/tests/test_apps.bitcoin.keychain.py @@ -1,7 +1,7 @@ # flake8: noqa: F403,F405 from common import * # isort:skip -from storage import cache_codec, cache_common +from storage import cache_common from trezor import wire from trezor.crypto import bip39 from trezor.wire import context @@ -9,20 +9,38 @@ from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin +if utils.USE_THP: + import thp_common +else: + from storage import cache_codec + class TestBitcoinKeychain(unittest.TestCase): - def setUpClass(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + if utils.USE_THP: + + def setUpClass(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + + def setUp(self): + seed = bip39.seed(" ".join(["all"] * 12), "") + context.cache_set(cache_common.APP_COMMON_SEED, seed) + + else: + + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def setUp(self): + cache_codec.start_session() + seed = bip39.seed(" ".join(["all"] * 12), "") + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def tearDownClass(self): context.CURRENT_CONTEXT = None - def setUp(self): - cache_codec.start_session() - seed = bip39.seed(" ".join(["all"] * 12), "") - cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) - def test_bitcoin(self): coin = _get_coin_by_name("Bitcoin") keychain = await_result(_get_keychain_for_coin(coin)) @@ -98,18 +116,30 @@ def test_unknown(self): @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") class TestAltcoinKeychains(unittest.TestCase): + if utils.USE_THP: + + def setUpClass(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + + def setUp(self): + seed = bip39.seed(" ".join(["all"] * 12), "") + context.cache_set(cache_common.APP_COMMON_SEED, seed) - def setUpClass(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + else: + + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def setUp(self): + cache_codec.start_session() + seed = bip39.seed(" ".join(["all"] * 12), "") + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def tearDownClass(self): context.CURRENT_CONTEXT = None - def setUp(self): - cache_codec.start_session() - seed = bip39.seed(" ".join(["all"] * 12), "") - cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) - def test_bcash(self): coin = _get_coin_by_name("Bcash") keychain = await_result(_get_keychain_for_coin(coin)) diff --git a/core/tests/test_apps.common.keychain.py b/core/tests/test_apps.common.keychain.py index f54f64d74ff..8d0839f3743 100644 --- a/core/tests/test_apps.common.keychain.py +++ b/core/tests/test_apps.common.keychain.py @@ -2,7 +2,7 @@ from common import * # isort:skip from mock_storage import mock_storage -from storage import cache, cache_codec, cache_common +from storage import cache, cache_common from trezor import wire from trezor.crypto import bip39 from trezor.enums import SafetyCheckLevel @@ -13,18 +13,32 @@ from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain from apps.common.paths import PATTERN_SEP5, PathSchema +if utils.USE_THP: + import thp_common +if not utils.USE_THP: + from storage import cache_codec + class TestKeychain(unittest.TestCase): - def setUpClass(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + if utils.USE_THP: + + def setUpClass(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + + else: + + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def setUp(self): + cache_codec.start_session() def tearDownClass(self): context.CURRENT_CONTEXT = None - def setUp(self): - cache_codec.start_session() - def tearDown(self): cache.clear_all() diff --git a/core/tests/test_apps.ethereum.keychain.py b/core/tests/test_apps.ethereum.keychain.py index 3215aba2674..6355da641c1 100644 --- a/core/tests/test_apps.ethereum.keychain.py +++ b/core/tests/test_apps.ethereum.keychain.py @@ -3,7 +3,7 @@ import unittest -from storage import cache_codec, cache_common +from storage import cache_common from trezor import wire from trezor.crypto import bip39 from trezor.wire import context @@ -12,6 +12,12 @@ from apps.common.keychain import get_keychain from apps.common.paths import HARDENED +if utils.USE_THP: + import thp_common +else: + from storage import cache_codec + + if not utils.BITCOIN_ONLY: from ethereum_common import encode_network, make_network from trezor.messages import ( @@ -74,17 +80,30 @@ def _check_keychain(self, keychain, slip44_id): addr, ) - def setUpClass(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + if utils.USE_THP: + + def setUpClass(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + + def setUp(self): + seed = bip39.seed(" ".join(["all"] * 12), "") + context.cache_set(cache_common.APP_COMMON_SEED, seed) + + else: + + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def setUp(self): + cache_codec.start_session() + seed = bip39.seed(" ".join(["all"] * 12), "") + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def tearDownClass(self): context.CURRENT_CONTEXT = None - def setUp(self): - cache_codec.start_session() - seed = bip39.seed(" ".join(["all"] * 12), "") - cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) - def from_address_n(self, address_n): slip44 = _slip44_from_address_n(address_n) network = make_network(slip44=slip44) diff --git a/core/tests/test_apps.thp.credential_manager.py b/core/tests/test_apps.thp.credential_manager.py index 267707d3744..e25cc6a0027 100644 --- a/core/tests/test_apps.thp.credential_manager.py +++ b/core/tests/test_apps.thp.credential_manager.py @@ -48,16 +48,28 @@ def test_credentials(self): cred_3 = _issue_credential(HOST_NAME_2, DUMMY_KEY_1) self.assertNotEqual(cred_1, cred_3) - self.assertTrue(credential_manager.validate_credential(cred_1, DUMMY_KEY_1)) - self.assertTrue(credential_manager.validate_credential(cred_3, DUMMY_KEY_1)) - self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_2)) + self.assertTrue( + credential_manager.decode_and_validate_credential(cred_1, DUMMY_KEY_1) + ) + self.assertTrue( + credential_manager.decode_and_validate_credential(cred_3, DUMMY_KEY_1) + ) + self.assertFalse( + credential_manager.decode_and_validate_credential(cred_1, DUMMY_KEY_2) + ) credential_manager.invalidate_cred_auth_key() cred_4 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1) self.assertNotEqual(cred_1, cred_4) - self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_1)) - self.assertFalse(credential_manager.validate_credential(cred_3, DUMMY_KEY_1)) - self.assertTrue(credential_manager.validate_credential(cred_4, DUMMY_KEY_1)) + self.assertFalse( + credential_manager.decode_and_validate_credential(cred_1, DUMMY_KEY_1) + ) + self.assertFalse( + credential_manager.decode_and_validate_credential(cred_3, DUMMY_KEY_1) + ) + self.assertTrue( + credential_manager.decode_and_validate_credential(cred_4, DUMMY_KEY_1) + ) def test_protobuf_encoding(self): """ diff --git a/core/tests/test_storage.cache.py b/core/tests/test_storage.cache.py index cc93015e05a..07a3904c295 100644 --- a/core/tests/test_storage.cache.py +++ b/core/tests/test_storage.cache.py @@ -1,241 +1,560 @@ # flake8: noqa: F403,F405 from common import * # isort:skip -from mock_storage import mock_storage -from storage import cache, cache_codec, cache_common -from trezor.messages import EndSession, Initialize -from trezor.wire import context -from trezor.wire.codec.codec_context import CodecContext - -from apps.base import handle_EndSession, handle_Initialize -from apps.common.cache import stored, stored_async KEY = 0 +if utils.USE_THP: + import thp_common + from mock_wire_interface import MockHID + from storage import cache, cache_thp + from trezor.wire.thp import ChannelState + from trezor.wire.thp.session_context import SessionContext + + _PROTOCOL_CACHE = cache_thp + +else: + from mock_storage import mock_storage + from storage import cache, cache_codec + from trezor.messages import EndSession, Initialize + + from apps.base import handle_EndSession + + _PROTOCOL_CACHE = cache_codec + + def is_session_started() -> bool: + return cache_codec.get_active_session() is not None -# Function moved from cache.py, as it was not used there -def is_session_started() -> bool: - return cache_codec._active_session_idx is not None + def get_active_session(): + return cache_codec.get_active_session() class TestStorageCache(unittest.TestCase): - def setUpClass(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) - - def tearDownClass(self): - context.CURRENT_CONTEXT = None - - def setUp(self): - cache.clear_all() - - def test_start_session(self): - session_id_a = cache_codec.start_session() - self.assertIsNotNone(session_id_a) - session_id_b = cache_codec.start_session() - self.assertNotEqual(session_id_a, session_id_b) - - cache.clear_all() - with self.assertRaises(cache_common.InvalidSessionError): - context.cache_set(KEY, "something") - with self.assertRaises(cache_common.InvalidSessionError): - context.cache_get(KEY) - - def test_end_session(self): - session_id = cache_codec.start_session() - self.assertTrue(is_session_started()) - context.cache_set(KEY, b"A") - cache_codec.end_current_session() - self.assertFalse(is_session_started()) - self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY) - - # ending an ended session should be a no-op - cache_codec.end_current_session() - self.assertFalse(is_session_started()) - - session_id_a = cache_codec.start_session(session_id) - # original session no longer exists - self.assertNotEqual(session_id_a, session_id) - # original session data no longer exists - self.assertIsNone(context.cache_get(KEY)) - - # create a new session - session_id_b = cache_codec.start_session() - # switch back to original session - session_id = cache_codec.start_session(session_id_a) - self.assertEqual(session_id, session_id_a) - # end original session - cache_codec.end_current_session() - # switch back to B - session_id = cache_codec.start_session(session_id_b) - self.assertEqual(session_id, session_id_b) - - def test_session_queue(self): - session_id = cache_codec.start_session() - self.assertEqual(cache_codec.start_session(session_id), session_id) - context.cache_set(KEY, b"A") - for _ in range(cache_codec._MAX_SESSIONS_COUNT): + if utils.USE_THP: + + def setUpClass(self): + if __debug__: + thp_common.suppres_debug_log() + super().__init__() + + def setUp(self): + self.interface = MockHID(0xDEADBEEF) + cache.clear_all() + + def test_new_channel_and_session(self): + channel = thp_common.get_new_channel(self.interface) + + # Assert that channel is created without any sessions + self.assertEqual(len(channel.sessions), 0) + + cid_1 = channel.channel_id + session_cache_1 = cache_thp.create_or_replace_session( + channel.channel_cache, b"\x01" + ) + session_1 = SessionContext(channel, session_cache_1) + self.assertEqual(session_1.channel_id, cid_1) + + session_cache_2 = cache_thp.create_or_replace_session( + channel.channel_cache, b"\x02" + ) + session_2 = SessionContext(channel, session_cache_2) + self.assertEqual(session_2.channel_id, cid_1) + self.assertEqual(session_1.channel_id, session_2.channel_id) + self.assertNotEqual(session_1.session_id, session_2.session_id) + + channel_2 = thp_common.get_new_channel(self.interface) + cid_2 = channel_2.channel_id + self.assertNotEqual(cid_1, cid_2) + + session_cache_3 = cache_thp.create_or_replace_session( + channel_2.channel_cache, b"\x01" + ) + session_3 = SessionContext(channel_2, session_cache_3) + self.assertEqual(session_3.channel_id, cid_2) + + # Sessions 1 and 3 should have different channel_id, but the same session_id + self.assertNotEqual(session_1.channel_id, session_3.channel_id) + self.assertEqual(session_1.session_id, session_3.session_id) + + self.assertEqual(cache_thp._SESSIONS[0], session_cache_1) + self.assertNotEqual(cache_thp._SESSIONS[0], session_cache_2) + self.assertEqual(cache_thp._SESSIONS[0].channel_id, session_1.channel_id) + + # Check that session data IS in cache for created sessions ONLY + for i in range(3): + self.assertNotEqual(cache_thp._SESSIONS[i].channel_id, b"") + self.assertNotEqual(cache_thp._SESSIONS[i].session_id, b"") + self.assertNotEqual(cache_thp._SESSIONS[i].last_usage, 0) + for i in range(3, cache_thp._MAX_SESSIONS_COUNT): + self.assertEqual(cache_thp._SESSIONS[i].channel_id, b"") + self.assertEqual(cache_thp._SESSIONS[i].session_id, b"") + self.assertEqual(cache_thp._SESSIONS[i].last_usage, 0) + + # Check that session data IS NOT in cache after cache.clear_all() + cache.clear_all() + for session in cache_thp._SESSIONS: + self.assertEqual(session.channel_id, b"") + self.assertEqual(session.session_id, b"") + self.assertEqual(session.last_usage, 0) + self.assertEqual(session.state, b"\x00") + + def test_channel_capacity_in_cache(self): + self.assertTrue(cache_thp._MAX_CHANNELS_COUNT >= 3) + channels = [] + for i in range(cache_thp._MAX_CHANNELS_COUNT): + channels.append(thp_common.get_new_channel(self.interface)) + channel_ids = [channel.channel_cache.channel_id for channel in channels] + + # Assert that each channel_id is unique and that cache and list of channels + # have the same "channels" on the same indexes + for i in range(len(channel_ids)): + self.assertEqual(cache_thp._CHANNELS[i].channel_id, channel_ids[i]) + for j in range(i + 1, len(channel_ids)): + self.assertNotEqual(channel_ids[i], channel_ids[j]) + + # Create a new channel that is over the capacity + new_channel = thp_common.get_new_channel(self.interface) + for c in channels: + self.assertNotEqual(c.channel_id, new_channel.channel_id) + + # Test that the oldest (least used) channel was replaced (_CHANNELS[0]) + self.assertNotEqual(cache_thp._CHANNELS[0].channel_id, channel_ids[0]) + self.assertEqual(cache_thp._CHANNELS[0].channel_id, new_channel.channel_id) + + # Update the "last used" value of the second channel in cache (_CHANNELS[1]) and + # assert that it is not replaced when creating a new channel + cache_thp.update_channel_last_used(channel_ids[1]) + new_new_channel = thp_common.get_new_channel(self.interface) + self.assertEqual(cache_thp._CHANNELS[1].channel_id, channel_ids[1]) + + # Assert that it was in fact the _CHANNEL[2] that was replaced + self.assertNotEqual(cache_thp._CHANNELS[2].channel_id, channel_ids[2]) + self.assertEqual( + cache_thp._CHANNELS[2].channel_id, new_new_channel.channel_id + ) + + def test_session_capacity_in_cache(self): + self.assertTrue(cache_thp._MAX_SESSIONS_COUNT >= 4) + channel_cache_A = thp_common.get_new_channel(self.interface).channel_cache + channel_cache_B = thp_common.get_new_channel(self.interface).channel_cache + + sesions_A = [] + cid = [] + sid = [] + for i in range(3): + sesions_A.append( + cache_thp.create_or_replace_session( + channel_cache_A, (i + 1).to_bytes(1, "big") + ) + ) + cid.append(sesions_A[i].channel_id) + sid.append(sesions_A[i].session_id) + + sessions_B = [] + for i in range(cache_thp._MAX_SESSIONS_COUNT - 3): + sessions_B.append( + cache_thp.create_or_replace_session( + channel_cache_B, (i + 10).to_bytes(1, "big") + ) + ) + + for i in range(3): + self.assertEqual(sesions_A[i], cache_thp._SESSIONS[i]) + self.assertEqual(cid[i], cache_thp._SESSIONS[i].channel_id) + self.assertEqual(sid[i], cache_thp._SESSIONS[i].session_id) + for i in range(3, cache_thp._MAX_SESSIONS_COUNT): + self.assertEqual(sessions_B[i - 3], cache_thp._SESSIONS[i]) + + # Assert that new session replaces the oldest (least used) one (_SESSOIONS[0]) + new_session = cache_thp.create_or_replace_session(channel_cache_B, b"\xab") + self.assertEqual(new_session, cache_thp._SESSIONS[0]) + self.assertNotEqual(new_session.channel_id, cid[0]) + self.assertNotEqual(new_session.session_id, sid[0]) + + # Assert that updating "last used" for session on channel A increases also + # the "last usage" of channel A. + self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage) + cache_thp.update_session_last_used( + channel_cache_A.channel_id, sesions_A[1].session_id + ) + self.assertTrue(channel_cache_A.last_usage > channel_cache_B.last_usage) + + new_new_session = cache_thp.create_or_replace_session( + channel_cache_B, b"\xaa" + ) + + # Assert that creating a new session on channel B shifts the "last usage" again + # and that _SESSIONS[1] was not replaced, but that _SESSIONS[2] was replaced + self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage) + self.assertEqual(sesions_A[1], cache_thp._SESSIONS[1]) + self.assertNotEqual(sesions_A[2], cache_thp._SESSIONS[2]) + self.assertEqual(new_new_session, cache_thp._SESSIONS[2]) + + def test_clear(self): + channel_A = thp_common.get_new_channel(self.interface) + channel_B = thp_common.get_new_channel(self.interface) + cid_A = channel_A.channel_id + cid_B = channel_B.channel_id + sessions = [] + + for i in range(3): + sessions.append( + cache_thp.create_or_replace_session( + channel_A.channel_cache, (i + 1).to_bytes(1, "big") + ) + ) + sessions.append( + cache_thp.create_or_replace_session( + channel_B.channel_cache, (i + 10).to_bytes(1, "big") + ) + ) + + self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, cid_A) + self.assertNotEqual(cache_thp._SESSIONS[2 * i].last_usage, 0) + + self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B) + self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0) + + # Assert that clearing of channel A works + self.assertNotEqual(channel_A.channel_cache.channel_id, b"") + self.assertNotEqual(channel_A.channel_cache.last_usage, 0) + self.assertEqual(channel_A.get_channel_state(), ChannelState.TH1) + + channel_A.clear() + + self.assertEqual(channel_A.channel_cache.channel_id, b"") + self.assertEqual(channel_A.channel_cache.last_usage, 0) + self.assertEqual(channel_A.get_channel_state(), ChannelState.UNALLOCATED) + + # Assert that clearing channel A also cleared all its sessions + for i in range(3): + self.assertEqual(cache_thp._SESSIONS[2 * i].last_usage, 0) + self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, b"") + + self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0) + self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B) + + cache.clear_all() + for session in cache_thp._SESSIONS: + self.assertEqual(session.last_usage, 0) + self.assertEqual(session.channel_id, b"") + for channel in cache_thp._CHANNELS: + self.assertEqual(channel.channel_id, b"") + self.assertEqual(channel.last_usage, 0) + self.assertEqual( + cache_thp._get_channel_state(channel), ChannelState.UNALLOCATED + ) + + def test_get_set(self): + channel = thp_common.get_new_channel(self.interface) + + session_1 = cache_thp.create_or_replace_session( + channel.channel_cache, b"\x01" + ) + session_1.set(KEY, b"hello") + self.assertEqual(session_1.get(KEY), b"hello") + + session_2 = cache_thp.create_or_replace_session( + channel.channel_cache, b"\x02" + ) + session_2.set(KEY, b"world") + self.assertEqual(session_2.get(KEY), b"world") + + self.assertEqual(session_1.get(KEY), b"hello") + + cache.clear_all() + self.assertIsNone(session_1.get(KEY)) + self.assertIsNone(session_2.get(KEY)) + + def test_get_set_int(self): + channel = thp_common.get_new_channel(self.interface) + + session_1 = cache_thp.create_or_replace_session( + channel.channel_cache, b"\x01" + ) + session_1.set_int(KEY, 1234) + + self.assertEqual(session_1.get_int(KEY), 1234) + + session_2 = cache_thp.create_or_replace_session( + channel.channel_cache, b"\x02" + ) + session_2.set_int(KEY, 5678) + self.assertEqual(session_2.get_int(KEY), 5678) + + self.assertEqual(session_1.get_int(KEY), 1234) + + cache.clear_all() + self.assertIsNone(session_1.get_int(KEY)) + self.assertIsNone(session_2.get_int(KEY)) + + def test_get_set_bool(self): + channel = thp_common.get_new_channel(self.interface) + + session_1 = cache_thp.create_or_replace_session( + channel.channel_cache, b"\x01" + ) + with self.assertRaises(AssertionError): + session_1.set_bool(KEY, True) + + # Change length of first session field to 0 so that the length check passes + session_1.fields = (0,) + session_1.fields[1:] + + # with self.assertRaises(AssertionError) as e: + session_1.set_bool(KEY, True) + self.assertEqual(session_1.get_bool(KEY), True) + + session_2 = cache_thp.create_or_replace_session( + channel.channel_cache, b"\x02" + ) + session_2.fields = session_2.fields = (0,) + session_2.fields[1:] + session_2.set_bool(KEY, False) + self.assertEqual(session_2.get_bool(KEY), False) + + self.assertEqual(session_1.get_bool(KEY), True) + + cache.clear_all() + + # Default value is False + self.assertFalse(session_1.get_bool(KEY)) + self.assertFalse(session_2.get_bool(KEY)) + + def test_delete(self): + channel = thp_common.get_new_channel(self.interface) + session_1 = cache_thp.create_or_replace_session( + channel.channel_cache, b"\x01" + ) + + self.assertIsNone(session_1.get(KEY)) + session_1.set(KEY, b"hello") + self.assertEqual(session_1.get(KEY), b"hello") + session_1.delete(KEY) + self.assertIsNone(session_1.get(KEY)) + + session_1.set(KEY, b"hello") + session_2 = cache_thp.create_or_replace_session( + channel.channel_cache, b"\x02" + ) + + self.assertIsNone(session_2.get(KEY)) + session_2.set(KEY, b"hello") + self.assertEqual(session_2.get(KEY), b"hello") + session_2.delete(KEY) + self.assertIsNone(session_2.get(KEY)) + + self.assertEqual(session_1.get(KEY), b"hello") + + else: + + def setUpClass(self): + from trezor.wire import context + from trezor.wire.codec.codec_context import CodecContext + + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def tearDownClass(self): + from trezor.wire import context + + context.CURRENT_CONTEXT = None + + def setUp(self): + cache.clear_all() + + def test_start_session(self): + session_id_a = cache_codec.start_session() + self.assertIsNotNone(session_id_a) + session_id_b = cache_codec.start_session() + self.assertNotEqual(session_id_a, session_id_b) + + cache.clear_all() + self.assertIsNone(get_active_session()) + for session in cache_codec._SESSIONS: + self.assertEqual(session.session_id, b"") + self.assertEqual(session.last_usage, 0) + + def test_end_session(self): + session_id = cache_codec.start_session() + self.assertTrue(is_session_started()) + get_active_session().set(KEY, b"A") + cache_codec.end_current_session() + self.assertFalse(is_session_started()) + self.assertIsNone(get_active_session()) + + # ending an ended session should be a no-op + cache_codec.end_current_session() + self.assertFalse(is_session_started()) + + session_id_a = cache_codec.start_session(session_id) + # original session no longer exists + self.assertNotEqual(session_id_a, session_id) + # original session data no longer exists + self.assertIsNone(get_active_session().get(KEY)) + + # create a new session + session_id_b = cache_codec.start_session() + # switch back to original session + session_id = cache_codec.start_session(session_id_a) + self.assertEqual(session_id, session_id_a) + # end original session + cache_codec.end_current_session() + # switch back to B + session_id = cache_codec.start_session(session_id_b) + self.assertEqual(session_id, session_id_b) + + def test_session_queue(self): + session_id = cache_codec.start_session() + self.assertEqual(cache_codec.start_session(session_id), session_id) + get_active_session().set(KEY, b"A") + for _ in range(_PROTOCOL_CACHE._MAX_SESSIONS_COUNT): + cache_codec.start_session() + self.assertNotEqual(cache_codec.start_session(session_id), session_id) + self.assertIsNone(get_active_session().get(KEY)) + + def test_get_set(self): + session_id1 = cache_codec.start_session() + cache_codec.get_active_session().set(KEY, b"hello") + self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello") + + session_id2 = cache_codec.start_session() + cache_codec.get_active_session().set(KEY, b"world") + self.assertEqual(cache_codec.get_active_session().get(KEY), b"world") + + cache_codec.start_session(session_id2) + self.assertEqual(cache_codec.get_active_session().get(KEY), b"world") + cache_codec.start_session(session_id1) + self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello") + + cache.clear_all() + self.assertIsNone(cache_codec.get_active_session()) + + def test_get_set_int(self): + session_id1 = cache_codec.start_session() + get_active_session().set_int(KEY, 1234) + self.assertEqual(get_active_session().get_int(KEY), 1234) + + session_id2 = cache_codec.start_session() + get_active_session().set_int(KEY, 5678) + self.assertEqual(get_active_session().get_int(KEY), 5678) + + cache_codec.start_session(session_id2) + self.assertEqual(get_active_session().get_int(KEY), 5678) + cache_codec.start_session(session_id1) + self.assertEqual(get_active_session().get_int(KEY), 1234) + + cache.clear_all() + self.assertIsNone(get_active_session()) + + def test_delete(self): + session_id1 = cache_codec.start_session() + self.assertIsNone(get_active_session().get(KEY)) + get_active_session().set(KEY, b"hello") + self.assertEqual(get_active_session().get(KEY), b"hello") + get_active_session().delete(KEY) + self.assertIsNone(get_active_session().get(KEY)) + + get_active_session().set(KEY, b"hello") + cache_codec.start_session() + self.assertIsNone(get_active_session().get(KEY)) + get_active_session().set(KEY, b"hello") + self.assertEqual(get_active_session().get(KEY), b"hello") + get_active_session().delete(KEY) + self.assertIsNone(get_active_session().get(KEY)) + + cache_codec.start_session(session_id1) + self.assertEqual(get_active_session().get(KEY), b"hello") + + def test_decorators(self): + run_count = 0 + cache_codec.start_session() + from apps.common.cache import stored + + @stored(KEY) + def func(): + nonlocal run_count + run_count += 1 + return b"foo" + + # cache is empty + self.assertIsNone(get_active_session().get(KEY)) + self.assertEqual(run_count, 0) + self.assertEqual(func(), b"foo") + # function was run + self.assertEqual(run_count, 1) + self.assertEqual(get_active_session().get(KEY), b"foo") + # function does not run again but returns cached value + self.assertEqual(func(), b"foo") + self.assertEqual(run_count, 1) + + def test_empty_value(self): + cache_codec.start_session() + + self.assertIsNone(get_active_session().get(KEY)) + get_active_session().set(KEY, b"") + self.assertEqual(get_active_session().get(KEY), b"") + + get_active_session().delete(KEY) + run_count = 0 + + from apps.common.cache import stored + + @stored(KEY) + def func(): + nonlocal run_count + run_count += 1 + return b"" + + self.assertEqual(func(), b"") + # function gets called once + self.assertEqual(run_count, 1) + self.assertEqual(func(), b"") + # function is not called for a second time + self.assertEqual(run_count, 1) + + @mock_storage + def test_Initialize(self): + from apps.base import handle_Initialize + + def call_Initialize(**kwargs): + msg = Initialize(**kwargs) + return await_result(handle_Initialize(msg)) + + # calling Initialize without an ID allocates a new one + session_id = cache_codec.start_session() + features = call_Initialize() + self.assertNotEqual(session_id, features.session_id) + + # calling Initialize with the current ID does not allocate a new one + features = call_Initialize(session_id=session_id) + self.assertEqual(session_id, features.session_id) + + # store "hello" + get_active_session().set(KEY, b"hello") + # check that it is cleared + features = call_Initialize() + session_id = features.session_id + self.assertIsNone(get_active_session().get(KEY)) + # store "hello" again + get_active_session().set(KEY, b"hello") + self.assertEqual(get_active_session().get(KEY), b"hello") + + # supplying a different session ID starts a new session + call_Initialize(session_id=b"A" * _PROTOCOL_CACHE.SESSION_ID_LENGTH) + self.assertIsNone(get_active_session().get(KEY)) + + # but resuming a session loads the previous one + call_Initialize(session_id=session_id) + self.assertEqual(get_active_session().get(KEY), b"hello") + + def test_EndSession(self): + + self.assertIsNone(get_active_session()) cache_codec.start_session() - self.assertNotEqual(cache_codec.start_session(session_id), session_id) - self.assertIsNone(context.cache_get(KEY)) - - def test_get_set(self): - session_id1 = cache_codec.start_session() - context.cache_set(KEY, b"hello") - self.assertEqual(context.cache_get(KEY), b"hello") - - session_id2 = cache_codec.start_session() - context.cache_set(KEY, b"world") - self.assertEqual(context.cache_get(KEY), b"world") - - cache_codec.start_session(session_id2) - self.assertEqual(context.cache_get(KEY), b"world") - cache_codec.start_session(session_id1) - self.assertEqual(context.cache_get(KEY), b"hello") - - cache.clear_all() - with self.assertRaises(cache_common.InvalidSessionError): - context.cache_get(KEY) - - def test_get_set_int(self): - session_id1 = cache_codec.start_session() - context.cache_set_int(KEY, 1234) - self.assertEqual(context.cache_get_int(KEY), 1234) - - session_id2 = cache_codec.start_session() - context.cache_set_int(KEY, 5678) - self.assertEqual(context.cache_get_int(KEY), 5678) - - cache_codec.start_session(session_id2) - self.assertEqual(context.cache_get_int(KEY), 5678) - cache_codec.start_session(session_id1) - self.assertEqual(context.cache_get_int(KEY), 1234) - - cache.clear_all() - with self.assertRaises(cache_common.InvalidSessionError): - context.cache_get_int(KEY) - - def test_delete(self): - session_id1 = cache_codec.start_session() - self.assertIsNone(context.cache_get(KEY)) - context.cache_set(KEY, b"hello") - self.assertEqual(context.cache_get(KEY), b"hello") - context.cache_delete(KEY) - self.assertIsNone(context.cache_get(KEY)) - - context.cache_set(KEY, b"hello") - cache_codec.start_session() - self.assertIsNone(context.cache_get(KEY)) - context.cache_set(KEY, b"hello") - self.assertEqual(context.cache_get(KEY), b"hello") - context.cache_delete(KEY) - self.assertIsNone(context.cache_get(KEY)) - - cache_codec.start_session(session_id1) - self.assertEqual(context.cache_get(KEY), b"hello") - - def test_decorators(self): - run_count = 0 - cache_codec.start_session() - - @stored(KEY) - def func(): - nonlocal run_count - run_count += 1 - return b"foo" - - # cache is empty - self.assertIsNone(context.cache_get(KEY)) - self.assertEqual(run_count, 0) - self.assertEqual(func(), b"foo") - # function was run - self.assertEqual(run_count, 1) - self.assertEqual(context.cache_get(KEY), b"foo") - # function does not run again but returns cached value - self.assertEqual(func(), b"foo") - self.assertEqual(run_count, 1) - - @stored_async(KEY) - async def async_func(): - nonlocal run_count - run_count += 1 - return b"bar" - - # cache is still full - self.assertEqual(await_result(async_func()), b"foo") - self.assertEqual(run_count, 1) - - cache_codec.start_session() - self.assertEqual(await_result(async_func()), b"bar") - self.assertEqual(run_count, 2) - # awaitable is also run only once - self.assertEqual(await_result(async_func()), b"bar") - self.assertEqual(run_count, 2) - - def test_empty_value(self): - cache_codec.start_session() - - self.assertIsNone(context.cache_get(KEY)) - context.cache_set(KEY, b"") - self.assertEqual(context.cache_get(KEY), b"") - - context.cache_delete(KEY) - run_count = 0 - - @stored(KEY) - def func(): - nonlocal run_count - run_count += 1 - return b"" - - self.assertEqual(func(), b"") - # function gets called once - self.assertEqual(run_count, 1) - self.assertEqual(func(), b"") - # function is not called for a second time - self.assertEqual(run_count, 1) - - @mock_storage - def test_Initialize(self): - def call_Initialize(**kwargs): - msg = Initialize(**kwargs) - return await_result(handle_Initialize(msg)) - - # calling Initialize without an ID allocates a new one - session_id = cache_codec.start_session() - features = call_Initialize() - self.assertNotEqual(session_id, features.session_id) - - # calling Initialize with the current ID does not allocate a new one - features = call_Initialize(session_id=session_id) - self.assertEqual(session_id, features.session_id) - - # store "hello" - context.cache_set(KEY, b"hello") - # check that it is cleared - features = call_Initialize() - session_id = features.session_id - self.assertIsNone(context.cache_get(KEY)) - # store "hello" again - context.cache_set(KEY, b"hello") - self.assertEqual(context.cache_get(KEY), b"hello") - - # supplying a different session ID starts a new cache - call_Initialize(session_id=b"A" * cache_codec.SESSION_ID_LENGTH) - self.assertIsNone(context.cache_get(KEY)) - - # but resuming a session loads the previous one - call_Initialize(session_id=session_id) - self.assertEqual(context.cache_get(KEY), b"hello") - - def test_EndSession(self): - self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY) - cache_codec.start_session() - self.assertTrue(is_session_started()) - self.assertIsNone(context.cache_get(KEY)) - await_result(handle_EndSession(EndSession())) - self.assertFalse(is_session_started()) - self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY) + self.assertTrue(is_session_started()) + self.assertIsNone(get_active_session().get(KEY)) + await_result(handle_EndSession(EndSession())) + self.assertFalse(is_session_started()) + self.assertIsNone(cache_codec.get_active_session()) if __name__ == "__main__": diff --git a/core/tests/test_trezor.wire.codec.codec_v1.py b/core/tests/test_trezor.wire.codec.codec_v1.py index 852f5f5b8b3..9e6d4785909 100644 --- a/core/tests/test_trezor.wire.codec.codec_v1.py +++ b/core/tests/test_trezor.wire.codec.codec_v1.py @@ -3,61 +3,11 @@ import ustruct +from mock_wire_interface import MockHID from trezor import io -from trezor.loop import wait from trezor.utils import chunks from trezor.wire.codec import codec_v1 - -class MockHID: - - TX_PACKET_LEN = 64 - RX_PACKET_LEN = 64 - - def __init__(self, num): - self.num = num - self.data = [] - self.packet = None - - def pad_packet(self, data): - if len(data) > self.RX_PACKET_LEN: - raise Exception("Too long packet") - padding_length = self.RX_PACKET_LEN - len(data) - return data + b"\x00" * padding_length - - def iface_num(self): - return self.num - - def write(self, msg): - self.data.append(bytearray(msg)) - return len(msg) - - def mock_read(self, packet, gen): - self.packet = self.pad_packet(packet) - return gen.send(self.RX_PACKET_LEN) - - def read(self, buffer, offset=0): - if self.packet is None: - raise Exception("No packet to read") - - if offset > len(buffer): - raise Exception("Offset out of bounds") - - buffer_space = len(buffer) - offset - - if len(self.packet) > buffer_space: - raise Exception("Buffer too small") - else: - end = offset + len(self.packet) - buffer[offset:end] = self.packet - read = len(self.packet) - self.packet = None - return read - - def wait_object(self, mode): - return wait(mode | self.num) - - MESSAGE_TYPE = 0x4242 HEADER_PAYLOAD_LENGTH = MockHID.RX_PACKET_LEN - 3 - ustruct.calcsize(">HL") diff --git a/core/tests/test_trezor.wire.thp.checksum.py b/core/tests/test_trezor.wire.thp.checksum.py new file mode 100644 index 00000000000..f5d59d6805b --- /dev/null +++ b/core/tests/test_trezor.wire.thp.checksum.py @@ -0,0 +1,95 @@ +# flake8: noqa: F403,F405 +from common import * # isort:skip + +if utils.USE_THP: + from trezor.wire.thp import checksum + + +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestTrezorHostProtocolChecksum(unittest.TestCase): + vectors_correct = [ + ( + b"", + b"\x00\x00\x00\x00", + ), + ( + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + b"\x19\x0A\x55\xAD", + ), + ( + bytes("a", "ascii"), + b"\xE8\xB7\xBE\x43", + ), + ( + bytes("abc", "ascii"), + b"\x35\x24\x41\xC2", + ), + ( + bytes("123456789", "ascii"), + b"\xCB\xF4\x39\x26", + ), + ( + bytes( + "12345678901234567890123456789012345678901234567890123456789012345678901234567890", + "ascii", + ), + b"\x7C\xA9\x4A\x72", + ), + ( + b"\x76\x61\x72\x69\x6F\x75\x73\x20\x43\x52\x43\x20\x61\x6C\x67\x6F\x72\x69\x74\x68\x6D\x73\x20\x69\x6E\x70\x75\x74\x20\x64\x61\x74\x61", + b"\x9B\xD3\x66\xAE", + ), + ( + b"\x67\x3a\x5f\x0e\x39\xc0\x3c\x79\x58\x22\x74\x76\x64\x9e\x36\xe9\x0b\x04\x8c\xd2\xc0\x4d\x76\x63\x1a\xa2\x17\x85\xe8\x50\xa7\x14\x18\xfb\x86\xed\xa3\x59\x2d\x62\x62\x49\x64\x62\x26\x12\xdb\x95\x3d\xd6\xb5\xca\x4b\x22\x0d\xc5\x78\xb2\x12\x97\x8e\x54\x4e\x06\xb7\x9c\x90\xf5\xa0\x21\xa6\xc7\xd8\x39\xfd\xea\x3a\xf1\x7b\xa2\xe8\x71\x41\xd6\xcb\x1e\x5b\x0e\x29\xf7\x0c\xc7\x57\x8b\x53\x20\x1d\x2b\x41\x1c\x25\xf9\x07\xbb\xb4\x37\x79\x6a\x13\x1f\x6c\x43\x71\xc1\x1e\x70\xe6\x74\xd3\x9c\xbf\x32\x15\xee\xf2\xa7\x86\xbe\x59\x99\xc4\x10\x09\x8a\x6a\xaa\xd4\xd1\xd0\x71\xd2\x06\x1a\xdd\x2a\xa0\x08\xeb\x08\x6c\xfb\xd2\x2d\xfb\xaa\x72\x56\xeb\xd1\x92\x92\xe5\x0e\x95\x67\xf8\x38\xc3\xab\x59\x37\xe6\xfd\x42\xb0\xd0\x31\xd0\xcb\x8a\x66\xce\x2d\x53\x72\x1e\x72\xd3\x84\x25\xb0\xb8\x93\xd2\x61\x5b\x32\xd5\xe7\xe4\x0e\x31\x11\xaf\xdc\xb4\xb8\xee\xa4\x55\x16\x5f\x78\x86\x8b\x50\x4d\xc5\x6d\x6e\xfc\xe1\x6b\x06\x5b\x37\x84\x2a\x67\x95\x28\x00\xa4\xd1\x32\x9f\xbf\xe1\x64\xf8\x17\x47\xe1\xad\x8b\x72\xd2\xd9\x45\x5b\x73\x43\x3c\xe6\x21\xf7\x53\xa3\x73\xf9\x2a\xb0\xe9\x75\x5e\xa6\xbe\x9a\xad\xfc\xed\xb5\x46\x5b\x9f\xa9\x5a\x4f\xcb\xb6\x60\x96\x31\x91\x42\xca\xaf\xee\xa5\x0c\xe0\xab\x3e\x83\xb8\xac\x88\x10\x2c\x63\xd3\xc9\xd2\xf2\x44\xef\xea\x3d\x19\x24\x3c\x5b\xe7\x0c\x52\xfd\xfe\x47\x41\x14\xd5\x4c\x67\x8d\xdb\xe5\xd9\xfa\x67\x9c\x06\x31\x01\x92\xba\x96\xc4\x0d\xef\xf7\xc1\xe9\x23\x28\x0f\xae\x27\x9b\xff\x28\x0b\x3e\x85\x0c\xae\x02\xda\x27\xb6\x04\x51\x04\x43\x04\x99\x8c\xa3\x97\x1d\x84\xec\x55\x59\xfb\xf3\x84\xe5\xf8\x40\xf8\x5f\x81\x65\x92\x4c\x92\x7a\x07\x51\x8d\x6f\xff\x8d\x15\x36\x5c\x57\x7a\x5b\x3a\x63\x1c\x87\x65\xee\x54\xd5\x96\x50\x73\x1a\x9c\xff\x59\xe5\xea\x6f\x89\xd2\xbb\xa9\x6a\x12\x21\xf5\x08\x8e\x8a\xc0\xd8\xf5\x14\xe9\x9d\x7e\x99\x13\x88\x29\xa8\xb4\x22\x2a\x41\x7c\xc5\x10\xdf\x11\x5e\xf8\x8d\x0e\xd9\x98\xd5\xaf\xa8\xf9\x55\x1e\xe3\x29\xcd\x2c\x51\x7b\x8a\x8d\x52\xaa\x8b\x87\xae\x8e\xb2\xfa\x31\x27\x60\x90\xcb\x01\x6f\x7a\x79\x38\x04\x05\x7c\x11\x79\x10\x40\x33\x70\x75\xfd\x0b\x88\xa5\xcd\x35\xd8\xa6\x3b\xb0\x45\x82\x64\xd1\xb5\xdc\x06\xc9\x89\xf4\x16\x3e\xc7\xb3\xf1\x9d\xd3\xc5\xe3\xaf\xe8\x25\x86\x7a\x4a\xfd\x10\x5d\x20\xe5\x76\x5a\x22\x5f\x8f\xbc\xaa\x97\xee\xf2\xc2\x4c\x0e\xdc\x7b\xc4\xee\x53\xa3\xe0\xfa\xcd\x1e\x4e\x54\x1d\x5e\xe1\x51\x17\x1f\x1a\x75\x7f\xed\x12\xd7\xf7\xe3\x18\x56\x24\xcf\xc6\x96\x30\x77\x0d\x73\x98\x9c\x09\x69\xa3\xbc\x96\x5e\xaf\xde\x76\xa4\x66\x04\x6b\x36\x2a\xac\x6d\x37\xf8\x1e\xe1\x2a\x3e\x42\x2d\x1d\xe6\x46\xdd\x28\xb9\x08\x44\xa1\x9e\xb2\x22\x7a\x45\x8a\x37\x39\x74\xb4\xae\xc8\x3b\x40\xf7\xec\xbf\xfd\xe5\xde\xb2\x83\x5e\xa4\x46\x19\xa6\x9d\xb0\xe8\x76\x80\xbd\xc1\x80\x7a\xd9\xeb\xe7\x90\x5b\x81\x25\x21\xd9\x5b\x4a\x80\x48\x92\x71\x77\x04\xb2\xac\x05\xc9\xdf\x5e\x44\x5a\xae\x6e\xb3\xd8\x30\x5e\xdc\x77\x2f\x79\xc2\x8e\x8b\x28\x24\x06\x1b\x6f\x8d\x88\x53\x80\x55\x0c\x3a\x7b\x85\xb8\x96\x85\xe9\xf0\x57\x63\xfe\x32\x80\xff\x57\xc9\x3c\xdb\xf6\xcd\x67\x14\x47\x6c\x43\x3d\x6d\x48\x3f\x9c\x00\x60\x0e\xf5\x94\xe4\x52\x97\x86\xcd\xac\xbc\xe4\xe3\xe7\xee\xa2\x91\x6e\x92\xbb\xd1\x55\x0c\x5c\x0d\x63\xdb\x6b\xb8\x6e\x45\x48\x0f\xdf\x44\x48\xd2\xf5\xf7\x4d\x7b\xd4\x4d\xd3\xcd\xcd\x5b\x40\x60\xb1\xb2\x8e\xc9\x9a\x65\xc5\x06\x24\xcf\xe9\xcc\x5e\x2c\x49\x47\x38\x45\x5d\xc5\xc0\x0d\x8a\x07\x1c\xb3\xbb\xb1\x69\xf5\x6d\x0e\x9c\x96\x14\x93\x58\x0c\xc9\x48\x74\xfc\x35\xda\x7d\x4e\x32\x73\xa3\x77\x4a\x9e\xc5\xd1\x08\xfe\xa6\xa0\xf1\x66\x72\xea\xc7\xae\x21\x81\x0e\x8a\xba\x99\x06\x97\xfc\xc6\x2b\x69\x53\xc6\x67\xec\x5d\xa1\xfc\xa1\x3b\xdd\x2a\xd6\x8f\x31\xa7\x8d\xec\xfe\x0a\x3b\x6b\x39\x70\x70\x09\x72\x12\xbc\x84\x67\xca\xd2\x4a\x17\x33\x94\x45\x25\xc7\xfd\x1e\xa2\x4a\x9e\x27\x9d\xfb\x87\xea\xe4\xfd\xb0\x11\x06\x9d\x72\xb9\x1d\xea\x9b\x81\x2e\x6a\x36\x76\x62\xfa\xbe\x96\x67\x7d\x35\xdd\x5e\x5c\x4f\x41\x0d\xce\xdb\x13\xb0\x46\x89\x92\x45\x02\x39\x0f\xe6\xd1\x20\x96\x1c\x34\x00\x8c\xc9\xdf\xe3\xf0\xb6\x92\x3a\xda\x5c\x96\xd9\x0b\x7d\x57\xf5\x78\x11\xc0\xcf\xbf\xb0\x92\x3d\xe5\x6a\x67\x34\xce\xd9\x16\x08\xa0\x09\x42\x0b\x07\x13\x7c\x73\x0c\xc6\x50\x17\x42\xcf\xd9\x85\xd9\x23\x3c\xb1\x40\x40\x0f\x94\x20\xed\x2d\xbf\x10\x44\x6e\x64\x65\xe5\x1d\x5f\xec\x24\xd8\x4b\xe8\xc2\xfb\x06\x11\x24\x3f\xdf\x54\x2d\xe8\x4d\xc2\x1c\x27\x11\xb8\xb3\xd4", + b"\x6B\xA4\xEC\x92", + ), + ] + vectors_incorrect = [ + ( + b"", + b"\x00\x00\x00\x00\x00", + ), + ( + b"", + b"", + ), + ( + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + b"\x19\x0A\x55\xAE", + ), + ( + bytes("A", "ascii"), + b"\xE8\xB7\xBE\x43", + ), + ( + bytes("abc ", "ascii"), + b"\x35\x24\x41\xC2", + ), + ( + bytes("1234567890", "ascii"), + b"\xCB\xF4\x39\x26", + ), + ( + bytes( + "1234567890123456789012345678901234567890123456789012345678901234567890123456789", + "ascii", + ), + b"\x7C\xA9\x4A\x72", + ), + ] + + def test_computation(self): + for data, chksum in self.vectors_correct: + self.assertEqual(checksum.compute(data), chksum) + + def test_validation_correct(self): + for data, chksum in self.vectors_correct: + self.assertTrue(checksum.is_valid(chksum, data)) + + def test_validation_incorrect(self): + for data, chksum in self.vectors_incorrect: + self.assertFalse(checksum.is_valid(chksum, data)) + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.wire.thp.crypto.py b/core/tests/test_trezor.wire.thp.crypto.py new file mode 100644 index 00000000000..1eee57f8cb9 --- /dev/null +++ b/core/tests/test_trezor.wire.thp.crypto.py @@ -0,0 +1,157 @@ +# flake8: noqa: F403,F405 +from common import * # isort:skip +from trezorcrypto import aesgcm, curve25519 + +import storage + +if utils.USE_THP: + import thp_common + from trezor.wire.thp import crypto + from trezor.wire.thp.crypto import IV_1, IV_2, Handshake + + def get_dummy_device_secret(): + return b"\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08" + + +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestTrezorHostProtocolCrypto(unittest.TestCase): + if utils.USE_THP: + handshake = Handshake() + key_1 = b"\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07" + # 0:key, 1:nonce, 2:auth_data, 3:plaintext, 4:expected_ciphertext, 5:expected_tag + vectors_enc = [ + ( + key_1, + 0, + b"\x55\x64", + b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09", + b"e2c9dd152fbee5821ea7", + b"10625812de81b14a46b9f1e5100a6d0c", + ), + ( + key_1, + 1, + b"\x55\x64", + b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09", + b"79811619ddb07c2b99f8", + b"71c6b872cdc499a7e9a3c7441f053214", + ), + ( + key_1, + 369, + b"\x55\x64", + b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + b"03bd030390f2dfe815a61c2b157a064f", + b"c1200f8a7ae9a6d32cef0fff878d55c2", + ), + ( + key_1, + 369, + b"\x55\x64\x73\x82\x91", + b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + b"03bd030390f2dfe815a61c2b157a064f", + b"693ac160cd93a20f7fc255f049d808d0", + ), + ] + # 0:chaining key, 1:input, 2:output_1, 3:output:2 + vectors_hkdf = [ + ( + crypto.PROTOCOL_NAME, + b"\x01\x02", + b"c784373a217d6be057cddc6068e6748f255fc8beb6f99b7b90cbc64aad947514", + b"12695451e29bf08ffe5e4e6ab734b0c3d7cdd99b16cd409f57bd4eaa874944ba", + ), + ( + b"\xc7\x84\x37\x3a\x21\x7d\x6b\xe0\x57\xcd\xdc\x60\x68\xe6\x74\x8f\x25\x5f\xc8\xbe\xb6\xf9\x9b\x7b\x90\xcb\xc6\x4a\xad\x94\x75\x14", + b"\x31\x41\x59\x26\x52\x12\x34\x56\x78\x89\x04\xaa", + b"f88c1e08d5c3bae8f6e4a3d3324c8cbc60a805603e399e69c4bf4eacb27c2f48", + b"5f0216bdb7110ee05372286974da8c9c8b96e2efa15b4af430755f462bd79a76", + ), + ] + vectors_iv = [ + (0, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"), + (1, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"), + (7, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07"), + (1025, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x01"), + (4294967295, b"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"), + (0xFFFFFFFFFFFFFFFF, b"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"), + ] + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + super().__init__() + + def setUp(self): + utils.DISABLE_ENCRYPTION = False + + def test_encryption(self): + for v in self.vectors_enc: + buffer = bytearray(v[3]) + tag = crypto.enc(buffer, v[0], v[1], v[2]) + self.assertEqual(hexlify(buffer), v[4]) + self.assertEqual(hexlify(tag), v[5]) + self.assertTrue(crypto.dec(buffer, tag, v[0], v[1], v[2])) + self.assertEqual(buffer, v[3]) + + def test_hkdf(self): + for v in self.vectors_hkdf: + ck, k = crypto._hkdf(v[0], v[1]) + self.assertEqual(hexlify(ck), v[2]) + self.assertEqual(hexlify(k), v[3]) + + def test_iv_from_nonce(self): + for v in self.vectors_iv: + # x = v[0] + # y = x.to_bytes(8, "big") + iv = crypto._get_iv_from_nonce(v[0]) + self.assertEqual(iv, v[1]) + with self.assertRaises(AssertionError) as e: + iv = crypto._get_iv_from_nonce(0xFFFFFFFFFFFFFFFF + 1) + self.assertEqual(e.value.value, "Nonce overflow, terminate the channel") + + def test_incorrect_vectors(self): + pass + + def test_th1_crypto(self): + storage.device.get_device_secret = get_dummy_device_secret + handshake = self.handshake + + host_ephemeral_privkey = curve25519.generate_secret() + host_ephemeral_pubkey = curve25519.publickey(host_ephemeral_privkey) + handshake.handle_th1_crypto(b"", host_ephemeral_pubkey) + + def test_th2_crypto(self): + handshake = self.handshake + + host_static_privkey = curve25519.generate_secret() + host_static_pubkey = curve25519.publickey(host_static_privkey) + aes_ctx = aesgcm(handshake.k, IV_2) + aes_ctx.auth(handshake.h) + encrypted_host_static_pubkey = bytearray( + aes_ctx.encrypt(host_static_pubkey) + aes_ctx.finish() + ) + + # Code to encrypt Host's noise encrypted payload correctly: + protomsg = bytearray(b"\x10\x02\x10\x03") + temp_k = handshake.k + temp_h = handshake.h + + temp_h = crypto._hash_of_two(temp_h, encrypted_host_static_pubkey) + _, temp_k = crypto._hkdf( + handshake.ck, + curve25519.multiply(handshake.trezor_ephemeral_privkey, host_static_pubkey), + ) + aes_ctx = aesgcm(temp_k, IV_1) + aes_ctx.encrypt_in_place(protomsg) + aes_ctx.auth(temp_h) + tag = aes_ctx.finish() + encrypted_payload = bytearray(protomsg + tag) + # end of encrypted payload generation + + handshake.handle_th2_crypto(encrypted_host_static_pubkey, encrypted_payload) + self.assertEqual(encrypted_payload[:4], b"\x10\x02\x10\x03") + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.wire.thp.py b/core/tests/test_trezor.wire.thp.py new file mode 100644 index 00000000000..7ba208a421c --- /dev/null +++ b/core/tests/test_trezor.wire.thp.py @@ -0,0 +1,381 @@ +# flake8: noqa: F403,F405 +from common import * # isort:skip +from mock_wire_interface import MockHID +from trezor import config, io, protobuf +from trezor.crypto.curve import curve25519 +from trezor.enums import ThpMessageType +from trezor.wire.errors import UnexpectedMessage +from trezor.wire.protocol_common import Message + +if utils.USE_THP: + from typing import TYPE_CHECKING + + import thp_common + from storage import cache_thp + from storage.cache_common import ( + CHANNEL_HANDSHAKE_HASH, + CHANNEL_KEY_RECEIVE, + CHANNEL_KEY_SEND, + CHANNEL_NONCE_RECEIVE, + CHANNEL_NONCE_SEND, + ) + from trezor.crypto import elligator2 + from trezor.enums import ThpPairingMethod + from trezor.messages import ( + ThpCodeEntryChallenge, + ThpCodeEntryCpaceHostTag, + ThpCredentialRequest, + ThpEndRequest, + ThpPairingRequest, + ) + from trezor.wire.thp import ( + ChannelState, + checksum, + interface_manager, + memory_manager, + thp_main, + ) + from trezor.wire.thp.crypto import Handshake + from trezor.wire.thp.pairing_context import PairingContext + + from apps.thp import pairing + + if TYPE_CHECKING: + from trezor.wire import WireInterface + + def get_dummy_key() -> bytes: + return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x01\x02\x03\x04\x05\x06\x07\x08\x09\x20\x01\x02\x03\x04\x05\x06\x07\x08\x09\x30\x31" + + def dummy_encode_iface(iface: WireInterface): + return thp_common._MOCK_INTERFACE_HID + + def send_channel_allocation_request( + interface: MockHID, nonce: bytes | None = None + ) -> bytes: + if nonce is None or len(nonce) != 8: + nonce = b"\x00\x11\x22\x33\x44\x55\x66\x77" + header = b"\x40\xff\xff\x00\x0c" + chksum = checksum.compute(header + nonce) + cid_req = header + nonce + chksum + gen = thp_main.thp_main_loop(interface) + expected_channel_index = cache_thp._get_next_channel_index() + gen.send(None) + interface.mock_read(cid_req, gen) + gen.send(None) + model = bytes(utils.INTERNAL_MODEL, "big") + response_data = ( + b"\x0a\x04" + model + "\x10\x00\x18\x00\x20\x02\x28\x02\x28\x03\x28\x04" + ) + response_without_crc = ( + b"\x41\xff\xff\x00\x20" + + nonce + + cache_thp._CHANNELS[expected_channel_index].channel_id + + response_data + ) + chkcsum = checksum.compute(response_without_crc) + expected_response = response_without_crc + chkcsum + b"\x00" * 27 + return expected_response + + def get_channel_id_from_response(channel_allocation_response: bytes) -> int: + return int.from_bytes(channel_allocation_response[13:15], "big") + + def get_ack(channel_id: bytes) -> bytes: + if len(channel_id) != 2: + raise Exception("Channel id should by two bytes long") + return ( + b"\x20" + + channel_id + + b"\x00\x04" + + checksum.compute(b"\x20" + channel_id + b"\x00\x04") + + b"\x00" * 55 + ) + + +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestTrezorHostProtocol(unittest.TestCase): + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + interface_manager.encode_iface = dummy_encode_iface + super().__init__() + + def setUp(self): + self.interface = MockHID(0xDEADBEEF) + memory_manager.READ_BUFFER = bytearray(64) + memory_manager.WRITE_BUFFER = bytearray(256) + interface_manager.decode_iface = thp_common.dummy_decode_iface + + def test_codec_message(self): + self.assertEqual(len(self.interface.data), 0) + gen = thp_main.thp_main_loop(self.interface) + gen.send(None) + + # There should be a failiure response to received init packet (starts with "?##") + test_codec_message = b"?## Some data" + self.interface.mock_read(test_codec_message, gen) + gen.send(None) + self.assertEqual(len(self.interface.data), 1) + + expected_response = b"?##\x00\x03\x00\x00\x00\x14\x08\x10" + self.assertEqual( + self.interface.data[-1][: len(expected_response)], expected_response + ) + + # There should be no response for continuation packet (starts with "?" only) + test_codec_message_2 = b"? Cont packet" + self.interface.mock_read(test_codec_message_2, gen) + + # Check that sending None fails on AssertionError + with self.assertRaises(AssertionError): + gen.send(None) + self.assertEqual(len(self.interface.data), 1) + + def test_message_on_unallocated_channel(self): + gen = thp_main.thp_main_loop(self.interface) + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + message_to_channel_789a = ( + b"\x04\x78\x9a\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c" + ) + self.interface.mock_read(message_to_channel_789a, gen) + gen.send(None) + unallocated_chanel_error_on_channel_789a = "42789a0005027b743563000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + self.assertEqual( + utils.get_bytes_as_str(self.interface.data[-1]), + unallocated_chanel_error_on_channel_789a, + ) + + def tbd_channel_allocation(self): + self.assertEqual(len(thp_main._CHANNELS), 0) + for c in cache_thp._CHANNELS: + self.assertEqual(int.from_bytes(c.state, "big"), ChannelState.UNALLOCATED) + + expected_channel_index = cache_thp._get_next_channel_index() + expected_response = send_channel_allocation_request(self.interface) + self.assertEqual(self.interface.data[-1], expected_response) + + cid = cache_thp._CHANNELS[expected_channel_index].channel_id + self.assertTrue(int.from_bytes(cid, "big") in thp_main._CHANNELS) + self.assertEqual(len(thp_main._CHANNELS), 1) + # test channel's default state is TH1: + cid = get_channel_id_from_response(self.interface.data[-1]) + self.assertEqual(thp_main._CHANNELS[cid].get_channel_state(), ChannelState.TH1) + + def tbd_invalid_encrypted_tag(self): + gen = thp_main.thp_main_loop(self.interface) + gen.send(None) + # prepare 2 new channels + expected_response_1 = send_channel_allocation_request(self.interface) + expected_response_2 = send_channel_allocation_request(self.interface) + self.assertEqual(self.interface.data[-2], expected_response_1) + self.assertEqual(self.interface.data[-1], expected_response_2) + + # test invalid encryption tag + config.init() + config.wipe() + cid_1 = get_channel_id_from_response(expected_response_1) + channel = thp_main._CHANNELS[cid_1] + channel.iface = self.interface + channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + header = b"\x04" + channel.channel_id + b"\x00\x14" + + tag = b"\x00" * 16 + chksum = checksum.compute(header + tag) + message_with_invalid_tag = header + tag + chksum + + channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key()) + channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0) + + cid_1_bytes = int.to_bytes(cid_1, 2, "big") + expected_ack_on_received_message = get_ack(cid_1_bytes) + + self.interface.mock_read(message_with_invalid_tag, gen) + gen.send(None) + + self.assertEqual( + self.interface.data[-1], + expected_ack_on_received_message, + ) + error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03" + chksum_err = checksum.compute(error_without_crc) + gen.send(None) + + decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54 + + self.assertEqual( + self.interface.data[-1], + decryption_failed_error, + ) + + def tbd_test_channel_errors(self): + gen = thp_main.thp_main_loop(self.interface) + gen.send(None) + # prepare 2 new channels + expected_response_1 = send_channel_allocation_request(self.interface) + expected_response_2 = send_channel_allocation_request(self.interface) + self.assertEqual(self.interface.data[-2], expected_response_1) + self.assertEqual(self.interface.data[-1], expected_response_2) + + # test invalid encryption tag + config.init() + config.wipe() + cid_1 = get_channel_id_from_response(expected_response_1) + channel = thp_main._CHANNELS[cid_1] + channel.iface = self.interface + channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + header = b"\x04" + channel.channel_id + b"\x00\x14" + + tag = b"\x00" * 16 + chksum = checksum.compute(header + tag) + message_with_invalid_tag = header + tag + chksum + + channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key()) + channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0) + + cid_1_bytes = int.to_bytes(cid_1, 2, "big") + # expected_ack_on_received_message = get_ack(cid_1_bytes) + + self.interface.mock_read(message_with_invalid_tag, gen) + # gen.send(None) + + # self.assertEqual( + # self.interface.data[-1], + # expected_ack_on_received_message, + # ) + error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03" + chksum_err = checksum.compute(error_without_crc) + # gen.send(None) + + decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54 + + self.assertEqual( + self.interface.data[-1], + decryption_failed_error, + ) + + # test invalid tag in handshake phase + cid_2 = get_channel_id_from_response(expected_response_1) + # cid_2_bytes = cid_2.to_bytes(2, "big") + channel = thp_main._CHANNELS[cid_2] + channel.iface = self.interface + + channel.set_channel_state(ChannelState.TH2) + + message_with_invalid_tag = b"\x0a\x12\x36\x00\x14\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x91\x65\x4c\xf9" + + channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key()) + channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0) + + # gen.send(message_with_invalid_tag) + # gen.send(None) + # gen.send(None) + # for i in self.interface.data: + # print(utils.get_bytes_as_str(i)) + + def tbd_skip_pairing(self): + config.init() + config.wipe() + channel = next(iter(thp_main._CHANNELS.values())) + channel.selected_pairing_methods = [ + ThpPairingMethod.SkipPairing, + ThpPairingMethod.CodeEntry, + ThpPairingMethod.NFC_Unidirectional, + ThpPairingMethod.QrCode, + ] + pairing_ctx = PairingContext(channel) + request_message = ThpPairingRequest() + channel.set_channel_state(ChannelState.TP1) + gen = pairing.handle_pairing_request(pairing_ctx, request_message) + + with self.assertRaises(StopIteration): + gen.send(None) + self.assertEqual(channel.get_channel_state(), ChannelState.ENCRYPTED_TRANSPORT) + + # Teardown: set back initial channel state value + channel.set_channel_state(ChannelState.TH1) + + def TODO_test_pairing(self): + config.init() + config.wipe() + cid = get_channel_id_from_response( + send_channel_allocation_request(self.interface) + ) + channel = thp_main._CHANNELS[cid] + channel.selected_pairing_methods = [ + ThpPairingMethod.CodeEntry, + ThpPairingMethod.NFC, + ThpPairingMethod.QrCode, + ] + pairing_ctx = PairingContext(channel) + request_message = ThpPairingRequest() + with self.assertRaises(UnexpectedMessage) as e: + pairing.handle_pairing_request(pairing_ctx, request_message) + print(e.value.message) + channel.set_channel_state(ChannelState.TP1) + gen = pairing.handle_pairing_request(pairing_ctx, request_message) + + channel.channel_cache.set(CHANNEL_KEY_SEND, get_dummy_key()) + channel.channel_cache.set_int(CHANNEL_NONCE_SEND, 0) + channel.channel_cache.set(CHANNEL_HANDSHAKE_HASH, b"") + + gen.send(None) + + async def _dummy(ctx: PairingContext, expected_types): + return await ctx.read([1018, 1024]) + + # pairing.show_display_data = _dummy + + msg_code_entry = ThpCodeEntryChallenge(challenge=b"\x12\x34") + buffer: bytearray = bytearray(protobuf.encoded_length(msg_code_entry)) + protobuf.encode(buffer, msg_code_entry) + code_entry_challenge = Message(ThpMessageType.ThpCodeEntryChallenge, buffer) + self.interface.mock_read(code_entry_challenge, gen) + + # tag_qrc = b"\x55\xdf\x6c\xba\x0b\xe9\x5e\xd1\x4b\x78\x61\xec\xfa\x07\x9b\x5d\x37\x60\xd8\x79\x9c\xd7\x89\xb4\x22\xc1\x6f\x39\xde\x8f\x3b\xc3" + # tag_nfc = b"\x8f\xf0\xfa\x37\x0a\x5b\xdb\x29\x32\x21\xd8\x2f\x95\xdd\xb6\xb8\xee\xfd\x28\x6f\x56\x9f\xa9\x0b\x64\x8c\xfc\x62\x46\x5a\xdd\xd0" + + pregenerator_host = b"\xf6\x94\xc3\x6f\xb3\xbd\xfb\xba\x2f\xfd\x0c\xd0\x71\xed\x54\x76\x73\x64\x37\xfa\x25\x85\x12\x8d\xcf\xb5\x6c\x02\xaf\x9d\xe8\xbe" + generator_host = elligator2.map_to_curve25519(pregenerator_host) + cpace_host_private_key = b"\x02\x80\x70\x3c\x06\x45\x19\x75\x87\x0c\x82\xe1\x64\x11\xc0\x18\x13\xb2\x29\x04\xb3\xf0\xe4\x1e\x6b\xfd\x77\x63\x11\x73\x07\xa9" + cpace_host_public_key: bytes = curve25519.multiply( + cpace_host_private_key, generator_host + ) + msg = ThpCodeEntryCpaceHost(cpace_host_public_key=cpace_host_public_key) + + # msg = ThpQrCodeTag(tag=tag_qrc) + # msg = ThpNfcTagHost(tag=tag_nfc) + buffer: bytearray = bytearray(protobuf.encoded_length(msg)) + + protobuf.encode(buffer, msg) + user_message = Message(ThpMessageType.ThpCodeEntryCpaceHost, buffer) + self.interface.mock_read(user_message, gen) + + tag_ent = b"\xd0\x15\xd6\x72\x7c\xa6\x9b\x2a\x07\xfa\x30\xee\x03\xf0\x2d\x04\xdc\x96\x06\x77\x0c\xbd\xb4\xaa\x77\xc7\x68\x6f\xae\xa9\xdd\x81" + msg = ThpCodeEntryTag(tag=tag_ent) + + buffer: bytearray = bytearray(protobuf.encoded_length(msg)) + + protobuf.encode(buffer, msg) + user_message = Message(ThpMessageType.ThpCodeEntryTag, buffer) + self.interface.mock_read(user_message, gen) + + host_static_pubkey = b"\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77" + msg = ThpCredentialRequest(host_static_pubkey=host_static_pubkey) + buffer: bytearray = bytearray(protobuf.encoded_length(msg)) + protobuf.encode(buffer, msg) + credential_request = Message(ThpMessageType.ThpCredentialRequest, buffer) + self.interface.mock_read(credential_request, gen) + + msg = ThpEndRequest() + + buffer: bytearray = bytearray(protobuf.encoded_length(msg)) + protobuf.encode(buffer, msg) + end_request = Message(1012, buffer) + with self.assertRaises(StopIteration) as e: + self.interface.mock_read(end_request, gen) + print("response message:", e.value.value.MESSAGE_NAME) + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.wire.thp.writer.py b/core/tests/test_trezor.wire.thp.writer.py new file mode 100644 index 00000000000..0f6110761ac --- /dev/null +++ b/core/tests/test_trezor.wire.thp.writer.py @@ -0,0 +1,150 @@ +# flake8: noqa: F403,F405 +from common import * # isort:skip + +from typing import Any, Awaitable + +if utils.USE_THP: + import thp_common + from mock_wire_interface import MockHID + from trezor.wire.thp import ENCRYPTED, PacketHeader, writer + + +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestTrezorHostProtocolWriter(unittest.TestCase): + short_payload_expected = b"04123400050700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + longer_payload_expected = [ + b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a", + b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677", + b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4", + b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1", + b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + ] + eight_longer_payloads_expected = [ + b"0412340800000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a", + b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677", + b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4", + b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1", + b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e", + b"8012342f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b", + b"8012346c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8", + b"801234a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5", + b"801234e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122", + b"801234232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f", + b"801234606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c", + b"8012349d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9", + b"801234dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f10111213141516", + b"8012341718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f50515253", + b"8012345455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f90", + b"8012349192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccd", + b"801234cecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a", + b"8012340b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f4041424344454647", + b"80123448494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f8081828384", + b"80123485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1", + b"801234c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfe", + b"801234ff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b", + b"8012343c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778", + b"801234797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5", + b"801234b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2", + b"801234f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f", + b"801234303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c", + b"8012346d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9", + b"801234aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6", + b"801234e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20212223", + b"8012342425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f60", + b"8012346162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d", + b"8012349e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9da", + b"801234dbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000000000000000000000000000000000000000000000000", + ] + empty_payload_with_checksum_expected = b"0412340004edbd479c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + longer_payload_with_checksum_expected = [ + b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a", + b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677", + b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4", + b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1", + b"801234f2f3f4f5f6f7f8f9fafbfcfdfefff40c65ee00000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + ] + + def await_until_result(self, task: Awaitable) -> Any: + with self.assertRaises(StopIteration): + while True: + task.send(None) + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + super().__init__() + + def setUp(self): + self.interface = MockHID(0xDEADBEEF) + + def test_write_empty_packet(self): + self.await_until_result(writer.write_packet_to_wire(self.interface, b"")) + + print(self.interface.data[0]) + self.assertEqual(len(self.interface.data), 1) + self.assertEqual(self.interface.data[0], b"") + + def test_write_empty_payload(self): + header = PacketHeader(ENCRYPTED, 4660, 4) + await_result(writer.write_payloads_to_wire(self.interface, header, (b"",))) + self.assertEqual(len(self.interface.data), 0) + + def test_write_short_payload(self): + header = PacketHeader(ENCRYPTED, 4660, 5) + data = b"\x07" + self.await_until_result( + writer.write_payloads_to_wire(self.interface, header, (data,)) + ) + self.assertEqual(hexlify(self.interface.data[0]), self.short_payload_expected) + + def test_write_longer_payload(self): + data = bytearray(range(256)) + header = PacketHeader(ENCRYPTED, 4660, 256) + self.await_until_result( + writer.write_payloads_to_wire(self.interface, header, (data,)) + ) + + for i in range(len(self.longer_payload_expected)): + self.assertEqual( + hexlify(self.interface.data[i]), self.longer_payload_expected[i] + ) + + def test_write_eight_longer_payloads(self): + data = bytearray(range(256)) + header = PacketHeader(ENCRYPTED, 4660, 2048) + self.await_until_result( + writer.write_payloads_to_wire( + self.interface, header, (data, data, data, data, data, data, data, data) + ) + ) + for i in range(len(self.eight_longer_payloads_expected)): + self.assertEqual( + hexlify(self.interface.data[i]), self.eight_longer_payloads_expected[i] + ) + + def test_write_empty_payload_with_checksum(self): + header = PacketHeader(ENCRYPTED, 4660, 4) + self.await_until_result( + writer.write_payload_to_wire_and_add_checksum(self.interface, header, b"") + ) + + self.assertEqual( + hexlify(self.interface.data[0]), self.empty_payload_with_checksum_expected + ) + + def test_write_longer_payload_with_checksum(self): + data = bytearray(range(256)) + header = PacketHeader(ENCRYPTED, 4660, 256) + self.await_until_result( + writer.write_payload_to_wire_and_add_checksum(self.interface, header, data) + ) + + for i in range(len(self.longer_payload_with_checksum_expected)): + self.assertEqual( + hexlify(self.interface.data[i]), + self.longer_payload_with_checksum_expected[i], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.wire.thp_deprecated.py b/core/tests/test_trezor.wire.thp_deprecated.py new file mode 100644 index 00000000000..fce54f0f09b --- /dev/null +++ b/core/tests/test_trezor.wire.thp_deprecated.py @@ -0,0 +1,337 @@ +# flake8: noqa: F403,F405 +from common import * # isort:skip +import ustruct +from typing import TYPE_CHECKING + +from mock_wire_interface import MockHID +from storage.cache_thp import BROADCAST_CHANNEL_ID +from trezor import io +from trezor.utils import chunks +from trezor.wire.protocol_common import Message + +if utils.USE_THP: + import thp_common + import trezor.wire.thp + from trezor.wire.thp import alternating_bit_protocol as ABP + from trezor.wire.thp import checksum, thp_main + from trezor.wire.thp.checksum import CHECKSUM_LENGTH + +if TYPE_CHECKING: + from trezorio import WireInterface + + +MESSAGE_TYPE = 0x4242 +MESSAGE_TYPE_BYTES = b"\x42\x42" +_MESSAGE_TYPE_LEN = 2 +PLAINTEXT_0 = 0x01 +PLAINTEXT_1 = 0x11 +COMMON_CID = 4660 +CONT = 0x80 + +HEADER_INIT_LENGTH = 5 +HEADER_CONT_LENGTH = 3 +if utils.USE_THP: + INIT_MESSAGE_DATA_LENGTH = HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN # + PACKET_LENGTH + + +def make_header(ctrl_byte, cid, length): + return ustruct.pack(">BHH", ctrl_byte, cid, length) + + +def make_cont_header(): + return ustruct.pack(">BH", CONT, COMMON_CID) + + +def makeSimpleMessage(header, message_type, message_data): + return header + ustruct.pack(">H", message_type) + message_data + + +def makeCidRequest(header, message_data): + return header + message_data + + +def getPlaintext() -> bytes: + if ABP.get_expected_receive_seq_bit(THP.get_active_session()) == 1: + return PLAINTEXT_1 + return PLAINTEXT_0 + + +async def deprecated_read_message( + iface: WireInterface, buffer: utils.BufferType +) -> Message: + return Message(-1, b"\x00") + + +async def deprecated_write_message( + iface: WireInterface, message: Message, is_retransmission: bool = False +) -> None: + pass + + +# This test suite is an adaptation of test_trezor.wire.codec_v1 +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestWireTrezorHostProtocolV1(unittest.TestCase): + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + super().__init__() + + def setUp(self): + self.interface = MockHID(0xDEADBEEF) + + def _simple(self): + cid_req_header = make_header( + ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12 + ) + cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c" + cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data) + + message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18) + cid_request_dummy_data_checksum = b"\x67\x8e\xac\xe0" + message = makeSimpleMessage( + message_header, + MESSAGE_TYPE, + cid_request_dummy_data + cid_request_dummy_data_checksum, + ) + + buffer = bytearray(64) + gen = deprecated_read_message(self.interface, buffer) + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + gen.send(cid_req_message) + gen.send(None) + gen.send(message) + with self.assertRaises(StopIteration) as e: + gen.send(None) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, cid_request_dummy_data) + + buffer_without_zeroes = buffer[: len(message) - 5] + message_without_header = message[5:] + # message should have been read into the buffer + self.assertEqual(buffer_without_zeroes, message_without_header) + + def _read_one_packet(self): + # zero length message - just a header + PLAINTEXT = getPlaintext() + header = make_header( + PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + CHECKSUM_LENGTH + ) + chksum = checksum.compute(header + MESSAGE_TYPE_BYTES) + message = header + MESSAGE_TYPE_BYTES + chksum + + buffer = bytearray(64) + gen = deprecated_read_message(self.interface, buffer) + + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + gen.send(message) + with self.assertRaises(StopIteration) as e: + gen.send(None) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, b"") + + # message should have been read into the buffer + self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58) + + def _read_many_packets(self): + message = bytes(range(256)) + header = make_header( + getPlaintext(), + COMMON_CID, + len(message) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH, + ) + chksum = checksum.compute(header + MESSAGE_TYPE_BYTES + message) + # message = MESSAGE_TYPE_BYTES + message + checksum + + # first packet is init header + 59 bytes of data + # other packets are cont header + 61 bytes of data + cont_header = make_cont_header() + packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [ + cont_header + chunk + for chunk in chunks( + message[INIT_MESSAGE_DATA_LENGTH:] + chksum, + 64 - HEADER_CONT_LENGTH, + ) + ] + buffer = bytearray(262) + gen = deprecated_read_message(self.interface, buffer) + query = gen.send(None) + for packet in packets: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + query = gen.send(packet) + + # last packet will stop + with self.assertRaises(StopIteration) as e: + gen.send(None) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, message) + + # message should have been read into the buffer ) + self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum) + + def _read_large_message(self): + message = b"hello world" + header = make_header( + getPlaintext(), + COMMON_CID, + _MESSAGE_TYPE_LEN + len(message) + CHECKSUM_LENGTH, + ) + + packet = ( + header + + MESSAGE_TYPE_BYTES + + message + + checksum.compute(header + MESSAGE_TYPE_BYTES + message) + ) + + # make sure we fit into one packet, to make this easier + # self.assertTrue(len(packet) <= thp_main.PACKET_LENGTH) + + buffer = bytearray(1) + self.assertTrue(len(buffer) <= len(packet)) + + gen = deprecated_read_message(self.interface, buffer) + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + gen.send(packet) + with self.assertRaises(StopIteration) as e: + gen.send(None) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, message) + + # read should have allocated its own buffer and not touch ours + self.assertEqual(buffer, b"\x00") + + def _roundtrip(self): + message_payload = bytes(range(256)) + message = Message( + MESSAGE_TYPE, message_payload, 1 + ) # TODO use different session id + gen = deprecated_write_message(self.interface, message) + # exhaust the iterator: + # (XXX we can only do this because the iterator is only accepting None and returns None) + for query in gen: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) + + buffer = bytearray(1024) + gen = deprecated_read_message(self.interface, buffer) + query = gen.send(None) + for packet in self.interface.data: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + print(utils.get_bytes_as_str(packet)) + query = gen.send(packet) + + with self.assertRaises(StopIteration) as e: + gen.send(None) + + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, message.data) + + def _write_one_packet(self): + message = Message(MESSAGE_TYPE, b"") + gen = deprecated_write_message(self.interface, message) + + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) + with self.assertRaises(StopIteration): + gen.send(None) + + header = make_header( + getPlaintext(), COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH + ) + expected_message = ( + header + + MESSAGE_TYPE_BYTES + + checksum.compute(header + MESSAGE_TYPE_BYTES) + + b"\x00" * (INIT_MESSAGE_DATA_LENGTH - CHECKSUM_LENGTH) + ) + self.assertTrue(self.interface.data == [expected_message]) + + def _write_multiple_packets(self): + message_payload = bytes(range(256)) + message = Message(MESSAGE_TYPE, message_payload) + gen = deprecated_write_message(self.interface, message) + + header = make_header( + PLAINTEXT_1, + COMMON_CID, + len(message.data) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH, + ) + cont_header = make_cont_header() + chksum = checksum.compute( + header + message.type.to_bytes(2, "big") + message.data + ) + packets = [ + header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH] + ] + [ + cont_header + chunk + for chunk in chunks( + message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum, + HEADER_CONT_LENGTH, # + PACKET_LENGTH + ) + ] + + for _ in packets: + # we receive as many queries as there are packets + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) + + # the first sent None only started the generator. the len(packets)-th None + # will finish writing and raise StopIteration + with self.assertRaises(StopIteration): + gen.send(None) + + # packets must be identical up to the last one + self.assertListEqual(packets[:-1], self.interface.data[:-1]) + # last packet must be identical up to message length. remaining bytes in + # the 64-byte packets are garbage -- in particular, it's the bytes of the + # previous packet + last_packet = packets[-1] + packets[-2][len(packets[-1]) :] + self.assertEqual(last_packet, self.interface.data[-1]) + + def _read_huge_packet(self): + PACKET_COUNT = 1180 + # message that takes up 1 180 USB packets + message_size = (PACKET_COUNT - 1) * ( + HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN # + PACKET_LENGTH + ) + INIT_MESSAGE_DATA_LENGTH + + # ensure that a message this big won't fit into memory + # Note: this control is changed, because THP has only 2 byte length field + self.assertTrue(message_size > thp_main.MAX_PAYLOAD_LEN) + # self.assertRaises(MemoryError, bytearray, message_size) + header = make_header(PLAINTEXT_1, COMMON_CID, message_size) + packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH) + buffer = bytearray(65536) + gen = deprecated_read_message(self.interface, buffer) + + query = gen.send(None) + + # THP returns "Message too large" error after reading the message size, + # it is different from codec_v1 as it does not allow big enough messages + # to raise MemoryError in this test + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + with self.assertRaises(trezor.wire.thp.ThpError) as e: + query = gen.send(packet) + + self.assertEqual(e.value.args[0], "Message too large") + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/thp_common.py b/core/tests/thp_common.py new file mode 100644 index 00000000000..ba6f0acdc30 --- /dev/null +++ b/core/tests/thp_common.py @@ -0,0 +1,47 @@ +# flake8: noqa: F403,F405 +from trezor import utils +from trezor.wire.thp import ChannelState + +if utils.USE_THP: + import unittest + from typing import TYPE_CHECKING + + from mock_wire_interface import MockHID + from storage import cache_thp + from trezor.wire import context + from trezor.wire.thp import interface_manager + from trezor.wire.thp.channel import Channel + from trezor.wire.thp.session_context import SessionContext + + _MOCK_INTERFACE_HID = b"\x00" + + if TYPE_CHECKING: + from trezor.wire import WireInterface + + def dummy_decode_iface(cached_iface: bytes): + return MockHID(0xDEADBEEF) + + def get_new_channel(channel_iface: WireInterface | None = None) -> Channel: + interface_manager.decode_iface = dummy_decode_iface + channel_cache = cache_thp.get_new_channel(_MOCK_INTERFACE_HID) + channel = Channel(channel_cache) + channel.set_channel_state(ChannelState.TH1) + if channel_iface is not None: + channel.iface = channel_iface + return channel + + def prepare_context() -> None: + channel = get_new_channel() + session_cache = cache_thp.create_or_replace_session( + channel.channel_cache, session_id=b"\x01" + ) + session_ctx = SessionContext(channel, session_cache) + context.CURRENT_CONTEXT = session_ctx + + +if __debug__: + # Disable log.debug + def suppres_debug_log() -> None: + from trezor import log + + log.debug = lambda name, msg, *args: None diff --git a/core/tools/codegen/get_trezor_keys.py b/core/tools/codegen/get_trezor_keys.py index 31c40fef1fe..b511abd807d 100755 --- a/core/tools/codegen/get_trezor_keys.py +++ b/core/tools/codegen/get_trezor_keys.py @@ -2,7 +2,7 @@ import binascii from trezorlib.client import TrezorClient -from trezorlib.transport_hid import HidTransport +from trezorlib.transport.hid import HidTransport devices = HidTransport.enumerate() if len(devices) > 0: diff --git a/legacy/firmware/fsm.c b/legacy/firmware/fsm.c index eefed373dba..25a67d1b8b3 100644 --- a/legacy/firmware/fsm.c +++ b/legacy/firmware/fsm.c @@ -191,6 +191,18 @@ void fsm_sendFailure(FailureType code, const char *text) case FailureType_Failure_InvalidSession: text = _("Invalid session"); break; + case FailureType_Failure_ThpUnallocatedSession: + text = _("Unallocated session"); + break; + case FailureType_Failure_InvalidProtocol: + text = _("Invalid protocol"); + break; + case FailureType_Failure_BufferError: + text = _("Buffer error"); + break; + case FailureType_Failure_DeviceIsBusy: + text = _("Device is busy"); + break; case FailureType_Failure_FirmwareError: text = _("Firmware error"); break; diff --git a/legacy/firmware/protob/Makefile b/legacy/firmware/protob/Makefile index 48dbd5ddc7d..2ea4e10ead8 100644 --- a/legacy/firmware/protob/Makefile +++ b/legacy/firmware/protob/Makefile @@ -3,14 +3,14 @@ Q := @ endif SKIPPED_MESSAGES := Binance Cardano DebugMonero Eos Monero Ontology Ripple SdProtect Tezos WebAuthn \ - DebugLinkRecordScreen DebugLinkEraseSdCard DebugLinkWatchLayout \ - DebugLinkLayout DebugLinkResetDebugEvents GetNonce \ + DebugLinkRecordScreen DebugLinkEraseSdCard DebugLinkWatchLayout DebugLinkLayout \ + DebugLinkResetDebugEvents DebugLinkGetPairingInfo DebugLinkPairingInfo GetNonce \ TxAckInput TxAckOutput TxAckPrev TxAckPaymentRequest \ EthereumSignTypedData EthereumTypedDataStructRequest EthereumTypedDataStructAck \ EthereumTypedDataValueRequest EthereumTypedDataValueAck ShowDeviceTutorial \ UnlockBootloader AuthenticateDevice AuthenticityProof \ Solana StellarClaimClaimableBalanceOp \ - ChangeLanguage TranslationDataRequest TranslationDataAck \ + ChangeLanguage TranslationDataRequest TranslationDataAck Thp\ SetBrightness DebugLinkOptigaSetSecMax EntropyCheckReady EntropyCheckContinue \ BenchmarkListNames BenchmarkRun BenchmarkNames BenchmarkResult diff --git a/legacy/firmware/protob/messages-thp.proto b/legacy/firmware/protob/messages-thp.proto new file mode 120000 index 00000000000..4799efe83ae --- /dev/null +++ b/legacy/firmware/protob/messages-thp.proto @@ -0,0 +1 @@ +../../vendor/trezor-common/protob/messages-thp.proto \ No newline at end of file diff --git a/python/requirements.txt b/python/requirements.txt index 161faad77ea..f4a121a1568 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -8,3 +8,5 @@ libusb1>=1.6.4 construct>=2.9,!=2.10.55 typing_extensions>=4.7.1 construct-classes>=0.1.2 +cryptography >=43.0.3 +platformdirs >=2 diff --git a/python/src/trezorlib/_internal/emulator.py b/python/src/trezorlib/_internal/emulator.py index 4f6d56f8ed1..2eb0e332bb0 100644 --- a/python/src/trezorlib/_internal/emulator.py +++ b/python/src/trezorlib/_internal/emulator.py @@ -93,6 +93,8 @@ def client(self) -> TrezorClientDebugLink: """ if self._client is None: raise RuntimeError + if self._client.is_invalidated: + self._client = self._client.get_new_client() return self._client def make_args(self) -> List[str]: @@ -112,7 +114,7 @@ def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None: start = time.monotonic() try: while True: - if transport._ping(): + if transport.ping(): break if self.process.poll() is not None: raise RuntimeError("Emulator process died") diff --git a/python/src/trezorlib/authentication.py b/python/src/trezorlib/authentication.py index 39e26f569fc..2e4a530af5b 100644 --- a/python/src/trezorlib/authentication.py +++ b/python/src/trezorlib/authentication.py @@ -7,7 +7,7 @@ from importlib import metadata from . import device -from .client import TrezorClient +from .transport.session import Session try: cryptography_version = metadata.version("cryptography") @@ -361,7 +361,7 @@ def verify_authentication_response( def authenticate_device( - client: TrezorClient, + session: Session, challenge: bytes | None = None, *, whitelist: t.Collection[bytes] | None = None, @@ -371,7 +371,7 @@ def authenticate_device( if challenge is None: challenge = secrets.token_bytes(16) - resp = device.authenticate(client, challenge) + resp = device.authenticate(session, challenge) return verify_authentication_response( challenge, diff --git a/python/src/trezorlib/benchmark.py b/python/src/trezorlib/benchmark.py index 6587e2a3abe..64218b7aad8 100644 --- a/python/src/trezorlib/benchmark.py +++ b/python/src/trezorlib/benchmark.py @@ -19,16 +19,16 @@ from . import messages if TYPE_CHECKING: - from .client import TrezorClient + from .transport.session import Session def list_names( - client: "TrezorClient", + session: "Session", ) -> messages.BenchmarkNames: - return client.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames) + return session.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames) -def run(client: "TrezorClient", name: str) -> messages.BenchmarkResult: - return client.call( +def run(session: "Session", name: str) -> messages.BenchmarkResult: + return session.call( messages.BenchmarkRun(name=name), expect=messages.BenchmarkResult ) diff --git a/python/src/trezorlib/binance.py b/python/src/trezorlib/binance.py index 938092a2dfc..6b35db0446e 100644 --- a/python/src/trezorlib/binance.py +++ b/python/src/trezorlib/binance.py @@ -18,20 +18,19 @@ from . import messages from .protobuf import dict_to_proto -from .tools import session if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.BinanceGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ), @@ -40,17 +39,16 @@ def get_address( def get_public_key( - client: "TrezorClient", address_n: "Address", show_display: bool = False + session: "Session", address_n: "Address", show_display: bool = False ) -> bytes: - return client.call( + return session.call( messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display), expect=messages.BinancePublicKey, ).public_key -@session def sign_tx( - client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False + session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False ) -> messages.BinanceSignedTx: msg = tx_json["msgs"][0] tx_msg = tx_json.copy() @@ -59,7 +57,7 @@ def sign_tx( tx_msg["chunkify"] = chunkify envelope = dict_to_proto(messages.BinanceSignTx, tx_msg) - client.call(envelope, expect=messages.BinanceTxRequest) + session.call(envelope, expect=messages.BinanceTxRequest) if "refid" in msg: msg = dict_to_proto(messages.BinanceCancelMsg, msg) @@ -70,4 +68,4 @@ def sign_tx( else: raise ValueError("can not determine msg type") - return client.call(msg, expect=messages.BinanceSignedTx) + return session.call(msg, expect=messages.BinanceSignedTx) diff --git a/python/src/trezorlib/btc.py b/python/src/trezorlib/btc.py index 078f486d9e6..e3980055fcd 100644 --- a/python/src/trezorlib/btc.py +++ b/python/src/trezorlib/btc.py @@ -25,11 +25,11 @@ from typing_extensions import Protocol, TypedDict from . import exceptions, messages -from .tools import _return_success, prepare_message_bytes, session +from .tools import _return_success, prepare_message_bytes if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session class ScriptSig(TypedDict): asm: str @@ -105,7 +105,7 @@ def make_bin_output(vout: "Vout") -> messages.TxOutputBinType: def get_public_node( - client: "TrezorClient", + session: "Session", n: "Address", ecdsa_curve_name: Optional[str] = None, show_display: bool = False, @@ -116,12 +116,12 @@ def get_public_node( unlock_path_mac: Optional[bytes] = None, ) -> messages.PublicKey: if unlock_path: - client.call( + session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac), expect=messages.UnlockedPathRequest, ) - return client.call( + return session.call( messages.GetPublicKey( address_n=n, ecdsa_curve_name=ecdsa_curve_name, @@ -139,7 +139,7 @@ def get_address(*args: Any, **kwargs: Any) -> str: def get_authenticated_address( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", show_display: bool = False, @@ -151,12 +151,12 @@ def get_authenticated_address( chunkify: bool = False, ) -> messages.Address: if unlock_path: - client.call( + session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac), expect=messages.UnlockedPathRequest, ) - return client.call( + return session.call( messages.GetAddress( address_n=n, coin_name=coin_name, @@ -171,13 +171,13 @@ def get_authenticated_address( def get_ownership_id( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", multisig: Optional[messages.MultisigRedeemScriptType] = None, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, ) -> bytes: - return client.call( + return session.call( messages.GetOwnershipId( address_n=n, coin_name=coin_name, @@ -188,8 +188,9 @@ def get_ownership_id( ).ownership_id +# TODO this is used by tests only def get_ownership_proof( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", multisig: Optional[messages.MultisigRedeemScriptType] = None, @@ -200,9 +201,9 @@ def get_ownership_proof( preauthorized: bool = False, ) -> Tuple[bytes, bytes]: if preauthorized: - client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest) + session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest) - res = client.call( + res = session.call( messages.GetOwnershipProof( address_n=n, coin_name=coin_name, @@ -219,7 +220,7 @@ def get_ownership_proof( def sign_message( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", message: AnyStr, @@ -227,7 +228,7 @@ def sign_message( no_script_type: bool = False, chunkify: bool = False, ) -> messages.MessageSignature: - return client.call( + return session.call( messages.SignMessage( coin_name=coin_name, address_n=n, @@ -241,7 +242,7 @@ def sign_message( def verify_message( - client: "TrezorClient", + session: "Session", coin_name: str, address: str, signature: bytes, @@ -249,7 +250,7 @@ def verify_message( chunkify: bool = False, ) -> bool: try: - client.call( + session.call( messages.VerifyMessage( address=address, signature=signature, @@ -264,9 +265,9 @@ def verify_message( return False -@session +# @session def sign_tx( - client: "TrezorClient", + session: "Session", coin_name: str, inputs: Sequence[messages.TxInputType], outputs: Sequence[messages.TxOutputType], @@ -314,14 +315,14 @@ def sign_tx( setattr(signtx, name, value) if unlock_path: - client.call( + session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac), expect=messages.UnlockedPathRequest, ) elif preauthorized: - client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest) + session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest) - res = client.call(signtx, expect=messages.TxRequest) + res = session.call(signtx, expect=messages.TxRequest) # Prepare structure for signatures signatures: List[Optional[bytes]] = [None] * len(inputs) @@ -380,7 +381,7 @@ def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType: if res.request_type == R.TXPAYMENTREQ: assert res.details.request_index is not None msg = payment_reqs[res.details.request_index] - res = client.call(msg, expect=messages.TxRequest) + res = session.call(msg, expect=messages.TxRequest) else: msg = messages.TransactionType() if res.request_type == R.TXMETA: @@ -410,7 +411,7 @@ def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType: f"Unknown request type - {res.request_type}." ) - res = client.call(messages.TxAck(tx=msg), expect=messages.TxRequest) + res = session.call(messages.TxAck(tx=msg), expect=messages.TxRequest) for i, sig in zip(inputs, signatures): if i.script_type != messages.InputScriptType.EXTERNAL and sig is None: @@ -420,7 +421,7 @@ def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType: def authorize_coinjoin( - client: "TrezorClient", + session: "Session", coordinator: str, max_rounds: int, max_coordinator_fee_rate: int, @@ -429,7 +430,7 @@ def authorize_coinjoin( coin_name: str, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, ) -> str | None: - resp = client.call( + resp = session.call( messages.AuthorizeCoinJoin( coordinator=coordinator, max_rounds=max_rounds, diff --git a/python/src/trezorlib/cardano.py b/python/src/trezorlib/cardano.py index 4cbc635f1ff..a945cc9b106 100644 --- a/python/src/trezorlib/cardano.py +++ b/python/src/trezorlib/cardano.py @@ -35,7 +35,7 @@ from . import tools if TYPE_CHECKING: - from .client import TrezorClient + from .transport.session import Session PROTOCOL_MAGICS = { "mainnet": 764824073, @@ -818,7 +818,7 @@ def _get_collateral_inputs_items( def get_address( - client: "TrezorClient", + session: "Session", address_parameters: m.CardanoAddressParametersType, protocol_magic: int = PROTOCOL_MAGICS["mainnet"], network_id: int = NETWORK_IDS["mainnet"], @@ -826,7 +826,7 @@ def get_address( derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS, chunkify: bool = False, ) -> str: - return client.call( + return session.call( m.CardanoGetAddress( address_parameters=address_parameters, protocol_magic=protocol_magic, @@ -840,12 +840,12 @@ def get_address( def get_public_key( - client: "TrezorClient", + session: "Session", address_n: List[int], derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS, show_display: bool = False, ) -> m.CardanoPublicKey: - return client.call( + return session.call( m.CardanoGetPublicKey( address_n=address_n, derivation_type=derivation_type, @@ -856,12 +856,12 @@ def get_public_key( def get_native_script_hash( - client: "TrezorClient", + session: "Session", native_script: m.CardanoNativeScript, display_format: m.CardanoNativeScriptHashDisplayFormat = m.CardanoNativeScriptHashDisplayFormat.HIDE, derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS, ) -> m.CardanoNativeScriptHash: - return client.call( + return session.call( m.CardanoGetNativeScriptHash( script=native_script, display_format=display_format, @@ -872,7 +872,7 @@ def get_native_script_hash( def sign_tx( - client: "TrezorClient", + session: "Session", signing_mode: m.CardanoTxSigningMode, inputs: List[InputWithPath], outputs: List[OutputWithData], @@ -907,7 +907,7 @@ def sign_tx( signing_mode, ) - response = client.call( + response = session.call( m.CardanoSignTxInit( signing_mode=signing_mode, inputs_count=len(inputs), @@ -942,12 +942,12 @@ def sign_tx( _get_certificates_items(certificates), withdrawals, ): - response = client.call(tx_item, expect=m.CardanoTxItemAck) + response = session.call(tx_item, expect=m.CardanoTxItemAck) sign_tx_response: Dict[str, Any] = {} if auxiliary_data is not None: - auxiliary_data_supplement = client.call( + auxiliary_data_supplement = session.call( auxiliary_data, expect=m.CardanoTxAuxiliaryDataSupplement ) if ( @@ -958,25 +958,25 @@ def sign_tx( auxiliary_data_supplement.__dict__ ) - response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck) + response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck) for tx_item in chain( _get_mint_items(mint), _get_collateral_inputs_items(collateral_inputs), required_signers, ): - response = client.call(tx_item, expect=m.CardanoTxItemAck) + response = session.call(tx_item, expect=m.CardanoTxItemAck) if collateral_return is not None: for tx_item in _get_output_items(collateral_return): - response = client.call(tx_item, expect=m.CardanoTxItemAck) + response = session.call(tx_item, expect=m.CardanoTxItemAck) for reference_input in reference_inputs: - response = client.call(reference_input, expect=m.CardanoTxItemAck) + response = session.call(reference_input, expect=m.CardanoTxItemAck) sign_tx_response["witnesses"] = [] for witness_request in witness_requests: - response = client.call(witness_request, expect=m.CardanoTxWitnessResponse) + response = session.call(witness_request, expect=m.CardanoTxWitnessResponse) sign_tx_response["witnesses"].append( { "type": response.type, @@ -986,9 +986,9 @@ def sign_tx( } ) - response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash) + response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash) sign_tx_response["tx_hash"] = response.tx_hash - response = client.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished) + response = session.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished) return sign_tx_response diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 6db335a7adc..192eac614c8 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -14,33 +14,42 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import functools +import logging +import os import sys +import typing as t from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional import click -from .. import exceptions, transport -from ..client import TrezorClient -from ..ui import ClickUI, ScriptUI +from .. import exceptions, transport, ui +from ..client import ProtocolVersion, TrezorClient +from ..messages import Capability +from ..transport import Transport +from ..transport.session import Session, SessionV1, SessionV2 +from ..transport.thp.channel_database import get_channel_db + +LOG = logging.getLogger(__name__) -if TYPE_CHECKING: +if t.TYPE_CHECKING: # Needed to enforce a return value from decorators # More details: https://www.python.org/dev/peps/pep-0612/ - from typing import TypeVar from typing_extensions import Concatenate, ParamSpec - from ..transport import Transport - from ..ui import TrezorClientUI - P = ParamSpec("P") - R = TypeVar("R") + R = t.TypeVar("R") + FuncWithSession = t.Callable[Concatenate[Session, P], R] class ChoiceType(click.Choice): - def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None: + + def __init__( + self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True + ) -> None: super().__init__(list(typemap.keys())) self.case_sensitive = case_sensitive if case_sensitive: @@ -48,7 +57,7 @@ def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None else: self.typemap = {k.lower(): v for k, v in typemap.items()} - def convert(self, value: Any, param: Any, ctx: click.Context) -> Any: + def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any: if value in self.typemap.values(): return value value = super().convert(value, param, ctx) @@ -57,11 +66,69 @@ def convert(self, value: Any, param: Any, ctx: click.Context) -> Any: return self.typemap[value] +def get_passphrase( + passphrase_on_host: bool, available_on_device: bool +) -> t.Union[str, object]: + if available_on_device and not passphrase_on_host: + return ui.PASSPHRASE_ON_DEVICE + + env_passphrase = os.getenv("PASSPHRASE") + if env_passphrase is not None: + ui.echo("Passphrase required. Using PASSPHRASE environment variable.") + return env_passphrase + + while True: + try: + passphrase = ui.prompt( + "Passphrase required", + hide_input=True, + default="", + show_default=False, + ) + # In case user sees the input on the screen, we do not need confirmation + if not ui.CAN_HANDLE_HIDDEN_INPUT: + return passphrase + second = ui.prompt( + "Confirm your passphrase", + hide_input=True, + default="", + show_default=False, + ) + if passphrase == second: + return passphrase + else: + ui.echo("Passphrase did not match. Please try again.") + except click.Abort: + raise exceptions.Cancelled from None + + +def get_client(transport: Transport) -> TrezorClient: + stored_channels = get_channel_db().load_stored_channels() + stored_transport_paths = [ch.transport_path for ch in stored_channels] + path = transport.get_path() + if path in stored_transport_paths: + stored_channel_with_correct_transport_path = next( + ch for ch in stored_channels if ch.transport_path == path + ) + try: + client = TrezorClient.resume( + transport, stored_channel_with_correct_transport_path + ) + except Exception: + LOG.debug("Failed to resume a channel. Replacing by a new one.") + get_channel_db().remove_channel(path) + client = TrezorClient(transport) + else: + client = TrezorClient(transport) + return client + + class TrezorConnection: + def __init__( self, path: str, - session_id: Optional[bytes], + session_id: bytes | None, passphrase_on_host: bool, script: bool, ) -> None: @@ -70,6 +137,54 @@ def __init__( self.passphrase_on_host = passphrase_on_host self.script = script + def get_session( + self, + derive_cardano: bool = False, + empty_passphrase: bool = False, + must_resume: bool = False, + ) -> Session: + client = self.get_client() + if must_resume and self.session_id is None: + click.echo("Failed to resume session - no session id provided") + raise RuntimeError("Failed to resume session - no session id provided") + + # Try resume session from id + if self.session_id is not None: + if client.protocol_version is ProtocolVersion.PROTOCOL_V1: + session = SessionV1.resume_from_id( + client=client, session_id=self.session_id + ) + elif client.protocol_version is ProtocolVersion.PROTOCOL_V2: + session = SessionV2(client, self.session_id) + # TODO fix resumption on THP + else: + raise Exception("Unsupported client protocol", client.protocol_version) + if must_resume: + if session.id != self.session_id or session.id is None: + click.echo("Failed to resume session") + RuntimeError("Failed to resume session - no session id provided") + return session + + features = client.protocol.get_features() + + passphrase_enabled = True # TODO what to do here? + + if not passphrase_enabled: + return client.get_session(derive_cardano=derive_cardano) + + if empty_passphrase: + passphrase = "" + else: + available_on_device = Capability.PassphraseEntry in features.capabilities + passphrase = get_passphrase(available_on_device, self.passphrase_on_host) + # TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") + session = client.get_session( + passphrase=passphrase, derive_cardano=derive_cardano + ) + return session + def get_transport(self) -> "Transport": try: # look for transport without prefix search @@ -82,19 +197,13 @@ def get_transport(self) -> "Transport": # if this fails, we want the exception to bubble up to the caller return transport.get_transport(self.path, prefix_search=True) - def get_ui(self) -> "TrezorClientUI": - if self.script: - # It is alright to return just the class object instead of instance, - # as the ScriptUI class object itself is the implementation of TrezorClientUI - # (ScriptUI is just a set of staticmethods) - return ScriptUI - else: - return ClickUI(passphrase_on_host=self.passphrase_on_host) - def get_client(self) -> TrezorClient: - transport = self.get_transport() - ui = self.get_ui() - return TrezorClient(transport, ui=ui, session_id=self.session_id) + return get_client(self.get_transport()) + + def get_seedless_session(self) -> Session: + client = self.get_client() + seedless_session = client.get_seedless_session() + return seedless_session @contextmanager def client_context(self): @@ -127,8 +236,106 @@ def client_context(self): raise click.ClickException(str(e)) from e # other exceptions may cause a traceback + @contextmanager + def session_context( + self, + empty_passphrase: bool = False, + derive_cardano: bool = False, + management: bool = False, + must_resume: bool = False, + ): + """Get a session instance as a context manager. Handle errors in a manner + appropriate for end-users. -def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]": + Usage: + >>> with obj.session_context() as session: + >>> do_your_actions_here() + """ + try: + if management: + session = self.get_seedless_session() + else: + session = self.get_session( + derive_cardano=derive_cardano, + empty_passphrase=empty_passphrase, + must_resume=must_resume, + ) + except exceptions.DeviceLockedException: + click.echo( + "Device is locked, enter a pin on the device.", + err=True, + ) + except transport.DeviceIsBusy: + click.echo("Device is in use by another process.") + sys.exit(1) + except Exception: + click.echo("Failed to find a Trezor device.") + if self.path is not None: + click.echo(f"Using path: {self.path}") + sys.exit(1) + + try: + yield session + except exceptions.Cancelled: + # handle cancel action + click.echo("Action was cancelled.") + sys.exit(1) + except exceptions.TrezorException as e: + # handle any Trezor-sent exceptions as user-readable + raise click.ClickException(str(e)) from e + # other exceptions may cause a traceback + + +def with_session( + func: "t.Callable[Concatenate[Session, P], R]|None" = None, + *, + empty_passphrase: bool = False, + derive_cardano: bool = False, + management: bool = False, + must_resume: bool = False, +) -> t.Callable[[FuncWithSession], t.Callable[P, R]]: + """Provides a Click command with parameter `session=obj.get_session(...)` + based on the parameters provided. + + If default parameters are ok, this decorator can be used without parentheses. + + TODO: handle resumption of sessions and their (potential) closure. + """ + + def decorator( + func: FuncWithSession, + ) -> "t.Callable[P, R]": + + @click.pass_obj + @functools.wraps(func) + def function_with_session( + obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" + ) -> "R": + with obj.session_context( + empty_passphrase=empty_passphrase, + derive_cardano=derive_cardano, + management=management, + must_resume=must_resume, + ) as session: + try: + return func(session, *args, **kwargs) + + finally: + pass + # TODO try end session if not resumed + + return function_with_session + + # If the decorator @get_session is used without parentheses + if func and callable(func): + return decorator(func) # type: ignore [Function return type] + + return decorator + + +def with_client( + func: "t.Callable[Concatenate[TrezorClient, P], R]", +) -> "t.Callable[P, R]": """Wrap a Click command in `with obj.client_context() as client`. Sessions are handled transparently. The user is warned when session did not resume @@ -142,23 +349,62 @@ def trezorctl_command_with_client( obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" ) -> "R": with obj.client_context() as client: - session_was_resumed = obj.session_id == client.session_id - if not session_was_resumed and obj.session_id is not None: - # tried to resume but failed - click.echo("Warning: failed to resume session.", err=True) - + # session_was_resumed = obj.session_id == client.session_id + # if not session_was_resumed and obj.session_id is not None: + # # tried to resume but failed + # click.echo("Warning: failed to resume session.", err=True) + click.echo( + "Warning: resume session detection is not implemented yet!", err=True + ) try: return func(client, *args, **kwargs) finally: - if not session_was_resumed: - try: - client.end_session() - except Exception: - pass + if client.protocol_version == ProtocolVersion.PROTOCOL_V2: + get_channel_db().save_channel(client.protocol) + # if not session_was_resumed: + # try: + # client.end_session() + # except Exception: + # pass return trezorctl_command_with_client +# def with_client( +# func: "t.Callable[Concatenate[TrezorClient, P], R]", +# ) -> "t.Callable[P, R]": +# """Wrap a Click command in `with obj.client_context() as client`. + +# Sessions are handled transparently. The user is warned when session did not resume +# cleanly. The session is closed after the command completes - unless the session +# was resumed, in which case it should remain open. +# """ + +# @click.pass_obj +# @functools.wraps(func) +# def trezorctl_command_with_client( +# obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" +# ) -> "R": +# with obj.client_context() as client: +# session_was_resumed = obj.session_id == client.session_id +# if not session_was_resumed and obj.session_id is not None: +# # tried to resume but failed +# click.echo("Warning: failed to resume session.", err=True) + +# try: +# return func(client, *args, **kwargs) +# finally: +# if not session_was_resumed: +# try: +# client.end_session() +# except Exception: +# pass + +# # the return type of @click.pass_obj is improperly specified and pyright doesn't +# # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs) +# return trezorctl_command_with_client + + class AliasedGroup(click.Group): """Command group that handles aliases and Click 6.x compatibility. @@ -188,14 +434,14 @@ class AliasedGroup(click.Group): def __init__( self, - aliases: Optional[Dict[str, click.Command]] = None, - *args: Any, - **kwargs: Any, + aliases: t.Dict[str, click.Command] | None = None, + *args: t.Any, + **kwargs: t.Any, ) -> None: super().__init__(*args, **kwargs) self.aliases = aliases or {} - def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]: + def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: cmd_name = cmd_name.replace("_", "-") # try to look up the real name cmd = super().get_command(ctx, cmd_name) diff --git a/python/src/trezorlib/cli/benchmark.py b/python/src/trezorlib/cli/benchmark.py index e445089815c..7908223881f 100644 --- a/python/src/trezorlib/cli/benchmark.py +++ b/python/src/trezorlib/cli/benchmark.py @@ -20,17 +20,15 @@ import click from .. import benchmark -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session -def list_names_patern( - client: "TrezorClient", pattern: Optional[str] = None -) -> List[str]: - names = list(benchmark.list_names(client).names) +def list_names_patern(session: "Session", pattern: Optional[str] = None) -> List[str]: + names = list(benchmark.list_names(session).names) if pattern is None: return names return [name for name in names if fnmatch(name, pattern)] @@ -43,10 +41,10 @@ def cli() -> None: @cli.command() @click.argument("pattern", required=False) -@with_client -def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None: +@with_session(empty_passphrase=True) +def list_names(session: "Session", pattern: Optional[str] = None) -> None: """List names of all supported benchmarks""" - names = list_names_patern(client, pattern) + names = list_names_patern(session, pattern) if len(names) == 0: click.echo("No benchmark satisfies the pattern.") else: @@ -56,13 +54,13 @@ def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None: @cli.command() @click.argument("pattern", required=False) -@with_client -def run(client: "TrezorClient", pattern: Optional[str]) -> None: +@with_session(empty_passphrase=True) +def run(session: "Session", pattern: Optional[str]) -> None: """Run benchmark""" - names = list_names_patern(client, pattern) + names = list_names_patern(session, pattern) if len(names) == 0: click.echo("No benchmark satisfies the pattern.") else: for name in names: - result = benchmark.run(client, name) + result = benchmark.run(session, name) click.echo(f"{name}: {result.value} {result.unit}") diff --git a/python/src/trezorlib/cli/binance.py b/python/src/trezorlib/cli/binance.py index a3139fb2711..d8097b3e900 100644 --- a/python/src/trezorlib/cli/binance.py +++ b/python/src/trezorlib/cli/binance.py @@ -20,11 +20,11 @@ import click from .. import binance, tools -from . import with_client +from ..transport.session import Session +from . import with_session if TYPE_CHECKING: from .. import messages - from ..client import TrezorClient PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0" @@ -39,23 +39,23 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Binance address for specified path.""" address_n = tools.parse_path(address) - return binance.get_address(client, address_n, show_display, chunkify) + return binance.get_address(session, address_n, show_display, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Binance public key.""" address_n = tools.parse_path(address) - return binance.get_public_key(client, address_n, show_display).hex() + return binance.get_public_key(session, address_n, show_display).hex() @cli.command() @@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> "messages.BinanceSignedTx": """Sign Binance transaction. Transaction must be provided as a JSON file. """ address_n = tools.parse_path(address) - return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify) + return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify) diff --git a/python/src/trezorlib/cli/btc.py b/python/src/trezorlib/cli/btc.py index d6a9867215c..77bbe83f811 100644 --- a/python/src/trezorlib/cli/btc.py +++ b/python/src/trezorlib/cli/btc.py @@ -13,6 +13,7 @@ # # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations import base64 import json @@ -22,10 +23,10 @@ import construct as c from .. import btc, messages, protobuf, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PURPOSE_BIP44 = 44 PURPOSE_BIP48 = 48 @@ -174,15 +175,15 @@ def cli() -> None: help="Sort pubkeys lexicographically using BIP-67", ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", coin: str, address: str, - script_type: Optional[messages.InputScriptType], + script_type: messages.InputScriptType | None, show_display: bool, multisig_xpub: List[str], - multisig_threshold: Optional[int], + multisig_threshold: int | None, multisig_suffix_length: int, multisig_sort_pubkeys: bool, chunkify: bool, @@ -235,7 +236,7 @@ def get_address( multisig = None return btc.get_address( - client, + session, coin, address_n, show_display, @@ -252,9 +253,9 @@ def get_address( @click.option("-e", "--curve") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_public_node( - client: "TrezorClient", + session: "Session", coin: str, address: str, curve: Optional[str], @@ -266,7 +267,7 @@ def get_public_node( if script_type is None: script_type = guess_script_type_from_path(address_n) result = btc.get_public_node( - client, + session, address_n, ecdsa_curve_name=curve, show_display=show_display, @@ -292,7 +293,7 @@ def _append_descriptor_checksum(desc: str) -> str: def _get_descriptor( - client: "TrezorClient", + session: "Session", coin: Optional[str], account: int, purpose: Optional[int], @@ -326,7 +327,7 @@ def _get_descriptor( n = tools.parse_path(path) pub = btc.get_public_node( - client, + session, n, show_display=show_display, coin_name=coin, @@ -363,9 +364,9 @@ def _get_descriptor( @click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE)) @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_descriptor( - client: "TrezorClient", + session: "Session", coin: Optional[str], account: int, account_type: Optional[int], @@ -375,7 +376,7 @@ def get_descriptor( """Get descriptor of given account.""" try: return _get_descriptor( - client, coin, account, account_type, script_type, show_display + session, coin, account, account_type, script_type, show_display ) except ValueError as e: raise click.ClickException(str(e)) @@ -390,8 +391,8 @@ def get_descriptor( @click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) @click.argument("json_file", type=click.File()) -@with_client -def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: +@with_session +def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None: """Sign transaction. Transaction data must be provided in a JSON file. See `transaction-format.md` for @@ -416,7 +417,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: } _, serialized_tx = btc.sign_tx( - client, + session, coin, inputs, outputs, @@ -447,9 +448,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: ) @click.option("-C", "--chunkify", is_flag=True) @click.argument("message") -@with_client +@with_session def sign_message( - client: "TrezorClient", + session: "Session", coin: str, address: str, message: str, @@ -462,7 +463,7 @@ def sign_message( if script_type is None: script_type = guess_script_type_from_path(address_n) res = btc.sign_message( - client, + session, coin, address_n, message, @@ -483,9 +484,9 @@ def sign_message( @click.argument("address") @click.argument("signature") @click.argument("message") -@with_client +@with_session def verify_message( - client: "TrezorClient", + session: "Session", coin: str, address: str, signature: str, @@ -495,7 +496,7 @@ def verify_message( """Verify message.""" signature_bytes = base64.b64decode(signature) return btc.verify_message( - client, coin, address, signature_bytes, message, chunkify=chunkify + session, coin, address, signature_bytes, message, chunkify=chunkify ) diff --git a/python/src/trezorlib/cli/cardano.py b/python/src/trezorlib/cli/cardano.py index 26d4eab5b99..1e6935d6d9a 100644 --- a/python/src/trezorlib/cli/cardano.py +++ b/python/src/trezorlib/cli/cardano.py @@ -20,10 +20,10 @@ import click from .. import cardano, messages, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0" @@ -62,9 +62,9 @@ def cli() -> None: @click.option("-i", "--include-network-id", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) @click.option("-T", "--tag-cbor-sets", is_flag=True) -@with_client +@with_session(derive_cardano=True) def sign_tx( - client: "TrezorClient", + session: "Session", file: TextIO, signing_mode: messages.CardanoTxSigningMode, protocol_magic: int, @@ -123,9 +123,8 @@ def sign_tx( for p in transaction["additional_witness_requests"] ] - client.init_device(derive_cardano=True) sign_tx_response = cardano.sign_tx( - client, + session, signing_mode, inputs, outputs, @@ -209,9 +208,9 @@ def sign_tx( default=messages.CardanoDerivationType.ICARUS, ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session(derive_cardano=True) def get_address( - client: "TrezorClient", + session: "Session", address: str, address_type: messages.CardanoAddressType, staking_address: str, @@ -262,9 +261,8 @@ def get_address( script_staking_hash_bytes, ) - client.init_device(derive_cardano=True) return cardano.get_address( - client, + session, address_parameters, protocol_magic, network_id, @@ -283,18 +281,17 @@ def get_address( default=messages.CardanoDerivationType.ICARUS, ) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session(derive_cardano=True) def get_public_key( - client: "TrezorClient", + session: "Session", address: str, derivation_type: messages.CardanoDerivationType, show_display: bool, ) -> messages.CardanoPublicKey: """Get Cardano public key.""" address_n = tools.parse_path(address) - client.init_device(derive_cardano=True) return cardano.get_public_key( - client, address_n, derivation_type=derivation_type, show_display=show_display + session, address_n, derivation_type=derivation_type, show_display=show_display ) @@ -312,9 +309,9 @@ def get_public_key( type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}), default=messages.CardanoDerivationType.ICARUS, ) -@with_client +@with_session(derive_cardano=True) def get_native_script_hash( - client: "TrezorClient", + session: "Session", file: TextIO, display_format: messages.CardanoNativeScriptHashDisplayFormat, derivation_type: messages.CardanoDerivationType, @@ -323,7 +320,6 @@ def get_native_script_hash( native_script_json = json.load(file) native_script = cardano.parse_native_script(native_script_json) - client.init_device(derive_cardano=True) return cardano.get_native_script_hash( - client, native_script, display_format, derivation_type=derivation_type + session, native_script, display_format, derivation_type=derivation_type ) diff --git a/python/src/trezorlib/cli/crypto.py b/python/src/trezorlib/cli/crypto.py index a58b80d4b69..469bc719a48 100644 --- a/python/src/trezorlib/cli/crypto.py +++ b/python/src/trezorlib/cli/crypto.py @@ -19,10 +19,10 @@ import click from .. import misc, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PROMPT_TYPE = ChoiceType( @@ -42,10 +42,10 @@ def cli() -> None: @cli.command() @click.argument("size", type=int) -@with_client -def get_entropy(client: "TrezorClient", size: int) -> str: +@with_session(empty_passphrase=True) +def get_entropy(session: "Session", size: int) -> str: """Get random bytes from device.""" - return misc.get_entropy(client, size).hex() + return misc.get_entropy(session, size).hex() @cli.command() @@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str: ) @click.argument("key") @click.argument("value") -@with_client +@with_session(empty_passphrase=True) def encrypt_keyvalue( - client: "TrezorClient", + session: "Session", address: str, key: str, value: str, @@ -75,7 +75,7 @@ def encrypt_keyvalue( ask_on_encrypt, ask_on_decrypt = prompt address_n = tools.parse_path(address) return misc.encrypt_keyvalue( - client, + session, address_n, key, value.encode(), @@ -91,9 +91,9 @@ def encrypt_keyvalue( ) @click.argument("key") @click.argument("value") -@with_client +@with_session(empty_passphrase=True) def decrypt_keyvalue( - client: "TrezorClient", + session: "Session", address: str, key: str, value: str, @@ -112,7 +112,7 @@ def decrypt_keyvalue( ask_on_encrypt, ask_on_decrypt = prompt address_n = tools.parse_path(address) return misc.decrypt_keyvalue( - client, + session, address_n, key, bytes.fromhex(value), diff --git a/python/src/trezorlib/cli/debug.py b/python/src/trezorlib/cli/debug.py index d9d936c7ab9..fc93174c778 100644 --- a/python/src/trezorlib/cli/debug.py +++ b/python/src/trezorlib/cli/debug.py @@ -18,13 +18,12 @@ import click -from .. import mapping, messages, protobuf -from ..client import TrezorClient from ..debuglink import TrezorClientDebugLink from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max from ..debuglink import prodtest_t1 as debuglink_prodtest_t1 from ..debuglink import record_screen -from . import with_client +from ..transport.session import Session +from . import with_session if TYPE_CHECKING: from . import TrezorConnection @@ -35,51 +34,51 @@ def cli() -> None: """Miscellaneous debug features.""" -@cli.command() -@click.argument("message_name_or_type") -@click.argument("hex_data") -@click.pass_obj -def send_bytes( - obj: "TrezorConnection", message_name_or_type: str, hex_data: str -) -> None: - """Send raw bytes to Trezor. +# @cli.command() +# @click.argument("message_name_or_type") +# @click.argument("hex_data") +# @click.pass_obj +# def send_bytes( +# obj: "NewTrezorConnection", message_name_or_type: str, hex_data: str +# ) -> None: +# """Send raw bytes to Trezor. - Message type and message data must be specified separately, due to how message - chunking works on the transport level. Message length is calculated and sent - automatically, and it is currently impossible to explicitly specify invalid length. +# Message type and message data must be specified separately, due to how message +# chunking works on the transport level. Message length is calculated and sent +# automatically, and it is currently impossible to explicitly specify invalid length. - MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum, - in which case the value of that enum is used. - """ - if message_name_or_type.isdigit(): - message_type = int(message_name_or_type) - else: - message_type = getattr(messages.MessageType, message_name_or_type) +# MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum, +# in which case the value of that enum is used. +# """ +# if message_name_or_type.isdigit(): +# message_type = int(message_name_or_type) +# else: +# message_type = getattr(messages.MessageType, message_name_or_type) - if not isinstance(message_type, int): - raise click.ClickException("Invalid message type.") +# if not isinstance(message_type, int): +# raise click.ClickException("Invalid message type.") - try: - message_data = bytes.fromhex(hex_data) - except Exception as e: - raise click.ClickException("Invalid hex data.") from e +# try: +# message_data = bytes.fromhex(hex_data) +# except Exception as e: +# raise click.ClickException("Invalid hex data.") from e - transport = obj.get_transport() - transport.begin_session() - transport.write(message_type, message_data) +# transport = obj.get_transport() +# transport.deprecated_begin_session() +# transport.write(message_type, message_data) - response_type, response_data = transport.read() - transport.end_session() +# response_type, response_data = transport.read() +# transport.deprecated_end_session() - click.echo(f"Response type: {response_type}") - click.echo(f"Response data: {response_data.hex()}") +# click.echo(f"Response type: {response_type}") +# click.echo(f"Response data: {response_data.hex()}") - try: - msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data) - click.echo("Parsed message:") - click.echo(protobuf.format_message(msg)) - except Exception as e: - click.echo(f"Could not parse response: {e}") +# try: +# msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data) +# click.echo("Parsed message:") +# click.echo(protobuf.format_message(msg)) +# except Exception as e: +# click.echo(f"Could not parse response: {e}") @cli.command() @@ -106,17 +105,17 @@ def record_screen_from_connection( @cli.command() -@with_client -def prodtest_t1(client: "TrezorClient") -> None: +@with_session(management=True) +def prodtest_t1(session: "Session") -> None: """Perform a prodtest on Model One. Only available on PRODTEST firmware and on T1B1. Formerly named self-test. """ - debuglink_prodtest_t1(client) + debuglink_prodtest_t1(session) @cli.command() -@with_client -def optiga_set_sec_max(client: "TrezorClient") -> None: +@with_session(management=True) +def optiga_set_sec_max(session: "Session") -> None: """Set Optiga's security event counter to maximum.""" - debuglink_optiga_set_sec_max(client) + debuglink_optiga_set_sec_max(session) diff --git a/python/src/trezorlib/cli/device.py b/python/src/trezorlib/cli/device.py index 0803b85a695..ebd80fd75ed 100644 --- a/python/src/trezorlib/cli/device.py +++ b/python/src/trezorlib/cli/device.py @@ -25,10 +25,10 @@ from .. import debuglink, device, exceptions, messages, ui from ..tools import format_path -from . import ChoiceType, with_client +from . import ChoiceType, with_session if t.TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session from . import TrezorConnection RECOVERY_DEVICE_INPUT_METHOD = { @@ -64,17 +64,18 @@ def cli() -> None: help="Wipe device in bootloader mode. This also erases the firmware.", is_flag=True, ) -@with_client -def wipe(client: "TrezorClient", bootloader: bool) -> None: +@with_session(management=True) +def wipe(session: "Session", bootloader: bool) -> None: """Reset device to factory defaults and remove all private data.""" + features = session.features if bootloader: - if not client.features.bootloader_mode: + if not features.bootloader_mode: click.echo("Please switch your device to bootloader mode.") sys.exit(1) else: click.echo("Wiping user data and firmware!") else: - if client.features.bootloader_mode: + if features.bootloader_mode: click.echo( "Your device is in bootloader mode. This operation would also erase firmware." ) @@ -86,7 +87,13 @@ def wipe(client: "TrezorClient", bootloader: bool) -> None: else: click.echo("Wiping user data!") - device.wipe(client) + try: + device.wipe( + session + ) # TODO decide where the wipe should happen - management or regular session + except exceptions.TrezorFailure as e: + click.echo("Action failed: {} {}".format(*e.args)) + sys.exit(3) @cli.command() @@ -99,9 +106,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> None: @click.option("-a", "--academic", is_flag=True) @click.option("-b", "--needs-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True) -@with_client +@with_session(management=True) def load( - client: "TrezorClient", + session: "Session", mnemonic: t.Sequence[str], pin: str, passphrase_protection: bool, @@ -132,7 +139,7 @@ def load( try: debuglink.load_device( - client, + session, mnemonic=list(mnemonic), pin=pin, passphrase_protection=passphrase_protection, @@ -167,9 +174,9 @@ def load( ) @click.option("-d", "--dry-run", is_flag=True) @click.option("-b", "--unlock-repeated-backup", is_flag=True) -@with_client +@with_session(management=True) def recover( - client: "TrezorClient", + session: "Session", words: str, expand: bool, pin_protection: bool, @@ -197,7 +204,7 @@ def recover( type = messages.RecoveryType.UnlockRepeatedBackup device.recover( - client, + session, word_count=int(words), passphrase_protection=passphrase_protection, pin_protection=pin_protection, @@ -219,9 +226,9 @@ def recover( @click.option("-n", "--no-backup", is_flag=True) @click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE)) @click.option("-e", "--entropy-check-count", type=click.IntRange(0)) -@with_client +@with_session(management=True) def setup( - client: "TrezorClient", + session: "Session", strength: int | None, passphrase_protection: bool, pin_protection: bool, @@ -241,10 +248,10 @@ def setup( if ( backup_type in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable) - and messages.Capability.Shamir not in client.features.capabilities + and messages.Capability.Shamir not in session.features.capabilities ) or ( backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable) - and messages.Capability.ShamirGroups not in client.features.capabilities + and messages.Capability.ShamirGroups not in session.features.capabilities ): click.echo( "WARNING: Your Trezor device does not indicate support for the requested\n" @@ -252,7 +259,7 @@ def setup( ) path_xpubs = device.setup( - client, + session, strength=strength, passphrase_protection=passphrase_protection, pin_protection=pin_protection, @@ -273,22 +280,21 @@ def setup( @cli.command() @click.option("-t", "--group-threshold", type=int) @click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N") -@with_client +@with_session(management=True) def backup( - client: "TrezorClient", + session: "Session", group_threshold: int | None = None, groups: t.Sequence[tuple[int, int]] = (), ) -> None: """Perform device seed backup.""" - device.backup(client, group_threshold, groups) + + device.backup(session, group_threshold, groups) @cli.command() @click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS)) -@with_client -def sd_protect( - client: "TrezorClient", operation: messages.SdProtectOperationType -) -> None: +@with_session(management=True) +def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> None: """Secure the device with SD card protection. When SD card protection is enabled, a randomly generated secret is stored @@ -302,9 +308,9 @@ def sd_protect( off - Remove SD card secret protection. refresh - Replace the current SD card secret with a new one. """ - if client.features.model == "1": + if session.features.model == "1": raise click.ClickException("Trezor One does not support SD card protection.") - device.sd_protect(client, operation) + device.sd_protect(session, operation) @cli.command() @@ -314,24 +320,24 @@ def reboot_to_bootloader(obj: "TrezorConnection") -> None: Currently only supported on Trezor Model One. """ - # avoid using @with_client because it closes the session afterwards, + # avoid using @with_session because it closes the session afterwards, # which triggers double prompt on device with obj.client_context() as client: - device.reboot_to_bootloader(client) + device.reboot_to_bootloader(client.get_seedless_session()) @cli.command() -@with_client -def tutorial(client: "TrezorClient") -> None: +@with_session(management=True) +def tutorial(session: "Session") -> None: """Show on-device tutorial.""" - device.show_device_tutorial(client) + device.show_device_tutorial(session) @cli.command() -@with_client -def unlock_bootloader(client: "TrezorClient") -> None: +@with_session(management=True) +def unlock_bootloader(session: "Session") -> None: """Unlocks bootloader. Irreversible.""" - device.unlock_bootloader(client) + device.unlock_bootloader(session) @cli.command() @@ -342,12 +348,11 @@ def unlock_bootloader(client: "TrezorClient") -> None: type=int, help="Dialog expiry in seconds.", ) -@with_client -def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> None: +@with_session(management=True) +def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> None: """Show a "Do not disconnect" dialog.""" if enable is False: - device.set_busy(client, None) - return + device.set_busy(session, None) if expiry is None: raise click.ClickException("Missing option '-e' / '--expiry'.") @@ -357,7 +362,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer." ) - device.set_busy(client, expiry * 1000) + device.set_busy(session, expiry * 1000) PUBKEY_WHITELIST_URL_TEMPLATE = ( @@ -377,9 +382,9 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> is_flag=True, help="Do not check intermediate certificates against the whitelist.", ) -@with_client +@with_session(management=True) def authenticate( - client: "TrezorClient", + session: "Session", hex_challenge: str | None, root: t.BinaryIO | None, raw: bool | None, @@ -404,7 +409,7 @@ def authenticate( challenge = bytes.fromhex(hex_challenge) if raw: - msg = device.authenticate(client, challenge) + msg = device.authenticate(session, challenge) click.echo(f"Challenge: {hex_challenge}") click.echo(f"Signature of challenge: {msg.signature.hex()}") @@ -452,14 +457,14 @@ def format(self, record: logging.LogRecord) -> str: else: whitelist_json = requests.get( PUBKEY_WHITELIST_URL_TEMPLATE.format( - model=client.model.internal_name.lower() + model=session.model.internal_name.lower() ) ).json() whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]] try: authentication.authenticate_device( - client, challenge, root_pubkey=root_bytes, whitelist=whitelist + session, challenge, root_pubkey=root_bytes, whitelist=whitelist ) except authentication.DeviceNotAuthentic: click.echo("Device is not authentic.") diff --git a/python/src/trezorlib/cli/eos.py b/python/src/trezorlib/cli/eos.py index 84c248c4a4c..27d461d8b0b 100644 --- a/python/src/trezorlib/cli/eos.py +++ b/python/src/trezorlib/cli/eos.py @@ -20,11 +20,11 @@ import click from .. import eos, tools -from . import with_client +from . import with_session if TYPE_CHECKING: from .. import messages - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0" @@ -37,11 +37,11 @@ def cli() -> None: @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Eos public key in base58 encoding.""" address_n = tools.parse_path(address) - res = eos.get_public_key(client, address_n, show_display) + res = eos.get_public_key(session, address_n, show_display) return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}" @@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_transaction( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> "messages.EosSignedTx": """Sign EOS transaction.""" tx_json = json.load(file) address_n = tools.parse_path(address) return eos.sign_tx( - client, + session, address_n, tx_json["transaction"], tx_json["chain_id"], diff --git a/python/src/trezorlib/cli/ethereum.py b/python/src/trezorlib/cli/ethereum.py index 6bbfc0d356d..d810d2bf2d1 100644 --- a/python/src/trezorlib/cli/ethereum.py +++ b/python/src/trezorlib/cli/ethereum.py @@ -26,14 +26,14 @@ from .. import _rlp, definitions, ethereum, tools from ..messages import EthereumDefinitions -from . import with_client +from . import with_session if TYPE_CHECKING: import web3 from eth_typing import ChecksumAddress # noqa: I900 from web3.types import Wei - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0" @@ -268,24 +268,24 @@ def cli( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Ethereum address in hex encoding.""" address_n = tools.parse_path(address) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) - return ethereum.get_address(client, address_n, show_display, network, chunkify) + return ethereum.get_address(session, address_n, show_display, network, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict: +@with_session +def get_public_node(session: "Session", address: str, show_display: bool) -> dict: """Get Ethereum public node of given path.""" address_n = tools.parse_path(address) - result = ethereum.get_public_node(client, address_n, show_display=show_display) + result = ethereum.get_public_node(session, address_n, show_display=show_display) return { "node": { "depth": result.node.depth, @@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-C", "--chunkify", is_flag=True) @click.argument("to_address") @click.argument("amount", callback=_amount_to_int) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", chain_id: int, address: str, amount: int, @@ -400,7 +400,7 @@ def sign_tx( encoded_network = DEFINITIONS_SOURCE.get_network(chain_id) address_n = tools.parse_path(address) from_address = ethereum.get_address( - client, address_n, encoded_network=encoded_network + session, address_n, encoded_network=encoded_network ) if token: @@ -446,7 +446,7 @@ def sign_tx( assert max_gas_fee is not None assert max_priority_fee is not None sig = ethereum.sign_tx_eip1559( - client, + session, n=address_n, nonce=nonce, gas_limit=gas_limit, @@ -465,7 +465,7 @@ def sign_tx( gas_price = _get_web3().eth.gas_price assert gas_price is not None sig = ethereum.sign_tx( - client, + session, n=address_n, tx_type=tx_type, nonce=nonce, @@ -526,14 +526,14 @@ def sign_tx( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-C", "--chunkify", is_flag=True) @click.argument("message") -@with_client +@with_session def sign_message( - client: "TrezorClient", address: str, message: str, chunkify: bool + session: "Session", address: str, message: str, chunkify: bool ) -> Dict[str, str]: """Sign message with Ethereum address.""" address_n = tools.parse_path(address) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) - ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify) + ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify) output = { "message": message, "address": ret.address, @@ -550,9 +550,9 @@ def sign_message( help="Be compatible with Metamask's signTypedData_v4 implementation", ) @click.argument("file", type=click.File("r")) -@with_client +@with_session def sign_typed_data( - client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO + session: "Session", address: str, metamask_v4_compat: bool, file: TextIO ) -> Dict[str, str]: """Sign typed data (EIP-712) with Ethereum address. @@ -565,7 +565,7 @@ def sign_typed_data( defs = EthereumDefinitions(encoded_network=network) data = json.loads(file.read()) ret = ethereum.sign_typed_data( - client, + session, address_n, data, metamask_v4_compat=metamask_v4_compat, @@ -583,9 +583,9 @@ def sign_typed_data( @click.argument("address") @click.argument("signature") @click.argument("message") -@with_client +@with_session def verify_message( - client: "TrezorClient", + session: "Session", address: str, signature: str, message: str, @@ -594,7 +594,7 @@ def verify_message( """Verify message signed with Ethereum address.""" signature_bytes = ethereum.decode_hex(signature) return ethereum.verify_message( - client, address, signature_bytes, message, chunkify=chunkify + session, address, signature_bytes, message, chunkify=chunkify ) @@ -602,9 +602,9 @@ def verify_message( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.argument("domain_hash_hex") @click.argument("message_hash_hex") -@with_client +@with_session def sign_typed_data_hash( - client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str + session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str ) -> Dict[str, str]: """ Sign hash of typed data (EIP-712) with Ethereum address. @@ -618,7 +618,7 @@ def sign_typed_data_hash( message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) ret = ethereum.sign_typed_data_hash( - client, address_n, domain_hash, message_hash, network + session, address_n, domain_hash, message_hash, network ) output = { "domain_hash": domain_hash_hex, diff --git a/python/src/trezorlib/cli/fido.py b/python/src/trezorlib/cli/fido.py index b51bb74e123..70133732417 100644 --- a/python/src/trezorlib/cli/fido.py +++ b/python/src/trezorlib/cli/fido.py @@ -19,10 +19,10 @@ import click from .. import fido -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"} @@ -40,10 +40,10 @@ def credentials() -> None: @credentials.command(name="list") -@with_client -def credentials_list(client: "TrezorClient") -> None: +@with_session(empty_passphrase=True) +def credentials_list(session: "Session") -> None: """List all resident credentials on the device.""" - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) for cred in creds: click.echo("") click.echo(f"WebAuthn credential at index {cred.index}:") @@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None: @credentials.command(name="add") @click.argument("hex_credential_id") -@with_client -def credentials_add(client: "TrezorClient", hex_credential_id: str) -> None: +@with_session(empty_passphrase=True) +def credentials_add(session: "Session", hex_credential_id: str) -> None: """Add the credential with the given ID as a resident credential. HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string. """ - fido.add_credential(client, bytes.fromhex(hex_credential_id)) + fido.add_credential(session, bytes.fromhex(hex_credential_id)) @credentials.command(name="remove") @click.option( "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." ) -@with_client -def credentials_remove(client: "TrezorClient", index: int) -> None: +@with_session(empty_passphrase=True) +def credentials_remove(session: "Session", index: int) -> None: """Remove the resident credential at the given index.""" - fido.remove_credential(client, index) + fido.remove_credential(session, index) # @@ -110,19 +110,19 @@ def counter() -> None: @counter.command(name="set") @click.argument("counter", type=int) -@with_client -def counter_set(client: "TrezorClient", counter: int) -> None: +@with_session(empty_passphrase=True) +def counter_set(session: "Session", counter: int) -> None: """Set FIDO/U2F counter value.""" - fido.set_counter(client, counter) + fido.set_counter(session, counter) @counter.command(name="get-next") -@with_client -def counter_get_next(client: "TrezorClient") -> int: +@with_session(empty_passphrase=True) +def counter_get_next(session: "Session") -> int: """Get-and-increase value of FIDO/U2F counter. FIDO counter value cannot be read directly. On each U2F exchange, the counter value is returned and atomically increased. This command performs the same operation and returns the counter value. """ - return fido.get_next_counter(client) + return fido.get_next_counter(session) diff --git a/python/src/trezorlib/cli/firmware.py b/python/src/trezorlib/cli/firmware.py index 4376a4f2839..262c9cc330b 100644 --- a/python/src/trezorlib/cli/firmware.py +++ b/python/src/trezorlib/cli/firmware.py @@ -37,10 +37,11 @@ from .. import device, exceptions, firmware, messages, models from ..firmware import models as fw_models from ..models import TrezorModel -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: from ..client import TrezorClient + from ..transport.session import Session from . import TrezorConnection MODEL_CHOICE = ChoiceType( @@ -74,9 +75,9 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool: This is the case from bootloader version 1.8.0, and also holds for firmware version 1.8.0 because that installs the appropriate bootloader. """ - f = client.features - version = (f.major_version, f.minor_version, f.patch_version) - bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0) + features = client.features + version = client.version + bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0) return bootloader_onev2 @@ -306,25 +307,26 @@ def find_best_firmware_version( If the specified version is not found, prints the closest available version (higher than the specified one, if existing). """ + features = client.features + model = client.model + if bitcoin_only is None: - bitcoin_only = _should_use_bitcoin_only(client.features) + bitcoin_only = _should_use_bitcoin_only(features) def version_str(version: Iterable[int]) -> str: return ".".join(map(str, version)) - f = client.features - - releases = get_all_firmware_releases(client.model, bitcoin_only, beta) + releases = get_all_firmware_releases(model, bitcoin_only, beta) highest_version = releases[0]["version"] if version: want_version = [int(x) for x in version.split(".")] if len(want_version) != 3: click.echo("Please use the 'X.Y.Z' version format.") - if want_version[0] != f.major_version: + if want_version[0] != features.major_version: click.echo( - f"Warning: Trezor {client.model.name} firmware version should be " - f"{f.major_version}.X.Y (requested: {version})" + f"Warning: Trezor {model.name} firmware version should be " + f"{features.major_version}.X.Y (requested: {version})" ) else: want_version = highest_version @@ -359,8 +361,8 @@ def version_str(version: Iterable[int]) -> str: # to the newer one, in that case update to the minimal # compatible version first # Choosing the version key to compare based on (not) being in BL mode - client_version = [f.major_version, f.minor_version, f.patch_version] - if f.bootloader_mode: + client_version = client.version + if features.bootloader_mode: key_to_compare = "min_bootloader_version" else: key_to_compare = "min_firmware_version" @@ -447,11 +449,11 @@ def extract_embedded_fw( def upload_firmware_into_device( - client: "TrezorClient", + session: "Session", firmware_data: bytes, ) -> None: """Perform the final act of loading the firmware into Trezor.""" - f = client.features + f = session.features try: if f.major_version == 1 and f.firmware_present is not False: # Trezor One does not send ButtonRequest @@ -461,7 +463,7 @@ def upload_firmware_into_device( with click.progressbar( label="Uploading", length=len(firmware_data), show_eta=False ) as bar: - firmware.update(client, firmware_data, bar.update) + firmware.update(session, firmware_data, bar.update) except exceptions.Cancelled: click.echo("Update aborted on device.") except exceptions.TrezorException as e: @@ -654,6 +656,7 @@ def update( against data.trezor.io information, if available. """ with obj.client_context() as client: + seedless_session = client.get_seedless_session() if sum(bool(x) for x in (filename, url, version)) > 1: click.echo("You can use only one of: filename, url, version.") sys.exit(1) @@ -709,7 +712,7 @@ def update( if _is_strict_update(client, firmware_data): header_size = _get_firmware_header_size(firmware_data) device.reboot_to_bootloader( - client, + seedless_session, boot_command=messages.BootCommand.INSTALL_UPGRADE, firmware_header=firmware_data[:header_size], language_data=language_data, @@ -719,7 +722,7 @@ def update( click.echo( "WARNING: Seamless installation not possible, language data will not be uploaded." ) - device.reboot_to_bootloader(client) + device.reboot_to_bootloader(seedless_session) click.echo("Waiting for bootloader...") while True: @@ -735,13 +738,15 @@ def update( click.echo("Please switch your device to bootloader mode.") sys.exit(1) - upload_firmware_into_device(client=client, firmware_data=firmware_data) + upload_firmware_into_device( + session=client.get_seedless_session(), firmware_data=firmware_data + ) @cli.command() @click.argument("hex_challenge", required=False) -@with_client -def get_hash(client: "TrezorClient", hex_challenge: Optional[str]) -> str: +@with_session(management=True) +def get_hash(session: "Session", hex_challenge: Optional[str]) -> str: """Get a hash of the installed firmware combined with the optional challenge.""" challenge = bytes.fromhex(hex_challenge) if hex_challenge else None - return firmware.get_hash(client, challenge).hex() + return firmware.get_hash(session, challenge).hex() diff --git a/python/src/trezorlib/cli/monero.py b/python/src/trezorlib/cli/monero.py index 355c562ae39..0441ebc09b4 100644 --- a/python/src/trezorlib/cli/monero.py +++ b/python/src/trezorlib/cli/monero.py @@ -19,10 +19,10 @@ import click from .. import messages, monero, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h" @@ -42,9 +42,9 @@ def cli() -> None: default=messages.MoneroNetworkType.MAINNET, ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, network_type: messages.MoneroNetworkType, @@ -52,7 +52,7 @@ def get_address( ) -> bytes: """Get Monero address for specified path.""" address_n = tools.parse_path(address) - return monero.get_address(client, address_n, show_display, network_type, chunkify) + return monero.get_address(session, address_n, show_display, network_type, chunkify) @cli.command() @@ -63,13 +63,13 @@ def get_address( type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}), default=messages.MoneroNetworkType.MAINNET, ) -@with_client +@with_session def get_watch_key( - client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType + session: "Session", address: str, network_type: messages.MoneroNetworkType ) -> Dict[str, str]: """Get Monero watch key for specified path.""" address_n = tools.parse_path(address) - res = monero.get_watch_key(client, address_n, network_type) + res = monero.get_watch_key(session, address_n, network_type) # TODO: could be made required in MoneroWatchKey assert res.address is not None assert res.watch_key is not None diff --git a/python/src/trezorlib/cli/nem.py b/python/src/trezorlib/cli/nem.py index 746ad187236..eac16c2d8c2 100644 --- a/python/src/trezorlib/cli/nem.py +++ b/python/src/trezorlib/cli/nem.py @@ -21,10 +21,10 @@ import requests from .. import nem, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h" @@ -39,9 +39,9 @@ def cli() -> None: @click.option("-N", "--network", type=int, default=0x68) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, network: int, show_display: bool, @@ -49,7 +49,7 @@ def get_address( ) -> str: """Get NEM address for specified path.""" address_n = tools.parse_path(address) - return nem.get_address(client, address_n, network, show_display, chunkify) + return nem.get_address(session, address_n, network, show_display, chunkify) @cli.command() @@ -58,9 +58,9 @@ def get_address( @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-b", "--broadcast", help="NIS to announce transaction to") @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", address: str, file: TextIO, broadcast: Optional[str], @@ -71,7 +71,7 @@ def sign_tx( Transaction file is expected in the NIS (RequestPrepareAnnounce) format. """ address_n = tools.parse_path(address) - transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify) + transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify) payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()} diff --git a/python/src/trezorlib/cli/ripple.py b/python/src/trezorlib/cli/ripple.py index e4bcc0b3503..634a92028e6 100644 --- a/python/src/trezorlib/cli/ripple.py +++ b/python/src/trezorlib/cli/ripple.py @@ -20,10 +20,10 @@ import click from .. import ripple, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0" @@ -37,13 +37,13 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Ripple address""" address_n = tools.parse_path(address) - return ripple.get_address(client, address_n, show_display, chunkify) + return ripple.get_address(session, address_n, show_display, chunkify) @cli.command() @@ -51,13 +51,13 @@ def get_address( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client -def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None: +@with_session +def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None: """Sign Ripple transaction""" address_n = tools.parse_path(address) msg = ripple.create_sign_tx_msg(json.load(file)) - result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify) + result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify) click.echo("Signature:") click.echo(result.signature.hex()) click.echo() diff --git a/python/src/trezorlib/cli/settings.py b/python/src/trezorlib/cli/settings.py index 00e4178c440..f62c043c0a5 100644 --- a/python/src/trezorlib/cli/settings.py +++ b/python/src/trezorlib/cli/settings.py @@ -24,10 +24,11 @@ import requests from .. import device, messages, toif -from . import AliasedGroup, ChoiceType, with_client +from ..transport.session import Session +from . import AliasedGroup, ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + pass try: from PIL import Image @@ -190,18 +191,18 @@ def cli() -> None: @cli.command() @click.option("-r", "--remove", is_flag=True, hidden=True) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) -@with_client -def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None: +@with_session(management=True) +def pin(session: "Session", enable: Optional[bool], remove: bool) -> None: """Set, change or remove PIN.""" # Remove argument is there for backwards compatibility - device.change_pin(client, remove=_should_remove(enable, remove)) + device.change_pin(session, remove=_should_remove(enable, remove)) @cli.command() @click.option("-r", "--remove", is_flag=True, hidden=True) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) -@with_client -def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None: +@with_session(management=True) +def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> None: """Set or remove the wipe code. The wipe code functions as a "self-destruct PIN". If the wipe code is ever @@ -209,32 +210,32 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> N removed and the device will be reset to factory defaults. """ # Remove argument is there for backwards compatibility - device.change_wipe_code(client, remove=_should_remove(enable, remove)) + device.change_wipe_code(session, remove=_should_remove(enable, remove)) @cli.command() # keep the deprecated -l/--label option, make it do nothing @click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.argument("label") -@with_client -def label(client: "TrezorClient", label: str) -> None: +@with_session(management=True) +def label(session: "Session", label: str) -> None: """Set new device label.""" - device.apply_settings(client, label=label) + device.apply_settings(session, label=label) @cli.command() -@with_client -def brightness(client: "TrezorClient") -> None: +@with_session(management=True) +def brightness(session: "Session") -> None: """Set display brightness.""" - device.set_brightness(client) + device.set_brightness(session) @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) -@with_client -def haptic_feedback(client: "TrezorClient", enable: bool) -> None: +@with_session(management=True) +def haptic_feedback(session: "Session", enable: bool) -> None: """Enable or disable haptic feedback.""" - device.apply_settings(client, haptic_feedback=enable) + device.apply_settings(session, haptic_feedback=enable) @cli.command() @@ -243,9 +244,9 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> None: "-r", "--remove", is_flag=True, default=False, help="Switch back to english." ) @click.option("-d/-D", "--display/--no-display", default=None) -@with_client +@with_session(management=True) def language( - client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None + session: "Session", path_or_url: str | None, remove: bool, display: bool | None ) -> None: """Set new language with translations.""" if remove != (path_or_url is None): @@ -269,30 +270,28 @@ def language( raise click.ClickException( f"Failed to load translations from {path_or_url}" ) from None - device.change_language(client, language_data=language_data, show_display=display) + device.change_language(session, language_data=language_data, show_display=display) @cli.command() @click.argument("rotation", type=ChoiceType(ROTATION)) -@with_client -def display_rotation( - client: "TrezorClient", rotation: messages.DisplayRotation -) -> None: +@with_session(management=True) +def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> None: """Set display rotation. Configure display rotation for Trezor Model T. The options are north, east, south or west. """ - device.apply_settings(client, display_rotation=rotation) + device.apply_settings(session, display_rotation=rotation) @cli.command() @click.argument("delay", type=str) -@with_client -def auto_lock_delay(client: "TrezorClient", delay: str) -> None: +@with_session(management=True) +def auto_lock_delay(session: "Session", delay: str) -> None: """Set auto-lock delay (in seconds).""" - if not client.features.pin_protection: + if not session.features.pin_protection: raise click.ClickException("Set up a PIN first") value, unit = delay[:-1], delay[-1:] @@ -301,13 +300,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> None: seconds = float(value) * units[unit] else: seconds = float(delay) # assume seconds if no unit is specified - device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000)) + device.apply_settings(session, auto_lock_delay_ms=int(seconds * 1000)) @cli.command() @click.argument("flags") -@with_client -def flags(client: "TrezorClient", flags: str) -> None: +@with_session(management=True) +def flags(session: "Session", flags: str) -> None: """Set device flags.""" if flags.lower().startswith("0b"): flags_int = int(flags, 2) @@ -315,7 +314,7 @@ def flags(client: "TrezorClient", flags: str) -> None: flags_int = int(flags, 16) else: flags_int = int(flags) - device.apply_flags(client, flags=flags_int) + device.apply_flags(session, flags=flags_int) @cli.command() @@ -324,8 +323,8 @@ def flags(client: "TrezorClient", flags: str) -> None: "-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False ) @click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)") -@with_client -def homescreen(client: "TrezorClient", filename: str, quality: int) -> None: +@with_session(management=True) +def homescreen(session: "Session", filename: str, quality: int) -> None: """Set new homescreen. To revert to default homescreen, use 'trezorctl set homescreen default' @@ -337,39 +336,39 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None: if not path.exists() or not path.is_file(): raise click.ClickException("Cannot open file") - if client.features.model == "1": + if session.features.model == "1": img = image_to_t1(path) else: - if client.features.homescreen_format == messages.HomescreenFormat.Jpeg: + if session.features.homescreen_format == messages.HomescreenFormat.Jpeg: width = ( - client.features.homescreen_width - if client.features.homescreen_width is not None + session.features.homescreen_width + if session.features.homescreen_width is not None else 240 ) height = ( - client.features.homescreen_height - if client.features.homescreen_height is not None + session.features.homescreen_height + if session.features.homescreen_height is not None else 240 ) img = image_to_jpeg(path, width, height, quality) - elif client.features.homescreen_format == messages.HomescreenFormat.ToiG: - width = client.features.homescreen_width - height = client.features.homescreen_height + elif session.features.homescreen_format == messages.HomescreenFormat.ToiG: + width = session.features.homescreen_width + height = session.features.homescreen_height if width is None or height is None: raise click.ClickException("Device did not report homescreen size.") img = image_to_toif(path, width, height, True) elif ( - client.features.homescreen_format == messages.HomescreenFormat.Toif - or client.features.homescreen_format is None + session.features.homescreen_format == messages.HomescreenFormat.Toif + or session.features.homescreen_format is None ): width = ( - client.features.homescreen_width - if client.features.homescreen_width is not None + session.features.homescreen_width + if session.features.homescreen_width is not None else 144 ) height = ( - client.features.homescreen_height - if client.features.homescreen_height is not None + session.features.homescreen_height + if session.features.homescreen_height is not None else 144 ) img = image_to_toif(path, width, height, False) @@ -379,7 +378,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None: "Unknown image format requested by the device." ) - device.apply_settings(client, homescreen=img) + device.apply_settings(session, homescreen=img) @cli.command() @@ -387,9 +386,9 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None: "--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.' ) @click.argument("level", type=ChoiceType(SAFETY_LEVELS)) -@with_client +@with_session(management=True) def safety_checks( - client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel + session: "Session", always: bool, level: messages.SafetyCheckLevel ) -> None: """Set safety check level. @@ -402,18 +401,18 @@ def safety_checks( """ if always and level == messages.SafetyCheckLevel.PromptTemporarily: level = messages.SafetyCheckLevel.PromptAlways - device.apply_settings(client, safety_checks=level) + device.apply_settings(session, safety_checks=level) @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) -@with_client -def experimental_features(client: "TrezorClient", enable: bool) -> None: +@with_session(management=True) +def experimental_features(session: "Session", enable: bool) -> None: """Enable or disable experimental message types. This is a developer feature. Use with caution. """ - device.apply_settings(client, experimental_features=enable) + device.apply_settings(session, experimental_features=enable) # @@ -436,25 +435,25 @@ def passphrase_main() -> None: @passphrase.command(name="on") @click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None) -@with_client -def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> None: +@with_session(management=True) +def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> None: """Enable passphrase.""" - if client.features.passphrase_protection is not True: + if session.features.passphrase_protection is not True: use_passphrase = True else: use_passphrase = None device.apply_settings( - client, + session, use_passphrase=use_passphrase, passphrase_always_on_device=force_on_device, ) @passphrase.command(name="off") -@with_client -def passphrase_off(client: "TrezorClient") -> None: +@with_session(management=True) +def passphrase_off(session: "Session") -> None: """Disable passphrase.""" - device.apply_settings(client, use_passphrase=False) + device.apply_settings(session, use_passphrase=False) # Registering the aliases for backwards compatibility @@ -467,10 +466,10 @@ def passphrase_off(client: "TrezorClient") -> None: @passphrase.command(name="hide") @click.argument("hide", type=ChoiceType({"on": True, "off": False})) -@with_client -def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> None: +@with_session(management=True) +def hide_passphrase_from_host(session: "Session", hide: bool) -> None: """Enable or disable hiding passphrase coming from host. This is a developer feature. Use with caution. """ - device.apply_settings(client, hide_passphrase_from_host=hide) + device.apply_settings(session, hide_passphrase_from_host=hide) diff --git a/python/src/trezorlib/cli/solana.py b/python/src/trezorlib/cli/solana.py index 590b4f79146..52574a89d65 100644 --- a/python/src/trezorlib/cli/solana.py +++ b/python/src/trezorlib/cli/solana.py @@ -4,10 +4,10 @@ import click from .. import messages, solana, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h" DEFAULT_PATH = "m/44h/501h/0h/0h" @@ -21,40 +21,40 @@ def cli() -> None: @cli.command() @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_public_key( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, ) -> bytes: """Get Solana public key.""" address_n = tools.parse_path(address) - return solana.get_public_key(client, address_n, show_display) + return solana.get_public_key(session, address_n, show_display) @cli.command() @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, chunkify: bool, ) -> str: """Get Solana address.""" address_n = tools.parse_path(address) - return solana.get_address(client, address_n, show_display, chunkify) + return solana.get_address(session, address_n, show_display, chunkify) @cli.command() @click.argument("serialized_tx", type=str) @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-a", "--additional-information-file", type=click.File("r")) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", address: str, serialized_tx: str, additional_information_file: Optional[TextIO], @@ -78,7 +78,7 @@ def sign_tx( ) return solana.sign_tx( - client, + session, address_n, bytes.fromhex(serialized_tx), additional_information, diff --git a/python/src/trezorlib/cli/stellar.py b/python/src/trezorlib/cli/stellar.py index 77ce700ee5b..9acb6a57ed7 100644 --- a/python/src/trezorlib/cli/stellar.py +++ b/python/src/trezorlib/cli/stellar.py @@ -21,10 +21,10 @@ import click from .. import stellar, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session try: from stellar_sdk import ( @@ -52,13 +52,13 @@ def cli() -> None: ) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Stellar public address.""" address_n = tools.parse_path(address) - return stellar.get_address(client, address_n, show_display, chunkify) + return stellar.get_address(session, address_n, show_display, chunkify) @cli.command() @@ -77,9 +77,9 @@ def get_address( help="Network passphrase (blank for public network).", ) @click.argument("b64envelope") -@with_client +@with_session def sign_transaction( - client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str + session: "Session", b64envelope: str, address: str, network_passphrase: str ) -> bytes: """Sign a base64-encoded transaction envelope. @@ -109,6 +109,6 @@ def sign_transaction( address_n = tools.parse_path(address) tx, operations = stellar.from_envelope(envelope) - resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase) + resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase) return base64.b64encode(resp.signature) diff --git a/python/src/trezorlib/cli/tezos.py b/python/src/trezorlib/cli/tezos.py index 7dcd1ab9db1..e4f0c1a877d 100644 --- a/python/src/trezorlib/cli/tezos.py +++ b/python/src/trezorlib/cli/tezos.py @@ -20,10 +20,10 @@ import click from .. import messages, protobuf, tezos, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h" @@ -37,23 +37,23 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Tezos address for specified path.""" address_n = tools.parse_path(address) - return tezos.get_address(client, address_n, show_display, chunkify) + return tezos.get_address(session, address_n, show_display, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Tezos public key.""" address_n = tools.parse_path(address) - return tezos.get_public_key(client, address_n, show_display) + return tezos.get_public_key(session, address_n, show_display) @cli.command() @@ -61,11 +61,11 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> messages.TezosSignedTx: """Sign Tezos transaction.""" address_n = tools.parse_path(address) msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file)) - return tezos.sign_tx(client, address_n, msg, chunkify=chunkify) + return tezos.sign_tx(session, address_n, msg, chunkify=chunkify) diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 60f8e8d3092..b94ee5af72e 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -24,9 +24,12 @@ import click -from .. import __version__, log, messages, protobuf, ui -from ..client import TrezorClient +from .. import __version__, log, messages, protobuf +from ..client import ProtocolVersion, TrezorClient from ..transport import DeviceIsBusy, enumerate_devices +from ..transport.session import Session +from ..transport.thp import channel_database +from ..transport.thp.channel_database import get_channel_db from ..transport.udp import UdpTransport from . import ( AliasedGroup, @@ -50,6 +53,7 @@ stellar, tezos, with_client, + with_session, ) F = TypeVar("F", bound=Callable) @@ -193,6 +197,13 @@ def configure_logging(verbose: int) -> None: "--record", help="Record screen changes into a specified directory.", ) +@click.option( + "-n", + "--no-store", + is_flag=True, + help="Do not store channels data between commands.", + default=False, +) @click.version_option(version=__version__) @click.pass_context def cli_main( @@ -204,9 +215,10 @@ def cli_main( script: bool, session_id: Optional[str], record: Optional[str], + no_store: bool, ) -> None: configure_logging(verbose) - + channel_database.set_channel_database(should_not_store=no_store) bytes_session_id: Optional[bytes] = None if session_id is not None: try: @@ -285,18 +297,23 @@ def format_device_name(features: messages.Features) -> str: def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: """List connected Trezor devices.""" if no_resolve: - return enumerate_devices() + for d in enumerate_devices(): + click.echo(d.get_path()) + return + + from . import get_client for transport in enumerate_devices(): try: - client = TrezorClient(transport, ui=ui.ClickUI()) + client = get_client(transport) description = format_device_name(client.features) - client.end_session() + if client.protocol_version == ProtocolVersion.PROTOCOL_V2: + get_channel_db().save_channel(client.protocol) except DeviceIsBusy: description = "Device is in use by another process" - except Exception: - description = "Failed to read details" - click.echo(f"{transport} - {description}") + except Exception as e: + description = "Failed to read details " + str(type(e)) + click.echo(f"{transport.get_path()} - {description}") return None @@ -314,15 +331,19 @@ def version() -> str: @cli.command() @click.argument("message") @click.option("-b", "--button-protection", is_flag=True) -@with_client -def ping(client: "TrezorClient", message: str, button_protection: bool) -> str: +@with_session(empty_passphrase=True) +def ping(session: "Session", message: str, button_protection: bool) -> str: """Send ping message.""" - return client.ping(message, button_protection=button_protection) + + # TODO return short-circuit from old client for old Trezors + return session.ping(message, button_protection) @cli.command() @click.pass_obj -def get_session(obj: TrezorConnection) -> str: +def get_session( + obj: TrezorConnection, passphrase: str = "", derive_cardano: bool = False +) -> str: """Get a session ID for subsequent commands. Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with @@ -341,18 +362,38 @@ def get_session(obj: TrezorConnection) -> str: "Upgrade your firmware to enable session support." ) - client.ensure_unlocked() - if client.session_id is None: + # client.ensure_unlocked() + session = client.get_session( + passphrase=passphrase, derive_cardano=derive_cardano + ) + if session.id is None: raise click.ClickException("Passphrase not enabled or firmware too old.") else: - return client.session_id.hex() + return session.id.hex() @cli.command() -@with_client -def clear_session(client: "TrezorClient") -> None: +@with_session(must_resume=True, empty_passphrase=True) +def clear_session(session: "Session") -> None: """Clear session (remove cached PIN, passphrase, etc.).""" - return client.clear_session() + if session is None: + click.echo("Cannot clear session as it was not properly resumed.") + return + session.call(messages.LockDevice()) + session.end() + # TODO different behaviour than main, not sure if ok + + +@cli.command() +def delete_channels() -> None: + """ + Delete cached channels. + + Do not use together with the `-n` (`--no-store`) flag, + as the JSON database will not be deleted in that case. + """ + get_channel_db().clear_stored_channels() + click.echo("Deleted stored channels") @cli.command() diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 4e432bd0123..9d4a9c0f39f 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -13,28 +13,24 @@ # # You should have received a copy of the License along with this library. # If not, see . - from __future__ import annotations import logging import os -import warnings -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar - -from mnemonic import Mnemonic - -from . import exceptions, mapping, messages, models -from .log import DUMP_BYTES -from .messages import Capability -from .protobuf import MessageType -from .tools import parse_path, session +import typing as t +from enum import IntEnum -if TYPE_CHECKING: - from .transport import Transport - from .ui import TrezorClientUI +from . import mapping, messages, models +from .mapping import ProtobufMapping +from .tools import parse_path +from .transport import Transport, get_transport +from .transport.thp.channel_data import ChannelData +from .transport.thp.protocol_and_channel import ProtocolAndChannel +from .transport.thp.protocol_v1 import ProtocolV1 +from .transport.thp.protocol_v2 import ProtocolV2 -UI = TypeVar("UI", bound="TrezorClientUI") -MT = TypeVar("MT", bound=MessageType) +if t.TYPE_CHECKING: + from .transport.session import Session LOG = logging.getLogger(__name__) @@ -51,447 +47,218 @@ """.strip() -def get_default_client( - path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any -) -> "TrezorClient": - """Get a client for a connected Trezor device. - - Returns a TrezorClient instance with minimum fuss. - - If path is specified, does a prefix-search for the specified device. Otherwise, uses - the value of TREZOR_PATH env variable, or finds first connected Trezor. - If no UI is supplied, instantiates the default CLI UI. - """ - from .transport import get_transport - from .ui import ClickUI - - if path is None: - path = os.getenv("TREZOR_PATH") - - transport = get_transport(path, prefix_search=True) - if ui is None: - ui = ClickUI() +LOG = logging.getLogger(__name__) - return TrezorClient(transport, ui, **kwargs) +class ProtocolVersion(IntEnum): + UNKNOWN = 0x00 + PROTOCOL_V1 = 0x01 # Codec + PROTOCOL_V2 = 0x02 # THP -class TrezorClient(Generic[UI]): - """Trezor client, a connection to a Trezor device. - This class allows you to manage connection state, send and receive protobuf - messages, handle user interactions, and perform some generic tasks - (send a cancel message, initialize or clear a session, ping the device). - """ +class TrezorClient: + button_callback: t.Callable[[Session, t.Any], t.Any] | None = None + passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None + pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None - model: models.TrezorModel - transport: "Transport" - session_id: Optional[bytes] - ui: UI - features: messages.Features + _seedless_session: Session | None = None + _features: messages.Features | None = None + _protocol_version: int + _setup_pin: str | None = None # Should by used only by conftest def __init__( self, - transport: "Transport", - ui: UI, - session_id: Optional[bytes] = None, - derive_cardano: Optional[bool] = None, - model: Optional[models.TrezorModel] = None, - _init_device: bool = True, + transport: Transport, + protobuf_mapping: ProtobufMapping | None = None, + protocol: ProtocolAndChannel | None = None, ) -> None: - """Create a TrezorClient instance. - - You have to provide a `transport`, i.e., a raw connection to the device. You can - use `trezorlib.transport.get_transport` to find one. - - You have to provide a UI implementation for the three kinds of interaction: - - button request (notify the user that their interaction is needed) - - PIN request (on T1, ask the user to input numbers for a PIN matrix) - - passphrase request (ask the user to enter a passphrase) See `trezorlib.ui` for - details. - - You can supply a `session_id` you might have saved in the previous session. If - you do, the user might not need to enter their passphrase again. - - You can provide Trezor model information. If not provided, it is detected from - the model name reported at initialization time. - - By default, the instance will open a connection to the Trezor device, send an - `Initialize` message, set up the `features` field from the response, and connect - to a session. By specifying `_init_device=False`, this step is skipped. Notably, - this means that `client.features` is unset. Use `client.init_device()` or - `client.refresh_features()` to fix that, otherwise A LOT OF THINGS will break. - Only use this if you are _sure_ that you know what you are doing. This feature - might be removed at any time. - """ - LOG.info(f"creating client instance for device: {transport.get_path()}") - # Here, self.model could be set to None. Unless _init_device is False, it will - # get correctly reconfigured as part of the init_device flow. - self.model = model # type: ignore ["None" is incompatible with "TrezorModel"] - if self.model: - self.mapping = self.model.default_mapping - else: - self.mapping = mapping.DEFAULT_MAPPING + self._is_invalidated: bool = False self.transport = transport - self.ui = ui - self.session_counter = 0 - self.session_id = session_id - if _init_device: - self.init_device(session_id=session_id, derive_cardano=derive_cardano) - - def open(self) -> None: - if self.session_counter == 0: - self.transport.begin_session() - self.session_counter += 1 - - def close(self) -> None: - self.session_counter = max(self.session_counter - 1, 0) - if self.session_counter == 0: - # TODO call EndSession here? - self.transport.end_session() - - def cancel(self) -> None: - self._raw_write(messages.Cancel()) - - def call_raw(self, msg: MessageType) -> MessageType: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - self._raw_write(msg) - return self._raw_read() - - def _raw_write(self, msg: MessageType) -> None: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - LOG.debug( - f"sending message: {msg.__class__.__name__}", - extra={"protobuf": msg}, - ) - msg_type, msg_bytes = self.mapping.encode(msg) - LOG.log( - DUMP_BYTES, - f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", - ) - self.transport.write(msg_type, msg_bytes) - - def _raw_read(self) -> MessageType: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - msg_type, msg_bytes = self.transport.read() - LOG.log( - DUMP_BYTES, - f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", - ) - msg = self.mapping.decode(msg_type, msg_bytes) - LOG.debug( - f"received message: {msg.__class__.__name__}", - extra={"protobuf": msg}, - ) - return msg - - def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType: - try: - pin = self.ui.get_pin(msg.type) - except exceptions.Cancelled: - self.call_raw(messages.Cancel()) - raise - - if any(d not in "123456789" for d in pin) or not ( - 1 <= len(pin) <= MAX_PIN_LENGTH - ): - self.call_raw(messages.Cancel()) - raise ValueError("Invalid PIN provided") - - resp = self.call_raw(messages.PinMatrixAck(pin=pin)) - if isinstance(resp, messages.Failure) and resp.code in ( - messages.FailureType.PinInvalid, - messages.FailureType.PinCancelled, - messages.FailureType.PinExpected, - ): - raise exceptions.PinException(resp.code, resp.message) + + if protobuf_mapping is None: + self.mapping = mapping.DEFAULT_MAPPING else: - return resp - - def _callback_passphrase(self, msg: messages.PassphraseRequest) -> MessageType: - available_on_device = Capability.PassphraseEntry in self.features.capabilities - - def send_passphrase( - passphrase: Optional[str] = None, on_device: Optional[bool] = None - ) -> MessageType: - msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) - resp = self.call_raw(msg) - if isinstance(resp, messages.Deprecated_PassphraseStateRequest): - self.session_id = resp.state - resp = self.call_raw(messages.Deprecated_PassphraseStateAck()) - return resp - - # short-circuit old style entry - if msg._on_device is True: - return send_passphrase(None, None) - - try: - passphrase = self.ui.get_passphrase(available_on_device=available_on_device) - except exceptions.Cancelled: - self.call_raw(messages.Cancel()) - raise - - if passphrase is PASSPHRASE_ON_DEVICE: - if not available_on_device: - self.call_raw(messages.Cancel()) - raise RuntimeError("Device is not capable of entering passphrase") - else: - return send_passphrase(on_device=True) - - # else process host-entered passphrase - if not isinstance(passphrase, str): - raise RuntimeError("Passphrase must be a str") - passphrase = Mnemonic.normalize_string(passphrase) - if len(passphrase) > MAX_PASSPHRASE_LENGTH: - self.call_raw(messages.Cancel()) - raise ValueError("Passphrase too long") - - return send_passphrase(passphrase, on_device=False) - - def _callback_button(self, msg: messages.ButtonRequest) -> MessageType: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - # do this raw - send ButtonAck first, notify UI later - self._raw_write(messages.ButtonAck()) - self.ui.button_request(msg) - return self._raw_read() - - @session - def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT: - self.check_firmware_version() - resp = self.call_raw(msg) - while True: - if isinstance(resp, messages.PinMatrixRequest): - resp = self._callback_pin(resp) - elif isinstance(resp, messages.PassphraseRequest): - resp = self._callback_passphrase(resp) - elif isinstance(resp, messages.ButtonRequest): - resp = self._callback_button(resp) - elif isinstance(resp, messages.Failure): - if resp.code == messages.FailureType.ActionCancelled: - raise exceptions.Cancelled - raise exceptions.TrezorFailure(resp) - elif not isinstance(resp, expect): - raise exceptions.UnexpectedMessageError(expect, resp) + self.mapping = protobuf_mapping + if protocol is None: + self.protocol = self._get_protocol() + else: + self.protocol = protocol + self.protocol.mapping = self.mapping + + if isinstance(self.protocol, ProtocolV1): + self._protocol_version = ProtocolVersion.PROTOCOL_V1 + elif isinstance(self.protocol, ProtocolV2): + self._protocol_version = ProtocolVersion.PROTOCOL_V2 + else: + self._protocol_version = ProtocolVersion.UNKNOWN + + @classmethod + def resume( + cls, + transport: Transport, + channel_data: ChannelData, + protobuf_mapping: ProtobufMapping | None = None, + ) -> TrezorClient: + if protobuf_mapping is None: + protobuf_mapping = mapping.DEFAULT_MAPPING + protocol_v1 = ProtocolV1(transport, protobuf_mapping) + if channel_data.protocol_version_major == 2: + try: + protocol_v1.write(messages.Ping(message="Sanity check - to resume")) + except Exception as e: + print(type(e)) + response = protocol_v1.read() + if ( + isinstance(response, messages.Failure) + and response.code == messages.FailureType.InvalidProtocol + ): + protocol = ProtocolV2(transport, protobuf_mapping, channel_data) + protocol.write(0, messages.Ping()) + response = protocol.read(0) + if not isinstance(response, messages.Success): + LOG.debug("Failed to resume ProtocolV2") + raise Exception("Failed to resume ProtocolV2") + LOG.debug("Protocol V2 detected - can be resumed") else: - return resp - - def _refresh_features(self, features: messages.Features) -> None: - """Update internal fields based on passed-in Features message.""" - - if not self.model: - # Trezor Model One bootloader 1.8.0 or older does not send model name - model = models.by_internal_name(features.internal_model) - if model is None: - model = models.by_name(features.model or "1") - if model is None: - raise RuntimeError( - "Unsupported Trezor model" - f" (internal_model: {features.internal_model}, model: {features.model})" - ) - self.model = model - - if features.vendor not in self.model.vendors: - raise RuntimeError("Unsupported device") - - self.features = features - self.version = ( - self.features.major_version, - self.features.minor_version, - self.features.patch_version, - ) - self.check_firmware_version(warn_only=True) - if self.features.session_id is not None: - self.session_id = self.features.session_id - self.features.session_id = None + LOG.debug("Failed to resume ProtocolV2") + raise Exception("Failed to resume ProtocolV2") + else: + protocol = ProtocolV1(transport, protobuf_mapping, channel_data) + return TrezorClient(transport, protobuf_mapping, protocol) - @session - def refresh_features(self) -> messages.Features: - """Reload features from the device. + def get_session( + self, + passphrase: str | object | None = None, + derive_cardano: bool = False, + session_id: int = 0, + ) -> Session: + """ + Returns initialized session (with derived seed). - Should be called after changing settings or performing operations that affect - device state. + Will fail if the device is not initialized """ - resp = self.call_raw(messages.GetFeatures()) - if not isinstance(resp, messages.Features): - raise exceptions.TrezorException("Unexpected response to GetFeatures") - self._refresh_features(resp) - return resp - - @session - def init_device( - self, - *, - session_id: Optional[bytes] = None, - new_session: bool = False, - derive_cardano: Optional[bool] = None, - ) -> Optional[bytes]: - """Initialize the device and return a session ID. - - You can optionally specify a session ID. If the session still exists on the - device, the same session ID will be returned and the session is resumed. - Otherwise a different session ID is returned. - - Specify `new_session=True` to open a fresh session. Since firmware version - 1.9.0/2.3.0, the previous session will remain cached on the device, and can be - resumed by calling `init_device` again with the appropriate session ID. - - If neither `new_session` nor `session_id` is specified, the current session ID - will be reused. If no session ID was cached, a new session ID will be allocated - and returned. - - # Version notes: - - Trezor One older than 1.9.0 does not have session management. Optional arguments - have no effect and the function returns None - - Trezor T older than 2.3.0 does not have session cache. Requesting a new session - will overwrite the old one. In addition, this function will always return None. - A valid session_id can be obtained from the `session_id` attribute, but only - after a passphrase-protected call is performed. You can use the following code: - - >>> client.init_device() - >>> client.ensure_unlocked() - >>> valid_session_id = client.session_id + from .transport.session import SessionV1, SessionV2 + + if isinstance(self.protocol, ProtocolV1): + if passphrase is None: + passphrase = "" + return SessionV1.new(self, passphrase, derive_cardano) + if isinstance(self.protocol, ProtocolV2): + assert isinstance(passphrase, str) or passphrase is None + return SessionV2.new(self, passphrase, derive_cardano, session_id) + raise NotImplementedError # TODO + + def resume_session(self, session: Session): """ - if new_session: - self.session_id = None - elif session_id is not None: - self.session_id = session_id - - resp = self.call_raw( - messages.Initialize( - session_id=self.session_id, - derive_cardano=derive_cardano, + Note: this function potentially modifies the input session. + """ + from .debuglink import SessionDebugWrapper + from .transport.session import SessionV1, SessionV2 + + if isinstance(session, SessionDebugWrapper): + session = session._session + + if isinstance(session, SessionV2): + return session + elif isinstance(session, SessionV1): + session.init_session() + return session + + else: + raise NotImplementedError + + def get_seedless_session(self, new_session: bool = False) -> Session: + from .transport.session import SessionV1, SessionV2 + + if not new_session and self._seedless_session is not None: + return self._seedless_session + if isinstance(self.protocol, ProtocolV1): + self._seedless_session = SessionV1.new( + client=self, + passphrase="", + derive_cardano=False, ) + elif isinstance(self.protocol, ProtocolV2): + self._seedless_session = SessionV2(client=self, id=b"\x00") + assert self._seedless_session is not None + return self._seedless_session + + def invalidate(self) -> None: + self._is_invalidated = True + + @property + def features(self) -> messages.Features: + if self._features is None: + self._features = self.protocol.get_features() + assert self._features is not None + return self._features + + @property + def protocol_version(self) -> int: + return self._protocol_version + + @property + def model(self) -> models.TrezorModel: + f = self.features + model = models.by_name(f.model or "1") + + if model is None: + raise RuntimeError( + "Unsupported Trezor model" + f" (internal_model: {f.internal_model}, model: {f.model})" + ) + return model + + @property + def version(self) -> tuple[int, int, int]: + f = self.features + ver = ( + f.major_version, + f.minor_version, + f.patch_version, ) - if isinstance(resp, messages.Failure): - # can happen if `derive_cardano` does not match the current session - raise exceptions.TrezorFailure(resp) - if not isinstance(resp, messages.Features): - raise exceptions.TrezorException("Unexpected response to Initialize") - - if self.session_id is not None and resp.session_id == self.session_id: - LOG.info("Successfully resumed session") - elif session_id is not None: - LOG.info("Failed to resume session") - - # TT < 2.3.0 compatibility: - # _refresh_features will clear out the session_id field. We want this function - # to return its value, so that callers can rely on it being either a valid - # session_id, or None if we can't do that. - # Older TT FW does not report session_id in Features and self.session_id might - # be invalid because TT will not allocate a session_id until a passphrase - # exchange happens. - reported_session_id = resp.session_id - self._refresh_features(resp) - return reported_session_id - - def is_outdated(self) -> bool: - if self.features.bootloader_mode: - return False - return self.version < self.model.minimum_version - - def check_firmware_version(self, warn_only: bool = False) -> None: - if self.is_outdated(): - if warn_only: - warnings.warn("Firmware is out of date", stacklevel=2) - else: - raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) - - def ping(self, msg: str, button_protection: bool = False) -> str: - # We would like ping to work on any valid TrezorClient instance, but - # due to the protection modes, we need to go through self.call, and that will - # raise an exception if the firmware is too old. - # So we short-circuit the simplest variant of ping with call_raw. - if not button_protection: - # XXX this should be: `with self:` - try: - self.open() - resp = self.call_raw(messages.Ping(message=msg)) - if isinstance(resp, messages.ButtonRequest): - # device is PIN-locked. - # respond and hope for the best - resp = self._callback_button(resp) - resp = messages.Success.ensure_isinstance(resp) - assert resp.message is not None - return resp.message - finally: - self.close() - - resp = self.call( - messages.Ping(message=msg, button_protection=button_protection), - expect=messages.Success, - ) - assert resp.message is not None - return resp.message + return ver - def get_device_id(self) -> Optional[str]: - return self.features.device_id + @property + def is_invalidated(self) -> bool: + return self._is_invalidated - @session - def lock(self, *, _refresh_features: bool = True) -> None: - """Lock the device. + def refresh_features(self) -> None: + self.protocol.update_features() + self._features = self.protocol.get_features() - If the device does not have a PIN configured, this will do nothing. - Otherwise, a lock screen will be shown and the device will prompt for PIN - before further actions. + def _get_protocol(self) -> ProtocolAndChannel: + self.transport.open() - This call does _not_ invalidate passphrase cache. If passphrase is in use, - the device will not prompt for it after unlocking. + protocol = ProtocolV1(self.transport, mapping.DEFAULT_MAPPING) - To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate - passphrase cache, use `clear_session()`. - """ - # Private argument _refresh_features can be used internally to avoid - # refreshing in cases where we will refresh soon anyway. This is used - # in TrezorClient.clear_session() - self.call(messages.LockDevice()) - if _refresh_features: - self.refresh_features() - - @session - def ensure_unlocked(self) -> None: - """Ensure the device is unlocked and a passphrase is cached. - - If the device is locked, this will prompt for PIN. If passphrase is enabled - and no passphrase is cached for the current session, the device will also - prompt for passphrase. - - After calling this method, further actions on the device will not prompt for - PIN or passphrase until the device is locked or the session becomes invalid. - """ - from .btc import get_address + protocol.write(messages.Initialize()) - get_address(self, "Testnet", PASSPHRASE_TEST_PATH) - self.refresh_features() + response = protocol.read() + self.transport.close() + if isinstance(response, messages.Failure): + if response.code == messages.FailureType.InvalidProtocol: + LOG.debug("Protocol V2 detected") + protocol = ProtocolV2(self.transport, self.mapping) + return protocol - def end_session(self) -> None: - """Close the current session and clear cached passphrase. - The session will become invalid until `init_device()` is called again. - If passphrase is enabled, further actions will prompt for it again. +def get_default_client( + path: t.Optional[str] = None, + **kwargs: t.Any, +) -> "TrezorClient": + """Get a client for a connected Trezor device. - This is a no-op in bootloader mode, as it does not support session management. - """ - # since: 2.3.4, 1.9.4 - try: - if not self.features.bootloader_mode: - self.call(messages.EndSession()) - except exceptions.TrezorFailure: - # A failure most likely means that the FW version does not support - # the EndSession call. We ignore the failure and clear the local session_id. - # The client-side end result is identical. - pass - self.session_id = None - - @session - def clear_session(self) -> None: - """Lock the device and present a fresh session. - - The current session will be invalidated and a new one will be started. If the - device has PIN enabled, it will become locked. - - Equivalent to calling `lock()`, `end_session()` and `init_device()`. - """ - self.lock(_refresh_features=False) - self.end_session() - self.init_device(new_session=True) + Returns a TrezorClient instance with minimum fuss. + + If path is specified, does a prefix-search for the specified device. Otherwise, uses + the value of TREZOR_PATH env variable, or finds first connected Trezor. + If no UI is supplied, instantiates the default CLI UI. + """ + + if path is None: + path = os.getenv("TREZOR_PATH") + + transport = get_transport(path, prefix_search=True) + + return TrezorClient(transport, **kwargs) diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 0a2096993b2..0c63c30b652 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -21,55 +21,55 @@ import re import textwrap import time +import typing as t from contextlib import contextmanager from copy import deepcopy from datetime import datetime from enum import Enum, IntEnum, auto from itertools import zip_longest from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Generator, - Iterable, - Iterator, - Sequence, - Tuple, - Union, -) from mnemonic import Mnemonic -from . import mapping, messages, models, protobuf -from .client import TrezorClient -from .exceptions import TrezorFailure, UnexpectedMessageError +from . import btc, mapping, messages, models, protobuf +from .client import ( + MAX_PASSPHRASE_LENGTH, + MAX_PIN_LENGTH, + PASSPHRASE_ON_DEVICE, + TrezorClient, +) +from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError from .log import DUMP_BYTES -from .messages import DebugWaitType +from .messages import Capability, DebugWaitType +from .protobuf import MessageType +from .tools import parse_path +from .transport.session import Session, SessionV1 +from .transport.thp.protocol_v1 import ProtocolV1 -if TYPE_CHECKING: +if t.TYPE_CHECKING: from typing_extensions import Protocol from .messages import PinMatrixRequestType from .transport import Transport - ExpectedMessage = Union[ - protobuf.MessageType, type[protobuf.MessageType], "MessageFilter" + ExpectedMessage = t.Union[ + protobuf.MessageType, t.Type[protobuf.MessageType], "MessageFilter" ] - AnyDict = Dict[str, Any] + AnyDict = t.Dict[str, t.Any] class InputFunc(Protocol): + def __call__( self, hold_ms: int | None = None, ) -> "None": ... - InputFlowType = Generator[None, messages.ButtonRequest, None] + InputFlowType = t.Generator[None, messages.ButtonRequest, None] EXPECTED_RESPONSES_CONTEXT_LINES = 3 +PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0") LOG = logging.getLogger(__name__) @@ -107,11 +107,11 @@ def __init__(self, json_str: str) -> None: except json.JSONDecodeError: self.dict = {} - def top_level_value(self, key: str) -> Any: + def top_level_value(self, key: str) -> t.Any: return self.dict.get(key) - def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]: - def recursively_find(data: Any) -> Iterator[Any]: + def find_objects_with_key_and_value(self, key: str, value: t.Any) -> list[AnyDict]: + def recursively_find(data: t.Any) -> t.Iterator[t.Any]: if isinstance(data, dict): if data.get(key) == value: yield data @@ -124,7 +124,7 @@ def recursively_find(data: Any) -> Iterator[Any]: return list(recursively_find(self.dict)) def find_unique_object_with_key_and_value( - self, key: str, value: Any + self, key: str, value: t.Any ) -> AnyDict | None: objects = self.find_objects_with_key_and_value(key, value) if not objects: @@ -132,8 +132,10 @@ def find_unique_object_with_key_and_value( assert len(objects) == 1 return objects[0] - def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]: - def recursively_find(data: Any) -> Iterator[Any]: + def find_values_by_key( + self, key: str, only_type: type | None = None + ) -> list[t.Any]: + def recursively_find(data: t.Any) -> t.Iterator[t.Any]: if isinstance(data, dict): if key in data: yield data[key] @@ -151,8 +153,8 @@ def recursively_find(data: Any) -> Iterator[Any]: return values def find_unique_value_by_key( - self, key: str, default: Any, only_type: type | None = None - ) -> Any: + self, key: str, default: t.Any, only_type: type | None = None + ) -> t.Any: values = self.find_values_by_key(key, only_type=only_type) if not values: return default @@ -163,7 +165,7 @@ def find_unique_value_by_key( class LayoutContent(UnstructuredJSONReader): """Contains helper functions to extract specific parts of the layout.""" - def __init__(self, json_tokens: Sequence[str]) -> None: + def __init__(self, json_tokens: t.Sequence[str]) -> None: json_str = "".join(json_tokens) super().__init__(json_str) @@ -429,6 +431,7 @@ def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: self.allow_interactions = auto_interact self.mapping = mapping.DEFAULT_MAPPING + self.protocol = ProtocolV1(self.transport, self.mapping) # To be set by TrezorClientDebugLink (is not known during creation time) self.model: models.TrezorModel | None = None self.version: tuple[int, int, int] = (0, 0, 0) @@ -471,10 +474,16 @@ def layout_type(self) -> LayoutType: return LayoutType.from_model(self.model) def open(self) -> None: - self.transport.begin_session() + self.transport.open() + # raise NotImplementedError + # TODO is this needed? + # self.transport.deprecated_begin_session() def close(self) -> None: - self.transport.end_session() + pass + # raise NotImplementedError + # TODO is this needed? + # self.transport.deprecated_end_session() def _write(self, msg: protobuf.MessageType) -> None: if self.waiting_for_layout_change: @@ -491,15 +500,10 @@ def _write(self, msg: protobuf.MessageType) -> None: DUMP_BYTES, f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", ) - self.transport.write(msg_type, msg_bytes) + self.protocol.write(msg) def _read(self) -> protobuf.MessageType: - ret_type, ret_bytes = self.transport.read() - LOG.log( - DUMP_BYTES, - f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}", - ) - msg = self.mapping.decode(ret_type, ret_bytes) + msg = self.protocol.read() # Collapse tokens to make log use less lines. msg_for_log = msg @@ -513,7 +517,7 @@ def _read(self) -> protobuf.MessageType: ) return msg - def _call(self, msg: protobuf.MessageType) -> Any: + def _call(self, msg: protobuf.MessageType) -> t.Any: self._write(msg) return self._read() @@ -531,6 +535,25 @@ def state(self, wait_type: DebugWaitType | None = None) -> messages.DebugLinkSta raise TrezorFailure(result) return result + def pairing_info( + self, + thp_channel_id: bytes | None = None, + handshake_hash: bytes | None = None, + nfc_secret_host: bytes | None = None, + ) -> messages.DebugLinkPairingInfo: + result = self._call( + messages.DebugLinkGetPairingInfo( + channel_id=thp_channel_id, + handshake_hash=handshake_hash, + nfc_secret_host=nfc_secret_host, + ) + ) + while not isinstance(result, (messages.Failure, messages.DebugLinkPairingInfo)): + result = self._read() + if isinstance(result, messages.Failure): + raise TrezorFailure(result) + return result + def read_layout(self, wait: bool | None = None) -> LayoutContent: """ Force waiting for the layout by setting `wait=True`. Force not waiting by @@ -547,7 +570,7 @@ def read_layout(self, wait: bool | None = None) -> LayoutContent: def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent: # Next layout change will be caused by external event - # (e.g. device being auto-locked or as a result of device_handler.run(xxx)) + # (e.g. device being auto-locked or as a result of device_handler.run_with_session(xxx)) # and not by our debug actions/decisions. # Resetting the debug state so we wait for the next layout change # (and do not return the current state). @@ -562,7 +585,7 @@ def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent: return LayoutContent(obj.tokens) @contextmanager - def wait_for_layout_change(self) -> Iterator[None]: + def wait_for_layout_change(self) -> t.Iterator[None]: # make sure some current layout is up by issuing a dummy GetState self.state() @@ -615,7 +638,7 @@ def encode_pin(self, pin: str, matrix: str | None = None) -> str: return "".join([str(matrix.index(p) + 1) for p in pin]) - def read_recovery_word(self) -> Tuple[str | None, int | None]: + def read_recovery_word(self) -> t.Tuple[str | None, int | None]: state = self.state() return (state.recovery_fake_word, state.recovery_word_pos) @@ -671,7 +694,7 @@ def input(self, word: str) -> None: """Send text input to the device. See `_decision` for more details.""" self._decision(messages.DebugLinkDecision(input=word)) - def click(self, click: Tuple[int, int], hold_ms: int | None = None) -> None: + def click(self, click: t.Tuple[int, int], hold_ms: int | None = None) -> None: """Send a click to the device. See `_decision` for more details.""" x, y = click self._decision(messages.DebugLinkDecision(x=x, y=y, hold_ms=hold_ms)) @@ -794,10 +817,10 @@ def __init__(self, debuglink: DebugLink) -> None: self.clear() def clear(self) -> None: - self.pins: Iterator[str] | None = None + self.pins: t.Iterator[str] | None = None self.passphrase = "" - self.input_flow: Union[ - Generator[None, messages.ButtonRequest, None], object, None + self.input_flow: t.Union[ + t.Generator[None, messages.ButtonRequest, None], object, None ] = None def _default_input_flow(self, br: messages.ButtonRequest) -> None: @@ -829,7 +852,7 @@ def button_request(self, br: messages.ButtonRequest) -> None: raise AssertionError("input flow ended prematurely") else: try: - assert isinstance(self.input_flow, Generator) + assert isinstance(self.input_flow, t.Generator) self.input_flow.send(br) except StopIteration: self.input_flow = self.INPUT_FLOW_DONE @@ -851,12 +874,15 @@ def get_passphrase(self, available_on_device: bool) -> str: class MessageFilter: - def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None: + + def __init__( + self, message_type: t.Type[protobuf.MessageType], **fields: t.Any + ) -> None: self.message_type = message_type - self.fields: Dict[str, Any] = {} + self.fields: t.Dict[str, t.Any] = {} self.update_fields(**fields) - def update_fields(self, **fields: Any) -> "MessageFilter": + def update_fields(self, **fields: t.Any) -> "MessageFilter": for name, value in fields.items(): try: self.fields[name] = self.from_message_or_type(value) @@ -904,7 +930,7 @@ def match(self, message: protobuf.MessageType) -> bool: return True def to_string(self, maxwidth: int = 80) -> str: - fields: list[Tuple[str, str]] = [] + fields: list[t.Tuple[str, str]] = [] for field in self.message_type.FIELDS.values(): if field.name not in self.fields: continue @@ -934,7 +960,7 @@ def to_string(self, maxwidth: int = 80) -> str: class MessageFilterGenerator: - def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]: + def __getattr__(self, key: str) -> t.Callable[..., "MessageFilter"]: message_type = getattr(messages, key) return MessageFilter(message_type).update_fields @@ -942,6 +968,245 @@ def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]: message_filters = MessageFilterGenerator() +class SessionDebugWrapper(Session): + def __init__(self, session: Session) -> None: + self._session = session + self.reset_debug_features() + if isinstance(session, SessionDebugWrapper): + raise Exception("Cannot wrap already wrapped session!") + + @property + def protocol_version(self) -> int: + return self.client.protocol_version + + @property + def client(self) -> TrezorClientDebugLink: + assert isinstance(self._session.client, TrezorClientDebugLink) + return self._session.client + + @property + def id(self) -> bytes: + return self._session.id + + def _write(self, msg: t.Any) -> None: + print("writing message:", msg.__class__.__name__) + self._session._write(self._filter_message(msg)) + + def _read(self) -> t.Any: + resp = self._filter_message(self._session._read()) + print("reading message:", resp.__class__.__name__) + if self.actual_responses is not None: + self.actual_responses.append(resp) + return resp + + def set_expected_responses( + self, + expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]], + ) -> None: + """Set a sequence of expected responses to session calls. + + Within a given with-block, the list of received responses from device must + match the list of expected responses, otherwise an ``AssertionError`` is raised. + + If an expected response is given a field value other than ``None``, that field value + must exactly match the received field value. If a given field is ``None`` + (or unspecified) in the expected response, the received field value is not + checked. + + Each expected response can also be a tuple ``(bool, message)``. In that case, the + expected response is only evaluated if the first field is ``True``. + This is useful for differentiating sequences between Trezor models: + + >>> trezor_one = session.features.model == "1" + >>> session.set_expected_responses([ + >>> messages.ButtonRequest(code=ConfirmOutput), + >>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)), + >>> messages.Success(), + >>> ]) + """ + if not self.in_with_statement: + raise RuntimeError("Must be called inside 'with' statement") + + # make sure all items are (bool, message) tuples + expected_with_validity = ( + e if isinstance(e, tuple) else (True, e) for e in expected + ) + + # only apply those items that are (True, message) + self.expected_responses = [ + MessageFilter.from_message_or_type(expected) + for valid, expected in expected_with_validity + if valid + ] + self.actual_responses = [] + + def lock(self, *, _refresh_features: bool = True) -> None: + """Lock the device. + + If the device does not have a PIN configured, this will do nothing. + Otherwise, a lock screen will be shown and the device will prompt for PIN + before further actions. + + This call does _not_ invalidate passphrase cache. If passphrase is in use, + the device will not prompt for it after unlocking. + + To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate + passphrase cache, use `clear_session()`. + """ + # TODO update the documentation above + # Private argument _refresh_features can be used internally to avoid + # refreshing in cases where we will refresh soon anyway. This is used + # in TrezorClient.clear_session() + self.call(messages.LockDevice()) + if _refresh_features: + self.refresh_features() + + def cancel(self) -> None: + self._write(messages.Cancel()) + + def ensure_unlocked(self) -> None: + btc.get_address(self, "Testnet", PASSPHRASE_TEST_PATH) + self.refresh_features() + + def set_filter( + self, + message_type: t.Type[protobuf.MessageType], + callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, + ) -> None: + """Configure a filter function for a specified message type. + + The `callback` must be a function that accepts a protobuf message, and returns + a (possibly modified) protobuf message of the same type. Whenever a message + is sent or received that matches `message_type`, `callback` is invoked on the + message and its result is substituted for the original. + + Useful for test scenarios with an active malicious actor on the wire. + """ + if not self.in_with_statement: + raise RuntimeError("Must be called inside 'with' statement") + + self.filters[message_type] = callback + + def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType: + message_type = msg.__class__ + callback = self.filters.get(message_type) + if callable(callback): + return callback(deepcopy(msg)) + else: + return msg + + def reset_debug_features(self) -> None: + """Prepare the debugging session for a new testcase. + + Clears all debugging state that might have been modified by a testcase. + """ + self.in_with_statement = False + self.expected_responses: list[MessageFilter] | None = None + self.actual_responses: list[protobuf.MessageType] | None = None + self.filters: t.Dict[ + t.Type[protobuf.MessageType], + t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, + ] = {} + self.button_callback = self.client.button_callback + self.pin_callback = self.client.pin_callback + self.passphrase_callback = self._session.passphrase_callback + self.passphrase = self._session.passphrase + + def __enter__(self) -> "SessionDebugWrapper": + # For usage in with/expected_responses + if self.in_with_statement: + raise RuntimeError("Do not nest!") + self.in_with_statement = True + return self + + def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + + # copy expected/actual responses before clearing them + expected_responses = self.expected_responses + actual_responses = self.actual_responses + + # grab a copy of the inputflow generator to raise an exception through it + if isinstance(self.client.ui, DebugUI): + input_flow = self.client.ui.input_flow + else: + input_flow = None + + self.reset_debug_features() + + if exc_type is None: + # If no other exception was raised, evaluate missed responses + # (raises AssertionError on mismatch) + self._verify_responses(expected_responses, actual_responses) + if isinstance(input_flow, t.Generator): + # Ensure that the input flow is exhausted + try: + input_flow.throw( + AssertionError("input flow continues past end of test") + ) + except StopIteration: + pass + + elif isinstance(input_flow, t.Generator): + # Propagate the exception through the input flow, so that we see in + # traceback where it is stuck. + input_flow.throw(exc_type, value, traceback) + + @classmethod + def _verify_responses( + cls, + expected: list[MessageFilter] | None, + actual: list[protobuf.MessageType] | None, + ) -> None: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + + if expected is None and actual is None: + return + + assert expected is not None + assert actual is not None + + for i, (exp, act) in enumerate(zip_longest(expected, actual)): + if exp is None: + output = cls._expectation_lines(expected, i) + output.append("No more messages were expected, but we got:") + for resp in actual[i:]: + output.append( + textwrap.indent(protobuf.format_message(resp), " ") + ) + raise AssertionError("\n".join(output)) + + if act is None: + output = cls._expectation_lines(expected, i) + output.append("This and the following message was not received.") + raise AssertionError("\n".join(output)) + + if not exp.match(act): + output = cls._expectation_lines(expected, i) + output.append("Actually received:") + output.append(textwrap.indent(protobuf.format_message(act), " ")) + raise AssertionError("\n".join(output)) + + @staticmethod + def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: + start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) + stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) + output: list[str] = [] + output.append("Expected responses:") + if start_at > 0: + output.append(f" (...{start_at} previous responses omitted)") + for i in range(start_at, stop_at): + exp = expected[i] + prefix = " " if i != current else ">>> " + output.append(textwrap.indent(exp.to_string(), prefix)) + if stop_at < len(expected): + omitted = len(expected) - stop_at + output.append(f" (...{omitted} following responses omitted)") + + output.append("") + return output + + class TrezorClientDebugLink(TrezorClient): # This class implements automatic responses # and other functionality for unit tests @@ -967,23 +1232,34 @@ def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: raise # set transport explicitly so that sync_responses can work + super().__init__(transport) + self.transport = transport + self.ui: DebugUI = DebugUI(self.debug) - self.reset_debug_features() + self.reset_debug_features(new_seedless_session=True) self.sync_responses() - super().__init__(transport, ui=self.ui) # So that we can choose right screenshotting logic (T1 vs TT) # and know the supported debug capabilities self.debug.model = self.model self.debug.version = self.version + self.passphrase: str | None = None @property def layout_type(self) -> LayoutType: return self.debug.layout_type - def reset_debug_features(self) -> None: - """Prepare the debugging client for a new testcase. + def get_new_client(self) -> TrezorClientDebugLink: + new_client = TrezorClientDebugLink( + self.transport, self.debug.allow_interactions + ) + new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir + return new_client + + def reset_debug_features(self, new_seedless_session: bool = False) -> None: + """ + Prepare the debugging client for a new testcase. Clears all debugging state that might have been modified by a testcase. """ @@ -991,30 +1267,139 @@ def reset_debug_features(self) -> None: self.in_with_statement = False self.expected_responses: list[MessageFilter] | None = None self.actual_responses: list[protobuf.MessageType] | None = None - self.filters: dict[ - type[protobuf.MessageType], - Callable[[protobuf.MessageType], protobuf.MessageType] | None, + self.filters: t.Dict[ + t.Type[protobuf.MessageType], + t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, ] = {} + if new_seedless_session: + self._seedless_session = self.get_seedless_session(new_session=True) + + @property + def button_callback(self): + + def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + # do this raw - send ButtonAck first, notify UI later + session._write(messages.ButtonAck()) + self.ui.button_request(msg) + return session._read() + + return _callback_button + + @property + def pin_callback(self): + + def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any: + try: + pin = self.ui.get_pin(msg.type) + except Cancelled: + session.call_raw(messages.Cancel()) + raise + + if any(d not in "123456789" for d in pin) or not ( + 1 <= len(pin) <= MAX_PIN_LENGTH + ): + session.call_raw(messages.Cancel()) + raise ValueError("Invalid PIN provided") + resp = session.call_raw(messages.PinMatrixAck(pin=pin)) + if isinstance(resp, messages.Failure) and resp.code in ( + messages.FailureType.PinInvalid, + messages.FailureType.PinCancelled, + messages.FailureType.PinExpected, + ): + raise PinException(resp.code, resp.message) + else: + return resp + + return _callback_pin + + @property + def passphrase_callback(self): + def _callback_passphrase( + session: Session, msg: messages.PassphraseRequest + ) -> t.Any: + available_on_device = ( + Capability.PassphraseEntry in session.features.capabilities + ) + + def send_passphrase( + passphrase: str | None = None, on_device: bool | None = None + ) -> MessageType: + msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) + resp = session.call_raw(msg) + if isinstance(resp, messages.Deprecated_PassphraseStateRequest): + # session.session_id = resp.state + resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) + return resp + + # short-circuit old style entry + if msg._on_device is True: + return send_passphrase(None, None) + + try: + if session.passphrase is None and isinstance(session, SessionV1): + passphrase = self.ui.get_passphrase( + available_on_device=available_on_device + ) + else: + passphrase = session.passphrase + except Cancelled: + session.call_raw(messages.Cancel()) + raise + + if passphrase is PASSPHRASE_ON_DEVICE: + if not available_on_device: + session.call_raw(messages.Cancel()) + raise RuntimeError("Device is not capable of entering passphrase") + else: + return send_passphrase(on_device=True) + + # else process host-entered passphrase + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") + passphrase = Mnemonic.normalize_string(passphrase) + if len(passphrase) > MAX_PASSPHRASE_LENGTH: + session.call_raw(messages.Cancel()) + raise ValueError("Passphrase too long") + + return send_passphrase(passphrase, on_device=False) + + return _callback_passphrase def ensure_open(self) -> None: """Only open session if there isn't already an open one.""" - if self.session_counter == 0: - self.open() + # if self.session_counter == 0: + # self.open() + # TODO check if is this needed def open(self) -> None: - super().open() - if self.session_counter == 1: - self.debug.open() + pass + # TODO is this needed? + # self.debug.open() def close(self) -> None: - if self.session_counter == 1: - self.debug.close() - super().close() + pass + # TODO is this needed? + # self.debug.close() + + def lock(self) -> None: + s = SessionDebugWrapper(self.get_seedless_session()) + s.lock() + + def get_session( + self, + passphrase: str | object | None = "", + derive_cardano: bool = False, + session_id: int = 0, + ) -> Session: + if isinstance(passphrase, str): + passphrase = Mnemonic.normalize_string(passphrase) + return super().get_session(passphrase, derive_cardano, session_id) def set_filter( self, - message_type: type[protobuf.MessageType], - callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None, + message_type: t.Type[protobuf.MessageType], + callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, ) -> None: """Configure a filter function for a specified message type. @@ -1039,7 +1424,7 @@ def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType: return msg def set_input_flow( - self, input_flow: InputFlowType | Callable[[], InputFlowType] + self, input_flow: InputFlowType | t.Callable[[], InputFlowType] ) -> None: """Configure a sequence of input events for the current with-block. @@ -1095,7 +1480,7 @@ def __enter__(self) -> "TrezorClientDebugLink": self.in_with_statement = True return self - def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None: + def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 # copy expected/actual responses before clearing them @@ -1108,21 +1493,23 @@ def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None: else: input_flow = None - self.reset_debug_features() + self.reset_debug_features(new_seedless_session=False) if exc_type is None: # If no other exception was raised, evaluate missed responses # (raises AssertionError on mismatch) self._verify_responses(expected_responses, actual_responses) - elif isinstance(input_flow, Generator): + elif isinstance(input_flow, t.Generator): # Propagate the exception through the input flow, so that we see in # traceback where it is stuck. input_flow.throw(exc_type, value, traceback) def set_expected_responses( self, - expected: Sequence[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]], + expected: t.Sequence[ + t.Union["ExpectedMessage", t.Tuple[bool, "ExpectedMessage"]] + ], ) -> None: """Set a sequence of expected responses to client calls. @@ -1161,7 +1548,7 @@ def set_expected_responses( ] self.actual_responses = [] - def use_pin_sequence(self, pins: Iterable[str]) -> None: + def use_pin_sequence(self, pins: t.Iterable[str]) -> None: """Respond to PIN prompts from device with the provided PINs. The sequence must be at least as long as the expected number of PIN prompts. """ @@ -1169,6 +1556,7 @@ def use_pin_sequence(self, pins: Iterable[str]) -> None: def use_passphrase(self, passphrase: str) -> None: """Respond to passphrase prompts from device with the provided passphrase.""" + self.passphrase = passphrase self.ui.passphrase = Mnemonic.normalize_string(passphrase) def use_mnemonic(self, mnemonic: str) -> None: @@ -1178,15 +1566,14 @@ def use_mnemonic(self, mnemonic: str) -> None: def _raw_read(self) -> protobuf.MessageType: __tracebackhide__ = True # for pytest # pylint: disable=W0612 - - resp = super()._raw_read() + resp = self.get_seedless_session()._read() resp = self._filter_message(resp) if self.actual_responses is not None: self.actual_responses.append(resp) return resp def _raw_write(self, msg: protobuf.MessageType) -> None: - return super()._raw_write(self._filter_message(msg)) + return self.get_seedless_session()._write(self._filter_message(msg)) @staticmethod def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: @@ -1256,23 +1643,25 @@ def sync_responses(self) -> None: # Start by canceling whatever is on screen. This will work to cancel T1 PIN # prompt, which is in TINY mode and does not respond to `Ping`. - cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel()) - self.transport.begin_session() + # TODO REMOVE: cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel()) + self.transport.open() try: - self.transport.write(*cancel_msg) - + # self.protocol.write(messages.Cancel()) message = "SYNC" + secrets.token_hex(8) - ping_msg = mapping.DEFAULT_MAPPING.encode(messages.Ping(message=message)) - self.transport.write(*ping_msg) + self.get_seedless_session()._write(messages.Ping(message=message)) resp = None while resp != messages.Success(message=message): - msg_id, msg_bytes = self.transport.read() try: - resp = mapping.DEFAULT_MAPPING.decode(msg_id, msg_bytes) + resp = self.get_seedless_session()._read() + + raise Exception + except Exception: pass + finally: - self.transport.end_session() + pass # TODO fix + # self.transport.end_session(self.session_id or b"") def mnemonic_callback(self, _) -> str: word, pos = self.debug.read_recovery_word() @@ -1285,8 +1674,8 @@ def mnemonic_callback(self, _) -> str: def load_device( - client: "TrezorClient", - mnemonic: Union[str, Iterable[str]], + session: "Session", + mnemonic: str | t.Iterable[str], pin: str | None, passphrase_protection: bool, label: str | None, @@ -1299,12 +1688,12 @@ def load_device( mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic] - if client.features.initialized: + if session.features.initialized: raise RuntimeError( "Device is initialized already. Call device.wipe() and try again." ) - client.call( + session.call( messages.LoadDevice( mnemonics=mnemonics, pin=pin, @@ -1316,18 +1705,18 @@ def load_device( ), expect=messages.Success, ) - client.init_device() + session.refresh_features() # keep the old name for compatibility load_device_by_mnemonic = load_device -def prodtest_t1(client: "TrezorClient") -> None: - if client.features.bootloader_mode is not True: +def prodtest_t1(session: "Session") -> None: + if session.features.bootloader_mode is not True: raise RuntimeError("Device must be in bootloader mode") - client.call( + session.call( messages.ProdTestT1( payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC" ), @@ -1337,8 +1726,8 @@ def prodtest_t1(client: "TrezorClient") -> None: def record_screen( debug_client: "TrezorClientDebugLink", - directory: Union[str, None], - report_func: Union[Callable[[str], None], None] = None, + directory: str | None, + report_func: t.Callable[[str], None] | None = None, ) -> None: """Record screen changes into a specified directory. @@ -1383,5 +1772,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool: return debug_client.features.fw_vendor == "EMULATOR" -def optiga_set_sec_max(client: "TrezorClient") -> None: - client.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success) +def optiga_set_sec_max(session: "Session") -> None: + session.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success) diff --git a/python/src/trezorlib/device.py b/python/src/trezorlib/device.py index c08d485ed0b..a3b24c247da 100644 --- a/python/src/trezorlib/device.py +++ b/python/src/trezorlib/device.py @@ -28,16 +28,10 @@ from . import messages from .exceptions import Cancelled, TrezorException -from .tools import ( - Address, - _deprecation_retval_helper, - _return_success, - parse_path, - session, -) +from .tools import Address, _deprecation_retval_helper, _return_success, parse_path if TYPE_CHECKING: - from .client import TrezorClient + from .transport.session import Session RECOVERY_BACK = "\x08" # backspace character, sent literally @@ -46,9 +40,8 @@ ENTROPY_CHECK_MIN_VERSION = (2, 8, 7) -@session def apply_settings( - client: "TrezorClient", + session: "Session", label: Optional[str] = None, language: Optional[str] = None, use_passphrase: Optional[bool] = None, @@ -79,13 +72,13 @@ def apply_settings( haptic_feedback=haptic_feedback, ) - out = client.call(settings, expect=messages.Success) - client.refresh_features() + out = session.call(settings, expect=messages.Success) + session.refresh_features() return _return_success(out) def _send_language_data( - client: "TrezorClient", + session: "Session", request: "messages.TranslationDataRequest", language_data: bytes, ) -> None: @@ -95,69 +88,63 @@ def _send_language_data( data_length = response.data_length data_offset = response.data_offset chunk = language_data[data_offset : data_offset + data_length] - response = client.call(messages.TranslationDataAck(data_chunk=chunk)) + response = session.call(messages.TranslationDataAck(data_chunk=chunk)) -@session def change_language( - client: "TrezorClient", + session: "Session", language_data: bytes, show_display: bool | None = None, ) -> str | None: data_length = len(language_data) msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display) - response = client.call(msg) + response = session.call(msg) if data_length > 0: response = messages.TranslationDataRequest.ensure_isinstance(response) - _send_language_data(client, response, language_data) + _send_language_data(session, response, language_data) else: messages.Success.ensure_isinstance(response) - client.refresh_features() # changing the language in features + session.refresh_features() # changing the language in features return _return_success(messages.Success(message="Language changed.")) -@session -def apply_flags(client: "TrezorClient", flags: int) -> str | None: - out = client.call(messages.ApplyFlags(flags=flags), expect=messages.Success) - client.refresh_features() +def apply_flags(session: "Session", flags: int) -> str | None: + out = session.call(messages.ApplyFlags(flags=flags), expect=messages.Success) + session.refresh_features() return _return_success(out) -@session -def change_pin(client: "TrezorClient", remove: bool = False) -> str | None: - ret = client.call(messages.ChangePin(remove=remove), expect=messages.Success) - client.refresh_features() +def change_pin(session: "Session", remove: bool = False) -> str | None: + ret = session.call(messages.ChangePin(remove=remove), expect=messages.Success) + session.refresh_features() return _return_success(ret) -@session -def change_wipe_code(client: "TrezorClient", remove: bool = False) -> str | None: - ret = client.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success) - client.refresh_features() +def change_wipe_code(session: "Session", remove: bool = False) -> str | None: + ret = session.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success) + session.refresh_features() return _return_success(ret) -@session def sd_protect( - client: "TrezorClient", operation: messages.SdProtectOperationType + session: "Session", operation: messages.SdProtectOperationType ) -> str | None: - ret = client.call(messages.SdProtect(operation=operation), expect=messages.Success) - client.refresh_features() + ret = session.call(messages.SdProtect(operation=operation), expect=messages.Success) + session.refresh_features() return _return_success(ret) -@session -def wipe(client: "TrezorClient") -> str | None: - ret = client.call(messages.WipeDevice(), expect=messages.Success) - if not client.features.bootloader_mode: - client.init_device() +def wipe(session: "Session") -> str | None: + ret = session.call(messages.WipeDevice(), expect=messages.Success) + session.invalidate() + # if not session.features.bootloader_mode: + # session.refresh_features() return _return_success(ret) -@session def recover( - client: "TrezorClient", + session: "Session", word_count: int = 24, passphrase_protection: bool = False, pin_protection: bool = True, @@ -193,13 +180,13 @@ def recover( if type is None: type = messages.RecoveryType.NormalRecovery - if client.features.model == "1" and input_callback is None: + if session.features.model == "1" and input_callback is None: raise RuntimeError("Input callback required for Trezor One") if word_count not in (12, 18, 24): raise ValueError("Invalid word count. Use 12/18/24") - if client.features.initialized and type == messages.RecoveryType.NormalRecovery: + if session.features.initialized and type == messages.RecoveryType.NormalRecovery: raise RuntimeError( "Device already initialized. Call device.wipe() and try again." ) @@ -221,20 +208,20 @@ def recover( msg.label = label msg.u2f_counter = u2f_counter - res = client.call(msg) + res = session.call(msg) while isinstance(res, messages.WordRequest): try: assert input_callback is not None inp = input_callback(res.type) - res = client.call(messages.WordAck(word=inp)) + res = session.call(messages.WordAck(word=inp)) except Cancelled: - res = client.call(messages.Cancel()) + res = session.call(messages.Cancel()) # check that the result is a Success res = messages.Success.ensure_isinstance(res) # reinitialize the device - client.init_device() + session.refresh_features() return _deprecation_retval_helper(res) @@ -280,7 +267,7 @@ def _seed_from_entropy( def reset( - client: "TrezorClient", + session: "Session", display_random: bool = False, strength: Optional[int] = None, passphrase_protection: bool = False, @@ -313,7 +300,7 @@ def reset( ) setup( - client, + session, strength=strength, passphrase_protection=passphrase_protection, pin_protection=pin_protection, @@ -331,9 +318,8 @@ def _get_external_entropy() -> bytes: return secrets.token_bytes(32) -@session def setup( - client: "TrezorClient", + session: "Session", *, strength: Optional[int] = None, passphrase_protection: bool = True, @@ -388,19 +374,19 @@ def setup( check. """ - if client.features.initialized: + if session.features.initialized: raise RuntimeError( "Device is initialized already. Call wipe_device() and try again." ) if strength is None: - if client.features.model == "1": + if session.features.model == "1": strength = 256 else: strength = 128 if backup_type is None: - if client.version < SLIP39_EXTENDABLE_MIN_VERSION: + if session.version < SLIP39_EXTENDABLE_MIN_VERSION: # includes Trezor One 1.x.x backup_type = messages.BackupType.Bip39 else: @@ -411,7 +397,7 @@ def setup( paths = [parse_path("m/84h/0h/0h"), parse_path("m/44h/60h/0h")] if entropy_check_count is None: - if client.version < ENTROPY_CHECK_MIN_VERSION: + if session.version < ENTROPY_CHECK_MIN_VERSION: # includes Trezor One 1.x.x entropy_check_count = 0 else: @@ -431,18 +417,18 @@ def setup( ) if entropy_check_count > 0: xpubs = _reset_with_entropycheck( - client, msg, entropy_check_count, paths, _get_entropy + session, msg, entropy_check_count, paths, _get_entropy ) else: - _reset_no_entropycheck(client, msg, _get_entropy) + _reset_no_entropycheck(session, msg, _get_entropy) xpubs = [] - client.init_device() + session.refresh_features() return xpubs def _reset_no_entropycheck( - client: "TrezorClient", + session: "Session", msg: messages.ResetDevice, get_entropy: Callable[[], bytes], ) -> None: @@ -454,12 +440,12 @@ def _reset_no_entropycheck( << Success """ assert msg.entropy_check is False - client.call(msg, expect=messages.EntropyRequest) - client.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success) + session.call(msg, expect=messages.EntropyRequest) + session.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success) def _reset_with_entropycheck( - client: "TrezorClient", + session: "Session", reset_msg: messages.ResetDevice, entropy_check_count: int, paths: Iterable[Address], @@ -495,7 +481,7 @@ def _reset_with_entropycheck( def get_xpubs() -> list[tuple[Address, str]]: xpubs = [] for path in paths: - resp = client.call( + resp = session.call( messages.GetPublicKey(address_n=path), expect=messages.PublicKey ) xpubs.append((path, resp.xpub)) @@ -524,13 +510,13 @@ def verify_entropy_commitment( raise TrezorException("Invalid XPUB in entropy check") xpubs = [] - resp = client.call(reset_msg, expect=messages.EntropyRequest) + resp = session.call(reset_msg, expect=messages.EntropyRequest) entropy_commitment = resp.entropy_commitment while True: # provide external entropy for this round external_entropy = get_entropy() - client.call( + session.call( messages.EntropyAck(entropy=external_entropy), expect=messages.EntropyCheckReady, ) @@ -540,7 +526,7 @@ def verify_entropy_commitment( if entropy_check_count <= 0: # last round, wait for a Success and exit the loop - client.call( + session.call( messages.EntropyCheckContinue(finish=True), expect=messages.Success, ) @@ -549,7 +535,7 @@ def verify_entropy_commitment( entropy_check_count -= 1 # Next round starts. - resp = client.call( + resp = session.call( messages.EntropyCheckContinue(finish=False), expect=messages.EntropyRequest, ) @@ -570,13 +556,12 @@ def verify_entropy_commitment( return xpubs -@session def backup( - client: "TrezorClient", + session: "Session", group_threshold: Optional[int] = None, groups: Iterable[tuple[int, int]] = (), ) -> str | None: - ret = client.call( + ret = session.call( messages.BackupDevice( group_threshold=group_threshold, groups=[ @@ -586,37 +571,36 @@ def backup( ), expect=messages.Success, ) - client.refresh_features() + session.refresh_features() return _return_success(ret) -def cancel_authorization(client: "TrezorClient") -> str | None: - ret = client.call(messages.CancelAuthorization(), expect=messages.Success) +def cancel_authorization(session: "Session") -> str | None: + ret = session.call(messages.CancelAuthorization(), expect=messages.Success) return _return_success(ret) -def unlock_path(client: "TrezorClient", n: "Address") -> bytes: - resp = client.call( +def unlock_path(session: "Session", n: "Address") -> bytes: + resp = session.call( messages.UnlockPath(address_n=n), expect=messages.UnlockedPathRequest ) # Cancel the UnlockPath workflow now that we have the authentication code. try: - client.call(messages.Cancel()) + session.call(messages.Cancel()) except Cancelled: return resp.mac else: raise TrezorException("Unexpected response in UnlockPath flow") -@session def reboot_to_bootloader( - client: "TrezorClient", + session: "Session", boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT, firmware_header: Optional[bytes] = None, language_data: bytes = b"", ) -> str | None: - response = client.call( + response = session.call( messages.RebootToBootloader( boot_command=boot_command, firmware_header=firmware_header, @@ -624,43 +608,38 @@ def reboot_to_bootloader( ) ) if isinstance(response, messages.TranslationDataRequest): - response = _send_language_data(client, response, language_data) + response = _send_language_data(session, response, language_data) return _return_success(messages.Success(message="")) -@session -def show_device_tutorial(client: "TrezorClient") -> str | None: - ret = client.call(messages.ShowDeviceTutorial(), expect=messages.Success) +def show_device_tutorial(session: "Session") -> str | None: + ret = session.call(messages.ShowDeviceTutorial(), expect=messages.Success) return _return_success(ret) -@session -def unlock_bootloader(client: "TrezorClient") -> str | None: - ret = client.call(messages.UnlockBootloader(), expect=messages.Success) +def unlock_bootloader(session: "Session") -> str | None: + ret = session.call(messages.UnlockBootloader(), expect=messages.Success) return _return_success(ret) -@session -def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> str | None: +def set_busy(session: "Session", expiry_ms: Optional[int]) -> str | None: """Sets or clears the busy state of the device. In the busy state the device shows a "Do not disconnect" message instead of the homescreen. Setting `expiry_ms=None` clears the busy state. """ - ret = client.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success) - client.refresh_features() + ret = session.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success) + session.refresh_features() return _return_success(ret) -def authenticate( - client: "TrezorClient", challenge: bytes -) -> messages.AuthenticityProof: - return client.call( +def authenticate(session: "Session", challenge: bytes) -> messages.AuthenticityProof: + return session.call( messages.AuthenticateDevice(challenge=challenge), expect=messages.AuthenticityProof, ) -def set_brightness(client: "TrezorClient", value: Optional[int] = None) -> str | None: - ret = client.call(messages.SetBrightness(value=value), expect=messages.Success) +def set_brightness(session: "Session", value: Optional[int] = None) -> str | None: + ret = session.call(messages.SetBrightness(value=value), expect=messages.Success) return _return_success(ret) diff --git a/python/src/trezorlib/eos.py b/python/src/trezorlib/eos.py index eb491f204c1..990adf38555 100644 --- a/python/src/trezorlib/eos.py +++ b/python/src/trezorlib/eos.py @@ -18,11 +18,11 @@ from typing import TYPE_CHECKING, List, Tuple from . import exceptions, messages -from .tools import b58decode, session +from .tools import b58decode if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session def name_to_number(name: str) -> int: @@ -319,17 +319,16 @@ def parse_transaction_json( def get_public_key( - client: "TrezorClient", n: "Address", show_display: bool = False + session: "Session", n: "Address", show_display: bool = False ) -> messages.EosPublicKey: - return client.call( + return session.call( messages.EosGetPublicKey(address_n=n, show_display=show_display), expect=messages.EosPublicKey, ) -@session def sign_tx( - client: "TrezorClient", + session: "Session", address: "Address", transaction: dict, chain_id: str, @@ -345,11 +344,11 @@ def sign_tx( chunkify=chunkify, ) - response = client.call(msg) + response = session.call(msg) try: while isinstance(response, messages.EosTxActionRequest): - response = client.call(actions.pop(0)) + response = session.call(actions.pop(0)) except IndexError: # pop from empty list raise exceptions.TrezorException( diff --git a/python/src/trezorlib/ethereum.py b/python/src/trezorlib/ethereum.py index 96ce4d10663..77b071f6b72 100644 --- a/python/src/trezorlib/ethereum.py +++ b/python/src/trezorlib/ethereum.py @@ -18,11 +18,11 @@ from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple from . import definitions, exceptions, messages -from .tools import prepare_message_bytes, session, unharden +from .tools import prepare_message_bytes, unharden if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session def int_to_big_endian(value: int) -> bytes: @@ -161,13 +161,13 @@ def network_from_address_n( def get_address( - client: "TrezorClient", + session: "Session", n: "Address", show_display: bool = False, encoded_network: Optional[bytes] = None, chunkify: bool = False, ) -> str: - resp = client.call( + resp = session.call( messages.EthereumGetAddress( address_n=n, show_display=show_display, @@ -181,17 +181,16 @@ def get_address( def get_public_node( - client: "TrezorClient", n: "Address", show_display: bool = False + session: "Session", n: "Address", show_display: bool = False ) -> messages.EthereumPublicKey: - return client.call( + return session.call( messages.EthereumGetPublicKey(address_n=n, show_display=show_display), expect=messages.EthereumPublicKey, ) -@session def sign_tx( - client: "TrezorClient", + session: "Session", n: "Address", nonce: int, gas_price: int, @@ -227,13 +226,13 @@ def sign_tx( data, chunk = data[1024:], data[:1024] msg.data_initial_chunk = chunk - response = client.call(msg) + response = session.call(msg) assert isinstance(response, messages.EthereumTxRequest) while response.data_length is not None: data_length = response.data_length data, chunk = data[data_length:], data[:data_length] - response = client.call(messages.EthereumTxAck(data_chunk=chunk)) + response = session.call(messages.EthereumTxAck(data_chunk=chunk)) assert isinstance(response, messages.EthereumTxRequest) assert response.signature_v is not None @@ -248,9 +247,8 @@ def sign_tx( return response.signature_v, response.signature_r, response.signature_s -@session def sign_tx_eip1559( - client: "TrezorClient", + session: "Session", n: "Address", *, nonce: int, @@ -283,13 +281,13 @@ def sign_tx_eip1559( chunkify=chunkify, ) - response = client.call(msg) + response = session.call(msg) assert isinstance(response, messages.EthereumTxRequest) while response.data_length is not None: data_length = response.data_length data, chunk = data[data_length:], data[:data_length] - response = client.call(messages.EthereumTxAck(data_chunk=chunk)) + response = session.call(messages.EthereumTxAck(data_chunk=chunk)) assert isinstance(response, messages.EthereumTxRequest) assert response.signature_v is not None @@ -299,13 +297,13 @@ def sign_tx_eip1559( def sign_message( - client: "TrezorClient", + session: "Session", n: "Address", message: AnyStr, encoded_network: Optional[bytes] = None, chunkify: bool = False, ) -> messages.EthereumMessageSignature: - return client.call( + return session.call( messages.EthereumSignMessage( address_n=n, message=prepare_message_bytes(message), @@ -317,7 +315,7 @@ def sign_message( def sign_typed_data( - client: "TrezorClient", + session: "Session", n: "Address", data: Dict[str, Any], *, @@ -333,7 +331,7 @@ def sign_typed_data( metamask_v4_compat=metamask_v4_compat, definitions=definitions, ) - response = client.call(request) + response = session.call(request) # Sending all the types while isinstance(response, messages.EthereumTypedDataStructRequest): @@ -349,7 +347,7 @@ def sign_typed_data( members.append(struct_member) request = messages.EthereumTypedDataStructAck(members=members) - response = client.call(request) + response = session.call(request) # Sending the whole message that should be signed while isinstance(response, messages.EthereumTypedDataValueRequest): @@ -362,7 +360,7 @@ def sign_typed_data( member_typename = data["primaryType"] member_data = data["message"] else: - client.cancel() + # TODO session.cancel() raise exceptions.TrezorException("Root index can only be 0 or 1") # It can be asking for a nested structure (the member path being [X, Y, Z, ...]) @@ -385,20 +383,20 @@ def sign_typed_data( encoded_data = encode_data(member_data, member_typename) request = messages.EthereumTypedDataValueAck(value=encoded_data) - response = client.call(request) + response = session.call(request) return messages.EthereumTypedDataSignature.ensure_isinstance(response) def verify_message( - client: "TrezorClient", + session: "Session", address: str, signature: bytes, message: AnyStr, chunkify: bool = False, ) -> bool: try: - client.call( + session.call( messages.EthereumVerifyMessage( address=address, signature=signature, @@ -413,13 +411,13 @@ def verify_message( def sign_typed_data_hash( - client: "TrezorClient", + session: "Session", n: "Address", domain_hash: bytes, message_hash: Optional[bytes], encoded_network: Optional[bytes] = None, ) -> messages.EthereumTypedDataSignature: - return client.call( + return session.call( messages.EthereumSignTypedHash( address_n=n, domain_separator_hash=domain_hash, diff --git a/python/src/trezorlib/exceptions.py b/python/src/trezorlib/exceptions.py index 99f0048dd36..44d25d7088c 100644 --- a/python/src/trezorlib/exceptions.py +++ b/python/src/trezorlib/exceptions.py @@ -65,3 +65,7 @@ def __init__(self, expected: type[MessageType], actual: MessageType) -> None: self.expected = expected self.actual = actual super().__init__(f"Expected {expected.__name__} but Trezor sent {actual}") + + +class DeviceLockedException(TrezorException): + pass diff --git a/python/src/trezorlib/fido.py b/python/src/trezorlib/fido.py index a2618b72dbb..aaa3b084bff 100644 --- a/python/src/trezorlib/fido.py +++ b/python/src/trezorlib/fido.py @@ -22,37 +22,37 @@ from .tools import _return_success if TYPE_CHECKING: - from .client import TrezorClient + from .transport.session import Session -def list_credentials(client: "TrezorClient") -> Sequence[messages.WebAuthnCredential]: - return client.call( +def list_credentials(session: "Session") -> Sequence[messages.WebAuthnCredential]: + return session.call( messages.WebAuthnListResidentCredentials(), expect=messages.WebAuthnCredentials ).credentials -def add_credential(client: "TrezorClient", credential_id: bytes) -> str | None: - ret = client.call( +def add_credential(session: "Session", credential_id: bytes) -> str | None: + ret = session.call( messages.WebAuthnAddResidentCredential(credential_id=credential_id), expect=messages.Success, ) return _return_success(ret) -def remove_credential(client: "TrezorClient", index: int) -> str | None: - ret = client.call( +def remove_credential(session: "Session", index: int) -> str | None: + ret = session.call( messages.WebAuthnRemoveResidentCredential(index=index), expect=messages.Success ) return _return_success(ret) -def set_counter(client: "TrezorClient", u2f_counter: int) -> str | None: - ret = client.call( +def set_counter(session: "Session", u2f_counter: int) -> str | None: + ret = session.call( messages.SetU2FCounter(u2f_counter=u2f_counter), expect=messages.Success ) return _return_success(ret) -def get_next_counter(client: "TrezorClient") -> int: - ret = client.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter) +def get_next_counter(session: "Session") -> int: + ret = session.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter) return ret.u2f_counter diff --git a/python/src/trezorlib/firmware/__init__.py b/python/src/trezorlib/firmware/__init__.py index 4cfc11dd40a..56168306bb6 100644 --- a/python/src/trezorlib/firmware/__init__.py +++ b/python/src/trezorlib/firmware/__init__.py @@ -20,7 +20,6 @@ from typing_extensions import Protocol, TypeGuard from .. import messages -from ..tools import session from .core import VendorFirmware from .legacy import LegacyFirmware, LegacyV2Firmware @@ -38,7 +37,7 @@ from .vendor import * # noqa: F401, F403 if t.TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session T = t.TypeVar("T", bound="FirmwareType") @@ -72,20 +71,19 @@ def is_onev2(fw: "FirmwareType") -> TypeGuard[LegacyFirmware]: # ====== Client functions ====== # -@session def update( - client: "TrezorClient", + session: "Session", data: bytes, progress_update: t.Callable[[int], t.Any] = lambda _: None, ): - if client.features.bootloader_mode is False: + if session.features.bootloader_mode is False: raise RuntimeError("Device must be in bootloader mode") - resp = client.call(messages.FirmwareErase(length=len(data))) + resp = session.call(messages.FirmwareErase(length=len(data))) # TREZORv1 method if isinstance(resp, messages.Success): - resp = client.call(messages.FirmwareUpload(payload=data)) + resp = session.call(messages.FirmwareUpload(payload=data)) progress_update(len(data)) if isinstance(resp, messages.Success): return @@ -97,7 +95,7 @@ def update( length = resp.length payload = data[resp.offset : resp.offset + length] digest = blake2s(payload).digest() - resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest)) + resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest)) progress_update(length) if isinstance(resp, messages.Success): @@ -106,7 +104,7 @@ def update( raise RuntimeError(f"Unexpected message {resp}") -def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]) -> bytes: - return client.call( +def get_hash(session: "Session", challenge: t.Optional[bytes]) -> bytes: + return session.call( messages.GetFirmwareHash(challenge=challenge), expect=messages.FirmwareHash ).hash diff --git a/python/src/trezorlib/mapping.py b/python/src/trezorlib/mapping.py index d50324d5868..04b75f0aa56 100644 --- a/python/src/trezorlib/mapping.py +++ b/python/src/trezorlib/mapping.py @@ -17,6 +17,7 @@ from __future__ import annotations import io +import logging from types import ModuleType from typing import Dict, Optional, Tuple, Type, TypeVar @@ -25,6 +26,7 @@ from . import messages, protobuf T = TypeVar("T") +LOG = logging.getLogger(__name__) class ProtobufMapping: @@ -63,11 +65,21 @@ def encode(self, msg: protobuf.MessageType) -> Tuple[int, bytes]: wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE) if wire_type is None: raise ValueError("Cannot encode class without wire type") - + LOG.debug("encoding wire type %d", wire_type) buf = io.BytesIO() protobuf.dump_message(buf, msg) return wire_type, buf.getvalue() + def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes: + """Serialize a Python protobuf class. + + Returns the byte representation of the protobuf message. + """ + + buf = io.BytesIO() + protobuf.dump_message(buf, msg) + return buf.getvalue() + def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType: """Deserialize a protobuf message into a Python class.""" cls = self.type_to_class[msg_wire_type] @@ -83,7 +95,9 @@ def from_module(cls, module: ModuleType) -> Self: mapping = cls() message_types = getattr(module, "MessageType") - for entry in message_types: + thp_message_types = getattr(module, "ThpMessageType") + + for entry in (*message_types, *thp_message_types): msg_class = getattr(module, entry.name, None) if msg_class is None: raise ValueError( diff --git a/python/src/trezorlib/messages.py b/python/src/trezorlib/messages.py index 024c3ae6961..b163c34b870 100644 --- a/python/src/trezorlib/messages.py +++ b/python/src/trezorlib/messages.py @@ -43,6 +43,10 @@ class FailureType(IntEnum): PinMismatch = 12 WipeCodeMismatch = 13 InvalidSession = 14 + ThpUnallocatedSession = 15 + InvalidProtocol = 16 + BufferError = 17 + DeviceIsBusy = 18 FirmwareError = 99 @@ -400,6 +404,34 @@ class TezosBallotType(IntEnum): Pass = 2 +class ThpMessageType(IntEnum): + ThpCreateNewSession = 1000 + ThpPairingRequest = 1006 + ThpPairingRequestApproved = 1007 + ThpSelectMethod = 1008 + ThpPairingPreparationsFinished = 1009 + ThpCredentialRequest = 1010 + ThpCredentialResponse = 1011 + ThpEndRequest = 1012 + ThpEndResponse = 1013 + ThpCodeEntryCommitment = 1016 + ThpCodeEntryChallenge = 1017 + ThpCodeEntryCpaceTrezor = 1018 + ThpCodeEntryCpaceHostTag = 1019 + ThpCodeEntrySecret = 1020 + ThpQrCodeTag = 1024 + ThpQrCodeSecret = 1025 + ThpNfcTagHost = 1032 + ThpNfcTagTrezor = 1033 + + +class ThpPairingMethod(IntEnum): + SkipPairing = 1 + CodeEntry = 2 + QrCode = 3 + NFC = 4 + + class MessageType(IntEnum): Initialize = 0 Ping = 1 @@ -500,6 +532,8 @@ class MessageType(IntEnum): DebugLinkWatchLayout = 9006 DebugLinkResetDebugEvents = 9007 DebugLinkOptigaSetSecMax = 9008 + DebugLinkGetPairingInfo = 9009 + DebugLinkPairingInfo = 9010 EthereumGetPublicKey = 450 EthereumPublicKey = 451 EthereumGetAddress = 56 @@ -4203,6 +4237,52 @@ def __init__( self.mnemonic_type = mnemonic_type +class DebugLinkGetPairingInfo(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 9009 + FIELDS = { + 1: protobuf.Field("channel_id", "bytes", repeated=False, required=False, default=None), + 2: protobuf.Field("handshake_hash", "bytes", repeated=False, required=False, default=None), + 3: protobuf.Field("nfc_secret_host", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + channel_id: Optional["bytes"] = None, + handshake_hash: Optional["bytes"] = None, + nfc_secret_host: Optional["bytes"] = None, + ) -> None: + self.channel_id = channel_id + self.handshake_hash = handshake_hash + self.nfc_secret_host = nfc_secret_host + + +class DebugLinkPairingInfo(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 9010 + FIELDS = { + 1: protobuf.Field("channel_id", "bytes", repeated=False, required=False, default=None), + 2: protobuf.Field("handshake_hash", "bytes", repeated=False, required=False, default=None), + 3: protobuf.Field("code_entry_code", "uint32", repeated=False, required=False, default=None), + 4: protobuf.Field("code_qr_code", "bytes", repeated=False, required=False, default=None), + 5: protobuf.Field("nfc_secret_trezor", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + channel_id: Optional["bytes"] = None, + handshake_hash: Optional["bytes"] = None, + code_entry_code: Optional["int"] = None, + code_qr_code: Optional["bytes"] = None, + nfc_secret_trezor: Optional["bytes"] = None, + ) -> None: + self.channel_id = channel_id + self.handshake_hash = handshake_hash + self.code_entry_code = code_entry_code + self.code_qr_code = code_qr_code + self.nfc_secret_trezor = nfc_secret_trezor + + class DebugLinkStop(protobuf.MessageType): MESSAGE_WIRE_TYPE = 103 @@ -7863,18 +7943,288 @@ def __init__( self.amount = amount +class ThpDeviceProperties(protobuf.MessageType): + MESSAGE_WIRE_TYPE = None + FIELDS = { + 1: protobuf.Field("internal_model", "string", repeated=False, required=False, default=None), + 2: protobuf.Field("model_variant", "uint32", repeated=False, required=False, default=None), + 3: protobuf.Field("protocol_version_major", "uint32", repeated=False, required=False, default=None), + 4: protobuf.Field("protocol_version_minor", "uint32", repeated=False, required=False, default=None), + 5: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None), + } + + def __init__( + self, + *, + pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None, + internal_model: Optional["str"] = None, + model_variant: Optional["int"] = None, + protocol_version_major: Optional["int"] = None, + protocol_version_minor: Optional["int"] = None, + ) -> None: + self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else [] + self.internal_model = internal_model + self.model_variant = model_variant + self.protocol_version_major = protocol_version_major + self.protocol_version_minor = protocol_version_minor + + +class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType): + MESSAGE_WIRE_TYPE = None + FIELDS = { + 1: protobuf.Field("host_pairing_credential", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + host_pairing_credential: Optional["bytes"] = None, + ) -> None: + self.host_pairing_credential = host_pairing_credential + + +class ThpCreateNewSession(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1000 + FIELDS = { + 1: protobuf.Field("passphrase", "string", repeated=False, required=False, default=None), + 2: protobuf.Field("on_device", "bool", repeated=False, required=False, default=None), + 3: protobuf.Field("derive_cardano", "bool", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + passphrase: Optional["str"] = None, + on_device: Optional["bool"] = None, + derive_cardano: Optional["bool"] = None, + ) -> None: + self.passphrase = passphrase + self.on_device = on_device + self.derive_cardano = derive_cardano + + +class ThpPairingRequest(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1006 + FIELDS = { + 1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + host_name: Optional["str"] = None, + ) -> None: + self.host_name = host_name + + +class ThpPairingRequestApproved(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1007 + + +class ThpSelectMethod(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1008 + FIELDS = { + 1: protobuf.Field("selected_pairing_method", "ThpPairingMethod", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + selected_pairing_method: Optional["ThpPairingMethod"] = None, + ) -> None: + self.selected_pairing_method = selected_pairing_method + + +class ThpPairingPreparationsFinished(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1009 + + +class ThpCodeEntryCommitment(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1016 + FIELDS = { + 1: protobuf.Field("commitment", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + commitment: Optional["bytes"] = None, + ) -> None: + self.commitment = commitment + + +class ThpCodeEntryChallenge(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1017 + FIELDS = { + 1: protobuf.Field("challenge", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + challenge: Optional["bytes"] = None, + ) -> None: + self.challenge = challenge + + +class ThpCodeEntryCpaceTrezor(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1018 + FIELDS = { + 1: protobuf.Field("cpace_trezor_public_key", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + cpace_trezor_public_key: Optional["bytes"] = None, + ) -> None: + self.cpace_trezor_public_key = cpace_trezor_public_key + + +class ThpCodeEntryCpaceHostTag(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1019 + FIELDS = { + 1: protobuf.Field("cpace_host_public_key", "bytes", repeated=False, required=False, default=None), + 2: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + cpace_host_public_key: Optional["bytes"] = None, + tag: Optional["bytes"] = None, + ) -> None: + self.cpace_host_public_key = cpace_host_public_key + self.tag = tag + + +class ThpCodeEntrySecret(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1020 + FIELDS = { + 1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + secret: Optional["bytes"] = None, + ) -> None: + self.secret = secret + + +class ThpQrCodeTag(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1024 + FIELDS = { + 1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + tag: Optional["bytes"] = None, + ) -> None: + self.tag = tag + + +class ThpQrCodeSecret(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1025 + FIELDS = { + 1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + secret: Optional["bytes"] = None, + ) -> None: + self.secret = secret + + +class ThpNfcTagHost(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1032 + FIELDS = { + 1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + tag: Optional["bytes"] = None, + ) -> None: + self.tag = tag + + +class ThpNfcTagTrezor(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1033 + FIELDS = { + 1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + tag: Optional["bytes"] = None, + ) -> None: + self.tag = tag + + +class ThpCredentialRequest(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1010 + FIELDS = { + 1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None), + 2: protobuf.Field("autoconnect", "bool", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + host_static_pubkey: Optional["bytes"] = None, + autoconnect: Optional["bool"] = None, + ) -> None: + self.host_static_pubkey = host_static_pubkey + self.autoconnect = autoconnect + + +class ThpCredentialResponse(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1011 + FIELDS = { + 1: protobuf.Field("trezor_static_pubkey", "bytes", repeated=False, required=False, default=None), + 2: protobuf.Field("credential", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + trezor_static_pubkey: Optional["bytes"] = None, + credential: Optional["bytes"] = None, + ) -> None: + self.trezor_static_pubkey = trezor_static_pubkey + self.credential = credential + + +class ThpEndRequest(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1012 + + +class ThpEndResponse(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1013 + + class ThpCredentialMetadata(protobuf.MessageType): MESSAGE_WIRE_TYPE = None FIELDS = { 1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None), + 2: protobuf.Field("autoconnect", "bool", repeated=False, required=False, default=None), } def __init__( self, *, host_name: Optional["str"] = None, + autoconnect: Optional["bool"] = None, ) -> None: self.host_name = host_name + self.autoconnect = autoconnect class ThpPairingCredential(protobuf.MessageType): diff --git a/python/src/trezorlib/misc.py b/python/src/trezorlib/misc.py index 578c1fa19f1..eeaea268721 100644 --- a/python/src/trezorlib/misc.py +++ b/python/src/trezorlib/misc.py @@ -19,22 +19,22 @@ from . import messages if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session -def get_entropy(client: "TrezorClient", size: int) -> bytes: - return client.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy +def get_entropy(session: "Session", size: int) -> bytes: + return session.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy def sign_identity( - client: "TrezorClient", + session: "Session", identity: messages.IdentityType, challenge_hidden: bytes, challenge_visual: str, ecdsa_curve_name: Optional[str] = None, ) -> messages.SignedIdentity: - return client.call( + return session.call( messages.SignIdentity( identity=identity, challenge_hidden=challenge_hidden, @@ -46,12 +46,12 @@ def sign_identity( def get_ecdh_session_key( - client: "TrezorClient", + session: "Session", identity: messages.IdentityType, peer_public_key: bytes, ecdsa_curve_name: Optional[str] = None, ) -> messages.ECDHSessionKey: - return client.call( + return session.call( messages.GetECDHSessionKey( identity=identity, peer_public_key=peer_public_key, @@ -62,7 +62,7 @@ def get_ecdh_session_key( def encrypt_keyvalue( - client: "TrezorClient", + session: "Session", n: "Address", key: str, value: bytes, @@ -70,7 +70,7 @@ def encrypt_keyvalue( ask_on_decrypt: bool = True, iv: bytes = b"", ) -> bytes: - return client.call( + return session.call( messages.CipherKeyValue( address_n=n, key=key, @@ -85,7 +85,7 @@ def encrypt_keyvalue( def decrypt_keyvalue( - client: "TrezorClient", + session: "Session", n: "Address", key: str, value: bytes, @@ -93,7 +93,7 @@ def decrypt_keyvalue( ask_on_decrypt: bool = True, iv: bytes = b"", ) -> bytes: - return client.call( + return session.call( messages.CipherKeyValue( address_n=n, key=key, @@ -107,5 +107,5 @@ def decrypt_keyvalue( ).value -def get_nonce(client: "TrezorClient") -> bytes: - return client.call(messages.GetNonce(), expect=messages.Nonce).nonce +def get_nonce(session: "Session") -> bytes: + return session.call(messages.GetNonce(), expect=messages.Nonce).nonce diff --git a/python/src/trezorlib/monero.py b/python/src/trezorlib/monero.py index b2e3214fb95..9e323461561 100644 --- a/python/src/trezorlib/monero.py +++ b/python/src/trezorlib/monero.py @@ -19,8 +19,8 @@ from . import messages if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session # MAINNET = 0 @@ -30,13 +30,13 @@ def get_address( - client: "TrezorClient", + session: "Session", n: "Address", show_display: bool = False, network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, chunkify: bool = False, ) -> bytes: - return client.call( + return session.call( messages.MoneroGetAddress( address_n=n, show_display=show_display, @@ -48,11 +48,11 @@ def get_address( def get_watch_key( - client: "TrezorClient", + session: "Session", n: "Address", network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, ) -> messages.MoneroWatchKey: - return client.call( + return session.call( messages.MoneroGetWatchKey(address_n=n, network_type=network_type), expect=messages.MoneroWatchKey, ) diff --git a/python/src/trezorlib/nem.py b/python/src/trezorlib/nem.py index 744dc3205f3..357de145ada 100644 --- a/python/src/trezorlib/nem.py +++ b/python/src/trezorlib/nem.py @@ -20,8 +20,8 @@ from . import exceptions, messages if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session TYPE_TRANSACTION_TRANSFER = 0x0101 TYPE_IMPORTANCE_TRANSFER = 0x0801 @@ -195,13 +195,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig def get_address( - client: "TrezorClient", + session: "Session", n: "Address", network: int, show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.NEMGetAddress( address_n=n, network=network, show_display=show_display, chunkify=chunkify ), @@ -210,7 +210,7 @@ def get_address( def sign_tx( - client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False + session: "Session", n: "Address", transaction: dict, chunkify: bool = False ) -> messages.NEMSignedTx: try: msg = create_sign_tx(transaction, chunkify=chunkify) @@ -219,4 +219,4 @@ def sign_tx( assert msg.transaction is not None msg.transaction.address_n = n - return client.call(msg, expect=messages.NEMSignedTx) + return session.call(msg, expect=messages.NEMSignedTx) diff --git a/python/src/trezorlib/ripple.py b/python/src/trezorlib/ripple.py index 00a027c6d97..e5e0f524cc3 100644 --- a/python/src/trezorlib/ripple.py +++ b/python/src/trezorlib/ripple.py @@ -21,20 +21,20 @@ from .tools import dict_from_camelcase if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment") REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination") def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.RippleGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ), @@ -43,14 +43,14 @@ def get_address( def sign_tx( - client: "TrezorClient", + session: "Session", address_n: "Address", msg: messages.RippleSignTx, chunkify: bool = False, ) -> messages.RippleSignedTx: msg.address_n = address_n msg.chunkify = chunkify - return client.call(msg, expect=messages.RippleSignedTx) + return session.call(msg, expect=messages.RippleSignedTx) def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx: diff --git a/python/src/trezorlib/solana.py b/python/src/trezorlib/solana.py index 0054e0fd924..3d0ee755498 100644 --- a/python/src/trezorlib/solana.py +++ b/python/src/trezorlib/solana.py @@ -3,27 +3,27 @@ from . import messages if TYPE_CHECKING: - from .client import TrezorClient + from .transport.session import Session def get_public_key( - client: "TrezorClient", + session: "Session", address_n: List[int], show_display: bool, ) -> bytes: - return client.call( + return session.call( messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display), expect=messages.SolanaPublicKey, ).public_key def get_address( - client: "TrezorClient", + session: "Session", address_n: List[int], show_display: bool, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.SolanaGetAddress( address_n=address_n, show_display=show_display, @@ -34,12 +34,12 @@ def get_address( def sign_tx( - client: "TrezorClient", + session: "Session", address_n: List[int], serialized_tx: bytes, additional_info: Optional[messages.SolanaTxAdditionalInfo], ) -> bytes: - return client.call( + return session.call( messages.SolanaSignTx( address_n=address_n, serialized_tx=serialized_tx, diff --git a/python/src/trezorlib/stellar.py b/python/src/trezorlib/stellar.py index 5bd0a749e42..843a2e0c393 100644 --- a/python/src/trezorlib/stellar.py +++ b/python/src/trezorlib/stellar.py @@ -20,8 +20,8 @@ from . import exceptions, messages if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session StellarMessageType = Union[ messages.StellarAccountMergeOp, @@ -322,12 +322,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset: def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.StellarGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ), @@ -336,7 +336,7 @@ def get_address( def sign_tx( - client: "TrezorClient", + session: "Session", tx: messages.StellarSignTx, operations: List["StellarMessageType"], address_n: "Address", @@ -352,10 +352,10 @@ def sign_tx( # 3. Receive a StellarTxOpRequest message # 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message # 5. The final message received will be StellarSignedTx which is returned from this method - resp = client.call(tx) + resp = session.call(tx) try: while isinstance(resp, messages.StellarTxOpRequest): - resp = client.call(operations.pop(0)) + resp = session.call(operations.pop(0)) except IndexError: # pop from empty list raise exceptions.TrezorException( diff --git a/python/src/trezorlib/tezos.py b/python/src/trezorlib/tezos.py index 9319aa1eaa1..06bcafe759a 100644 --- a/python/src/trezorlib/tezos.py +++ b/python/src/trezorlib/tezos.py @@ -19,17 +19,17 @@ from . import messages if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.TezosGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ), @@ -38,12 +38,12 @@ def get_address( def get_public_key( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.TezosGetPublicKey( address_n=address_n, show_display=show_display, chunkify=chunkify ), @@ -52,11 +52,11 @@ def get_public_key( def sign_tx( - client: "TrezorClient", + session: "Session", address_n: "Address", sign_tx_msg: messages.TezosSignTx, chunkify: bool = False, ) -> messages.TezosSignedTx: sign_tx_msg.address_n = address_n sign_tx_msg.chunkify = chunkify - return client.call(sign_tx_msg, expect=messages.TezosSignedTx) + return session.call(sign_tx_msg, expect=messages.TezosSignedTx) diff --git a/python/src/trezorlib/tools.py b/python/src/trezorlib/tools.py index 6ba8c64dba3..f753e68a330 100644 --- a/python/src/trezorlib/tools.py +++ b/python/src/trezorlib/tools.py @@ -45,7 +45,7 @@ # More details: https://www.python.org/dev/peps/pep-0612/ from typing import TypeVar - from typing_extensions import Concatenate, ParamSpec + from typing_extensions import ParamSpec from . import client from .messages import Success @@ -389,23 +389,6 @@ def _return_success(msg: "Success") -> str | None: return _deprecation_retval_helper(msg.message, stacklevel=1) -def session( - f: "Callable[Concatenate[TrezorClient, P], R]", -) -> "Callable[Concatenate[TrezorClient, P], R]": - # Decorator wraps a BaseClient method - # with session activation / deactivation - @functools.wraps(f) - def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R": - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - client.open() - try: - return f(client, *args, **kwargs) - finally: - client.close() - - return wrapped_f - - # de-camelcasifier # https://stackoverflow.com/a/1176023/222189 diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index b04876b6b77..2c208be36dc 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -14,24 +14,18 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging -from typing import ( - TYPE_CHECKING, - Iterable, - List, - Optional, - Sequence, - Tuple, - Type, - TypeVar, -) +import typing as t from ..exceptions import TrezorException -if TYPE_CHECKING: +if t.TYPE_CHECKING: from ..models import TrezorModel - T = TypeVar("T", bound="Transport") + T = t.TypeVar("T", bound="Transport") + LOG = logging.getLogger(__name__) @@ -41,7 +35,7 @@ """.strip() -MessagePayload = Tuple[int, bytes] +MessagePayload = t.Tuple[int, bytes] class TransportException(TrezorException): @@ -53,72 +47,54 @@ class DeviceIsBusy(TransportException): class Transport: - """Raw connection to a Trezor device. + PATH_PREFIX: str - Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB - or USB-HID connection, or UDP socket of listening emulator(s). - It can also enumerate devices available over this communication link, and return - them as instances. + @classmethod + def enumerate( + cls: t.Type["T"], models: t.Iterable["TrezorModel"] | None = None + ) -> t.Iterable["T"]: + raise NotImplementedError - Transport instance is a thing that: - - can be identified and requested by a string URI-like path - - can open and close sessions, which enclose related operations - - can read and write protobuf messages + @classmethod + def find_by_path(cls: t.Type["T"], path: str, prefix_search: bool = False) -> "T": + for device in cls.enumerate(): - You need to implement a new Transport subclass if you invent a new way to connect - a Trezor device to a computer. - """ + if device.get_path() == path: + return device - PATH_PREFIX: str - ENABLED = False + if prefix_search and device.get_path().startswith(path): + return device - def __str__(self) -> str: - return self.get_path() + raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") def get_path(self) -> str: raise NotImplementedError - def begin_session(self) -> None: - raise NotImplementedError - - def end_session(self) -> None: + def find_debug(self: "T") -> "T": raise NotImplementedError - def read(self) -> MessagePayload: + def open(self) -> None: raise NotImplementedError - def write(self, message_type: int, message_data: bytes) -> None: + def close(self) -> None: raise NotImplementedError - def find_debug(self: "T") -> "T": + def write_chunk(self, chunk: bytes) -> None: raise NotImplementedError - @classmethod - def enumerate( - cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None - ) -> Iterable["T"]: + def read_chunk(self) -> bytes: raise NotImplementedError - @classmethod - def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T": - for device in cls.enumerate(): - if ( - path is None - or device.get_path() == path - or (prefix_search and device.get_path().startswith(path)) - ): - return device - - raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") + CHUNK_SIZE: t.ClassVar[int] -def all_transports() -> Iterable[Type["Transport"]]: +def all_transports() -> t.Iterable[t.Type["Transport"]]: from .bridge import BridgeTransport from .hid import HidTransport from .udp import UdpTransport from .webusb import WebUsbTransport - transports: Tuple[Type["Transport"], ...] = ( + transports: t.Tuple[t.Type["Transport"], ...] = ( BridgeTransport, HidTransport, UdpTransport, @@ -128,9 +104,9 @@ def all_transports() -> Iterable[Type["Transport"]]: def enumerate_devices( - models: Optional[Iterable["TrezorModel"]] = None, -) -> Sequence["Transport"]: - devices: List["Transport"] = [] + models: t.Iterable["TrezorModel"] | None = None, +) -> t.Sequence["Transport"]: + devices: t.List["Transport"] = [] for transport in all_transports(): name = transport.__name__ try: @@ -145,9 +121,7 @@ def enumerate_devices( return devices -def get_transport( - path: Optional[str] = None, prefix_search: bool = False -) -> "Transport": +def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport": if path is None: try: return next(iter(enumerate_devices())) diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index e0c34a8f701..7f136608b03 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -14,24 +14,30 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import struct -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional +import typing as t import requests from ..log import DUMP_PACKETS from . import DeviceIsBusy, MessagePayload, Transport, TransportException -if TYPE_CHECKING: +if t.TYPE_CHECKING: from ..models import TrezorModel LOG = logging.getLogger(__name__) +PROTOCOL_VERSION_1 = 1 +PROTOCOL_VERSION_2 = 2 + TREZORD_HOST = "http://127.0.0.1:21325" TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"} TREZORD_VERSION_MODERN = (2, 0, 25) +TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value CONNECTION = requests.Session() CONNECTION.headers.update(TREZORD_ORIGIN_HEADER) @@ -45,7 +51,7 @@ def __init__(self, path: str, status: int, message: str) -> None: super().__init__(f"trezord: {path} failed with code {status}: {message}") -def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: +def call_bridge(path: str, data: str | None = None) -> requests.Response: url = TREZORD_HOST + "/" + path r = CONNECTION.post(url, data=data) if r.status_code != 200: @@ -53,10 +59,54 @@ def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: return r -def is_legacy_bridge() -> bool: +def get_bridge_version() -> t.Tuple[int, ...]: config = call_bridge("configure").json() - version_tuple = tuple(map(int, config["version"].split("."))) - return version_tuple < TREZORD_VERSION_MODERN + return tuple(map(int, config["version"].split("."))) + + +def is_legacy_bridge() -> bool: + return get_bridge_version() < TREZORD_VERSION_MODERN + + +def supports_protocolV2() -> bool: + return get_bridge_version() >= TREZORD_VERSION_THP_SUPPORT + + +def detect_protocol_version(transport: "BridgeTransport") -> int: + from .. import mapping, messages + from ..messages import FailureType + + protocol_version = PROTOCOL_VERSION_1 + request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize()) + transport.deprecated_begin_session() + transport.deprecated_write(request_type, request_data) + + response_type, response_data = transport.deprecated_read() + response = mapping.DEFAULT_MAPPING.decode(response_type, response_data) + transport.deprecated_begin_session() + if isinstance(response, messages.Failure): + if response.code == FailureType.InvalidProtocol: + LOG.debug("Protocol V2 detected") + protocol_version = PROTOCOL_VERSION_2 + + return protocol_version + + +def _is_transport_valid(transport: "BridgeTransport") -> bool: + is_valid = ( + supports_protocolV2() + or detect_protocol_version(transport) == PROTOCOL_VERSION_1 + ) + if not is_valid: + LOG.warning("Detected unsupported Bridge transport!") + return is_valid + + +def filter_invalid_bridge_transports( + transports: t.Iterable["BridgeTransport"], +) -> t.Sequence["BridgeTransport"]: + """Filters out invalid bridge transports. Keeps only valid ones.""" + return [t for t in transports if _is_transport_valid(t)] class BridgeHandle: @@ -84,7 +134,7 @@ def read_buf(self) -> bytes: class BridgeHandleLegacy(BridgeHandle): def __init__(self, transport: "BridgeTransport") -> None: super().__init__(transport) - self.request: Optional[str] = None + self.request: str | None = None def write_buf(self, buf: bytes) -> None: if self.request is not None: @@ -112,13 +162,12 @@ class BridgeTransport(Transport): ENABLED: bool = True def __init__( - self, device: Dict[str, Any], legacy: bool, debug: bool = False + self, device: t.Dict[str, t.Any], legacy: bool, debug: bool = False ) -> None: if legacy and debug: raise TransportException("Debugging not supported on legacy Bridge") - self.device = device - self.session: Optional[str] = None + self.session: str | None = device["session"] self.debug = debug self.legacy = legacy @@ -135,7 +184,7 @@ def find_debug(self) -> "BridgeTransport": raise TransportException("Debug device not available") return BridgeTransport(self.device, self.legacy, debug=True) - def _call(self, action: str, data: Optional[str] = None) -> requests.Response: + def _call(self, action: str, data: str | None = None) -> requests.Response: session = self.session or "null" uri = action + "/" + str(session) if self.debug: @@ -144,17 +193,20 @@ def _call(self, action: str, data: Optional[str] = None) -> requests.Response: @classmethod def enumerate( - cls, _models: Optional[Iterable["TrezorModel"]] = None - ) -> Iterable["BridgeTransport"]: + cls, _models: t.Iterable["TrezorModel"] | None = None + ) -> t.Iterable["BridgeTransport"]: try: legacy = is_legacy_bridge() - return [ - BridgeTransport(dev, legacy) for dev in call_bridge("enumerate").json() - ] + return filter_invalid_bridge_transports( + [ + BridgeTransport(dev, legacy) + for dev in call_bridge("enumerate").json() + ] + ) except Exception: return [] - def begin_session(self) -> None: + def deprecated_begin_session(self) -> None: try: data = self._call("acquire/" + self.device["path"]) except BridgeException as e: @@ -163,18 +215,32 @@ def begin_session(self) -> None: raise self.session = data.json()["session"] - def end_session(self) -> None: + def deprecated_end_session(self) -> None: if not self.session: return self._call("release") self.session = None - def write(self, message_type: int, message_data: bytes) -> None: + def deprecated_write(self, message_type: int, message_data: bytes) -> None: header = struct.pack(">HL", message_type, len(message_data)) self.handle.write_buf(header + message_data) - def read(self) -> MessagePayload: + def deprecated_read(self) -> MessagePayload: data = self.handle.read_buf() headerlen = struct.calcsize(">HL") msg_type, datalen = struct.unpack(">HL", data[:headerlen]) return msg_type, data[headerlen : headerlen + datalen] + + def open(self) -> None: + pass + # TODO self.handle.open() + + def close(self) -> None: + pass + # TODO self.handle.close() + + def write_chunk(self, chunk: bytes) -> None: # TODO check if it works :) + self.handle.write_buf(chunk) + + def read_chunk(self) -> bytes: # TODO check if it works :) + return self.handle.read_buf() diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index 65fa08ccd70..65e2cddf7d5 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -14,15 +14,16 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import sys import time -from typing import Any, Dict, Iterable, List, Optional +import typing as t from ..log import DUMP_PACKETS from ..models import TREZOR_ONE, TrezorModel -from . import UDEV_RULES_STR, TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import UDEV_RULES_STR, Transport, TransportException LOG = logging.getLogger(__name__) @@ -35,23 +36,61 @@ HID_IMPORTED = False -HidDevice = Dict[str, Any] -HidDeviceHandle = Any +HidDevice = t.Dict[str, t.Any] +HidDeviceHandle = t.Any + + +class HidTransport(Transport): + """ + HidTransport implements transport over USB HID interface. + """ + PATH_PREFIX = "hid" + ENABLED = HID_IMPORTED -class HidHandle: - def __init__( - self, path: bytes, serial: str, probe_hid_version: bool = False - ) -> None: - self.path = path - self.serial = serial + def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None: + self.device = device + self.device_path = device["path"] + self.device_serial_number = device["serial_number"] self.handle: HidDeviceHandle = None self.hid_version = None if probe_hid_version else 2 + def get_path(self) -> str: + return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" + + @classmethod + def enumerate( + cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False + ) -> t.Iterable["HidTransport"]: + if models is None: + models = {TREZOR_ONE} + usb_ids = [id for model in models for id in model.usb_ids] + + devices: t.List["HidTransport"] = [] + for dev in hid.enumerate(0, 0): + usb_id = (dev["vendor_id"], dev["product_id"]) + if usb_id not in usb_ids: + continue + if debug: + if not is_debuglink(dev): + continue + else: + if not is_wirelink(dev): + continue + devices.append(HidTransport(dev)) + return devices + + def find_debug(self) -> "HidTransport": + # For v1 protocol, find debug USB interface for the same serial number + for debug in HidTransport.enumerate(debug=True): + if debug.device["serial_number"] == self.device["serial_number"]: + return debug + raise TransportException("Debug HID device not found") + def open(self) -> None: self.handle = hid.device() try: - self.handle.open_path(self.path) + self.handle.open_path(self.device_path) except (IOError, OSError) as e: if sys.platform.startswith("linux"): e.args = e.args + (UDEV_RULES_STR,) @@ -62,11 +101,11 @@ def open(self) -> None: # and we wouldn't even know. # So we check that the serial matches what we expect. serial = self.handle.get_serial_number_string() - if serial != self.serial: + if serial != self.device_serial_number: self.handle.close() self.handle = None raise TransportException( - f"Unexpected device {serial} on path {self.path.decode()}" + f"Unexpected device {serial} on path {self.device_path.decode()}" ) self.handle.set_nonblocking(True) @@ -77,7 +116,7 @@ def open(self) -> None: def close(self) -> None: if self.handle is not None: # reload serial, because device.wipe() can reset it - self.serial = self.handle.get_serial_number_string() + self.device_serial_number = self.handle.get_serial_number_string() self.handle.close() self.handle = None @@ -115,53 +154,6 @@ def probe_hid_version(self) -> int: raise TransportException("Unknown HID version") -class HidTransport(ProtocolBasedTransport): - """ - HidTransport implements transport over USB HID interface. - """ - - PATH_PREFIX = "hid" - ENABLED = HID_IMPORTED - - def __init__(self, device: HidDevice) -> None: - self.device = device - self.handle = HidHandle(device["path"], device["serial_number"]) - - super().__init__(protocol=ProtocolV1(self.handle)) - - def get_path(self) -> str: - return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" - - @classmethod - def enumerate( - cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False - ) -> Iterable["HidTransport"]: - if models is None: - models = {TREZOR_ONE} - usb_ids = [id for model in models for id in model.usb_ids] - - devices: List["HidTransport"] = [] - for dev in hid.enumerate(0, 0): - usb_id = (dev["vendor_id"], dev["product_id"]) - if usb_id not in usb_ids: - continue - if debug: - if not is_debuglink(dev): - continue - else: - if not is_wirelink(dev): - continue - devices.append(HidTransport(dev)) - return devices - - def find_debug(self) -> "HidTransport": - # For v1 protocol, find debug USB interface for the same serial number - for debug in HidTransport.enumerate(debug=True): - if debug.device["serial_number"] == self.device["serial_number"]: - return debug - raise TransportException("Debug HID device not found") - - def is_wirelink(dev: HidDevice) -> bool: return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0 diff --git a/python/src/trezorlib/transport/protocol.py b/python/src/trezorlib/transport/protocol.py deleted file mode 100644 index a5a0ee6be4d..00000000000 --- a/python/src/trezorlib/transport/protocol.py +++ /dev/null @@ -1,165 +0,0 @@ -# This file is part of the Trezor project. -# -# Copyright (C) 2012-2022 SatoshiLabs and contributors -# -# This library is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# This library is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the License along with this library. -# If not, see . - -import logging -import struct -from typing import Tuple - -from typing_extensions import Protocol as StructuralType - -from . import MessagePayload, Transport - -REPLEN = 64 - -V2_FIRST_CHUNK = 0x01 -V2_NEXT_CHUNK = 0x02 -V2_BEGIN_SESSION = 0x03 -V2_END_SESSION = 0x04 - -LOG = logging.getLogger(__name__) - - -class Handle(StructuralType): - """PEP 544 structural type for Handle functionality. - (called a "Protocol" in the proposed PEP, name which is impractical here) - - Handle is a "physical" layer for a protocol. - It can open/close a connection and read/write bare data in 64-byte chunks. - - Functionally we gain nothing from making this an (abstract) base class for handle - implementations, so this definition is for type hinting purposes only. You can, - but don't have to, inherit from it. - """ - - def open(self) -> None: ... - - def close(self) -> None: ... - - def read_chunk(self) -> bytes: ... - - def write_chunk(self, chunk: bytes) -> None: ... - - -class Protocol: - """Wire protocol that can communicate with a Trezor device, given a Handle. - - A Protocol implements the part of the Transport API that relates to communicating - logical messages over a physical layer. It is a thing that can: - - open and close sessions, - - send and receive protobuf messages, - given the ability to: - - open and close physical connections, - - and send and receive binary chunks. - - For now, the class also handles session counting and opening the underlying Handle. - This will probably be removed in the future. - - We will need a new Protocol class if we change the way a Trezor device encapsulates - its messages. - """ - - def __init__(self, handle: Handle) -> None: - self.handle = handle - self.session_counter = 0 - - # XXX we might be able to remove this now that TrezorClient does session handling - def begin_session(self) -> None: - if self.session_counter == 0: - self.handle.open() - self.session_counter += 1 - - def end_session(self) -> None: - self.session_counter = max(self.session_counter - 1, 0) - if self.session_counter == 0: - self.handle.close() - - def read(self) -> MessagePayload: - raise NotImplementedError - - def write(self, message_type: int, message_data: bytes) -> None: - raise NotImplementedError - - -class ProtocolBasedTransport(Transport): - """Transport that implements its communications through a Protocol. - - Intended as a base class for implementations that proxy their communication - operations to a Protocol. - """ - - def __init__(self, protocol: Protocol) -> None: - self.protocol = protocol - - def write(self, message_type: int, message_data: bytes) -> None: - self.protocol.write(message_type, message_data) - - def read(self) -> MessagePayload: - return self.protocol.read() - - def begin_session(self) -> None: - self.protocol.begin_session() - - def end_session(self) -> None: - self.protocol.end_session() - - -class ProtocolV1(Protocol): - """Protocol version 1. Currently (11/2018) in use on all Trezors. - Does not understand sessions. - """ - - HEADER_LEN = struct.calcsize(">HL") - - def write(self, message_type: int, message_data: bytes) -> None: - header = struct.pack(">HL", message_type, len(message_data)) - buffer = bytearray(b"##" + header + message_data) - - while buffer: - # Report ID, data padded to 63 bytes - chunk = b"?" + buffer[: REPLEN - 1] - chunk = chunk.ljust(REPLEN, b"\x00") - self.handle.write_chunk(chunk) - buffer = buffer[63:] - - def read(self) -> MessagePayload: - buffer = bytearray() - # Read header with first part of message data - msg_type, datalen, first_chunk = self.read_first() - buffer.extend(first_chunk) - - # Read the rest of the message - while len(buffer) < datalen: - buffer.extend(self.read_next()) - - return msg_type, buffer[:datalen] - - def read_first(self) -> Tuple[int, int, bytes]: - chunk = self.handle.read_chunk() - if chunk[:3] != b"?##": - raise RuntimeError("Unexpected magic characters") - try: - msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) - except Exception: - raise RuntimeError("Cannot parse header") - - data = chunk[3 + self.HEADER_LEN :] - return msg_type, datalen, data - - def read_next(self) -> bytes: - chunk = self.handle.read_chunk() - if chunk[:1] != b"?": - raise RuntimeError("Unexpected magic characters") - return chunk[1:] diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py new file mode 100644 index 00000000000..90dab10bfac --- /dev/null +++ b/python/src/trezorlib/transport/session.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +import logging +import typing as t + +from .. import exceptions, messages, models +from ..protobuf import MessageType +from .thp.protocol_v1 import ProtocolV1 +from .thp.protocol_v2 import ProtocolV2 + +if t.TYPE_CHECKING: + from ..client import TrezorClient + +LOG = logging.getLogger(__name__) + +MT = t.TypeVar("MT", bound=MessageType) + + +class Session: + button_callback: t.Callable[[Session, t.Any], t.Any] | None = None + pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None + passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None + + def __init__( + self, client: TrezorClient, id: bytes, passphrase: str | object | None = None + ) -> None: + self.client = client + self._id = id + self.passphrase = passphrase + + @classmethod + def new( + cls, client: TrezorClient, passphrase: str | object | None, derive_cardano: bool + ) -> Session: + raise NotImplementedError + + def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT: + # TODO self.check_firmware_version() + resp = self.call_raw(msg) + + while True: + if isinstance(resp, messages.PinMatrixRequest): + if self.pin_callback is None: + raise Exception # TODO + resp = self.pin_callback(self, resp) + elif isinstance(resp, messages.PassphraseRequest): + if self.passphrase_callback is None: + raise Exception # TODO + resp = self.passphrase_callback(self, resp) + elif isinstance(resp, messages.ButtonRequest): + if self.button_callback is None: + raise Exception # TODO + resp = self.button_callback(self, resp) + elif isinstance(resp, messages.Failure): + if resp.code == messages.FailureType.ActionCancelled: + raise exceptions.Cancelled + raise exceptions.TrezorFailure(resp) + elif not isinstance(resp, expect): + raise exceptions.UnexpectedMessageError(expect, resp) + else: + return resp + + def call_raw(self, msg: t.Any) -> t.Any: + self._write(msg) + return self._read() + + def _write(self, msg: t.Any) -> None: + raise NotImplementedError + + def _read(self) -> t.Any: + raise NotImplementedError + + def refresh_features(self) -> None: + self.client.refresh_features() + + def end(self) -> t.Any: + return self.call(messages.EndSession()) + + def ping(self, message: str, button_protection: bool | None = None) -> str: + resp = self.call( + messages.Ping(message=message, button_protection=button_protection), + expect=messages.Success, + ) + assert resp.message is not None + return resp.message + + def invalidate(self) -> None: + self.client.invalidate() + + @property + def features(self) -> messages.Features: + return self.client.features + + @property + def model(self) -> models.TrezorModel: + return self.client.model + + @property + def version(self) -> t.Tuple[int, int, int]: + return self.client.version + + @property + def id(self) -> bytes: + return self._id + + @id.setter + def id(self, value: bytes) -> None: + if not isinstance(value, bytes): + raise ValueError("id must be of type bytes") + self._id = value + + +class SessionV1(Session): + derive_cardano: bool | None = False + + @classmethod + def new( + cls, + client: TrezorClient, + passphrase: str | object = "", + derive_cardano: bool = False, + session_id: bytes | None = None, + ) -> SessionV1: + assert isinstance(client.protocol, ProtocolV1) + session = SessionV1(client, id=session_id or b"") + + session._init_callbacks() + session.passphrase = passphrase + session.derive_cardano = derive_cardano + session.init_session(session.derive_cardano) + return session + + @classmethod + def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1: + assert isinstance(client.protocol, ProtocolV1) + session = SessionV1(client, session_id) + session.init_session() + return session + + def _init_callbacks(self) -> None: + self.button_callback = self.client.button_callback + if self.button_callback is None: + self.button_callback = _callback_button + self.pin_callback = self.client.pin_callback + self.passphrase_callback = self.client.passphrase_callback + + def _write(self, msg: t.Any) -> None: + if t.TYPE_CHECKING: + assert isinstance(self.client.protocol, ProtocolV1) + self.client.protocol.write(msg) + + def _read(self) -> t.Any: + if t.TYPE_CHECKING: + assert isinstance(self.client.protocol, ProtocolV1) + return self.client.protocol.read() + + def init_session(self, derive_cardano: bool | None = None): + if self.id == b"": + session_id = None + else: + session_id = self.id + resp: messages.Features = self.call_raw( + messages.Initialize(session_id=session_id, derive_cardano=derive_cardano) + ) + if isinstance(self.passphrase, str): + self.passphrase_callback = self.client.passphrase_callback + self._id = resp.session_id + + +def _callback_button(session: Session, msg: t.Any) -> t.Any: + print("Please confirm action on your Trezor device.") # TODO how to handle UI? + return session.call(messages.ButtonAck()) + + +class SessionV2(Session): + + @classmethod + def new( + cls, + client: TrezorClient, + passphrase: str | None, + derive_cardano: bool, + session_id: int = 0, + ) -> SessionV2: + assert isinstance(client.protocol, ProtocolV2) + session = cls(client, session_id.to_bytes(1, "big")) + session.call( + messages.ThpCreateNewSession( + passphrase=passphrase, derive_cardano=derive_cardano + ), + expect=messages.Success, + ) + session.update_id_and_sid(session_id.to_bytes(1, "big")) + return session + + def __init__(self, client: TrezorClient, id: bytes) -> None: + from ..debuglink import TrezorClientDebugLink + + super().__init__(client, id) + assert isinstance(client.protocol, ProtocolV2) + + self.pin_callback = client.pin_callback + self.button_callback = client.button_callback + if self.button_callback is None: + self.button_callback = _callback_button + helper_debug = None + if isinstance(client, TrezorClientDebugLink): + helper_debug = client.debug + self.channel: ProtocolV2 = client.protocol.get_channel(helper_debug) + self.update_id_and_sid(id) + + def _write(self, msg: t.Any) -> None: + LOG.debug("writing message %s", type(msg)) + self.channel.write(self.sid, msg) + + def _read(self) -> t.Any: + msg = self.channel.read(self.sid) + LOG.debug("reading message %s", type(msg)) + return msg + + def update_id_and_sid(self, id: bytes) -> None: + self._id = id + self.sid = int.from_bytes(id, "big") # TODO update to extract only sid diff --git a/python/src/trezorlib/transport/thp/alternating_bit_protocol.py b/python/src/trezorlib/transport/thp/alternating_bit_protocol.py new file mode 100644 index 00000000000..62fb650fab0 --- /dev/null +++ b/python/src/trezorlib/transport/thp/alternating_bit_protocol.py @@ -0,0 +1,102 @@ +# from storage.cache_thp import ChannelCache +# from trezor import log +# from trezor.wire.thp import ThpError + + +# def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool: +# """ +# Checks if: +# - an ACK message is expected +# - the received ACK message acknowledges correct sequence number (bit) +# """ +# if not _is_ack_expected(cache): +# return False + +# if not _has_ack_correct_sync_bit(cache, ack_bit): +# return False + +# return True + + +# def _is_ack_expected(cache: ChannelCache) -> bool: +# is_expected: bool = not is_sending_allowed(cache) +# if __debug__ and not is_expected: +# log.debug(__name__, "Received unexpected ACK message") +# return is_expected + + +# def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool: +# is_correct: bool = get_send_seq_bit(cache) == sync_bit +# if __debug__ and not is_correct: +# log.debug(__name__, "Received ACK message with wrong ack bit") +# return is_correct + + +# def is_sending_allowed(cache: ChannelCache) -> bool: +# """ +# Checks whether sending a message in the provided channel is allowed. + +# Note: Sending a message in a channel before receipt of ACK message for the previously +# sent message (in the channel) is prohibited, as it can lead to desynchronization. +# """ +# return bool(cache.sync >> 7) + + +# def get_send_seq_bit(cache: ChannelCache) -> int: +# """ +# Returns the sequential number (bit) of the next message to be sent +# in the provided channel. +# """ +# return (cache.sync & 0x20) >> 5 + + +# def get_expected_receive_seq_bit(cache: ChannelCache) -> int: +# """ +# Returns the (expected) sequential number (bit) of the next message +# to be received in the provided channel. +# """ +# return (cache.sync & 0x40) >> 6 + + +# def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None: +# """ +# Set the flag whether sending a message in this channel is allowed or not. +# """ +# cache.sync &= 0x7F +# if sending_allowed: +# cache.sync |= 0x80 + + +# def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None: +# """ +# Set the expected sequential number (bit) of the next message to be received +# in the provided channel +# """ +# if __debug__: +# log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit) +# if seq_bit not in (0, 1): +# raise ThpError("Unexpected receive sync bit") + +# # set second bit to "seq_bit" value +# cache.sync &= 0xBF +# if seq_bit: +# cache.sync |= 0x40 + + +# def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None: +# if seq_bit not in (0, 1): +# raise ThpError("Unexpected send seq bit") +# if __debug__: +# log.debug(__name__, "setting sync send seq bit to %d", seq_bit) +# # set third bit to "seq_bit" value +# cache.sync &= 0xDF +# if seq_bit: +# cache.sync |= 0x20 + + +# def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None: +# """ +# Set the sequential bit of the "next message to be send" to the opposite value, +# i.e. 1 -> 0 and 0 -> 1 +# """ +# _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache)) diff --git a/python/src/trezorlib/transport/thp/channel_data.py b/python/src/trezorlib/transport/thp/channel_data.py new file mode 100644 index 00000000000..4d9d11d8d0d --- /dev/null +++ b/python/src/trezorlib/transport/thp/channel_data.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from binascii import hexlify + + +class ChannelData: + + def __init__( + self, + protocol_version_major: int, + protocol_version_minor: int, + transport_path: str, + channel_id: int, + key_request: bytes, + key_response: bytes, + nonce_request: int, + nonce_response: int, + sync_bit_send: int, + sync_bit_receive: int, + handshake_hash: bytes, + ) -> None: + self.protocol_version_major: int = protocol_version_major + self.protocol_version_minor: int = protocol_version_minor + self.transport_path: str = transport_path + self.channel_id: int = channel_id + self.key_request: str = hexlify(key_request).decode() + self.key_response: str = hexlify(key_response).decode() + self.nonce_request: int = nonce_request + self.nonce_response: int = nonce_response + self.sync_bit_receive: int = sync_bit_receive + self.sync_bit_send: int = sync_bit_send + self.handshake_hash: str = hexlify(handshake_hash).decode() + + def to_dict(self): + return { + "protocol_version_major": self.protocol_version_major, + "protocol_version_minor": self.protocol_version_minor, + "transport_path": self.transport_path, + "channel_id": self.channel_id, + "key_request": self.key_request, + "key_response": self.key_response, + "nonce_request": self.nonce_request, + "nonce_response": self.nonce_response, + "sync_bit_send": self.sync_bit_send, + "sync_bit_receive": self.sync_bit_receive, + "handshake_hash": self.handshake_hash, + } diff --git a/python/src/trezorlib/transport/thp/channel_database.py b/python/src/trezorlib/transport/thp/channel_database.py new file mode 100644 index 00000000000..03be0f7ecea --- /dev/null +++ b/python/src/trezorlib/transport/thp/channel_database.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import json +import logging +import os +import typing as t + +from ..thp.channel_data import ChannelData +from .protocol_and_channel import ProtocolAndChannel + +LOG = logging.getLogger(__name__) + +db: "ChannelDatabase | None" = None + + +def get_channel_db() -> ChannelDatabase: + if db is None: + set_channel_database(should_not_store=True) + assert db is not None + return db + + +class ChannelDatabase: + + def load_stored_channels(self) -> t.List[ChannelData]: ... + def clear_stored_channels(self) -> None: ... + def read_all_channels(self) -> t.List: ... + def save_all_channels(self, channels: t.List[t.Dict]) -> None: ... + def save_channel(self, new_channel: ProtocolAndChannel): ... + def remove_channel(self, transport_path: str) -> None: ... + + +class DummyChannelDatabase(ChannelDatabase): + + def load_stored_channels(self) -> t.List[ChannelData]: + return [] + + def clear_stored_channels(self) -> None: + pass + + def read_all_channels(self) -> t.List: + return [] + + def save_all_channels(self, channels: t.List[t.Dict]) -> None: + return + + def save_channel(self, new_channel: ProtocolAndChannel): + pass + + def remove_channel(self, transport_path: str) -> None: + pass + + +class JsonChannelDatabase(ChannelDatabase): + def __init__(self, data_path: str) -> None: + self.data_path = data_path + super().__init__() + + def load_stored_channels(self) -> t.List[ChannelData]: + dicts = self.read_all_channels() + return [dict_to_channel_data(d) for d in dicts] + + def clear_stored_channels(self) -> None: + LOG.debug("Clearing contents of %s", self.data_path) + with open(self.data_path, "w") as f: + json.dump([], f) + try: + os.remove(self.data_path) + except Exception as e: + LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e))) + + def read_all_channels(self) -> t.List: + ensure_file_exists(self.data_path) + with open(self.data_path, "r") as f: + return json.load(f) + + def save_all_channels(self, channels: t.List[t.Dict]) -> None: + LOG.debug("saving all channels") + with open(self.data_path, "w") as f: + json.dump(channels, f, indent=4) + + def save_channel(self, new_channel: ProtocolAndChannel): + + LOG.debug("save channel") + channels = self.read_all_channels() + transport_path = new_channel.transport.get_path() + + # If the channel is found in database: replace the old entry by the new + for i, channel in enumerate(channels): + if channel["transport_path"] == transport_path: + LOG.debug("Modified channel entry for %s", transport_path) + channels[i] = new_channel.get_channel_data().to_dict() + self.save_all_channels(channels) + return + + # Channel was not found: add a new channel entry + LOG.debug("Created a new channel entry on path %s", transport_path) + channels.append(new_channel.get_channel_data().to_dict()) + self.save_all_channels(channels) + + def remove_channel(self, transport_path: str) -> None: + LOG.debug( + "Removing channel with path %s from the channel database.", + transport_path, + ) + channels = self.read_all_channels() + remaining_channels = [ + ch for ch in channels if ch["transport_path"] != transport_path + ] + self.save_all_channels(remaining_channels) + + +def dict_to_channel_data(dict: t.Dict) -> ChannelData: + return ChannelData( + protocol_version_major=dict["protocol_version_minor"], + protocol_version_minor=dict["protocol_version_major"], + transport_path=dict["transport_path"], + channel_id=dict["channel_id"], + key_request=bytes.fromhex(dict["key_request"]), + key_response=bytes.fromhex(dict["key_response"]), + nonce_request=dict["nonce_request"], + nonce_response=dict["nonce_response"], + sync_bit_send=dict["sync_bit_send"], + sync_bit_receive=dict["sync_bit_receive"], + handshake_hash=bytes.fromhex(dict["handshake_hash"]), + ) + + +def ensure_file_exists(file_path: str) -> None: + LOG.debug("checking if file %s exists", file_path) + if not os.path.exists(file_path): + os.makedirs(os.path.dirname(file_path), exist_ok=True) + LOG.debug("File %s does not exist. Creating a new one.", file_path) + with open(file_path, "w") as f: + json.dump([], f) + + +def set_channel_database(should_not_store: bool): + global db + if should_not_store: + db = DummyChannelDatabase() + else: + from platformdirs import user_cache_dir + + APP_NAME = "@trezor" # TODO + DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json") + + db = JsonChannelDatabase(DATA_PATH) diff --git a/python/src/trezorlib/transport/thp/checksum.py b/python/src/trezorlib/transport/thp/checksum.py new file mode 100644 index 00000000000..8e0f32f0132 --- /dev/null +++ b/python/src/trezorlib/transport/thp/checksum.py @@ -0,0 +1,19 @@ +import zlib + +CHECKSUM_LENGTH = 4 + + +def compute(data: bytes) -> bytes: + """ + Returns a CRC-32 checksum of the provided `data`. + """ + return zlib.crc32(data).to_bytes(CHECKSUM_LENGTH, "big") + + +def is_valid(checksum: bytes, data: bytes) -> bool: + """ + Checks whether the CRC-32 checksum of the `data` is the same + as the checksum provided in `checksum`. + """ + data_checksum = compute(data) + return checksum == data_checksum diff --git a/python/src/trezorlib/transport/thp/control_byte.py b/python/src/trezorlib/transport/thp/control_byte.py new file mode 100644 index 00000000000..dca681ef020 --- /dev/null +++ b/python/src/trezorlib/transport/thp/control_byte.py @@ -0,0 +1,63 @@ +CODEC_V1 = 0x3F +CONTINUATION_PACKET = 0x80 +HANDSHAKE_INIT_REQ = 0x00 +HANDSHAKE_INIT_RES = 0x01 +HANDSHAKE_COMP_REQ = 0x02 +HANDSHAKE_COMP_RES = 0x03 +ENCRYPTED_TRANSPORT = 0x04 + +CONTINUATION_PACKET_MASK = 0x80 +ACK_MASK = 0xF7 +DATA_MASK = 0xE7 + +ACK_MESSAGE = 0x20 +_ERROR = 0x42 +CHANNEL_ALLOCATION_REQ = 0x40 +_CHANNEL_ALLOCATION_RES = 0x41 + +TREZOR_STATE_UNPAIRED = b"\x00" +TREZOR_STATE_PAIRED = b"\x01" + + +def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int: + if seq_bit == 0: + return ctrl_byte & 0xEF + if seq_bit == 1: + return ctrl_byte | 0x10 + raise Exception("Unexpected sequence bit") + + +def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int: + if ack_bit == 0: + return ctrl_byte & 0xF7 + if ack_bit == 1: + return ctrl_byte | 0x08 + raise Exception("Unexpected acknowledgement bit") + + +def get_seq_bit(ctrl_byte: int) -> int: + return (ctrl_byte & 0x10) >> 4 + + +def is_ack(ctrl_byte: int) -> bool: + return ctrl_byte & ACK_MASK == ACK_MESSAGE + + +def is_error(ctrl_byte: int) -> bool: + return ctrl_byte == _ERROR + + +def is_continuation(ctrl_byte: int) -> bool: + return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET + + +def is_encrypted_transport(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT + + +def is_handshake_init_req(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ + + +def is_handshake_comp_req(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ diff --git a/python/src/trezorlib/transport/thp/cpace.py b/python/src/trezorlib/transport/thp/cpace.py new file mode 100644 index 00000000000..d0b28e265c4 --- /dev/null +++ b/python/src/trezorlib/transport/thp/cpace.py @@ -0,0 +1,40 @@ +import typing as t +from hashlib import sha512 + +from . import curve25519 + +_PREFIX = b"\x08\x43\x50\x61\x63\x65\x32\x35\x35\x06" +_PADDING = b"\x6f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x20" + + +class Cpace: + """ + CPace, a balanced composable PAKE: https://datatracker.ietf.org/doc/draft-irtf-cfrg-cpace/ + """ + + random_bytes: t.Callable[[int], bytes] + + def __init__(self, handshake_hash: bytes) -> None: + self.handshake_hash: bytes = handshake_hash + self.shared_secret: bytes + self.host_private_key: bytes + self.host_public_key: bytes + + def generate_keys_and_secret( + self, code_code_entry: bytes, trezor_public_key: bytes + ) -> None: + """ + Generate ephemeral key pair and a shared secret using Elligator2 with X25519. + """ + sha_ctx = sha512(_PREFIX) + sha_ctx.update(code_code_entry) + sha_ctx.update(_PADDING) + sha_ctx.update(self.handshake_hash) + sha_ctx.update(b"\x00") + pregenerator = sha_ctx.digest()[:32] + generator = curve25519.elligator2(pregenerator) + self.host_private_key = self.random_bytes(32) + self.host_public_key = curve25519.multiply(self.host_private_key, generator) + self.shared_secret = curve25519.multiply( + self.host_private_key, trezor_public_key + ) diff --git a/python/src/trezorlib/transport/thp/curve25519.py b/python/src/trezorlib/transport/thp/curve25519.py new file mode 100644 index 00000000000..e4416225f1f --- /dev/null +++ b/python/src/trezorlib/transport/thp/curve25519.py @@ -0,0 +1,159 @@ +from typing import Tuple + +p = 2**255 - 19 +J = 486662 + +c3 = 19681161376707505956807079304988542015446066515923890162744021073123829784752 # sqrt(-1) +c4 = 7237005577332262213973186563042994240829374041602535252466099000494570602493 # (p - 5) // 8 +a24 = 121666 # (J + 2) // 4 + + +def decode_scalar(scalar: bytes) -> int: + # decodeScalar25519 from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + + if len(scalar) != 32: + raise ValueError("Invalid length of scalar") + + array = bytearray(scalar) + array[0] &= 248 + array[31] &= 127 + array[31] |= 64 + + return int.from_bytes(array, "little") + + +def decode_coordinate(coordinate: bytes) -> int: + # decodeUCoordinate from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + if len(coordinate) != 32: + raise ValueError("Invalid length of coordinate") + + array = bytearray(coordinate) + array[-1] &= 0x7F + return int.from_bytes(array, "little") % p + + +def encode_coordinate(coordinate: int) -> bytes: + # encodeUCoordinate from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + return coordinate.to_bytes(32, "little") + + +def get_private_key(secret: bytes) -> bytes: + return decode_scalar(secret).to_bytes(32, "little") + + +def get_public_key(private_key: bytes) -> bytes: + base_point = int.to_bytes(9, 32, "little") + return multiply(private_key, base_point) + + +def multiply(private_scalar: bytes, public_point: bytes): + # X25519 from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + + def ladder_operation( + x1: int, x2: int, z2: int, x3: int, z3: int + ) -> Tuple[int, int, int, int]: + # https://hyperelliptic.org/EFD/g1p/auto-montgom-xz.html#ladder-ladd-1987-m-3 + # (x4, z4) = 2 * (x2, z2) + # (x5, z5) = (x2, z2) + (x3, z3) + # where (x1, 1) = (x3, z3) - (x2, z2) + + a = (x2 + z2) % p + aa = (a * a) % p + b = (x2 - z2) % p + bb = (b * b) % p + e = (aa - bb) % p + c = (x3 + z3) % p + d = (x3 - z3) % p + da = (d * a) % p + cb = (c * b) % p + t0 = (da + cb) % p + x5 = (t0 * t0) % p + t1 = (da - cb) % p + t2 = (t1 * t1) % p + z5 = (x1 * t2) % p + x4 = (aa * bb) % p + t3 = (a24 * e) % p + t4 = (bb + t3) % p + z4 = (e * t4) % p + + return x4, z4, x5, z5 + + def conditional_swap(first: int, second: int, condition: int): + # Returns (second, first) if condition is true and (first, second) otherwise + # Must be implemented in a way that it is constant time + true_mask = -condition + false_mask = ~true_mask + return (first & false_mask) | (second & true_mask), (second & false_mask) | ( + first & true_mask + ) + + k = decode_scalar(private_scalar) + u = decode_coordinate(public_point) + + x_1 = u + x_2 = 1 + z_2 = 0 + x_3 = u + z_3 = 1 + swap = 0 + + for i in reversed(range(256)): + bit = (k >> i) & 1 + swap = bit ^ swap + (x_2, x_3) = conditional_swap(x_2, x_3, swap) + (z_2, z_3) = conditional_swap(z_2, z_3, swap) + swap = bit + x_2, z_2, x_3, z_3 = ladder_operation(x_1, x_2, z_2, x_3, z_3) + + (x_2, x_3) = conditional_swap(x_2, x_3, swap) + (z_2, z_3) = conditional_swap(z_2, z_3, swap) + + x = pow(z_2, p - 2, p) * x_2 % p + return encode_coordinate(x) + + +def elligator2(point: bytes) -> bytes: + # map_to_curve_elligator2_curve25519 from + # https://www.rfc-editor.org/rfc/rfc9380.html#ell2-opt + + def conditional_move(first: int, second: int, condition: bool): + # Returns second if condition is true and first otherwise + # Must be implemented in a way that it is constant time + true_mask = -condition + false_mask = ~true_mask + return (first & false_mask) | (second & true_mask) + + u = decode_coordinate(point) + tv1 = (u * u) % p + tv1 = (2 * tv1) % p + xd = (tv1 + 1) % p + x1n = (-J) % p + tv2 = (xd * xd) % p + gxd = (tv2 * xd) % p + gx1 = (J * tv1) % p + gx1 = (gx1 * x1n) % p + gx1 = (gx1 + tv2) % p + gx1 = (gx1 * x1n) % p + tv3 = (gxd * gxd) % p + tv2 = (tv3 * tv3) % p + tv3 = (tv3 * gxd) % p + tv3 = (tv3 * gx1) % p + tv2 = (tv2 * tv3) % p + y11 = pow(tv2, c4, p) + y11 = (y11 * tv3) % p + y12 = (y11 * c3) % p + tv2 = (y11 * y11) % p + tv2 = (tv2 * gxd) % p + e1 = tv2 == gx1 + y1 = conditional_move(y12, y11, e1) + x2n = (x1n * tv1) % p + tv2 = (y1 * y1) % p + tv2 = (tv2 * gxd) % p + e3 = tv2 == gx1 + xn = conditional_move(x2n, x1n, e3) + x = xn * pow(xd, p - 2, p) % p + return encode_coordinate(x) diff --git a/python/src/trezorlib/transport/thp/message_header.py b/python/src/trezorlib/transport/thp/message_header.py new file mode 100644 index 00000000000..d2ff002d636 --- /dev/null +++ b/python/src/trezorlib/transport/thp/message_header.py @@ -0,0 +1,82 @@ +import struct + +CODEC_V1 = 0x3F +CONTINUATION_PACKET = 0x80 +HANDSHAKE_INIT_REQ = 0x00 +HANDSHAKE_INIT_RES = 0x01 +HANDSHAKE_COMP_REQ = 0x02 +HANDSHAKE_COMP_RES = 0x03 +ENCRYPTED_TRANSPORT = 0x04 + +CONTINUATION_PACKET_MASK = 0x80 +ACK_MASK = 0xF7 +DATA_MASK = 0xE7 + +ACK_MESSAGE = 0x20 +_ERROR = 0x42 +CHANNEL_ALLOCATION_REQ = 0x40 +_CHANNEL_ALLOCATION_RES = 0x41 + +TREZOR_STATE_UNPAIRED = b"\x00" +TREZOR_STATE_PAIRED = b"\x01" + +BROADCAST_CHANNEL_ID = 0xFFFF + + +class MessageHeader: + format_str_init = ">BHH" + format_str_cont = ">BH" + + def __init__(self, ctrl_byte: int, cid: int, length: int) -> None: + self.ctrl_byte = ctrl_byte + self.cid = cid + self.data_length = length + + def to_bytes_init(self) -> bytes: + return struct.pack( + self.format_str_init, self.ctrl_byte, self.cid, self.data_length + ) + + def to_bytes_cont(self) -> bytes: + return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid) + + def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + struct.pack_into( + self.format_str_init, + buffer, + buffer_offset, + self.ctrl_byte, + self.cid, + self.data_length, + ) + + def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + struct.pack_into( + self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid + ) + + def is_ack(self) -> bool: + return self.ctrl_byte & ACK_MASK == ACK_MESSAGE + + def is_channel_allocation_response(self): + return ( + self.cid == BROADCAST_CHANNEL_ID + and self.ctrl_byte == _CHANNEL_ALLOCATION_RES + ) + + def is_handshake_init_response(self) -> bool: + return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES + + def is_handshake_comp_response(self) -> bool: + return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES + + def is_encrypted_transport(self) -> bool: + return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT + + @classmethod + def get_error_header(cls, cid: int, length: int): + return cls(_ERROR, cid, length) + + @classmethod + def get_channel_allocation_request_header(cls, length: int): + return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length) diff --git a/python/src/trezorlib/transport/thp/protocol_and_channel.py b/python/src/trezorlib/transport/thp/protocol_and_channel.py new file mode 100644 index 00000000000..fa420ac0af2 --- /dev/null +++ b/python/src/trezorlib/transport/thp/protocol_and_channel.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import logging + +from ... import messages +from ...mapping import ProtobufMapping +from .. import Transport +from ..thp.channel_data import ChannelData + +LOG = logging.getLogger(__name__) + + +class ProtocolAndChannel: + + def __init__( + self, + transport: Transport, + mapping: ProtobufMapping, + channel_data: ChannelData | None = None, + ) -> None: + self.transport = transport + self.mapping = mapping + self.channel_keys = channel_data + + def get_features(self) -> messages.Features: + raise NotImplementedError() + + def get_channel_data(self) -> ChannelData: + raise NotImplementedError + + def update_features(self) -> None: + raise NotImplementedError diff --git a/python/src/trezorlib/transport/thp/protocol_v1.py b/python/src/trezorlib/transport/thp/protocol_v1.py new file mode 100644 index 00000000000..baea7e74010 --- /dev/null +++ b/python/src/trezorlib/transport/thp/protocol_v1.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import logging +import struct +import typing as t + +from ... import exceptions, messages +from ...log import DUMP_BYTES +from .protocol_and_channel import ProtocolAndChannel + +LOG = logging.getLogger(__name__) + + +class ProtocolV1(ProtocolAndChannel): + HEADER_LEN = struct.calcsize(">HL") + _features: messages.Features | None = None + + def get_features(self) -> messages.Features: + if self._features is None: + self.update_features() + assert self._features is not None + return self._features + + def update_features(self) -> None: + self.write(messages.GetFeatures()) + resp = self.read() + if not isinstance(resp, messages.Features): + raise exceptions.TrezorException("Unexpected response to GetFeatures") + self._features = resp + + def read(self) -> t.Any: + msg_type, msg_bytes = self._read() + LOG.log( + DUMP_BYTES, + f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", + ) + msg = self.mapping.decode(msg_type, msg_bytes) + LOG.debug( + f"received message: {msg.__class__.__name__}", + extra={"protobuf": msg}, + ) + self.transport.close() + return msg + + def write(self, msg: t.Any) -> None: + LOG.debug( + f"sending message: {msg.__class__.__name__}", + extra={"protobuf": msg}, + ) + msg_type, msg_bytes = self.mapping.encode(msg) + LOG.log( + DUMP_BYTES, + f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", + ) + self._write(msg_type, msg_bytes) + + def _write(self, message_type: int, message_data: bytes) -> None: + chunk_size = self.transport.CHUNK_SIZE + header = struct.pack(">HL", message_type, len(message_data)) + buffer = bytearray(b"##" + header + message_data) + + while buffer: + # Report ID, data padded to 63 bytes + chunk = b"?" + buffer[: chunk_size - 1] + chunk = chunk.ljust(chunk_size, b"\x00") + self.transport.write_chunk(chunk) + buffer = buffer[63:] + + def _read(self) -> t.Tuple[int, bytes]: + buffer = bytearray() + # Read header with first part of message data + msg_type, datalen, first_chunk = self.read_first() + buffer.extend(first_chunk) + + # Read the rest of the message + while len(buffer) < datalen: + buffer.extend(self.read_next()) + + return msg_type, buffer[:datalen] + + def read_first(self) -> t.Tuple[int, int, bytes]: + chunk = self.transport.read_chunk() + if chunk[:3] != b"?##": + raise RuntimeError("Unexpected magic characters") + try: + msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) + except Exception: + raise RuntimeError("Cannot parse header") + + data = chunk[3 + self.HEADER_LEN :] + return msg_type, datalen, data + + def read_next(self) -> bytes: + chunk = self.transport.read_chunk() + if chunk[:1] != b"?": + raise RuntimeError("Unexpected magic characters") + return chunk[1:] diff --git a/python/src/trezorlib/transport/thp/protocol_v2.py b/python/src/trezorlib/transport/thp/protocol_v2.py new file mode 100644 index 00000000000..b073a0264de --- /dev/null +++ b/python/src/trezorlib/transport/thp/protocol_v2.py @@ -0,0 +1,490 @@ +from __future__ import annotations + +import hashlib +import hmac +import logging +import os +import typing as t +from binascii import hexlify + +import click +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +from ... import exceptions, messages, protobuf +from ...mapping import ProtobufMapping +from .. import Transport +from ..thp import checksum, curve25519, thp_io +from ..thp.channel_data import ChannelData +from ..thp.checksum import CHECKSUM_LENGTH +from ..thp.message_header import MessageHeader +from . import control_byte +from .channel_database import ChannelDatabase, get_channel_db +from .protocol_and_channel import ProtocolAndChannel + +LOG = logging.getLogger(__name__) + +DEFAULT_SESSION_ID: int = 0 + +if t.TYPE_CHECKING: + from ...debuglink import DebugLink +MT = t.TypeVar("MT", bound=protobuf.MessageType) + + +def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes: + hash = hashlib.sha256(val_1) + hash.update(val_2) + return hash.digest() + + +def _hkdf(chaining_key: bytes, input: bytes): + temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest() + output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest() + ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256) + ctx_output_2.update(b"\x02") + output_2 = ctx_output_2.digest() + return (output_1, output_2) + + +def _get_iv_from_nonce(nonce: int) -> bytes: + if not nonce <= 0xFFFFFFFFFFFFFFFF: + raise ValueError("Nonce overflow, terminate the channel") + return bytes(4) + nonce.to_bytes(8, "big") + + +class ProtocolV2(ProtocolAndChannel): + channel_id: int + channel_database: ChannelDatabase + key_request: bytes + key_response: bytes + nonce_request: int + nonce_response: int + sync_bit_send: int + sync_bit_receive: int + handshake_hash: bytes + + _has_valid_channel: bool = False + _features: messages.Features | None = None + + def __init__( + self, + transport: Transport, + mapping: ProtobufMapping, + channel_data: ChannelData | None = None, + ) -> None: + self.channel_database: ChannelDatabase = get_channel_db() + super().__init__(transport, mapping, channel_data) + if channel_data is not None: + self.channel_id = channel_data.channel_id + self.key_request = bytes.fromhex(channel_data.key_request) + self.key_response = bytes.fromhex(channel_data.key_response) + self.nonce_request = channel_data.nonce_request + self.nonce_response = channel_data.nonce_response + self.sync_bit_receive = channel_data.sync_bit_receive + self.sync_bit_send = channel_data.sync_bit_send + self.handshake_hash = bytes.fromhex(channel_data.handshake_hash) + self._has_valid_channel = True + + def get_channel(self, helper_debug: DebugLink | None = None) -> ProtocolV2: + if not self._has_valid_channel: + self._establish_new_channel(helper_debug) + return self + + def get_channel_data(self) -> ChannelData: + return ChannelData( + protocol_version_major=2, + protocol_version_minor=2, + transport_path=self.transport.get_path(), + channel_id=self.channel_id, + key_request=self.key_request, + key_response=self.key_response, + nonce_request=self.nonce_request, + nonce_response=self.nonce_response, + sync_bit_receive=self.sync_bit_receive, + sync_bit_send=self.sync_bit_send, + handshake_hash=self.handshake_hash, + ) + + def read(self, session_id: int) -> t.Any: + sid, msg_type, msg_data = self.read_and_decrypt() + if sid != session_id: + raise Exception("Received messsage on a different session.") + self.channel_database.save_channel(self) + return self.mapping.decode(msg_type, msg_data) + + def write(self, session_id: int, msg: t.Any) -> None: + msg_type, msg_data = self.mapping.encode(msg) + self._encrypt_and_write(session_id, msg_type, msg_data) + self.channel_database.save_channel(self) + + def get_features(self) -> messages.Features: + if not self._has_valid_channel: + self._establish_new_channel() + if self._features is None: + self.update_features() + assert self._features is not None + return self._features + + def update_features(self) -> None: + message = messages.GetFeatures() + message_type, message_data = self.mapping.encode(message) + self.session_id: int = DEFAULT_SESSION_ID + self._encrypt_and_write(DEFAULT_SESSION_ID, message_type, message_data) + _ = self._read_until_valid_crc_check() # TODO check ACK + _, msg_type, msg_data = self.read_and_decrypt() + features = self.mapping.decode(msg_type, msg_data) + if not isinstance(features, messages.Features): + raise exceptions.TrezorException("Unexpected response to GetFeatures") + self._features = features + + def _send_message( + self, + message: protobuf.MessageType, + session_id: int = DEFAULT_SESSION_ID, + ): + message_type, message_data = self.mapping.encode(message) + self._encrypt_and_write(session_id, message_type, message_data) + self._read_ack() + + def _read_message(self, message_type: type[MT]) -> MT: + _, msg_type, msg_data = self.read_and_decrypt() + msg = self.mapping.decode(msg_type, msg_data) + assert isinstance(msg, message_type) + return msg + + def _establish_new_channel(self, helper_debug: DebugLink | None = None) -> None: + self._reset_sync_bits() + self._do_channel_allocation() + self._do_handshake() + self._do_pairing(helper_debug) + + def _reset_sync_bits(self) -> None: + self.sync_bit_send = 0 + self.sync_bit_receive = 0 + + def _do_channel_allocation(self) -> None: + channel_allocation_nonce = os.urandom(8) + self._send_channel_allocation_request(channel_allocation_nonce) + cid, dp = self._read_channel_allocation_response(channel_allocation_nonce) + self.channel_id = cid + self.device_properties = dp + + def _send_channel_allocation_request(self, nonce: bytes): + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, + MessageHeader.get_channel_allocation_request_header(12), + nonce, + ) + + def _read_channel_allocation_response( + self, expected_nonce: bytes + ) -> tuple[int, bytes]: + header, payload = self._read_until_valid_crc_check() + if not self._is_valid_channel_allocation_response( + header, payload, expected_nonce + ): + raise Exception("Invalid channel allocation response.") + + channel_id = int.from_bytes(payload[8:10], "big") + device_properties = payload[10:] + return (channel_id, device_properties) + + def _do_handshake( + self, credential: bytes | None = None, host_static_privkey: bytes | None = None + ): + host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) + host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) + + self._send_handshake_init_request(host_ephemeral_pubkey) + self._read_ack() + init_response = self._read_handshake_init_response() + + trezor_ephemeral_pubkey = init_response[:32] + encrypted_trezor_static_pubkey = init_response[32:80] + noise_tag = init_response[80:96] + LOG.debug("noise_tag: %s", hexlify(noise_tag).decode()) + + # TODO check noise_tag is valid + + ck = self._send_handshake_completion_request( + host_ephemeral_pubkey, + host_ephemeral_privkey, + trezor_ephemeral_pubkey, + encrypted_trezor_static_pubkey, + credential, + host_static_privkey, + ) + self._read_ack() + self._read_handshake_completion_response() + self.key_request, self.key_response = _hkdf(ck, b"") + self.nonce_request = 0 + self.nonce_response = 1 + + def _send_handshake_init_request(self, host_ephemeral_pubkey: bytes) -> None: + ha_init_req_header = MessageHeader(0, self.channel_id, 36) + + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, ha_init_req_header, host_ephemeral_pubkey + ) + + def _read_handshake_init_response(self) -> bytes: + header, payload = self._read_until_valid_crc_check() + self._send_ack_0() + + if header.ctrl_byte == 0x42: + if payload == b"\x05": + raise exceptions.DeviceLockedException() + + if not header.is_handshake_init_response(): + LOG.debug("Received message is not a valid handshake init response message") + + click.echo( + "Received message is not a valid handshake init response message", + err=True, + ) + return payload + + def _send_handshake_completion_request( + self, + host_ephemeral_pubkey: bytes, + host_ephemeral_privkey: bytes, + trezor_ephemeral_pubkey: bytes, + encrypted_trezor_static_pubkey: bytes, + credential: bytes | None = None, + host_static_privkey: bytes | None = None, + ) -> bytes: + PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" + IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + h = _sha256_of_two(PROTOCOL_NAME, self.device_properties) + h = _sha256_of_two(h, host_ephemeral_pubkey) + h = _sha256_of_two(h, trezor_ephemeral_pubkey) + ck, k = _hkdf( + PROTOCOL_NAME, + curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey), + ) + + aes_ctx = AESGCM(k) + try: + trezor_masked_static_pubkey = aes_ctx.decrypt( + IV_1, encrypted_trezor_static_pubkey, h + ) + except Exception as e: + click.echo( + f"Exception of type{type(e)}", err=True + ) # TODO how to handle potential exceptions? Q for Matejcik + h = _sha256_of_two(h, encrypted_trezor_static_pubkey) + ck, k = _hkdf( + ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey) + ) + aes_ctx = AESGCM(k) + + tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h) + h = _sha256_of_two(h, tag_of_empty_string) + + # TODO: search for saved credentials + if host_static_privkey is not None and credential is not None: + host_static_pubkey = curve25519.get_public_key(host_static_privkey) + else: + credential = None + zeroes_32 = int.to_bytes(0, 32, "little") + temp_host_static_privkey = curve25519.get_private_key(zeroes_32) + temp_host_static_pubkey = curve25519.get_public_key( + temp_host_static_privkey + ) + host_static_privkey = temp_host_static_privkey + host_static_pubkey = temp_host_static_pubkey + + aes_ctx = AESGCM(k) + encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, host_static_pubkey, h) + h = _sha256_of_two(h, encrypted_host_static_pubkey) + ck, k = _hkdf( + ck, curve25519.multiply(host_static_privkey, trezor_ephemeral_pubkey) + ) + msg_data = self.mapping.encode_without_wire_type( + messages.ThpHandshakeCompletionReqNoisePayload( + host_pairing_credential=credential, + ) + ) + + aes_ctx = AESGCM(k) + + encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h) + h = _sha256_of_two(h, encrypted_payload[:-16]) + ha_completion_req_header = MessageHeader( + 0x12, + self.channel_id, + len(encrypted_host_static_pubkey) + + len(encrypted_payload) + + CHECKSUM_LENGTH, + ) + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, + ha_completion_req_header, + encrypted_host_static_pubkey + encrypted_payload, + ) + self.handshake_hash = h + return ck + + def _read_handshake_completion_response(self) -> None: + # Read handshake completion response, ignore payload as we do not care about the state + header, _ = self._read_until_valid_crc_check() + if not header.is_handshake_comp_response(): + click.echo( + "Received message is not a valid handshake completion response", + err=True, + ) + self._send_ack_1() + + def _do_pairing(self, helper_debug: DebugLink | None): + + self._send_message(messages.ThpPairingRequest()) + self._read_message(messages.ButtonRequest) + self._send_message(messages.ButtonAck()) + + if helper_debug is not None: + helper_debug.press_yes() + + self._read_message(messages.ThpPairingRequestApproved) + self._send_message( + messages.ThpSelectMethod( + selected_pairing_method=messages.ThpPairingMethod.SkipPairing + ) + ) + self._read_message(messages.ThpEndResponse) + + self._has_valid_channel = True + + def _read_ack(self): + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + click.echo("Received message is not a valid ACK", err=True) + + def _send_ack_0(self): + LOG.debug("sending ack 0") + header = MessageHeader(0x20, self.channel_id, 4) + thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"") + + def _send_ack_1(self): + LOG.debug("sending ack 1") + header = MessageHeader(0x28, self.channel_id, 4) + thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"") + + def _encrypt_and_write( + self, + session_id: int, + message_type: int, + message_data: bytes, + ctrl_byte: int | None = None, + ) -> None: + assert self.key_request is not None + aes_ctx = AESGCM(self.key_request) + + if ctrl_byte is None: + ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send) + self.sync_bit_send = 1 - self.sync_bit_send + + sid = session_id.to_bytes(1, "big") + msg_type = message_type.to_bytes(2, "big") + data = sid + msg_type + message_data + nonce = _get_iv_from_nonce(self.nonce_request) + self.nonce_request += 1 + encrypted_message = aes_ctx.encrypt(nonce, data, b"") + header = MessageHeader( + ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH + ) + + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, header, encrypted_message + ) + + def read_and_decrypt(self) -> t.Tuple[int, int, bytes]: + header, raw_payload = self._read_until_valid_crc_check() + if control_byte.is_ack(header.ctrl_byte): + # TODO fix this recursion + return self.read_and_decrypt() + if control_byte.is_error(header.ctrl_byte): + # TODO check for different channel + err = _get_error_from_int(raw_payload[0]) + raise Exception("Received ThpError: " + err) + if not header.is_encrypted_transport(): + click.echo( + "Trying to decrypt not encrypted message! (" + + hexlify(header.to_bytes_init() + raw_payload).decode() + + ")", + err=True, + ) + + if not control_byte.is_ack(header.ctrl_byte): + LOG.debug( + "--> Get sequence bit %d %s %s", + control_byte.get_seq_bit(header.ctrl_byte), + "from control byte", + hexlify(header.ctrl_byte.to_bytes(1, "big")).decode(), + ) + if control_byte.get_seq_bit(header.ctrl_byte): + self._send_ack_1() + else: + self._send_ack_0() + aes_ctx = AESGCM(self.key_response) + nonce = _get_iv_from_nonce(self.nonce_response) + self.nonce_response += 1 + + message = aes_ctx.decrypt(nonce, raw_payload, b"") + session_id = message[0] + message_type = message[1:3] + message_data = message[3:] + return ( + session_id, + int.from_bytes(message_type, "big"), + message_data, + ) + + def _read_until_valid_crc_check( + self, + ) -> t.Tuple[MessageHeader, bytes]: + is_valid = False + header, payload, chksum = thp_io.read(self.transport) + while not is_valid: + is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload) + if not is_valid: + click.echo( + "Received a message with an invalid checksum:" + + hexlify(header.to_bytes_init() + payload + chksum).decode(), + err=True, + ) + header, payload, chksum = thp_io.read(self.transport) + + return header, payload + + def _is_valid_channel_allocation_response( + self, header: MessageHeader, payload: bytes, original_nonce: bytes + ) -> bool: + if not header.is_channel_allocation_response(): + click.echo( + "Received message is not a channel allocation response", err=True + ) + return False + if len(payload) < 10: + click.echo("Invalid channel allocation response payload", err=True) + return False + if payload[:8] != original_nonce: + click.echo( + "Invalid channel allocation response payload (nonce mismatch)", err=True + ) + return False + return True + + +def _get_error_from_int(error_code: int) -> str: + # TODO FIXME improve this (ThpErrorType) + if error_code == 1: + return "TRANSPORT BUSY" + if error_code == 2: + return "UNALLOCATED CHANNEL" + if error_code == 3: + return "DECRYPTION FAILED" + if error_code == 4: + return "INVALID DATA" + if error_code == 5: + return "DEVICE LOCKED" + raise Exception("Not Implemented error case") diff --git a/python/src/trezorlib/transport/thp/thp_io.py b/python/src/trezorlib/transport/thp/thp_io.py new file mode 100644 index 00000000000..d0237f9e36d --- /dev/null +++ b/python/src/trezorlib/transport/thp/thp_io.py @@ -0,0 +1,93 @@ +import struct +from typing import Tuple + +from .. import Transport +from ..thp import checksum +from .message_header import MessageHeader + +INIT_HEADER_LENGTH = 5 +CONT_HEADER_LENGTH = 3 +MAX_PAYLOAD_LEN = 60000 +MESSAGE_TYPE_LENGTH = 2 + +CONTINUATION_PACKET = 0x80 + + +def write_payload_to_wire_and_add_checksum( + transport: Transport, header: MessageHeader, transport_payload: bytes +): + chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload) + data = transport_payload + chksum + write_payload_to_wire(transport, header, data) + + +def write_payload_to_wire( + transport: Transport, header: MessageHeader, transport_payload: bytes +): + transport.open() + buffer = bytearray(transport_payload) + chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH] + chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00") + transport.write_chunk(chunk) + + buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :] + while buffer: + chunk = ( + header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH] + ) + chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00") + transport.write_chunk(chunk) + buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :] + + +def read(transport: Transport) -> Tuple[MessageHeader, bytes, bytes]: + """ + Reads from the given wire transport. + + Returns `Tuple[MessageHeader, bytes, bytes]`: + 1. `header` (`MessageHeader`): Header of the message. + 2. `data` (`bytes`): Contents of the message (if any). + 3. `checksum` (`bytes`): crc32 checksum of the header + data. + + """ + buffer = bytearray() + + # Read header with first part of message data + header, first_chunk = read_first(transport) + buffer.extend(first_chunk) + + # Read the rest of the message + while len(buffer) < header.data_length: + buffer.extend(read_next(transport, header.cid)) + + data_len = header.data_length - checksum.CHECKSUM_LENGTH + msg_data = buffer[:data_len] + chksum = buffer[data_len : data_len + checksum.CHECKSUM_LENGTH] + + return (header, msg_data, chksum) + + +def read_first(transport: Transport) -> Tuple[MessageHeader, bytes]: + chunk = transport.read_chunk() + try: + ctrl_byte, cid, data_length = struct.unpack( + MessageHeader.format_str_init, chunk[:INIT_HEADER_LENGTH] + ) + except Exception: + raise RuntimeError("Cannot parse header") + + data = chunk[INIT_HEADER_LENGTH:] + return MessageHeader(ctrl_byte, cid, data_length), data + + +def read_next(transport: Transport, cid: int) -> bytes: + chunk = transport.read_chunk() + ctrl_byte, read_cid = struct.unpack( + MessageHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH] + ) + if ctrl_byte != CONTINUATION_PACKET: + raise RuntimeError("Continuation packet with incorrect control byte") + if read_cid != cid: + raise RuntimeError("Continuation packet for different channel") + + return chunk[CONT_HEADER_LENGTH:] diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index 7e4c4614c63..e17d6f45006 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -14,14 +14,15 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import socket import time -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable, Tuple from ..log import DUMP_PACKETS -from . import TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import Transport, TransportException if TYPE_CHECKING: from ..models import TrezorModel @@ -31,14 +32,18 @@ LOG = logging.getLogger(__name__) -class UdpTransport(ProtocolBasedTransport): +class UdpTransport(Transport): DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 21324 PATH_PREFIX = "udp" ENABLED: bool = True + CHUNK_SIZE = 64 - def __init__(self, device: Optional[str] = None) -> None: + def __init__( + self, + device: str | None = None, + ) -> None: if not device: host = UdpTransport.DEFAULT_HOST port = UdpTransport.DEFAULT_PORT @@ -46,24 +51,17 @@ def __init__(self, device: Optional[str] = None) -> None: devparts = device.split(":") host = devparts[0] port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT - self.device = (host, port) - self.socket: Optional[socket.socket] = None - - super().__init__(protocol=ProtocolV1(self)) - - def get_path(self) -> str: - return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) + self.device: Tuple[str, int] = (host, port) - def find_debug(self) -> "UdpTransport": - host, port = self.device - return UdpTransport(f"{host}:{port + 1}") + self.socket: socket.socket | None = None + super().__init__() @classmethod def _try_path(cls, path: str) -> "UdpTransport": d = cls(path) try: d.open() - if d._ping(): + if d.ping(): return d else: raise TransportException( @@ -77,7 +75,7 @@ def _try_path(cls, path: str) -> "UdpTransport": @classmethod def enumerate( - cls, _models: Optional[Iterable["TrezorModel"]] = None + cls, _models: Iterable["TrezorModel"] | None = None ) -> Iterable["UdpTransport"]: default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}" try: @@ -99,20 +97,8 @@ def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport": else: raise TransportException(f"No UDP device at {path}") - def wait_until_ready(self, timeout: float = 10) -> None: - try: - self.open() - start = time.monotonic() - while True: - if self._ping(): - break - elapsed = time.monotonic() - start - if elapsed >= timeout: - raise TransportException("Timed out waiting for connection.") - - time.sleep(0.05) - finally: - self.close() + def get_path(self) -> str: + return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) def open(self) -> None: self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -124,18 +110,9 @@ def close(self) -> None: self.socket.close() self.socket = None - def _ping(self) -> bool: - """Test if the device is listening.""" - assert self.socket is not None - resp = None - try: - self.socket.sendall(b"PINGPING") - resp = self.socket.recv(8) - except Exception: - pass - return resp == b"PONGPONG" - def write_chunk(self, chunk: bytes) -> None: + if self.socket is None: + self.open() assert self.socket is not None if len(chunk) != 64: raise TransportException("Unexpected data length") @@ -143,6 +120,8 @@ def write_chunk(self, chunk: bytes) -> None: self.socket.sendall(chunk) def read_chunk(self) -> bytes: + if self.socket is None: + self.open() assert self.socket is not None while True: try: @@ -154,3 +133,33 @@ def read_chunk(self) -> bytes: if len(chunk) != 64: raise TransportException(f"Unexpected chunk size: {len(chunk)}") return bytearray(chunk) + + def find_debug(self) -> "UdpTransport": + host, port = self.device + return UdpTransport(f"{host}:{port + 1}") + + def wait_until_ready(self, timeout: float = 10) -> None: + try: + self.open() + start = time.monotonic() + while True: + if self.ping(): + break + elapsed = time.monotonic() - start + if elapsed >= timeout: + raise TransportException("Timed out waiting for connection.") + + time.sleep(0.05) + finally: + self.close() + + def ping(self) -> bool: + """Test if the device is listening.""" + assert self.socket is not None + resp = None + try: + self.socket.sendall(b"PINGPING") + resp = self.socket.recv(8) + except Exception: + pass + return resp == b"PONGPONG" diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index 8e2d08147a6..872d9619601 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -14,16 +14,17 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import atexit import logging import sys import time -from typing import Iterable, List, Optional +from typing import Iterable, List from ..log import DUMP_PACKETS from ..models import TREZORS, TrezorModel -from . import UDEV_RULES_STR, DeviceIsBusy, TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import UDEV_RULES_STR, DeviceIsBusy, Transport, TransportException LOG = logging.getLogger(__name__) @@ -44,13 +45,69 @@ WEBUSB_CHUNK_SIZE = 64 -class WebUsbHandle: - def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None: +class WebUsbTransport(Transport): + """ + WebUsbTransport implements transport over WebUSB interface. + """ + + PATH_PREFIX = "webusb" + ENABLED = USB_IMPORTED + context = None + CHUNK_SIZE = 64 + + def __init__( + self, + device: "usb1.USBDevice", + debug: bool = False, + ) -> None: + self.device = device + self.debug = debug + self.interface = DEBUG_INTERFACE if debug else INTERFACE self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT - self.count = 0 - self.handle: Optional["usb1.USBDeviceHandle"] = None + self.handle: usb1.USBDeviceHandle | None = None + + super().__init__() + + @classmethod + def enumerate( + cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False + ) -> Iterable["WebUsbTransport"]: + if cls.context is None: + cls.context = usb1.USBContext() + cls.context.open() + atexit.register(cls.context.close) + + if models is None: + models = TREZORS + usb_ids = [id for model in models for id in model.usb_ids] + devices: List["WebUsbTransport"] = [] + for dev in cls.context.getDeviceIterator(skip_on_error=True): + usb_id = (dev.getVendorID(), dev.getProductID()) + if usb_id not in usb_ids: + continue + if not is_vendor_class(dev): + continue + if usb_reset: + handle = dev.open() + handle.resetDevice() + handle.close() + continue + try: + # workaround for issue #223: + # on certain combinations of Windows USB drivers and libusb versions, + # Trezor is returned twice (possibly because Windows know it as both + # a HID and a WebUSB device), and one of the returned devices is + # non-functional. + dev.getProduct() + devices.append(WebUsbTransport(dev)) + except usb1.USBErrorNotSupported: + pass + return devices + + def get_path(self) -> str: + return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" def open(self) -> None: self.handle = self.device.open() @@ -64,6 +121,8 @@ def open(self) -> None: self.handle.claimInterface(self.interface) except usb1.USBErrorAccess as e: raise DeviceIsBusy(self.device) from e + except usb1.USBErrorBusy as e: + raise DeviceIsBusy(self.device) from e def close(self) -> None: if self.handle is not None: @@ -75,6 +134,8 @@ def close(self) -> None: self.handle = None def write_chunk(self, chunk: bytes) -> None: + if self.handle is None: + self.open() assert self.handle is not None if len(chunk) != WEBUSB_CHUNK_SIZE: raise TransportException(f"Unexpected chunk size: {len(chunk)}") @@ -97,6 +158,8 @@ def write_chunk(self, chunk: bytes) -> None: return def read_chunk(self) -> bytes: + if self.handle is None: + self.open() assert self.handle is not None endpoint = 0x80 | self.endpoint while True: @@ -117,70 +180,6 @@ def read_chunk(self) -> bytes: raise TransportException(f"Unexpected chunk size: {len(chunk)}") return chunk - -class WebUsbTransport(ProtocolBasedTransport): - """ - WebUsbTransport implements transport over WebUSB interface. - """ - - PATH_PREFIX = "webusb" - ENABLED = USB_IMPORTED - context = None - - def __init__( - self, - device: "usb1.USBDevice", - handle: Optional[WebUsbHandle] = None, - debug: bool = False, - ) -> None: - if handle is None: - handle = WebUsbHandle(device, debug) - - self.device = device - self.handle = handle - self.debug = debug - - super().__init__(protocol=ProtocolV1(handle)) - - def get_path(self) -> str: - return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" - - @classmethod - def enumerate( - cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False - ) -> Iterable["WebUsbTransport"]: - if cls.context is None: - cls.context = usb1.USBContext() - cls.context.open() - atexit.register(cls.context.close) - - if models is None: - models = TREZORS - usb_ids = [id for model in models for id in model.usb_ids] - devices: List["WebUsbTransport"] = [] - for dev in cls.context.getDeviceIterator(skip_on_error=True): - usb_id = (dev.getVendorID(), dev.getProductID()) - if usb_id not in usb_ids: - continue - if not is_vendor_class(dev): - continue - try: - # workaround for issue #223: - # on certain combinations of Windows USB drivers and libusb versions, - # Trezor is returned twice (possibly because Windows know it as both - # a HID and a WebUSB device), and one of the returned devices is - # non-functional. - dev.getProduct() - devices.append(WebUsbTransport(dev)) - except usb1.USBErrorNotSupported: - pass - except usb1.USBErrorPipe: - if usb_reset: - handle = dev.open() - handle.resetDevice() - handle.close() - return devices - def find_debug(self) -> "WebUsbTransport": # For v1 protocol, find debug USB interface for the same serial number return WebUsbTransport(self.device, debug=True) diff --git a/python/tools/encfs_aes_getpass.py b/python/tools/encfs_aes_getpass.py index 82773e50fa7..9466cce574d 100755 --- a/python/tools/encfs_aes_getpass.py +++ b/python/tools/encfs_aes_getpass.py @@ -35,7 +35,6 @@ from trezorlib.client import TrezorClient from trezorlib.tools import Address from trezorlib.transport import enumerate_devices -from trezorlib.ui import ClickUI version_tuple = tuple(map(int, trezorlib.__version__.split("."))) if not (0, 11) <= version_tuple < (0, 14): @@ -71,7 +70,7 @@ def choose_device(devices: Sequence["Transport"]) -> "Transport": sys.stderr.write("Available devices:\n") for d in devices: try: - client = TrezorClient(d, ui=ClickUI()) + client = TrezorClient(d) except IOError: sys.stderr.write("[-] \n") continue @@ -80,7 +79,7 @@ def choose_device(devices: Sequence["Transport"]) -> "Transport": sys.stderr.write(f"[{i}] {client.features.label}\n") else: sys.stderr.write(f"[{i}] \n") - client.close() + # TODO client.close() i += 1 sys.stderr.write("----------------------------\n") @@ -106,7 +105,8 @@ def main() -> None: devices = wait_for_devices() transport = choose_device(devices) - client = TrezorClient(transport, ui=ClickUI()) + client = TrezorClient(transport) + session = client.get_seedless_session() rootdir = os.environ["encfs_root"] # Read "man encfs" for more passw_file = os.path.join(rootdir, "password.dat") @@ -120,7 +120,7 @@ def main() -> None: sys.stderr.write("Computer asked Trezor for new strong password.\n") # 32 bytes, good for AES - trezor_entropy = trezorlib.misc.get_entropy(client, 32) + trezor_entropy = trezorlib.misc.get_entropy(session, 32) urandom_entropy = os.urandom(32) passw = hashlib.sha256(trezor_entropy + urandom_entropy).digest() @@ -129,7 +129,7 @@ def main() -> None: bip32_path = Address([10, 0]) passw_encrypted = trezorlib.misc.encrypt_keyvalue( - client, bip32_path, label, passw, False, True + session, bip32_path, label, passw, False, True ) data = { @@ -144,7 +144,7 @@ def main() -> None: data = json.load(open(passw_file, "r")) passw = trezorlib.misc.decrypt_keyvalue( - client, + session, data["bip32_path"], data["label"], bytes.fromhex(data["password_encrypted_hex"]), diff --git a/python/tools/helloworld.py b/python/tools/helloworld.py index 76b4502da2d..2403bb4bc29 100755 --- a/python/tools/helloworld.py +++ b/python/tools/helloworld.py @@ -24,13 +24,14 @@ def main() -> None: # Use first connected device client = get_default_client() + session = client.get_session() # Print out Trezor's features and settings - print(client.features) + print(session.features) # Get the first address of first BIP44 account bip32_path = parse_path("44h/0h/0h/0/0") - address = btc.get_address(client, "Bitcoin", bip32_path, True) + address = btc.get_address(session, "Bitcoin", bip32_path, True) print("Bitcoin address:", address) diff --git a/python/tools/pwd_reader.py b/python/tools/pwd_reader.py index afd405e1642..3921da24de6 100755 --- a/python/tools/pwd_reader.py +++ b/python/tools/pwd_reader.py @@ -26,23 +26,24 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from trezorlib import misc, ui +from trezorlib import misc from trezorlib.client import TrezorClient from trezorlib.tools import parse_path from trezorlib.transport import get_transport +from trezorlib.transport.session import Session # Return path by BIP-32 BIP32_PATH = parse_path("10016h/0") # Deriving master key -def getMasterKey(client: TrezorClient) -> str: +def getMasterKey(session: Session) -> str: bip32_path = BIP32_PATH ENC_KEY = "Activate TREZOR Password Manager?" ENC_VALUE = bytes.fromhex( "2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee" ) - key = misc.encrypt_keyvalue(client, bip32_path, ENC_KEY, ENC_VALUE, True, True) + key = misc.encrypt_keyvalue(session, bip32_path, ENC_KEY, ENC_VALUE, True, True) return key.hex() @@ -101,7 +102,7 @@ def decryptEntryValue(nonce: str, val: bytes) -> dict: # Decrypt give entry nonce -def getDecryptedNonce(client: TrezorClient, entry: dict) -> str: +def getDecryptedNonce(session: Session, entry: dict) -> str: print() print("Waiting for Trezor input ...") print() @@ -117,7 +118,7 @@ def getDecryptedNonce(client: TrezorClient, entry: dict) -> str: ENC_KEY = f"Unlock {item} for user {entry['username']}?" ENC_VALUE = entry["nonce"] decrypted_nonce = misc.decrypt_keyvalue( - client, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True + session, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True ) return decrypted_nonce.hex() @@ -144,13 +145,14 @@ def main() -> None: print(e) return - client = TrezorClient(transport=transport, ui=ui.ClickUI()) + client = TrezorClient(transport=transport) + session = client.get_seedless_session() print() print("Confirm operation on Trezor") print() - masterKey = getMasterKey(client) + masterKey = getMasterKey(session) # print('master key:', masterKey) fileName = getFileEncKey(masterKey)[0] @@ -173,7 +175,7 @@ def main() -> None: entry_id = input("Select entry number to decrypt: ") entry_id = str(entry_id) - plain_nonce = getDecryptedNonce(client, entries[entry_id]) + plain_nonce = getDecryptedNonce(session, entries[entry_id]) pwdArr = entries[entry_id]["password"]["data"] pwdHex = "".join([hex(x)[2:].zfill(2) for x in pwdArr]) diff --git a/python/tools/pybridge.py b/python/tools/pybridge.py index 30d69bbc9b1..d94ec121d03 100644 --- a/python/tools/pybridge.py +++ b/python/tools/pybridge.py @@ -24,6 +24,8 @@ from gevent import monkey +import trezorlib.transport + monkey.patch_all() import json @@ -103,11 +105,11 @@ def __init__(self, transport: trezorlib.transport.Transport) -> None: self.session: Session | None = None self.transport = transport - client = TrezorClient(transport, ui=SilentUI()) + client = TrezorClient(transport) # TODO add silent UI? self.model = ( trezorlib.models.by_name(client.features.model) or trezorlib.models.TREZOR_T ) - client.end_session() + # TODO client.end_session() def acquire(self, sid: str) -> str: if self.session_id() != sid: @@ -116,11 +118,11 @@ def acquire(self, sid: str) -> str: self.session.release() self.session = Session(self) - self.transport.begin_session() + # TODO self.transport.deprecated_begin_session() return self.session.id def release(self) -> None: - self.transport.end_session() + # TODO self.transport.deprecated_end_session() self.session = None def session_id(self) -> str | None: @@ -141,10 +143,14 @@ def to_json(self) -> dict: } def write(self, msg_id: int, data: bytes) -> None: - self.transport.write(msg_id, data) + raise NotImplementedError + # TODO + # self.transport.write(msg_id, data) def read(self) -> tuple[int, bytes]: - return self.transport.read() + raise NotImplementedError + # TODO + # return self.transport.read() @classmethod def find(cls, path: str) -> Transport | None: diff --git a/python/tools/rng_entropy_collector.py b/python/tools/rng_entropy_collector.py index 2b0a5b80d79..e99ed1e2734 100755 --- a/python/tools/rng_entropy_collector.py +++ b/python/tools/rng_entropy_collector.py @@ -7,14 +7,15 @@ import io import sys -from trezorlib import misc, ui +from trezorlib import misc from trezorlib.client import TrezorClient from trezorlib.transport import get_transport def main() -> None: try: - client = TrezorClient(get_transport(), ui=ui.ClickUI()) + client = TrezorClient(get_transport()) + session = client.get_seedless_session() except Exception as e: print(e) return @@ -25,11 +26,9 @@ def main() -> None: with io.open(arg1, "wb") as f: for _ in range(0, arg2, step): - entropy = misc.get_entropy(client, step) + entropy = misc.get_entropy(session, step) f.write(entropy) - client.close() - if __name__ == "__main__": main() diff --git a/python/tools/trezor-otp.py b/python/tools/trezor-otp.py index bc0b66daa97..435feac821f 100755 --- a/python/tools/trezor-otp.py +++ b/python/tools/trezor-otp.py @@ -27,26 +27,25 @@ from trezorlib.misc import decrypt_keyvalue, encrypt_keyvalue from trezorlib.tools import parse_path from trezorlib.transport import get_transport -from trezorlib.ui import ClickUI BIP32_PATH = parse_path("10016h/0") def encrypt(type: str, domain: str, secret: str) -> str: transport = get_transport() - client = TrezorClient(transport, ClickUI()) + client = TrezorClient(transport) + session = client.get_seedless_session() dom = type.upper() + ": " + domain - enc = encrypt_keyvalue(client, BIP32_PATH, dom, secret.encode(), False, True) - client.close() + enc = encrypt_keyvalue(session, BIP32_PATH, dom, secret.encode(), False, True) return enc.hex() def decrypt(type: str, domain: str, secret: bytes) -> bytes: transport = get_transport() - client = TrezorClient(transport, ClickUI()) + client = TrezorClient(transport) + session = client.get_seedless_session() dom = type.upper() + ": " + domain - dec = decrypt_keyvalue(client, BIP32_PATH, dom, secret, False, True) - client.close() + dec = decrypt_keyvalue(session, BIP32_PATH, dom, secret, False, True) return dec diff --git a/rust/trezor-client/src/messages/generated.rs b/rust/trezor-client/src/messages/generated.rs index 551a1e92e22..ab956549368 100644 --- a/rust/trezor-client/src/messages/generated.rs +++ b/rust/trezor-client/src/messages/generated.rs @@ -82,6 +82,8 @@ trezor_message_impl! { DebugLinkWatchLayout => MessageType_DebugLinkWatchLayout, DebugLinkResetDebugEvents => MessageType_DebugLinkResetDebugEvents, DebugLinkOptigaSetSecMax => MessageType_DebugLinkOptigaSetSecMax, + DebugLinkGetPairingInfo => MessageType_DebugLinkGetPairingInfo, + DebugLinkPairingInfo => MessageType_DebugLinkPairingInfo, BenchmarkListNames => MessageType_BenchmarkListNames, BenchmarkNames => MessageType_BenchmarkNames, BenchmarkRun => MessageType_BenchmarkRun, diff --git a/rust/trezor-client/src/protos/generated/messages.rs b/rust/trezor-client/src/protos/generated/messages.rs index 0a265410a6c..109f6e52329 100644 --- a/rust/trezor-client/src/protos/generated/messages.rs +++ b/rust/trezor-client/src/protos/generated/messages.rs @@ -226,6 +226,10 @@ pub enum MessageType { MessageType_DebugLinkResetDebugEvents = 9007, // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_DebugLinkOptigaSetSecMax) MessageType_DebugLinkOptigaSetSecMax = 9008, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_DebugLinkGetPairingInfo) + MessageType_DebugLinkGetPairingInfo = 9009, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_DebugLinkPairingInfo) + MessageType_DebugLinkPairingInfo = 9010, // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_EthereumGetPublicKey) MessageType_EthereumGetPublicKey = 450, // @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_EthereumPublicKey) @@ -632,6 +636,8 @@ impl ::protobuf::Enum for MessageType { 9006 => ::std::option::Option::Some(MessageType::MessageType_DebugLinkWatchLayout), 9007 => ::std::option::Option::Some(MessageType::MessageType_DebugLinkResetDebugEvents), 9008 => ::std::option::Option::Some(MessageType::MessageType_DebugLinkOptigaSetSecMax), + 9009 => ::std::option::Option::Some(MessageType::MessageType_DebugLinkGetPairingInfo), + 9010 => ::std::option::Option::Some(MessageType::MessageType_DebugLinkPairingInfo), 450 => ::std::option::Option::Some(MessageType::MessageType_EthereumGetPublicKey), 451 => ::std::option::Option::Some(MessageType::MessageType_EthereumPublicKey), 56 => ::std::option::Option::Some(MessageType::MessageType_EthereumGetAddress), @@ -885,6 +891,8 @@ impl ::protobuf::Enum for MessageType { "MessageType_DebugLinkWatchLayout" => ::std::option::Option::Some(MessageType::MessageType_DebugLinkWatchLayout), "MessageType_DebugLinkResetDebugEvents" => ::std::option::Option::Some(MessageType::MessageType_DebugLinkResetDebugEvents), "MessageType_DebugLinkOptigaSetSecMax" => ::std::option::Option::Some(MessageType::MessageType_DebugLinkOptigaSetSecMax), + "MessageType_DebugLinkGetPairingInfo" => ::std::option::Option::Some(MessageType::MessageType_DebugLinkGetPairingInfo), + "MessageType_DebugLinkPairingInfo" => ::std::option::Option::Some(MessageType::MessageType_DebugLinkPairingInfo), "MessageType_EthereumGetPublicKey" => ::std::option::Option::Some(MessageType::MessageType_EthereumGetPublicKey), "MessageType_EthereumPublicKey" => ::std::option::Option::Some(MessageType::MessageType_EthereumPublicKey), "MessageType_EthereumGetAddress" => ::std::option::Option::Some(MessageType::MessageType_EthereumGetAddress), @@ -1137,6 +1145,8 @@ impl ::protobuf::Enum for MessageType { MessageType::MessageType_DebugLinkWatchLayout, MessageType::MessageType_DebugLinkResetDebugEvents, MessageType::MessageType_DebugLinkOptigaSetSecMax, + MessageType::MessageType_DebugLinkGetPairingInfo, + MessageType::MessageType_DebugLinkPairingInfo, MessageType::MessageType_EthereumGetPublicKey, MessageType::MessageType_EthereumPublicKey, MessageType::MessageType_EthereumGetAddress, @@ -1395,154 +1405,156 @@ impl ::protobuf::EnumFull for MessageType { MessageType::MessageType_DebugLinkWatchLayout => 96, MessageType::MessageType_DebugLinkResetDebugEvents => 97, MessageType::MessageType_DebugLinkOptigaSetSecMax => 98, - MessageType::MessageType_EthereumGetPublicKey => 99, - MessageType::MessageType_EthereumPublicKey => 100, - MessageType::MessageType_EthereumGetAddress => 101, - MessageType::MessageType_EthereumAddress => 102, - MessageType::MessageType_EthereumSignTx => 103, - MessageType::MessageType_EthereumSignTxEIP1559 => 104, - MessageType::MessageType_EthereumTxRequest => 105, - MessageType::MessageType_EthereumTxAck => 106, - MessageType::MessageType_EthereumSignMessage => 107, - MessageType::MessageType_EthereumVerifyMessage => 108, - MessageType::MessageType_EthereumMessageSignature => 109, - MessageType::MessageType_EthereumSignTypedData => 110, - MessageType::MessageType_EthereumTypedDataStructRequest => 111, - MessageType::MessageType_EthereumTypedDataStructAck => 112, - MessageType::MessageType_EthereumTypedDataValueRequest => 113, - MessageType::MessageType_EthereumTypedDataValueAck => 114, - MessageType::MessageType_EthereumTypedDataSignature => 115, - MessageType::MessageType_EthereumSignTypedHash => 116, - MessageType::MessageType_NEMGetAddress => 117, - MessageType::MessageType_NEMAddress => 118, - MessageType::MessageType_NEMSignTx => 119, - MessageType::MessageType_NEMSignedTx => 120, - MessageType::MessageType_NEMDecryptMessage => 121, - MessageType::MessageType_NEMDecryptedMessage => 122, - MessageType::MessageType_TezosGetAddress => 123, - MessageType::MessageType_TezosAddress => 124, - MessageType::MessageType_TezosSignTx => 125, - MessageType::MessageType_TezosSignedTx => 126, - MessageType::MessageType_TezosGetPublicKey => 127, - MessageType::MessageType_TezosPublicKey => 128, - MessageType::MessageType_StellarSignTx => 129, - MessageType::MessageType_StellarTxOpRequest => 130, - MessageType::MessageType_StellarGetAddress => 131, - MessageType::MessageType_StellarAddress => 132, - MessageType::MessageType_StellarCreateAccountOp => 133, - MessageType::MessageType_StellarPaymentOp => 134, - MessageType::MessageType_StellarPathPaymentStrictReceiveOp => 135, - MessageType::MessageType_StellarManageSellOfferOp => 136, - MessageType::MessageType_StellarCreatePassiveSellOfferOp => 137, - MessageType::MessageType_StellarSetOptionsOp => 138, - MessageType::MessageType_StellarChangeTrustOp => 139, - MessageType::MessageType_StellarAllowTrustOp => 140, - MessageType::MessageType_StellarAccountMergeOp => 141, - MessageType::MessageType_StellarManageDataOp => 142, - MessageType::MessageType_StellarBumpSequenceOp => 143, - MessageType::MessageType_StellarManageBuyOfferOp => 144, - MessageType::MessageType_StellarPathPaymentStrictSendOp => 145, - MessageType::MessageType_StellarClaimClaimableBalanceOp => 146, - MessageType::MessageType_StellarSignedTx => 147, - MessageType::MessageType_CardanoGetPublicKey => 148, - MessageType::MessageType_CardanoPublicKey => 149, - MessageType::MessageType_CardanoGetAddress => 150, - MessageType::MessageType_CardanoAddress => 151, - MessageType::MessageType_CardanoTxItemAck => 152, - MessageType::MessageType_CardanoTxAuxiliaryDataSupplement => 153, - MessageType::MessageType_CardanoTxWitnessRequest => 154, - MessageType::MessageType_CardanoTxWitnessResponse => 155, - MessageType::MessageType_CardanoTxHostAck => 156, - MessageType::MessageType_CardanoTxBodyHash => 157, - MessageType::MessageType_CardanoSignTxFinished => 158, - MessageType::MessageType_CardanoSignTxInit => 159, - MessageType::MessageType_CardanoTxInput => 160, - MessageType::MessageType_CardanoTxOutput => 161, - MessageType::MessageType_CardanoAssetGroup => 162, - MessageType::MessageType_CardanoToken => 163, - MessageType::MessageType_CardanoTxCertificate => 164, - MessageType::MessageType_CardanoTxWithdrawal => 165, - MessageType::MessageType_CardanoTxAuxiliaryData => 166, - MessageType::MessageType_CardanoPoolOwner => 167, - MessageType::MessageType_CardanoPoolRelayParameters => 168, - MessageType::MessageType_CardanoGetNativeScriptHash => 169, - MessageType::MessageType_CardanoNativeScriptHash => 170, - MessageType::MessageType_CardanoTxMint => 171, - MessageType::MessageType_CardanoTxCollateralInput => 172, - MessageType::MessageType_CardanoTxRequiredSigner => 173, - MessageType::MessageType_CardanoTxInlineDatumChunk => 174, - MessageType::MessageType_CardanoTxReferenceScriptChunk => 175, - MessageType::MessageType_CardanoTxReferenceInput => 176, - MessageType::MessageType_RippleGetAddress => 177, - MessageType::MessageType_RippleAddress => 178, - MessageType::MessageType_RippleSignTx => 179, - MessageType::MessageType_RippleSignedTx => 180, - MessageType::MessageType_MoneroTransactionInitRequest => 181, - MessageType::MessageType_MoneroTransactionInitAck => 182, - MessageType::MessageType_MoneroTransactionSetInputRequest => 183, - MessageType::MessageType_MoneroTransactionSetInputAck => 184, - MessageType::MessageType_MoneroTransactionInputViniRequest => 185, - MessageType::MessageType_MoneroTransactionInputViniAck => 186, - MessageType::MessageType_MoneroTransactionAllInputsSetRequest => 187, - MessageType::MessageType_MoneroTransactionAllInputsSetAck => 188, - MessageType::MessageType_MoneroTransactionSetOutputRequest => 189, - MessageType::MessageType_MoneroTransactionSetOutputAck => 190, - MessageType::MessageType_MoneroTransactionAllOutSetRequest => 191, - MessageType::MessageType_MoneroTransactionAllOutSetAck => 192, - MessageType::MessageType_MoneroTransactionSignInputRequest => 193, - MessageType::MessageType_MoneroTransactionSignInputAck => 194, - MessageType::MessageType_MoneroTransactionFinalRequest => 195, - MessageType::MessageType_MoneroTransactionFinalAck => 196, - MessageType::MessageType_MoneroKeyImageExportInitRequest => 197, - MessageType::MessageType_MoneroKeyImageExportInitAck => 198, - MessageType::MessageType_MoneroKeyImageSyncStepRequest => 199, - MessageType::MessageType_MoneroKeyImageSyncStepAck => 200, - MessageType::MessageType_MoneroKeyImageSyncFinalRequest => 201, - MessageType::MessageType_MoneroKeyImageSyncFinalAck => 202, - MessageType::MessageType_MoneroGetAddress => 203, - MessageType::MessageType_MoneroAddress => 204, - MessageType::MessageType_MoneroGetWatchKey => 205, - MessageType::MessageType_MoneroWatchKey => 206, - MessageType::MessageType_DebugMoneroDiagRequest => 207, - MessageType::MessageType_DebugMoneroDiagAck => 208, - MessageType::MessageType_MoneroGetTxKeyRequest => 209, - MessageType::MessageType_MoneroGetTxKeyAck => 210, - MessageType::MessageType_MoneroLiveRefreshStartRequest => 211, - MessageType::MessageType_MoneroLiveRefreshStartAck => 212, - MessageType::MessageType_MoneroLiveRefreshStepRequest => 213, - MessageType::MessageType_MoneroLiveRefreshStepAck => 214, - MessageType::MessageType_MoneroLiveRefreshFinalRequest => 215, - MessageType::MessageType_MoneroLiveRefreshFinalAck => 216, - MessageType::MessageType_EosGetPublicKey => 217, - MessageType::MessageType_EosPublicKey => 218, - MessageType::MessageType_EosSignTx => 219, - MessageType::MessageType_EosTxActionRequest => 220, - MessageType::MessageType_EosTxActionAck => 221, - MessageType::MessageType_EosSignedTx => 222, - MessageType::MessageType_BinanceGetAddress => 223, - MessageType::MessageType_BinanceAddress => 224, - MessageType::MessageType_BinanceGetPublicKey => 225, - MessageType::MessageType_BinancePublicKey => 226, - MessageType::MessageType_BinanceSignTx => 227, - MessageType::MessageType_BinanceTxRequest => 228, - MessageType::MessageType_BinanceTransferMsg => 229, - MessageType::MessageType_BinanceOrderMsg => 230, - MessageType::MessageType_BinanceCancelMsg => 231, - MessageType::MessageType_BinanceSignedTx => 232, - MessageType::MessageType_WebAuthnListResidentCredentials => 233, - MessageType::MessageType_WebAuthnCredentials => 234, - MessageType::MessageType_WebAuthnAddResidentCredential => 235, - MessageType::MessageType_WebAuthnRemoveResidentCredential => 236, - MessageType::MessageType_SolanaGetPublicKey => 237, - MessageType::MessageType_SolanaPublicKey => 238, - MessageType::MessageType_SolanaGetAddress => 239, - MessageType::MessageType_SolanaAddress => 240, - MessageType::MessageType_SolanaSignTx => 241, - MessageType::MessageType_SolanaTxSignature => 242, - MessageType::MessageType_BenchmarkListNames => 243, - MessageType::MessageType_BenchmarkNames => 244, - MessageType::MessageType_BenchmarkRun => 245, - MessageType::MessageType_BenchmarkResult => 246, + MessageType::MessageType_DebugLinkGetPairingInfo => 99, + MessageType::MessageType_DebugLinkPairingInfo => 100, + MessageType::MessageType_EthereumGetPublicKey => 101, + MessageType::MessageType_EthereumPublicKey => 102, + MessageType::MessageType_EthereumGetAddress => 103, + MessageType::MessageType_EthereumAddress => 104, + MessageType::MessageType_EthereumSignTx => 105, + MessageType::MessageType_EthereumSignTxEIP1559 => 106, + MessageType::MessageType_EthereumTxRequest => 107, + MessageType::MessageType_EthereumTxAck => 108, + MessageType::MessageType_EthereumSignMessage => 109, + MessageType::MessageType_EthereumVerifyMessage => 110, + MessageType::MessageType_EthereumMessageSignature => 111, + MessageType::MessageType_EthereumSignTypedData => 112, + MessageType::MessageType_EthereumTypedDataStructRequest => 113, + MessageType::MessageType_EthereumTypedDataStructAck => 114, + MessageType::MessageType_EthereumTypedDataValueRequest => 115, + MessageType::MessageType_EthereumTypedDataValueAck => 116, + MessageType::MessageType_EthereumTypedDataSignature => 117, + MessageType::MessageType_EthereumSignTypedHash => 118, + MessageType::MessageType_NEMGetAddress => 119, + MessageType::MessageType_NEMAddress => 120, + MessageType::MessageType_NEMSignTx => 121, + MessageType::MessageType_NEMSignedTx => 122, + MessageType::MessageType_NEMDecryptMessage => 123, + MessageType::MessageType_NEMDecryptedMessage => 124, + MessageType::MessageType_TezosGetAddress => 125, + MessageType::MessageType_TezosAddress => 126, + MessageType::MessageType_TezosSignTx => 127, + MessageType::MessageType_TezosSignedTx => 128, + MessageType::MessageType_TezosGetPublicKey => 129, + MessageType::MessageType_TezosPublicKey => 130, + MessageType::MessageType_StellarSignTx => 131, + MessageType::MessageType_StellarTxOpRequest => 132, + MessageType::MessageType_StellarGetAddress => 133, + MessageType::MessageType_StellarAddress => 134, + MessageType::MessageType_StellarCreateAccountOp => 135, + MessageType::MessageType_StellarPaymentOp => 136, + MessageType::MessageType_StellarPathPaymentStrictReceiveOp => 137, + MessageType::MessageType_StellarManageSellOfferOp => 138, + MessageType::MessageType_StellarCreatePassiveSellOfferOp => 139, + MessageType::MessageType_StellarSetOptionsOp => 140, + MessageType::MessageType_StellarChangeTrustOp => 141, + MessageType::MessageType_StellarAllowTrustOp => 142, + MessageType::MessageType_StellarAccountMergeOp => 143, + MessageType::MessageType_StellarManageDataOp => 144, + MessageType::MessageType_StellarBumpSequenceOp => 145, + MessageType::MessageType_StellarManageBuyOfferOp => 146, + MessageType::MessageType_StellarPathPaymentStrictSendOp => 147, + MessageType::MessageType_StellarClaimClaimableBalanceOp => 148, + MessageType::MessageType_StellarSignedTx => 149, + MessageType::MessageType_CardanoGetPublicKey => 150, + MessageType::MessageType_CardanoPublicKey => 151, + MessageType::MessageType_CardanoGetAddress => 152, + MessageType::MessageType_CardanoAddress => 153, + MessageType::MessageType_CardanoTxItemAck => 154, + MessageType::MessageType_CardanoTxAuxiliaryDataSupplement => 155, + MessageType::MessageType_CardanoTxWitnessRequest => 156, + MessageType::MessageType_CardanoTxWitnessResponse => 157, + MessageType::MessageType_CardanoTxHostAck => 158, + MessageType::MessageType_CardanoTxBodyHash => 159, + MessageType::MessageType_CardanoSignTxFinished => 160, + MessageType::MessageType_CardanoSignTxInit => 161, + MessageType::MessageType_CardanoTxInput => 162, + MessageType::MessageType_CardanoTxOutput => 163, + MessageType::MessageType_CardanoAssetGroup => 164, + MessageType::MessageType_CardanoToken => 165, + MessageType::MessageType_CardanoTxCertificate => 166, + MessageType::MessageType_CardanoTxWithdrawal => 167, + MessageType::MessageType_CardanoTxAuxiliaryData => 168, + MessageType::MessageType_CardanoPoolOwner => 169, + MessageType::MessageType_CardanoPoolRelayParameters => 170, + MessageType::MessageType_CardanoGetNativeScriptHash => 171, + MessageType::MessageType_CardanoNativeScriptHash => 172, + MessageType::MessageType_CardanoTxMint => 173, + MessageType::MessageType_CardanoTxCollateralInput => 174, + MessageType::MessageType_CardanoTxRequiredSigner => 175, + MessageType::MessageType_CardanoTxInlineDatumChunk => 176, + MessageType::MessageType_CardanoTxReferenceScriptChunk => 177, + MessageType::MessageType_CardanoTxReferenceInput => 178, + MessageType::MessageType_RippleGetAddress => 179, + MessageType::MessageType_RippleAddress => 180, + MessageType::MessageType_RippleSignTx => 181, + MessageType::MessageType_RippleSignedTx => 182, + MessageType::MessageType_MoneroTransactionInitRequest => 183, + MessageType::MessageType_MoneroTransactionInitAck => 184, + MessageType::MessageType_MoneroTransactionSetInputRequest => 185, + MessageType::MessageType_MoneroTransactionSetInputAck => 186, + MessageType::MessageType_MoneroTransactionInputViniRequest => 187, + MessageType::MessageType_MoneroTransactionInputViniAck => 188, + MessageType::MessageType_MoneroTransactionAllInputsSetRequest => 189, + MessageType::MessageType_MoneroTransactionAllInputsSetAck => 190, + MessageType::MessageType_MoneroTransactionSetOutputRequest => 191, + MessageType::MessageType_MoneroTransactionSetOutputAck => 192, + MessageType::MessageType_MoneroTransactionAllOutSetRequest => 193, + MessageType::MessageType_MoneroTransactionAllOutSetAck => 194, + MessageType::MessageType_MoneroTransactionSignInputRequest => 195, + MessageType::MessageType_MoneroTransactionSignInputAck => 196, + MessageType::MessageType_MoneroTransactionFinalRequest => 197, + MessageType::MessageType_MoneroTransactionFinalAck => 198, + MessageType::MessageType_MoneroKeyImageExportInitRequest => 199, + MessageType::MessageType_MoneroKeyImageExportInitAck => 200, + MessageType::MessageType_MoneroKeyImageSyncStepRequest => 201, + MessageType::MessageType_MoneroKeyImageSyncStepAck => 202, + MessageType::MessageType_MoneroKeyImageSyncFinalRequest => 203, + MessageType::MessageType_MoneroKeyImageSyncFinalAck => 204, + MessageType::MessageType_MoneroGetAddress => 205, + MessageType::MessageType_MoneroAddress => 206, + MessageType::MessageType_MoneroGetWatchKey => 207, + MessageType::MessageType_MoneroWatchKey => 208, + MessageType::MessageType_DebugMoneroDiagRequest => 209, + MessageType::MessageType_DebugMoneroDiagAck => 210, + MessageType::MessageType_MoneroGetTxKeyRequest => 211, + MessageType::MessageType_MoneroGetTxKeyAck => 212, + MessageType::MessageType_MoneroLiveRefreshStartRequest => 213, + MessageType::MessageType_MoneroLiveRefreshStartAck => 214, + MessageType::MessageType_MoneroLiveRefreshStepRequest => 215, + MessageType::MessageType_MoneroLiveRefreshStepAck => 216, + MessageType::MessageType_MoneroLiveRefreshFinalRequest => 217, + MessageType::MessageType_MoneroLiveRefreshFinalAck => 218, + MessageType::MessageType_EosGetPublicKey => 219, + MessageType::MessageType_EosPublicKey => 220, + MessageType::MessageType_EosSignTx => 221, + MessageType::MessageType_EosTxActionRequest => 222, + MessageType::MessageType_EosTxActionAck => 223, + MessageType::MessageType_EosSignedTx => 224, + MessageType::MessageType_BinanceGetAddress => 225, + MessageType::MessageType_BinanceAddress => 226, + MessageType::MessageType_BinanceGetPublicKey => 227, + MessageType::MessageType_BinancePublicKey => 228, + MessageType::MessageType_BinanceSignTx => 229, + MessageType::MessageType_BinanceTxRequest => 230, + MessageType::MessageType_BinanceTransferMsg => 231, + MessageType::MessageType_BinanceOrderMsg => 232, + MessageType::MessageType_BinanceCancelMsg => 233, + MessageType::MessageType_BinanceSignedTx => 234, + MessageType::MessageType_WebAuthnListResidentCredentials => 235, + MessageType::MessageType_WebAuthnCredentials => 236, + MessageType::MessageType_WebAuthnAddResidentCredential => 237, + MessageType::MessageType_WebAuthnRemoveResidentCredential => 238, + MessageType::MessageType_SolanaGetPublicKey => 239, + MessageType::MessageType_SolanaPublicKey => 240, + MessageType::MessageType_SolanaGetAddress => 241, + MessageType::MessageType_SolanaAddress => 242, + MessageType::MessageType_SolanaSignTx => 243, + MessageType::MessageType_SolanaTxSignature => 244, + MessageType::MessageType_BenchmarkListNames => 245, + MessageType::MessageType_BenchmarkNames => 246, + MessageType::MessageType_BenchmarkRun => 247, + MessageType::MessageType_BenchmarkResult => 248, }; Self::enum_descriptor().value_by_index(index) } @@ -1561,7 +1573,7 @@ impl MessageType { } static file_descriptor_proto_data: &'static [u8] = b"\ - \n\x0emessages.proto\x12\x12hw.trezor.messages\x1a\roptions.proto*\xe8U\ + \n\x0emessages.proto\x12\x12hw.trezor.messages\x1a\roptions.proto*\xcdV\ \n\x0bMessageType\x12(\n\x16MessageType_Initialize\x10\0\x1a\x0c\x80\xa6\ \x1d\x01\xb0\xb5\x18\x01\x90\xb5\x18\x01\x12\x1e\n\x10MessageType_Ping\ \x10\x01\x1a\x08\x80\xa6\x1d\x01\x90\xb5\x18\x01\x12%\n\x13MessageType_S\ @@ -1681,172 +1693,174 @@ static file_descriptor_proto_data: &'static [u8] = b"\ \x1d\x01\xa0\xb5\x18\x01\x124\n%MessageType_DebugLinkResetDebugEvents\ \x10\xafF\x1a\x08\x80\xa6\x1d\x01\xa0\xb5\x18\x01\x123\n$MessageType_Deb\ ugLinkOptigaSetSecMax\x10\xb0F\x1a\x08\x80\xa6\x1d\x01\xa0\xb5\x18\x01\ - \x12+\n\x20MessageType_EthereumGetPublicKey\x10\xc2\x03\x1a\x04\x90\xb5\ - \x18\x01\x12(\n\x1dMessageType_EthereumPublicKey\x10\xc3\x03\x1a\x04\x98\ - \xb5\x18\x01\x12(\n\x1eMessageType_EthereumGetAddress\x108\x1a\x04\x90\ - \xb5\x18\x01\x12%\n\x1bMessageType_EthereumAddress\x109\x1a\x04\x98\xb5\ - \x18\x01\x12$\n\x1aMessageType_EthereumSignTx\x10:\x1a\x04\x90\xb5\x18\ - \x01\x12,\n!MessageType_EthereumSignTxEIP1559\x10\xc4\x03\x1a\x04\x90\ - \xb5\x18\x01\x12'\n\x1dMessageType_EthereumTxRequest\x10;\x1a\x04\x98\ - \xb5\x18\x01\x12#\n\x19MessageType_EthereumTxAck\x10<\x1a\x04\x90\xb5\ - \x18\x01\x12)\n\x1fMessageType_EthereumSignMessage\x10@\x1a\x04\x90\xb5\ - \x18\x01\x12+\n!MessageType_EthereumVerifyMessage\x10A\x1a\x04\x90\xb5\ - \x18\x01\x12.\n$MessageType_EthereumMessageSignature\x10B\x1a\x04\x98\ - \xb5\x18\x01\x12,\n!MessageType_EthereumSignTypedData\x10\xd0\x03\x1a\ - \x04\x90\xb5\x18\x01\x125\n*MessageType_EthereumTypedDataStructRequest\ - \x10\xd1\x03\x1a\x04\x98\xb5\x18\x01\x121\n&MessageType_EthereumTypedDat\ - aStructAck\x10\xd2\x03\x1a\x04\x90\xb5\x18\x01\x124\n)MessageType_Ethere\ - umTypedDataValueRequest\x10\xd3\x03\x1a\x04\x98\xb5\x18\x01\x120\n%Messa\ - geType_EthereumTypedDataValueAck\x10\xd4\x03\x1a\x04\x90\xb5\x18\x01\x12\ - 1\n&MessageType_EthereumTypedDataSignature\x10\xd5\x03\x1a\x04\x98\xb5\ - \x18\x01\x12,\n!MessageType_EthereumSignTypedHash\x10\xd6\x03\x1a\x04\ - \x90\xb5\x18\x01\x12#\n\x19MessageType_NEMGetAddress\x10C\x1a\x04\x90\ - \xb5\x18\x01\x12\x20\n\x16MessageType_NEMAddress\x10D\x1a\x04\x98\xb5\ - \x18\x01\x12\x1f\n\x15MessageType_NEMSignTx\x10E\x1a\x04\x90\xb5\x18\x01\ - \x12!\n\x17MessageType_NEMSignedTx\x10F\x1a\x04\x98\xb5\x18\x01\x12'\n\ - \x1dMessageType_NEMDecryptMessage\x10K\x1a\x04\x90\xb5\x18\x01\x12)\n\ - \x1fMessageType_NEMDecryptedMessage\x10L\x1a\x04\x98\xb5\x18\x01\x12&\n\ - \x1bMessageType_TezosGetAddress\x10\x96\x01\x1a\x04\x90\xb5\x18\x01\x12#\ - \n\x18MessageType_TezosAddress\x10\x97\x01\x1a\x04\x98\xb5\x18\x01\x12\"\ - \n\x17MessageType_TezosSignTx\x10\x98\x01\x1a\x04\x90\xb5\x18\x01\x12$\n\ - \x19MessageType_TezosSignedTx\x10\x99\x01\x1a\x04\x98\xb5\x18\x01\x12(\n\ - \x1dMessageType_TezosGetPublicKey\x10\x9a\x01\x1a\x04\x90\xb5\x18\x01\ - \x12%\n\x1aMessageType_TezosPublicKey\x10\x9b\x01\x1a\x04\x98\xb5\x18\ - \x01\x12$\n\x19MessageType_StellarSignTx\x10\xca\x01\x1a\x04\x90\xb5\x18\ - \x01\x12)\n\x1eMessageType_StellarTxOpRequest\x10\xcb\x01\x1a\x04\x98\ - \xb5\x18\x01\x12(\n\x1dMessageType_StellarGetAddress\x10\xcf\x01\x1a\x04\ - \x90\xb5\x18\x01\x12%\n\x1aMessageType_StellarAddress\x10\xd0\x01\x1a\ - \x04\x98\xb5\x18\x01\x12-\n\"MessageType_StellarCreateAccountOp\x10\xd2\ - \x01\x1a\x04\x90\xb5\x18\x01\x12'\n\x1cMessageType_StellarPaymentOp\x10\ - \xd3\x01\x1a\x04\x90\xb5\x18\x01\x128\n-MessageType_StellarPathPaymentSt\ - rictReceiveOp\x10\xd4\x01\x1a\x04\x90\xb5\x18\x01\x12/\n$MessageType_Ste\ - llarManageSellOfferOp\x10\xd5\x01\x1a\x04\x90\xb5\x18\x01\x126\n+Message\ - Type_StellarCreatePassiveSellOfferOp\x10\xd6\x01\x1a\x04\x90\xb5\x18\x01\ - \x12*\n\x1fMessageType_StellarSetOptionsOp\x10\xd7\x01\x1a\x04\x90\xb5\ - \x18\x01\x12+\n\x20MessageType_StellarChangeTrustOp\x10\xd8\x01\x1a\x04\ - \x90\xb5\x18\x01\x12*\n\x1fMessageType_StellarAllowTrustOp\x10\xd9\x01\ - \x1a\x04\x90\xb5\x18\x01\x12,\n!MessageType_StellarAccountMergeOp\x10\ - \xda\x01\x1a\x04\x90\xb5\x18\x01\x12*\n\x1fMessageType_StellarManageData\ - Op\x10\xdc\x01\x1a\x04\x90\xb5\x18\x01\x12,\n!MessageType_StellarBumpSeq\ - uenceOp\x10\xdd\x01\x1a\x04\x90\xb5\x18\x01\x12.\n#MessageType_StellarMa\ - nageBuyOfferOp\x10\xde\x01\x1a\x04\x90\xb5\x18\x01\x125\n*MessageType_St\ - ellarPathPaymentStrictSendOp\x10\xdf\x01\x1a\x04\x90\xb5\x18\x01\x125\n*\ - MessageType_StellarClaimClaimableBalanceOp\x10\xe1\x01\x1a\x04\x90\xb5\ - \x18\x01\x12&\n\x1bMessageType_StellarSignedTx\x10\xe6\x01\x1a\x04\x98\ - \xb5\x18\x01\x12*\n\x1fMessageType_CardanoGetPublicKey\x10\xb1\x02\x1a\ - \x04\x90\xb5\x18\x01\x12'\n\x1cMessageType_CardanoPublicKey\x10\xb2\x02\ - \x1a\x04\x98\xb5\x18\x01\x12(\n\x1dMessageType_CardanoGetAddress\x10\xb3\ - \x02\x1a\x04\x90\xb5\x18\x01\x12%\n\x1aMessageType_CardanoAddress\x10\ - \xb4\x02\x1a\x04\x98\xb5\x18\x01\x12'\n\x1cMessageType_CardanoTxItemAck\ - \x10\xb9\x02\x1a\x04\x98\xb5\x18\x01\x127\n,MessageType_CardanoTxAuxilia\ - ryDataSupplement\x10\xba\x02\x1a\x04\x98\xb5\x18\x01\x12.\n#MessageType_\ - CardanoTxWitnessRequest\x10\xbb\x02\x1a\x04\x90\xb5\x18\x01\x12/\n$Messa\ - geType_CardanoTxWitnessResponse\x10\xbc\x02\x1a\x04\x98\xb5\x18\x01\x12'\ - \n\x1cMessageType_CardanoTxHostAck\x10\xbd\x02\x1a\x04\x90\xb5\x18\x01\ - \x12(\n\x1dMessageType_CardanoTxBodyHash\x10\xbe\x02\x1a\x04\x98\xb5\x18\ - \x01\x12,\n!MessageType_CardanoSignTxFinished\x10\xbf\x02\x1a\x04\x98\ - \xb5\x18\x01\x12(\n\x1dMessageType_CardanoSignTxInit\x10\xc0\x02\x1a\x04\ - \x90\xb5\x18\x01\x12%\n\x1aMessageType_CardanoTxInput\x10\xc1\x02\x1a\ - \x04\x90\xb5\x18\x01\x12&\n\x1bMessageType_CardanoTxOutput\x10\xc2\x02\ - \x1a\x04\x90\xb5\x18\x01\x12(\n\x1dMessageType_CardanoAssetGroup\x10\xc3\ - \x02\x1a\x04\x90\xb5\x18\x01\x12#\n\x18MessageType_CardanoToken\x10\xc4\ - \x02\x1a\x04\x90\xb5\x18\x01\x12+\n\x20MessageType_CardanoTxCertificate\ - \x10\xc5\x02\x1a\x04\x90\xb5\x18\x01\x12*\n\x1fMessageType_CardanoTxWith\ - drawal\x10\xc6\x02\x1a\x04\x90\xb5\x18\x01\x12-\n\"MessageType_CardanoTx\ - AuxiliaryData\x10\xc7\x02\x1a\x04\x90\xb5\x18\x01\x12'\n\x1cMessageType_\ - CardanoPoolOwner\x10\xc8\x02\x1a\x04\x90\xb5\x18\x01\x121\n&MessageType_\ - CardanoPoolRelayParameters\x10\xc9\x02\x1a\x04\x90\xb5\x18\x01\x121\n&Me\ - ssageType_CardanoGetNativeScriptHash\x10\xca\x02\x1a\x04\x90\xb5\x18\x01\ - \x12.\n#MessageType_CardanoNativeScriptHash\x10\xcb\x02\x1a\x04\x98\xb5\ - \x18\x01\x12$\n\x19MessageType_CardanoTxMint\x10\xcc\x02\x1a\x04\x90\xb5\ - \x18\x01\x12/\n$MessageType_CardanoTxCollateralInput\x10\xcd\x02\x1a\x04\ - \x90\xb5\x18\x01\x12.\n#MessageType_CardanoTxRequiredSigner\x10\xce\x02\ - \x1a\x04\x90\xb5\x18\x01\x120\n%MessageType_CardanoTxInlineDatumChunk\ - \x10\xcf\x02\x1a\x04\x90\xb5\x18\x01\x124\n)MessageType_CardanoTxReferen\ - ceScriptChunk\x10\xd0\x02\x1a\x04\x90\xb5\x18\x01\x12.\n#MessageType_Car\ - danoTxReferenceInput\x10\xd1\x02\x1a\x04\x90\xb5\x18\x01\x12'\n\x1cMessa\ - geType_RippleGetAddress\x10\x90\x03\x1a\x04\x90\xb5\x18\x01\x12$\n\x19Me\ - ssageType_RippleAddress\x10\x91\x03\x1a\x04\x98\xb5\x18\x01\x12#\n\x18Me\ - ssageType_RippleSignTx\x10\x92\x03\x1a\x04\x90\xb5\x18\x01\x12%\n\x1aMes\ - sageType_RippleSignedTx\x10\x93\x03\x1a\x04\x90\xb5\x18\x01\x123\n(Messa\ - geType_MoneroTransactionInitRequest\x10\xf5\x03\x1a\x04\x98\xb5\x18\x01\ - \x12/\n$MessageType_MoneroTransactionInitAck\x10\xf6\x03\x1a\x04\x98\xb5\ - \x18\x01\x127\n,MessageType_MoneroTransactionSetInputRequest\x10\xf7\x03\ - \x1a\x04\x98\xb5\x18\x01\x123\n(MessageType_MoneroTransactionSetInputAck\ - \x10\xf8\x03\x1a\x04\x98\xb5\x18\x01\x128\n-MessageType_MoneroTransactio\ - nInputViniRequest\x10\xfb\x03\x1a\x04\x98\xb5\x18\x01\x124\n)MessageType\ - _MoneroTransactionInputViniAck\x10\xfc\x03\x1a\x04\x98\xb5\x18\x01\x12;\ - \n0MessageType_MoneroTransactionAllInputsSetRequest\x10\xfd\x03\x1a\x04\ - \x98\xb5\x18\x01\x127\n,MessageType_MoneroTransactionAllInputsSetAck\x10\ - \xfe\x03\x1a\x04\x98\xb5\x18\x01\x128\n-MessageType_MoneroTransactionSet\ - OutputRequest\x10\xff\x03\x1a\x04\x98\xb5\x18\x01\x124\n)MessageType_Mon\ - eroTransactionSetOutputAck\x10\x80\x04\x1a\x04\x98\xb5\x18\x01\x128\n-Me\ - ssageType_MoneroTransactionAllOutSetRequest\x10\x81\x04\x1a\x04\x98\xb5\ - \x18\x01\x124\n)MessageType_MoneroTransactionAllOutSetAck\x10\x82\x04\ - \x1a\x04\x98\xb5\x18\x01\x128\n-MessageType_MoneroTransactionSignInputRe\ - quest\x10\x83\x04\x1a\x04\x98\xb5\x18\x01\x124\n)MessageType_MoneroTrans\ - actionSignInputAck\x10\x84\x04\x1a\x04\x98\xb5\x18\x01\x124\n)MessageTyp\ - e_MoneroTransactionFinalRequest\x10\x85\x04\x1a\x04\x98\xb5\x18\x01\x120\ - \n%MessageType_MoneroTransactionFinalAck\x10\x86\x04\x1a\x04\x98\xb5\x18\ - \x01\x126\n+MessageType_MoneroKeyImageExportInitRequest\x10\x92\x04\x1a\ - \x04\x98\xb5\x18\x01\x122\n'MessageType_MoneroKeyImageExportInitAck\x10\ - \x93\x04\x1a\x04\x98\xb5\x18\x01\x124\n)MessageType_MoneroKeyImageSyncSt\ - epRequest\x10\x94\x04\x1a\x04\x98\xb5\x18\x01\x120\n%MessageType_MoneroK\ - eyImageSyncStepAck\x10\x95\x04\x1a\x04\x98\xb5\x18\x01\x125\n*MessageTyp\ - e_MoneroKeyImageSyncFinalRequest\x10\x96\x04\x1a\x04\x98\xb5\x18\x01\x12\ - 1\n&MessageType_MoneroKeyImageSyncFinalAck\x10\x97\x04\x1a\x04\x98\xb5\ - \x18\x01\x12'\n\x1cMessageType_MoneroGetAddress\x10\x9c\x04\x1a\x04\x90\ - \xb5\x18\x01\x12$\n\x19MessageType_MoneroAddress\x10\x9d\x04\x1a\x04\x98\ - \xb5\x18\x01\x12(\n\x1dMessageType_MoneroGetWatchKey\x10\x9e\x04\x1a\x04\ - \x90\xb5\x18\x01\x12%\n\x1aMessageType_MoneroWatchKey\x10\x9f\x04\x1a\ - \x04\x98\xb5\x18\x01\x12-\n\"MessageType_DebugMoneroDiagRequest\x10\xa2\ - \x04\x1a\x04\x90\xb5\x18\x01\x12)\n\x1eMessageType_DebugMoneroDiagAck\ - \x10\xa3\x04\x1a\x04\x98\xb5\x18\x01\x12,\n!MessageType_MoneroGetTxKeyRe\ - quest\x10\xa6\x04\x1a\x04\x90\xb5\x18\x01\x12(\n\x1dMessageType_MoneroGe\ - tTxKeyAck\x10\xa7\x04\x1a\x04\x98\xb5\x18\x01\x124\n)MessageType_MoneroL\ - iveRefreshStartRequest\x10\xa8\x04\x1a\x04\x90\xb5\x18\x01\x120\n%Messag\ - eType_MoneroLiveRefreshStartAck\x10\xa9\x04\x1a\x04\x98\xb5\x18\x01\x123\ - \n(MessageType_MoneroLiveRefreshStepRequest\x10\xaa\x04\x1a\x04\x90\xb5\ - \x18\x01\x12/\n$MessageType_MoneroLiveRefreshStepAck\x10\xab\x04\x1a\x04\ - \x98\xb5\x18\x01\x124\n)MessageType_MoneroLiveRefreshFinalRequest\x10\ - \xac\x04\x1a\x04\x90\xb5\x18\x01\x120\n%MessageType_MoneroLiveRefreshFin\ - alAck\x10\xad\x04\x1a\x04\x98\xb5\x18\x01\x12&\n\x1bMessageType_EosGetPu\ - blicKey\x10\xd8\x04\x1a\x04\x90\xb5\x18\x01\x12#\n\x18MessageType_EosPub\ - licKey\x10\xd9\x04\x1a\x04\x98\xb5\x18\x01\x12\x20\n\x15MessageType_EosS\ - ignTx\x10\xda\x04\x1a\x04\x90\xb5\x18\x01\x12)\n\x1eMessageType_EosTxAct\ - ionRequest\x10\xdb\x04\x1a\x04\x98\xb5\x18\x01\x12%\n\x1aMessageType_Eos\ - TxActionAck\x10\xdc\x04\x1a\x04\x90\xb5\x18\x01\x12\"\n\x17MessageType_E\ - osSignedTx\x10\xdd\x04\x1a\x04\x98\xb5\x18\x01\x12(\n\x1dMessageType_Bin\ - anceGetAddress\x10\xbc\x05\x1a\x04\x90\xb5\x18\x01\x12%\n\x1aMessageType\ - _BinanceAddress\x10\xbd\x05\x1a\x04\x98\xb5\x18\x01\x12*\n\x1fMessageTyp\ - e_BinanceGetPublicKey\x10\xbe\x05\x1a\x04\x90\xb5\x18\x01\x12'\n\x1cMess\ - ageType_BinancePublicKey\x10\xbf\x05\x1a\x04\x98\xb5\x18\x01\x12$\n\x19M\ - essageType_BinanceSignTx\x10\xc0\x05\x1a\x04\x90\xb5\x18\x01\x12'\n\x1cM\ - essageType_BinanceTxRequest\x10\xc1\x05\x1a\x04\x98\xb5\x18\x01\x12)\n\ - \x1eMessageType_BinanceTransferMsg\x10\xc2\x05\x1a\x04\x90\xb5\x18\x01\ - \x12&\n\x1bMessageType_BinanceOrderMsg\x10\xc3\x05\x1a\x04\x90\xb5\x18\ - \x01\x12'\n\x1cMessageType_BinanceCancelMsg\x10\xc4\x05\x1a\x04\x90\xb5\ - \x18\x01\x12&\n\x1bMessageType_BinanceSignedTx\x10\xc5\x05\x1a\x04\x98\ - \xb5\x18\x01\x126\n+MessageType_WebAuthnListResidentCredentials\x10\xa0\ - \x06\x1a\x04\x90\xb5\x18\x01\x12*\n\x1fMessageType_WebAuthnCredentials\ - \x10\xa1\x06\x1a\x04\x98\xb5\x18\x01\x124\n)MessageType_WebAuthnAddResid\ - entCredential\x10\xa2\x06\x1a\x04\x90\xb5\x18\x01\x127\n,MessageType_Web\ - AuthnRemoveResidentCredential\x10\xa3\x06\x1a\x04\x90\xb5\x18\x01\x12)\n\ - \x1eMessageType_SolanaGetPublicKey\x10\x84\x07\x1a\x04\x90\xb5\x18\x01\ - \x12&\n\x1bMessageType_SolanaPublicKey\x10\x85\x07\x1a\x04\x98\xb5\x18\ - \x01\x12'\n\x1cMessageType_SolanaGetAddress\x10\x86\x07\x1a\x04\x90\xb5\ - \x18\x01\x12$\n\x19MessageType_SolanaAddress\x10\x87\x07\x1a\x04\x98\xb5\ - \x18\x01\x12#\n\x18MessageType_SolanaSignTx\x10\x88\x07\x1a\x04\x90\xb5\ - \x18\x01\x12(\n\x1dMessageType_SolanaTxSignature\x10\x89\x07\x1a\x04\x98\ - \xb5\x18\x01\x12)\n\x1eMessageType_BenchmarkListNames\x10\x8cG\x1a\x04\ - \x80\xa6\x1d\x01\x12%\n\x1aMessageType_BenchmarkNames\x10\x8dG\x1a\x04\ - \x80\xa6\x1d\x01\x12#\n\x18MessageType_BenchmarkRun\x10\x8eG\x1a\x04\x80\ - \xa6\x1d\x01\x12&\n\x1bMessageType_BenchmarkResult\x10\x8fG\x1a\x04\x80\ - \xa6\x1d\x01\x1a\x04\xc8\xf3\x18\x01\"\x04\x08Z\x10\\\"\x04\x08G\x10J\"\ - \x04\x08r\x10z\"\x06\x08\xdb\x01\x10\xdb\x01\"\x06\x08\xe0\x01\x10\xe0\ - \x01\"\x06\x08\xac\x02\x10\xb0\x02\"\x06\x08\xb5\x02\x10\xb8\x02\"\x06\ - \x08\xe8\x07\x10\xcb\x08B8\n#com.satoshilabs.trezor.lib.protobufB\rTrezo\ - rMessage\x80\xa6\x1d\x01\ + \x122\n#MessageType_DebugLinkGetPairingInfo\x10\xb1F\x1a\x08\x80\xa6\x1d\ + \x01\xa0\xb5\x18\x01\x12/\n\x20MessageType_DebugLinkPairingInfo\x10\xb2F\ + \x1a\x08\x80\xa6\x1d\x01\xa8\xb5\x18\x01\x12+\n\x20MessageType_EthereumG\ + etPublicKey\x10\xc2\x03\x1a\x04\x90\xb5\x18\x01\x12(\n\x1dMessageType_Et\ + hereumPublicKey\x10\xc3\x03\x1a\x04\x98\xb5\x18\x01\x12(\n\x1eMessageTyp\ + e_EthereumGetAddress\x108\x1a\x04\x90\xb5\x18\x01\x12%\n\x1bMessageType_\ + EthereumAddress\x109\x1a\x04\x98\xb5\x18\x01\x12$\n\x1aMessageType_Ether\ + eumSignTx\x10:\x1a\x04\x90\xb5\x18\x01\x12,\n!MessageType_EthereumSignTx\ + EIP1559\x10\xc4\x03\x1a\x04\x90\xb5\x18\x01\x12'\n\x1dMessageType_Ethere\ + umTxRequest\x10;\x1a\x04\x98\xb5\x18\x01\x12#\n\x19MessageType_EthereumT\ + xAck\x10<\x1a\x04\x90\xb5\x18\x01\x12)\n\x1fMessageType_EthereumSignMess\ + age\x10@\x1a\x04\x90\xb5\x18\x01\x12+\n!MessageType_EthereumVerifyMessag\ + e\x10A\x1a\x04\x90\xb5\x18\x01\x12.\n$MessageType_EthereumMessageSignatu\ + re\x10B\x1a\x04\x98\xb5\x18\x01\x12,\n!MessageType_EthereumSignTypedData\ + \x10\xd0\x03\x1a\x04\x90\xb5\x18\x01\x125\n*MessageType_EthereumTypedDat\ + aStructRequest\x10\xd1\x03\x1a\x04\x98\xb5\x18\x01\x121\n&MessageType_Et\ + hereumTypedDataStructAck\x10\xd2\x03\x1a\x04\x90\xb5\x18\x01\x124\n)Mess\ + ageType_EthereumTypedDataValueRequest\x10\xd3\x03\x1a\x04\x98\xb5\x18\ + \x01\x120\n%MessageType_EthereumTypedDataValueAck\x10\xd4\x03\x1a\x04\ + \x90\xb5\x18\x01\x121\n&MessageType_EthereumTypedDataSignature\x10\xd5\ + \x03\x1a\x04\x98\xb5\x18\x01\x12,\n!MessageType_EthereumSignTypedHash\ + \x10\xd6\x03\x1a\x04\x90\xb5\x18\x01\x12#\n\x19MessageType_NEMGetAddress\ + \x10C\x1a\x04\x90\xb5\x18\x01\x12\x20\n\x16MessageType_NEMAddress\x10D\ + \x1a\x04\x98\xb5\x18\x01\x12\x1f\n\x15MessageType_NEMSignTx\x10E\x1a\x04\ + \x90\xb5\x18\x01\x12!\n\x17MessageType_NEMSignedTx\x10F\x1a\x04\x98\xb5\ + \x18\x01\x12'\n\x1dMessageType_NEMDecryptMessage\x10K\x1a\x04\x90\xb5\ + \x18\x01\x12)\n\x1fMessageType_NEMDecryptedMessage\x10L\x1a\x04\x98\xb5\ + \x18\x01\x12&\n\x1bMessageType_TezosGetAddress\x10\x96\x01\x1a\x04\x90\ + \xb5\x18\x01\x12#\n\x18MessageType_TezosAddress\x10\x97\x01\x1a\x04\x98\ + \xb5\x18\x01\x12\"\n\x17MessageType_TezosSignTx\x10\x98\x01\x1a\x04\x90\ + \xb5\x18\x01\x12$\n\x19MessageType_TezosSignedTx\x10\x99\x01\x1a\x04\x98\ + \xb5\x18\x01\x12(\n\x1dMessageType_TezosGetPublicKey\x10\x9a\x01\x1a\x04\ + \x90\xb5\x18\x01\x12%\n\x1aMessageType_TezosPublicKey\x10\x9b\x01\x1a\ + \x04\x98\xb5\x18\x01\x12$\n\x19MessageType_StellarSignTx\x10\xca\x01\x1a\ + \x04\x90\xb5\x18\x01\x12)\n\x1eMessageType_StellarTxOpRequest\x10\xcb\ + \x01\x1a\x04\x98\xb5\x18\x01\x12(\n\x1dMessageType_StellarGetAddress\x10\ + \xcf\x01\x1a\x04\x90\xb5\x18\x01\x12%\n\x1aMessageType_StellarAddress\ + \x10\xd0\x01\x1a\x04\x98\xb5\x18\x01\x12-\n\"MessageType_StellarCreateAc\ + countOp\x10\xd2\x01\x1a\x04\x90\xb5\x18\x01\x12'\n\x1cMessageType_Stella\ + rPaymentOp\x10\xd3\x01\x1a\x04\x90\xb5\x18\x01\x128\n-MessageType_Stella\ + rPathPaymentStrictReceiveOp\x10\xd4\x01\x1a\x04\x90\xb5\x18\x01\x12/\n$M\ + essageType_StellarManageSellOfferOp\x10\xd5\x01\x1a\x04\x90\xb5\x18\x01\ + \x126\n+MessageType_StellarCreatePassiveSellOfferOp\x10\xd6\x01\x1a\x04\ + \x90\xb5\x18\x01\x12*\n\x1fMessageType_StellarSetOptionsOp\x10\xd7\x01\ + \x1a\x04\x90\xb5\x18\x01\x12+\n\x20MessageType_StellarChangeTrustOp\x10\ + \xd8\x01\x1a\x04\x90\xb5\x18\x01\x12*\n\x1fMessageType_StellarAllowTrust\ + Op\x10\xd9\x01\x1a\x04\x90\xb5\x18\x01\x12,\n!MessageType_StellarAccount\ + MergeOp\x10\xda\x01\x1a\x04\x90\xb5\x18\x01\x12*\n\x1fMessageType_Stella\ + rManageDataOp\x10\xdc\x01\x1a\x04\x90\xb5\x18\x01\x12,\n!MessageType_Ste\ + llarBumpSequenceOp\x10\xdd\x01\x1a\x04\x90\xb5\x18\x01\x12.\n#MessageTyp\ + e_StellarManageBuyOfferOp\x10\xde\x01\x1a\x04\x90\xb5\x18\x01\x125\n*Mes\ + sageType_StellarPathPaymentStrictSendOp\x10\xdf\x01\x1a\x04\x90\xb5\x18\ + \x01\x125\n*MessageType_StellarClaimClaimableBalanceOp\x10\xe1\x01\x1a\ + \x04\x90\xb5\x18\x01\x12&\n\x1bMessageType_StellarSignedTx\x10\xe6\x01\ + \x1a\x04\x98\xb5\x18\x01\x12*\n\x1fMessageType_CardanoGetPublicKey\x10\ + \xb1\x02\x1a\x04\x90\xb5\x18\x01\x12'\n\x1cMessageType_CardanoPublicKey\ + \x10\xb2\x02\x1a\x04\x98\xb5\x18\x01\x12(\n\x1dMessageType_CardanoGetAdd\ + ress\x10\xb3\x02\x1a\x04\x90\xb5\x18\x01\x12%\n\x1aMessageType_CardanoAd\ + dress\x10\xb4\x02\x1a\x04\x98\xb5\x18\x01\x12'\n\x1cMessageType_CardanoT\ + xItemAck\x10\xb9\x02\x1a\x04\x98\xb5\x18\x01\x127\n,MessageType_CardanoT\ + xAuxiliaryDataSupplement\x10\xba\x02\x1a\x04\x98\xb5\x18\x01\x12.\n#Mess\ + ageType_CardanoTxWitnessRequest\x10\xbb\x02\x1a\x04\x90\xb5\x18\x01\x12/\ + \n$MessageType_CardanoTxWitnessResponse\x10\xbc\x02\x1a\x04\x98\xb5\x18\ + \x01\x12'\n\x1cMessageType_CardanoTxHostAck\x10\xbd\x02\x1a\x04\x90\xb5\ + \x18\x01\x12(\n\x1dMessageType_CardanoTxBodyHash\x10\xbe\x02\x1a\x04\x98\ + \xb5\x18\x01\x12,\n!MessageType_CardanoSignTxFinished\x10\xbf\x02\x1a\ + \x04\x98\xb5\x18\x01\x12(\n\x1dMessageType_CardanoSignTxInit\x10\xc0\x02\ + \x1a\x04\x90\xb5\x18\x01\x12%\n\x1aMessageType_CardanoTxInput\x10\xc1\ + \x02\x1a\x04\x90\xb5\x18\x01\x12&\n\x1bMessageType_CardanoTxOutput\x10\ + \xc2\x02\x1a\x04\x90\xb5\x18\x01\x12(\n\x1dMessageType_CardanoAssetGroup\ + \x10\xc3\x02\x1a\x04\x90\xb5\x18\x01\x12#\n\x18MessageType_CardanoToken\ + \x10\xc4\x02\x1a\x04\x90\xb5\x18\x01\x12+\n\x20MessageType_CardanoTxCert\ + ificate\x10\xc5\x02\x1a\x04\x90\xb5\x18\x01\x12*\n\x1fMessageType_Cardan\ + oTxWithdrawal\x10\xc6\x02\x1a\x04\x90\xb5\x18\x01\x12-\n\"MessageType_Ca\ + rdanoTxAuxiliaryData\x10\xc7\x02\x1a\x04\x90\xb5\x18\x01\x12'\n\x1cMessa\ + geType_CardanoPoolOwner\x10\xc8\x02\x1a\x04\x90\xb5\x18\x01\x121\n&Messa\ + geType_CardanoPoolRelayParameters\x10\xc9\x02\x1a\x04\x90\xb5\x18\x01\ + \x121\n&MessageType_CardanoGetNativeScriptHash\x10\xca\x02\x1a\x04\x90\ + \xb5\x18\x01\x12.\n#MessageType_CardanoNativeScriptHash\x10\xcb\x02\x1a\ + \x04\x98\xb5\x18\x01\x12$\n\x19MessageType_CardanoTxMint\x10\xcc\x02\x1a\ + \x04\x90\xb5\x18\x01\x12/\n$MessageType_CardanoTxCollateralInput\x10\xcd\ + \x02\x1a\x04\x90\xb5\x18\x01\x12.\n#MessageType_CardanoTxRequiredSigner\ + \x10\xce\x02\x1a\x04\x90\xb5\x18\x01\x120\n%MessageType_CardanoTxInlineD\ + atumChunk\x10\xcf\x02\x1a\x04\x90\xb5\x18\x01\x124\n)MessageType_Cardano\ + TxReferenceScriptChunk\x10\xd0\x02\x1a\x04\x90\xb5\x18\x01\x12.\n#Messag\ + eType_CardanoTxReferenceInput\x10\xd1\x02\x1a\x04\x90\xb5\x18\x01\x12'\n\ + \x1cMessageType_RippleGetAddress\x10\x90\x03\x1a\x04\x90\xb5\x18\x01\x12\ + $\n\x19MessageType_RippleAddress\x10\x91\x03\x1a\x04\x98\xb5\x18\x01\x12\ + #\n\x18MessageType_RippleSignTx\x10\x92\x03\x1a\x04\x90\xb5\x18\x01\x12%\ + \n\x1aMessageType_RippleSignedTx\x10\x93\x03\x1a\x04\x90\xb5\x18\x01\x12\ + 3\n(MessageType_MoneroTransactionInitRequest\x10\xf5\x03\x1a\x04\x98\xb5\ + \x18\x01\x12/\n$MessageType_MoneroTransactionInitAck\x10\xf6\x03\x1a\x04\ + \x98\xb5\x18\x01\x127\n,MessageType_MoneroTransactionSetInputRequest\x10\ + \xf7\x03\x1a\x04\x98\xb5\x18\x01\x123\n(MessageType_MoneroTransactionSet\ + InputAck\x10\xf8\x03\x1a\x04\x98\xb5\x18\x01\x128\n-MessageType_MoneroTr\ + ansactionInputViniRequest\x10\xfb\x03\x1a\x04\x98\xb5\x18\x01\x124\n)Mes\ + sageType_MoneroTransactionInputViniAck\x10\xfc\x03\x1a\x04\x98\xb5\x18\ + \x01\x12;\n0MessageType_MoneroTransactionAllInputsSetRequest\x10\xfd\x03\ + \x1a\x04\x98\xb5\x18\x01\x127\n,MessageType_MoneroTransactionAllInputsSe\ + tAck\x10\xfe\x03\x1a\x04\x98\xb5\x18\x01\x128\n-MessageType_MoneroTransa\ + ctionSetOutputRequest\x10\xff\x03\x1a\x04\x98\xb5\x18\x01\x124\n)Message\ + Type_MoneroTransactionSetOutputAck\x10\x80\x04\x1a\x04\x98\xb5\x18\x01\ + \x128\n-MessageType_MoneroTransactionAllOutSetRequest\x10\x81\x04\x1a\ + \x04\x98\xb5\x18\x01\x124\n)MessageType_MoneroTransactionAllOutSetAck\ + \x10\x82\x04\x1a\x04\x98\xb5\x18\x01\x128\n-MessageType_MoneroTransactio\ + nSignInputRequest\x10\x83\x04\x1a\x04\x98\xb5\x18\x01\x124\n)MessageType\ + _MoneroTransactionSignInputAck\x10\x84\x04\x1a\x04\x98\xb5\x18\x01\x124\ + \n)MessageType_MoneroTransactionFinalRequest\x10\x85\x04\x1a\x04\x98\xb5\ + \x18\x01\x120\n%MessageType_MoneroTransactionFinalAck\x10\x86\x04\x1a\ + \x04\x98\xb5\x18\x01\x126\n+MessageType_MoneroKeyImageExportInitRequest\ + \x10\x92\x04\x1a\x04\x98\xb5\x18\x01\x122\n'MessageType_MoneroKeyImageEx\ + portInitAck\x10\x93\x04\x1a\x04\x98\xb5\x18\x01\x124\n)MessageType_Moner\ + oKeyImageSyncStepRequest\x10\x94\x04\x1a\x04\x98\xb5\x18\x01\x120\n%Mess\ + ageType_MoneroKeyImageSyncStepAck\x10\x95\x04\x1a\x04\x98\xb5\x18\x01\ + \x125\n*MessageType_MoneroKeyImageSyncFinalRequest\x10\x96\x04\x1a\x04\ + \x98\xb5\x18\x01\x121\n&MessageType_MoneroKeyImageSyncFinalAck\x10\x97\ + \x04\x1a\x04\x98\xb5\x18\x01\x12'\n\x1cMessageType_MoneroGetAddress\x10\ + \x9c\x04\x1a\x04\x90\xb5\x18\x01\x12$\n\x19MessageType_MoneroAddress\x10\ + \x9d\x04\x1a\x04\x98\xb5\x18\x01\x12(\n\x1dMessageType_MoneroGetWatchKey\ + \x10\x9e\x04\x1a\x04\x90\xb5\x18\x01\x12%\n\x1aMessageType_MoneroWatchKe\ + y\x10\x9f\x04\x1a\x04\x98\xb5\x18\x01\x12-\n\"MessageType_DebugMoneroDia\ + gRequest\x10\xa2\x04\x1a\x04\x90\xb5\x18\x01\x12)\n\x1eMessageType_Debug\ + MoneroDiagAck\x10\xa3\x04\x1a\x04\x98\xb5\x18\x01\x12,\n!MessageType_Mon\ + eroGetTxKeyRequest\x10\xa6\x04\x1a\x04\x90\xb5\x18\x01\x12(\n\x1dMessage\ + Type_MoneroGetTxKeyAck\x10\xa7\x04\x1a\x04\x98\xb5\x18\x01\x124\n)Messag\ + eType_MoneroLiveRefreshStartRequest\x10\xa8\x04\x1a\x04\x90\xb5\x18\x01\ + \x120\n%MessageType_MoneroLiveRefreshStartAck\x10\xa9\x04\x1a\x04\x98\ + \xb5\x18\x01\x123\n(MessageType_MoneroLiveRefreshStepRequest\x10\xaa\x04\ + \x1a\x04\x90\xb5\x18\x01\x12/\n$MessageType_MoneroLiveRefreshStepAck\x10\ + \xab\x04\x1a\x04\x98\xb5\x18\x01\x124\n)MessageType_MoneroLiveRefreshFin\ + alRequest\x10\xac\x04\x1a\x04\x90\xb5\x18\x01\x120\n%MessageType_MoneroL\ + iveRefreshFinalAck\x10\xad\x04\x1a\x04\x98\xb5\x18\x01\x12&\n\x1bMessage\ + Type_EosGetPublicKey\x10\xd8\x04\x1a\x04\x90\xb5\x18\x01\x12#\n\x18Messa\ + geType_EosPublicKey\x10\xd9\x04\x1a\x04\x98\xb5\x18\x01\x12\x20\n\x15Mes\ + sageType_EosSignTx\x10\xda\x04\x1a\x04\x90\xb5\x18\x01\x12)\n\x1eMessage\ + Type_EosTxActionRequest\x10\xdb\x04\x1a\x04\x98\xb5\x18\x01\x12%\n\x1aMe\ + ssageType_EosTxActionAck\x10\xdc\x04\x1a\x04\x90\xb5\x18\x01\x12\"\n\x17\ + MessageType_EosSignedTx\x10\xdd\x04\x1a\x04\x98\xb5\x18\x01\x12(\n\x1dMe\ + ssageType_BinanceGetAddress\x10\xbc\x05\x1a\x04\x90\xb5\x18\x01\x12%\n\ + \x1aMessageType_BinanceAddress\x10\xbd\x05\x1a\x04\x98\xb5\x18\x01\x12*\ + \n\x1fMessageType_BinanceGetPublicKey\x10\xbe\x05\x1a\x04\x90\xb5\x18\ + \x01\x12'\n\x1cMessageType_BinancePublicKey\x10\xbf\x05\x1a\x04\x98\xb5\ + \x18\x01\x12$\n\x19MessageType_BinanceSignTx\x10\xc0\x05\x1a\x04\x90\xb5\ + \x18\x01\x12'\n\x1cMessageType_BinanceTxRequest\x10\xc1\x05\x1a\x04\x98\ + \xb5\x18\x01\x12)\n\x1eMessageType_BinanceTransferMsg\x10\xc2\x05\x1a\ + \x04\x90\xb5\x18\x01\x12&\n\x1bMessageType_BinanceOrderMsg\x10\xc3\x05\ + \x1a\x04\x90\xb5\x18\x01\x12'\n\x1cMessageType_BinanceCancelMsg\x10\xc4\ + \x05\x1a\x04\x90\xb5\x18\x01\x12&\n\x1bMessageType_BinanceSignedTx\x10\ + \xc5\x05\x1a\x04\x98\xb5\x18\x01\x126\n+MessageType_WebAuthnListResident\ + Credentials\x10\xa0\x06\x1a\x04\x90\xb5\x18\x01\x12*\n\x1fMessageType_We\ + bAuthnCredentials\x10\xa1\x06\x1a\x04\x98\xb5\x18\x01\x124\n)MessageType\ + _WebAuthnAddResidentCredential\x10\xa2\x06\x1a\x04\x90\xb5\x18\x01\x127\ + \n,MessageType_WebAuthnRemoveResidentCredential\x10\xa3\x06\x1a\x04\x90\ + \xb5\x18\x01\x12)\n\x1eMessageType_SolanaGetPublicKey\x10\x84\x07\x1a\ + \x04\x90\xb5\x18\x01\x12&\n\x1bMessageType_SolanaPublicKey\x10\x85\x07\ + \x1a\x04\x98\xb5\x18\x01\x12'\n\x1cMessageType_SolanaGetAddress\x10\x86\ + \x07\x1a\x04\x90\xb5\x18\x01\x12$\n\x19MessageType_SolanaAddress\x10\x87\ + \x07\x1a\x04\x98\xb5\x18\x01\x12#\n\x18MessageType_SolanaSignTx\x10\x88\ + \x07\x1a\x04\x90\xb5\x18\x01\x12(\n\x1dMessageType_SolanaTxSignature\x10\ + \x89\x07\x1a\x04\x98\xb5\x18\x01\x12)\n\x1eMessageType_BenchmarkListName\ + s\x10\x8cG\x1a\x04\x80\xa6\x1d\x01\x12%\n\x1aMessageType_BenchmarkNames\ + \x10\x8dG\x1a\x04\x80\xa6\x1d\x01\x12#\n\x18MessageType_BenchmarkRun\x10\ + \x8eG\x1a\x04\x80\xa6\x1d\x01\x12&\n\x1bMessageType_BenchmarkResult\x10\ + \x8fG\x1a\x04\x80\xa6\x1d\x01\x1a\x04\xc8\xf3\x18\x01\"\x04\x08Z\x10\\\"\ + \x04\x08G\x10J\"\x04\x08r\x10z\"\x06\x08\xdb\x01\x10\xdb\x01\"\x06\x08\ + \xe0\x01\x10\xe0\x01\"\x06\x08\xac\x02\x10\xb0\x02\"\x06\x08\xb5\x02\x10\ + \xb8\x02\"\x06\x08\xe8\x07\x10\xcb\x08B8\n#com.satoshilabs.trezor.lib.pr\ + otobufB\rTrezorMessage\x80\xa6\x1d\x01\ "; /// `FileDescriptorProto` object which was a source for this generated file diff --git a/rust/trezor-client/src/protos/generated/messages_common.rs b/rust/trezor-client/src/protos/generated/messages_common.rs index 4fd72b22f0b..55b796f7b4c 100644 --- a/rust/trezor-client/src/protos/generated/messages_common.rs +++ b/rust/trezor-client/src/protos/generated/messages_common.rs @@ -414,6 +414,14 @@ pub mod failure { Failure_WipeCodeMismatch = 13, // @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_InvalidSession) Failure_InvalidSession = 14, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_ThpUnallocatedSession) + Failure_ThpUnallocatedSession = 15, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_InvalidProtocol) + Failure_InvalidProtocol = 16, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_BufferError) + Failure_BufferError = 17, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_DeviceIsBusy) + Failure_DeviceIsBusy = 18, // @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_FirmwareError) Failure_FirmwareError = 99, } @@ -441,6 +449,10 @@ pub mod failure { 12 => ::std::option::Option::Some(FailureType::Failure_PinMismatch), 13 => ::std::option::Option::Some(FailureType::Failure_WipeCodeMismatch), 14 => ::std::option::Option::Some(FailureType::Failure_InvalidSession), + 15 => ::std::option::Option::Some(FailureType::Failure_ThpUnallocatedSession), + 16 => ::std::option::Option::Some(FailureType::Failure_InvalidProtocol), + 17 => ::std::option::Option::Some(FailureType::Failure_BufferError), + 18 => ::std::option::Option::Some(FailureType::Failure_DeviceIsBusy), 99 => ::std::option::Option::Some(FailureType::Failure_FirmwareError), _ => ::std::option::Option::None } @@ -462,6 +474,10 @@ pub mod failure { "Failure_PinMismatch" => ::std::option::Option::Some(FailureType::Failure_PinMismatch), "Failure_WipeCodeMismatch" => ::std::option::Option::Some(FailureType::Failure_WipeCodeMismatch), "Failure_InvalidSession" => ::std::option::Option::Some(FailureType::Failure_InvalidSession), + "Failure_ThpUnallocatedSession" => ::std::option::Option::Some(FailureType::Failure_ThpUnallocatedSession), + "Failure_InvalidProtocol" => ::std::option::Option::Some(FailureType::Failure_InvalidProtocol), + "Failure_BufferError" => ::std::option::Option::Some(FailureType::Failure_BufferError), + "Failure_DeviceIsBusy" => ::std::option::Option::Some(FailureType::Failure_DeviceIsBusy), "Failure_FirmwareError" => ::std::option::Option::Some(FailureType::Failure_FirmwareError), _ => ::std::option::Option::None } @@ -482,6 +498,10 @@ pub mod failure { FailureType::Failure_PinMismatch, FailureType::Failure_WipeCodeMismatch, FailureType::Failure_InvalidSession, + FailureType::Failure_ThpUnallocatedSession, + FailureType::Failure_InvalidProtocol, + FailureType::Failure_BufferError, + FailureType::Failure_DeviceIsBusy, FailureType::Failure_FirmwareError, ]; } @@ -508,7 +528,11 @@ pub mod failure { FailureType::Failure_PinMismatch => 11, FailureType::Failure_WipeCodeMismatch => 12, FailureType::Failure_InvalidSession => 13, - FailureType::Failure_FirmwareError => 14, + FailureType::Failure_ThpUnallocatedSession => 14, + FailureType::Failure_InvalidProtocol => 15, + FailureType::Failure_BufferError => 16, + FailureType::Failure_DeviceIsBusy => 17, + FailureType::Failure_FirmwareError => 18, }; Self::enum_descriptor().value_by_index(index) } @@ -2481,9 +2505,9 @@ impl ::protobuf::reflect::ProtobufValue for HDNodeType { static file_descriptor_proto_data: &'static [u8] = b"\ \n\x15messages-common.proto\x12\x19hw.trezor.messages.common\x1a\roption\ s.proto\"%\n\x07Success\x12\x1a\n\x07message\x18\x01\x20\x01(\t:\0R\x07m\ - essage\"\x8f\x04\n\x07Failure\x12B\n\x04code\x18\x01\x20\x01(\x0e2..hw.t\ + essage\"\x82\x05\n\x07Failure\x12B\n\x04code\x18\x01\x20\x01(\x0e2..hw.t\ rezor.messages.common.Failure.FailureTypeR\x04code\x12\x18\n\x07message\ - \x18\x02\x20\x01(\tR\x07message\"\xa5\x03\n\x0bFailureType\x12\x1d\n\x19\ + \x18\x02\x20\x01(\tR\x07message\"\x98\x04\n\x0bFailureType\x12\x1d\n\x19\ Failure_UnexpectedMessage\x10\x01\x12\x1a\n\x16Failure_ButtonExpected\ \x10\x02\x12\x15\n\x11Failure_DataError\x10\x03\x12\x1b\n\x17Failure_Act\ ionCancelled\x10\x04\x12\x17\n\x13Failure_PinExpected\x10\x05\x12\x18\n\ @@ -2492,44 +2516,46 @@ static file_descriptor_proto_data: &'static [u8] = b"\ essError\x10\t\x12\x1a\n\x16Failure_NotEnoughFunds\x10\n\x12\x1a\n\x16Fa\ ilure_NotInitialized\x10\x0b\x12\x17\n\x13Failure_PinMismatch\x10\x0c\ \x12\x1c\n\x18Failure_WipeCodeMismatch\x10\r\x12\x1a\n\x16Failure_Invali\ - dSession\x10\x0e\x12\x19\n\x15Failure_FirmwareError\x10c\"\xab\x06\n\rBu\ - ttonRequest\x12N\n\x04code\x18\x01\x20\x01(\x0e2:.hw.trezor.messages.com\ - mon.ButtonRequest.ButtonRequestTypeR\x04code\x12\x14\n\x05pages\x18\x02\ - \x20\x01(\rR\x05pages\x12\x12\n\x04name\x18\x04\x20\x01(\tR\x04name\"\ - \x99\x05\n\x11ButtonRequestType\x12\x17\n\x13ButtonRequest_Other\x10\x01\ - \x12\"\n\x1eButtonRequest_FeeOverThreshold\x10\x02\x12\x1f\n\x1bButtonRe\ - quest_ConfirmOutput\x10\x03\x12\x1d\n\x19ButtonRequest_ResetDevice\x10\ - \x04\x12\x1d\n\x19ButtonRequest_ConfirmWord\x10\x05\x12\x1c\n\x18ButtonR\ - equest_WipeDevice\x10\x06\x12\x1d\n\x19ButtonRequest_ProtectCall\x10\x07\ - \x12\x18\n\x14ButtonRequest_SignTx\x10\x08\x12\x1f\n\x1bButtonRequest_Fi\ - rmwareCheck\x10\t\x12\x19\n\x15ButtonRequest_Address\x10\n\x12\x1b\n\x17\ - ButtonRequest_PublicKey\x10\x0b\x12#\n\x1fButtonRequest_MnemonicWordCoun\ - t\x10\x0c\x12\x1f\n\x1bButtonRequest_MnemonicInput\x10\r\x120\n(_Depreca\ - ted_ButtonRequest_PassphraseType\x10\x0e\x1a\x02\x08\x01\x12'\n#ButtonRe\ - quest_UnknownDerivationPath\x10\x0f\x12\"\n\x1eButtonRequest_RecoveryHom\ - epage\x10\x10\x12\x19\n\x15ButtonRequest_Success\x10\x11\x12\x19\n\x15Bu\ - ttonRequest_Warning\x10\x12\x12!\n\x1dButtonRequest_PassphraseEntry\x10\ - \x13\x12\x1a\n\x16ButtonRequest_PinEntry\x10\x14J\x04\x08\x03\x10\x04\"\ - \x0b\n\tButtonAck\"\xbb\x02\n\x10PinMatrixRequest\x12T\n\x04type\x18\x01\ - \x20\x01(\x0e2@.hw.trezor.messages.common.PinMatrixRequest.PinMatrixRequ\ - estTypeR\x04type\"\xd0\x01\n\x14PinMatrixRequestType\x12\x20\n\x1cPinMat\ - rixRequestType_Current\x10\x01\x12!\n\x1dPinMatrixRequestType_NewFirst\ - \x10\x02\x12\"\n\x1ePinMatrixRequestType_NewSecond\x10\x03\x12&\n\"PinMa\ - trixRequestType_WipeCodeFirst\x10\x04\x12'\n#PinMatrixRequestType_WipeCo\ - deSecond\x10\x05\"\x20\n\x0cPinMatrixAck\x12\x10\n\x03pin\x18\x01\x20\ - \x02(\tR\x03pin\"5\n\x11PassphraseRequest\x12\x20\n\n_on_device\x18\x01\ - \x20\x01(\x08R\x08OnDeviceB\x02\x18\x01\"g\n\rPassphraseAck\x12\x1e\n\np\ - assphrase\x18\x01\x20\x01(\tR\npassphrase\x12\x19\n\x06_state\x18\x02\ - \x20\x01(\x0cR\x05StateB\x02\x18\x01\x12\x1b\n\ton_device\x18\x03\x20\ - \x01(\x08R\x08onDevice\"=\n!Deprecated_PassphraseStateRequest\x12\x14\n\ - \x05state\x18\x01\x20\x01(\x0cR\x05state:\x02\x18\x01\"#\n\x1dDeprecated\ - _PassphraseStateAck:\x02\x18\x01\"\xc0\x01\n\nHDNodeType\x12\x14\n\x05de\ - pth\x18\x01\x20\x02(\rR\x05depth\x12\x20\n\x0bfingerprint\x18\x02\x20\ - \x02(\rR\x0bfingerprint\x12\x1b\n\tchild_num\x18\x03\x20\x02(\rR\x08chil\ - dNum\x12\x1d\n\nchain_code\x18\x04\x20\x02(\x0cR\tchainCode\x12\x1f\n\ - \x0bprivate_key\x18\x05\x20\x01(\x0cR\nprivateKey\x12\x1d\n\npublic_key\ - \x18\x06\x20\x02(\x0cR\tpublicKeyB>\n#com.satoshilabs.trezor.lib.protobu\ - fB\x13TrezorMessageCommon\x80\xa6\x1d\x01\ + dSession\x10\x0e\x12!\n\x1dFailure_ThpUnallocatedSession\x10\x0f\x12\x1b\ + \n\x17Failure_InvalidProtocol\x10\x10\x12\x17\n\x13Failure_BufferError\ + \x10\x11\x12\x18\n\x14Failure_DeviceIsBusy\x10\x12\x12\x19\n\x15Failure_\ + FirmwareError\x10c\"\xab\x06\n\rButtonRequest\x12N\n\x04code\x18\x01\x20\ + \x01(\x0e2:.hw.trezor.messages.common.ButtonRequest.ButtonRequestTypeR\ + \x04code\x12\x14\n\x05pages\x18\x02\x20\x01(\rR\x05pages\x12\x12\n\x04na\ + me\x18\x04\x20\x01(\tR\x04name\"\x99\x05\n\x11ButtonRequestType\x12\x17\ + \n\x13ButtonRequest_Other\x10\x01\x12\"\n\x1eButtonRequest_FeeOverThresh\ + old\x10\x02\x12\x1f\n\x1bButtonRequest_ConfirmOutput\x10\x03\x12\x1d\n\ + \x19ButtonRequest_ResetDevice\x10\x04\x12\x1d\n\x19ButtonRequest_Confirm\ + Word\x10\x05\x12\x1c\n\x18ButtonRequest_WipeDevice\x10\x06\x12\x1d\n\x19\ + ButtonRequest_ProtectCall\x10\x07\x12\x18\n\x14ButtonRequest_SignTx\x10\ + \x08\x12\x1f\n\x1bButtonRequest_FirmwareCheck\x10\t\x12\x19\n\x15ButtonR\ + equest_Address\x10\n\x12\x1b\n\x17ButtonRequest_PublicKey\x10\x0b\x12#\n\ + \x1fButtonRequest_MnemonicWordCount\x10\x0c\x12\x1f\n\x1bButtonRequest_M\ + nemonicInput\x10\r\x120\n(_Deprecated_ButtonRequest_PassphraseType\x10\ + \x0e\x1a\x02\x08\x01\x12'\n#ButtonRequest_UnknownDerivationPath\x10\x0f\ + \x12\"\n\x1eButtonRequest_RecoveryHomepage\x10\x10\x12\x19\n\x15ButtonRe\ + quest_Success\x10\x11\x12\x19\n\x15ButtonRequest_Warning\x10\x12\x12!\n\ + \x1dButtonRequest_PassphraseEntry\x10\x13\x12\x1a\n\x16ButtonRequest_Pin\ + Entry\x10\x14J\x04\x08\x03\x10\x04\"\x0b\n\tButtonAck\"\xbb\x02\n\x10Pin\ + MatrixRequest\x12T\n\x04type\x18\x01\x20\x01(\x0e2@.hw.trezor.messages.c\ + ommon.PinMatrixRequest.PinMatrixRequestTypeR\x04type\"\xd0\x01\n\x14PinM\ + atrixRequestType\x12\x20\n\x1cPinMatrixRequestType_Current\x10\x01\x12!\ + \n\x1dPinMatrixRequestType_NewFirst\x10\x02\x12\"\n\x1ePinMatrixRequestT\ + ype_NewSecond\x10\x03\x12&\n\"PinMatrixRequestType_WipeCodeFirst\x10\x04\ + \x12'\n#PinMatrixRequestType_WipeCodeSecond\x10\x05\"\x20\n\x0cPinMatrix\ + Ack\x12\x10\n\x03pin\x18\x01\x20\x02(\tR\x03pin\"5\n\x11PassphraseReques\ + t\x12\x20\n\n_on_device\x18\x01\x20\x01(\x08R\x08OnDeviceB\x02\x18\x01\"\ + g\n\rPassphraseAck\x12\x1e\n\npassphrase\x18\x01\x20\x01(\tR\npassphrase\ + \x12\x19\n\x06_state\x18\x02\x20\x01(\x0cR\x05StateB\x02\x18\x01\x12\x1b\ + \n\ton_device\x18\x03\x20\x01(\x08R\x08onDevice\"=\n!Deprecated_Passphra\ + seStateRequest\x12\x14\n\x05state\x18\x01\x20\x01(\x0cR\x05state:\x02\ + \x18\x01\"#\n\x1dDeprecated_PassphraseStateAck:\x02\x18\x01\"\xc0\x01\n\ + \nHDNodeType\x12\x14\n\x05depth\x18\x01\x20\x02(\rR\x05depth\x12\x20\n\ + \x0bfingerprint\x18\x02\x20\x02(\rR\x0bfingerprint\x12\x1b\n\tchild_num\ + \x18\x03\x20\x02(\rR\x08childNum\x12\x1d\n\nchain_code\x18\x04\x20\x02(\ + \x0cR\tchainCode\x12\x1f\n\x0bprivate_key\x18\x05\x20\x01(\x0cR\nprivate\ + Key\x12\x1d\n\npublic_key\x18\x06\x20\x02(\x0cR\tpublicKeyB>\n#com.satos\ + hilabs.trezor.lib.protobufB\x13TrezorMessageCommon\x80\xa6\x1d\x01\ "; /// `FileDescriptorProto` object which was a source for this generated file diff --git a/rust/trezor-client/src/protos/generated/messages_debug.rs b/rust/trezor-client/src/protos/generated/messages_debug.rs index 58328f52508..bf37c7f9314 100644 --- a/rust/trezor-client/src/protos/generated/messages_debug.rs +++ b/rust/trezor-client/src/protos/generated/messages_debug.rs @@ -2118,6 +2118,629 @@ impl ::protobuf::reflect::ProtobufValue for DebugLinkState { type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; } +// @@protoc_insertion_point(message:hw.trezor.messages.debug.DebugLinkGetPairingInfo) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct DebugLinkGetPairingInfo { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkGetPairingInfo.channel_id) + pub channel_id: ::std::option::Option<::std::vec::Vec>, + // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkGetPairingInfo.handshake_hash) + pub handshake_hash: ::std::option::Option<::std::vec::Vec>, + // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkGetPairingInfo.nfc_secret_host) + pub nfc_secret_host: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.debug.DebugLinkGetPairingInfo.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a DebugLinkGetPairingInfo { + fn default() -> &'a DebugLinkGetPairingInfo { + ::default_instance() + } +} + +impl DebugLinkGetPairingInfo { + pub fn new() -> DebugLinkGetPairingInfo { + ::std::default::Default::default() + } + + // optional bytes channel_id = 1; + + pub fn channel_id(&self) -> &[u8] { + match self.channel_id.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_channel_id(&mut self) { + self.channel_id = ::std::option::Option::None; + } + + pub fn has_channel_id(&self) -> bool { + self.channel_id.is_some() + } + + // Param is passed by value, moved + pub fn set_channel_id(&mut self, v: ::std::vec::Vec) { + self.channel_id = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_channel_id(&mut self) -> &mut ::std::vec::Vec { + if self.channel_id.is_none() { + self.channel_id = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.channel_id.as_mut().unwrap() + } + + // Take field + pub fn take_channel_id(&mut self) -> ::std::vec::Vec { + self.channel_id.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + // optional bytes handshake_hash = 2; + + pub fn handshake_hash(&self) -> &[u8] { + match self.handshake_hash.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_handshake_hash(&mut self) { + self.handshake_hash = ::std::option::Option::None; + } + + pub fn has_handshake_hash(&self) -> bool { + self.handshake_hash.is_some() + } + + // Param is passed by value, moved + pub fn set_handshake_hash(&mut self, v: ::std::vec::Vec) { + self.handshake_hash = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_handshake_hash(&mut self) -> &mut ::std::vec::Vec { + if self.handshake_hash.is_none() { + self.handshake_hash = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.handshake_hash.as_mut().unwrap() + } + + // Take field + pub fn take_handshake_hash(&mut self) -> ::std::vec::Vec { + self.handshake_hash.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + // optional bytes nfc_secret_host = 3; + + pub fn nfc_secret_host(&self) -> &[u8] { + match self.nfc_secret_host.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_nfc_secret_host(&mut self) { + self.nfc_secret_host = ::std::option::Option::None; + } + + pub fn has_nfc_secret_host(&self) -> bool { + self.nfc_secret_host.is_some() + } + + // Param is passed by value, moved + pub fn set_nfc_secret_host(&mut self, v: ::std::vec::Vec) { + self.nfc_secret_host = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_nfc_secret_host(&mut self) -> &mut ::std::vec::Vec { + if self.nfc_secret_host.is_none() { + self.nfc_secret_host = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.nfc_secret_host.as_mut().unwrap() + } + + // Take field + pub fn take_nfc_secret_host(&mut self) -> ::std::vec::Vec { + self.nfc_secret_host.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(3); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "channel_id", + |m: &DebugLinkGetPairingInfo| { &m.channel_id }, + |m: &mut DebugLinkGetPairingInfo| { &mut m.channel_id }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "handshake_hash", + |m: &DebugLinkGetPairingInfo| { &m.handshake_hash }, + |m: &mut DebugLinkGetPairingInfo| { &mut m.handshake_hash }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "nfc_secret_host", + |m: &DebugLinkGetPairingInfo| { &m.nfc_secret_host }, + |m: &mut DebugLinkGetPairingInfo| { &mut m.nfc_secret_host }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "DebugLinkGetPairingInfo", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for DebugLinkGetPairingInfo { + const NAME: &'static str = "DebugLinkGetPairingInfo"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.channel_id = ::std::option::Option::Some(is.read_bytes()?); + }, + 18 => { + self.handshake_hash = ::std::option::Option::Some(is.read_bytes()?); + }, + 26 => { + self.nfc_secret_host = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.channel_id.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + if let Some(v) = self.handshake_hash.as_ref() { + my_size += ::protobuf::rt::bytes_size(2, &v); + } + if let Some(v) = self.nfc_secret_host.as_ref() { + my_size += ::protobuf::rt::bytes_size(3, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.channel_id.as_ref() { + os.write_bytes(1, v)?; + } + if let Some(v) = self.handshake_hash.as_ref() { + os.write_bytes(2, v)?; + } + if let Some(v) = self.nfc_secret_host.as_ref() { + os.write_bytes(3, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> DebugLinkGetPairingInfo { + DebugLinkGetPairingInfo::new() + } + + fn clear(&mut self) { + self.channel_id = ::std::option::Option::None; + self.handshake_hash = ::std::option::Option::None; + self.nfc_secret_host = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static DebugLinkGetPairingInfo { + static instance: DebugLinkGetPairingInfo = DebugLinkGetPairingInfo { + channel_id: ::std::option::Option::None, + handshake_hash: ::std::option::Option::None, + nfc_secret_host: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for DebugLinkGetPairingInfo { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("DebugLinkGetPairingInfo").unwrap()).clone() + } +} + +impl ::std::fmt::Display for DebugLinkGetPairingInfo { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for DebugLinkGetPairingInfo { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.debug.DebugLinkPairingInfo) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct DebugLinkPairingInfo { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkPairingInfo.channel_id) + pub channel_id: ::std::option::Option<::std::vec::Vec>, + // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkPairingInfo.handshake_hash) + pub handshake_hash: ::std::option::Option<::std::vec::Vec>, + // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkPairingInfo.code_entry_code) + pub code_entry_code: ::std::option::Option, + // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkPairingInfo.code_qr_code) + pub code_qr_code: ::std::option::Option<::std::vec::Vec>, + // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkPairingInfo.nfc_secret_trezor) + pub nfc_secret_trezor: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.debug.DebugLinkPairingInfo.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a DebugLinkPairingInfo { + fn default() -> &'a DebugLinkPairingInfo { + ::default_instance() + } +} + +impl DebugLinkPairingInfo { + pub fn new() -> DebugLinkPairingInfo { + ::std::default::Default::default() + } + + // optional bytes channel_id = 1; + + pub fn channel_id(&self) -> &[u8] { + match self.channel_id.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_channel_id(&mut self) { + self.channel_id = ::std::option::Option::None; + } + + pub fn has_channel_id(&self) -> bool { + self.channel_id.is_some() + } + + // Param is passed by value, moved + pub fn set_channel_id(&mut self, v: ::std::vec::Vec) { + self.channel_id = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_channel_id(&mut self) -> &mut ::std::vec::Vec { + if self.channel_id.is_none() { + self.channel_id = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.channel_id.as_mut().unwrap() + } + + // Take field + pub fn take_channel_id(&mut self) -> ::std::vec::Vec { + self.channel_id.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + // optional bytes handshake_hash = 2; + + pub fn handshake_hash(&self) -> &[u8] { + match self.handshake_hash.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_handshake_hash(&mut self) { + self.handshake_hash = ::std::option::Option::None; + } + + pub fn has_handshake_hash(&self) -> bool { + self.handshake_hash.is_some() + } + + // Param is passed by value, moved + pub fn set_handshake_hash(&mut self, v: ::std::vec::Vec) { + self.handshake_hash = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_handshake_hash(&mut self) -> &mut ::std::vec::Vec { + if self.handshake_hash.is_none() { + self.handshake_hash = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.handshake_hash.as_mut().unwrap() + } + + // Take field + pub fn take_handshake_hash(&mut self) -> ::std::vec::Vec { + self.handshake_hash.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + // optional uint32 code_entry_code = 3; + + pub fn code_entry_code(&self) -> u32 { + self.code_entry_code.unwrap_or(0) + } + + pub fn clear_code_entry_code(&mut self) { + self.code_entry_code = ::std::option::Option::None; + } + + pub fn has_code_entry_code(&self) -> bool { + self.code_entry_code.is_some() + } + + // Param is passed by value, moved + pub fn set_code_entry_code(&mut self, v: u32) { + self.code_entry_code = ::std::option::Option::Some(v); + } + + // optional bytes code_qr_code = 4; + + pub fn code_qr_code(&self) -> &[u8] { + match self.code_qr_code.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_code_qr_code(&mut self) { + self.code_qr_code = ::std::option::Option::None; + } + + pub fn has_code_qr_code(&self) -> bool { + self.code_qr_code.is_some() + } + + // Param is passed by value, moved + pub fn set_code_qr_code(&mut self, v: ::std::vec::Vec) { + self.code_qr_code = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_code_qr_code(&mut self) -> &mut ::std::vec::Vec { + if self.code_qr_code.is_none() { + self.code_qr_code = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.code_qr_code.as_mut().unwrap() + } + + // Take field + pub fn take_code_qr_code(&mut self) -> ::std::vec::Vec { + self.code_qr_code.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + // optional bytes nfc_secret_trezor = 5; + + pub fn nfc_secret_trezor(&self) -> &[u8] { + match self.nfc_secret_trezor.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_nfc_secret_trezor(&mut self) { + self.nfc_secret_trezor = ::std::option::Option::None; + } + + pub fn has_nfc_secret_trezor(&self) -> bool { + self.nfc_secret_trezor.is_some() + } + + // Param is passed by value, moved + pub fn set_nfc_secret_trezor(&mut self, v: ::std::vec::Vec) { + self.nfc_secret_trezor = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_nfc_secret_trezor(&mut self) -> &mut ::std::vec::Vec { + if self.nfc_secret_trezor.is_none() { + self.nfc_secret_trezor = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.nfc_secret_trezor.as_mut().unwrap() + } + + // Take field + pub fn take_nfc_secret_trezor(&mut self) -> ::std::vec::Vec { + self.nfc_secret_trezor.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(5); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "channel_id", + |m: &DebugLinkPairingInfo| { &m.channel_id }, + |m: &mut DebugLinkPairingInfo| { &mut m.channel_id }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "handshake_hash", + |m: &DebugLinkPairingInfo| { &m.handshake_hash }, + |m: &mut DebugLinkPairingInfo| { &mut m.handshake_hash }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "code_entry_code", + |m: &DebugLinkPairingInfo| { &m.code_entry_code }, + |m: &mut DebugLinkPairingInfo| { &mut m.code_entry_code }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "code_qr_code", + |m: &DebugLinkPairingInfo| { &m.code_qr_code }, + |m: &mut DebugLinkPairingInfo| { &mut m.code_qr_code }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "nfc_secret_trezor", + |m: &DebugLinkPairingInfo| { &m.nfc_secret_trezor }, + |m: &mut DebugLinkPairingInfo| { &mut m.nfc_secret_trezor }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "DebugLinkPairingInfo", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for DebugLinkPairingInfo { + const NAME: &'static str = "DebugLinkPairingInfo"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.channel_id = ::std::option::Option::Some(is.read_bytes()?); + }, + 18 => { + self.handshake_hash = ::std::option::Option::Some(is.read_bytes()?); + }, + 24 => { + self.code_entry_code = ::std::option::Option::Some(is.read_uint32()?); + }, + 34 => { + self.code_qr_code = ::std::option::Option::Some(is.read_bytes()?); + }, + 42 => { + self.nfc_secret_trezor = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.channel_id.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + if let Some(v) = self.handshake_hash.as_ref() { + my_size += ::protobuf::rt::bytes_size(2, &v); + } + if let Some(v) = self.code_entry_code { + my_size += ::protobuf::rt::uint32_size(3, v); + } + if let Some(v) = self.code_qr_code.as_ref() { + my_size += ::protobuf::rt::bytes_size(4, &v); + } + if let Some(v) = self.nfc_secret_trezor.as_ref() { + my_size += ::protobuf::rt::bytes_size(5, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.channel_id.as_ref() { + os.write_bytes(1, v)?; + } + if let Some(v) = self.handshake_hash.as_ref() { + os.write_bytes(2, v)?; + } + if let Some(v) = self.code_entry_code { + os.write_uint32(3, v)?; + } + if let Some(v) = self.code_qr_code.as_ref() { + os.write_bytes(4, v)?; + } + if let Some(v) = self.nfc_secret_trezor.as_ref() { + os.write_bytes(5, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> DebugLinkPairingInfo { + DebugLinkPairingInfo::new() + } + + fn clear(&mut self) { + self.channel_id = ::std::option::Option::None; + self.handshake_hash = ::std::option::Option::None; + self.code_entry_code = ::std::option::Option::None; + self.code_qr_code = ::std::option::Option::None; + self.nfc_secret_trezor = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static DebugLinkPairingInfo { + static instance: DebugLinkPairingInfo = DebugLinkPairingInfo { + channel_id: ::std::option::Option::None, + handshake_hash: ::std::option::Option::None, + code_entry_code: ::std::option::Option::None, + code_qr_code: ::std::option::Option::None, + nfc_secret_trezor: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for DebugLinkPairingInfo { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("DebugLinkPairingInfo").unwrap()).clone() + } +} + +impl ::std::fmt::Display for DebugLinkPairingInfo { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for DebugLinkPairingInfo { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + // @@protoc_insertion_point(message:hw.trezor.messages.debug.DebugLinkStop) #[derive(PartialEq,Clone,Default,Debug)] pub struct DebugLinkStop { @@ -3707,20 +4330,28 @@ static file_descriptor_proto_data: &'static [u8] = b"\ dPos\x12$\n\x0ereset_word_pos\x18\x0b\x20\x01(\rR\x0cresetWordPos\x12N\n\ \rmnemonic_type\x18\x0c\x20\x01(\x0e2).hw.trezor.messages.management.Bac\ kupTypeR\x0cmnemonicType\x12\x16\n\x06tokens\x18\r\x20\x03(\tR\x06tokens\ - \"\x0f\n\rDebugLinkStop\"P\n\x0cDebugLinkLog\x12\x14\n\x05level\x18\x01\ - \x20\x01(\rR\x05level\x12\x16\n\x06bucket\x18\x02\x20\x01(\tR\x06bucket\ - \x12\x12\n\x04text\x18\x03\x20\x01(\tR\x04text\"G\n\x13DebugLinkMemoryRe\ - ad\x12\x18\n\x07address\x18\x01\x20\x01(\rR\x07address\x12\x16\n\x06leng\ - th\x18\x02\x20\x01(\rR\x06length\")\n\x0fDebugLinkMemory\x12\x16\n\x06me\ - mory\x18\x01\x20\x01(\x0cR\x06memory\"^\n\x14DebugLinkMemoryWrite\x12\ - \x18\n\x07address\x18\x01\x20\x01(\rR\x07address\x12\x16\n\x06memory\x18\ - \x02\x20\x01(\x0cR\x06memory\x12\x14\n\x05flash\x18\x03\x20\x01(\x08R\ - \x05flash\"-\n\x13DebugLinkFlashErase\x12\x16\n\x06sector\x18\x01\x20\ - \x01(\rR\x06sector\".\n\x14DebugLinkEraseSdCard\x12\x16\n\x06format\x18\ - \x01\x20\x01(\x08R\x06format\"0\n\x14DebugLinkWatchLayout\x12\x14\n\x05w\ - atch\x18\x01\x20\x01(\x08R\x05watch:\x02\x18\x01\"\x1f\n\x19DebugLinkRes\ - etDebugEvents:\x02\x18\x01\"\x1a\n\x18DebugLinkOptigaSetSecMaxB=\n#com.s\ - atoshilabs.trezor.lib.protobufB\x12TrezorMessageDebug\x80\xa6\x1d\x01\ + \"\x87\x01\n\x17DebugLinkGetPairingInfo\x12\x1d\n\nchannel_id\x18\x01\ + \x20\x01(\x0cR\tchannelId\x12%\n\x0ehandshake_hash\x18\x02\x20\x01(\x0cR\ + \rhandshakeHash\x12&\n\x0fnfc_secret_host\x18\x03\x20\x01(\x0cR\rnfcSecr\ + etHost\"\xd2\x01\n\x14DebugLinkPairingInfo\x12\x1d\n\nchannel_id\x18\x01\ + \x20\x01(\x0cR\tchannelId\x12%\n\x0ehandshake_hash\x18\x02\x20\x01(\x0cR\ + \rhandshakeHash\x12&\n\x0fcode_entry_code\x18\x03\x20\x01(\rR\rcodeEntry\ + Code\x12\x20\n\x0ccode_qr_code\x18\x04\x20\x01(\x0cR\ncodeQrCode\x12*\n\ + \x11nfc_secret_trezor\x18\x05\x20\x01(\x0cR\x0fnfcSecretTrezor\"\x0f\n\r\ + DebugLinkStop\"P\n\x0cDebugLinkLog\x12\x14\n\x05level\x18\x01\x20\x01(\r\ + R\x05level\x12\x16\n\x06bucket\x18\x02\x20\x01(\tR\x06bucket\x12\x12\n\ + \x04text\x18\x03\x20\x01(\tR\x04text\"G\n\x13DebugLinkMemoryRead\x12\x18\ + \n\x07address\x18\x01\x20\x01(\rR\x07address\x12\x16\n\x06length\x18\x02\ + \x20\x01(\rR\x06length\")\n\x0fDebugLinkMemory\x12\x16\n\x06memory\x18\ + \x01\x20\x01(\x0cR\x06memory\"^\n\x14DebugLinkMemoryWrite\x12\x18\n\x07a\ + ddress\x18\x01\x20\x01(\rR\x07address\x12\x16\n\x06memory\x18\x02\x20\ + \x01(\x0cR\x06memory\x12\x14\n\x05flash\x18\x03\x20\x01(\x08R\x05flash\"\ + -\n\x13DebugLinkFlashErase\x12\x16\n\x06sector\x18\x01\x20\x01(\rR\x06se\ + ctor\".\n\x14DebugLinkEraseSdCard\x12\x16\n\x06format\x18\x01\x20\x01(\ + \x08R\x06format\"0\n\x14DebugLinkWatchLayout\x12\x14\n\x05watch\x18\x01\ + \x20\x01(\x08R\x05watch:\x02\x18\x01\"\x1f\n\x19DebugLinkResetDebugEvent\ + s:\x02\x18\x01\"\x1a\n\x18DebugLinkOptigaSetSecMaxB=\n#com.satoshilabs.t\ + rezor.lib.protobufB\x12TrezorMessageDebug\x80\xa6\x1d\x01\ "; /// `FileDescriptorProto` object which was a source for this generated file @@ -3741,13 +4372,15 @@ pub fn file_descriptor() -> &'static ::protobuf::reflect::FileDescriptor { deps.push(super::messages_common::file_descriptor().clone()); deps.push(super::messages_management::file_descriptor().clone()); deps.push(super::options::file_descriptor().clone()); - let mut messages = ::std::vec::Vec::with_capacity(16); + let mut messages = ::std::vec::Vec::with_capacity(18); messages.push(DebugLinkDecision::generated_message_descriptor_data()); messages.push(DebugLinkLayout::generated_message_descriptor_data()); messages.push(DebugLinkReseedRandom::generated_message_descriptor_data()); messages.push(DebugLinkRecordScreen::generated_message_descriptor_data()); messages.push(DebugLinkGetState::generated_message_descriptor_data()); messages.push(DebugLinkState::generated_message_descriptor_data()); + messages.push(DebugLinkGetPairingInfo::generated_message_descriptor_data()); + messages.push(DebugLinkPairingInfo::generated_message_descriptor_data()); messages.push(DebugLinkStop::generated_message_descriptor_data()); messages.push(DebugLinkLog::generated_message_descriptor_data()); messages.push(DebugLinkMemoryRead::generated_message_descriptor_data()); diff --git a/rust/trezor-client/src/protos/generated/messages_thp.rs b/rust/trezor-client/src/protos/generated/messages_thp.rs index 9e0d8e8aea8..bd04e4d9b05 100644 --- a/rust/trezor-client/src/protos/generated/messages_thp.rs +++ b/rust/trezor-client/src/protos/generated/messages_thp.rs @@ -25,12 +25,3291 @@ /// of protobuf runtime. const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_3_3_0; +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpDeviceProperties) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpDeviceProperties { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpDeviceProperties.internal_model) + pub internal_model: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpDeviceProperties.model_variant) + pub model_variant: ::std::option::Option, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpDeviceProperties.protocol_version_major) + pub protocol_version_major: ::std::option::Option, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpDeviceProperties.protocol_version_minor) + pub protocol_version_minor: ::std::option::Option, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpDeviceProperties.pairing_methods) + pub pairing_methods: ::std::vec::Vec<::protobuf::EnumOrUnknown>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpDeviceProperties.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpDeviceProperties { + fn default() -> &'a ThpDeviceProperties { + ::default_instance() + } +} + +impl ThpDeviceProperties { + pub fn new() -> ThpDeviceProperties { + ::std::default::Default::default() + } + + // optional string internal_model = 1; + + pub fn internal_model(&self) -> &str { + match self.internal_model.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_internal_model(&mut self) { + self.internal_model = ::std::option::Option::None; + } + + pub fn has_internal_model(&self) -> bool { + self.internal_model.is_some() + } + + // Param is passed by value, moved + pub fn set_internal_model(&mut self, v: ::std::string::String) { + self.internal_model = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_internal_model(&mut self) -> &mut ::std::string::String { + if self.internal_model.is_none() { + self.internal_model = ::std::option::Option::Some(::std::string::String::new()); + } + self.internal_model.as_mut().unwrap() + } + + // Take field + pub fn take_internal_model(&mut self) -> ::std::string::String { + self.internal_model.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional uint32 model_variant = 2; + + pub fn model_variant(&self) -> u32 { + self.model_variant.unwrap_or(0) + } + + pub fn clear_model_variant(&mut self) { + self.model_variant = ::std::option::Option::None; + } + + pub fn has_model_variant(&self) -> bool { + self.model_variant.is_some() + } + + // Param is passed by value, moved + pub fn set_model_variant(&mut self, v: u32) { + self.model_variant = ::std::option::Option::Some(v); + } + + // optional uint32 protocol_version_major = 3; + + pub fn protocol_version_major(&self) -> u32 { + self.protocol_version_major.unwrap_or(0) + } + + pub fn clear_protocol_version_major(&mut self) { + self.protocol_version_major = ::std::option::Option::None; + } + + pub fn has_protocol_version_major(&self) -> bool { + self.protocol_version_major.is_some() + } + + // Param is passed by value, moved + pub fn set_protocol_version_major(&mut self, v: u32) { + self.protocol_version_major = ::std::option::Option::Some(v); + } + + // optional uint32 protocol_version_minor = 4; + + pub fn protocol_version_minor(&self) -> u32 { + self.protocol_version_minor.unwrap_or(0) + } + + pub fn clear_protocol_version_minor(&mut self) { + self.protocol_version_minor = ::std::option::Option::None; + } + + pub fn has_protocol_version_minor(&self) -> bool { + self.protocol_version_minor.is_some() + } + + // Param is passed by value, moved + pub fn set_protocol_version_minor(&mut self, v: u32) { + self.protocol_version_minor = ::std::option::Option::Some(v); + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(5); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "internal_model", + |m: &ThpDeviceProperties| { &m.internal_model }, + |m: &mut ThpDeviceProperties| { &mut m.internal_model }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "model_variant", + |m: &ThpDeviceProperties| { &m.model_variant }, + |m: &mut ThpDeviceProperties| { &mut m.model_variant }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "protocol_version_major", + |m: &ThpDeviceProperties| { &m.protocol_version_major }, + |m: &mut ThpDeviceProperties| { &mut m.protocol_version_major }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "protocol_version_minor", + |m: &ThpDeviceProperties| { &m.protocol_version_minor }, + |m: &mut ThpDeviceProperties| { &mut m.protocol_version_minor }, + )); + fields.push(::protobuf::reflect::rt::v2::make_vec_simpler_accessor::<_, _>( + "pairing_methods", + |m: &ThpDeviceProperties| { &m.pairing_methods }, + |m: &mut ThpDeviceProperties| { &mut m.pairing_methods }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpDeviceProperties", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpDeviceProperties { + const NAME: &'static str = "ThpDeviceProperties"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.internal_model = ::std::option::Option::Some(is.read_string()?); + }, + 16 => { + self.model_variant = ::std::option::Option::Some(is.read_uint32()?); + }, + 24 => { + self.protocol_version_major = ::std::option::Option::Some(is.read_uint32()?); + }, + 32 => { + self.protocol_version_minor = ::std::option::Option::Some(is.read_uint32()?); + }, + 40 => { + self.pairing_methods.push(is.read_enum_or_unknown()?); + }, + 42 => { + ::protobuf::rt::read_repeated_packed_enum_or_unknown_into(is, &mut self.pairing_methods)? + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.internal_model.as_ref() { + my_size += ::protobuf::rt::string_size(1, &v); + } + if let Some(v) = self.model_variant { + my_size += ::protobuf::rt::uint32_size(2, v); + } + if let Some(v) = self.protocol_version_major { + my_size += ::protobuf::rt::uint32_size(3, v); + } + if let Some(v) = self.protocol_version_minor { + my_size += ::protobuf::rt::uint32_size(4, v); + } + for value in &self.pairing_methods { + my_size += ::protobuf::rt::int32_size(5, value.value()); + }; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.internal_model.as_ref() { + os.write_string(1, v)?; + } + if let Some(v) = self.model_variant { + os.write_uint32(2, v)?; + } + if let Some(v) = self.protocol_version_major { + os.write_uint32(3, v)?; + } + if let Some(v) = self.protocol_version_minor { + os.write_uint32(4, v)?; + } + for v in &self.pairing_methods { + os.write_enum(5, ::protobuf::EnumOrUnknown::value(v))?; + }; + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpDeviceProperties { + ThpDeviceProperties::new() + } + + fn clear(&mut self) { + self.internal_model = ::std::option::Option::None; + self.model_variant = ::std::option::Option::None; + self.protocol_version_major = ::std::option::Option::None; + self.protocol_version_minor = ::std::option::Option::None; + self.pairing_methods.clear(); + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpDeviceProperties { + static instance: ThpDeviceProperties = ThpDeviceProperties { + internal_model: ::std::option::Option::None, + model_variant: ::std::option::Option::None, + protocol_version_major: ::std::option::Option::None, + protocol_version_minor: ::std::option::Option::None, + pairing_methods: ::std::vec::Vec::new(), + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpDeviceProperties { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpDeviceProperties").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpDeviceProperties { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpDeviceProperties { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpHandshakeCompletionReqNoisePayload) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpHandshakeCompletionReqNoisePayload { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpHandshakeCompletionReqNoisePayload.host_pairing_credential) + pub host_pairing_credential: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpHandshakeCompletionReqNoisePayload.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpHandshakeCompletionReqNoisePayload { + fn default() -> &'a ThpHandshakeCompletionReqNoisePayload { + ::default_instance() + } +} + +impl ThpHandshakeCompletionReqNoisePayload { + pub fn new() -> ThpHandshakeCompletionReqNoisePayload { + ::std::default::Default::default() + } + + // optional bytes host_pairing_credential = 1; + + pub fn host_pairing_credential(&self) -> &[u8] { + match self.host_pairing_credential.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_host_pairing_credential(&mut self) { + self.host_pairing_credential = ::std::option::Option::None; + } + + pub fn has_host_pairing_credential(&self) -> bool { + self.host_pairing_credential.is_some() + } + + // Param is passed by value, moved + pub fn set_host_pairing_credential(&mut self, v: ::std::vec::Vec) { + self.host_pairing_credential = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_host_pairing_credential(&mut self) -> &mut ::std::vec::Vec { + if self.host_pairing_credential.is_none() { + self.host_pairing_credential = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.host_pairing_credential.as_mut().unwrap() + } + + // Take field + pub fn take_host_pairing_credential(&mut self) -> ::std::vec::Vec { + self.host_pairing_credential.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "host_pairing_credential", + |m: &ThpHandshakeCompletionReqNoisePayload| { &m.host_pairing_credential }, + |m: &mut ThpHandshakeCompletionReqNoisePayload| { &mut m.host_pairing_credential }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpHandshakeCompletionReqNoisePayload", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpHandshakeCompletionReqNoisePayload { + const NAME: &'static str = "ThpHandshakeCompletionReqNoisePayload"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.host_pairing_credential = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.host_pairing_credential.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.host_pairing_credential.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpHandshakeCompletionReqNoisePayload { + ThpHandshakeCompletionReqNoisePayload::new() + } + + fn clear(&mut self) { + self.host_pairing_credential = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpHandshakeCompletionReqNoisePayload { + static instance: ThpHandshakeCompletionReqNoisePayload = ThpHandshakeCompletionReqNoisePayload { + host_pairing_credential: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpHandshakeCompletionReqNoisePayload { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpHandshakeCompletionReqNoisePayload").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpHandshakeCompletionReqNoisePayload { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpHandshakeCompletionReqNoisePayload { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCreateNewSession) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCreateNewSession { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCreateNewSession.passphrase) + pub passphrase: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCreateNewSession.on_device) + pub on_device: ::std::option::Option, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCreateNewSession.derive_cardano) + pub derive_cardano: ::std::option::Option, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCreateNewSession.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCreateNewSession { + fn default() -> &'a ThpCreateNewSession { + ::default_instance() + } +} + +impl ThpCreateNewSession { + pub fn new() -> ThpCreateNewSession { + ::std::default::Default::default() + } + + // optional string passphrase = 1; + + pub fn passphrase(&self) -> &str { + match self.passphrase.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_passphrase(&mut self) { + self.passphrase = ::std::option::Option::None; + } + + pub fn has_passphrase(&self) -> bool { + self.passphrase.is_some() + } + + // Param is passed by value, moved + pub fn set_passphrase(&mut self, v: ::std::string::String) { + self.passphrase = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_passphrase(&mut self) -> &mut ::std::string::String { + if self.passphrase.is_none() { + self.passphrase = ::std::option::Option::Some(::std::string::String::new()); + } + self.passphrase.as_mut().unwrap() + } + + // Take field + pub fn take_passphrase(&mut self) -> ::std::string::String { + self.passphrase.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional bool on_device = 2; + + pub fn on_device(&self) -> bool { + self.on_device.unwrap_or(false) + } + + pub fn clear_on_device(&mut self) { + self.on_device = ::std::option::Option::None; + } + + pub fn has_on_device(&self) -> bool { + self.on_device.is_some() + } + + // Param is passed by value, moved + pub fn set_on_device(&mut self, v: bool) { + self.on_device = ::std::option::Option::Some(v); + } + + // optional bool derive_cardano = 3; + + pub fn derive_cardano(&self) -> bool { + self.derive_cardano.unwrap_or(false) + } + + pub fn clear_derive_cardano(&mut self) { + self.derive_cardano = ::std::option::Option::None; + } + + pub fn has_derive_cardano(&self) -> bool { + self.derive_cardano.is_some() + } + + // Param is passed by value, moved + pub fn set_derive_cardano(&mut self, v: bool) { + self.derive_cardano = ::std::option::Option::Some(v); + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(3); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "passphrase", + |m: &ThpCreateNewSession| { &m.passphrase }, + |m: &mut ThpCreateNewSession| { &mut m.passphrase }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "on_device", + |m: &ThpCreateNewSession| { &m.on_device }, + |m: &mut ThpCreateNewSession| { &mut m.on_device }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "derive_cardano", + |m: &ThpCreateNewSession| { &m.derive_cardano }, + |m: &mut ThpCreateNewSession| { &mut m.derive_cardano }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCreateNewSession", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCreateNewSession { + const NAME: &'static str = "ThpCreateNewSession"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.passphrase = ::std::option::Option::Some(is.read_string()?); + }, + 16 => { + self.on_device = ::std::option::Option::Some(is.read_bool()?); + }, + 24 => { + self.derive_cardano = ::std::option::Option::Some(is.read_bool()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.passphrase.as_ref() { + my_size += ::protobuf::rt::string_size(1, &v); + } + if let Some(v) = self.on_device { + my_size += 1 + 1; + } + if let Some(v) = self.derive_cardano { + my_size += 1 + 1; + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.passphrase.as_ref() { + os.write_string(1, v)?; + } + if let Some(v) = self.on_device { + os.write_bool(2, v)?; + } + if let Some(v) = self.derive_cardano { + os.write_bool(3, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCreateNewSession { + ThpCreateNewSession::new() + } + + fn clear(&mut self) { + self.passphrase = ::std::option::Option::None; + self.on_device = ::std::option::Option::None; + self.derive_cardano = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCreateNewSession { + static instance: ThpCreateNewSession = ThpCreateNewSession { + passphrase: ::std::option::Option::None, + on_device: ::std::option::Option::None, + derive_cardano: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCreateNewSession { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCreateNewSession").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCreateNewSession { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCreateNewSession { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpPairingRequest) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpPairingRequest { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpPairingRequest.host_name) + pub host_name: ::std::option::Option<::std::string::String>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpPairingRequest.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpPairingRequest { + fn default() -> &'a ThpPairingRequest { + ::default_instance() + } +} + +impl ThpPairingRequest { + pub fn new() -> ThpPairingRequest { + ::std::default::Default::default() + } + + // optional string host_name = 1; + + pub fn host_name(&self) -> &str { + match self.host_name.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_host_name(&mut self) { + self.host_name = ::std::option::Option::None; + } + + pub fn has_host_name(&self) -> bool { + self.host_name.is_some() + } + + // Param is passed by value, moved + pub fn set_host_name(&mut self, v: ::std::string::String) { + self.host_name = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_host_name(&mut self) -> &mut ::std::string::String { + if self.host_name.is_none() { + self.host_name = ::std::option::Option::Some(::std::string::String::new()); + } + self.host_name.as_mut().unwrap() + } + + // Take field + pub fn take_host_name(&mut self) -> ::std::string::String { + self.host_name.take().unwrap_or_else(|| ::std::string::String::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "host_name", + |m: &ThpPairingRequest| { &m.host_name }, + |m: &mut ThpPairingRequest| { &mut m.host_name }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpPairingRequest", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpPairingRequest { + const NAME: &'static str = "ThpPairingRequest"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.host_name = ::std::option::Option::Some(is.read_string()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.host_name.as_ref() { + my_size += ::protobuf::rt::string_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.host_name.as_ref() { + os.write_string(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpPairingRequest { + ThpPairingRequest::new() + } + + fn clear(&mut self) { + self.host_name = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpPairingRequest { + static instance: ThpPairingRequest = ThpPairingRequest { + host_name: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpPairingRequest { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpPairingRequest").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpPairingRequest { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpPairingRequest { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpPairingRequestApproved) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpPairingRequestApproved { + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpPairingRequestApproved.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpPairingRequestApproved { + fn default() -> &'a ThpPairingRequestApproved { + ::default_instance() + } +} + +impl ThpPairingRequestApproved { + pub fn new() -> ThpPairingRequestApproved { + ::std::default::Default::default() + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(0); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpPairingRequestApproved", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpPairingRequestApproved { + const NAME: &'static str = "ThpPairingRequestApproved"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpPairingRequestApproved { + ThpPairingRequestApproved::new() + } + + fn clear(&mut self) { + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpPairingRequestApproved { + static instance: ThpPairingRequestApproved = ThpPairingRequestApproved { + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpPairingRequestApproved { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpPairingRequestApproved").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpPairingRequestApproved { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpPairingRequestApproved { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpSelectMethod) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpSelectMethod { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpSelectMethod.selected_pairing_method) + pub selected_pairing_method: ::std::option::Option<::protobuf::EnumOrUnknown>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpSelectMethod.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpSelectMethod { + fn default() -> &'a ThpSelectMethod { + ::default_instance() + } +} + +impl ThpSelectMethod { + pub fn new() -> ThpSelectMethod { + ::std::default::Default::default() + } + + // optional .hw.trezor.messages.thp.ThpPairingMethod selected_pairing_method = 1; + + pub fn selected_pairing_method(&self) -> ThpPairingMethod { + match self.selected_pairing_method { + Some(e) => e.enum_value_or(ThpPairingMethod::SkipPairing), + None => ThpPairingMethod::SkipPairing, + } + } + + pub fn clear_selected_pairing_method(&mut self) { + self.selected_pairing_method = ::std::option::Option::None; + } + + pub fn has_selected_pairing_method(&self) -> bool { + self.selected_pairing_method.is_some() + } + + // Param is passed by value, moved + pub fn set_selected_pairing_method(&mut self, v: ThpPairingMethod) { + self.selected_pairing_method = ::std::option::Option::Some(::protobuf::EnumOrUnknown::new(v)); + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "selected_pairing_method", + |m: &ThpSelectMethod| { &m.selected_pairing_method }, + |m: &mut ThpSelectMethod| { &mut m.selected_pairing_method }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpSelectMethod", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpSelectMethod { + const NAME: &'static str = "ThpSelectMethod"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 8 => { + self.selected_pairing_method = ::std::option::Option::Some(is.read_enum_or_unknown()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.selected_pairing_method { + my_size += ::protobuf::rt::int32_size(1, v.value()); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.selected_pairing_method { + os.write_enum(1, ::protobuf::EnumOrUnknown::value(&v))?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpSelectMethod { + ThpSelectMethod::new() + } + + fn clear(&mut self) { + self.selected_pairing_method = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpSelectMethod { + static instance: ThpSelectMethod = ThpSelectMethod { + selected_pairing_method: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpSelectMethod { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpSelectMethod").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpSelectMethod { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpSelectMethod { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpPairingPreparationsFinished) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpPairingPreparationsFinished { + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpPairingPreparationsFinished.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpPairingPreparationsFinished { + fn default() -> &'a ThpPairingPreparationsFinished { + ::default_instance() + } +} + +impl ThpPairingPreparationsFinished { + pub fn new() -> ThpPairingPreparationsFinished { + ::std::default::Default::default() + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(0); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpPairingPreparationsFinished", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpPairingPreparationsFinished { + const NAME: &'static str = "ThpPairingPreparationsFinished"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpPairingPreparationsFinished { + ThpPairingPreparationsFinished::new() + } + + fn clear(&mut self) { + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpPairingPreparationsFinished { + static instance: ThpPairingPreparationsFinished = ThpPairingPreparationsFinished { + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpPairingPreparationsFinished { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpPairingPreparationsFinished").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpPairingPreparationsFinished { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpPairingPreparationsFinished { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntryCommitment) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntryCommitment { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntryCommitment.commitment) + pub commitment: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntryCommitment.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntryCommitment { + fn default() -> &'a ThpCodeEntryCommitment { + ::default_instance() + } +} + +impl ThpCodeEntryCommitment { + pub fn new() -> ThpCodeEntryCommitment { + ::std::default::Default::default() + } + + // optional bytes commitment = 1; + + pub fn commitment(&self) -> &[u8] { + match self.commitment.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_commitment(&mut self) { + self.commitment = ::std::option::Option::None; + } + + pub fn has_commitment(&self) -> bool { + self.commitment.is_some() + } + + // Param is passed by value, moved + pub fn set_commitment(&mut self, v: ::std::vec::Vec) { + self.commitment = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_commitment(&mut self) -> &mut ::std::vec::Vec { + if self.commitment.is_none() { + self.commitment = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.commitment.as_mut().unwrap() + } + + // Take field + pub fn take_commitment(&mut self) -> ::std::vec::Vec { + self.commitment.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "commitment", + |m: &ThpCodeEntryCommitment| { &m.commitment }, + |m: &mut ThpCodeEntryCommitment| { &mut m.commitment }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntryCommitment", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntryCommitment { + const NAME: &'static str = "ThpCodeEntryCommitment"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.commitment = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.commitment.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.commitment.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntryCommitment { + ThpCodeEntryCommitment::new() + } + + fn clear(&mut self) { + self.commitment = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntryCommitment { + static instance: ThpCodeEntryCommitment = ThpCodeEntryCommitment { + commitment: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntryCommitment { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntryCommitment").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntryCommitment { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntryCommitment { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntryChallenge) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntryChallenge { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntryChallenge.challenge) + pub challenge: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntryChallenge.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntryChallenge { + fn default() -> &'a ThpCodeEntryChallenge { + ::default_instance() + } +} + +impl ThpCodeEntryChallenge { + pub fn new() -> ThpCodeEntryChallenge { + ::std::default::Default::default() + } + + // optional bytes challenge = 1; + + pub fn challenge(&self) -> &[u8] { + match self.challenge.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_challenge(&mut self) { + self.challenge = ::std::option::Option::None; + } + + pub fn has_challenge(&self) -> bool { + self.challenge.is_some() + } + + // Param is passed by value, moved + pub fn set_challenge(&mut self, v: ::std::vec::Vec) { + self.challenge = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_challenge(&mut self) -> &mut ::std::vec::Vec { + if self.challenge.is_none() { + self.challenge = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.challenge.as_mut().unwrap() + } + + // Take field + pub fn take_challenge(&mut self) -> ::std::vec::Vec { + self.challenge.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "challenge", + |m: &ThpCodeEntryChallenge| { &m.challenge }, + |m: &mut ThpCodeEntryChallenge| { &mut m.challenge }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntryChallenge", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntryChallenge { + const NAME: &'static str = "ThpCodeEntryChallenge"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.challenge = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.challenge.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.challenge.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntryChallenge { + ThpCodeEntryChallenge::new() + } + + fn clear(&mut self) { + self.challenge = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntryChallenge { + static instance: ThpCodeEntryChallenge = ThpCodeEntryChallenge { + challenge: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntryChallenge { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntryChallenge").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntryChallenge { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntryChallenge { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntryCpaceTrezor) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntryCpaceTrezor { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntryCpaceTrezor.cpace_trezor_public_key) + pub cpace_trezor_public_key: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntryCpaceTrezor.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntryCpaceTrezor { + fn default() -> &'a ThpCodeEntryCpaceTrezor { + ::default_instance() + } +} + +impl ThpCodeEntryCpaceTrezor { + pub fn new() -> ThpCodeEntryCpaceTrezor { + ::std::default::Default::default() + } + + // optional bytes cpace_trezor_public_key = 1; + + pub fn cpace_trezor_public_key(&self) -> &[u8] { + match self.cpace_trezor_public_key.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_cpace_trezor_public_key(&mut self) { + self.cpace_trezor_public_key = ::std::option::Option::None; + } + + pub fn has_cpace_trezor_public_key(&self) -> bool { + self.cpace_trezor_public_key.is_some() + } + + // Param is passed by value, moved + pub fn set_cpace_trezor_public_key(&mut self, v: ::std::vec::Vec) { + self.cpace_trezor_public_key = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_cpace_trezor_public_key(&mut self) -> &mut ::std::vec::Vec { + if self.cpace_trezor_public_key.is_none() { + self.cpace_trezor_public_key = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.cpace_trezor_public_key.as_mut().unwrap() + } + + // Take field + pub fn take_cpace_trezor_public_key(&mut self) -> ::std::vec::Vec { + self.cpace_trezor_public_key.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "cpace_trezor_public_key", + |m: &ThpCodeEntryCpaceTrezor| { &m.cpace_trezor_public_key }, + |m: &mut ThpCodeEntryCpaceTrezor| { &mut m.cpace_trezor_public_key }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntryCpaceTrezor", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntryCpaceTrezor { + const NAME: &'static str = "ThpCodeEntryCpaceTrezor"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.cpace_trezor_public_key = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.cpace_trezor_public_key.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.cpace_trezor_public_key.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntryCpaceTrezor { + ThpCodeEntryCpaceTrezor::new() + } + + fn clear(&mut self) { + self.cpace_trezor_public_key = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntryCpaceTrezor { + static instance: ThpCodeEntryCpaceTrezor = ThpCodeEntryCpaceTrezor { + cpace_trezor_public_key: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntryCpaceTrezor { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntryCpaceTrezor").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntryCpaceTrezor { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntryCpaceTrezor { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntryCpaceHostTag) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntryCpaceHostTag { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntryCpaceHostTag.cpace_host_public_key) + pub cpace_host_public_key: ::std::option::Option<::std::vec::Vec>, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntryCpaceHostTag.tag) + pub tag: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntryCpaceHostTag.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntryCpaceHostTag { + fn default() -> &'a ThpCodeEntryCpaceHostTag { + ::default_instance() + } +} + +impl ThpCodeEntryCpaceHostTag { + pub fn new() -> ThpCodeEntryCpaceHostTag { + ::std::default::Default::default() + } + + // optional bytes cpace_host_public_key = 1; + + pub fn cpace_host_public_key(&self) -> &[u8] { + match self.cpace_host_public_key.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_cpace_host_public_key(&mut self) { + self.cpace_host_public_key = ::std::option::Option::None; + } + + pub fn has_cpace_host_public_key(&self) -> bool { + self.cpace_host_public_key.is_some() + } + + // Param is passed by value, moved + pub fn set_cpace_host_public_key(&mut self, v: ::std::vec::Vec) { + self.cpace_host_public_key = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_cpace_host_public_key(&mut self) -> &mut ::std::vec::Vec { + if self.cpace_host_public_key.is_none() { + self.cpace_host_public_key = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.cpace_host_public_key.as_mut().unwrap() + } + + // Take field + pub fn take_cpace_host_public_key(&mut self) -> ::std::vec::Vec { + self.cpace_host_public_key.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + // optional bytes tag = 2; + + pub fn tag(&self) -> &[u8] { + match self.tag.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_tag(&mut self) { + self.tag = ::std::option::Option::None; + } + + pub fn has_tag(&self) -> bool { + self.tag.is_some() + } + + // Param is passed by value, moved + pub fn set_tag(&mut self, v: ::std::vec::Vec) { + self.tag = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_tag(&mut self) -> &mut ::std::vec::Vec { + if self.tag.is_none() { + self.tag = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.tag.as_mut().unwrap() + } + + // Take field + pub fn take_tag(&mut self) -> ::std::vec::Vec { + self.tag.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(2); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "cpace_host_public_key", + |m: &ThpCodeEntryCpaceHostTag| { &m.cpace_host_public_key }, + |m: &mut ThpCodeEntryCpaceHostTag| { &mut m.cpace_host_public_key }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "tag", + |m: &ThpCodeEntryCpaceHostTag| { &m.tag }, + |m: &mut ThpCodeEntryCpaceHostTag| { &mut m.tag }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntryCpaceHostTag", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntryCpaceHostTag { + const NAME: &'static str = "ThpCodeEntryCpaceHostTag"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.cpace_host_public_key = ::std::option::Option::Some(is.read_bytes()?); + }, + 18 => { + self.tag = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.cpace_host_public_key.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + if let Some(v) = self.tag.as_ref() { + my_size += ::protobuf::rt::bytes_size(2, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.cpace_host_public_key.as_ref() { + os.write_bytes(1, v)?; + } + if let Some(v) = self.tag.as_ref() { + os.write_bytes(2, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntryCpaceHostTag { + ThpCodeEntryCpaceHostTag::new() + } + + fn clear(&mut self) { + self.cpace_host_public_key = ::std::option::Option::None; + self.tag = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntryCpaceHostTag { + static instance: ThpCodeEntryCpaceHostTag = ThpCodeEntryCpaceHostTag { + cpace_host_public_key: ::std::option::Option::None, + tag: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntryCpaceHostTag { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntryCpaceHostTag").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntryCpaceHostTag { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntryCpaceHostTag { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntrySecret) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntrySecret { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntrySecret.secret) + pub secret: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntrySecret.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntrySecret { + fn default() -> &'a ThpCodeEntrySecret { + ::default_instance() + } +} + +impl ThpCodeEntrySecret { + pub fn new() -> ThpCodeEntrySecret { + ::std::default::Default::default() + } + + // optional bytes secret = 1; + + pub fn secret(&self) -> &[u8] { + match self.secret.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_secret(&mut self) { + self.secret = ::std::option::Option::None; + } + + pub fn has_secret(&self) -> bool { + self.secret.is_some() + } + + // Param is passed by value, moved + pub fn set_secret(&mut self, v: ::std::vec::Vec) { + self.secret = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_secret(&mut self) -> &mut ::std::vec::Vec { + if self.secret.is_none() { + self.secret = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.secret.as_mut().unwrap() + } + + // Take field + pub fn take_secret(&mut self) -> ::std::vec::Vec { + self.secret.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "secret", + |m: &ThpCodeEntrySecret| { &m.secret }, + |m: &mut ThpCodeEntrySecret| { &mut m.secret }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntrySecret", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntrySecret { + const NAME: &'static str = "ThpCodeEntrySecret"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.secret = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.secret.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.secret.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntrySecret { + ThpCodeEntrySecret::new() + } + + fn clear(&mut self) { + self.secret = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntrySecret { + static instance: ThpCodeEntrySecret = ThpCodeEntrySecret { + secret: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntrySecret { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntrySecret").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntrySecret { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntrySecret { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpQrCodeTag) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpQrCodeTag { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpQrCodeTag.tag) + pub tag: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpQrCodeTag.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpQrCodeTag { + fn default() -> &'a ThpQrCodeTag { + ::default_instance() + } +} + +impl ThpQrCodeTag { + pub fn new() -> ThpQrCodeTag { + ::std::default::Default::default() + } + + // optional bytes tag = 1; + + pub fn tag(&self) -> &[u8] { + match self.tag.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_tag(&mut self) { + self.tag = ::std::option::Option::None; + } + + pub fn has_tag(&self) -> bool { + self.tag.is_some() + } + + // Param is passed by value, moved + pub fn set_tag(&mut self, v: ::std::vec::Vec) { + self.tag = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_tag(&mut self) -> &mut ::std::vec::Vec { + if self.tag.is_none() { + self.tag = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.tag.as_mut().unwrap() + } + + // Take field + pub fn take_tag(&mut self) -> ::std::vec::Vec { + self.tag.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "tag", + |m: &ThpQrCodeTag| { &m.tag }, + |m: &mut ThpQrCodeTag| { &mut m.tag }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpQrCodeTag", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpQrCodeTag { + const NAME: &'static str = "ThpQrCodeTag"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.tag = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.tag.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.tag.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpQrCodeTag { + ThpQrCodeTag::new() + } + + fn clear(&mut self) { + self.tag = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpQrCodeTag { + static instance: ThpQrCodeTag = ThpQrCodeTag { + tag: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpQrCodeTag { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpQrCodeTag").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpQrCodeTag { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpQrCodeTag { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpQrCodeSecret) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpQrCodeSecret { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpQrCodeSecret.secret) + pub secret: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpQrCodeSecret.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpQrCodeSecret { + fn default() -> &'a ThpQrCodeSecret { + ::default_instance() + } +} + +impl ThpQrCodeSecret { + pub fn new() -> ThpQrCodeSecret { + ::std::default::Default::default() + } + + // optional bytes secret = 1; + + pub fn secret(&self) -> &[u8] { + match self.secret.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_secret(&mut self) { + self.secret = ::std::option::Option::None; + } + + pub fn has_secret(&self) -> bool { + self.secret.is_some() + } + + // Param is passed by value, moved + pub fn set_secret(&mut self, v: ::std::vec::Vec) { + self.secret = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_secret(&mut self) -> &mut ::std::vec::Vec { + if self.secret.is_none() { + self.secret = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.secret.as_mut().unwrap() + } + + // Take field + pub fn take_secret(&mut self) -> ::std::vec::Vec { + self.secret.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "secret", + |m: &ThpQrCodeSecret| { &m.secret }, + |m: &mut ThpQrCodeSecret| { &mut m.secret }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpQrCodeSecret", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpQrCodeSecret { + const NAME: &'static str = "ThpQrCodeSecret"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.secret = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.secret.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.secret.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpQrCodeSecret { + ThpQrCodeSecret::new() + } + + fn clear(&mut self) { + self.secret = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpQrCodeSecret { + static instance: ThpQrCodeSecret = ThpQrCodeSecret { + secret: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpQrCodeSecret { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpQrCodeSecret").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpQrCodeSecret { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpQrCodeSecret { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpNfcTagHost) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpNfcTagHost { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpNfcTagHost.tag) + pub tag: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpNfcTagHost.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpNfcTagHost { + fn default() -> &'a ThpNfcTagHost { + ::default_instance() + } +} + +impl ThpNfcTagHost { + pub fn new() -> ThpNfcTagHost { + ::std::default::Default::default() + } + + // optional bytes tag = 1; + + pub fn tag(&self) -> &[u8] { + match self.tag.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_tag(&mut self) { + self.tag = ::std::option::Option::None; + } + + pub fn has_tag(&self) -> bool { + self.tag.is_some() + } + + // Param is passed by value, moved + pub fn set_tag(&mut self, v: ::std::vec::Vec) { + self.tag = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_tag(&mut self) -> &mut ::std::vec::Vec { + if self.tag.is_none() { + self.tag = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.tag.as_mut().unwrap() + } + + // Take field + pub fn take_tag(&mut self) -> ::std::vec::Vec { + self.tag.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "tag", + |m: &ThpNfcTagHost| { &m.tag }, + |m: &mut ThpNfcTagHost| { &mut m.tag }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpNfcTagHost", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpNfcTagHost { + const NAME: &'static str = "ThpNfcTagHost"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.tag = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.tag.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.tag.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpNfcTagHost { + ThpNfcTagHost::new() + } + + fn clear(&mut self) { + self.tag = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpNfcTagHost { + static instance: ThpNfcTagHost = ThpNfcTagHost { + tag: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpNfcTagHost { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpNfcTagHost").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpNfcTagHost { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpNfcTagHost { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpNfcTagTrezor) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpNfcTagTrezor { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpNfcTagTrezor.tag) + pub tag: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpNfcTagTrezor.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpNfcTagTrezor { + fn default() -> &'a ThpNfcTagTrezor { + ::default_instance() + } +} + +impl ThpNfcTagTrezor { + pub fn new() -> ThpNfcTagTrezor { + ::std::default::Default::default() + } + + // optional bytes tag = 1; + + pub fn tag(&self) -> &[u8] { + match self.tag.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_tag(&mut self) { + self.tag = ::std::option::Option::None; + } + + pub fn has_tag(&self) -> bool { + self.tag.is_some() + } + + // Param is passed by value, moved + pub fn set_tag(&mut self, v: ::std::vec::Vec) { + self.tag = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_tag(&mut self) -> &mut ::std::vec::Vec { + if self.tag.is_none() { + self.tag = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.tag.as_mut().unwrap() + } + + // Take field + pub fn take_tag(&mut self) -> ::std::vec::Vec { + self.tag.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "tag", + |m: &ThpNfcTagTrezor| { &m.tag }, + |m: &mut ThpNfcTagTrezor| { &mut m.tag }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpNfcTagTrezor", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpNfcTagTrezor { + const NAME: &'static str = "ThpNfcTagTrezor"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.tag = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.tag.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.tag.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpNfcTagTrezor { + ThpNfcTagTrezor::new() + } + + fn clear(&mut self) { + self.tag = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpNfcTagTrezor { + static instance: ThpNfcTagTrezor = ThpNfcTagTrezor { + tag: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpNfcTagTrezor { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpNfcTagTrezor").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpNfcTagTrezor { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpNfcTagTrezor { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCredentialRequest) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCredentialRequest { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCredentialRequest.host_static_pubkey) + pub host_static_pubkey: ::std::option::Option<::std::vec::Vec>, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCredentialRequest.autoconnect) + pub autoconnect: ::std::option::Option, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCredentialRequest.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCredentialRequest { + fn default() -> &'a ThpCredentialRequest { + ::default_instance() + } +} + +impl ThpCredentialRequest { + pub fn new() -> ThpCredentialRequest { + ::std::default::Default::default() + } + + // optional bytes host_static_pubkey = 1; + + pub fn host_static_pubkey(&self) -> &[u8] { + match self.host_static_pubkey.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_host_static_pubkey(&mut self) { + self.host_static_pubkey = ::std::option::Option::None; + } + + pub fn has_host_static_pubkey(&self) -> bool { + self.host_static_pubkey.is_some() + } + + // Param is passed by value, moved + pub fn set_host_static_pubkey(&mut self, v: ::std::vec::Vec) { + self.host_static_pubkey = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_host_static_pubkey(&mut self) -> &mut ::std::vec::Vec { + if self.host_static_pubkey.is_none() { + self.host_static_pubkey = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.host_static_pubkey.as_mut().unwrap() + } + + // Take field + pub fn take_host_static_pubkey(&mut self) -> ::std::vec::Vec { + self.host_static_pubkey.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + // optional bool autoconnect = 2; + + pub fn autoconnect(&self) -> bool { + self.autoconnect.unwrap_or(false) + } + + pub fn clear_autoconnect(&mut self) { + self.autoconnect = ::std::option::Option::None; + } + + pub fn has_autoconnect(&self) -> bool { + self.autoconnect.is_some() + } + + // Param is passed by value, moved + pub fn set_autoconnect(&mut self, v: bool) { + self.autoconnect = ::std::option::Option::Some(v); + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(2); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "host_static_pubkey", + |m: &ThpCredentialRequest| { &m.host_static_pubkey }, + |m: &mut ThpCredentialRequest| { &mut m.host_static_pubkey }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "autoconnect", + |m: &ThpCredentialRequest| { &m.autoconnect }, + |m: &mut ThpCredentialRequest| { &mut m.autoconnect }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCredentialRequest", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCredentialRequest { + const NAME: &'static str = "ThpCredentialRequest"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.host_static_pubkey = ::std::option::Option::Some(is.read_bytes()?); + }, + 16 => { + self.autoconnect = ::std::option::Option::Some(is.read_bool()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.host_static_pubkey.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + if let Some(v) = self.autoconnect { + my_size += 1 + 1; + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.host_static_pubkey.as_ref() { + os.write_bytes(1, v)?; + } + if let Some(v) = self.autoconnect { + os.write_bool(2, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCredentialRequest { + ThpCredentialRequest::new() + } + + fn clear(&mut self) { + self.host_static_pubkey = ::std::option::Option::None; + self.autoconnect = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCredentialRequest { + static instance: ThpCredentialRequest = ThpCredentialRequest { + host_static_pubkey: ::std::option::Option::None, + autoconnect: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCredentialRequest { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCredentialRequest").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCredentialRequest { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCredentialRequest { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCredentialResponse) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCredentialResponse { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCredentialResponse.trezor_static_pubkey) + pub trezor_static_pubkey: ::std::option::Option<::std::vec::Vec>, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCredentialResponse.credential) + pub credential: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCredentialResponse.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCredentialResponse { + fn default() -> &'a ThpCredentialResponse { + ::default_instance() + } +} + +impl ThpCredentialResponse { + pub fn new() -> ThpCredentialResponse { + ::std::default::Default::default() + } + + // optional bytes trezor_static_pubkey = 1; + + pub fn trezor_static_pubkey(&self) -> &[u8] { + match self.trezor_static_pubkey.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_trezor_static_pubkey(&mut self) { + self.trezor_static_pubkey = ::std::option::Option::None; + } + + pub fn has_trezor_static_pubkey(&self) -> bool { + self.trezor_static_pubkey.is_some() + } + + // Param is passed by value, moved + pub fn set_trezor_static_pubkey(&mut self, v: ::std::vec::Vec) { + self.trezor_static_pubkey = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_trezor_static_pubkey(&mut self) -> &mut ::std::vec::Vec { + if self.trezor_static_pubkey.is_none() { + self.trezor_static_pubkey = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.trezor_static_pubkey.as_mut().unwrap() + } + + // Take field + pub fn take_trezor_static_pubkey(&mut self) -> ::std::vec::Vec { + self.trezor_static_pubkey.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + // optional bytes credential = 2; + + pub fn credential(&self) -> &[u8] { + match self.credential.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_credential(&mut self) { + self.credential = ::std::option::Option::None; + } + + pub fn has_credential(&self) -> bool { + self.credential.is_some() + } + + // Param is passed by value, moved + pub fn set_credential(&mut self, v: ::std::vec::Vec) { + self.credential = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_credential(&mut self) -> &mut ::std::vec::Vec { + if self.credential.is_none() { + self.credential = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.credential.as_mut().unwrap() + } + + // Take field + pub fn take_credential(&mut self) -> ::std::vec::Vec { + self.credential.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(2); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "trezor_static_pubkey", + |m: &ThpCredentialResponse| { &m.trezor_static_pubkey }, + |m: &mut ThpCredentialResponse| { &mut m.trezor_static_pubkey }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "credential", + |m: &ThpCredentialResponse| { &m.credential }, + |m: &mut ThpCredentialResponse| { &mut m.credential }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCredentialResponse", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCredentialResponse { + const NAME: &'static str = "ThpCredentialResponse"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.trezor_static_pubkey = ::std::option::Option::Some(is.read_bytes()?); + }, + 18 => { + self.credential = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.trezor_static_pubkey.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + if let Some(v) = self.credential.as_ref() { + my_size += ::protobuf::rt::bytes_size(2, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.trezor_static_pubkey.as_ref() { + os.write_bytes(1, v)?; + } + if let Some(v) = self.credential.as_ref() { + os.write_bytes(2, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCredentialResponse { + ThpCredentialResponse::new() + } + + fn clear(&mut self) { + self.trezor_static_pubkey = ::std::option::Option::None; + self.credential = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCredentialResponse { + static instance: ThpCredentialResponse = ThpCredentialResponse { + trezor_static_pubkey: ::std::option::Option::None, + credential: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCredentialResponse { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCredentialResponse").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCredentialResponse { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCredentialResponse { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpEndRequest) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpEndRequest { + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpEndRequest.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpEndRequest { + fn default() -> &'a ThpEndRequest { + ::default_instance() + } +} + +impl ThpEndRequest { + pub fn new() -> ThpEndRequest { + ::std::default::Default::default() + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(0); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpEndRequest", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpEndRequest { + const NAME: &'static str = "ThpEndRequest"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpEndRequest { + ThpEndRequest::new() + } + + fn clear(&mut self) { + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpEndRequest { + static instance: ThpEndRequest = ThpEndRequest { + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpEndRequest { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpEndRequest").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpEndRequest { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpEndRequest { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpEndResponse) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpEndResponse { + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpEndResponse.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpEndResponse { + fn default() -> &'a ThpEndResponse { + ::default_instance() + } +} + +impl ThpEndResponse { + pub fn new() -> ThpEndResponse { + ::std::default::Default::default() + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(0); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpEndResponse", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpEndResponse { + const NAME: &'static str = "ThpEndResponse"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpEndResponse { + ThpEndResponse::new() + } + + fn clear(&mut self) { + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpEndResponse { + static instance: ThpEndResponse = ThpEndResponse { + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpEndResponse { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpEndResponse").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpEndResponse { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpEndResponse { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + // @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCredentialMetadata) #[derive(PartialEq,Clone,Default,Debug)] pub struct ThpCredentialMetadata { // message fields // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCredentialMetadata.host_name) pub host_name: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCredentialMetadata.autoconnect) + pub autoconnect: ::std::option::Option, // special fields // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCredentialMetadata.special_fields) pub special_fields: ::protobuf::SpecialFields, @@ -83,14 +3362,38 @@ impl ThpCredentialMetadata { self.host_name.take().unwrap_or_else(|| ::std::string::String::new()) } + // optional bool autoconnect = 2; + + pub fn autoconnect(&self) -> bool { + self.autoconnect.unwrap_or(false) + } + + pub fn clear_autoconnect(&mut self) { + self.autoconnect = ::std::option::Option::None; + } + + pub fn has_autoconnect(&self) -> bool { + self.autoconnect.is_some() + } + + // Param is passed by value, moved + pub fn set_autoconnect(&mut self, v: bool) { + self.autoconnect = ::std::option::Option::Some(v); + } + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { - let mut fields = ::std::vec::Vec::with_capacity(1); + let mut fields = ::std::vec::Vec::with_capacity(2); let mut oneofs = ::std::vec::Vec::with_capacity(0); fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( "host_name", |m: &ThpCredentialMetadata| { &m.host_name }, |m: &mut ThpCredentialMetadata| { &mut m.host_name }, )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "autoconnect", + |m: &ThpCredentialMetadata| { &m.autoconnect }, + |m: &mut ThpCredentialMetadata| { &mut m.autoconnect }, + )); ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( "ThpCredentialMetadata", fields, @@ -112,6 +3415,9 @@ impl ::protobuf::Message for ThpCredentialMetadata { 10 => { self.host_name = ::std::option::Option::Some(is.read_string()?); }, + 16 => { + self.autoconnect = ::std::option::Option::Some(is.read_bool()?); + }, tag => { ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; }, @@ -127,6 +3433,9 @@ impl ::protobuf::Message for ThpCredentialMetadata { if let Some(v) = self.host_name.as_ref() { my_size += ::protobuf::rt::string_size(1, &v); } + if let Some(v) = self.autoconnect { + my_size += 1 + 1; + } my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); self.special_fields.cached_size().set(my_size as u32); my_size @@ -136,6 +3445,9 @@ impl ::protobuf::Message for ThpCredentialMetadata { if let Some(v) = self.host_name.as_ref() { os.write_string(1, v)?; } + if let Some(v) = self.autoconnect { + os.write_bool(2, v)?; + } os.write_unknown_fields(self.special_fields.unknown_fields())?; ::std::result::Result::Ok(()) } @@ -154,12 +3466,14 @@ impl ::protobuf::Message for ThpCredentialMetadata { fn clear(&mut self) { self.host_name = ::std::option::Option::None; + self.autoconnect = ::std::option::Option::None; self.special_fields.clear(); } fn default_instance() -> &'static ThpCredentialMetadata { static instance: ThpCredentialMetadata = ThpCredentialMetadata { host_name: ::std::option::Option::None, + autoconnect: ::std::option::Option::None, special_fields: ::protobuf::SpecialFields::new(), }; &instance @@ -537,17 +3851,313 @@ impl ::protobuf::reflect::ProtobufValue for ThpAuthenticatedCredentialData { type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; } +#[derive(Clone,Copy,PartialEq,Eq,Debug,Hash)] +// @@protoc_insertion_point(enum:hw.trezor.messages.thp.ThpMessageType) +pub enum ThpMessageType { + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCreateNewSession) + ThpMessageType_ThpCreateNewSession = 1000, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpPairingRequest) + ThpMessageType_ThpPairingRequest = 1006, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpPairingRequestApproved) + ThpMessageType_ThpPairingRequestApproved = 1007, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpSelectMethod) + ThpMessageType_ThpSelectMethod = 1008, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpPairingPreparationsFinished) + ThpMessageType_ThpPairingPreparationsFinished = 1009, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCredentialRequest) + ThpMessageType_ThpCredentialRequest = 1010, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCredentialResponse) + ThpMessageType_ThpCredentialResponse = 1011, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpEndRequest) + ThpMessageType_ThpEndRequest = 1012, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpEndResponse) + ThpMessageType_ThpEndResponse = 1013, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCodeEntryCommitment) + ThpMessageType_ThpCodeEntryCommitment = 1016, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCodeEntryChallenge) + ThpMessageType_ThpCodeEntryChallenge = 1017, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCodeEntryCpaceTrezor) + ThpMessageType_ThpCodeEntryCpaceTrezor = 1018, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCodeEntryCpaceHostTag) + ThpMessageType_ThpCodeEntryCpaceHostTag = 1019, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCodeEntrySecret) + ThpMessageType_ThpCodeEntrySecret = 1020, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpQrCodeTag) + ThpMessageType_ThpQrCodeTag = 1024, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpQrCodeSecret) + ThpMessageType_ThpQrCodeSecret = 1025, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpNfcTagHost) + ThpMessageType_ThpNfcTagHost = 1032, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpNfcTagTrezor) + ThpMessageType_ThpNfcTagTrezor = 1033, +} + +impl ::protobuf::Enum for ThpMessageType { + const NAME: &'static str = "ThpMessageType"; + + fn value(&self) -> i32 { + *self as i32 + } + + fn from_i32(value: i32) -> ::std::option::Option { + match value { + 1000 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCreateNewSession), + 1006 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpPairingRequest), + 1007 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpPairingRequestApproved), + 1008 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpSelectMethod), + 1009 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpPairingPreparationsFinished), + 1010 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCredentialRequest), + 1011 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCredentialResponse), + 1012 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpEndRequest), + 1013 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpEndResponse), + 1016 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryCommitment), + 1017 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryChallenge), + 1018 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryCpaceTrezor), + 1019 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryCpaceHostTag), + 1020 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntrySecret), + 1024 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpQrCodeTag), + 1025 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpQrCodeSecret), + 1032 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpNfcTagHost), + 1033 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpNfcTagTrezor), + _ => ::std::option::Option::None + } + } + + fn from_str(str: &str) -> ::std::option::Option { + match str { + "ThpMessageType_ThpCreateNewSession" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCreateNewSession), + "ThpMessageType_ThpPairingRequest" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpPairingRequest), + "ThpMessageType_ThpPairingRequestApproved" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpPairingRequestApproved), + "ThpMessageType_ThpSelectMethod" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpSelectMethod), + "ThpMessageType_ThpPairingPreparationsFinished" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpPairingPreparationsFinished), + "ThpMessageType_ThpCredentialRequest" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCredentialRequest), + "ThpMessageType_ThpCredentialResponse" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCredentialResponse), + "ThpMessageType_ThpEndRequest" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpEndRequest), + "ThpMessageType_ThpEndResponse" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpEndResponse), + "ThpMessageType_ThpCodeEntryCommitment" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryCommitment), + "ThpMessageType_ThpCodeEntryChallenge" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryChallenge), + "ThpMessageType_ThpCodeEntryCpaceTrezor" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryCpaceTrezor), + "ThpMessageType_ThpCodeEntryCpaceHostTag" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryCpaceHostTag), + "ThpMessageType_ThpCodeEntrySecret" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntrySecret), + "ThpMessageType_ThpQrCodeTag" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpQrCodeTag), + "ThpMessageType_ThpQrCodeSecret" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpQrCodeSecret), + "ThpMessageType_ThpNfcTagHost" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpNfcTagHost), + "ThpMessageType_ThpNfcTagTrezor" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpNfcTagTrezor), + _ => ::std::option::Option::None + } + } + + const VALUES: &'static [ThpMessageType] = &[ + ThpMessageType::ThpMessageType_ThpCreateNewSession, + ThpMessageType::ThpMessageType_ThpPairingRequest, + ThpMessageType::ThpMessageType_ThpPairingRequestApproved, + ThpMessageType::ThpMessageType_ThpSelectMethod, + ThpMessageType::ThpMessageType_ThpPairingPreparationsFinished, + ThpMessageType::ThpMessageType_ThpCredentialRequest, + ThpMessageType::ThpMessageType_ThpCredentialResponse, + ThpMessageType::ThpMessageType_ThpEndRequest, + ThpMessageType::ThpMessageType_ThpEndResponse, + ThpMessageType::ThpMessageType_ThpCodeEntryCommitment, + ThpMessageType::ThpMessageType_ThpCodeEntryChallenge, + ThpMessageType::ThpMessageType_ThpCodeEntryCpaceTrezor, + ThpMessageType::ThpMessageType_ThpCodeEntryCpaceHostTag, + ThpMessageType::ThpMessageType_ThpCodeEntrySecret, + ThpMessageType::ThpMessageType_ThpQrCodeTag, + ThpMessageType::ThpMessageType_ThpQrCodeSecret, + ThpMessageType::ThpMessageType_ThpNfcTagHost, + ThpMessageType::ThpMessageType_ThpNfcTagTrezor, + ]; +} + +impl ::protobuf::EnumFull for ThpMessageType { + fn enum_descriptor() -> ::protobuf::reflect::EnumDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::EnumDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().enum_by_package_relative_name("ThpMessageType").unwrap()).clone() + } + + fn descriptor(&self) -> ::protobuf::reflect::EnumValueDescriptor { + let index = match self { + ThpMessageType::ThpMessageType_ThpCreateNewSession => 0, + ThpMessageType::ThpMessageType_ThpPairingRequest => 1, + ThpMessageType::ThpMessageType_ThpPairingRequestApproved => 2, + ThpMessageType::ThpMessageType_ThpSelectMethod => 3, + ThpMessageType::ThpMessageType_ThpPairingPreparationsFinished => 4, + ThpMessageType::ThpMessageType_ThpCredentialRequest => 5, + ThpMessageType::ThpMessageType_ThpCredentialResponse => 6, + ThpMessageType::ThpMessageType_ThpEndRequest => 7, + ThpMessageType::ThpMessageType_ThpEndResponse => 8, + ThpMessageType::ThpMessageType_ThpCodeEntryCommitment => 9, + ThpMessageType::ThpMessageType_ThpCodeEntryChallenge => 10, + ThpMessageType::ThpMessageType_ThpCodeEntryCpaceTrezor => 11, + ThpMessageType::ThpMessageType_ThpCodeEntryCpaceHostTag => 12, + ThpMessageType::ThpMessageType_ThpCodeEntrySecret => 13, + ThpMessageType::ThpMessageType_ThpQrCodeTag => 14, + ThpMessageType::ThpMessageType_ThpQrCodeSecret => 15, + ThpMessageType::ThpMessageType_ThpNfcTagHost => 16, + ThpMessageType::ThpMessageType_ThpNfcTagTrezor => 17, + }; + Self::enum_descriptor().value_by_index(index) + } +} + +// Note, `Default` is implemented although default value is not 0 +impl ::std::default::Default for ThpMessageType { + fn default() -> Self { + ThpMessageType::ThpMessageType_ThpCreateNewSession + } +} + +impl ThpMessageType { + fn generated_enum_descriptor_data() -> ::protobuf::reflect::GeneratedEnumDescriptorData { + ::protobuf::reflect::GeneratedEnumDescriptorData::new::("ThpMessageType") + } +} + +#[derive(Clone,Copy,PartialEq,Eq,Debug,Hash)] +// @@protoc_insertion_point(enum:hw.trezor.messages.thp.ThpPairingMethod) +pub enum ThpPairingMethod { + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpPairingMethod.SkipPairing) + SkipPairing = 1, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpPairingMethod.CodeEntry) + CodeEntry = 2, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpPairingMethod.QrCode) + QrCode = 3, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpPairingMethod.NFC) + NFC = 4, +} + +impl ::protobuf::Enum for ThpPairingMethod { + const NAME: &'static str = "ThpPairingMethod"; + + fn value(&self) -> i32 { + *self as i32 + } + + fn from_i32(value: i32) -> ::std::option::Option { + match value { + 1 => ::std::option::Option::Some(ThpPairingMethod::SkipPairing), + 2 => ::std::option::Option::Some(ThpPairingMethod::CodeEntry), + 3 => ::std::option::Option::Some(ThpPairingMethod::QrCode), + 4 => ::std::option::Option::Some(ThpPairingMethod::NFC), + _ => ::std::option::Option::None + } + } + + fn from_str(str: &str) -> ::std::option::Option { + match str { + "SkipPairing" => ::std::option::Option::Some(ThpPairingMethod::SkipPairing), + "CodeEntry" => ::std::option::Option::Some(ThpPairingMethod::CodeEntry), + "QrCode" => ::std::option::Option::Some(ThpPairingMethod::QrCode), + "NFC" => ::std::option::Option::Some(ThpPairingMethod::NFC), + _ => ::std::option::Option::None + } + } + + const VALUES: &'static [ThpPairingMethod] = &[ + ThpPairingMethod::SkipPairing, + ThpPairingMethod::CodeEntry, + ThpPairingMethod::QrCode, + ThpPairingMethod::NFC, + ]; +} + +impl ::protobuf::EnumFull for ThpPairingMethod { + fn enum_descriptor() -> ::protobuf::reflect::EnumDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::EnumDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().enum_by_package_relative_name("ThpPairingMethod").unwrap()).clone() + } + + fn descriptor(&self) -> ::protobuf::reflect::EnumValueDescriptor { + let index = match self { + ThpPairingMethod::SkipPairing => 0, + ThpPairingMethod::CodeEntry => 1, + ThpPairingMethod::QrCode => 2, + ThpPairingMethod::NFC => 3, + }; + Self::enum_descriptor().value_by_index(index) + } +} + +// Note, `Default` is implemented although default value is not 0 +impl ::std::default::Default for ThpPairingMethod { + fn default() -> Self { + ThpPairingMethod::SkipPairing + } +} + +impl ThpPairingMethod { + fn generated_enum_descriptor_data() -> ::protobuf::reflect::GeneratedEnumDescriptorData { + ::protobuf::reflect::GeneratedEnumDescriptorData::new::("ThpPairingMethod") + } +} + static file_descriptor_proto_data: &'static [u8] = b"\ \n\x12messages-thp.proto\x12\x16hw.trezor.messages.thp\x1a\roptions.prot\ - o\":\n\x15ThpCredentialMetadata\x12\x1b\n\thost_name\x18\x01\x20\x01(\tR\ - \x08hostName:\x04\x98\xb2\x19\x01\"\x82\x01\n\x14ThpPairingCredential\ + o\"\xa0\x02\n\x13ThpDeviceProperties\x12%\n\x0einternal_model\x18\x01\ + \x20\x01(\tR\rinternalModel\x12#\n\rmodel_variant\x18\x02\x20\x01(\rR\ + \x0cmodelVariant\x124\n\x16protocol_version_major\x18\x03\x20\x01(\rR\ + \x14protocolVersionMajor\x124\n\x16protocol_version_minor\x18\x04\x20\ + \x01(\rR\x14protocolVersionMinor\x12Q\n\x0fpairing_methods\x18\x05\x20\ + \x03(\x0e2(.hw.trezor.messages.thp.ThpPairingMethodR\x0epairingMethods\"\ + _\n%ThpHandshakeCompletionReqNoisePayload\x126\n\x17host_pairing_credent\ + ial\x18\x01\x20\x01(\x0cR\x15hostPairingCredential\"y\n\x13ThpCreateNewS\ + ession\x12\x1e\n\npassphrase\x18\x01\x20\x01(\tR\npassphrase\x12\x1b\n\t\ + on_device\x18\x02\x20\x01(\x08R\x08onDevice\x12%\n\x0ederive_cardano\x18\ + \x03\x20\x01(\x08R\rderiveCardano\"0\n\x11ThpPairingRequest\x12\x1b\n\th\ + ost_name\x18\x01\x20\x01(\tR\x08hostName\"\x1b\n\x19ThpPairingRequestApp\ + roved\"s\n\x0fThpSelectMethod\x12`\n\x17selected_pairing_method\x18\x01\ + \x20\x01(\x0e2(.hw.trezor.messages.thp.ThpPairingMethodR\x15selectedPair\ + ingMethod\"\x20\n\x1eThpPairingPreparationsFinished\"8\n\x16ThpCodeEntry\ + Commitment\x12\x1e\n\ncommitment\x18\x01\x20\x01(\x0cR\ncommitment\"5\n\ + \x15ThpCodeEntryChallenge\x12\x1c\n\tchallenge\x18\x01\x20\x01(\x0cR\tch\ + allenge\"P\n\x17ThpCodeEntryCpaceTrezor\x125\n\x17cpace_trezor_public_ke\ + y\x18\x01\x20\x01(\x0cR\x14cpaceTrezorPublicKey\"_\n\x18ThpCodeEntryCpac\ + eHostTag\x121\n\x15cpace_host_public_key\x18\x01\x20\x01(\x0cR\x12cpaceH\ + ostPublicKey\x12\x10\n\x03tag\x18\x02\x20\x01(\x0cR\x03tag\",\n\x12ThpCo\ + deEntrySecret\x12\x16\n\x06secret\x18\x01\x20\x01(\x0cR\x06secret\"\x20\ + \n\x0cThpQrCodeTag\x12\x10\n\x03tag\x18\x01\x20\x01(\x0cR\x03tag\")\n\ + \x0fThpQrCodeSecret\x12\x16\n\x06secret\x18\x01\x20\x01(\x0cR\x06secret\ + \"!\n\rThpNfcTagHost\x12\x10\n\x03tag\x18\x01\x20\x01(\x0cR\x03tag\"#\n\ + \x0fThpNfcTagTrezor\x12\x10\n\x03tag\x18\x01\x20\x01(\x0cR\x03tag\"f\n\ + \x14ThpCredentialRequest\x12,\n\x12host_static_pubkey\x18\x01\x20\x01(\ + \x0cR\x10hostStaticPubkey\x12\x20\n\x0bautoconnect\x18\x02\x20\x01(\x08R\ + \x0bautoconnect\"i\n\x15ThpCredentialResponse\x120\n\x14trezor_static_pu\ + bkey\x18\x01\x20\x01(\x0cR\x12trezorStaticPubkey\x12\x1e\n\ncredential\ + \x18\x02\x20\x01(\x0cR\ncredential\"\x0f\n\rThpEndRequest\"\x10\n\x0eThp\ + EndResponse\"\\\n\x15ThpCredentialMetadata\x12\x1b\n\thost_name\x18\x01\ + \x20\x01(\tR\x08hostName\x12\x20\n\x0bautoconnect\x18\x02\x20\x01(\x08R\ + \x0bautoconnect:\x04\x98\xb2\x19\x01\"\x82\x01\n\x14ThpPairingCredential\ \x12R\n\rcred_metadata\x18\x01\x20\x01(\x0b2-.hw.trezor.messages.thp.Thp\ CredentialMetadataR\x0ccredMetadata\x12\x10\n\x03mac\x18\x02\x20\x01(\ \x0cR\x03mac:\x04\x98\xb2\x19\x01\"\xa8\x01\n\x1eThpAuthenticatedCredent\ ialData\x12,\n\x12host_static_pubkey\x18\x01\x20\x01(\x0cR\x10hostStatic\ Pubkey\x12R\n\rcred_metadata\x18\x02\x20\x01(\x0b2-.hw.trezor.messages.t\ - hp.ThpCredentialMetadataR\x0ccredMetadata:\x04\x98\xb2\x19\x01B;\n#com.s\ - atoshilabs.trezor.lib.protobufB\x10TrezorMessageThp\x80\xa6\x1d\x01\ + hp.ThpCredentialMetadataR\x0ccredMetadata:\x04\x98\xb2\x19\x01*\xeb\x06\ + \n\x0eThpMessageType\x12-\n\"ThpMessageType_ThpCreateNewSession\x10\xe8\ + \x07\x1a\x04\x80\xa6\x1d\x01\x12+\n\x20ThpMessageType_ThpPairingRequest\ + \x10\xee\x07\x1a\x04\x80\xa6\x1d\x01\x123\n(ThpMessageType_ThpPairingReq\ + uestApproved\x10\xef\x07\x1a\x04\x80\xa6\x1d\x01\x12)\n\x1eThpMessageTyp\ + e_ThpSelectMethod\x10\xf0\x07\x1a\x04\x80\xa6\x1d\x01\x128\n-ThpMessageT\ + ype_ThpPairingPreparationsFinished\x10\xf1\x07\x1a\x04\x80\xa6\x1d\x01\ + \x12.\n#ThpMessageType_ThpCredentialRequest\x10\xf2\x07\x1a\x04\x80\xa6\ + \x1d\x01\x12/\n$ThpMessageType_ThpCredentialResponse\x10\xf3\x07\x1a\x04\ + \x80\xa6\x1d\x01\x12'\n\x1cThpMessageType_ThpEndRequest\x10\xf4\x07\x1a\ + \x04\x80\xa6\x1d\x01\x12(\n\x1dThpMessageType_ThpEndResponse\x10\xf5\x07\ + \x1a\x04\x80\xa6\x1d\x01\x120\n%ThpMessageType_ThpCodeEntryCommitment\ + \x10\xf8\x07\x1a\x04\x80\xa6\x1d\x01\x12/\n$ThpMessageType_ThpCodeEntryC\ + hallenge\x10\xf9\x07\x1a\x04\x80\xa6\x1d\x01\x121\n&ThpMessageType_ThpCo\ + deEntryCpaceTrezor\x10\xfa\x07\x1a\x04\x80\xa6\x1d\x01\x122\n'ThpMessage\ + Type_ThpCodeEntryCpaceHostTag\x10\xfb\x07\x1a\x04\x80\xa6\x1d\x01\x12,\n\ + !ThpMessageType_ThpCodeEntrySecret\x10\xfc\x07\x1a\x04\x80\xa6\x1d\x01\ + \x12&\n\x1bThpMessageType_ThpQrCodeTag\x10\x80\x08\x1a\x04\x80\xa6\x1d\ + \x01\x12)\n\x1eThpMessageType_ThpQrCodeSecret\x10\x81\x08\x1a\x04\x80\ + \xa6\x1d\x01\x12'\n\x1cThpMessageType_ThpNfcTagHost\x10\x88\x08\x1a\x04\ + \x80\xa6\x1d\x01\x12)\n\x1eThpMessageType_ThpNfcTagTrezor\x10\x89\x08\ + \x1a\x04\x80\xa6\x1d\x01\"\x05\x08\0\x10\xe7\x07\"\t\x08\xcc\x08\x10\xff\ + \xff\xff\xff\x07*G\n\x10ThpPairingMethod\x12\x0f\n\x0bSkipPairing\x10\ + \x01\x12\r\n\tCodeEntry\x10\x02\x12\n\n\x06QrCode\x10\x03\x12\x07\n\x03N\ + FC\x10\x04B;\n#com.satoshilabs.trezor.lib.protobufB\x10TrezorMessageThp\ + \x80\xa6\x1d\x01\ "; /// `FileDescriptorProto` object which was a source for this generated file @@ -566,11 +4176,33 @@ pub fn file_descriptor() -> &'static ::protobuf::reflect::FileDescriptor { let generated_file_descriptor = generated_file_descriptor_lazy.get(|| { let mut deps = ::std::vec::Vec::with_capacity(1); deps.push(super::options::file_descriptor().clone()); - let mut messages = ::std::vec::Vec::with_capacity(3); + let mut messages = ::std::vec::Vec::with_capacity(23); + messages.push(ThpDeviceProperties::generated_message_descriptor_data()); + messages.push(ThpHandshakeCompletionReqNoisePayload::generated_message_descriptor_data()); + messages.push(ThpCreateNewSession::generated_message_descriptor_data()); + messages.push(ThpPairingRequest::generated_message_descriptor_data()); + messages.push(ThpPairingRequestApproved::generated_message_descriptor_data()); + messages.push(ThpSelectMethod::generated_message_descriptor_data()); + messages.push(ThpPairingPreparationsFinished::generated_message_descriptor_data()); + messages.push(ThpCodeEntryCommitment::generated_message_descriptor_data()); + messages.push(ThpCodeEntryChallenge::generated_message_descriptor_data()); + messages.push(ThpCodeEntryCpaceTrezor::generated_message_descriptor_data()); + messages.push(ThpCodeEntryCpaceHostTag::generated_message_descriptor_data()); + messages.push(ThpCodeEntrySecret::generated_message_descriptor_data()); + messages.push(ThpQrCodeTag::generated_message_descriptor_data()); + messages.push(ThpQrCodeSecret::generated_message_descriptor_data()); + messages.push(ThpNfcTagHost::generated_message_descriptor_data()); + messages.push(ThpNfcTagTrezor::generated_message_descriptor_data()); + messages.push(ThpCredentialRequest::generated_message_descriptor_data()); + messages.push(ThpCredentialResponse::generated_message_descriptor_data()); + messages.push(ThpEndRequest::generated_message_descriptor_data()); + messages.push(ThpEndResponse::generated_message_descriptor_data()); messages.push(ThpCredentialMetadata::generated_message_descriptor_data()); messages.push(ThpPairingCredential::generated_message_descriptor_data()); messages.push(ThpAuthenticatedCredentialData::generated_message_descriptor_data()); - let mut enums = ::std::vec::Vec::with_capacity(0); + let mut enums = ::std::vec::Vec::with_capacity(2); + enums.push(ThpMessageType::generated_enum_descriptor_data()); + enums.push(ThpPairingMethod::generated_enum_descriptor_data()); ::protobuf::reflect::GeneratedFileDescriptor::new_generated( file_descriptor_proto(), deps, diff --git a/tests/REGISTERED_MARKERS b/tests/REGISTERED_MARKERS index fab4ec8b3a3..bec85ca898b 100644 --- a/tests/REGISTERED_MARKERS +++ b/tests/REGISTERED_MARKERS @@ -11,6 +11,7 @@ multisig nem ontology peercoin +protocol ripple sd_card solana diff --git a/tests/click_tests/test_autolock.py b/tests/click_tests/test_autolock.py index 78ef5c9e13d..1be8303b620 100644 --- a/tests/click_tests/test_autolock.py +++ b/tests/click_tests/test_autolock.py @@ -21,7 +21,9 @@ import pytest from trezorlib import btc, device, exceptions, messages +from trezorlib.client import PASSPHRASE_ON_DEVICE from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.protobuf import MessageType from trezorlib.tools import parse_path @@ -58,8 +60,8 @@ def set_autolock_delay(device_handler: "BackgroundDeviceHandler", delay_ms: int): debug = device_handler.debuglink() - - device_handler.run(device.apply_settings, auto_lock_delay_ms=delay_ms) # type: ignore + Session(device_handler.client.get_seedless_session()).lock() + device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=delay_ms) # type: ignore assert "PinKeyboard" in debug.read_layout().all_components() @@ -98,7 +100,7 @@ def test_autolock_interrupts_signing(device_handler: "BackgroundDeviceHandler"): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - device_handler.run(btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) # type: ignore + device_handler.run_with_session(btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) # type: ignore assert ( "1MJ2tj2ThBE62zXbBYA5ZaN3fdve5CPAz1" @@ -136,6 +138,10 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa set_autolock_delay(device_handler, 10_000) debug = device_handler.debuglink() + + # Prepare session to use later + session = Session(device_handler.client.get_session()) + # try to sign a transaction inp1 = messages.TxInputType( address_n=parse_path("86h/0h/0h/0/0"), @@ -151,8 +157,8 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa script_type=messages.OutputScriptType.PAYTOADDRESS, ) - device_handler.run( - btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET + device_handler.run_with_provided_session( + session, btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET ) assert ( @@ -182,14 +188,14 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa def sleepy_filter(msg: MessageType) -> MessageType: time.sleep(10.1) - device_handler.client.set_filter(messages.TxAck, None) + session.set_filter(messages.TxAck, None) return msg - with device_handler.client: - device_handler.client.set_filter(messages.TxAck, sleepy_filter) + with session, device_handler.client: + session.set_filter(messages.TxAck, sleepy_filter) # confirm transaction if debug.layout_type is LayoutType.Bolt: - debug.click(buttons.OK) + debug.click(buttons.OK, hold_ms=1000) elif debug.layout_type is LayoutType.Delizia: debug.click(buttons.TAP_TO_CONFIRM) elif debug.layout_type is LayoutType.Caesar: @@ -198,7 +204,6 @@ def sleepy_filter(msg: MessageType) -> MessageType: signatures, tx = device_handler.result() assert len(signatures) == 1 assert tx - assert device_handler.features().unlocked is False @@ -208,8 +213,10 @@ def test_autolock_passphrase_keyboard(device_handler: "BackgroundDeviceHandler") debug = device_handler.debuglink() # get address - device_handler.run(common.get_test_address) # type: ignore - + session = Session( + device_handler.client.get_session(passphrase=PASSPHRASE_ON_DEVICE) + ) + device_handler.run_with_provided_session(session, common.get_test_address) # type: ignore assert "PassphraseKeyboard" in debug.read_layout().all_components() if debug.layout_type is LayoutType.Caesar: @@ -250,8 +257,10 @@ def test_autolock_interrupts_passphrase(device_handler: "BackgroundDeviceHandler debug = device_handler.debuglink() # get address - device_handler.run(common.get_test_address) # type: ignore - + session = Session( + device_handler.client.get_session(passphrase=PASSPHRASE_ON_DEVICE) + ) + device_handler.run_with_provided_session(session, common.get_test_address) # type: ignore assert "PassphraseKeyboard" in debug.read_layout().all_components() if debug.layout_type is LayoutType.Caesar: @@ -290,7 +299,7 @@ def test_dryrun_locks_at_number_of_words(device_handler: "BackgroundDeviceHandle set_autolock_delay(device_handler, 10_000) debug = device_handler.debuglink() - device_handler.run(device.recover, type=messages.RecoveryType.DryRun) + device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun) layout = unlock_dry_run(debug) assert TR.recovery__num_of_words in debug.read_layout().text_content() @@ -323,7 +332,7 @@ def test_dryrun_locks_at_word_entry(device_handler: "BackgroundDeviceHandler"): set_autolock_delay(device_handler, 10_000) debug = device_handler.debuglink() - device_handler.run(device.recover, type=messages.RecoveryType.DryRun) + device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun) unlock_dry_run(debug) @@ -350,7 +359,7 @@ def test_dryrun_enter_word_slowly(device_handler: "BackgroundDeviceHandler"): set_autolock_delay(device_handler, 10_000) debug = device_handler.debuglink() - device_handler.run(device.recover, type=messages.RecoveryType.DryRun) + device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun) unlock_dry_run(debug) @@ -415,7 +424,11 @@ def test_autolock_does_not_interrupt_preauthorized( debug = device_handler.debuglink() - device_handler.run( + # Prepare session to use later + session = Session(device_handler.client.get_session()) + + device_handler.run_with_provided_session( + session, btc.authorize_coinjoin, coordinator="www.example.com", max_rounds=2, @@ -529,14 +542,15 @@ def test_autolock_does_not_interrupt_preauthorized( def sleepy_filter(msg: MessageType) -> MessageType: time.sleep(10.1) - device_handler.client.set_filter(messages.SignTx, None) + session.set_filter(messages.SignTx, None) return msg - with device_handler.client: + with session: # Start DoPreauthorized flow when device is unlocked. Wait 10s before # delivering SignTx, by that time autolock timer should have fired. - device_handler.client.set_filter(messages.SignTx, sleepy_filter) - device_handler.run( + session.set_filter(messages.SignTx, sleepy_filter) + device_handler.run_with_provided_session( + session, btc.sign_tx, "Testnet", inputs, diff --git a/tests/click_tests/test_backup_slip39_custom.py b/tests/click_tests/test_backup_slip39_custom.py index c98752d2c03..98dff0cc8a6 100644 --- a/tests/click_tests/test_backup_slip39_custom.py +++ b/tests/click_tests/test_backup_slip39_custom.py @@ -52,7 +52,9 @@ def test_backup_slip39_custom( assert features.initialized is False - device_handler.run( + session = device_handler.client.get_seedless_session() + device_handler.run_with_provided_session( + session, device.setup, strength=128, backup_type=messages.BackupType.Slip39_Basic, @@ -71,7 +73,7 @@ def test_backup_slip39_custom( # retrieve the result to check that it's not a TrezorFailure exception device_handler.result() - device_handler.run( + device_handler.run_with_session( device.backup, group_threshold=group_threshold, groups=[(share_threshold, share_count)], diff --git a/tests/click_tests/test_lock.py b/tests/click_tests/test_lock.py index 9a0340910f6..f0ae89d78ef 100644 --- a/tests/click_tests/test_lock.py +++ b/tests/click_tests/test_lock.py @@ -19,7 +19,7 @@ import pytest -from trezorlib import models +from trezorlib import messages, models from trezorlib.debuglink import LayoutType from .. import buttons, common @@ -34,6 +34,9 @@ @pytest.mark.setup_client(pin=PIN4) def test_hold_to_lock(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() + session = device_handler.client.get_seedless_session() + session.call(messages.LockDevice()) + session.refresh_features() short_duration = { models.T1B1: 500, @@ -59,22 +62,25 @@ def hold(duration: int) -> None: assert device_handler.features().unlocked is False # unlock with message - device_handler.run(common.get_test_address) + device_handler.run_with_session(common.get_test_address) assert "PinKeyboard" in debug.read_layout().all_components() debug.input("1234") assert device_handler.result() + session.refresh_features() assert device_handler.features().unlocked is True # short touch hold(short_duration) time.sleep(0.5) # so that the homescreen appears again (hacky) + session.refresh_features() assert device_handler.features().unlocked is True # lock hold(lock_duration) + session.refresh_features() assert device_handler.features().unlocked is False # unlock by touching @@ -86,8 +92,10 @@ def hold(duration: int) -> None: assert "PinKeyboard" in layout.all_components() debug.input("1234") + session.refresh_features() assert device_handler.features().unlocked is True # lock hold(lock_duration) + session.refresh_features() assert device_handler.features().unlocked is False diff --git a/tests/click_tests/test_passphrase_bolt.py b/tests/click_tests/test_passphrase_bolt.py index 8f490c03098..79993b954fb 100644 --- a/tests/click_tests/test_passphrase_bolt.py +++ b/tests/click_tests/test_passphrase_bolt.py @@ -69,7 +69,7 @@ def prepare_passphrase_dialogue( device_handler: "BackgroundDeviceHandler", address: Optional[str] = None ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() - device_handler.run(get_test_address) # type: ignore + device_handler.run_with_session(get_test_address) # type: ignore assert debug.read_layout().main_component() == "PassphraseKeyboard" # Resetting the category as it could have been changed by previous tests diff --git a/tests/click_tests/test_passphrase_caesar.py b/tests/click_tests/test_passphrase_caesar.py index 57685451ba0..0affa4fbb6b 100644 --- a/tests/click_tests/test_passphrase_caesar.py +++ b/tests/click_tests/test_passphrase_caesar.py @@ -91,7 +91,7 @@ def prepare_passphrase_dialogue( device_handler: "BackgroundDeviceHandler", address: Optional[str] = None ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() - device_handler.run(get_test_address) # type: ignore + device_handler.run_with_session(get_test_address) # type: ignore layout = debug.read_layout() assert "PassphraseKeyboard" in layout.all_components() assert layout.passphrase() == "" diff --git a/tests/click_tests/test_passphrase_delizia.py b/tests/click_tests/test_passphrase_delizia.py index 85bdc371735..fc7d79610e4 100644 --- a/tests/click_tests/test_passphrase_delizia.py +++ b/tests/click_tests/test_passphrase_delizia.py @@ -97,7 +97,7 @@ def prepare_passphrase_dialogue( device_handler: "BackgroundDeviceHandler", address: Optional[str] = None ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() - device_handler.run(get_test_address) # type: ignore + device_handler.run_with_session(get_test_address) # type: ignore # TODO assert debug.read_layout().main_component() == "PassphraseKeyboard" diff --git a/tests/click_tests/test_pin.py b/tests/click_tests/test_pin.py index d810910dc8d..81ec77d742a 100644 --- a/tests/click_tests/test_pin.py +++ b/tests/click_tests/test_pin.py @@ -23,6 +23,7 @@ from trezorlib import device, exceptions from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from .. import buttons from .. import translations as TR @@ -91,17 +92,19 @@ def prepare( tap = False + Session(device_handler.client.get_seedless_session()).lock() + # Setup according to the wanted situation if situation == Situation.PIN_INPUT: # Any action triggering the PIN dialogue - device_handler.run(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore + device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore tap = True if situation == Situation.PIN_INPUT_CANCEL: # Any action triggering the PIN dialogue - device_handler.run(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore + device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore elif situation == Situation.PIN_SETUP: # Set new PIN - device_handler.run(device.change_pin) # type: ignore + device_handler.run_with_session(device.change_pin) # type: ignore assert ( TR.pin__turn_on in debug.read_layout().text_content() or TR.pin__info in debug.read_layout().text_content() @@ -115,14 +118,14 @@ def prepare( go_next(debug) elif situation == Situation.PIN_CHANGE: # Change PIN - device_handler.run(device.change_pin) # type: ignore + device_handler.run_with_session(device.change_pin) # type: ignore _input_see_confirm(debug, old_pin) assert TR.pin__change in debug.read_layout().text_content() go_next(debug) _input_see_confirm(debug, old_pin) elif situation == Situation.WIPE_CODE_SETUP: # Set wipe code - device_handler.run(device.change_wipe_code) # type: ignore + device_handler.run_with_session(device.change_wipe_code) # type: ignore if old_pin: _input_see_confirm(debug, old_pin) assert TR.wipe_code__turn_on in debug.read_layout().text_content() diff --git a/tests/click_tests/test_recovery.py b/tests/click_tests/test_recovery.py index 7161e774dda..ee091b6a63c 100644 --- a/tests/click_tests/test_recovery.py +++ b/tests/click_tests/test_recovery.py @@ -40,7 +40,7 @@ def prepare_recovery_and_evaluate( features = device_handler.features() debug = device_handler.debuglink() assert features.initialized is False - device_handler.run(device.recover, pin_protection=False) # type: ignore + device_handler.run_with_session(device.recover, pin_protection=False) # type: ignore yield debug diff --git a/tests/click_tests/test_repeated_backup.py b/tests/click_tests/test_repeated_backup.py index e0d62b53d2a..9723fb6d2d8 100644 --- a/tests/click_tests/test_repeated_backup.py +++ b/tests/click_tests/test_repeated_backup.py @@ -41,7 +41,7 @@ def test_repeated_backup( assert features.initialized is False - device_handler.run( + device_handler.run_with_session( device.setup, strength=128, backup_type=messages.BackupType.Slip39_Basic, @@ -97,7 +97,7 @@ def test_repeated_backup( assert features.recovery_status == messages.RecoveryStatus.Nothing # run recovery to unlock backup - device_handler.run( + device_handler.run_with_session( device.recover, type=messages.RecoveryType.UnlockRepeatedBackup, ) @@ -164,7 +164,7 @@ def test_repeated_backup( assert features.recovery_status == messages.RecoveryStatus.Nothing # try to unlock backup again... - device_handler.run( + device_handler.run_with_session( device.recover, type=messages.RecoveryType.UnlockRepeatedBackup, ) diff --git a/tests/click_tests/test_reset_bip39.py b/tests/click_tests/test_reset_bip39.py index d405f514413..2d9d400cb23 100644 --- a/tests/click_tests/test_reset_bip39.py +++ b/tests/click_tests/test_reset_bip39.py @@ -39,7 +39,7 @@ def test_reset_bip39(device_handler: "BackgroundDeviceHandler"): assert features.initialized is False - device_handler.run( + device_handler.run_with_session( device.setup, strength=128, backup_type=messages.BackupType.Bip39, diff --git a/tests/click_tests/test_reset_slip39_advanced.py b/tests/click_tests/test_reset_slip39_advanced.py index 77f68269e69..a93eb71cec6 100644 --- a/tests/click_tests/test_reset_slip39_advanced.py +++ b/tests/click_tests/test_reset_slip39_advanced.py @@ -51,7 +51,7 @@ def test_reset_slip39_advanced( assert features.initialized is False - device_handler.run( + device_handler.run_with_session( device.setup, backup_type=messages.BackupType.Slip39_Advanced, pin_protection=False, diff --git a/tests/click_tests/test_reset_slip39_basic.py b/tests/click_tests/test_reset_slip39_basic.py index d18bf7a4fba..d13d7ba884b 100644 --- a/tests/click_tests/test_reset_slip39_basic.py +++ b/tests/click_tests/test_reset_slip39_basic.py @@ -47,7 +47,7 @@ def test_reset_slip39_basic( assert features.initialized is False - device_handler.run( + device_handler.run_with_session( device.setup, strength=128, backup_type=messages.BackupType.Slip39_Basic, diff --git a/tests/click_tests/test_tutorial_caesar.py b/tests/click_tests/test_tutorial_caesar.py index 2394b0a1022..de0e010f1dd 100644 --- a/tests/click_tests/test_tutorial_caesar.py +++ b/tests/click_tests/test_tutorial_caesar.py @@ -39,7 +39,7 @@ def prepare_tutorial_and_cancel_after_it( device_handler: "BackgroundDeviceHandler", cancelled: bool = False ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) yield debug diff --git a/tests/click_tests/test_tutorial_delizia.py b/tests/click_tests/test_tutorial_delizia.py index 08902162f68..1f0db5a63b3 100644 --- a/tests/click_tests/test_tutorial_delizia.py +++ b/tests/click_tests/test_tutorial_delizia.py @@ -36,7 +36,7 @@ def test_tutorial_ignore_menu(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) assert debug.read_layout().title() == TR.tutorial__welcome_safe5 debug.click(buttons.TAP_TO_CONFIRM) @@ -56,7 +56,7 @@ def test_tutorial_ignore_menu(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_open_close(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) assert debug.read_layout().title() == TR.tutorial__welcome_safe5 debug.click(buttons.TAP_TO_CONFIRM) @@ -82,7 +82,7 @@ def test_tutorial_menu_open_close(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_exit(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) assert debug.read_layout().title() == TR.tutorial__welcome_safe5 debug.click(buttons.TAP_TO_CONFIRM) @@ -105,7 +105,7 @@ def test_tutorial_menu_exit(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_repeat(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) assert debug.read_layout().title() == TR.tutorial__welcome_safe5 debug.click(buttons.TAP_TO_CONFIRM) @@ -135,7 +135,7 @@ def test_tutorial_menu_repeat(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_funfact(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) assert debug.read_layout().title() == TR.tutorial__welcome_safe5 debug.click(buttons.TAP_TO_CONFIRM) diff --git a/tests/common.py b/tests/common.py index 4d2151ce08e..9e938d4ddf9 100644 --- a/tests/common.py +++ b/tests/common.py @@ -34,8 +34,8 @@ from _pytest.mark.structures import MarkDecorator from trezorlib.debuglink import DebugLink - from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import ButtonRequest + from trezorlib.transport.session import Session PRIVATE_KEYS_DEV = [byte * 32 for byte in (b"\xdd", b"\xde", b"\xdf")] @@ -338,10 +338,10 @@ def check_pin_backoff_time(attempts: int, start: float) -> None: assert got >= expected -def get_test_address(client: "Client") -> str: +def get_test_address(session: "Session") -> str: """Fetch a testnet address on a fixed path. Useful to make a pin/passphrase protected call, or to identify the root secret (seed+passphrase)""" - return btc.get_address(client, "Testnet", TEST_ADDRESS_N) + return btc.get_address(session, "Testnet", TEST_ADDRESS_N) def compact_size(n: int) -> bytes: @@ -380,5 +380,5 @@ def swipe_till_the_end(debug: "DebugLink", br: messages.ButtonRequest) -> None: debug.swipe_up() -def is_core(client: "Client") -> bool: - return client.model is not models.T1B1 +def is_core(session: "Session") -> bool: + return session.model is not models.T1B1 diff --git a/tests/conftest.py b/tests/conftest.py index 05e6622e1b7..66417c70ee0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,17 +20,22 @@ import typing as t from enum import IntEnum from pathlib import Path +from time import sleep +import cryptography import pytest import xdist from _pytest.python import IdMaker from _pytest.reports import TestReport from trezorlib import debuglink, log, models +from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.device import apply_settings from trezorlib.device import wipe as wipe_device from trezorlib.transport import enumerate_devices, get_transport +from trezorlib.transport.thp.protocol_v1 import ProtocolV1 # register rewrites before importing from local package # so that we see details of failed asserts from this module @@ -135,6 +140,10 @@ def _get_port() -> int: @pytest.fixture(scope="session") def _raw_client(request: pytest.FixtureRequest) -> Client: + return _get_raw_client(request) + + +def _get_raw_client(request: pytest.FixtureRequest) -> Client: # In case tests run in parallel, each process has its own emulator/client. # Requesting the emulator fixture only if relevant. if request.session.config.getoption("control_emulators"): @@ -165,10 +174,7 @@ def _client_from_path( def _find_client(request: pytest.FixtureRequest, interact: bool) -> Client: devices = enumerate_devices() for device in devices: - try: - return Client(device, auto_interact=not interact) - except Exception: - pass + return Client(device, auto_interact=not interact) request.session.shouldstop = "Failed to communicate with Trezor" raise RuntimeError("No debuggable device found") @@ -243,7 +249,7 @@ def _set_from_marker_list( @pytest.fixture(scope="function") -def client( +def _client_unlocked( request: pytest.FixtureRequest, _raw_client: Client ) -> t.Generator[Client, None, None]: """Client fixture. @@ -273,6 +279,29 @@ def client( if _raw_client.model not in models_filter: pytest.skip(f"Skipping test for model {_raw_client.model.internal_name}") + protocol_marker: Mark | None = request.node.get_closest_marker("protocol") + if protocol_marker: + args = protocol_marker.args + protocol_version = _raw_client.protocol_version + + if ( + protocol_version == ProtocolVersion.PROTOCOL_V1 + and "protocol_v1" not in args + ): + pytest.skip( + f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported." + ) + + if ( + protocol_version == ProtocolVersion.PROTOCOL_V2 + and "protocol_v2" not in args + ): + pytest.skip( + f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported." + ) + + if _raw_client.protocol_version is ProtocolVersion.PROTOCOL_V2: + pass sd_marker = request.node.get_closest_marker("sd_card") if sd_marker and not _raw_client.features.sd_card_present: raise RuntimeError( @@ -283,14 +312,15 @@ def client( test_ui = request.config.getoption("ui") - _raw_client.reset_debug_features() + _raw_client.reset_debug_features(new_seedless_session=True) _raw_client.open() - try: - _raw_client.sync_responses() - _raw_client.init_device() - except Exception: - request.session.shouldstop = "Failed to communicate with Trezor" - pytest.fail("Failed to communicate with Trezor") + if isinstance(_raw_client.protocol, ProtocolV1): + try: + _raw_client.sync_responses() + # TODO _raw_client.init_device() + except Exception: + request.session.shouldstop = "Failed to communicate with Trezor" + pytest.fail("Failed to communicate with Trezor") # Resetting all the debug events to not be influenced by previous test _raw_client.debug.reset_debug_events() @@ -303,13 +333,36 @@ def client( should_format = sd_marker.kwargs.get("formatted", True) _raw_client.debug.erase_sd_card(format=should_format) - wipe_device(_raw_client) + while True: + try: + if _raw_client.is_invalidated: + _raw_client = _get_raw_client(request) + session = _raw_client.get_seedless_session() + wipe_device(session) + sleep(1.5) # Makes tests more stable (wait for wipe to finish) + break + except cryptography.exceptions.InvalidTag: + # Get a new client + _raw_client = _get_raw_client(request) + + from trezorlib.transport.thp.channel_database import get_channel_db + + get_channel_db().clear_stored_channels() + _raw_client.protocol = None + _raw_client.__init__( + transport=_raw_client.transport, + auto_interact=_raw_client.debug.allow_interactions, + ) + if not _raw_client.features.bootloader_mode: + _raw_client.refresh_features() # Load language again, as it got erased in wipe if _raw_client.model is not models.T1B1: lang = request.session.config.getoption("lang") or "en" assert isinstance(lang, str) - translations.set_language(_raw_client, lang) + translations.set_language( + SessionDebugWrapper(_raw_client.get_seedless_session()), lang + ) setup_params = dict( uninitialized=False, @@ -327,10 +380,10 @@ def client( use_passphrase = setup_params["passphrase"] is True or isinstance( setup_params["passphrase"], str ) - if not setup_params["uninitialized"]: + session = _raw_client.get_seedless_session(new_session=True) debuglink.load_device( - _raw_client, + session, mnemonic=setup_params["mnemonic"], # type: ignore pin=setup_params["pin"], # type: ignore passphrase_protection=use_passphrase, @@ -338,21 +391,56 @@ def client( needs_backup=setup_params["needs_backup"], # type: ignore no_backup=setup_params["no_backup"], # type: ignore ) + _raw_client._setup_pin = setup_params["pin"] if request.node.get_closest_marker("experimental"): - apply_settings(_raw_client, experimental_features=True) + apply_settings(session, experimental_features=True) if use_passphrase and isinstance(setup_params["passphrase"], str): _raw_client.use_passphrase(setup_params["passphrase"]) - _raw_client.clear_session() + # TODO _raw_client.clear_session() - with ui_tests.screen_recording(_raw_client, request): - yield _raw_client + yield _raw_client _raw_client.close() +@pytest.fixture(scope="function") +def client( + request: pytest.FixtureRequest, _client_unlocked: Client +) -> t.Generator[Client, None, None]: + _client_unlocked.lock() + with ui_tests.screen_recording(_client_unlocked, request): + yield _client_unlocked + + +@pytest.fixture(scope="function") +def session( + request: pytest.FixtureRequest, _client_unlocked: Client +) -> t.Generator[SessionDebugWrapper, None, None]: + if bool(request.node.get_closest_marker("uninitialized_session")): + session = _client_unlocked.get_seedless_session() + else: + derive_cardano = bool(request.node.get_closest_marker("cardano")) + passphrase = _client_unlocked.passphrase or "" + if _client_unlocked._setup_pin is not None: + _client_unlocked.use_pin_sequence([_client_unlocked._setup_pin]) + session = _client_unlocked.get_session( + derive_cardano=derive_cardano, passphrase=passphrase + ) + try: + wrapped_session = SessionDebugWrapper(session) + if _client_unlocked._setup_pin is not None: + wrapped_session.lock() + with ui_tests.screen_recording(_client_unlocked, request): + yield wrapped_session + finally: + pass + # TODO + # session.end() + + def _is_main_runner(session_or_request: pytest.Session | pytest.FixtureRequest) -> bool: """Return True if the current process is the main test runner. @@ -462,6 +550,10 @@ def pytest_configure(config: "Config") -> None: "markers", 'setup_client(mnemonic="all all all...", pin=None, passphrase=False, uninitialized=False): configure the client instance', ) + config.addinivalue_line( + "markers", + "uninitialized_session: use uninitialized session instance", + ) with open(os.path.join(os.path.dirname(__file__), "REGISTERED_MARKERS")) as f: for line in f: config.addinivalue_line("markers", line.strip()) diff --git a/tests/device_handler.py b/tests/device_handler.py index 74eb77a5a52..4df01f9c844 100644 --- a/tests/device_handler.py +++ b/tests/device_handler.py @@ -52,7 +52,7 @@ def _configure_client(self, client: "Client") -> None: self.client.watch_layout(True) self.client.debug.input_wait_type = DebugWaitType.CURRENT_LAYOUT - def run( + def run_with_session( self, function: t.Callable[tx.Concatenate["Client", P], t.Any], *args: P.args, @@ -67,15 +67,34 @@ def run( # wait for the first UI change triggered by the task running in the background with self.debuglink().wait_for_layout_change(): - self.task = self._pool.submit(function, self.client, *args, **kwargs) + session = self.client.get_session() + self.task = self._pool.submit(function, session, *args, **kwargs) + + def run_with_provided_session( + self, + session, + function: t.Callable[tx.Concatenate["Client", P], t.Any], + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + """Runs some function that interacts with a device. + + Makes sure the UI is updated before returning. + """ + if self.task is not None: + raise RuntimeError("Wait for previous task first") + + # wait for the first UI change triggered by the task running in the background + with self.debuglink().wait_for_layout_change(): + self.task = self._pool.submit(function, session, *args, **kwargs) def kill_task(self) -> None: if self.task is not None: # Force close the client, which should raise an exception in a client # waiting on IO. Does not work over Bridge, because bridge doesn't have # a close() method. - while self.client.session_counter > 0: - self.client.close() + # while self.client.session_counter > 0: + # self.client.close() try: self.task.result(timeout=1) except Exception: @@ -99,7 +118,7 @@ def result(self, timeout: float | None = None) -> t.Any: def features(self) -> "Features": if self.task is not None: raise RuntimeError("Cannot query features while task is running") - self.client.init_device() + self.client.refresh_features() return self.client.features def debuglink(self) -> "DebugLink": diff --git a/tests/device_tests/binance/test_get_address.py b/tests/device_tests/binance/test_get_address.py index cdb6e722713..6b5a0247676 100644 --- a/tests/device_tests/binance/test_get_address.py +++ b/tests/device_tests/binance/test_get_address.py @@ -17,7 +17,7 @@ import pytest from trezorlib.binance import get_address -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...input_flows import InputFlowShowAddressQRCode @@ -38,23 +38,23 @@ @pytest.mark.parametrize("path, expected_address", BINANCE_ADDRESS_TEST_VECTORS) -def test_binance_get_address(client: Client, path: str, expected_address: str): +def test_binance_get_address(session: Session, path: str, expected_address: str): # data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50 - address = get_address(client, parse_path(path), show_display=True) + address = get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", BINANCE_ADDRESS_TEST_VECTORS) def test_binance_get_address_chunkify_details( - client: Client, path: str, expected_address: str + session: Session, path: str, expected_address: str ): # data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50 - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address diff --git a/tests/device_tests/binance/test_get_public_key.py b/tests/device_tests/binance/test_get_public_key.py index ea04fdbd88f..f65baa5dd83 100644 --- a/tests/device_tests/binance/test_get_public_key.py +++ b/tests/device_tests/binance/test_get_public_key.py @@ -17,7 +17,7 @@ import pytest from trezorlib import binance -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...input_flows import InputFlowShowXpubQRCode @@ -31,11 +31,11 @@ @pytest.mark.setup_client( mnemonic="offer caution gift cross surge pretty orange during eye soldier popular holiday mention east eight office fashion ill parrot vault rent devote earth cousin" ) -def test_binance_get_public_key(client: Client): - with client: +def test_binance_get_public_key(session: Session): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) - sig = binance.get_public_key(client, BINANCE_PATH, show_display=True) + sig = binance.get_public_key(session, BINANCE_PATH, show_display=True) assert ( sig.hex() == "029729a52e4e3c2b4a4e52aa74033eedaf8ba1df5ab6d1f518fd69e67bbd309b0e" diff --git a/tests/device_tests/binance/test_sign_tx.py b/tests/device_tests/binance/test_sign_tx.py index ceb06924650..1665e005a46 100644 --- a/tests/device_tests/binance/test_sign_tx.py +++ b/tests/device_tests/binance/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import binance -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path BINANCE_TEST_VECTORS = [ @@ -110,10 +110,10 @@ @pytest.mark.parametrize("message, expected_response", BINANCE_TEST_VECTORS) @pytest.mark.parametrize("chunkify", (True, False)) def test_binance_sign_message( - client: Client, chunkify: bool, message: dict, expected_response: dict + session: Session, chunkify: bool, message: dict, expected_response: dict ): response = binance.sign_tx( - client, parse_path("m/44h/714h/0h/0/0"), message, chunkify=chunkify + session, parse_path("m/44h/714h/0h/0/0"), message, chunkify=chunkify ) assert response.public_key.hex() == expected_response["public_key"] diff --git a/tests/device_tests/bitcoin/payment_req.py b/tests/device_tests/bitcoin/payment_req.py index 73d98859ba1..f928a5fa8e8 100644 --- a/tests/device_tests/bitcoin/payment_req.py +++ b/tests/device_tests/bitcoin/payment_req.py @@ -4,6 +4,7 @@ from ecdsa import SECP256k1, SigningKey from trezorlib import btc, messages +from trezorlib.transport.session import Session from ...common import compact_size @@ -27,7 +28,12 @@ def hash_bytes_prefixed(hasher, data): def make_payment_request( - client, recipient_name, outputs, change_addresses=None, memos=None, nonce=None + session: Session, + recipient_name, + outputs, + change_addresses=None, + memos=None, + nonce=None, ): h_pr = sha256(b"SL\x00\x24") @@ -52,7 +58,7 @@ def make_payment_request( hash_bytes_prefixed(h_pr, memo.text.encode()) elif isinstance(memo, RefundMemo): address_resp = btc.get_authenticated_address( - client, "Testnet", memo.address_n + session, "Testnet", memo.address_n ) msg_memo = messages.RefundMemo( address=address_resp.address, mac=address_resp.mac @@ -63,7 +69,7 @@ def make_payment_request( hash_bytes_prefixed(h_pr, address_resp.address.encode()) elif isinstance(memo, CoinPurchaseMemo): address_resp = btc.get_authenticated_address( - client, memo.coin_name, memo.address_n + session, memo.coin_name, memo.address_n ) msg_memo = messages.CoinPurchaseMemo( coin_type=memo.slip44, diff --git a/tests/device_tests/bitcoin/test_authorize_coinjoin.py b/tests/device_tests/bitcoin/test_authorize_coinjoin.py index 15028d83b3e..8c0e7a4484c 100644 --- a/tests/device_tests/bitcoin/test_authorize_coinjoin.py +++ b/tests/device_tests/bitcoin/test_authorize_coinjoin.py @@ -19,6 +19,7 @@ import pytest from trezorlib import btc, device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -59,15 +60,15 @@ @pytest.mark.parametrize("chunkify", (True, False)) @pytest.mark.setup_client(pin=PIN) -def test_sign_tx(client: Client, chunkify: bool): +def test_sign_tx(session: Session, chunkify: bool): # NOTE: FAKE input tx - + assert session.features.unlocked is False commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big") - with client: + with session.client as client: client.use_pin_sequence([PIN]) btc.authorize_coinjoin( - client, + session, coordinator="www.example.com", max_rounds=2, max_coordinator_fee_rate=500_000, # 0.5 % @@ -77,14 +78,14 @@ def test_sign_tx(client: Client, chunkify: bool): script_type=messages.InputScriptType.SPENDTAPROOT, ) - client.call(messages.LockDevice()) + session.call(messages.LockDevice()) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [messages.PreauthorizedRequest, messages.OwnershipProof] ) btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -93,12 +94,12 @@ def test_sign_tx(client: Client, chunkify: bool): preauthorized=True, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [messages.PreauthorizedRequest, messages.OwnershipProof] ) btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/5"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -206,8 +207,8 @@ def test_sign_tx(client: Client, chunkify: bool): no_fee_indices=[], ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.PreauthorizedRequest(), request_input(0), @@ -222,7 +223,7 @@ def test_sign_tx(client: Client, chunkify: bool): ] ) signatures, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -243,7 +244,7 @@ def test_sign_tx(client: Client, chunkify: bool): # Test for a second time. btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -256,7 +257,7 @@ def test_sign_tx(client: Client, chunkify: bool): # Test for a third time, number of rounds should be exceeded. with pytest.raises(TrezorFailure, match="No preauthorized operation"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -267,7 +268,7 @@ def test_sign_tx(client: Client, chunkify: bool): ) -def test_sign_tx_large(client: Client): +def test_sign_tx_large(session: Session): # NOTE: FAKE input tx commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big") @@ -278,17 +279,16 @@ def test_sign_tx_large(client: Client): output_denom = 10_000 # sats max_expected_delay = 80 # seconds - with client: - btc.authorize_coinjoin( - client, - coordinator="www.example.com", - max_rounds=2, - max_coordinator_fee_rate=500_000, # 0.5 % - max_fee_per_kvbyte=3500, - n=parse_path("m/10025h/1h/0h/1h"), - coin_name="Testnet", - script_type=messages.InputScriptType.SPENDTAPROOT, - ) + btc.authorize_coinjoin( + session, + coordinator="www.example.com", + max_rounds=2, + max_coordinator_fee_rate=500_000, # 0.5 % + max_fee_per_kvbyte=3500, + n=parse_path("m/10025h/1h/0h/1h"), + coin_name="Testnet", + script_type=messages.InputScriptType.SPENDTAPROOT, + ) # INPUTS. @@ -399,22 +399,21 @@ def test_sign_tx_large(client: Client): ) start = time.time() - with client: - btc.sign_tx( - client, - "Testnet", - inputs, - outputs, - prev_txes=TX_CACHE_TESTNET, - coinjoin_request=coinjoin_req, - preauthorized=True, - serialize=False, - ) + btc.sign_tx( + session, + "Testnet", + inputs, + outputs, + prev_txes=TX_CACHE_TESTNET, + coinjoin_request=coinjoin_req, + preauthorized=True, + serialize=False, + ) delay = time.time() - start assert delay <= max_expected_delay -def test_sign_tx_spend(client: Client): +def test_sign_tx_spend(session: Session): # NOTE: FAKE input tx inputs = [ @@ -446,15 +445,15 @@ def test_sign_tx_spend(client: Client): # Ensure that Trezor refuses to spend from CoinJoin without user authorization. with pytest.raises(TrezorFailure, match="Forbidden key path"): _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, prev_txes=TX_CACHE_TESTNET, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, @@ -462,7 +461,7 @@ def test_sign_tx_spend(client: Client): request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -472,7 +471,7 @@ def test_sign_tx_spend(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -487,7 +486,7 @@ def test_sign_tx_spend(client: Client): ) -def test_sign_tx_migration(client: Client): +def test_sign_tx_migration(session: Session): inputs = [ messages.TxInputType( address_n=parse_path("m/84h/1h/3h/0/12"), @@ -520,15 +519,15 @@ def test_sign_tx_migration(client: Client): # Ensure that Trezor refuses to receive to CoinJoin path without the user first authorizing access to CoinJoin paths. with pytest.raises(TrezorFailure, match="Forbidden key path"): _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, prev_txes=TX_CACHE_TESTNET, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, @@ -536,7 +535,7 @@ def test_sign_tx_migration(client: Client): request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_2cc3c1), @@ -558,7 +557,7 @@ def test_sign_tx_migration(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -573,11 +572,11 @@ def test_sign_tx_migration(client: Client): ) -def test_wrong_coordinator(client: Client): +def test_wrong_coordinator(session: Session): # Ensure that a preauthorized GetOwnershipProof fails if the commitment_data doesn't match the coordinator. btc.authorize_coinjoin( - client, + session, coordinator="www.example.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -589,7 +588,7 @@ def test_wrong_coordinator(client: Client): with pytest.raises(TrezorFailure, match="Unauthorized operation"): btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -599,9 +598,9 @@ def test_wrong_coordinator(client: Client): ) -def test_wrong_account_type(client: Client): +def test_wrong_account_type(session: Session): params = { - "client": client, + "session": session, "coordinator": "www.example.com", "max_rounds": 10, "max_coordinator_fee_rate": 500_000, # 0.5 % @@ -625,11 +624,11 @@ def test_wrong_account_type(client: Client): ) -def test_cancel_authorization(client: Client): +def test_cancel_authorization(session: Session): # Ensure that a preauthorized GetOwnershipProof fails if the commitment_data doesn't match the coordinator. btc.authorize_coinjoin( - client, + session, coordinator="www.example.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -639,11 +638,11 @@ def test_cancel_authorization(client: Client): script_type=messages.InputScriptType.SPENDTAPROOT, ) - device.cancel_authorization(client) + device.cancel_authorization(session) with pytest.raises(TrezorFailure, match="No preauthorized operation"): btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -653,35 +652,35 @@ def test_cancel_authorization(client: Client): ) -def test_get_public_key(client: Client): +def test_get_public_key(session: Session): ACCOUNT_PATH = parse_path("m/10025h/1h/0h/1h") EXPECTED_XPUB = "tpubDEMKm4M3S2Grx5DHTfbX9et5HQb9KhdjDCkUYdH9gvVofvPTE6yb2MH52P9uc4mx6eFohUmfN1f4hhHNK28GaZnWRXr3b8KkfFcySo1SmXU" # Ensure that user cannot access SLIP-25 path without UnlockPath. with pytest.raises(TrezorFailure, match="Forbidden key path"): resp = btc.get_public_node( - client, + session, ACCOUNT_PATH, coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, ) # Get unlock path MAC. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, messages.Failure(code=messages.FailureType.ActionCancelled), ] ) - unlock_path_mac = device.unlock_path(client, n=SLIP25_PATH) + unlock_path_mac = device.unlock_path(session, n=SLIP25_PATH) # Ensure that UnlockPath fails with invalid MAC. invalid_unlock_path_mac = bytes([unlock_path_mac[0] ^ 1]) + unlock_path_mac[1:] with pytest.raises(TrezorFailure, match="Invalid MAC"): resp = btc.get_public_node( - client, + session, ACCOUNT_PATH, coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, @@ -690,15 +689,15 @@ def test_get_public_key(client: Client): ) # Ensure that user does not need to confirm access when path unlock is requested with MAC. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.UnlockedPathRequest, messages.PublicKey, ] ) resp = btc.get_public_node( - client, + session, ACCOUNT_PATH, coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, @@ -708,11 +707,12 @@ def test_get_public_key(client: Client): assert resp.xpub == EXPECTED_XPUB -def test_get_address(client: Client): +def test_get_address(session: Session): + # Ensure that the SLIP-0025 external chain is inaccessible without user confirmation. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/0/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -720,20 +720,20 @@ def test_get_address(client: Client): ) # Unlock CoinJoin path. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, messages.Failure(code=messages.FailureType.ActionCancelled), ] ) - unlock_path_mac = device.unlock_path(client, SLIP25_PATH) + unlock_path_mac = device.unlock_path(session, SLIP25_PATH) # Ensure that the SLIP-0025 external chain is accessible after user confirmation. for chunkify in (True, False): resp = btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/0/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -745,7 +745,7 @@ def test_get_address(client: Client): assert resp == "tb1pl3y9gf7xk2ryvmav5ar66ra0d2hk7lhh9mmusx3qvn0n09kmaghqh32ru7" resp = btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/0/1"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -758,7 +758,7 @@ def test_get_address(client: Client): # Ensure that the SLIP-0025 internal chain is inaccessible even with user authorization. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -769,7 +769,7 @@ def test_get_address(client: Client): with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/1"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -781,7 +781,7 @@ def test_get_address(client: Client): # Ensure that another SLIP-0025 account is inaccessible with the same MAC. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/1h/1h/0/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -793,8 +793,10 @@ def test_get_address(client: Client): def test_multisession_authorization(client: Client): # Authorize CoinJoin with www.example1.com in session 1. + session1 = client.get_session(session_id=1) + btc.authorize_coinjoin( - client, + session1, coordinator="www.example1.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -803,14 +805,14 @@ def test_multisession_authorization(client: Client): coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, ) - + session2 = client.get_session(session_id=2) # Open a second session. - session_id1 = client.session_id - client.init_device(new_session=True) + # session_id1 = session.session_id + # TODO client.init_device(new_session=True) # Authorize CoinJoin with www.example2.com in session 2. btc.authorize_coinjoin( - client, + session2, coordinator="www.example2.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -823,7 +825,7 @@ def test_multisession_authorization(client: Client): # Requesting a preauthorized ownership proof for www.example1.com should fail in session 2. with pytest.raises(TrezorFailure, match="Unauthorized operation"): ownership_proof, _ = btc.get_ownership_proof( - client, + session2, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -834,7 +836,7 @@ def test_multisession_authorization(client: Client): # Requesting a preauthorized ownership proof for www.example2.com should succeed in session 2. ownership_proof, _ = btc.get_ownership_proof( - client, + session2, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -849,12 +851,12 @@ def test_multisession_authorization(client: Client): ) # Switch back to the first session. - session_id2 = client.session_id - client.init_device(session_id=session_id1) - + # session_id2 = session.session_id + # TODO client.init_device(session_id=session_id1) + client.resume_session(session1) # Requesting a preauthorized ownership proof for www.example1.com should succeed in session 1. ownership_proof, _ = btc.get_ownership_proof( - client, + session1, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -871,7 +873,7 @@ def test_multisession_authorization(client: Client): # Requesting a preauthorized ownership proof for www.example2.com should fail in session 1. with pytest.raises(TrezorFailure, match="Unauthorized operation"): ownership_proof, _ = btc.get_ownership_proof( - client, + session1, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -881,12 +883,12 @@ def test_multisession_authorization(client: Client): ) # Cancel the authorization in session 1. - device.cancel_authorization(client) + device.cancel_authorization(session1) # Requesting a preauthorized ownership proof should fail now. with pytest.raises(TrezorFailure, match="No preauthorized operation"): ownership_proof, _ = btc.get_ownership_proof( - client, + session1, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -896,11 +898,11 @@ def test_multisession_authorization(client: Client): ) # Switch to the second session. - client.init_device(session_id=session_id2) - + # TODO client.init_device(session_id=session_id2) + client.resume_session(session2) # Requesting a preauthorized ownership proof for www.example2.com should still succeed in session 2. ownership_proof, _ = btc.get_ownership_proof( - client, + session2, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, diff --git a/tests/device_tests/bitcoin/test_bcash.py b/tests/device_tests/bitcoin/test_bcash.py index 76538828632..d1f0129741c 100644 --- a/tests/device_tests/bitcoin/test_bcash.py +++ b/tests/device_tests/bitcoin/test_bcash.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -53,7 +53,7 @@ pytestmark = pytest.mark.altcoin -def test_send_bch_change(client: Client): +def test_send_bch_change(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/145h/0h/0/0"), # bitcoincash:qr08q88p9etk89wgv05nwlrkm4l0urz4cyl36hh9sv @@ -72,14 +72,14 @@ def test_send_bch_change(client: Client): amount=73_452, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_bc37c2), @@ -92,9 +92,9 @@ def test_send_bch_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API + session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API ) - + # raise Exception(hexlify(serialized_tx)) assert_tx_matches( serialized_tx, hash_link="https://bch1.trezor.io/api/tx/502e8577b237b0152843a416f8f1ab0c63321b1be7a8cad7bf5c5c216fcf062c", @@ -102,7 +102,7 @@ def test_send_bch_change(client: Client): ) -def test_send_bch_nochange(client: Client): +def test_send_bch_nochange(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/145h/0h/1/0"), # bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw @@ -124,14 +124,14 @@ def test_send_bch_nochange(client: Client): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_502e85), @@ -150,7 +150,7 @@ def test_send_bch_nochange(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API ) assert_tx_matches( @@ -160,7 +160,7 @@ def test_send_bch_nochange(client: Client): ) -def test_send_bch_oldaddr(client: Client): +def test_send_bch_oldaddr(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/145h/0h/1/0"), # bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw @@ -182,14 +182,14 @@ def test_send_bch_oldaddr(client: Client): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_502e85), @@ -208,7 +208,7 @@ def test_send_bch_oldaddr(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API ) assert_tx_matches( @@ -218,7 +218,7 @@ def test_send_bch_oldaddr(client: Client): ) -def test_attack_change_input(client: Client): +def test_attack_change_input(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -252,15 +252,15 @@ def attack_processor(msg): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_bd32ff), @@ -271,16 +271,16 @@ def attack_processor(msg): ] ) with pytest.raises(TrezorFailure): - btc.sign_tx(client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API) + btc.sign_tx(session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API) @pytest.mark.multisig -def test_send_bch_multisig_wrongchange(client: Client): +def test_send_bch_multisig_wrongchange(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" + session, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" ).node for i in range(1, 4) ] @@ -327,13 +327,13 @@ def getmultisig(chain, nr, signatures): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=23_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_062fbd), @@ -346,7 +346,7 @@ def getmultisig(chain, nr, signatures): ] ) (signatures1, serialized_tx) = btc.sign_tx( - client, "Bcash", [inp1], [out1], prev_txes=TX_API + session, "Bcash", [inp1], [out1], prev_txes=TX_API ) assert ( signatures1[0].hex() @@ -359,12 +359,12 @@ def getmultisig(chain, nr, signatures): @pytest.mark.multisig -def test_send_bch_multisig_change(client: Client): +def test_send_bch_multisig_change(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" + session, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" ).node for i in range(1, 4) ] @@ -395,13 +395,13 @@ def getmultisig(chain, nr, signatures): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=24_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -415,7 +415,7 @@ def getmultisig(chain, nr, signatures): ] ) (signatures1, serialized_tx) = btc.sign_tx( - client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API + session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -434,13 +434,13 @@ def getmultisig(chain, nr, signatures): ) out2.address_n[2] = H_(1) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -454,7 +454,7 @@ def getmultisig(chain, nr, signatures): ] ) (signatures1, serialized_tx) = btc.sign_tx( - client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API + session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -468,7 +468,7 @@ def getmultisig(chain, nr, signatures): @pytest.mark.models("core") -def test_send_bch_external_presigned(client: Client): +def test_send_bch_external_presigned(session: Session): inp1 = messages.TxInputType( # address_n=parse_path("44'/145'/0'/1/0"), # bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw @@ -496,14 +496,14 @@ def test_send_bch_external_presigned(client: Client): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_502e85), @@ -522,7 +522,7 @@ def test_send_bch_external_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API ) assert_tx_matches( diff --git a/tests/device_tests/bitcoin/test_bgold.py b/tests/device_tests/bitcoin/test_bgold.py index 71c1a6c3ad4..831ea216cbd 100644 --- a/tests/device_tests/bitcoin/test_bgold.py +++ b/tests/device_tests/bitcoin/test_bgold.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path, tx_hash @@ -51,7 +51,7 @@ # All data taken from T1 -def test_send_bitcoin_gold_change(client: Client): +def test_send_bitcoin_gold_change(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -71,14 +71,14 @@ def test_send_bitcoin_gold_change(client: Client): amount=1_252_382_934 - 1_896_050 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -92,7 +92,7 @@ def test_send_bitcoin_gold_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -101,7 +101,7 @@ def test_send_bitcoin_gold_change(client: Client): ) -def test_send_bitcoin_gold_nochange(client: Client): +def test_send_bitcoin_gold_nochange(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -124,14 +124,14 @@ def test_send_bitcoin_gold_nochange(client: Client): amount=1_252_382_934 + 38_448_607 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -150,7 +150,7 @@ def test_send_bitcoin_gold_nochange(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( @@ -159,7 +159,7 @@ def test_send_bitcoin_gold_nochange(client: Client): ) -def test_attack_change_input(client: Client): +def test_attack_change_input(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -193,15 +193,15 @@ def attack_processor(msg): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -213,16 +213,16 @@ def attack_processor(msg): ] ) with pytest.raises(TrezorFailure): - btc.sign_tx(client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API) + btc.sign_tx(session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API) @pytest.mark.multisig -def test_send_btg_multisig_change(client: Client): +def test_send_btg_multisig_change(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/156h/{i}h/0h"), coin_name="Bgold" + session, parse_path(f"m/48h/156h/{i}h/0h"), coin_name="Bgold" ).node for i in range(1, 4) ] @@ -254,13 +254,13 @@ def getmultisig(chain, nr, signatures): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=1_252_382_934 - 24_000 - 1_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -275,7 +275,7 @@ def getmultisig(chain, nr, signatures): ] ) signatures, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -293,13 +293,13 @@ def getmultisig(chain, nr, signatures): ) out2.address_n[2] = H_(1) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -314,7 +314,7 @@ def getmultisig(chain, nr, signatures): ] ) signatures, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -327,7 +327,7 @@ def getmultisig(chain, nr, signatures): ) -def test_send_p2sh(client: Client): +def test_send_p2sh(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -347,16 +347,16 @@ def test_send_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=1_252_382_934 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_db7239), @@ -371,7 +371,7 @@ def test_send_p2sh(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -380,7 +380,7 @@ def test_send_p2sh(client: Client): ) -def test_send_p2sh_witness_change(client: Client): +def test_send_p2sh_witness_change(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -400,13 +400,13 @@ def test_send_p2sh_witness_change(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=1_252_382_934 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -422,7 +422,7 @@ def test_send_p2sh_witness_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -432,12 +432,12 @@ def test_send_p2sh_witness_change(client: Client): @pytest.mark.multisig -def test_send_multisig_1(client: Client): +def test_send_multisig_1(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/156h/{i}h"), coin_name="Bgold" + session, parse_path(f"m/49h/156h/{i}h"), coin_name="Bgold" ).node for i in range(1, 4) ] @@ -460,13 +460,13 @@ def test_send_multisig_1(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_7f1f6b), @@ -479,17 +479,17 @@ def test_send_multisig_1(client: Client): request_finished(), ] ) - signatures, _ = btc.sign_tx(client, "Bgold", [inp1], [out1], prev_txes=TX_API) + signatures, _ = btc.sign_tx(session, "Bgold", [inp1], [out1], prev_txes=TX_API) # store signature inp1.multisig.signatures[0] = signatures[0] # sign with third key inp1.address_n[2] = H_(3) - client.set_expected_responses( + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_7f1f6b), @@ -503,7 +503,7 @@ def test_send_multisig_1(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1], prev_txes=TX_API + session, "Bgold", [inp1], [out1], prev_txes=TX_API ) assert ( @@ -512,7 +512,7 @@ def test_send_multisig_1(client: Client): ) -def test_send_mixed_inputs(client: Client): +def test_send_mixed_inputs(session: Session): # NOTE: fake input tx used # First is non-segwit, second is segwit. @@ -537,9 +537,9 @@ def test_send_mixed_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( @@ -549,7 +549,7 @@ def test_send_mixed_inputs(client: Client): @pytest.mark.models("core") -def test_send_btg_external_presigned(client: Client): +def test_send_btg_external_presigned(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -577,14 +577,14 @@ def test_send_btg_external_presigned(client: Client): amount=1_252_382_934 + 58_456 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -603,7 +603,7 @@ def test_send_btg_external_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( diff --git a/tests/device_tests/bitcoin/test_dash.py b/tests/device_tests/bitcoin/test_dash.py index 4dde98bfbfd..06b335c1487 100644 --- a/tests/device_tests/bitcoin/test_dash.py +++ b/tests/device_tests/bitcoin/test_dash.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -43,7 +43,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.models("t1b1", "t2t1")] -def test_send_dash(client: Client): +def test_send_dash(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/5h/0h/0/0"), # dash:XdTw4G5AWW4cogGd7ayybyBNDbuB45UpgH @@ -57,13 +57,13 @@ def test_send_dash(client: Client): amount=999_999_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(inp1.prev_hash), @@ -77,7 +77,9 @@ def test_send_dash(client: Client): request_finished(), ] ) - _, serialized_tx = btc.sign_tx(client, "Dash", [inp1], [out1], prev_txes=TX_API) + _, serialized_tx = btc.sign_tx( + session, "Dash", [inp1], [out1], prev_txes=TX_API + ) assert ( serialized_tx.hex() @@ -85,7 +87,7 @@ def test_send_dash(client: Client): ) -def test_send_dash_dip2_input(client: Client): +def test_send_dash_dip2_input(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/5h/0h/0/0"), # dash:XdTw4G5AWW4cogGd7ayybyBNDbuB45UpgH @@ -104,14 +106,14 @@ def test_send_dash_dip2_input(client: Client): amount=95_000_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(inp1.prev_hash), @@ -128,7 +130,7 @@ def test_send_dash_dip2_input(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Dash", [inp1], [out1, out2], prev_txes=TX_API + session, "Dash", [inp1], [out1, out2], prev_txes=TX_API ) assert ( diff --git a/tests/device_tests/bitcoin/test_decred.py b/tests/device_tests/bitcoin/test_decred.py index 78bb1b0c3af..204d0559280 100644 --- a/tests/device_tests/bitcoin/test_decred.py +++ b/tests/device_tests/bitcoin/test_decred.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -57,7 +57,7 @@ ] -def test_send_decred(client: Client): +def test_send_decred(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -76,13 +76,13 @@ def test_send_decred(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.FeeOverThreshold), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -95,7 +95,7 @@ def test_send_decred(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Decred Testnet", [inp1], [out1], prev_txes=TX_API + session, "Decred Testnet", [inp1], [out1], prev_txes=TX_API ) assert ( @@ -105,7 +105,7 @@ def test_send_decred(client: Client): @pytest.mark.models("core") -def test_purchase_ticket_decred(client: Client): +def test_purchase_ticket_decred(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -133,8 +133,8 @@ def test_purchase_ticket_decred(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), @@ -153,7 +153,7 @@ def test_purchase_ticket_decred(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Decred Testnet", [inp1], [out1, out2, out3], @@ -168,7 +168,7 @@ def test_purchase_ticket_decred(client: Client): @pytest.mark.models("core") -def test_spend_from_stake_generation_and_revocation_decred(client: Client): +def test_spend_from_stake_generation_and_revocation_decred(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -197,14 +197,14 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_8b6890), @@ -223,7 +223,7 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Decred Testnet", [inp1, inp2], [out1], prev_txes=TX_API + session, "Decred Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( @@ -232,7 +232,7 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client): ) -def test_send_decred_change(client: Client): +def test_send_decred_change(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -278,15 +278,15 @@ def test_send_decred_change(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_input(2), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -311,7 +311,7 @@ def test_send_decred_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Decred Testnet", [inp1, inp2, inp3], [out1, out2], @@ -325,12 +325,12 @@ def test_send_decred_change(client: Client): @pytest.mark.multisig -def test_decred_multisig_change(client: Client): +def test_decred_multisig_change(session: Session): # NOTE: fake input tx used paths = [parse_path(f"m/48h/1h/{index}'/0'") for index in range(3)] nodes = [ - btc.get_public_node(client, address_n, coin_name="Decred Testnet").node + btc.get_public_node(session, address_n, coin_name="Decred Testnet").node for address_n in paths ] @@ -384,15 +384,15 @@ def test_multisig(index): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_9ac7d2), @@ -410,7 +410,7 @@ def test_multisig(index): ] ) signature, serialized_tx = btc.sign_tx( - client, + session, "Decred Testnet", [inp1, inp2], [out1, out2], diff --git a/tests/device_tests/bitcoin/test_descriptors.py b/tests/device_tests/bitcoin/test_descriptors.py index 6efdd99ed82..7a077b20527 100644 --- a/tests/device_tests/bitcoin/test_descriptors.py +++ b/tests/device_tests/bitcoin/test_descriptors.py @@ -18,7 +18,7 @@ from trezorlib import btc, messages, models from trezorlib.cli import btc as btc_cli -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import H_ from ...input_flows import InputFlowShowXpubQRCode @@ -165,14 +165,16 @@ def _address_n(purpose, coin, account, script_type): @pytest.mark.parametrize( "coin, account, purpose, script_type, descriptors", VECTORS_DESCRIPTORS ) -def test_descriptors(client: Client, coin, account, purpose, script_type, descriptors): - with client: +def test_descriptors( + session: Session, coin, account, purpose, script_type, descriptors +): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) address_n = _address_n(purpose, coin, account, script_type) res = btc.get_public_node( - client, + session, _address_n(purpose, coin, account, script_type), show_display=True, coin_name=coin, @@ -187,13 +189,13 @@ def test_descriptors(client: Client, coin, account, purpose, script_type, descri "coin, account, purpose, script_type, descriptors", VECTORS_DESCRIPTORS ) def test_descriptors_trezorlib( - client: Client, coin, account, purpose, script_type, descriptors + session: Session, coin, account, purpose, script_type, descriptors ): - with client: + with session.client as client: if client.model != models.T1B1: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) res = btc_cli._get_descriptor( - client, coin, account, purpose, script_type, show_display=True + session, coin, account, purpose, script_type, show_display=True ) assert res == descriptors diff --git a/tests/device_tests/bitcoin/test_firo.py b/tests/device_tests/bitcoin/test_firo.py index 52db787957d..2ceeb2c2d77 100644 --- a/tests/device_tests/bitcoin/test_firo.py +++ b/tests/device_tests/bitcoin/test_firo.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -30,7 +30,7 @@ @pytest.mark.altcoin -def test_spend_lelantus(client: Client): +def test_spend_lelantus(session: Session): inp1 = messages.TxInputType( # THgGLVqfzJcaxRVPWE5fd8YJ1GpVePq2Uk address_n=parse_path("m/44h/1h/0h/0/4"), @@ -45,7 +45,7 @@ def test_spend_lelantus(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Firo Testnet", [inp1], [out1], prev_txes=TX_API + session, "Firo Testnet", [inp1], [out1], prev_txes=TX_API ) assert_tx_matches( serialized_tx, diff --git a/tests/device_tests/bitcoin/test_fujicoin.py b/tests/device_tests/bitcoin/test_fujicoin.py index f28747c7173..45886e8603b 100644 --- a/tests/device_tests/bitcoin/test_fujicoin.py +++ b/tests/device_tests/bitcoin/test_fujicoin.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path TXHASH_33043a = bytes.fromhex( @@ -27,7 +27,7 @@ pytestmark = pytest.mark.altcoin -def test_send_p2tr(client: Client): +def test_send_p2tr(session: Session): inp1 = messages.TxInputType( # fc1prr07akly3xjtmggue0p04vghr8pdcgxrye2s00sahptwjeawxrkq2rxzr7 address_n=parse_path("m/86h/75h/0h/0/1"), @@ -42,7 +42,7 @@ def test_send_p2tr(client: Client): amount=99_996_670_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - _, serialized_tx = btc.sign_tx(client, "Fujicoin", [inp1], [out1]) + _, serialized_tx = btc.sign_tx(session, "Fujicoin", [inp1], [out1]) # Transaction hex changed with fix #2085, all other details are the same as this tx: # https://explorer.fujicoin.org/tx/a1c6a81f5e8023b17e6e3e51e2596d5b5e1d4914ea13c0c31cef90b3c3edee86 assert ( diff --git a/tests/device_tests/bitcoin/test_getaddress.py b/tests/device_tests/bitcoin/test_getaddress.py index 5367bcbb3e6..3c8a2fbc9da 100644 --- a/tests/device_tests/bitcoin/test_getaddress.py +++ b/tests/device_tests/bitcoin/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import MultisigPubkeysOrder, SafetyCheckLevel from trezorlib.tools import parse_path @@ -36,112 +36,112 @@ def getmultisig(chain, nr, xpubs): ) -def test_btc(client: Client): +def test_btc(session: Session): assert ( - btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) == "1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL" ) assert ( - btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/1")) + btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/1")) == "1GWFxtwWmNVqotUPXLcKVL2mUKpshuJYo" ) assert ( - btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) + btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE" ) @pytest.mark.altcoin -def test_ltc(client: Client): +def test_ltc(session: Session): assert ( - btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/0/0")) + btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/0/0")) == "LcubERmHD31PWup1fbozpKuiqjHZ4anxcL" ) assert ( - btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/0/1")) + btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/0/1")) == "LVWBmHBkCGNjSPHucvL2PmnuRAJnucmRE6" ) assert ( - btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/1/0")) + btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/1/0")) == "LWj6ApswZxay4cJEJES2sGe7fLMLRvvv8h" ) -def test_tbtc(client: Client): +def test_tbtc(session: Session): assert ( - btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/0/0")) + btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/0/0")) == "mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q" ) assert ( - btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/0/1")) + btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/0/1")) == "mopZWqZZyQc3F2Sy33cvDtJchSAMsnLi7b" ) assert ( - btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/1/0")) + btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/1/0")) == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1NLjCbZ" ) @pytest.mark.altcoin -def test_bch(client: Client): +def test_bch(session: Session): assert ( - btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/0/0")) + btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/0/0")) == "bitcoincash:qr08q88p9etk89wgv05nwlrkm4l0urz4cyl36hh9sv" ) assert ( - btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/0/1")) + btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/0/1")) == "bitcoincash:qr23ajjfd9wd73l87j642puf8cad20lfmqdgwvpat4" ) assert ( - btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/1/0")) + btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/1/0")) == "bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw" ) @pytest.mark.altcoin -def test_grs(client: Client): +def test_grs(session: Session): assert ( - btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/0/0")) + btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/0/0")) == "Fj62rBJi8LvbmWu2jzkaUX1NFXLEqDLoZM" ) assert ( - btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/1/0")) + btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/1/0")) == "FmRaqvVBRrAp2Umfqx9V1ectZy8gw54QDN" ) assert ( - btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/1/1")) + btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/1/1")) == "Fmhtxeh7YdCBkyQF7AQG4QnY8y3rJg89di" ) @pytest.mark.altcoin -def test_tgrs(client: Client): +def test_tgrs(session: Session): assert ( - btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0")) + btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0")) == "mvbu1Gdy8SUjTenqerxUaZyYjmvedc787y" ) assert ( - btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/0")) + btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/0")) == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1LMq8cN" ) assert ( - btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/1")) + btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/1")) == "mjXZwmEi1z1MzveZrKUAo4DBgbdq6ZhGD6" ) @pytest.mark.altcoin -def test_elements(client: Client): +def test_elements(session: Session): assert ( - btc.get_address(client, "Elements", parse_path("m/44h/1h/0h/0/0")) + btc.get_address(session, "Elements", parse_path("m/44h/1h/0h/0/0")) == "2dpWh6jbhAowNsQ5agtFzi7j6nKscj6UnEr" ) @pytest.mark.models("core") -def test_address_mac(client: Client): +def test_address_mac(session: Session): resp = btc.get_authenticated_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/1/0") + session, "Bitcoin", parse_path("m/44h/0h/0h/1/0") ) assert resp.address == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE" assert ( @@ -150,7 +150,7 @@ def test_address_mac(client: Client): ) resp = btc.get_authenticated_address( - client, "Testnet", parse_path("m/44h/1h/0h/1/0") + session, "Testnet", parse_path("m/44h/1h/0h/1/0") ) assert resp.address == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1NLjCbZ" assert ( @@ -160,16 +160,16 @@ def test_address_mac(client: Client): # Script type mismatch. resp = btc.get_authenticated_address( - client, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=False + session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=False ) assert resp.mac is None @pytest.mark.models("core") @pytest.mark.altcoin -def test_altcoin_address_mac(client: Client): +def test_altcoin_address_mac(session: Session): resp = btc.get_authenticated_address( - client, "Litecoin", parse_path("m/44h/2h/0h/1/0") + session, "Litecoin", parse_path("m/44h/2h/0h/1/0") ) assert resp.address == "LWj6ApswZxay4cJEJES2sGe7fLMLRvvv8h" assert ( @@ -178,7 +178,7 @@ def test_altcoin_address_mac(client: Client): ) resp = btc.get_authenticated_address( - client, "Bcash", parse_path("m/44h/145h/0h/1/0") + session, "Bcash", parse_path("m/44h/145h/0h/1/0") ) assert resp.address == "bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw" assert ( @@ -187,7 +187,7 @@ def test_altcoin_address_mac(client: Client): ) resp = btc.get_authenticated_address( - client, "Groestlcoin", parse_path("m/44h/17h/0h/1/1") + session, "Groestlcoin", parse_path("m/44h/17h/0h/1/1") ) assert resp.address == "Fmhtxeh7YdCBkyQF7AQG4QnY8y3rJg89di" assert ( @@ -197,9 +197,9 @@ def test_altcoin_address_mac(client: Client): @pytest.mark.multisig -def test_multisig_pubkeys_order(client: Client): - xpub_internal = btc.get_public_node(client, parse_path("m/45h/0")).xpub - xpub_external = btc.get_public_node(client, parse_path("m/45h/1")).xpub +def test_multisig_pubkeys_order(session: Session): + xpub_internal = btc.get_public_node(session, parse_path("m/45h/0")).xpub + xpub_external = btc.get_public_node(session, parse_path("m/45h/1")).xpub multisig_unsorted_1 = messages.MultisigRedeemScriptType( nodes=[bip32.deserialize(xpub) for xpub in [xpub_external, xpub_internal]], @@ -238,45 +238,45 @@ def test_multisig_pubkeys_order(client: Client): assert ( btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1 ) == address_unsorted_1 ) assert ( btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2 ) == address_unsorted_2 ) assert ( btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_1 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_1 ) == address_unsorted_2 ) assert ( btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_2 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_2 ) == address_unsorted_2 ) @pytest.mark.multisig -def test_multisig(client: Client): +def test_multisig(session: Session): xpubs = [] for n in range(1, 4): - node = btc.get_public_node(client, parse_path(f"m/44h/0h/{n}h")) + node = btc.get_public_node(session, parse_path(f"m/44h/0h/{n}h")) xpubs.append(node.xpub) for nr in range(1, 4): - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", parse_path(f"m/44h/0h/{nr}h/0/0"), show_display=(nr == 1), @@ -286,7 +286,7 @@ def test_multisig(client: Client): ) assert ( btc.get_address( - client, + session, "Bitcoin", parse_path(f"m/44h/0h/{nr}h/1/0"), show_display=(nr == 1), @@ -298,11 +298,11 @@ def test_multisig(client: Client): @pytest.mark.multisig @pytest.mark.parametrize("show_display", (True, False)) -def test_multisig_missing(client: Client, show_display): +def test_multisig_missing(session: Session, show_display): # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/44h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/44h/0h/{i}h")).node for i in range(1, 4) ] @@ -321,12 +321,12 @@ def test_multisig_missing(client: Client, show_display): ) for multisig in (multisig1, multisig2): - with client, pytest.raises(TrezorFailure): - if is_core(client): + with session.client as client, pytest.raises(TrezorFailure): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=show_display, @@ -336,22 +336,22 @@ def test_multisig_missing(client: Client, show_display): @pytest.mark.altcoin @pytest.mark.multisig -def test_bch_multisig(client: Client): +def test_bch_multisig(session: Session): xpubs = [] for n in range(1, 4): node = btc.get_public_node( - client, parse_path(f"m/44h/145h/{n}h"), coin_name="Bcash" + session, parse_path(f"m/44h/145h/{n}h"), coin_name="Bcash" ) xpubs.append(node.xpub) for nr in range(1, 4): - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bcash", parse_path(f"m/44h/145h/{nr}h/0/0"), show_display=(nr == 1), @@ -361,7 +361,7 @@ def test_bch_multisig(client: Client): ) assert ( btc.get_address( - client, + session, "Bcash", parse_path(f"m/44h/145h/{nr}h/1/0"), show_display=(nr == 1), @@ -371,43 +371,43 @@ def test_bch_multisig(client: Client): ) -def test_public_ckd(client: Client): - node = btc.get_public_node(client, parse_path("m/44h/0h/0h")).node - node_sub1 = btc.get_public_node(client, parse_path("m/44h/0h/0h/1/0")).node +def test_public_ckd(session: Session): + node = btc.get_public_node(session, parse_path("m/44h/0h/0h")).node + node_sub1 = btc.get_public_node(session, parse_path("m/44h/0h/0h/1/0")).node node_sub2 = bip32.public_ckd(node, [1, 0]) assert node_sub1.chain_code == node_sub2.chain_code assert node_sub1.public_key == node_sub2.public_key - address1 = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) + address1 = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) address2 = bip32.get_address(node_sub2, 0) assert address2 == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE" assert address1 == address2 -def test_invalid_path(client: Client): +def test_invalid_path(session: Session): with pytest.raises(TrezorFailure, match="Forbidden key path"): # slip44 id mismatch btc.get_address( - client, "Bitcoin", parse_path("m/44h/111h/0h/0/0"), show_display=True + session, "Bitcoin", parse_path("m/44h/111h/0h/0/0"), show_display=True ) -def test_unknown_path(client: Client): +def test_unknown_path(session: Session): UNKNOWN_PATH = parse_path("m/44h/9h/0h/0/0") - with client: - client.set_expected_responses([messages.Failure]) + with session: + session.set_expected_responses([messages.Failure]) with pytest.raises(TrezorFailure, match="Forbidden key path"): # account number is too high - btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True) + btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True) # disable safety checks - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( [ messages.ButtonRequest( code=messages.ButtonRequestType.UnknownDerivationPath @@ -416,30 +416,30 @@ def test_unknown_path(client: Client): messages.Address, ] ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) # try again with a warning - btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True) + btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True) - with client: + with session: # no warning is displayed when the call is silent - client.set_expected_responses([messages.Address]) - btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=False) + session.set_expected_responses([messages.Address]) + btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=False) @pytest.mark.altcoin -def test_crw(client: Client): +def test_crw(session: Session): assert ( - btc.get_address(client, "Crown", parse_path("m/44h/72h/0h/0/0")) + btc.get_address(session, "Crown", parse_path("m/44h/72h/0h/0/0")) == "CRWYdvZM1yXMKQxeN3hRsAbwa7drfvTwys48" ) @pytest.mark.multisig -def test_multisig_different_paths(client: Client): +def test_multisig_different_paths(session: Session): nodes = [ - btc.get_public_node(client, parse_path(f"m/45h/{i}"), coin_name="Bitcoin").node + btc.get_public_node(session, parse_path(f"m/45h/{i}"), coin_name="Bitcoin").node for i in range(2) ] @@ -455,12 +455,12 @@ def test_multisig_different_paths(client: Client): with pytest.raises( Exception, match="Using different paths for different xpubs is not allowed" ): - with client: - if is_core(client): + with session.client as client, session: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", parse_path("m/45h/0/0/0"), show_display=True, @@ -468,13 +468,13 @@ def test_multisig_different_paths(client: Client): script_type=messages.InputScriptType.SPENDMULTISIG, ) - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) - with client: - if is_core(client): + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", parse_path("m/45h/0/0/0"), show_display=True, diff --git a/tests/device_tests/bitcoin/test_getaddress_segwit.py b/tests/device_tests/bitcoin/test_getaddress_segwit.py index 848097a8cbb..b1e3affac72 100644 --- a/tests/device_tests/bitcoin/test_getaddress_segwit.py +++ b/tests/device_tests/bitcoin/test_getaddress_segwit.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -25,10 +25,10 @@ from ...input_flows import InputFlowConfirmAllWarnings -def test_show_segwit(client: Client): +def test_show_segwit(session: Session): assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/49h/1h/0h/1/0"), True, @@ -39,7 +39,7 @@ def test_show_segwit(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/49h/1h/0h/0/0"), False, @@ -50,7 +50,7 @@ def test_show_segwit(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/44h/1h/0h/0/0"), False, @@ -61,7 +61,7 @@ def test_show_segwit(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/44h/1h/0h/0/0"), False, @@ -73,14 +73,14 @@ def test_show_segwit(client: Client): @pytest.mark.altcoin -def test_show_segwit_altcoin(client: Client): - with client: - if is_core(client): +def test_show_segwit_altcoin(session: Session): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/49h/1h/0h/1/0"), True, @@ -91,7 +91,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/49h/1h/0h/0/0"), True, @@ -102,7 +102,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0"), True, @@ -113,7 +113,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0"), True, @@ -124,7 +124,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Elements", parse_path("m/49h/1h/0h/0/0"), True, @@ -136,10 +136,10 @@ def test_show_segwit_altcoin(client: Client): @pytest.mark.multisig -def test_show_multisig_3(client: Client): +def test_show_multisig_3(session: Session): nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" ).node for i in range(1, 4) ] @@ -155,7 +155,7 @@ def test_show_multisig_3(client: Client): for i in [1, 2, 3]: assert ( btc.get_address( - client, + session, "Testnet", parse_path(f"m/49h/1h/{i}h/0/7"), False, @@ -168,11 +168,11 @@ def test_show_multisig_3(client: Client): @pytest.mark.multisig @pytest.mark.parametrize("show_display", (True, False)) -def test_multisig_missing(client: Client, show_display): +def test_multisig_missing(session: Session, show_display): # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/49h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/49h/0h/{i}h")).node for i in range(1, 4) ] @@ -193,7 +193,7 @@ def test_multisig_missing(client: Client, show_display): for multisig in (multisig1, multisig2): with pytest.raises(TrezorFailure): btc.get_address( - client, + session, "Bitcoin", parse_path("m/49h/0h/0h/0/0"), show_display=show_display, diff --git a/tests/device_tests/bitcoin/test_getaddress_segwit_native.py b/tests/device_tests/bitcoin/test_getaddress_segwit_native.py index 55b0fbfdb5e..7c220adf65a 100644 --- a/tests/device_tests/bitcoin/test_getaddress_segwit_native.py +++ b/tests/device_tests/bitcoin/test_getaddress_segwit_native.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -141,7 +141,7 @@ @pytest.mark.parametrize("show_display", (True, False)) @pytest.mark.parametrize("coin, path, script_type, address", VECTORS) def test_show_segwit( - client: Client, + session: Session, show_display: bool, coin: str, path: str, @@ -150,7 +150,7 @@ def test_show_segwit( ): assert ( btc.get_address( - client, + session, coin, parse_path(path), show_display, @@ -166,10 +166,10 @@ def test_show_segwit( mnemonic="abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about" ) @pytest.mark.parametrize("path, address", BIP86_VECTORS) -def test_bip86(client: Client, path: str, address: str): +def test_bip86(session: Session, path: str, address: str): assert ( btc.get_address( - client, + session, "Bitcoin", parse_path(path), False, @@ -181,10 +181,10 @@ def test_bip86(client: Client, path: str, address: str): @pytest.mark.multisig -def test_show_multisig_3(client: Client): +def test_show_multisig_3(session: Session): nodes = [ btc.get_public_node( - client, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" ).node for index in range(1, 4) ] @@ -197,7 +197,7 @@ def test_show_multisig_3(client: Client): for i in [1, 2, 3]: assert ( btc.get_address( - client, + session, "Testnet", parse_path(f"m/84h/1h/{i}h/0/1"), False, @@ -208,7 +208,7 @@ def test_show_multisig_3(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path(f"m/84h/1h/{i}h/0/0"), False, @@ -221,11 +221,11 @@ def test_show_multisig_3(client: Client): @pytest.mark.multisig @pytest.mark.parametrize("show_display", (True, False)) -def test_multisig_missing(client: Client, show_display: bool): +def test_multisig_missing(session: Session, show_display: bool): # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/84h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/84h/0h/{i}h")).node for i in range(1, 4) ] @@ -246,7 +246,7 @@ def test_multisig_missing(client: Client, show_display: bool): for multisig in (multisig1, multisig2): with pytest.raises(TrezorFailure): btc.get_address( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=show_display, diff --git a/tests/device_tests/bitcoin/test_getaddress_show.py b/tests/device_tests/bitcoin/test_getaddress_show.py index 8770176d427..bcb685db1dc 100644 --- a/tests/device_tests/bitcoin/test_getaddress_show.py +++ b/tests/device_tests/bitcoin/test_getaddress_show.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, tools -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from ...common import is_core @@ -55,20 +55,20 @@ @pytest.mark.models("legacy") @pytest.mark.parametrize("path, script_type, address", VECTORS) def test_show_t1( - client: Client, path: str, script_type: messages.InputScriptType, address: str + session: Session, path: str, script_type: messages.InputScriptType, address: str ): def input_flow_t1(): yield - client.debug.press_no() + session.debug.press_no() yield - client.debug.press_yes() + session.debug.press_yes() - with client: + with session: # This is the only place where even T1 is using input flow - client.set_input_flow(input_flow_t1) + session.set_input_flow(input_flow_t1) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(path), script_type=script_type, @@ -82,18 +82,18 @@ def input_flow_t1(): @pytest.mark.parametrize("chunkify", (True, False)) @pytest.mark.parametrize("path, script_type, address", VECTORS) def test_show_tt( - client: Client, + session: Session, chunkify: bool, path: str, script_type: messages.InputScriptType, address: str, ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(path), script_type=script_type, @@ -107,13 +107,13 @@ def test_show_tt( @pytest.mark.models("core") @pytest.mark.parametrize("path, script_type, address", VECTORS) def test_show_cancel( - client: Client, path: str, script_type: messages.InputScriptType, address: str + session: Session, path: str, script_type: messages.InputScriptType, address: str ): - with client, pytest.raises(Cancelled): + with session.client as client, pytest.raises(Cancelled): IF = InputFlowShowAddressQRCodeCancel(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", tools.parse_path(path), script_type=script_type, @@ -121,10 +121,10 @@ def test_show_cancel( ) -def test_show_unrecognized_path(client: Client): +def test_show_unrecognized_path(session: Session): with pytest.raises(TrezorFailure): btc.get_address( - client, + session, "Bitcoin", tools.parse_path("m/24684621h/516582h/5156h/21/856"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -133,10 +133,10 @@ def test_show_unrecognized_path(client: Client): @pytest.mark.multisig -def test_show_multisig_3(client: Client): +def test_show_multisig_3(session: Session): nodes = [ btc.get_public_node( - client, tools.parse_path(f"m/45h/{i}"), coin_name="Bitcoin" + session, tools.parse_path(f"m/45h/{i}"), coin_name="Bitcoin" ).node for i in [1, 2, 3] ] @@ -157,13 +157,13 @@ def test_show_multisig_3(client: Client): for multisig in (multisig1, multisig2): for i in [1, 2, 3]: - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(f"m/45h/{i}/0/0"), show_display=True, @@ -250,7 +250,7 @@ def test_show_multisig_3(client: Client): "script_type, bip48_type, address, xpubs, ignore_xpub_magic", VECTORS_MULTISIG ) def test_show_multisig_xpubs( - client: Client, + session: Session, script_type: messages.InputScriptType, bip48_type: int, address: str, @@ -259,7 +259,7 @@ def test_show_multisig_xpubs( ): nodes = [ btc.get_public_node( - client, + session, tools.parse_path(f"m/48h/0h/{i}h/{bip48_type}h"), coin_name="Bitcoin", ) @@ -273,13 +273,13 @@ def test_show_multisig_xpubs( ) for i in range(3): - with client: + with session, session.client as client: IF = InputFlowShowMultisigXPUBs(client, address, xpubs, i) client.set_input_flow(IF.get()) client.debug.synchronize_at("Homescreen") client.watch_layout() btc.get_address( - client, + session, "Bitcoin", tools.parse_path(f"m/48h/0h/{i}h/{bip48_type}h/0/0"), show_display=True, @@ -290,10 +290,10 @@ def test_show_multisig_xpubs( @pytest.mark.multisig -def test_show_multisig_15(client: Client): +def test_show_multisig_15(session: Session): nodes = [ btc.get_public_node( - client, tools.parse_path(f"m/45h/{i}"), coin_name="Bitcoin" + session, tools.parse_path(f"m/45h/{i}"), coin_name="Bitcoin" ).node for i in range(15) ] @@ -314,13 +314,13 @@ def test_show_multisig_15(client: Client): for multisig in [multisig1, multisig2]: for i in range(15): - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(f"m/45h/{i}/0/0"), show_display=True, diff --git a/tests/device_tests/bitcoin/test_getownershipproof.py b/tests/device_tests/bitcoin/test_getownershipproof.py index b21fe944b0e..51309eb625d 100644 --- a/tests/device_tests/bitcoin/test_getownershipproof.py +++ b/tests/device_tests/bitcoin/test_getownershipproof.py @@ -17,14 +17,14 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path -def test_p2wpkh_ownership_id(client: Client): +def test_p2wpkh_ownership_id(session: Session): ownership_id = btc.get_ownership_id( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -35,9 +35,9 @@ def test_p2wpkh_ownership_id(client: Client): ) -def test_p2tr_ownership_id(client: Client): +def test_p2tr_ownership_id(session: Session): ownership_id = btc.get_ownership_id( - client, + session, "Bitcoin", parse_path("m/86h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -48,12 +48,12 @@ def test_p2tr_ownership_id(client: Client): ) -def test_attack_ownership_id(client: Client): +def test_attack_ownership_id(session: Session): # Multisig with global suffix specification. # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/84h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/84h/0h/{i}h")).node for i in range(1, 4) ] multisig1 = messages.MultisigRedeemScriptType( @@ -62,7 +62,7 @@ def test_attack_ownership_id(client: Client): # Multisig with per-node suffix specification. node = btc.get_public_node( - client, parse_path("m/84h/0h/0h/0"), coin_name="Bitcoin" + session, parse_path("m/84h/0h/0h/0"), coin_name="Bitcoin" ).node multisig2 = messages.MultisigRedeemScriptType( pubkeys=[ @@ -77,7 +77,7 @@ def test_attack_ownership_id(client: Client): for multisig in (multisig1, multisig2): with pytest.raises(TrezorFailure): btc.get_ownership_id( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), multisig=multisig, @@ -85,9 +85,9 @@ def test_attack_ownership_id(client: Client): ) -def test_p2wpkh_ownership_proof(client: Client): +def test_p2wpkh_ownership_proof(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -98,9 +98,9 @@ def test_p2wpkh_ownership_proof(client: Client): ) -def test_p2tr_ownership_proof(client: Client): +def test_p2tr_ownership_proof(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/86h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -111,10 +111,10 @@ def test_p2tr_ownership_proof(client: Client): ) -def test_fake_ownership_id(client: Client): +def test_fake_ownership_id(session: Session): with pytest.raises(TrezorFailure, match="Invalid ownership identifier"): btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -124,9 +124,9 @@ def test_fake_ownership_id(client: Client): ) -def test_confirm_ownership_proof(client: Client): +def test_confirm_ownership_proof(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -139,9 +139,9 @@ def test_confirm_ownership_proof(client: Client): ) -def test_confirm_ownership_proof_with_data(client: Client): +def test_confirm_ownership_proof_with_data(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, diff --git a/tests/device_tests/bitcoin/test_getpublickey.py b/tests/device_tests/bitcoin/test_getpublickey.py index be0c43e535a..7fcc1a595e6 100644 --- a/tests/device_tests/bitcoin/test_getpublickey.py +++ b/tests/device_tests/bitcoin/test_getpublickey.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -110,33 +110,37 @@ @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) -def test_get_public_node(client: Client, coin_name, xpub_magic, path, xpub): - res = btc.get_public_node(client, path, coin_name=coin_name) +def test_get_public_node(session: Session, coin_name, xpub_magic, path, xpub): + res = btc.get_public_node(session, path, coin_name=coin_name) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub @pytest.mark.models("core") @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) -def test_get_public_node_show(client: Client, coin_name, xpub_magic, path, xpub): - with client: +def test_get_public_node_show(session: Session, coin_name, xpub_magic, path, xpub): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) - res = btc.get_public_node(client, path, coin_name=coin_name, show_display=True) + res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub @pytest.mark.xfail(reason="Currently path validation on get_public_node is disabled.") @pytest.mark.parametrize("coin_name, path", VECTORS_INVALID) -def test_invalid_path(client: Client, coin_name, path): +def test_invalid_path(session: Session, coin_name, path): with pytest.raises(TrezorFailure, match="Forbidden key path"): - btc.get_public_node(client, path, coin_name=coin_name) + btc.get_public_node(session, path, coin_name=coin_name) @pytest.mark.models("legacy") @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) -def test_get_public_node_show_legacy(client: Client, coin_name, xpub_magic, path, xpub): +def test_get_public_node_show_legacy( + session: Session, coin_name, xpub_magic, path, xpub +): + client = session.client + def input_flow(): yield client.debug.press_no() # show QR code @@ -167,11 +171,11 @@ def input_flow(): assert bip32.serialize(res.node, xpub_magic) == xpub -def test_slip25_path(client: Client): +def test_slip25_path(session: Session): # Ensure that CoinJoin XPUBs are inaccessible without user authorization. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_public_node( - client, + session, parse_path("m/10025h/0h/0h/1h"), script_type=messages.InputScriptType.SPENDTAPROOT, ) @@ -202,14 +206,14 @@ def test_slip25_path(client: Client): @pytest.mark.parametrize("script_type, xpub, xpub_ignored_magic", VECTORS_SCRIPT_TYPES) -def test_script_type(client: Client, script_type, xpub, xpub_ignored_magic): +def test_script_type(session: Session, script_type, xpub, xpub_ignored_magic): path = parse_path("m/44h/0h/0") res = btc.get_public_node( - client, path, coin_name="Bitcoin", script_type=script_type + session, path, coin_name="Bitcoin", script_type=script_type ) assert res.xpub == xpub res = btc.get_public_node( - client, + session, path, coin_name="Bitcoin", script_type=script_type, diff --git a/tests/device_tests/bitcoin/test_getpublickey_curve.py b/tests/device_tests/bitcoin/test_getpublickey_curve.py index 8b8cba68871..393afca61c8 100644 --- a/tests/device_tests/bitcoin/test_getpublickey_curve.py +++ b/tests/device_tests/bitcoin/test_getpublickey_curve.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -54,21 +54,21 @@ @pytest.mark.parametrize("curve, path, pubkey", VECTORS) -def test_publickey_curve(client: Client, curve, path, pubkey): - resp = btc.get_public_node(client, path, ecdsa_curve_name=curve) +def test_publickey_curve(session: Session, curve, path, pubkey): + resp = btc.get_public_node(session, path, ecdsa_curve_name=curve) assert resp.node.public_key.hex() == pubkey -def test_ed25519_public(client: Client): +def test_ed25519_public(session: Session): with pytest.raises(TrezorFailure): - btc.get_public_node(client, PATH_PUBLIC, ecdsa_curve_name="ed25519") + btc.get_public_node(session, PATH_PUBLIC, ecdsa_curve_name="ed25519") @pytest.mark.xfail(reason="Currently path validation on get_public_node is disabled.") -def test_coin_and_curve(client: Client): +def test_coin_and_curve(session: Session): with pytest.raises( TrezorFailure, match="Cannot use coin_name or script_type with ecdsa_curve_name" ): btc.get_public_node( - client, PATH_PRIVATE, coin_name="Bitcoin", ecdsa_curve_name="ed25519" + session, PATH_PRIVATE, coin_name="Bitcoin", ecdsa_curve_name="ed25519" ) diff --git a/tests/device_tests/bitcoin/test_grs.py b/tests/device_tests/bitcoin/test_grs.py index d25ffd20f00..ff2b5c4cdfc 100644 --- a/tests/device_tests/bitcoin/test_grs.py +++ b/tests/device_tests/bitcoin/test_grs.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -42,7 +42,7 @@ pytestmark = pytest.mark.altcoin -def test_legacy(client: Client): +def test_legacy(session: Session): inp1 = messages.TxInputType( # FXHDsC5ZqWQHkDmShzgRVZ1MatpWhwxTAA address_n=parse_path("m/44h/17h/0h/0/2"), @@ -56,7 +56,7 @@ def test_legacy(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Groestlcoin", [inp1], [out1], prev_txes=TX_API + session, "Groestlcoin", [inp1], [out1], prev_txes=TX_API ) assert ( serialized_tx.hex() @@ -64,7 +64,7 @@ def test_legacy(client: Client): ) -def test_legacy_change(client: Client): +def test_legacy_change(session: Session): inp1 = messages.TxInputType( # FXHDsC5ZqWQHkDmShzgRVZ1MatpWhwxTAA address_n=parse_path("m/44h/17h/0h/0/2"), @@ -78,7 +78,7 @@ def test_legacy_change(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Groestlcoin", [inp1], [out1], prev_txes=TX_API + session, "Groestlcoin", [inp1], [out1], prev_txes=TX_API ) assert ( serialized_tx.hex() @@ -86,7 +86,7 @@ def test_legacy_change(client: Client): ) -def test_send_segwit_p2sh(client: Client): +def test_send_segwit_p2sh(session: Session): inp1 = messages.TxInputType( # 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7 address_n=parse_path("m/49h/1h/0h/1/0"), @@ -107,7 +107,7 @@ def test_send_segwit_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -120,7 +120,7 @@ def test_send_segwit_p2sh(client: Client): ) -def test_send_segwit_p2sh_change(client: Client): +def test_send_segwit_p2sh_change(session: Session): inp1 = messages.TxInputType( # 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7 address_n=parse_path("m/49h/1h/0h/1/0"), @@ -141,7 +141,7 @@ def test_send_segwit_p2sh_change(client: Client): amount=123_456_789 - 11_000 - 12_300_000, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -154,7 +154,7 @@ def test_send_segwit_p2sh_change(client: Client): ) -def test_send_segwit_native(client: Client): +def test_send_segwit_native(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/0"), amount=12_300_000, @@ -174,7 +174,7 @@ def test_send_segwit_native(client: Client): amount=12_300_000 - 11_000 - 5_000_000, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -187,7 +187,7 @@ def test_send_segwit_native(client: Client): ) -def test_send_segwit_native_change(client: Client): +def test_send_segwit_native_change(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/0"), amount=12_300_000, @@ -207,7 +207,7 @@ def test_send_segwit_native_change(client: Client): amount=12_300_000 - 11_000 - 5_000_000, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -220,7 +220,7 @@ def test_send_segwit_native_change(client: Client): ) -def test_send_p2tr(client: Client): +def test_send_p2tr(session: Session): inp1 = messages.TxInputType( # tgrs1paxhjl357yzctuf3fe58fcdx6nul026hhh6kyldpfsf3tckj9a3wsvuqrgn address_n=parse_path("m/86h/1h/1h/0/0"), @@ -236,7 +236,7 @@ def test_send_p2tr(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Groestlcoin Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Groestlcoin Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # Transaction hex changed with fix #2085, all other details are the same as this tx: # https://blockbook-test.groestlcoin.org/tx/c66a79075044aaab3dba17daffb23f48addee87d7c87c7bc88e2997ce38a74ee diff --git a/tests/device_tests/bitcoin/test_komodo.py b/tests/device_tests/bitcoin/test_komodo.py index f883afc7bcd..111acefc6f3 100644 --- a/tests/device_tests/bitcoin/test_komodo.py +++ b/tests/device_tests/bitcoin/test_komodo.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -43,7 +43,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.komodo] -def test_one_one_fee_sapling(client: Client): +def test_one_one_fee_sapling(session: Session): # prevout: 2807c5b126ec8e2b078cab0f12e4c8b4ce1d7724905f8ebef8dca26b0c8e0f1d:0 # input 1: 10.9998 KMD @@ -61,13 +61,13 @@ def test_one_one_fee_sapling(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -82,7 +82,7 @@ def test_one_one_fee_sapling(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Komodo", [inp1], [out1], @@ -100,7 +100,7 @@ def test_one_one_fee_sapling(client: Client): ) -def test_one_one_rewards_claim(client: Client): +def test_one_one_rewards_claim(session: Session): # prevout: 7b28bd91119e9776f0d4ebd80e570165818a829bbf4477cd1afe5149dbcd34b1:0 # input 1: 10.9997 KMD @@ -125,16 +125,16 @@ def test_one_one_rewards_claim(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -150,7 +150,7 @@ def test_one_one_rewards_claim(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Komodo", [inp1], [out1, out2], diff --git a/tests/device_tests/bitcoin/test_multisig.py b/tests/device_tests/bitcoin/test_multisig.py index 2a01db8108f..5888409d864 100644 --- a/tests/device_tests/bitcoin/test_multisig.py +++ b/tests/device_tests/bitcoin/test_multisig.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -55,12 +55,12 @@ @pytest.mark.multisig @pytest.mark.parametrize("chunkify", (True, False)) -def test_2_of_3(client: Client, chunkify: bool): +def test_2_of_3(session: Session, chunkify: bool): # input tx: 6b07c1321b52d9c85743f9695e13eb431b41708cdf4e1585258d51208e5b93fc nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Testnet" + session, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Testnet" ).node for index in range(1, 4) ] @@ -89,7 +89,7 @@ def test_2_of_3(client: Client, chunkify: bool): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_6b07c1), @@ -101,12 +101,12 @@ def test_2_of_3(client: Client, chunkify: bool): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) # Now we have first signature signatures1, _ = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1], @@ -143,10 +143,10 @@ def test_2_of_3(client: Client, chunkify: bool): multisig=multisig, ) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures2, serialized_tx = btc.sign_tx( - client, "Testnet", [inp3], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp3], [out1], prev_txes=TX_API_TESTNET ) assert ( @@ -162,12 +162,12 @@ def test_2_of_3(client: Client, chunkify: bool): @pytest.mark.multisig -def test_pubkeys_order(client: Client): +def test_pubkeys_order(session: Session): node_internal = btc.get_public_node( - client, parse_path("m/45h/0"), coin_name="Bitcoin" + session, parse_path("m/45h/0"), coin_name="Bitcoin" ).node node_external = btc.get_public_node( - client, parse_path("m/45h/1"), coin_name="Bitcoin" + session, parse_path("m/45h/1"), coin_name="Bitcoin" ).node # A dummy signature is used to ensure that the signatures are serialized in the correct order @@ -206,17 +206,17 @@ def test_pubkeys_order(client: Client): ) address_unsorted_1 = btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1 ) address_unsorted_2 = btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2 ) pubkey_internal = btc.get_public_node( - client, parse_path("m/45h/0/0/0"), coin_name="Bitcoin" + session, parse_path("m/45h/0/0/0"), coin_name="Bitcoin" ).node.public_key pubkey_external = btc.get_public_node( - client, parse_path("m/45h/1/0/0"), coin_name="Bitcoin" + session, parse_path("m/45h/1/0/0"), coin_name="Bitcoin" ).node.public_key # This assertion implies that script pubkey of multisig_sorted_1, multisig_sorted_2 and multisig_unsorted_1 are the same @@ -295,7 +295,7 @@ def test_pubkeys_order(client: Client): tx_unsorted_2 = "0100000001637ffac0d4fbd8a6c02b114e36b079615ec3e4bdf09b769c7bf8b5fd6f8e781701000000da004800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000147304402204914036468434698e2d87985007a66691f170195e4a16507bbb86b4c00da5fde02200a788312d447b3796ee5288ce9e9c0247896debfa473339302bc928da6dd78cb014751210369b79f2094a6eb89e7aff0e012a5699f7272968a341e48e99e64a54312f2932b210262e9ac5bea4c84c7dea650424ed768cf123af9e447eef3c63d37c41d1f825e4952aeffffffff01301b0f000000000017a914320ad0ff0f1b605ab1fa8e29b70d22827cf45a9f8700000000" _, tx = btc.sign_tx( - client, + session, "Bitcoin", [input_unsorted_1], [output_unsorted_1], @@ -304,7 +304,7 @@ def test_pubkeys_order(client: Client): assert tx.hex() == tx_unsorted_1 _, tx = btc.sign_tx( - client, + session, "Bitcoin", [input_unsorted_2], [output_unsorted_2], @@ -313,7 +313,7 @@ def test_pubkeys_order(client: Client): assert tx.hex() == tx_unsorted_2 _, tx = btc.sign_tx( - client, + session, "Bitcoin", [input_sorted_1], [output_sorted_1], @@ -322,7 +322,7 @@ def test_pubkeys_order(client: Client): assert tx.hex() == tx_unsorted_1 _, tx = btc.sign_tx( - client, + session, "Bitcoin", [input_sorted_2], [output_sorted_2], @@ -332,11 +332,11 @@ def test_pubkeys_order(client: Client): @pytest.mark.multisig -def test_15_of_15(client: Client): +def test_15_of_15(session: Session): # input tx: 0d5b5648d47b5650edea1af3d47bbe5624213abb577cf1b1c96f98321f75cdbc node = btc.get_public_node( - client, parse_path("m/48h/1h/1h/0h"), coin_name="Testnet" + session, parse_path("m/48h/1h/1h/0h"), coin_name="Testnet" ).node pubs = [messages.HDNodePathType(node=node, address_n=[0, x]) for x in range(15)] @@ -362,9 +362,9 @@ def test_15_of_15(client: Client): multisig=multisig, ) - with client: + with session: sig, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) signatures[x] = sig[0] @@ -376,9 +376,9 @@ def test_15_of_15(client: Client): @pytest.mark.multisig @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_missing_pubkey(client: Client): +def test_missing_pubkey(session: Session): node = btc.get_public_node( - client, parse_path("m/48h/0h/1h/0h/0"), coin_name="Bitcoin" + session, parse_path("m/48h/0h/1h/0h/0"), coin_name="Bitcoin" ).node multisig = messages.MultisigRedeemScriptType( @@ -408,16 +408,16 @@ def test_missing_pubkey(client: Client): ) with pytest.raises(TrezorFailure) as exc: - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_API) - if client.model is models.T1B1: + if session.model is models.T1B1: assert exc.value.message.endswith("Failed to derive scriptPubKey") else: assert exc.value.message.endswith("Pubkey not found in multisig script") @pytest.mark.multisig -def test_attack_change_input(client: Client): +def test_attack_change_input(session: Session): """ In Phases 1 and 2 the attacker replaces a non-multisig input `input_real` with a multisig input `input_fake`, which allows the @@ -440,7 +440,7 @@ def test_attack_change_input(client: Client): multisig_fake = messages.MultisigRedeemScriptType( m=1, nodes=[ - btc.get_public_node(client, address_n, coin_name="Testnet").node, + btc.get_public_node(session, address_n, coin_name="Testnet").node, messages.HDNodeType( depth=0, fingerprint=0, @@ -475,12 +475,12 @@ def test_attack_change_input(client: Client): ) # Transaction can be signed without the attack processor - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Testnet", [input_real], [output_payee, output_change], @@ -497,11 +497,11 @@ def attack_processor(msg): attack_count -= 1 return msg - with client: - client.set_filter(messages.TxAck, attack_processor) + with session: + session.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( - client, + session, "Testnet", [input_real], [output_payee, output_change], diff --git a/tests/device_tests/bitcoin/test_multisig_change.py b/tests/device_tests/bitcoin/test_multisig_change.py index 7beaa31badc..efc4f42d56c 100644 --- a/tests/device_tests/bitcoin/test_multisig_change.py +++ b/tests/device_tests/bitcoin/test_multisig_change.py @@ -19,7 +19,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import H_, parse_path from ... import bip32 @@ -191,7 +191,7 @@ def _responses( - client: Client, + session: Session, INP1: messages.TxInputType, INP2: messages.TxInputType, change_indices: Optional[list[int]] = None, @@ -212,7 +212,7 @@ def _responses( resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath)) if 1 not in change_indices: resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) - if is_core(client): + if is_core(session): resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) resp.append(request_output(1)) @@ -221,7 +221,7 @@ def _responses( resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath)) if 2 not in change_indices: resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) - if is_core(client): + if is_core(session): resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) resp += [ @@ -250,7 +250,7 @@ def _responses( # both outputs are external -def test_external_external(client: Client): +def test_external_external(session: Session): out1 = messages.TxOutputType( address="1F8yBZB2NZhPZvJekhjTwjhQRRvQeTjjXr", amount=40_000_000, @@ -263,10 +263,10 @@ def test_external_external(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -275,7 +275,7 @@ def test_external_external(client: Client): # first external, second internal -def test_external_internal(client: Client): +def test_external_internal(session: Session): out1 = messages.TxOutputType( address="1F8yBZB2NZhPZvJekhjTwjhQRRvQeTjjXr", amount=40_000_000, @@ -288,21 +288,21 @@ def test_external_internal(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( _responses( - client, + session, INP1, INP2, change_indices=[], foreign_indices=[2], ) ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -311,7 +311,7 @@ def test_external_internal(client: Client): # first internal, second external -def test_internal_external(client: Client): +def test_internal_external(session: Session): out1 = messages.TxOutputType( address_n=parse_path("m/45h/0/1/0"), amount=40_000_000, @@ -324,21 +324,21 @@ def test_internal_external(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( _responses( - client, + session, INP1, INP2, change_indices=[], foreign_indices=[1], ) ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -347,7 +347,7 @@ def test_internal_external(client: Client): # both outputs are external -def test_multisig_external_external(client: Client): +def test_multisig_external_external(session: Session): out1 = messages.TxOutputType( address="3B23k4kFBRtu49zvpG3Z9xuFzfpHvxBcwt", amount=40_000_000, @@ -360,10 +360,10 @@ def test_multisig_external_external(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -372,7 +372,7 @@ def test_multisig_external_external(client: Client): # inputs match, change matches (first is change) -def test_multisig_change_match_first(client: Client): +def test_multisig_change_match_first(session: Session): multisig_out1 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_EXT2, NODE_INT], address_n=[1, 0], @@ -393,12 +393,12 @@ def test_multisig_change_match_first(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( - _responses(client, INP1, INP2, change_indices=[1]) + with session: + session.set_expected_responses( + _responses(session, INP1, INP2, change_indices=[1]) ) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -407,7 +407,7 @@ def test_multisig_change_match_first(client: Client): # inputs match, change matches (second is change) -def test_multisig_change_match_second(client: Client): +def test_multisig_change_match_second(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_EXT2, NODE_INT], address_n=[1, 1], @@ -428,12 +428,12 @@ def test_multisig_change_match_second(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses( - _responses(client, INP1, INP2, change_indices=[2]) + with session: + session.set_expected_responses( + _responses(session, INP1, INP2, change_indices=[2]) ) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -442,7 +442,7 @@ def test_multisig_change_match_second(client: Client): # inputs match, change matches (first is change) -def test_sorted_multisig_change_match_first(client: Client): +def test_sorted_multisig_change_match_first(session: Session): multisig_out1 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_INT, NODE_EXT2], address_n=[1, 0], @@ -464,12 +464,12 @@ def test_sorted_multisig_change_match_first(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( - _responses(client, INP4, INP5, change_indices=[1]) + with session: + session.set_expected_responses( + _responses(session, INP4, INP5, change_indices=[1]) ) btc.sign_tx( - client, + session, "Bitcoin", [INP4, INP5], [out1, out2], @@ -478,7 +478,7 @@ def test_sorted_multisig_change_match_first(client: Client): # inputs match, change mismatches (second tries to be change but isn't because the pubkeys are in different order) -def test_multisig_mismatch_multisig_change(client: Client): +def test_multisig_mismatch_multisig_change(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_INT, NODE_EXT2], address_n=[1, 0], @@ -499,10 +499,10 @@ def test_multisig_mismatch_multisig_change(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -511,7 +511,7 @@ def test_multisig_mismatch_multisig_change(client: Client): # inputs match, change mismatches (second tries to be change but isn't because the pubkeys are different) -def test_sorted_multisig_mismatch_multisig_change(client: Client): +def test_sorted_multisig_mismatch_multisig_change(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_INT, NODE_EXT3], address_n=[1, 0], @@ -532,10 +532,10 @@ def test_sorted_multisig_mismatch_multisig_change(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses(_responses(client, INP4, INP5)) + with session: + session.set_expected_responses(_responses(session, INP4, INP5)) btc.sign_tx( - client, + session, "Bitcoin", [INP4, INP5], [out1, out2], @@ -544,7 +544,7 @@ def test_sorted_multisig_mismatch_multisig_change(client: Client): # inputs match, change mismatches (second tries to be change but isn't because is uses per-node paths) -def test_multisig_mismatch_multisig_change_different_paths(client: Client): +def test_multisig_mismatch_multisig_change_different_paths(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( pubkeys=[ messages.HDNodePathType(node=NODE_EXT1, address_n=[1, 0]), @@ -568,10 +568,10 @@ def test_multisig_mismatch_multisig_change_different_paths(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -580,7 +580,7 @@ def test_multisig_mismatch_multisig_change_different_paths(client: Client): # inputs mismatch because the pubkeys are different, change matches with first input -def test_multisig_mismatch_inputs(client: Client): +def test_multisig_mismatch_inputs(session: Session): multisig_out1 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_EXT2, NODE_INT], address_n=[1, 0], @@ -601,10 +601,10 @@ def test_multisig_mismatch_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP3)) + with session: + session.set_expected_responses(_responses(session, INP1, INP3)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP3], [out1, out2], @@ -613,7 +613,7 @@ def test_multisig_mismatch_inputs(client: Client): # inputs mismatch because the pubkeys are different, change matches with first input -def test_sorted_multisig_mismatch_inputs(client: Client): +def test_sorted_multisig_mismatch_inputs(session: Session): multisig_out1 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_EXT2, NODE_INT], address_n=[1, 0], @@ -635,10 +635,10 @@ def test_sorted_multisig_mismatch_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP4, INP6)) + with session: + session.set_expected_responses(_responses(session, INP4, INP6)) btc.sign_tx( - client, + session, "Bitcoin", [INP4, INP6], [out1, out2], diff --git a/tests/device_tests/bitcoin/test_nonstandard_paths.py b/tests/device_tests/bitcoin/test_nonstandard_paths.py index ac33ee8b40e..77d57aa9517 100644 --- a/tests/device_tests/bitcoin/test_nonstandard_paths.py +++ b/tests/device_tests/bitcoin/test_nonstandard_paths.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -94,11 +94,11 @@ # accepted in case we make this more restrictive in the future. @pytest.mark.parametrize("path, script_types", VECTORS) def test_getpublicnode( - client: Client, path: str, script_types: list[messages.InputScriptType] + session: Session, path: str, script_types: list[messages.InputScriptType] ): for script_type in script_types: res = btc.get_public_node( - client, parse_path(path), coin_name="Bitcoin", script_type=script_type + session, parse_path(path), coin_name="Bitcoin", script_type=script_type ) assert res.xpub @@ -107,18 +107,18 @@ def test_getpublicnode( @pytest.mark.parametrize("chunkify", (True, False)) @pytest.mark.parametrize("path, script_types", VECTORS) def test_getaddress( - client: Client, + session: Session, chunkify: bool, path: str, script_types: list[messages.InputScriptType], ): for script_type in script_types: - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) res = btc.get_address( - client, + session, "Bitcoin", parse_path(path), show_display=True, @@ -131,16 +131,16 @@ def test_getaddress( @pytest.mark.parametrize("path, script_types", VECTORS) def test_signmessage( - client: Client, path: str, script_types: list[messages.InputScriptType] + session: Session, path: str, script_types: list[messages.InputScriptType] ): for script_type in script_types: - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) sig = btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path(path), script_type=script_type, @@ -152,12 +152,14 @@ def test_signmessage( @pytest.mark.parametrize("path, script_types", VECTORS) def test_signtx( - client: Client, path: str, script_types: list[messages.InputScriptType] + session: Session, path: str, script_types: list[messages.InputScriptType] ): address_n = parse_path(path) for script_type in script_types: - address = btc.get_address(client, "Bitcoin", address_n, script_type=script_type) + address = btc.get_address( + session, "Bitcoin", address_n, script_type=script_type + ) prevhash, prevtx = forge_prevtx([(address, 390_000)]) inp1 = messages.TxInputType( address_n=address_n, @@ -173,12 +175,12 @@ def test_signtx( script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} + session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} ) assert serialized_tx.hex() @@ -187,12 +189,12 @@ def test_signtx( @pytest.mark.multisig @pytest.mark.parametrize("paths, address_index", VECTORS_MULTISIG) def test_getaddress_multisig( - client: Client, paths: list[str], address_index: list[int] + session: Session, paths: list[str], address_index: list[int] ): pubs = [ messages.HDNodePathType( node=btc.get_public_node( - client, parse_path(path), coin_name="Bitcoin" + session, parse_path(path), coin_name="Bitcoin" ).node, address_n=address_index, ) @@ -200,12 +202,12 @@ def test_getaddress_multisig( ] multisig = messages.MultisigRedeemScriptType(pubkeys=pubs, m=2) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) address = btc.get_address( - client, + session, "Bitcoin", parse_path(paths[0]) + address_index, show_display=True, @@ -218,11 +220,11 @@ def test_getaddress_multisig( @pytest.mark.multisig @pytest.mark.parametrize("paths, address_index", VECTORS_MULTISIG) -def test_signtx_multisig(client: Client, paths: list[str], address_index: list[int]): +def test_signtx_multisig(session: Session, paths: list[str], address_index: list[int]): pubs = [ messages.HDNodePathType( node=btc.get_public_node( - client, parse_path(path), coin_name="Bitcoin" + session, parse_path(path), coin_name="Bitcoin" ).node, address_n=address_index, ) @@ -235,7 +237,7 @@ def test_signtx_multisig(client: Client, paths: list[str], address_index: list[i address_n = parse_path(paths[0]) + address_index address = btc.get_address( - client, + session, "Bitcoin", address_n, multisig=multisig, @@ -259,12 +261,12 @@ def test_signtx_multisig(client: Client, paths: list[str], address_index: list[i script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) sig, _ = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} + session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} ) assert sig[0] diff --git a/tests/device_tests/bitcoin/test_op_return.py b/tests/device_tests/bitcoin/test_op_return.py index b5063891993..0aa8acb0802 100644 --- a/tests/device_tests/bitcoin/test_op_return.py +++ b/tests/device_tests/bitcoin/test_op_return.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -43,7 +43,7 @@ ) -def test_opreturn(client: Client): +def test_opreturn(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/1h/1h/0/21"), # myGMXcCxmuDooMdzZFPMmvHviijzqYKhza amount=89_581, @@ -63,13 +63,13 @@ def test_opreturn(client: Client): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.SignTx), @@ -86,7 +86,7 @@ def test_opreturn(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -96,7 +96,7 @@ def test_opreturn(client: Client): ) -def test_nonzero_opreturn(client: Client): +def test_nonzero_opreturn(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/10h/0/5"), amount=390_000, @@ -110,18 +110,18 @@ def test_nonzero_opreturn(client: Client): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [request_input(0), request_output(0), messages.Failure()] ) with pytest.raises( TrezorFailure, match="OP_RETURN output with non-zero amount" ): - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_API) -def test_opreturn_address(client: Client): +def test_opreturn_address(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/2"), amount=390_000, @@ -136,11 +136,11 @@ def test_opreturn_address(client: Client): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [request_input(0), request_output(0), messages.Failure()] ) with pytest.raises( TrezorFailure, match="Output's address_n provided but not expected." ): - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_API) diff --git a/tests/device_tests/bitcoin/test_peercoin.py b/tests/device_tests/bitcoin/test_peercoin.py index b1b62e49e55..b3de714e26e 100644 --- a/tests/device_tests/bitcoin/test_peercoin.py +++ b/tests/device_tests/bitcoin/test_peercoin.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -32,7 +32,7 @@ @pytest.mark.altcoin @pytest.mark.peercoin -def test_timestamp_included(client: Client): +def test_timestamp_included(session: Session): # tx: 41b29ad615d8eea40a4654a052d18bb10cd08f203c351f4d241f88b031357d3d # input 0: 0.1 PPC @@ -50,7 +50,7 @@ def test_timestamp_included(client: Client): ) _, timestamp_tx = btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -66,7 +66,7 @@ def test_timestamp_included(client: Client): @pytest.mark.altcoin @pytest.mark.peercoin -def test_timestamp_missing(client: Client): +def test_timestamp_missing(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/6h/0h/0/0"), amount=100_000, @@ -81,7 +81,7 @@ def test_timestamp_missing(client: Client): with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -92,7 +92,7 @@ def test_timestamp_missing(client: Client): with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -104,7 +104,7 @@ def test_timestamp_missing(client: Client): @pytest.mark.altcoin @pytest.mark.peercoin -def test_timestamp_missing_prevtx(client: Client): +def test_timestamp_missing_prevtx(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/6h/0h/0/0"), amount=100_000, @@ -122,7 +122,7 @@ def test_timestamp_missing_prevtx(client: Client): with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -134,7 +134,7 @@ def test_timestamp_missing_prevtx(client: Client): prevtx.timestamp = None with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], diff --git a/tests/device_tests/bitcoin/test_signmessage.py b/tests/device_tests/bitcoin/test_signmessage.py index be3eb46ec82..ce6bb2debea 100644 --- a/tests/device_tests/bitcoin/test_signmessage.py +++ b/tests/device_tests/bitcoin/test_signmessage.py @@ -20,7 +20,7 @@ from trezorlib import btc, messages from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import message_filters from trezorlib.exceptions import Cancelled from trezorlib.tools import parse_path @@ -286,7 +286,7 @@ def case( "coin_name, path, script_type, no_script_type, address, message, signature", VECTORS ) def test_signmessage( - client: Client, + session: Session, coin_name: str, path: str, script_type: messages.InputScriptType, @@ -296,7 +296,7 @@ def test_signmessage( signature: str, ): sig = btc.sign_message( - client, + session, coin_name=coin_name, n=parse_path(path), script_type=script_type, @@ -312,7 +312,7 @@ def test_signmessage( "coin_name, path, script_type, no_script_type, address, message, signature", VECTORS ) def test_signmessage_info( - client: Client, + session: Session, coin_name: str, path: str, script_type: messages.InputScriptType, @@ -321,11 +321,11 @@ def test_signmessage_info( message: str, signature: str, ): - with client, pytest.raises(Cancelled): + with session.client as client, pytest.raises(Cancelled): IF = InputFlowSignMessageInfo(client) client.set_input_flow(IF.get()) sig = btc.sign_message( - client, + session, coin_name=coin_name, n=parse_path(path), script_type=script_type, @@ -352,12 +352,12 @@ def test_signmessage_info( @pytest.mark.models("core") @pytest.mark.parametrize("message", MESSAGE_LENGTHS) -def test_signmessage_pagination(client: Client, message: str): - with client: +def test_signmessage_pagination(session: Session, message: str): + with session.client as client: IF = InputFlowSignMessagePagination(client) client.set_input_flow(IF.get()) btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path("m/44h/0h/0h/0/0"), message=message, @@ -365,19 +365,19 @@ def test_signmessage_pagination(client: Client, message: str): # We cannot differentiate between a newline and space in the message read from Trezor. # TODO: do the check also for T2B1 - if client.layout_type in (LayoutType.Bolt, LayoutType.Delizia): + if session.client.layout_type in (LayoutType.Bolt, LayoutType.Delizia): message_read = IF.message_read.replace(" ", "").replace("...", "") signed_message = message.replace("\n", "").replace(" ", "") assert signed_message in message_read @pytest.mark.models("t2t1", reason="Tailored to TT fonts and screen size") -def test_signmessage_pagination_trailing_newline(client: Client): +def test_signmessage_pagination_trailing_newline(session: Session): message = "THIS\nMUST\nNOT\nBE\nPAGINATED\n" # The trailing newline must not cause a new paginated screen to appear. # The UI must be a single dialog without pagination. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ # expect address confirmation message_filters.ButtonRequest(code=messages.ButtonRequestType.Other), @@ -387,18 +387,18 @@ def test_signmessage_pagination_trailing_newline(client: Client): ] ) btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path("m/44h/0h/0h/0/0"), message=message, ) -def test_signmessage_path_warning(client: Client): +def test_signmessage_path_warning(session: Session): message = "This is an example of a signed message." - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( [ # expect a path warning message_filters.ButtonRequest( @@ -409,11 +409,11 @@ def test_signmessage_path_warning(client: Client): messages.MessageSignature, ] ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path("m/86h/0h/0h/0/0"), message=message, diff --git a/tests/device_tests/bitcoin/test_signtx.py b/tests/device_tests/bitcoin/test_signtx.py index 9dce03a317a..0a637efb63f 100644 --- a/tests/device_tests/bitcoin/test_signtx.py +++ b/tests/device_tests/bitcoin/test_signtx.py @@ -19,7 +19,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from trezorlib.tools import H_, parse_path @@ -111,7 +111,7 @@ CORNER_BUTTON = (215, 25) -def test_one_one_fee(client: Client): +def test_one_one_fee(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -127,13 +127,13 @@ def test_one_one_fee(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_0dac36), @@ -148,7 +148,7 @@ def test_one_one_fee(client: Client): ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET ) assert_tx_matches( @@ -158,7 +158,7 @@ def test_one_one_fee(client: Client): ) -def test_testnet_one_two_fee(client: Client): +def test_testnet_one_two_fee(session: Session): # input tx: e5040e1bc1ae7667ffb9e5248e90b2fb93cd9150234151ce90e14ab2f5933bcd inp1 = messages.TxInputType( @@ -180,13 +180,13 @@ def test_testnet_one_two_fee(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -203,7 +203,7 @@ def test_testnet_one_two_fee(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -213,7 +213,7 @@ def test_testnet_one_two_fee(client: Client): ) -def test_testnet_fee_high_warning(client: Client): +def test_testnet_fee_high_warning(session: Session): # input tx: 25fee583181847cbe9d9fd9a483a8b8626c99854a72d01de848ef40508d0f3bc # (The "25fee" tx hash is very suitable for testing high fees) @@ -230,13 +230,13 @@ def test_testnet_fee_high_warning(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.FeeOverThreshold), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -250,7 +250,7 @@ def test_testnet_fee_high_warning(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -260,7 +260,7 @@ def test_testnet_fee_high_warning(client: Client): ) -def test_one_two_fee(client: Client): +def test_one_two_fee(session: Session): # input tx: 50f6f1209ca92d7359564be803cb2c932cde7d370f7cee50fd1fad6790f6206d inp1 = messages.TxInputType( @@ -282,14 +282,14 @@ def test_one_two_fee(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_50f6f1), @@ -305,7 +305,7 @@ def test_one_two_fee(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1, out2], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1], [out1, out2], prev_txes=TX_CACHE_MAINNET ) assert_tx_matches( @@ -316,7 +316,7 @@ def test_one_two_fee(client: Client): @pytest.mark.parametrize("chunkify", (True, False)) -def test_one_three_fee(client: Client, chunkify: bool): +def test_one_three_fee(session: Session, chunkify: bool): # input tx: bb5169091f09e833e155b291b662019df56870effe388c626221c5ea84274bc4 inp1 = messages.TxInputType( @@ -344,16 +344,16 @@ def test_one_three_fee(client: Client, chunkify: bool): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -371,7 +371,7 @@ def test_one_three_fee(client: Client, chunkify: bool): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2, out3], @@ -386,7 +386,7 @@ def test_one_three_fee(client: Client, chunkify: bool): ) -def test_two_two(client: Client): +def test_two_two(session: Session): # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a inp1 = messages.TxInputType( @@ -415,15 +415,15 @@ def test_two_two(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ac4ca0), @@ -449,7 +449,7 @@ def test_two_two(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [inp1, inp2], [out1, out2], @@ -464,7 +464,7 @@ def test_two_two(client: Client): @pytest.mark.slow -def test_lots_of_inputs(client: Client): +def test_lots_of_inputs(session: Session): # Tests if device implements serialization of len(inputs) correctly # input tx: 3019487f064329247daad245aed7a75349d09c14b1d24f170947690e030f5b20 @@ -485,7 +485,7 @@ def test_lots_of_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Testnet", inputs, [out], prev_txes=TX_CACHE_TESTNET + session, "Testnet", inputs, [out], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -495,7 +495,7 @@ def test_lots_of_inputs(client: Client): @pytest.mark.slow -def test_lots_of_outputs(client: Client): +def test_lots_of_outputs(session: Session): # Tests if device implements serialization of len(outputs) correctly # input tx: 58d56a5d1325cf83543ee4c87fd73a784e4ba1499ced574be359fa2bdcb9ac8e @@ -518,7 +518,7 @@ def test_lots_of_outputs(client: Client): outputs.append(out) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -528,7 +528,7 @@ def test_lots_of_outputs(client: Client): @pytest.mark.slow -def test_lots_of_change(client: Client): +def test_lots_of_change(session: Session): # Tests if device implements prompting for multiple change addresses correctly # input tx: 892d06cb3394b8e6006eec9a2aa90692b718a29be6844b6c6a9e89ec3aa6aac4 @@ -559,13 +559,13 @@ def test_lots_of_change(client: Client): request_change_outputs = [request_output(i + 1) for i in range(cnt)] - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), ] + request_change_outputs + [ @@ -585,7 +585,7 @@ def test_lots_of_change(client: Client): ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -594,7 +594,7 @@ def test_lots_of_change(client: Client): ) -def test_fee_high_warning(client: Client): +def test_fee_high_warning(session: Session): # input tx: 1f326f65768d55ef146efbb345bd87abe84ac7185726d0457a026fc347a26ef3 inp1 = messages.TxInputType( @@ -610,13 +610,13 @@ def test_fee_high_warning(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.FeeOverThreshold), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -631,7 +631,7 @@ def test_fee_high_warning(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -642,7 +642,7 @@ def test_fee_high_warning(client: Client): @pytest.mark.models("core") -def test_fee_high_hardfail(client: Client): +def test_fee_high_hardfail(session: Session): # input tx: 25fee583181847cbe9d9fd9a483a8b8626c99854a72d01de848ef40508d0f3bc # (The "25fee" tx hash is very suitable for testing high fees) @@ -660,18 +660,18 @@ def test_fee_high_hardfail(client: Client): ) with pytest.raises(TrezorFailure, match="fee is unexpectedly large"): - btc.sign_tx(client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx(session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET) # set SafetyCheckLevel to PromptTemporarily and try again device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) - with client: + with session.client as client: IF = InputFlowSignTxHighFee(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) assert IF.finished @@ -682,7 +682,7 @@ def test_fee_high_hardfail(client: Client): ) -def test_not_enough_funds(client: Client): +def test_not_enough_funds(session: Session): # input tx: d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882 inp1 = messages.TxInputType( @@ -698,21 +698,21 @@ def test_not_enough_funds(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.Failure(code=messages.FailureType.NotEnoughFunds), ] ) with pytest.raises(TrezorFailure, match="NotEnoughFunds"): - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) -def test_p2sh(client: Client): +def test_p2sh(session: Session): # input tx: 58d56a5d1325cf83543ee4c87fd73a784e4ba1499ced574be359fa2bdcb9ac8e inp1 = messages.TxInputType( @@ -728,13 +728,13 @@ def test_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOSCRIPTHASH, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_58d56a), @@ -748,7 +748,7 @@ def test_p2sh(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -758,7 +758,7 @@ def test_p2sh(client: Client): ) -def test_testnet_big_amount(client: Client): +def test_testnet_big_amount(session: Session): # This test is testing transaction with amount bigger than fits to uint32 # input tx: 074b0070939db4c2635c1bef0c8e68412ccc8d3c8782137547c7a2bbde073fc0 @@ -775,7 +775,7 @@ def test_testnet_big_amount(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -785,7 +785,7 @@ def test_testnet_big_amount(client: Client): ) -def test_attack_change_outputs(client: Client): +def test_attack_change_outputs(session: Session): # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a inp1 = messages.TxInputType( @@ -815,15 +815,15 @@ def test_attack_change_outputs(client: Client): ) # Test if the transaction can be signed normally - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ac4ca0), @@ -849,7 +849,7 @@ def test_attack_change_outputs(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_MAINNET ) assert_tx_matches( @@ -871,14 +871,14 @@ def attack_processor(msg): return msg - with client, pytest.raises( + with session, pytest.raises( TrezorFailure, match="Transaction has changed during signing" ): # Set up attack processors - client.set_filter(messages.TxAck, attack_processor) + session.set_filter(messages.TxAck, attack_processor) btc.sign_tx( - client, + session, "Bitcoin", [inp1, inp2], [out1, out2], @@ -886,7 +886,7 @@ def attack_processor(msg): ) -def test_attack_modify_change_address(client: Client): +def test_attack_modify_change_address(session: Session): # Ensure that if the change output is modified after the user confirms the # transaction, then signing fails. @@ -926,16 +926,18 @@ def attack_processor(msg): return msg - with client, pytest.raises( + with session, pytest.raises( TrezorFailure, match="Transaction has changed during signing" ): # Set up attack processors - client.set_filter(messages.TxAck, attack_processor) + session.set_filter(messages.TxAck, attack_processor) - btc.sign_tx(client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + ) -def test_attack_change_input_address(client: Client): +def test_attack_change_input_address(session: Session): # input tx: d2dcdaf547ea7f57a713c607f15e883ddc4a98167ee2c43ed953c53cb5153e24 inp1 = messages.TxInputType( @@ -960,7 +962,7 @@ def test_attack_change_input_address(client: Client): # Test if the transaction can be signed normally _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -982,14 +984,14 @@ def attack_processor(msg): return msg # Now run the attack, must trigger the exception - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -1004,7 +1006,7 @@ def attack_processor(msg): # Now run the attack, must trigger the exception with pytest.raises(TrezorFailure) as exc: btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -1015,7 +1017,7 @@ def attack_processor(msg): assert exc.value.message.endswith("Transaction has changed during signing") -def test_spend_coinbase(client: Client): +def test_spend_coinbase(session: Session): # NOTE: the input transaction is not real # We did not have any coinbase transaction at connected with `all all` seed, # so it was artificially created for the test purpose @@ -1033,13 +1035,13 @@ def test_spend_coinbase(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_005f6f), @@ -1052,7 +1054,7 @@ def test_spend_coinbase(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -1062,7 +1064,7 @@ def test_spend_coinbase(client: Client): ) -def test_two_changes(client: Client): +def test_two_changes(session: Session): # input tx: e5040e1bc1ae7667ffb9e5248e90b2fb93cd9150234151ce90e14ab2f5933bcd # see 87be0736f202f7c2bff0781b42bad3e0cdcb54761939da69ea793a3735552c56 @@ -1091,13 +1093,13 @@ def test_two_changes(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), request_output(2), messages.ButtonRequest(code=B.SignTx), @@ -1118,7 +1120,7 @@ def test_two_changes(client: Client): ) btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out_change1, out_change2], @@ -1126,7 +1128,7 @@ def test_two_changes(client: Client): ) -def test_change_on_main_chain_allowed(client: Client): +def test_change_on_main_chain_allowed(session: Session): # input tx: e5040e1bc1ae7667ffb9e5248e90b2fb93cd9150234151ce90e14ab2f5933bcd # see 87be0736f202f7c2bff0781b42bad3e0cdcb54761939da69ea793a3735552c56 @@ -1150,13 +1152,13 @@ def test_change_on_main_chain_allowed(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -1174,7 +1176,7 @@ def test_change_on_main_chain_allowed(client: Client): ) btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out_change], @@ -1182,7 +1184,7 @@ def test_change_on_main_chain_allowed(client: Client): ) -def test_not_enough_vouts(client: Client): +def test_not_enough_vouts(session: Session): # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a prev_tx = TX_CACHE_MAINNET[TXHASH_ac4ca0] @@ -1222,7 +1224,7 @@ def test_not_enough_vouts(client: Client): TrezorFailure, match="Not enough outputs in previous transaction." ): btc.sign_tx( - client, + session, "Bitcoin", [inp0, inp1, inp2], [out1], @@ -1240,7 +1242,7 @@ def test_not_enough_vouts(client: Client): ("branch_id", 13), ), ) -def test_prevtx_forbidden_fields(client: Client, field, value): +def test_prevtx_forbidden_fields(session: Session, field, value): inp0 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), # 1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL prev_hash=TXHASH_157041, @@ -1258,7 +1260,7 @@ def test_prevtx_forbidden_fields(client: Client, field, value): name = field.replace("_", " ") with pytest.raises(TrezorFailure, match=rf"(?i){name} not enabled on this coin"): btc.sign_tx( - client, "Bitcoin", [inp0], [out1], prev_txes={TXHASH_157041: prev_tx} + session, "Bitcoin", [inp0], [out1], prev_txes={TXHASH_157041: prev_tx} ) @@ -1266,7 +1268,7 @@ def test_prevtx_forbidden_fields(client: Client, field, value): "field, value", (("expiry", 9), ("timestamp", 42), ("version_group_id", 69), ("branch_id", 13)), ) -def test_signtx_forbidden_fields(client: Client, field: str, value: int): +def test_signtx_forbidden_fields(session: Session, field: str, value: int): inp0 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), # 1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL prev_hash=TXHASH_157041, @@ -1283,7 +1285,7 @@ def test_signtx_forbidden_fields(client: Client, field: str, value: int): name = field.replace("_", " ") with pytest.raises(TrezorFailure, match=rf"(?i){name} not enabled on this coin"): btc.sign_tx( - client, "Bitcoin", [inp0], [out1], prev_txes=TX_CACHE_MAINNET, **kwargs + session, "Bitcoin", [inp0], [out1], prev_txes=TX_CACHE_MAINNET, **kwargs ) @@ -1291,7 +1293,7 @@ def test_signtx_forbidden_fields(client: Client, field: str, value: int): "script_type", (messages.InputScriptType.SPENDADDRESS, messages.InputScriptType.EXTERNAL), ) -def test_incorrect_input_script_type(client: Client, script_type): +def test_incorrect_input_script_type(session: Session, script_type): address_n = parse_path("m/44h/1h/0h/0/0") # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q attacker_multisig_public_key = bytes.fromhex( "030e669acac1f280d1ddf441cd2ba5e97417bf2689e4bbec86df4f831bf9f7ffd0" @@ -1300,7 +1302,7 @@ def test_incorrect_input_script_type(client: Client, script_type): multisig = messages.MultisigRedeemScriptType( m=1, nodes=[ - btc.get_public_node(client, address_n, coin_name="Testnet").node, + btc.get_public_node(session, address_n, coin_name="Testnet").node, messages.HDNodeType( depth=0, fingerprint=0, @@ -1335,7 +1337,9 @@ def test_incorrect_input_script_type(client: Client, script_type): with pytest.raises( TrezorFailure, match="Multisig field provided but not expected." ): - btc.sign_tx(client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + ) @pytest.mark.parametrize( @@ -1346,7 +1350,7 @@ def test_incorrect_input_script_type(client: Client, script_type): ), ) def test_incorrect_output_script_type( - client: Client, script_type: messages.OutputScriptType + session: Session, script_type: messages.OutputScriptType ): address_n = parse_path("m/44h/1h/0h/0/0") # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q attacker_multisig_public_key = bytes.fromhex( @@ -1356,7 +1360,7 @@ def test_incorrect_output_script_type( multisig = messages.MultisigRedeemScriptType( m=1, nodes=[ - btc.get_public_node(client, address_n, coin_name="Testnet").node, + btc.get_public_node(session, address_n, coin_name="Testnet").node, messages.HDNodeType( depth=0, fingerprint=0, @@ -1390,14 +1394,16 @@ def test_incorrect_output_script_type( with pytest.raises( TrezorFailure, match="Multisig field provided but not expected." ): - btc.sign_tx(client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + ) @pytest.mark.parametrize( "lock_time, sequence", ((499_999_999, 0xFFFFFFFE), (500_000_000, 0xFFFFFFFE), (1, 0xFFFFFFFF)), ) -def test_lock_time(client: Client, lock_time: int, sequence: int): +def test_lock_time(session: Session, lock_time: int, sequence: int): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1414,13 +1420,13 @@ def test_lock_time(client: Client, lock_time: int, sequence: int): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -1436,7 +1442,7 @@ def test_lock_time(client: Client, lock_time: int, sequence: int): ) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1446,7 +1452,7 @@ def test_lock_time(client: Client, lock_time: int, sequence: int): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_lock_time_blockheight(client: Client): +def test_lock_time_blockheight(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1463,12 +1469,12 @@ def test_lock_time_blockheight(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session.client as client: IF = InputFlowLockTimeBlockHeight(client, "499999999") client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1481,7 +1487,7 @@ def test_lock_time_blockheight(client: Client): @pytest.mark.parametrize( "lock_time_str", ("1985-11-05 00:53:20", "2048-08-16 22:14:00") ) -def test_lock_time_datetime(client: Client, lock_time_str: str): +def test_lock_time_datetime(session: Session, lock_time_str: str): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1502,12 +1508,12 @@ def test_lock_time_datetime(client: Client, lock_time_str: str): lock_time_utc = lock_time_naive.replace(tzinfo=timezone.utc) lock_time_timestamp = int(lock_time_utc.timestamp()) - with client: + with session.client as client: IF = InputFlowLockTimeDatetime(client, lock_time_str) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1517,7 +1523,7 @@ def test_lock_time_datetime(client: Client, lock_time_str: str): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_information(client: Client): +def test_information(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1534,12 +1540,12 @@ def test_information(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session.client as client: IF = InputFlowSignTxInformation(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1548,7 +1554,7 @@ def test_information(client: Client): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_information_mixed(client: Client): +def test_information_mixed(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/1h/0h/0/0"), # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q amount=31_000_000, @@ -1569,12 +1575,12 @@ def test_information_mixed(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session.client as client: IF = InputFlowSignTxInformationMixed(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -1583,7 +1589,7 @@ def test_information_mixed(client: Client): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_information_cancel(client: Client): +def test_information_cancel(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1600,12 +1606,12 @@ def test_information_cancel(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client, pytest.raises(Cancelled): + with session.client as client, pytest.raises(Cancelled): IF = InputFlowSignTxInformationCancel(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1618,7 +1624,7 @@ def test_information_cancel(client: Client): skip="delizia", reason="Cannot test layouts on T1, not implemented in Delizia UI", ) -def test_information_replacement(client: Client): +def test_information_replacement(session: Session): # Use the change output and an external output to bump the fee. # Originally fee was 3780, now 108060 (94280 from change and 10000 from external). @@ -1650,12 +1656,12 @@ def test_information_replacement(client: Client): orig_index=0, ) - with client: + with session.client as client: IF = InputFlowSignTxInformationReplacement(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], diff --git a/tests/device_tests/bitcoin/test_signtx_amount_unit.py b/tests/device_tests/bitcoin/test_signtx_amount_unit.py index d3dfa3d00ec..50cc19151b6 100644 --- a/tests/device_tests/bitcoin/test_signtx_amount_unit.py +++ b/tests/device_tests/bitcoin/test_signtx_amount_unit.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -42,7 +42,7 @@ @pytest.mark.parametrize("amount_unit", VECTORS) -def test_signtx_testnet(client: Client, amount_unit): +def test_signtx_testnet(session: Session, amount_unit): inp1 = messages.TxInputType( # tb1qajr3a3y5uz27lkxrmn7ck8lp22dgytvagr5nqy address_n=parse_path("m/84h/1h/0h/0/87"), @@ -61,9 +61,9 @@ def test_signtx_testnet(client: Client, amount_unit): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=100_000 - 40_000 - 10_000, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -79,7 +79,7 @@ def test_signtx_testnet(client: Client, amount_unit): @pytest.mark.parametrize("amount_unit", VECTORS) -def test_signtx_btc(client: Client, amount_unit): +def test_signtx_btc(session: Session, amount_unit): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -95,9 +95,9 @@ def test_signtx_btc(client: Client, amount_unit): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], diff --git a/tests/device_tests/bitcoin/test_signtx_external.py b/tests/device_tests/bitcoin/test_signtx_external.py index fd8e0cff3e9..4d44e3ec763 100644 --- a/tests/device_tests/bitcoin/test_signtx_external.py +++ b/tests/device_tests/bitcoin/test_signtx_external.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import SafetyCheckLevel from trezorlib.tools import parse_path @@ -82,7 +82,7 @@ @pytest.mark.models("core") -def test_p2pkh_presigned(client: Client): +def test_p2pkh_presigned(session: Session): inp1 = messages.TxInputType( # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q address_n=parse_path("m/44h/1h/0h/0/0"), @@ -142,9 +142,9 @@ def test_p2pkh_presigned(client: Client): ) # Test with first input as pre-signed external. - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1ext, inp2], [out1, out2], @@ -155,9 +155,9 @@ def test_p2pkh_presigned(client: Client): assert serialized_tx.hex() == expected_tx # Test with second input as pre-signed external. - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2ext], [out1, out2], @@ -170,7 +170,7 @@ def test_p2pkh_presigned(client: Client): inp2ext.script_sig[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature"): _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2ext], [out1, out2], @@ -179,7 +179,7 @@ def test_p2pkh_presigned(client: Client): @pytest.mark.models("core") -def test_p2wpkh_in_p2sh_presigned(client: Client): +def test_p2wpkh_in_p2sh_presigned(session: Session): inp1 = messages.TxInputType( # 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX amount=123_456_789, @@ -216,20 +216,20 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -252,7 +252,7 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -267,20 +267,20 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): # Test corrupted script hash in scriptsig. inp1.script_sig[10] ^= 1 - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -293,7 +293,7 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): with pytest.raises(TrezorFailure, match="Invalid public key hash"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -302,7 +302,7 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): @pytest.mark.models("core") -def test_p2wpkh_presigned(client: Client): +def test_p2wpkh_presigned(session: Session): inp1 = messages.TxInputType( # tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9 address_n=parse_path("m/84h/1h/0h/0/0"), @@ -339,9 +339,9 @@ def test_p2wpkh_presigned(client: Client): ) # Test with second input as pre-signed external. - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -358,7 +358,7 @@ def test_p2wpkh_presigned(client: Client): inp2.witness[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -367,7 +367,7 @@ def test_p2wpkh_presigned(client: Client): @pytest.mark.models("core") -def test_p2wsh_external_presigned(client: Client): +def test_p2wsh_external_presigned(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/0"), amount=10_000, @@ -399,14 +399,14 @@ def test_p2wsh_external_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ec16dc), @@ -429,7 +429,7 @@ def test_p2wsh_external_presigned(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -444,14 +444,14 @@ def test_p2wsh_external_presigned(client: Client): # Test corrupted signature in witness. inp2.witness[10] ^= 1 - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ec16dc), @@ -470,12 +470,12 @@ def test_p2wsh_external_presigned(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET ) @pytest.mark.models("core") -def test_p2tr_external_presigned(client: Client): +def test_p2tr_external_presigned(session: Session): inp1 = messages.TxInputType( # tb1p8tvmvsvhsee73rhym86wt435qrqm92psfsyhy6a3n5gw455znnpqm8wald address_n=parse_path("m/86h/1h/0h/0/1"), @@ -509,14 +509,14 @@ def test_p2tr_external_presigned(client: Client): amount=4_600, script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(1), @@ -530,7 +530,7 @@ def test_p2tr_external_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -541,14 +541,14 @@ def test_p2tr_external_presigned(client: Client): # Test corrupted signature in witness. inp2.witness[10] ^= 1 - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(1), @@ -558,7 +558,7 @@ def test_p2tr_external_presigned(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -567,18 +567,18 @@ def test_p2tr_external_presigned(client: Client): @pytest.mark.models("core") -def test_p2pkh_with_proof(client: Client): +def test_p2pkh_with_proof(session: Session): # TODO pass @pytest.mark.models("core") -def test_p2wpkh_in_p2sh_with_proof(client: Client): +def test_p2wpkh_in_p2sh_with_proof(session: Session): # TODO pass -def test_p2wpkh_with_proof(client: Client): +def test_p2wpkh_with_proof(session: Session): inp1 = messages.TxInputType( # seed "alcohol woman abuse must during monitor noble actual mixed trade anger aisle" # 84'/1'/0'/0/0 @@ -610,18 +610,18 @@ def test_p2wpkh_with_proof(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + with session: + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_e5b7e2), @@ -643,7 +643,7 @@ def test_p2wpkh_with_proof(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -660,7 +660,7 @@ def test_p2wpkh_with_proof(client: Client): inp1.ownership_proof[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature|Invalid external input"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -671,7 +671,7 @@ def test_p2wpkh_with_proof(client: Client): @pytest.mark.setup_client( mnemonic="abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about" ) -def test_p2tr_with_proof(client: Client): +def test_p2tr_with_proof(session: Session): # Resulting TXID 48ec6dc7bb772ff18cbce0135fedda7c0e85212c7b2f85a5d0cc7a917d77c48a inp1 = messages.TxInputType( @@ -703,15 +703,15 @@ def test_p2tr_with_proof(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + with session: + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_input(1), @@ -722,7 +722,7 @@ def test_p2tr_with_proof(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -736,10 +736,12 @@ def test_p2tr_with_proof(client: Client): # Test corrupted ownership proof. inp1.ownership_proof[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature|Invalid external input"): - btc.sign_tx(client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET + ) -def test_p2wpkh_with_false_proof(client: Client): +def test_p2wpkh_with_false_proof(session: Session): inp1 = messages.TxInputType( # tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9 address_n=parse_path("m/84h/1h/0h/0/0"), @@ -768,8 +770,8 @@ def test_p2wpkh_with_false_proof(client: Client): script_type=messages.OutputScriptType.PAYTOWITNESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), @@ -779,7 +781,7 @@ def test_p2wpkh_with_false_proof(client: Client): with pytest.raises(TrezorFailure, match="Invalid external input"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -787,7 +789,7 @@ def test_p2wpkh_with_false_proof(client: Client): ) -def test_p2tr_external_unverified(client: Client): +def test_p2tr_external_unverified(session: Session): inp1 = messages.TxInputType( # tb1pswrqtykue8r89t9u4rprjs0gt4qzkdfuursfnvqaa3f2yql07zmq8s8a5u address_n=parse_path("m/86h/1h/0h/0/0"), @@ -823,13 +825,13 @@ def test_p2tr_external_unverified(client: Client): # Unverified external inputs should be rejected when safety checks are enabled. with pytest.raises(TrezorFailure, match="[Ee]xternal input"): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Signing should succeed after disabling safety checks. - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Second witness is missing from the serialized transaction. @@ -840,7 +842,7 @@ def test_p2tr_external_unverified(client: Client): ) -def test_p2wpkh_external_unverified(client: Client): +def test_p2wpkh_external_unverified(session: Session): inp1 = messages.TxInputType( # tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9 address_n=parse_path("m/84h/1h/0h/0/0"), @@ -875,13 +877,13 @@ def test_p2wpkh_external_unverified(client: Client): # Unverified external inputs should be rejected when safety checks are enabled. with pytest.raises(TrezorFailure, match="[Ee]xternal input"): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Signing should succeed after disabling safety checks. - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Second witness is missing from the serialized transaction. diff --git a/tests/device_tests/bitcoin/test_signtx_invalid_path.py b/tests/device_tests/bitcoin/test_signtx_invalid_path.py index 5ef4ba0389c..27f0599de9b 100644 --- a/tests/device_tests/bitcoin/test_signtx_invalid_path.py +++ b/tests/device_tests/bitcoin/test_signtx_invalid_path.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -36,7 +36,7 @@ # Litecoin does not have strong replay protection using SIGHASH_FORKID, # spending from Bitcoin path should fail. @pytest.mark.altcoin -def test_invalid_path_fail(client: Client): +def test_invalid_path_fail(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), amount=390_000, @@ -52,7 +52,7 @@ def test_invalid_path_fail(client: Client): ) with pytest.raises(TrezorFailure) as exc: - btc.sign_tx(client, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) assert exc.value.code == messages.FailureType.DataError assert exc.value.message.endswith("Forbidden key path") @@ -61,7 +61,7 @@ def test_invalid_path_fail(client: Client): # Litecoin does not have strong replay protection using SIGHASH_FORKID, but # spending from Bitcoin path should pass with safety checks set to prompt. @pytest.mark.altcoin -def test_invalid_path_prompt(client: Client): +def test_invalid_path_prompt(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), amount=390_000, @@ -77,21 +77,21 @@ def test_invalid_path_prompt(client: Client): ) device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) # Bcash does have strong replay protection using SIGHASH_FORKID, # spending from Bitcoin path should work. @pytest.mark.altcoin -def test_invalid_path_pass_forkid(client: Client): +def test_invalid_path_pass_forkid(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), amount=390_000, @@ -106,32 +106,32 @@ def test_invalid_path_pass_forkid(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) -def test_attack_path_segwit(client: Client): +def test_attack_path_segwit(session: Session): # Scenario: The attacker falsely claims that the transaction uses Testnet paths to # avoid the path warning dialog, but in step6_sign_segwit_inputs() uses Bitcoin paths # to get a valid signature. device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) # Generate keys address_a = btc.get_address( - client, + session, "Testnet", parse_path("m/84h/0h/0h/0/0"), script_type=messages.InputScriptType.SPENDWITNESS, ) address_b = btc.get_address( - client, + session, "Testnet", parse_path("m/84h/0h/1h/0/1"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -178,15 +178,15 @@ def attack_processor(msg): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) + with session: + session.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes={prev_hash: prev_tx} + session, "Testnet", [inp1, inp2], [out1], prev_txes={prev_hash: prev_tx} ) -def test_invalid_path_fail_asap(client: Client): +def test_invalid_path_fail_asap(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/0"), amount=1_000_000, @@ -202,14 +202,14 @@ def test_invalid_path_fail_asap(client: Client): script_type=messages.OutputScriptType.PAYTOWITNESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), messages.Failure(code=messages.FailureType.DataError), ] ) try: - btc.sign_tx(client, "Testnet", [inp1], [out1]) + btc.sign_tx(session, "Testnet", [inp1], [out1]) except TrezorFailure: pass diff --git a/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py b/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py index de0f3807689..d3ab1cf37b0 100644 --- a/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py +++ b/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py @@ -15,7 +15,7 @@ # If not, see . from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -34,7 +34,7 @@ ) -def test_non_segwit_segwit_inputs(client: Client): +def test_non_segwit_segwit_inputs(session: Session): # First is non-segwit, second is segwit. inp1 = messages.TxInputType( @@ -58,9 +58,9 @@ def test_non_segwit_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) assert len(signatures) == 2 @@ -71,7 +71,7 @@ def test_non_segwit_segwit_inputs(client: Client): ) -def test_segwit_non_segwit_inputs(client: Client): +def test_segwit_non_segwit_inputs(session: Session): # First is segwit, second is non-segwit. inp1 = messages.TxInputType( @@ -94,9 +94,9 @@ def test_segwit_non_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) assert len(signatures) == 2 @@ -107,7 +107,7 @@ def test_segwit_non_segwit_inputs(client: Client): ) -def test_segwit_non_segwit_segwit_inputs(client: Client): +def test_segwit_non_segwit_segwit_inputs(session: Session): # First is segwit, second is non-segwit and third is segwit again. inp1 = messages.TxInputType( @@ -138,9 +138,9 @@ def test_segwit_non_segwit_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API ) assert len(signatures) == 3 @@ -151,7 +151,7 @@ def test_segwit_non_segwit_segwit_inputs(client: Client): ) -def test_non_segwit_segwit_non_segwit_inputs(client: Client): +def test_non_segwit_segwit_non_segwit_inputs(session: Session): # First is non-segwit, second is segwit and third is non-segwit again. inp1 = messages.TxInputType( @@ -180,9 +180,9 @@ def test_non_segwit_segwit_non_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API ) assert len(signatures) == 3 diff --git a/tests/device_tests/bitcoin/test_signtx_payreq.py b/tests/device_tests/bitcoin/test_signtx_payreq.py index e02cb2b6c6b..32c90d05e0c 100644 --- a/tests/device_tests/bitcoin/test_signtx_payreq.py +++ b/tests/device_tests/bitcoin/test_signtx_payreq.py @@ -18,8 +18,8 @@ import pytest -from trezorlib import btc, messages, misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib import btc, messages, misc, models +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -138,7 +138,7 @@ def case(id, *args, altcoin: bool = False, models: str | None = None): case("out12", (PaymentRequestParams([1, 2], [], get_nonce=True),)), ), ) -def test_payment_request(client: Client, payment_request_params): +def test_payment_request(session: Session, payment_request_params): for txo in outputs: txo.payment_req_index = None @@ -148,10 +148,10 @@ def test_payment_request(client: Client, payment_request_params): for txo_index in params.txo_indices: outputs[txo_index].payment_req_index = i request_outputs.append(outputs[txo_index]) - nonce = misc.get_nonce(client) if params.get_nonce else None + nonce = misc.get_nonce(session) if params.get_nonce else None payment_reqs.append( make_payment_request( - client, + session, recipient_name="trezor.io", outputs=request_outputs, change_addresses=["tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9"], @@ -161,7 +161,7 @@ def test_payment_request(client: Client, payment_request_params): ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -174,7 +174,7 @@ def test_payment_request(client: Client, payment_request_params): # Ensure that the nonce has been invalidated. with pytest.raises(TrezorFailure, match="Invalid nonce in payment request"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -184,15 +184,18 @@ def test_payment_request(client: Client, payment_request_params): @pytest.mark.models(skip="safe3") -def test_payment_request_details(client: Client): +def test_payment_request_details(session: Session): + if session.model is models.T2B1: + pytest.skip("Details not implemented on T2B1") + # Test that payment request details are shown when requested. outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None - nonce = misc.get_nonce(client) + nonce = misc.get_nonce(session) payment_reqs = [ make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], memos=[TextMemo("Invoice #87654321.")], @@ -200,12 +203,12 @@ def test_payment_request_details(client: Client): ) ] - with client: + with session.client as client: IF = InputFlowPaymentRequestDetails(client, outputs) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -216,16 +219,16 @@ def test_payment_request_details(client: Client): assert serialized_tx.hex() == SERIALIZED_TX -def test_payment_req_wrong_amount(client: Client): +def test_payment_req_wrong_amount(session: Session): # Test wrong total amount in payment request. outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Decrease the total amount of the payment request. @@ -233,7 +236,7 @@ def test_payment_req_wrong_amount(client: Client): with pytest.raises(TrezorFailure, match="Invalid amount in payment request"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -242,18 +245,18 @@ def test_payment_req_wrong_amount(client: Client): ) -def test_payment_req_wrong_mac_refund(client: Client): +def test_payment_req_wrong_mac_refund(session: Session): # Test wrong MAC in payment request memo. memo = RefundMemo(parse_path("m/44h/1h/0h/1/0")) outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], memos=[memo], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Corrupt the MAC value. @@ -263,7 +266,7 @@ def test_payment_req_wrong_mac_refund(client: Client): with pytest.raises(TrezorFailure, match="Invalid address MAC"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -274,7 +277,7 @@ def test_payment_req_wrong_mac_refund(client: Client): @pytest.mark.altcoin @pytest.mark.models("t2t1", reason="Dash not supported on Safe family") -def test_payment_req_wrong_mac_purchase(client: Client): +def test_payment_req_wrong_mac_purchase(session: Session): # Test wrong MAC in payment request memo. memo = CoinPurchaseMemo( amount="22.34904 DASH", @@ -286,11 +289,11 @@ def test_payment_req_wrong_mac_purchase(client: Client): outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], memos=[memo], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Corrupt the MAC value. @@ -300,7 +303,7 @@ def test_payment_req_wrong_mac_purchase(client: Client): with pytest.raises(TrezorFailure, match="Invalid address MAC"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -309,16 +312,16 @@ def test_payment_req_wrong_mac_purchase(client: Client): ) -def test_payment_req_wrong_output(client: Client): +def test_payment_req_wrong_output(session: Session): # Test wrong output in payment request. outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Use a different address in the second output. @@ -335,7 +338,7 @@ def test_payment_req_wrong_output(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature in payment request"): btc.sign_tx( - client, + session, "Testnet", inputs, fake_outputs, diff --git a/tests/device_tests/bitcoin/test_signtx_prevhash.py b/tests/device_tests/bitcoin/test_signtx_prevhash.py index 307823a9f3f..a2f96c04ed1 100644 --- a/tests/device_tests/bitcoin/test_signtx_prevhash.py +++ b/tests/device_tests/bitcoin/test_signtx_prevhash.py @@ -5,7 +5,7 @@ import pytest from trezorlib import btc, messages, models, tools -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...common import is_core @@ -78,7 +78,7 @@ def _check_error_message(value: bytes, model: models.TrezorModel, message: str): @with_bad_prevhashes -def test_invalid_prev_hash(client: Client, prev_hash): +def test_invalid_prev_hash(session: Session, prev_hash): inp1 = messages.TxInputType( address_n=tools.parse_path("m/44h/0h/0h/0/0"), amount=123_456_789, @@ -93,12 +93,12 @@ def test_invalid_prev_hash(client: Client, prev_hash): ) with pytest.raises(TrezorFailure) as e: - btc.sign_tx(client, "Testnet", [inp1], [out1], prev_txes={}) - _check_error_message(prev_hash, client.model, e.value.message) + btc.sign_tx(session, "Testnet", [inp1], [out1], prev_txes={}) + _check_error_message(prev_hash, session.model, e.value.message) @with_bad_prevhashes -def test_invalid_prev_hash_attack(client: Client, prev_hash): +def test_invalid_prev_hash_attack(session: Session, prev_hash): # prepare input with a valid prev-hash inp1 = messages.TxInputType( address_n=tools.parse_path("m/44h/0h/0h/0/0"), @@ -130,20 +130,20 @@ def attack_filter(msg): msg.tx.inputs[0].prev_hash = prev_hash return msg - with client, pytest.raises(TrezorFailure) as e: - client.set_filter(messages.TxAck, attack_filter) - if is_core(client): + with session, session.client as client, pytest.raises(TrezorFailure) as e: + session.set_filter(messages.TxAck, attack_filter) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES) # check that injection was performed assert counter == 0 - _check_error_message(prev_hash, client.model, e.value.message) + _check_error_message(prev_hash, session.model, e.value.message) @with_bad_prevhashes -def test_invalid_prev_hash_in_prevtx(client: Client, prev_hash): +def test_invalid_prev_hash_in_prevtx(session: Session, prev_hash): prev_tx = copy(PREV_TX) # smoke check: replace prev_hash with all zeros, reserialize and hash, try to sign @@ -161,16 +161,16 @@ def test_invalid_prev_hash_in_prevtx(client: Client, prev_hash): amount=99_000_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - btc.sign_tx(client, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) + btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) # attack: replace prev_hash with an invalid value prev_tx.inputs[0].prev_hash = prev_hash tx_hash = hash_tx(serialize_tx(prev_tx)) inp0.prev_hash = tx_hash - with client, pytest.raises(TrezorFailure) as e: - if client.model is not models.T1B1: + with session, session.client as client, pytest.raises(TrezorFailure) as e: + if session.model is not models.T1B1: IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) - _check_error_message(prev_hash, client.model, e.value.message) + btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) + _check_error_message(prev_hash, session.model, e.value.message) diff --git a/tests/device_tests/bitcoin/test_signtx_replacement.py b/tests/device_tests/bitcoin/test_signtx_replacement.py index 97fe7e2d873..fd5db6a5027 100644 --- a/tests/device_tests/bitcoin/test_signtx_replacement.py +++ b/tests/device_tests/bitcoin/test_signtx_replacement.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -90,7 +90,7 @@ ) -def test_p2pkh_fee_bump(client: Client): +def test_p2pkh_fee_bump(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/4"), amount=174_998, @@ -116,8 +116,8 @@ def test_p2pkh_fee_bump(client: Client): orig_index=1, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_50f6f1), @@ -132,7 +132,7 @@ def test_p2pkh_fee_bump(client: Client): request_meta(TXHASH_beafc7), request_input(0, TXHASH_beafc7), request_output(0, TXHASH_beafc7), - (is_core(client), request_orig_input(0, TXHASH_50f6f1)), + (is_core(session), request_orig_input(0, TXHASH_50f6f1)), request_orig_input(0, TXHASH_50f6f1), request_orig_output(0, TXHASH_50f6f1), request_orig_output(1, TXHASH_50f6f1), @@ -145,7 +145,7 @@ def test_p2pkh_fee_bump(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1, out2], @@ -159,7 +159,7 @@ def test_p2pkh_fee_bump(client: Client): ) -def test_p2wpkh_op_return_fee_bump(client: Client): +def test_p2wpkh_op_return_fee_bump(session: Session): # Original input. inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/1h/0/14"), @@ -190,9 +190,9 @@ def test_p2wpkh_op_return_fee_bump(client: Client): orig_index=1, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -207,7 +207,7 @@ def test_p2wpkh_op_return_fee_bump(client: Client): # txid 48bc29fc42a64b43d043b0b7b99b21aa39654234754608f791c60bcbd91a8e92 -def test_p2tr_fee_bump(client: Client): +def test_p2tr_fee_bump(session: Session): inp1 = messages.TxInputType( # tb1p8tvmvsvhsee73rhym86wt435qrqm92psfsyhy6a3n5gw455znnpqm8wald address_n=parse_path("m/86h/1h/0h/0/1"), @@ -243,8 +243,8 @@ def test_p2tr_fee_bump(client: Client): orig_index=1, script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_8e4af7), @@ -269,7 +269,7 @@ def test_p2tr_fee_bump(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -281,7 +281,7 @@ def test_p2tr_fee_bump(client: Client): ) -def test_p2wpkh_finalize(client: Client): +def test_p2wpkh_finalize(session: Session): # Original input with disabled RBF opt-in, i.e. we finalize the transaction. inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/2"), @@ -312,8 +312,8 @@ def test_p2wpkh_finalize(client: Client): orig_index=1, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_70f987), @@ -339,7 +339,7 @@ def test_p2wpkh_finalize(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -401,7 +401,7 @@ def test_p2wpkh_finalize(client: Client): ), ) def test_p2wpkh_payjoin( - client, out1_amount, out2_amount, copayer_witness, fee_confirm, expected_tx + session, out1_amount, out2_amount, copayer_witness, fee_confirm, expected_tx ): # Original input. inp1 = messages.TxInputType( @@ -444,8 +444,8 @@ def test_p2wpkh_payjoin( orig_index=1, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_65b768), @@ -478,7 +478,7 @@ def test_p2wpkh_payjoin( ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -489,7 +489,7 @@ def test_p2wpkh_payjoin( assert serialized_tx.hex() == expected_tx -def test_p2wpkh_in_p2sh_remove_change(client: Client): +def test_p2wpkh_in_p2sh_remove_change(session: Session): # Test fee bump with change-output removal. Originally fee was 3780, now 98060. inp1 = messages.TxInputType( @@ -520,8 +520,8 @@ def test_p2wpkh_in_p2sh_remove_change(client: Client): orig_index=0, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -553,7 +553,7 @@ def test_p2wpkh_in_p2sh_remove_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -567,7 +567,7 @@ def test_p2wpkh_in_p2sh_remove_change(client: Client): ) -def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): +def test_p2wpkh_in_p2sh_fee_bump_from_external(session: Session): # Use the change output and an external output to bump the fee. # Originally fee was 3780, now 108060 (94280 from change and 10000 from external). @@ -599,8 +599,8 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): orig_index=0, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -634,7 +634,7 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -649,7 +649,7 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): @pytest.mark.models("core") -def test_tx_meld(client: Client): +def test_tx_meld(session: Session): # Meld two original transactions into one, joining the change-outputs into a different one. inp1 = messages.TxInputType( @@ -720,8 +720,8 @@ def test_tx_meld(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -785,7 +785,7 @@ def test_tx_meld(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2, inp3, inp4], [out1, out2, out3], @@ -799,7 +799,7 @@ def test_tx_meld(client: Client): ) -def test_attack_steal_change(client: Client): +def test_attack_steal_change(session: Session): # Attempt to steal amount equivalent to the change in the original transaction by # hiding the fact that an output in the original transaction is a change-output. @@ -860,7 +860,7 @@ def test_attack_steal_change(client: Client): TrezorFailure, match="Original output is missing change-output parameters" ): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -870,7 +870,7 @@ def test_attack_steal_change(client: Client): @pytest.mark.models("core") -def test_attack_false_internal(client: Client): +def test_attack_false_internal(session: Session): # Falsely claim that an external input is internal in the original transaction. # If this were possible, it would allow an attacker to make it look like the # user was spending more in the original than they actually were, making it @@ -914,7 +914,7 @@ def test_attack_false_internal(client: Client): TrezorFailure, match="Original input does not match current input" ): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -922,7 +922,7 @@ def test_attack_false_internal(client: Client): ) -def test_attack_fake_int_input_amount(client: Client): +def test_attack_fake_int_input_amount(session: Session): # Give a fake input amount for an original internal input while giving the correct # amount for the replacement input. If an attacker could increase the amount of an # internal input in the original transaction, then they could bump the fee of the @@ -968,7 +968,7 @@ def test_attack_fake_int_input_amount(client: Client): TrezorFailure, match="Original input does not match current input" ): btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1, out2], @@ -977,7 +977,7 @@ def test_attack_fake_int_input_amount(client: Client): @pytest.mark.models("core") -def test_attack_fake_ext_input_amount(client: Client): +def test_attack_fake_ext_input_amount(session: Session): # Give a fake input amount for an original external input while giving the correct # amount for the replacement input. If an attacker could decrease the amount of an # external input in the original transaction, then they could steal the fee from @@ -1044,7 +1044,7 @@ def test_attack_fake_ext_input_amount(client: Client): TrezorFailure, match="Original input does not match current input" ): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -1052,7 +1052,7 @@ def test_attack_fake_ext_input_amount(client: Client): ) -def test_p2wpkh_invalid_signature(client: Client): +def test_p2wpkh_invalid_signature(session: Session): # Ensure that transaction replacement fails when the original signature is invalid. # Original input with disabled RBF opt-in, i.e. we finalize the transaction. @@ -1096,7 +1096,7 @@ def test_p2wpkh_invalid_signature(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -1105,7 +1105,7 @@ def test_p2wpkh_invalid_signature(client: Client): ) -def test_p2tr_invalid_signature(client: Client): +def test_p2tr_invalid_signature(session: Session): # Ensure that transaction replacement fails when the original signature is invalid. inp1 = messages.TxInputType( @@ -1151,4 +1151,4 @@ def test_p2tr_invalid_signature(client: Client): prev_txes = {TXHASH_8e4af7: prev_tx_invalid} with pytest.raises(TrezorFailure, match="Invalid signature"): - btc.sign_tx(client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=prev_txes) + btc.sign_tx(session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=prev_txes) diff --git a/tests/device_tests/bitcoin/test_signtx_segwit.py b/tests/device_tests/bitcoin/test_signtx_segwit.py index 763626caef0..ef8c988ff39 100644 --- a/tests/device_tests/bitcoin/test_signtx_segwit.py +++ b/tests/device_tests/bitcoin/test_signtx_segwit.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -47,7 +47,7 @@ @pytest.mark.parametrize("chunkify", (True, False)) -def test_send_p2sh(client: Client, chunkify: bool): +def test_send_p2sh(session: Session, chunkify: bool): inp1 = messages.TxInputType( address_n=parse_path("m/49h/1h/0h/1/0"), # 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX @@ -66,16 +66,16 @@ def test_send_p2sh(client: Client, chunkify: bool): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -90,7 +90,7 @@ def test_send_p2sh(client: Client, chunkify: bool): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -105,7 +105,7 @@ def test_send_p2sh(client: Client, chunkify: bool): ) -def test_send_p2sh_change(client: Client): +def test_send_p2sh_change(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/49h/1h/0h/1/0"), # 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX @@ -124,13 +124,13 @@ def test_send_p2sh_change(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -146,7 +146,7 @@ def test_send_p2sh_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -156,11 +156,11 @@ def test_send_p2sh_change(client: Client): ) -def test_testnet_segwit_big_amount(client: Client): +def test_testnet_segwit_big_amount(session: Session): # This test is testing transaction with amount bigger than fits to uint32 address_n = parse_path("m/49h/1h/0h/0/0") address = btc.get_address( - client, + session, "Testnet", address_n, script_type=messages.InputScriptType.SPENDP2SHWITNESS, @@ -179,13 +179,13 @@ def test_testnet_segwit_big_amount(client: Client): amount=2**32 + 1, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(prev_hash), @@ -198,7 +198,7 @@ def test_testnet_segwit_big_amount(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes={prev_hash: prev_tx} + session, "Testnet", [inp1], [out1], prev_txes={prev_hash: prev_tx} ) # Transaction does not exist on the blockchain, not using assert_tx_matches() assert ( @@ -208,12 +208,12 @@ def test_testnet_segwit_big_amount(client: Client): @pytest.mark.multisig -def test_send_multisig_1(client: Client): +def test_send_multisig_1(session: Session): # input: 338e2d02e0eaf8848e38925904e51546cf22e58db5b1860c4a0e72b69c56afe5 nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" ).node for i in range(1, 4) ] @@ -241,7 +241,7 @@ def test_send_multisig_1(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_338e2d), @@ -254,10 +254,10 @@ def test_send_multisig_1(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -265,10 +265,10 @@ def test_send_multisig_1(client: Client): # sign with third key inp1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -278,7 +278,7 @@ def test_send_multisig_1(client: Client): ) -def test_attack_change_input_address(client: Client): +def test_attack_change_input_address(session: Session): # Simulates an attack where the user is coerced into unknowingly # transferring funds from one account to another one of their accounts, # potentially resulting in privacy issues. @@ -303,17 +303,17 @@ def test_attack_change_input_address(client: Client): ) # Test if the transaction can be signed normally. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), # The user is required to confirm transfer to another account. messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -328,7 +328,7 @@ def test_attack_change_input_address(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -349,15 +349,15 @@ def attack_processor(msg): return msg # Now run the attack, must trigger the exception - with client: - client.set_filter(messages.TxAck, attack_processor) + with session: + session.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) -def test_attack_mixed_inputs(client: Client): +def test_attack_mixed_inputs(session: Session): TRUE_AMOUNT = 123_456_789 FAKE_AMOUNT = 120_000_000 @@ -389,11 +389,11 @@ def test_attack_mixed_inputs(client: Client): request_output(0), messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ( - is_core(client), + is_core(session), messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ), ( - is_core(client), + is_core(session), messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), ), messages.ButtonRequest(code=messages.ButtonRequestType.FeeOverThreshold), @@ -417,16 +417,16 @@ def test_attack_mixed_inputs(client: Client): request_finished(), ] - if client.model is models.T1B1: + if session.model is models.T1B1: # T1 asks for first input for witness again expected_responses.insert(-2, request_input(0)) - with client: + with session: # Sign unmodified transaction. # "Fee over threshold" warning is displayed - fee is the whole TRUE_AMOUNT - client.set_expected_responses(expected_responses) + session.set_expected_responses(expected_responses) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -436,7 +436,7 @@ def test_attack_mixed_inputs(client: Client): # In Phase 1 make the user confirm a lower value of the segwit input. inp2.amount = FAKE_AMOUNT - if client.model is models.T1B1: + if session.model is models.T1B1: # T1 fails as soon as it encounters the fake amount. expected_responses = ( expected_responses[:4] + expected_responses[5:15] + [messages.Failure()] @@ -446,10 +446,10 @@ def test_attack_mixed_inputs(client: Client): expected_responses[:4] + expected_responses[5:16] + [messages.Failure()] ) - with pytest.raises(TrezorFailure) as e, client: - client.set_expected_responses(expected_responses) + with pytest.raises(TrezorFailure) as e, session: + session.set_expected_responses(expected_responses) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], diff --git a/tests/device_tests/bitcoin/test_signtx_segwit_native.py b/tests/device_tests/bitcoin/test_signtx_segwit_native.py index 0c779c777ef..920b0bf48b7 100644 --- a/tests/device_tests/bitcoin/test_signtx_segwit_native.py +++ b/tests/device_tests/bitcoin/test_signtx_segwit_native.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import H_, parse_path from ...bip32 import deserialize @@ -61,7 +61,7 @@ ) -def test_send_p2sh(client: Client): +def test_send_p2sh(session: Session): # input tx: 20912f98ea3ed849042efed0fdac8cb4fc301961c5988cba56902d8ffb61c337 inp1 = messages.TxInputType( @@ -82,16 +82,16 @@ def test_send_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -106,7 +106,7 @@ def test_send_p2sh(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -116,7 +116,7 @@ def test_send_p2sh(client: Client): ) -def test_send_p2sh_change(client: Client): +def test_send_p2sh_change(session: Session): # input tx: 20912f98ea3ed849042efed0fdac8cb4fc301961c5988cba56902d8ffb61c337 inp1 = messages.TxInputType( @@ -137,13 +137,13 @@ def test_send_p2sh_change(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -159,7 +159,7 @@ def test_send_p2sh_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -169,7 +169,7 @@ def test_send_p2sh_change(client: Client): ) -def test_send_native(client: Client): +def test_send_native(session: Session): # input tx: b36780ceb86807ca6e7535a6fd418b1b788cb9b227d2c8a26a0de295e523219e inp1 = messages.TxInputType( @@ -190,16 +190,16 @@ def test_send_native(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=100_000 - 40_000 - 10_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b36780), @@ -214,7 +214,7 @@ def test_send_native(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -224,7 +224,7 @@ def test_send_native(client: Client): ) -def test_send_to_taproot(client: Client): +def test_send_to_taproot(session: Session): # input tx: ec16dc5a539c5d60001a7471c37dbb0b5294c289c77df8bd07870b30d73e2231 inp1 = messages.TxInputType( @@ -244,9 +244,9 @@ def test_send_to_taproot(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=10_000 - 7_000 - 200, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -256,7 +256,7 @@ def test_send_to_taproot(client: Client): ) -def test_send_native_change(client: Client): +def test_send_native_change(session: Session): # input tx: fcb3f5436224900afdba50e9e763d98b920dfed056e552040d99ea9bc03a9d83 inp1 = messages.TxInputType( @@ -277,13 +277,13 @@ def test_send_native_change(client: Client): script_type=messages.OutputScriptType.PAYTOWITNESS, amount=100_000 - 40_000 - 10_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -300,7 +300,7 @@ def test_send_native_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -310,7 +310,7 @@ def test_send_native_change(client: Client): ) -def test_send_both(client: Client): +def test_send_both(session: Session): # input 1 tx: 65047a2b107d6301d72d4a1e49e7aea9cf06903fdc4ae74a4a9bba9bc1a414d2 # input 2 tx: d159fd2fcb5854a7c8b275d598765a446f1e2ff510bf077545a404a0c9db65f7 @@ -344,21 +344,21 @@ def test_send_both(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), - (is_core(client), messages.ButtonRequest(code=B.SignTx)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.SignTx)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_65047a), @@ -382,7 +382,7 @@ def test_send_both(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -397,12 +397,12 @@ def test_send_both(client: Client): @pytest.mark.multisig -def test_send_multisig_1(client: Client): +def test_send_multisig_1(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -433,7 +433,7 @@ def test_send_multisig_1(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -449,10 +449,10 @@ def test_send_multisig_1(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -460,10 +460,10 @@ def test_send_multisig_1(client: Client): # sign with third key inp1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -474,12 +474,12 @@ def test_send_multisig_1(client: Client): @pytest.mark.multisig -def test_send_multisig_2(client: Client): +def test_send_multisig_2(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -510,7 +510,7 @@ def test_send_multisig_2(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -526,10 +526,10 @@ def test_send_multisig_2(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -537,10 +537,10 @@ def test_send_multisig_2(client: Client): # sign with first key inp1.address_n[2] = H_(1) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -551,12 +551,12 @@ def test_send_multisig_2(client: Client): @pytest.mark.multisig -def test_send_multisig_3_change(client: Client): +def test_send_multisig_3_change(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -595,7 +595,7 @@ def test_send_multisig_3_change(client: Client): request_output(0), messages.ButtonRequest(code=B.UnknownDerivationPath), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -611,13 +611,13 @@ def test_send_multisig_3_change(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -626,13 +626,13 @@ def test_send_multisig_3_change(client: Client): inp1.address_n[2] = H_(3) out1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -643,12 +643,12 @@ def test_send_multisig_3_change(client: Client): @pytest.mark.multisig -def test_send_multisig_4_change(client: Client): +def test_send_multisig_4_change(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -687,7 +687,7 @@ def test_send_multisig_4_change(client: Client): request_output(0), messages.ButtonRequest(code=B.UnknownDerivationPath), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -703,13 +703,13 @@ def test_send_multisig_4_change(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -718,13 +718,13 @@ def test_send_multisig_4_change(client: Client): inp1.address_n[2] = H_(3) out1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -734,7 +734,7 @@ def test_send_multisig_4_change(client: Client): ) -def test_multisig_mismatch_inputs_single(client: Client): +def test_multisig_mismatch_inputs_single(session: Session): # Ensure that if there is a non-multisig input, then a multisig output # will not be identified as a change output. @@ -788,18 +788,18 @@ def test_multisig_mismatch_inputs_single(client: Client): amount=100_000 + 100_000 - 50_000 - 10_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), # Ensure that the multisig output is not identified as a change output. messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_1c022d), @@ -824,7 +824,7 @@ def test_multisig_mismatch_inputs_single(client: Client): ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( diff --git a/tests/device_tests/bitcoin/test_signtx_taproot.py b/tests/device_tests/bitcoin/test_signtx_taproot.py index f548154ae70..0453474af91 100644 --- a/tests/device_tests/bitcoin/test_signtx_taproot.py +++ b/tests/device_tests/bitcoin/test_signtx_taproot.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -64,7 +64,7 @@ @pytest.mark.parametrize("chunkify", (True, False)) -def test_send_p2tr(client: Client, chunkify: bool): +def test_send_p2tr(session: Session, chunkify: bool): inp1 = messages.TxInputType( # tb1pn2d0yjeedavnkd8z8lhm566p0f2utm3lgvxrsdehnl94y34txmts5s7t4c address_n=parse_path("m/86h/1h/0h/1/0"), @@ -79,13 +79,13 @@ def test_send_p2tr(client: Client, chunkify: bool): amount=4_450, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -94,7 +94,7 @@ def test_send_p2tr(client: Client, chunkify: bool): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API, chunkify=chunkify + session, "Testnet", [inp1], [out1], prev_txes=TX_API, chunkify=chunkify ) assert_tx_matches( @@ -104,7 +104,7 @@ def test_send_p2tr(client: Client, chunkify: bool): ) -def test_send_two_with_change(client: Client): +def test_send_two_with_change(session: Session): inp1 = messages.TxInputType( # tb1pswrqtykue8r89t9u4rprjs0gt4qzkdfuursfnvqaa3f2yql07zmq8s8a5u address_n=parse_path("m/86h/1h/0h/0/0"), @@ -133,14 +133,14 @@ def test_send_two_with_change(client: Client): script_type=messages.OutputScriptType.PAYTOTAPROOT, amount=6_800 + 13_000 - 200 - 15_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -153,7 +153,7 @@ def test_send_two_with_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API ) assert_tx_matches( @@ -163,7 +163,7 @@ def test_send_two_with_change(client: Client): ) -def test_send_mixed(client: Client): +def test_send_mixed(session: Session): inp1 = messages.TxInputType( # 2MutHjgAXkqo3jxX2DZWorLAckAnwTxSM9V address_n=parse_path("m/49h/1h/1h/0/0"), @@ -222,8 +222,8 @@ def test_send_mixed(client: Client): script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ # process inputs request_input(0), @@ -233,19 +233,19 @@ def test_send_mixed(client: Client): # approve outputs request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(3), messages.ButtonRequest(code=B.ConfirmOutput), request_output(4), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), - (is_core(client), messages.ButtonRequest(code=B.SignTx)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.SignTx)), messages.ButtonRequest(code=B.SignTx), # verify inputs request_input(0), @@ -293,12 +293,12 @@ def test_send_mixed(client: Client): request_input(0), request_input(1), request_input(2), - (client.model is models.T1B1, request_input(3)), + (session.model is models.T1B1, request_input(3)), request_finished(), ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2, inp3, inp4], [out1, out2, out3, out4, out5], @@ -312,13 +312,12 @@ def test_send_mixed(client: Client): ) -def test_attack_script_type(client: Client): +def test_attack_script_type(session: Session): # Scenario: The attacker falsely claims that the transaction is Taproot-only to # avoid prev tx streaming and gives a lower amount for one of the inputs. The # correct input types and amounts are revelaled only in step6_sign_segwit_inputs() # to get a valid signature. This results in a transaction which pays a fee much # larger than what the user confirmed. - inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/1/0"), amount=7_289_000, @@ -354,16 +353,16 @@ def attack_processor(msg): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), - (is_core(client), messages.ButtonRequest(code=B.SignTx)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.SignTx)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_input(1), @@ -374,7 +373,7 @@ def attack_processor(msg): ] ) with pytest.raises(TrezorFailure) as exc: - btc.sign_tx(client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API) assert exc.value.code == messages.FailureType.ProcessError assert exc.value.message.endswith("Transaction has changed during signing") @@ -392,7 +391,7 @@ def attack_processor(msg): "tb1pllllllllllllllllllllllllllllllllllllllllllllallllscqgl4zhn", ), ) -def test_send_invalid_address(client: Client, address: str): +def test_send_invalid_address(session: Session, address: str): inp1 = messages.TxInputType( # tb1pn2d0yjeedavnkd8z8lhm566p0f2utm3lgvxrsdehnl94y34txmts5s7t4c address_n=parse_path("m/86h/1h/0h/1/0"), @@ -407,12 +406,12 @@ def test_send_invalid_address(client: Client, address: str): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client, pytest.raises(TrezorFailure): - client.set_expected_responses( + with session, pytest.raises(TrezorFailure): + session.set_expected_responses( [ request_input(0), request_output(0), messages.Failure, ] ) - btc.sign_tx(client, "Testnet", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Testnet", [inp1], [out1], prev_txes=TX_API) diff --git a/tests/device_tests/bitcoin/test_verifymessage.py b/tests/device_tests/bitcoin/test_verifymessage.py index 86389d8a515..88907e318dc 100644 --- a/tests/device_tests/bitcoin/test_verifymessage.py +++ b/tests/device_tests/bitcoin/test_verifymessage.py @@ -19,12 +19,12 @@ import pytest from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_message_long(client: Client): +def test_message_long(session: Session): ret = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -35,9 +35,9 @@ def test_message_long(client: Client): assert ret is True -def test_message_testnet(client: Client): +def test_message_testnet(session: Session): ret = btc.verify_message( - client, + session, "Testnet", "mirio8q3gtv7fhdnmb3TpZ4EuafdzSs7zL", bytes.fromhex( @@ -49,9 +49,9 @@ def test_message_testnet(client: Client): @pytest.mark.altcoin -def test_message_grs(client: Client): +def test_message_grs(session: Session): ret = btc.verify_message( - client, + session, "Groestlcoin", "Fj62rBJi8LvbmWu2jzkaUX1NFXLEqDLoZM", base64.b64decode( @@ -62,9 +62,9 @@ def test_message_grs(client: Client): assert ret is True -def test_message_verify(client: Client): +def test_message_verify(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "1JwSSubhmg6iPtRjtyqhUYYH7bZg3Lfy1T", bytes.fromhex( @@ -76,7 +76,7 @@ def test_message_verify(client: Client): # uncompressed pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "1JwSSubhmg6iPtRjtyqhUYYH7bZg3Lfy1T", bytes.fromhex( @@ -88,7 +88,7 @@ def test_message_verify(client: Client): # uncompressed pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "1JwSSubhmg6iPtRjtyqhUYYH7bZg3Lfy1T", bytes.fromhex( @@ -100,7 +100,7 @@ def test_message_verify(client: Client): # compressed pubkey - OK res = btc.verify_message( - client, + session, "Bitcoin", "1C7zdTfnkzmr13HfA2vNm5SJYRK6nEKyq8", bytes.fromhex( @@ -112,7 +112,7 @@ def test_message_verify(client: Client): # compressed pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "1C7zdTfnkzmr13HfA2vNm5SJYRK6nEKyq8", bytes.fromhex( @@ -124,7 +124,7 @@ def test_message_verify(client: Client): # compressed pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "1C7zdTfnkzmr13HfA2vNm5SJYRK6nEKyq8", bytes.fromhex( @@ -136,7 +136,7 @@ def test_message_verify(client: Client): # trezor pubkey - OK res = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -148,7 +148,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -160,7 +160,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -172,9 +172,9 @@ def test_message_verify(client: Client): @pytest.mark.altcoin -def test_message_verify_bcash(client: Client): +def test_message_verify_bcash(session: Session): res = btc.verify_message( - client, + session, "Bcash", "bitcoincash:qqj22md58nm09vpwsw82fyletkxkq36zxyxh322pru", bytes.fromhex( @@ -185,9 +185,9 @@ def test_message_verify_bcash(client: Client): assert res is True -def test_verify_bitcoind(client: Client): +def test_verify_bitcoind(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "1KzXE97kV7DrpxCViCN3HbGbiKhzzPM7TQ", bytes.fromhex( @@ -199,12 +199,12 @@ def test_verify_bitcoind(client: Client): assert res is True -def test_verify_utf(client: Client): +def test_verify_utf(session: Session): words_nfkd = "Pr\u030ci\u0301s\u030cerne\u030c z\u030clut\u030couc\u030cky\u0301 ku\u030an\u030c u\u0301pe\u030cl d\u030ca\u0301belske\u0301 o\u0301dy za\u0301ker\u030cny\u0301 uc\u030cen\u030c be\u030cz\u030ci\u0301 pode\u0301l zo\u0301ny u\u0301lu\u030a" words_nfc = "P\u0159\xed\u0161ern\u011b \u017elu\u0165ou\u010dk\xfd k\u016f\u0148 \xfap\u011bl \u010f\xe1belsk\xe9 \xf3dy z\xe1ke\u0159n\xfd u\u010de\u0148 b\u011b\u017e\xed pod\xe9l z\xf3ny \xfal\u016f" res_nfkd = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -214,7 +214,7 @@ def test_verify_utf(client: Client): ) res_nfc = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( diff --git a/tests/device_tests/bitcoin/test_verifymessage_segwit.py b/tests/device_tests/bitcoin/test_verifymessage_segwit.py index 84f04442646..9c3169e0c78 100644 --- a/tests/device_tests/bitcoin/test_verifymessage_segwit.py +++ b/tests/device_tests/bitcoin/test_verifymessage_segwit.py @@ -15,12 +15,12 @@ # If not, see . from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_message_long(client: Client): +def test_message_long(session: Session): ret = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -31,9 +31,9 @@ def test_message_long(client: Client): assert ret is True -def test_message_testnet(client: Client): +def test_message_testnet(session: Session): ret = btc.verify_message( - client, + session, "Testnet", "2N4VkePSzKH2sv5YBikLHGvzUYvfPxV6zS9", bytes.fromhex( @@ -44,9 +44,9 @@ def test_message_testnet(client: Client): assert ret is True -def test_message_verify(client: Client): +def test_message_verify(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -58,7 +58,7 @@ def test_message_verify(client: Client): # no script type res = btc.verify_message( - client, + session, "Bitcoin", "3L6TyTisPBmrDAj6RoKmDzNnj4eQi54gD2", bytes.fromhex( @@ -70,7 +70,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -82,7 +82,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -93,12 +93,12 @@ def test_message_verify(client: Client): assert res is False -def test_verify_utf(client: Client): +def test_verify_utf(session: Session): words_nfkd = "Pr\u030ci\u0301s\u030cerne\u030c z\u030clut\u030couc\u030cky\u0301 ku\u030an\u030c u\u0301pe\u030cl d\u030ca\u0301belske\u0301 o\u0301dy za\u0301ker\u030cny\u0301 uc\u030cen\u030c be\u030cz\u030ci\u0301 pode\u0301l zo\u0301ny u\u0301lu\u030a" words_nfc = "P\u0159\xed\u0161ern\u011b \u017elu\u0165ou\u010dk\xfd k\u016f\u0148 \xfap\u011bl \u010f\xe1belsk\xe9 \xf3dy z\xe1ke\u0159n\xfd u\u010de\u0148 b\u011b\u017e\xed pod\xe9l z\xf3ny \xfal\u016f" res_nfkd = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -108,7 +108,7 @@ def test_verify_utf(client: Client): ) res_nfc = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( diff --git a/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py b/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py index 5bea51f7dc1..3a4ed68e5da 100644 --- a/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py +++ b/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py @@ -15,12 +15,12 @@ # If not, see . from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_message_long(client: Client): +def test_message_long(session: Session): ret = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -31,9 +31,9 @@ def test_message_long(client: Client): assert ret is True -def test_message_testnet(client: Client): +def test_message_testnet(session: Session): ret = btc.verify_message( - client, + session, "Testnet", "tb1qyjjkmdpu7metqt5r36jf872a34syws336p3n3p", bytes.fromhex( @@ -44,9 +44,9 @@ def test_message_testnet(client: Client): assert ret is True -def test_message_verify(client: Client): +def test_message_verify(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -58,7 +58,7 @@ def test_message_verify(client: Client): # no script type res = btc.verify_message( - client, + session, "Bitcoin", "bc1qannfxke2tfd4l7vhepehpvt05y83v3qsf6nfkk", bytes.fromhex( @@ -70,7 +70,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -82,7 +82,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -93,12 +93,12 @@ def test_message_verify(client: Client): assert res is False -def test_verify_utf(client: Client): +def test_verify_utf(session: Session): words_nfkd = "Pr\u030ci\u0301s\u030cerne\u030c z\u030clut\u030couc\u030cky\u0301 ku\u030an\u030c u\u0301pe\u030cl d\u030ca\u0301belske\u0301 o\u0301dy za\u0301ker\u030cny\u0301 uc\u030cen\u030c be\u030cz\u030ci\u0301 pode\u0301l zo\u0301ny u\u0301lu\u030a" words_nfc = "P\u0159\xed\u0161ern\u011b \u017elu\u0165ou\u010dk\xfd k\u016f\u0148 \xfap\u011bl \u010f\xe1belsk\xe9 \xf3dy z\xe1ke\u0159n\xfd u\u010de\u0148 b\u011b\u017e\xed pod\xe9l z\xf3ny \xfal\u016f" res_nfkd = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -108,7 +108,7 @@ def test_verify_utf(client: Client): ) res_nfc = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( diff --git a/tests/device_tests/bitcoin/test_zcash.py b/tests/device_tests/bitcoin/test_zcash.py index dc959199a35..adb99589150 100644 --- a/tests/device_tests/bitcoin/test_zcash.py +++ b/tests/device_tests/bitcoin/test_zcash.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -57,7 +57,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.zcash] -def test_v3_not_supported(client: Client): +def test_v3_not_supported(session: Session): # prevout: aaf51e4606c264e47e5c42c958fe4cf1539c5172684721e38e69f4ef634d75dc:1 # input 1: 3.0 TAZ @@ -75,9 +75,9 @@ def test_v3_not_supported(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client, pytest.raises(TrezorFailure, match="DataError"): + with session, pytest.raises(TrezorFailure, match="DataError"): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -88,7 +88,7 @@ def test_v3_not_supported(client: Client): ) -def test_one_one_fee_sapling(client: Client): +def test_one_one_fee_sapling(session: Session): # prevout: e3820602226974b1dd87b7113cc8aea8c63e5ae29293991e7bfa80c126930368:0 # input 1: 3.0 TAZ @@ -106,13 +106,13 @@ def test_one_one_fee_sapling(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_e38206), @@ -128,7 +128,7 @@ def test_one_one_fee_sapling(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -145,7 +145,7 @@ def test_one_one_fee_sapling(client: Client): ) -def test_version_group_id_missing(client: Client): +def test_version_group_id_missing(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -161,7 +161,7 @@ def test_version_group_id_missing(client: Client): with pytest.raises(TrezorFailure, match="Version group ID must be set."): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -170,7 +170,7 @@ def test_version_group_id_missing(client: Client): ) -def test_spend_old_versions(client: Client): +def test_spend_old_versions(session: Session): # NOTE: fake input tx used input_v1 = messages.TxInputType( @@ -210,9 +210,9 @@ def test_spend_old_versions(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", inputs, [output], @@ -229,7 +229,7 @@ def test_spend_old_versions(client: Client): @pytest.mark.models("core") -def test_external_presigned(client: Client): +def test_external_presigned(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -259,14 +259,14 @@ def test_external_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_e38206), @@ -289,7 +289,7 @@ def test_external_presigned(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1, inp2], [out1], diff --git a/tests/device_tests/cardano/test_address_public_key.py b/tests/device_tests/cardano/test_address_public_key.py index d7c02e6b6dc..d8ec9288eb7 100644 --- a/tests/device_tests/cardano/test_address_public_key.py +++ b/tests/device_tests/cardano/test_address_public_key.py @@ -22,7 +22,7 @@ get_public_key, parse_optional_bytes, ) -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import CardanoAddressType, CardanoDerivationType from trezorlib.tools import parse_path @@ -48,15 +48,15 @@ "cardano/get_base_address.derivations.json", ) @pytest.mark.parametrize("chunkify", (True, False)) -def test_cardano_get_address(client: Client, chunkify: bool, parameters, result): - client.init_device(new_session=True, derive_cardano=True) +def test_cardano_get_address(session: Session, chunkify: bool, parameters, result): + # session.init_device(new_session=True, derive_cardano=True) derivation_type = CardanoDerivationType.__members__[ parameters.get("derivation_type", "ICARUS_TREZOR") ] address = get_address( - client, + session, address_parameters=create_address_parameters( address_type=getattr( CardanoAddressType, parameters["address_type"].upper() @@ -94,17 +94,17 @@ def test_cardano_get_address(client: Client, chunkify: bool, parameters, result) "cardano/get_public_key.slip39.json", "cardano/get_public_key.derivations.json", ) -def test_cardano_get_public_key(client: Client, parameters, result): - with client: - IF = InputFlowShowXpubQRCode(client, passphrase=bool(client.ui.passphrase)) +def test_cardano_get_public_key(session: Session, parameters, result): + with session, session.client as client: + IF = InputFlowShowXpubQRCode(client, passphrase=bool(session.passphrase)) client.set_input_flow(IF.get()) - client.init_device(new_session=True, derive_cardano=True) + # session.init_device(new_session=True, derive_cardano=True) derivation_type = CardanoDerivationType.__members__[ parameters.get("derivation_type", "ICARUS_TREZOR") ] key = get_public_key( - client, parse_path(parameters["path"]), derivation_type, show_display=True + session, parse_path(parameters["path"]), derivation_type, show_display=True ) assert key.node.public_key.hex() == result["public_key"] diff --git a/tests/device_tests/cardano/test_derivations.py b/tests/device_tests/cardano/test_derivations.py index 656c31a8bde..148a0a85037 100644 --- a/tests/device_tests/cardano/test_derivations.py +++ b/tests/device_tests/cardano/test_derivations.py @@ -17,7 +17,7 @@ import pytest from trezorlib.cardano import get_public_key -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import CardanoDerivationType as D from trezorlib.tools import parse_path @@ -26,35 +26,29 @@ pytestmark = [ pytest.mark.altcoin, - pytest.mark.cardano, pytest.mark.models("core"), ] ADDRESS_N = parse_path("m/1852h/1815h/0h") -def test_bad_session(client: Client): - client.init_device(new_session=True) +def test_bad_session(session: Session): with pytest.raises(TrezorFailure, match="not enabled"): - get_public_key(client, ADDRESS_N, derivation_type=D.ICARUS) + get_public_key(session, ADDRESS_N, derivation_type=D.ICARUS) - client.init_device(new_session=True, derive_cardano=False) - with pytest.raises(TrezorFailure, match="not enabled"): - get_public_key(client, ADDRESS_N, derivation_type=D.ICARUS) +def test_ledger_available_without_cardano(session: Session): + get_public_key(session, ADDRESS_N, derivation_type=D.LEDGER) -def test_ledger_available_always(client: Client): - client.init_device(new_session=True, derive_cardano=False) - get_public_key(client, ADDRESS_N, derivation_type=D.LEDGER) - client.init_device(new_session=True, derive_cardano=True) - get_public_key(client, ADDRESS_N, derivation_type=D.LEDGER) +@pytest.mark.cardano # derive_cardano=True +def test_ledger_available_with_cardano(session: Session): + get_public_key(session, ADDRESS_N, derivation_type=D.LEDGER) @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) @pytest.mark.parametrize("derivation_type", D) # try ALL derivation types -def test_derivation_irrelevant_on_slip39(client: Client, derivation_type): - client.init_device(new_session=True, derive_cardano=False) - pubkey = get_public_key(client, ADDRESS_N, derivation_type=D.ICARUS) - test_pubkey = get_public_key(client, ADDRESS_N, derivation_type=derivation_type) +def test_derivation_irrelevant_on_slip39(session: Session, derivation_type): + pubkey = get_public_key(session, ADDRESS_N, derivation_type=D.ICARUS) + test_pubkey = get_public_key(session, ADDRESS_N, derivation_type=derivation_type) assert pubkey == test_pubkey diff --git a/tests/device_tests/cardano/test_get_native_script_hash.py b/tests/device_tests/cardano/test_get_native_script_hash.py index 63ee56d16fb..2859d69a41a 100644 --- a/tests/device_tests/cardano/test_get_native_script_hash.py +++ b/tests/device_tests/cardano/test_get_native_script_hash.py @@ -18,7 +18,7 @@ from trezorlib import messages from trezorlib.cardano import get_native_script_hash, parse_native_script -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import parametrize_using_common_fixtures @@ -32,11 +32,9 @@ @parametrize_using_common_fixtures( "cardano/get_native_script_hash.json", ) -def test_cardano_get_native_script_hash(client: Client, parameters, result): - client.init_device(new_session=True, derive_cardano=True) - +def test_cardano_get_native_script_hash(session: Session, parameters, result): native_script_hash = get_native_script_hash( - client, + session, native_script=parse_native_script(parameters["native_script"]), display_format=messages.CardanoNativeScriptHashDisplayFormat.__members__[ parameters["display_format"] diff --git a/tests/device_tests/cardano/test_sign_tx.py b/tests/device_tests/cardano/test_sign_tx.py index 1900ff2cbe6..ef02d82965a 100644 --- a/tests/device_tests/cardano/test_sign_tx.py +++ b/tests/device_tests/cardano/test_sign_tx.py @@ -18,6 +18,7 @@ from trezorlib import cardano, device, messages from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure @@ -58,9 +59,9 @@ def show_details_input_flow(client: Client): "cardano/sign_tx.plutus.json", "cardano/sign_tx.slip39.json", ) -def test_cardano_sign_tx(client: Client, parameters, result): +def test_cardano_sign_tx(session: Session, parameters, result): response = call_sign_tx( - client, + session, parameters, input_flow=lambda client: InputFlowConfirmAllWarnings(client).get(), ) @@ -68,8 +69,8 @@ def test_cardano_sign_tx(client: Client, parameters, result): @parametrize_using_common_fixtures("cardano/sign_tx.show_details.json") -def test_cardano_sign_tx_show_details(client: Client, parameters, result): - response = call_sign_tx(client, parameters, show_details_input_flow, chunkify=True) +def test_cardano_sign_tx_show_details(session: Session, parameters, result): + response = call_sign_tx(session, parameters, show_details_input_flow, chunkify=True) assert response == _transform_expected_result(result) @@ -79,13 +80,13 @@ def test_cardano_sign_tx_show_details(client: Client, parameters, result): "cardano/sign_tx.multisig.failed.json", "cardano/sign_tx.plutus.failed.json", ) -def test_cardano_sign_tx_failed(client: Client, parameters, result): +def test_cardano_sign_tx_failed(session: Session, parameters, result): with pytest.raises(TrezorFailure, match=result["error_message"]): - call_sign_tx(client, parameters, None) + call_sign_tx(session, parameters, None) -def call_sign_tx(client: Client, parameters, input_flow=None, chunkify: bool = False): - client.init_device(new_session=True, derive_cardano=True) +def call_sign_tx(session: Session, parameters, input_flow=None, chunkify: bool = False): + # session.init_device(new_session=True, derive_cardano=True) signing_mode = messages.CardanoTxSigningMode.__members__[parameters["signing_mode"]] inputs = [cardano.parse_input(i) for i in parameters["inputs"]] @@ -116,18 +117,18 @@ def call_sign_tx(client: Client, parameters, input_flow=None, chunkify: bool = F if parameters.get("security_checks") == "prompt": device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) else: - device.apply_settings(client, safety_checks=messages.SafetyCheckLevel.Strict) + device.apply_settings(session, safety_checks=messages.SafetyCheckLevel.Strict) - with client: + with session.client as client: if input_flow is not None: client.watch_layout() client.set_input_flow(input_flow(client)) return cardano.sign_tx( - client=client, + session=session, signing_mode=signing_mode, inputs=inputs, outputs=outputs, diff --git a/tests/device_tests/eos/test_get_public_key.py b/tests/device_tests/eos/test_get_public_key.py index 1b518e95f2e..d99c54cb2b6 100644 --- a/tests/device_tests/eos/test_get_public_key.py +++ b/tests/device_tests/eos/test_get_public_key.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.eos import get_public_key from trezorlib.tools import parse_path @@ -28,12 +28,12 @@ @pytest.mark.eos @pytest.mark.models("t2t1") @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_eos_get_public_key(client: Client): - with client: +def test_eos_get_public_key(session: Session): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) public_key = get_public_key( - client, parse_path("m/44h/194h/0h/0/0"), show_display=True + session, parse_path("m/44h/194h/0h/0/0"), show_display=True ) assert ( public_key.wif_public_key @@ -43,7 +43,7 @@ def test_eos_get_public_key(client: Client): public_key.raw_public_key.hex() == "02015fabe197c955036bab25f4e7c16558f9f672f9f625314ab1ec8f64f7b1198e" ) - public_key = get_public_key(client, parse_path("m/44h/194h/0h/0/1")) + public_key = get_public_key(session, parse_path("m/44h/194h/0h/0/1")) assert ( public_key.wif_public_key == "EOS5d1VP15RKxT4dSakWu2TFuEgnmaGC2ckfSvQwND7pZC1tXkfLP" @@ -52,7 +52,7 @@ def test_eos_get_public_key(client: Client): public_key.raw_public_key.hex() == "02608bc2c431521dee0b9d5f2fe34053e15fc3b20d2895e0abda857b9ed8e77a78" ) - public_key = get_public_key(client, parse_path("m/44h/194h/1h/0/0")) + public_key = get_public_key(session, parse_path("m/44h/194h/1h/0/0")) assert ( public_key.wif_public_key == "EOS7UuNeTf13nfcG85rDB7AHGugZi4C4wJ4ft12QRotqNfxdV2NvP" diff --git a/tests/device_tests/eos/test_signtx.py b/tests/device_tests/eos/test_signtx.py index 57fd051bb4a..54ebece6a9a 100644 --- a/tests/device_tests/eos/test_signtx.py +++ b/tests/device_tests/eos/test_signtx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import eos -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import EosSignedTx from trezorlib.tools import parse_path @@ -35,7 +35,7 @@ @pytest.mark.parametrize("chunkify", (True, False)) -def test_eos_signtx_transfer_token(client: Client, chunkify: bool): +def test_eos_signtx_transfer_token(session: Session, chunkify: bool): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -60,8 +60,8 @@ def test_eos_signtx_transfer_token(client: Client, chunkify: bool): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID, chunkify=chunkify) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID, chunkify=chunkify) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -69,7 +69,7 @@ def test_eos_signtx_transfer_token(client: Client, chunkify: bool): ) -def test_eos_signtx_buyram(client: Client): +def test_eos_signtx_buyram(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -93,8 +93,8 @@ def test_eos_signtx_buyram(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -102,7 +102,7 @@ def test_eos_signtx_buyram(client: Client): ) -def test_eos_signtx_buyrambytes(client: Client): +def test_eos_signtx_buyrambytes(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -126,8 +126,8 @@ def test_eos_signtx_buyrambytes(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -135,7 +135,7 @@ def test_eos_signtx_buyrambytes(client: Client): ) -def test_eos_signtx_sellram(client: Client): +def test_eos_signtx_sellram(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -155,8 +155,8 @@ def test_eos_signtx_sellram(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -164,7 +164,7 @@ def test_eos_signtx_sellram(client: Client): ) -def test_eos_signtx_delegate(client: Client): +def test_eos_signtx_delegate(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -190,8 +190,8 @@ def test_eos_signtx_delegate(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -199,7 +199,7 @@ def test_eos_signtx_delegate(client: Client): ) -def test_eos_signtx_undelegate(client: Client): +def test_eos_signtx_undelegate(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -224,8 +224,8 @@ def test_eos_signtx_undelegate(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -233,7 +233,7 @@ def test_eos_signtx_undelegate(client: Client): ) -def test_eos_signtx_refund(client: Client): +def test_eos_signtx_refund(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -253,8 +253,8 @@ def test_eos_signtx_refund(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -262,7 +262,7 @@ def test_eos_signtx_refund(client: Client): ) -def test_eos_signtx_linkauth(client: Client): +def test_eos_signtx_linkauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -287,8 +287,8 @@ def test_eos_signtx_linkauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -296,7 +296,7 @@ def test_eos_signtx_linkauth(client: Client): ) -def test_eos_signtx_unlinkauth(client: Client): +def test_eos_signtx_unlinkauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -320,8 +320,8 @@ def test_eos_signtx_unlinkauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -329,7 +329,7 @@ def test_eos_signtx_unlinkauth(client: Client): ) -def test_eos_signtx_updateauth(client: Client): +def test_eos_signtx_updateauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -376,8 +376,8 @@ def test_eos_signtx_updateauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -385,7 +385,7 @@ def test_eos_signtx_updateauth(client: Client): ) -def test_eos_signtx_deleteauth(client: Client): +def test_eos_signtx_deleteauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -405,8 +405,8 @@ def test_eos_signtx_deleteauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -414,7 +414,7 @@ def test_eos_signtx_deleteauth(client: Client): ) -def test_eos_signtx_vote(client: Client): +def test_eos_signtx_vote(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -468,8 +468,8 @@ def test_eos_signtx_vote(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -477,7 +477,7 @@ def test_eos_signtx_vote(client: Client): ) -def test_eos_signtx_vote_proxy(client: Client): +def test_eos_signtx_vote_proxy(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -497,8 +497,8 @@ def test_eos_signtx_vote_proxy(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -506,7 +506,7 @@ def test_eos_signtx_vote_proxy(client: Client): ) -def test_eos_signtx_unknown(client: Client): +def test_eos_signtx_unknown(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -526,8 +526,8 @@ def test_eos_signtx_unknown(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -535,7 +535,7 @@ def test_eos_signtx_unknown(client: Client): ) -def test_eos_signtx_newaccount(client: Client): +def test_eos_signtx_newaccount(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -602,8 +602,8 @@ def test_eos_signtx_newaccount(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -611,7 +611,7 @@ def test_eos_signtx_newaccount(client: Client): ) -def test_eos_signtx_setcontract(client: Client): +def test_eos_signtx_setcontract(session: Session): transaction = { "expiration": "2018-06-19T13:29:53", "ref_block_num": 30587, @@ -638,8 +638,8 @@ def test_eos_signtx_setcontract(client: Client): "context_free_data": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature diff --git a/tests/device_tests/ethereum/test_definitions.py b/tests/device_tests/ethereum/test_definitions.py index 314189ca597..9cc3fd57043 100644 --- a/tests/device_tests/ethereum/test_definitions.py +++ b/tests/device_tests/ethereum/test_definitions.py @@ -5,7 +5,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -40,60 +40,60 @@ } -def test_builtin(client: Client) -> None: +def test_builtin(session: Session) -> None: # Ethereum (SLIP-44 60, chain_id 1) will sign without any definitions provided - ethereum.sign_tx(client, **DEFAULT_TX_PARAMS) + ethereum.sign_tx(session, **DEFAULT_TX_PARAMS) -def test_chain_id_allowed(client: Client) -> None: +def test_chain_id_allowed(session: Session) -> None: # Any chain id is allowed as long as the SLIP44 stays the same params = DEFAULT_TX_PARAMS.copy() params.update(chain_id=222222) - ethereum.sign_tx(client, **params) + ethereum.sign_tx(session, **params) -def test_slip44_disallowed(client: Client) -> None: +def test_slip44_disallowed(session: Session) -> None: # SLIP44 is not allowed without a valid network definition params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/66666h/0h/0/0")) with pytest.raises(TrezorFailure, match="Forbidden key path"): - ethereum.sign_tx(client, **params) + ethereum.sign_tx(session, **params) -def test_slip44_external(client: Client) -> None: +def test_slip44_external(session: Session) -> None: # to use a non-default SLIP44, a valid network definition must be provided network = common.encode_network(chain_id=66666, slip44=66666) params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/66666h/0h/0/0"), chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) -def test_slip44_external_disallowed(client: Client) -> None: +def test_slip44_external_disallowed(session: Session) -> None: # network definition does not allow a different SLIP44 network = common.encode_network(chain_id=66666, slip44=66666) params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/55555h/0h/0/0"), chain_id=66666) with pytest.raises(TrezorFailure, match="Forbidden key path"): - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) -def test_chain_id_mismatch(client: Client) -> None: +def test_chain_id_mismatch(session: Session) -> None: # network definition for a different chain id will be rejected network = common.encode_network(chain_id=66666, slip44=60) params = DEFAULT_TX_PARAMS.copy() params.update(chain_id=55555) with pytest.raises(TrezorFailure, match="Network definition mismatch"): - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) -def test_definition_does_not_override_builtin(client: Client) -> None: +def test_definition_does_not_override_builtin(session: Session) -> None: # The builtin definition for Ethereum (SLIP44 60, chain_id 1) will be used # even if a valid definition with a different SLIP44 is provided network = common.encode_network(chain_id=1, slip44=66666) params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/66666h/0h/0/0"), chain_id=1) with pytest.raises(TrezorFailure, match="Forbidden key path"): - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) # TODO: test that the builtin definition will not show different symbol @@ -102,50 +102,50 @@ def test_definition_does_not_override_builtin(client: Client) -> None: # all tokens are currently accepted, we would need to check the screenshots -def test_builtin_token(client: Client) -> None: +def test_builtin_token(session: Session) -> None: # The builtin definition for USDT (ERC20) will be used even if not provided params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_BUILTIN_TOKEN) - ethereum.sign_tx(client, **params) + ethereum.sign_tx(session, **params) # TODO check that USDT symbol is shown # TODO: test_builtin_token_not_overriden (builtin definition is used even if a custom one is provided) -def test_external_token(client: Client) -> None: +def test_external_token(session: Session) -> None: # A valid token definition must be provided to use a non-builtin token token = common.encode_token(address=ERC20_FAKE_ADDRESS, chain_id=1, decimals=8) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS) - ethereum.sign_tx(client, **params, definitions=common.make_defs(None, token)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(None, token)) # TODO check that FakeTok symbol is shown -def test_external_chain_without_token(client: Client) -> None: - with client: +def test_external_chain_without_token(session: Session) -> None: + with session, session.client as client: if not client.debug.legacy_debug: client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) # when using an external chains, unknown tokens are allowed network = common.encode_network(chain_id=66666, slip44=60) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_BUILTIN_TOKEN, chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) # TODO check that UNKN token is used, FAKE network -def test_external_chain_token_ok(client: Client) -> None: +def test_external_chain_token_ok(session: Session) -> None: # when providing an external chain and matching token, everything works network = common.encode_network(chain_id=66666, slip44=60) token = common.encode_token(address=ERC20_FAKE_ADDRESS, chain_id=66666, decimals=8) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS, chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, token)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, token)) # TODO check that FakeTok is used, FAKE network -def test_external_chain_token_mismatch(client: Client) -> None: - with client: +def test_external_chain_token_mismatch(session: Session) -> None: + with session, session.client as client: if not client.debug.legacy_debug: client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) # when providing external defs, we explicitly allow, but not use, tokens @@ -156,31 +156,33 @@ def test_external_chain_token_mismatch(client: Client) -> None: ) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS, chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, token)) + ethereum.sign_tx( + session, **params, definitions=common.make_defs(network, token) + ) # TODO check that UNKN is used for token, FAKE for network -def _call_getaddress(client: Client, slip44: int, network: bytes | None) -> None: +def _call_getaddress(session: Session, slip44: int, network: bytes | None) -> None: ethereum.get_address( - client, + session, parse_path(f"m/44h/{slip44}h/0h"), show_display=False, encoded_network=network, ) -def _call_signmessage(client: Client, slip44: int, network: bytes | None) -> None: +def _call_signmessage(session: Session, slip44: int, network: bytes | None) -> None: ethereum.sign_message( - client, + session, parse_path(f"m/44h/{slip44}h/0h"), b"hello", encoded_network=network, ) -def _call_sign_typed_data(client: Client, slip44: int, network: bytes | None) -> None: +def _call_sign_typed_data(session: Session, slip44: int, network: bytes | None) -> None: ethereum.sign_typed_data( - client, + session, parse_path(f"m/44h/{slip44}h/0h/0/0"), TYPED_DATA, metamask_v4_compat=True, @@ -189,10 +191,10 @@ def _call_sign_typed_data(client: Client, slip44: int, network: bytes | None) -> def _call_sign_typed_data_hash( - client: Client, slip44: int, network: bytes | None + session: Session, slip44: int, network: bytes | None ) -> None: ethereum.sign_typed_data_hash( - client, + session, parse_path(f"m/44h/{slip44}h/0h/0/0"), b"\x00" * 32, b"\xff" * 32, @@ -200,7 +202,7 @@ def _call_sign_typed_data_hash( ) -MethodType = Callable[[Client, int, "bytes | None"], None] +MethodType = Callable[[Session, int, "bytes | None"], None] METHODS = ( @@ -212,29 +214,29 @@ def _call_sign_typed_data_hash( @pytest.mark.parametrize("method", METHODS) -def test_method_builtin(client: Client, method: MethodType) -> None: +def test_method_builtin(session: Session, method: MethodType) -> None: # calling a method with a builtin slip44 will work - method(client, 60, None) + method(session, 60, None) @pytest.mark.parametrize("method", METHODS) -def test_method_def_missing(client: Client, method: MethodType) -> None: +def test_method_def_missing(session: Session, method: MethodType) -> None: # calling a method with a slip44 that has no definition will fail with pytest.raises(TrezorFailure, match="Forbidden key path"): - method(client, 66666, None) + method(session, 66666, None) @pytest.mark.parametrize("method", METHODS) -def test_method_external(client: Client, method: MethodType) -> None: +def test_method_external(session: Session, method: MethodType) -> None: # calling a method with a slip44 that has an external definition will work network = common.encode_network(slip44=66666) - method(client, 66666, network) + method(session, 66666, network) @pytest.mark.parametrize("method", METHODS) -def test_method_external_mismatch(client: Client, method: MethodType) -> None: +def test_method_external_mismatch(session: Session, method: MethodType) -> None: # calling a method with a slip44 that has an external definition that does not match # the slip44 will fail network = common.encode_network(slip44=77777) with pytest.raises(TrezorFailure, match="Network definition mismatch"): - method(client, 66666, network) + method(session, 66666, network) diff --git a/tests/device_tests/ethereum/test_definitions_bad.py b/tests/device_tests/ethereum/test_definitions_bad.py index 3f21195643f..ae917105ae9 100644 --- a/tests/device_tests/ethereum/test_definitions_bad.py +++ b/tests/device_tests/ethereum/test_definitions_bad.py @@ -5,7 +5,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import EthereumDefinitionType from trezorlib.tools import parse_path @@ -16,99 +16,99 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum] -def fails(client: Client, network: bytes, match: str) -> None: +def fails(session: Session, network: bytes, match: str) -> None: with pytest.raises(TrezorFailure, match=match): ethereum.get_address( - client, + session, parse_path("m/44h/666666h/0h"), show_display=False, encoded_network=network, ) -def test_short_message(client: Client) -> None: - fails(client, b"\x00", "Invalid Ethereum definition") +def test_short_message(session: Session) -> None: + fails(session, b"\x00", "Invalid Ethereum definition") -def test_mangled_signature(client: Client) -> None: +def test_mangled_signature(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) bad_signature = signature[:-1] + b"\xff" - fails(client, payload + proof + bad_signature, "Invalid definition signature") + fails(session, payload + proof + bad_signature, "Invalid definition signature") -def test_not_enough_signatures(client: Client) -> None: +def test_not_enough_signatures(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, [], threshold=1) - fails(client, payload + proof + signature, "Invalid definition signature") + fails(session, payload + proof + signature, "Invalid definition signature") -def test_missing_signature(client: Client) -> None: +def test_missing_signature(session: Session) -> None: payload = make_payload() proof, _ = sign_payload(payload, []) - fails(client, payload + proof, "Invalid Ethereum definition") + fails(session, payload + proof, "Invalid Ethereum definition") -def test_mangled_payload(client: Client) -> None: +def test_mangled_payload(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) bad_payload = payload[:-1] + b"\xff" - fails(client, bad_payload + proof + signature, "Invalid definition signature") + fails(session, bad_payload + proof + signature, "Invalid definition signature") -def test_proof_length_mismatch(client: Client) -> None: +def test_proof_length_mismatch(session: Session) -> None: payload = make_payload() _, signature = sign_payload(payload, []) bad_proof = b"\x01" - fails(client, payload + bad_proof + signature, "Invalid Ethereum definition") + fails(session, payload + bad_proof + signature, "Invalid Ethereum definition") -def test_bad_proof(client: Client) -> None: +def test_bad_proof(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, [sha256(b"x").digest()]) bad_proof = proof[:-1] + b"\xff" - fails(client, payload + bad_proof + signature, "Invalid definition signature") + fails(session, payload + bad_proof + signature, "Invalid definition signature") -def test_trimmed_proof(client: Client) -> None: +def test_trimmed_proof(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) bad_proof = proof[:-1] - fails(client, payload + bad_proof + signature, "Invalid Ethereum definition") + fails(session, payload + bad_proof + signature, "Invalid Ethereum definition") -def test_bad_prefix(client: Client) -> None: +def test_bad_prefix(session: Session) -> None: payload = make_payload() payload = b"trzd2" + payload[5:] proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Invalid Ethereum definition") + fails(session, payload + proof + signature, "Invalid Ethereum definition") -def test_bad_type(client: Client) -> None: +def test_bad_type(session: Session) -> None: # assuming we expect a network definition payload = make_payload(data_type=EthereumDefinitionType.TOKEN, message=make_token()) proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Definition type mismatch") + fails(session, payload + proof + signature, "Definition type mismatch") -def test_outdated(client: Client) -> None: +def test_outdated(session: Session) -> None: payload = make_payload(timestamp=0) proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Definition is outdated") + fails(session, payload + proof + signature, "Definition is outdated") -def test_malformed_protobuf(client: Client) -> None: +def test_malformed_protobuf(session: Session) -> None: payload = make_payload(message=b"\x00") proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Invalid Ethereum definition") + fails(session, payload + proof + signature, "Invalid Ethereum definition") -def test_protobuf_mismatch(client: Client) -> None: +def test_protobuf_mismatch(session: Session) -> None: payload = make_payload( data_type=EthereumDefinitionType.NETWORK, message=make_token() ) proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Invalid Ethereum definition") + fails(session, payload + proof + signature, "Invalid Ethereum definition") payload = make_payload( data_type=EthereumDefinitionType.TOKEN, message=make_network() @@ -119,13 +119,13 @@ def test_protobuf_mismatch(client: Client) -> None: params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS) ethereum.sign_tx( - client, + session, **params, definitions=make_defs(None, payload + proof + signature), ) -def test_trailing_garbage(client: Client) -> None: +def test_trailing_garbage(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature + b"\x00", "Invalid Ethereum definition") + fails(session, payload + proof + signature + b"\x00", "Invalid Ethereum definition") diff --git a/tests/device_tests/ethereum/test_getaddress.py b/tests/device_tests/ethereum/test_getaddress.py index 3add0ad92fb..b57fcd6afd3 100644 --- a/tests/device_tests/ethereum/test_getaddress.py +++ b/tests/device_tests/ethereum/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -27,21 +27,21 @@ @parametrize_using_common_fixtures("ethereum/getaddress.json") -def test_getaddress(client: Client, parameters, result): +def test_getaddress(session: Session, parameters, result): address_n = parse_path(parameters["path"]) assert ( - ethereum.get_address(client, address_n, show_display=True) == result["address"] + ethereum.get_address(session, address_n, show_display=True) == result["address"] ) @pytest.mark.models("core", reason="No input flow for T1") @parametrize_using_common_fixtures("ethereum/getaddress.json") -def test_getaddress_chunkify_details(client: Client, parameters, result): - with client: +def test_getaddress_chunkify_details(session: Session, parameters, result): + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address_n = parse_path(parameters["path"]) assert ( - ethereum.get_address(client, address_n, show_display=True, chunkify=True) + ethereum.get_address(session, address_n, show_display=True, chunkify=True) == result["address"] ) diff --git a/tests/device_tests/ethereum/test_getpublickey.py b/tests/device_tests/ethereum/test_getpublickey.py index 103b261f579..586abf736d7 100644 --- a/tests/device_tests/ethereum/test_getpublickey.py +++ b/tests/device_tests/ethereum/test_getpublickey.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -27,9 +27,9 @@ @parametrize_using_common_fixtures("ethereum/getpublickey.json") -def test_ethereum_getpublickey(client: Client, parameters, result): +def test_ethereum_getpublickey(session: Session, parameters, result): path = parse_path(parameters["path"]) - res = ethereum.get_public_node(client, path) + res = ethereum.get_public_node(session, path) assert res.node.depth == len(path) assert res.node.fingerprint == result["fingerprint"] assert res.node.child_num == result["child_num"] @@ -38,14 +38,14 @@ def test_ethereum_getpublickey(client: Client, parameters, result): assert res.xpub == result["xpub"] -def test_slip25_disallowed(client: Client): +def test_slip25_disallowed(session: Session): path = parse_path("m/10025'/60'/0'/0/0") with pytest.raises(TrezorFailure): - ethereum.get_public_node(client, path) + ethereum.get_public_node(session, path) @pytest.mark.models("legacy") -def test_legacy_restrictions(client: Client): +def test_legacy_restrictions(session: Session): path = parse_path("m/46'") with pytest.raises(TrezorFailure, match="Invalid path for EthereumGetPublicKey"): - ethereum.get_public_node(client, path) + ethereum.get_public_node(session, path) diff --git a/tests/device_tests/ethereum/test_sign_typed_data.py b/tests/device_tests/ethereum/test_sign_typed_data.py index 38159e39e04..dbb70c0810a 100644 --- a/tests/device_tests/ethereum/test_sign_typed_data.py +++ b/tests/device_tests/ethereum/test_sign_typed_data.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum, exceptions -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -28,11 +28,11 @@ @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_typed_data.json") -def test_ethereum_sign_typed_data(client: Client, parameters, result): - with client: +def test_ethereum_sign_typed_data(session: Session, parameters, result): + with session: address_n = parse_path(parameters["path"]) ret = ethereum.sign_typed_data( - client, + session, address_n, parameters["data"], metamask_v4_compat=parameters["metamask_v4_compat"], @@ -43,11 +43,11 @@ def test_ethereum_sign_typed_data(client: Client, parameters, result): @pytest.mark.models("legacy") @parametrize_using_common_fixtures("ethereum/sign_typed_data.json") -def test_ethereum_sign_typed_data_blind(client: Client, parameters, result): - with client: +def test_ethereum_sign_typed_data_blind(session: Session, parameters, result): + with session: address_n = parse_path(parameters["path"]) ret = ethereum.sign_typed_data_hash( - client, + session, address_n, ethereum.decode_hex(parameters["domain_separator_hash"]), # message hash is empty for domain-only hashes @@ -96,13 +96,13 @@ def test_ethereum_sign_typed_data_blind(client: Client, parameters, result): @pytest.mark.models("core", skip="delizia", reason="Not yet implemented in new UI") -def test_ethereum_sign_typed_data_show_more_button(client: Client): - with client: +def test_ethereum_sign_typed_data_show_more_button(session: Session): + with session.client as client: client.watch_layout() IF = InputFlowEIP712ShowMore(client) client.set_input_flow(IF.get()) ethereum.sign_typed_data( - client, + session, parse_path("m/44h/60h/0h/0/0"), DATA, metamask_v4_compat=True, @@ -110,13 +110,13 @@ def test_ethereum_sign_typed_data_show_more_button(client: Client): @pytest.mark.models("core") -def test_ethereum_sign_typed_data_cancel(client: Client): - with client, pytest.raises(exceptions.Cancelled): +def test_ethereum_sign_typed_data_cancel(session: Session): + with session.client as client, pytest.raises(exceptions.Cancelled): client.watch_layout() IF = InputFlowEIP712Cancel(client) client.set_input_flow(IF.get()) ethereum.sign_typed_data( - client, + session, parse_path("m/44h/60h/0h/0/0"), DATA, metamask_v4_compat=True, diff --git a/tests/device_tests/ethereum/test_sign_verify_message.py b/tests/device_tests/ethereum/test_sign_verify_message.py index 8cf2680ad82..7e50bd205ad 100644 --- a/tests/device_tests/ethereum/test_sign_verify_message.py +++ b/tests/device_tests/ethereum/test_sign_verify_message.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -26,18 +26,18 @@ @parametrize_using_common_fixtures("ethereum/signmessage.json") -def test_signmessage(client: Client, parameters, result): +def test_signmessage(session: Session, parameters, result): res = ethereum.sign_message( - client, parse_path(parameters["path"]), parameters["msg"] + session, parse_path(parameters["path"]), parameters["msg"] ) assert res.address == result["address"] assert res.signature.hex() == result["sig"] @parametrize_using_common_fixtures("ethereum/verifymessage.json") -def test_verify(client: Client, parameters, result): +def test_verify(session: Session, parameters, result): res = ethereum.verify_message( - client, + session, parameters["address"], bytes.fromhex(parameters["sig"]), parameters["msg"], @@ -45,7 +45,7 @@ def test_verify(client: Client, parameters, result): assert res is True -def test_verify_invalid(client: Client): +def test_verify_invalid(session: Session): # First vector from the verifymessage JSON fixture msg = "This is an example of a signed message." address = "0xEa53AF85525B1779eE99ece1a5560C0b78537C3b" @@ -54,7 +54,7 @@ def test_verify_invalid(client: Client): ) res = ethereum.verify_message( - client, + session, address, sig, msg, @@ -63,7 +63,7 @@ def test_verify_invalid(client: Client): # Changing the signature, expecting failure res = ethereum.verify_message( - client, + session, address, sig[:-1] + b"\x00", msg, @@ -72,7 +72,7 @@ def test_verify_invalid(client: Client): # Changing the message, expecting failure res = ethereum.verify_message( - client, + session, address, sig, msg + "abc", @@ -81,7 +81,7 @@ def test_verify_invalid(client: Client): # Changing the address, expecting failure res = ethereum.verify_message( - client, + session, address[:-1] + "a", sig, msg, diff --git a/tests/device_tests/ethereum/test_signtx.py b/tests/device_tests/ethereum/test_signtx.py index 4c84843585d..62180957b40 100644 --- a/tests/device_tests/ethereum/test_signtx.py +++ b/tests/device_tests/ethereum/test_signtx.py @@ -17,6 +17,7 @@ import pytest from trezorlib import ethereum, exceptions, messages, models +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import message_filters from trezorlib.exceptions import TrezorFailure @@ -56,28 +57,28 @@ def make_defs(parameters: dict) -> messages.EthereumDefinitions: "ethereum/sign_tx_eip155.json", ) @pytest.mark.parametrize("chunkify", (True, False)) -def test_signtx(client: Client, chunkify: bool, parameters: dict, result: dict): +def test_signtx(session: Session, chunkify: bool, parameters: dict, result: dict): input_flow = ( - InputFlowConfirmAllWarnings(client).get() - if not client.debug.legacy_debug + InputFlowConfirmAllWarnings(session.client).get() + if not session.client.debug.legacy_debug else None ) - _do_test_signtx(client, parameters, result, input_flow, chunkify=chunkify) + _do_test_signtx(session, parameters, result, input_flow, chunkify=chunkify) def _do_test_signtx( - client: Client, + session: Session, parameters: dict, result: dict, input_flow=None, chunkify: bool = False, ): - with client: + with session.client as client: if input_flow: client.watch_layout() client.set_input_flow(input_flow) sig_v, sig_r, sig_s = ethereum.sign_tx( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), gas_price=int(parameters["gas_price"], 16), @@ -120,10 +121,10 @@ def _do_test_signtx( @pytest.mark.models("core", reason="T1 does not support input flows") -def test_signtx_fee_info(client: Client): - input_flow = InputFlowEthereumSignTxShowFeeInfo(client).get() +def test_signtx_fee_info(session: Session): + input_flow = InputFlowEthereumSignTxShowFeeInfo(session.client).get() _do_test_signtx( - client, + session, example_input_data["parameters"], example_input_data["result"], input_flow, @@ -135,10 +136,10 @@ def test_signtx_fee_info(client: Client): skip="delizia", reason="T1 does not support input flows; Delizia can't send Cancel on Summary", ) -def test_signtx_go_back_from_summary(client: Client): - input_flow = InputFlowEthereumSignTxGoBackFromSummary(client).get() +def test_signtx_go_back_from_summary(session: Session): + input_flow = InputFlowEthereumSignTxGoBackFromSummary(session.client).get() _do_test_signtx( - client, + session, example_input_data["parameters"], example_input_data["result"], input_flow, @@ -147,12 +148,14 @@ def test_signtx_go_back_from_summary(client: Client): @parametrize_using_common_fixtures("ethereum/sign_tx_eip1559.json") @pytest.mark.parametrize("chunkify", (True, False)) -def test_signtx_eip1559(client: Client, chunkify: bool, parameters: dict, result: dict): - with client: +def test_signtx_eip1559( + session: Session, chunkify: bool, parameters: dict, result: dict +): + with session, session.client as client: if not client.debug.legacy_debug: client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), gas_limit=int(parameters["gas_limit"], 16), @@ -171,14 +174,14 @@ def test_signtx_eip1559(client: Client, chunkify: bool, parameters: dict, result assert sig_v == result["sig_v"] -def test_sanity_checks(client: Client): +def test_sanity_checks(session: Session): """Is not vectorized because these are internal-only tests that do not need to be exposed to the public. """ # contract creation without data should fail. with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=123_456, gas_price=20_000, @@ -191,7 +194,7 @@ def test_sanity_checks(client: Client): # gas overflow with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=123_456, gas_price=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, @@ -204,7 +207,7 @@ def test_sanity_checks(client: Client): # bad chain ID with pytest.raises(TrezorFailure, match=r"Chain ID out of bounds"): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=123_456, gas_price=20_000, @@ -215,13 +218,13 @@ def test_sanity_checks(client: Client): ) -def test_data_streaming(client: Client): +def test_data_streaming(session: Session): """Only verifying the expected responses, the signatures are checked in vectorized function above. """ - with client: - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + with session: + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), (is_t1, messages.ButtonRequest(code=messages.ButtonRequestType.SignTx)), @@ -259,7 +262,7 @@ def test_data_streaming(client: Client): ) ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=0, gas_price=20_000, @@ -271,11 +274,11 @@ def test_data_streaming(client: Client): ) -def test_signtx_eip1559_access_list(client: Client): - with client: +def test_signtx_eip1559_access_list(session: Session): + with session: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -310,11 +313,11 @@ def test_signtx_eip1559_access_list(client: Client): ) -def test_signtx_eip1559_access_list_larger(client: Client): - with client: +def test_signtx_eip1559_access_list_larger(session: Session): + with session: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -363,14 +366,14 @@ def test_signtx_eip1559_access_list_larger(client: Client): ) -def test_sanity_checks_eip1559(client: Client): +def test_sanity_checks_eip1559(session: Session): """Is not vectorized because these are internal-only tests that do not need to be exposed to the public. """ # contract creation without data should fail. with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -384,7 +387,7 @@ def test_sanity_checks_eip1559(client: Client): # max fee overflow with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, @@ -398,7 +401,7 @@ def test_sanity_checks_eip1559(client: Client): # priority fee overflow with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, @@ -412,7 +415,7 @@ def test_sanity_checks_eip1559(client: Client): # bad chain ID with pytest.raises(TrezorFailure, match=r"Chain ID out of bounds"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -443,10 +446,10 @@ def input_flow_data_go_back(client: Client, cancel: bool = False): "flow", (input_flow_data_skip, input_flow_data_scroll_down, input_flow_data_go_back) ) @pytest.mark.models("core", skip="delizia", reason="Not yet implemented in new UI") -def test_signtx_data_pagination(client: Client, flow): +def test_signtx_data_pagination(session: Session, flow): def _sign_tx_call(): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=0x0, gas_price=0x14, @@ -458,13 +461,13 @@ def _sign_tx_call(): data=bytes.fromhex(HEXDATA), ) - with client: + with session, session.client as client: client.watch_layout() client.set_input_flow(flow(client)) _sign_tx_call() if flow is not input_flow_data_scroll_down: - with client, pytest.raises(exceptions.Cancelled): + with session, session.client as client, pytest.raises(exceptions.Cancelled): client.watch_layout() client.set_input_flow(flow(client, cancel=True)) _sign_tx_call() @@ -473,20 +476,22 @@ def _sign_tx_call(): @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_tx_staking.json") @pytest.mark.parametrize("chunkify", (True, False)) -def test_signtx_staking(client: Client, chunkify: bool, parameters: dict, result: dict): - input_flow = InputFlowEthereumSignTxStaking(client).get() +def test_signtx_staking( + session: Session, chunkify: bool, parameters: dict, result: dict +): + input_flow = InputFlowEthereumSignTxStaking(session.client).get() _do_test_signtx( - client, parameters, result, input_flow=input_flow, chunkify=chunkify + session, parameters, result, input_flow=input_flow, chunkify=chunkify ) @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_tx_staking_data_error.json") -def test_signtx_staking_bad_inputs(client: Client, parameters: dict, result: dict): +def test_signtx_staking_bad_inputs(session: Session, parameters: dict, result: dict): # result not needed with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), gas_price=int(parameters["gas_price"], 16), @@ -503,10 +508,10 @@ def test_signtx_staking_bad_inputs(client: Client, parameters: dict, result: dic @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_tx_staking_eip1559.json") -def test_signtx_staking_eip1559(client: Client, parameters: dict, result: dict): - with client: +def test_signtx_staking_eip1559(session: Session, parameters: dict, result: dict): + with session: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), max_gas_fee=int(parameters["max_gas_fee"], 16), diff --git a/tests/device_tests/misc/test_msg_cipherkeyvalue.py b/tests/device_tests/misc/test_msg_cipherkeyvalue.py index 7a9fe664206..4efec7ab060 100644 --- a/tests/device_tests/misc/test_msg_cipherkeyvalue.py +++ b/tests/device_tests/misc/test_msg_cipherkeyvalue.py @@ -17,15 +17,15 @@ import pytest from trezorlib import misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_encrypt(client: Client): +def test_encrypt(session: Session): res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -35,7 +35,7 @@ def test_encrypt(client: Client): assert res.hex() == "676faf8f13272af601776bc31bc14e8f" res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -45,7 +45,7 @@ def test_encrypt(client: Client): assert res.hex() == "5aa0fbcb9d7fa669880745479d80c622" res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -55,7 +55,7 @@ def test_encrypt(client: Client): assert res.hex() == "958d4f63269b61044aaedc900c8d6208" res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -66,7 +66,7 @@ def test_encrypt(client: Client): # different key res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test2", b"testing message!", @@ -77,7 +77,7 @@ def test_encrypt(client: Client): # different message res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message! it is different", @@ -90,7 +90,7 @@ def test_encrypt(client: Client): # different path res = misc.encrypt_keyvalue( - client, + session, [0, 1, 3], "test", b"testing message!", @@ -101,9 +101,9 @@ def test_encrypt(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_decrypt(client: Client): +def test_decrypt(session: Session): res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("676faf8f13272af601776bc31bc14e8f"), @@ -113,7 +113,7 @@ def test_decrypt(client: Client): assert res == b"testing message!" res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("5aa0fbcb9d7fa669880745479d80c622"), @@ -123,7 +123,7 @@ def test_decrypt(client: Client): assert res == b"testing message!" res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("958d4f63269b61044aaedc900c8d6208"), @@ -133,7 +133,7 @@ def test_decrypt(client: Client): assert res == b"testing message!" res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("e0cf0eb0425947000eb546cc3994bc6c"), @@ -144,7 +144,7 @@ def test_decrypt(client: Client): # different key res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test2", bytes.fromhex("de247a6aa6be77a134bb3f3f925f13af"), @@ -155,7 +155,7 @@ def test_decrypt(client: Client): # different message res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex( @@ -168,7 +168,7 @@ def test_decrypt(client: Client): # different path res = misc.decrypt_keyvalue( - client, + session, [0, 1, 3], "test", bytes.fromhex("b4811a9d492f5355a5186ddbfccaae7b"), @@ -178,11 +178,11 @@ def test_decrypt(client: Client): assert res == b"testing message!" -def test_encrypt_badlen(client: Client): +def test_encrypt_badlen(session: Session): with pytest.raises(Exception): - misc.encrypt_keyvalue(client, [0, 1, 2], "test", b"testing") + misc.encrypt_keyvalue(session, [0, 1, 2], "test", b"testing") -def test_decrypt_badlen(client: Client): +def test_decrypt_badlen(session: Session): with pytest.raises(Exception): - misc.decrypt_keyvalue(client, [0, 1, 2], "test", b"testing") + misc.decrypt_keyvalue(session, [0, 1, 2], "test", b"testing") diff --git a/tests/device_tests/misc/test_msg_enablelabeling.py b/tests/device_tests/misc/test_msg_enablelabeling.py index 2c33498b75c..4856e234ebf 100644 --- a/tests/device_tests/misc/test_msg_enablelabeling.py +++ b/tests/device_tests/misc/test_msg_enablelabeling.py @@ -17,6 +17,7 @@ import pytest from trezorlib import misc +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from ... import translations as TR @@ -32,10 +33,11 @@ def input_flow(): client.debug.swipe_up() client.debug.press_yes() - with client: + session = Session(client.get_session()) + with client, session: client.set_input_flow(input_flow()) misc.encrypt_keyvalue( - client, + session, [], "Enable labeling?", b"", diff --git a/tests/device_tests/misc/test_msg_getecdhsessionkey.py b/tests/device_tests/misc/test_msg_getecdhsessionkey.py index 8c38f612b1d..d7c532dc5a0 100644 --- a/tests/device_tests/misc/test_msg_getecdhsessionkey.py +++ b/tests/device_tests/misc/test_msg_getecdhsessionkey.py @@ -17,13 +17,13 @@ import pytest from trezorlib import messages, misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_ecdh(client: Client): +def test_ecdh(session: Session): identity = messages.IdentityType( proto="gpg", user="", @@ -37,7 +37,7 @@ def test_ecdh(client: Client): "0407f2c6e5becf3213c1d07df0cfbe8e39f70a8c643df7575e5c56859ec52c45ca950499c019719dae0fda04248d851e52cf9d66eeb211d89a77be40de22b6c89d" ) result = misc.get_ecdh_session_key( - client, + session, identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name="secp256k1", @@ -55,7 +55,7 @@ def test_ecdh(client: Client): "04811a6c2bd2a547d0dd84747297fec47719e7c3f9b0024f027c2b237be99aac39a9230acbd163d0cb1524a0f5ea4bfed6058cec6f18368f72a12aa0c4d083ff64" ) result = misc.get_ecdh_session_key( - client, + session, identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name="nist256p1", @@ -73,7 +73,7 @@ def test_ecdh(client: Client): "40a8cf4b6a64c4314e80f15a8ea55812bd735fbb365936a48b2d78807b575fa17a" ) result = misc.get_ecdh_session_key( - client, + session, identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name="curve25519", diff --git a/tests/device_tests/misc/test_msg_getentropy.py b/tests/device_tests/misc/test_msg_getentropy.py index 593fb1a76c1..d5d19425f9b 100644 --- a/tests/device_tests/misc/test_msg_getentropy.py +++ b/tests/device_tests/misc/test_msg_getentropy.py @@ -20,7 +20,7 @@ from trezorlib import messages as m from trezorlib import misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session ENTROPY_LENGTHS_POW2 = [2**l for l in range(10)] ENTROPY_LENGTHS_POW2_1 = [2**l + 1 for l in range(10)] @@ -40,11 +40,11 @@ def entropy(data): @pytest.mark.parametrize("entropy_length", ENTROPY_LENGTHS) -def test_entropy(client: Client, entropy_length): - with client: - client.set_expected_responses( +def test_entropy(session: Session, entropy_length): + with session: + session.set_expected_responses( [m.ButtonRequest(code=m.ButtonRequestType.ProtectCall), m.Entropy] ) - ent = misc.get_entropy(client, entropy_length) + ent = misc.get_entropy(session, entropy_length) assert len(ent) == entropy_length print(f"{entropy_length} bytes: entropy = {entropy(ent)}") diff --git a/tests/device_tests/misc/test_msg_signidentity.py b/tests/device_tests/misc/test_msg_signidentity.py index bc9e7f5bd4e..6715387d387 100644 --- a/tests/device_tests/misc/test_msg_signidentity.py +++ b/tests/device_tests/misc/test_msg_signidentity.py @@ -17,13 +17,13 @@ import pytest from trezorlib import messages, misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_sign(client: Client): +def test_sign(session: Session): hidden = bytes.fromhex( "cd8552569d6e4509266ef137584d1e62c7579b5b8ed69bbafa4b864c6521e7c2" ) @@ -40,7 +40,7 @@ def test_sign(client: Client): path="/login", index=0, ) - sig = misc.sign_identity(client, identity, hidden, visual) + sig = misc.sign_identity(session, identity, hidden, visual) assert sig.address == "17F17smBTX9VTZA9Mj8LM5QGYNZnmziCjL" assert ( sig.public_key.hex() @@ -62,7 +62,7 @@ def test_sign(client: Client): path="/pub", index=3, ) - sig = misc.sign_identity(client, identity, hidden, visual) + sig = misc.sign_identity(session, identity, hidden, visual) assert sig.address == "1KAr6r5qF2kADL8bAaRQBjGKYEGxn9WrbS" assert ( sig.public_key.hex() @@ -80,7 +80,7 @@ def test_sign(client: Client): proto="ssh", user="satoshi", host="bitcoin.org", port="", path="", index=47 ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="nist256p1" + session, identity, hidden, visual, ecdsa_curve_name="nist256p1" ) assert sig.address is None assert ( @@ -99,7 +99,7 @@ def test_sign(client: Client): proto="ssh", user="satoshi", host="bitcoin.org", port="", path="", index=47 ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="ed25519" + session, identity, hidden, visual, ecdsa_curve_name="ed25519" ) assert sig.address is None assert ( @@ -116,7 +116,7 @@ def test_sign(client: Client): proto="gpg", user="satoshi", host="bitcoin.org", port="", path="" ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="ed25519" + session, identity, hidden, visual, ecdsa_curve_name="ed25519" ) assert sig.address is None assert ( @@ -133,7 +133,7 @@ def test_sign(client: Client): proto="signify", user="satoshi", host="bitcoin.org", port="", path="" ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="ed25519" + session, identity, hidden, visual, ecdsa_curve_name="ed25519" ) assert sig.address is None assert ( diff --git a/tests/device_tests/monero/test_getaddress.py b/tests/device_tests/monero/test_getaddress.py index dfd0ce5ab09..1a6d3ffc01c 100644 --- a/tests/device_tests/monero/test_getaddress.py +++ b/tests/device_tests/monero/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import monero -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -47,19 +47,19 @@ @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) -def test_monero_getaddress(client: Client, path: str, expected_address: bytes): - address = monero.get_address(client, parse_path(path), show_display=True) +def test_monero_getaddress(session: Session, path: str, expected_address: bytes): + address = monero.get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) def test_monero_getaddress_chunkify_details( - client: Client, path: str, expected_address: bytes + session: Session, path: str, expected_address: bytes ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = monero.get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address diff --git a/tests/device_tests/monero/test_getwatchkey.py b/tests/device_tests/monero/test_getwatchkey.py index eee83d0445f..30e3d7b1140 100644 --- a/tests/device_tests/monero/test_getwatchkey.py +++ b/tests/device_tests/monero/test_getwatchkey.py @@ -17,7 +17,7 @@ import pytest from trezorlib import monero -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -27,8 +27,8 @@ @pytest.mark.monero @pytest.mark.models("core") @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_monero_getwatchkey(client: Client): - res = monero.get_watch_key(client, parse_path("m/44h/128h/0h")) +def test_monero_getwatchkey(session: Session): + res = monero.get_watch_key(session, parse_path("m/44h/128h/0h")) assert ( res.address == b"4Ahp23WfMrMFK3wYL2hLWQFGt87ZTeRkufS6JoQZu6MEFDokAQeGWmu9MA3GFq1yVLSJQbKJqVAn9F9DLYGpRzRAEXqAXKM" @@ -37,7 +37,7 @@ def test_monero_getwatchkey(client: Client): res.watch_key.hex() == "8722520a581e2a50cc1adab4a1692401effd37b0d63b9d9b60fd7f34ea2b950e" ) - res = monero.get_watch_key(client, parse_path("m/44h/128h/1h")) + res = monero.get_watch_key(session, parse_path("m/44h/128h/1h")) assert ( res.address == b"44iAazhoAkv5a5RqLNVyh82a1n3ceNggmN4Ho7bUBJ14WkEVR8uFTe9f7v5rNnJ2kEbVXxfXiRzsD5Jtc6NvBi4D6WNHPie" @@ -46,7 +46,7 @@ def test_monero_getwatchkey(client: Client): res.watch_key.hex() == "1f70b7d9e86c11b7a5bee883b75c43d6be189c8f812726ea1ecd94b06bb7db04" ) - res = monero.get_watch_key(client, parse_path("m/44h/128h/2h")) + res = monero.get_watch_key(session, parse_path("m/44h/128h/2h")) assert ( res.address == b"47ejhmbZ4wHUhXaqA4b7PN667oPMkokf4ZkNdWrMSPy9TNaLVr7vLqVUQHh2MnmaAEiyrvLsX8xUf99q3j1iAeMV8YvSFcH" diff --git a/tests/device_tests/nem/test_getaddress.py b/tests/device_tests/nem/test_getaddress.py index b2b20c529ec..920dd974904 100644 --- a/tests/device_tests/nem/test_getaddress.py +++ b/tests/device_tests/nem/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -28,10 +28,10 @@ @pytest.mark.models("t1b1", "t2t1") @pytest.mark.setup_client(mnemonic=MNEMONIC12) @pytest.mark.parametrize("chunkify", (True, False)) -def test_nem_getaddress(client: Client, chunkify: bool): +def test_nem_getaddress(session: Session, chunkify: bool): assert ( nem.get_address( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), 0x68, show_display=True, @@ -41,7 +41,7 @@ def test_nem_getaddress(client: Client, chunkify: bool): ) assert ( nem.get_address( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), 0x98, show_display=True, diff --git a/tests/device_tests/nem/test_signtx_mosaics.py b/tests/device_tests/nem/test_signtx_mosaics.py index 51cfd556a77..3e6b835f953 100644 --- a/tests/device_tests/nem/test_signtx_mosaics.py +++ b/tests/device_tests/nem/test_signtx_mosaics.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -32,9 +32,9 @@ ] -def test_nem_signtx_mosaic_supply_change(client: Client): +def test_nem_signtx_mosaic_supply_change(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, @@ -61,9 +61,9 @@ def test_nem_signtx_mosaic_supply_change(client: Client): ) -def test_nem_signtx_mosaic_creation(client: Client): +def test_nem_signtx_mosaic_creation(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, @@ -93,9 +93,9 @@ def test_nem_signtx_mosaic_creation(client: Client): ) -def test_nem_signtx_mosaic_creation_properties(client: Client): +def test_nem_signtx_mosaic_creation_properties(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, @@ -130,9 +130,9 @@ def test_nem_signtx_mosaic_creation_properties(client: Client): ) -def test_nem_signtx_mosaic_creation_levy(client: Client): +def test_nem_signtx_mosaic_creation_levy(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, diff --git a/tests/device_tests/nem/test_signtx_multisig.py b/tests/device_tests/nem/test_signtx_multisig.py index d153547c424..ef641e52f39 100644 --- a/tests/device_tests/nem/test_signtx_multisig.py +++ b/tests/device_tests/nem/test_signtx_multisig.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -31,9 +31,9 @@ # assertion data from T1 -def test_nem_signtx_aggregate_modification(client: Client): +def test_nem_signtx_aggregate_modification(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -61,9 +61,9 @@ def test_nem_signtx_aggregate_modification(client: Client): ) -def test_nem_signtx_multisig(client: Client): +def test_nem_signtx_multisig(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 1, @@ -98,7 +98,7 @@ def test_nem_signtx_multisig(client: Client): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -132,9 +132,9 @@ def test_nem_signtx_multisig(client: Client): ) -def test_nem_signtx_multisig_signer(client: Client): +def test_nem_signtx_multisig_signer(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 333, @@ -169,7 +169,7 @@ def test_nem_signtx_multisig_signer(client: Client): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 900000, diff --git a/tests/device_tests/nem/test_signtx_others.py b/tests/device_tests/nem/test_signtx_others.py index f775c60cdf6..9760d8c5235 100644 --- a/tests/device_tests/nem/test_signtx_others.py +++ b/tests/device_tests/nem/test_signtx_others.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -31,10 +31,10 @@ # assertion data from T1 -def test_nem_signtx_importance_transfer(client: Client): - with client: +def test_nem_signtx_importance_transfer(session: Session): + with session: tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 12349215, @@ -60,9 +60,9 @@ def test_nem_signtx_importance_transfer(client: Client): ) -def test_nem_signtx_provision_namespace(client: Client): +def test_nem_signtx_provision_namespace(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, diff --git a/tests/device_tests/nem/test_signtx_transfers.py b/tests/device_tests/nem/test_signtx_transfers.py index 0388b30ffb4..2df62b55936 100644 --- a/tests/device_tests/nem/test_signtx_transfers.py +++ b/tests/device_tests/nem/test_signtx_transfers.py @@ -17,7 +17,7 @@ import pytest from trezorlib import messages, nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12, is_core @@ -32,16 +32,16 @@ # assertion data from T1 @pytest.mark.parametrize("chunkify", (True, False)) -def test_nem_signtx_simple(client: Client, chunkify: bool): - with client: - client.set_expected_responses( +def test_nem_signtx_simple(session: Session, chunkify: bool): + with session: + session.set_expected_responses( [ # Confirm transfer and network fee messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), # Unencrypted message messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ( - is_core(client), + is_core(session), messages.ButtonRequest( code=messages.ButtonRequestType.ConfirmOutput ), @@ -53,7 +53,7 @@ def test_nem_signtx_simple(client: Client, chunkify: bool): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -82,16 +82,16 @@ def test_nem_signtx_simple(client: Client, chunkify: bool): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_encrypted_payload(client: Client): - with client: - client.set_expected_responses( +def test_nem_signtx_encrypted_payload(session: Session): + with session: + session.set_expected_responses( [ # Confirm transfer and network fee messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), # Ask for encryption messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ( - is_core(client), + is_core(session), messages.ButtonRequest( code=messages.ButtonRequestType.ConfirmOutput ), @@ -103,7 +103,7 @@ def test_nem_signtx_encrypted_payload(client: Client): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -134,9 +134,9 @@ def test_nem_signtx_encrypted_payload(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_xem_as_mosaic(client: Client): +def test_nem_signtx_xem_as_mosaic(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -168,9 +168,9 @@ def test_nem_signtx_xem_as_mosaic(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_unknown_mosaic(client: Client): +def test_nem_signtx_unknown_mosaic(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -202,9 +202,9 @@ def test_nem_signtx_unknown_mosaic(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_known_mosaic(client: Client): +def test_nem_signtx_known_mosaic(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -236,9 +236,9 @@ def test_nem_signtx_known_mosaic(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_known_mosaic_with_levy(client: Client): +def test_nem_signtx_known_mosaic_with_levy(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -270,9 +270,9 @@ def test_nem_signtx_known_mosaic_with_levy(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_multiple_mosaics(client: Client): +def test_nem_signtx_multiple_mosaics(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py index 416fef78eac..8841a52426f 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py @@ -19,7 +19,7 @@ import pytest from trezorlib import device, exceptions, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 from ...input_flows import ( @@ -28,9 +28,9 @@ ) -def do_recover_legacy(client: Client, mnemonic: list[str]): +def do_recover_legacy(session: Session, mnemonic: list[str]): def input_callback(_): - word, pos = client.debug.read_recovery_word() + word, pos = session.client.debug.read_recovery_word() if pos != 0 and pos is not None: word = mnemonic[pos - 1] mnemonic[pos - 1] = None @@ -39,7 +39,7 @@ def input_callback(_): return word ret = device.recover( - client, + session, type=messages.RecoveryType.DryRun, word_count=len(mnemonic), input_method=messages.RecoveryDeviceInputMethod.ScrambledWords, @@ -50,58 +50,59 @@ def input_callback(_): return ret -def do_recover_core(client: Client, mnemonic: list[str], mismatch: bool = False): - with client: +def do_recover_core(session: Session, mnemonic: list[str], mismatch: bool = False): + with session.client as client: client.watch_layout() IF = InputFlowBip39RecoveryDryRun(client, mnemonic, mismatch=mismatch) client.set_input_flow(IF.get()) - return device.recover(client, type=messages.RecoveryType.DryRun) + return device.recover(session, type=messages.RecoveryType.DryRun) -def do_recover(client: Client, mnemonic: list[str], mismatch: bool = False): - if client.model is models.T1B1: - return do_recover_legacy(client, mnemonic) +def do_recover(session: Session, mnemonic: list[str], mismatch: bool = False): + if session.model is models.T1B1: + return do_recover_legacy(session, mnemonic) else: - return do_recover_core(client, mnemonic, mismatch) + return do_recover_core(session, mnemonic, mismatch) @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_dry_run(client: Client): - ret = do_recover(client, MNEMONIC12.split(" ")) +def test_dry_run(session: Session): + ret = do_recover(session, MNEMONIC12.split(" ")) assert isinstance(ret, messages.Success) @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_seed_mismatch(client: Client): +def test_seed_mismatch(session: Session): with pytest.raises( exceptions.TrezorFailure, match="does not match the one in the device" ): - do_recover(client, ["all"] * 12, mismatch=True) + do_recover(session, ["all"] * 12, mismatch=True) @pytest.mark.models("legacy") -def test_invalid_seed_t1(client: Client): +def test_invalid_seed_t1(session: Session): with pytest.raises(exceptions.TrezorFailure, match="Invalid seed"): - do_recover(client, ["stick"] * 12) + do_recover(session, ["stick"] * 12) @pytest.mark.models("core") -def test_invalid_seed_core(client: Client): - with client: +def test_invalid_seed_core(session: Session): + with session, session.client as client: client.watch_layout() - IF = InputFlowBip39RecoveryDryRunInvalid(client) + IF = InputFlowBip39RecoveryDryRunInvalid(session) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): return device.recover( - client, + session, type=messages.RecoveryType.DryRun, ) @pytest.mark.setup_client(uninitialized=True) -def test_uninitialized(client: Client): +@pytest.mark.uninitialized_session +def test_uninitialized(session: Session): with pytest.raises(exceptions.TrezorFailure, match="not initialized"): - do_recover(client, ["all"] * 12) + do_recover(session, ["all"] * 12) DRY_RUN_ALLOWED_FIELDS = ( @@ -140,7 +141,7 @@ def _make_bad_params(): @pytest.mark.parametrize("field_name, field_value", _make_bad_params()) -def test_bad_parameters(client: Client, field_name: str, field_value: Any): +def test_bad_parameters(session: Session, field_name: str, field_value: Any): msg = messages.RecoveryDevice( type=messages.RecoveryType.DryRun, word_count=12, @@ -152,4 +153,4 @@ def test_bad_parameters(client: Client, field_name: str, field_value: Any): exceptions.TrezorFailure, match="Forbidden field set in dry-run", ): - client.call(msg) + session.call(msg) diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py b/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py index 4f2eab6147b..7ddc634b8d4 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -29,9 +29,10 @@ @pytest.mark.setup_client(uninitialized=True) -def test_pin_passphrase(client: Client): +def test_pin_passphrase(session: Session): + debug = session.client.debug mnemonic = MNEMONIC12.split(" ") - ret = client.call_raw( + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=True, @@ -43,30 +44,30 @@ def test_pin_passphrase(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin(PIN6) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN6) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time - pin_encoded = client.debug.encode_pin(PIN6) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN6) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) fakes = 0 for _ in range(int(12 * 2)): assert isinstance(ret, messages.WordRequest) - (word, pos) = client.debug.read_recovery_word() + (word, pos) = debug.read_recovery_word() if pos != 0: - ret = client.call_raw(messages.WordAck(word=mnemonic[pos - 1])) + ret = session.call_raw(messages.WordAck(word=mnemonic[pos - 1])) mnemonic[pos - 1] = None else: - ret = client.call_raw(messages.WordAck(word=word)) + ret = session.call_raw(messages.WordAck(word=word)) fakes += 1 # Workflow succesfully ended @@ -76,23 +77,26 @@ def test_pin_passphrase(client: Client): assert fakes == 12 assert mnemonic == [None] * 12 + raise Exception("TEST IS USING INIT MESSAGE - TODO CHANGE") # Mnemonic is the same - client.init_device() - assert client.debug.state().mnemonic_secret == MNEMONIC12.encode() + session.init_device() + assert debug.state().mnemonic_secret == MNEMONIC12.encode() - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True + assert session.features.pin_protection is True + assert session.features.passphrase_protection is True # Do passphrase-protected action, PassphraseRequest should be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.PassphraseRequest) - client.call_raw(messages.Cancel()) + session.call_raw(messages.Cancel()) @pytest.mark.setup_client(uninitialized=True) -def test_nopin_nopassphrase(client: Client): +def test_nopin_nopassphrase(session: Session): mnemonic = MNEMONIC12.split(" ") - ret = client.call_raw( + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=False, @@ -104,19 +108,20 @@ def test_nopin_nopassphrase(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug = session.client.debug + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) fakes = 0 for _ in range(int(12 * 2)): assert isinstance(ret, messages.WordRequest) - (word, pos) = client.debug.read_recovery_word() + (word, pos) = debug.read_recovery_word() if pos != 0: - ret = client.call_raw(messages.WordAck(word=mnemonic[pos - 1])) + ret = session.call_raw(messages.WordAck(word=mnemonic[pos - 1])) mnemonic[pos - 1] = None else: - ret = client.call_raw(messages.WordAck(word=word)) + ret = session.call_raw(messages.WordAck(word=word)) fakes += 1 # Workflow succesfully ended @@ -126,21 +131,26 @@ def test_nopin_nopassphrase(client: Client): assert fakes == 12 assert mnemonic == [None] * 12 + raise Exception("TEST IS USING INIT MESSAGE - TODO CHANGE") + # Mnemonic is the same - client.init_device() - assert client.debug.state().mnemonic_secret == MNEMONIC12.encode() + # session.init_device() + assert debug.state().mnemonic_secret == MNEMONIC12.encode() - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False # Do pin & passphrase-protected action, PassphraseRequest should NOT be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.Address) @pytest.mark.setup_client(uninitialized=True) -def test_word_fail(client: Client): - ret = client.call_raw( +def test_word_fail(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=False, @@ -152,23 +162,24 @@ def test_word_fail(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.WordRequest) for _ in range(int(12 * 2)): - (word, pos) = client.debug.read_recovery_word() + (word, pos) = debug.read_recovery_word() if pos != 0: - ret = client.call_raw(messages.WordAck(word="kwyjibo")) + ret = session.call_raw(messages.WordAck(word="kwyjibo")) assert isinstance(ret, messages.Failure) break else: - client.call_raw(messages.WordAck(word=word)) + session.call_raw(messages.WordAck(word=word)) @pytest.mark.setup_client(uninitialized=True) -def test_pin_fail(client: Client): - ret = client.call_raw( +def test_pin_fail(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=True, @@ -180,36 +191,36 @@ def test_pin_fail(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin(PIN4) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN4) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time, but different one - pin_encoded = client.debug.encode_pin(PIN6) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN6) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) # Failure should be raised assert isinstance(ret, messages.Failure) -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(RuntimeError): device.recover( - client, + session, word_count=12, pin_protection=False, passphrase_protection=False, label="label", - input_callback=client.mnemonic_callback, + input_callback=session.client.mnemonic_callback, ) - ret = client.call_raw( + ret = session.call_raw( messages.RecoveryDevice( word_count=12, input_method=messages.RecoveryDeviceInputMethod.ScrambledWords, diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py b/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py index 6046e85ca78..abca75bbee6 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, exceptions, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 from ...input_flows import InputFlowBip39Recovery @@ -26,47 +26,49 @@ @pytest.mark.setup_client(uninitialized=True) -def test_tt_pin_passphrase(client: Client): - with client: +@pytest.mark.uninitialized_session +def test_tt_pin_passphrase(session: Session): + with session.client as client: IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" "), pin="654") client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=True, passphrase_protection=True, label="hello", ) - assert client.debug.state().mnemonic_secret.decode() == MNEMONIC12 + assert session.client.debug.state().mnemonic_secret.decode() == MNEMONIC12 - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True - assert client.features.backup_type is messages.BackupType.Bip39 - assert client.features.label == "hello" + assert session.features.pin_protection is True + assert session.features.passphrase_protection is True + assert session.features.backup_type is messages.BackupType.Bip39 + assert session.features.label == "hello" @pytest.mark.setup_client(uninitialized=True) -def test_tt_nopin_nopassphrase(client: Client): - with client: +@pytest.mark.uninitialized_session +def test_tt_nopin_nopassphrase(session: Session): + with session.client as client: IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" ")) client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=False, passphrase_protection=False, label="hello", ) - assert client.debug.state().mnemonic_secret.decode() == MNEMONIC12 - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Bip39 - assert client.features.label == "hello" + assert session.client.debug.state().mnemonic_secret.decode() == MNEMONIC12 + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is messages.BackupType.Bip39 + assert session.features.label == "hello" -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(RuntimeError): - device.recover(client) + device.recover(session) with pytest.raises(exceptions.TrezorFailure, match="Already initialized"): - client.call(messages.RecoveryDevice()) + session.call(messages.RecoveryDevice()) diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py index ad6f51ed43f..9000952b01a 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, exceptions, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC_SLIP39_ADVANCED_20, MNEMONIC_SLIP39_ADVANCED_33 from ...input_flows import ( @@ -28,7 +28,7 @@ InputFlowSlip39AdvancedRecoveryThresholdReached, ) -pytestmark = pytest.mark.models("core") +pytestmark = [pytest.mark.models("core"), pytest.mark.uninitialized_session] EXTRA_GROUP_SHARE = [ "eraser senior decision smug corner ruin rescue cubic angel tackle skin skunk program roster trash rumor slush angel flea amazing" @@ -46,98 +46,98 @@ # To allow reusing functionality for multiple tests def _test_secret( - client: Client, shares: list[str], secret: str, click_info: bool = False + session: Session, shares: list[str], secret: str, click_info: bool = False ): - with client: + with session.client as client: IF = InputFlowSlip39AdvancedRecovery(client, shares, click_info=click_info) client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=False, passphrase_protection=False, label="label", ) - assert client.features.initialized is True - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Slip39_Advanced - assert client.debug.state().mnemonic_secret.hex() == secret + assert session.features.initialized is True + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is messages.BackupType.Slip39_Advanced + assert session.client.debug.state().mnemonic_secret.hex() == secret @pytest.mark.parametrize("shares, secret", VECTORS) @pytest.mark.setup_client(uninitialized=True) -def test_secret(client: Client, shares: list[str], secret: str): - _test_secret(client, shares, secret) +def test_secret(session: Session, shares: list[str], secret: str): + _test_secret(session, shares, secret) @pytest.mark.parametrize("shares, secret", VECTORS) @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models(skip="safe3", reason="safe3 does not have info button") -def test_secret_click_info_button(client: Client, shares: list[str], secret: str): - _test_secret(client, shares, secret, click_info=True) +def test_secret_click_info_button(session: Session, shares: list[str], secret: str): + _test_secret(session, shares, secret, click_info=True) @pytest.mark.setup_client(uninitialized=True) -def test_extra_share_entered(client: Client): +def test_extra_share_entered(session: Session): _test_secret( - client, + session, shares=EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20, secret=VECTORS[0][1], ) @pytest.mark.setup_client(uninitialized=True) -def test_abort(client: Client): - with client: +def test_abort(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryAbort(client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is False @pytest.mark.setup_client(uninitialized=True) -def test_noabort(client: Client): - with client: +def test_noabort(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryNoAbort( client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 ) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is True + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is True @pytest.mark.setup_client(uninitialized=True) -def test_same_share(client: Client): +def test_same_share(session: Session): # we choose the second share from the fixture because # the 1st is 1of1 and group threshold condition is reached first first_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ") # second share is first 4 words of first second_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")[:4] - with client: + with session, session.client as client: IF = InputFlowSlip39AdvancedRecoveryShareAlreadyEntered( - client, first_share, second_share + session, first_share, second_share ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) -def test_group_threshold_reached(client: Client): +def test_group_threshold_reached(session: Session): # first share in the fixture is 1of1 so we choose that first_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ") # second share is first 3 words of first second_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")[:3] - with client: + with session, session.client as client: IF = InputFlowSlip39AdvancedRecoveryThresholdReached( - client, first_share, second_share + session, first_share, second_share ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py index 52309834974..37b4a0264dd 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...common import MNEMONIC_SLIP39_ADVANCED_20 @@ -39,14 +39,14 @@ @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20, passphrase=False) -def test_2of3_dryrun(client: Client): - with client: +def test_2of3_dryrun(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryDryRun( client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 ) client.set_input_flow(IF.get()) device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", @@ -55,9 +55,9 @@ def test_2of3_dryrun(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20) -def test_2of3_invalid_seed_dryrun(client: Client): +def test_2of3_invalid_seed_dryrun(session: Session): # test fails because of different seed on device - with client, pytest.raises( + with session.client as client, pytest.raises( TrezorFailure, match=r"The seed does not match the one in the device" ): IF = InputFlowSlip39AdvancedRecoveryDryRun( @@ -65,7 +65,7 @@ def test_2of3_invalid_seed_dryrun(client: Client): ) client.set_input_flow(IF.get()) device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py b/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py index 9c84af71184..6ca108820f8 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, exceptions, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import ( MNEMONIC_SLIP39_BASIC_20_3of6, @@ -36,7 +36,7 @@ InputFlowSlip39BasicRecoveryWrongNthWord, ) -pytestmark = pytest.mark.models("core") +pytestmark = [pytest.mark.models("core"), pytest.mark.uninitialized_session] MNEMONIC_SLIP39_BASIC_20_1of1 = [ "academic academic academic academic academic academic academic academic academic academic academic academic academic academic academic academic academic rebuild aquatic spew" @@ -70,137 +70,137 @@ @pytest.mark.setup_client(uninitialized=True) @pytest.mark.parametrize("shares, secret, backup_type", VECTORS) def test_secret( - client: Client, shares: list[str], secret: str, backup_type: messages.BackupType + session: Session, shares: list[str], secret: str, backup_type: messages.BackupType ): - with client: + with session.client as client: IF = InputFlowSlip39BasicRecovery(client, shares) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") # Workflow successfully ended - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is backup_type + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is backup_type # Check mnemonic - assert client.debug.state().mnemonic_secret.hex() == secret + assert session.client.debug.state().mnemonic_secret.hex() == secret @pytest.mark.setup_client(uninitialized=True) -def test_recover_with_pin_passphrase(client: Client): - with client: +def test_recover_with_pin_passphrase(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecovery( client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654" ) client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=True, passphrase_protection=True, label="label", ) # Workflow successfully ended - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True - assert client.features.backup_type is messages.BackupType.Slip39_Basic + assert session.features.pin_protection is True + assert session.features.passphrase_protection is True + assert session.features.backup_type is messages.BackupType.Slip39_Basic @pytest.mark.setup_client(uninitialized=True) -def test_abort(client: Client): - with client: +def test_abort(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryAbort(client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False - assert client.features.recovery_status is messages.RecoveryStatus.Nothing + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is False + assert session.features.recovery_status is messages.RecoveryStatus.Nothing @pytest.mark.setup_client(uninitialized=True) -def test_abort_between_shares(client: Client): - with client: +def test_abort_between_shares(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryAbortBetweenShares( client, MNEMONIC_SLIP39_BASIC_20_3of6 ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False - assert client.features.recovery_status is messages.RecoveryStatus.Nothing + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is False + assert session.features.recovery_status is messages.RecoveryStatus.Nothing @pytest.mark.setup_client(uninitialized=True) -def test_noabort(client: Client): - with client: +def test_noabort(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryNoAbort(client, MNEMONIC_SLIP39_BASIC_20_3of6) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is True + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is True @pytest.mark.setup_client(uninitialized=True) -def test_invalid_mnemonic_first_share(client: Client): - with client: - IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(client) +def test_invalid_mnemonic_first_share(session: Session): + with session, session.client as client: + IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(session) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is False @pytest.mark.setup_client(uninitialized=True) -def test_invalid_mnemonic_second_share(client: Client): - with client: +def test_invalid_mnemonic_second_share(session: Session): + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryInvalidSecondShare( - client, MNEMONIC_SLIP39_BASIC_20_3of6 + session, MNEMONIC_SLIP39_BASIC_20_3of6 ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is False @pytest.mark.setup_client(uninitialized=True) @pytest.mark.parametrize("nth_word", range(3)) -def test_wrong_nth_word(client: Client, nth_word: int): +def test_wrong_nth_word(session: Session, nth_word: int): share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") - with client: - IF = InputFlowSlip39BasicRecoveryWrongNthWord(client, share, nth_word) + with session, session.client as client: + IF = InputFlowSlip39BasicRecoveryWrongNthWord(session, share, nth_word) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) -def test_same_share(client: Client): +def test_same_share(session: Session): share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") - with client: - IF = InputFlowSlip39BasicRecoverySameShare(client, share) + with session, session.client as client: + IF = InputFlowSlip39BasicRecoverySameShare(session, share) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) -def test_1of1(client: Client): - with client: +def test_1of1(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecovery(client, MNEMONIC_SLIP39_BASIC_20_1of1) client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=False, passphrase_protection=False, label="label", ) # Workflow successfully ended - assert client.features.initialized is True - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Slip39_Basic + assert session.features.initialized is True + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is messages.BackupType.Slip39_Basic diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py index 8d5d57f9a13..b9c4ca6daab 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...input_flows import InputFlowSlip39BasicRecoveryDryRun @@ -37,12 +37,12 @@ @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) -def test_2of3_dryrun(client: Client): - with client: +def test_2of3_dryrun(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun(client, SHARES_20_2of3[1:3]) client.set_input_flow(IF.get()) device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", @@ -51,9 +51,9 @@ def test_2of3_dryrun(client: Client): @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) -def test_2of3_invalid_seed_dryrun(client: Client): +def test_2of3_invalid_seed_dryrun(session: Session): # test fails because of different seed on device - with client, pytest.raises( + with session.client as client, pytest.raises( TrezorFailure, match=r"The seed does not match the one in the device" ): IF = InputFlowSlip39BasicRecoveryDryRun( @@ -61,7 +61,7 @@ def test_2of3_invalid_seed_dryrun(client: Client): ) client.set_input_flow(IF.get()) device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", diff --git a/tests/device_tests/reset_recovery/test_reset_backup.py b/tests/device_tests/reset_recovery/test_reset_backup.py index db7e3c88454..9710ee6201f 100644 --- a/tests/device_tests/reset_recovery/test_reset_backup.py +++ b/tests/device_tests/reset_recovery/test_reset_backup.py @@ -19,7 +19,7 @@ from shamir_mnemonic import shamir from trezorlib import device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import BackupAvailability, BackupType from ...common import MOCK_GET_ENTROPY @@ -31,32 +31,32 @@ ) -def backup_flow_bip39(client: Client) -> bytes: - with client: +def backup_flow_bip39(session: Session) -> bytes: + with session.client as client: IF = InputFlowBip39Backup(client) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) assert IF.mnemonic is not None return IF.mnemonic.encode() -def backup_flow_slip39_basic(client: Client): - with client: +def backup_flow_slip39_basic(session: Session): + with session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) groups = shamir.decode_mnemonics(IF.mnemonics[:3]) ems = shamir.recover_ems(groups) return ems.ciphertext -def backup_flow_slip39_advanced(client: Client): - with client: +def backup_flow_slip39_advanced(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics[0:3] + IF.mnemonics[5:8] + IF.mnemonics[10:13] groups = shamir.decode_mnemonics(mnemonics) @@ -74,10 +74,13 @@ def backup_flow_slip39_advanced(client: Client): @pytest.mark.models("core") @pytest.mark.parametrize("backup_type, backup_flow", VECTORS) @pytest.mark.setup_client(uninitialized=True) -def test_skip_backup_msg(client: Client, backup_type, backup_flow): - with client: +@pytest.mark.uninitialized_session +def test_skip_backup_msg(session: Session, backup_type, backup_flow): + assert session.features.initialized is False + + with session: device.setup( - client, + session, skip_backup=True, passphrase_protection=False, pin_protection=False, @@ -86,22 +89,22 @@ def test_skip_backup_msg(client: Client, backup_type, backup_flow): _get_entropy=MOCK_GET_ENTROPY, ) - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.Required - assert client.features.unfinished_backup is False - assert client.features.no_backup is False - assert client.features.backup_type is backup_type + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.Required + assert session.features.unfinished_backup is False + assert session.features.no_backup is False + assert session.features.backup_type is backup_type - secret = backup_flow(client) + secret = backup_flow(session) - client.init_device() - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.unfinished_backup is False - assert client.features.backup_type is backup_type + session = session.client.get_session() + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.unfinished_backup is False + assert session.features.backup_type is backup_type assert secret is not None - state = client.debug.state() + state = session.client.debug.state() assert state.mnemonic_type is backup_type assert state.mnemonic_secret == secret @@ -109,12 +112,15 @@ def test_skip_backup_msg(client: Client, backup_type, backup_flow): @pytest.mark.models("core") @pytest.mark.parametrize("backup_type, backup_flow", VECTORS) @pytest.mark.setup_client(uninitialized=True) -def test_skip_backup_manual(client: Client, backup_type: BackupType, backup_flow): - with client: +@pytest.mark.uninitialized_session +def test_skip_backup_manual(session: Session, backup_type: BackupType, backup_flow): + assert session.features.initialized is False + + with session, session.client as client: IF = InputFlowResetSkipBackup(client) client.set_input_flow(IF.get()) device.setup( - client, + session, pin_protection=False, passphrase_protection=False, backup_type=backup_type, @@ -122,21 +128,21 @@ def test_skip_backup_manual(client: Client, backup_type: BackupType, backup_flow _get_entropy=MOCK_GET_ENTROPY, ) - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.Required - assert client.features.unfinished_backup is False - assert client.features.no_backup is False - assert client.features.backup_type is backup_type + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.Required + assert session.features.unfinished_backup is False + assert session.features.no_backup is False + assert session.features.backup_type is backup_type - secret = backup_flow(client) + secret = backup_flow(session) - client.init_device() - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.unfinished_backup is False - assert client.features.backup_type is backup_type + session = session.client.get_session() + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.unfinished_backup is False + assert session.features.backup_type is backup_type assert secret is not None - state = client.debug.state() + state = session.client.debug.state() assert state.mnemonic_type is backup_type assert state.mnemonic_secret == secret diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py b/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py index 803818b3752..b9989ff8520 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py @@ -18,7 +18,7 @@ from mnemonic import Mnemonic from trezorlib import messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import EXTERNAL_ENTROPY, generate_entropy @@ -28,8 +28,10 @@ @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_skip_backup(client: Client): - ret = client.call_raw( +@pytest.mark.uninitialized_session +def test_reset_device_skip_backup(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.ResetDevice( strength=STRENGTH, passphrase_protection=False, @@ -40,17 +42,17 @@ def test_reset_device_skip_backup(client: Client): ) assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - internal_entropy = client.debug.state().reset_entropy - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + internal_entropy = debug.state().reset_entropy + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) assert isinstance(ret, messages.Success) # Check if device is properly initialized - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.GetFeatures()) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.Required assert ret.unfinished_backup is False @@ -61,14 +63,14 @@ def test_reset_device_skip_backup(client: Client): expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) # start Backup workflow - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) mnemonic = [] for _ in range(STRENGTH // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + session.call_raw(messages.ButtonAck()) mnemonic = " ".join(mnemonic) @@ -78,9 +80,9 @@ def test_reset_device_skip_backup(client: Client): mnemonic = [] for _ in range(STRENGTH // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.Success) @@ -90,13 +92,15 @@ def test_reset_device_skip_backup(client: Client): assert mnemonic == expected_mnemonic # start backup again - should fail - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) assert isinstance(ret, messages.Failure) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_skip_backup_break(client: Client): - ret = client.call_raw( +@pytest.mark.uninitialized_session +def test_reset_device_skip_backup_break(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.ResetDevice( strength=STRENGTH, passphrase_protection=False, @@ -107,26 +111,26 @@ def test_reset_device_skip_backup_break(client: Client): ) assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) assert isinstance(ret, messages.Success) # Check if device is properly initialized - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.GetFeatures()) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.Required assert ret.unfinished_backup is False assert ret.no_backup is False # start Backup workflow - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) # send Initialize -> break workflow - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.Initialize()) assert isinstance(ret, messages.Features) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.NotAvailable @@ -134,11 +138,11 @@ def test_reset_device_skip_backup_break(client: Client): assert ret.no_backup is False # start backup again - should fail - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) assert isinstance(ret, messages.Failure) # read Features again - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.Initialize()) assert isinstance(ret, messages.Features) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.NotAvailable @@ -146,6 +150,6 @@ def test_reset_device_skip_backup_break(client: Client): assert ret.no_backup is False -def test_initialized_device_backup_fail(client: Client): - ret = client.call_raw(messages.BackupDevice()) +def test_initialized_device_backup_fail(session: Session): + ret = session.call_raw(messages.BackupDevice()) assert isinstance(ret, messages.Failure) diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_t1.py b/tests/device_tests/reset_recovery/test_reset_bip39_t1.py index 0c96ee4f5c8..b482a5af699 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_t1.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_t1.py @@ -18,7 +18,7 @@ from mnemonic import Mnemonic from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import EXTERNAL_ENTROPY, generate_entropy @@ -26,9 +26,10 @@ pytestmark = pytest.mark.models("legacy") -def reset_device(client: Client, strength: int): +def reset_device(session: Session, strength: int): + debug = session.client.debug # No PIN, no passphrase - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice( strength=strength, passphrase_protection=False, @@ -38,13 +39,13 @@ def reset_device(client: Client, strength: int): ) assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - internal_entropy = client.debug.state().reset_entropy - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + internal_entropy = debug.state().reset_entropy + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) # Generate mnemonic locally entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) @@ -53,9 +54,9 @@ def reset_device(client: Client, strength: int): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - client.call_raw(messages.ButtonAck()) + mnemonic.append(session.debug.read_reset_word()) + session.debug.press_yes() + session.call_raw(messages.ButtonAck()) mnemonic = " ".join(mnemonic) @@ -65,9 +66,9 @@ def reset_device(client: Client, strength: int): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - resp = client.call_raw(messages.ButtonAck()) + mnemonic.append(session.debug.read_reset_word()) + debug.press_yes() + resp = session.call_raw(messages.ButtonAck()) assert isinstance(resp, messages.Success) @@ -77,32 +78,38 @@ def reset_device(client: Client, strength: int): assert mnemonic == expected_mnemonic # Check if device is properly initialized - resp = client.call_raw(messages.Initialize()) + resp = session.call_raw(messages.GetFeatures()) assert resp.initialized is True assert resp.backup_availability == messages.BackupAvailability.NotAvailable assert resp.pin_protection is False assert resp.passphrase_protection is False # Do pin & passphrase-protected action, PassphraseRequest should NOT be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.Address) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_128(client: Client): - reset_device(client, 128) +@pytest.mark.uninitialized_session +def test_reset_device_128(session: Session): + reset_device(session, 128) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_192(client: Client): - reset_device(client, 192) +@pytest.mark.uninitialized_session +def test_reset_device_192(session: Session): + reset_device(session, 192) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_256_pin(client: Client): +@pytest.mark.uninitialized_session +def test_reset_device_256_pin(session: Session): + debug = session.client.debug strength = 256 - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice( strength=strength, passphrase_protection=True, @@ -113,24 +120,24 @@ def test_reset_device_256_pin(client: Client): # Do you want ... ? assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin("654") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("654") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time - pin_encoded = client.debug.encode_pin("654") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("654") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - internal_entropy = client.debug.state().reset_entropy - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + internal_entropy = debug.state().reset_entropy + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) # Generate mnemonic locally entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) @@ -139,9 +146,9 @@ def test_reset_device_256_pin(client: Client): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + session.call_raw(messages.ButtonAck()) mnemonic = " ".join(mnemonic) @@ -151,9 +158,9 @@ def test_reset_device_256_pin(client: Client): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - resp = client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + resp = session.call_raw(messages.ButtonAck()) assert isinstance(resp, messages.Success) @@ -163,23 +170,27 @@ def test_reset_device_256_pin(client: Client): assert mnemonic == expected_mnemonic # Check if device is properly initialized - resp = client.call_raw(messages.Initialize()) + resp = session.call_raw(messages.GetFeatures()) assert resp.initialized is True assert resp.backup_availability == messages.BackupAvailability.NotAvailable assert resp.pin_protection is True assert resp.passphrase_protection is True # Do passphrase-protected action, PassphraseRequest should be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.PassphraseRequest) - client.call_raw(messages.Cancel()) + session.call_raw(messages.Cancel()) @pytest.mark.setup_client(uninitialized=True) -def test_failed_pin(client: Client): +@pytest.mark.uninitialized_session +def test_failed_pin(session: Session): + debug = session.client.debug strength = 128 - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice( strength=strength, passphrase_protection=True, @@ -190,27 +201,27 @@ def test_failed_pin(client: Client): # Do you want ... ? assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin("1234") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("1234") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time - pin_encoded = client.debug.encode_pin("6789") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("6789") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.Failure) -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(Exception): device.setup( - client, + session, strength=128, passphrase_protection=True, pin_protection=True, diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_t2.py b/tests/device_tests/reset_recovery/test_reset_bip39_t2.py index fe627400674..42e6f30dd83 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_t2.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_t2.py @@ -19,8 +19,9 @@ from trezorlib import device, messages from trezorlib.btc import get_public_node +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...common import EXTERNAL_ENTROPY, MNEMONIC12, MOCK_GET_ENTROPY, generate_entropy @@ -33,14 +34,15 @@ pytestmark = pytest.mark.models("core") -def reset_device(client: Client, strength: int): - with client: +def reset_device(session: Session, strength: int): + debug = session.client.debug + with session.client as client: IF = InputFlowBip39ResetBackup(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -51,7 +53,7 @@ def reset_device(client: Client, strength: int): ) # generate mnemonic locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = debug.state().reset_entropy assert internal_entropy is not None entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -60,40 +62,43 @@ def reset_device(client: Client, strength: int): assert IF.mnemonic == expected_mnemonic # Check if device is properly initialized - assert client.features.initialized is True - assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable - ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Bip39 + resp = session.call_raw(messages.GetFeatures()) + assert resp.initialized is True + assert resp.backup_availability == messages.BackupAvailability.NotAvailable + assert resp.pin_protection is False + assert resp.passphrase_protection is False + assert resp.backup_type is messages.BackupType.Bip39 # backup attempt fails because backup was done in reset with pytest.raises(TrezorFailure, match="ProcessError: Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device(client: Client): - reset_device(client, 128) # 12 words +@pytest.mark.uninitialized_session +def test_reset_device(session: Session): + reset_device(session, 128) # 12 words @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_192(client: Client): - reset_device(client, 192) # 18 words +@pytest.mark.uninitialized_session +def test_reset_device_192(session: Session): + reset_device(session, 192) # 18 words @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_pin(client: Client): +@pytest.mark.uninitialized_session +def test_reset_device_pin(session: Session): + debug = session.client.debug strength = 256 # 24 words - with client: + with session.client as client: IF = InputFlowBip39ResetPIN(client) client.set_input_flow(IF.get()) # PIN, passphrase, display random device.setup( - client, + session, strength=strength, passphrase_protection=True, pin_protection=True, @@ -104,7 +109,7 @@ def test_reset_device_pin(client: Client): ) # generate mnemonic locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = debug.state().reset_entropy assert internal_entropy is not None entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -113,25 +118,25 @@ def test_reset_device_pin(client: Client): assert IF.mnemonic == expected_mnemonic # Check if device is properly initialized - assert client.features.initialized is True - assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable - ) - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True + resp = session.call_raw(messages.GetFeatures()) + assert resp.initialized is True + assert resp.backup_availability == messages.BackupAvailability.NotAvailable + assert resp.pin_protection is True + assert resp.passphrase_protection is True @pytest.mark.setup_client(uninitialized=True) -def test_reset_entropy_check(client: Client): +@pytest.mark.uninitialized_session +def test_reset_entropy_check(session: Session): strength = 128 # 12 words - with client: + with session.client as client: IF = InputFlowBip39ResetBackup(client) client.set_input_flow(IF.get()) # No PIN, no passphrase path_xpubs = device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -151,31 +156,38 @@ def test_reset_entropy_check(client: Client): assert IF.mnemonic == expected_mnemonic # Check that the device is properly initialized. - assert client.features.initialized is True - assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable - ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Bip39 + if client.protocol_version is ProtocolVersion.PROTOCOL_V1: + features = session.call_raw(messages.Initialize()) + else: + session.refresh_features() + features = session.features + + assert features.initialized is True + assert features.backup_availability == messages.BackupAvailability.NotAvailable + assert features.pin_protection is False + assert features.passphrase_protection is False + assert features.backup_type is messages.BackupType.Bip39 # Check that the XPUBs are the same as those from the entropy check. + session = session.client.get_session() for path, xpub in path_xpubs: - res = get_public_node(client, path) + res = get_public_node(session, path) assert res.xpub == xpub @pytest.mark.setup_client(uninitialized=True) -def test_reset_failed_check(client: Client): +@pytest.mark.uninitialized_session +def test_reset_failed_check(session: Session): + debug = session.client.debug strength = 256 # 24 words - with client: + with session.client as client: IF = InputFlowBip39ResetFailedCheck(client) client.set_input_flow(IF.get()) # PIN, passphrase, display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -186,7 +198,7 @@ def test_reset_failed_check(client: Client): ) # generate mnemonic locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = debug.state().reset_entropy assert internal_entropy is not None entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -195,55 +207,65 @@ def test_reset_failed_check(client: Client): assert IF.mnemonic == expected_mnemonic # Check if device is properly initialized - assert client.features.initialized is True - assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable - ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Bip39 + resp = session.call_raw(messages.GetFeatures()) + assert resp.initialized is True + assert resp.backup_availability == messages.BackupAvailability.NotAvailable + assert resp.pin_protection is False + assert resp.passphrase_protection is False + assert resp.backup_type is messages.BackupType.Bip39 @pytest.mark.setup_client(uninitialized=True) -def test_failed_pin(client: Client): +@pytest.mark.uninitialized_session +def test_failed_pin(session: Session): + debug = session.client.debug strength = 128 - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice(strength=strength, pin_protection=True, label="test") ) # Confirm Reset assert isinstance(ret, messages.ButtonRequest) - client._raw_write(messages.ButtonAck()) - client.debug.press_yes() + + # client._raw_write(messages.ButtonAck()) + # client.debug.press_yes() + + # # Enter PIN for first time + # client.debug.input("654") + # ret = client.call_raw(messages.ButtonAck()) + + debug.press_yes() # TODO test fails here on T3T1 + ret = session.call_raw(messages.ButtonAck()) # Enter PIN for first time - client.debug.input("654") - ret = client.call_raw(messages.ButtonAck()) + assert isinstance(ret, messages.ButtonRequest) + debug.input("654") + ret = session.call_raw(messages.ButtonAck()) # Re-enter PIN for TR - if client.layout_type is LayoutType.Caesar: + if session.client.layout_type is LayoutType.Caesar: assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Enter PIN for second time assert isinstance(ret, messages.ButtonRequest) - client.debug.input("456") - ret = client.call_raw(messages.ButtonAck()) + debug.input("456") + ret = session.call_raw(messages.ButtonAck()) # PIN mismatch assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.ButtonRequest) @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(Exception): device.setup( - client, + session, strength=128, passphrase_protection=True, pin_protection=True, @@ -252,10 +274,11 @@ def test_already_initialized(client: Client): @pytest.mark.setup_client(uninitialized=True) -def test_entropy_check(client: Client): - with client: - delizia = client.debug.layout_type is LayoutType.Delizia - client.set_expected_responses( +@pytest.mark.uninitialized_session +def test_entropy_check(session: Session): + with session: + delizia = session.client.debug.layout_type is LayoutType.Delizia + session.set_expected_responses( [ messages.ButtonRequest(name="setup_device"), (delizia, messages.ButtonRequest(name="confirm_setup_device")), @@ -273,11 +296,10 @@ def test_entropy_check(client: Client): messages.PublicKey, (delizia, messages.ButtonRequest(name="backup_device")), messages.Success, - messages.Features, ] ) device.setup( - client, + session, strength=128, entropy_check_count=2, backup_type=messages.BackupType.Bip39, @@ -289,21 +311,21 @@ def test_entropy_check(client: Client): @pytest.mark.setup_client(uninitialized=True) -def test_no_entropy_check(client: Client): - with client: - delizia = client.debug.layout_type is LayoutType.Delizia - client.set_expected_responses( +@pytest.mark.uninitialized_session +def test_no_entropy_check(session: Session): + with session: + delizia = session.client.debug.layout_type is LayoutType.Delizia + session.set_expected_responses( [ messages.ButtonRequest(name="setup_device"), (delizia, messages.ButtonRequest(name="confirm_setup_device")), messages.EntropyRequest, (delizia, messages.ButtonRequest(name="backup_device")), messages.Success, - messages.Features, ] ) device.setup( - client, + session, strength=128, entropy_check_count=0, backup_type=messages.BackupType.Bip39, diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py b/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py index ac24ccbcfa6..fd2d3f36049 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py @@ -17,6 +17,7 @@ import pytest from trezorlib import btc, device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import BackupType from trezorlib.tools import parse_path @@ -29,25 +30,30 @@ @pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) def test_reset_recovery(client: Client): - mnemonic = reset(client) - address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + session = client.get_seedless_session() + mnemonic = reset(session) + session = client.get_session() + address_before = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) lang = client.features.language or "en" - device.wipe(client) - set_language(client, lang[:2]) - recover(client, mnemonic) - address_after = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + device.wipe(session) + client = client.get_new_client() + session = Session(client.get_seedless_session()) + set_language(session, lang[:2]) + recover(session, mnemonic) + session = client.get_session() + address_after = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) assert address_before == address_after -def reset(client: Client, strength: int = 128, skip_backup: bool = False) -> str: - with client: +def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> str: + with session.client as client: IF = InputFlowBip39ResetBackup(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -58,24 +64,25 @@ def reset(client: Client, strength: int = 128, skip_backup: bool = False) -> str ) # Check if device is properly initialized - assert client.features.initialized is True + assert session.features.initialized is True assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False assert IF.mnemonic is not None return IF.mnemonic -def recover(client: Client, mnemonic: str): +def recover(session: Session, mnemonic: str): words = mnemonic.split(" ") - with client: + with session.client as client: IF = InputFlowBip39Recovery(client, words) client.set_input_flow(IF.get()) client.watch_layout() - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False + # Workflow successfully ended + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py index ffa9e73f772..8f95e00a1f7 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py @@ -17,6 +17,7 @@ import pytest from trezorlib import btc, device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import BackupType from trezorlib.tools import parse_path @@ -32,8 +33,10 @@ @pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) def test_reset_recovery(client: Client): - mnemonics = reset(client) - address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + session = client.get_seedless_session() + mnemonics = reset(session) + session = client.get_session() + address_before = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) # we're generating 3of5 groups 3of5 shares each test_combinations = [ mnemonics[0:3] # shares 1-3 from groups 1-3 @@ -50,25 +53,28 @@ def test_reset_recovery(client: Client): + mnemonics[22:25], ] for combination in test_combinations: + session = client.get_seedless_session() lang = client.features.language or "en" - device.wipe(client) - set_language(client, lang[:2]) - - recover(client, combination) + device.wipe(session) + client = client.get_new_client() + session = Session(client.get_seedless_session()) + set_language(session, lang[:2]) + recover(session, combination) + session = client.get_session() address_after = btc.get_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0") + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0") ) assert address_before == address_after -def reset(client: Client, strength: int = 128) -> list[str]: - with client: +def reset(session: Session, strength: int = 128) -> list[str]: + with session.client as client: IF = InputFlowSlip39AdvancedResetRecovery(client, False) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -79,23 +85,24 @@ def reset(client: Client, strength: int = 128) -> list[str]: ) # Check if device is properly initialized - assert client.features.initialized is True + assert session.features.initialized is True assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Advanced_Extendable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Advanced_Extendable return IF.mnemonics -def recover(client: Client, shares: list[str]): - with client: +def recover(session: Session, shares: list[str]): + with session.client as client: IF = InputFlowSlip39AdvancedRecovery(client, shares, False) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Advanced_Extendable + # Workflow successfully ended + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Advanced_Extendable diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py index 44baf4cff36..7753b123f34 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py @@ -20,6 +20,7 @@ import pytest from trezorlib import btc, device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import BackupType from trezorlib.tools import parse_path @@ -35,29 +36,35 @@ @pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) def test_reset_recovery(client: Client): - mnemonics = reset(client) - address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + session = client.get_seedless_session() + mnemonics = reset(session) + session = client.get_session() + address_before = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) for share_subset in itertools.combinations(mnemonics, 3): + session = client.get_seedless_session() lang = client.features.language or "en" - device.wipe(client) - set_language(client, lang[:2]) + device.wipe(session) + client = client.get_new_client() + session = Session(client.get_seedless_session()) + set_language(session, lang[:2]) selected_mnemonics = share_subset - recover(client, selected_mnemonics) + recover(session, selected_mnemonics) + session = client.get_session() address_after = btc.get_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0") + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0") ) assert address_before == address_after -def reset(client: Client, strength: int = 128) -> list[str]: - with client: +def reset(session: Session, strength: int = 128) -> list[str]: + with session.client as client: IF = InputFlowSlip39BasicResetRecovery(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -68,23 +75,24 @@ def reset(client: Client, strength: int = 128) -> list[str]: ) # Check if device is properly initialized - assert client.features.initialized is True + assert session.features.initialized is True assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable return IF.mnemonics -def recover(client: Client, shares: t.Sequence[str]): - with client: +def recover(session: Session, shares: t.Sequence[str]): + with session.client as client: IF = InputFlowSlip39BasicRecovery(client, shares) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + # Workflow successfully ended + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py b/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py index 840841d734e..2d5c9edd4ac 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py @@ -37,10 +37,10 @@ def test_reset_device_slip39_advanced(client: Client): with client: IF = InputFlowSlip39AdvancedResetRecovery(client, False) client.set_input_flow(IF.get()) - + session = client.get_seedless_session() # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -57,17 +57,17 @@ def test_reset_device_slip39_advanced(client: Client): # validate that all combinations will result in the correct master secret validate_mnemonics(IF.mnemonics, member_threshold, secret) - + session = client.get_session() # Check if device is properly initialized - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Advanced_Extendable + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Advanced_Extendable # backup attempt fails because backup was done in reset with pytest.raises(TrezorFailure, match="ProcessError: Seed already backed up"): - device.backup(client) + device.backup(session) def validate_mnemonics( diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py index b284012cbe6..dd25fc1342f 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py @@ -21,7 +21,7 @@ from trezorlib import device from trezorlib.btc import get_public_node -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import BackupAvailability, BackupType @@ -31,16 +31,16 @@ pytestmark = pytest.mark.models("core") -def reset_device(client: Client, strength: int): +def reset_device(session: Session, strength: int): member_threshold = 3 - with client: + with session.client as client: IF = InputFlowSlip39BasicResetRecovery(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -51,48 +51,51 @@ def reset_device(client: Client, strength: int): ) # generate secret locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = session.client.debug.state().reset_entropy assert internal_entropy is not None secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) # validate that all combinations will result in the correct master secret validate_mnemonics(IF.mnemonics, member_threshold, secret) - + session = session.client.get_session() # Check if device is properly initialized - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable # backup attempt fails because backup was done in reset with pytest.raises(TrezorFailure, match="ProcessError: Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_slip39_basic(client: Client): - reset_device(client, 128) +@pytest.mark.uninitialized_session +def test_reset_device_slip39_basic(session: Session): + reset_device(session, 128) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_slip39_basic_256(client: Client): - reset_device(client, 256) +@pytest.mark.uninitialized_session +def test_reset_device_slip39_basic_256(session: Session): + reset_device(session, 256) @pytest.mark.setup_client(uninitialized=True) -def test_reset_entropy_check(client: Client): +@pytest.mark.uninitialized_session +def test_reset_entropy_check(session: Session): member_threshold = 3 strength = 128 # 20 words - with client: + with session.client as client: IF = InputFlowSlip39BasicResetRecovery(client) client.set_input_flow(IF.get()) # No PIN, no passphrase. path_xpubs = device.setup( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -101,25 +104,27 @@ def test_reset_entropy_check(client: Client): entropy_check_count=3, _get_entropy=MOCK_GET_ENTROPY, ) - # Generate the master secret locally. - internal_entropy = client.debug.state().reset_entropy + internal_entropy = session.client.debug.state().reset_entropy assert internal_entropy is not None secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) # Check that all combinations will result in the correct master secret. validate_mnemonics(IF.mnemonics, member_threshold, secret) + # Create a session with cache backing + session = session.client.get_session() + # Check that the device is properly initialized. - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable # Check that the XPUBs are the same as those from the entropy check. for path, xpub in path_xpubs: - res = get_public_node(client, path) + res = get_public_node(session, path) assert res.xpub == xpub diff --git a/tests/device_tests/ripple/test_get_address.py b/tests/device_tests/ripple/test_get_address.py index 0d35b6c5b93..2a066926cd8 100644 --- a/tests/device_tests/ripple/test_get_address.py +++ b/tests/device_tests/ripple/test_get_address.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.ripple import get_address from trezorlib.tools import parse_path @@ -43,28 +43,28 @@ @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) -def test_ripple_get_address(client: Client, path: str, expected_address: str): - address = get_address(client, parse_path(path), show_display=True) +def test_ripple_get_address(session: Session, path: str, expected_address: str): + address = get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) def test_ripple_get_address_chunkify_details( - client: Client, path: str, expected_address: str + session: Session, path: str, expected_address: str ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address @pytest.mark.setup_client(mnemonic=CUSTOM_MNEMONIC) -def test_ripple_get_address_other(client: Client): +def test_ripple_get_address_other(session: Session): # data from https://github.com/you21979/node-ripple-bip32/blob/master/test/test.js - address = get_address(client, parse_path("m/44h/144h/0h/0/0")) + address = get_address(session, parse_path("m/44h/144h/0h/0/0")) assert address == "r4ocGE47gm4G4LkA9mriVHQqzpMLBTgnTY" - address = get_address(client, parse_path("m/44h/144h/0h/0/1")) + address = get_address(session, parse_path("m/44h/144h/0h/0/1")) assert address == "rUt9ULSrUvfCmke8HTFU1szbmFpWzVbBXW" diff --git a/tests/device_tests/ripple/test_sign_tx.py b/tests/device_tests/ripple/test_sign_tx.py index a03a29d4bec..82911c8abe8 100644 --- a/tests/device_tests/ripple/test_sign_tx.py +++ b/tests/device_tests/ripple/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ripple -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -29,7 +29,7 @@ @pytest.mark.parametrize("chunkify", (True, False)) -def test_ripple_sign_simple_tx(client: Client, chunkify: bool): +def test_ripple_sign_simple_tx(session: Session, chunkify: bool): msg = ripple.create_sign_tx_msg( { "TransactionType": "Payment", @@ -43,7 +43,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): } ) resp = ripple.sign_tx( - client, parse_path("m/44h/144h/0h/0/0"), msg, chunkify=chunkify + session, parse_path("m/44h/144h/0h/0/0"), msg, chunkify=chunkify ) assert ( resp.signature.hex() @@ -66,7 +66,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): } ) resp = ripple.sign_tx( - client, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify + session, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify ) assert ( resp.signature.hex() @@ -92,7 +92,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): } ) resp = ripple.sign_tx( - client, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify + session, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify ) assert ( resp.signature.hex() @@ -104,7 +104,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): ) -def test_ripple_sign_invalid_fee(client: Client): +def test_ripple_sign_invalid_fee(session: Session): msg = ripple.create_sign_tx_msg( { "TransactionType": "Payment", @@ -121,4 +121,4 @@ def test_ripple_sign_invalid_fee(client: Client): TrezorFailure, match="ProcessError: Fee must be in the range of 10 to 10,000 drops", ): - ripple.sign_tx(client, parse_path("m/44h/144h/0h/0/2"), msg) + ripple.sign_tx(session, parse_path("m/44h/144h/0h/0/2"), msg) diff --git a/tests/device_tests/solana/test_address.py b/tests/device_tests/solana/test_address.py index b3af4ea8ed0..e3f53aba875 100644 --- a/tests/device_tests/solana/test_address.py +++ b/tests/device_tests/solana/test_address.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.solana import get_address from trezorlib.tools import parse_path @@ -32,9 +32,9 @@ @parametrize_using_common_fixtures( "solana/get_address.json", ) -def test_solana_get_address(client: Client, parameters, result): +def test_solana_get_address(session: Session, parameters, result): actual_result = get_address( - client, address_n=parse_path(parameters["path"]), show_display=True + session, address_n=parse_path(parameters["path"]), show_display=True ) assert actual_result == result["expected_address"] diff --git a/tests/device_tests/solana/test_public_key.py b/tests/device_tests/solana/test_public_key.py index e12c345fc3e..4ef7924b4db 100644 --- a/tests/device_tests/solana/test_public_key.py +++ b/tests/device_tests/solana/test_public_key.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.solana import get_public_key from trezorlib.tools import parse_path @@ -32,9 +32,9 @@ @parametrize_using_common_fixtures( "solana/get_public_key.json", ) -def test_solana_get_public_key(client: Client, parameters, result): +def test_solana_get_public_key(session: Session, parameters, result): actual_result = get_public_key( - client, address_n=parse_path(parameters["path"]), show_display=True + session, address_n=parse_path(parameters["path"]), show_display=True ) assert actual_result.hex() == result["expected_public_key"] diff --git a/tests/device_tests/solana/test_sign_tx.py b/tests/device_tests/solana/test_sign_tx.py index de7ec344625..b9579053a2d 100644 --- a/tests/device_tests/solana/test_sign_tx.py +++ b/tests/device_tests/solana/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.solana import sign_tx from trezorlib.tools import parse_path @@ -42,13 +42,11 @@ "solana/sign_tx.unknown_instructions.json", "solana/sign_tx.predefined_transactions.json", ) -def test_solana_sign_tx(client: Client, parameters, result): - client.init_device(new_session=True) - +def test_solana_sign_tx(session: Session, parameters, result): serialized_tx = _serialize_tx(parameters["construct"]) actual_result = sign_tx( - client, + session, address_n=parse_path(parameters["address"]), serialized_tx=serialized_tx, additional_info=( diff --git a/tests/device_tests/stellar/test_stellar.py b/tests/device_tests/stellar/test_stellar.py index 8e214ab1135..1d5c59e1f8e 100644 --- a/tests/device_tests/stellar/test_stellar.py +++ b/tests/device_tests/stellar/test_stellar.py @@ -55,7 +55,7 @@ import pytest from trezorlib import messages, protobuf, stellar -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -87,10 +87,10 @@ def make_op(operation_data): @parametrize_using_common_fixtures("stellar/sign_tx.json") -def test_sign_tx(client: Client, parameters, result): +def test_sign_tx(session: Session, parameters, result): tx, operations = parameters_to_proto(parameters) response = stellar.sign_tx( - client, tx, operations, tx.address_n, tx.network_passphrase + session, tx, operations, tx.address_n, tx.network_passphrase ) assert response.public_key.hex() == result["public_key"] assert b64encode(response.signature).decode() == result["signature"] @@ -113,20 +113,20 @@ def test_xdr(parameters, result): @parametrize_using_common_fixtures("stellar/get_address.json") -def test_get_address(client: Client, parameters, result): +def test_get_address(session: Session, parameters, result): address_n = parse_path(parameters["path"]) - address = stellar.get_address(client, address_n, show_display=True) + address = stellar.get_address(session, address_n, show_display=True) assert address == result["address"] @pytest.mark.models("core") @parametrize_using_common_fixtures("stellar/get_address.json") -def test_get_address_chunkify_details(client: Client, parameters, result): - with client: +def test_get_address_chunkify_details(session: Session, parameters, result): + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address_n = parse_path(parameters["path"]) address = stellar.get_address( - client, address_n, show_display=True, chunkify=True + session, address_n, show_display=True, chunkify=True ) assert address == result["address"] diff --git a/tests/device_tests/test_authenticate_device.py b/tests/device_tests/test_authenticate_device.py index f2ffb5d7157..5e697b4f070 100644 --- a/tests/device_tests/test_authenticate_device.py +++ b/tests/device_tests/test_authenticate_device.py @@ -5,7 +5,7 @@ from cryptography.x509 import extensions as ext from trezorlib import device, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ..common import compact_size @@ -35,16 +35,16 @@ ), ), ) -def test_authenticate_device(client: Client, challenge: bytes) -> None: +def test_authenticate_device(session: Session, challenge: bytes) -> None: # NOTE Applications must generate a random challenge for each request. # Issue an AuthenticateDevice challenge to Trezor. - proof = device.authenticate(client, challenge) + proof = device.authenticate(session, challenge) certs = [x509.load_der_x509_certificate(cert) for cert in proof.certificates] # Verify the last certificate in the certificate chain against trust anchor. root_public_key = ec.EllipticCurvePublicKey.from_encoded_point( - ec.SECP256R1(), ROOT_PUBLIC_KEY[client.model] + ec.SECP256R1(), ROOT_PUBLIC_KEY[session.model] ) root_public_key.verify( certs[-1].signature, @@ -78,11 +78,11 @@ def test_authenticate_device(client: Client, challenge: bytes) -> None: # Verify that the common name matches the Trezor model. common_name = cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)[0] - if client.model == models.T3B1: + if session.model == models.T3B1: # XXX TODO replace as soon as we have T3B1 staging internal_model = "T2B1" else: - internal_model = client.model.internal_name + internal_model = session.model.internal_name assert common_name.value.startswith(internal_model) # Verify the signature of the challenge. diff --git a/tests/device_tests/test_autolock.py b/tests/device_tests/test_autolock.py index dc0f69a1df9..a310ff3841e 100644 --- a/tests/device_tests/test_autolock.py +++ b/tests/device_tests/test_autolock.py @@ -19,7 +19,7 @@ import pytest from trezorlib import device, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ..common import TEST_ADDRESS_N, get_test_address @@ -29,42 +29,42 @@ pytestmark = pytest.mark.setup_client(pin=PIN4) -def pin_request(client: Client): +def pin_request(session: Session): return ( messages.PinMatrixRequest - if client.model is models.T1B1 + if session.model is models.T1B1 else messages.ButtonRequest ) -def set_autolock_delay(client: Client, delay): - with client: +def set_autolock_delay(session: Session, delay): + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - pin_request(client), + pin_request(session), messages.ButtonRequest, messages.Success, - messages.Features, + # messages.Features, ] ) - device.apply_settings(client, auto_lock_delay_ms=delay) + device.apply_settings(session, auto_lock_delay_ms=delay) -def test_apply_auto_lock_delay(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_apply_auto_lock_delay(session: Session): + set_autolock_delay(session, 10 * 1000) time.sleep(0.1) # sleep less than auto-lock delay - with client: + with session: # No PIN protection is required. - client.set_expected_responses([messages.Address]) - get_test_address(client) + session.set_expected_responses([messages.Address]) + get_test_address(session) time.sleep(10.5) # sleep more than auto-lock delay - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses([pin_request(client), messages.Address]) - get_test_address(client) + session.set_expected_responses([pin_request(session), messages.Address]) + get_test_address(session) @pytest.mark.parametrize( @@ -78,44 +78,45 @@ def test_apply_auto_lock_delay(client: Client): 536870, # 149 hours, maximum ], ) -def test_apply_auto_lock_delay_valid(client: Client, seconds): - set_autolock_delay(client, seconds * 1000) - assert client.features.auto_lock_delay_ms == seconds * 1000 +def test_apply_auto_lock_delay_valid(session: Session, seconds): + set_autolock_delay(session, seconds * 1000) + assert session.features.auto_lock_delay_ms == seconds * 1000 -def test_autolock_default_value(client: Client): - assert client.features.auto_lock_delay_ms is None - with client: +def test_autolock_default_value(session: Session): + assert session.features.auto_lock_delay_ms is None + with session, session.client as client: client.use_pin_sequence([PIN4]) - device.apply_settings(client, label="pls unlock") - client.refresh_features() - assert client.features.auto_lock_delay_ms == 60 * 10 * 1000 + device.apply_settings(session, label="pls unlock") + session.refresh_features() + assert session.features.auto_lock_delay_ms == 60 * 10 * 1000 @pytest.mark.parametrize( "seconds", [0, 1, 9, 536871, 2**22], ) -def test_apply_auto_lock_delay_out_of_range(client: Client, seconds): - with client: +def test_apply_auto_lock_delay_out_of_range(session: Session, seconds): + + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - pin_request(client), + pin_request(session), messages.Failure(code=messages.FailureType.ProcessError), ] ) delay = seconds * 1000 with pytest.raises(TrezorFailure): - device.apply_settings(client, auto_lock_delay_ms=delay) + device.apply_settings(session, auto_lock_delay_ms=delay) @pytest.mark.models("core") -def test_autolock_cancels_ui(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_autolock_cancels_ui(session: Session): + set_autolock_delay(session, 10 * 1000) - resp = client.call_raw( + resp = session.call_raw( messages.GetAddress( coin_name="Testnet", address_n=TEST_ADDRESS_N, @@ -126,44 +127,47 @@ def test_autolock_cancels_ui(client: Client): assert isinstance(resp, messages.ButtonRequest) # send an ack, do not read response - client._raw_write(messages.ButtonAck()) + session._write(messages.ButtonAck()) # sleep more than auto-lock delay time.sleep(10.5) - resp = client._raw_read() + resp = session._read() assert isinstance(resp, messages.Failure) assert resp.code == messages.FailureType.ActionCancelled -def test_autolock_ignores_initialize(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_autolock_ignores_initialize(session: Session): + client = session.client + set_autolock_delay(session, 10 * 1000) - assert client.features.unlocked is True + assert session.features.unlocked is True start = time.monotonic() while time.monotonic() - start < 11: # init_device should always work even if locked - client.init_device() + client.resume_session(session) time.sleep(0.1) # after 11 seconds we are definitely locked - assert client.features.unlocked is False + session.refresh_features() + assert session.features.unlocked is False + +def test_autolock_ignores_getaddress(session: Session): -def test_autolock_ignores_getaddress(client: Client): - set_autolock_delay(client, 10 * 1000) + set_autolock_delay(session, 10 * 1000) - assert client.features.unlocked is True + assert session.features.unlocked is True start = time.monotonic() # let's continue for 8 seconds to give a little leeway to the slow CI while time.monotonic() - start < 8: - get_test_address(client) + get_test_address(session) time.sleep(0.1) # sleep 3 more seconds to wait for autolock time.sleep(3) # after 11 seconds we are definitely locked - client.refresh_features() - assert client.features.unlocked is False + session.refresh_features() + assert session.features.unlocked is False diff --git a/tests/device_tests/test_basic.py b/tests/device_tests/test_basic.py index c2d1202eb52..f6ec0965023 100644 --- a/tests/device_tests/test_basic.py +++ b/tests/device_tests/test_basic.py @@ -15,44 +15,64 @@ # If not, see . from trezorlib import device, messages, models +from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client def test_features(client: Client): - f0 = client.features - # client erases session_id from its features - f0.session_id = client.session_id - f1 = client.call(messages.Initialize(session_id=f0.session_id)) - assert f0 == f1 + session = client.get_session() + f0 = session.features + if client.protocol_version == ProtocolVersion.PROTOCOL_V1: + # session erases session_id from its features + f0.session_id = session.id + f1 = session.call(messages.Initialize(session_id=session.id)) + assert f0 == f1 + else: + session2 = client.resume_session(session) + f1: messages.Features = session2.call(messages.GetFeatures()) + assert f1.session_id is None + assert f0 == f1 -def test_capabilities(client: Client): - assert (messages.Capability.Translations in client.features.capabilities) == ( - client.model is not models.T1B1 + +def test_capabilities(session: Session): + assert (messages.Capability.Translations in session.features.capabilities) == ( + session.model is not models.T1B1 ) -def test_ping(client: Client): - ping = client.call(messages.Ping(message="ahoj!")) +def test_ping(session: Session): + ping = session.call(messages.Ping(message="ahoj!")) assert ping == messages.Success(message="ahoj!") def test_device_id_same(client: Client): - id1 = client.get_device_id() - client.init_device() - id2 = client.get_device_id() + session1 = client.get_session() + session2 = client.get_session() + id1 = session1.features.device_id + session2.refresh_features() + id2 = session2.features.device_id + client = client.get_new_client() + session3 = client.get_session() + id3 = session3.features.device_id # ID must be at least 12 characters assert len(id1) >= 12 # Every resulf of UUID must be the same assert id1 == id2 + assert id2 == id3 def test_device_id_different(client: Client): - id1 = client.get_device_id() - device.wipe(client) - id2 = client.get_device_id() + session = client.get_seedless_session() + id1 = client.features.device_id + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() + + id2 = client.features.device_id # Device ID must be fresh after every reset assert id1 != id2 diff --git a/tests/device_tests/test_bip32_speed.py b/tests/device_tests/test_bip32_speed.py index 1d184c7e4a8..84d8cf9ae59 100644 --- a/tests/device_tests/test_bip32_speed.py +++ b/tests/device_tests/test_bip32_speed.py @@ -19,7 +19,7 @@ import pytest from trezorlib import btc, device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import SafetyCheckLevel from trezorlib.tools import H_ @@ -29,47 +29,47 @@ ] -def test_public_ckd(client: Client): +def test_public_ckd(session: Session): # disable safety checks to access non-standard paths - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) - btc.get_address(client, "Bitcoin", []) # to compute root node via BIP39 + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) + btc.get_address(session, "Bitcoin", []) # to compute root node via BIP39 for depth in range(8): start = time.time() - btc.get_address(client, "Bitcoin", range(depth)) + btc.get_address(session, "Bitcoin", range(depth)) delay = time.time() - start expected = (depth + 1) * 0.26 print("DEPTH", depth, "EXPECTED DELAY", expected, "REAL DELAY", delay) assert delay <= expected -def test_private_ckd(client: Client): +def test_private_ckd(session: Session): # disable safety checks to access non-standard paths - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) - btc.get_address(client, "Bitcoin", []) # to compute root node via BIP39 + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) + btc.get_address(session, "Bitcoin", []) # to compute root node via BIP39 for depth in range(8): start = time.time() address_n = [H_(-i) for i in range(-depth, 0)] - btc.get_address(client, "Bitcoin", address_n) + btc.get_address(session, "Bitcoin", address_n) delay = time.time() - start expected = (depth + 1) * 0.26 print("DEPTH", depth, "EXPECTED DELAY", expected, "REAL DELAY", delay) assert delay <= expected -def test_cache(client: Client): +def test_cache(session: Session): # disable safety checks to access non-standard paths - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) start = time.time() for x in range(10): - btc.get_address(client, "Bitcoin", [x, 2, 3, 4, 5, 6, 7, 8]) + btc.get_address(session, "Bitcoin", [x, 2, 3, 4, 5, 6, 7, 8]) nocache_time = time.time() - start start = time.time() for x in range(10): - btc.get_address(client, "Bitcoin", [1, 2, 3, 4, 5, 6, 7, x]) + btc.get_address(session, "Bitcoin", [1, 2, 3, 4, 5, 6, 7, x]) cache_time = time.time() - start print("NOCACHE TIME", nocache_time) diff --git a/tests/device_tests/test_busy_state.py b/tests/device_tests/test_busy_state.py index 706745a1981..27fb1b23e6a 100644 --- a/tests/device_tests/test_busy_state.py +++ b/tests/device_tests/test_busy_state.py @@ -20,62 +20,66 @@ from trezorlib import btc, device from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path PIN = "1234" -def _assert_busy(client: Client, should_be_busy: bool, screen: str = "Homescreen"): - assert client.features.busy is should_be_busy - if client.layout_type is not LayoutType.T1: +def _assert_busy(session: Session, should_be_busy: bool, screen: str = "Homescreen"): + assert session.features.busy is should_be_busy + if session.client.layout_type is not LayoutType.T1: if should_be_busy: - assert "CoinJoinProgress" in client.debug.read_layout().all_components() + assert ( + "CoinJoinProgress" + in session.client.debug.read_layout().all_components() + ) else: - assert client.debug.read_layout().main_component() == screen + assert session.client.debug.read_layout().main_component() == screen @pytest.mark.setup_client(pin=PIN) -def test_busy_state(client: Client): - _assert_busy(client, False, "Lockscreen") - assert client.features.unlocked is False +def test_busy_state(session: Session): + _assert_busy(session, False, "Lockscreen") + assert session.features.unlocked is False # Show busy dialog for 1 minute. - device.set_busy(client, expiry_ms=60 * 1000) - _assert_busy(client, True) - assert client.features.unlocked is False + device.set_busy(session, expiry_ms=60 * 1000) + _assert_busy(session, True) + assert session.features.unlocked is False - with client: + with session.client as client: client.use_pin_sequence([PIN]) btc.get_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=True + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=True ) - client.refresh_features() - _assert_busy(client, True) - assert client.features.unlocked is True + session.refresh_features() + _assert_busy(session, True) + assert session.features.unlocked is True # Hide the busy dialog. - device.set_busy(client, None) + device.set_busy(session, None) - _assert_busy(client, False) - assert client.features.unlocked is True + _assert_busy(session, False) + assert session.features.unlocked is True @pytest.mark.models("core") -def test_busy_expiry_core(client: Client): +def test_busy_expiry_core(session: Session): WAIT_TIME_MS = 1500 TOLERANCE = 1000 - _assert_busy(client, False) + _assert_busy(session, False) # Start a timer start = time.monotonic() # Show the busy dialog. - device.set_busy(client, expiry_ms=WAIT_TIME_MS) - _assert_busy(client, True) + device.set_busy(session, expiry_ms=WAIT_TIME_MS) + _assert_busy(session, True) # Wait until the layout changes - client.debug.wait_layout() + time.sleep(0.1) # Improves stability of the test for devices with THP + session.client.debug.wait_layout() end = time.monotonic() # Check that the busy dialog was shown for at least WAIT_TIME_MS. @@ -84,26 +88,26 @@ def test_busy_expiry_core(client: Client): # Check that the device is no longer busy. # Also needs to come back to Homescreen (for UI tests). - client.refresh_features() - _assert_busy(client, False) + session.refresh_features() + _assert_busy(session, False) @pytest.mark.flaky(max_runs=5) @pytest.mark.models("legacy") -def test_busy_expiry_legacy(client: Client): - _assert_busy(client, False) +def test_busy_expiry_legacy(session: Session): + _assert_busy(session, False) # Show the busy dialog. - device.set_busy(client, expiry_ms=1500) - _assert_busy(client, True) + device.set_busy(session, expiry_ms=1500) + _assert_busy(session, True) # Hasn't expired yet. time.sleep(0.1) - _assert_busy(client, True) + _assert_busy(session, True) # Wait for it to expire. Add some tolerance to account for CI/hardware slowness. time.sleep(4.0) # Check that the device is no longer busy. # Also needs to come back to Homescreen (for UI tests). - client.refresh_features() - _assert_busy(client, False) + session.refresh_features() + _assert_busy(session, False) diff --git a/tests/device_tests/test_cancel.py b/tests/device_tests/test_cancel.py index b72e95a88e9..9ab7e9165e8 100644 --- a/tests/device_tests/test_cancel.py +++ b/tests/device_tests/test_cancel.py @@ -17,7 +17,7 @@ import pytest import trezorlib.messages as m -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled from ..common import TEST_ADDRESS_N @@ -35,15 +35,15 @@ ), ], ) -def test_cancel_message_via_cancel(client: Client, message): +def test_cancel_message_via_cancel(session: Session, message): def input_flow(): yield - client.cancel() + session.cancel() - with client, pytest.raises(Cancelled): - client.set_expected_responses([m.ButtonRequest(), m.Failure()]) + with session, session.client as client, pytest.raises(Cancelled): + session.set_expected_responses([m.ButtonRequest(), m.Failure()]) client.set_input_flow(input_flow) - client.call(message) + session.call(message) @pytest.mark.parametrize( @@ -58,43 +58,45 @@ def input_flow(): ), ], ) -def test_cancel_message_via_initialize(client: Client, message): - resp = client.call_raw(message) +@pytest.mark.protocol("protocol_v1") +def test_cancel_message_via_initialize(session: Session, message): + resp = session.call_raw(message) assert isinstance(resp, m.ButtonRequest) - client._raw_write(m.ButtonAck()) - client._raw_write(m.Initialize()) + session._write(m.ButtonAck()) + session._write(m.Initialize()) - resp = client._raw_read() + resp = session._read() assert isinstance(resp, m.Features) @pytest.mark.models("core") -def test_cancel_on_paginated(client: Client): +def test_cancel_on_paginated(session: Session): """Check that device is responsive on paginated screen. See #1708.""" # In #1708, the device would ignore USB (or UDP) events while waiting for the user # to page through the screen. This means that this testcase, instead of failing, # would get stuck waiting for the _raw_read result. # I'm not spending the effort to modify the testcase to cause a _failure_ if that # happens again. Just be advised that this should not get stuck. + message = m.SignMessage( message=b"hello" * 64, address_n=TEST_ADDRESS_N, coin_name="Testnet", ) - resp = client.call_raw(message) + resp = session.call_raw(message) assert isinstance(resp, m.ButtonRequest) - client._raw_write(m.ButtonAck()) - client.debug.press_yes() + session._write(m.ButtonAck()) + session.client.debug.press_yes() - resp = client._raw_read() + resp = session._read() assert isinstance(resp, m.ButtonRequest) assert resp.pages is not None - client._raw_write(m.ButtonAck()) + session._write(m.ButtonAck()) - client._raw_write(m.Cancel()) - resp = client._raw_read() + session._write(m.Cancel()) + resp = session._read() assert isinstance(resp, m.Failure) assert resp.code == m.FailureType.ActionCancelled diff --git a/tests/device_tests/test_debuglink.py b/tests/device_tests/test_debuglink.py index 747613db127..d9445fddec8 100644 --- a/tests/device_tests/test_debuglink.py +++ b/tests/device_tests/test_debuglink.py @@ -17,6 +17,8 @@ import pytest from trezorlib import debuglink, device, messages, misc +from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.tools import parse_path from trezorlib.transport import udp @@ -32,35 +34,41 @@ def test_layout(client: Client): @pytest.mark.models("legacy") @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_mnemonic(client: Client): - client.ensure_unlocked() - mnemonic = client.debug.state().mnemonic_secret +def test_mnemonic(session: Session): + session.ensure_unlocked() + mnemonic = session.client.debug.state().mnemonic_secret assert mnemonic == MNEMONIC12.encode() @pytest.mark.models("legacy") @pytest.mark.setup_client(mnemonic=MNEMONIC12, pin="1234", passphrase="") -def test_pin(client: Client): - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) +def test_pin(session: Session): + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.PinMatrixRequest) - state = client.debug.state() - assert state.pin == "1234" - assert state.matrix != "" + with session.client as client: + state = client.debug.state() + assert state.pin == "1234" + assert state.matrix != "" - pin_encoded = client.debug.encode_pin("1234") - resp = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) - assert isinstance(resp, messages.PassphraseRequest) + pin_encoded = client.debug.encode_pin("1234") + resp = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + assert isinstance(resp, messages.PassphraseRequest) - resp = client.call_raw(messages.PassphraseAck(passphrase="")) - assert isinstance(resp, messages.Address) + resp = session.call_raw(messages.PassphraseAck(passphrase="")) + assert isinstance(resp, messages.Address) @pytest.mark.models("core") -def test_softlock_instability(client: Client): +def test_softlock_instability(session: Session): + if session.protocol_version == ProtocolVersion.PROTOCOL_V2: + raise Exception("THIS NEEDS TO BE CHANGED FOR THP") + def load_device(): debuglink.load_device( - client, + session, mnemonic=MNEMONIC12, pin="1234", passphrase_protection=False, @@ -68,27 +76,29 @@ def load_device(): ) # start from a clean slate: - resp = client.debug.reseed(0) + resp = session.client.debug.reseed(0) if isinstance(resp, messages.Failure) and not isinstance( - client.transport, udp.UdpTransport + session.client.transport, udp.UdpTransport ): pytest.xfail("reseed only supported on emulator") - device.wipe(client) - entropy_after_wipe = misc.get_entropy(client, 16) + device.wipe(session) + entropy_after_wipe = misc.get_entropy(session, 16) + session.refresh_features() # configure and wipe the device load_device() - client.debug.reseed(0) - device.wipe(client) - assert misc.get_entropy(client, 16) == entropy_after_wipe + session.client.debug.reseed(0) + device.wipe(session) + assert misc.get_entropy(session, 16) == entropy_after_wipe + session.refresh_features() load_device() # the device has PIN -> lock it - client.call(messages.LockDevice()) - client.debug.reseed(0) + session.call(messages.LockDevice()) + session.client.debug.reseed(0) # wipe_device should succeed with no need to unlock - device.wipe(client) + device.wipe(session) # the device is now trying to run the lockscreen, which attempts to unlock. # If the device actually called config.unlock(), it would use additional randomness. # That is undesirable. Assert that the returned entropy is still the same. - assert misc.get_entropy(client, 16) == entropy_after_wipe + assert misc.get_entropy(session, 16) == entropy_after_wipe diff --git a/tests/device_tests/test_firmware_hash.py b/tests/device_tests/test_firmware_hash.py index 50eb063c2b3..217be1c45d9 100644 --- a/tests/device_tests/test_firmware_hash.py +++ b/tests/device_tests/test_firmware_hash.py @@ -3,7 +3,7 @@ import pytest from trezorlib import firmware, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session # size of FIRMWARE_AREA, see core/embed/models/model_*_layout.c FIRMWARE_LENGTHS = { @@ -15,35 +15,35 @@ } -def test_firmware_hash_emu(client: Client) -> None: - if client.features.fw_vendor != "EMULATOR": +def test_firmware_hash_emu(session: Session) -> None: + if session.features.fw_vendor != "EMULATOR": pytest.skip("Only for emulator") - data = b"\xff" * FIRMWARE_LENGTHS[client.model] + data = b"\xff" * FIRMWARE_LENGTHS[session.model] expected_hash = blake2s(data).digest() - hash = firmware.get_hash(client, None) + hash = firmware.get_hash(session, None) assert hash == expected_hash challenge = b"Hello Trezor" expected_hash = blake2s(data, key=challenge).digest() - hash = firmware.get_hash(client, challenge) + hash = firmware.get_hash(session, challenge) assert hash == expected_hash -def test_firmware_hash_hw(client: Client) -> None: - if client.features.fw_vendor == "EMULATOR": +def test_firmware_hash_hw(session: Session) -> None: + if session.features.fw_vendor == "EMULATOR": pytest.skip("Only for hardware") # TODO get firmware image from outside the environment, check for actual result challenge = b"Hello Trezor" - empty_data = b"\xff" * FIRMWARE_LENGTHS[client.model] + empty_data = b"\xff" * FIRMWARE_LENGTHS[session.model] empty_hash = blake2s(empty_data).digest() empty_hash_challenge = blake2s(empty_data, key=challenge).digest() - hash = firmware.get_hash(client, None) + hash = firmware.get_hash(session, None) assert hash != empty_hash - hash2 = firmware.get_hash(client, challenge) + hash2 = firmware.get_hash(session, challenge) assert hash != hash2 assert hash2 != empty_hash_challenge diff --git a/tests/device_tests/test_language.py b/tests/device_tests/test_language.py index d313608ee20..73bbb87ee3b 100644 --- a/tests/device_tests/test_language.py +++ b/tests/device_tests/test_language.py @@ -23,6 +23,7 @@ from trezorlib import debuglink, device, exceptions, messages, models from trezorlib._internal import translations +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import message_filters @@ -57,228 +58,235 @@ def get_ping_title(lang: str) -> str: @pytest.fixture -def client(client: Client) -> Iterator[Client]: - lang_before = client.features.language or "" +def session(session: Session) -> Iterator[Session]: + lang_before = session.features.language or "" try: - set_language(client, "en") - yield client + set_language(session, "en") + yield session finally: - set_language(client, lang_before[:2]) + set_language(session, lang_before[:2]) -def _check_ping_screen_texts(client: Client, title: str, right_button: str) -> None: - def ping_input_flow(client: Client, title: str, right_button: str): +def _check_ping_screen_texts(session: Session, title: str, right_button: str) -> None: + def ping_input_flow(session: Session, title: str, right_button: str): yield - layout = client.debug.read_layout() + layout = session.client.debug.read_layout() assert layout.title().upper() == title.upper() assert layout.button_contents()[-1].upper() == right_button.upper() - client.debug.press_yes() + session.client.debug.press_yes() # TT does not have a right button text (but a green OK tick) - if client.model in (models.T2T1, models.T3T1): + if session.model in (models.T2T1, models.T3T1): right_button = "-" - with client: + with session, session.client as client: client.watch_layout(True) - client.set_input_flow(ping_input_flow(client, title, right_button)) - ping = client.call(messages.Ping(message="ahoj!", button_protection=True)) + client.set_input_flow(ping_input_flow(session, title, right_button)) + ping = session.call(messages.Ping(message="ahoj!", button_protection=True)) assert ping == messages.Success(message="ahoj!") -def test_error_too_long(client: Client): - assert client.features.language == "en-US" +def test_error_too_long(session: Session): + assert session.features.language == "en-US" # Translations too long # Sending more than allowed by the flash capacity - max_length = MAX_DATA_LENGTH[client.model] - with pytest.raises(exceptions.TrezorFailure, match="Translations too long"), client: + max_length = MAX_DATA_LENGTH[session.model] + with pytest.raises( + exceptions.TrezorFailure, match="Translations too long" + ), session: bad_data = (max_length + 1) * b"a" - device.change_language(client, language_data=bad_data) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + device.change_language(session, language_data=bad_data) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_data_length(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_data_length(session: Session): + assert session.features.language == "en-US" # Invalid data length # Sending more data than advertised in the header - with pytest.raises(exceptions.TrezorFailure, match="Invalid data length"), client: - good_data = build_and_sign_blob("cs", client) + with pytest.raises(exceptions.TrezorFailure, match="Invalid data length"), session: + good_data = build_and_sign_blob("cs", session) bad_data = good_data + b"abcd" - device.change_language(client, language_data=bad_data) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + device.change_language(session, language_data=bad_data) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_header_magic(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_header_magic(session: Session): + assert session.features.language == "en-US" # Invalid header magic # Does not match the expected magic with pytest.raises( exceptions.TrezorFailure, match="Invalid translations data" - ), client: - good_data = build_and_sign_blob("cs", client) + ), session: + good_data = build_and_sign_blob("cs", session) bad_data = 4 * b"a" + good_data[4:] - device.change_language(client, language_data=bad_data) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + device.change_language(session, language_data=bad_data) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_data_hash(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_data_hash(session: Session): + assert session.features.language == "en-US" # Invalid data hash # Changing the data after their hash has been calculated with pytest.raises( exceptions.TrezorFailure, match="Translation data verification failed" - ), client: - good_data = build_and_sign_blob("cs", client) + ), session: + good_data = build_and_sign_blob("cs", session) bad_data = good_data[:-8] + 8 * b"a" device.change_language( - client, + session, language_data=bad_data, ) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_version_mismatch(client: Client): - assert client.features.language == "en-US" +def test_error_version_mismatch(session: Session): + assert session.features.language == "en-US" # Translations version mismatch # Change the version to one not matching the current device with pytest.raises( exceptions.TrezorFailure, match="Translations version mismatch" - ), client: - blob = prepare_blob("cs", client.model, (3, 5, 4, 0)) + ), session: + blob = prepare_blob("cs", session.model, (3, 5, 4, 0)) device.change_language( - client, + session, language_data=sign_blob(blob), ) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_signature(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_signature(session: Session): + assert session.features.language == "en-US" # Invalid signature # Changing the data in the signature section with pytest.raises( exceptions.TrezorFailure, match="Invalid translations data" - ), client: - blob = prepare_blob("cs", client.model, client.version) + ), session: + blob = prepare_blob("cs", session.model, session.version) blob.proof = translations.Proof( merkle_proof=[], sigmask=0b011, signature=b"a" * 64, ) device.change_language( - client, + session, language_data=blob.build(), ) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) @pytest.mark.parametrize("lang", LANGUAGES) -def test_full_language_change(client: Client, lang: str): - assert client.features.language == "en-US" - assert client.features.language_version_matches is True +def test_full_language_change(session: Session, lang: str): + assert session.features.language == "en-US" + assert session.features.language_version_matches is True # Setting selected language - set_language(client, lang) - assert client.features.language[:2] == lang - assert client.features.language_version_matches is True - _check_ping_screen_texts(client, get_ping_title(lang), get_ping_button(lang)) + set_language(session, lang) + assert session.features.language[:2] == lang + assert session.features.language_version_matches is True + _check_ping_screen_texts(session, get_ping_title(lang), get_ping_button(lang)) # Setting the default language via empty data - set_language(client, "en") - assert client.features.language == "en-US" - assert client.features.language_version_matches is True - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + set_language(session, "en") + assert session.features.language == "en-US" + assert session.features.language_version_matches is True + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) def test_language_is_removed_after_wipe(client: Client): - assert client.features.language == "en-US" + session = Session(client.get_session()) + assert session.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) # Setting cs language - set_language(client, "cs") - assert client.features.language == "cs-CZ" + set_language(session, "cs") + assert session.features.language == "cs-CZ" - _check_ping_screen_texts(client, get_ping_title("cs"), get_ping_button("cs")) + _check_ping_screen_texts(session, get_ping_title("cs"), get_ping_button("cs")) # Wipe device - device.wipe(client) - assert client.features.language == "en-US" + device.wipe(session) + client = client.get_new_client() + session = Session(client.get_seedless_session()) + assert session.features.language == "en-US" # Load it again debuglink.load_device( - client, + session, mnemonic=" ".join(["all"] * 12), pin=None, passphrase_protection=False, label="test", ) - assert client.features.language == "en-US" + assert session.features.language == "en-US" + + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) +def test_translations_renders_on_screen(session: Session): -def test_translations_renders_on_screen(client: Client): czech_data = get_lang_json("cs") # Setting some values of words__confirm key and checking that in ping screen title - assert client.features.language == "en-US" + assert session.features.language == "en-US" # Normal english - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) - + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) # Normal czech - set_language(client, "cs") - assert client.features.language == "cs-CZ" - _check_ping_screen_texts(client, get_ping_title("cs"), get_ping_button("cs")) + set_language(session, "cs") + + assert session.features.language == "cs-CZ" + _check_ping_screen_texts(session, get_ping_title("cs"), get_ping_button("cs")) # Modified czech - changed value czech_data_copy = deepcopy(czech_data) new_czech_confirm = "ABCD" czech_data_copy["translations"]["words__confirm"] = new_czech_confirm device.change_language( - client, - language_data=build_and_sign_blob(czech_data_copy, client), + session, + language_data=build_and_sign_blob(czech_data_copy, session), ) - _check_ping_screen_texts(client, new_czech_confirm, get_ping_button("cs")) + _check_ping_screen_texts(session, new_czech_confirm, get_ping_button("cs")) # Modified czech - key deleted completely, english is shown czech_data_copy = deepcopy(czech_data) del czech_data_copy["translations"]["words__confirm"] device.change_language( - client, - language_data=build_and_sign_blob(czech_data_copy, client), + session, + language_data=build_and_sign_blob(czech_data_copy, session), ) - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("cs")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("cs")) + +def test_reject_update(session: Session): -def test_reject_update(client: Client): - assert client.features.language == "en-US" + assert session.features.language == "en-US" lang = "cs" - language_data = build_and_sign_blob(lang, client) + language_data = build_and_sign_blob(lang, session) def input_flow_reject(): yield - client.debug.press_no() + session.client.debug.press_no() - with pytest.raises(exceptions.Cancelled), client: + with pytest.raises(exceptions.Cancelled), session, session.client as client: client.set_input_flow(input_flow_reject) - device.change_language(client, language_data) + device.change_language(session, language_data) - assert client.features.language == "en-US" + assert session.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) def _maybe_confirm_set_language( - client: Client, lang: str, show_display: bool | None, is_displayed: bool + session: Session, lang: str, show_display: bool | None, is_displayed: bool ) -> None: - language_data = build_and_sign_blob(lang, client) + language_data = build_and_sign_blob(lang, session) CHUNK_SIZE = 1024 @@ -289,34 +297,35 @@ def chunks(data, size): expected_responses_silent: list[Any] = [ messages.TranslationDataRequest(data_offset=off, data_length=len) for off, len in chunks(language_data, CHUNK_SIZE) - ] + [message_filters.Success(), message_filters.Features()] + ] + [message_filters.Success()] + # , message_filters.Features()] expected_responses_confirm = expected_responses_silent[:] # confirmation after first TranslationDataRequest expected_responses_confirm.insert(1, message_filters.ButtonRequest()) # success screen before Success / Features - expected_responses_confirm.insert(-2, message_filters.ButtonRequest()) + expected_responses_confirm.insert(-1, message_filters.ButtonRequest()) if is_displayed: expected_responses = expected_responses_confirm else: expected_responses = expected_responses_silent - with client: - client.set_expected_responses(expected_responses) - device.change_language(client, language_data, show_display=show_display) - assert client.features.language is not None - assert client.features.language[:2] == lang + with session: + session.set_expected_responses(expected_responses) + device.change_language(session, language_data, show_display=show_display) + assert session.features.language is not None + assert session.features.language[:2] == lang # explicitly handle the cases when expected_responses are correct for # change_language but incorrect for selected is_displayed mode (otherwise the # user would get an unhelpful generic expected_responses mismatch) - if is_displayed and client.actual_responses == expected_responses_silent: + if is_displayed and session.actual_responses == expected_responses_silent: raise AssertionError("Change should have been visible but was silent") - if not is_displayed and client.actual_responses == expected_responses_confirm: + if not is_displayed and session.actual_responses == expected_responses_confirm: raise AssertionError("Change should have been silent but was visible") # if the expected_responses do not match either, the generic error message will - # be raised by the client context manager + # be raised by the session context manager @pytest.mark.parametrize( @@ -328,61 +337,64 @@ def chunks(data, size): ], ) @pytest.mark.setup_client(uninitialized=True) -def test_silent_first_install(client: Client, show_display: bool, is_displayed: bool): - assert not client.features.initialized - _maybe_confirm_set_language(client, "cs", show_display, is_displayed) +@pytest.mark.uninitialized_session +def test_silent_first_install(session: Session, show_display: bool, is_displayed: bool): + assert not session.features.initialized + _maybe_confirm_set_language(session, "cs", show_display, is_displayed) @pytest.mark.parametrize("show_display", (True, None)) -def test_switch_from_english(client: Client, show_display: bool | None): - assert client.features.initialized - assert client.features.language == "en-US" - _maybe_confirm_set_language(client, "cs", show_display, True) +def test_switch_from_english(session: Session, show_display: bool | None): + assert session.features.initialized + assert session.features.language == "en-US" + _maybe_confirm_set_language(session, "cs", show_display, True) -def test_switch_from_english_not_silent(client: Client): - assert client.features.initialized - assert client.features.language == "en-US" +def test_switch_from_english_not_silent(session: Session): + assert session.features.initialized + assert session.features.language == "en-US" with pytest.raises( exceptions.TrezorFailure, match="Cannot change language without user prompt" ): - _maybe_confirm_set_language(client, "cs", False, False) + _maybe_confirm_set_language(session, "cs", False, False) @pytest.mark.setup_client(uninitialized=True) -def test_switch_language(client: Client): - assert not client.features.initialized - assert client.features.language == "en-US" +@pytest.mark.uninitialized_session +def test_switch_language(session: Session): + assert not session.features.initialized + assert session.features.language == "en-US" # switch to Czech silently - _maybe_confirm_set_language(client, "cs", False, False) + _maybe_confirm_set_language(session, "cs", False, False) # switch to French silently with pytest.raises( exceptions.TrezorFailure, match="Cannot change language without user prompt" ): - _maybe_confirm_set_language(client, "fr", False, False) + _maybe_confirm_set_language(session, "fr", False, False) # switch to French with display, explicitly - _maybe_confirm_set_language(client, "fr", True, True) + _maybe_confirm_set_language(session, "fr", True, True) # switch back to Czech with display, implicitly - _maybe_confirm_set_language(client, "cs", None, True) + _maybe_confirm_set_language(session, "cs", None, True) -def test_header_trailing_data(client: Client): +def test_header_trailing_data(session: Session): """Adding trailing data to _header_ section specifically must be accepted by firmware, as long as the blob is otherwise valid and signed. (this ensures forwards compatibility if we extend the header) """ - assert client.features.language == "en-US" + + assert session.features.language == "en-US" lang = "cs" - blob = prepare_blob(lang, client.model, client.version) + blob = prepare_blob(lang, session.model, session.version) blob.header_bytes += b"trailing dataa" assert len(blob.header_bytes) % 2 == 0, "Trailing data must keep the 2-alignment" language_data = sign_blob(blob) - device.change_language(client, language_data) - assert client.features.language == "cs-CZ" - _check_ping_screen_texts(client, get_ping_title(lang), get_ping_button(lang)) + device.change_language(session, language_data) + assert session.features.language == "cs-CZ" + _check_ping_screen_texts(session, get_ping_title(lang), get_ping_button(lang)) diff --git a/tests/device_tests/test_msg_applysettings.py b/tests/device_tests/test_msg_applysettings.py index 9e3161bb8be..ff5776a5467 100644 --- a/tests/device_tests/test_msg_applysettings.py +++ b/tests/device_tests/test_msg_applysettings.py @@ -19,7 +19,8 @@ import pytest from trezorlib import btc, device, exceptions, messages, misc, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ..input_flows import InputFlowConfirmAllWarnings @@ -30,7 +31,7 @@ EXPECTED_RESPONSES_NOPIN = [ messages.ButtonRequest(), messages.Success, - messages.Features, + # messages.Features, ] EXPECTED_RESPONSES_PIN_T1 = [messages.PinMatrixRequest()] + EXPECTED_RESPONSES_NOPIN EXPECTED_RESPONSES_PIN_TT = [messages.ButtonRequest()] + EXPECTED_RESPONSES_NOPIN @@ -38,7 +39,7 @@ EXPECTED_RESPONSES_EXPERIMENTAL_FEATURES = [ messages.ButtonRequest, messages.Success, - messages.Features, + # messages.Features, ] PIN4 = "1234" @@ -50,173 +51,178 @@ TR_HOMESCREEN = b"TOIG\x80\x00@\x00\x0c\x04\x00\x00\xa5RY\x96\xdc0\x08\xe4\x06\xdc\xff\x96\xdc\x80\xa8\x16\x90z\xd2y\xf9\x18{\xc0\xf1\xe5\xc9y\x0f\x95\x7f;C\xfe\xd0\xe1K\xefS\x96o\xf9\xb739\x1a\n\xc7\xde\x89\xff\x11\xd8=\xd5\xcf\xb1\x9f\xf7U\xf2\xa3spx\xb0&t\xe4\xaf3x\xcaT\xec\xe50k\xb4\xe8\nl\x16\xbf`'\xf3\xa7Z\x8d-\x98h\x1c\x03\x07\xf0\xcf\xf0\x8aD\x13\xec\x1f@y\x9e\xd8\xa3\xc6\x84F*\x1dx\x02U\x00\x10\xd3\x8cF\xbb\x97y\x18J\xa5T\x18x\x1c\x02\xc6\x90\xfd\xdc\x89\x1a\x94\xb3\xeb\x01\xdc\x9f2\x8c/\xe9/\x8c$\xc6\x9c\x1e\xf8C\x8f@\x17Q\x1d\x11F\x02g\xe4A \xebO\xad\xc6\xe3F\xa7\x8b\xf830R\x82\x0b\x8e\x16\x1dL,\x14\xce\x057tht^\xfe\x00\x9e\x86\xc2\x86\xa3b~^Bl\x18\x1f\xb9+w\x11\x14\xceO\xe9\xb6W\xd8\x85\xbeX\x17\xc2\x13,M`y\xd1~\xa3/\xcd0\xed6\xda\xf5b\x15\xb5\x18\x0f_\xf6\xe2\xdc\x8d\x8ez\xdd\xd5\r^O\x9e\xb6|\xc4e\x0f\x1f\xff0k\xd4\xb8\n\x12{\x8d\x8a>\x0b5\xa2o\xf2jZ\xe5\xee\xdc\x14\xd1\xbd\xd5\xad\x95\xbe\x8c\t\x8f\xb9\xde\xc4\xa551,#`\x94'\x1b\xe7\xd53u\x8fq\xbd4v>3\x8f\xcc\x1d\xbcV>\x90^\xb3L\xc3\xde0]\x05\xec\x83\xd0\x07\xd2(\xbb\xcf+\xd0\xc7ru\xecn\x14k-\xc0|\xd2\x0e\xe8\xe08\xa8<\xdaQ+{\xad\x01\x02#\x16\x12+\xc8\xe0P\x06\xedD7\xae\xd0\xa4\x97\x84\xe32\xca;]\xd04x:\x94`\xbe\xca\x89\xe2\xcb\xc5L\x03\xac|\xe7\xd5\x1f\xe3\x08_\xee!\x04\xd2\xef\x00\xd8\xea\x91p)\xed^#\xb1\xa78eJ\x00F*\xc7\xf1\x0c\x1a\x04\xf5l\xcc\xfc\xa4\x83,c\x1e\xb1>\xc5q\x8b\xe6Y9\xc7\x07\xfa\xcf\xf9\x15\x8a\xdd\x11\x1f\x98\x82\xbe>\xbe+u#g]aC\\\x1bC\xb1\xe8P\xce2\xd6\xb6r\x12\x1c*\xd3\x92\x9d9\xf9cB\x82\xf9S.\xc2B\xe7\x9d\xcf\xdb\xf3\xfd#\xfd\x94x9p\x8d%\x14\xa5\xb3\xe9p5\xa1;~4:\xcd\xe0&\x11\x1d\xe9\xf6\xa1\x1fw\xf54\x95eWx\xda\xd0u\x91\x86\xb8\xbc\xdf\xdc\x008f\x15\xc6\xf6\x7f\xf0T\xb8\xc1\xa3\xc5_A\xc0G\x930\xe7\xdc=\xd5\xa7\xc1\xbcI\x16\xb8s\x9c&\xaa\x06\xc1}\x8b\x19\x9d'c\xc3\xe3^\xc3m\xb6n\xb0(\x16\xf6\xdeg\xb3\x96:i\xe5\x9c\x02\x93\x9fF\x9f-\xa7\"w\xf3X\x9f\x87\x08\x84\"v,\xab!9:. from trezorlib import messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_ping(client: Client): - with client: - client.set_expected_responses([messages.Success]) - res = client.ping("random data") - assert res == "random data" +def test_ping(session: Session): + with session: + session.set_expected_responses([messages.Success]) + res = session.call(messages.Ping(message="random data")) + assert res.message == "random data" - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), messages.Success, ] ) - res = client.ping("random data", button_protection=True) - assert res == "random data" + res = session.call( + messages.Ping(message="random data 2", button_protection=True) + ) + assert res.message == "random data 2" diff --git a/tests/device_tests/test_msg_sd_protect.py b/tests/device_tests/test_msg_sd_protect.py index fb305613825..7c509d95ff2 100644 --- a/tests/device_tests/test_msg_sd_protect.py +++ b/tests/device_tests/test_msg_sd_protect.py @@ -17,6 +17,7 @@ import pytest from trezorlib import debuglink, device +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.messages import SdProtectOperationType as Op @@ -26,64 +27,71 @@ pytestmark = [pytest.mark.models("core", skip="safe3"), pytest.mark.sd_card] -def test_enable_disable(client: Client): - assert client.features.sd_protection is False +def test_enable_disable(session: Session): + assert session.features.sd_protection is False # Disabling SD protection should fail with pytest.raises(TrezorFailure): - device.sd_protect(client, Op.DISABLE) + device.sd_protect(session, Op.DISABLE) # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Enabling SD protection should fail with pytest.raises(TrezorFailure): - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Disable SD protection - device.sd_protect(client, Op.DISABLE) - assert client.features.sd_protection is False + device.sd_protect(session, Op.DISABLE) + assert session.features.sd_protection is False -def test_refresh(client: Client): - assert client.features.sd_protection is False +def test_refresh(session: Session): + assert session.features.sd_protection is False # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Refresh SD protection - device.sd_protect(client, Op.REFRESH) - assert client.features.sd_protection is True + device.sd_protect(session, Op.REFRESH) + assert session.features.sd_protection is True # Disable SD protection - device.sd_protect(client, Op.DISABLE) - assert client.features.sd_protection is False + device.sd_protect(session, Op.DISABLE) + assert session.features.sd_protection is False # Refreshing SD protection should fail with pytest.raises(TrezorFailure): - device.sd_protect(client, Op.REFRESH) - assert client.features.sd_protection is False + device.sd_protect(session, Op.REFRESH) + assert session.features.sd_protection is False def test_wipe(client: Client): + session = client.get_seedless_session() # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Wipe device (this wipes internal storage) - device.wipe(client) - assert client.features.sd_protection is False + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() + assert session.features.sd_protection is False # Restore device to working status debuglink.load_device( - client, mnemonic=MNEMONIC12, pin=None, passphrase_protection=False, label="test" + session, + mnemonic=MNEMONIC12, + pin=None, + passphrase_protection=False, + label="test", ) - assert client.features.sd_protection is False + assert session.features.sd_protection is False # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Refresh SD protection - device.sd_protect(client, Op.REFRESH) + device.sd_protect(session, Op.REFRESH) diff --git a/tests/device_tests/test_msg_show_device_tutorial.py b/tests/device_tests/test_msg_show_device_tutorial.py index 52904c50c50..f6a083879f3 100644 --- a/tests/device_tests/test_msg_show_device_tutorial.py +++ b/tests/device_tests/test_msg_show_device_tutorial.py @@ -17,11 +17,11 @@ import pytest from trezorlib import device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models("safe") -def test_tutorial(client: Client): - device.show_device_tutorial(client) - assert client.features.initialized is False +def test_tutorial(session: Session): + device.show_device_tutorial(session) + assert session.features.initialized is False diff --git a/tests/device_tests/test_msg_wipedevice.py b/tests/device_tests/test_msg_wipedevice.py index 6009dd624d3..c758f72e769 100644 --- a/tests/device_tests/test_msg_wipedevice.py +++ b/tests/device_tests/test_msg_wipedevice.py @@ -19,6 +19,7 @@ import pytest from trezorlib import device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from ..common import get_test_address @@ -31,31 +32,35 @@ def test_wipe_device(client: Client): assert client.features.initialized is True assert client.features.label == "test" assert client.features.passphrase_protection is True - device_id = client.get_device_id() - - device.wipe(client) + device_id = client.features.device_id + device.wipe(client.get_session()) + client = client.get_new_client() assert client.features.initialized is False assert client.features.label is None assert client.features.passphrase_protection is False - assert client.get_device_id() != device_id + assert client.features.device_id != device_id @pytest.mark.setup_client(pin=PIN4) -def test_autolock_not_retained(client: Client): +def test_autolock_not_retained(session: Session): + client = session.client with client: client.use_pin_sequence([PIN4]) - device.apply_settings(client, auto_lock_delay_ms=10_000) + device.apply_settings(session, auto_lock_delay_ms=10_000) + + assert session.features.auto_lock_delay_ms == 10_000 - assert client.features.auto_lock_delay_ms == 10_000 + device.wipe(session) + client = client.get_new_client() + session = client.get_seedless_session() - device.wipe(client) assert client.features.auto_lock_delay_ms > 10_000 with client: client.use_pin_sequence([PIN4, PIN4]) device.setup( - client, + session, skip_backup=True, pin_protection=True, passphrase_protection=False, @@ -64,7 +69,9 @@ def test_autolock_not_retained(client: Client): ) time.sleep(10.5) - with client: + session = Session(client.get_session()) + + with session, client: # after sleeping for the pre-wipe autolock amount, Trezor must still be unlocked - client.set_expected_responses([messages.Address]) - get_test_address(client) + session.set_expected_responses([messages.Address]) + get_test_address(session) diff --git a/tests/device_tests/test_passphrase_slip39_advanced.py b/tests/device_tests/test_passphrase_slip39_advanced.py index 64ef1f5e577..89a68fb1de2 100644 --- a/tests/device_tests/test_passphrase_slip39_advanced.py +++ b/tests/device_tests/test_passphrase_slip39_advanced.py @@ -34,14 +34,14 @@ def test_128bit_passphrase(client: Client): xprv9s21ZrQH143K3dzDLfeY3cMp23u5vDeFYftu5RPYZPucKc99mNEddU4w99GxdgUGcSfMpVDxhnR1XpJzZNXRN1m6xNgnzFS5MwMP6QyBRKV """ assert client.features.passphrase_protection is True - client.use_passphrase("TREZOR") - address = get_test_address(client) + session = client.get_session(passphrase="TREZOR") + address = get_test_address(session) assert address == "mkKDUMRR1CcK8eLAzCZAjKnNbCquPoWPxN" - client.clear_session() - client.use_passphrase("ROZERT") - address_compare = get_test_address(client) + session = client.get_session(passphrase="ROZERT") + address_compare = get_test_address(session) assert address != address_compare + assert address_compare == "n1HeeeojjHgQnG6Bf5VWkM1gcpQkkXqSGw" @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_33, passphrase=True) @@ -53,11 +53,10 @@ def test_256bit_passphrase(client: Client): xprv9s21ZrQH143K2UspC9FRPfQC9NcDB4HPkx1XG9UEtuceYtpcCZ6ypNZWdgfxQ9dAFVeD1F4Zg4roY7nZm2LB7THPD6kaCege3M7EuS8v85c """ assert client.features.passphrase_protection is True - client.use_passphrase("TREZOR") - address = get_test_address(client) + session = client.get_session(passphrase="TREZOR") + address = get_test_address(session) assert address == "mxVtGxUJ898WLzPMmy6PT1FDHD1GUCWGm7" - client.clear_session() - client.use_passphrase("ROZERT") - address_compare = get_test_address(client) + session = client.get_session(passphrase="ROZERT") + address_compare = get_test_address(session) assert address != address_compare diff --git a/tests/device_tests/test_passphrase_slip39_basic.py b/tests/device_tests/test_passphrase_slip39_basic.py index de0e7a734b2..120a6f556ec 100644 --- a/tests/device_tests/test_passphrase_slip39_basic.py +++ b/tests/device_tests/test_passphrase_slip39_basic.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ..common import ( MNEMONIC_SLIP39_BASIC_20_3of6, @@ -28,14 +28,14 @@ @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6, passphrase="TREZOR") -def test_3of6_passphrase(client: Client): +def test_3of6_passphrase(session: Session): """ BIP32 Root Key for passphrase TREZOR: provided by Andrew, address calculated via https://iancoleman.io/bip39/ xprv9s21ZrQH143K2pMWi8jrTawHaj16uKk4CSbvo4Zt61tcrmuUDMx2o1Byzcr3saXNGNvHP8zZgXVdJHsXVdzYFPavxvCyaGyGr1WkAYG83ce """ - assert client.features.passphrase_protection is True - address = get_test_address(client) + assert session.features.passphrase_protection is True + address = get_test_address(session) assert address == "mi4HXfRJAqCDyEdet5veunBvXLTKSxpuim" @@ -46,25 +46,25 @@ def test_3of6_passphrase(client: Client): ), passphrase="TREZOR", ) -def test_2of5_passphrase(client: Client): +def test_2of5_passphrase(session: Session): """ BIP32 Root Key for passphrase TREZOR: provided by Andrew, address calculated via https://iancoleman.io/bip39/ xprv9s21ZrQH143K2o6EXEHpVy8TCYoMmkBnDCCESLdR2ieKwmcNG48ck2XJQY4waS7RUQcXqR9N7HnQbUVEDMWYyREdF1idQqxFHuCfK7fqFni """ - assert client.features.passphrase_protection is True - address = get_test_address(client) + assert session.features.passphrase_protection is True + address = get_test_address(session) assert address == "mjXH4pN7TtbHp3tWLqVKktKuaQeByHMoBZ" @pytest.mark.setup_client( mnemonic=MNEMONIC_SLIP39_BASIC_EXT_20_2of3, passphrase="TREZOR" ) -def test_2of3_ext_passphrase(client: Client): +def test_2of3_ext_passphrase(session: Session): """ BIP32 Root Key for passphrase TREZOR: xprv9s21ZrQH143K4FS1qQdXYAFVAHiSAnjj21YAKGh2CqUPJ2yQhMmYGT4e5a2tyGLiVsRgTEvajXkxhg92zJ8zmWZas9LguQWz7WZShfJg6RS """ - assert client.features.passphrase_protection is True - address = get_test_address(client) + assert session.features.passphrase_protection is True + address = get_test_address(session) assert address == "moELJhDbGK41k6J2ePYh2U8uc5qskC663C" diff --git a/tests/device_tests/test_pin.py b/tests/device_tests/test_pin.py index ee58790c046..c911dfee503 100644 --- a/tests/device_tests/test_pin.py +++ b/tests/device_tests/test_pin.py @@ -19,7 +19,7 @@ import pytest from trezorlib import messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import PinException from ..common import check_pin_backoff_time, get_test_address @@ -32,18 +32,18 @@ @pytest.mark.setup_client(pin=None) -def test_no_protection(client: Client): - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) +def test_no_protection(session: Session): + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) -def test_correct_pin(client: Client): - with client: +def test_correct_pin(session: Session): + with session, session.client as client: client.use_pin_sequence([PIN4]) # Expected responses differ between T1 and TT - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ (is_t1, messages.PinMatrixRequest), ( @@ -53,45 +53,44 @@ def test_correct_pin(client: Client): messages.Address, ] ) - # client.set_expected_responses([messages.ButtonRequest, messages.Address]) - get_test_address(client) + get_test_address(session) @pytest.mark.models("legacy") -def test_incorrect_pin_t1(client: Client): +def test_incorrect_pin_t1(session: Session): with pytest.raises(PinException): - client.use_pin_sequence([BAD_PIN]) - get_test_address(client) + session.client.use_pin_sequence([BAD_PIN]) + get_test_address(session) @pytest.mark.models("core") -def test_incorrect_pin_t2(client: Client): - with client: +def test_incorrect_pin_t2(session: Session): + with session, session.client as client: # After first incorrect attempt, TT will not raise an error, but instead ask for another attempt client.use_pin_sequence([BAD_PIN, PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), messages.Address, ] ) - get_test_address(client) + get_test_address(session) @pytest.mark.models("legacy") -def test_exponential_backoff_t1(client: Client): +def test_exponential_backoff_t1(session: Session): for attempt in range(3): start = time.time() - with client, pytest.raises(PinException): + with session, session.client as client, pytest.raises(PinException): client.use_pin_sequence([BAD_PIN]) - get_test_address(client) + get_test_address(session) check_pin_backoff_time(attempt, start) @pytest.mark.models("core") -def test_exponential_backoff_t2(client: Client): - with client: +def test_exponential_backoff_t2(session: Session): + with session.client as client: IF = InputFlowPINBackoff(client, BAD_PIN, PIN4) client.set_input_flow(IF.get()) - get_test_address(client) + get_test_address(session) diff --git a/tests/device_tests/test_protection_levels.py b/tests/device_tests/test_protection_levels.py index be2a3a81e0d..16190728a31 100644 --- a/tests/device_tests/test_protection_levels.py +++ b/tests/device_tests/test_protection_levels.py @@ -17,8 +17,9 @@ import pytest from trezorlib import btc, device, messages, misc, models +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -43,186 +44,226 @@ pytestmark = pytest.mark.setup_client(pin=PIN4, passphrase=True) -def _pin_request(client: Client): +def _pin_request(session: Session): """Get appropriate PIN request for each model""" - if client.model is models.T1B1: + if session.model is models.T1B1: return messages.PinMatrixRequest else: return messages.ButtonRequest(code=B.PinEntry) def _assert_protection( - client: Client, pin: bool = True, passphrase: bool = True -) -> None: + session: Session, pin: bool = True, passphrase: bool = True +) -> Session: """Make sure PIN and passphrase protection have expected values""" - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.ensure_unlocked() + session.ensure_unlocked() + client.refresh_features() assert client.features.pin_protection is pin assert client.features.passphrase_protection is passphrase - client.clear_session() + if session.protocol_version == ProtocolVersion.PROTOCOL_V2: + new_session = session.client.get_session() + session.lock() + # session.end() + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + new_session = session.client.get_session() + return Session(new_session) -def test_initialize(client: Client): - _assert_protection(client) - with client: - client.set_expected_responses([messages.Features]) - client.init_device() +def test_initialize(session: Session): + if session.protocol_version == ProtocolVersion.PROTOCOL_V2: + # Test is skipped for THP + return + + with session, session.client as client: + client.use_pin_sequence([PIN4]) + session.ensure_unlocked() + session = _assert_protection(session) + with session: + session.set_expected_responses([messages.Features]) + session.call(messages.Initialize(session_id=session.id)) @pytest.mark.models("core") @pytest.mark.setup_client(pin=PIN4) @pytest.mark.parametrize("passphrase", (True, False)) -def test_passphrase_reporting(client: Client, passphrase): +def test_passphrase_reporting(session: Session, passphrase): """On TT, passphrase_protection is a private setting, so a locked device should report passphrase_protection=None. """ - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - device.apply_settings(client, use_passphrase=passphrase) + device.apply_settings(session, use_passphrase=passphrase) - client.lock() + session.lock() # on a locked device, passphrase_protection should be None - assert client.features.unlocked is False - assert client.features.passphrase_protection is None + assert session.features.unlocked is False + assert session.features.passphrase_protection is None # on an unlocked device, protection should be reported accurately - _assert_protection(client, pin=True, passphrase=passphrase) + session = _assert_protection(session, pin=True, passphrase=passphrase) # after re-locking, the setting should be hidden again - client.lock() - assert client.features.unlocked is False - assert client.features.passphrase_protection is None + session.lock() + assert session.features.unlocked is False + assert session.features.passphrase_protection is None -def test_apply_settings(client: Client): - _assert_protection(client) - with client: +def test_apply_settings(session: Session): + session = _assert_protection(session) + + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.ButtonRequest, messages.Success, - messages.Features, + # messages.Features, ] - ) # TrezorClient reinitializes device - device.apply_settings(client, label="nazdar") + ) + device.apply_settings(session, label="nazdar") @pytest.mark.models("legacy") -def test_change_pin_t1(client: Client): - _assert_protection(client) - with client: +def test_change_pin_t1(session: Session): + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4, PIN4, PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ messages.ButtonRequest, - _pin_request(client), - _pin_request(client), - _pin_request(client), + _pin_request(session), + _pin_request(session), + _pin_request(session), messages.Success, messages.Features, ] ) - device.change_pin(client) + device.change_pin(session) @pytest.mark.models("core") -def test_change_pin_t2(client: Client): - _assert_protection(client) - with client: +def test_change_pin_t2(session: Session): + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4, PIN4, PIN4, PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.ButtonRequest, - _pin_request(client), - _pin_request(client), - (client.layout_type is LayoutType.Caesar, messages.ButtonRequest), - _pin_request(client), + _pin_request(session), + _pin_request(session), + ( + session.client.layout_type is LayoutType.Caesar, + messages.ButtonRequest, + ), + _pin_request(session), messages.ButtonRequest, messages.Success, - messages.Features, + # messages.Features, ] ) - device.change_pin(client) + device.change_pin(session) @pytest.mark.setup_client(pin=None, passphrase=False) -def test_ping(client: Client): - _assert_protection(client, pin=False, passphrase=False) - with client: - client.set_expected_responses([messages.ButtonRequest, messages.Success]) - client.ping("msg", True) +def test_ping(session: Session): + session = _assert_protection(session, pin=False, passphrase=False) + with session: + session.set_expected_responses([messages.ButtonRequest, messages.Success]) + session.call(messages.Ping(message="msg", button_protection=True)) -def test_get_entropy(client: Client): - _assert_protection(client) - with client: +def test_get_entropy(session: Session): + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.ButtonRequest(code=B.ProtectCall), messages.Entropy, ] ) - misc.get_entropy(client, 10) + misc.get_entropy(session, 10) + +def test_get_public_key(session: Session): + session = _assert_protection(session) -def test_get_public_key(client: Client): - _assert_protection(client) - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( - [ - _pin_request(client), - messages.PassphraseRequest, - messages.PublicKey, - ] - ) - btc.get_public_node(client, []) + expected_responses = [_pin_request(session)] + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + expected_responses.append(messages.PassphraseRequest) + expected_responses.append(messages.PublicKey) -def test_get_address(client: Client): - _assert_protection(client) - with client: - client.use_pin_sequence([PIN4]) - client.set_expected_responses( - [ - _pin_request(client), - messages.PassphraseRequest, - messages.Address, - ] - ) - get_test_address(client) + session.set_expected_responses(expected_responses) + btc.get_public_node(session, []) -def test_wipe_device(client: Client): - _assert_protection(client) - with client: - client.set_expected_responses( - [messages.ButtonRequest, messages.Success, messages.Features] - ) - device.wipe(client) +def test_get_address(session: Session): + session = _assert_protection(session) + + with session, session.client as client: + client.use_pin_sequence([PIN4]) + expected_responses = [_pin_request(session)] + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + expected_responses.append(messages.PassphraseRequest) + expected_responses.append(messages.Address) + + session.set_expected_responses(expected_responses) + + get_test_address(session) + + +def test_wipe_device(session: Session): + # Precise cause of crash is not determined, it happens with some order of + # tests, but not with all. The following leads to crash: + # pytest --random-order-seed=675848 tests/device_tests/test_protection_levels.py + # + # Traceback (most recent call last): + # File "trezor/wire/__init__.py", line 70, in handle_session + # File "trezor/wire/thp_main.py", line 79, in thp_main_loop + # File "trezor/wire/thp_main.py", line 145, in _handle_allocated + # File "trezor/wire/thp/received_message_handler.py", line 123, in handle_received_message + # File "trezor/wire/thp/received_message_handler.py", line 231, in _handle_state_TH1 + # File "trezor/wire/thp/crypto.py", line 93, in handle_th1_crypto + # File "trezor/wire/thp/crypto.py", line 178, in _derive_static_key_pair + # File "storage/device.py", line 364, in get_device_secret + # File "storage/common.py", line 21, in set + # RuntimeError: Could not save value + + session = _assert_protection(session) + with session: + session.set_expected_responses([messages.ButtonRequest, messages.Success]) + device.wipe(session) + client = session.client.get_new_client() + session = Session(client.get_seedless_session()) + with session, session.client as client: + client.use_pin_sequence([PIN4]) + session.set_expected_responses([messages.Features]) + session.call(messages.GetFeatures()) @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models("legacy") -def test_reset_device(client: Client): - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - with client: - client.set_expected_responses( +def test_reset_device(session: Session): + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + with session: + session.set_expected_responses( [messages.ButtonRequest] + [messages.EntropyRequest] + [messages.ButtonRequest] * 24 + [messages.Success, messages.Features] ) device.setup( - client, + session, strength=128, passphrase_protection=True, pin_protection=False, @@ -230,11 +271,12 @@ def test_reset_device(client: Client): entropy_check_count=0, _get_entropy=MOCK_GET_ENTROPY, ) + session.call(messages.GetFeatures()) with pytest.raises(TrezorFailure): # This must fail, because device is already initialized # Using direct call because `device.setup` has its own check - client.call( + session.call( messages.ResetDevice( strength=128, passphrase_protection=True, @@ -246,30 +288,30 @@ def test_reset_device(client: Client): @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models("legacy") -def test_recovery_device(client: Client): - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - client.use_mnemonic(MNEMONIC12) - with client: - client.set_expected_responses( +def test_recovery_device(session: Session): + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + session.client.use_mnemonic(MNEMONIC12) + with session: + session.set_expected_responses( [messages.ButtonRequest] + [messages.WordRequest] * 24 - + [messages.Success, messages.Features] + + [messages.Success] # , messages.Features] ) device.recover( - client, + session, 12, False, False, "label", - input_callback=client.mnemonic_callback, + input_callback=session.client.mnemonic_callback, ) with pytest.raises(TrezorFailure): # This must fail, because device is already initialized # Using direct call because `device.recover` has its own check - client.call( + session.call( messages.RecoveryDevice( word_count=12, passphrase_protection=False, @@ -279,29 +321,37 @@ def test_recovery_device(client: Client): ) -def test_sign_message(client: Client): - _assert_protection(client) - with client: +def test_sign_message(session: Session): + session = _assert_protection(session) + + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + + expected_responses = [_pin_request(session)] + + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + expected_responses.append(messages.PassphraseRequest) + + expected_responses.extend( [ - _pin_request(client), - messages.PassphraseRequest, messages.ButtonRequest, messages.ButtonRequest, messages.MessageSignature, ] ) + + session.set_expected_responses(expected_responses) + btc.sign_message( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), "testing message" + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), "testing message" ) @pytest.mark.models("legacy") -def test_verify_message_t1(client: Client): - _assert_protection(client) - with client: - client.set_expected_responses( +def test_verify_message_t1(session: Session): + session = _assert_protection(session) + with session: + session.set_expected_responses( [ messages.ButtonRequest, messages.ButtonRequest, @@ -310,7 +360,7 @@ def test_verify_message_t1(client: Client): ] ) btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -321,13 +371,13 @@ def test_verify_message_t1(client: Client): @pytest.mark.models("core") -def test_verify_message_t2(client: Client): - _assert_protection(client) - with client: +def test_verify_message_t2(session: Session): + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.ButtonRequest, messages.ButtonRequest, messages.ButtonRequest, @@ -335,7 +385,7 @@ def test_verify_message_t2(client: Client): ] ) btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -345,7 +395,7 @@ def test_verify_message_t2(client: Client): ) -def test_signtx(client: Client): +def test_signtx(session: Session): # input tx: 50f6f1209ca92d7359564be803cb2c932cde7d370f7cee50fd1fad6790f6206d inp1 = messages.TxInputType( @@ -361,17 +411,18 @@ def test_signtx(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - _assert_protection(client) - with client: + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + expected_responses = [_pin_request(session)] + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + expected_responses.append(messages.PassphraseRequest) + expected_responses.extend( [ - _pin_request(client), - messages.PassphraseRequest, request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_50f6f1), @@ -384,7 +435,9 @@ def test_signtx(client: Client): request_finished(), ] ) - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TxCache("Bitcoin")) + session.set_expected_responses(expected_responses) + + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TxCache("Bitcoin")) # def test_firmware_erase(): @@ -395,29 +448,37 @@ def test_signtx(client: Client): @pytest.mark.setup_client(pin=PIN4, passphrase=False) -def test_unlocked(client: Client): - assert client.features.unlocked is False +def test_unlocked(session: Session): + assert session.features.unlocked is False + + session = _assert_protection(session, passphrase=False) - _assert_protection(client, passphrase=False) - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses([_pin_request(client), messages.Address]) - get_test_address(client) + session.set_expected_responses([_pin_request(session), messages.Address]) + get_test_address(session) - client.init_device() - assert client.features.unlocked is True - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) + session.refresh_features() + assert session.features.unlocked is True + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) @pytest.mark.setup_client(pin=None, passphrase=True) -def test_passphrase_cached(client: Client): - _assert_protection(client, pin=False) - with client: - client.set_expected_responses([messages.PassphraseRequest, messages.Address]) - get_test_address(client) - - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) +def test_passphrase_cached(session: Session): + session = _assert_protection(session, pin=False) + with session: + if session.protocol_version == 1: + session.set_expected_responses( + [messages.PassphraseRequest, messages.Address] + ) + elif session.protocol_version == 2: + session.set_expected_responses([messages.Address]) + else: + raise Exception("Unknown session type") + get_test_address(session) + + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) diff --git a/tests/device_tests/test_repeated_backup.py b/tests/device_tests/test_repeated_backup.py index 9fc25ad202d..601c898fbbd 100644 --- a/tests/device_tests/test_repeated_backup.py +++ b/tests/device_tests/test_repeated_backup.py @@ -17,8 +17,8 @@ import pytest -from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib import device, exceptions, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from .. import translations as TR @@ -33,187 +33,191 @@ @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) -def test_repeated_backup(client: Client): - assert client.features.backup_availability == messages.BackupAvailability.Required - assert client.features.recovery_status == messages.RecoveryStatus.Nothing +def test_repeated_backup(session: Session): + assert session.features.backup_availability == messages.BackupAvailability.Required + assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics assert len(mnemonics) == 5 # cannot backup, since we already just did that! assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup # we can now perform another backup - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False, repeated=True) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) # the backup feature is locked again... assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_SINGLE_EXT_20) -def test_repeated_backup_upgrade_single(client: Client): +def test_repeated_backup_upgrade_single(session: Session): assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing - assert client.features.backup_type == messages.BackupType.Slip39_Single_Extendable + assert session.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.backup_type == messages.BackupType.Slip39_Single_Extendable # unlock repeated backup by entering the single share - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, MNEMONIC_SLIP39_SINGLE_EXT_20, unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup # we can now perform another backup - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False, repeated=True) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) # backup type was upgraded: - assert client.features.backup_type == messages.BackupType.Slip39_Basic_Extendable + assert session.features.backup_type == messages.BackupType.Slip39_Basic_Extendable # the backup feature is locked again... assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) -def test_repeated_backup_cancel(client: Client): - assert client.features.backup_availability == messages.BackupAvailability.Required - assert client.features.recovery_status == messages.RecoveryStatus.Nothing +def test_repeated_backup_cancel(session: Session): + assert session.features.backup_availability == messages.BackupAvailability.Required + assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics assert len(mnemonics) == 5 # cannot backup, since we already just did that! assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup - layout = client.debug.read_layout() + layout = session.client.debug.read_layout() assert TR.recovery__unlock_repeated_backup in layout.text_content() # send a Cancel message with pytest.raises(Cancelled): - client.call(messages.Cancel()) + session.call(messages.Cancel()) - client.refresh_features() + session.refresh_features() # the backup feature is locked again... assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) -def test_repeated_backup_send_disallowed_message(client: Client): - assert client.features.backup_availability == messages.BackupAvailability.Required - assert client.features.recovery_status == messages.RecoveryStatus.Nothing +def test_repeated_backup_send_disallowed_message(session: Session): + assert session.features.backup_availability == messages.BackupAvailability.Required + assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics assert len(mnemonics) == 5 # cannot backup, since we already just did that! assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup - layout = client.debug.read_layout() + layout = session.client.debug.read_layout() assert TR.recovery__unlock_repeated_backup in layout.text_content() # send a GetAddress message - resp = client.call_raw( + resp = session.call_raw( messages.GetAddress( coin_name="Testnet", address_n=TEST_ADDRESS_N, @@ -224,10 +228,13 @@ def test_repeated_backup_send_disallowed_message(client: Client): assert isinstance(resp, messages.Failure) assert "not allowed" in resp.message - assert client.features.backup_availability == messages.BackupAvailability.Available - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.backup_availability == messages.BackupAvailability.Available + assert session.features.recovery_status == messages.RecoveryStatus.Backup # we are still on the confirmation screen! assert ( - TR.recovery__unlock_repeated_backup in client.debug.read_layout().text_content() + TR.recovery__unlock_repeated_backup + in session.client.debug.read_layout().text_content() ) + with pytest.raises(exceptions.Cancelled): + session.call(messages.Cancel()) diff --git a/tests/device_tests/test_sdcard.py b/tests/device_tests/test_sdcard.py index 69098d81df7..8d5c45b81fc 100644 --- a/tests/device_tests/test_sdcard.py +++ b/tests/device_tests/test_sdcard.py @@ -17,111 +17,117 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import SdProtectOperationType as Op from .. import translations as TR +PIN = "1234" + pytestmark = pytest.mark.models("core", skip="safe3") @pytest.mark.sd_card(formatted=False) -def test_sd_format(client: Client): - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True +def test_sd_format(session: Session): + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True @pytest.mark.sd_card(formatted=False) -def test_sd_no_format(client: Client): +def test_sd_no_format(session: Session): + debug = session.client.debug + def input_flow(): yield # enable SD protection? - client.debug.press_yes() + debug.press_yes() yield # format SD card - client.debug.press_no() + debug.press_no() - with pytest.raises(TrezorFailure) as e, client: + with session, session.client as client, pytest.raises(TrezorFailure) as e: client.set_input_flow(input_flow) - device.sd_protect(client, Op.ENABLE) + device.sd_protect(session, Op.ENABLE) assert e.value.code == messages.FailureType.ProcessError @pytest.mark.sd_card -@pytest.mark.setup_client(pin="1234") -def test_sd_protect_unlock(client: Client): - layout = client.debug.read_layout +@pytest.mark.setup_client(pin=PIN) +def test_sd_protect_unlock(session: Session): + debug = session.client.debug + layout = debug.read_layout def input_flow_enable_sd_protect(): + # debug.press_yes() yield # Enter PIN to unlock device assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # do you really want to enable SD protection assert TR.sd_card__enable in layout().text_content() - client.debug.press_yes() + debug.press_yes() yield # enter current PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # you have successfully enabled SD protection assert TR.sd_card__enabled in layout().text_content() - client.debug.press_yes() + debug.press_yes() - with client: + with session, session.client as client: client.watch_layout() client.set_input_flow(input_flow_enable_sd_protect) - device.sd_protect(client, Op.ENABLE) + device.sd_protect(session, Op.ENABLE) def input_flow_change_pin(): yield # do you really want to change PIN? assert layout().title() == TR.pin__title_settings - client.debug.press_yes() + debug.press_yes() yield # enter current PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # enter new PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # enter new PIN again assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # Pin change successful assert TR.pin__changed in layout().text_content() - client.debug.press_yes() + debug.press_yes() - with client: + with session, session.client as client: client.watch_layout() client.set_input_flow(input_flow_change_pin) - device.change_pin(client) + device.change_pin(session) - client.debug.erase_sd_card(format=False) + debug.erase_sd_card(format=False) def input_flow_change_pin_format(): yield # do you really want to change PIN? assert layout().title() == TR.pin__title_settings - client.debug.press_yes() + debug.press_yes() yield # enter current PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # SD card problem assert ( TR.sd_card__unplug_and_insert_correct in layout().text_content() or TR.sd_card__insert_correct_card in layout().text_content() ) - client.debug.press_no() # close + debug.press_no() # close - with client, pytest.raises(TrezorFailure) as e: + with session, session.client as client, pytest.raises(TrezorFailure) as e: client.watch_layout() client.set_input_flow(input_flow_change_pin_format) - device.change_pin(client) + device.change_pin(session) assert e.value.code == messages.FailureType.ProcessError diff --git a/tests/device_tests/test_session.py b/tests/device_tests/test_session.py index a8020d0354d..3fbc860db7b 100644 --- a/tests/device_tests/test_session.py +++ b/tests/device_tests/test_session.py @@ -18,6 +18,7 @@ from trezorlib import cardano, messages, models from trezorlib.btc import get_public_node +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -30,6 +31,16 @@ PIN4 = "1234" +@pytest.mark.protocol("protocol_v2") +def test_thp_end_session(client: Client): + session = Session(client.get_session()) + + msg = session.call(messages.EndSession()) + assert isinstance(msg, messages.Success) + with pytest.raises(TrezorFailure, match="ThpUnallocatedSession"): + session.call(messages.GetFeatures()) + + @pytest.mark.setup_client(pin=PIN4, passphrase="") def test_clear_session(client: Client): is_t1 = client.model is models.T1B1 @@ -39,100 +50,105 @@ def test_clear_session(client: Client): ] cached_responses = [messages.PublicKey] - - with client: + session = Session(client.get_session()) + session.lock() + with client, session: client.use_pin_sequence([PIN4]) - client.set_expected_responses(init_responses + cached_responses) - assert get_public_node(client, ADDRESS_N).xpub == XPUB + session.set_expected_responses(init_responses + cached_responses) + assert get_public_node(session, ADDRESS_N).xpub == XPUB - with client: + client.resume_session(session) + with session: # pin and passphrase are cached - client.set_expected_responses(cached_responses) - assert get_public_node(client, ADDRESS_N).xpub == XPUB + session.set_expected_responses(cached_responses) + assert get_public_node(session, ADDRESS_N).xpub == XPUB - client.clear_session() + session.lock() + session.end() + session = Session(client.get_session()) # session cache is cleared - with client: + with client, session: client.use_pin_sequence([PIN4]) - client.set_expected_responses(init_responses + cached_responses) - assert get_public_node(client, ADDRESS_N).xpub == XPUB + session.set_expected_responses(init_responses + cached_responses) + assert get_public_node(session, ADDRESS_N).xpub == XPUB - with client: + client.resume_session(session) + with session: # pin and passphrase are cached - client.set_expected_responses(cached_responses) - assert get_public_node(client, ADDRESS_N).xpub == XPUB + session.set_expected_responses(cached_responses) + assert get_public_node(session, ADDRESS_N).xpub == XPUB def test_end_session(client: Client): # client instance starts out not initialized # XXX do we want to change this? - assert client.session_id is not None + session = client.get_session() + assert session.id is not None # get_address will succeed - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) + with Session(session) as session: + session.set_expected_responses([messages.Address]) + get_test_address(session) - client.end_session() - assert client.session_id is None + session.end() + # assert client.session_id is None with pytest.raises(TrezorFailure) as exc: - get_test_address(client) + get_test_address(session) assert exc.value.code == messages.FailureType.InvalidSession assert exc.value.message.endswith("Invalid session") - client.init_device() - assert client.session_id is not None - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) + session = client.get_session() + assert session.id is not None + with Session(session) as session: + session.set_expected_responses([messages.Address]) + get_test_address(session) - with client: - # end_session should succeed on empty session too - client.set_expected_responses([messages.Success] * 2) - client.end_session() - client.end_session() + # TODO: is the following valid? I do not think so + # with Session(session) as session: + # # end_session should succeed on empty session too + # session.set_expected_responses([messages.Success] * 2) + # session.end_session() + # session.end_session() def test_cannot_resume_ended_session(client: Client): - session_id = client.session_id - with client: - client.set_expected_responses([messages.Features]) - client.init_device(session_id=session_id) + session = client.get_session() + session_id = session.id + + session_resumed = client.resume_session(session) - assert session_id == client.session_id + assert session_resumed.id == session_id - client.end_session() - with client: - client.set_expected_responses([messages.Features]) - client.init_device(session_id=session_id) + session.end() + session_resumed2 = client.resume_session(session) - assert session_id != client.session_id + assert session_resumed2.id != session_id def test_end_session_only_current(client: Client): """test that EndSession only destroys the current session""" - session_id_a = client.session_id - client.init_device(new_session=True) - session_id_b = client.session_id + session_a = client.get_session() + session_b = client.get_session() + session_b_id = session_b.id - client.end_session() - assert client.session_id is None + session_b.end() + # assert client.session_id is None # resume ended session - client.init_device(session_id=session_id_b) - assert client.session_id != session_id_b + session_b_resumed = client.resume_session(session_b) + assert session_b_resumed.id != session_b_id # resume first session that was not ended - client.init_device(session_id=session_id_a) - assert client.session_id == session_id_a + session_a_resumed = client.resume_session(session_a) + assert session_a_resumed.id == session_a.id @pytest.mark.setup_client(passphrase=True) def test_session_recycling(client: Client): - session_id_orig = client.session_id - with client: - client.set_expected_responses( + session = Session(client.get_session(passphrase="TREZOR")) + with client, session: + session.set_expected_responses( [ messages.PassphraseRequest, messages.ButtonRequest, @@ -141,20 +157,22 @@ def test_session_recycling(client: Client): ] ) client.use_passphrase("TREZOR") - address = get_test_address(client) + _ = get_test_address(session) + # address = get_test_address(session) # create and close 100 sessions - more than the session limit for _ in range(100): - client.init_device(new_session=True) - client.end_session() + session_x = client.get_session() + session_x.end() # it should still be possible to resume the original session - with client: - # passphrase should still be cached - client.set_expected_responses([messages.Features, messages.Address]) - client.use_passphrase("TREZOR") - client.init_device(session_id=session_id_orig) - assert address == get_test_address(client) + # TODO imo not True anymore + # with client, session: + # # passphrase should still be cached + # session.set_expected_responses([messages.Features, messages.Address]) + # client.use_passphrase("TREZOR") + # client.resume_session(session) + # assert address == get_test_address(session) @pytest.mark.altcoin @@ -162,18 +180,19 @@ def test_session_recycling(client: Client): @pytest.mark.models("core") def test_derive_cardano_empty_session(client: Client): # start new session - client.init_device(new_session=True) - session_id = client.session_id + session = client.get_session(derive_cardano=True) + # session_id = client.session_id # restarting same session should go well - client.init_device() - assert session_id == client.session_id + session2 = client.resume_session(session) + assert session.id == session2.id # restarting same session should go well with any setting - client.init_device(derive_cardano=False) - assert session_id == client.session_id - client.init_device(derive_cardano=True) - assert session_id == client.session_id + # TODO I do not think that it holds True now + # client.init_device(derive_cardano=False) + # assert session_id == client.session_id + # client.init_device(derive_cardano=True) + # assert session_id == client.session_id @pytest.mark.altcoin @@ -181,43 +200,41 @@ def test_derive_cardano_empty_session(client: Client): @pytest.mark.models("core") def test_derive_cardano_running_session(client: Client): # start new session - client.init_device(new_session=True) - session_id = client.session_id + session = client.get_session(derive_cardano=False) + # force derivation of seed - get_test_address(client) + get_test_address(session) # session should not have Cardano capability with pytest.raises(TrezorFailure, match="not enabled"): - cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) + cardano.get_public_key(session, parse_path("m/44h/1815h/0h")) # restarting same session should go well - client.init_device() - assert session_id == client.session_id + session2 = client.resume_session(session) + assert session.id == session2.id - # restarting same session should go well if we _don't_ want to derive cardano - client.init_device(derive_cardano=False) - assert session_id == client.session_id + # TODO restarting same session should go well if we _don't_ want to derive cardano + # # client.init_device(derive_cardano=False) + # # assert session_id == client.session_id # restarting with derive_cardano=True should kill old session and create new one - client.init_device(derive_cardano=True) - assert session_id != client.session_id - - session_id = client.session_id + session3 = client.get_session(derive_cardano=True) + assert session3.id != session.id # new session should have Cardano capability - cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) + cardano.get_public_key(session3, parse_path("m/44h/1815h/0h")) # restarting with derive_cardano=True should keep same session - client.init_device(derive_cardano=True) - assert session_id == client.session_id + session4 = client.resume_session(session3) + assert session4.id == session3.id - # restarting with no setting should keep same session - client.init_device() - assert session_id == client.session_id + # # restarting with no setting should keep same session + # client.init_device() + # assert session_id == client.session_id - # restarting with derive_cardano=False should kill old session and create new one - client.init_device(derive_cardano=False) - assert session_id != client.session_id + # # restarting with derive_cardano=False should kill old session and create new one + # client.init_device(derive_cardano=False) + # assert session_id != client.session_id - with pytest.raises(TrezorFailure, match="not enabled"): - cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) + # with pytest.raises(TrezorFailure, match="not enabled"): + # cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) diff --git a/tests/device_tests/test_session_id_and_passphrase.py b/tests/device_tests/test_session_id_and_passphrase.py index 51a2c0731fd..ee1b20ab02d 100644 --- a/tests/device_tests/test_session_id_and_passphrase.py +++ b/tests/device_tests/test_session_id_and_passphrase.py @@ -19,7 +19,9 @@ import pytest from trezorlib import device, exceptions, messages +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.messages import FailureType, SafetyCheckLevel @@ -49,19 +51,13 @@ SESSIONS_STORED = 10 -def _init_session(client: Client, session_id=None, derive_cardano=False): - """Call Initialize, check and return the session ID.""" - response = client.call( - messages.Initialize(session_id=session_id, derive_cardano=derive_cardano) - ) - assert isinstance(response, messages.Features) - assert len(response.session_id) == 32 - return response.session_id - - -def _get_xpub(client: Client, passphrase=None): +def _get_xpub( + session: Session, + expected_passphrase_req: bool = False, + passphrase_v1: str | None = None, +): """Get XPUB and check that the appropriate passphrase flow has happened.""" - if passphrase is not None: + if expected_passphrase_req: expected_responses = [ messages.PassphraseRequest, messages.ButtonRequest, @@ -70,126 +66,143 @@ def _get_xpub(client: Client, passphrase=None): ] else: expected_responses = [messages.PublicKey] - - with client: - client.use_passphrase(passphrase or "") - client.set_expected_responses(expected_responses) - result = client.call(XPUB_REQUEST) + if ( + passphrase_v1 is not None + and session.protocol_version == ProtocolVersion.PROTOCOL_V1 + ): + session.passphrase = passphrase_v1 + + with session: + session.set_expected_responses(expected_responses) + result = session.call(XPUB_REQUEST) return result.xpub @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_session_with_passphrase(client: Client): - # Let's start the communication by calling Initialize. - session_id = _init_session(client) + session = Session(client.get_session(passphrase="A")) + session_id = session.id # GetPublicKey requires passphrase and since it is not cached, # Trezor will prompt for it. - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + assert _get_xpub(session, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] # Call Initialize again, this time with the received session id and then call # GetPublicKey. The passphrase should be cached now so Trezor must # not ask for it again, whilst returning the same xpub. - new_session_id = _init_session(client, session_id=session_id) - assert new_session_id == session_id - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + session2 = Session(client.resume_session(session)) + assert session2.id == session_id + assert _get_xpub(session2) == XPUB_PASSPHRASES["A"] # If we set session id in Initialize to None, the cache will be cleared # and Trezor will ask for the passphrase again. - new_session_id = _init_session(client) - assert new_session_id != session_id - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + session3 = Session(client.get_session(passphrase="A")) + assert session3 != session_id + assert _get_xpub(session3, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] - # Unknown session id is the same as setting it to None. - _init_session(client, session_id=b"X" * 32) - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + # TODO: The following part is kept only for solving UI-diff in tests + # - it can be removed if fixtures are updated, imo + session4 = Session(client.get_session(passphrase="A")) + assert session4 != session_id + assert _get_xpub(session4, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_multiple_sessions(client: Client): # start SESSIONS_STORED sessions session_ids = [] + sessions = [] for _ in range(SESSIONS_STORED): - session_ids.append(_init_session(client)) + session = client.get_session() + sessions.append(session) + session_ids.append(session.id) # Resume each session - for session_id in session_ids: - new_session_id = _init_session(client, session_id) - assert session_id == new_session_id + for i in range(SESSIONS_STORED): + resumed_session = client.resume_session(sessions[i]) + assert session_ids[i] == resumed_session.id # Creating a new session replaces the least-recently-used session - _init_session(client) + client.get_session() # Resuming session 1 through SESSIONS_STORED will still work - for session_id in session_ids[1:]: - new_session_id = _init_session(client, session_id) - assert session_id == new_session_id + for i in range(1, SESSIONS_STORED): + resumed_session = client.resume_session(sessions[i]) + assert session_ids[i] == resumed_session.id # Resuming session 0 will not work - new_session_id = _init_session(client, session_ids[0]) - assert new_session_id != session_ids[0] + resumed_session = client.resume_session(sessions[0]) + assert session_ids[0] != resumed_session.id # New session bumped out the least-recently-used anonymous session. # Resuming session 1 through SESSIONS_STORED will still work - for session_id in session_ids[1:]: - new_session_id = _init_session(client, session_id) - assert session_id == new_session_id + for i in range(1, SESSIONS_STORED): + resumed_session = client.resume_session(sessions[i]) + assert session_ids[i] == resumed_session.id # Creating a new session replaces session_ids[0] again - _init_session(client) + client.get_session() # Resuming all sessions one by one will in turn bump out the previous session. - for session_id in session_ids: - new_session_id = _init_session(client, session_id) - assert session_id != new_session_id + for i in range(SESSIONS_STORED): + resumed_session = client.resume_session(sessions[i]) + assert session_ids[i] != resumed_session.id @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_multiple_passphrases(client: Client): # start a session - session_a = _init_session(client) - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + session_a = Session(client.get_session(passphrase="A")) + session_a_id = session_a.id + assert _get_xpub(session_a, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] # start it again wit the same session id - new_session_id = _init_session(client, session_id=session_a) + session_a_resumed = Session(client.resume_session(session_a)) # session is the same - assert new_session_id == session_a + assert session_a_resumed.id == session_a_id # passphrase is not prompted - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + assert _get_xpub(session_a_resumed) == XPUB_PASSPHRASES["A"] # start a second session - session_b = _init_session(client) + session_b = Session(client.get_session(passphrase="B")) + session_b_id = session_b.id # new session -> new session id and passphrase prompt - assert _get_xpub(client, passphrase="B") == XPUB_PASSPHRASES["B"] + assert _get_xpub(session_b, expected_passphrase_req=True) == XPUB_PASSPHRASES["B"] # provide the same session id -> must not ask for passphrase again. - new_session_id = _init_session(client, session_id=session_b) - assert new_session_id == session_b - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] + session_b_resumed = Session(client.resume_session(session_b)) + assert session_b_resumed.id == session_b_id + assert _get_xpub(session_b_resumed) == XPUB_PASSPHRASES["B"] # provide the first session id -> must not ask for passphrase again and return the same result. - new_session_id = _init_session(client, session_id=session_a) - assert new_session_id == session_a - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + session_a_resumed_again = Session(client.resume_session(session_a)) + assert session_a_resumed_again.id == session_a_id + assert _get_xpub(session_a_resumed_again) == XPUB_PASSPHRASES["A"] # provide the second session id -> must not ask for passphrase again and return the same result. - new_session_id = _init_session(client, session_id=session_b) - assert new_session_id == session_b - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] + session_b_resumed_again = Session(client.resume_session(session_b)) + assert session_b_resumed_again.id == session_b_id + assert _get_xpub(session_b_resumed_again) == XPUB_PASSPHRASES["B"] @pytest.mark.slow @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_max_sessions_with_passphrases(client: Client): # for the following tests, we are using as many passphrases as there are available sessions assert len(XPUB_PASSPHRASES) == SESSIONS_STORED # start as many sessions as the limit is session_ids = {} + sessions = {} for passphrase, xpub in XPUB_PASSPHRASES.items(): - session_id = _init_session(client) - assert session_id not in session_ids.values() - session_ids[passphrase] = session_id - assert _get_xpub(client, passphrase=passphrase) == xpub + session = Session(client.get_session(passphrase=passphrase)) + assert session.id not in session_ids.values() + session_ids[passphrase] = session.id + sessions[passphrase] = session + assert _get_xpub(session, expected_passphrase_req=True) == xpub # passphrase is not prompted for the started the sessions, regardless the order # let's try 20 different orderings @@ -198,125 +211,135 @@ def test_max_sessions_with_passphrases(client: Client): for _ in range(20): random.shuffle(shuffling) for passphrase in shuffling: - session_id = _init_session(client, session_id=session_ids[passphrase]) - assert session_id == session_ids[passphrase] - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES[passphrase] + resumed_session = Session(client.resume_session(sessions[passphrase])) + assert resumed_session.id == session_ids[passphrase] + assert _get_xpub(resumed_session) == XPUB_PASSPHRASES[passphrase] # make sure the usage order is the reverse of the creation order for passphrase in reversed(passphrases): - session_id = _init_session(client, session_id=session_ids[passphrase]) - assert session_id == session_ids[passphrase] - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES[passphrase] + resumed_session = Session(client.resume_session(sessions[passphrase])) + assert resumed_session.id == session_ids[passphrase] + assert _get_xpub(resumed_session) == XPUB_PASSPHRASES[passphrase] # creating one more session will exceed the limit - _init_session(client) + new_session = Session(client.get_session(passphrase="XX")) # new session asks for passphrase - _get_xpub(client, passphrase="XX") + _get_xpub(new_session, expected_passphrase_req=True) # restoring the sessions in reverse will evict the next-up session for passphrase in reversed(passphrases): - _init_session(client, session_id=session_ids[passphrase]) - _get_xpub(client, passphrase="whatever") # passphrase is prompted + resumed_session = Session(client.resume_session(sessions[passphrase])) + _get_xpub( + resumed_session, + expected_passphrase_req=True, + passphrase_v1="whatever", + ) # passphrase is prompted +@pytest.mark.protocol("protocol_v1") def test_session_enable_passphrase(client: Client): # Let's start the communication by calling Initialize. - session_id = _init_session(client) + session = Session(client.get_session(passphrase="")) # Trezor will not prompt for passphrase because it is turned off. - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASE_NONE + assert _get_xpub(session, expected_passphrase_req=False) == XPUB_PASSPHRASE_NONE # Turn on passphrase. # Emit the call explicitly to avoid ClearSession done by the library function - response = client.call(messages.ApplySettings(use_passphrase=True)) + response = session.call(messages.ApplySettings(use_passphrase=True)) assert isinstance(response, messages.Success) # The session id is unchanged, therefore we do not prompt for the passphrase. - new_session_id = _init_session(client, session_id=session_id) - assert session_id == new_session_id - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASE_NONE + session_id = session.id + resumed_session = Session(client.resume_session(session)) + assert session_id == resumed_session.id + assert _get_xpub(resumed_session) == XPUB_PASSPHRASE_NONE # We clear the session id now, so the passphrase should be asked. - new_session_id = _init_session(client) - assert session_id != new_session_id - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + new_session = Session(client.get_session(passphrase="A")) + assert session_id != new_session.id + assert _get_xpub(new_session, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] @pytest.mark.models("core") @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_passphrase_on_device(client: Client): - _init_session(client) - + # _init_session(client) + session = client.get_session(passphrase="A") # try to get xpub with passphrase on host: - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) # using `client.call` to auto-skip subsequent ButtonRequests for "show passphrase" - response = client.call(messages.PassphraseAck(passphrase="A", on_device=False)) + response = session.call(messages.PassphraseAck(passphrase="A", on_device=False)) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] # try to get xpub again, passphrase should be cached - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] # make a new session - _init_session(client) + session2 = session.client.get_session(passphrase="A") # try to get xpub with passphrase on device: - response = client.call_raw(XPUB_REQUEST) + response = session2.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(on_device=True)) + response = session2.call_raw(messages.PassphraseAck(on_device=True)) # no "show passphrase" here assert isinstance(response, messages.ButtonRequest) client.debug.input("A") - response = client.call_raw(messages.ButtonAck()) + response = session2.call_raw(messages.ButtonAck()) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] # try to get xpub again, passphrase should be cached - response = client.call_raw(XPUB_REQUEST) + response = session2.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] @pytest.mark.models("core") @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_passphrase_always_on_device(client: Client): # Let's start the communication by calling Initialize. - session_id = _init_session(client) + session = client.get_session() + # session_id = _init_session(client) # Force passphrase entry on Trezor. - response = client.call(messages.ApplySettings(passphrase_always_on_device=True)) + response = session.call(messages.ApplySettings(passphrase_always_on_device=True)) assert isinstance(response, messages.Success) # Since we enabled the always_on_device setting, Trezor will send ButtonRequests and ask for it on the device. - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.ButtonRequest) client.debug.input("") # Input empty passphrase. - response = client.call_raw(messages.ButtonAck()) + response = session.call_raw(messages.ButtonAck()) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASE_NONE # Passphrase will not be prompted. The session id stays the same and the passphrase is cached. - _init_session(client, session_id=session_id) - response = client.call_raw(XPUB_REQUEST) + resumed_session = client.resume_session(session) + response = resumed_session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASE_NONE # In case we want to add a new passphrase we need to send session_id = None. - _init_session(client) - response = client.call_raw(XPUB_REQUEST) + new_session = client.get_session(passphrase="A") + response = new_session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.ButtonRequest) client.debug.input("A") # Input non-empty passphrase. - response = client.call_raw(messages.ButtonAck()) + response = new_session.call_raw(messages.ButtonAck()) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] @pytest.mark.models("legacy") @pytest.mark.setup_client(passphrase="") +@pytest.mark.protocol("protocol_v1") def test_passphrase_on_device_not_possible_on_t1(client: Client): # This setting makes no sense on T1. response = client.call_raw(messages.ApplySettings(passphrase_always_on_device=True)) @@ -332,37 +355,42 @@ def test_passphrase_on_device_not_possible_on_t1(client: Client): @pytest.mark.setup_client(passphrase=True) -def test_passphrase_ack_mismatch(client: Client): - response = client.call_raw(XPUB_REQUEST) +@pytest.mark.protocol("protocol_v1") +def test_passphrase_ack_mismatch(session: Session): + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(passphrase="A", on_device=True)) + response = session.call_raw(messages.PassphraseAck(passphrase="A", on_device=True)) assert isinstance(response, messages.Failure) assert response.code == FailureType.DataError @pytest.mark.setup_client(passphrase="") -def test_passphrase_missing(client: Client): - response = client.call_raw(XPUB_REQUEST) +@pytest.mark.protocol("protocol_v1") +def test_passphrase_missing(session: Session): + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(passphrase=None)) + response = session.call_raw(messages.PassphraseAck(passphrase=None)) assert isinstance(response, messages.Failure) assert response.code == FailureType.DataError - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(passphrase=None, on_device=False)) + response = session.call_raw( + messages.PassphraseAck(passphrase=None, on_device=False) + ) assert isinstance(response, messages.Failure) assert response.code == FailureType.DataError @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_passphrase_length(client: Client): def call(passphrase: str, expected_result: bool): - _init_session(client) - response = client.call_raw(XPUB_REQUEST) + session = client.get_session(passphrase=passphrase) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) try: - response = client.call(messages.PassphraseAck(passphrase=passphrase)) + response = session.call(messages.PassphraseAck(passphrase=passphrase)) assert expected_result is True, "Call should have failed" assert isinstance(response, messages.PublicKey) except exceptions.TrezorFailure as e: @@ -381,19 +409,21 @@ def call(passphrase: str, expected_result: bool): @pytest.mark.models("core") @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_hide_passphrase_from_host(client: Client): # Without safety checks, turning it on fails + session = client.get_seedless_session() with pytest.raises(TrezorFailure, match="Safety checks are strict"), client: - device.apply_settings(client, hide_passphrase_from_host=True) + device.apply_settings(session, hide_passphrase_from_host=True) - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) # Turning it on - device.apply_settings(client, hide_passphrase_from_host=True) + device.apply_settings(session, hide_passphrase_from_host=True) passphrase = "abc" - - with client: + session = Session(client.get_session(passphrase=passphrase)) + with client, session: def input_flow(): yield @@ -410,7 +440,7 @@ def input_flow(): client.watch_layout() client.set_input_flow(input_flow) - client.set_expected_responses( + session.set_expected_responses( [ messages.PassphraseRequest, messages.ButtonRequest, @@ -418,17 +448,17 @@ def input_flow(): ] ) client.use_passphrase(passphrase) - result = client.call(XPUB_REQUEST) + result = session.call(XPUB_REQUEST) assert isinstance(result, messages.PublicKey) xpub_hidden_passphrase = result.xpub # Turning it off - device.apply_settings(client, hide_passphrase_from_host=False) + device.apply_settings(session, hide_passphrase_from_host=False) # Starting new session, otherwise the passphrase would be cached - _init_session(client) + session = Session(client.get_session(passphrase=passphrase)) - with client: + with client, session: def input_flow(): yield @@ -445,7 +475,7 @@ def input_flow(): client.watch_layout() client.set_input_flow(input_flow) - client.set_expected_responses( + session.set_expected_responses( [ messages.PassphraseRequest, messages.ButtonRequest, @@ -454,22 +484,22 @@ def input_flow(): ] ) client.use_passphrase(passphrase) - result = client.call(XPUB_REQUEST) + result = session.call(XPUB_REQUEST) assert isinstance(result, messages.PublicKey) xpub_shown_passphrase = result.xpub assert xpub_hidden_passphrase == xpub_shown_passphrase -def _get_xpub_cardano(client: Client, passphrase): +def _get_xpub_cardano(session: Session, expected_passphrase_req: bool = False): msg = messages.CardanoGetPublicKey( address_n=parse_path("m/44h/1815h/0h/0/0"), derivation_type=messages.CardanoDerivationType.ICARUS, ) - response = client.call_raw(msg) - if passphrase is not None: + response = session.call_raw(msg) + if expected_passphrase_req: assert isinstance(response, messages.PassphraseRequest) - response = client.call(messages.PassphraseAck(passphrase=passphrase)) + response = session.call(messages.PassphraseAck(passphrase=session.passphrase)) assert isinstance(response, messages.CardanoPublicKey) return response.xpub @@ -477,36 +507,43 @@ def _get_xpub_cardano(client: Client, passphrase): @pytest.mark.models("core") @pytest.mark.altcoin @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_cardano_passphrase(client: Client): # Cardano has a separate derivation method that needs to access the plaintext # of the passphrase. # Historically, Cardano calls would ask for passphrase again. Now, they should not. - session_id = _init_session(client, derive_cardano=True) + # session_id = _init_session(client, derive_cardano=True) # GetPublicKey requires passphrase and since it is not cached, # Trezor will prompt for it. - assert _get_xpub(client, passphrase="B") == XPUB_PASSPHRASES["B"] + session = Session(client.get_session(passphrase="B", derive_cardano=True)) + assert _get_xpub(session, expected_passphrase_req=True) == XPUB_PASSPHRASES["B"] # The passphrase is now cached for non-Cardano coins. - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] + assert _get_xpub(session) == XPUB_PASSPHRASES["B"] # The passphrase should be cached for Cardano as well - assert _get_xpub_cardano(client, passphrase=None) == XPUB_CARDANO_PASSPHRASE_B + assert _get_xpub_cardano(session) == XPUB_CARDANO_PASSPHRASE_B # Initialize with the session id does not destroy the state - _init_session(client, session_id=session_id, derive_cardano=True) - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] - assert _get_xpub_cardano(client, passphrase=None) == XPUB_CARDANO_PASSPHRASE_B + resumed_session = Session(client.resume_session(session)) + # _init_session(client, session_id=session_id, derive_cardano=True) + assert _get_xpub(resumed_session) == XPUB_PASSPHRASES["B"] + assert _get_xpub_cardano(resumed_session) == XPUB_CARDANO_PASSPHRASE_B # New session will destroy the state - _init_session(client, derive_cardano=True) + new_session = Session(client.get_session(passphrase="A", derive_cardano=True)) + # _init_session(client, derive_cardano=True) # Cardano must ask for passphrase again - assert _get_xpub_cardano(client, passphrase="A") == XPUB_CARDANO_PASSPHRASE_A + assert ( + _get_xpub_cardano(new_session, expected_passphrase_req=True) + == XPUB_CARDANO_PASSPHRASE_A + ) # Passphrase is now cached for Cardano - assert _get_xpub_cardano(client, passphrase=None) == XPUB_CARDANO_PASSPHRASE_A + assert _get_xpub_cardano(new_session) == XPUB_CARDANO_PASSPHRASE_A # Passphrase is cached for non-Cardano coins too - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + assert _get_xpub(new_session) == XPUB_PASSPHRASES["A"] diff --git a/tests/device_tests/tezos/test_getaddress.py b/tests/device_tests/tezos/test_getaddress.py index 3e6b5423938..9f35118370d 100644 --- a/tests/device_tests/tezos/test_getaddress.py +++ b/tests/device_tests/tezos/test_getaddress.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tezos import get_address from trezorlib.tools import parse_path @@ -35,19 +35,19 @@ @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) -def test_tezos_get_address(client: Client, path: str, expected_address: str): - address = get_address(client, parse_path(path), show_display=True) +def test_tezos_get_address(session: Session, path: str, expected_address: str): + address = get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) def test_tezos_get_address_chunkify_details( - client: Client, path: str, expected_address: str + session: Session, path: str, expected_address: str ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address diff --git a/tests/device_tests/tezos/test_getpublickey.py b/tests/device_tests/tezos/test_getpublickey.py index 9f5bfcd0f74..8b1e72609d7 100644 --- a/tests/device_tests/tezos/test_getpublickey.py +++ b/tests/device_tests/tezos/test_getpublickey.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tezos import get_public_key from trezorlib.tools import parse_path @@ -24,11 +24,11 @@ @pytest.mark.altcoin @pytest.mark.tezos @pytest.mark.models("core") -def test_tezos_get_public_key(client: Client): +def test_tezos_get_public_key(session: Session): path = parse_path("m/44h/1729h/0h") - pk = get_public_key(client, path) + pk = get_public_key(session, path) assert pk == "edpkttLhEbVfMC3DhyVVFzdwh8ncRnEWiLD1x8TAuPU7vSJak7RtBX" path = parse_path("m/44h/1729h/1h") - pk = get_public_key(client, path) + pk = get_public_key(session, path) assert pk == "edpkuTPqWjcApwyD3VdJhviKM5C13zGk8c4m87crgFarQboF3Mp56f" diff --git a/tests/device_tests/tezos/test_sign_tx.py b/tests/device_tests/tezos/test_sign_tx.py index 06e17304db6..f70a4934d9c 100644 --- a/tests/device_tests/tezos/test_sign_tx.py +++ b/tests/device_tests/tezos/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import messages, tezos -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.protobuf import dict_to_proto from trezorlib.tools import parse_path @@ -32,10 +32,10 @@ ] -def test_tezos_sign_tx_proposal(client: Client): - with client: +def test_tezos_sign_tx_proposal(session: Session): + with session: resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -63,10 +63,10 @@ def test_tezos_sign_tx_proposal(client: Client): assert resp.operation_hash == "opLqntFUu984M7LnGsFvfGW6kWe9QjAz4AfPDqQvwJ1wPM4Si4c" -def test_tezos_sign_tx_multiple_proposals(client: Client): - with client: +def test_tezos_sign_tx_multiple_proposals(session: Session): + with session: resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -95,9 +95,9 @@ def test_tezos_sign_tx_multiple_proposals(client: Client): assert resp.operation_hash == "onobSyNgiitGXxSVFJN6949MhUomkkxvH4ZJ2owgWwNeDdntF9Y" -def test_tezos_sing_tx_ballot_yay(client: Client): +def test_tezos_sing_tx_ballot_yay(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -119,9 +119,9 @@ def test_tezos_sing_tx_ballot_yay(client: Client): ) -def test_tezos_sing_tx_ballot_nay(client: Client): +def test_tezos_sing_tx_ballot_nay(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -142,9 +142,9 @@ def test_tezos_sing_tx_ballot_nay(client: Client): ) -def test_tezos_sing_tx_ballot_pass(client: Client): +def test_tezos_sing_tx_ballot_pass(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -167,9 +167,9 @@ def test_tezos_sing_tx_ballot_pass(client: Client): @pytest.mark.parametrize("chunkify", (True, False)) -def test_tezos_sign_tx_tranasaction(client: Client, chunkify: bool): +def test_tezos_sign_tx_tranasaction(session: Session, chunkify: bool): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -202,9 +202,9 @@ def test_tezos_sign_tx_tranasaction(client: Client, chunkify: bool): assert resp.operation_hash == "oon8PNUsPETGKzfESv1Epv4535rviGS7RdCfAEKcPvzojrcuufb" -def test_tezos_sign_tx_delegation(client: Client): +def test_tezos_sign_tx_delegation(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_15, dict_to_proto( messages.TezosSignTx, @@ -232,9 +232,9 @@ def test_tezos_sign_tx_delegation(client: Client): assert resp.operation_hash == "op79C1tR7wkUgYNid2zC1WNXmGorS38mTXZwtAjmCQm2kG7XG59" -def test_tezos_sign_tx_origination(client: Client): +def test_tezos_sign_tx_origination(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -263,9 +263,9 @@ def test_tezos_sign_tx_origination(client: Client): assert resp.operation_hash == "onmq9FFZzvG2zghNdr1bgv9jzdbzNycXjSSNmCVhXCGSnV3WA9g" -def test_tezos_sign_tx_reveal(client: Client): +def test_tezos_sign_tx_reveal(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH, dict_to_proto( messages.TezosSignTx, @@ -305,9 +305,9 @@ def test_tezos_sign_tx_reveal(client: Client): assert resp.operation_hash == "oo9JFiWTnTSvUZfajMNwQe1VyFN2pqwiJzZPkpSAGfGD57Z6mZJ" -def test_tezos_smart_contract_delegation(client: Client): +def test_tezos_smart_contract_delegation(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -342,9 +342,9 @@ def test_tezos_smart_contract_delegation(client: Client): assert resp.operation_hash == "oo75gfQGGPEPChXZzcPPAGtYqCpsg2BS5q9gmhrU3NQP7CEffpU" -def test_tezos_kt_remove_delegation(client: Client): +def test_tezos_kt_remove_delegation(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -377,9 +377,9 @@ def test_tezos_kt_remove_delegation(client: Client): assert resp.operation_hash == "ootMi1tXbfoVgFyzJa8iXyR4mnHd5TxLm9hmxVzMVRkbyVjKaHt" -def test_tezos_smart_contract_transfer(client: Client): +def test_tezos_smart_contract_transfer(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -420,9 +420,9 @@ def test_tezos_smart_contract_transfer(client: Client): assert resp.operation_hash == "ooRGGtCmoQDgB36XvQqmM7govc3yb77YDUoa7p2QS7on27wGRns" -def test_tezos_smart_contract_transfer_to_contract(client: Client): +def test_tezos_smart_contract_transfer_to_contract(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, diff --git a/tests/device_tests/thp/__init__.py b/tests/device_tests/thp/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/device_tests/thp/test_thp.py b/tests/device_tests/thp/test_thp.py new file mode 100644 index 00000000000..1745483a714 --- /dev/null +++ b/tests/device_tests/thp/test_thp.py @@ -0,0 +1,326 @@ +import os +import random +import typing as t +from hashlib import sha256 + +import pytest +import typing_extensions as tx + +from trezorlib import protobuf +from trezorlib.client import ProtocolV2 +from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.messages import ( + ButtonAck, + ButtonRequest, + ThpCodeEntryChallenge, + ThpCodeEntryCommitment, + ThpCodeEntryCpaceHostTag, + ThpCodeEntryCpaceTrezor, + ThpCodeEntrySecret, + ThpCredentialRequest, + ThpCredentialResponse, + ThpEndRequest, + ThpEndResponse, + ThpNfcTagHost, + ThpNfcTagTrezor, + ThpPairingMethod, + ThpPairingPreparationsFinished, + ThpPairingRequest, + ThpPairingRequestApproved, + ThpQrCodeSecret, + ThpQrCodeTag, + ThpSelectMethod, +) +from trezorlib.transport.thp import curve25519 +from trezorlib.transport.thp.cpace import Cpace +from trezorlib.transport.thp.protocol_v2 import _hkdf + +if t.TYPE_CHECKING: + P = tx.ParamSpec("P") + +MT = t.TypeVar("MT", bound=protobuf.MessageType) + +pytestmark = [pytest.mark.protocol("protocol_v2")] + + +def _prepare_protocol(client: Client) -> ProtocolV2: + protocol = client.protocol + assert isinstance(protocol, ProtocolV2) + protocol._reset_sync_bits() + return protocol + + +def _prepare_protocol_for_pairing(client: Client) -> ProtocolV2: + protocol = _prepare_protocol(client) + protocol._do_channel_allocation() + protocol._do_handshake() + return protocol + + +def _handle_pairing_request(client: Client, protocol: ProtocolV2) -> None: + protocol._send_message(ThpPairingRequest()) + button_req = protocol._read_message(ButtonRequest) + assert button_req.name == "pairing_request" + + protocol._send_message(ButtonAck()) + + client.debug.press_yes() + + protocol._read_message(ThpPairingRequestApproved) + + +def test_allocate_channel(client: Client) -> None: + protocol = _prepare_protocol(client) + + nonce = random.randbytes(8) + + # Use valid nonce + protocol._send_channel_allocation_request(nonce) + protocol._read_channel_allocation_response(nonce) + + # Expect different nonce + protocol._send_channel_allocation_request(nonce) + with pytest.raises(Exception, match="Invalid channel allocation response."): + protocol._read_channel_allocation_response( + expected_nonce=b"\xDE\xAD\xBE\xEF\xDE\xAD\xBE\xEF" + ) + client.invalidate() + + +def test_handshake(client: Client) -> None: + protocol = _prepare_protocol(client) + + host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) + host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) + + protocol._do_channel_allocation() + protocol._send_handshake_init_request(host_ephemeral_pubkey) + protocol._read_ack() + init_response = protocol._read_handshake_init_response() + + trezor_ephemeral_pubkey = init_response[:32] + encrypted_trezor_static_pubkey = init_response[32:80] + noise_tag = init_response[80:96] + + # TODO check noise_tag is valid + + ck = protocol._send_handshake_completion_request( + host_ephemeral_pubkey, + host_ephemeral_privkey, + trezor_ephemeral_pubkey, + encrypted_trezor_static_pubkey, + ) + protocol._read_ack() + protocol._read_handshake_completion_response() + protocol.key_request, protocol.key_response = _hkdf(ck, b"") + protocol.nonce_request = 0 + protocol.nonce_response = 1 + + # TODO - without pairing, the client is damaged and results in fail of the following test + # so far no luck in solving it - it should be also tackled in FW, as it causes unexpected FW error + protocol._do_pairing(client.debug) + + # TODO the following is just to make style checker happy + assert noise_tag is not None + + +def test_pairing_qr_code(client: Client) -> None: + protocol = _prepare_protocol_for_pairing(client) + _handle_pairing_request(client, protocol) + protocol._send_message( + ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode) + ) + protocol._read_message(ThpPairingPreparationsFinished) + + # QR Code shown + + # Read code from "Trezor's display" using debuglink + + pairing_info = client.debug.pairing_info( + thp_channel_id=protocol.channel_id.to_bytes(2, "big") + ) + code = pairing_info.code_qr_code + + # Compute tag for response + sha_ctx = sha256(protocol.handshake_hash) + sha_ctx.update(code) + tag = sha_ctx.digest() + + protocol._send_message(ThpQrCodeTag(tag=tag)) + + secret_msg = protocol._read_message(ThpQrCodeSecret) + + # Check that the `code` was derived from the revealed secret + sha_ctx = sha256(ThpPairingMethod.QrCode.to_bytes(1, "big")) + sha_ctx.update(protocol.handshake_hash) + sha_ctx.update(secret_msg.secret) + computed_code = sha_ctx.digest()[:16] + assert code == computed_code + + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + + protocol._has_valid_channel = True + + +def test_pairing_code_entry(client: Client) -> None: + protocol = _prepare_protocol_for_pairing(client) + + _handle_pairing_request(client, protocol) + + protocol._send_message( + ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry) + ) + + commitment_msg = protocol._read_message(ThpCodeEntryCommitment) + commitment = commitment_msg.commitment + + challenge = random.randbytes(16) + protocol._send_message(ThpCodeEntryChallenge(challenge=challenge)) + + cpace_trezor = protocol._read_message(ThpCodeEntryCpaceTrezor) + cpace_trezor_public_key = cpace_trezor.cpace_trezor_public_key + + # Code Entry code shown + + pairing_info = client.debug.pairing_info( + thp_channel_id=protocol.channel_id.to_bytes(2, "big") + ) + code = pairing_info.code_entry_code + + cpace = Cpace(handshake_hash=protocol.handshake_hash) + cpace.random_bytes = random.randbytes + cpace.generate_keys_and_secret(code.to_bytes(6, "big"), cpace_trezor_public_key) + sha_ctx = sha256(cpace.shared_secret) + tag = sha_ctx.digest() + + protocol._send_message( + ThpCodeEntryCpaceHostTag( + cpace_host_public_key=cpace.host_public_key, + tag=tag, + ) + ) + + secret_msg = protocol._read_message(ThpCodeEntrySecret) + + # Check `commitment` and `code` + sha_ctx = sha256(secret_msg.secret) + computed_commitment = sha_ctx.digest() + assert commitment == computed_commitment + + sha_ctx = sha256(ThpPairingMethod.CodeEntry.to_bytes(1, "big")) + sha_ctx.update(protocol.handshake_hash) + sha_ctx.update(secret_msg.secret) + sha_ctx.update(challenge) + code_hash = sha_ctx.digest() + computed_code = int.from_bytes(code_hash, "big") % 1000000 + assert code == computed_code + + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + + protocol._has_valid_channel = True + + +def test_pairing_nfc(client: Client) -> None: + protocol = _prepare_protocol_for_pairing(client) + + _nfc_pairing(client, protocol) + + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + protocol._has_valid_channel = True + + +def _nfc_pairing(client: Client, protocol: ProtocolV2): + + _handle_pairing_request(client, protocol) + + protocol._send_message( + ThpSelectMethod(selected_pairing_method=ThpPairingMethod.NFC) + ) + protocol._read_message(ThpPairingPreparationsFinished) + + # NFC screen shown + + nfc_secret_host = random.randbytes(16) + # Read `nfc_secret` and `handshake_hash` from Trezor using debuglink + pairing_info = client.debug.pairing_info( + thp_channel_id=protocol.channel_id.to_bytes(2, "big"), + handshake_hash=protocol.handshake_hash, + nfc_secret_host=nfc_secret_host, + ) + handshake_hash_trezor = pairing_info.handshake_hash + nfc_secret_trezor = pairing_info.nfc_secret_trezor + + assert handshake_hash_trezor[:16] == protocol.handshake_hash[:16] + + # Compute tag for response + sha_ctx = sha256(ThpPairingMethod.NFC.to_bytes(1, "big")) + sha_ctx.update(protocol.handshake_hash) + sha_ctx.update(nfc_secret_trezor) + tag_host = sha_ctx.digest() + + protocol._send_message(ThpNfcTagHost(tag=tag_host)) + + tag_trezor_msg = protocol._read_message(ThpNfcTagTrezor) + + # Check that the `code` was derived from the revealed secret + sha_ctx = sha256(ThpPairingMethod.NFC.to_bytes(1, "big")) + sha_ctx.update(protocol.handshake_hash) + sha_ctx.update(nfc_secret_host) + computed_tag = sha_ctx.digest() + assert tag_trezor_msg.tag == computed_tag + + +def test_credential_phase(client: Client): + protocol = _prepare_protocol_for_pairing(client) + _nfc_pairing(client, protocol) + + # Request credential with confirmation after pairing + host_static_privkey = curve25519.get_private_key(os.urandom(32)) + host_static_pubkey = curve25519.get_public_key(host_static_privkey) + protocol._send_message( + ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=False) + ) + credential_response = protocol._read_message(ThpCredentialResponse) + + assert credential_response.credential is not None + credential = credential_response.credential + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + + # Connect using credential with confirmation + protocol = _prepare_protocol(client) + protocol._do_channel_allocation() + protocol._do_handshake(credential, host_static_privkey) + protocol._send_message(ThpEndRequest()) + button_req = protocol._read_message(ButtonRequest) + assert button_req.name == "connection_request" + protocol._send_message(ButtonAck()) + client.debug.press_yes() + protocol._read_message(ThpEndResponse) + + # Connect using credential with confirmation and ask for autoconnect credential + protocol = _prepare_protocol(client) + protocol._do_channel_allocation() + protocol._do_handshake(credential, host_static_privkey) + protocol._send_message( + ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=True) + ) + button_req = protocol._read_message(ButtonRequest) + assert button_req.name == "connection_request" + protocol._send_message(ButtonAck()) + client.debug.press_yes() + credential_response_2 = protocol._read_message(ThpCredentialResponse) + assert credential_response_2.credential is not None + credential_auto = credential_response_2.credential + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + + # Connect using autoconnect credential + protocol = _prepare_protocol(client) + protocol._do_channel_allocation() + protocol._do_handshake(credential_auto, host_static_privkey) + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) diff --git a/tests/device_tests/webauthn/test_msg_webauthn.py b/tests/device_tests/webauthn/test_msg_webauthn.py index 3fd7ca7fd95..7016e2f5f80 100644 --- a/tests/device_tests/webauthn/test_msg_webauthn.py +++ b/tests/device_tests/webauthn/test_msg_webauthn.py @@ -17,7 +17,7 @@ import pytest from trezorlib import fido -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from ...common import MNEMONIC12 @@ -30,23 +30,23 @@ @pytest.mark.models("core") @pytest.mark.altcoin @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_add_remove(client: Client): - with client: +def test_add_remove(session: Session): + with session, session.client as client: IF = InputFlowFidoConfirm(client) client.set_input_flow(IF.get()) # Remove index 0 should fail. with pytest.raises(TrezorFailure): - fido.remove_credential(client, 0) + fido.remove_credential(session, 0) # List should be empty. - assert fido.list_credentials(client) == [] + assert fido.list_credentials(session) == [] # Add valid credential #1. - fido.add_credential(client, CRED1) + fido.add_credential(session, CRED1) # Check that the credential was added and parameters are correct. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 1 assert creds[0].rp_id == "example.com" assert creds[0].rp_name == "Example" @@ -59,10 +59,10 @@ def test_add_remove(client: Client): assert creds[0].hmac_secret is True # Add valid credential #2, which has same rpId and userId as credential #1. - fido.add_credential(client, CRED2) + fido.add_credential(session, CRED2) # Check that the credential #2 replaced credential #1 and parameters are correct. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 1 assert creds[0].rp_id == "example.com" assert creds[0].rp_name is None @@ -76,32 +76,32 @@ def test_add_remove(client: Client): # Adding an invalid credential should appear as if user cancelled. with pytest.raises(Cancelled): - fido.add_credential(client, CRED1[:-2]) + fido.add_credential(session, CRED1[:-2]) # Check that the invalid credential was not added. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 1 # Add valid credential, which has same userId as #2, but different rpId. - fido.add_credential(client, CRED3) + fido.add_credential(session, CRED3) # Check that the credential was added. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 2 # Fill up the credential storage to maximum capacity. for cred in CREDS[: RK_CAPACITY - 2]: - fido.add_credential(client, cred) + fido.add_credential(session, cred) # Adding one more valid credential to full storage should fail. with pytest.raises(TrezorFailure): - fido.add_credential(client, CREDS[-1]) + fido.add_credential(session, CREDS[-1]) # Removing the index, which is one past the end, should fail. with pytest.raises(TrezorFailure): - fido.remove_credential(client, RK_CAPACITY) + fido.remove_credential(session, RK_CAPACITY) # Remove index 2. - fido.remove_credential(client, 2) + fido.remove_credential(session, 2) # Adding another valid credential should succeed now. - fido.add_credential(client, CREDS[-1]) + fido.add_credential(session, CREDS[-1]) diff --git a/tests/device_tests/webauthn/test_u2f_counter.py b/tests/device_tests/webauthn/test_u2f_counter.py index d99467f2b9d..c140ba54578 100644 --- a/tests/device_tests/webauthn/test_u2f_counter.py +++ b/tests/device_tests/webauthn/test_u2f_counter.py @@ -17,15 +17,15 @@ import pytest from trezorlib import fido -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session @pytest.mark.altcoin -def test_u2f_counter(client: Client): - assert fido.get_next_counter(client) == 0 - assert fido.get_next_counter(client) == 1 - fido.set_counter(client, 111111) - assert fido.get_next_counter(client) == 111112 - assert fido.get_next_counter(client) == 111113 - fido.set_counter(client, 0) - assert fido.get_next_counter(client) == 1 +def test_u2f_counter(session: Session): + assert fido.get_next_counter(session) == 0 + assert fido.get_next_counter(session) == 1 + fido.set_counter(session, 111111) + assert fido.get_next_counter(session) == 111112 + assert fido.get_next_counter(session) == 111113 + fido.set_counter(session, 0) + assert fido.get_next_counter(session) == 1 diff --git a/tests/device_tests/zcash/test_sign_tx.py b/tests/device_tests/zcash/test_sign_tx.py index d689c8af969..4d7df800903 100644 --- a/tests/device_tests/zcash/test_sign_tx.py +++ b/tests/device_tests/zcash/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -53,7 +53,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.zcash] -def test_version_group_id_missing(client: Client): +def test_version_group_id_missing(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -69,7 +69,7 @@ def test_version_group_id_missing(client: Client): with pytest.raises(TrezorFailure, match="Version group ID must be set."): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -77,7 +77,7 @@ def test_version_group_id_missing(client: Client): ) -def test_spend_v4_input(client: Client): +def test_spend_v4_input(session: Session): # 4b6cecb81c825180786ebe07b65bcc76078afc5be0f1c64e08d764005012380d is a v4 tx inp1 = messages.TxInputType( @@ -95,13 +95,13 @@ def test_spend_v4_input(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -110,7 +110,7 @@ def test_spend_v4_input(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -126,7 +126,7 @@ def test_spend_v4_input(client: Client): ) -def test_send_to_multisig(client: Client): +def test_send_to_multisig(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/8"), @@ -143,13 +143,13 @@ def test_send_to_multisig(client: Client): script_type=messages.OutputScriptType.PAYTOSCRIPTHASH, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -158,7 +158,7 @@ def test_send_to_multisig(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -174,7 +174,7 @@ def test_send_to_multisig(client: Client): ) -def test_spend_v5_input(client: Client): +def test_spend_v5_input(session: Session): inp1 = messages.TxInputType( # tmBMyeJebzkP5naji8XUKqLyL1NDwNkgJFt address_n=parse_path("m/44h/1h/0h/0/9"), @@ -190,13 +190,13 @@ def test_spend_v5_input(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -205,7 +205,7 @@ def test_spend_v5_input(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -221,7 +221,7 @@ def test_spend_v5_input(client: Client): ) -def test_one_two(client: Client): +def test_one_two(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -243,13 +243,13 @@ def test_one_two(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -260,7 +260,7 @@ def test_one_two(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1, out2], @@ -277,7 +277,7 @@ def test_one_two(client: Client): @pytest.mark.models("core") -def test_unified_address(client: Client): +def test_unified_address(session: Session): # identical to the test_one_two # but receiver address is unified with an orchard address inp1 = messages.TxInputType( @@ -301,13 +301,13 @@ def test_unified_address(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -318,7 +318,7 @@ def test_unified_address(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1, out2], @@ -335,7 +335,7 @@ def test_unified_address(client: Client): @pytest.mark.models("core") -def test_external_presigned(client: Client): +def test_external_presigned(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -365,14 +365,14 @@ def test_external_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(1), request_input(0), @@ -383,7 +383,7 @@ def test_external_presigned(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1, inp2], [out1], @@ -399,7 +399,7 @@ def test_external_presigned(client: Client): ) -def test_refuse_replacement_tx(client: Client): +def test_refuse_replacement_tx(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/1h/0h/0/4"), amount=174998, @@ -437,7 +437,7 @@ def test_refuse_replacement_tx(client: Client): TrezorFailure, match="Replacement transactions are not supported." ): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1, out2], @@ -447,12 +447,12 @@ def test_refuse_replacement_tx(client: Client): ) -def test_spend_multisig(client: Client): +def test_spend_multisig(session: Session): # Cloned from tests/device_tests/bitcoin/test_multisig.py::test_2_of_3 nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Zcash Testnet" + session, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Zcash Testnet" ).node for index in range(1, 4) ] @@ -482,17 +482,17 @@ def test_spend_multisig(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures1, _ = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -529,10 +529,10 @@ def test_spend_multisig(client: Client): multisig=multisig, ) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures2, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp3], [out1], diff --git a/tests/input_flows.py b/tests/input_flows.py index 5896523ce70..ea1e395d3ae 100644 --- a/tests/input_flows.py +++ b/tests/input_flows.py @@ -16,6 +16,7 @@ from trezorlib import messages from trezorlib.debuglink import DebugLink, LayoutContent, LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import multipage_content @@ -129,13 +130,15 @@ def input_two_different_pins() -> BRGeneratorType: class InputFlowCodeChangeFail(InputFlowBase): + def __init__( - self, client: Client, current_pin: str, new_pin_1: str, new_pin_2: str + self, session: Session, current_pin: str, new_pin_1: str, new_pin_2: str ): - super().__init__(client) + super().__init__(session.client) self.current_pin = current_pin self.new_pin_1 = new_pin_1 self.new_pin_2 = new_pin_2 + self.session = session def input_flow_common(self) -> BRGeneratorType: yield # do you want to change pin? @@ -150,7 +153,7 @@ def input_flow_common(self) -> BRGeneratorType: # failed retry yield # enter current pin again - self.client.cancel() + self.session.cancel() class InputFlowWrongPIN(InputFlowBase): @@ -1891,9 +1894,11 @@ def input_flow_common(self) -> BRGeneratorType: class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase): - def __init__(self, client: Client): - super().__init__(client) + + def __init__(self, session: Session): + super().__init__(session.client) self.invalid_mnemonic = ["stick"] * 12 + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_dry_run() @@ -1902,7 +1907,7 @@ def input_flow_common(self) -> BRGeneratorType: yield from self.REC.warning_invalid_recovery_seed() yield - self.client.cancel() + self.session.cancel() class InputFlowBip39Recovery(InputFlowBase): @@ -1985,15 +1990,17 @@ def input_flow_common(self) -> BRGeneratorType: class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase): + def __init__( self, - client: Client, + session: Session, first_share: list[str], second_share: list[str], ): - super().__init__(client) + super().__init__(session.client) self.first_share = first_share self.second_share = second_share + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2005,19 +2012,21 @@ def input_flow_common(self) -> BRGeneratorType: yield from self.REC.warning_group_threshold_reached() yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase): + def __init__( self, - client: Client, + session: Session, first_share: list[str], second_share: list[str], ): - super().__init__(client) + super().__init__(session.client) self.first_share = first_share self.second_share = second_share + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2029,7 +2038,7 @@ def input_flow_common(self) -> BRGeneratorType: yield from self.REC.warning_share_already_entered() yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39BasicRecoveryDryRun(InputFlowBase): @@ -2128,10 +2137,12 @@ def input_flow_common(self) -> BRGeneratorType: class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase): - def __init__(self, client: Client): - super().__init__(client) + + def __init__(self, session: Session): + super().__init__(session.client) self.first_invalid = ["slush"] * 20 self.second_invalid = ["slush"] * 33 + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2143,16 +2154,18 @@ def input_flow_common(self) -> BRGeneratorType: yield from self.REC.warning_invalid_recovery_share() yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase): - def __init__(self, client: Client, shares: list[str]): - super().__init__(client) + + def __init__(self, session: Session, shares: list[str]): + super().__init__(session.client) self.shares = shares self.first_share = shares[0].split(" ") self.invalid_share = self.first_share[:3] + ["slush"] * 17 self.second_share = shares[1].split(" ") + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2165,16 +2178,18 @@ def input_flow_common(self) -> BRGeneratorType: yield from self.REC.success_more_shares_needed(1) yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase): - def __init__(self, client: Client, share: list[str], nth_word: int): - super().__init__(client) + + def __init__(self, session: Session, share: list[str], nth_word: int): + super().__init__(session.client) self.share = share self.nth_word = nth_word # Invalid share - just enough words to trigger the warning self.modified_share = share[:nth_word] + [self.share[-1]] + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2185,15 +2200,17 @@ def input_flow_common(self) -> BRGeneratorType: yield from self.REC.warning_share_from_another_shamir() yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39BasicRecoverySameShare(InputFlowBase): - def __init__(self, client: Client, share: list[str]): - super().__init__(client) + + def __init__(self, session: Session, share: list[str]): + super().__init__(session.client) self.share = share # Second duplicate share - only 4 words are needed to verify it self.duplicate_share = self.share[:4] + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2204,7 +2221,7 @@ def input_flow_common(self) -> BRGeneratorType: yield from self.REC.warning_share_already_entered() yield - self.client.cancel() + self.session.cancel() class InputFlowResetSkipBackup(InputFlowBase): diff --git a/tests/persistence_tests/test_safety_checks.py b/tests/persistence_tests/test_safety_checks.py index 1cbf7d75516..04d137ec141 100644 --- a/tests/persistence_tests/test_safety_checks.py +++ b/tests/persistence_tests/test_safety_checks.py @@ -20,16 +20,17 @@ def test_safety_checks_level_after_reboot( core_emulator: Emulator, set_level: SafetyCheckLevel, after_level: SafetyCheckLevel ): - device.wipe(core_emulator.client) + device.wipe(core_emulator.client.get_seedless_session()) debuglink.load_device( - core_emulator.client, + core_emulator.client.get_seedless_session(), mnemonic=MNEMONIC12, pin="", passphrase_protection=False, label="SAFETYLEVEL", ) - device.apply_settings(core_emulator.client, safety_checks=set_level) + device.apply_settings(core_emulator.client.get_session(), safety_checks=set_level) + core_emulator.client.refresh_features() assert core_emulator.client.features.safety_checks == set_level core_emulator.restart() diff --git a/tests/persistence_tests/test_shamir_persistence.py b/tests/persistence_tests/test_shamir_persistence.py index 5907df17964..52bc9636702 100644 --- a/tests/persistence_tests/test_shamir_persistence.py +++ b/tests/persistence_tests/test_shamir_persistence.py @@ -16,7 +16,8 @@ import pytest -from trezorlib import device +from trezorlib import device, messages +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import DebugLink, LayoutType from trezorlib.messages import RecoveryStatus @@ -45,7 +46,7 @@ def test_abort(core_emulator: Emulator): assert features.recovery_status == RecoveryStatus.Nothing - device_handler.run(device.recover, pin_protection=False) + device_handler.run_with_session(device.recover, pin_protection=False) recovery.confirm_recovery(debug) layout = debug.read_layout() @@ -82,7 +83,7 @@ def test_recovery_single_reset(core_emulator: Emulator): assert features.initialized is False assert features.recovery_status == RecoveryStatus.Nothing - device_handler.run(device.recover, pin_protection=False) + device_handler.run_with_session(device.recover, pin_protection=False) recovery.confirm_recovery(debug) @@ -129,7 +130,7 @@ def assert_mnemonic_keyboard(debug: DebugLink) -> None: assert features.recovery_status == RecoveryStatus.Nothing # enter recovery mode - device_handler.run(device.recover, pin_protection=False) + device_handler.run_with_session(device.recover, pin_protection=False) recovery.confirm_recovery(debug) @@ -157,7 +158,8 @@ def assert_mnemonic_keyboard(debug: DebugLink) -> None: layout = debug.read_layout() # while keyboard is open, hit the device with Initialize/GetFeatures - device_handler.client.init_device() + if device_handler.client.protocol_version == ProtocolVersion.PROTOCOL_V1: + device_handler.client.get_seedless_session().call(messages.Initialize()) device_handler.client.refresh_features() # try entering remaining 19 words @@ -207,7 +209,7 @@ def enter_shares_with_restarts(debug: DebugLink) -> None: assert features.recovery_status == RecoveryStatus.Nothing # start device and recovery - device_handler.run(device.recover, pin_protection=False) + device_handler.run_with_session(device.recover, pin_protection=False) recovery.confirm_recovery(debug) diff --git a/tests/persistence_tests/test_wipe_code.py b/tests/persistence_tests/test_wipe_code.py index 2497a708f6e..8dee771a6a6 100644 --- a/tests/persistence_tests/test_wipe_code.py +++ b/tests/persistence_tests/test_wipe_code.py @@ -11,46 +11,55 @@ def setup_device_legacy(client: Client, pin: str, wipe_code: str) -> None: - device.wipe(client) + device.wipe(client.get_seedless_session()) + client = client.get_new_client() debuglink.load_device( - client, MNEMONIC12, pin, passphrase_protection=False, label="WIPECODE" + client.get_seedless_session(), + MNEMONIC12, + pin, + passphrase_protection=False, + label="WIPECODE", ) with client: client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) - device.change_wipe_code(client) + device.change_wipe_code(client.get_seedless_session()) def setup_device_core(client: Client, pin: str, wipe_code: str) -> None: - device.wipe(client) + device.wipe(client.get_seedless_session()) + client = client.get_new_client() debuglink.load_device( - client, MNEMONIC12, pin, passphrase_protection=False, label="WIPECODE" + client.get_seedless_session(), + MNEMONIC12, + pin, + passphrase_protection=False, + label="WIPECODE", ) with client: client.use_pin_sequence([pin, wipe_code, wipe_code]) - device.change_wipe_code(client) + device.change_wipe_code(client.get_seedless_session()) @core_only def test_wipe_code_activate_core(core_emulator: Emulator): # set up device setup_device_core(core_emulator.client, PIN, WIPE_CODE) - - core_emulator.client.init_device() + session = core_emulator.client.get_session() device_id = core_emulator.client.features.device_id # Initiate Change pin process - ret = core_emulator.client.call_raw(messages.ChangePin(remove=False)) + ret = session.call_raw(messages.ChangePin(remove=False)) assert isinstance(ret, messages.ButtonRequest) assert ret.name == "change_pin" core_emulator.client.debug.press_yes() - ret = core_emulator.client.call_raw(messages.ButtonAck()) + ret = session.call_raw(messages.ButtonAck()) # Enter the wipe code instead of the current PIN expected = message_filters.ButtonRequest(code=messages.ButtonRequestType.PinEntry) assert expected.match(ret) - core_emulator.client._raw_write(messages.ButtonAck()) + session._write(messages.ButtonAck()) core_emulator.client.debug.input(WIPE_CODE) # preserving screenshots even after it dies and starts again @@ -75,25 +84,26 @@ def test_wipe_code_activate_legacy(): # set up device setup_device_legacy(emu.client, PIN, WIPE_CODE) - emu.client.init_device() + session = emu.client.get_session() device_id = emu.client.features.device_id # Initiate Change pin process - ret = emu.client.call_raw(messages.ChangePin(remove=False)) + ret = session.call_raw(messages.ChangePin(remove=False)) assert isinstance(ret, messages.ButtonRequest) emu.client.debug.press_yes() - ret = emu.client.call_raw(messages.ButtonAck()) + ret = session.call_raw(messages.ButtonAck()) # Enter the wipe code instead of the current PIN assert isinstance(ret, messages.PinMatrixRequest) wipe_code_encoded = emu.client.debug.encode_pin(WIPE_CODE) - emu.client._raw_write(messages.PinMatrixAck(pin=wipe_code_encoded)) + session._write(messages.PinMatrixAck(pin=wipe_code_encoded)) # wait 30 seconds for emulator to shut down # this will raise a TimeoutError if the emulator doesn't die. emu.wait(30) emu.start() + emu.client.refresh_features() assert emu.client.features.initialized is False assert emu.client.features.pin_protection is False assert emu.client.features.wipe_code_protection is False diff --git a/tests/translations.py b/tests/translations.py index afb12a5fec2..be2c4e762e0 100644 --- a/tests/translations.py +++ b/tests/translations.py @@ -8,7 +8,7 @@ from trezorlib import cosi, device, models from trezorlib._internal import translations -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from . import common @@ -58,19 +58,19 @@ def sign_blob(blob: translations.TranslationsBlob) -> bytes: def build_and_sign_blob( lang_or_def: translations.JsonDef | Path | str, - client: Client, + session: Session, ) -> bytes: - blob = prepare_blob(lang_or_def, client.model, client.version) + blob = prepare_blob(lang_or_def, session.model, session.version) return sign_blob(blob) -def set_language(client: Client, lang: str): +def set_language(session: Session, lang: str): if lang.startswith("en"): language_data = b"" else: - language_data = build_and_sign_blob(lang, client) - with client: - device.change_language(client, language_data) # type: ignore + language_data = build_and_sign_blob(lang, session) + with session: + device.change_language(session, language_data) # type: ignore _CURRENT_TRANSLATION.TR = TRANSLATIONS[lang] diff --git a/tests/ui_tests/__init__.py b/tests/ui_tests/__init__.py index 2213a03dabd..5d542578293 100644 --- a/tests/ui_tests/__init__.py +++ b/tests/ui_tests/__init__.py @@ -8,6 +8,7 @@ from _pytest.nodes import Node from _pytest.outcomes import Failed +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import TrezorClientDebugLink as Client from . import common @@ -56,11 +57,14 @@ def screen_recording( yield finally: client.ensure_open() - client.sync_responses() + if client.protocol_version == ProtocolVersion.PROTOCOL_V1: + client.sync_responses() # Wait for response to Initialize, which gives the emulator time to catch up # and redraw the homescreen. Otherwise there's a race condition between that # and stopping recording. - client.init_device() + + # Instead of client.init_device() we create a new management session + client.get_seedless_session() client.debug.stop_recording() result = testcase.build_result(request) diff --git a/tests/ui_tests/fixtures.json b/tests/ui_tests/fixtures.json index 72d6216e595..5c41e6385eb 100644 --- a/tests/ui_tests/fixtures.json +++ b/tests/ui_tests/fixtures.json @@ -864,19 +864,19 @@ "T2T1_en_test_passphrase_bolt.py::test_passphrase_prompt_disappears": "d051fc05dc3af0c685de6ec8f00b0ab4facf8f6fd49dcece503fa15261d6c90a", "T2T1_en_test_pin.py::test_last_digit_timeout": "a170405f1451dd9092afc83c4326e08d9076e52a6ef20c940ff916baa739e9c3", "T2T1_en_test_pin.py::test_pin_cancel": "477133459306a2a9f64fc2bd3abeebaf67a678e59bdd1418db4536b2be4e657f", -"T2T1_en_test_pin.py::test_pin_change": "b3dccad89be83c8a5c62169b861835730ad46bf01fd0cf598c47b2ebb6cd3e14", -"T2T1_en_test_pin.py::test_pin_delete_hold": "a667ab0e8b32e0c633ff518631c41e4f8883beab0fb9eba2e389b67326230f8a", -"T2T1_en_test_pin.py::test_pin_empty_cannot_send": "43bbc9818d48677f0f03e70a4a94d5fde13b4629f45b570a777375d19de104b7", -"T2T1_en_test_pin.py::test_pin_incorrect": "033b19549e9d7e351e1f4184599775444cda22b146a0efebed6050176123732d", -"T2T1_en_test_pin.py::test_pin_long": "d70d5a0a58f910fb8d48d2c207aaabeda49c28b024fdd6c62773bb33ca93c66e", -"T2T1_en_test_pin.py::test_pin_long_delete": "6f888b50d0e62a0da3321cbf4e70055273a0721e6344f561468fe0474d4f4c8c", -"T2T1_en_test_pin.py::test_pin_longer_than_max": "bde7e9c8d3be494c8757973459d8a944e5cbc0db6067c383be9833be7bf9deb4", -"T2T1_en_test_pin.py::test_pin_same_as_wipe_code": "290606d09ad41c9f741e75e3e757d1881b7f13c2fe762935b542b11a82329b1c", -"T2T1_en_test_pin.py::test_pin_setup": "2a77bf25fd3b7601d68ba6e13bba43cb947da41b4fd61eb10b73ac079926f881", -"T2T1_en_test_pin.py::test_pin_setup_mismatch": "21d3063f21659942e3d7e40a377f44a6ae3e8e3fac9cac42c3bb49bd05d31156", -"T2T1_en_test_pin.py::test_pin_short": "43bbc9818d48677f0f03e70a4a94d5fde13b4629f45b570a777375d19de104b7", -"T2T1_en_test_pin.py::test_wipe_code_same_as_pin": "1a98ffdf03a0ab799dd794c3215d52bced51fbff622a9499855d87d85cdc0850", -"T2T1_en_test_pin.py::test_wipe_code_setup": "506353c53d464a32d28557965ebe951c2dbcbf1fa64194f7b943873991f14920", +"T2T1_en_test_pin.py::test_pin_change": "22780139b80e64be55e0b06ba7e30a7bf50aedd9a3b5efdcadd209016a861650", +"T2T1_en_test_pin.py::test_pin_delete_hold": "b0ad94241ff310c2420c3e139ef20dd950cc49ea41cc3a15a8063ca46f9b846e", +"T2T1_en_test_pin.py::test_pin_empty_cannot_send": "1ea497ef03af36316e2c2f0f8d8d129b708ff55973b755ad784f49051027f93f", +"T2T1_en_test_pin.py::test_pin_incorrect": "edb24511dd44e09a8e3956154a896f562925c7f2e06ceccdc0ee442d9a86e2b3", +"T2T1_en_test_pin.py::test_pin_long": "8a400098a949714ec45eefaa2ef082ad2b5fec646ef76659f73d2d688b107401", +"T2T1_en_test_pin.py::test_pin_long_delete": "90f45e06e4a2224420738c02849ff4e9fcfaf9a31293d46a41d478e1f08a0ab8", +"T2T1_en_test_pin.py::test_pin_longer_than_max": "720987d80d19b92a3722fcdd4f818edb901ff8e7e2da023d41af972e2af1b159", +"T2T1_en_test_pin.py::test_pin_same_as_wipe_code": "cd82be6ab9d456c6c2ced6b247664c6917bdeb52241a3ece8e26c8e0a8905d51", +"T2T1_en_test_pin.py::test_pin_setup": "ab3141a83ed7e4acdd622c234289849a90c4a9b2f7ec4e8d03c909bce5d620f0", +"T2T1_en_test_pin.py::test_pin_setup_mismatch": "b200a882d69330b05774a887b52199dcf32729056e8792f7cabd05be4bae6b50", +"T2T1_en_test_pin.py::test_pin_short": "1ea497ef03af36316e2c2f0f8d8d129b708ff55973b755ad784f49051027f93f", +"T2T1_en_test_pin.py::test_wipe_code_same_as_pin": "3725a3942a40441f4d94a35378d283f42b67e059e23522c2a4663fcaafee3cd0", +"T2T1_en_test_pin.py::test_wipe_code_setup": "25bc375fb004952492c707312c6ead7b733b1c06b80aec04f2aef19d77d551be", "T2T1_en_test_recovery.py::test_recovery_bip39": "7dac6daa1501680327e1f4e231d2688e34098d5ed15327258a2ef6543af49656", "T2T1_en_test_recovery.py::test_recovery_bip39_previous_word": "acafd9ab4375113cab57f29a0bbd07f5c29651a9e107aa3d82fed37b368cbe04", "T2T1_en_test_recovery.py::test_recovery_slip39_basic": "99678a6fa374037bce7880bb749f04b5c6df5abacd28840221cbc1f9307c04bb", @@ -3965,8 +3965,8 @@ "T2T1_en_bitcoin-test_authorize_coinjoin.py::test_get_address": "b4324b15dba7a207a37cba43b20f6ecdbfe34dfbedbd48daa738e013e076f2fe", "T2T1_en_bitcoin-test_authorize_coinjoin.py::test_get_public_key": "88cd73984ce16c4f9875977848d66d0d077bd3cf41b54b8c9785def18d2b27e7", "T2T1_en_bitcoin-test_authorize_coinjoin.py::test_multisession_authorization": "f2e1e127dc7c8a6c96944844b6ea881c2f806f9a9a6b67da371390f5bee09050", -"T2T1_en_bitcoin-test_authorize_coinjoin.py::test_sign_tx[False]": "7f058e56fc8249224956ec4cd3fa8d64d2e8b7e257b2a6a17b1786526ae7bd18", -"T2T1_en_bitcoin-test_authorize_coinjoin.py::test_sign_tx[True]": "7f058e56fc8249224956ec4cd3fa8d64d2e8b7e257b2a6a17b1786526ae7bd18", +"T2T1_en_bitcoin-test_authorize_coinjoin.py::test_sign_tx[False]": "72e26d8ee8a859fa919d5ebe81ca686e0516b6f80eaa2cf88baed7f83e521ee6", +"T2T1_en_bitcoin-test_authorize_coinjoin.py::test_sign_tx[True]": "72e26d8ee8a859fa919d5ebe81ca686e0516b6f80eaa2cf88baed7f83e521ee6", "T2T1_en_bitcoin-test_authorize_coinjoin.py::test_sign_tx_large": "d76ca131de94a33d526dc6cf562f5a236a7bffa909279cfec165ae7f1ad96d05", "T2T1_en_bitcoin-test_authorize_coinjoin.py::test_sign_tx_migration": "c49427c655671c0a82b1c561219c3a98f89f040fca356c601de0e0c62a4e2fd0", "T2T1_en_bitcoin-test_authorize_coinjoin.py::test_sign_tx_spend": "360668f93741aa043881a329c5750964ac3a92c6dd384ab4e0c65c2838245cb1", @@ -4554,7 +4554,8 @@ "T2T1_en_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICA-3b0af713": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.ICARUS]": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_cardano-test_derivations.py::test_derivation_irrelevant_on_slip39[CardanoDerivationType.LEDGER]": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", -"T2T1_en_cardano-test_derivations.py::test_ledger_available_always": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", +"T2T1_en_cardano-test_derivations.py::test_ledger_available_with_cardano": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", +"T2T1_en_cardano-test_derivations.py::test_ledger_available_without_cardano": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script]": "bb8c74d4180f2d57eb785179466d23c2cfe48bdda77156516beaa12635c2b5de", "T2T1_en_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-aae1e6c3": "5939e15589fbaf035ce23c76f0c9ccd647f4845e75fac13ba1c65a9f99435d81", "T2T1_en_cardano-test_get_native_script_hash.py::test_cardano_get_native_script_hash[all_script_cont-e4ca0ea5": "665198937b04fec2d2478a19dc96ac11f6d6f6e8f8e8dc0f1a9d3feedaf3e834", @@ -5067,12 +5068,12 @@ "T2T1_en_reset_recovery-test_recovery_slip39_basic.py::test_wrong_nth_word[2]": "07576510508fd5bdeeb8c8c441a3707aa85167605e31c3557475cd31fbbe688a", "T2T1_en_reset_recovery-test_recovery_slip39_basic_dryrun.py::test_2of3_dryrun": "e86b1694c918ad9be15595ba812e45904264a1aa06199a4ac423edcce6a144ed", "T2T1_en_reset_recovery-test_recovery_slip39_basic_dryrun.py::test_2of3_invalid_seed_dryrun": "7c4b959f5742054ee34e7b7a78d6123e9b8f16f910fd217d33347eacd13e22bf", -"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Bip39-backup_flow_bip39]": "685ec41d860f8c235e2cbc5dddf5c0c74816c3e05fbb0fa3b894354d32c46434", -"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Slip39_Advanced_Ext-10ea47d6": "7dac291289bbf2db94362615565e3c18aec61c31ec4b26be551c537e175af6c3", -"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Slip39_Basic_Extend-5dbe8b0f": "ddeed0604093458a8a42d55cd2b5e2b5b2459e8d593d3fb3710243d50a7ecd76", -"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Bip39-backup_flow_bip39]": "5ef25dac555000986c757955c40f18e336979819936f487c70c9ecd633acc653", -"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Slip39_Advanced_Extend-8b11c1dc": "c4d39c0bb7d54c71a0a87e3ee912b5bb5ce6919693c210596cc316ab57deb48f", -"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Slip39_Basic_Extendabl-cc19e908": "ff2d873c486d3f6f55bf0d6558aa759db2b42a8aa1e697a76eb3abf0b0144844", +"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Bip39-backup_flow_bip39]": "24f33fb3ad69e0f18f83d8b03940aa1771a64302923751178abeb90d13cd27c2", +"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Slip39_Advanced_Ext-10ea47d6": "a8c445a1ef669bb842eb054430239c164b259ffa43a6c315fb77ae0836a60755", +"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_manual[BackupType.Slip39_Basic_Extend-5dbe8b0f": "1511f33bfacc3e88015f2b407a0ac226b4c9f7bf3f6ba6dd65ca229f4911b3bf", +"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Bip39-backup_flow_bip39]": "c62b6304143661cc5642b80b367c413a3aa0dbfc476bf628a88ee0be7d6c09ea", +"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Slip39_Advanced_Extend-8b11c1dc": "b8e9b35ccc909429d204181ea7cec6716f5b15c80fc57f685ac8f234a1cdc852", +"T2T1_en_reset_recovery-test_reset_backup.py::test_skip_backup_msg[BackupType.Slip39_Basic_Extendabl-cc19e908": "775feff539b81e0329c705900dee03834696d2b2f1bd27154bb30e387270964b", "T2T1_en_reset_recovery-test_reset_bip39_t2.py::test_already_initialized": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_reset_recovery-test_reset_bip39_t2.py::test_entropy_check": "f4d850259404fdef6a340ba2245e2d1cc3e68f06ef5eea819b91b3015bc4b4f4", "T2T1_en_reset_recovery-test_reset_bip39_t2.py::test_failed_pin": "2561ba9b866f53847e8b00bf1cf2eb29946fd1df66e96686b327ea63b067aa71", @@ -5226,29 +5227,29 @@ "T2T1_en_stellar-test_stellar.py::test_sign_tx[timebounds-461535181-0]": "6f667b9545ab865b58a8d1cd6ebbe8a6c502d8e312540c224afc647250cc8fbc", "T2T1_en_stellar-test_stellar.py::test_sign_tx[timebounds-461535181-1575234180]": "a44ea241432fe7ee7fe9629eebe1011e66ee480a8e002e1a561f67cc5f4124eb", "T2T1_en_stellar-test_stellar.py::test_sign_tx[tx_source_account_not_equal_signing_key]": "625c54e8e8da8b9e9d29df67a978c815d84a78a2a9c77d49540fba55dac06b64", -"T2T1_en_test_autolock.py::test_apply_auto_lock_delay": "6edaa315d6f038055f4fbe5eccfacdddddd34102d2ac9a6c136db4ab46ba6065", -"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[0]": "c34ff9e8d47a33013b85dc35c9b45804a557ad2240cb72853e2bac9e4dc94a6a", -"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[1]": "c34ff9e8d47a33013b85dc35c9b45804a557ad2240cb72853e2bac9e4dc94a6a", -"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[4194304]": "c34ff9e8d47a33013b85dc35c9b45804a557ad2240cb72853e2bac9e4dc94a6a", -"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[536871]": "c34ff9e8d47a33013b85dc35c9b45804a557ad2240cb72853e2bac9e4dc94a6a", -"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[9]": "c34ff9e8d47a33013b85dc35c9b45804a557ad2240cb72853e2bac9e4dc94a6a", -"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_valid[10]": "871b8370825185284b9ad5e8b9ceda31356b21e1f393b8b34fc7ae85442dfbb5", -"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_valid[123]": "ba628bf2cb774bf010499ab63253d4e3ad0f4db4ea6cff17afb890ab46727cb8", -"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_valid[3601]": "493f46ea5bee73672c6654558f2fc45c30dc1aa2ea5b2cc16657940098f4ed79", -"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_valid[536870]": "cb1acfb370e33aed8ba279f69a53f7f1f6edd3bfb58d958d3ba9d09e8b38389d", -"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_valid[60]": "2ebf21e1d9e7a8e884c87b98b0b3482cd19b3eaa4b066d68ab0a9ae85b551cc8", -"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_valid[7227]": "9e17998aa27515cd8c4d5da4c2a8d62608e7099b183d366bce11506c443a5ae9", -"T2T1_en_test_autolock.py::test_autolock_cancels_ui": "a413b0427b162b48c29eaa1630eb115eba93f3ee6c6105cfc69ada7bac85dfe8", -"T2T1_en_test_autolock.py::test_autolock_default_value": "4ad05386b2adf9eddecd7200f18e1d14dfd40dbf898a5c4732d15f1025b65de3", -"T2T1_en_test_autolock.py::test_autolock_ignores_getaddress": "b1daf580d1da4fa5e11e77b53b008f5a08aaa338db1b6658ae965c55246184c9", -"T2T1_en_test_autolock.py::test_autolock_ignores_initialize": "b1daf580d1da4fa5e11e77b53b008f5a08aaa338db1b6658ae965c55246184c9", +"T2T1_en_test_autolock.py::test_apply_auto_lock_delay": "f303115ff7660afbcf79a4a0c897ce3d54dd342b9b3298295571ba9cdff5dc7b", +"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[0]": "c4114a8be5e3e30af4edd2b20116211c809c8402169cbffcbf70f9c86bcc4ae0", +"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[1]": "c4114a8be5e3e30af4edd2b20116211c809c8402169cbffcbf70f9c86bcc4ae0", +"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[4194304]": "c4114a8be5e3e30af4edd2b20116211c809c8402169cbffcbf70f9c86bcc4ae0", +"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[536871]": "c4114a8be5e3e30af4edd2b20116211c809c8402169cbffcbf70f9c86bcc4ae0", +"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_out_of_range[9]": "c4114a8be5e3e30af4edd2b20116211c809c8402169cbffcbf70f9c86bcc4ae0", +"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_valid[10]": "763c1c0b793ba56481aa0094d278630252381415381377f5e0f78431170ae5a3", +"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_valid[123]": "34757fc0e9e4f91cf04953998bccd231157ce772fc4b03f3bc933d4f90e8cf05", +"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_valid[3601]": "f76d81c214f2e010c1c721064996fd39909ea390b1d33a10e7d5f704e87f7438", +"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_valid[536870]": "e944f909df714a4e3909cf322e628b7f51c498a1f848deac5a29ce394e814f61", +"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_valid[60]": "2160e34e10e3ddcafc80f18bdc5a421165c4e2c6aeb2a1fcaa7c953ca28ec8bb", +"T2T1_en_test_autolock.py::test_apply_auto_lock_delay_valid[7227]": "4cee4bf1c17b669766187e34135792c1360e0146979fa854e0e5b65ab7a0bd06", +"T2T1_en_test_autolock.py::test_autolock_cancels_ui": "7f618c591571f8836875f686894831c2e9bfed3b700106144c730daad139d5d0", +"T2T1_en_test_autolock.py::test_autolock_default_value": "ef5784e6b0044440ba453e851b019d7c474291d170868988075286486801e48b", +"T2T1_en_test_autolock.py::test_autolock_ignores_getaddress": "1d9bbca2294b5cb48d68907d922dcfa933939ff5cee05f41ae1eda18fd784679", +"T2T1_en_test_autolock.py::test_autolock_ignores_initialize": "1d9bbca2294b5cb48d68907d922dcfa933939ff5cee05f41ae1eda18fd784679", "T2T1_en_test_basic.py::test_capabilities": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_test_basic.py::test_device_id_different": "6aa7c9a4b9599d7abfe8ce860ede3830249ef3e68b312036df1e32dcd48ea6b8", "T2T1_en_test_basic.py::test_device_id_same": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_test_basic.py::test_features": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_test_basic.py::test_ping": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_test_busy_state.py::test_busy_expiry_core": "725eefdcf0a01c29cd7d5cb3d9c467e2c3c40e6a6e037da92625f1d415d4f3e3", -"T2T1_en_test_busy_state.py::test_busy_state": "2d34c06b02cb36da44daea2341c4dd65f564fb7e530e64881a95187d6ca977c5", +"T2T1_en_test_busy_state.py::test_busy_state": "d3611171126e3efd9bac6a8cbf7ebf63b05aa0621fac418e0179ebe62cd5d249", "T2T1_en_test_cancel.py::test_cancel_message_via_cancel[message0]": "6c42efb29e5a843dd3322ab7e9f8a4ddd0eae2d6d25935c7f4c30cc657c8d936", "T2T1_en_test_cancel.py::test_cancel_message_via_cancel[message1]": "6b509c6baf3516715969ba97d41a5e157419928f52624e19603b0a8807d8d971", "T2T1_en_test_cancel.py::test_cancel_message_via_initialize[message0]": "6c42efb29e5a843dd3322ab7e9f8a4ddd0eae2d6d25935c7f4c30cc657c8d936", @@ -5282,39 +5283,39 @@ "T2T1_en_test_language.py::test_switch_from_english_not_silent": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_test_language.py::test_switch_language": "7bc47ca05bb7a7e15ba5e1b0ffe8565e40a23a59ca02ff8887506f6b351f0368", "T2T1_en_test_language.py::test_translations_renders_on_screen": "a2f3ba7e0af0f5a27b369e565d0177fe313e501bb39f5da2ad134d9cf2b4284d", -"T2T1_en_test_msg_applysettings.py::test_apply_homescreen_jpeg": "a8ec4a261b75e205f35b296247938f16a17d670ec3d182991f9c3238f350674d", -"T2T1_en_test_msg_applysettings.py::test_apply_homescreen_jpeg_progressive": "c34ff9e8d47a33013b85dc35c9b45804a557ad2240cb72853e2bac9e4dc94a6a", -"T2T1_en_test_msg_applysettings.py::test_apply_homescreen_jpeg_wrong_size": "c34ff9e8d47a33013b85dc35c9b45804a557ad2240cb72853e2bac9e4dc94a6a", -"T2T1_en_test_msg_applysettings.py::test_apply_homescreen_toif": "c34ff9e8d47a33013b85dc35c9b45804a557ad2240cb72853e2bac9e4dc94a6a", -"T2T1_en_test_msg_applysettings.py::test_apply_settings": "4098a76e826f702c012b7997b7132403f9c07c81a10b9c0ef4f9dae6e1f90763", -"T2T1_en_test_msg_applysettings.py::test_apply_settings_passphrase": "4c412c6bd8e0222d55c87e61152abb8ce14b06ec0e4c8ccdac817912bb184fc6", +"T2T1_en_test_msg_applysettings.py::test_apply_homescreen_jpeg": "d2eaba56eefd55f238df08d4f01ae5365f8b0e3a82a011ba5355732ab31767d4", +"T2T1_en_test_msg_applysettings.py::test_apply_homescreen_jpeg_progressive": "c4114a8be5e3e30af4edd2b20116211c809c8402169cbffcbf70f9c86bcc4ae0", +"T2T1_en_test_msg_applysettings.py::test_apply_homescreen_jpeg_wrong_size": "c4114a8be5e3e30af4edd2b20116211c809c8402169cbffcbf70f9c86bcc4ae0", +"T2T1_en_test_msg_applysettings.py::test_apply_homescreen_toif": "c4114a8be5e3e30af4edd2b20116211c809c8402169cbffcbf70f9c86bcc4ae0", +"T2T1_en_test_msg_applysettings.py::test_apply_settings": "eafd022791016d723bed5f6eca899a5b164fa38e34a21bf965ab97161e0ac4ec", +"T2T1_en_test_msg_applysettings.py::test_apply_settings_passphrase": "9d57c692539ad81c471c12e7922d108034034c9afe32620ddf1295c79c905550", "T2T1_en_test_msg_applysettings.py::test_apply_settings_passphrase_on_device": "e5dff4c4635ba6cd638b0db0ad06d90c79b70cced704b1c65f7b22a8b60d0d6c", -"T2T1_en_test_msg_applysettings.py::test_apply_settings_rotation": "c08b570382d8ec6f12ba090a2c2c96bc6dc1725b88d4536132e1786c8492835b", -"T2T1_en_test_msg_applysettings.py::test_experimental_features": "5c82db1025af207c81aa9c8ecacceab148dbc8202642b75f16215b128392f510", +"T2T1_en_test_msg_applysettings.py::test_apply_settings_rotation": "cb082739a7a16b579885852a253e5499ec04b03eeb69620945f600d6784e40f5", +"T2T1_en_test_msg_applysettings.py::test_experimental_features": "58c927e0b827512a9cc99d61e6d34cd58abe6337ca19a02600b9f9f6e871d0fd", "T2T1_en_test_msg_applysettings.py::test_label_too_long": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_test_msg_applysettings.py::test_safety_checks": "b9b9aec8a5789794ae4cae7bb96b286451a017096db804ce182fece10122caff", "T2T1_en_test_msg_applysettings.py::test_set_brightness": "6dbb950febdafc731112225c8059f2a7079b92201378a470376126dc5330b2b6", -"T2T1_en_test_msg_backup_device.py::test_backup_bip39": "90baa1b711bf16f0a3f7bd47dfe18e31da476689d601e6c74aa578a1836561f2", -"T2T1_en_test_msg_backup_device.py::test_backup_slip39_advanced[click_info]": "96da39b91f26d1620bb5a3f678f5db99e9ea5dfb921ff2e8c2bae2195754f013", -"T2T1_en_test_msg_backup_device.py::test_backup_slip39_advanced[no_click_info]": "726e801238c70be4311f9dd6f1c997f8f22e8559e145574520de3dd199d7566d", -"T2T1_en_test_msg_backup_device.py::test_backup_slip39_basic[click_info]": "f915af2252387adb0bf7825004a66d269f0bb6b7866d88e3a4bf3a73287b6f5b", -"T2T1_en_test_msg_backup_device.py::test_backup_slip39_basic[no_click_info]": "aa92ec8f3897c1c5b2c12d919dc887d0676b372b830416598e8fc6437fc8e91f", -"T2T1_en_test_msg_backup_device.py::test_backup_slip39_custom[1_of_1]": "45c77c9db20f5b1d2213452f83d54e820d50f02af35935257edc68d168d8b7af", -"T2T1_en_test_msg_backup_device.py::test_backup_slip39_custom[2_of_2]": "9249b3e8fe1ca35a3bef0c0b8ccadac9dff08b68986e473a13de26e7fd741e48", -"T2T1_en_test_msg_backup_device.py::test_backup_slip39_custom[3_of_5]": "edfa808de10e41770facb759853cff0e06900e98fd41d29e53709b0171810e41", -"T2T1_en_test_msg_backup_device.py::test_backup_slip39_single": "55e19db7d4afad012c2f05a6a8322d22e0de7ffa0da407a66fb0ee411e3b4a36", +"T2T1_en_test_msg_backup_device.py::test_backup_bip39": "da666d177da65010a87bc3eca4a2164f2a25173b5fa70acdc3086fea7eefd183", +"T2T1_en_test_msg_backup_device.py::test_backup_slip39_advanced[click_info]": "0bc14569e2e77bd06cc541539dfaf7a3af62a6d1bc43ad215fa7ab6d0b80e6cb", +"T2T1_en_test_msg_backup_device.py::test_backup_slip39_advanced[no_click_info]": "e4ae8f266321e86b1a213453c35ea819608b1fbf03434258d69f78e3334fcbb2", +"T2T1_en_test_msg_backup_device.py::test_backup_slip39_basic[click_info]": "945ce559795f706a619ed58c4eace534509a4f78069eff9922d9c3f154192d3c", +"T2T1_en_test_msg_backup_device.py::test_backup_slip39_basic[no_click_info]": "a1dbcafea80920a996a6b6cea0e094e0993b4b26049f5418bdd11143f0d6efde", +"T2T1_en_test_msg_backup_device.py::test_backup_slip39_custom[1_of_1]": "2376e4e472ff40b027fb57a000fa943ef47c469358cac90c6478214834755c4a", +"T2T1_en_test_msg_backup_device.py::test_backup_slip39_custom[2_of_2]": "cadf0ae7b72ad6d093742ccdac3ab9401e82e7050aabdfc545890a57f9841bfd", +"T2T1_en_test_msg_backup_device.py::test_backup_slip39_custom[3_of_5]": "15dc324896766f9dd95cafe0c70537eb6a716101c9d3a39096bf8e82c9c46b4f", +"T2T1_en_test_msg_backup_device.py::test_backup_slip39_single": "f0141998927d5eaee40bbc0e2ccf0df902c583cea209c56b57355f2d2114fced", "T2T1_en_test_msg_backup_device.py::test_interrupt_backup_fails": "e6df3e89ebe2f90cc45aadde502642d3d168cc79caae0a7c966a7ed7e4af8d5c", "T2T1_en_test_msg_backup_device.py::test_no_backup_fails": "068b76a6436dd51195b719f513e17421c7c41e0bdf9ea10b9f81305ea927cef3", -"T2T1_en_test_msg_change_wipe_code_t2.py::test_set_pin_to_wipe_code": "aa58b24d34921fcfea1b2dae599e7d165f6de4a557edaa6ac1a04d92e654f3fd", -"T2T1_en_test_msg_change_wipe_code_t2.py::test_set_remove_wipe_code": "b6cdf4e95649350d95695152f037e5759e168ad16900887a047b6626696a6e32", -"T2T1_en_test_msg_change_wipe_code_t2.py::test_set_wipe_code_mismatch": "dd76af9f691edbb87ab424e33990030c08c24fb8f0c53c162101df0197567a0a", -"T2T1_en_test_msg_change_wipe_code_t2.py::test_set_wipe_code_to_pin": "9dff41569145ed02aeb13af5c299bd37befae0021703fd4563f06ececfa7d4dd", -"T2T1_en_test_msg_changepin_t2.py::test_change_failed": "8392cbcb12e58c355654699c8b7fe9e3c111c5ab19a66326c9e311ff99221980", -"T2T1_en_test_msg_changepin_t2.py::test_change_invalid_current": "7b03238fa54ef02cadfb1a21acca568c50b89645b48ab2554f1f23940f5d8cc6", -"T2T1_en_test_msg_changepin_t2.py::test_change_pin": "bc042fe1f99a208aed34ec55c94a071aacbdbb0e2817cff4e0340b88f3633ccf", -"T2T1_en_test_msg_changepin_t2.py::test_remove_pin": "1cada68e3519daedcd638d3ac09fd1038f427ce25b007cc2fdbca03adac6c034", -"T2T1_en_test_msg_changepin_t2.py::test_set_failed": "5e25fdb8ab1862f14e961472a795d0e27fb71c707d08d371e058942c02c63d77", -"T2T1_en_test_msg_changepin_t2.py::test_set_pin": "eb1edcd1c762e9ee9d740dfa9723c81d4d6108f1be7f9e5fe4b9747044feb2e4", +"T2T1_en_test_msg_change_wipe_code_t2.py::test_set_pin_to_wipe_code": "6637d756aed5b2a7b748620adc51d8055f45211f52ae62c214ae848cae205fd3", +"T2T1_en_test_msg_change_wipe_code_t2.py::test_set_remove_wipe_code": "bc0f7d64306965cd8d6623f6d00c86f6c9a33d0689e0e7ebf7c8c22d0c1edeb8", +"T2T1_en_test_msg_change_wipe_code_t2.py::test_set_wipe_code_mismatch": "ba0aab3b4996207d883d36cf5273a2b30dc6fd952026d623850e0fd5d5f008ac", +"T2T1_en_test_msg_change_wipe_code_t2.py::test_set_wipe_code_to_pin": "ad2556a7039643c245a806903ca238f319e4f10eebd1bdeb70f940dd0dc60253", +"T2T1_en_test_msg_changepin_t2.py::test_change_failed": "ff2d1c551cbb02fad2fbfc328a8bee806b08762c46a2d7b0a8cfa1c6ffe8db1c", +"T2T1_en_test_msg_changepin_t2.py::test_change_invalid_current": "8cf2ad3141c918cf8b9e674fc36a136b04a5409de07e31873690e78645d72f78", +"T2T1_en_test_msg_changepin_t2.py::test_change_pin": "936544171a15442acaafc4dc41705ba66b9c7f555425e5351ebe6d86ac97cd07", +"T2T1_en_test_msg_changepin_t2.py::test_remove_pin": "336b045c7027e5c414114e7ac7f3e5282377b7eafba0286aaa31e679500892a1", +"T2T1_en_test_msg_changepin_t2.py::test_set_failed": "060a0bfdc9a85e607529a709308e385becb4fe09aa21036b6e4a82849ac3ad07", +"T2T1_en_test_msg_changepin_t2.py::test_set_pin": "9285726fdc7b133c671cd6b2cab37e000d1127ffc1d2ae51103b59149ceb8052", "T2T1_en_test_msg_loaddevice.py::test_load_device_1": "925e0161d96bc9e1c9ff32a682d8eda903d0c5a3b9750d0c6ae2814d159b8e82", "T2T1_en_test_msg_loaddevice.py::test_load_device_2": "e9fe2a1843fe8416fe371a98739dfa1a4959363efb47676a908b78e7975e2e49", "T2T1_en_test_msg_loaddevice.py::test_load_device_slip39_advanced": "925e0161d96bc9e1c9ff32a682d8eda903d0c5a3b9750d0c6ae2814d159b8e82", @@ -5324,41 +5325,41 @@ "T2T1_en_test_msg_sd_protect.py::test_enable_disable": "fa93d1ab8cb72e41a4244f6504c918a3ce660a421e7f7f2e1ee1547ce40272e9", "T2T1_en_test_msg_sd_protect.py::test_refresh": "ff0c0c9c3d840cf39974348fade80373626404a6c1c347e7fea0aeaae9fdfe66", "T2T1_en_test_msg_sd_protect.py::test_wipe": "d3ebaf352f40c243f085f600eacb75c08a7f12d86ebf4b16a33c010ad2cdeebd", -"T2T1_en_test_msg_wipedevice.py::test_autolock_not_retained": "1c594c9efcb9fb5dcb7500e0ba0112bdf9e4f9c52bb576a0154da32bc619647e", +"T2T1_en_test_msg_wipedevice.py::test_autolock_not_retained": "cead5ad28ada8e2629a6df0dbd3e9791113350b34a07de16e0ce9aa2e5f98f37", "T2T1_en_test_msg_wipedevice.py::test_wipe_device": "6aa7c9a4b9599d7abfe8ce860ede3830249ef3e68b312036df1e32dcd48ea6b8", "T2T1_en_test_passphrase_slip39_advanced.py::test_128bit_passphrase": "ea3a5d16e7b1e6b9bf8c45132f935a9af3149a3b280a83eace2297c675a7a917", "T2T1_en_test_passphrase_slip39_advanced.py::test_256bit_passphrase": "ea3a5d16e7b1e6b9bf8c45132f935a9af3149a3b280a83eace2297c675a7a917", "T2T1_en_test_passphrase_slip39_basic.py::test_2of3_ext_passphrase": "363d8cc034a9afa5c1aabc271b8c71c4a117f94d97874738bfd496376d6738b6", "T2T1_en_test_passphrase_slip39_basic.py::test_2of5_passphrase": "363d8cc034a9afa5c1aabc271b8c71c4a117f94d97874738bfd496376d6738b6", "T2T1_en_test_passphrase_slip39_basic.py::test_3of6_passphrase": "363d8cc034a9afa5c1aabc271b8c71c4a117f94d97874738bfd496376d6738b6", -"T2T1_en_test_pin.py::test_correct_pin": "c34ff9e8d47a33013b85dc35c9b45804a557ad2240cb72853e2bac9e4dc94a6a", -"T2T1_en_test_pin.py::test_exponential_backoff_t2": "0f1457c286b7de45bb83e313eab0996bb31a09be8c7e5ca85da1c353413b8072", -"T2T1_en_test_pin.py::test_incorrect_pin_t2": "b3c18ef8c497015c357ff95c9290b8a8171e6e13bbae47858721d76fb3e57d94", +"T2T1_en_test_pin.py::test_correct_pin": "c4114a8be5e3e30af4edd2b20116211c809c8402169cbffcbf70f9c86bcc4ae0", +"T2T1_en_test_pin.py::test_exponential_backoff_t2": "09f432e9c800c751817f651868e7feab85fca74f2f6ccdfa03eeedfef8fe66bc", +"T2T1_en_test_pin.py::test_incorrect_pin_t2": "57441093f4eed7886ebceab9b6aa3fe79a4d9fd1a9e380c54746a96f14c6bbc9", "T2T1_en_test_pin.py::test_no_protection": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", -"T2T1_en_test_protection_levels.py::test_apply_settings": "1e274b007a153d6747cca00fb325bb1beaa013e0d2c1edfd8d5888748476b823", -"T2T1_en_test_protection_levels.py::test_change_pin_t2": "68d1b4e7c568f60628ca5c7df51827dd1b714d4624f0716497d934aa6a620c40", -"T2T1_en_test_protection_levels.py::test_get_address": "d4c60218fa775ea63b3211e147b9c8383729a8899fb6efdcdd4ccc471ea52bbc", -"T2T1_en_test_protection_levels.py::test_get_entropy": "9c0ff6d283a7f808ce4490d4a65dd4918c037ac82a200880d8918c6017b5c46a", -"T2T1_en_test_protection_levels.py::test_get_public_key": "d4c60218fa775ea63b3211e147b9c8383729a8899fb6efdcdd4ccc471ea52bbc", -"T2T1_en_test_protection_levels.py::test_initialize": "22017d4b6c83d0ae5b7bf0ff2b23b2df3199b38548a7c336d8d401d8c6d33615", +"T2T1_en_test_protection_levels.py::test_apply_settings": "eb5cf860adc79f5e1b091d1f86a25fbcf6ff5ff82897f08b6a72afaa3e9066bf", +"T2T1_en_test_protection_levels.py::test_change_pin_t2": "324b841db1f36b5c97c80b0d751ba875b04177533ef254c93cb5c34bcad62279", +"T2T1_en_test_protection_levels.py::test_get_address": "f7c3b44c00d2526340c7230993fb770052f164d078624460d7a1a3e0acd979a1", +"T2T1_en_test_protection_levels.py::test_get_entropy": "c45e9c878a2abb6b329cba8784edfd93234b4e99dc7999c12114a0315abc76b2", +"T2T1_en_test_protection_levels.py::test_get_public_key": "f7c3b44c00d2526340c7230993fb770052f164d078624460d7a1a3e0acd979a1", +"T2T1_en_test_protection_levels.py::test_initialize": "7275aacf62695434e1669d5d12fa87caf760d812816a34b9bfa209f659018670", "T2T1_en_test_protection_levels.py::test_passphrase_cached": "9f6ac3223d0de11cb0498a7c640ba23787345f075228cd7ed45cf4a1e288ec4d", -"T2T1_en_test_protection_levels.py::test_passphrase_reporting[False]": "7ebb8fe84870cf971030f9c9c7029ce1e44b64cfb22dc764068c11a4bc5749ba", -"T2T1_en_test_protection_levels.py::test_passphrase_reporting[True]": "e46e761a46b835168ca8b046a07ea7dfc2a8f0898703e752e245561d574df609", +"T2T1_en_test_protection_levels.py::test_passphrase_reporting[False]": "db61f6c29aee2171fe18acbc29f5ea4f97617bde979ff4463e6b7be461882a43", +"T2T1_en_test_protection_levels.py::test_passphrase_reporting[True]": "ba309ceeef35df842ea7b630f3d2690a7c64feb2ba5cef1d2201b2368bcc1fbe", "T2T1_en_test_protection_levels.py::test_ping": "6c42efb29e5a843dd3322ab7e9f8a4ddd0eae2d6d25935c7f4c30cc657c8d936", -"T2T1_en_test_protection_levels.py::test_sign_message": "007578a302bdc155ecfb76928df3d339dfd2c1b7570ce815187ba8593c83d249", -"T2T1_en_test_protection_levels.py::test_signtx": "0fea2d371cdc37c7bf183c66bce9c9b247fcfc7407b277490f2f53717c48b4bf", -"T2T1_en_test_protection_levels.py::test_unlocked": "a49f6dd35508854f0cdd5fc3027df48e6d15eded24a677a9ed5edc874329dacd", -"T2T1_en_test_protection_levels.py::test_verify_message_t2": "a48f80b4e382f9cacd716f1fcd9cc325697c397db79e209f847b9c03d0e72eb9", -"T2T1_en_test_protection_levels.py::test_wipe_device": "239edbff14ad4d73bc99f95d83e8b42190766a7d97fb62537d94184685156428", -"T2T1_en_test_repeated_backup.py::test_repeated_backup": "b55062c2977f345af5d24fee6d1de212a7ff2edd1b9e61969259c1872546924a", -"T2T1_en_test_repeated_backup.py::test_repeated_backup_cancel": "06dc6421f627a6921ca660fb58a271b0fdb0b77fb72dbce4346b2e1590820189", -"T2T1_en_test_repeated_backup.py::test_repeated_backup_send_disallowed_message": "06dc6421f627a6921ca660fb58a271b0fdb0b77fb72dbce4346b2e1590820189", -"T2T1_en_test_repeated_backup.py::test_repeated_backup_upgrade_single": "1f4f782a41804810c0727a31ba4b29507cca84a1ea06d44b3dc738e438772421", +"T2T1_en_test_protection_levels.py::test_sign_message": "bbafe37d4e9e4fcc51e0a3b113374916f68d7d77ec07d404595f344ee47022bb", +"T2T1_en_test_protection_levels.py::test_signtx": "5530a8a89f48b5c0579040299e2ef4da4bcd6a1dd520cb3e661d37cba1e0f226", +"T2T1_en_test_protection_levels.py::test_unlocked": "56233ab170459b5ae3e8827a83c26e5274c1d87a8d42c6869093c9f345d3b56f", +"T2T1_en_test_protection_levels.py::test_verify_message_t2": "d64ff1dd608ecf7e59fd71de6987fd3bf68307f5b403378749d24e8efa161630", +"T2T1_en_test_protection_levels.py::test_wipe_device": "f7128d0929e6b2d701a1e891a328dffa71acb14b56ae8ae256d41a44d67f40cd", +"T2T1_en_test_repeated_backup.py::test_repeated_backup": "b720ca96a106c895640d7c7a0031de175f2b5bdb6e7af80b11fb725f663fe372", +"T2T1_en_test_repeated_backup.py::test_repeated_backup_cancel": "bf6d389313cc4ecd4af16c98dc19ce4306acaa4ef3e08ec34981d887818a7062", +"T2T1_en_test_repeated_backup.py::test_repeated_backup_send_disallowed_message": "bf6d389313cc4ecd4af16c98dc19ce4306acaa4ef3e08ec34981d887818a7062", +"T2T1_en_test_repeated_backup.py::test_repeated_backup_upgrade_single": "fd547bfd65c4bbeb26f6487adbbd07b299e962f56a2c39cf2d72816b8521c393", "T2T1_en_test_sdcard.py::test_sd_format": "bb4abf3cbc25e1899e9588beb0d78be362fb3b60980d6d9eeffe87cc37c3192d", "T2T1_en_test_sdcard.py::test_sd_no_format": "37358128f7a653920355f51f017859ebcf7ce5450cca9a820842d3982310704e", -"T2T1_en_test_sdcard.py::test_sd_protect_unlock": "147dc4cf6a8d2bfb007793f8c5b142deef60a600143025059b20379d75358c65", +"T2T1_en_test_sdcard.py::test_sd_protect_unlock": "47367ad9c650396a64e8ec5cd55edbdfc07b29f2a0f8187747c25cf034fea0e0", "T2T1_en_test_session.py::test_cannot_resume_ended_session": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", -"T2T1_en_test_session.py::test_clear_session": "13480d22ae643d4f6ce4efde8f996df89a5e4e88d4801fea42846a95f375e5b0", +"T2T1_en_test_session.py::test_clear_session": "3e82537c70910b9f308e8c4c45c909c45b38a3b1151f438999898eceeeabe77d", "T2T1_en_test_session.py::test_derive_cardano_empty_session": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_test_session.py::test_derive_cardano_running_session": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", "T2T1_en_test_session.py::test_end_session": "8b1ccc0dbd6e6e3d02a896650ab90dd332ba4edbbcc4095e0fbb6a96e5256f75", @@ -9794,14 +9795,14 @@ "T2T1_pt_zcash-test_sign_tx.py::test_version_group_id_missing": "e9e80e1bfd347b598699a7a84deb8932eff90fc5fb8b56771453e06fe3c4c216" }, "persistence_tests": { -"T2T1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.PromptAlways--081810a6": "d1e7905797c25c34e1be947b2cbcb655b4497068e6e6bdaccf88a1c821339085", -"T2T1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.PromptTempora-b3d21f4a": "d69338d8911fc6f4d4a9b4f23297537a9be60975694572d6c8cd1272c55d2af9", -"T2T1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.Strict-Safety-f1ff9c26": "458e2c14abebb664a50fef4082d8933d0db7c167439bcd7ffb11b7fca75dfb66", +"T2T1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.PromptAlways--081810a6": "496325a548e40423a74b81d4ca79d45012348295610b16955ead89fd33b32e8d", +"T2T1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.PromptTempora-b3d21f4a": "a249dbc9924eb8adcd7b5f8284467dbea1b07a7757582d47727482cd7c67340d", +"T2T1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.Strict-Safety-f1ff9c26": "bd79312a0b9be78de1bdc77b9d1950e6a751d17f74b2768957736466333e5216", "T2T1_en_test_shamir_persistence.py::test_abort": "2ff31047e86de855d4142040cb55fc38afffb4f7ac2a6ce623c77470c5255766", "T2T1_en_test_shamir_persistence.py::test_recovery_multiple_resets": "f353c194c6e7a13b6950b2a6f98056b0f3596170135c34123c696db746b7834d", "T2T1_en_test_shamir_persistence.py::test_recovery_on_old_wallet": "8405ef6de2513430b62149c70baec92a20def9a8b17206f45870ed4ed85b7c6b", "T2T1_en_test_shamir_persistence.py::test_recovery_single_reset": "f83817579bb27bf151c0907ae2bd952a94d1939dce26215f5d97ddf5b25a5668", -"T2T1_en_test_wipe_code.py::test_wipe_code_activate_core": "80875ba4bae6dcd0a4a3d1438528bfe99c5bc4567ea6c357e59836271f265282" +"T2T1_en_test_wipe_code.py::test_wipe_code_activate_core": "b503edbc69df9dbeebd37e5a169c7342fab6a0924cff0ac71268361f7ee08091" } }, "T3B1": { @@ -18332,13 +18333,13 @@ "T3B1_pt_zcash-test_sign_tx.py::test_version_group_id_missing": "6ba5ca7223cd8ad675e081407f186acdfc8420304eea96de0fde5eda45ef0a57" }, "persistence_tests": { -"T3B1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.PromptAlways--081810a6": "b13a3b6ab6fd088f72cdd0c0875892152773780e0ae36b194ff89f4cb418f103", -"T3B1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.PromptTempora-b3d21f4a": "b2db0657b53a40b5ca32ba0bd5873bde18bb2d385c909518580d93398ee2bfe2", -"T3B1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.Strict-Safety-f1ff9c26": "4a4ac39daacae7e3e15d3e818f5e3bced74eb96bf627cde1b67172abf81cdfcf", +"T3B1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.PromptAlways--081810a6": "61286ef2d28b704da9679e88a25898685208a2b2c18744636af1d434d96b8685", +"T3B1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.PromptTempora-b3d21f4a": "6f824ca693625c2b138b84fde399ae7c024776791f4ec43d6d62faaace228572", +"T3B1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.Strict-Safety-f1ff9c26": "3abba184f39ab2b8d7c90522311c8e2f46c9ad83705c731567dc1b16e8f768b3", "T3B1_en_test_shamir_persistence.py::test_abort": "229c6b2ff4b5fd3a3ca76b71bed897d8a86612a407c687c8f9a6a02c485f8bbe", -"T3B1_en_test_shamir_persistence.py::test_recovery_multiple_resets": "12316220c45ea93bdaed2ca43fb1c63d533c5a58425ef6338ebcc87c45519a83", -"T3B1_en_test_shamir_persistence.py::test_recovery_on_old_wallet": "e0c5b66688f9624fcdc0e9ae473b63bbd9dd11b07dc4e97043b0247a81bbe985", -"T3B1_en_test_shamir_persistence.py::test_recovery_single_reset": "6c4ef8b4d455afd581eb66a7538b798fa5b5068821032342212308ff49a762d9", +"T3B1_en_test_shamir_persistence.py::test_recovery_multiple_resets": "91fbc18c7c365ac85fa7fe347ff2e8da0e11a08f956d660a185f75ea85d33cf7", +"T3B1_en_test_shamir_persistence.py::test_recovery_on_old_wallet": "a4b048440ff852f7c7428bdc770378583c64763a238a4e27d91adc5424f1e708", +"T3B1_en_test_shamir_persistence.py::test_recovery_single_reset": "a4042c71b10bd465ef095733c0dc3691b3f0315d023e4bdee04a8afa8d501bb9", "T3B1_en_test_wipe_code.py::test_wipe_code_activate_core": "cc3b6cabf915701dafcd7768501fa19c06b8cbc57df4e3f8dcfc7786fc32ea0b" } }, @@ -27048,9 +27049,9 @@ "T3T1_pt_zcash-test_sign_tx.py::test_version_group_id_missing": "e8eb9b57d62689b40a58c053da11694819eb927d8b82b505c6487505cb60a889" }, "persistence_tests": { -"T3T1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.PromptAlways--081810a6": "b05832a6be652be0da7d63e6f29d03ec70c4d13c71d0f0140fb7b76970cbb9c6", -"T3T1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.PromptTempora-b3d21f4a": "f3bb5f6e8812aa5bd227209caca1db8b0cf339b3527cce200160cd141765d639", -"T3T1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.Strict-Safety-f1ff9c26": "2dd96d140aa4e025fb5accbd3bff3a565792474f9e3241f2a4b12a1fe5b58012", +"T3T1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.PromptAlways--081810a6": "a4a8610673138a223f560eb6ef1e516654e755b61fd5b0a15512d46e01f70226", +"T3T1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.PromptTempora-b3d21f4a": "af81f7fb461c5b3387642d3b9cdd87da4ecb4f6ac49e07b7712ed9a6ad615880", +"T3T1_en_test_safety_checks.py::test_safety_checks_level_after_reboot[SafetyCheckLevel.Strict-Safety-f1ff9c26": "beda01c27152657546a1ebde8bcf9893a3170b858a09a7c9bf5b1dd094be0f6c", "T3T1_en_test_shamir_persistence.py::test_abort": "bd699a53a1a7280a9106222c61dffc4a8967afeef8c5141485bd71eb06daf3de", "T3T1_en_test_shamir_persistence.py::test_recovery_multiple_resets": "dfd6c4bc7168d3dfb1c068dcfe8c4768343d4bd1b3522c54eaf487644a6dc835", "T3T1_en_test_shamir_persistence.py::test_recovery_on_old_wallet": "4fa6d5b04fe395594fdb1d428120647f127e23b9e6fd72b0bc5c6b12d618e76d", diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index 5d8c8778650..49cd73d6c82 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -20,7 +20,9 @@ import pytest from shamir_mnemonic import shamir -from trezorlib import btc, debuglink, device, exceptions, fido, models +from trezorlib import btc, debuglink, device, exceptions, fido, messages, models +from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import ( ApplySettings, BackupAvailability, @@ -57,15 +59,19 @@ @for_all() def test_upgrade_load(gen: str, tag: str) -> None: def asserts(client: "Client"): + client.refresh_features() assert not client.features.pin_protection assert not client.features.passphrase_protection assert client.features.initialized assert client.features.label == LABEL - assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS + assert ( + btc.get_address(client.get_session(passphrase=""), "Bitcoin", PATH) + == ADDRESS + ) with EmulatorWrapper(gen, tag) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_seedless_session(), mnemonic=MNEMONIC, pin="", passphrase_protection=False, @@ -89,12 +95,14 @@ def asserts(client: "Client") -> None: assert not client.features.passphrase_protection assert client.features.initialized assert client.features.label == LABEL - client.use_pin_sequence([PIN]) - assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS + session = Session(client.get_session()) + with client, session: + client.use_pin_sequence([PIN]) + assert btc.get_address(session, "Bitcoin", PATH) == ADDRESS with EmulatorWrapper(gen, tag) as emu: debuglink.load_device_by_mnemonic( - emu.client, + Session(emu.client.get_seedless_session()), mnemonic=MNEMONIC, pin=PIN, passphrase_protection=False, @@ -130,11 +138,11 @@ def asserts(client: "Client") -> None: assert client.features.initialized assert client.features.label == LABEL client.use_pin_sequence([PIN]) - assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS + assert btc.get_address(client.get_session(), "Bitcoin", PATH) == ADDRESS with EmulatorWrapper(gen, tags[0]) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_seedless_session(), mnemonic=MNEMONIC, pin=PIN, passphrase_protection=False, @@ -164,11 +172,11 @@ def asserts(client: "Client"): assert client.features.initialized assert client.features.label == LABEL client.use_pin_sequence([PIN]) - assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS + assert btc.get_address(client.get_session(), "Bitcoin", PATH) == ADDRESS with EmulatorWrapper(gen, tag) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_seedless_session(), mnemonic=MNEMONIC, pin=PIN, passphrase_protection=False, @@ -177,7 +185,9 @@ def asserts(client: "Client"): # Set wipe code. emu.client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) - device.change_wipe_code(emu.client) + session = Session(emu.client.get_seedless_session()) + session.refresh_features() + device.change_wipe_code(session) device_id = emu.client.features.device_id asserts(emu.client) @@ -189,11 +199,13 @@ def asserts(client: "Client"): # Check that wipe code is set by changing the PIN to it. emu.client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) + session = Session(emu.client.get_seedless_session()) + session.refresh_features() with pytest.raises( exceptions.TrezorFailure, match="The new PIN must be different from your wipe code", ): - return device.change_pin(emu.client) + return device.change_pin(session) @for_all("legacy") @@ -209,7 +221,7 @@ def asserts(client: "Client"): with EmulatorWrapper(gen, tag) as emu: device.setup( - emu.client, + emu.client.get_session(), strength=STRENGTH, passphrase_protection=False, pin_protection=False, @@ -219,13 +231,13 @@ def asserts(client: "Client"): ) device_id = emu.client.features.device_id asserts(emu.client) - address = btc.get_address(emu.client, "Bitcoin", PATH) + address = btc.get_address(emu.client.get_session(), "Bitcoin", PATH) storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: assert device_id == emu.client.features.device_id asserts(emu.client) - assert btc.get_address(emu.client, "Bitcoin", PATH) == address + assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address @for_all() @@ -241,7 +253,7 @@ def asserts(client: "Client"): with EmulatorWrapper(gen, tag) as emu: device.setup( - emu.client, + emu.client.get_session(), strength=STRENGTH, passphrase_protection=False, pin_protection=False, @@ -252,13 +264,13 @@ def asserts(client: "Client"): ) device_id = emu.client.features.device_id asserts(emu.client) - address = btc.get_address(emu.client, "Bitcoin", PATH) + address = btc.get_address(emu.client.get_session(), "Bitcoin", PATH) storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: assert device_id == emu.client.features.device_id asserts(emu.client) - assert btc.get_address(emu.client, "Bitcoin", PATH) == address + assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address @for_all(legacy_minimum_version=(1, 7, 2)) @@ -274,7 +286,7 @@ def asserts(client: "Client"): with EmulatorWrapper(gen, tag) as emu: device.setup( - emu.client, + emu.client.get_session(), strength=STRENGTH, passphrase_protection=False, pin_protection=False, @@ -286,13 +298,13 @@ def asserts(client: "Client"): device_id = emu.client.features.device_id asserts(emu.client) - address = btc.get_address(emu.client, "Bitcoin", PATH) + address = btc.get_address(emu.client.get_session(), "Bitcoin", PATH) storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: assert device_id == emu.client.features.device_id asserts(emu.client) - assert btc.get_address(emu.client, "Bitcoin", PATH) == address + assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address # Although Shamir was introduced in 2.1.2 already, the debug instrumentation was not present until 2.1.9. @@ -305,7 +317,7 @@ def test_upgrade_shamir_recovery(gen: str, tag: Optional[str]): emu.client.watch_layout(True) debug = device_handler.debuglink() - device_handler.run(device.recover, pin_protection=False) + device_handler.run_with_session(device.recover, pin_protection=False) recovery_old.confirm_recovery(debug) recovery_old.select_number_of_words(debug) @@ -350,9 +362,10 @@ def test_upgrade_shamir_recovery(gen: str, tag: Optional[str]): @for_all("core", core_minimum_version=(2, 1, 9)) def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): with EmulatorWrapper(gen, tag) as emu: + session = Session(emu.client.get_seedless_session()) # Generate a new encrypted master secret and record it. device.setup( - emu.client, + session, pin_protection=False, skip_backup=True, backup_type=BackupType.Slip39_Basic, @@ -363,14 +376,16 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): mnemonic_secret = emu.client.debug.state().mnemonic_secret # Set passphrase_source = HOST. - resp = emu.client.call(ApplySettings(_passphrase_source=2, use_passphrase=True)) + session = Session(emu.client.get_session()) + resp = session.call(ApplySettings(_passphrase_source=2, use_passphrase=True)) assert isinstance(resp, Success) # Get a passphrase-less and a passphrased address. - address = btc.get_address(emu.client, "Bitcoin", PATH) - emu.client.init_device(new_session=True) - emu.client.use_passphrase("TREZOR") - address_passphrase = btc.get_address(emu.client, "Bitcoin", PATH) + address = btc.get_address(session, "Bitcoin", PATH) + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + session.call(messages.Initialize(new_session=True)) + new_session = emu.client.get_session(passphrase="TREZOR") + address_passphrase = btc.get_address(new_session, "Bitcoin", PATH) assert emu.client.features.backup_availability == BackupAvailability.Required storage = emu.get_storage() @@ -383,7 +398,7 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): with emu.client: IF = InputFlowSlip39BasicBackup(emu.client, False) emu.client.set_input_flow(IF.get()) - device.backup(emu.client) + device.backup(emu.client.get_session()) assert ( emu.client.features.backup_availability == BackupAvailability.NotAvailable ) @@ -404,10 +419,13 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): assert ems.ciphertext == mnemonic_secret # Check that addresses are the same after firmware upgrade and backup. - assert btc.get_address(emu.client, "Bitcoin", PATH) == address - emu.client.init_device(new_session=True) - emu.client.use_passphrase("TREZOR") - assert btc.get_address(emu.client, "Bitcoin", PATH) == address_passphrase + assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address + assert ( + btc.get_address( + emu.client.get_session(passphrase="TREZOR"), "Bitcoin", PATH + ) + == address_passphrase + ) @for_all(legacy_minimum_version=(1, 8, 4), core_minimum_version=(2, 1, 9)) @@ -415,21 +433,21 @@ def test_upgrade_u2f(gen: str, tag: str): """Check U2F counter stayed the same after an upgrade.""" with EmulatorWrapper(gen, tag) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_seedless_session(), mnemonic=MNEMONIC, pin="", passphrase_protection=False, label=LABEL, ) + session = emu.client.get_seedless_session() + fido.set_counter(session, 10) - fido.set_counter(emu.client, 10) - - counter = fido.get_next_counter(emu.client) + counter = fido.get_next_counter(session) assert counter == 11 storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: - counter = fido.get_next_counter(emu.client) + counter = fido.get_next_counter(session) assert counter == 12 diff --git a/tests/upgrade_tests/test_passphrase_consistency.py b/tests/upgrade_tests/test_passphrase_consistency.py index a368c75bc50..e2b85453b5d 100644 --- a/tests/upgrade_tests/test_passphrase_consistency.py +++ b/tests/upgrade_tests/test_passphrase_consistency.py @@ -20,6 +20,8 @@ from trezorlib import btc, device, mapping, messages, models, protobuf from trezorlib._internal.emulator import Emulator +from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ..emulators import EmulatorWrapper @@ -47,13 +49,14 @@ def emulator(gen: str, tag: str) -> Iterator[Emulator]: with EmulatorWrapper(gen, tag) as emu: # set up a passphrase-protected device device.setup( - emu.client, + emu.client.get_seedless_session(), pin_protection=False, skip_backup=True, entropy_check_count=0, backup_type=messages.BackupType.Bip39, ) - resp = emu.client.call( + emu.client.invalidate() + resp = emu.client.get_seedless_session().call( ApplySettingsCompat(use_passphrase=True, passphrase_source=SOURCE_HOST) ) assert isinstance(resp, messages.Success) @@ -89,11 +92,10 @@ def test_passphrase_works(emulator: Emulator): messages.ButtonRequest, messages.Address, ] - - with emulator.client: - emulator.client.use_passphrase("TREZOR") - emulator.client.set_expected_responses(expected_responses) - btc.get_address(emulator.client, "Testnet", parse_path("44h/1h/0h/0/0")) + emu_session = emulator.client.get_session(passphrase="TREZOR") + with Session(emu_session) as session: + session.set_expected_responses(expected_responses) + btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0")) @for_all( @@ -133,13 +135,18 @@ def test_init_device(emulator: Emulator): messages.Address, ] - with emulator.client: - emulator.client.use_passphrase("TREZOR") - emulator.client.set_expected_responses(expected_responses) + emu_session = emulator.client.get_session(passphrase="TREZOR") + with Session(emu_session) as session: + session.set_expected_responses(expected_responses) - btc.get_address(emulator.client, "Testnet", parse_path("44h/1h/0h/0/0")) + btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0")) # in TT < 2.3.0 session_id will only be available after PassphraseStateRequest - session_id = emulator.client.session_id - emulator.client.init_device() - btc.get_address(emulator.client, "Testnet", parse_path("44h/1h/0h/0/0")) - assert session_id == emulator.client.session_id + session_id = session.id + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + session.call(messages.Initialize(session_id=session_id)) + btc.get_address( + session, + "Testnet", + parse_path("44h/1h/0h/0/0"), + ) + assert session_id == session.id diff --git a/vendor/fido2-tests b/vendor/fido2-tests index 737b4960c98..42f810c2060 160000 --- a/vendor/fido2-tests +++ b/vendor/fido2-tests @@ -1 +1 @@ -Subproject commit 737b4960c98b4877653c77ff97a0bb5cfc319213 +Subproject commit 42f810c20602fe25d221cd79c2983a37816b476f