|
| 1 | +import copy |
| 2 | +import keras |
| 3 | +import numpy as np |
| 4 | +import pandas as pd |
| 5 | +#from model import get_model |
| 6 | +from keras.models import load_model |
| 7 | + |
| 8 | +x_train , x_test , y_test , y_train = pd.read_csv('sudoku.csv') |
| 9 | + |
| 10 | +adam = keras.optimizers.Adam(lr = 0.001) |
| 11 | +model.compile(loss='sparse_categorical_crossentropy' , optimizer = adam) |
| 12 | +model.fit(x_train , y_train , batch_size = 32, epochs = 2) |
| 13 | +model =load_model('sudokumodel.py') |
| 14 | +def norm(a): |
| 15 | + return (a/9)-.5; |
| 16 | +def denorm(a): |
| 17 | + return (a+ .5)*9; |
| 18 | +def inference_sudoku(sample): |
| 19 | + |
| 20 | + |
| 21 | + |
| 22 | + feat = copy.copy(sample) |
| 23 | + |
| 24 | + while(1): |
| 25 | + |
| 26 | + out = model.predict(feat.reshape((1,9,9,1))) |
| 27 | + out = out.squeeze() |
| 28 | + |
| 29 | + pred = np.argmax(out, axis=1).reshape((9,9))+1 |
| 30 | + prob = np.around(np.max(out, axis=1).reshape((9,9)), 2) |
| 31 | + |
| 32 | + feat = denorm(feat).reshape((9,9)) |
| 33 | + mask = (feat==0) |
| 34 | + |
| 35 | + if(mask.sum()==0): |
| 36 | + break |
| 37 | + |
| 38 | + prob_new = prob*mask |
| 39 | + ind = np.argmax(prob_new) |
| 40 | + x, y = (ind//9), (ind%9) |
| 41 | + |
| 42 | + val = pred[x][y] |
| 43 | + feat[x][y] = val |
| 44 | + feat = norm(feat) |
| 45 | + |
| 46 | + return pred |
| 47 | +def solve_sudoku(game): |
| 48 | + |
| 49 | + game = game.replace('\n', '') |
| 50 | + game = game.replace(' ', '') |
| 51 | + game = np.array([int(j) for j in game]).reshape((9,9,1)) |
| 52 | + game = norm(game) |
| 53 | + game = inference_sudoku(game) |
| 54 | + return game |
| 55 | +game = ''' |
| 56 | + 0 8 0 0 3 2 0 0 1 |
| 57 | + 7 0 3 0 8 0 0 0 2 |
| 58 | + 5 0 0 0 0 7 0 3 0 |
| 59 | + 0 5 0 0 0 1 9 7 0 |
| 60 | + 6 0 0 7 0 9 0 0 8 |
| 61 | + 0 4 7 2 0 0 0 5 0 |
| 62 | + 0 2 0 6 0 0 0 0 9 |
| 63 | + 8 0 0 0 9 0 3 0 5 |
| 64 | + 3 0 0 8 2 0 0 1 0 |
| 65 | + ''' |
| 66 | + |
| 67 | +game = solve_sudoku(game) |
| 68 | + |
| 69 | +print('solved puzzle:\n') |
| 70 | +print(game) |
| 71 | + |
0 commit comments