Michael Rodin
fd1742ec38
Added a user details page and the ability to update display name and password
297 lines
13 KiB
Python
## MAIN FUNCTIONS FILE FOR BACK-BACKEND OF FLASK
|
|
import mariadb as sql
|
|
from os import environ
|
|
import time,re
|
|
from config import *
|
|
|
|
## Connection parameters for MariaDB. Each value is taken from the matching
## environment variable when it is set and non-empty; otherwise the default
## imported from `config` is used (change those defaults for a permanent setup).
## `or` keeps the default lazy: it is only evaluated when the variable is unset/empty.
conn_params = {
    "user": environ.get("MARIADB_USER") or MARIADB_USER,
    "password": environ.get("MARIADB_PASSWORD") or MARIADB_PASSWORD,
    "host": environ.get("MARIADB_HOST") or MARIADB_HOST,
    "database": environ.get("MARIADB_DB") or MARIADB_DB,
}
|
|
|
|
class db:
|
|
def __init__(self):
|
|
self.conn=sql.connect(**conn_params)
|
|
self.conn.autocommit=True
|
|
self.cur=self.conn.cursor()
|
|
|
|
## Creates all archives, if they don't exist already
|
|
## Called only on startup, hence the name
|
|
def startup(self):
|
|
self.cur.execute("""CREATE TABLE IF NOT EXISTS Archs(
|
|
ID int PRIMARY KEY AUTO_INCREMENT,
|
|
NAME text NOT NULL,
|
|
HASH text NOT NULL UNIQUE,
|
|
SIZE bigint NOT NULL,
|
|
IMPORTED int,
|
|
CATEGORY int,
|
|
OWNER int
|
|
);""")
|
|
self.cur.execute("""CREATE TABLE IF NOT EXISTS Users(
|
|
ID int PRIMARY KEY AUTO_INCREMENT,
|
|
UNAME text NOT NULL UNIQUE,
|
|
DNAME text NOT NULL,
|
|
CREATED int NOT NULL,
|
|
STATE text,
|
|
PASSHASH text NOT NULL
|
|
);""")
|
|
self.cur.execute("""CREATE TABLE IF NOT EXISTS Sessions(
|
|
ID int PRIMARY KEY AUTO_INCREMENT,
|
|
SESSKEY text NOT NULL UNIQUE,
|
|
USERID int NOT NULL,
|
|
CREATED int NOT NULL,
|
|
LIFE int
|
|
);""")
|
|
self.cur.execute("""CREATE TABLE IF NOT EXISTS Cats(
|
|
ID int PRIMARY KEY AUTO_INCREMENT,
|
|
CATEGORY text NOT NULL,
|
|
PARENT int,
|
|
DESCRIPTION text
|
|
);""")
|
|
self.cur.execute("""CREATE TABLE IF NOT EXISTS ArchLab(
|
|
ID int PRIMARY KEY AUTO_INCREMENT,
|
|
ARCHID int NOT NULL,
|
|
LABID int NOT NULL
|
|
);""")
|
|
self.cur.execute("""CREATE TABLE IF NOT EXISTS CatLabType(
|
|
ID int PRIMARY KEY AUTO_INCREMENT,
|
|
CATID int NOT NULL,
|
|
LABID int NOT NULL
|
|
);""")
|
|
self.cur.execute("""CREATE TABLE IF NOT EXISTS Labs(
|
|
ID int PRIMARY KEY AUTO_INCREMENT,
|
|
LABEL text NOT NULL,
|
|
TYPE int NOT NULL
|
|
);""")
|
|
self.cur.execute("""CREATE TABLE IF NOT EXISTS LabType(
|
|
ID int PRIMARY KEY AUTO_INCREMENT,
|
|
NAME text NOT NULL,
|
|
DESCRIPTION text
|
|
);""")
|
|
|
|
## Gets the passhash from a specific user
|
|
## OUTPUT: (If user exists) int=200, ID:int, passhash:str
|
|
## (If user does not exist) int=400, Exception:str
|
|
def get_passhash(self, username:str):
|
|
self.cur.execute(f"SELECT ID,PASSHASH FROM Users WHERE UNAME='{username}'")
|
|
try:
|
|
resp=self.cur.fetchone()
|
|
except Exception as e:
|
|
return 400, e, NULL
|
|
return 200, resp[0], resp[1]
|
|
|
|
## Checks if sesskey exists and is not expired
|
|
## OUTPUT: (if valiid) bool=True, USERID:str
|
|
## (in invalid) bool=False, str=""
|
|
def check_sesskey(self, sesskey:str):
|
|
self.cur.execute(f"SELECT SESSKEY,USERID FROM Sessions WHERE SESSKEY='{sesskey}'")
|
|
entry=self.cur.fetchone()
|
|
if entry and sesskey in entry:
|
|
return True, entry[1]
|
|
else:
|
|
return False, ""
|
|
|
|
## Sets a session key. That's it.
|
|
def set_sesskey(self, sesskey:str, userid:int, lifetime:int):
|
|
self.cur.execute(f"INSERT INTO Sessions(SESSKEY,USERID,CREATED,LIFE) VALUES('{sesskey}',{userid},{time.time()},{lifetime})")
|
|
|
|
def logout_user(self, sesskey:str):
|
|
self.cur.execute(f"DELETE FROM Sessions WHERE SESSKEY='{sesskey}'")
|
|
|
|
## Gets and returns all user info about one (1) user
|
|
## OUTPUT: tuple=(ID:int,UNAME:str,DNAME:str,CREATED:int,STATE:text,PASSHASH:text)
|
|
def get_user_info(self, userid:int):
|
|
self.cur.execute(f"SELECT * FROM Users WHERE ID='{userid}'")
|
|
return self.cur.fetchone()
|
|
|
|
## like above, just with uname
|
|
## OUTPUT: tuple=(ID:int,UNAME:str,DNAME:str,CREATED:int,STATE:text,PASSHASH:text)
|
|
def get_user_info_from_uname(self, uname:str):
|
|
self.cur.execute(f"SELECT * FROM Users WHERE UNAME='{uname}'")
|
|
return self.cur.fetchone()
|
|
|
|
def update_user_info(self, userid, update_type:str,value):
|
|
allowed_types={"DNAME":str,"PASSHASH":str}
|
|
if update_type.upper() not in allowed_types:
|
|
return False, "Not allowed"
|
|
self.cur.execute(f"""UPDATE Users SET {update_type}={value if allowed_types[update_type]==int else f"'{value}'"} WHERE ID={userid}""")
|
|
return True, "Updated"
|
|
|
|
## Checks information for errors and adds archive to the DB
|
|
## OUTPUT: (if successful) res:bool=True, ID:int
|
|
## (if unsuccessful) res:bool=False, str
|
|
def add_archive(self, archive:dict):
|
|
# Check everything for errors or malicious things
|
|
archive["hash"]=archive["hash"].upper()
|
|
if not re.match('[A-Z0-9]{40}', archive["hash"]):
|
|
return False, "Hash needs to be 40 characters in hexadecimal (SHA-1)."
|
|
if re.match('.*[^A-Za-z0-9\. +_-].*', archive["name"]):
|
|
return False, "The name contains illegal characters. Allowed chars: '[A-Za-z0-9\. _-]'"
|
|
print(archive["name"])
|
|
|
|
curtime=time.time()
|
|
try:
|
|
self.cur.execute(f"INSERT INTO Archs(NAME,HASH,SIZE,IMPORTED,CATEGORY,OWNER) VALUES('{archive['name']}','{archive['hash']}',{archive['size']},{curtime},{archive['category']},{archive['owner']})")
|
|
except Exception as e: # hash needs to be unique
|
|
return False, e
|
|
self.cur.execute(f"SELECT ID FROM Archs WHERE HASH='{archive['hash']}'")
|
|
archid=self.cur.fetchone()
|
|
return True,archid[0]
|
|
|
|
def delete_archive(self, archid:int):
|
|
self.cur.execute(f"""DELETE FROM Archs WHERE ID={archid}""")
|
|
self.cur.execute(f"""DELETE FROM ArchLab WHERE ARCHID={archid}""")
|
|
|
|
## Returns all relevant information about one (1) archive
|
|
## OUTPUT: archive:tuple=(ID:int,NAME:str,HASH:str,SIZE:int,IMPORTED[UNIX]:int,CATEGORY.ID:int,CATEGORY,str,CATEGORY.DESCRIPTION:str,USER.ID:int,DNAME:str),
|
|
## category:tuple=(ID:int,CATEGORY:str,DESCRIPTION:str,PID:int,PCAT:str,PDESC:str)
|
|
## labels:list=[…,(ID:int,LABEL:str,LABTYPE:int,LABDESC:str),…]
|
|
def get_archive_info(self, archid:int):
|
|
# get info about archive itself
|
|
self.cur.execute(f"""SELECT Archs.ID,Archs.NAME,Archs.HASH,Archs.SIZE,Archs.IMPORTED,Cats.ID,Cats.CATEGORY,Cats.DESCRIPTION,Users.ID,Users.DNAME FROM Archs
|
|
JOIN Cats ON Cats.ID=Archs.CATEGORY
|
|
JOIN Users ON Users.ID=Archs.OWNER
|
|
WHERE Archs.ID='{archid}'""")
|
|
archive=self.cur.fetchone()
|
|
# get info about category and it's parent
|
|
self.cur.execute(f"""SELECT c.ID,c.CATEGORY,c.DESCRIPTION,p.ID AS PID,p.CATEGORY as PCAT,p.DESCRIPTION AS PDESC FROM Cats c, Cats p
|
|
WHERE c.ID={archive[5]} AND c.PARENT=p.ID""")
|
|
category=self.cur.fetchone()
|
|
# get info about labels of archive
|
|
|
|
labels=self.get_label_info(archid)
|
|
return archive, category, labels
|
|
|
|
def get_label_info(self, archid:int):
|
|
self.cur.execute(f"""SELECT Labs.ID,Labs.LABEL,LabType.ID AS LABTYPE,LabType.DESCRIPTION AS LABDESC FROM ArchLab
|
|
JOIN Archs ON Archs.ID=ArchLab.ARCHID
|
|
JOIN Labs ON Labs.ID=ArchLab.LABID
|
|
JOIN LabType ON Labs.TYPE=LabType.ID
|
|
WHERE ARCHID={archid};""")
|
|
return self.cur.fetchall()
|
|
|
|
## Returns all categories.
|
|
## OUTPUT: array=[…,(ID:int,CATEGORY:str,PARENT:int,DESCRIPTION:str),…]
|
|
def get_categories(self):
|
|
self.cur.execute("SELECT * FROM Cats;")
|
|
return self.cur.fetchall()
|
|
|
|
## get all labeltypes and their respective labels based on a category parent
|
|
## OUTPUT: res_dict:dict={…,LabType.NAME:[…,(ID:int,NAME:str),…],…}
|
|
def get_label_labeltypes(self, catparentid:int):
|
|
# gets all relevant labtypes: […,(ID.int,NAME:str),…]
|
|
self.cur.execute(f"""SELECT LabType.ID,LabType.NAME FROM CatLabType
|
|
JOIN Cats ON Cats.ID=CatLabType.CATID
|
|
JOIN LabType ON LabType.ID=CatLabType.LABID
|
|
WHERE Cats.ID={catparentid}
|
|
ORDER BY LabType.NAME ASC""")
|
|
labtypes_list=self.cur.fetchall()
|
|
labtypes_ids,labtypes_names=[],[]
|
|
for w,e in labtypes_list:
|
|
labtypes_ids.append(str(w))
|
|
labtypes_names.append(e)
|
|
ltid_string="(" + ",".join(labtypes_ids) + ")"
|
|
# gets all relevant labs: […,(ID:int,NAME:str,LTNAME:str),…]
|
|
self.cur.execute(f"""SELECT Labs.ID,Labs.LABEL,LabType.NAME AS LTNAME FROM Labs
|
|
JOIN LabType ON Labs.TYPE=LabType.ID
|
|
WHERE LabType.ID IN {ltid_string}
|
|
ORDER BY Labs.LABEL ASC""")
|
|
labs_list=self.cur.fetchall()
|
|
res_dict={}
|
|
# creates all labtype entries in dict
|
|
for i in labtypes_names:
|
|
res_dict[i]=[]
|
|
# puts all labs into their respective labtype
|
|
for entry in labs_list:
|
|
res_dict[entry[2]].append(entry[:2])
|
|
return res_dict
|
|
|
|
## get a list of enabled labels and update the DB to reflect that state
|
|
## OUTPUT: (if on_labels empty) bool=False, str
|
|
## (else)
|
|
def update_labels(self, archid:int, on_labels:list):
|
|
# fail if no labels passed
|
|
if len(on_labels) == 0:
|
|
return False, "You have to select at least one label!"
|
|
|
|
# get all relevant labels
|
|
self.cur.execute(f"""SELECT ArchLab.LABID FROM ArchLab
|
|
WHERE ArchLab.ARCHID={archid}""")
|
|
existing_labs=[]
|
|
for i in self.cur.fetchall():
|
|
existing_labs.append(i[0])
|
|
to_add=[]
|
|
# get all missing labels to add
|
|
for lab in on_labels:
|
|
if int(lab) not in existing_labs:
|
|
to_add.append(lab)
|
|
|
|
# remove all labels which are not on
|
|
self.cur.execute(f"""DELETE FROM ArchLab WHERE ARCHID={archid} AND LABID NOT IN ({",".join(on_labels)})""")
|
|
to_add_list=[]
|
|
for i in to_add:
|
|
to_add_list.append("(" + str(archid) + "," + str(i) + ")")
|
|
# add all new labels
|
|
self.cur.execute(f"""INSERT INTO ArchLab(ARCHID,LABID) VALUES{",".join(to_add_list)}""")
|
|
return True, ""
|
|
|
|
## Returns n archives, sorted by (imported )time or size
|
|
## OUTPUT: archives:array=[…,(ID:int,NAME:str,SIZE:str,IMPORTED[UNIX]:int),…]
|
|
def get_n_archives(self, sorttype:str="time",category:int=0, keywords:list=[], count:int=20,labels:list=[]): # TODO: CLEANN!!!!!
|
|
match sorttype:
|
|
case "size":
|
|
sorttype="SIZE DESC"
|
|
case "time":
|
|
sorttype="IMPORTED DESC"
|
|
case "za":
|
|
sorttype="NAME DESC"
|
|
case _:
|
|
sorttype="NAME ASC"
|
|
|
|
# create SQL query for keywords
|
|
keyword_string=""
|
|
for w in keywords:
|
|
keyword_string+=f"AND NAME LIKE '%{w}%' "
|
|
if len(keywords) == 1:
|
|
keyword_string+=f"OR HASH = '{keywords[0]}' "
|
|
|
|
# get all children of category (if exist) and put into query string
|
|
categories=self.get_categories()
|
|
catlist=[str(category)]
|
|
for i in categories:
|
|
if i[2] == int(category):
|
|
catlist.append(str(i[0]))
|
|
categories="(" + ",".join(catlist) + ")"
|
|
|
|
self.cur.execute(f"""SELECT Archs.ID,Archs.NAME,Archs.SIZE,Archs.IMPORTED FROM Archs
|
|
{"WHERE 1=1" if category==0 else "WHERE CATEGORY IN " + categories}
|
|
{keyword_string}
|
|
ORDER BY {sorttype} LIMIT {count if count else 20}""")
|
|
archives=self.cur.fetchall()
|
|
## WARNING: JANK
|
|
#positive_archives=[]
|
|
#print("LABELS:", labels)
|
|
#if len(labels) >= 1:
|
|
# for arch in archives:
|
|
# archid=arch[0]
|
|
# archive_labels=self.get_label_info(archid)
|
|
# success=True
|
|
# for label in archive_labels:
|
|
# if not label[0] in labels:
|
|
# success=False
|
|
# if success:
|
|
# positive_archives.append(arch)
|
|
# if len(positive_archives) >= count:
|
|
# break
|
|
|
|
return archives
|
|
|
|
if __name__ == "__main__":
    # Smoke test: open a connection and close it again.
    # db.__init__ takes no arguments (it reads the module-level conn_params);
    # the instance gets its own name so the class `db` is not shadowed.
    database = db()
    # database.startup()  # uncomment to create any missing tables
    database.cur.close()
    database.conn.close()