import sys
import json
import os
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from joblib import dump
import pandas as pd
import numpy as np

def preprocess_data(courses, users, competencies):
    """
    Build a DataFrame of courses with a combined ``all_text`` column.

    For every course the name, short description, description and the
    space-separated skill ids are concatenated (single-space separated) into
    one string stored under ``all_text``.

    Args:
        courses: Iterable of course dicts. The keys ``name``,
            ``short_description``, ``description`` and ``skills`` are read.
            A missing or ``None`` text field contributes an empty string
            rather than raising ``KeyError`` or injecting the literal text
            ``"None"`` into the corpus. A missing or ``None`` ``skills``
            value is treated as an empty list. Each skill entry must carry a
            ``skill_id`` key.
        users: Accepted for interface compatibility; not used here.
        competencies: Accepted for interface compatibility; not used here.

    Returns:
        pandas.DataFrame with one row per course holding the course's own
        fields plus ``all_text``. The ``all_text`` column exists even when
        ``courses`` is empty.
    """
    def _text(value):
        # None would otherwise be rendered as the string "None".
        return "" if value is None else str(value)

    rows = []
    for course in courses:
        skill_ids = " ".join(
            str(skill['skill_id']) for skill in (course.get('skills') or [])
        )
        all_text = (
            f"{_text(course.get('name'))} "
            f"{_text(course.get('short_description'))} "
            f"{_text(course.get('description'))} "
            f"{skill_ids}"
        )
        # Build a new dict so the caller's course objects are not mutated.
        rows.append({**course, 'all_text': all_text})

    if not rows:
        # Keep the column that downstream code indexes by.
        return pd.DataFrame(columns=['all_text'])

    return pd.DataFrame(rows)

def train_model(courses_df):
    """
    Fit a TF-IDF vectorizer on the ``all_text`` column of the courses frame.

    Args:
        courses_df: DataFrame that contains an ``all_text`` column of strings.

    Returns:
        Tuple ``(vectorizer, tfidf_matrix)``: the fitted ``TfidfVectorizer``
        (configured with English stop words) and the matrix produced by
        fitting and transforming the ``all_text`` column.
    """
    corpus = courses_df['all_text']
    tfidf = TfidfVectorizer(stop_words='english')
    document_term_matrix = tfidf.fit_transform(corpus)
    return tfidf, document_term_matrix

def save_model(client_id, vectorizer, tfidf_matrix):
    """
    Persist the fitted vectorizer and TF-IDF matrix for one client.

    Both objects are written with ``dump`` into the directory
    ``models/<client_id>/`` (relative to the current working directory) as
    ``tfidf_vectorizer.joblib`` and ``tfidf_matrix.joblib``. The directory is
    created if it does not exist yet.

    Args:
        client_id: Identifier used as the sub-directory name under ``models``.
        vectorizer: Fitted vectorizer object to store.
        tfidf_matrix: TF-IDF matrix to store.
    """
    # NOTE(review): client_id is interpolated into the path unsanitized; if it
    # can come from untrusted input, reject values containing path separators
    # or '..' before calling this function.
    model_dir = f'models/{client_id}'
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(model_dir, exist_ok=True)
    dump(vectorizer, os.path.join(model_dir, 'tfidf_vectorizer.joblib'))
    dump(tfidf_matrix, os.path.join(model_dir, 'tfidf_matrix.joblib'))

def main(training_data_json):
    """
    Train and save one TF-IDF model per client from a JSON payload.

    Args:
        training_data_json: JSON string. It must contain a ``courses_data``
            object mapping each client id to a list of course dicts.
            ``users_data`` and ``competency_data`` (both keyed by client id)
            are optional; a missing or ``null`` section is treated as empty.

    Raises:
        json.JSONDecodeError: If ``training_data_json`` is not valid JSON.
        KeyError: If the payload has no ``courses_data`` entry.
    """
    training_data = json.loads(training_data_json)

    # These sections are only passed through to preprocessing, so their
    # absence should not abort the run.
    users_by_client = training_data.get('users_data') or {}
    competencies_by_client = training_data.get('competency_data') or {}

    for client_id, courses in training_data['courses_data'].items():
        users = users_by_client.get(client_id, {})
        competencies = competencies_by_client.get(client_id, [])

        if not courses:
            # Nothing to fit on; skip this client instead of failing the
            # whole run (and every client after it).
            print(
                f"Skipping client {client_id}: no courses to train on.",
                file=sys.stderr,
            )
            continue

        # Preprocess the data
        courses_df = preprocess_data(courses, users, competencies)

        # Train the model
        vectorizer, tfidf_matrix = train_model(courses_df)

        # Save the model and matrix
        save_model(client_id, vectorizer, tfidf_matrix)

    print("Training completed successfully.")

if __name__ == "__main__":
    # The script takes exactly one positional argument: the training data
    # serialized as a JSON string.
    if len(sys.argv) == 2:
        training_data_json = sys.argv[1]
        main(training_data_json)
    else:
        print("Usage: python train_model.py <training_data_json>")
        sys.exit(1)
