Skip to content

Commit ff53138

Browse files
committed
Implement an easier wrapper around InitializeSecurityContext and use it
Since calling it is a lot easier now, it seems to actually finish negotiating a connection.. sometimes
1 parent 25c4bb9 commit ff53138

1 file changed

Lines changed: 77 additions & 67 deletions

File tree

src/windows/SChannelConnection.cpp

Lines changed: 77 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
#include <schnlsp.h>
55
#include <assert.h>
66
#include <algorithm>
7+
#include <memory>
78

9+
#include "common/config.h"
810
#include "SChannelConnection.h"
911

1012
#ifndef SCH_USE_STRONG_CRYPTO
@@ -70,6 +72,57 @@ SChannelConnection::~SChannelConnection()
7072
}
7173
}
7274

75+
SECURITY_STATUS InitializeSecurityContext(CredHandle *phCredential, std::unique_ptr<CtxtHandle>& phContext, const std::string& szTargetName, ULONG fContextReq, const std::vector<char>& inputBuffer, std::vector<char>& outputBuffer, ULONG *pfContextAttr)
76+
{
77+
std::array<SecBuffer, 2> recvBuffers;
78+
recvBuffers[0].BufferType = SECBUFFER_TOKEN;
79+
recvBuffers[0].pvBuffer = outputBuffer.data();
80+
recvBuffers[0].cbBuffer = outputBuffer.size();
81+
82+
recvBuffers[1].BufferType = SECBUFFER_EMPTY;
83+
recvBuffers[1].pvBuffer = nullptr;
84+
recvBuffers[1].cbBuffer = 0;
85+
86+
SecBuffer sendBuffer;
87+
sendBuffer.BufferType = SECBUFFER_TOKEN;
88+
sendBuffer.pvBuffer = const_cast<char*>(inputBuffer.data());
89+
sendBuffer.cbBuffer = inputBuffer.size();
90+
91+
SecBufferDesc recvBufferDesc, sendBufferDesc;
92+
recvBufferDesc.ulVersion = sendBufferDesc.ulVersion = SECBUFFER_VERSION;
93+
recvBufferDesc.pBuffers = &recvBuffers[0];
94+
recvBufferDesc.cBuffers = recvBuffers.size();
95+
96+
if (inputBuffer.size() > 0)
97+
{
98+
sendBufferDesc.pBuffers = &sendBuffer;
99+
sendBufferDesc.cBuffers = 1;
100+
}
101+
else
102+
{
103+
sendBufferDesc.pBuffers = nullptr;
104+
sendBufferDesc.cBuffers = 0;
105+
}
106+
107+
CtxtHandle* phOldContext = nullptr;
108+
CtxtHandle* phNewContext = nullptr;
109+
if (!phContext)
110+
{
111+
phContext = std::make_unique<CtxtHandle>();
112+
phNewContext = phContext.get();
113+
}
114+
else
115+
{
116+
phOldContext = phContext.get();
117+
}
118+
119+
auto ret = InitializeSecurityContext(phCredential, phOldContext, const_cast<char*>(szTargetName.c_str()), fContextReq, 0, 0, &sendBufferDesc, 0, phNewContext, &recvBufferDesc, pfContextAttr, nullptr);
120+
121+
outputBuffer.resize(recvBuffers[0].cbBuffer);
122+
123+
return ret;
124+
}
125+
73126
bool SChannelConnection::connect(const std::string &hostname, uint16_t port)
74127
{
75128
debug << "Trying to connect to " << hostname << ":" << port << "\n";
@@ -93,42 +146,30 @@ bool SChannelConnection::connect(const std::string &hostname, uint16_t port)
93146
}
94147
debug << "Acquired handle\n";
95148

96-
CtxtHandle *context = new CtxtHandle;
97-
CtxtHandle *inHandle = nullptr, *outHandle = context;
98-
99-
SecBufferDesc inputBuffer, outputBuffer;
100-
inputBuffer.ulVersion = outputBuffer.ulVersion = SECBUFFER_VERSION;
101-
inputBuffer.cBuffers = outputBuffer.cBuffers = 0;
102-
inputBuffer.pBuffers = outputBuffer.pBuffers = nullptr;
103-
104-
ULONG contextAttr;
105149

106150
static constexpr size_t bufferSize = 8192;
107151
bool done = false, success = false, contextCreated = false;
108-
char *recvBuffer = nullptr;
109-
char *sendBuffer = new char[2*bufferSize];
110-
111-
SecBuffer recvSecBuffer, sendSecBuffer;
112-
recvSecBuffer.BufferType = sendSecBuffer.BufferType = SECBUFFER_TOKEN;
113-
sendSecBuffer.cbBuffer = bufferSize;
114-
sendSecBuffer.pvBuffer = sendBuffer;
115152

116-
outputBuffer.cBuffers = 1;
117-
outputBuffer.pBuffers = &sendSecBuffer;
153+
ULONG contextAttr;
154+
std::unique_ptr<CtxtHandle> context;
155+
std::vector<char> inputBuffer;
156+
std::vector<char> outputBuffer;
118157

119158
do
120159
{
160+
outputBuffer.resize(bufferSize);
161+
121162
bool recvData = false;
122-
auto ret = InitializeSecurityContext(&credHandle, inHandle, (char*) hostname.c_str(), ISC_REQ_STREAM, 0, 0, &inputBuffer, 0, outHandle, &outputBuffer, &contextAttr, nullptr);
163+
auto ret = InitializeSecurityContext(&credHandle, context, hostname, ISC_REQ_STREAM, inputBuffer, outputBuffer, &contextAttr);
123164
switch (ret)
124165
{
125-
case SEC_I_COMPLETE_NEEDED:
166+
/*case SEC_I_COMPLETE_NEEDED:
126167
case SEC_I_COMPLETE_AND_CONTINUE:
127-
if (CompleteAuthToken(outHandle, &outputBuffer) != SEC_E_OK)
168+
if (CompleteAuthToken(context.get(), &outputBuffer) != SEC_E_OK)
128169
done = true;
129170
else if (ret == SEC_I_COMPLETE_NEEDED)
130171
success = done = true;
131-
break;
172+
break;*/
132173
case SEC_I_CONTINUE_NEEDED:
133174
recvData = true;
134175
break;
@@ -150,64 +191,33 @@ bool SChannelConnection::connect(const std::string &hostname, uint16_t port)
150191
if (!done)
151192
contextCreated = true;
152193

153-
inHandle = context;
154-
outHandle = nullptr;
194+
debug << "Initialize done, with " << outputBuffer.size() << " bytes of output and status " << ret << "\n";
155195

156-
debug << "Initialize done, with " << outputBuffer.cBuffers << " output buffers and status " << ret << "\n";
157-
for (unsigned int i = 0; i < outputBuffer.cBuffers && !success; ++i)
158-
{
159-
auto &buffer = outputBuffer.pBuffers[i];
160-
debug << "\tBuffer of size: " << buffer.cbBuffer << "\n";
161-
if (buffer.cbBuffer > 0 && buffer.BufferType == SECBUFFER_TOKEN)
162-
{
163-
socket.write((const char*) buffer.pvBuffer, buffer.cbBuffer);
164-
}
165-
else
166-
debug << "Got buffer with type " << buffer.BufferType << "\n";
167-
168-
if (buffer.pvBuffer == sendBuffer)
169-
{
170-
memset(sendBuffer, 0, bufferSize);
171-
buffer.cbBuffer = bufferSize;
172-
}
173-
//FreeContextBuffer(&buffer);
174-
}
196+
if (outputBuffer.size() > 0)
197+
socket.write(outputBuffer.data(), outputBuffer.size());
175198

176199
if (recvData)
177200
{
178-
debug << "Receiving data\n";
179-
if (!recvBuffer)
180-
recvBuffer = new char[bufferSize];
181-
182-
recvSecBuffer.cbBuffer = socket.read(recvBuffer, bufferSize);
183-
recvSecBuffer.pvBuffer = recvBuffer;
201+
inputBuffer.resize(bufferSize);
202+
size_t actual = socket.read(inputBuffer.data(), bufferSize);
203+
inputBuffer.resize(actual);
184204

185-
inputBuffer.cBuffers = 1;
186-
inputBuffer.pBuffers = &recvSecBuffer;
187-
}
188-
else
189-
{
190-
inputBuffer.cBuffers = 0;
191-
inputBuffer.pBuffers = nullptr;
205+
debug << "Received " << actual << " bytes of data\n";
206+
if (actual == 0)
207+
{
208+
debug << "No data received, break\n";
209+
break;
210+
}
192211
}
193-
194212
// TODO: A bunch of frees?
195213
} while (!done);
196214

197-
delete[] sendBuffer;
198-
delete[] recvBuffer;
199-
200215
debug << "Done!\n";
201216
// TODO: Check resulting context attributes
202217
if (success)
203-
{
204-
this->context = static_cast<void*>(context);
205-
}
218+
this->context = static_cast<void*>(context.release());
206219
else if (contextCreated)
207-
{
208-
DeleteSecurityContext(context);
209-
delete context;
210-
}
220+
DeleteSecurityContext(context.get());
211221

212222
return success;
213223
}

0 commit comments

Comments
 (0)