Skip to content

Commit 4fa5e97

Browse files
committed
Merge branch 'ImportTableResolver' into MmpTlsFiber
2 parents 64675a3 + 73af5f1 commit 4fa5e97

File tree

11 files changed

+338
-190
lines changed

11 files changed

+338
-190
lines changed

MemoryModule/ImportTable.cpp

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
#include "stdafx.h"
2+
3+
typedef struct _MMP_IAT_HANDLE {
4+
5+
HMODULE hModule;
6+
PMM_IAT_RESOLVER lpResolver;
7+
8+
}MMP_IAT_HANDLE, * PMMP_IAT_HANDLE;
9+
10+
HMODULE MmpLoadLibraryA(
11+
_In_ LPCSTR lpModuleName,
12+
_Out_ PMM_IAT_RESOLVER* lpModuleResolver) {
13+
14+
HMODULE hModule = nullptr;
15+
PMM_IAT_RESOLVER resolver = nullptr;
16+
17+
EnterCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
18+
19+
PLIST_ENTRY lpResolver = MmpGlobalDataPtr->MmpIat->MmpIatResolverList.Flink;
20+
while (lpResolver != &MmpGlobalDataPtr->MmpIat->MmpIatResolverList) {
21+
PMM_IAT_RESOLVER entry = CONTAINING_RECORD(lpResolver, MM_IAT_RESOLVER, MM_IAT_RESOLVER::InMmpIatResolverList);
22+
23+
hModule = entry->LoadLibraryProv(lpModuleName);
24+
if (hModule) {
25+
resolver = entry;
26+
++entry->ReferenceCount;
27+
break;
28+
}
29+
30+
lpResolver = lpResolver->Flink;
31+
}
32+
33+
LeaveCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
34+
35+
*lpModuleResolver = resolver;
36+
return hModule;
37+
}
38+
39+
VOID MemoryFreeImportTable(_In_ PMEMORYMODULE hMemoryModule) {
40+
41+
EnterCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
42+
43+
PMMP_IAT_HANDLE list = (PMMP_IAT_HANDLE)hMemoryModule->hModulesList;
44+
for (DWORD i = 0; i < hMemoryModule->dwModulesCount; ++i) {
45+
auto entry = list[i];
46+
entry.lpResolver->FreeLibraryProv(entry.hModule);
47+
--entry.lpResolver->ReferenceCount;
48+
}
49+
50+
LeaveCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
51+
52+
53+
RtlFreeHeap(NtCurrentPeb()->ProcessHeap, 0, hMemoryModule->hModulesList);
54+
hMemoryModule->hModulesList = nullptr;
55+
hMemoryModule->dwModulesCount = 0;
56+
}
57+
58+
NTSTATUS MemoryResolveImportTable(
59+
_In_ LPBYTE base,
60+
_In_ PIMAGE_NT_HEADERS lpNtHeaders,
61+
_In_ PMEMORYMODULE hMemoryModule) {
62+
NTSTATUS status = STATUS_SUCCESS;
63+
PIMAGE_IMPORT_DESCRIPTOR importDesc = nullptr;
64+
DWORD count = 0;
65+
66+
do {
67+
__try {
68+
PIMAGE_DATA_DIRECTORY dir = &lpNtHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT];
69+
PIMAGE_IMPORT_DESCRIPTOR iat = nullptr;
70+
71+
if (dir && dir->Size) {
72+
iat = importDesc = PIMAGE_IMPORT_DESCRIPTOR(lpNtHeaders->OptionalHeader.ImageBase + dir->VirtualAddress);
73+
}
74+
75+
if (iat) {
76+
while (iat->Name) {
77+
++count;
78+
++iat;
79+
}
80+
}
81+
82+
if (importDesc && count) {
83+
PMMP_IAT_HANDLE handles = (PMMP_IAT_HANDLE)RtlAllocateHeap(NtCurrentPeb()->ProcessHeap, HEAP_ZERO_MEMORY, sizeof(MMP_IAT_HANDLE) * count);
84+
hMemoryModule->hModulesList = handles;
85+
if (!hMemoryModule->hModulesList) {
86+
status = STATUS_NO_MEMORY;
87+
break;
88+
}
89+
90+
for (DWORD i = 0; i < count; ++i, ++importDesc) {
91+
uintptr_t* thunkRef;
92+
FARPROC* funcRef;
93+
PMM_IAT_RESOLVER resolver;
94+
HMODULE handle = MmpLoadLibraryA((LPCSTR)(base + importDesc->Name), &resolver);
95+
96+
if (!handle) {
97+
status = STATUS_DLL_NOT_FOUND;
98+
break;
99+
}
100+
101+
handles[hMemoryModule->dwModulesCount].hModule = handle;
102+
handles[hMemoryModule->dwModulesCount++].lpResolver = resolver;
103+
thunkRef = (uintptr_t*)(base + (importDesc->OriginalFirstThunk ? importDesc->OriginalFirstThunk : importDesc->FirstThunk));
104+
funcRef = (FARPROC*)(base + importDesc->FirstThunk);
105+
while (*thunkRef) {
106+
*funcRef = GetProcAddress(
107+
handle,
108+
IMAGE_SNAP_BY_ORDINAL(*thunkRef) ? (LPCSTR)IMAGE_ORDINAL(*thunkRef) : (LPCSTR)PIMAGE_IMPORT_BY_NAME(base + (*thunkRef))->Name
109+
);
110+
if (!*funcRef) {
111+
status = STATUS_ENTRYPOINT_NOT_FOUND;
112+
break;
113+
}
114+
++thunkRef;
115+
++funcRef;
116+
}
117+
118+
if (!NT_SUCCESS(status))break;
119+
}
120+
121+
}
122+
}
123+
__except (EXCEPTION_EXECUTE_HANDLER) {
124+
status = GetExceptionCode();
125+
}
126+
} while (false);
127+
128+
if (!NT_SUCCESS(status)) {
129+
MemoryFreeImportTable(hMemoryModule);
130+
}
131+
132+
return status;
133+
}
134+
135+
HANDLE WINAPI MmRegisterImportTableResolver(
136+
_In_ MM_IAT_RESOLVER_ENTRY LoadLibraryProv,
137+
_In_ MM_IAT_FREE_ENTRY FreeLibraryProv) {
138+
139+
HANDLE heap = RtlProcessHeap();
140+
PMM_IAT_RESOLVER resolver = (PMM_IAT_RESOLVER)RtlAllocateHeap(heap, 0, sizeof(MM_IAT_RESOLVER));
141+
142+
if (resolver) {
143+
EnterCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
144+
145+
resolver->ReferenceCount = 1;
146+
resolver->LoadLibraryProv = LoadLibraryProv;
147+
resolver->FreeLibraryProv = FreeLibraryProv;
148+
InsertTailList(&MmpGlobalDataPtr->MmpIat->MmpIatResolverList, &resolver->InMmpIatResolverList);
149+
150+
LeaveCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
151+
}
152+
153+
return resolver;
154+
}
155+
156+
_Success_(return)
157+
BOOL WINAPI MmRemoveImportTableResolver(_In_ HANDLE hMmIatResolver) {
158+
159+
HANDLE heap = RtlProcessHeap();
160+
161+
if (hMmIatResolver == &MmpGlobalDataPtr->MmpIat->MmpIatResolverHead) {
162+
return FALSE;
163+
}
164+
165+
PMM_IAT_RESOLVER resolver = CONTAINING_RECORD(hMmIatResolver, MM_IAT_RESOLVER, MM_IAT_RESOLVER::InMmpIatResolverList);
166+
167+
EnterCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
168+
169+
if (resolver->ReferenceCount > 1) {
170+
LeaveCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
171+
return FALSE;
172+
}
173+
174+
RemoveHeadList(&resolver->InMmpIatResolverList);
175+
LeaveCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
176+
177+
return RtlFreeHeap(heap, 0, hMmIatResolver);
178+
}

MemoryModule/ImportTable.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#pragma once
2+
3+
typedef HMODULE(WINAPI* MM_IAT_RESOLVER_ENTRY)(LPCSTR lpModuleName);
4+
typedef BOOL(WINAPI* MM_IAT_FREE_ENTRY)(HMODULE hModule);
5+
6+
typedef struct _MM_IAT_RESOLVER {
7+
8+
LIST_ENTRY InMmpIatResolverList;
9+
10+
MM_IAT_RESOLVER_ENTRY LoadLibraryProv;
11+
MM_IAT_FREE_ENTRY FreeLibraryProv;
12+
13+
DWORD ReferenceCount;
14+
15+
}MM_IAT_RESOLVER, * PMM_IAT_RESOLVER;
16+
17+
VOID MemoryFreeImportTable(_In_ PMEMORYMODULE hMemoryModule);
18+
19+
NTSTATUS MemoryResolveImportTable(
20+
_In_ LPBYTE base,
21+
_In_ PIMAGE_NT_HEADERS lpNtHeaders,
22+
_In_ PMEMORYMODULE hMemoryModule
23+
);
24+
25+
HANDLE WINAPI MmRegisterImportTableResolver(
26+
_In_ MM_IAT_RESOLVER_ENTRY LoadLibraryProv,
27+
_In_ MM_IAT_FREE_ENTRY FreeLibraryProv
28+
);
29+
30+
_Success_(return)
31+
BOOL WINAPI MmRemoveImportTableResolver(_In_ HANDLE hMmIatResolver);

MemoryModule/Initialize.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55

66
PMMP_GLOBAL_DATA MmpGlobalDataPtr;
77

8+
#if MEMORY_MODULE_IS_PREVIEW(MEMORY_MODULE_MINOR_VERSION)
9+
#pragma message("WARNING: You are using a preview version of MemoryModulePP.")
10+
#endif
11+
812
PRTL_RB_TREE FindLdrpModuleBaseAddressIndex() {
913
PRTL_RB_TREE LdrpModuleBaseAddressIndex = nullptr;
1014
PLDR_DATA_TABLE_ENTRY_WIN10 nt10 = decltype(nt10)(MmpGlobalDataPtr->MmpBaseAddressIndex->NtdllLdrEntry);
@@ -427,8 +431,10 @@ NTSTATUS InitializeLockHeld() {
427431
status = MmpAllocateGlobalData();
428432
if (!NT_SUCCESS(status)) {
429433
if (status == STATUS_ALREADY_INITIALIZED) {
430-
if ((MmpGlobalDataPtr->MajorVersion < MEMORY_MODULE_MAJOR_VERSION) ||
431-
(MmpGlobalDataPtr->MajorVersion == MEMORY_MODULE_MAJOR_VERSION && MmpGlobalDataPtr->MinorVersion < MEMORY_MODULE_MINOR_VERSION)) {
434+
if ((MmpGlobalDataPtr->MajorVersion != MEMORY_MODULE_MAJOR_VERSION) ||
435+
MEMORY_MODULE_IS_PREVIEW(MmpGlobalDataPtr->MinorVersion) != MEMORY_MODULE_IS_PREVIEW(MEMORY_MODULE_MINOR_VERSION) ||
436+
(MEMORY_MODULE_IS_PREVIEW(MEMORY_MODULE_MINOR_VERSION) ? MmpGlobalDataPtr->MinorVersion != MEMORY_MODULE_MINOR_VERSION :
437+
MmpGlobalDataPtr->MinorVersion < MEMORY_MODULE_MINOR_VERSION)) {
432438
status = STATUS_NOT_SUPPORTED;
433439
}
434440
else {
@@ -458,6 +464,7 @@ NTSTATUS InitializeLockHeld() {
458464
MmpGlobalDataPtr->MmpTls = (PMMP_TLS_DATA)((LPBYTE)MmpGlobalDataPtr->MmpLdrEntry + sizeof(MMP_LDR_ENTRY_DATA));
459465
MmpGlobalDataPtr->MmpDotNet = (PMMP_DOT_NET_DATA)((LPBYTE)MmpGlobalDataPtr->MmpTls + sizeof(MMP_TLS_DATA));
460466
MmpGlobalDataPtr->MmpFunctions = (PMMP_FUNCTIONS)((LPBYTE)MmpGlobalDataPtr->MmpDotNet + sizeof(MMP_DOT_NET_DATA));
467+
MmpGlobalDataPtr->MmpIat = (PMMP_IAT_DATA)((LPBYTE)MmpGlobalDataPtr->MmpFunctions + sizeof(MMP_FUNCTIONS));
461468

462469
PLDR_DATA_TABLE_ENTRY pNtdllEntry = RtlFindLdrTableEntryByBaseName(L"ntdll.dll");
463470
MmpGlobalDataPtr->MmpBaseAddressIndex->NtdllLdrEntry = pNtdllEntry;
@@ -480,6 +487,14 @@ NTSTATUS InitializeLockHeld() {
480487
MmpGlobalDataPtr->MmpFunctions->_MmpHandleTlsData = MmpHandleTlsData;
481488
MmpGlobalDataPtr->MmpFunctions->_MmpReleaseTlsEntry = MmpReleaseTlsEntry;
482489

490+
InitializeCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
491+
InitializeListHead(&MmpGlobalDataPtr->MmpIat->MmpIatResolverList);
492+
InitializeListHead(&MmpGlobalDataPtr->MmpIat->MmpIatResolverHead.InMmpIatResolverList);
493+
MmpGlobalDataPtr->MmpIat->MmpIatResolverHead.LoadLibraryProv = LoadLibraryA;
494+
MmpGlobalDataPtr->MmpIat->MmpIatResolverHead.FreeLibraryProv = FreeLibrary;
495+
MmpGlobalDataPtr->MmpIat->MmpIatResolverHead.ReferenceCount = 1;
496+
InsertTailList(&MmpGlobalDataPtr->MmpIat->MmpIatResolverList, &MmpGlobalDataPtr->MmpIat->MmpIatResolverHead.InMmpIatResolverList);
497+
483498
MmpTlsInitialize();
484499

485500
MmpGlobalDataPtr->MmpDotNet->Initialized = MmpGlobalDataPtr->MmpDotNet->PreHooked = FALSE;

MemoryModule/Loader.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "stdafx.h"
2+
#include <cmath>
23

34
NTSTATUS NTAPI LdrMapDllMemory(
45
_In_ HMEMORYMODULE ViewBase,
@@ -64,6 +65,7 @@ NTSTATUS NTAPI LdrLoadDllMemoryExW(
6465
if (dwFlags & LOAD_FLAGS_USE_DLL_NAME && (!DllName || !DllFullName))return STATUS_INVALID_PARAMETER_3;
6566

6667
if (DllName) {
68+
int length = (int)wcslen(DllName);
6769
PLIST_ENTRY ListHead = &NtCurrentPeb()->Ldr->InLoadOrderModuleList, ListEntry = ListHead->Flink;
6870
PIMAGE_NT_HEADERS h1 = RtlImageNtHeader(BufferAddress), h2 = nullptr;
6971
if (!h1)return STATUS_INVALID_IMAGE_FORMAT;
@@ -74,11 +76,19 @@ NTSTATUS NTAPI LdrLoadDllMemoryExW(
7476

7577
/* Check if it's being unloaded */
7678
if (!CurEntry->InMemoryOrderLinks.Flink) continue;
79+
80+
auto dist = (CurEntry->BaseDllName.Length / sizeof(wchar_t)) - length;
81+
bool equal = false;
82+
if (dist == 0 || dist == 4) {
83+
equal = !_wcsnicmp(DllName, CurEntry->BaseDllName.Buffer, length);
84+
}
85+
else {
86+
continue;
87+
}
7788

7889
/* Check if name matches */
79-
if (!_wcsnicmp(DllName, CurEntry->BaseDllName.Buffer, (CurEntry->BaseDllName.Length / sizeof(wchar_t)) - 4) ||
80-
!_wcsnicmp(DllName, CurEntry->BaseDllName.Buffer, CurEntry->BaseDllName.Length / sizeof(wchar_t))) {
81-
90+
if (equal) {
91+
8292
/* Let's compare their headers */
8393
if (!(h2 = RtlImageNtHeader(CurEntry->DllBase)))continue;
8494
if (!(module = MapMemoryModuleHandle((HMEMORYMODULE)CurEntry->DllBase)))continue;

MemoryModule/MemoryModule.cpp

Lines changed: 1 addition & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -85,86 +85,6 @@ NTSTATUS MmpInitializeStructure(DWORD ImageFileSize, LPCVOID ImageFileBuffer, PI
8585
return STATUS_SUCCESS;
8686
}
8787

88-
89-
NTSTATUS MemoryResolveImportTable(
90-
_In_ LPBYTE base,
91-
_In_ PIMAGE_NT_HEADERS lpNtHeaders,
92-
_In_ PMEMORYMODULE hMemoryModule) {
93-
NTSTATUS status = STATUS_SUCCESS;
94-
PIMAGE_IMPORT_DESCRIPTOR importDesc = nullptr;
95-
DWORD count = 0;
96-
97-
do {
98-
__try {
99-
PIMAGE_DATA_DIRECTORY dir = GET_HEADER_DICTIONARY(lpNtHeaders, IMAGE_DIRECTORY_ENTRY_IMPORT);
100-
PIMAGE_IMPORT_DESCRIPTOR iat = nullptr;
101-
102-
if (dir && dir->Size) {
103-
iat = importDesc = PIMAGE_IMPORT_DESCRIPTOR(lpNtHeaders->OptionalHeader.ImageBase + dir->VirtualAddress);
104-
}
105-
106-
if (iat) {
107-
while (iat->Name) {
108-
++count;
109-
++iat;
110-
}
111-
}
112-
113-
if (importDesc && count) {
114-
hMemoryModule->hModulesList = (HMODULE*)RtlAllocateHeap(NtCurrentPeb()->ProcessHeap, HEAP_ZERO_MEMORY, sizeof(HMODULE) * count);
115-
if (!hMemoryModule->hModulesList) {
116-
status = STATUS_NO_MEMORY;
117-
break;
118-
}
119-
120-
for (DWORD i = 0; i < count; ++i, ++importDesc) {
121-
uintptr_t* thunkRef;
122-
FARPROC* funcRef;
123-
HMODULE handle = LoadLibraryA((LPCSTR)(base + importDesc->Name));
124-
125-
if (!handle) {
126-
status = STATUS_DLL_NOT_FOUND;
127-
break;
128-
}
129-
130-
hMemoryModule->hModulesList[hMemoryModule->dwModulesCount++] = handle;
131-
thunkRef = (uintptr_t*)(base + (importDesc->OriginalFirstThunk ? importDesc->OriginalFirstThunk : importDesc->FirstThunk));
132-
funcRef = (FARPROC*)(base + importDesc->FirstThunk);
133-
while (*thunkRef) {
134-
*funcRef = GetProcAddress(
135-
handle,
136-
IMAGE_SNAP_BY_ORDINAL(*thunkRef) ? (LPCSTR)IMAGE_ORDINAL(*thunkRef) : (LPCSTR)PIMAGE_IMPORT_BY_NAME(base + (*thunkRef))->Name
137-
);
138-
if (!*funcRef) {
139-
status = STATUS_ENTRYPOINT_NOT_FOUND;
140-
break;
141-
}
142-
++thunkRef;
143-
++funcRef;
144-
}
145-
146-
if (!NT_SUCCESS(status))break;
147-
}
148-
149-
}
150-
}
151-
__except (EXCEPTION_EXECUTE_HANDLER) {
152-
status = GetExceptionCode();
153-
}
154-
} while (false);
155-
156-
if (!NT_SUCCESS(status)) {
157-
for (DWORD i = 0; i < hMemoryModule->dwModulesCount; ++i)
158-
FreeLibrary(hMemoryModule->hModulesList[i]);
159-
160-
RtlFreeHeap(NtCurrentPeb()->ProcessHeap, 0, hMemoryModule->hModulesList);
161-
hMemoryModule->hModulesList = nullptr;
162-
hMemoryModule->dwModulesCount = 0;
163-
}
164-
165-
return status;
166-
}
167-
16888
NTSTATUS MemorySetSectionProtection(
16989
_In_ LPBYTE base,
17090
_In_ PIMAGE_NT_HEADERS lpNtHeaders) {
@@ -404,15 +324,7 @@ BOOL MemoryFreeLibrary(HMEMORYMODULE mod) {
404324

405325
if (!module) return FALSE;
406326
if (module->loadFromLdrLoadDllMemory && !module->underUnload)return FALSE;
407-
if (module->hModulesList) {
408-
for (DWORD i = 0; i < module->dwModulesCount; ++i) {
409-
if (module->hModulesList[i]) {
410-
FreeLibrary(module->hModulesList[i]);
411-
}
412-
}
413-
414-
RtlFreeHeap(NtCurrentPeb()->ProcessHeap, 0, module->hModulesList);
415-
}
327+
if (module->hModulesList)MemoryFreeImportTable(module);
416328

417329
if (module->codeBase) VirtualFree(mod, 0, MEM_RELEASE);
418330
return TRUE;

0 commit comments

Comments
 (0)