import os

import matplotlib.pyplot as plt
import mysql.connector
from mysql.connector import Error

from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq

# Groq API key, read from the environment so the secret is never committed to
# source control. Set GROQ_API_KEY before running the script. The value is
# None when the variable is unset.
# NOTE(review): a key was previously hard-coded here. It remains in VCS history
# and should be revoked/rotated in the Groq console.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

def connect_to_mysql(host='localhost', user='root', password='', database='demo database'):
    """Open and return a connection to a MySQL database.

    The defaults target the local ``demo database`` used by the rest of this
    script.

    Args:
        host: MySQL server host name.
        user: MySQL user name.
        password: Password for ``user``.
        database: Database (schema) selected on connect.

    Returns:
        The connection object created by ``mysql.connector.connect``. The caller
        owns it and is responsible for calling ``close()`` on it.

    Raises:
        mysql.connector.Error: if the connection cannot be established.
    """
    return mysql.connector.connect(
        host=host,
        user=user,
        password=password,
        database=database,
    )

        



def generate_sql_query(
    question="Give me top 10 name, course id and price of course having highest price.",
    model_name='llama-3.1-70b-versatile',
):
    """Convert an English question about ``lms_courses`` into SQL using a Groq LLM.

    Args:
        question: Natural-language question to translate. The default is the
            top-10-courses-by-price question this script plots in ``main``.
        model_name: Groq model identifier passed to ``ChatGroq``.

    Returns:
        The SQL text produced by the model, with surrounding whitespace and any
        Markdown code fence (```` ``` ```` / ```` ```sql ````) removed.
    """
    system = """
    You are an expert in converting English questions to SQL query!
    The SQL table is named lms_courses and has the following columns - name, price, 
    discount, lms_course_id \n\nFor example,\nExample 1 - How many entries of records are present?, 
    the SQL command will be something like this SELECT COUNT(*) FROM lms_courses ; 
    also the sql code should not have ``` in beginning or end and sql word in output

    """

    chat = ChatGroq(model_name=model_name, groq_api_key=GROQ_API_KEY)
    # "{text}" is a ChatPromptTemplate placeholder filled by chain.invoke below.
    prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{text}")])

    chain = prompt | chat
    output = chain.invoke({"text": question})
    return _strip_code_fences(output.content)


def _strip_code_fences(text):
    """Remove a wrapping Markdown code fence that LLMs often emit despite instructions."""
    text = text.strip()
    if text.startswith("```"):
        text = text[3:]
        # Drop an optional language tag on the opening fence, e.g. ```sql.
        if text[:3].lower() == "sql":
            text = text[3:]
        text = text.rstrip()
        if text.endswith("```"):
            text = text[:-3]
    return text.strip()
    

def execute_sql_query(sql_query):
    """Execute ``sql_query`` against the local ``demo database`` and fetch all rows.

    NOTE(review): ``sql_query`` is model-generated text and is executed verbatim.
    If the question ever comes from an untrusted user, run this with a read-only
    MySQL account, because prompt injection can yield arbitrary SQL.

    Args:
        sql_query: SQL statement to run.

    Returns:
        The rows from ``cursor.fetchall()``, or ``None`` if MySQL raised an
        ``Error`` (the error is printed).
    """
    connection = None
    cursor = None
    try:
        connection = mysql.connector.connect(
            host='localhost',
            user='root',
            password='',
            database='demo database',
        )
        cursor = connection.cursor()
        cursor.execute(sql_query)
        return cursor.fetchall()
    except Error as e:
        print("Error executing SQL query:", e)
        return None
    finally:
        # Release the cursor and the server connection on every path.
        # Either may be None if connect() or cursor() itself failed.
        if cursor is not None:
            cursor.close()
        if connection is not None:
            connection.close()


def main():
    """Generate SQL for the default question, run it, print the rows and plot prices.

    Each result row is expected to be ``(name, course_id, price, ...)``, the
    column order requested by ``generate_sql_query``'s default question.
    """
    sql_query = generate_sql_query()
    print(sql_query)
    results = execute_sql_query(sql_query)
    if results is None:
        print("No results or an error occurred while executing the query.")
        return

    print("Query Results:")
    course_names = []
    course_prices = []
    for row in results:
        print(row)
        # Rows without a price column, or with a NULL price, cannot be plotted.
        # They are still printed above.
        if len(row) < 3 or row[2] is None:
            continue
        course_names.append(str(row[0]))  # course name
        course_prices.append(float(row[2]))  # price (MySQL DECIMAL -> float)

    if not course_names:
        print("No plottable rows (expected name, course id, price columns).")
        return

    _plot_course_prices(course_names, course_prices)


def _plot_course_prices(course_names, course_prices):
    """Show a horizontal bar chart with one bar per row, first row at the top."""
    # Plot against numeric positions: passing the names as categorical y-values
    # would merge rows sharing the same course name onto a single bar.
    positions = list(range(len(course_names)))
    plt.figure(figsize=(10, 6))
    plt.barh(positions, course_prices)
    plt.yticks(positions, course_names)
    plt.xlabel("Price (in currency)", fontsize=12)
    plt.ylabel("Course Name", fontsize=12)
    plt.title("Top 10 Courses by Price", fontsize=14)
    plt.gca().invert_yaxis()  # Invert y-axis so the first row is on top
    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    # Run the question -> SQL -> query -> bar-chart pipeline only when this file
    # is executed directly, not when it is imported.
    main()