Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion aioesphomeapi/_frame_helper/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,30 @@ def _handle_handshake(self, msg: bytes) -> None:
if msg[0] != 0:
self._error_on_incorrect_preamble(msg)
return
self._proto.read_message(msg[1:])
try:
self._proto.read_message(msg[1:])
except InvalidTag as exc:
# The peer's handshake response failed AEAD authentication. Either
# the PSK doesn't match or the ciphertext was tampered with. ESPHome
# firmware normally rejects with the dedicated preamble=0x01
# "Handshake MAC failure" frame, so reaching this path means the
# peer is buggy or hostile; surface the same friendly error the
# named-failure branch raises.
key_err = InvalidEncryptionKeyAPIError(
f"{self._log_name}: Invalid encryption key",
self._server_name,
self._server_mac,
)
key_err.__cause__ = exc
self._handle_error_and_close(key_err)
return
except Exception as exc:
handshake_err = HandshakeAPIError(
f"{self._log_name}: Handshake failed: {exc}"
)
handshake_err.__cause__ = exc
self._handle_error_and_close(handshake_err)
return
self._state = NOISE_STATE_READY
noise_protocol = self._proto.noise_protocol
self._decrypt_cipher = DecryptCipher(noise_protocol.cipher_state_decrypt) # pylint: disable=no-member
Expand Down
62 changes: 62 additions & 0 deletions tests/test__frame_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,68 @@ async def test_noise_frame_helper_empty_handshake_frame():
await helper.ready_future


async def test_noise_frame_helper_handshake_invalid_tag() -> None:
"""Handshake body with valid preamble but bogus AEAD payload surfaces as InvalidEncryptionKeyAPIError."""
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
connection=connection,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest",
client_info="my client",
log_name="test",
expected_mac=None,
)

await asyncio.sleep(0) # let the task run to read the hello packet
hello_pkt_with_header = _make_noise_hello_pkt(b"\x01servicetest\0")
mock_data_received(helper, hello_pkt_with_header)

# preamble=0x00 (handshake), then 48 bytes of garbage that look like a
# valid-length NN responder reply (32-byte ephemeral pubkey + 16-byte MAC)
# but fail authentication. Without explicit InvalidTag handling the
# cryptography library exception would propagate raw to the caller.
bogus_handshake = b"\x00" + (b"\xff" * 48)
pkt_len = len(bogus_handshake)
handshake_pkt_with_header = (
bytes((0x01, (pkt_len >> 8) & 0xFF, pkt_len & 0xFF)) + bogus_handshake
)
mock_data_received(helper, handshake_pkt_with_header)

with pytest.raises(InvalidEncryptionKeyAPIError) as exc_info:
await helper.ready_future
assert exc_info.value.received_name == "servicetest"


async def test_noise_frame_helper_handshake_other_noise_error():
"""A non-InvalidTag noise lib exception is wrapped as HandshakeAPIError."""
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
connection=connection,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest",
client_info="my client",
log_name="test",
expected_mac=None,
)

await asyncio.sleep(0)
hello_pkt_with_header = _make_noise_hello_pkt(b"\x01servicetest\0")
mock_data_received(helper, hello_pkt_with_header)

# preamble=0x00 then a too-short body. The noise library raises
# NoiseInvalidMessage (not InvalidTag) when the responder reply is
# shorter than the protocol's e + ee size.
bogus_handshake = b"\x00\x01\x02"
pkt_len = len(bogus_handshake)
handshake_pkt_with_header = (
bytes((0x01, (pkt_len >> 8) & 0xFF, pkt_len & 0xFF)) + bogus_handshake
)
mock_data_received(helper, handshake_pkt_with_header)

with pytest.raises(HandshakeAPIError, match="Handshake failed"):
await helper.ready_future


async def test_noise_frame_helper_wrong_protocol():
"""Test noise with the wrong protocol."""
connection, _ = _make_mock_connection()
Expand Down
Loading