diff --git a/uwume/__init__.py b/uwume/__init__.py index 4e485e9..4afd0e8 100644 --- a/uwume/__init__.py +++ b/uwume/__init__.py @@ -2,13 +2,16 @@ from flask import Flask from os import urandom from os import environ try: - import uwume.lib.databaseMethods + from . lib.databaseMethods import initialize_database + from . lib.classUser import User except: from .lib.databaseMethods import initialize_database + from .lib.classUser import User app = Flask(__name__) app.secret_key = urandom(12).hex() initialize_database() +User.restore_users() from . import uwume \ No newline at end of file diff --git a/uwume/lib/classUser.py b/uwume/lib/classUser.py index 204fef0..d5374ce 100644 --- a/uwume/lib/classUser.py +++ b/uwume/lib/classUser.py @@ -1,25 +1,34 @@ try: - from classUwuCounter import UwuCounter - from databaseMethods import create_user -except: from .classUwuCounter import UwuCounter - from .databaseMethods import create_user + from .databaseMethods import create_user, get_users, get_count +except: + from classUwuCounter import UwuCounter + from databaseMethods import create_user, get_users, get_count from flask import Flask, Response from flask_login import LoginManager, UserMixin, login_required class User(UserMixin): # proxy for a database of users - user_database = {"JohnDoe": ("JohnDoe", "John", UwuCounter(0)), - "JaneDoe": ("JaneDoe", "Jane", UwuCounter(10)), - "admin": ("admin", "admin", UwuCounter(100))} + loaded_users={} - def __init__(self, username, password, initialCount=UwuCounter(0)): + def __init__(self, username, password='', initialCount=0): self.id = username - self.password = password self.count = UwuCounter(initialCount) - create_user(0, username, password, 0) + if(password): + create_user(0, username, password, 0) + + @classmethod + def load_user(cls, user): + cls.loaded_users[user.id] = user + + @classmethod + def restore_users(cls): + user_list = get_users() + for user in user_list: + cls.load_user(User(user, initialCount=get_count(user))) + @classmethod def get(cls, id): - return cls.user_database.get(id) \ No newline at end of file + return cls.loaded_users.get(id) \ No newline at end of file diff --git a/uwume/lib/databaseMethods.py b/uwume/lib/databaseMethods.py index 863dd03..61dc7fb 100644 --- a/uwume/lib/databaseMethods.py +++ b/uwume/lib/databaseMethods.py @@ -10,15 +10,19 @@ md = MetaData() def initialize_database(): - users = Table('users', md, - Column('id', Integer()), - Column('name', String(255)), - Column('pass', String(255)), - Column('count', Integer())) - try: - md.create_all(engine) - except Exception as e: - print(f'{e}') + if(not engine.dialect.has_table(engine, 'users')): + users = Table('users', md, + Column('id', Integer()), + Column('name', String(255)), + Column('password', String(255)), + Column('count', Integer())) + try: + md.create_all(engine) + except Exception as e: + print(f'{e}') + create_user(0, 'admin', 'admin', 0) + else: + print('Users table already created, skipping initialization.') def create_user(idn, name, password, count): @@ -33,6 +37,13 @@ def create_user(idn, name, password, count): output = conn.execute(query, values) +def check_password(user, password): + users = Table('users', md, autoload=True, autoload_with=engine) + query = select([users.columns.password]).where(users.columns.name == user) + output = conn.execute(query) + return output.fetchone()[0] == password + + def get_table(): users = Table('users', md, autoload=True, autoload_with=engine) return users.columns.keys() @@ -46,6 +57,32 @@ def get_count(user): def update_count(user, count): + if(get_count(user) < count): + users = Table('users', md, autoload=True, autoload_with=engine) + query = update(users).values(count=count).where(users.columns.name == user) + return conn.execute(query) + + +def get_user(id): users = Table('users', md, autoload=True, autoload_with=engine) - query = update(users).values(count=count).where(users.columns.name == user) - return conn.execute(query) + query = select([users.columns]).where(users.columns.name == id) + output = conn.execute(query) + return output.fetchone()[0] + + +def get_users(): + users = Table('users', md, autoload=True, autoload_with=engine) + query = select([users.columns.name]) + output = conn.execute(query) + user_list = output.fetchall() + user_list_formatted = [] + for u in user_list: + user_list_formatted += [str(u[0])] + return user_list_formatted + + +def get_all(): + users = Table('users', md, autoload=True, autoload_with=engine) + query = select([users]) + output = conn.execute(query) + return output.fetchall() \ No newline at end of file diff --git a/uwume/uwume.py b/uwume/uwume.py index 50f0a79..b058cb9 100644 --- a/uwume/uwume.py +++ b/uwume/uwume.py @@ -2,12 +2,12 @@ try: from .lib.helpers import get_static_paths, get_content_text from .lib.classUser import User from .lib.classUwuCounter import UwuCounter - from .lib.databaseMethods import get_table, get_count, update_count + from .lib.databaseMethods import get_table, get_count, update_count, get_users, get_all, check_password except: from lib.classUser import User from lib.helpers import get_static_paths, get_content_text from lib.classUwuCounter import UwuCounter - from lib.databaseMethods import get_table, get_count, update_count + from lib.databaseMethods import get_table, get_count, update_count, get_users, get_all, check_password from . import app from flask import Flask, render_template, redirect, url_for, request, flash from flask_login import LoginManager, UserMixin, login_required, login_user, current_user, logout_user @@ -38,7 +38,7 @@ def load_user(request): @login_manager.user_loader def load_user(request): - return User(request, User.user_database[request][1]) + return User.get(request) @app.route('/user//getCurval', methods=['GET']) @@ -46,9 +46,9 @@ def get_curval(username): error = '' if(request.method == 'GET'): try: - new_count = f'{User.user_database[username][2].curval}' + new_count = User.get(username).count.curval update_count(username, new_count) - return new_count + return f'{new_count}' except Exception as e: error = f'{e}' return f'ERROR: {error}' @@ -58,15 +58,14 @@ def get_curval(username): def user_page(username): error = '' if(request.method == 'GET'): - if(username in User.user_database.keys()): - return render_template('user/index.html', this_user=username, user_curval=str(get_count(username)), static_paths=get_static_paths(), content_text=get_content_text()) + if(username in get_users()): + return render_template('user/index.html', this_user=username, user_curval=f'{User.get(username).count.curval}', static_paths=get_static_paths(), content_text=get_content_text()) else: - error = 'User doesn\'t exist.' + error = f'User doesn\'t exist.\n{username}\n{get_users()}' elif(request.method == 'POST'): try: - User.user_database[username][2].increment() - print('test') - return f'{User.user_database[username][2].curval}' + User.get(username).count.increment() + return f'{User.get(username).count.curval}' except Exception as e: error = f'{e}' return f'ERROR: {error}' @@ -77,7 +76,9 @@ def home_get_curval(): error = '' if(request.method == 'GET'): try: - return f'{User.user_database[current_user.id][2].curval}' + new_count = User.get(current_user.id).count.curval + update_count(current_user.id, new_count) + return f'{new_count}' except Exception as e: error = f'{e}' return f'ERROR: {error}' @@ -86,7 +87,7 @@ def home_get_curval(): @app.route('/home', methods=['GET']) @login_required def home(): - return render_template('home/index.html', user_curval=str(User.user_database[current_user.id][2].curval), static_paths=get_static_paths(), content_text=get_content_text()) + return render_template('home/index.html', user_curval=str(User.get(current_user.id).count.curval), static_paths=get_static_paths(), content_text=get_content_text()) @app.route('/signup', methods=['GET', 'POST']) @@ -95,8 +96,9 @@ def signup(): if(request.method == 'POST'): username = request.form['username'] password = request.form['password'] - if(not username in User.user_database.keys()): - User.user_database[username] = (username, password, UwuCounter(0)) + if(not username in get_users()): + User.load_user(User(username, password)) + #User.user_database[username] = (username, password, UwuCounter(0)) return redirect('/') else: flash('It looks like a user already exists with that name.') @@ -112,8 +114,8 @@ def login(): if(request.method == 'POST'): username = request.form['username'] password = request.form['password'] - if(username in User.user_database.keys() and password == User.user_database.get(username)[1]): - userClass = User(username, password) + if(username in get_users() and check_password(username, password)): + userClass = User.get(username) login_user(userClass) return redirect('home') else: