This repository has been archived on 2025-04-11. You can view files and clone it, but cannot push or open issues or pull requests.
nn420-private-pine64backup/nn_puzzle_solver.ipynb

291 lines
8.2 KiB
Text
Raw Normal View History

2017-12-01 14:27:20 -06:00
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
2017-12-01 14:27:20 -06:00
"metadata": {},
"outputs": [],
2017-12-01 14:27:20 -06:00
"source": [
"# Setting up our imported libraries.\n",
"from functools import reduce\n",
2017-12-01 14:27:20 -06:00
"import numpy as np\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense\n",
"from os.path import join"
2017-12-01 14:27:20 -06:00
]
},
{
"cell_type": "code",
"execution_count": null,
2017-12-01 14:27:20 -06:00
"metadata": {},
"outputs": [],
2017-12-01 14:27:20 -06:00
"source": [
"# Generating dummy data.\n",
"\n",
2017-12-01 14:27:20 -06:00
"data = np.random.random((1000,240))\n",
"output = np.random.random((1000, 29))\n",
"\n",
"# Replace this with parser code later."
2017-12-01 14:27:20 -06:00
]
},
{
"cell_type": "code",
"execution_count": null,
2017-12-01 14:27:20 -06:00
"metadata": {},
"outputs": [],
2017-12-01 14:27:20 -06:00
"source": [
"# Sets up a Sequential model, Sequential is all\n",
"# that should need to be used for this project,\n",
"# considering that it will only be dealing with\n",
"# a linear stack of layers of neurons.\n",
"\n",
2017-12-01 14:27:20 -06:00
"model = Sequential()\n",
"\n",
"# Adding layers to the model.\n",
"\n",
"model.add(Dense(units=240, activation='tanh', input_dim=240))\n",
"model.add(Dense(units=120, activation='tanh'))\n",
"model.add(Dense(units=29, activation='sigmoid'))\n",
"\n",
"# Configure the learning process.\n",
"\n",
"model.compile(optimizer='sgd',\n",
" loss='mean_squared_error',\n",
" metrics=['accuracy'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
2017-12-01 14:27:20 -06:00
"\n",
"def format_input(acc, elem):\n",
" hex_elem = (elem - (elem >> 4 << 4))\n",
" for x in range(16):\n",
" if x == hex_elem:\n",
" acc.append(1)\n",
" else:\n",
" acc.append(0)\n",
" hex_elem = (elem >> 4) % 16\n",
" for x in range(16):\n",
" if x == hex_elem:\n",
" acc.append(1)\n",
" else:\n",
" acc.append(0)\n",
" return acc\n",
2017-12-01 14:27:20 -06:00
"\n",
"with open('data/0.bin', 'rb') as f:\n",
" data = f.read(8)\n",
" counter = 0\n",
2017-12-01 14:27:20 -06:00
"\n",
" while(data):\n",
" bin_data = reduce(format_input, list(data), [])\n",
" bin_data.reverse()\n",
" bin_data = bin_data[16:]\n",
" print(bin_data)\n",
2017-12-01 14:27:20 -06:00
"\n",
" print(counter)\n",
" \n",
" for i in range(int(len(bin_data)/240)):\n",
" for x in range((i*16),(i*16) + 15):\n",
" temp = []\n",
" for y in range(16):\n",
" temp.append(bin_data[(x*16) + y])\n",
" print(temp)\n",
" \n",
" data = f.read(8)\n",
" counter += 1"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
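"I'm going to need to calculate the Manhattan distance for each state at some point. For a tile at position $(x_1, y_1)$ whose target position is $(x_2, y_2)$, the distance is $d = |x_1 - x_2| + |y_1 - y_2|$, and a state's total is the sum of $d$ over all tiles.\n",
"\n",
"This website might be helpful for that formula (note: Wikispaces shut down in 2018, so this link may no longer resolve):\n",
"https://heuristicswiki.wikispaces.com/Manhattan+Distance"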
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def man_dist(x, y):\n",
" for a, b in zip(x, y):\n",
" a_one, a_two = x\n",
" b_one, b_two = y\n",
" \n",
" return (abs(a_one - b_one) + abs(a_two - b_two))\n",
" \n",
"def man_dist_state(x, y):\n",
" return sum(man_dist(a, b) for a, b in zip(x, y))\n",
"\n",
"def format_pos(acc, elem):\n",
" hex_elem = (elem[1] - (elem[1] >> 4 << 4))\n",
" if hex_elem == 0:\n",
" acc.append((hex_elem, (3,3)))\n",
" else:\n",
" acc.append((hex_elem, ((15 - ((elem[0]) * 2)) % 4,int((15 - ((elem[0]) * 2)) / 4))))\n",
" hex_elem = (elem[1] >> 4) % 16\n",
" if hex_elem == 0:\n",
" acc.append((hex_elem, (3,3)))\n",
" else:\n",
" acc.append((hex_elem, ((15 - ((elem[0]) * 2 + 1)) % 4,int((15 - ((elem[0]) * 2 + 1)) / 4))))\n",
" \n",
" return acc\n",
"\n",
"def generate_pos(acc, elem):\n",
" if(elem[0] == 0):\n",
" acc.append((3,3))\n",
" else:\n",
" acc.append((((elem[0] - 1) % 4), (int((elem[0] - 1)/4))))\n",
" \n",
" return acc\n",
"\n",
"def format_man_dist(elem):\n",
" acc = []\n",
" for x in range(28, -1, -1):\n",
" if x == elem:\n",
" acc.append(1)\n",
" else:\n",
" acc.append(0)\n",
" return acc\n",
"\n",
"with open('data/0.bin', 'rb') as f:\n",
" data = f.read(8)\n",
" counter = 0\n",
" \n",
" while(data):\n",
" pos_data = reduce(format_pos, enumerate(list(data)), [])\n",
" pos_data.reverse()\n",
" pos_data = pos_data[1:]\n",
" \n",
" state_pos = []\n",
" \n",
" for p in pos_data:\n",
" state_pos.append(p[1])\n",
" \n",
" target_pos = reduce(generate_pos, pos_data, [])\n",
" \n",
" print(format_man_dist(man_dist_state(state_pos, target_pos)))\n",
" \n",
" data = f.read(8)\n",
" counter += 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for i in range(29):\n",
" filename = join('data', str(i) + '.bin')\n",
" \n",
" print(i)\n",
" \n",
" with open(filename, 'rb') as f:\n",
" data = f.read(8)\n",
" counter = 0\n",
"\n",
" while(data and counter < 1000):\n",
" bin_data = reduce(format_input, list(data), [])\n",
" bin_data.reverse()\n",
" bin_data = bin_data[16:]\n",
"\n",
" pos_data = reduce(format_pos, enumerate(list(data)), [])\n",
" pos_data.reverse()\n",
"\n",
" state_pos = []\n",
"\n",
" for p in pos_data:\n",
" state_pos.append(p[1])\n",
"\n",
" target_pos = reduce(generate_pos, pos_data, [])\n",
"\n",
" #for i in range(int(len(bin_data)/256)):\n",
" # for x in range((i*16) + 1,(i*16) + 16):\n",
" # temp = []\n",
" # for y in range(16):\n",
" # temp.append(bin_data[(x*16) + y])\n",
" # print(temp)\n",
"\n",
" target = []\n",
" target.append(format_man_dist(man_dist_state(state_pos, target_pos)))\n",
"\n",
" temp = []\n",
" temp.append(bin_data)\n",
"\n",
" # Train the network.\n",
"\n",
" #model.fit(data, output, epochs=5, batch_size=10)\n",
" model.train_on_batch(np.array(temp), np.array(target))\n",
"\n",
" data = f.read(8)\n",
" counter += 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open('data/2.bin', 'rb') as f:\n",
" data = f.read(8)\n",
" counter = 0\n",
" \n",
" while(data):\n",
" bin_data = reduce(format_input, list(data), [])\n",
" bin_data.reverse()\n",
" bin_data = bin_data[16:]\n",
" \n",
" temp = []\n",
" temp.append(bin_data)\n",
"\n",
" # Generating predictions:\n",
2017-12-01 14:27:20 -06:00
"\n",
" predictions = model.predict(np.array(temp), batch_size=1)\n",
" print(predictions)"
2017-12-01 14:27:20 -06:00
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}