import os

import numpy as np
import psycopg2
import psycopg2.extensions
from dotenv import load_dotenv

import config.paths

load_dotenv(dotenv_path=config.paths.ENV_FILE)

# Connection settings for the vector database, read from the environment
# after loading the project's .env file.
vector_db_name = os.getenv("VECTOR_DB_NAME")
vector_db_user = os.getenv("DB_USER")
vector_db_password = os.getenv("DB_PASSWORD")
vector_db_host = os.getenv("DB_HOST")

# Build the libpq DSN with make_dsn instead of string formatting: it quotes
# values containing spaces, quotes or backslashes (common in passwords) and
# omits unset (None) settings rather than emitting a literal "None" value,
# so libpq falls back to its own defaults for anything not configured.
connection_string = psycopg2.extensions.make_dsn(
    dbname=vector_db_name,
    user=vector_db_user,
    password=vector_db_password,
    host=vector_db_host,
)

def insert_into_vector_database(id, text, embedding_vector: np.ndarray, audio_embedding_vector: np.ndarray):
    """Insert one row into ``lecture_repre`` and commit it.

    Args:
        id: Value for the ``id`` column. The parameter name shadows the ``id``
            builtin but is kept so keyword callers keep working.
        text: Value for the ``content`` column.
        embedding_vector: Stored in the ``embedding`` column; sent to the
            database as a plain Python list via ``tolist()``.
        audio_embedding_vector: Stored in the ``audio_embedding`` column; sent
            the same way.

    The transaction is committed on success and rolled back if the insert
    fails; the connection is closed in either case.
    """
    insert_query = """
    INSERT INTO lecture_repre (id, content, embedding, audio_embedding)
    VALUES (%s, %s, %s, %s)
    """
    # Convert before connecting so invalid input never opens a connection.
    params = (id, text, embedding_vector.tolist(), audio_embedding_vector.tolist())

    conn = psycopg2.connect(connection_string)
    try:
        # `with conn` commits on success / rolls back on error, and the cursor
        # context closes the cursor; neither closes the connection itself.
        with conn, conn.cursor() as cursor:
            cursor.execute(insert_query, params)
    finally:
        conn.close()

def find_topk_most_relevant_lectures(k, query_embedding_vector: np.ndarray):
    """Print and return the ``k`` rows of ``lecture_repre`` nearest a query.

    Rows are ordered by the pgvector ``<=>`` operator (cosine distance)
    between the ``embedding`` column and ``query_embedding_vector``.

    Args:
        k: Maximum number of rows to fetch (passed as the SQL ``LIMIT``).
        query_embedding_vector: Query embedding; sent as a Python list via
            ``tolist()`` and cast to ``vector`` in SQL.

    Returns:
        The fetched rows (every column of ``lecture_repre``), nearest first.
        Each row is also printed, as before.
    """
    # "<=>" means cosine distance
    select_query = """
        SELECT *
        FROM lecture_repre
        ORDER BY embedding <=> %s::vector
        LIMIT %s;
        """
    conn = psycopg2.connect(connection_string)
    try:
        # Transaction + cursor contexts guarantee the cursor is closed and the
        # transaction ended; `finally` guarantees the connection is closed.
        with conn, conn.cursor() as cursor:
            cursor.execute(select_query, (query_embedding_vector.tolist(), k))
            results = cursor.fetchall()
    finally:
        conn.close()

    for row in results:
        print(row)
    return results

def show_all_tables():
    """Print, and return, the names listed in ``information_schema.tables``
    for the ``public`` schema (PostgreSQL lists views there as well).

    Returns:
        A list of table names, in the order the query returned them.
    """
    conn = psycopg2.connect(connection_string)
    try:
        # Previously neither the cursor nor the connection was ever closed.
        with conn.cursor() as cursor:
            cursor.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
            # Each fetched row is a 1-tuple; keep just the name.
            table_names = [row[0] for row in cursor.fetchall()]
    finally:
        conn.close()

    for table_name in table_names:
        print(table_name)
    return table_names

if __name__ == '__main__':
    # Manual smoke test against the database configured in the .env file.
    show_all_tables()
    # Inserting writes a row, so it is left disabled. To try it, uncomment the
    # lines below. insert_into_vector_database takes two embedding vectors;
    # the 768-dim sizes are assumptions carried over from the query test
    # below — TODO confirm they match the table's vector column dimensions.
    # import random
    # row_id = random.randint(1, 10000000)
    # insert_into_vector_database(row_id, "test text", np.random.random([768]), np.random.random([768]))
    find_topk_most_relevant_lectures(5, np.random.random([768]))

