diff --git a/aioesphomeapi/_frame_helper/noise.py b/aioesphomeapi/_frame_helper/noise.py index d45e534af..6dd6f5399 100644 --- a/aioesphomeapi/_frame_helper/noise.py +++ b/aioesphomeapi/_frame_helper/noise.py @@ -312,7 +312,30 @@ def _handle_handshake(self, msg: bytes) -> None: if msg[0] != 0: self._error_on_incorrect_preamble(msg) return - self._proto.read_message(msg[1:]) + try: + self._proto.read_message(msg[1:]) + except InvalidTag as exc: + # The peer's handshake response failed AEAD authentication. Either + # the PSK doesn't match or the ciphertext was tampered with. ESPHome + # firmware normally rejects with the dedicated preamble=0x01 + # "Handshake MAC failure" frame, so reaching this path means the + # peer is buggy or hostile; surface the same friendly error the + # named-failure branch raises. + key_err = InvalidEncryptionKeyAPIError( + f"{self._log_name}: Invalid encryption key", + self._server_name, + self._server_mac, + ) + key_err.__cause__ = exc + self._handle_error_and_close(key_err) + return + except Exception as exc: + handshake_err = HandshakeAPIError( + f"{self._log_name}: Handshake failed: {exc}" + ) + handshake_err.__cause__ = exc + self._handle_error_and_close(handshake_err) + return self._state = NOISE_STATE_READY noise_protocol = self._proto.noise_protocol self._decrypt_cipher = DecryptCipher(noise_protocol.cipher_state_decrypt) # pylint: disable=no-member diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 2d59dbe0a..fa493861d 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -990,6 +990,68 @@ async def test_noise_frame_helper_empty_handshake_frame(): await helper.ready_future +async def test_noise_frame_helper_handshake_invalid_tag() -> None: + """Handshake body with valid preamble but bogus AEAD payload surfaces as InvalidEncryptionKeyAPIError.""" + connection, _ = _make_mock_connection() + helper = MockAPINoiseFrameHelper( + connection=connection, + noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", + expected_name="servicetest", + client_info="my client", + log_name="test", + expected_mac=None, + ) + + await asyncio.sleep(0) # let the task run to read the hello packet + hello_pkt_with_header = _make_noise_hello_pkt(b"\x01servicetest\0") + mock_data_received(helper, hello_pkt_with_header) + + # preamble=0x00 (handshake), then 48 bytes of garbage that look like a + # valid-length NN responder reply (32-byte ephemeral pubkey + 16-byte MAC) + # but fail authentication. Without explicit InvalidTag handling the + # cryptography library exception would propagate raw to the caller. + bogus_handshake = b"\x00" + (b"\xff" * 48) + pkt_len = len(bogus_handshake) + handshake_pkt_with_header = ( + bytes((0x01, (pkt_len >> 8) & 0xFF, pkt_len & 0xFF)) + bogus_handshake + ) + mock_data_received(helper, handshake_pkt_with_header) + + with pytest.raises(InvalidEncryptionKeyAPIError) as exc_info: + await helper.ready_future + assert exc_info.value.received_name == "servicetest" + + +async def test_noise_frame_helper_handshake_other_noise_error(): + """A non-InvalidTag noise lib exception is wrapped as HandshakeAPIError.""" + connection, _ = _make_mock_connection() + helper = MockAPINoiseFrameHelper( + connection=connection, + noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", + expected_name="servicetest", + client_info="my client", + log_name="test", + expected_mac=None, + ) + + await asyncio.sleep(0) + hello_pkt_with_header = _make_noise_hello_pkt(b"\x01servicetest\0") + mock_data_received(helper, hello_pkt_with_header) + + # preamble=0x00 then a too-short body. The noise library raises + # NoiseInvalidMessage (not InvalidTag) when the responder reply is + # shorter than the protocol's e + ee size. + bogus_handshake = b"\x00\x01\x02" + pkt_len = len(bogus_handshake) + handshake_pkt_with_header = ( + bytes((0x01, (pkt_len >> 8) & 0xFF, pkt_len & 0xFF)) + bogus_handshake + ) + mock_data_received(helper, handshake_pkt_with_header) + + with pytest.raises(HandshakeAPIError, match="Handshake failed"): + await helper.ready_future + + async def test_noise_frame_helper_wrong_protocol(): """Test noise with the wrong protocol.""" connection, _ = _make_mock_connection()