from django.core.management.base import BaseCommand
from django.db import transaction
from apps.tamin.models import (
    PrescriptionType, ServiceType, DrugInstruction, ParTarefValue,
    DrugAmount, DrugUsage, TreatmentPlan, Illness, Complaint,
    ICD10Code, MedicalSpecialty, Service
)

class Command(BaseCommand):
    """Reload the Tamin "basic data" reference tables from the Tamin API.

    Every section present in the API payload fully replaces its table: the
    existing rows are deleted and the payload records are inserted. Sections
    absent from the payload leave their table untouched.
    """

    help = 'Load basic data from Tamin API'

    def add_arguments(self, parser):
        """Register the API credential options."""
        parser.add_argument(
            '--client-id',
            type=str,
            help='Tamin API client ID',
        )
        # '--access-token' matches the hyphenated style of '--client-id';
        # the original '--access_token' spelling stays as an alias so
        # existing invocations keep working (both store to the same dest).
        parser.add_argument(
            '--access-token',
            '--access_token',
            dest='access_token',
            type=str,
            help='Tamin API access_token',
        )

    def handle(self, *args, **options):
        """Command entry point: fetch the basic data and store it."""
        self.load_from_api(options)

    def load_from_api(self, options):
        """Fetch basic data from the Tamin API and store it in the database.

        Args:
            options: parsed command options. ``client_id`` falls back to
                ``'portal-js'`` when missing or empty; ``access_token`` is
                assigned to the client as-is (possibly ``None``).
        """
        from apps.tamin.tamin_sdk import TaminClient

        client_id = options.get('client_id') or 'portal-js'
        access_token = options.get('access_token')

        # NOTE(review): always targets the Tamin test environment.
        client = TaminClient(
            client_id=client_id,
            redirect_uri='http://127.0.0.1',
            test_environment=True,
        )
        # NOTE(review): writes a private attribute of TaminClient directly;
        # a public way to inject the token would be less brittle.
        client._access_token = access_token
        basic_data = client.get_basic_data()
        self.process_basic_data(basic_data)

    def process_basic_data(self, data):
        """Replace each reference table whose section is present in ``data``.

        Args:
            data: mapping of section name (e.g. ``'prescription_types'``) to
                an iterable of API records (dicts). ``'service_codes'`` is the
                exception: it maps a service-type code to that type's records.

        All writes run in one transaction. An unhandled error part way
        through (e.g. a record missing a required key) rolls back every
        table instead of leaving some of them emptied.
        """
        with transaction.atomic():
            if 'prescription_types' in data:
                self._replace_all(
                    PrescriptionType, data['prescription_types'],
                    label='prescription types', batch_size=300,
                    build=lambda item: PrescriptionType(
                        id=item['prescTypeId'],
                        code=item['prescTypeCode'],
                        desc=item['prescTypeDesc'],
                    ),
                )

            # None when the payload has no 'service_types' section.
            srv_type_cache = self._load_service_types(data)

            if 'drug_instructions' in data:
                self._replace_all(
                    DrugInstruction, data['drug_instructions'],
                    label='drug instructions', batch_size=300,
                    build=lambda item: DrugInstruction(
                        id=item['drugInstId'],
                        code=item.get('drugInstCode'),
                        summary=item.get('drugInstSumry'),
                        latin=item.get('drugInstLatin'),
                        concept=item['drugInstConcept'],
                    ),
                )

            if 'par_taref_values' in data:
                self._replace_all(
                    ParTarefValue, data['par_taref_values'],
                    label='par taref values', batch_size=300,
                    build=lambda item: ParTarefValue(
                        grp_code=item['parGrpCode'],
                        grp_desc=item['parGrpDesc'],
                        grp_rem=item.get('parGrpRem'),
                        status=item['status'],
                        status_date=item['statusStDate'],
                    ),
                )

            if 'drug_amounts' in data:
                self._replace_all(
                    DrugAmount, data['drug_amounts'],
                    label='drug amounts', batch_size=300,
                    build=lambda item: DrugAmount(
                        id=item['drugAmntId'],
                        code=item['drugAmntCode'],
                        summary=item.get('drugAmntSumry'),
                        latin=item.get('drugAmntLatin'),
                        concept=item['drugAmntConcept'],
                        visible=item['visibled'],
                    ),
                )

            if 'drug_usages' in data:
                self._replace_all(
                    DrugUsage, data['drug_usages'],
                    label='drug usages', batch_size=300,
                    build=lambda item: DrugUsage(
                        id=item['drugUsageId'],
                        code=item['drugUsageCode'],
                        summary=item.get('drugUsageSumry'),
                        latin=item.get('drugUsageLatin'),
                        concept=item['drugUsageConcept'],
                        visible=item['visible'],
                        drug_form_code=item.get('drugFormCode'),
                    ),
                )

            if 'treatment_plans' in data:
                self._replace_all(
                    TreatmentPlan, data['treatment_plans'],
                    label='treatment plans', batch_size=300,
                    build=lambda item: TreatmentPlan(
                        id=item['planId'],
                        desc=item['planDesc'],
                        code=item['planCode'],
                    ),
                )

            if 'illnesses' in data:
                self._replace_all(
                    Illness, data['illnesses'],
                    label='illnesses', batch_size=600,
                    build=lambda item: Illness(
                        id=item['illnessId'],
                        desc=item['illnessDesc'],
                    ),
                )

            if 'complaints' in data:
                self._replace_all(
                    Complaint, data['complaints'],
                    label='complaints', batch_size=5000,
                    build=lambda item: Complaint(
                        icd_id=item['icdId'],
                        icd_code=item['icdCode'],
                        icd_name=item['icdName'],
                    ),
                )

            if 'icd10_codes' in data:
                self._replace_all(
                    ICD10Code, data['icd10_codes'],
                    label='ICD10 codes', batch_size=5000,
                    build=lambda item: ICD10Code(
                        id=item['id'],
                        code=item['code'],
                        display_name=item['displayName'],
                        english_name=item['englishName'],
                        terminology=item['terminology'],
                        status=item['status'],
                    ),
                )

            if 'medical_specialties' in data:
                self._replace_all(
                    MedicalSpecialty, data['medical_specialties'],
                    label='medical specialties', batch_size=600,
                    build=lambda item: MedicalSpecialty(
                        code=item['specCode'],
                        desc=item['specDesc'],
                        grp=item.get('specGRP'),
                        doc_comment=item.get('docComment'),
                        status=item.get('status'),
                        status_date=item.get('statusstDate'),
                        type_spec=item.get('typeSpec'),
                        max_normal=item.get('maxNormal'),
                        max_special=item.get('maxSpecial'),
                        l_status=item.get('lStatus'),
                    ),
                )

            if 'service_codes' in data:
                if srv_type_cache is None:
                    # Service types were not reloaded: link the services to
                    # the service types already stored instead of to None.
                    srv_type_cache = {
                        str(st.code): st for st in ServiceType.objects.all()
                    }
                self._load_services(data['service_codes'], srv_type_cache)

    def _replace_all(self, model, items, *, build, label, batch_size):
        """Replace every row of ``model`` with ``build(item)`` for each item.

        Args:
            model: Django model class whose table is replaced.
            items: iterable of raw API records.
            build: callable turning one record into an unsaved ``model``.
            label: human-readable section name used in the progress message.
            batch_size: ``bulk_create`` batch size.

        The instances are built before anything is deleted, so a malformed
        record raises before the table is touched.
        """
        self.stdout.write(f'Loading {label}...')
        objects = [build(item) for item in items]
        model.objects.all().delete()
        model.objects.bulk_create(objects, batch_size=batch_size)

    def _load_service_types(self, data):
        """Reload service types from ``data`` and return them keyed by code.

        Returns:
            ``None`` when ``data`` has no ``'service_types'`` section,
            otherwise a dict mapping ``str(srvType)`` to the created
            ``ServiceType`` instance.
        """
        if 'service_types' not in data:
            return None

        self.stdout.write('Loading service types...')
        ServiceType.objects.all().delete()
        cache = {}
        for item in data['service_types']:
            # Saved one at a time with create() so every instance has its
            # primary key for the services' foreign key (bulk_create only
            # sets primary keys on some database backends).
            cache[str(item['srvType'])] = ServiceType.objects.create(
                code=item['srvType'],
                desc=item['srvTypeDes'],
                status=item.get('status'),
                status_date=item.get('statusstDate'),
                cost_type=item.get('custType'),
                presc_type_id=item.get('prescTypeId'),
                head_expire_date=item.get('headExpireDate'),
            )
        return cache

    def _load_services(self, service_codes, srv_type_cache):
        """Replace all services with the records in ``service_codes``.

        Args:
            service_codes: mapping of service-type code to that type's
                service records.
            srv_type_cache: mapping of ``str(code)`` to ``ServiceType``, used
                to link each service to its type (``None`` when unknown).

        Each type's services are saved as one unit. A unit that fails to
        save is reported on stderr and skipped so other types still load.
        Non-dict records are reported on stderr and skipped.
        """
        self.stdout.write("Loading services...")
        Service.objects.all().delete()
        for service_type, items in service_codes.items():
            # Compare as str so a numeric srvType still matches a string
            # mapping key (JSON object keys are always strings).
            # NOTE(review): confirm the key/value types get_basic_data returns.
            type_obj = srv_type_cache.get(str(service_type))
            services = []
            for item in items:
                if not isinstance(item, dict):
                    self.stderr.write(
                        f"Irregular service item: {item} -> {service_type}"
                    )
                    continue
                services.append(self._build_service(item, type_obj))
            try:
                # Nested atomic() creates a savepoint. A failed unit rolls
                # back only itself; catching a database error without one
                # would leave the enclosing transaction unusable.
                with transaction.atomic():
                    Service.objects.bulk_create(services, batch_size=5000)
            except Exception as e:
                self.stderr.write(f"Error saving services: {e}")
                self.stderr.write(f"Failed service type: {service_type}")

    @staticmethod
    def _build_service(item, service_type):
        """Build an unsaved ``Service`` from one API record."""
        return Service(
            service_type=service_type,
            srv_id=item['srvId'],
            srv_type=item['srvType'],
            srv_code=item.get('srvCode'),
            name=item.get('srvName'),
            name2=item.get('srvName2'),
            bim_sw=item.get('srvBimSw'),
            sex=item.get('srvSex'),
            price=item.get('srvPrice'),
            price_date=item.get('srvPriceDate'),
            dose_code=item.get('doseCode'),
            form_code=item.get('formCode'),
            par_taref_group=item.get('parTarefGrp'),
            status=item.get('status'),
            status_date=item.get('statusstDate'),
            bg_type=item.get('bGType'),
            gsrv_code=item.get('gSrvCode'),
            # NOTE(review): 'afreementFlag' looks like a misspelling of
            # 'agreementFlag'; confirm which key the API actually sends.
            agreement_flag=item.get('afreementFlag'),
            is_deleted=item.get('is_deleted'),
            visible=item.get('visible'),
            dental_service_type=item.get('dentalServiceType'),
            ws_srv_code=item.get('wsSrvCode'),
            hos_presc_type=item.get('hosprescType'),
            srv_rule=item.get('srvRule'),
            count_is_restricted=item.get('countIsRestricted'),
            terminology=item.get('terminology'),
            srv_code_complete=item.get('srvCodeComplete'),
        )