44#include < schnlsp.h>
55#include < assert.h>
66#include < algorithm>
7+ #include < memory>
78
9+ #include " common/config.h"
810#include " SChannelConnection.h"
911
1012#ifndef SCH_USE_STRONG_CRYPTO
@@ -70,6 +72,57 @@ SChannelConnection::~SChannelConnection()
7072 }
7173}
7274
75+ SECURITY_STATUS InitializeSecurityContext (CredHandle *phCredential, std::unique_ptr<CtxtHandle>& phContext, const std::string& szTargetName, ULONG fContextReq , const std::vector<char >& inputBuffer, std::vector<char >& outputBuffer, ULONG *pfContextAttr)
76+ {
77+ std::array<SecBuffer, 2 > recvBuffers;
78+ recvBuffers[0 ].BufferType = SECBUFFER_TOKEN;
79+ recvBuffers[0 ].pvBuffer = outputBuffer.data ();
80+ recvBuffers[0 ].cbBuffer = outputBuffer.size ();
81+
82+ recvBuffers[1 ].BufferType = SECBUFFER_EMPTY;
83+ recvBuffers[1 ].pvBuffer = nullptr ;
84+ recvBuffers[1 ].cbBuffer = 0 ;
85+
86+ SecBuffer sendBuffer;
87+ sendBuffer.BufferType = SECBUFFER_TOKEN;
88+ sendBuffer.pvBuffer = const_cast <char *>(inputBuffer.data ());
89+ sendBuffer.cbBuffer = inputBuffer.size ();
90+
91+ SecBufferDesc recvBufferDesc, sendBufferDesc;
92+ recvBufferDesc.ulVersion = sendBufferDesc.ulVersion = SECBUFFER_VERSION;
93+ recvBufferDesc.pBuffers = &recvBuffers[0 ];
94+ recvBufferDesc.cBuffers = recvBuffers.size ();
95+
96+ if (inputBuffer.size () > 0 )
97+ {
98+ sendBufferDesc.pBuffers = &sendBuffer;
99+ sendBufferDesc.cBuffers = 1 ;
100+ }
101+ else
102+ {
103+ sendBufferDesc.pBuffers = nullptr ;
104+ sendBufferDesc.cBuffers = 0 ;
105+ }
106+
107+ CtxtHandle* phOldContext = nullptr ;
108+ CtxtHandle* phNewContext = nullptr ;
109+ if (!phContext)
110+ {
111+ phContext = std::make_unique<CtxtHandle>();
112+ phNewContext = phContext.get ();
113+ }
114+ else
115+ {
116+ phOldContext = phContext.get ();
117+ }
118+
119+ auto ret = InitializeSecurityContext (phCredential, phOldContext, const_cast <char *>(szTargetName.c_str ()), fContextReq , 0 , 0 , &sendBufferDesc, 0 , phNewContext, &recvBufferDesc, pfContextAttr, nullptr );
120+
121+ outputBuffer.resize (recvBuffers[0 ].cbBuffer );
122+
123+ return ret;
124+ }
125+
73126bool SChannelConnection::connect (const std::string &hostname, uint16_t port)
74127{
75128 debug << " Trying to connect to " << hostname << " :" << port << " \n " ;
@@ -93,42 +146,30 @@ bool SChannelConnection::connect(const std::string &hostname, uint16_t port)
93146 }
94147 debug << " Acquired handle\n " ;
95148
96- CtxtHandle *context = new CtxtHandle;
97- CtxtHandle *inHandle = nullptr , *outHandle = context;
98-
99- SecBufferDesc inputBuffer, outputBuffer;
100- inputBuffer.ulVersion = outputBuffer.ulVersion = SECBUFFER_VERSION;
101- inputBuffer.cBuffers = outputBuffer.cBuffers = 0 ;
102- inputBuffer.pBuffers = outputBuffer.pBuffers = nullptr ;
103-
104- ULONG contextAttr;
105149
106150 static constexpr size_t bufferSize = 8192 ;
107151 bool done = false , success = false , contextCreated = false ;
108- char *recvBuffer = nullptr ;
109- char *sendBuffer = new char [2 *bufferSize];
110-
111- SecBuffer recvSecBuffer, sendSecBuffer;
112- recvSecBuffer.BufferType = sendSecBuffer.BufferType = SECBUFFER_TOKEN;
113- sendSecBuffer.cbBuffer = bufferSize;
114- sendSecBuffer.pvBuffer = sendBuffer;
115152
116- outputBuffer.cBuffers = 1 ;
117- outputBuffer.pBuffers = &sendSecBuffer;
153+ ULONG contextAttr;
154+ std::unique_ptr<CtxtHandle> context;
155+ std::vector<char > inputBuffer;
156+ std::vector<char > outputBuffer;
118157
119158 do
120159 {
160+ outputBuffer.resize (bufferSize);
161+
121162 bool recvData = false ;
122- auto ret = InitializeSecurityContext (&credHandle, inHandle, ( char *) hostname. c_str () , ISC_REQ_STREAM, 0 , 0 , & inputBuffer, 0 , outHandle, & outputBuffer, &contextAttr, nullptr );
163+ auto ret = InitializeSecurityContext (&credHandle, context, hostname, ISC_REQ_STREAM, inputBuffer, outputBuffer, &contextAttr);
123164 switch (ret)
124165 {
125- case SEC_I_COMPLETE_NEEDED:
166+ /* case SEC_I_COMPLETE_NEEDED:
126167 case SEC_I_COMPLETE_AND_CONTINUE:
127- if (CompleteAuthToken (outHandle , &outputBuffer) != SEC_E_OK)
168+ if (CompleteAuthToken(context.get() , &outputBuffer) != SEC_E_OK)
128169 done = true;
129170 else if (ret == SEC_I_COMPLETE_NEEDED)
130171 success = done = true;
131- break ;
172+ break;*/
132173 case SEC_I_CONTINUE_NEEDED:
133174 recvData = true ;
134175 break ;
@@ -150,64 +191,33 @@ bool SChannelConnection::connect(const std::string &hostname, uint16_t port)
150191 if (!done)
151192 contextCreated = true ;
152193
153- inHandle = context;
154- outHandle = nullptr ;
194+ debug << " Initialize done, with " << outputBuffer.size () << " bytes of output and status " << ret << " \n " ;
155195
156- debug << " Initialize done, with " << outputBuffer.cBuffers << " output buffers and status " << ret << " \n " ;
157- for (unsigned int i = 0 ; i < outputBuffer.cBuffers && !success; ++i)
158- {
159- auto &buffer = outputBuffer.pBuffers [i];
160- debug << " \t Buffer of size: " << buffer.cbBuffer << " \n " ;
161- if (buffer.cbBuffer > 0 && buffer.BufferType == SECBUFFER_TOKEN)
162- {
163- socket.write ((const char *) buffer.pvBuffer , buffer.cbBuffer );
164- }
165- else
166- debug << " Got buffer with type " << buffer.BufferType << " \n " ;
167-
168- if (buffer.pvBuffer == sendBuffer)
169- {
170- memset (sendBuffer, 0 , bufferSize);
171- buffer.cbBuffer = bufferSize;
172- }
173- // FreeContextBuffer(&buffer);
174- }
196+ if (outputBuffer.size () > 0 )
197+ socket.write (outputBuffer.data (), outputBuffer.size ());
175198
176199 if (recvData)
177200 {
178- debug << " Receiving data\n " ;
179- if (!recvBuffer)
180- recvBuffer = new char [bufferSize];
181-
182- recvSecBuffer.cbBuffer = socket.read (recvBuffer, bufferSize);
183- recvSecBuffer.pvBuffer = recvBuffer;
201+ inputBuffer.resize (bufferSize);
202+ size_t actual = socket.read (inputBuffer.data (), bufferSize);
203+ inputBuffer.resize (actual);
184204
185- inputBuffer.cBuffers = 1 ;
186- inputBuffer.pBuffers = &recvSecBuffer;
187- }
188- else
189- {
190- inputBuffer.cBuffers = 0 ;
191- inputBuffer.pBuffers = nullptr ;
205+ debug << " Received " << actual << " bytes of data\n " ;
206+ if (actual == 0 )
207+ {
208+ debug << " No data received, break\n " ;
209+ break ;
210+ }
192211 }
193-
194212 // TODO: A bunch of frees?
195213 } while (!done);
196214
197- delete[] sendBuffer;
198- delete[] recvBuffer;
199-
200215 debug << " Done!\n " ;
201216 // TODO: Check resulting context attributes
202217 if (success)
203- {
204- this ->context = static_cast <void *>(context);
205- }
218+ this ->context = static_cast <void *>(context.release ());
206219 else if (contextCreated)
207- {
208- DeleteSecurityContext (context);
209- delete context;
210- }
220+ DeleteSecurityContext (context.get ());
211221
212222 return success;
213223}
0 commit comments