Implement send/receive for the new type.

This commit is contained in:
John Preston 2019-07-08 17:41:34 +02:00
parent 69b6b48738
commit 2f0331b2e0
7 changed files with 201 additions and 73 deletions

View file

@ -24,6 +24,7 @@ constexpr auto kPacketSizeMax = int(0x01000000 * sizeof(mtpPrime));
constexpr auto kFullConnectionTimeout = 8 * crl::time(1000);
constexpr auto kSmallBufferSize = 256 * 1024;
constexpr auto kMinPacketBuffer = 256;
constexpr auto kConnectionStartPrefixSize = 64;
} // namespace
@ -277,7 +278,7 @@ void TcpConnection::ensureAvailableInBuffer(int amount) {
void TcpConnection::socketRead() {
Expects(_leftBytes > 0 || !_usingLargeBuffer);
if (_socket->isConnected()) {
if (!_socket || !_socket->isConnected()) {
LOG(("MTP Error: Socket not connected in socketRead()"));
emit error(kErrorCodeOther);
return;
@ -295,11 +296,7 @@ void TcpConnection::socketRead() {
auto &buffer = _usingLargeBuffer ? _largeBuffer : _smallBuffer;
const auto full = bytes::make_span(buffer).subspan(_offsetBytes);
const auto free = full.subspan(_readBytes);
Assert(free.size() >= readLimit);
const auto readCount = _socket->read(
reinterpret_cast<char*>(free.data()),
readLimit);
const auto readCount = _socket->read(free.subspan(0, readLimit));
if (readCount > 0) {
const auto read = free.subspan(0, readCount);
aesCtrEncrypt(read, _receiveKey, &_receiveState);
@ -365,7 +362,9 @@ void TcpConnection::socketRead() {
TCP_LOG(("TCP Info: no bytes read, but bytes available was true..."));
break;
}
} while (_socket->isConnected() && _socket->hasBytesAvailable());
} while (_socket
&& _socket->isConnected()
&& _socket->hasBytesAvailable());
}
mtpBuffer TcpConnection::parsePacket(bytes::const_span bytes) {
@ -423,16 +422,31 @@ bool TcpConnection::requiresExtendedPadding() const {
void TcpConnection::sendData(mtpBuffer &&buffer) {
Expects(buffer.size() > 2);
if (_status != Status::Finished) {
sendBuffer(std::move(buffer));
if (!_socket) {
return;
}
char connectionStartPrefixBytes[kConnectionStartPrefixSize];
const auto connectionStartPrefix = prepareConnectionStartPrefix(
bytes::make_span(connectionStartPrefixBytes));
// buffer: 2 available int-s + data + available int.
const auto bytes = _protocol->finalizePacket(buffer);
TCP_LOG(("TCP Info: write packet %1 bytes").arg(bytes.size()));
aesCtrEncrypt(bytes, _sendKey, &_sendState);
_socket->write(connectionStartPrefix, bytes);
}
void TcpConnection::writeConnectionStart() {
bytes::const_span TcpConnection::prepareConnectionStartPrefix(
bytes::span buffer) {
Expects(_protocol != nullptr);
if (_connectionStarted) {
return {};
}
_connectionStarted = true;
// prepare random part
auto nonceBytes = bytes::vector(64);
char nonceBytes[64];
const auto nonce = bytes::make_span(nonceBytes);
const auto zero = reinterpret_cast<uchar*>(nonce.data());
@ -481,31 +495,17 @@ void TcpConnection::writeConnectionStart() {
const auto dcId = reinterpret_cast<int16*>(nonce.data() + 60);
*dcId = _protocolDcId;
_socket->write(reinterpret_cast<const char*>(nonce.data()), 56);
bytes::copy(buffer, nonce.subspan(0, 56));
aesCtrEncrypt(nonce, _sendKey, &_sendState);
_socket->write(
reinterpret_cast<const char*>(nonce.subspan(56).data()),
8);
bytes::copy(buffer.subspan(56), nonce.subspan(56));
return buffer;
}
void TcpConnection::sendBuffer(mtpBuffer &&buffer) {
if (!_connectionStarted) {
writeConnectionStart();
_connectionStarted = true;
}
// buffer: 2 available int-s + data + available int.
const auto bytes = _protocol->finalizePacket(buffer);
TCP_LOG(("TCP Info: write packet %1 bytes").arg(bytes.size()));
aesCtrEncrypt(bytes, _sendKey, &_sendState);
_socket->write(
reinterpret_cast<const char*>(bytes.data()),
bytes.size());
}
void TcpConnection::disconnectFromServer() {
if (_status == Status::Finished) return;
if (_status == Status::Finished) {
return;
}
_status = Status::Finished;
_connectedLifetime.destroy();
_lifetime.destroy();
@ -576,7 +576,7 @@ crl::time TcpConnection::fullConnectTimeout() const {
}
void TcpConnection::socketPacket(bytes::const_span bytes) {
if (_status == Status::Finished) return;
Expects(_socket != nullptr);
// old quickack?..
const auto data = parsePacket(bytes);
@ -620,7 +620,7 @@ bool TcpConnection::isConnected() const {
}
int32 TcpConnection::debugState() const {
return _socket->debugState();
return _socket ? _socket->debugState() : -1;
}
QString TcpConnection::transport() const {
@ -645,7 +645,9 @@ QString TcpConnection::tag() const {
}
void TcpConnection::socketError() {
if (_status == Status::Finished) return;
if (!_socket) {
return;
}
emit error(kErrorCodeOther);
}

View file

@ -48,7 +48,7 @@ private:
};
void socketRead();
void writeConnectionStart();
bytes::const_span prepareConnectionStartPrefix(bytes::span buffer);
void socketPacket(bytes::const_span bytes);
@ -58,14 +58,11 @@ private:
mtpBuffer parsePacket(bytes::const_span bytes);
void ensureAvailableInBuffer(int amount);
static void handleError(QAbstractSocket::SocketError e, QTcpSocket &sock);
static uint32 fourCharsToUInt(char ch1, char ch2, char ch3, char ch4) {
char ch[4] = { ch1, ch2, ch3, ch4 };
return *reinterpret_cast<uint32*>(ch);
}
void sendBuffer(mtpBuffer &&buffer);
std::unique_ptr<AbstractSocket> _socket;
bool _connectionStarted = false;

View file

@ -7,6 +7,8 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
*/
#pragma once
#include "base/bytes.h"
namespace MTP {
namespace internal {
@ -38,8 +40,10 @@ public:
virtual void connectToHost(const QString &address, int port) = 0;
[[nodiscard]] virtual bool isConnected() = 0;
[[nodiscard]] virtual bool hasBytesAvailable() = 0;
[[nodiscard]] virtual int64 read(char *buffer, int64 maxLength) = 0;
virtual int64 write(const char *buffer, int64 length) = 0;
[[nodiscard]] virtual int64 read(bytes::span buffer) = 0;
virtual void write(
bytes::const_span prefix,
bytes::const_span buffer) = 0;
virtual int32 debugState() = 0;

View file

@ -53,12 +53,23 @@ bool TcpSocket::hasBytesAvailable() {
return _socket.bytesAvailable() > 0;
}
int64 TcpSocket::read(char *buffer, int64 maxLength) {
return _socket.read(buffer, maxLength);
int64 TcpSocket::read(bytes::span buffer) {
return _socket.read(
reinterpret_cast<char*>(buffer.data()),
buffer.size());
}
int64 TcpSocket::write(const char *buffer, int64 length) {
return _socket.write(buffer, length);
void TcpSocket::write(bytes::const_span prefix, bytes::const_span buffer) {
Expects(!buffer.empty());
if (!prefix.empty()) {
_socket.write(
reinterpret_cast<const char*>(prefix.data()),
prefix.size());
}
_socket.write(
reinterpret_cast<const char*>(buffer.data()),
buffer.size());
}
int32 TcpSocket::debugState() {

View file

@ -12,15 +12,15 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
namespace MTP {
namespace internal {
class TcpSocket : public AbstractSocket {
class TcpSocket final : public AbstractSocket {
public:
TcpSocket(not_null<QThread*> thread, const ProxyData &proxy);
void connectToHost(const QString &address, int port) override;
bool isConnected() override;
bool hasBytesAvailable() override;
int64 read(char *buffer, int64 maxLength) override;
int64 write(const char *buffer, int64 length) override;
int64 read(bytes::span buffer) override;
void write(bytes::const_span prefix, bytes::const_span buffer) override;
int32 debugState() override;

View file

@ -22,7 +22,9 @@ const auto kServerHelloPart1 = qstr("\x16\x03\x03");
const auto kServerHelloPart3 = qstr("\x14\x03\x03\x00\x01\x01\x17\x03\x03");
constexpr auto kServerHelloDigestPosition = 11;
const auto kServerHeader = qstr("\x17\x03\x03");
constexpr auto kServerDataSkip = 5;
constexpr auto kClientPartSize = 2878;
const auto kClientPrefix = qstr("\x14\x03\x03\x00\x01\x01");
const auto kClientHeader = qstr("\x17\x03\x03");
[[nodiscard]] MTPTlsClientHello PrepareClientHelloRules() {
auto stack = std::vector<QVector<MTPTlsBlock>>();
@ -258,11 +260,11 @@ void ClientHelloGenerator::writeBlock(const MTPDtlsBlockDomain &data) {
}
void ClientHelloGenerator::writeBlock(const MTPDtlsBlockScope &data) {
const auto already = _result.size();
const auto storage = grow(kLengthSize);
if (storage.empty()) {
return;
}
const auto already = _result.size();
writeBlocks(data.ventries().v);
const auto length = qToBigEndian(uint16(_result.size() - already));
bytes::copy(storage, bytes::object_as_span(&length));
@ -383,7 +385,7 @@ void TlsSocket::plainConnected() {
static const auto kClientHelloRules = PrepareClientHelloRules();
const auto hello = PrepareClientHello(
kClientHelloRules,
"google.com",
"www.google.com",
_key);
if (hello.data.isEmpty()) {
LOG(("TLS Error: Could not generate Client Hello!"));
@ -400,14 +402,15 @@ void TlsSocket::plainDisconnected() {
_state = State::NotConnected;
_incoming = QByteArray();
_serverHelloLength = 0;
_incomingGoodDataOffset = 0;
_incomingGoodDataLimit = 0;
_disconnected.fire({});
}
void TlsSocket::plainReadyRead() {
switch (_state) {
case State::WaitingHello: return readHello();
case State::Ready:
case State::Working: return readData();
case State::Connected: return readData();
}
}
@ -481,8 +484,7 @@ void TlsSocket::checkHelloParts34(int parts123Size) {
}
void TlsSocket::checkHelloDigest() {
const auto incoming = bytes::make_detached_span(_incoming);
const auto fulldata = incoming.subspan(
const auto fulldata = bytes::make_detached_span(_incoming).subspan(
0,
kHelloDigestLength + _serverHelloLength);
const auto digest = fulldata.subspan(
@ -496,18 +498,70 @@ void TlsSocket::checkHelloDigest() {
handleError();
return;
}
if (incoming.size() > fulldata.size()) {
bytes::move(incoming, incoming.subspan(fulldata.size()));
_incoming.chop(fulldata.size());
InvokeQueued(this, [=] { readData(); });
} else {
_incoming.clear();
shiftIncomingBy(fulldata.size());
if (!_incoming.isEmpty()) {
InvokeQueued(this, [=] {
if (!checkNextPacket()) {
handleError();
}
});
}
_state = State::Ready;
_incomingGoodDataOffset = _incomingGoodDataLimit = 0;
_state = State::Connected;
_connected.fire({});
}
void TlsSocket::readData() {
if (!isConnected()) {
return;
}
_incoming.append(_socket.readAll());
if (!checkNextPacket()) {
handleError();
} else if (hasBytesAvailable()) {
_readyRead.fire({});
}
}
bool TlsSocket::checkNextPacket() {
auto offset = 0;
const auto incoming = bytes::make_span(_incoming);
while (!_incomingGoodDataLimit) {
const auto fullHeader = kServerHeader.size() + kLengthSize;
if (incoming.size() <= offset + fullHeader) {
return true;
}
if (!CheckPart(incoming.subspan(offset), kServerHeader)) {
LOG(("TLS Error: Bad packet header."));
return false;
}
const auto length = ReadPartLength(
incoming,
offset + kServerHeader.size());
if (length > 0) {
if (offset > 0) {
shiftIncomingBy(offset);
}
_incomingGoodDataOffset = fullHeader;
_incomingGoodDataLimit = length;
} else {
offset += kServerHeader.size() + kLengthSize + length;
}
}
return true;
}
void TlsSocket::shiftIncomingBy(int amount) {
Expects(_incomingGoodDataOffset == 0);
Expects(_incomingGoodDataLimit == 0);
const auto incoming = bytes::make_detached_span(_incoming);
if (incoming.size() > amount) {
bytes::move(incoming, incoming.subspan(amount));
_incoming.chop(amount);
} else {
_incoming.clear();
}
}
void TlsSocket::connectToHost(const QString &address, int port) {
@ -518,19 +572,76 @@ void TlsSocket::connectToHost(const QString &address, int port) {
}
bool TlsSocket::isConnected() {
return (_socket.state() == QAbstractSocket::ConnectedState);
return (_state == State::Connected);
}
bool TlsSocket::hasBytesAvailable() {
return _socket.bytesAvailable();
return (_incomingGoodDataLimit > 0)
&& (_incomingGoodDataOffset < _incoming.size());
}
int64 TlsSocket::read(char *buffer, int64 maxLength) {
return _socket.read(buffer, maxLength);
int64 TlsSocket::read(bytes::span buffer) {
auto written = int64(0);
while (_incomingGoodDataLimit) {
const auto available = std::min(
_incomingGoodDataLimit,
_incoming.size() - _incomingGoodDataOffset);
if (available <= 0) {
return written;
}
const auto write = std::min(index_type(available), buffer.size());
if (write <= 0) {
return written;
}
bytes::copy(
buffer,
bytes::make_span(_incoming).subspan(
_incomingGoodDataOffset,
write));
written += write;
buffer = buffer.subspan(write);
_incomingGoodDataLimit -= write;
_incomingGoodDataOffset += write;
if (_incomingGoodDataLimit) {
return written;
}
shiftIncomingBy(base::take(_incomingGoodDataOffset));
if (!checkNextPacket()) {
_state = State::Error;
InvokeQueued(this, [=] { handleError(); });
return written;
}
}
return written;
}
int64 TlsSocket::write(const char *buffer, int64 length) {
return _socket.write(buffer, length);
void TlsSocket::write(bytes::const_span prefix, bytes::const_span buffer) {
Expects(!buffer.empty());
if (!isConnected()) {
return;
}
if (!prefix.empty()) {
_socket.write(kClientPrefix.data(), kClientPrefix.size());
}
while (!buffer.empty()) {
const auto write = std::min(
kClientPartSize - prefix.size(),
buffer.size());
_socket.write(kClientHeader.data(), kClientHeader.size());
const auto size = qToBigEndian(uint16(prefix.size() + write));
_socket.write(reinterpret_cast<const char*>(&size), sizeof(size));
if (!prefix.empty()) {
_socket.write(
reinterpret_cast<const char*>(prefix.data()),
prefix.size());
prefix = bytes::const_span();
}
_socket.write(
reinterpret_cast<const char*>(buffer.data()),
write);
buffer = buffer.subspan(write);
}
}
int32 TlsSocket::debugState() {

View file

@ -12,7 +12,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
namespace MTP {
namespace internal {
class TlsSocket : public AbstractSocket {
class TlsSocket final : public AbstractSocket {
public:
TlsSocket(
not_null<QThread*> thread,
@ -22,8 +22,8 @@ public:
void connectToHost(const QString &address, int port) override;
bool isConnected() override;
bool hasBytesAvailable() override;
int64 read(char *buffer, int64 maxLength) override;
int64 write(const char *buffer, int64 length) override;
int64 read(bytes::span buffer) override;
void write(bytes::const_span prefix, bytes::const_span buffer) override;
int32 debugState() override;
@ -32,8 +32,7 @@ private:
NotConnected,
Connecting,
WaitingHello,
Ready,
Working,
Connected,
Error,
};
@ -47,11 +46,15 @@ private:
void checkHelloParts34(int parts123Size);
void checkHelloDigest();
void readData();
[[nodiscard]] bool checkNextPacket();
void shiftIncomingBy(int amount);
QTcpSocket _socket;
bytes::vector _key;
State _state = State::NotConnected;
QByteArray _incoming;
int _incomingGoodDataOffset = 0;
int _incomingGoodDataLimit = 0;
int16 _serverHelloLength = 0;
};