Skip to content

Commit 5fe22b0

Browse files
committed
adding deep learning
1 parent 92a0412 commit 5fe22b0

10 files changed

Lines changed: 59 additions & 19 deletions

File tree

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@
44
[submodule "3rd/N3LDG"]
55
path = 3rd/N3LDG
66
url = https://github.com/chncwang/N3LDG-1
7+
[submodule "3rd/eigen"]
8+
path = 3rd/eigen
9+
url = https://github.com/eigenteam/eigen-git-mirror

3rd/N3LDG

Submodule N3LDG updated from ad05845 to ad6a28d

3rd/eigen

Submodule eigen added at a5d8d5a

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ IF(Boost_FOUND)
2222
ENDIF()
2323

2424
INCLUDE_DIRECTORIES(src 3rd/googletest/googletest/include 3rd/googletest/googlemock/include spdlog/include)
25-
INCLUDE_DIRECTORIES(3rd/N3LDG/include 3rd/eigen/include)
25+
INCLUDE_DIRECTORIES(3rd/N3LDG/include 3rd/eigen)
2626

2727
AUX_SOURCE_DIRECTORY(src/board SRCS)
2828
AUX_SOURCE_DIRECTORY(src/piece_structure SRCS)

src/board/full_board.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,13 @@ void FullBoard<BOARD_LEN>::PlayMove(const Move &move) {
245245
PositionIndex move_index = move.position_index;
246246

247247
assert(PstionAndIndxCcltr<BOARD_LEN>::Ins().IsInBoard(move_index));
248-
assert(GetPointState(move_index) == EMPTY_POINT);
248+
if (GetPointState(move_index) != EMPTY_POINT) {
249+
const Position &pos = PstionAndIndxCcltr<BOARD_LEN>::Ins().GetPosition(
250+
move_index);
251+
std::cout << "x:" << (char)(pos.x + 'a') << " y:" << (char)(pos.y + 'a') <<
252+
std::endl;
253+
abort();
254+
}
249255

250256
BoardDifference board_difference;
251257
board_difference.Init(LastForce(), KoIndex());

src/board/pos_cal.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,20 @@ class PstionAndIndxCcltr {
1818
} CentralEdgeCorner;
1919
static PstionAndIndxCcltr &Ins();
2020

21-
inline const Position &GetPosition(PositionIndex index) const {
21+
const Position &GetPosition(PositionIndex index) const {
2222
assert(IsInBoard(index));
2323
return position_[index];
2424
}
25-
inline PositionIndex GetIndex(const Position &pos) const {
25+
PositionIndex GetIndex(const Position &pos) const {
2626
assert(IsInBoard(pos));
2727
return indexes_[pos.y][pos.x];
2828
}
2929

30-
inline bool IsInBoard(const Position &pos) const {
30+
bool IsInBoard(const Position &pos) const {
3131
return pos.x >= 0 && pos.x < BOARD_LEN && pos.y >= 0
3232
&& pos.y < BOARD_LEN;
3333
}
34-
inline bool IsInBoard(PositionIndex indx) const {
34+
bool IsInBoard(PositionIndex indx) const {
3535
return indx >= 0 && indx < BoardLenSquare<BOARD_LEN>();
3636
}
3737
CentralEdgeCorner CentralOrEdgeOrCorner(const Position &pos);
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef FOOLGO_SRC_DEEP_LEARNING_RESNET
2+
#define FOOLGO_SRC_DEEP_LEARNING_RESNET
3+
4+
#include <vector>
5+
#include <array>
6+
7+
#include "N3LDG.h"
8+
9+
struct GraphBuilder {
10+
std::vector<BucketNode> input_nodes;
11+
std::array<std::vector<ConcatNode>, > concat_nodes;
12+
13+
void CreateNodes() {
14+
input_nodes.resize(441);
15+
}
16+
};
17+
18+
#endif

src/deep_learning/sample.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ template<BoardLen BOARD_LEN>
99
struct Sample {
1010
FullBoard<BOARD_LEN> full_board;
1111
PositionIndex position_index;
12+
13+
Sample() = default;
14+
15+
Sample(const Sample &sample) : position_index(sample.position_index) {
16+
full_board.Copy(sample.full_board);
17+
}
1218
};
1319

1420
}

src/game/sgf_game.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <cstdint>
55
#include <memory>
66
#include <functional>
7+
#include <vector>
78

89
#include "board/position.h"
910
#include "def.h"
@@ -24,41 +25,43 @@ class SgfGame : public Game<BOARD_LEN> {
2425
~SgfGame() = default;
2526

2627
static std::unique_ptr<SgfGame> BuildSgfGame(const GameInfo &game_info,
27-
const std::function<void(const Sample<BOARD_LEN>&)> &collect_sample = nullptr) {
28+
std::vector<Sample<BOARD_LEN>> *samples = nullptr) {
2829
SgfPlayer<BOARD_LEN> *black_player = new SgfPlayer<BOARD_LEN>(game_info);
2930
SgfPlayer<BOARD_LEN> *white_player = new SgfPlayer<BOARD_LEN>(game_info);
3031
FullBoard<BOARD_LEN> full_board;
3132
full_board.Init();
3233

3334
return std::unique_ptr<SgfGame>(new SgfGame<BOARD_LEN>(
34-
full_board, black_player, white_player, game_info));
35+
full_board, black_player, white_player, game_info, samples));
3536
}
3637

3738
protected:
3839
SgfGame(const FullBoard<BOARD_LEN> &full_board,
3940
Player<BOARD_LEN> *black_player,
4041
Player<BOARD_LEN> *white_player,
4142
const GameInfo &game_info,
43+
std::vector<Sample<BOARD_LEN>> *samples,
4244
bool only_log_board = true) : Game<BOARD_LEN>(full_board, black_player,
43-
white_player, only_log_board), game_info_(game_info) {}
45+
white_player, only_log_board), game_info_(game_info),
46+
samples_(samples) {}
4447

4548
bool ShouldLog() const override {
46-
return true;
49+
return false;
4750
}
4851

4952
void BeforePlay (PositionIndex index) override {
50-
if (collect_sample_ != nullptr) {
53+
if (samples_ != nullptr) {
5154
Sample<BOARD_LEN> sample;
5255
sample.full_board.Copy(this->GetFullBoard());
5356
sample.position_index = index;
54-
collect_sample_(sample);
57+
samples_->push_back(sample);
5558
}
5659
}
5760

5861
private:
5962
DISALLOW_COPY_AND_ASSIGN_AND_MOVE(SgfGame);
6063
GameInfo game_info_;
61-
std::function<void(const Sample<BOARD_LEN> &)> collect_sample_;
64+
std::vector<Sample<BOARD_LEN>> *samples_ = nullptr;
6265
};
6366

6467
}

src/trainer.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "util/cxxopts.hpp"
1313
#include "util/SGFParser.h"
1414
#include "deep_learning/sample.h"
15+
#include "deep_learning/resnet/graph_builder.h"
1516
#include "N3LDG.h"
1617

1718
using namespace foolgo;
@@ -27,19 +28,21 @@ int main(int argc, char *argv[]) {
2728
("s,sgf", "sgf file name", cxxopts::value<string>());
2829
auto args = options.parse(argc, argv);
2930
string sgf_file_name = args["sgf"].as<string>();
30-
cout << "sgf:" << sgf_file_name << endl;
3131

3232
SGFParser parser;
3333
vector<string> strs = parser.chop_all(sgf_file_name);
3434
cout << strs.size() << endl;
3535
vector<GameInfo> game_infos = parser.get_game_infos(sgf_file_name);
36-
vector<Sample> samples;
37-
auto collect_samples = [&](const Sample &sample) {
38-
}
36+
vector<Sample<19>> samples;
3937

38+
int iter = 0;
4039
for (const GameInfo &game_info : game_infos) {
41-
auto sgf_game = SgfGame<19>::BuildSgfGame(game_info, collect_samples);
40+
auto sgf_game = SgfGame<19>::BuildSgfGame(game_info, &samples);
4241
sgf_game->Run();
42+
if (++iter % 100 == 0) {
43+
cout << "iter:" << iter << "sample count:" << samples.size() << endl;
44+
break;
45+
}
4346
}
4447

4548
return 0;

0 commit comments

Comments
 (0)