import base64
import json
import math
import os
import traceback
from datetime import datetime, date
from datetime import datetime, timedelta
from decimal import Decimal

import mysql.connector
from flask import Blueprint, request, jsonify, send_from_directory
from flask_jwt_extended import jwt_required
from werkzeug.utils import secure_filename

from config.auth import role_required
from db.db import get_db_connection
from product_image_optimizer import ProductImageOptimizer, optimize_to_50kb, optimize_to_100kb, optimize_to_size

# Blueprint on which this module's product routes are registered.
product_bp = Blueprint('product', __name__)


# ==========================================
# FILE UPLOAD CONFIGURATION
# ==========================================
# Relative path: os.makedirs() below and the file writes in save_base64_file
# resolve it against the process working directory. Created at import time.
UPLOAD_FOLDER = 'uploads/products'
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Byte limits checked by save_base64_file against the decoded data it is about
# to write. NOTE: for images that is the data AFTER optimization (or the
# original if optimization returned nothing), not the raw upload.
MAX_IMAGE_SIZE = 5 * 1024 * 1024  # 5MB, images
MAX_CERT_SIZE = 10 * 1024 * 1024  # 10MB, certificates (PDFs)

# Image optimization settings used by save_base64_file:
# - 'target_size': optimize_to_size(..., target_kb=IMAGE_TARGET_SIZE_KB)
# - any other value: ProductImageOptimizer.optimize_base64_image with
#   quality=IMAGE_QUALITY and IMAGE_MAX_DIMENSION as both max width and height.
IMAGE_OPTIMIZATION_MODE = 'target_size'
IMAGE_TARGET_SIZE_KB = 20
IMAGE_QUALITY = 55
IMAGE_MAX_DIMENSION = 1200


# ==========================================
# FILE SERVING ROUTE
# ==========================================
@product_bp.route('/uploads/products/<path:filename>', methods=['GET'])
def serve_product_file(filename):
    """Serve an uploaded product file (image or certificate) by file name.

    Files are written by ``save_base64_file`` under ``UPLOAD_FOLDER``, a path
    relative to the process working directory. Flask's
    ``send_from_directory`` resolves a *relative* directory against the
    application's ``root_path`` instead, so the folder is made absolute here
    to read from the same place the files are written.

    ``send_from_directory`` rejects paths escaping the directory and reports a
    missing file by raising werkzeug's ``NotFound`` (not
    ``FileNotFoundError``); both are turned into this endpoint's JSON 404.

    Args:
        filename: Path of the file relative to the upload folder.

    Returns:
        The file response, or ``({'error': 'File not found'}, 404)``.
    """
    # werkzeug is already a dependency of this module (secure_filename).
    from werkzeug.exceptions import NotFound

    try:
        return send_from_directory(os.path.abspath(UPLOAD_FOLDER), filename)
    except (NotFound, FileNotFoundError):
        return jsonify({'error': 'File not found'}), 404


# ==========================================
# HELPER FUNCTIONS
# ==========================================

def validate_required_fields(data, fields):
    """Return the names in *fields* whose value in *data* is missing or blank.

    A field counts as missing when it is absent, ``None``, or a string that is
    empty after stripping whitespace. Whitespace-only strings are rejected
    because the add-product handler strips text fields such as
    ``product_name`` after validating, which would otherwise store an empty
    value. Non-string values (``0``, ``False``, lists, ...) count as present.

    Args:
        data: Mapping of request fields (anything with ``.get``).
        fields: Iterable of required field names.

    Returns:
        List of missing field names, in the order given by *fields*.
    """
    missing = []
    for field in fields:
        value = data.get(field)
        if value is None or (isinstance(value, str) and not value.strip()):
            missing.append(field)
    return missing


def safe_float(value, default=0.0):
    """Convert *value* to a finite float, falling back to *default*.

    ``None``, ``''`` and the string ``'null'`` map to *default*, as does any
    value ``float()`` rejects. Non-finite results (NaN, +/-inf) also map to
    *default*: strings like ``'nan'``/``'inf'`` parse as floats, and Python's
    ``json`` parser accepts ``NaN``/``Infinity`` literals by default. Such
    values slip past ``<= 0`` validation checks and have no representation
    in MySQL numeric columns.

    Args:
        value: Raw request value (string, number, or ``None``).
        default: Value returned when conversion is not possible.

    Returns:
        A finite float, or *default*.
    """
    try:
        if value in (None, '', 'null'):
            return default
        result = float(value)
    except (ValueError, TypeError):
        return default
    return result if math.isfinite(result) else default


def safe_int(value, default=0):
    """Convert *value* to an int, falling back to *default*.

    ``None``, ``''`` and the string ``'null'`` map to *default*, as does any
    value ``int()`` rejects (e.g. ``'3.5'`` or ``'abc'``). Floats are
    truncated toward zero by ``int()``. Infinite floats (reachable through
    Python's ``json`` parser, which accepts ``Infinity``) make ``int()``
    raise ``OverflowError``; that is caught too, so this helper never raises
    on request input.

    Args:
        value: Raw request value (string, number, or ``None``).
        default: Value returned when conversion is not possible.

    Returns:
        An int, or *default*.
    """
    try:
        return int(value) if value not in (None, '', 'null') else default
    except (ValueError, TypeError, OverflowError):
        return default


def save_base64_file(base64_data, file_prefix='file'):
    """Decode a base64 data URL and write it to ``UPLOAD_FOLDER``.

    Images (``image/`` in the data-URL header) are first passed through the
    optimizer selected by ``IMAGE_OPTIMIZATION_MODE``; if it returns nothing,
    the original image is kept.

    Args:
        base64_data: Data URL such as ``"data:image/png;base64,<payload>"``.
            Empty values and strings without a ``"base64,"`` marker are
            ignored.
        file_prefix: Prefix of the generated name
            ``<prefix>_<YYYYmmdd_HHMMSS_ffffff><ext>`` (run through
            ``secure_filename``).

    Returns:
        The written file's path (``UPLOAD_FOLDER`` joined with the name), or
        ``None`` when the input is ignored, the decoded data exceeds its size
        limit, or any exception occurs (it is printed, not raised).

    Size limits apply to the bytes actually written:
    - ``MAX_IMAGE_SIZE`` for images, so it is checked after optimization.
    - ``MAX_CERT_SIZE`` for PDFs.
    - ``MAX_CERT_SIZE`` for any other content type (saved as ``.bin``), so
      nothing is written without a size check.
    """
    # (MIME substring, extension) pairs checked in order; the first match wins.
    mime_extensions = (
        ('image/jpeg', '.jpg'),
        ('image/jpg', '.jpg'),
        ('image/png', '.png'),
        ('image/gif', '.gif'),
        ('image/webp', '.webp'),
        ('application/pdf', '.pdf'),
    )

    try:
        if not base64_data or 'base64,' not in base64_data:
            return None

        header, encoded = base64_data.split(',', 1)

        is_image = 'image/' in header
        is_pdf = 'application/pdf' in header

        if is_image:
            print(f"\n🖼️  Optimizing image (Mode: {IMAGE_OPTIMIZATION_MODE})...")

            if IMAGE_OPTIMIZATION_MODE == 'target_size':
                print(f"🎯 Target size: {IMAGE_TARGET_SIZE_KB}KB")
                optimized_base64 = optimize_to_size(
                    base64_data,
                    target_kb=IMAGE_TARGET_SIZE_KB
                )
            else:
                print(f"📊 Quality: {IMAGE_QUALITY}%, Max dimension: {IMAGE_MAX_DIMENSION}px")
                optimized_base64 = ProductImageOptimizer.optimize_base64_image(
                    base64_data,
                    max_width=IMAGE_MAX_DIMENSION,
                    max_height=IMAGE_MAX_DIMENSION,
                    quality=IMAGE_QUALITY
                )

            if optimized_base64:
                # The optimized data URL carries its own header (its format may
                # differ from the upload), so take both parts from it.
                header, encoded = optimized_base64.split(',', 1)
                print("✅ Image optimized successfully")
            else:
                print("⚠️  Optimization failed, using original image")

        ext = next(
            (extension for mime, extension in mime_extensions if mime in header),
            '.bin'
        )

        # Microsecond timestamp keeps names unique across calls.
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        filename = secure_filename(f"{file_prefix}_{timestamp}{ext}")
        filepath = os.path.join(UPLOAD_FOLDER, filename)

        file_data = base64.b64decode(encoded)
        file_size = len(file_data)

        if is_image and file_size > MAX_IMAGE_SIZE:
            print(f"❌ Image too large after optimization: {file_size:,} bytes (max: {MAX_IMAGE_SIZE:,})")
            return None
        if is_pdf and file_size > MAX_CERT_SIZE:
            print(f"❌ PDF too large: {file_size:,} bytes (max: {MAX_CERT_SIZE:,})")
            return None
        if not (is_image or is_pdf) and file_size > MAX_CERT_SIZE:
            # Other content types would otherwise be written with no size limit.
            print(f"❌ File too large: {file_size:,} bytes (max: {MAX_CERT_SIZE:,})")
            return None

        with open(filepath, 'wb') as f:
            f.write(file_data)

        print(f"💾 Saved: {filename} ({file_size:,} bytes = {file_size/1024:.1f}KB)")

        return filepath

    except Exception as e:
        # Best effort: callers treat None as "file not saved".
        print(f"❌ Error saving file: {e}")
        traceback.print_exc()
        return None


def cleanup_uploaded_files(image_paths, certificate_path):
    """Best-effort removal of files saved earlier in a request that then failed.

    Args:
        image_paths: Iterable of saved image paths; a falsy value skips images.
        certificate_path: Saved certificate path; a falsy value skips it.

    Paths that no longer exist are skipped. Any error while removing a file is
    printed and swallowed so cleanup never hides the original failure.
    """
    for image_path in image_paths or ():
        try:
            if not os.path.exists(image_path):
                continue
            os.remove(image_path)
            print(f"  🗑️  Cleaned up: {image_path}")
        except Exception as exc:
            print(f"  ⚠️  Could not delete {image_path}: {exc}")

    if not certificate_path or not os.path.exists(certificate_path):
        return

    try:
        os.remove(certificate_path)
        print(f"  🗑️  Cleaned up certificate: {certificate_path}")
    except Exception as exc:
        print(f"  ⚠️  Could not delete certificate: {exc}")


# ==========================================
# GET ALL PRODUCTS
# products.base_unit_id references base_units.id directly.
# (Before a schema change it pointed to units.id, which in turn referenced
#  base_units.id; sale_unit_id and purchase_unit_id still go through units.)
# ==========================================

def _to_float(value):
    """Return ``float(value)``, or ``0.0`` for ``None`` (e.g. SUM/AVG over zero rows)."""
    return float(value) if value is not None else 0.0


def _public_file_url(path):
    """Map a stored upload path to its public URL, or ``None`` if unset or missing on disk."""
    if path and os.path.exists(path):
        return f"/uploads/products/{os.path.basename(path)}"
    return None


def _parse_product_images(product_id, raw):
    """Turn the stored ``product_image`` value into a list of public image URLs.

    *raw* is either a JSON string or an already-decoded value; only a list of
    paths is used, and paths missing on disk are skipped. Parse errors are
    printed and whatever was collected so far is returned.
    """
    images = []
    if not raw:
        return images
    try:
        image_paths = json.loads(raw) if isinstance(raw, str) else raw
        if isinstance(image_paths, list):
            for img_path in image_paths:
                url = _public_file_url(img_path)
                if url:
                    images.append(url)
    except (json.JSONDecodeError, TypeError) as e:
        print(f"❌ Image parse error (product {product_id}): {e}")
    return images


def _variation_filter(column, variation_id):
    """Build the variation part of a WHERE clause and its bound parameters.

    ``variation_id=None`` selects base-product rows (``<column> IS NULL``);
    otherwise rows of that one variation. Only the fixed column names passed
    by the helpers below are interpolated into SQL; the id is always a bound
    parameter.
    """
    if variation_id is None:
        return f"{column} IS NULL", ()
    return f"{column} = %s", (variation_id,)


def _fetch_total_stock(cursor, product_id, variation_id=None):
    """Return the summed warehouse quantity of a product or one variation, as a float."""
    clause, params = _variation_filter('variation_id', variation_id)
    cursor.execute(f"""
        SELECT IFNULL(SUM(quantity), 0) AS total_stock
        FROM warehouse_stock
        WHERE product_id = %s AND {clause}
    """, (product_id, *params))
    row = cursor.fetchone()
    return float(row['total_stock']) if row else 0.0


def _fetch_warehouse_stock(cursor, product_id, variation_id=None):
    """Return per-warehouse stock rows (warehouse_id, warehouse_name, stock_quantity)."""
    clause, params = _variation_filter('ws.variation_id', variation_id)
    cursor.execute(f"""
        SELECT
            ws.warehouse_id,
            COALESCE(w.warehouse_name, 'Unknown') AS warehouse_name,
            SUM(ws.quantity)                       AS stock_quantity
        FROM warehouse_stock ws
        LEFT JOIN warehouses w ON ws.warehouse_id = w.id
        WHERE ws.product_id = %s AND {clause}
        GROUP BY ws.warehouse_id, w.warehouse_name
    """, (product_id, *params))
    return cursor.fetchall()


def _fetch_batches(cursor, product_id, variation_id=None):
    """Return batch rows (with purchase order and supplier info), newest first."""
    clause, params = _variation_filter('pb.variation_id', variation_id)
    cursor.execute(f"""
        SELECT
            pb.batch_id,
            pb.batch_number,
            pb.quantity,
            pb.remaining_quantity,
            pb.cost,
            pb.price,
            pb.expiration_date,
            pb.created_on,
            po.order_id,
            po.status       AS order_status,
            po.grn_status,
            COALESCE(s.supplier_name, 'Unknown') AS supplier_name,
            pb.grn_id
        FROM product_batches pb
        LEFT JOIN purchase_orders po ON pb.purchase_order_id = po.order_id
        LEFT JOIN suppliers       s  ON po.supplier_id       = s.id
        WHERE pb.product_id = %s AND {clause}
        ORDER BY pb.created_on DESC
    """, (product_id, *params))
    return cursor.fetchall()


# Aggregates over a product's base (non-variation) batches, in the order they
# are attached to the product dict.
_PRICE_STAT_COLUMNS = (
    'avg_cost', 'avg_price',
    'min_cost', 'max_cost',
    'min_price', 'max_price',
    'total_purchased', 'total_remaining',
)


def _fetch_price_stats(cursor, product_id):
    """Return cost/price/quantity aggregates of base-product batches; NULLs become 0.0."""
    cursor.execute("""
        SELECT
            AVG(cost)              AS avg_cost,
            AVG(price)             AS avg_price,
            MIN(cost)              AS min_cost,
            MAX(cost)              AS max_cost,
            MIN(price)             AS min_price,
            MAX(price)             AS max_price,
            SUM(quantity)          AS total_purchased,
            SUM(remaining_quantity) AS total_remaining
        FROM product_batches
        WHERE product_id = %s AND variation_id IS NULL
    """, (product_id,))
    row = cursor.fetchone()
    return {col: _to_float(row[col]) if row else 0.0 for col in _PRICE_STAT_COLUMNS}


def _fetch_variations(cursor, product_id):
    """Return a product's variations (oldest first), each with stock, batches and batch totals."""
    cursor.execute("""
        SELECT
            id          AS variation_id,
            variation_name,
            variation_type,
            variation_sku,
            variation_cost,
            variation_price,
            variation_tax_type,
            variation_tax,
            variation_stock_alert,
            expiration_date,
            created_at,
            updated_at
        FROM product_variations
        WHERE product_id = %s
        ORDER BY created_at ASC
    """, (product_id,))
    variations = cursor.fetchall()

    for variation in variations:
        vid = variation['variation_id']
        variation['total_stock'] = _fetch_total_stock(cursor, product_id, vid)
        variation['warehouse_stock'] = _fetch_warehouse_stock(cursor, product_id, vid)
        variation['batches'] = _fetch_batches(cursor, product_id, vid)

        cursor.execute("""
            SELECT
                SUM(quantity)           AS total_purchased,
                SUM(remaining_quantity) AS total_remaining
            FROM product_batches
            WHERE product_id = %s AND variation_id = %s
        """, (product_id, vid))
        totals = cursor.fetchone()
        variation['total_purchased'] = _to_float(totals['total_purchased']) if totals else 0.0
        variation['total_remaining'] = _to_float(totals['total_remaining']) if totals else 0.0

    return variations


def _enrich_product(cursor, product):
    """Attach file URLs, stock, batch, variation and purchase-order data to *product* in place."""
    product_id = product['id']

    # ── Files: replace raw stored paths with public URLs ────────────────────
    product_images = _parse_product_images(product_id, product.get('product_image'))
    product['product_images'] = product_images
    product['image_count'] = len(product_images)

    certificate_url = _public_file_url(product.get('certificate_image'))
    product['certificate_url'] = certificate_url
    product['has_certificate'] = bool(certificate_url)

    # Raw file-path columns are server-side paths; clients don't need them.
    product.pop('product_image', None)
    product.pop('certificate_image', None)

    # ── Base-product stock and batches (variation_id IS NULL) ───────────────
    product['total_stock'] = _fetch_total_stock(cursor, product_id)
    product.update(_fetch_price_stats(cursor, product_id))
    product['warehouse_stock'] = _fetch_warehouse_stock(cursor, product_id)
    product['batches'] = _fetch_batches(cursor, product_id)

    # ── Variations ──────────────────────────────────────────────────────────
    if product['product_type'] == 'variable':
        variations = _fetch_variations(cursor, product_id)
        product['variations'] = variations
        # For variable products the stock figure is the sum over variations.
        product['total_stock'] = sum(v['total_stock'] for v in variations)
        product['total_variations'] = len(variations)
    else:
        product['variations'] = []
        product['total_variations'] = 0

    # ── Purchase orders that produced any batch of this product ─────────────
    cursor.execute("""
        SELECT DISTINCT
            po.order_id,
            po.status,
            po.grn_status,
            po.grand_total,
            po.created_on,
            COALESCE(s.supplier_name,  'Unknown') AS supplier_name,
            COALESCE(w.warehouse_name, 'Unknown') AS warehouse_name
        FROM purchase_orders po
        LEFT JOIN suppliers  s ON po.supplier_id  = s.id
        LEFT JOIN warehouses w ON po.warehouse_id = w.id
        WHERE po.order_id IN (
            SELECT DISTINCT purchase_order_id
            FROM product_batches
            WHERE product_id = %s AND purchase_order_id IS NOT NULL
        )
        ORDER BY po.created_on DESC
    """, (product_id,))
    product['purchase_orders'] = cursor.fetchall()


@product_bp.route('/get_products', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def get_products():
    """Return every product with stock, batch, variation and purchase-order data.

    Requires a JWT and the ``admin`` or ``cashier`` role.

    Response (200): ``{'success': True, 'count': <int>, 'data': [product, ...]}``,
    products ordered by ``created_at`` descending. Each product carries:
    - brand, category and unit names;
    - ``product_images`` / ``certificate_url`` in place of the raw path columns;
    - stock and batch figures;
    - ``variations``;
    - ``purchase_orders``.

    Errors: 500 with ``{'error': ...}`` when the connection fails, on a MySQL
    error (message included), or on any other exception (generic message).

    NOTE(review): each product triggers several follow-up queries (plus four
    per variation), an N+1 pattern; consider batching with
    ``WHERE product_id IN (...)`` if the catalogue grows.
    """
    conn = get_db_connection()
    if conn is None:
        return jsonify({'error': 'Database connection failed'}), 500

    # Created inside the try so a failure here still reaches the finally and
    # closes the connection.
    cursor = None
    try:
        cursor = conn.cursor(dictionary=True)

        # Unit joins:
        #   - products.base_unit_id     → base_units bu1 (direct)
        #   - products.sale_unit_id     → units u2 → base_units bu2
        #   - products.purchase_unit_id → units u3 → base_units bu3
        cursor.execute("""
            SELECT 
                p.id,
                p.product_name,
                p.sku,
                p.barcode_symbology,
                p.product_image,
                p.certificate_image,
                p.product_type,
                p.stock_alert,
                p.note,
                p.product_tax,
                p.tax_type,
                p.expiration_date,
                p.created_at,
                p.updated_at,

                -- Brand
                COALESCE(b.brand_name, '-')     AS brand_name,
                b.id                            AS brand_id,

                -- Category
                COALESCE(c.category_name, '-')  AS category_name,
                c.id                            AS category_id,

                -- ✅ Base unit (products.base_unit_id → base_units DIRECT)
                COALESCE(bu1.base_unit, '-')    AS base_unit_name,
                bu1.id                          AS base_unit_id,

                -- Sale unit  (products.sale_unit_id → units → base_units)
                COALESCE(u2.unit_name, '-')     AS sale_unit,
                u2.id                           AS sale_unit_id,
                COALESCE(bu2.base_unit, '-')    AS sale_base_unit_name,

                -- Purchase unit (products.purchase_unit_id → units → base_units)
                COALESCE(u3.unit_name, '-')     AS purchase_unit,
                u3.id                           AS purchase_unit_id,
                COALESCE(bu3.base_unit, '-')    AS purchase_base_unit_name

            FROM products p
            LEFT JOIN brands     b   ON p.brand_id         = b.id
            LEFT JOIN categories c   ON p.category_id      = c.id

            -- ✅ FIXED: direct join to base_units (no intermediate units table)
            LEFT JOIN base_units bu1 ON p.base_unit_id     = bu1.id

            -- sale unit still goes through units → base_units
            LEFT JOIN units      u2  ON p.sale_unit_id     = u2.id
            LEFT JOIN base_units bu2 ON u2.base_unit_id    = bu2.id

            -- purchase unit still goes through units → base_units
            LEFT JOIN units      u3  ON p.purchase_unit_id = u3.id
            LEFT JOIN base_units bu3 ON u3.base_unit_id    = bu3.id

            ORDER BY p.created_at DESC
        """)
        products = cursor.fetchall()

        for product in products:
            _enrich_product(cursor, product)

        return jsonify({
            'success': True,
            'count': len(products),
            'data': products
        }), 200

    except mysql.connector.Error as err:
        print(f"Database Error: {err}")
        traceback.print_exc()
        return jsonify({'error': f'Database error: {str(err)}'}), 500

    except Exception as err:
        print(f"Unexpected error: {err}")
        traceback.print_exc()
        return jsonify({'error': 'Failed to fetch products'}), 500

    finally:
        if cursor is not None:
            cursor.close()
        conn.close()


# ==========================================
# ✅ FIXED: ADD PRODUCT WITH PURCHASE ORDER
# ==========================================
# FIX: Use dictionary cursor and properly handle all query results

@product_bp.route('/add_product_and_order', methods=['POST'])
@jwt_required()
@role_required('admin', 'manager')
def add_product_and_order():
    """
    ✅ FIXED: Add new product with purchase order
    
    KEY FIX: Proper cursor result handling to avoid "Commands out of sync" errors
    
    Changes:
    1. Use cursor(dictionary=True) consistently
    2. Consume all results with fetchone()/fetchall() before next query
    3. Check cursor.rowcount instead of checking fetchone() existence for validation
    4. ✅ NEW: Validate base_unit_id against base_units table (not units table)
    """
    
    data = request.get_json()
    if not data:
        return jsonify({'error': 'No data provided'}), 400
    
    print("=" * 80)
    print("📦 NEW PRODUCT REQUEST (FIXED - PROPER CURSOR HANDLING)")
    print("=" * 80)
    
    # Extract fields (same as before)
    product_name = data.get('product_name', '').strip()
    sku = data.get('sku', '').strip()
    barcode_symbology = data.get('barcode_symbology')
    product_type = data.get('product_type', 'single')
    brand_id = data.get('brand_id')
    category_id = data.get('category_id')
    base_unit_id = data.get('base_unit_id')
    sale_unit_id = data.get('sale_unit_id')
    purchase_unit_id = data.get('purchase_unit_id')
    stock_alert = safe_int(data.get('stock_alert', 0))
    note = data.get('note', '').strip()
    product_tax = safe_float(data.get('product_tax', 0.0))
    tax_type = data.get('tax_type', 'exclusive')
    expiration_date = data.get('expiration_date') or None
    
    product_cost = safe_float(data.get('product_cost', 0.0))
    product_price = safe_float(data.get('product_price', 0.0))
    product_quantity = safe_float(data.get('product_quantity', 0.0))
    
    variations = data.get('variations', [])
    
    supplier_id = data.get('supplier_id')
    warehouse_id = data.get('warehouse_id')
    store_id = safe_int(data.get('store_id', 1))
    order_tax = safe_float(data.get('order_tax', 0.0))
    discount = safe_float(data.get('discount', 0.0))
    status = data.get('status', 'Pending')
    
    print("\n📋 EXTRACTED FIELDS:")
    print(f"  product_name: {product_name}")
    print(f"  sku: {sku}")
    print(f"  base_unit_id: {base_unit_id} (from base_units table)")
    print(f"  sale_unit_id: {sale_unit_id} (from units table)")
    print(f"  purchase_unit_id: {purchase_unit_id} (from units table)")
    print(f"  product_cost (PURCHASE): {product_cost} ✅")
    print(f"  product_price (SELLING): {product_price} ✅")
    print(f"  supplier_id: {supplier_id}")
    print("=" * 80)
    
    # Prevent "Received" status
    if status and status.lower() == 'received':
        return jsonify({
            'error': 'Status "Received" not allowed. Use "Ordered" and create GRN.',
            'allowed_statuses': ['Pending', 'Ordered']
        }), 400
    
    grn_status = 'not_received'
    grand_total = safe_float(data.get('grand_total', 0.0))
    products_list = data.get('products', [])
    
    # Process images (same as before)
    product_images_data = data.get('product_image')
    certificate_data = data.get('certificate_image')
    
    saved_image_paths = []
    if product_images_data:
        try:
            images_list = json.loads(product_images_data) if isinstance(product_images_data, str) else product_images_data
            
            if len(images_list) > 10:
                return jsonify({'error': 'Maximum 10 images allowed'}), 400
            
            for idx, img_base64 in enumerate(images_list, 1):
                if img_base64:
                    file_path = save_base64_file(img_base64, f'product_img_{idx}')
                    if file_path:
                        saved_image_paths.append(file_path)
        except Exception as e:
            print(f"❌ Error processing images: {e}")
    
    saved_certificate_path = None
    if certificate_data:
        try:
            saved_certificate_path = save_base64_file(certificate_data, 'certificate')
        except Exception as e:
            print(f"❌ Error processing certificate: {e}")
    
    product_image_json = json.dumps(saved_image_paths) if saved_image_paths else None
    certificate_image_path = saved_certificate_path
    
    # Validation (same as before)
    required_fields = [
        'product_name', 'sku', 'product_type',
        'brand_id', 'category_id', 'base_unit_id',
        'sale_unit_id', 'purchase_unit_id',
        'supplier_id', 'warehouse_id'
    ]
    
    missing = validate_required_fields(data, required_fields)
    
    if product_type == 'single':
        if product_cost <= 0:
            missing.append('product_cost')
        if product_price <= 0:
            missing.append('product_price')
        if product_quantity <= 0:
            missing.append('product_quantity')
    elif product_type == 'variable':
        if not variations or len(variations) == 0:
            missing.append('variations')
    
    if missing:
        cleanup_uploaded_files(saved_image_paths, saved_certificate_path)
        return jsonify({
            'error': 'Validation failed',
            'missing_fields': list(set(missing))
        }), 400
    
    # Database operations
    conn = get_db_connection()
    if not conn:
        cleanup_uploaded_files(saved_image_paths, saved_certificate_path)
        return jsonify({'error': 'Database connection failed'}), 500
    
    # ✅ FIX: Use dictionary cursor
    cursor = conn.cursor(dictionary=True)
    
    try:
        conn.start_transaction()
        print("\n🔄 Transaction started")
        
        # ✅ FIX: Validate supplier - consume result immediately
        print(f"\n🔍 Validating supplier_id: {supplier_id}")
        cursor.execute("SELECT id FROM suppliers WHERE id = %s", (supplier_id,))
        supplier_result = cursor.fetchone()  # ✅ Consume result
        
        if not supplier_result:
            raise ValueError(f"Supplier with ID {supplier_id} not found")
        print(f"✅ Supplier validated")
        
        # ✅ FIX: Validate warehouse - consume result immediately
        print(f"🔍 Validating warehouse_id: {warehouse_id}")
        cursor.execute("SELECT id FROM warehouses WHERE id = %s", (warehouse_id,))
        warehouse_result = cursor.fetchone()  # ✅ Consume result
        
        if not warehouse_result:
            raise ValueError(f"Warehouse with ID {warehouse_id} not found")
        print(f"✅ Warehouse validated")
        
        # ✅ FIX: Validate brand - consume result immediately
        if brand_id:
            print(f"🔍 Validating brand_id: {brand_id}")
            cursor.execute("SELECT id FROM brands WHERE id = %s", (brand_id,))
            brand_result = cursor.fetchone()  # ✅ Consume result
            
            if not brand_result:
                raise ValueError(f"Brand with ID {brand_id} not found")
            print(f"✅ Brand validated")
        
        # ✅ FIX: Validate category - consume result immediately
        if category_id:
            print(f"🔍 Validating category_id: {category_id}")
            cursor.execute("SELECT id FROM categories WHERE id = %s", (category_id,))
            category_result = cursor.fetchone()  # ✅ Consume result
            
            if not category_result:
                raise ValueError(f"Category with ID {category_id} not found")
            print(f"✅ Category validated")
        
        # ✅ FIXED: Validate base_unit_id against base_units table (NOT units table)
        print(f"🔍 Validating base_unit_id: {base_unit_id} (from base_units table)")
        cursor.execute("SELECT id FROM base_units WHERE id = %s", (base_unit_id,))
        base_unit_result = cursor.fetchone()  # ✅ Consume result
        
        if not base_unit_result:
            raise ValueError(f"Base unit with ID {base_unit_id} not found in base_units table")
        print(f"✅ Base unit validated")
        
        # ✅ FIXED: Validate sale_unit_id against units table
        print(f"🔍 Validating sale_unit_id: {sale_unit_id} (from units table)")
        cursor.execute("SELECT id FROM units WHERE id = %s", (sale_unit_id,))
        sale_unit_result = cursor.fetchone()  # ✅ Consume result
        
        if not sale_unit_result:
            raise ValueError(f"Sale unit with ID {sale_unit_id} not found in units table")
        print(f"✅ Sale unit validated")
        
        # ✅ FIXED: Validate purchase_unit_id against units table
        print(f"🔍 Validating purchase_unit_id: {purchase_unit_id} (from units table)")
        cursor.execute("SELECT id FROM units WHERE id = %s", (purchase_unit_id,))
        purchase_unit_result = cursor.fetchone()  # ✅ Consume result
        
        if not purchase_unit_result:
            raise ValueError(f"Purchase unit with ID {purchase_unit_id} not found in units table")
        print(f"✅ Purchase unit validated")
        
        # Insert product
        print(f"\n📦 Inserting product '{product_name}'...")
        cursor.execute("""
            INSERT INTO products (
                product_name, sku, barcode_symbology, product_type,
                expiration_date, brand_id, category_id,
                base_unit_id, sale_unit_id, purchase_unit_id,
                stock_alert, note, product_tax, tax_type,
                product_image, certificate_image, created_at
            )
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())
        """, (
            product_name, sku, barcode_symbology, product_type,
            expiration_date, brand_id, category_id,
            base_unit_id, sale_unit_id, purchase_unit_id,
            stock_alert, note, product_tax, tax_type,
            product_image_json, certificate_image_path
        ))
        
        product_id = cursor.lastrowid
        print(f"✅ Product created with ID: {product_id}")
        
        # Insert variations if needed
        variation_map = {}
        if product_type == 'variable' and variations:
            print(f"\n🔢 Inserting {len(variations)} variations...")
            for idx, var in enumerate(variations, 1):
                cursor.execute("""
                    INSERT INTO product_variations (
                        product_id, variation_name, variation_type, variation_sku,
                        variation_cost, variation_price, variation_tax_type, variation_tax,
                        variation_stock_alert, expiration_date, created_at
                    )
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())
                """, (
                    product_id,
                    var.get('variation_name'),
                    var.get('variation_type'),
                    var.get('variation_sku'),
                    safe_float(var.get('variation_cost')),
                    safe_float(var.get('variation_price')),
                    var.get('variation_tax_type', 'exclusive'),
                    safe_float(var.get('variation_tax')),
                    safe_int(var.get('variation_stock_alert')),
                    var.get('expiration_date')
                ))
                
                variation_id = cursor.lastrowid
                variation_map[var.get('variation_sku')] = {
                    'id': variation_id,
                    'quantity': safe_float(var.get('variation_quantity', 0)),
                    'data': var
                }
                print(f"  ✅ Variation {idx}: {var.get('variation_type')} (ID: {variation_id})")
        
        # Create purchase order
        print(f"\n🛒 Creating purchase order...")
        
        cursor.execute("""
            INSERT INTO purchase_orders (
                supplier_id, warehouse_id, store_id,
                note, order_tax, discount, 
                status, grn_status,
                grand_total,
                payment_status, paid_amount, due_amount, 
                created_on
            )
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())
        """, (
            supplier_id, warehouse_id, store_id,
            note, order_tax, discount,
            status,
            grn_status,
            grand_total,
            'Unpaid', 0.00, grand_total
        ))
        
        order_id = cursor.lastrowid
        print(f"✅ Purchase order created with ID: {order_id}")
        
        # Insert order items
        print(f"\n📝 Creating {len(products_list)} order items...")
        
        for idx, prod in enumerate(products_list, 1):
            sku_match = prod.get('sku')
            variation_info = variation_map.get(sku_match) if sku_match else None
            variation_id = variation_info['id'] if variation_info else None
            
            if variation_id:
                var_data = variation_info['data']
                exp_date = var_data.get('expiration_date')
                tax_percentage = safe_float(var_data.get('variation_tax'))
                tax_type_value = var_data.get('variation_tax_type', 'exclusive')
                cost_value = safe_float(var_data.get('variation_cost'))
                price_value = safe_float(var_data.get('variation_price'))
            else:
                exp_date = expiration_date
                tax_percentage = product_tax
                tax_type_value = tax_type
                cost_value = product_cost
                price_value = product_price
            
            quantity_value = safe_float(prod.get('quantity'))
            unit_price_cost = safe_float(prod.get('unit_price', cost_value))
            net_unit_cost_from_frontend = safe_float(prod.get('price'))
            subtotal = safe_float(prod.get('subtotal'))
            tax_amount = safe_float(prod.get('tax'))
            discount_amount = safe_float(prod.get('discount', 0.0))
            
            net_unit_cost_value = net_unit_cost_from_frontend
            
            print(f"\n  📌 Item {idx}:")
            print(f"     unit_price: {unit_price_cost}")
            print(f"     net_unit_cost: {net_unit_cost_value}")
            print(f"     selling_price: {price_value}")
            
            cursor.execute("""
                INSERT INTO order_items (
                    order_id, product_id, variation_id,
                    product_tax, tax_type, quantity,
                    purchase_unit, discount, discount_type,
                    tax, subtotal, expiration_date,
                    unit_price, net_unit_cost, selling_price,
                    created_on
                )
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())
            """, (
                order_id, product_id, variation_id,
                tax_percentage,
                tax_type_value,
                quantity_value,
                prod.get('purchase_unit', ''),
                discount_amount,
                'fixed',
                tax_amount,
                subtotal,
                exp_date,
                unit_price_cost,
                net_unit_cost_value,
                price_value
            ))
            
            order_item_id = cursor.lastrowid
            print(f"     ✅ Order item created (ID: {order_item_id})")
        
        print("\n⚠️  Batches and stock will be created when GRN is approved")
        
        conn.commit()
        print("=" * 80)
        print("✅ TRANSACTION COMMITTED!")
        print("=" * 80)
        
        return jsonify({
            'success': True,
            'message': f"Product and order created successfully! {('Create GRN for PO #' + str(order_id)) if status == 'Ordered' else ''}",
            'data': {
                'product_id': product_id,
                'product_name': product_name,
                'order_id': order_id,
                'status': status,
                'grn_status': grn_status,
                'next_step': f"Create GRN for Purchase Order #{order_id}" if status == 'Ordered' else "Order created"
            }
        }), 201
    
    except mysql.connector.IntegrityError as err:
        conn.rollback()
        cleanup_uploaded_files(saved_image_paths, saved_certificate_path)
        error_msg = str(err)
        if 'Duplicate entry' in error_msg:
            if 'sku' in error_msg.lower():
                return jsonify({'error': 'SKU already exists'}), 409
            return jsonify({'error': 'Duplicate entry'}), 409
        elif 'foreign key constraint' in error_msg.lower():
            return jsonify({'error': 'Invalid supplier, warehouse, brand, category or unit ID'}), 400
        return jsonify({'error': f'Database error: {error_msg}'}), 500
    
    except ValueError as err:
        conn.rollback()
        cleanup_uploaded_files(saved_image_paths, saved_certificate_path)
        return jsonify({'error': str(err)}), 400
    
    except Exception as e:
        conn.rollback()
        cleanup_uploaded_files(saved_image_paths, saved_certificate_path)
        print(f"\n❌ ERROR: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Server error: {str(e)}'}), 500
    
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()

def convert_datetime_to_string(obj):
    """
    Turn a ``datetime`` into its ISO 8601 string; return anything else untouched.

    Example: a DB value of ``2026-01-25 22:20:01`` becomes ``"2026-01-25T22:20:01"``.
    Display formatting (e.g. ``2026/01/25 10:20:01 PM``) is left to the frontend.
    Note that plain ``date`` objects are not ``datetime`` instances and are
    returned unchanged.
    """
    if not isinstance(obj, datetime):
        return obj
    return obj.isoformat()


def process_row_datetimes(row):
    """
    Return a JSON-friendly copy of a DB row dict.

    - ``datetime`` values become ISO 8601 strings (``2026-01-25T22:20:01``).
    - ``date`` values (e.g. DATE columns) become ISO strings (``2026-01-25``).
      Without this, Flask's JSON provider would render them as RFC 822
      strings, contradicting the ISO 8601 format the endpoints advertise.
    - ``Decimal`` values become ``float``.
    - Every other value is passed through unchanged.

    A falsy ``row`` (``None`` or an empty dict) is returned as-is.
    """
    if not row:
        return row

    processed_row = {}
    for key, value in row.items():
        # datetime is a subclass of date, so this branch covers both types.
        if isinstance(value, date):
            processed_row[key] = value.isoformat()
        elif isinstance(value, Decimal):
            processed_row[key] = float(value)
        else:
            processed_row[key] = value
    return processed_row


# ==========================================
# GET SINGLE PRODUCT WITH BATCH DISCOUNT RULES
# ==========================================
@product_bp.route('/get_product/<int:id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_product(id):
    """
    Fetch complete product details.

    - Datetime fields are serialized through process_row_datetimes
      (ISO format: YYYY-MM-DDTHH:MM:SS).
    - our_price column REMOVED (was dropped from product_batches).
    - Each batch includes discount_rules[] from the product_batch_discounts table,
      active rules only:
      [{rule_id, payment_method_id, method_name, discount_rate, discount_type, is_active}]
    - For 'variable' products each variation carries its own total stock,
      warehouse stock, batches and batch stats. The product-level total_stock
      is then the sum of the variations' stock.

    Args:
        id: Product primary key. The name is fixed by the ``<int:id>`` URL rule.

    Returns:
        200 with ``{'success', 'data', 'timestamp_format', 'note'}``,
        404 if the product does not exist, 500 on connection/query failure.
    """
    conn = get_db_connection()
    if conn is None:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = None
    try:
        # Opened inside the try so `finally` still closes the connection
        # if creating the cursor itself fails.
        cursor = conn.cursor(dictionary=True)

        def to_float(value):
            """Convert a numeric DB value to float, mapping NULL to 0.0."""
            return float(value) if value is not None else 0.0

        # ── Helper: load discount rules for a list of batch_ids ────────────────
        def get_discount_rules_for_batches(batch_ids):
            """
            Returns a dict: {batch_id: [rules]}
            Rules are fetched in one query for all batches (active rules only).
            """
            if not batch_ids:
                return {}

            # One %s placeholder per id keeps the IN (...) list parameterized.
            fmt = ','.join(['%s'] * len(batch_ids))
            cursor.execute(f"""
                SELECT
                    pbd.batch_id,
                    pbd.id             AS rule_id,
                    pbd.payment_method_id,
                    pm.method_name,
                    pbd.discount_rate,
                    pbd.discount_type,
                    pbd.is_active
                FROM product_batch_discounts pbd
                LEFT JOIN payment_methods pm ON pm.id = pbd.payment_method_id
                WHERE pbd.batch_id IN ({fmt})
                  AND pbd.is_active = 1
                ORDER BY pbd.batch_id, pbd.payment_method_id
            """, tuple(batch_ids))

            rows  = cursor.fetchall()
            rules = {}
            for row in rows:
                bid = row['batch_id']
                if bid not in rules:
                    rules[bid] = []
                rules[bid].append({
                    'rule_id':           row['rule_id'],
                    'payment_method_id': row['payment_method_id'],
                    'method_name':       row['method_name'],
                    'discount_rate':     float(row['discount_rate']),
                    'discount_type':     row['discount_type'],
                    'is_active':         bool(row['is_active']),
                })
            return rules

        # ── Base product query ──────────────────────────────────────────────────
        cursor.execute("""
            SELECT
                p.id,
                p.product_name,
                p.sku,
                p.barcode_symbology,
                p.product_image,
                p.certificate_image,
                p.product_type,
                p.stock_alert,
                p.note,
                p.product_tax,
                p.tax_type,
                p.expiration_date,
                p.created_at,
                p.updated_at,
                b.brand_name,
                b.id AS brand_id,
                c.category_name,
                c.id AS category_id,
                p.base_unit_id,
                bu.base_unit AS base_unit_name,
                u2.id AS sale_unit_id,
                u2.unit_name AS sale_unit_name,
                u3.id AS purchase_unit_id,
                u3.unit_name AS purchase_unit_name
            FROM products p
            LEFT JOIN brands b      ON p.brand_id      = b.id
            LEFT JOIN categories c  ON p.category_id   = c.id
            LEFT JOIN base_units bu ON p.base_unit_id  = bu.id
            LEFT JOIN units u2      ON p.sale_unit_id  = u2.id
            LEFT JOIN units u3      ON p.purchase_unit_id = u3.id
            WHERE p.id = %s
        """, (id,))

        product = cursor.fetchone()

        if not product:
            return jsonify({'error': 'Product not found'}), 404

        product = process_row_datetimes(product)

        # ── Product images ──────────────────────────────────────────────────────
        # product_image holds a JSON list of stored paths. Only files that
        # still exist on disk are exposed, as URLs under /uploads/products/.
        product_images = []
        if product.get('product_image'):
            try:
                image_paths = (
                    json.loads(product['product_image'])
                    if isinstance(product['product_image'], str)
                    else product['product_image']
                )
                if isinstance(image_paths, list):
                    for img_path in image_paths:
                        if img_path and os.path.exists(img_path):
                            product_images.append(f"/uploads/products/{os.path.basename(img_path)}")
            except (json.JSONDecodeError, TypeError) as e:
                print(f"❌ Error parsing product images for product {id}: {e}")

        product['product_images'] = product_images
        product['image_count']    = len(product_images)

        # ── Certificate image ───────────────────────────────────────────────────
        certificate_url = None
        if product.get('certificate_image'):
            cert_path = product['certificate_image']
            if cert_path and os.path.exists(cert_path):
                certificate_url = f"/uploads/products/{os.path.basename(cert_path)}"

        product['certificate_url'] = certificate_url
        product['has_certificate'] = bool(certificate_url)

        # Raw storage paths are replaced by the URL fields above.
        product.pop('product_image',   None)
        product.pop('certificate_image', None)

        # ── Total stock (single product, no variation) ──────────────────────────
        cursor.execute("""
            SELECT IFNULL(SUM(quantity), 0) AS total_stock
            FROM warehouse_stock
            WHERE product_id = %s AND variation_id IS NULL
        """, (id,))
        stock_result = cursor.fetchone()
        product['total_stock'] = float(stock_result['total_stock']) if stock_result else 0.0

        # ── Avg/min/max cost and price from batches (no our_price) ──────────────
        cursor.execute("""
            SELECT
                AVG(cost)              AS avg_cost,
                AVG(price)             AS avg_price,
                MIN(cost)              AS min_cost,
                MAX(cost)              AS max_cost,
                MIN(price)             AS min_price,
                MAX(price)             AS max_price,
                SUM(quantity)          AS total_purchased,
                SUM(remaining_quantity) AS total_remaining
            FROM product_batches
            WHERE product_id = %s AND variation_id IS NULL
        """, (id,))
        cp = cursor.fetchone()

        # AVG() over zero rows is NULL, so a NULL avg_cost means "no batches".
        # Test for None rather than truthiness: a genuine average cost of 0
        # must not wipe out the price statistics of existing batches.
        if cp and cp['avg_cost'] is not None:
            product['avg_cost']        = to_float(cp['avg_cost'])
            product['avg_price']       = to_float(cp['avg_price'])
            product['min_cost']        = to_float(cp['min_cost'])
            product['max_cost']        = to_float(cp['max_cost'])
            product['min_price']       = to_float(cp['min_price'])
            product['max_price']       = to_float(cp['max_price'])
            product['total_purchased'] = to_float(cp['total_purchased'])
            product['total_remaining'] = to_float(cp['total_remaining'])

            product['price_display'] = (
                f"{product['min_price']:.2f}"
                if product['min_price'] == product['max_price']
                else f"{product['min_price']:.2f} - {product['max_price']:.2f}"
            )
        else:
            product.update({
                'avg_cost': 0.0, 'avg_price': 0.0,
                'min_cost': 0.0, 'max_cost': 0.0,
                'min_price': 0.0, 'max_price': 0.0,
                'total_purchased': 0.0, 'total_remaining': 0.0,
                'price_display': '0.00',
            })

        # ── Warehouse-wise stock ────────────────────────────────────────────────
        cursor.execute("""
            SELECT
                ws.warehouse_id,
                w.warehouse_name,
                ws.batch_id,
                SUM(ws.quantity) AS stock_quantity
            FROM warehouse_stock ws
            LEFT JOIN warehouses w ON ws.warehouse_id = w.id
            WHERE ws.product_id = %s AND ws.variation_id IS NULL
            GROUP BY ws.warehouse_id, w.warehouse_name, ws.batch_id
        """, (id,))
        product['warehouse_stock'] = [
            process_row_datetimes(row) for row in cursor.fetchall()
        ]

        # ── Batches (active or latest empty, with discount_rules) ───────────────
        # Selects every batch with remaining stock. For each (cost, price) pair
        # that has NO batch with remaining stock, it also selects the single
        # most recent empty batch.
        cursor.execute("""
            SELECT
                pb.batch_id,
                pb.batch_number,
                pb.cost,
                pb.price,
                pb.quantity,
                pb.remaining_quantity,
                pb.expiration_date,
                pb.created_on,
                po.order_id,
                po.status       AS order_status,
                po.grn_status,
                s.supplier_name,
                pb.grn_id,
                g.grn_code,
                g.status        AS grn_status_detail,
                g.grn_date,
                g.created_at    AS grn_created_at
            FROM product_batches pb
            LEFT JOIN purchase_orders po ON pb.purchase_order_id = po.order_id
            LEFT JOIN suppliers s        ON po.supplier_id = s.id
            LEFT JOIN grn g              ON pb.grn_id = g.grn_id
            WHERE pb.product_id = %s
              AND pb.variation_id IS NULL
              AND (
                  pb.remaining_quantity > 0
                  OR (
                      pb.remaining_quantity = 0
                      AND NOT EXISTS (
                          SELECT 1 FROM product_batches pb_check
                          WHERE pb_check.product_id = pb.product_id
                            AND pb_check.variation_id IS NULL
                            AND pb_check.cost  = pb.cost
                            AND pb_check.price = pb.price
                            AND pb_check.remaining_quantity > 0
                      )
                      AND pb.batch_id = (
                          SELECT pb2.batch_id
                          FROM product_batches pb2
                          WHERE pb2.product_id    = pb.product_id
                            AND pb2.variation_id  IS NULL
                            AND pb2.cost          = pb.cost
                            AND pb2.price         = pb.price
                            AND pb2.remaining_quantity = 0
                          ORDER BY pb2.created_on DESC, pb2.batch_id DESC
                          LIMIT 1
                      )
                  )
              )
            ORDER BY pb.created_on DESC, pb.batch_id DESC
        """, (id,))
        raw_batches = cursor.fetchall()

        # Fetch discount rules for all these batches in one query
        batch_ids        = [b['batch_id'] for b in raw_batches]
        batch_rules_map  = get_discount_rules_for_batches(batch_ids)

        processed_batches = []
        for batch in raw_batches:
            pb = process_row_datetimes(batch)
            pb['discount_rules'] = batch_rules_map.get(pb['batch_id'], [])
            processed_batches.append(pb)

        product['batches'] = processed_batches

        # ── Variations (variable products) ──────────────────────────────────────
        if product['product_type'] == 'variable':
            cursor.execute("""
                SELECT
                    id AS variation_id,
                    variation_name,
                    variation_type,
                    variation_sku,
                    variation_cost,
                    variation_price,
                    variation_tax_type,
                    variation_tax,
                    variation_stock_alert,
                    expiration_date,
                    created_at,
                    updated_at
                FROM product_variations
                WHERE product_id = %s
                ORDER BY created_at ASC
            """, (id,))
            variations = cursor.fetchall()

            processed_variations = []
            for variation in variations:
                pv = process_row_datetimes(variation)
                variation_id = pv['variation_id']

                # Variation total stock
                cursor.execute("""
                    SELECT IFNULL(SUM(quantity), 0) AS total_stock
                    FROM warehouse_stock
                    WHERE product_id = %s AND variation_id = %s
                """, (id, variation_id))
                vs = cursor.fetchone()
                pv['total_stock'] = float(vs['total_stock']) if vs else 0.0

                # Variation warehouse stock
                cursor.execute("""
                    SELECT
                        ws.warehouse_id,
                        w.warehouse_name,
                        ws.batch_id,
                        SUM(ws.quantity) AS stock_quantity
                    FROM warehouse_stock ws
                    LEFT JOIN warehouses w ON ws.warehouse_id = w.id
                    WHERE ws.product_id = %s AND ws.variation_id = %s
                    GROUP BY ws.warehouse_id, w.warehouse_name, ws.batch_id
                """, (id, variation_id))
                pv['warehouse_stock'] = [
                    process_row_datetimes(row) for row in cursor.fetchall()
                ]

                # Variation batches (same selection rule as the product batches)
                cursor.execute("""
                    SELECT
                        pb.batch_id,
                        pb.batch_number,
                        pb.cost,
                        pb.price,
                        pb.quantity,
                        pb.remaining_quantity,
                        pb.expiration_date,
                        pb.created_on,
                        po.order_id,
                        po.status       AS order_status,
                        po.grn_status,
                        s.supplier_name,
                        pb.grn_id,
                        g.grn_code,
                        g.status        AS grn_status_detail,
                        g.grn_date,
                        g.created_at    AS grn_created_at
                    FROM product_batches pb
                    LEFT JOIN purchase_orders po ON pb.purchase_order_id = po.order_id
                    LEFT JOIN suppliers s        ON po.supplier_id = s.id
                    LEFT JOIN grn g              ON pb.grn_id = g.grn_id
                    WHERE pb.product_id  = %s
                      AND pb.variation_id = %s
                      AND (
                          pb.remaining_quantity > 0
                          OR (
                              pb.remaining_quantity = 0
                              AND NOT EXISTS (
                                  SELECT 1 FROM product_batches pb_check
                                  WHERE pb_check.product_id   = pb.product_id
                                    AND pb_check.variation_id = pb.variation_id
                                    AND pb_check.cost         = pb.cost
                                    AND pb_check.price        = pb.price
                                    AND pb_check.remaining_quantity > 0
                              )
                              AND pb.batch_id = (
                                  SELECT pb2.batch_id
                                  FROM product_batches pb2
                                  WHERE pb2.product_id    = pb.product_id
                                    AND pb2.variation_id  = pb.variation_id
                                    AND pb2.cost          = pb.cost
                                    AND pb2.price         = pb.price
                                    AND pb2.remaining_quantity = 0
                                  ORDER BY pb2.created_on DESC, pb2.batch_id DESC
                                  LIMIT 1
                              )
                          )
                      )
                    ORDER BY pb.created_on DESC, pb.batch_id DESC
                """, (id, variation_id))
                var_raw_batches = cursor.fetchall()

                var_batch_ids   = [b['batch_id'] for b in var_raw_batches]
                var_rules_map   = get_discount_rules_for_batches(var_batch_ids)

                processed_var_batches = []
                for batch in var_raw_batches:
                    pb = process_row_datetimes(batch)
                    pb['discount_rules'] = var_rules_map.get(pb['batch_id'], [])
                    processed_var_batches.append(pb)

                pv['batches'] = processed_var_batches

                # Variation batch stats (no our_price). Aggregates are NULL when
                # the variation has no batches; to_float maps NULL to 0.0.
                cursor.execute("""
                    SELECT
                        SUM(quantity)           AS total_purchased,
                        SUM(remaining_quantity) AS total_remaining,
                        AVG(price)              AS avg_price,
                        MIN(price)              AS min_price,
                        MAX(price)              AS max_price
                    FROM product_batches
                    WHERE product_id = %s AND variation_id = %s
                """, (id, variation_id))
                vbs = cursor.fetchone()
                pv['total_purchased'] = to_float(vbs['total_purchased']) if vbs else 0.0
                pv['total_remaining'] = to_float(vbs['total_remaining']) if vbs else 0.0
                pv['avg_price']       = to_float(vbs['avg_price'])       if vbs else 0.0
                pv['min_price']       = to_float(vbs['min_price'])       if vbs else 0.0
                pv['max_price']       = to_float(vbs['max_price'])       if vbs else 0.0

                processed_variations.append(pv)

            product['variations']     = processed_variations
            # For variable products, stock is the sum over variations.
            product['total_stock']    = sum(v['total_stock'] for v in processed_variations)
            product['total_variations'] = len(processed_variations)

            if processed_variations:
                # Price range is taken from the variations' configured prices,
                # overriding the batch-based price_display computed above.
                variation_prices = [
                    float(v['variation_price'])
                    for v in processed_variations
                    if v.get('variation_price')
                ]
                if variation_prices:
                    mn, mx = min(variation_prices), max(variation_prices)
                    product['price_display'] = f"{mn:.2f}" if mn == mx else f"{mn:.2f} - {mx:.2f}"
        else:
            product['variations']     = []
            product['total_variations'] = 0

        # ── Associated purchase orders ──────────────────────────────────────────
        cursor.execute("""
            SELECT DISTINCT
                po.order_id,
                po.status,
                po.grn_status,
                po.grand_total,
                po.created_on,
                s.supplier_name,
                w.warehouse_name
            FROM purchase_orders po
            LEFT JOIN suppliers s  ON po.supplier_id  = s.id
            LEFT JOIN warehouses w ON po.warehouse_id = w.id
            WHERE po.order_id IN (
                SELECT DISTINCT purchase_order_id
                FROM product_batches
                WHERE product_id = %s AND purchase_order_id IS NOT NULL
            )
            ORDER BY po.created_on DESC
        """, (id,))
        product['purchase_orders'] = [
            process_row_datetimes(row) for row in cursor.fetchall()
        ]

        return jsonify({
            'success': True,
            'data':    product,
            'timestamp_format': 'ISO 8601 (YYYY-MM-DDTHH:MM:SS)',
            'note': (
                'All datetime fields are in database timezone. '
                'our_price removed — use batch.discount_rules[] for per-payment-method discounts.'
            ),
        }), 200

    except mysql.connector.Error as err:
        print(f"❌ Database Error fetching product: {err}")
        traceback.print_exc()
        return jsonify({'error': f'Database error: {str(err)}'}), 500

    except Exception as err:
        print(f"❌ Error fetching product: {err}")
        traceback.print_exc()
        return jsonify({'error': 'Failed to fetch product'}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()

# ==========================================
# GET ALL GRNs WITH EXACT TIMESTAMPS
# ==========================================
@product_bp.route('/grns', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_all_grns():
    """
    Get all GRNs with exact database timestamps, newest first.

    Each row includes the joined purchase order id, supplier, warehouse, store
    and the names of the receiving/creating/approving users. Rows are passed
    through process_row_datetimes, so datetimes become ISO strings and Decimal
    amounts become floats.

    Returns:
        200 with ``{'success', 'data', 'total', 'timestamp_format'}``;
        500 if the connection fails or the query raises.
    """
    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = None
    try:
        # Opened inside the try so `finally` still closes the connection
        # if creating the cursor itself fails.
        cursor = conn.cursor(dictionary=True)

        # No filtering or pagination is applied: every GRN is returned.
        query = """
            SELECT 
                g.grn_id,
                g.grn_code,
                g.grn_date,
                g.invoice_number,
                g.invoice_date,
                g.status,
                g.payment_status,
                g.paid_amount,
                g.due_amount,
                g.grand_total,
                g.total_items,
                g.note,
                g.created_at,
                g.approved_at,
                g.updated_at,
                po.order_id,
                s.supplier_name,
                s.supplier_code,
                w.warehouse_name,
                st.store_name,
                u1.name AS received_by_name,
                u2.name AS created_by_name,
                u3.name AS approved_by_name
            FROM grn g
            LEFT JOIN purchase_orders po ON g.purchase_order_id = po.order_id
            LEFT JOIN suppliers s ON g.supplier_id = s.id
            LEFT JOIN warehouses w ON g.warehouse_id = w.id
            LEFT JOIN stores st ON g.store_id = st.id
            LEFT JOIN users u1 ON g.received_by = u1.id
            LEFT JOIN users u2 ON g.created_by = u2.id
            LEFT JOIN users u3 ON g.approved_by = u3.id
            ORDER BY g.created_at DESC
        """

        cursor.execute(query)
        grns = cursor.fetchall()

        # Process all GRNs with datetime/Decimal conversion
        processed_grns = [process_row_datetimes(grn) for grn in grns]

        return jsonify({
            'success': True,
            'data': processed_grns,
            'total': len(processed_grns),
            'timestamp_format': 'ISO 8601 (YYYY-MM-DDTHH:MM:SS)'
        }), 200

    except mysql.connector.Error as err:
        print(f"❌ Database Error: {err}")
        traceback.print_exc()
        return jsonify({'error': f'Database error: {str(err)}'}), 500

    except Exception as err:
        print(f"❌ Error: {err}")
        traceback.print_exc()
        return jsonify({'error': 'Failed to fetch GRNs'}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


# ==========================================
# GET SINGLE GRN WITH EXACT TIMESTAMPS
# ==========================================
@product_bp.route('/grn/<int:grn_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_grn_details(grn_id):
    """Get GRN details with exact database timestamps.

    The GRN header is joined with its purchase order, supplier, warehouse,
    store and the received/created/approved-by users. Its line items are
    joined with product, variation and batch details. Every row is passed
    through ``process_row_datetimes`` before serialisation.

    Args:
        grn_id: Primary key of the ``grn`` row (taken from the URL).

    Returns:
        200 with ``{'success', 'data', 'timestamp_format'}``, where ``data`` is
        the header dict plus an ``items`` list; 404 if the GRN does not exist;
        500 on connection, query or processing failure.
    """
    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = None
    try:
        # Created inside the try so the connection is still closed if this fails.
        cursor = conn.cursor(dictionary=True)

        # Get GRN header
        cursor.execute("""
            SELECT 
                g.*,
                po.order_id,
                po.status AS po_status,
                s.supplier_name,
                s.supplier_code,
                w.warehouse_name,
                st.store_name,
                u1.name AS received_by_name,
                u2.name AS created_by_name,
                u3.name AS approved_by_name
            FROM grn g
            LEFT JOIN purchase_orders po ON g.purchase_order_id = po.order_id
            LEFT JOIN suppliers s ON g.supplier_id = s.id
            LEFT JOIN warehouses w ON g.warehouse_id = w.id
            LEFT JOIN stores st ON g.store_id = st.id
            LEFT JOIN users u1 ON g.received_by = u1.id
            LEFT JOIN users u2 ON g.created_by = u2.id
            LEFT JOIN users u3 ON g.approved_by = u3.id
            WHERE g.grn_id = %s
        """, (grn_id,))

        grn = cursor.fetchone()

        if not grn:
            return jsonify({'error': 'GRN not found'}), 404

        # Process GRN header with datetime conversion
        grn = process_row_datetimes(grn)

        # Get GRN items
        cursor.execute("""
            SELECT 
                gi.*,
                p.product_name,
                p.sku,
                pv.variation_name,
                pv.variation_type,
                pb.batch_number
            FROM grn_items gi
            LEFT JOIN products p ON gi.product_id = p.id
            LEFT JOIN product_variations pv ON gi.variation_id = pv.id
            LEFT JOIN product_batches pb ON gi.batch_id = pb.batch_id
            WHERE gi.grn_id = %s
            ORDER BY gi.grn_item_id
        """, (grn_id,))

        grn_items = cursor.fetchall()
        grn['items'] = [process_row_datetimes(item) for item in grn_items]

        return jsonify({
            'success': True,
            'data': grn,
            'timestamp_format': 'ISO 8601 (YYYY-MM-DDTHH:MM:SS)'
        }), 200

    except mysql.connector.Error as err:
        print(f"❌ Database Error: {err}")
        traceback.print_exc()
        return jsonify({'error': f'Database error: {str(err)}'}), 500

    except Exception as err:
        # Mirrors the list endpoint: non-DB failures (e.g. while converting
        # row values) still produce a JSON error response.
        print(f"❌ Error: {err}")
        traceback.print_exc()
        return jsonify({'error': 'Failed to fetch GRN details'}), 500

    finally:
        if cursor is not None:
            cursor.close()
        conn.close()
            
# @product_bp.route('/get_product/<int:id>', methods=['GET'])
# @jwt_required()
# @role_required('admin', 'manager')
# def get_product(id):
#     """
#     ✅ FIXED: Fetch complete product details
#     - Correctly joins base_units table
#     """
#     conn = get_db_connection()
#     if conn is None:
#         return jsonify({'error': 'Database connection failed'}), 500

#     cursor = conn.cursor(dictionary=True)
#     try:
#         # ✅ FIXED: Fetch product with correct base_units JOIN
#         cursor.execute("""
#             SELECT 
#                 p.id,
#                 p.product_name,
#                 p.sku,
#                 p.barcode_symbology,
#                 p.product_image,
#                 p.certificate_image,
#                 p.product_type,
#                 p.stock_alert,
#                 p.note,
#                 p.product_tax,
#                 p.tax_type,
#                 p.expiration_date,
#                 p.created_at,
#                 p.updated_at,
                
#                 b.brand_name,
#                 b.id AS brand_id,
                
#                 c.category_name,
#                 c.id AS category_id,
                
#                 -- ✅ FIXED: Get base_unit directly from base_units table
#                 p.base_unit_id AS base_unit_id,
#                 bu.base_unit AS base_unit_name,
                
#                 -- Sale unit (from units table)
#                 u2.id AS sale_unit_id,
#                 u2.unit_name AS sale_unit_name,
                
#                 -- Purchase unit (from units table)
#                 u3.id AS purchase_unit_id,
#                 u3.unit_name AS purchase_unit_name
                
#             FROM products p
#             LEFT JOIN brands b ON p.brand_id = b.id
#             LEFT JOIN categories c ON p.category_id = c.id
            
#             -- ✅ FIXED: Direct join to base_units table
#             LEFT JOIN base_units bu ON p.base_unit_id = bu.id
            
#             -- Sale and Purchase units from units table
#             LEFT JOIN units u2 ON p.sale_unit_id = u2.id
#             LEFT JOIN units u3 ON p.purchase_unit_id = u3.id
            
#             WHERE p.id = %s
#         """, (id,))
#         product = cursor.fetchone()

#         if not product:
#             return jsonify({'error': 'Product not found'}), 404

#         # Process product images
#         product_images = []
#         if product.get('product_image'):
#             try:
#                 image_paths = json.loads(product['product_image']) if isinstance(product['product_image'], str) else product['product_image']
                
#                 if isinstance(image_paths, list):
#                     for img_path in image_paths:
#                         if img_path and os.path.exists(img_path):
#                             filename = os.path.basename(img_path)
#                             image_url = f"/uploads/products/{filename}"
#                             product_images.append(image_url)
#             except (json.JSONDecodeError, TypeError) as e:
#                 print(f"❌ Error parsing product images for product {id}: {e}")
        
#         product['product_images'] = product_images
#         product['image_count'] = len(product_images)
        
#         # Process certificate image
#         certificate_url = None
#         if product.get('certificate_image'):
#             cert_path = product['certificate_image']
#             if cert_path and os.path.exists(cert_path):
#                 filename = os.path.basename(cert_path)
#                 certificate_url = f"/uploads/products/{filename}"
        
#         product['certificate_url'] = certificate_url
#         product['has_certificate'] = bool(certificate_url)
        
#         product.pop('product_image', None)
#         product.pop('certificate_image', None)

#         # ... rest of the code (stock, batches, variations) remains the same ...
#         # Calculate total stock from warehouse_stock
#         cursor.execute("""
#             SELECT IFNULL(SUM(quantity), 0) AS total_stock
#             FROM warehouse_stock
#             WHERE product_id = %s AND variation_id IS NULL
#         """, (id,))
#         stock_result = cursor.fetchone()
#         product['total_stock'] = float(stock_result['total_stock']) if stock_result else 0.0
        
#         # Get average cost and price from batches
#         cursor.execute("""
#             SELECT 
#                 AVG(cost) AS avg_cost,
#                 AVG(price) AS avg_price,
#                 MIN(cost) AS min_cost,
#                 MAX(cost) AS max_cost,
#                 MIN(price) AS min_price,
#                 MAX(price) AS max_price,
#                 SUM(quantity) AS total_purchased,
#                 SUM(remaining_quantity) AS total_remaining
#             FROM product_batches
#             WHERE product_id = %s AND variation_id IS NULL
#         """, (id,))
#         cost_price_result = cursor.fetchone()
        
#         if cost_price_result and cost_price_result['avg_cost']:
#             product['avg_cost'] = float(cost_price_result['avg_cost'])
#             product['avg_price'] = float(cost_price_result['avg_price'])
#             product['min_cost'] = float(cost_price_result['min_cost'])
#             product['max_cost'] = float(cost_price_result['max_cost'])
#             product['min_price'] = float(cost_price_result['min_price'])
#             product['max_price'] = float(cost_price_result['max_price'])
#             product['total_purchased'] = float(cost_price_result['total_purchased'])
#             product['total_remaining'] = float(cost_price_result['total_remaining'])
            
#             if product['min_price'] == product['max_price']:
#                 product['price_display'] = f"{product['min_price']:.2f}"
#             else:
#                 product['price_display'] = f"{product['min_price']:.2f} - {product['max_price']:.2f}"
#         else:
#             product['avg_cost'] = 0.0
#             product['avg_price'] = 0.0
#             product['min_cost'] = 0.0
#             product['max_cost'] = 0.0
#             product['min_price'] = 0.0
#             product['max_price'] = 0.0
#             product['total_purchased'] = 0.0
#             product['total_remaining'] = 0.0
#             product['price_display'] = "0.00"
        
#         # (Include all other queries for warehouse_stock, batches, variations, etc.)
#         # ... [Keep the rest of your existing code] ...

#         return jsonify({
#             'success': True,
#             'data': product
#         }), 200

#     except mysql.connector.Error as err:
#         print(f"Database Error fetching product: {err}")
#         traceback.print_exc()
#         return jsonify({'error': f'Database error: {str(err)}'}), 500
    
#     except Exception as err:
#         print(f"Error fetching product: {err}")
#         traceback.print_exc()
#         return jsonify({'error': 'Failed to fetch product'}), 500

#     finally:
#         if cursor:
#             cursor.close()
#         if conn:
#             conn.close()

# ============================================
# ✅ CORRECTED: EXPIRED PRODUCTS ENDPOINTS
# ============================================



def _expired_item_from_row(row):
    """Convert one row of the expired-products query into a JSON-ready dict.

    Numeric columns go through ``safe_float``. A NULL quantity/cost/price/
    subtotal becomes 0. A NULL variation cost/price becomes None (single
    products have no variation). A genuine zero is kept as 0.0 instead of
    being mistaken for "missing".
    """
    # Variable products are identified by their variation SKU.
    display_sku = row['variation_sku'] if row['variation_id'] else row['sku']

    # Build product name with variation
    product_display_name = row['product_name']
    if row['variation_name'] and row['variation_type']:
        product_display_name = f"{row['product_name']} - {row['variation_name']} - ({row['variation_type']})"
    elif row['variation_name']:
        product_display_name = f"{row['product_name']} - {row['variation_name']}"

    return {
        'batch_id': row['batch_id'],
        'batch_number': row['batch_number'],
        'quantity': safe_float(row['quantity'], 0),
        'remaining_quantity': safe_float(row['remaining_quantity'], 0),
        'cost': safe_float(row['cost'], 0),
        'price': safe_float(row['price'], 0),
        'expiration_date': row['expiration_date'].isoformat() if row['expiration_date'] else None,
        'purchase_order_id': row['purchase_order_id'],

        # Product info
        'product_id': row['product_id'],
        'product_name': row['product_name'],
        'product_display_name': product_display_name,
        'sku': display_sku,
        'product_type': row['product_type'],

        # Variation info (for variable products)
        'variation_id': row['variation_id'],
        'variation_name': row['variation_name'],
        'variation_type': row['variation_type'],
        'variation_sku': row['variation_sku'],
        'variation_cost': safe_float(row['variation_cost'], None),
        'variation_price': safe_float(row['variation_price'], None),

        # Unit info
        'sale_unit': row['sale_unit'] or '',
        'sale_unit_name': row['sale_unit_name'] or '',

        # Additional details
        'brand_name': row['brand_name'],
        'category_name': row['category_name'],
        'grn_id': row['grn_id'],
        'grn_code': row['grn_code'],
        'invoice_number': row['invoice_number'],
        'purchase_status': row['purchase_status'],
        'supplier_name': row['supplier_name'],
        'warehouse_name': row['warehouse_name'],
        'store_name': row['store_name'],
        'subtotal': safe_float(row['subtotal'], 0),
        'status': row['status'],
    }


def _merge_location_name(item, key, name):
    """Fold another warehouse/store name into item[key] (comma-joined, no duplicates)."""
    if not name:
        return
    current = item[key]
    if not current:
        item[key] = name
    elif name not in current.split(', '):
        item[key] = f"{current}, {name}"


@product_bp.route('/expired_products', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_expired_products():
    """
    Get all expired and expiring-soon products with batch tracking.

    Only batches with an expiration date and ``remaining_quantity > 0`` are
    considered. "Today" is the database's ``CURDATE()`` for both the filter
    and the status, so the result agrees with ``/expired_products/summary``
    and ``/expired_products/by_category``.

    Each batch is reported once. When its stock is spread over several
    warehouse_stock rows, the location names are comma-joined in
    ``warehouse_name`` / ``store_name``.

    Returns:
        JSON with two arrays:
        - expired: Products that have already expired
        - expiring_soon: Products expiring within 30 days
        plus a ``summary`` of item counts and values (500 on failure).
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if conn is None:
            return jsonify({
                'success': False,
                'error': 'Database connection failed'
            }), 500
        cursor = conn.cursor(dictionary=True)

        query = """
        SELECT 
            pb.batch_id,
            pb.batch_number,
            pb.quantity,
            pb.remaining_quantity,
            pb.cost,
            pb.price,
            pb.expiration_date,
            pb.purchase_order_id,
            
            -- Product details
            p.id AS product_id,
            p.product_name,
            p.sku,
            p.product_type,
            
            -- Variation details (if applicable)
            pv.id AS variation_id,
            pv.variation_name,
            pv.variation_type,
            pv.variation_sku,
            pv.variation_cost,
            pv.variation_price,
            
            -- Brand
            b.brand_name,
            
            -- Category
            c.category_name,
            
            -- GRN details
            g.grn_id,
            g.grn_code,
            g.invoice_number,
            
            -- Purchase Order details
            po.order_id,
            po.status AS purchase_status,
            
            -- Supplier details
            s.supplier_name,
            
            -- Warehouse/Store location
            w.warehouse_name,
            st.store_name,
            
            -- Unit details
            u.unit_name AS sale_unit_name,
            u.unit_short AS sale_unit,
            
            -- Calculate subtotal (use variation cost if exists, else product batch cost)
            (pb.remaining_quantity * COALESCE(pv.variation_cost, pb.cost)) AS subtotal,
            
            -- Determine status
            CASE 
                WHEN pb.expiration_date < CURDATE() THEN 'expired'
                WHEN pb.expiration_date <= DATE_ADD(CURDATE(), INTERVAL 30 DAY) THEN 'expiring_soon'
                ELSE 'normal'
            END AS status
            
        FROM product_batches pb
        
        -- Join with products (required)
        INNER JOIN products p ON pb.product_id = p.id
        
        -- Join with variations (optional - only for variable products)
        LEFT JOIN product_variations pv ON pb.variation_id = pv.id
        
        -- Join with brands (optional)
        LEFT JOIN brands b ON p.brand_id = b.id
        
        -- Join with categories (optional)
        LEFT JOIN categories c ON p.category_id = c.id
        
        -- Join with units (for unit name)
        LEFT JOIN units u ON p.sale_unit_id = u.id
        
        -- Join with GRN (optional)
        LEFT JOIN grn g ON pb.grn_id = g.grn_id
        
        -- Join with purchase orders (optional)
        LEFT JOIN purchase_orders po ON pb.purchase_order_id = po.order_id
        
        -- Join with suppliers (via purchase order)
        LEFT JOIN suppliers s ON po.supplier_id = s.id
        
        -- Join with warehouse stock to get location (may yield several rows per batch)
        LEFT JOIN warehouse_stock ws ON pb.batch_id = ws.batch_id
        LEFT JOIN warehouses w ON ws.warehouse_id = w.id
        LEFT JOIN stores st ON ws.store_id = st.id
        
        WHERE 
            pb.expiration_date IS NOT NULL
            AND pb.remaining_quantity > 0
            -- Same clock as the status CASE above, so no row is filtered in
            -- but then classified as 'normal' and dropped.
            AND pb.expiration_date <= DATE_ADD(CURDATE(), INTERVAL 30 DAY)
            
        ORDER BY 
            pb.expiration_date ASC,
            p.product_name ASC
        """

        cursor.execute(query)
        results = cursor.fetchall()

        # The warehouse_stock join returns one row per stock location, so a
        # batch can appear several times. Keep one item per batch (the first
        # occurrence keeps the ORDER BY order) and fold the extra locations
        # into it. This way a batch's subtotal is counted only once.
        items_by_batch = {}
        for row in results:
            item = items_by_batch.get(row['batch_id'])
            if item is None:
                items_by_batch[row['batch_id']] = _expired_item_from_row(row)
            else:
                _merge_location_name(item, 'warehouse_name', row['warehouse_name'])
                _merge_location_name(item, 'store_name', row['store_name'])

        items = list(items_by_batch.values())
        expired = [item for item in items if item['status'] == 'expired']
        expiring_soon = [item for item in items if item['status'] == 'expiring_soon']

        # Calculate totals
        total_expired_value = sum(item['subtotal'] for item in expired)
        total_expiring_value = sum(item['subtotal'] for item in expiring_soon)

        return jsonify({
            'success': True,
            'expired': expired,
            'expiring_soon': expiring_soon,
            'summary': {
                'total_expired_items': len(expired),
                'total_expiring_items': len(expiring_soon),
                'total_expired_value': round(total_expired_value, 2),
                'total_expiring_value': round(total_expiring_value, 2),
                'total_items': len(expired) + len(expiring_soon),
                'total_value': round(total_expired_value + total_expiring_value, 2)
            },
            'generated_at': datetime.now().isoformat()
        }), 200

    except Exception as e:
        print(f"❌ Error fetching expired products: {str(e)}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': 'Failed to fetch expired products',
            'message': str(e)
        }), 500

    finally:
        # Runs on every path, so the connection is not leaked on errors.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()


@product_bp.route('/expired_products/summary', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_expired_products_summary():
    """
    Get summary statistics for expired products.

    Scope: every batch with an expiration date and ``remaining_quantity > 0``.
    The top-level ``total_*`` figures cover all such batches, not only the
    expired/expiring ones. The ``expired`` and ``expiring_soon`` sections
    break out batches past ``CURDATE()`` and those due within the next 30
    days. Values use the variation cost when present, else the batch cost.

    Returns:
        200 with ``{'success': True, 'summary': {...}}``; 500 on failure.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if conn is None:
            return jsonify({
                'success': False,
                'error': 'Database connection failed'
            }), 500
        cursor = conn.cursor(dictionary=True)

        query = """
        SELECT 
            COUNT(DISTINCT pb.batch_id) AS total_batches,
            COUNT(DISTINCT pb.product_id) AS total_products,
            SUM(pb.remaining_quantity) AS total_quantity,
            SUM(pb.remaining_quantity * COALESCE(pv.variation_cost, pb.cost)) AS total_value,
            
            -- Expired count
            SUM(CASE WHEN pb.expiration_date < CURDATE() THEN 1 ELSE 0 END) AS expired_count,
            SUM(CASE 
                WHEN pb.expiration_date < CURDATE() 
                THEN pb.remaining_quantity * COALESCE(pv.variation_cost, pb.cost) 
                ELSE 0 
            END) AS expired_value,
            
            -- Expiring soon (30 days)
            SUM(CASE 
                WHEN pb.expiration_date >= CURDATE() 
                AND pb.expiration_date <= DATE_ADD(CURDATE(), INTERVAL 30 DAY) 
                THEN 1 ELSE 0 
            END) AS expiring_soon_count,
            SUM(CASE 
                WHEN pb.expiration_date >= CURDATE() 
                AND pb.expiration_date <= DATE_ADD(CURDATE(), INTERVAL 30 DAY) 
                THEN pb.remaining_quantity * COALESCE(pv.variation_cost, pb.cost)
                ELSE 0 
            END) AS expiring_soon_value
            
        FROM product_batches pb
        LEFT JOIN product_variations pv ON pb.variation_id = pv.id
        WHERE 
            pb.expiration_date IS NOT NULL
            AND pb.remaining_quantity > 0
        """

        cursor.execute(query)
        # An aggregate query without GROUP BY always yields exactly one row.
        result = cursor.fetchone()

        # MySQL returns SUM() over integers as DECIMAL. Without int() the
        # counts would be serialised by jsonify as strings, not numbers.
        return jsonify({
            'success': True,
            'summary': {
                'total_batches': result['total_batches'] or 0,
                'total_products': result['total_products'] or 0,
                'total_quantity': float(result['total_quantity']) if result['total_quantity'] else 0,
                'total_value': float(result['total_value']) if result['total_value'] else 0,
                'expired': {
                    'count': int(result['expired_count'] or 0),
                    'value': float(result['expired_value']) if result['expired_value'] else 0
                },
                'expiring_soon': {
                    'count': int(result['expiring_soon_count'] or 0),
                    'value': float(result['expiring_soon_value']) if result['expiring_soon_value'] else 0
                }
            }
        }), 200

    except Exception as e:
        print(f"❌ Error fetching summary: {str(e)}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': 'Failed to fetch summary',
            'message': str(e)
        }), 500

    finally:
        # Close on every path; previously an exception leaked the connection.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()


@product_bp.route('/expired_products/by_category', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_expired_by_category():
    """
    Get expired products grouped by category.

    Covers batches with stock left whose expiration date is on or before
    ``CURDATE() + 30 days`` (i.e. expired and expiring-soon). Rows are grouped
    by category name, and a NULL name is reported as 'Uncategorized'. Groups
    are ordered by total value, highest first.

    Returns:
        200 with ``{'success': True, 'categories': [...]}``; 500 on failure.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if conn is None:
            return jsonify({
                'success': False,
                'error': 'Database connection failed'
            }), 500
        cursor = conn.cursor(dictionary=True)

        query = """
        SELECT 
            c.category_name,
            COUNT(DISTINCT pb.batch_id) AS batch_count,
            SUM(pb.remaining_quantity) AS total_quantity,
            SUM(pb.remaining_quantity * COALESCE(pv.variation_cost, pb.cost)) AS total_value,
            
            SUM(CASE WHEN pb.expiration_date < CURDATE() THEN 1 ELSE 0 END) AS expired_count,
            SUM(CASE 
                WHEN pb.expiration_date >= CURDATE() 
                AND pb.expiration_date <= DATE_ADD(CURDATE(), INTERVAL 30 DAY) 
                THEN 1 ELSE 0 
            END) AS expiring_soon_count
            
        FROM product_batches pb
        INNER JOIN products p ON pb.product_id = p.id
        LEFT JOIN product_variations pv ON pb.variation_id = pv.id
        LEFT JOIN categories c ON p.category_id = c.id
        
        WHERE 
            pb.expiration_date IS NOT NULL
            AND pb.remaining_quantity > 0
            AND pb.expiration_date <= DATE_ADD(CURDATE(), INTERVAL 30 DAY)
            
        GROUP BY c.category_name
        ORDER BY total_value DESC
        """

        cursor.execute(query)
        results = cursor.fetchall()

        # SUM() over integers comes back as DECIMAL, so the counts are cast to
        # int to serialise them as JSON numbers rather than strings.
        categories = [
            {
                'category_name': row['category_name'] or 'Uncategorized',
                'batch_count': row['batch_count'] or 0,
                'total_quantity': float(row['total_quantity']) if row['total_quantity'] else 0,
                'total_value': float(row['total_value']) if row['total_value'] else 0,
                'expired_count': int(row['expired_count'] or 0),
                'expiring_soon_count': int(row['expiring_soon_count'] or 0),
            }
            for row in results
        ]

        return jsonify({
            'success': True,
            'categories': categories
        }), 200

    except Exception as e:
        print(f"❌ Error fetching by category: {str(e)}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': 'Failed to fetch category data',
            'message': str(e)
        }), 500

    finally:
        # Close on every path; previously an exception leaked the connection.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

# ==========================================
# UPDATE PRODUCT  (our_price removed)
# ==========================================

@product_bp.route('/update_product/<int:id>', methods=['PUT'])
@jwt_required()
@role_required('admin')
def update_product(id):
    """
    Complete update endpoint for products.

    Handles:
    - Single products  : products table + optional specific batch
    - Variable products: product_variations table + optional variation batch
    - Image updates    : product images and certificate images (auto-optimised)

    our_price is intentionally removed – the column has been dropped from
    product_batches and all related tables.
    """
    data = request.get_json()
    if not data:
        return jsonify({'error': 'No data provided'}), 400

    conn = get_db_connection()
    if conn is None:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        # ── Get basic product info ──────────────────────────────────────────
        cursor.execute(
            "SELECT product_type, product_image, certificate_image FROM products WHERE id = %s",
            (id,)
        )
        product_info = cursor.fetchone()
        if not product_info:
            return jsonify({'error': 'Product not found'}), 404

        product_type = product_info['product_type']
        updated      = False

        new_image_paths      = []
        new_certificate_path = None
        old_image_paths      = []
        old_certificate_path = None

        # ── Image updates ───────────────────────────────────────────────────
        if 'product_image' in data:
            if data['product_image'] is None or data['product_image'] == '':
                print("\n🗑️ Deleting all product images...")
                if product_info['product_image']:
                    try:
                        old_imgs = (
                            json.loads(product_info['product_image'])
                            if isinstance(product_info['product_image'], str)
                            else product_info['product_image']
                        )
                        if isinstance(old_imgs, list):
                            old_image_paths = old_imgs
                    except Exception:
                        pass
                cursor.execute("UPDATE products SET product_image = NULL WHERE id = %s", (id,))
                updated = True
            else:
                try:
                    images_list = (
                        json.loads(data['product_image'])
                        if isinstance(data['product_image'], str)
                        else data['product_image']
                    )
                    if not isinstance(images_list, list):
                        images_list = [images_list]
                    if len(images_list) > 10:
                        return jsonify({'error': 'Maximum 10 images allowed'}), 400

                    saved_image_paths = []
                    for idx, img_b64 in enumerate(images_list, 1):
                        if img_b64:
                            if 'base64,' in str(img_b64):
                                file_path = save_base64_file(img_b64, f'product_img_{idx}')
                                if file_path:
                                    saved_image_paths.append(file_path)
                                    new_image_paths.append(file_path)
                            else:
                                if img_b64.startswith('/uploads/products/'):
                                    filename  = img_b64.split('/uploads/products/')[-1]
                                    file_path = os.path.join(UPLOAD_FOLDER, filename)
                                    if os.path.exists(file_path):
                                        saved_image_paths.append(file_path)

                    if saved_image_paths:
                        if product_info['product_image']:
                            try:
                                old_imgs = (
                                    json.loads(product_info['product_image'])
                                    if isinstance(product_info['product_image'], str)
                                    else product_info['product_image']
                                )
                                if isinstance(old_imgs, list):
                                    old_image_paths = [p for p in old_imgs if p not in saved_image_paths]
                            except Exception:
                                pass
                        cursor.execute(
                            "UPDATE products SET product_image = %s WHERE id = %s",
                            (json.dumps(saved_image_paths), id)
                        )
                        updated = True
                except Exception as e:
                    print(f"❌ Error processing product images: {e}")
                    traceback.print_exc()

        if 'certificate_image' in data:
            if data['certificate_image'] is None or data['certificate_image'] == '':
                if product_info['certificate_image']:
                    old_certificate_path = product_info['certificate_image']
                cursor.execute("UPDATE products SET certificate_image = NULL WHERE id = %s", (id,))
                updated = True
            else:
                try:
                    cert_data = data['certificate_image']
                    if 'base64,' in str(cert_data):
                        saved_cert = save_base64_file(cert_data, 'certificate')
                        if saved_cert:
                            new_certificate_path = saved_cert
                            if product_info['certificate_image']:
                                old_certificate_path = product_info['certificate_image']
                            cursor.execute(
                                "UPDATE products SET certificate_image = %s WHERE id = %s",
                                (saved_cert, id)
                            )
                            updated = True
                except Exception as e:
                    print(f"❌ Error processing certificate: {e}")
                    traceback.print_exc()

        # ── products table — common fields ──────────────────────────────────
        prod_mapping = {
            'product_name':      'product_name',
            'sku':               'sku',
            'barcode_symbology': 'barcode_symbology',
            'expiration_date':   'expiration_date',
            'brand_id':          'brand_id',
            'category_id':       'category_id',
            'base_unit_id':      'base_unit_id',
            'sale_unit_id':      'sale_unit_id',
            'purchase_unit_id':  'purchase_unit_id',
            'stock_alert':       'stock_alert',
            'note':              'note',
            'product_tax':       'product_tax',
            'tax_type':          'tax_type',
        }
        prod_fields = []
        prod_values = []
        for key, col in prod_mapping.items():
            if key in data:
                prod_fields.append(f"{col} = %s")
                prod_values.append(data[key])

        if prod_fields:
            prod_values.append(id)
            cursor.execute(
                f"UPDATE products SET {', '.join(prod_fields)} WHERE id = %s",
                prod_values
            )
            updated = True

        # ── SINGLE PRODUCT ──────────────────────────────────────────────────
        if product_type == 'single':

            if data.get('batch_id'):
                batch_id = data['batch_id']
                cursor.execute(
                    """SELECT batch_id FROM product_batches
                       WHERE batch_id = %s AND product_id = %s AND variation_id IS NULL""",
                    (batch_id, id)
                )
                if not cursor.fetchone():
                    return jsonify({'error': 'Batch not found or does not belong to this product'}), 404

                # ✅ our_price removed from batch mapping
                batch_mapping = {
                    'product_cost':  'cost',
                    'product_price': 'price',
                    'batch_number':  'batch_number',
                    'expiration_date': 'expiration_date',
                }
                batch_fields = []
                batch_values = []
                for key, col in batch_mapping.items():
                    if key in data:
                        batch_fields.append(f"{col} = %s")
                        batch_values.append(data[key])

                if batch_fields:
                    batch_values.append(batch_id)
                    cursor.execute(
                        f"UPDATE product_batches SET {', '.join(batch_fields)} WHERE batch_id = %s",
                        batch_values
                    )
                    updated = True
                    print(f"✅ Updated batch {batch_id} for single product {id}")

            if not updated:
                return jsonify({'error': 'No valid fields to update'}), 400

            conn.commit()
            cleanup_old_files(old_image_paths, old_certificate_path)
            return jsonify({
                'success': True,
                'message': 'Product updated successfully',
                'images_optimized': len(new_image_paths),
            }), 200

        # ── VARIABLE PRODUCT ────────────────────────────────────────────────
        elif product_type == 'variable':

            if data.get('variation_id'):
                variation_id = data['variation_id']
                cursor.execute(
                    "SELECT id FROM product_variations WHERE id = %s AND product_id = %s",
                    (variation_id, id)
                )
                if not cursor.fetchone():
                    return jsonify({'error': 'Variation not found or does not belong to this product'}), 404

                # Update variation-level fields
                var_mapping = {
                    'variation_name':         'variation_name',
                    'variation_type':         'variation_type',
                    'variation_sku':          'variation_sku',
                    'variation_tax':          'variation_tax',
                    'variation_tax_type':     'variation_tax_type',
                    'variation_stock_alert':  'variation_stock_alert',
                }
                var_fields = []
                var_values = []
                for key, col in var_mapping.items():
                    if key in data:
                        var_fields.append(f"{col} = %s")
                        var_values.append(data[key])

                if var_fields:
                    var_values.append(variation_id)
                    cursor.execute(
                        f"UPDATE product_variations SET {', '.join(var_fields)} WHERE id = %s",
                        var_values
                    )
                    updated = True

                if data.get('batch_id'):
                    batch_id = data['batch_id']
                    cursor.execute(
                        """SELECT batch_id FROM product_batches
                           WHERE batch_id = %s AND product_id = %s AND variation_id = %s""",
                        (batch_id, id, variation_id)
                    )
                    if not cursor.fetchone():
                        return jsonify({'error': 'Batch not found or does not belong to this variation'}), 404

                    # ✅ our_price removed from variation batch mapping
                    var_batch_mapping = {
                        'variation_cost':  'cost',
                        'variation_price': 'price',
                        'expiration_date': 'expiration_date',
                    }
                    vb_fields = []
                    vb_values = []
                    for key, col in var_batch_mapping.items():
                        if key in data:
                            vb_fields.append(f"{col} = %s")
                            vb_values.append(data[key])

                    if vb_fields:
                        vb_values.append(batch_id)
                        cursor.execute(
                            f"UPDATE product_batches SET {', '.join(vb_fields)} WHERE batch_id = %s",
                            vb_values
                        )
                        updated = True
                        print(f"✅ Updated batch {batch_id} for variation {variation_id}")

                else:
                    # No batch — update variation defaults
                    var_default_mapping = {
                        'variation_cost':  'variation_cost',
                        'variation_price': 'variation_price',
                        'variation_sku':   'variation_sku',
                        'expiration_date': 'expiration_date',
                    }
                    vd_fields = []
                    vd_values = []
                    for key, col in var_default_mapping.items():
                        if key in data:
                            vd_fields.append(f"{col} = %s")
                            vd_values.append(data[key])

                    if vd_fields:
                        vd_values.append(variation_id)
                        cursor.execute(
                            f"UPDATE product_variations SET {', '.join(vd_fields)} WHERE id = %s",
                            vd_values
                        )
                        updated = True

            if not updated:
                return jsonify({'error': 'No valid fields to update'}), 400

            conn.commit()
            cleanup_old_files(old_image_paths, old_certificate_path)
            return jsonify({
                'success': True,
                'message': 'Variable product updated successfully',
                'images_optimized': len(new_image_paths),
            }), 200

        else:
            return jsonify({'error': 'Unknown product type'}), 400

    except mysql.connector.Error as e:
        if conn: conn.rollback()
        cleanup_new_files(new_image_paths, new_certificate_path)
        print(f"❌ DB error in update_product: {e}")
        traceback.print_exc()
        return jsonify({'error': 'Database error occurred', 'details': str(e)}), 500

    except Exception as e:
        if conn: conn.rollback()
        cleanup_new_files(new_image_paths, new_certificate_path)
        print(f"❌ Error in update_product: {e}")
        traceback.print_exc()
        return jsonify({'error': 'Failed to update product', 'details': str(e)}), 500

    finally:
        if cursor: cursor.close()
        if conn:   conn.close()


# ==========================================
# ✅ GET / PUT — Batch Discount Rules
# Endpoint: /batch_discount_rules/<batch_id>
# ==========================================

@product_bp.route('/batch_discount_rules/<int:batch_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_batch_discount_rules(batch_id):
    """
    GET  /batch_discount_rules/<batch_id>

    List the discount rules stored for ``batch_id``.

    Each rule is inner-joined with ``payment_methods`` to expose
    ``method_name``, so rules whose payment method row is missing are not
    returned. Before serialisation:
    - truthy ``created_at`` / ``updated_at`` become 'YYYY-MM-DD HH:MM:SS' strings;
    - ``discount_rate`` becomes a float (NULL -> 0.0);
    - ``is_active`` becomes a bool.

    Returns 200 ``{'success': True, 'batch_id': ..., 'rules': [...]}``, or
    500 when the connection or the query fails.
    """
    db = get_db_connection()
    if not db:
        return jsonify({'error': 'Database connection failed'}), 500

    cur = db.cursor(dictionary=True)
    try:
        cur.execute(
            """
            SELECT
                pbd.id,
                pbd.batch_id,
                pbd.payment_method_id,
                pm.method_name AS method_name,
                pbd.discount_rate,
                pbd.discount_type,
                pbd.is_active,
                pbd.created_at,
                pbd.updated_at
            FROM product_batch_discounts pbd
            JOIN payment_methods pm ON pm.id = pbd.payment_method_id
            WHERE pbd.batch_id = %s
            ORDER BY pbd.id
            """,
            (batch_id,)
        )
        rules = cur.fetchall()

        for rule in rules:
            # NULL / falsy timestamps are passed through untouched.
            for stamp_key in ('created_at', 'updated_at'):
                stamp = rule.get(stamp_key)
                if stamp:
                    rule[stamp_key] = stamp.strftime('%Y-%m-%d %H:%M:%S')
            rule['discount_rate'] = float(rule['discount_rate'] or 0)
            rule['is_active'] = bool(rule['is_active'])

        payload = {'success': True, 'batch_id': batch_id, 'rules': rules}
        return jsonify(payload), 200

    except Exception as e:
        print(f"❌ Error in get_batch_discount_rules: {e}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

    finally:
        cur.close()
        db.close()


def _parse_discount_rules(rules):
    """Validate and normalise the ``rules`` array of a discount-rules PUT.

    Returns ``(parsed, keep_ids, error)``.

    On success ``error`` is None and:
    - ``parsed`` is a list of dicts with keys ``id`` (int or None),
      ``payment_method_id`` (int), ``discount_rate`` (float),
      ``discount_type`` ('percent' | 'fixed') and ``is_active`` (0/1);
    - ``keep_ids`` is the set of existing rule ids named in the payload.

    On the first invalid entry ``parsed`` and ``keep_ids`` are None and
    ``error`` is a client-facing message.

    Rows without a ``payment_method_id`` are skipped (as before), but their
    ``id`` is still added to ``keep_ids``, so the stored rule is left
    untouched rather than deleted.
    """
    import math

    parsed = []
    keep_ids = set()
    seen_methods = set()

    for position, rule in enumerate(rules, start=1):
        if not isinstance(rule, dict):
            return None, None, f'Rule #{position} must be an object'

        rule_id = rule.get('id')
        if rule_id:
            try:
                rule_id = int(rule_id)
            except (TypeError, ValueError):
                return None, None, f'Rule #{position}: "id" must be an integer'
            keep_ids.add(rule_id)
        else:
            rule_id = None  # null / missing / 0 -> new rule

        pm_id = rule.get('payment_method_id')
        if not pm_id:
            continue  # skip incomplete rows
        try:
            pm_id = int(pm_id)
        except (TypeError, ValueError):
            return None, None, f'Rule #{position}: "payment_method_id" must be an integer'

        # (batch_id, payment_method_id) is unique. Two payload rows for the
        # same method would silently merge via ON DUPLICATE KEY UPDATE.
        if pm_id in seen_methods:
            return None, None, f'Rule #{position}: duplicate payment_method_id {pm_id}'
        seen_methods.add(pm_id)

        dtype = rule.get('discount_type', 'percent')
        if dtype not in ('percent', 'fixed'):
            dtype = 'percent'

        try:
            rate = float(rule.get('discount_rate', 0))
        except (TypeError, ValueError):
            rate = float('nan')
        if not math.isfinite(rate) or rate < 0:
            return None, None, f'Rule #{position}: "discount_rate" must be a non-negative number'
        if dtype == 'percent' and rate > 100:
            return None, None, f'Rule #{position}: percent discount cannot exceed 100'

        parsed.append({
            'id': rule_id,
            'payment_method_id': pm_id,
            'discount_rate': rate,
            'discount_type': dtype,
            'is_active': 1 if rule.get('is_active', True) else 0,
        })

    return parsed, keep_ids, None


@product_bp.route('/batch_discount_rules/<int:batch_id>', methods=['PUT'])
@jwt_required()
@role_required('admin')
def upsert_batch_discount_rules(batch_id):
    """
    PUT  /batch_discount_rules/<batch_id>
    Full replacement of discount rules for a batch.

    Request body:
    {
        "rules": [
            {
                "id": 1,                     // null → INSERT, int → UPDATE
                "payment_method_id": 2,
                "discount_rate": 10.00,
                "discount_type": "percent",  // "percent" | "fixed"
                "is_active": 1
            },
            ...
        ]
    }

    Strategy (one transaction):
    - The whole payload is validated first. An invalid rule → 400 and
      nothing is written. Invalid rules include: a non-object rule, a
      non-integer id or payment method, a negative or non-numeric rate, a
      percent rate above 100, or a payment method named twice.
    - A rule id that is not a rule of this batch → 400.
    - Existing rule ids not in the payload → DELETE.
    - id is int → UPDATE.
    - id is null → INSERT.
    - UPDATEs run before INSERTs, so a payment method moved off an existing
      rule can be taken by a new rule in the same request.
    - Unique constraint: (batch_id, payment_method_id).

    Returns 200 with ``inserted`` / ``updated`` / ``deleted`` counts, 404 if
    the batch does not exist, and 500 on database errors (after rollback).
    """
    data = request.get_json()
    if not isinstance(data, dict) or 'rules' not in data:
        return jsonify({'error': 'No rules data provided'}), 400

    rules = data['rules']

    # Basic validation
    if not isinstance(rules, list):
        return jsonify({'error': '"rules" must be an array'}), 400

    parsed_rules, keep_ids, error = _parse_discount_rules(rules)
    if error:
        return jsonify({'error': error}), 400

    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        # Verify batch exists
        cursor.execute("SELECT batch_id FROM product_batches WHERE batch_id = %s", (batch_id,))
        if not cursor.fetchone():
            return jsonify({'error': f'Batch {batch_id} not found'}), 404

        # An id that is not a rule of this batch would make the UPDATE match
        # nothing, and the rule would be silently lost. Reject it before any
        # write.
        cursor.execute(
            "SELECT id FROM product_batch_discounts WHERE batch_id = %s",
            (batch_id,)
        )
        existing_ids = {row['id'] for row in cursor.fetchall()}
        unknown_ids = sorted(keep_ids - existing_ids)
        if unknown_ids:
            return jsonify({
                'error': f'Rule id(s) {unknown_ids} do not belong to batch {batch_id}'
            }), 400

        # ── Delete rules not in the incoming payload ────────────────────────
        if keep_ids:
            placeholders = ','.join(['%s'] * len(keep_ids))
            cursor.execute(
                f"DELETE FROM product_batch_discounts WHERE batch_id = %s AND id NOT IN ({placeholders})",
                [batch_id, *keep_ids]
            )
        else:
            # No existing ids → delete all existing rules for this batch
            cursor.execute(
                "DELETE FROM product_batch_discounts WHERE batch_id = %s",
                (batch_id,)
            )
        deleted = cursor.rowcount

        # ── UPDATE existing rules first (may free payment methods) ──────────
        updated = 0
        for rule in parsed_rules:
            if rule['id'] is None:
                continue
            cursor.execute(
                """UPDATE product_batch_discounts
                   SET payment_method_id = %s,
                       discount_rate     = %s,
                       discount_type     = %s,
                       is_active         = %s
                   WHERE id = %s AND batch_id = %s""",
                (rule['payment_method_id'], rule['discount_rate'], rule['discount_type'],
                 rule['is_active'], rule['id'], batch_id)
            )
            updated += 1

        # ── INSERT new rules (ON DUPLICATE KEY UPDATE handles races) ────────
        inserted = 0
        for rule in parsed_rules:
            if rule['id'] is not None:
                continue
            cursor.execute(
                """INSERT INTO product_batch_discounts
                       (batch_id, payment_method_id, discount_rate, discount_type, is_active)
                   VALUES (%s, %s, %s, %s, %s)
                   ON DUPLICATE KEY UPDATE
                       discount_rate = VALUES(discount_rate),
                       discount_type = VALUES(discount_type),
                       is_active     = VALUES(is_active)""",
                (batch_id, rule['payment_method_id'], rule['discount_rate'],
                 rule['discount_type'], rule['is_active'])
            )
            inserted += 1

        conn.commit()

        print(f"✅ Discount rules updated for batch {batch_id}: "
              f"{inserted} inserted, {updated} updated, {deleted} deleted")

        return jsonify({
            'success':  True,
            'message':  f'Discount rules saved for batch {batch_id}',
            'inserted': inserted,
            'updated':  updated,
            'deleted':  deleted,
            'batch_id': batch_id,
        }), 200

    except mysql.connector.Error as e:
        if conn: conn.rollback()
        print(f"❌ DB error in upsert_batch_discount_rules: {e}")
        traceback.print_exc()
        return jsonify({'error': 'Database error', 'details': str(e)}), 500

    except Exception as e:
        if conn: conn.rollback()
        print(f"❌ Error in upsert_batch_discount_rules: {e}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

    finally:
        cursor.close()
        conn.close()


# ==========================================
# FILE CLEANUP HELPERS
# ==========================================

def cleanup_old_files(old_image_paths, old_certificate_path):
    """Remove files superseded by an update, once the DB commit has succeeded.

    Best-effort: files that are missing, and falsy paths, are skipped. Any
    error raised while deleting is printed and swallowed, never re-raised.

    Args:
        old_image_paths: iterable of image file paths, or None.
        old_certificate_path: certificate file path, or a falsy value.
    """
    images = old_image_paths or []
    for path in images:
        try:
            if not path or not os.path.exists(path):
                continue
            os.remove(path)
            print(f"🗑️ Deleted old image: {path}")
        except Exception as exc:
            print(f"⚠️ Could not delete old image {path}: {exc}")

    if not old_certificate_path:
        return

    try:
        if os.path.exists(old_certificate_path):
            os.remove(old_certificate_path)
            print(f"🗑️ Deleted old certificate: {old_certificate_path}")
    except Exception as exc:
        print(f"⚠️ Could not delete old certificate {old_certificate_path}: {exc}")


def cleanup_new_files(new_image_paths, new_certificate_path):
    """Undo the file side of a failed update by removing freshly saved uploads.

    Best-effort counterpart of the post-commit cleanup. Falsy paths and
    files that no longer exist are ignored. Errors during deletion are
    printed rather than raised, so they never mask the original failure.

    Args:
        new_image_paths: iterable of newly written image paths, or None.
        new_certificate_path: newly written certificate path, or a falsy value.
    """
    for candidate in new_image_paths or ():
        try:
            if candidate and os.path.exists(candidate):
                os.remove(candidate)
                print(f"🗑️ Rolled back new image: {candidate}")
        except Exception as problem:
            print(f"⚠️ Could not delete new image {candidate}: {problem}")

    if new_certificate_path:
        try:
            if os.path.exists(new_certificate_path):
                os.remove(new_certificate_path)
                print(f"🗑️ Rolled back new certificate: {new_certificate_path}")
        except Exception as problem:
            print(f"⚠️ Could not delete new certificate {new_certificate_path}: {problem}")
            
def _fetch_scalar(cursor, query, params, column, default=0):
    """Execute ``query`` and return ``column`` from its first row.

    Returns ``default`` when the query yields no row or the value is NULL.
    Expects a dictionary cursor (rows are dicts keyed by column alias).
    """
    cursor.execute(query, params)
    row = cursor.fetchone()
    if not row or row[column] is None:
        return default
    return row[column]


@product_bp.route('/delete_product/<int:id>', methods=['DELETE'])
@jwt_required()
@role_required('admin')
def delete_product(id):
    """
    Delete a product and all related data in a single transaction.

    Validations (each returns 400 without changing anything):
    - Total warehouse stock for the product must not be positive.
    - The product must not appear in 'pending' / 'ordered' purchase orders.

    Deleted:
    - warehouse_stock rows;
    - discount rules of the product's batches (product_batch_discounts);
    - product_batches and product_variations;
    - the products row itself.

    order_items rows are kept when any 'received' / 'completed' purchase
    order references the product (for history); otherwise they are deleted
    too.

    Returns 404 if the product does not exist, and 500 on database or
    unexpected errors (after rollback).

    Note: This is a destructive operation. Consider soft delete for production.
    """
    conn = get_db_connection()
    if conn is None:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        # 1️⃣ Check if product exists
        cursor.execute("""
            SELECT
                id,
                product_name,
                sku,
                product_type
            FROM products
            WHERE id = %s
        """, (id,))
        product = cursor.fetchone()

        if not product:
            return jsonify({'error': 'Product not found'}), 404

        product_name = product['product_name']
        product_sku = product['sku']
        product_type = product['product_type']

        # 2️⃣ Check total stock across all warehouses
        total_stock = float(_fetch_scalar(cursor, """
            SELECT IFNULL(SUM(quantity), 0) AS total_stock
            FROM warehouse_stock
            WHERE product_id = %s
        """, (id,), 'total_stock'))

        if total_stock > 0:
            return jsonify({
                'error': f'Cannot delete product with existing stock. Current total stock: {total_stock}. Please remove all stock first through stock adjustments.'
            }), 400

        # 3️⃣ Check if product is used in pending purchase orders
        pending_orders = _fetch_scalar(cursor, """
            SELECT COUNT(*) AS order_count
            FROM order_items oi
            JOIN purchase_orders po ON oi.order_id = po.order_id
            WHERE oi.product_id = %s
            AND po.status IN ('pending', 'ordered')
        """, (id,), 'order_count')

        if pending_orders > 0:
            return jsonify({
                'error': f'Cannot delete product used in {pending_orders} pending purchase order(s). Complete, receive, or cancel orders first.'
            }), 400

        # 4️⃣ Count related records before deletion (reported back to the caller)
        variation_count = _fetch_scalar(cursor, """
            SELECT COUNT(*) AS variation_count
            FROM product_variations
            WHERE product_id = %s
        """, (id,), 'variation_count')

        batch_count = _fetch_scalar(cursor, """
            SELECT COUNT(*) AS batch_count
            FROM product_batches
            WHERE product_id = %s
        """, (id,), 'batch_count')

        warehouse_stock_count = _fetch_scalar(cursor, """
            SELECT COUNT(*) AS warehouse_stock_count
            FROM warehouse_stock
            WHERE product_id = %s
        """, (id,), 'warehouse_stock_count')

        # Completed orders keep their order_items for history
        completed_orders = _fetch_scalar(cursor, """
            SELECT COUNT(*) AS completed_order_count
            FROM order_items oi
            JOIN purchase_orders po ON oi.order_id = po.order_id
            WHERE oi.product_id = %s
            AND po.status IN ('received', 'completed')
        """, (id,), 'completed_order_count')

        # 5️⃣ Delete warehouse stock rows (total quantity is not positive here)
        cursor.execute("""
            DELETE FROM warehouse_stock
            WHERE product_id = %s
        """, (id,))
        deleted_stock = cursor.rowcount

        # 6️⃣ Delete the discount rules of this product's batches before the
        # batches themselves. Otherwise the rules would be left orphaned, or
        # the batch delete would fail on a foreign key to batch_id.
        cursor.execute("""
            DELETE FROM product_batch_discounts
            WHERE batch_id IN (
                SELECT batch_id FROM product_batches WHERE product_id = %s
            )
        """, (id,))
        deleted_discounts = cursor.rowcount

        # 7️⃣ Delete related product batches
        cursor.execute("""
            DELETE FROM product_batches
            WHERE product_id = %s
        """, (id,))
        deleted_batches = cursor.rowcount

        # 8️⃣ Delete related product variations (if variable product)
        cursor.execute("""
            DELETE FROM product_variations
            WHERE product_id = %s
        """, (id,))
        deleted_variations = cursor.rowcount

        # 9️⃣ Handle order_items - keep them when completed orders exist (history)
        if completed_orders > 0:
            print(f"⚠️ Keeping {completed_orders} order_items in completed orders for history")
            deleted_order_items = 0
        else:
            cursor.execute("""
                DELETE FROM order_items
                WHERE product_id = %s
            """, (id,))
            deleted_order_items = cursor.rowcount

        # 🔟 Finally, delete the product itself
        cursor.execute("""
            DELETE FROM products
            WHERE id = %s
        """, (id,))

        if cursor.rowcount == 0:
            conn.rollback()
            return jsonify({'error': 'Failed to delete product'}), 500

        # Commit all changes
        conn.commit()

        return jsonify({
            'success': True,
            'message': f"✅ Product '{product_name}' (SKU: {product_sku}) deleted successfully!",
            'data': {
                'deleted_product_id': id,
                'product_name': product_name,
                'product_sku': product_sku,
                'product_type': product_type,
                'deleted_records': {
                    'variations': deleted_variations,
                    'batches': deleted_batches,
                    'batch_discounts': deleted_discounts,
                    'warehouse_stock': deleted_stock,
                    'order_items': deleted_order_items
                },
                'related_counts_before_deletion': {
                    'variations': variation_count,
                    'batches': batch_count,
                    'warehouse_stock_entries': warehouse_stock_count
                },
                'completed_orders_kept': completed_orders
            }
        }), 200

    except mysql.connector.Error as err:
        conn.rollback()
        print(f"❌ Database Error deleting product: {err}")
        traceback.print_exc()
        return jsonify({'error': f'Database error: {str(err)}'}), 500

    except Exception as e:
        conn.rollback()
        print(f"❌ Error deleting product: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Failed to delete product: {str(e)}'}), 500

    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
            
            
@product_bp.route('/base_units', methods=['GET'])
@jwt_required()
@role_required('admin')
def get_base_units():
    """Return all base units, newest first, each with its linked ``units``.

    Response 200 is ``{"base_units": [{...base_unit columns..., "units": [...]}]}``.
    ``units`` holds the rows of the ``units`` table whose ``base_unit_id``
    matches, ordered by ``units.id``. Response 500 on connection or query
    failure.

    The connection and cursor are always closed. Previously a failure in
    ``conn.cursor()`` left the connection open.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Failed to connect to database'}), 500

        cursor = conn.cursor(dictionary=True)

        # Fetch base units (latest first)
        cursor.execute("SELECT * FROM base_units ORDER BY id DESC")
        base_units = cursor.fetchall()

        # One query for all units instead of one query per base unit (N+1).
        # Rows are then bucketed by base_unit_id; units whose base unit no
        # longer exists are ignored, as before.
        units_by_base = {base_unit['id']: [] for base_unit in base_units}
        if units_by_base:
            cursor.execute("SELECT * FROM units ORDER BY id")
            for unit in cursor.fetchall():
                bucket = units_by_base.get(unit['base_unit_id'])
                if bucket is not None:
                    bucket.append(unit)

        for base_unit in base_units:
            base_unit['units'] = units_by_base[base_unit['id']]

        return jsonify({"base_units": base_units}), 200

    except Exception as e:
        print("Error in /base_units:", str(e))
        traceback.print_exc()
        return jsonify({'error': 'An error occurred while fetching base units'}), 500

    finally:
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
            
            
@product_bp.route('/base_unit', methods=['POST'])
@jwt_required()
@role_required('admin')
def add_base_unit():
    """Create a base unit from JSON ``{"base_unit": "<name>"}``.

    The name is stripped of surrounding whitespace before the duplicate
    check and the insert.

    Returns:
    - 201 ``{'message', 'id'}`` on success;
    - 400 for a non-JSON or non-object body, or a missing or blank name;
    - 409 if the name already exists (including a concurrent insert caught
      by the unique key);
    - 500 otherwise.
    """
    if not request.is_json:
        return jsonify({'error': 'Request body must be JSON'}), 400

    # silent=True: malformed JSON must be a 400, not a 500 from the broad
    # except below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    base_unit = data.get('base_unit')
    if not isinstance(base_unit, str) or not base_unit.strip():
        return jsonify({'error': 'base_unit is required'}), 400

    base_unit = base_unit.strip()

    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Failed to connect to database'}), 500

        cursor = conn.cursor()

        # Check duplicate
        cursor.execute(
            "SELECT id FROM base_units WHERE base_unit = %s",
            (base_unit,)
        )
        if cursor.fetchone():
            return jsonify({'error': 'Base unit already exists'}), 409

        cursor.execute(
            "INSERT INTO base_units (base_unit) VALUES (%s)",
            (base_unit,)
        )
        conn.commit()

        return jsonify({
            'message': 'Base unit added successfully',
            'id': cursor.lastrowid
        }), 201

    except Exception as e:
        if conn:
            try:
                conn.rollback()
            except Exception:
                pass  # connection may already be unusable; it is closed below
        # 1062 = ER_DUP_ENTRY: another request inserted the same name between
        # our duplicate check and the INSERT.
        if isinstance(e, mysql.connector.IntegrityError) and e.errno == 1062:
            return jsonify({'error': 'Base unit already exists'}), 409
        print("Error in /base_unit:", str(e))
        traceback.print_exc()
        return jsonify({'error': 'An error occurred while adding base unit'}), 500

    finally:
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()


@product_bp.route('/base_units/<int:id>', methods=['PUT'])
@jwt_required()
@role_required('admin')
def update_base_unit(id):
    """Rename base unit ``id`` from JSON ``{"base_unit": "<name>"}``.

    The new name is stripped before comparison and storage. An unchanged
    name is reported as a successful update without writing.

    Returns:
    - 200 on success;
    - 400 for a non-JSON or non-object body, or a missing or blank name;
    - 404 if the base unit does not exist;
    - 409 if another base unit already uses the name;
    - 500 otherwise.
    """
    if not request.is_json:
        return jsonify({'error': 'Request body must be JSON'}), 400

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    base_unit = data.get('base_unit')
    if not isinstance(base_unit, str) or not base_unit.strip():
        return jsonify({'error': 'base_unit is required'}), 400

    base_unit = base_unit.strip()

    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Failed to connect to database'}), 500

        cursor = conn.cursor()

        # Check if base unit exists
        cursor.execute("SELECT base_unit FROM base_units WHERE id = %s", (id,))
        existing = cursor.fetchone()

        if not existing:
            return jsonify({'error': 'Base unit not found'}), 404

        # If value unchanged
        if existing[0] == base_unit:
            return jsonify({'message': 'Base unit updated successfully'}), 200

        # Check duplicate name on other rows
        cursor.execute(
            "SELECT id FROM base_units WHERE base_unit = %s AND id <> %s",
            (base_unit, id)
        )
        if cursor.fetchone():
            return jsonify({'error': 'Another base unit with this name already exists'}), 409

        cursor.execute(
            "UPDATE base_units SET base_unit = %s WHERE id = %s",
            (base_unit, id)
        )
        conn.commit()

        return jsonify({'message': 'Base unit updated successfully'}), 200

    except Exception as e:
        if conn:
            try:
                conn.rollback()
            except Exception:
                pass  # connection may already be unusable; it is closed below
        # 1062 = ER_DUP_ENTRY: a concurrent rename took the name after our check.
        if isinstance(e, mysql.connector.IntegrityError) and e.errno == 1062:
            return jsonify({'error': 'Another base unit with this name already exists'}), 409
        print("Error in /base_units/<id> [PUT]:", str(e))
        traceback.print_exc()
        return jsonify({'error': 'An error occurred while updating base unit'}), 500

    finally:
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()

@product_bp.route('/base_units/<int:id>', methods=['DELETE'])
@jwt_required()
@role_required('admin')
def delete_base_unit(id):
    """Delete base unit ``id``.

    Returns:
    - 200 on success;
    - 404 if the base unit does not exist;
    - 400 while rows in ``units`` still reference it via ``base_unit_id``,
      or when MySQL rejects the delete because another table's foreign key
      still points at it;
    - 500 otherwise.

    The connection is always closed.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Failed to connect to database'}), 500

        cursor = conn.cursor()

        # Check if base unit exists
        cursor.execute("SELECT id FROM base_units WHERE id = %s", (id,))
        if not cursor.fetchone():
            return jsonify({'error': 'Base unit not found'}), 404

        # Safety: prevent delete if units exist
        cursor.execute(
            "SELECT COUNT(*) FROM units WHERE base_unit_id = %s",
            (id,)
        )
        unit_count = cursor.fetchone()[0]

        if unit_count > 0:
            return jsonify({
                'error': 'Cannot delete base unit with linked units'
            }), 400

        cursor.execute("DELETE FROM base_units WHERE id = %s", (id,))
        conn.commit()

        return jsonify({'message': 'Base unit deleted successfully'}), 200

    except Exception as e:
        if conn:
            try:
                conn.rollback()
            except Exception:
                pass  # connection may already be unusable; it is closed below
        # 1451 = ER_ROW_IS_REFERENCED_2: a foreign key elsewhere still points here.
        if isinstance(e, mysql.connector.IntegrityError) and e.errno == 1451:
            return jsonify({
                'error': 'Cannot delete base unit: it is still referenced by other records'
            }), 400
        print("Error in /base_units/<id> [DELETE]:", str(e))
        traceback.print_exc()
        return jsonify({'error': 'An error occurred while deleting base unit'}), 500

    finally:
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
        

@product_bp.route('/get_brands', methods=['GET'])
@jwt_required()
@role_required('admin')
def get_brands():
    """List brands, newest first, with the number of products per brand.

    Returns:
    - 200 with a JSON array of ``{id, brand_name, product_count}``;
    - 200 ``{'message': 'No brands found'}`` when the table is empty;
    - 500 on connection or query failure.

    The cursor and connection are closed on every path, including errors.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if conn is None:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor(dictionary=True)

        # LEFT JOIN so brands without products are listed with a count of 0.
        query = """
            SELECT 
                b.id, 
                b.brand_name, 
                COUNT(p.id) AS product_count
            FROM brands b
            LEFT JOIN products p ON p.brand_id = b.id
            GROUP BY b.id, b.brand_name
            ORDER BY b.id DESC
        """

        cursor.execute(query)
        brands = cursor.fetchall()

        if not brands:
            return jsonify({'message': 'No brands found'}), 200

        return jsonify(brands), 200

    except Exception as e:
        print("Error in /get_brands:", str(e))
        traceback.print_exc()
        return jsonify({'error': 'An error occurred while fetching brands'}), 500

    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

@product_bp.route('/add_brand', methods=['POST'])
@jwt_required()
@role_required('admin')
def add_brand():
    """Create a brand from JSON or form field ``brand_name``.

    The name is stripped of surrounding whitespace before the duplicate
    check and the insert. Previously only the emptiness test was stripped,
    so "Nike " could be stored next to "Nike".

    Returns:
    - 201 ``{'message', 'id', 'brand_name'}`` on success;
    - 400 for a missing, blank or non-string name, or an unparsable JSON body;
    - 409 if the brand already exists (including a concurrent insert caught
      by the unique key);
    - 500 otherwise.
    """
    if request.is_json:
        # silent=True: malformed JSON yields None (-> 400) instead of a 500.
        data = request.get_json(silent=True)
        brand_name = data.get('brand_name') if isinstance(data, dict) else None
    else:
        brand_name = request.form.get('brand_name')

    if not isinstance(brand_name, str) or not brand_name.strip():
        return jsonify({'error': 'Brand name is required'}), 400

    brand_name = brand_name.strip()

    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if conn is None:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor()

        # Check if brand already exists
        cursor.execute("SELECT id FROM brands WHERE brand_name = %s", (brand_name,))
        if cursor.fetchone():
            return jsonify({'error': 'Brand already exists'}), 409

        cursor.execute(
            "INSERT INTO brands (brand_name) VALUES (%s)",
            (brand_name,)
        )
        conn.commit()

        return jsonify({
            'message': 'Brand added successfully',
            'id': cursor.lastrowid,
            'brand_name': brand_name
        }), 201

    except Exception as e:
        if conn is not None:
            try:
                conn.rollback()
            except Exception:
                pass  # connection may already be unusable; it is closed below
        # 1062 = ER_DUP_ENTRY: a concurrent request inserted the same brand.
        if isinstance(e, mysql.connector.IntegrityError) and e.errno == 1062:
            return jsonify({'error': 'Brand already exists'}), 409
        print("Error in /add_brand:", str(e))
        traceback.print_exc()
        return jsonify({'error': 'An error occurred while adding the brand'}), 500

    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()


@product_bp.route('/update_brand/<int:id>', methods=['PUT'])
@jwt_required()
@role_required('admin')
def update_brand(id):
    """Rename the brand whose primary key is ``id``.

    Expects a JSON body ``{"brand_name": "..."}``. The name is stripped of
    surrounding whitespace and must not already belong to a different brand
    (the same rule ``update_category`` enforces for categories).

    Returns:
        200 on success;
        400 if the body is missing/not JSON, or the name is missing, blank,
        or not a string;
        404 if no brand has this id;
        409 if another brand already uses the name;
        500 if the database is unreachable or an unexpected error occurs.
    """
    conn = None
    cursor = None
    try:
        # silent=True: a non-JSON or malformed body yields None and is reported
        # as 400 below, instead of raising and surfacing as a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'error': 'Request body is required'}), 400

        brand_name = data.get('brand_name') if isinstance(data, dict) else None
        if not isinstance(brand_name, str) or not brand_name.strip():
            return jsonify({'error': 'Brand name is required'}), 400
        brand_name = brand_name.strip()

        conn = get_db_connection()
        if conn is None:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor()

        # Check the target brand exists
        cursor.execute("SELECT id FROM brands WHERE id = %s", (id,))
        if cursor.fetchone() is None:
            return jsonify({'error': 'Brand not found'}), 404

        # Reject a rename that would collide with a different brand's name
        cursor.execute(
            "SELECT id FROM brands WHERE brand_name = %s AND id <> %s LIMIT 1",
            (brand_name, id)
        )
        if cursor.fetchone():
            return jsonify({'error': 'Another brand with this name already exists'}), 409

        cursor.execute(
            "UPDATE brands SET brand_name = %s WHERE id = %s",
            (brand_name, id)
        )
        conn.commit()

        return jsonify({'message': 'Brand updated successfully'}), 200

    except Exception as e:
        print("Error in /update_brand:", str(e))
        traceback.print_exc()
        if conn is not None:
            try:
                conn.rollback()
            except Exception:
                pass
        return jsonify({'error': 'An error occurred while updating the brand'}), 500

    finally:
        # Best-effort cleanup on every exit path, including early returns.
        if cursor is not None:
            try:
                cursor.close()
            except Exception:
                pass
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass


@product_bp.route('/delete_brand/<int:id>', methods=['DELETE'])
@jwt_required()
@role_required('admin')
def delete_brand(id):
    """Delete a brand by ID.

    Returns:
        200 if a brand row was deleted;
        404 if no brand has this id;
        500 if the database is unreachable or the DELETE fails (for example,
        if a foreign key still references the brand; whether one exists
        depends on the schema, which is not visible here).

    NOTE(review): unlike ``delete_category``, this does not check for
    products that still reference the brand before deleting.
    """
    conn = get_db_connection()
    if conn is None:
        # Guard before touching the connection; calling .cursor() on None
        # would raise outside any handler.
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = None
    try:
        cursor = conn.cursor()
        cursor.execute("DELETE FROM brands WHERE id = %s", (id,))
        # rowcount is the number of rows the DELETE affected; 0 means the id
        # did not exist, so there is nothing to commit.
        if cursor.rowcount == 0:
            return jsonify({'error': 'Brand not found'}), 404
        conn.commit()
        return jsonify({'message': 'Brand deleted successfully'}), 200
    except Exception as err:
        print("Error in /delete_brand:", str(err))
        traceback.print_exc()
        try:
            conn.rollback()
        except Exception:
            pass
        return jsonify({'error': 'Failed to delete brand'}), 500
    finally:
        # Best-effort cleanup so a close() failure never masks the response.
        if cursor is not None:
            try:
                cursor.close()
            except Exception:
                pass
        try:
            conn.close()
        except Exception:
            pass


# GET categories with product counts
@product_bp.route('/get_categories', methods=['GET'])
@jwt_required()
@role_required('admin')
def get_categories():
    """List every category with the number of products that reference it.

    Categories come back newest-first (descending id). Categories with no
    products still appear, with ``product_count`` 0, thanks to the LEFT JOIN.
    When there are no categories at all, a ``{'message': ...}`` object is
    returned instead of a list (still HTTP 200).
    """
    try:
        db = get_db_connection()
        if db is None:
            return jsonify({'error': 'Database connection failed'}), 500

        cur = db.cursor(dictionary=True)
        cur.execute("""
            SELECT 
                c.id,
                c.category_name,
                COUNT(p.id) AS product_count
            FROM categories c
            LEFT JOIN products p ON p.category_id = c.id
            GROUP BY c.id, c.category_name
            ORDER BY c.id DESC
        """)
        category_rows = cur.fetchall()
        cur.close()
        db.close()

        if category_rows:
            return jsonify(category_rows), 200
        return jsonify({'message': 'No categories found'}), 200

    except Exception as exc:
        print("Error in /get_categories:", str(exc))
        traceback.print_exc()
        # Best-effort release of whatever was opened. Either name may still be
        # unassigned if the failure came early; that NameError is swallowed too.
        for release in (lambda: cur.close(), lambda: db.close()):
            try:
                release()
            except Exception:
                pass
        return jsonify({'error': 'An error occurred while fetching categories'}), 500


# ADD category
@product_bp.route('/add_category', methods=['POST'])
@jwt_required()
@role_required('admin')
def add_category():
    """Create a category from a ``category_name`` JSON or form field.

    The name is stripped of surrounding whitespace and must be unique.

    Returns:
        201 with ``message``, ``id`` (the new row id) and ``category_name``;
        400 if the name is missing, blank, or not a string;
        409 if a category with the same name already exists;
        500 if the database is unreachable or an unexpected error occurs.
    """
    conn = None
    cursor = None
    try:
        if request.is_json:
            # silent=True: malformed JSON yields None instead of raising, and a
            # non-object body (list, number) simply carries no category_name.
            data = request.get_json(silent=True)
            category_name = data.get('category_name') if isinstance(data, dict) else None
        else:
            category_name = request.form.get('category_name')

        # A non-string JSON value (e.g. 123) has no .strip(); reject it as 400.
        if not isinstance(category_name, str) or not category_name.strip():
            return jsonify({'error': 'Category name is required'}), 400

        category_name = category_name.strip()

        conn = get_db_connection()
        if conn is None:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor()

        # Check duplicate (one row is enough to decide)
        cursor.execute(
            "SELECT id FROM categories WHERE category_name = %s LIMIT 1",
            (category_name,)
        )
        if cursor.fetchone():
            return jsonify({'error': 'Category already exists'}), 409

        cursor.execute("INSERT INTO categories (category_name) VALUES (%s)", (category_name,))
        conn.commit()

        return jsonify({
            'message': 'Category added successfully',
            'id': cursor.lastrowid,
            'category_name': category_name
        }), 201

    except Exception as e:
        print("Error in /add_category:", str(e))
        traceback.print_exc()
        if conn is not None:
            try:
                conn.rollback()
            except Exception:
                pass
        return jsonify({'error': 'An error occurred while adding the category'}), 500

    finally:
        # Best-effort cleanup on every exit path, including early returns.
        if cursor is not None:
            try:
                cursor.close()
            except Exception:
                pass
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass


# UPDATE category
@product_bp.route('/update_category/<int:id>', methods=['PUT'])
@jwt_required()
@role_required('admin')
def update_category(id):
    """Rename the category whose primary key is ``id``.

    Expects a JSON body ``{"category_name": "..."}``. The name is stripped
    and must not already belong to a different category.

    Returns:
        200 on success;
        400 if the body is not JSON/empty/malformed, or the name is missing,
        blank, or not a string;
        404 if no category has this id;
        409 if another category already uses the name;
        500 if the database is unreachable or an unexpected error occurs.
    """
    conn = None
    cursor = None
    try:
        if not request.is_json:
            return jsonify({'error': 'Request body must be JSON'}), 400

        # silent=True: malformed JSON yields None -> 400 below instead of a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'error': 'Request body is required'}), 400

        category_name = data.get('category_name') if isinstance(data, dict) else None
        if not isinstance(category_name, str) or not category_name.strip():
            return jsonify({'error': 'Category name is required'}), 400
        category_name = category_name.strip()

        conn = get_db_connection()
        if conn is None:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor()

        # Check target exists
        cursor.execute("SELECT id FROM categories WHERE id = %s", (id,))
        if cursor.fetchone() is None:
            return jsonify({'error': 'Category not found'}), 404

        # Check duplicate name for other id
        cursor.execute(
            "SELECT id FROM categories WHERE category_name = %s AND id <> %s LIMIT 1",
            (category_name, id)
        )
        if cursor.fetchone():
            return jsonify({'error': 'Another category with this name already exists'}), 409

        cursor.execute("UPDATE categories SET category_name = %s WHERE id = %s", (category_name, id))
        conn.commit()

        return jsonify({'message': 'Category updated successfully'}), 200

    except Exception as e:
        print("Error in /update_category:", str(e))
        traceback.print_exc()
        if conn is not None:
            try:
                conn.rollback()
            except Exception:
                pass
        return jsonify({'error': 'An error occurred while updating the category'}), 500

    finally:
        # Best-effort cleanup on every exit path, including early returns.
        if cursor is not None:
            try:
                cursor.close()
            except Exception:
                pass
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass


# DELETE category
@product_bp.route('/delete_category/<int:id>', methods=['DELETE'])
@jwt_required()
@role_required('admin')
def delete_category(id):
    """Delete one category, but only if no product is assigned to it.

    Responses:
        200 - the category row was removed;
        400 - at least one product still has this category_id;
        404 - no category row has this id;
        500 - no database connection, or an unexpected error (the connection
              is rolled back and closed on a best-effort basis).
    """
    try:
        conn = get_db_connection()
        if conn is None:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor()

        def close_all():
            # Shared by every normal exit path; an error raised here falls
            # through to the except handler below, like the rest of the body.
            cursor.close()
            conn.close()

        cursor.execute("SELECT id FROM categories WHERE id = %s", (id,))
        category_exists = cursor.fetchone() is not None
        if not category_exists:
            close_all()
            return jsonify({'error': 'Category not found'}), 404

        # Refuse to delete while products still reference this category.
        cursor.execute("SELECT COUNT(1) AS cnt FROM products WHERE category_id = %s", (id,))
        count_row = cursor.fetchone()
        linked_products = count_row[0] if count_row else 0
        if linked_products and linked_products > 0:
            close_all()
            return jsonify({'error': 'Cannot delete category with associated products'}), 400

        cursor.execute("DELETE FROM categories WHERE id = %s", (id,))
        conn.commit()
        close_all()
        return jsonify({'message': 'Category deleted successfully'}), 200

    except Exception as exc:
        print("Error in /delete_category:", str(exc))
        traceback.print_exc()
        # Either name may be unbound if the failure came early; the resulting
        # NameError is swallowed by these best-effort guards.
        try:
            if cursor:
                cursor.close()
        except Exception:
            pass
        try:
            if conn:
                conn.rollback()
                conn.close()
        except Exception:
            pass
        return jsonify({'error': 'An error occurred while deleting the category'}), 500


@product_bp.route('/add_unit', methods=['POST'])
@jwt_required()
@role_required('admin')
def add_unit():
    """Create a unit of measure under an existing base unit.

    JSON body fields:
        unit_name:    display name (required, stripped).
        unit_short:   abbreviation (required, stripped).
        base_unit_id: id of a row in ``base_units`` (required). The key
                      ``base_unit`` is accepted as a fallback for backward
                      compatibility.

    Returns:
        201 with ``message`` and the new ``id``;
        400 if the body is not JSON or a field is missing, blank, or (for the
        names) not a string;
        404 if the base unit does not exist;
        409 if a unit with this name already exists for that base unit;
        500 if the database is unreachable or an unexpected error occurs.
    """
    conn = None
    cursor = None
    try:
        if not request.is_json:
            return jsonify({'error': 'Request body must be JSON'}), 400

        # silent=True plus the dict check: malformed JSON or a non-object body
        # is treated as "no fields" and reported as a 400 below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        unit_name = data.get("unit_name")
        unit_short = data.get("unit_short")
        # ✅ Accept both 'base_unit' and 'base_unit_id' for backward compatibility
        base_unit_id = data.get("base_unit_id") or data.get("base_unit")

        if not isinstance(unit_name, str) or not unit_name.strip():
            return jsonify({'error': 'unit_name is required'}), 400

        if not isinstance(unit_short, str) or not unit_short.strip():
            return jsonify({'error': 'unit_short is required'}), 400

        if not base_unit_id:
            return jsonify({'error': 'base_unit_id is required'}), 400

        unit_name = unit_name.strip()
        unit_short = unit_short.strip()

        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Failed to connect to database'}), 500

        cursor = conn.cursor()

        # Check base unit exists
        cursor.execute("SELECT id FROM base_units WHERE id = %s", (base_unit_id,))
        if not cursor.fetchone():
            return jsonify({'error': 'Base unit not found'}), 404

        # Prevent duplicate unit name under the same base unit
        cursor.execute(
            "SELECT id FROM units WHERE unit_name = %s AND base_unit_id = %s LIMIT 1",
            (unit_name, base_unit_id)
        )
        if cursor.fetchone():
            return jsonify({'error': 'Unit already exists'}), 409

        cursor.execute(
            "INSERT INTO units (unit_name, unit_short, base_unit_id) VALUES (%s, %s, %s)",
            (unit_name, unit_short, base_unit_id)
        )
        conn.commit()

        return jsonify({
            "message": "Unit added successfully!",
            "id": cursor.lastrowid
        }), 201

    except Exception as e:
        print("Error in /add_unit:", str(e))
        traceback.print_exc()
        if conn is not None:
            try:
                conn.rollback()
            except Exception:
                pass
        return jsonify({'error': 'An error occurred while adding unit'}), 500

    finally:
        # Each resource is released independently, so a failure creating or
        # closing the cursor can no longer prevent the connection from closing.
        if cursor is not None:
            try:
                cursor.close()
            except Exception:
                pass
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass
    
    
@product_bp.route('/get_units', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager', 'cashier')
def get_all_units():
    """
    Get all units with their base unit information.

    Units are ordered by ``unit_name``. The LEFT JOIN keeps units whose base
    unit row is missing; for those, ``base_unit_name`` is null.

    Returns:
        200 with a (possibly empty) JSON list of unit rows;
        500 if the database is unreachable or the query fails. Driver error
        details are logged server-side only and not echoed to the client.
    """
    conn = get_db_connection()
    if conn is None:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = None
    try:
        # Created inside the try so a failure here still reaches ``finally``
        # and the connection is closed.
        cursor = conn.cursor(dictionary=True)
        cursor.execute("""
            SELECT 
                u.id,
                u.unit_name,
                u.unit_short,
                u.base_unit_id,
                bu.base_unit AS base_unit_name,
                u.created_at
            FROM units u
            LEFT JOIN base_units bu ON u.base_unit_id = bu.id
            ORDER BY u.unit_name ASC
        """)

        units = cursor.fetchall()

        return jsonify(units), 200

    except mysql.connector.Error as err:
        print(f"Database Error fetching units: {err}")
        traceback.print_exc()
        # Generic message: raw driver errors can reveal schema/SQL details.
        return jsonify({'error': 'Database error while fetching units'}), 500

    except Exception as err:
        print(f"Error fetching units: {err}")
        traceback.print_exc()
        return jsonify({'error': 'Failed to fetch units'}), 500

    finally:
        # Best-effort cleanup so a close() failure never masks the response.
        if cursor is not None:
            try:
                cursor.close()
            except Exception:
                pass
        try:
            conn.close()
        except Exception:
            pass


@product_bp.route('/get_units/<int:id>', methods=['GET'])
@jwt_required()
@role_required('admin')
def get_unit(id):
    """Fetch a single unit by id together with its base unit's name.

    The query uses an INNER JOIN on ``base_units``, so a unit whose base
    unit row is missing is reported as 404, the same as a nonexistent id.

    Returns:
        200 with the unit row;
        404 if no matching unit is found;
        500 if the database is unreachable or an unexpected error occurs.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Failed to connect to database'}), 500

        cursor = conn.cursor(dictionary=True)

        cursor.execute("""
            SELECT 
                u.id, 
                u.unit_name, 
                u.unit_short, 
                u.base_unit_id, 
                b.base_unit, 
                u.created_at
            FROM units u
            JOIN base_units b ON u.base_unit_id = b.id
            WHERE u.id = %s
        """, (id,))

        unit = cursor.fetchone()

        if unit:
            return jsonify(unit), 200
        return jsonify({"error": "Unit not found"}), 404

    except Exception as e:
        print("Error in /get_units/<id>:", str(e))
        traceback.print_exc()
        return jsonify({'error': 'An error occurred while fetching unit'}), 500

    finally:
        # Each resource is released independently, so a failure creating or
        # closing the cursor can no longer prevent the connection from closing.
        if cursor is not None:
            try:
                cursor.close()
            except Exception:
                pass
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass


@product_bp.route('/update_unit/<int:id>', methods=['PUT'])
@jwt_required()
@role_required('admin')
def update_unit(id):
    """Update a unit's name, abbreviation and base unit.

    JSON body fields: ``unit_name`` and ``unit_short`` (required, stripped)
    and ``base_unit_id`` (required; ``base_unit`` is accepted as a fallback
    key for compatibility). The (unit_name, base_unit_id) pair must not
    already belong to a different unit.

    Returns:
        200 on success;
        400 if the body is not JSON or a field is missing, blank, or (for the
        names) not a string;
        404 if the unit or the base unit does not exist;
        409 if another unit already uses this name under that base unit;
        500 if the database is unreachable or an unexpected error occurs.
    """
    conn = None
    cursor = None
    try:
        # Validate request
        if not request.is_json:
            return jsonify({'error': 'Request body must be JSON'}), 400

        # silent=True plus the dict check: malformed JSON or a non-object body
        # is treated as "no fields" and reported as a 400 below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        unit_name = data.get("unit_name")
        unit_short = data.get("unit_short")
        # ✅ Accept both 'base_unit_id' and 'base_unit' for compatibility
        base_unit_id = data.get("base_unit_id") or data.get("base_unit")

        if not isinstance(unit_name, str) or not unit_name.strip():
            return jsonify({'error': 'unit_name is required'}), 400
        if not isinstance(unit_short, str) or not unit_short.strip():
            return jsonify({'error': 'unit_short is required'}), 400
        if not base_unit_id:
            return jsonify({'error': 'base_unit_id is required'}), 400

        unit_name = unit_name.strip()
        unit_short = unit_short.strip()

        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Failed to connect to database'}), 500
        cursor = conn.cursor()

        # Check if unit exists
        cursor.execute("SELECT id FROM units WHERE id = %s", (id,))
        if not cursor.fetchone():
            return jsonify({'error': f'Unit with ID {id} not found'}), 404

        # Check if base unit exists
        cursor.execute("SELECT id FROM base_units WHERE id = %s", (base_unit_id,))
        if not cursor.fetchone():
            return jsonify({'error': 'Base unit not found'}), 404

        # Prevent duplicate unit name for same base unit
        cursor.execute(
            "SELECT id FROM units WHERE unit_name = %s AND base_unit_id = %s AND id <> %s LIMIT 1",
            (unit_name, base_unit_id, id)
        )
        if cursor.fetchone():
            return jsonify({'error': 'Another unit with this name already exists for this base unit'}), 409

        # Perform update
        cursor.execute(
            """
            UPDATE units
            SET unit_name = %s, unit_short = %s, base_unit_id = %s
            WHERE id = %s
            """,
            (unit_name, unit_short, base_unit_id, id)
        )
        conn.commit()

        return jsonify({'message': f'Unit with ID {id} updated successfully!'}), 200

    except Exception as e:
        print("Error in /update_unit/<id>:", str(e))
        traceback.print_exc()
        if conn is not None:
            try:
                conn.rollback()
            except Exception:
                pass
        return jsonify({'error': 'An error occurred while updating unit'}), 500

    finally:
        # Each resource is released independently, so a failure creating or
        # closing the cursor can no longer prevent the connection from closing.
        if cursor is not None:
            try:
                cursor.close()
            except Exception:
                pass
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass


@product_bp.route('/delete_unit/<int:id>', methods=['DELETE'])
@jwt_required()
@role_required('admin')
def delete_unit(id):
    """Delete a unit, refusing if any product still references it.

    A unit counts as "in use" when a product row has it as its
    ``base_unit_id``, ``sale_unit_id`` or ``purchase_unit_id``.

    Returns:
        200 on success;
        404 if no unit has this id;
        409 if the unit is referenced by one or more products;
        500 if the database is unreachable or an unexpected error occurs.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Failed to connect to database'}), 500

        cursor = conn.cursor()

        # Check unit exists
        cursor.execute("SELECT id FROM units WHERE id = %s", (id,))
        if not cursor.fetchone():
            return jsonify({'error': 'Unit not found'}), 404

        # ✅ Check if unit is being used by products
        cursor.execute("""
            SELECT COUNT(*) as count FROM products 
            WHERE base_unit_id = %s OR sale_unit_id = %s OR purchase_unit_id = %s
        """, (id, id, id))

        result = cursor.fetchone()
        if result and result[0] > 0:
            return jsonify({
                'error': 'Cannot delete unit. It is being used by one or more products.'
            }), 409

        # Delete the unit
        cursor.execute("DELETE FROM units WHERE id = %s", (id,))
        conn.commit()

        return jsonify({"message": f"Unit with ID {id} deleted successfully!"}), 200

    except Exception as e:
        print("Error in /delete_unit:", str(e))
        traceback.print_exc()
        if conn is not None:
            try:
                conn.rollback()
            except Exception:
                pass
        return jsonify({'error': 'An error occurred while deleting unit'}), 500

    finally:
        # Each resource is released independently, so a failure creating or
        # closing the cursor can no longer prevent the connection from closing.
        if cursor is not None:
            try:
                cursor.close()
            except Exception:
                pass
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass

# @product_bp.route('/search_products', methods=['GET'])
# @jwt_required()
# @role_required('admin')
# def search_products():
#     """
#     Search products with warehouse-specific batch information
#     ✅ FIXED: Returns batches filtered by warehouse with correct remaining quantities
#     ✅ NEW: Handles new variation products without batches - shows product details with 0 stock
#     """
#     query = request.args.get('query')
    
#     if not query:
#         return jsonify({"status": "error", "message": "Query parameter is required"}), 400

#     conn = get_db_connection()
#     cursor = conn.cursor(dictionary=True)

#     try:
#         like_query = f"%{query}%"

#         # Main product search query
#         cursor.execute("""
#             SELECT 
#                 p.id AS product_id,
#                 p.product_name,
#                 p.sku AS product_sku,
#                 p.product_type,
#                 p.tax_type,
#                 p.product_tax,
#                 p.stock_alert,
#                 p.expiration_date,
#                 p.note,
                
#                 b.id AS brand_id,
#                 b.brand_name,
                
#                 c.id AS category_id,
#                 c.category_name,
                
#                 bu.id AS base_unit_id,
#                 bu.unit_name AS unit_name,
#                 bu.unit_short AS unit_short,
#                 base_u.base_unit AS base_unit_name,
                
#                 su.id AS sale_unit_id,
#                 su.unit_name AS sale_unit_name,
#                 su.unit_short AS sale_unit_short,
                
#                 pu.id AS purchase_unit_id,
#                 pu.unit_name AS purchase_unit_name,
#                 pu.unit_short AS purchase_unit_short,
                
#                 pv.id AS variation_id,
#                 pv.variation_name,
#                 pv.variation_type,
#                 pv.variation_sku,
#                 pv.variation_cost,
#                 pv.variation_price,
#                 pv.variation_tax_type,
#                 pv.variation_tax,
#                 pv.variation_stock_alert,
#                 pv.expiration_date AS variation_expiration_date
                
#             FROM products p
#             LEFT JOIN brands b ON p.brand_id = b.id
#             LEFT JOIN categories c ON p.category_id = c.id
#             LEFT JOIN units bu ON p.base_unit_id = bu.id
#             LEFT JOIN base_units base_u ON bu.base_unit_id = base_u.id
#             LEFT JOIN units su ON p.sale_unit_id = su.id
#             LEFT JOIN units pu ON p.purchase_unit_id = pu.id
#             LEFT JOIN product_variations pv ON pv.product_id = p.id
            
#             WHERE 
#                 LOWER(p.product_name) LIKE LOWER(%s) OR 
#                 LOWER(p.sku) LIKE LOWER(%s) OR
#                 LOWER(pv.variation_name) LIKE LOWER(%s) OR 
#                 LOWER(pv.variation_sku) LIKE LOWER(%s) OR
#                 LOWER(b.brand_name) LIKE LOWER(%s) OR
#                 LOWER(c.category_name) LIKE LOWER(%s)
            
#             GROUP BY 
#                 p.id, pv.id
            
#             ORDER BY 
#                 CASE
#                     WHEN LOWER(p.product_name) = LOWER(%s) THEN 1
#                     WHEN LOWER(p.sku) = LOWER(%s) THEN 2
#                     WHEN LOWER(pv.variation_sku) = LOWER(%s) THEN 3
#                     WHEN LOWER(p.product_name) LIKE LOWER(%s) THEN 4
#                     ELSE 5
#                 END,
#                 p.product_name ASC,
#                 pv.variation_name ASC
#         """, (like_query, like_query, like_query, like_query, like_query, like_query,
#               query, query, query, like_query))

#         rows = cursor.fetchall()
#         products = {}

#         for row in rows:
#             pid = row['product_id']
            
#             if pid not in products:
#                 # ✅ Get warehouse-specific batches for SINGLE products
#                 batches = []
#                 total_stock = 0.0
                
#                 if row['product_type'] == 'single':
#                     # Get batches with warehouse stock information
#                     cursor.execute("""
#                         SELECT 
#                             pb.batch_id,
#                             pb.batch_number,
#                             pb.cost AS product_cost,
#                             pb.price AS product_price,
#                             pb.expiration_date,
#                             pb.created_on,
#                             ws.warehouse_id,
#                             ws.quantity AS warehouse_remaining_qty,
#                             -- Total quantity across ALL warehouses for this batch
#                             (SELECT COALESCE(SUM(ws2.quantity), 0) 
#                              FROM warehouse_stock ws2 
#                              WHERE ws2.batch_id = pb.batch_id) AS product_total_stock
#                         FROM product_batches pb
#                         INNER JOIN warehouse_stock ws ON ws.batch_id = pb.batch_id
#                         WHERE pb.product_id = %s 
#                             AND pb.variation_id IS NULL
#                         ORDER BY ws.warehouse_id, pb.created_on DESC
#                     """, (pid,))
                    
#                     batch_rows = cursor.fetchall()
                    
#                     # Process batches
#                     for batch_row in batch_rows:
#                         warehouse_qty = float(batch_row['warehouse_remaining_qty']) if batch_row['warehouse_remaining_qty'] else 0.0
#                         total_qty = float(batch_row['product_total_stock']) if batch_row['product_total_stock'] else 0.0
                        
#                         batches.append({
#                             'batch_id': batch_row['batch_id'],
#                             'batch_number': batch_row['batch_number'],
#                             'product_cost': float(batch_row['product_cost']) if batch_row['product_cost'] else 0.0,
#                             'product_price': float(batch_row['product_price']) if batch_row['product_price'] else 0.0,
#                             'remaining_quantity': warehouse_qty,  # ✅ Warehouse-specific quantity
#                             'product_total_stock': total_qty,     # ✅ Total across all warehouses
#                             'expiration_date': batch_row['expiration_date'].isoformat() if batch_row['expiration_date'] else None,
#                             'created_on': batch_row['created_on'].isoformat() if batch_row['created_on'] else None,
#                             'warehouse_id': batch_row['warehouse_id']
#                         })
                    
#                     # Calculate total stock across all warehouses
#                     cursor.execute("""
#                         SELECT COALESCE(SUM(ws.quantity), 0) as total_stock
#                         FROM warehouse_stock ws
#                         INNER JOIN product_batches pb ON ws.batch_id = pb.batch_id
#                         WHERE pb.product_id = %s AND pb.variation_id IS NULL
#                     """, (pid,))
#                     total_result = cursor.fetchone()
#                     total_stock = float(total_result['total_stock']) if total_result else 0.0
                
#                 products[pid] = {
#                     "product_id": pid,
#                     "product_name": row['product_name'],
#                     "sku": row['product_sku'],
#                     "product_type": row['product_type'],
#                     "tax_type": row['tax_type'],
#                     "product_tax": float(row['product_tax']),
#                     "stock_alert": row['stock_alert'],
#                     "expiration_date": row['expiration_date'].isoformat() if row['expiration_date'] else None,
#                     "note": row['note'],
                    
#                     "brand_id": row['brand_id'],
#                     "brand_name": row['brand_name'],
#                     "category_id": row['category_id'],
#                     "category_name": row['category_name'],
                    
#                     "base_unit": {
#                         "id": row['base_unit_id'],
#                         "name": row['unit_name'],
#                         "short": row['unit_short'],
#                         "base_unit_name": row['base_unit_name']
#                     },
#                     "sale_unit": {
#                         "id": row['sale_unit_id'],
#                         "name": row['sale_unit_name'],
#                         "short": row['sale_unit_short']
#                     },
#                     "purchase_unit": {
#                         "id": row['purchase_unit_id'],
#                         "name": row['purchase_unit_name'],
#                         "short": row['purchase_unit_short']
#                     },
                    
#                     "total_stock": total_stock,
#                     "batches": batches,
#                     "variations": []
#                 }

#             # ✅ Add variation if available (EVEN IF NO BATCHES EXIST)
#             if row['variation_id']:
#                 variation_id = row['variation_id']
                
#                 # Get warehouse-specific batches for this variation
#                 cursor.execute("""
#                     SELECT 
#                         pb.batch_id,
#                         pb.batch_number,
#                         pb.cost AS product_cost,
#                         pb.price AS product_price,
#                         pb.expiration_date,
#                         pb.created_on,
#                         ws.warehouse_id,
#                         ws.quantity AS warehouse_remaining_qty,
#                         -- Total quantity across ALL warehouses for this batch
#                         (SELECT COALESCE(SUM(ws2.quantity), 0) 
#                          FROM warehouse_stock ws2 
#                          WHERE ws2.batch_id = pb.batch_id) AS product_total_stock
#                     FROM product_batches pb
#                     INNER JOIN warehouse_stock ws ON ws.batch_id = pb.batch_id
#                     WHERE pb.product_id = %s 
#                         AND pb.variation_id = %s
#                     ORDER BY ws.warehouse_id, pb.created_on DESC
#                 """, (pid, variation_id))
                
#                 var_batch_rows = cursor.fetchall()
                
#                 variation_batches = []
                
#                 # ✅ Process batches if they exist
#                 if var_batch_rows:
#                     for batch_row in var_batch_rows:
#                         warehouse_qty = float(batch_row['warehouse_remaining_qty']) if batch_row['warehouse_remaining_qty'] else 0.0
#                         total_qty = float(batch_row['product_total_stock']) if batch_row['product_total_stock'] else 0.0
                        
#                         variation_batches.append({
#                             'batch_id': batch_row['batch_id'],
#                             'batch_number': batch_row['batch_number'],
#                             'product_cost': float(batch_row['product_cost']) if batch_row['product_cost'] else 0.0,
#                             'product_price': float(batch_row['product_price']) if batch_row['product_price'] else 0.0,
#                             'remaining_quantity': warehouse_qty,  # ✅ Warehouse-specific
#                             'product_total_stock': total_qty,     # ✅ Total across all warehouses
#                             'expiration_date': batch_row['expiration_date'].isoformat() if batch_row['expiration_date'] else None,
#                             'created_on': batch_row['created_on'].isoformat() if batch_row['created_on'] else None,
#                             'warehouse_id': batch_row['warehouse_id']
#                         })
                
#                 # Calculate total stock for this variation across all warehouses
#                 cursor.execute("""
#                     SELECT COALESCE(SUM(ws.quantity), 0) as variation_total_stock
#                     FROM warehouse_stock ws
#                     INNER JOIN product_batches pb ON ws.batch_id = pb.batch_id
#                     WHERE pb.product_id = %s AND pb.variation_id = %s
#                 """, (pid, variation_id))
#                 var_total_result = cursor.fetchone()
#                 variation_total_stock = float(var_total_result['variation_total_stock']) if var_total_result else 0.0
                
#                 # ✅ IMPORTANT: Add variation even if batches list is empty
#                 # This allows new variations to be searchable and purchasable
#                 products[pid]['variations'].append({
#                     "variation_id": variation_id,
#                     "variation_name": row['variation_name'],
#                     "variation_type": row['variation_type'],
#                     "variation_sku": row['variation_sku'],
#                     # ✅ Use variation's defined cost/price if available, otherwise 0
#                     "product_cost": float(row['variation_cost']) if row['variation_cost'] else 0.0,
#                     "product_price": float(row['variation_price']) if row['variation_price'] else 0.0,
#                     "variation_tax_type": row['variation_tax_type'] or 'exclusive',
#                     "variation_tax": float(row['variation_tax'] or 0),
#                     "variation_stock_alert": row['variation_stock_alert'],
#                     "expiration_date": row['variation_expiration_date'].isoformat() if row['variation_expiration_date'] else None,
#                     "total_stock": variation_total_stock,  # Will be 0 for new variations
#                     "batches": variation_batches  # Will be empty [] for new variations
#                 })
        
#         # Calculate total stock for variable products (sum of all variations)
#         for product in products.values():
#             if product['product_type'] == 'variable' and product['variations']:
#                 product['total_stock'] = sum(v['total_stock'] for v in product['variations'])

#         print(f"✅ Found {len(products)} products matching query: {query}")
        
#         # Log variations without batches for debugging
#         for product in products.values():
#             if product['product_type'] == 'variable':
#                 for variation in product['variations']:
#                     if not variation['batches']:
#                         print(f"📦 New variation without batches: {product['product_name']} - {variation['variation_name']} (ID: {variation['variation_id']})")
        
#         return jsonify(list(products.values())), 200

#     except Exception as e:
#         print("❌ Search Error:", e)
#         import traceback
#         traceback.print_exc()
#         return jsonify({"status": "error", "message": str(e)}), 500

#     finally:
#         cursor.close()
#         conn.close()


def group_batches(batches):
    """
    Merge batch rows that share warehouse, price, cost and expiry into single entries.

    Rows without a ``warehouse_id`` are skipped (with a printed warning) because
    they cannot be attributed to a warehouse.

    Args:
        batches: List of batch dicts. Each dict must provide ``batch_id``,
            ``batch_number``, ``warehouse_id``, ``warehouse_name``,
            ``product_price``, ``product_cost``, ``expiration_date``,
            ``created_on`` and ``remaining_quantity``; ``is_expired`` is optional.

    Returns:
        List of grouped dicts with keys ``batch_ids``, ``batch_numbers``,
        ``batch_count``, ``warehouse_id``, ``warehouse_name``, ``product_cost``,
        ``product_price``, ``remaining_quantity`` (summed), ``expiration_date``,
        ``created_on`` (earliest non-empty value) and ``is_expired`` (true if any
        member row was flagged). Groups are ordered oldest ``created_on`` first
        (FIFO), with groups lacking a ``created_on`` last. ``batch_ids`` and
        ``batch_numbers`` keep the input order of the rows.
    """
    if not batches:
        return []

    grouped = {}

    for batch in batches:
        warehouse_id = batch.get('warehouse_id')
        if warehouse_id is None:
            print(f"⚠️ Warning: Skipping batch {batch.get('batch_id')} - no warehouse_id")
            continue

        # Rows with identical warehouse, price, cost and expiry are interchangeable
        # for selling purposes, so they collapse into one entry.
        key = (
            warehouse_id,
            batch['product_price'],
            batch['product_cost'],
            batch['expiration_date'],
        )

        group = grouped.get(key)
        if group is None:
            group = grouped[key] = {
                'batch_ids': [],
                'batch_numbers': [],
                'batch_count': 0,
                'warehouse_id': warehouse_id,
                'warehouse_name': batch['warehouse_name'],
                'product_cost': batch['product_cost'],
                'product_price': batch['product_price'],
                'remaining_quantity': 0.0,
                'expiration_date': batch['expiration_date'],
                'created_on': batch['created_on'],
                'is_expired': False,
            }

        group['batch_ids'].append(batch['batch_id'])
        group['batch_numbers'].append(batch['batch_number'])
        group['batch_count'] += 1
        group['remaining_quantity'] += batch['remaining_quantity']
        group['is_expired'] = group['is_expired'] or bool(batch.get('is_expired', False))

        # Keep the earliest non-empty created_on for FIFO ordering.
        created_on = batch['created_on']
        if created_on and (group['created_on'] is None or created_on < group['created_on']):
            group['created_on'] = created_on

    # Sort on (missing?, created_on) instead of substituting naive datetime.max:
    # comparing datetime.max with a date or a tz-aware datetime raises TypeError.
    # Undated groups still sort last and the sort stays stable.
    return sorted(grouped.values(), key=lambda g: (g['created_on'] is None, g['created_on']))


# ==========================================
# SEARCH PRODUCTS - COMPLETE FIX
# ==========================================
def _search_iso(value):
    """Serialize a date/datetime to ISO-8601; falsy values become None, anything else passes through."""
    if not value:
        return None
    # Values may already be strings (see _search_is_expired), which have no isoformat().
    return value.isoformat() if hasattr(value, 'isoformat') else value


def _search_is_expired(expiration_date):
    """Return True when ``expiration_date`` (date, datetime or 'YYYY-MM-DD' string) is before today."""
    if not expiration_date:
        return False
    expiry = expiration_date
    if isinstance(expiry, str):
        expiry = datetime.strptime(expiry, '%Y-%m-%d').date()
    elif isinstance(expiry, datetime):
        # datetime subclasses date but cannot be ordered against a plain date.
        expiry = expiry.date()
    return expiry < date.today()


def _search_fetch_batches(cursor, product_id, variation_id, warehouse_id):
    """
    Load in-stock, warehouse-assigned batch rows for a product or one of its variations.

    Args:
        cursor: Dictionary cursor on the caller's connection.
        product_id: Product whose batches are loaded.
        variation_id: Variation id, or None for the product's own (non-variation) batches.
        warehouse_id: Optional int; when not None only that warehouse's stock is returned.

    Returns:
        List of ungrouped batch dicts (the shape ``group_batches`` consumes, plus an
        ``is_expired`` flag), ordered by warehouse and then oldest ``created_on`` first.
    """
    sql = """
        SELECT
            pb.batch_id,
            pb.batch_number,
            pb.cost AS product_cost,
            pb.price AS product_price,
            pb.expiration_date,
            pb.created_on,
            ws.warehouse_id,
            w.warehouse_name,
            ws.quantity AS warehouse_remaining_qty
        FROM product_batches pb
        INNER JOIN warehouse_stock ws ON ws.batch_id = pb.batch_id
        LEFT JOIN warehouses w ON ws.warehouse_id = w.id
        WHERE pb.product_id = %s
            AND ws.quantity > 0
            AND ws.warehouse_id IS NOT NULL
    """
    params = [product_id]

    if variation_id is None:
        sql += " AND pb.variation_id IS NULL"
    else:
        sql += " AND pb.variation_id = %s"
        params.append(variation_id)

    if warehouse_id is not None:
        sql += " AND ws.warehouse_id = %s"
        params.append(warehouse_id)

    sql += " ORDER BY ws.warehouse_id, pb.created_on ASC"

    cursor.execute(sql, params)

    batches = []
    for row in cursor.fetchall():
        quantity = safe_float(row['warehouse_remaining_qty'])

        # ✅ FIX #2 / FIX #4: the SQL already filters these; kept as a defensive guard.
        if quantity <= 0:
            continue
        if row['warehouse_id'] is None:
            print(f"⚠️ Skipping batch {row['batch_id']} - no warehouse_id")
            continue

        is_expired = _search_is_expired(row['expiration_date'])
        if is_expired:
            print(f"⚠️ Batch {row['batch_number']} expired on {row['expiration_date']}")

        batches.append({
            'batch_id': row['batch_id'],
            'batch_number': row['batch_number'],
            'product_cost': safe_float(row['product_cost']),
            'product_price': safe_float(row['product_price']),
            'remaining_quantity': quantity,
            'expiration_date': row['expiration_date'],
            'created_on': row['created_on'],
            'warehouse_id': row['warehouse_id'],
            'warehouse_name': row['warehouse_name'] or 'Unknown Warehouse',
            'is_expired': is_expired,
        })
    return batches


def _search_format_batches(raw_batches, group):
    """
    Convert ungrouped batch dicts into the JSON-ready response shape.

    When ``group`` is true the batches are merged with ``group_batches``; otherwise
    each batch is emitted alone with single-element ``batch_ids``/``batch_numbers``.
    Both shapes carry the same keys, including ``is_expired``.
    """
    if not (group and raw_batches):
        return [{
            'batch_id': b['batch_id'],
            'batch_ids': [b['batch_id']],
            'batch_number': b['batch_number'],
            'batch_numbers': [b['batch_number']],
            'batch_count': 1,
            'product_cost': b['product_cost'],
            'product_price': b['product_price'],
            'remaining_quantity': b['remaining_quantity'],
            'expiration_date': _search_iso(b['expiration_date']),
            'created_on': _search_iso(b['created_on']),
            'warehouse_id': b['warehouse_id'],
            'warehouse_name': b['warehouse_name'],
            'is_expired': b.get('is_expired', False),
        } for b in raw_batches]

    formatted = []
    for entry in group_batches(raw_batches):
        # ✅ FIX #1: de-duplicate while keeping input (FIFO) order, so the primary
        # batch_id is the oldest batch and matches batch_numbers[0]. A set would
        # scramble the order.
        batch_ids = list(dict.fromkeys(entry['batch_ids']))
        formatted.append({
            'batch_id': batch_ids[0],
            'batch_ids': batch_ids,
            'batch_number': entry['batch_numbers'][0],
            'batch_numbers': entry['batch_numbers'],
            'batch_count': entry['batch_count'],
            'product_cost': entry['product_cost'],
            'product_price': entry['product_price'],
            'remaining_quantity': entry['remaining_quantity'],
            'expiration_date': _search_iso(entry['expiration_date']),
            'created_on': _search_iso(entry['created_on']),
            'warehouse_id': entry['warehouse_id'],
            'warehouse_name': entry['warehouse_name'],
            # Expiry is part of the grouping key, so every member shares this flag.
            'is_expired': _search_is_expired(entry['expiration_date']),
        })
    return formatted


def _search_stock_total(cursor, product_id, variation_id, warehouse_id):
    """Sum warehouse_stock quantity for a product (variation_id None) or one variation, optionally per warehouse."""
    sql = """
        SELECT COALESCE(SUM(ws.quantity), 0) AS total_stock
        FROM warehouse_stock ws
        INNER JOIN product_batches pb ON ws.batch_id = pb.batch_id
        WHERE pb.product_id = %s
        AND ws.warehouse_id IS NOT NULL
    """
    params = [product_id]

    if variation_id is None:
        sql += " AND pb.variation_id IS NULL"
    else:
        sql += " AND pb.variation_id = %s"
        params.append(variation_id)

    if warehouse_id is not None:
        sql += " AND ws.warehouse_id = %s"
        params.append(warehouse_id)

    cursor.execute(sql, params)
    row = cursor.fetchone()
    return safe_float(row['total_stock']) if row else 0.0


@product_bp.route('/search_products', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager', 'cashier')
def search_products():
    """
    Search products and return them with warehouse-specific batch information.

    Query parameters:
        query (required): Text matched (case-insensitive, substring) against product
            name/SKU, variation name/SKU, brand name and category name.
        group_batches (optional, default 'true'): When 'true', batches sharing
            warehouse/price/cost/expiry are merged into one entry.
        warehouse_id (optional): Integer; restricts batches and stock totals to
            that warehouse.

    Returns:
        200 with a JSON list of products. Single products carry ``batches`` and
        ``total_stock``; variable products carry ``variations`` (each with its own
        ``batches`` and ``total_stock``) and a ``total_stock`` summed over them.
        400 when ``query`` is missing or ``warehouse_id`` is not an integer.
        500 on database/connection errors.
    """
    query = request.args.get('query')
    group_batches_param = request.args.get('group_batches', 'true').lower() == 'true'
    raw_warehouse_id = request.args.get('warehouse_id')  # ✅ Optional warehouse filter

    if not query:
        return jsonify({"status": "error", "message": "Query parameter is required"}), 400

    # Validate the filter once, up front, instead of failing with a 500 from int()
    # deep inside the batch loop.
    warehouse_id = None
    if raw_warehouse_id:
        try:
            warehouse_id = int(raw_warehouse_id)
        except ValueError:
            return jsonify({"status": "error", "message": "warehouse_id must be an integer"}), 400

    conn = get_db_connection()
    if conn is None:
        return jsonify({"status": "error", "message": "Database connection failed"}), 500

    cursor = None
    try:
        cursor = conn.cursor(dictionary=True)
        like_query = f"%{query}%"

        print(f"\n🔍 Searching products: '{query}'")
        if warehouse_id is not None:
            print(f"   Filtering by warehouse: {warehouse_id}")

        # One row per (product, variation); products without variations yield a
        # single row with NULL variation columns.
        cursor.execute("""
            SELECT 
                p.id AS product_id,
                p.product_name,
                p.sku AS product_sku,
                p.product_type,
                p.tax_type,
                p.product_tax,
                p.stock_alert,
                p.expiration_date,
                p.note,
                
                b.id AS brand_id,
                b.brand_name,
                
                c.id AS category_id,
                c.category_name,
                
                bu.id AS base_unit_id,
                bu.unit_name AS unit_name,
                bu.unit_short AS unit_short,
                base_u.base_unit AS base_unit_name,
                
                su.id AS sale_unit_id,
                su.unit_name AS sale_unit_name,
                su.unit_short AS sale_unit_short,
                
                pu.id AS purchase_unit_id,
                pu.unit_name AS purchase_unit_name,
                pu.unit_short AS purchase_unit_short,
                
                pv.id AS variation_id,
                pv.variation_name,
                pv.variation_type,
                pv.variation_sku,
                pv.variation_cost,
                pv.variation_price,
                pv.variation_tax_type,
                pv.variation_tax,
                pv.variation_stock_alert,
                pv.expiration_date AS variation_expiration_date
                
            FROM products p
            LEFT JOIN brands b ON p.brand_id = b.id
            LEFT JOIN categories c ON p.category_id = c.id
            LEFT JOIN units bu ON p.base_unit_id = bu.id
            LEFT JOIN base_units base_u ON bu.base_unit_id = base_u.id
            LEFT JOIN units su ON p.sale_unit_id = su.id
            LEFT JOIN units pu ON p.purchase_unit_id = pu.id
            LEFT JOIN product_variations pv ON pv.product_id = p.id
            
            WHERE 
                LOWER(p.product_name) LIKE LOWER(%s) OR 
                LOWER(p.sku) LIKE LOWER(%s) OR
                LOWER(pv.variation_name) LIKE LOWER(%s) OR 
                LOWER(pv.variation_sku) LIKE LOWER(%s) OR
                LOWER(b.brand_name) LIKE LOWER(%s) OR
                LOWER(c.category_name) LIKE LOWER(%s)
            
            GROUP BY 
                p.id, pv.id
            
            ORDER BY 
                CASE
                    WHEN LOWER(p.product_name) = LOWER(%s) THEN 1
                    WHEN LOWER(p.sku) = LOWER(%s) THEN 2
                    WHEN LOWER(pv.variation_sku) = LOWER(%s) THEN 3
                    WHEN LOWER(p.product_name) LIKE LOWER(%s) THEN 4
                    ELSE 5
                END,
                p.product_name ASC,
                pv.variation_name ASC
        """, (like_query, like_query, like_query, like_query, like_query, like_query,
              query, query, query, like_query))

        rows = cursor.fetchall()
        products = {}

        print(f"   Found {len(rows)} product rows")

        for row in rows:
            pid = row['product_id']

            if pid not in products:
                batches = []
                total_stock = 0.0

                # Only single products hold batches directly; variable products
                # keep their stock on the variations below.
                if row['product_type'] == 'single':
                    raw_batches = _search_fetch_batches(cursor, pid, None, warehouse_id)
                    batches = _search_format_batches(raw_batches, group_batches_param)
                    total_stock = _search_stock_total(cursor, pid, None, warehouse_id)

                # ✅ FIX #5: Safe display of product info
                products[pid] = {
                    "product_id": pid,
                    "product_name": row['product_name'] or 'Unnamed Product',
                    "sku": row['product_sku'] or 'No SKU',
                    "product_type": row['product_type'],
                    "tax_type": row['tax_type'],
                    "product_tax": safe_float(row['product_tax']),
                    "stock_alert": row['stock_alert'],
                    "expiration_date": _search_iso(row['expiration_date']),
                    "note": row['note'],

                    "brand_id": row['brand_id'],
                    "brand_name": row['brand_name'],
                    "category_id": row['category_id'],
                    "category_name": row['category_name'],

                    "base_unit": {
                        "id": row['base_unit_id'],
                        "name": row['unit_name'],
                        "short": row['unit_short'],
                        "base_unit_name": row['base_unit_name']
                    },
                    "sale_unit": {
                        "id": row['sale_unit_id'],
                        "name": row['sale_unit_name'],
                        "short": row['sale_unit_short']
                    },
                    "purchase_unit": {
                        "id": row['purchase_unit_id'],
                        "name": row['purchase_unit_name'],
                        "short": row['purchase_unit_short']
                    },

                    "total_stock": total_stock,
                    "batches": batches,
                    "variations": []
                }

            variation_id = row['variation_id']
            if variation_id:
                raw_var_batches = _search_fetch_batches(cursor, pid, variation_id, warehouse_id)

                products[pid]['variations'].append({
                    "variation_id": variation_id,
                    "variation_name": row['variation_name'],
                    "variation_type": row['variation_type'],
                    "variation_sku": row['variation_sku'],
                    "product_cost": safe_float(row['variation_cost']),
                    "product_price": safe_float(row['variation_price']),
                    "variation_tax_type": row['variation_tax_type'] or 'exclusive',
                    "variation_tax": safe_float(row['variation_tax']),
                    "variation_stock_alert": row['variation_stock_alert'],
                    "expiration_date": _search_iso(row['variation_expiration_date']),
                    "total_stock": _search_stock_total(cursor, pid, variation_id, warehouse_id),
                    "batches": _search_format_batches(raw_var_batches, group_batches_param)
                })

        # A variable product's stock is the sum of its variations' stock.
        for product in products.values():
            if product['product_type'] == 'variable' and product['variations']:
                product['total_stock'] = sum(v['total_stock'] for v in product['variations'])

        print(f"✅ Returning {len(products)} unique products")

        return jsonify(list(products.values())), 200

    except Exception as e:
        print("❌ Search Error:", e)
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500

    finally:
        if cursor is not None:
            cursor.close()
        conn.close()

# ✅ GET PRODUCT'S ADDED VARIATIONS WITH REMAINING TYPES TO ADD
@product_bp.route('/product/<int:product_id>/variations-with-types', methods=['GET'])
@jwt_required()
def get_product_variations_with_types(product_id):
    """
    List the variations already attached to a product, with their types split into added vs. still available.

    For every distinct ``variation_name`` present in ``product_variations`` for the
    product, the matching ``variations`` row is looked up by name and all of its
    ``variation_types`` are returned, partitioned by whether the
    (variation_name, type_name) pair already exists on the product.

    Args:
        product_id: Product primary key from the URL.

    Returns:
        200 with ``{"status": True, "product": {...}, "variations": [...]}``, where each
        variation is ``{"id", "name", "added_types", "available_types"}``. When the
        product has no variations, ``variations`` is empty and a ``message`` is included.
        404 when the product does not exist; 500 on database or server errors.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if conn is None:
            return jsonify({"error": "Database connection failed"}), 500
        cursor = conn.cursor(dictionary=True)

        # 1️⃣ Check if product exists
        cursor.execute("""
            SELECT id, product_name 
            FROM products 
            WHERE id = %s
        """, (product_id,))
        product = cursor.fetchone()

        if not product:
            return jsonify({"error": "Product not found"}), 404

        # 2️⃣ Get variation names that are already added to this product
        cursor.execute("""
            SELECT DISTINCT variation_name
            FROM product_variations
            WHERE product_id = %s
        """, (product_id,))

        added_variation_names = [row['variation_name'] for row in cursor.fetchall()]

        if not added_variation_names:
            return jsonify({
                "status": True,
                "product": product,
                "variations": [],
                "message": "No variations added to this product yet"
            }), 200

        # 3️⃣ Get (name, type) pairs already added to this product
        cursor.execute("""
            SELECT variation_name, variation_type
            FROM product_variations
            WHERE product_id = %s
        """, (product_id,))

        added_combinations = {
            (row['variation_name'], row['variation_type']) for row in cursor.fetchall()
        }

        # 4️⃣ For each added variation name, get ALL available types
        variations_list = []

        for var_name in added_variation_names:
            # LIMIT 1: with an unbuffered mysql.connector cursor, leaving extra rows
            # unread after fetchone() makes the next execute() fail with
            # "Unread result found" if variation names are not unique.
            cursor.execute("""
                SELECT id 
                FROM variations 
                WHERE name = %s
                ORDER BY id
                LIMIT 1
            """, (var_name,))
            var_result = cursor.fetchone()

            if not var_result:
                continue

            variation_id = var_result['id']

            cursor.execute("""
                SELECT id, type_name, sort_order
                FROM variation_types
                WHERE variation_id = %s
                ORDER BY sort_order
            """, (variation_id,))

            added_types_list = []
            not_added_types_list = []

            for type_row in cursor.fetchall():
                type_data = {
                    "id": type_row['id'],
                    "type_name": type_row['type_name']
                }

                if (var_name, type_row['type_name']) in added_combinations:
                    added_types_list.append(type_data)
                else:
                    not_added_types_list.append(type_data)

            variations_list.append({
                "id": variation_id,
                "name": var_name,
                "added_types": added_types_list,
                "available_types": not_added_types_list
            })

        # 5️⃣ Return final response
        return jsonify({
            "status": True,
            "product": product,
            "variations": variations_list
        }), 200

    except mysql.connector.Error as err:
        print("Database Error:", err)
        traceback.print_exc()
        return jsonify({"error": f"Database error: {str(err)}"}), 500

    except Exception as e:
        print("Server Error:", e)
        traceback.print_exc()
        return jsonify({"error": f"Server error: {str(e)}"}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


@product_bp.route('/product/<int:product_id>/add-variation', methods=['POST'])
@jwt_required()
@role_required('admin')
def add_variation_to_product(product_id):
    """
    Add a new variation to a variable product.
    Creates ONLY variation record - NO warehouse stock or batches.

    Stock is added separately through purchase orders or stock adjustments.

    JSON body (must be an object):
        Required: variation_name, variation_type, variation_sku,
                  variation_cost, variation_price, variation_tax_type
        Optional: variation_tax (default 0.0), variation_stock_alert (default 0),
                  expiration_date (default NULL)

    Responses:
        201 - variation created; body echoes the stored values.
        400 - body missing / not a JSON object, required field missing,
              non-numeric cost/price/tax/stock alert, product is not
              'variable', or the name/type combination or SKU already exists
              for this product.
        404 - product not found.
        500 - database connection failure or unexpected error (rolled back).
    """
    # silent=True: a missing or malformed body yields None instead of raising,
    # so every bad-payload case is reported with the same 400 JSON error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    # Validate required fields
    required_fields = [
        'variation_name',
        'variation_type',
        'variation_sku',
        'variation_cost',
        'variation_price',
        'variation_tax_type'
    ]

    # None / '' count as missing; a legitimate numeric 0 (e.g. zero cost) does not.
    missing = validate_required_fields(data, required_fields)
    if missing:
        return jsonify({'error': f"Missing required fields: {', '.join(missing)}"}), 400

    # Parse numeric inputs before touching the DB so bad values are client
    # errors (400) rather than a 500 from inside the transaction.
    try:
        variation_cost = float(data['variation_cost'])
        variation_price = float(data['variation_price'])
        # `or` also maps an explicit null / '' to the default instead of crashing.
        variation_tax = float(data.get('variation_tax') or 0.0)
        variation_stock_alert = int(data.get('variation_stock_alert') or 0)
    except (TypeError, ValueError):
        return jsonify({
            'error': 'variation_cost, variation_price and variation_tax must be numbers; '
                     'variation_stock_alert must be an integer'
        }), 400

    expiration_date = data.get('expiration_date') or None

    conn = get_db_connection()
    if conn is None:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        # Check if product exists and is variable type
        cursor.execute("""
            SELECT id, product_type, product_name 
            FROM products 
            WHERE id = %s
        """, (product_id,))
        product = cursor.fetchone()

        if not product:
            return jsonify({'error': 'Product not found'}), 404

        if product['product_type'] != 'variable':
            return jsonify({'error': 'Cannot add variations to non-variable products'}), 400

        # Check if this exact variation combination already exists
        cursor.execute("""
            SELECT id FROM product_variations 
            WHERE product_id = %s 
            AND variation_name = %s 
            AND variation_type = %s
        """, (product_id, data['variation_name'], data['variation_type']))

        if cursor.fetchone():
            return jsonify({
                'error': f"Variation '{data['variation_name']} - {data['variation_type']}' already exists for this product"
            }), 400

        # Check if variation SKU already exists for this product
        cursor.execute("""
            SELECT id FROM product_variations 
            WHERE product_id = %s AND variation_sku = %s
        """, (product_id, data['variation_sku']))

        if cursor.fetchone():
            return jsonify({
                'error': f"Variation SKU '{data['variation_sku']}' already exists for this product"
            }), 400

        # Insert the variation (ONLY variation record, no stock)
        cursor.execute("""
            INSERT INTO product_variations (
                product_id, 
                variation_name, 
                variation_type,
                variation_sku,
                variation_cost, 
                variation_price,
                variation_tax_type, 
                variation_tax, 
                variation_stock_alert, 
                expiration_date,
                created_at,
                updated_at
            ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), NOW())
        """, (
            product_id,
            data['variation_name'],
            data['variation_type'],
            data['variation_sku'],
            variation_cost,
            variation_price,
            data['variation_tax_type'],
            variation_tax,
            variation_stock_alert,
            expiration_date
        ))

        variation_id = cursor.lastrowid

        conn.commit()

        return jsonify({
            'success': True,
            'message': f"Variation '{data['variation_name']} - {data['variation_type']}' added successfully!",
            'data': {
                'variation_id': variation_id,
                'product_id': product_id,
                'product_name': product['product_name'],
                'variation_name': data['variation_name'],
                'variation_type': data['variation_type'],
                'variation_sku': data['variation_sku'],
                'variation_cost': variation_cost,
                'variation_price': variation_price,
                'variation_tax_type': data['variation_tax_type'],
                'variation_tax': variation_tax,
                'variation_stock_alert': variation_stock_alert,
                'expiration_date': expiration_date,
                'initial_stock': 0,
                'note': 'Add stock through Purchase Orders or Stock Adjustments'
            }
        }), 201

    except mysql.connector.Error as err:
        conn.rollback()
        print(f"Database Error adding variation: {err}")
        traceback.print_exc()
        return jsonify({'error': f'Database error: {str(err)}'}), 500

    except Exception as err:
        conn.rollback()
        print(f"Error adding variation: {err}")
        traceback.print_exc()
        return jsonify({'error': 'Failed to add variation'}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()

@product_bp.route('/delete_variations/<int:variation_id>', methods=['DELETE'])
@jwt_required()
@role_required('admin')
def delete_variations(variation_id):
    """
    Permanently delete one product variation together with its dependent rows.

    Guards (each returns early, before anything is deleted):
      * 404 when no product_variations row has this id.
      * 400 when warehouse_stock still sums to a positive quantity for it.
      * 400 when it appears on purchase orders in 'pending' or 'ordered' status.

    Deletions, performed in this order inside a single transaction:
      1. warehouse_stock rows for (product_id, variation_id)
      2. product_batches rows for (product_id, variation_id)
      3. order_items rows for the variation - skipped entirely when any of
         them belong to a 'received'/'completed' purchase order (history kept)
      4. the product_variations row itself

    The 200 response reports the per-table delete counts, how many variations
    the product still has, and a 'warning' entry when the last variation of a
    'variable' product was removed.

    Note: This is a destructive operation. Consider soft delete for production.
    """
    conn = get_db_connection()
    if conn is None:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    def _single_value(sql, params, column):
        # Run a one-row aggregate query and read `column`; no row counts as 0.
        cursor.execute(sql, params)
        hit = cursor.fetchone()
        return hit[column] if hit else 0

    try:
        # Locate the variation and the name of the product that owns it.
        cursor.execute("""
            SELECT pv.id, pv.product_id, pv.variation_name, pv.variation_type,
                   pv.variation_sku, p.product_name
            FROM product_variations pv
            JOIN products p ON pv.product_id = p.id
            WHERE pv.id = %s
        """, (variation_id,))
        target = cursor.fetchone()
        if not target:
            return jsonify({'error': 'Variation not found'}), 404

        owner_id = target['product_id']
        label = f"{target['variation_name']} - {target['variation_type']}"

        # Guard: any stock left in any warehouse blocks the delete.
        on_hand = float(_single_value("""
            SELECT IFNULL(SUM(quantity), 0) as total_stock
            FROM warehouse_stock
            WHERE product_id = %s AND variation_id = %s
        """, (owner_id, variation_id), 'total_stock'))
        if on_hand > 0:
            return jsonify({
                'error': f'Cannot delete variation with existing stock. Current stock: {on_hand}. Please remove stock first through stock adjustments.'
            }), 400

        # Guard: still referenced by open purchase orders.
        open_orders = _single_value("""
            SELECT COUNT(*) as order_count
            FROM order_items oi
            JOIN purchase_orders po ON oi.order_id = po.order_id
            WHERE oi.variation_id = %s AND po.status IN ('pending', 'ordered')
        """, (variation_id,), 'order_count')
        if open_orders > 0:
            return jsonify({
                'error': f'Cannot delete variation used in {open_orders} pending purchase order(s). Complete, receive, or cancel orders first.'
            }), 400

        # Remove stock rows (all zero-quantity by now), then batches.
        cursor.execute(
            "DELETE FROM warehouse_stock WHERE product_id = %s AND variation_id = %s",
            (owner_id, variation_id),
        )
        stock_rows = cursor.rowcount
        print(f"✅ Deleted {stock_rows} warehouse_stock records")

        cursor.execute(
            "DELETE FROM product_batches WHERE product_id = %s AND variation_id = %s",
            (owner_id, variation_id),
        )
        batch_rows = cursor.rowcount
        print(f"✅ Deleted {batch_rows} product_batches records")

        # order_items survive when a received/completed order references them.
        closed_orders = _single_value("""
            SELECT COUNT(*) as completed_order_count
            FROM order_items oi
            JOIN purchase_orders po ON oi.order_id = po.order_id
            WHERE oi.variation_id = %s AND po.status IN ('received', 'completed')
        """, (variation_id,), 'completed_order_count')

        if closed_orders <= 0:
            cursor.execute("DELETE FROM order_items WHERE variation_id = %s", (variation_id,))
            item_rows = cursor.rowcount
            print(f"✅ Deleted {item_rows} order_items records")
        else:
            print(f"⚠️ Keeping {closed_orders} order_items in completed orders for history")
            item_rows = 0

        # Finally drop the variation row itself.
        cursor.execute("DELETE FROM product_variations WHERE id = %s", (variation_id,))
        if cursor.rowcount == 0:
            conn.rollback()
            return jsonify({'error': 'Failed to delete variation'}), 500

        left_over = _single_value("""
            SELECT COUNT(*) as remaining_count
            FROM product_variations
            WHERE product_id = %s
        """, (owner_id,), 'remaining_count')

        # Warn when a 'variable' product has just lost its last variation.
        warning_text = None
        if left_over == 0:
            cursor.execute("SELECT product_type FROM products WHERE id = %s", (owner_id,))
            parent = cursor.fetchone()
            if parent and parent['product_type'] == 'variable':
                warning_text = "⚠️ Warning: This was the last variation. Consider converting product to 'single' type or adding new variations."

        conn.commit()

        payload = {
            'success': True,
            'message': f"✅ Variation '{label}' deleted successfully!",
            'data': {
                'deleted_variation_id': variation_id,
                'product_id': owner_id,
                'product_name': target['product_name'],
                'variation_info': label,
                'variation_sku': target['variation_sku'],
                'remaining_variations': left_over,
                'deleted_records': {
                    'batches': batch_rows,
                    'warehouse_stock': stock_rows,
                    'order_items': item_rows
                },
                'completed_orders_kept': closed_orders
            }
        }
        if warning_text:
            payload['warning'] = warning_text

        return jsonify(payload), 200

    except mysql.connector.Error as err:
        conn.rollback()
        print(f"❌ Database Error deleting variation: {err}")
        traceback.print_exc()
        return jsonify({'error': f'Database error: {str(err)}'}), 500

    except Exception as e:
        conn.rollback()
        print(f"❌ Error deleting variation: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Failed to delete variation: {str(e)}'}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
            
            
@product_bp.route('/search_vproducts', methods=['GET'])
@jwt_required()
@role_required('admin', 'user', 'doctor')
def search_vproducts():
    """
    Search products / variations that currently have stock from a received purchase order.

    Query params:
        query (required): substring matched case-insensitively against the
            product name, product SKU, variation name and variation SKU.

    Returns:
        200 with a JSON list (at most 100 entries, exact matches first, then by
            product name). A product with variations yields one entry per
            matching variation ("product_type": "variation"); a product without
            variations yields one entry ("product_type": "single").
        400 when `query` is missing or empty.
        500 on database connection failure or query error.

    An entry is included only if a product_batches row with remaining_quantity > 0
    exists for it on a purchase order whose status is 'Received'. For a variation
    entry that batch must belong to that variation itself.
    """
    query = request.args.get('query')
    if not query:
        return jsonify({"status": "error", "message": "Query parameter is required"}), 400

    conn = get_db_connection()
    if conn is None:
        return jsonify({"status": "error", "message": "Database connection failed"}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        like_query = f"%{query}%"

        # Stock filter: batches always carry product_id, and variation batches
        # also carry variation_id. A variation row therefore needs a batch of
        # its OWN (previously an OR let a zero-stock variation through whenever
        # a sibling variation had stock). A product row without variations
        # (pv.id IS NULL) matches any batch of the product.
        cursor.execute("""
            SELECT
                p.id AS product_id,
                p.product_name,
                p.sku AS product_sku,
                p.product_type,
                p.product_quantity,
                p.product_price,
                p.product_cost,
                p.tax_type,
                p.product_tax,
                p.sales_unit,
                pv.id AS variation_id,
                pv.variation_name,
                pv.variation_sku,
                pv.variation_quantity,
                pv.variation_price,
                pv.variation_cost,
                pv.variation_tax_type,
                pv.variation_tax
            FROM products p
            LEFT JOIN product_variations pv ON pv.product_id = p.id
            WHERE
                (
                    LOWER(p.product_name) LIKE LOWER(%s) OR
                    LOWER(p.sku) LIKE LOWER(%s) OR
                    LOWER(pv.variation_name) LIKE LOWER(%s) OR
                    LOWER(pv.variation_sku) LIKE LOWER(%s)
                )
                AND EXISTS (
                    SELECT 1
                    FROM product_batches b
                    JOIN purchase_orders po ON po.order_id = b.purchase_order_id
                    WHERE
                        b.product_id = p.id
                        AND (pv.id IS NULL OR b.variation_id = pv.id)
                        AND po.status = 'Received'
                        AND b.remaining_quantity > 0
                )
            ORDER BY 
                CASE
                    WHEN LOWER(p.product_name) = LOWER(%s) THEN 1
                    WHEN LOWER(p.sku) = LOWER(%s) THEN 1
                    WHEN LOWER(pv.variation_name) = LOWER(%s) THEN 1
                    WHEN LOWER(pv.variation_sku) = LOWER(%s) THEN 1
                    ELSE 2
                END,
                p.product_name ASC
            LIMIT 100
        """, (like_query, like_query, like_query, like_query,
              query, query, query, query))

        rows = cursor.fetchall()
        results = []

        for row in rows:
            # safe_float: a NULL price/cost/tax column becomes 0.0 instead of
            # raising TypeError and failing the whole search.
            if row['variation_id']:
                results.append({
                    "product_id": row['product_id'],
                    "variation_id": row['variation_id'],
                    "product_name": row['product_name'],
                    "variation_name": row['variation_name'],
                    "sku": row['variation_sku'],
                    "display_name": (
                        f"{row['product_name']} - {row['variation_name']} ({row['variation_sku']})"
                        if row['variation_name'] else
                        f"{row['product_name']} ({row['product_sku']})"
                    ),
                    "product_type": "variation",
                    "product_quantity": row['variation_quantity'],
                    "product_price": safe_float(row['variation_price']),
                    "product_cost": safe_float(row['variation_cost']),
                    "tax_type": row.get('variation_tax_type'),
                    "product_tax": safe_float(row.get('variation_tax')),
                    "sales_unit": row['sales_unit']
                })
            else:
                results.append({
                    "product_id": row['product_id'],
                    "variation_id": None,
                    "product_name": row['product_name'],
                    "variation_name": None,
                    "sku": row['product_sku'],
                    "display_name": f"{row['product_name']} ({row['product_sku']})",
                    "product_type": "single",
                    "product_quantity": row['product_quantity'],
                    "product_price": safe_float(row['product_price']),
                    "product_cost": safe_float(row['product_cost']),
                    "tax_type": row['tax_type'],
                    "product_tax": safe_float(row['product_tax']),
                    "sales_unit": row['sales_unit']
                })

        return jsonify(results), 200

    except Exception as e:
        print("Error:", e)
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500

    finally:
        cursor.close()
        conn.close()

# @product_bp.route('/search_sales_products', methods=['GET'])
# @jwt_required()
# @role_required('admin', 'cashier')
# def search_sales_products():
#     """
#     Search products and variations with warehouse-specific batch-based pricing and stock.
#     Returns products available in the selected warehouse only.
#     ✅ GROUPS batches by cost, price, and our_price - shows combined stock with latest expiration date
#     ✅ INCLUDES our_price from product_batches table
#     ✅ INCLUDES GRN information for traceability
#     ✅ FIXED: Shows decimal quantities properly (0.25, 0.5, etc.)
#     ✅ FIXED: Variation products now searchable by product name
#     """
#     # Accept multiple possible parameters
#     query = request.args.get('query', '').strip()
#     productname = request.args.get('productname', '').strip()
#     variation = request.args.get('variation', '').strip()
#     warehouse_id = request.args.get('warehouse_id', '').strip()
#     store_id = request.args.get('store_id', '').strip()

#     # Use whichever parameter is provided
#     search_term = query or productname or variation

#     if not search_term:
#         return jsonify({"status": "error", "message": "Query parameter is required"}), 400

#     # Detect combined format like "Product - Variation"
#     main_name = None
#     variation_name = None
#     if " - " in search_term:
#         parts = search_term.split(" - ", 1)
#         main_name = parts[0].strip()
#         variation_name = parts[1].strip()
#     else:
#         main_name = search_term
#         variation_name = None  # ✅ FIXED: Don't assume variation_name = main_name

#     conn = get_db_connection()
#     if conn is None:
#         return jsonify({'error': 'Database connection failed'}), 500
        
#     cursor = conn.cursor(dictionary=True)
#     result = []

#     try:
#         # ============================================
#         # SEARCH PRODUCTS WITH UNIT INFORMATION
#         # ============================================
#         cursor.execute("""
#             SELECT 
#                 p.id,
#                 p.product_name,
#                 p.sku,
#                 p.product_type,
#                 p.tax_type,
#                 p.product_tax,
#                 p.base_unit_id,
#                 p.sale_unit_id,
#                 p.purchase_unit_id,
#                 u1.unit_name AS base_unit_name,
#                 u1.unit_short AS base_unit_short,
#                 u2.unit_name AS sale_unit_name,
#                 u2.unit_short AS sale_unit_short,
#                 u3.unit_name AS purchase_unit_name,
#                 u3.unit_short AS purchase_unit_short
#             FROM products p
#             LEFT JOIN units u1 ON p.base_unit_id = u1.id
#             LEFT JOIN units u2 ON p.sale_unit_id = u2.id
#             LEFT JOIN units u3 ON p.purchase_unit_id = u3.id
#             WHERE p.product_name LIKE %s OR p.sku LIKE %s
#         """, (f"%{main_name}%", f"%{main_name}%"))
#         products = cursor.fetchall()

#         for product in products:
#             # ============================================
#             # VARIABLE PRODUCTS (with variations)
#             # ============================================
#             if product['product_type'] == 'variable':
#                 # ✅ FIXED: If variation_name is provided, filter by it; otherwise show ALL variations
#                 if variation_name:
#                     # User searched for specific variation (e.g., "MILK - 1L")
#                     cursor.execute("""
#                         SELECT 
#                             pv.id,
#                             pv.variation_name,
#                             pv.variation_type,
#                             pv.variation_sku,
#                             pv.variation_cost,
#                             pv.variation_price,
#                             pv.variation_tax_type,
#                             pv.variation_tax
#                         FROM product_variations pv
#                         WHERE pv.product_id = %s
#                         AND (pv.variation_name LIKE %s 
#                              OR pv.variation_type LIKE %s 
#                              OR pv.variation_sku LIKE %s)
#                     """, (product['id'], f"%{variation_name}%", f"%{variation_name}%", f"%{variation_name}%"))
#                 else:
#                     # User searched only by product name (e.g., "MILK") - show ALL variations
#                     cursor.execute("""
#                         SELECT 
#                             pv.id,
#                             pv.variation_name,
#                             pv.variation_type,
#                             pv.variation_sku,
#                             pv.variation_cost,
#                             pv.variation_price,
#                             pv.variation_tax_type,
#                             pv.variation_tax
#                         FROM product_variations pv
#                         WHERE pv.product_id = %s
#                     """, (product['id'],))
                
#                 variations = cursor.fetchall()

#                 for var in variations:
#                     # ✅ GET BATCHES GROUPED BY COST, PRICE, AND OUR_PRICE - WITH GRN INFO
#                     if warehouse_id and store_id:
#                         # Filter by both warehouse and store
#                         cursor.execute("""
#                             SELECT 
#                                 pb.price,
#                                 pb.cost,
#                                 pb.our_price,
#                                 MAX(pb.expiration_date) AS latest_expiration_date,
#                                 SUM(ws.quantity) AS total_stock,
#                                 GROUP_CONCAT(DISTINCT pb.batch_id ORDER BY pb.batch_id SEPARATOR ',') AS batch_ids,
#                                 GROUP_CONCAT(DISTINCT CONCAT(g.grn_code, ' (', DATE_FORMAT(g.grn_date, '%Y-%m-%d'), ')') SEPARATOR ', ') AS grn_info
#                             FROM product_batches pb
#                             INNER JOIN warehouse_stock ws ON pb.batch_id = ws.batch_id
#                             LEFT JOIN grn g ON pb.grn_id = g.grn_id
#                             WHERE pb.product_id = %s 
#                             AND pb.variation_id = %s 
#                             AND pb.remaining_quantity > 0
#                             AND ws.warehouse_id = %s
#                             AND ws.store_id = %s
#                             AND ws.quantity > 0
#                             GROUP BY pb.price, pb.cost, pb.our_price
#                             ORDER BY pb.price ASC, pb.cost ASC
#                         """, (product['id'], var['id'], warehouse_id, store_id))
#                     elif warehouse_id:
#                         # Filter by warehouse only
#                         cursor.execute("""
#                             SELECT 
#                                 pb.price,
#                                 pb.cost,
#                                 pb.our_price,
#                                 MAX(pb.expiration_date) AS latest_expiration_date,
#                                 SUM(ws.quantity) AS total_stock,
#                                 GROUP_CONCAT(DISTINCT pb.batch_id ORDER BY pb.batch_id SEPARATOR ',') AS batch_ids,
#                                 GROUP_CONCAT(DISTINCT CONCAT(g.grn_code, ' (', DATE_FORMAT(g.grn_date, '%Y-%m-%d'), ')') SEPARATOR ', ') AS grn_info
#                             FROM product_batches pb
#                             INNER JOIN warehouse_stock ws ON pb.batch_id = ws.batch_id
#                             LEFT JOIN grn g ON pb.grn_id = g.grn_id
#                             WHERE pb.product_id = %s 
#                             AND pb.variation_id = %s 
#                             AND pb.remaining_quantity > 0
#                             AND ws.warehouse_id = %s
#                             AND ws.quantity > 0
#                             GROUP BY pb.price, pb.cost, pb.our_price
#                             ORDER BY pb.price ASC, pb.cost ASC
#                         """, (product['id'], var['id'], warehouse_id))
#                     else:
#                         # No warehouse filter - show all stock grouped by price, cost, and our_price
#                         cursor.execute("""
#                             SELECT 
#                                 price,
#                                 cost,
#                                 our_price,
#                                 MAX(expiration_date) AS latest_expiration_date,
#                                 SUM(remaining_quantity) AS total_stock,
#                                 GROUP_CONCAT(DISTINCT batch_id ORDER BY batch_id SEPARATOR ',') AS batch_ids,
#                                 GROUP_CONCAT(DISTINCT CONCAT(g.grn_code, ' (', DATE_FORMAT(g.grn_date, '%Y-%m-%d'), ')') SEPARATOR ', ') AS grn_info
#                             FROM product_batches pb
#                             LEFT JOIN grn g ON pb.grn_id = g.grn_id
#                             WHERE pb.product_id = %s 
#                             AND pb.variation_id = %s 
#                             AND pb.remaining_quantity > 0
#                             GROUP BY pb.price, pb.cost, pb.our_price
#                             ORDER BY pb.price ASC, pb.cost ASC
#                         """, (product['id'], var['id']))
                    
#                     batches = cursor.fetchall()

#                     # Default values from variation
#                     default_price = float(var['variation_price']) if var['variation_price'] else 0.0
#                     default_cost = float(var['variation_cost']) if var['variation_cost'] else 0.0
#                     default_our_price = 0.0

#                     if batches:
#                         # ✅ CREATE ENTRY FOR EACH UNIQUE COST/PRICE/OUR_PRICE COMBINATION
#                         for batch_group in batches:
#                             price = float(batch_group['price']) if batch_group['price'] else default_price
#                             cost = float(batch_group['cost']) if batch_group['cost'] else default_cost
#                             our_price = float(batch_group['our_price']) if batch_group['our_price'] else default_our_price
#                             stock = float(batch_group['total_stock']) if batch_group['total_stock'] else 0.0
#                             exp_date = batch_group['latest_expiration_date'].strftime('%Y-%m-%d') if batch_group['latest_expiration_date'] else '-'
#                             batch_ids = batch_group['batch_ids']
#                             grn_info = batch_group.get('grn_info', 'N/A')  # ✅ GRN info
                            
#                             # Use variation tax if available, otherwise product tax
#                             tax_type = var.get('variation_tax_type') or product.get('tax_type')
#                             product_tax = float(var.get('variation_tax', 0)) if var.get('variation_tax') else float(product.get('product_tax', 0))
                            
#                             # ✅ FIXED: Format quantity with decimals (removes trailing zeros)
#                             stock_display = f"{stock:.2f}".rstrip('0').rstrip('.') if stock % 1 else f"{int(stock)}"
                            
#                             # ✅ Format with our_price: MILK - 1L Pack - (Cost 120 - Our Price 140 - Price 150) Stock 200 Exp: 2025-01-15
#                             display_name = f"{product['product_name']} - {var['variation_name']} - (Cost {cost:.0f} - Our Price {our_price:.0f} - Price {price:.0f}) Stock {stock_display} Exp: {exp_date}"
                            
#                             result.append({
#                                 "product_id": product['id'],
#                                 "variation_id": var['id'],
#                                 "batch_ids": batch_ids,
#                                 "grn_info": grn_info,
#                                 "product_name": product['product_name'],
#                                 "variation_name": var['variation_name'],
#                                 "variation_type": var['variation_type'],
#                                 "sku": var['variation_sku'],
#                                 "display_name": display_name,
#                                 "product_type": "variation",
#                                 "product_quantity": stock,
#                                 "product_price": price,
#                                 "product_cost": cost,
#                                 "our_price": our_price,  # ✅ NEW FIELD
#                                 "expiration_date": exp_date,
#                                 "tax_type": tax_type,
#                                 "product_tax": product_tax,
#                                 "base_unit": product.get('base_unit_name'),
#                                 "sale_unit": product.get('sale_unit_name'),
#                                 "purchase_unit": product.get('purchase_unit_name'),
#                                 "sales_unit": product.get('sale_unit_id'),
#                                 "sales_units": [
#                                     {"id": product.get('base_unit_id'), "name": product.get('base_unit_name')},
#                                     {"id": product.get('sale_unit_id'), "name": product.get('sale_unit_name')},
#                                     {"id": product.get('purchase_unit_id'), "name": product.get('purchase_unit_name')}
#                                 ]
#                             })
#                     else:
#                         # ❌ No stock in warehouse
#                         display_name = f"{product['product_name']} - {var['variation_name']} - (Cost {default_cost:.0f} - Our Price 0 - Price {default_price:.0f}) Stock 0 Exp: -"
                        
#                         result.append({
#                             "product_id": product['id'],
#                             "variation_id": var['id'],
#                             "batch_ids": None,
#                             "grn_info": None,
#                             "product_name": product['product_name'],
#                             "variation_name": var['variation_name'],
#                             "variation_type": var['variation_type'],
#                             "sku": var['variation_sku'],
#                             "display_name": display_name,
#                             "product_type": "variation",
#                             "product_quantity": 0.0,
#                             "product_price": default_price,
#                             "product_cost": default_cost,
#                             "our_price": 0.0,  # ✅ NEW FIELD
#                             "expiration_date": '-',
#                             "tax_type": var.get('variation_tax_type') or product.get('tax_type'),
#                             "product_tax": float(var.get('variation_tax', 0)) if var.get('variation_tax') else float(product.get('product_tax', 0)),
#                             "base_unit": product.get('base_unit_name'),
#                             "sale_unit": product.get('sale_unit_name'),
#                             "purchase_unit": product.get('purchase_unit_name'),
#                             "sales_unit": product.get('sale_unit_id'),
#                             "sales_units": [
#                                 {"id": product.get('base_unit_id'), "name": product.get('base_unit_name')},
#                                 {"id": product.get('sale_unit_id'), "name": product.get('sale_unit_name')},
#                                 {"id": product.get('purchase_unit_id'), "name": product.get('purchase_unit_name')}
#                             ]
#                         })
            
#             # ============================================
#             # SINGLE PRODUCTS (no variations)
#             # ============================================
#             else:
#                 # Match if search term matches product name or SKU
#                 if (main_name.lower() in product['product_name'].lower()) or \
#                    (main_name.lower() in (product['sku'] or '').lower()):
                    
#                     # ✅ GET BATCHES GROUPED BY COST, PRICE, AND OUR_PRICE - WITH GRN INFO
#                     if warehouse_id and store_id:
#                         # Filter by both warehouse and store
#                         cursor.execute("""
#                             SELECT 
#                                 pb.price,
#                                 pb.cost,
#                                 pb.our_price,
#                                 MAX(pb.expiration_date) AS latest_expiration_date,
#                                 SUM(ws.quantity) AS total_stock,
#                                 GROUP_CONCAT(DISTINCT pb.batch_id ORDER BY pb.batch_id SEPARATOR ',') AS batch_ids,
#                                 GROUP_CONCAT(DISTINCT CONCAT(g.grn_code, ' (', DATE_FORMAT(g.grn_date, '%Y-%m-%d'), ')') SEPARATOR ', ') AS grn_info
#                             FROM product_batches pb
#                             INNER JOIN warehouse_stock ws ON pb.batch_id = ws.batch_id
#                             LEFT JOIN grn g ON pb.grn_id = g.grn_id
#                             WHERE pb.product_id = %s 
#                             AND pb.variation_id IS NULL 
#                             AND pb.remaining_quantity > 0
#                             AND ws.warehouse_id = %s
#                             AND ws.store_id = %s
#                             AND ws.quantity > 0
#                             GROUP BY pb.price, pb.cost, pb.our_price
#                             ORDER BY pb.price ASC, pb.cost ASC
#                         """, (product['id'], warehouse_id, store_id))
#                     elif warehouse_id:
#                         # Filter by warehouse only
#                         cursor.execute("""
#                             SELECT 
#                                 pb.price,
#                                 pb.cost,
#                                 pb.our_price,
#                                 MAX(pb.expiration_date) AS latest_expiration_date,
#                                 SUM(ws.quantity) AS total_stock,
#                                 GROUP_CONCAT(DISTINCT pb.batch_id ORDER BY pb.batch_id SEPARATOR ',') AS batch_ids,
#                                 GROUP_CONCAT(DISTINCT CONCAT(g.grn_code, ' (', DATE_FORMAT(g.grn_date, '%Y-%m-%d'), ')') SEPARATOR ', ') AS grn_info
#                             FROM product_batches pb
#                             INNER JOIN warehouse_stock ws ON pb.batch_id = ws.batch_id
#                             LEFT JOIN grn g ON pb.grn_id = g.grn_id
#                             WHERE pb.product_id = %s 
#                             AND pb.variation_id IS NULL 
#                             AND pb.remaining_quantity > 0
#                             AND ws.warehouse_id = %s
#                             AND ws.quantity > 0
#                             GROUP BY pb.price, pb.cost, pb.our_price
#                             ORDER BY pb.price ASC, pb.cost ASC
#                         """, (product['id'], warehouse_id))
#                     else:
#                         # No warehouse filter - show all stock
#                         cursor.execute("""
#                             SELECT 
#                                 price,
#                                 cost,
#                                 our_price,
#                                 MAX(expiration_date) AS latest_expiration_date,
#                                 SUM(remaining_quantity) AS total_stock,
#                                 GROUP_CONCAT(DISTINCT batch_id ORDER BY batch_id SEPARATOR ',') AS batch_ids,
#                                 GROUP_CONCAT(DISTINCT CONCAT(g.grn_code, ' (', DATE_FORMAT(g.grn_date, '%Y-%m-%d'), ')') SEPARATOR ', ') AS grn_info
#                             FROM product_batches pb
#                             LEFT JOIN grn g ON pb.grn_id = g.grn_id
#                             WHERE pb.product_id = %s 
#                             AND pb.variation_id IS NULL 
#                             AND pb.remaining_quantity > 0
#                             GROUP BY pb.price, pb.cost, pb.our_price
#                             ORDER BY pb.price ASC, pb.cost ASC
#                         """, (product['id'],))
                    
#                     batches = cursor.fetchall()

#                     if batches:
#                         # ✅ CREATE ENTRY FOR EACH UNIQUE COST/PRICE/OUR_PRICE COMBINATION
#                         for batch_group in batches:
#                             price = float(batch_group['price']) if batch_group['price'] else 0.0
#                             cost = float(batch_group['cost']) if batch_group['cost'] else 0.0
#                             our_price = float(batch_group['our_price']) if batch_group['our_price'] else 0.0
#                             stock = float(batch_group['total_stock']) if batch_group['total_stock'] else 0.0
#                             exp_date = batch_group['latest_expiration_date'].strftime('%Y-%m-%d') if batch_group['latest_expiration_date'] else '-'
#                             batch_ids = batch_group['batch_ids']
#                             grn_info = batch_group.get('grn_info', 'N/A')
                            
#                             # ✅ FIXED: Format quantity with decimals
#                             stock_display = f"{stock:.2f}".rstrip('0').rstrip('.') if stock % 1 else f"{int(stock)}"
                            
#                             # ✅ Format with our_price: MILK (MILK) - (Cost 120 - Our Price 140 - Price 150) Stock 200 Exp: 2025-01-15
#                             display_name = f"{product['product_name']} ({product['sku']}) - (Cost {cost:.0f} - Our Price {our_price:.0f} - Price {price:.0f}) Stock {stock_display} Exp: {exp_date}"
                            
#                             result.append({
#                                 "product_id": product['id'],
#                                 "variation_id": None,
#                                 "batch_ids": batch_ids,
#                                 "grn_info": grn_info,
#                                 "product_name": product['product_name'],
#                                 "variation_name": None,
#                                 "variation_type": None,
#                                 "sku": product['sku'],
#                                 "display_name": display_name,
#                                 "product_type": "single",
#                                 "product_quantity": stock,
#                                 "product_price": price,
#                                 "product_cost": cost,
#                                 "our_price": our_price,  # ✅ NEW FIELD
#                                 "expiration_date": exp_date,
#                                 "tax_type": product['tax_type'],
#                                 "product_tax": float(product['product_tax']) if product['product_tax'] else 0.0,
#                                 "base_unit": product.get('base_unit_name'),
#                                 "sale_unit": product.get('sale_unit_name'),
#                                 "purchase_unit": product.get('purchase_unit_name'),
#                                 "sales_unit": product.get('sale_unit_id'),
#                                 "sales_units": [
#                                     {"id": product.get('base_unit_id'), "name": product.get('base_unit_name')},
#                                     {"id": product.get('sale_unit_id'), "name": product.get('sale_unit_name')},
#                                     {"id": product.get('purchase_unit_id'), "name": product.get('purchase_unit_name')}
#                                 ]
#                             })
#                     else:
#                         # ❌ No stock in warehouse
#                         display_name = f"{product['product_name']} ({product['sku']}) - (Cost 0 - Our Price 0 - Price 0) Stock 0 Exp: -"
                        
#                         result.append({
#                             "product_id": product['id'],
#                             "variation_id": None,
#                             "batch_ids": None,
#                             "grn_info": None,
#                             "product_name": product['product_name'],
#                             "variation_name": None,
#                             "variation_type": None,
#                             "sku": product['sku'],
#                             "display_name": display_name,
#                             "product_type": "single",
#                             "product_quantity": 0.0,
#                             "product_price": 0.0,
#                             "product_cost": 0.0,
#                             "our_price": 0.0,  # ✅ NEW FIELD
#                             "expiration_date": '-',
#                             "tax_type": product['tax_type'],
#                             "product_tax": float(product['product_tax']) if product['product_tax'] else 0.0,
#                             "base_unit": product.get('base_unit_name'),
#                             "sale_unit": product.get('sale_unit_name'),
#                             "purchase_unit": product.get('purchase_unit_name'),
#                             "sales_unit": product.get('sale_unit_id'),
#                             "sales_units": [
#                                 {"id": product.get('base_unit_id'), "name": product.get('base_unit_name')},
#                                 {"id": product.get('sale_unit_id'), "name": product.get('sale_unit_name')},
#                                 {"id": product.get('purchase_unit_id'), "name": product.get('purchase_unit_name')}
#                             ]
#                         })

#         return jsonify(result), 200

#     except mysql.connector.Error as err:
#         print(f"❌ Database Error in search_sales_products: {err}")
#         traceback.print_exc()
#         return jsonify({"status": "error", "message": f"Database error: {str(err)}"}), 500
    
#     except Exception as e:
#         print(f"❌ Error in search_sales_products: {e}")
#         traceback.print_exc()
#         return jsonify({"status": "error", "message": str(e)}), 500

#     finally:
#         if cursor:
#             cursor.close()
#         if conn:
#             conn.close()



@product_bp.route('/search_sales_products', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def search_sales_products():
    """
    Search products and variations with warehouse-specific batch-based pricing and stock.
    Returns products available in the selected warehouse only.

    ✅ GROUPS batches by cost and price — shows combined stock with latest expiration date
    ✅ INCLUDES discount_rules[] per batch group from product_batch_discounts table
       [{payment_method_id, method_name, discount_rate, discount_type, is_active}]
    ✅ INCLUDES GRN information for traceability
    ✅ FIXED: Shows decimal quantities properly (0.25, 0.5, etc.)
    ✅ FIXED: Variation products now searchable by product name
    ✅ FIXED: Variation SKU search now works correctly
    ✅ REMOVED: our_price — use discount_rules[] for per-payment-method discounts
    """

    query        = request.args.get('query',       '').strip()
    productname  = request.args.get('productname', '').strip()
    variation    = request.args.get('variation',   '').strip()
    warehouse_id = request.args.get('warehouse_id','').strip()
    store_id     = request.args.get('store_id',    '').strip()

    search_term = query or productname or variation
    if not search_term:
        return jsonify({"status": "error", "message": "Query parameter is required"}), 400

    # Detect combined format "Product - Variation"
    if " - " in search_term:
        parts          = search_term.split(" - ", 1)
        main_name      = parts[0].strip()
        variation_name = parts[1].strip()
    else:
        main_name      = search_term
        variation_name = None

    conn = get_db_connection()
    if conn is None:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)
    result = []

    try:
        # ── Helper: fetch discount rules for a comma-separated list of batch_ids ─
        def get_discount_rules(batch_ids_csv):
            """Return the active per-payment-method discount rules for one batch group.

            Args:
                batch_ids_csv: the GROUP_CONCAT value listing the batch ids of one
                    price/cost group, e.g. "12,13,15". May be None/empty, and may
                    arrive as bytes/bytearray depending on the MySQL driver.

            Returns:
                A list of dicts with keys payment_method_id, method_name,
                discount_rate (float), discount_type and is_active (bool), ordered by
                payment_method_id, with at most ONE entry per payment method. When
                the batches in the group carry different rules for the same method,
                the rule attached to the lowest batch_id wins.
            """
            if not batch_ids_csv:
                return []

            # The driver may hand back GROUP_CONCAT output as bytes; str() on that
            # would yield "b'12,13'" and every id would be silently discarded.
            if isinstance(batch_ids_csv, (bytes, bytearray)):
                batch_ids_csv = batch_ids_csv.decode('ascii', errors='ignore')

            # Only purely numeric tokens reach the query.
            # NOTE(review): GROUP_CONCAT output is capped by group_concat_max_len,
            # so a very large group may arrive truncated — confirm server setting.
            ids = [int(tok) for tok in str(batch_ids_csv).split(',') if tok.strip().isdecimal()]
            if not ids:
                return []

            placeholders = ','.join(['%s'] * len(ids))
            # No DISTINCT: de-duplication per payment method happens below, and
            # the secondary ORDER BY on batch_id makes "first rule wins"
            # deterministic (MySQL rejects ORDER BY on a non-selected column
            # combined with DISTINCT).
            cursor.execute(f"""
                SELECT
                    pbd.payment_method_id,
                    pm.method_name,
                    pbd.discount_rate,
                    pbd.discount_type,
                    pbd.is_active
                FROM product_batch_discounts pbd
                LEFT JOIN payment_methods pm ON pm.id = pbd.payment_method_id
                WHERE pbd.batch_id IN ({placeholders})
                  AND pbd.is_active = 1
                ORDER BY pbd.payment_method_id, pbd.batch_id
            """, tuple(ids))

            rules = []
            seen_methods = set()
            for r in cursor.fetchall():
                method_id = r['payment_method_id']
                if method_id in seen_methods:
                    continue
                seen_methods.add(method_id)
                rate = r['discount_rate']
                rules.append({
                    'payment_method_id': method_id,
                    'method_name':       r['method_name'],
                    # A NULL rate would make float() raise and fail the whole search.
                    'discount_rate':     float(rate) if rate is not None else 0.0,
                    'discount_type':     r['discount_type'],
                    'is_active':         bool(r['is_active']),
                })
            return rules

        # ── Product search ────────────────────────────────────────────────────────
        cursor.execute("""
            SELECT DISTINCT
                p.id,
                p.product_name,
                p.sku,
                p.product_type,
                p.tax_type,
                p.product_tax,
                p.base_unit_id,
                p.sale_unit_id,
                p.purchase_unit_id,
                u1.unit_name  AS base_unit_name,
                u1.unit_short AS base_unit_short,
                u2.unit_name  AS sale_unit_name,
                u2.unit_short AS sale_unit_short,
                u3.unit_name  AS purchase_unit_name,
                u3.unit_short AS purchase_unit_short
            FROM products p
            LEFT JOIN units u1 ON p.base_unit_id    = u1.id
            LEFT JOIN units u2 ON p.sale_unit_id     = u2.id
            LEFT JOIN units u3 ON p.purchase_unit_id = u3.id
            LEFT JOIN product_variations pv ON pv.product_id = p.id
            WHERE p.product_name LIKE %s
               OR p.sku          LIKE %s
               OR pv.variation_sku  LIKE %s
               OR pv.variation_name LIKE %s
        """, (f"%{main_name}%", f"%{main_name}%", f"%{main_name}%", f"%{main_name}%"))
        products = cursor.fetchall()

        for product in products:

            # ── Batch SQL builder (GROUP BY price, cost — our_price removed) ────
            def batch_sql_variation(wh_id, st_id, variation_id=None):
                """Return ``(sql, params)`` for the batch-group query of one variation.

                Batches of ``product`` (the enclosing loop variable) that belong to
                the variation are grouped by (price, cost). Each group reports its
                latest expiration date, total stock, comma-joined batch ids and GRN
                labels.

                Args:
                    wh_id: warehouse id from the request. When falsy, stock is taken
                        from ``product_batches.remaining_quantity`` with no
                        warehouse_stock join.
                    st_id: store id; only applied together with ``wh_id``.
                    variation_id: variation to filter on. Defaults to ``None``, which
                        leaves a ``None`` placeholder at ``params[1]`` for the caller
                        to replace (the original calling convention).

                Returns:
                    Tuple of the SQL string and its positional parameter tuple.
                """
                if wh_id:
                    # Warehouse-scoped stock: only batches with a positive
                    # warehouse_stock row in the selected warehouse (and store).
                    stock_expr  = "SUM(ws.quantity)"
                    stock_join  = "INNER JOIN warehouse_stock ws ON pb.batch_id = ws.batch_id"
                    stock_where = "AND ws.warehouse_id = %s"
                    params = [product['id'], variation_id, wh_id]
                    if st_id:
                        stock_where += " AND ws.store_id = %s"
                        params.append(st_id)
                    stock_where += " AND ws.quantity > 0"
                else:
                    # NOTE(review): st_id is ignored when no warehouse is given,
                    # matching the previous behaviour — confirm this is intended.
                    stock_expr  = "SUM(pb.remaining_quantity)"
                    stock_join  = ""
                    stock_where = ""
                    params = [product['id'], variation_id]

                # '%%' in DATE_FORMAT is needed because the query is executed with a
                # params tuple, so the driver applies %-style substitution.
                sql = f"""
                    SELECT
                        pb.price,
                        pb.cost,
                        MAX(pb.expiration_date) AS latest_expiration_date,
                        {stock_expr} AS total_stock,
                        GROUP_CONCAT(DISTINCT pb.batch_id ORDER BY pb.batch_id SEPARATOR ',') AS batch_ids,
                        GROUP_CONCAT(DISTINCT CONCAT(g.grn_code, ' (', DATE_FORMAT(g.grn_date, '%%Y-%%m-%%d'), ')') SEPARATOR ', ') AS grn_info
                    FROM product_batches pb
                    {stock_join}
                    LEFT JOIN grn g ON pb.grn_id = g.grn_id
                    WHERE pb.product_id = %s AND pb.variation_id = %s
                      AND pb.remaining_quantity > 0
                      {stock_where}
                    GROUP BY pb.price, pb.cost
                    ORDER BY pb.price ASC, pb.cost ASC
                """
                return sql, tuple(params)

            def batch_sql_single(wh_id, st_id):
                """Return ``(sql, params)`` for the batch-group query of a single product.

                Batches of ``product`` (the enclosing loop variable) with no
                variation are grouped by (price, cost). Each group reports its latest
                expiration date, total stock, comma-joined batch ids and GRN labels.

                Args:
                    wh_id: warehouse id from the request. When falsy, stock is taken
                        from ``product_batches.remaining_quantity`` with no
                        warehouse_stock join.
                    st_id: store id; only applied together with ``wh_id``.

                Returns:
                    Tuple of the SQL string and its positional parameter tuple.
                """
                params = [product['id']]
                if wh_id:
                    # Warehouse-scoped stock: only batches with a positive
                    # warehouse_stock row in the selected warehouse (and store).
                    stock_expr  = "SUM(ws.quantity)"
                    stock_join  = "INNER JOIN warehouse_stock ws ON pb.batch_id = ws.batch_id"
                    stock_where = "AND ws.warehouse_id = %s"
                    params.append(wh_id)
                    if st_id:
                        stock_where += " AND ws.store_id = %s"
                        params.append(st_id)
                    stock_where += " AND ws.quantity > 0"
                else:
                    # NOTE(review): st_id is ignored when no warehouse is given,
                    # matching the previous behaviour — confirm this is intended.
                    stock_expr  = "SUM(pb.remaining_quantity)"
                    stock_join  = ""
                    stock_where = ""

                # '%%' in DATE_FORMAT is needed because the query is executed with a
                # params tuple, so the driver applies %-style substitution.
                sql = f"""
                    SELECT
                        pb.price,
                        pb.cost,
                        MAX(pb.expiration_date) AS latest_expiration_date,
                        {stock_expr} AS total_stock,
                        GROUP_CONCAT(DISTINCT pb.batch_id ORDER BY pb.batch_id SEPARATOR ',') AS batch_ids,
                        GROUP_CONCAT(DISTINCT CONCAT(g.grn_code, ' (', DATE_FORMAT(g.grn_date, '%%Y-%%m-%%d'), ')') SEPARATOR ', ') AS grn_info
                    FROM product_batches pb
                    {stock_join}
                    LEFT JOIN grn g ON pb.grn_id = g.grn_id
                    WHERE pb.product_id = %s
                      AND pb.variation_id IS NULL
                      AND pb.remaining_quantity > 0
                      {stock_where}
                    GROUP BY pb.price, pb.cost
                    ORDER BY pb.price ASC, pb.cost ASC
                """
                return sql, tuple(params)

            # ================================================================
            # VARIABLE PRODUCT
            # ================================================================
            if product['product_type'] == 'variable':

                if variation_name:
                    cursor.execute("""
                        SELECT pv.id, pv.variation_name, pv.variation_type, pv.variation_sku,
                               pv.variation_cost, pv.variation_price,
                               pv.variation_tax_type, pv.variation_tax
                        FROM product_variations pv
                        WHERE pv.product_id = %s
                          AND (pv.variation_name LIKE %s
                               OR pv.variation_type LIKE %s
                               OR pv.variation_sku  LIKE %s)
                    """, (product['id'], f"%{variation_name}%", f"%{variation_name}%", f"%{variation_name}%"))
                else:
                    cursor.execute("""
                        SELECT COUNT(*) AS cnt
                        FROM product_variations
                        WHERE product_id = %s
                          AND (variation_sku LIKE %s OR variation_name LIKE %s)
                    """, (product['id'], f"%{main_name}%", f"%{main_name}%"))
                    sku_match_count = cursor.fetchone()['cnt']

                    if sku_match_count > 0 and main_name.lower() not in product['product_name'].lower():
                        cursor.execute("""
                            SELECT pv.id, pv.variation_name, pv.variation_type, pv.variation_sku,
                                   pv.variation_cost, pv.variation_price,
                                   pv.variation_tax_type, pv.variation_tax
                            FROM product_variations pv
                            WHERE pv.product_id = %s
                              AND (pv.variation_sku LIKE %s OR pv.variation_name LIKE %s)
                        """, (product['id'], f"%{main_name}%", f"%{main_name}%"))
                    else:
                        cursor.execute("""
                            SELECT pv.id, pv.variation_name, pv.variation_type, pv.variation_sku,
                                   pv.variation_cost, pv.variation_price,
                                   pv.variation_tax_type, pv.variation_tax
                            FROM product_variations pv
                            WHERE pv.product_id = %s
                        """, (product['id'],))

                variations = cursor.fetchall()

                for var in variations:
                    default_price = float(var['variation_price']) if var['variation_price'] else 0.0
                    default_cost  = float(var['variation_cost'])  if var['variation_cost']  else 0.0

                    # Build batch query with variation_id patched in
                    sql, params = batch_sql_variation(warehouse_id, store_id)
                    # patch None → actual variation id
                    params = tuple(
                        var['id'] if p is None else p
                        for p in params
                    )
                    cursor.execute(sql, params)
                    batches = cursor.fetchall()

                    tax_type   = var.get('variation_tax_type') or product.get('tax_type')
                    product_tax = float(var['variation_tax']) if var.get('variation_tax') else float(product.get('product_tax', 0))

                    base_units = [
                        {"id": product.get('base_unit_id'),     "name": product.get('base_unit_name')},
                        {"id": product.get('sale_unit_id'),     "name": product.get('sale_unit_name')},
                        {"id": product.get('purchase_unit_id'), "name": product.get('purchase_unit_name')},
                    ]

                    if batches:
                        for bg in batches:
                            price    = float(bg['price'])        if bg['price']        else default_price
                            cost     = float(bg['cost'])         if bg['cost']         else default_cost
                            stock    = float(bg['total_stock'])  if bg['total_stock']  else 0.0
                            exp_date = bg['latest_expiration_date'].strftime('%Y-%m-%d') if bg['latest_expiration_date'] else '-'
                            batch_ids = bg['batch_ids']
                            grn_info  = bg.get('grn_info', 'N/A')

                            # Fetch discount rules for these batch ids
                            discount_rules = get_discount_rules(batch_ids)

                            stock_display = f"{stock:.2f}".rstrip('0').rstrip('.') if stock % 1 else str(int(stock))
                            display_name  = (
                                f"{product['product_name']} - {var['variation_name']} - "
                                f"(Cost {cost:.0f} - Price {price:.0f}) "
                                f"Stock {stock_display} Exp: {exp_date}"
                            )

                            result.append({
                                "product_id":     product['id'],
                                "variation_id":   var['id'],
                                "batch_ids":      batch_ids,
                                "grn_info":       grn_info,
                                "product_name":   product['product_name'],
                                "variation_name": var['variation_name'],
                                "variation_type": var['variation_type'],
                                "sku":            var['variation_sku'],
                                "display_name":   display_name,
                                "product_type":   "variation",
                                "product_quantity": stock,
                                "product_price":  price,
                                "product_cost":   cost,
                                "discount_rules": discount_rules,
                                "expiration_date": exp_date,
                                "tax_type":       tax_type,
                                "product_tax":    product_tax,
                                "base_unit":      product.get('base_unit_name'),
                                "sale_unit":      product.get('sale_unit_name'),
                                "purchase_unit":  product.get('purchase_unit_name'),
                                "sales_unit":     product.get('sale_unit_id'),
                                "sales_units":    base_units,
                            })
                    else:
                        # No stock in warehouse
                        display_name = (
                            f"{product['product_name']} - {var['variation_name']} - "
                            f"(Cost {default_cost:.0f} - Price {default_price:.0f}) "
                            f"Stock 0 Exp: -"
                        )
                        result.append({
                            "product_id":     product['id'],
                            "variation_id":   var['id'],
                            "batch_ids":      None,
                            "grn_info":       None,
                            "product_name":   product['product_name'],
                            "variation_name": var['variation_name'],
                            "variation_type": var['variation_type'],
                            "sku":            var['variation_sku'],
                            "display_name":   display_name,
                            "product_type":   "variation",
                            "product_quantity": 0.0,
                            "product_price":  default_price,
                            "product_cost":   default_cost,
                            "discount_rules": [],
                            "expiration_date": '-',
                            "tax_type":       tax_type,
                            "product_tax":    product_tax,
                            "base_unit":      product.get('base_unit_name'),
                            "sale_unit":      product.get('sale_unit_name'),
                            "purchase_unit":  product.get('purchase_unit_name'),
                            "sales_unit":     product.get('sale_unit_id'),
                            "sales_units":    base_units,
                        })

            # ================================================================
            # SINGLE PRODUCT
            # ================================================================
            else:
                if (main_name.lower() in product['product_name'].lower()) or \
                   (main_name.lower() in (product['sku'] or '').lower()):

                    sql, params = batch_sql_single(warehouse_id, store_id)
                    cursor.execute(sql, params)
                    batches = cursor.fetchall()

                    base_units = [
                        {"id": product.get('base_unit_id'),     "name": product.get('base_unit_name')},
                        {"id": product.get('sale_unit_id'),     "name": product.get('sale_unit_name')},
                        {"id": product.get('purchase_unit_id'), "name": product.get('purchase_unit_name')},
                    ]

                    if batches:
                        for bg in batches:
                            price    = float(bg['price'])       if bg['price']       else 0.0
                            cost     = float(bg['cost'])        if bg['cost']        else 0.0
                            stock    = float(bg['total_stock']) if bg['total_stock'] else 0.0
                            exp_date = bg['latest_expiration_date'].strftime('%Y-%m-%d') if bg['latest_expiration_date'] else '-'
                            batch_ids = bg['batch_ids']
                            grn_info  = bg.get('grn_info', 'N/A')

                            # Fetch discount rules for these batch ids
                            discount_rules = get_discount_rules(batch_ids)

                            stock_display = f"{stock:.2f}".rstrip('0').rstrip('.') if stock % 1 else str(int(stock))
                            display_name  = (
                                f"{product['product_name']} ({product['sku']}) - "
                                f"(Cost {cost:.0f} - Price {price:.0f}) "
                                f"Stock {stock_display} Exp: {exp_date}"
                            )

                            result.append({
                                "product_id":     product['id'],
                                "variation_id":   None,
                                "batch_ids":      batch_ids,
                                "grn_info":       grn_info,
                                "product_name":   product['product_name'],
                                "variation_name": None,
                                "variation_type": None,
                                "sku":            product['sku'],
                                "display_name":   display_name,
                                "product_type":   "single",
                                "product_quantity": stock,
                                "product_price":  price,
                                "product_cost":   cost,
                                "discount_rules": discount_rules,
                                "expiration_date": exp_date,
                                "tax_type":       product['tax_type'],
                                "product_tax":    float(product['product_tax']) if product['product_tax'] else 0.0,
                                "base_unit":      product.get('base_unit_name'),
                                "sale_unit":      product.get('sale_unit_name'),
                                "purchase_unit":  product.get('purchase_unit_name'),
                                "sales_unit":     product.get('sale_unit_id'),
                                "sales_units":    base_units,
                            })
                    else:
                        # No stock in warehouse
                        display_name = (
                            f"{product['product_name']} ({product['sku']}) - "
                            f"(Cost 0 - Price 0) Stock 0 Exp: -"
                        )
                        result.append({
                            "product_id":     product['id'],
                            "variation_id":   None,
                            "batch_ids":      None,
                            "grn_info":       None,
                            "product_name":   product['product_name'],
                            "variation_name": None,
                            "variation_type": None,
                            "sku":            product['sku'],
                            "display_name":   display_name,
                            "product_type":   "single",
                            "product_quantity": 0.0,
                            "product_price":  0.0,
                            "product_cost":   0.0,
                            "discount_rules": [],
                            "expiration_date": '-',
                            "tax_type":       product['tax_type'],
                            "product_tax":    float(product['product_tax']) if product['product_tax'] else 0.0,
                            "base_unit":      product.get('base_unit_name'),
                            "sale_unit":      product.get('sale_unit_name'),
                            "purchase_unit":  product.get('purchase_unit_name'),
                            "sales_unit":     product.get('sale_unit_id'),
                            "sales_units":    base_units,
                        })

        return jsonify(result), 200

    except mysql.connector.Error as err:
        print(f"❌ Database Error in search_sales_products: {err}")
        traceback.print_exc()
        return jsonify({"status": "error", "message": f"Database error: {str(err)}"}), 500

    except Exception as e:
        print(f"❌ Error in search_sales_products: {e}")
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


def _warehouse_stock_search_sql(query):
    """Build the grouped warehouse-stock SQL for an optional free-text *query*.

    Returns ``(sql, extra_params)``. The SQL always expects the warehouse_id and
    store_id parameters first; ``extra_params`` holds the LIKE parameters for
    the search condition (eight copies of ``%query%``, one per searched column)
    or is empty when *query* is empty.
    """
    search_condition = ""
    extra_params = []

    if query:
        search_condition = """
                AND (
                    p.product_name LIKE %s 
                    OR p.sku LIKE %s
                    OR pv.variation_name LIKE %s
                    OR pv.variation_type LIKE %s
                    OR pv.variation_sku LIKE %s
                    OR pb.batch_number LIKE %s
                    OR g.grn_code LIKE %s
                    OR s.supplier_name LIKE %s
                )
            """
        extra_params = [f"%{query}%"] * 8

    # Cost, price and expiry come from product_batches (pb); one output row per
    # product/variation + cost + price + expiry + supplier combination.
    sql = f"""
            SELECT 
                -- Grouping Keys (from product_batches)
                ws.product_id,
                ws.variation_id,
                COALESCE(pb.cost, 0) as batch_cost,
                COALESCE(pb.price, 0) as batch_price,
                pb.expiration_date,
                po.supplier_id,
                
                -- Aggregated Data
                GROUP_CONCAT(DISTINCT ws.batch_id ORDER BY ws.batch_id) as batch_ids,
                SUM(ws.quantity) as total_stock,
                GROUP_CONCAT(DISTINCT pb.batch_number ORDER BY pb.batch_number SEPARATOR ', ') as batch_numbers,
                GROUP_CONCAT(DISTINCT g.grn_code ORDER BY g.grn_code SEPARATOR ', ') as grn_codes,
                MIN(ws.id) as first_stock_id,
                
                -- Product Details
                p.product_name,
                p.sku,
                p.product_type,
                p.tax_type,
                p.product_tax,
                p.base_unit_id,
                p.sale_unit_id,
                p.purchase_unit_id,
                
                -- Unit Names
                u1.unit_name AS base_unit_name,
                u1.unit_short AS base_unit_short,
                u2.unit_name AS sale_unit_name,
                u2.unit_short AS sale_unit_short,
                u3.unit_name AS purchase_unit_name,
                u3.unit_short AS purchase_unit_short,
                
                -- Variation Details
                pv.variation_name,
                pv.variation_type,
                pv.variation_sku,
                pv.variation_cost,
                pv.variation_price,
                pv.variation_tax_type,
                pv.variation_tax,
                
                -- Supplier Details
                s.supplier_name,
                s.supplier_code
                
            FROM warehouse_stock ws
            INNER JOIN products p ON ws.product_id = p.id
            LEFT JOIN product_variations pv ON ws.variation_id = pv.id
            LEFT JOIN product_batches pb ON ws.batch_id = pb.batch_id  -- ✅ JOIN to get batch details
            LEFT JOIN grn g ON pb.grn_id = g.grn_id
            LEFT JOIN purchase_orders po ON g.purchase_order_id = po.order_id
            LEFT JOIN suppliers s ON po.supplier_id = s.id
            LEFT JOIN units u1 ON p.base_unit_id = u1.id
            LEFT JOIN units u2 ON p.sale_unit_id = u2.id
            LEFT JOIN units u3 ON p.purchase_unit_id = u3.id
            
            WHERE ws.warehouse_id = %s
            AND ws.store_id = %s
            AND ws.quantity > 0
            {search_condition}
            
            GROUP BY 
                ws.product_id,
                ws.variation_id,
                batch_cost,          -- ✅ From product_batches
                batch_price,         -- ✅ From product_batches
                pb.expiration_date,  -- ✅ From product_batches
                po.supplier_id,
                p.product_name,
                p.sku,
                p.product_type,
                p.tax_type,
                p.product_tax,
                p.base_unit_id,
                p.sale_unit_id,
                p.purchase_unit_id,
                u1.unit_name,
                u1.unit_short,
                u2.unit_name,
                u2.unit_short,
                u3.unit_name,
                u3.unit_short,
                pv.variation_name,
                pv.variation_type,
                pv.variation_sku,
                pv.variation_cost,
                pv.variation_price,
                pv.variation_tax_type,
                pv.variation_tax,
                s.supplier_name,
                s.supplier_code
                
            ORDER BY p.product_name, pv.variation_name, batch_cost, pb.expiration_date
        """
    return sql, extra_params


def _format_warehouse_stock_group(item):
    """Convert one grouped row of the warehouse-stock query into the API's JSON dict.

    Numeric columns may be NULL (or DB decimals); they are normalised to
    floats, with NULL mapped to 0.0.
    """
    tax_type = item.get('variation_tax_type') or item.get('tax_type')
    # A non-zero variation tax wins, otherwise the product tax is used.
    # safe_float maps NULL to 0.0; the previous float(item.get('product_tax', 0))
    # raised TypeError whenever the product_tax column was NULL (the key exists,
    # so .get returned None), failing the whole search with a 500.
    product_tax = safe_float(item.get('variation_tax')) or safe_float(item.get('product_tax'))

    # Cost and price come from product_batches (pb.cost / pb.price).
    cost = safe_float(item['batch_cost'])
    price = safe_float(item['batch_price'])
    total_stock = safe_float(item['total_stock'])
    exp_date = item['expiration_date'].strftime('%Y-%m-%d') if item['expiration_date'] else '-'

    supplier_name = item['supplier_name'] or 'Unknown'
    supplier_code = item['supplier_code'] or '-'
    batch_numbers = item['batch_numbers'] or '-'
    grn_codes = item['grn_codes'] or '-'

    if item['variation_id']:
        base_name = f"{item['product_name']} - {item['variation_name']}"
        if item['variation_type']:
            base_name += f" ({item['variation_type']})"
        sku = item['variation_sku']
    else:
        base_name = item['product_name']
        sku = item['sku']

    display_name = (
        f"{base_name} | "
        f"Supplier: {supplier_name} | "
        f"Cost {cost:.2f} - Price {price:.2f} | "
        f"Stock {total_stock:.2f} | "
        f"Exp: {exp_date}"
    )
    if batch_numbers != '-':
        display_name += f" | Batches: {batch_numbers}"

    return {
        # IDs
        "stock_id": item['first_stock_id'],
        "product_id": item['product_id'],
        "variation_id": item['variation_id'],
        "batch_ids": item['batch_ids'],  # Comma-separated list
        "supplier_id": item['supplier_id'],

        # Product Info
        "product_name": item['product_name'],
        "variation_name": item['variation_name'],
        "variation_type": item['variation_type'],
        "sku": sku,
        "display_name": display_name,
        "product_type": "variation" if item['variation_id'] else "single",

        # Pricing & Stock (from product_batches)
        "product_price": price,
        "product_cost": cost,
        "product_quantity": total_stock,

        # Batch & Supplier Info
        "batch_numbers": batch_numbers,
        "grn_codes": grn_codes,
        "supplier_name": supplier_name,
        "supplier_code": supplier_code,
        "expiration_date": exp_date,

        # Tax Info
        "tax_type": tax_type,
        "product_tax": product_tax,

        # Units
        "base_unit": item.get('base_unit_name'),
        "sale_unit": item.get('sale_unit_name'),
        "purchase_unit": item.get('purchase_unit_name'),
        "sales_unit": item.get('sale_unit_id'),
        "sales_units": [
            {"id": item.get('base_unit_id'), "name": item.get('base_unit_name')},
            {"id": item.get('sale_unit_id'), "name": item.get('sale_unit_name')},
            {"id": item.get('purchase_unit_id'), "name": item.get('purchase_unit_name')}
        ]
    }


@product_bp.route('/warehouse_stock_search', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def warehouse_stock_search():
    """
    Search warehouse stock grouped by Product/Variation + Cost + Price + Supplier + Expiry.

    Batch cost, price and expiry are read from the product_batches table.

    Query params:
        warehouse_id: required.
        store_id: required.
        query: optional free text, matched with LIKE against product and
            variation names/SKUs, variation type, batch number, GRN code and
            supplier name.

    Returns:
        200 with a JSON list of stock groups; 400 when warehouse_id or
        store_id is missing; 500 on a connection or query failure.
    """
    warehouse_id = request.args.get('warehouse_id', '').strip()
    store_id = request.args.get('store_id', '').strip()
    query = request.args.get('query', '').strip()

    if not warehouse_id or not store_id:
        return jsonify({
            "status": "error", 
            "message": "warehouse_id and store_id are required"
        }), 400

    conn = get_db_connection()
    if conn is None:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = None
    try:
        # Created inside the try so the connection is still closed in
        # `finally` if creating the cursor itself fails.
        cursor = conn.cursor(dictionary=True)

        sql, extra_params = _warehouse_stock_search_sql(query)
        cursor.execute(sql, [warehouse_id, store_id, *extra_params])

        result = [_format_warehouse_stock_group(item) for item in cursor.fetchall()]
        return jsonify(result), 200

    except mysql.connector.Error as err:
        print(f"❌ Database Error: {err}")
        traceback.print_exc()
        return jsonify({
            "status": "error", 
            "message": f"Database error: {str(err)}"
        }), 500

    except Exception as e:
        print(f"❌ Error: {e}")
        traceback.print_exc()
        return jsonify({
            "status": "error", 
            "message": str(e)
        }), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


def _fetch_purchase_return_batch_groups(cursor, product_id, variation_id, warehouse_id, store_id):
    """Fetch returnable batch groups for one product (or one of its variations).

    Only batches with remaining quantity and positive stock in the given
    warehouse/store are included, and only those from a 'completed' GRN of a
    'Received' purchase order. Rows are grouped by supplier, cost, price and
    expiration date.

    Args:
        cursor: dictionary cursor to run the query on (its previous result
            set is replaced).
        product_id: products.id to search.
        variation_id: product_variations.id, or None for single products
            (matches batches whose variation_id IS NULL).
        warehouse_id, store_id: location filter on warehouse_stock.

    Returns:
        List of row dicts with supplier_id, supplier_name, cost, price,
        expiration_date, total_stock, batch_ids, first_batch_id, grn_codes.
    """
    if variation_id is None:
        variation_filter = "pb.variation_id IS NULL"
        params = (product_id, warehouse_id, store_id)
    else:
        variation_filter = "pb.variation_id = %s"
        params = (product_id, variation_id, warehouse_id, store_id)

    # s.supplier_name is grouped explicitly (it depends only on supplier_id) so
    # the SELECT/ORDER BY stay valid under ONLY_FULL_GROUP_BY on MariaDB too.
    # NULL expiry dates sort last (MariaDB has no NULLS LAST).
    cursor.execute(f"""
        SELECT 
            po.supplier_id,
            s.supplier_name,
            pb.cost,
            pb.price,
            pb.expiration_date,
            SUM(ws.quantity) AS total_stock,
            GROUP_CONCAT(DISTINCT pb.batch_id ORDER BY pb.batch_id SEPARATOR ',') AS batch_ids,
            MIN(pb.batch_id) AS first_batch_id,
            GROUP_CONCAT(DISTINCT g.grn_code ORDER BY g.grn_code SEPARATOR ', ') AS grn_codes
        FROM product_batches pb
        INNER JOIN warehouse_stock ws ON pb.batch_id = ws.batch_id
        INNER JOIN grn g ON pb.grn_id = g.grn_id
        INNER JOIN purchase_orders po ON g.purchase_order_id = po.order_id
        LEFT JOIN suppliers s ON po.supplier_id = s.id
        WHERE pb.product_id = %s 
        AND {variation_filter}
        AND pb.remaining_quantity > 0
        AND ws.warehouse_id = %s
        AND ws.store_id = %s
        AND ws.quantity > 0
        AND g.status = 'completed'
        AND po.status = 'Received'
        GROUP BY 
            po.supplier_id, 
            s.supplier_name,
            pb.cost, 
            pb.price, 
            pb.expiration_date
        ORDER BY 
            s.supplier_name ASC, 
            CASE WHEN pb.expiration_date IS NULL THEN 1 ELSE 0 END,
            pb.expiration_date ASC,
            pb.cost ASC, 
            pb.price ASC
    """, params)
    return cursor.fetchall()


def _build_purchase_return_entry(group, *, label, product_id, variation_id,
                                 product_name, variation_name, sku, product_type,
                                 tax_type, product_tax, purchase_unit):
    """Turn one batch-group row into a purchase-return search result dict.

    *label* is the product part of the display name (e.g. "Name - Variation"
    or "Name (SKU)"); the supplier, pricing, expiry and stock details are
    appended to it. NULL numeric columns become 0.0 and a NULL expiry becomes
    ``None`` in the payload ("No Expiry" in the display name).
    """
    supplier_name = group['supplier_name'] or 'Unknown'
    cost = safe_float(group['cost'])
    price = safe_float(group['price'])
    stock = safe_float(group['total_stock'])

    expiration_date = group['expiration_date']
    if expiration_date:
        exp_str = expiration_date.strftime('%Y-%m-%d')
        exp_display = f"Exp: {exp_str}"
    else:
        exp_str = None
        exp_display = "No Expiry"

    display_name = (
        f"{label} | "
        f"Supplier: {supplier_name} | "
        f"Cost: Rs {cost:.2f} | "
        f"Price: Rs {price:.2f} | "
        f"{exp_display} | "
        f"Stock: {stock:.2f} {purchase_unit}"
    )

    return {
        "product_id": product_id,
        "variation_id": variation_id,
        "batch_ids": group['batch_ids'],
        "grn_codes": group['grn_codes'],
        "supplier_id": group['supplier_id'],
        "supplier_name": supplier_name,
        "product_name": product_name,
        "variation_name": variation_name,
        "sku": sku,
        "display_name": display_name,
        "product_type": product_type,
        "product_quantity": stock,
        "product_price": price,
        "product_cost": cost,
        "expiration_date": exp_str,
        "tax_type": tax_type,
        "product_tax": product_tax,
        "purchase_unit": purchase_unit
    }


@product_bp.route('/search_products_for_purchase_return', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def search_products_for_purchase_return():
    """
    GRN-aware purchase return search.

    Groups stock by:
    - Product/Variation
    - Supplier
    - Cost
    - Price
    - Expiration Date (if exists)

    Each distinct cost/price/expiry/supplier combination is a separate entry,
    so the user can pick the specific batch group to return. Only stock from
    completed GRNs of received purchase orders is considered.

    Query params:
        query: required. Either a plain term or "Product - Variation".
        warehouse_id, store_id: required.

    Returns:
        200 with a JSON list of batch groups; 400 when a parameter is
        missing; 500 on a connection or query failure.

    Example Results:
    1. Milk 1L - Supplier ABC - Cost 120 - Price 150 - Exp: 2025-06-15 - Stock: 50
    2. Milk 1L - Supplier ABC - Cost 120 - Price 150 - Exp: 2025-07-20 - Stock: 30
    3. Milk 1L - Supplier XYZ - Cost 125 - Price 155 - Exp: 2025-06-10 - Stock: 25
    """
    query = request.args.get('query', '').strip()
    warehouse_id = request.args.get('warehouse_id', '').strip()
    store_id = request.args.get('store_id', '').strip()

    if not query:
        return jsonify({"error": "Query parameter is required"}), 400

    if not warehouse_id or not store_id:
        return jsonify({"error": "Both warehouse_id and store_id are required"}), 400

    # "Product - Variation" splits on the first separator; otherwise the whole
    # query is used for both the product and the variation match.
    # NOTE(review): without a separator, variable products found by their own
    # name only return variations whose name/type/SKU also contains the query
    # (e.g. "Milk" will not list a "1L" variation) — confirm this is intended.
    if " - " in query:
        head, _, tail = query.partition(" - ")
        main_name = head.strip()
        variation_name = tail.strip()
    else:
        main_name = query
        variation_name = query

    conn = get_db_connection()
    if conn is None:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = None
    result = []

    try:
        # Created inside the try so the connection is still closed in
        # `finally` if creating the cursor itself fails.
        cursor = conn.cursor(dictionary=True)

        # STEP 1: candidate products by name or SKU.
        cursor.execute("""
            SELECT 
                p.id,
                p.product_name,
                p.sku,
                p.product_type,
                p.tax_type,
                p.product_tax,
                p.purchase_unit_id,
                u.unit_short AS purchase_unit
            FROM products p
            LEFT JOIN units u ON p.purchase_unit_id = u.id
            WHERE p.product_name LIKE %s OR p.sku LIKE %s
        """, (f"%{main_name}%", f"%{main_name}%"))
        products = cursor.fetchall()

        main_lower = main_name.lower()
        variation_lower = variation_name.lower()
        variation_like = f"%{variation_name}%"

        for product in products:
            product_id = product['id']
            product_name = product['product_name']
            product_sku = product['sku']
            purchase_unit = product['purchase_unit'] or ''
            tax_type = product['tax_type']
            product_tax = safe_float(product['product_tax'])

            # VARIABLE PRODUCTS: one set of batch groups per matching variation.
            if product['product_type'] == 'variable':
                cursor.execute("""
                    SELECT 
                        pv.id,
                        pv.variation_name,
                        pv.variation_type,
                        pv.variation_sku,
                        pv.variation_tax_type,
                        pv.variation_tax
                    FROM product_variations pv
                    WHERE pv.product_id = %s
                    AND (pv.variation_name LIKE %s 
                         OR pv.variation_type LIKE %s 
                         OR pv.variation_sku LIKE %s)
                """, (product_id, variation_like, variation_like, variation_like))
                # Materialise before the loop: the cursor is reused inside it.
                variations = cursor.fetchall()

                for var in variations:
                    # Variation tax overrides the product tax only when set and non-zero.
                    var_tax = safe_float(var['variation_tax']) if var['variation_tax'] else product_tax
                    groups = _fetch_purchase_return_batch_groups(
                        cursor, product_id, var['id'], warehouse_id, store_id
                    )
                    result.extend(
                        _build_purchase_return_entry(
                            group,
                            label=f"{product_name} - {var['variation_name']}",
                            product_id=product_id,
                            variation_id=var['id'],
                            product_name=product_name,
                            variation_name=var['variation_name'],
                            sku=var['variation_sku'],
                            product_type="variation",
                            tax_type=var['variation_tax_type'] or tax_type,
                            product_tax=var_tax,
                            purchase_unit=purchase_unit,
                        )
                        for group in groups
                    )

            # SINGLE PRODUCTS (no variations).
            else:
                # Literal, case-insensitive re-check of the SQL LIKE match.
                name_lower = (product_name or '').lower()
                if (main_lower in name_lower) or \
                   (variation_lower in name_lower) or \
                   (main_lower in (product_sku or '').lower()):
                    groups = _fetch_purchase_return_batch_groups(
                        cursor, product_id, None, warehouse_id, store_id
                    )
                    result.extend(
                        _build_purchase_return_entry(
                            group,
                            label=f"{product_name} ({product_sku})",
                            product_id=product_id,
                            variation_id=None,
                            product_name=product_name,
                            variation_name=None,
                            sku=product_sku,
                            product_type="single",
                            tax_type=tax_type,
                            product_tax=product_tax,
                            purchase_unit=purchase_unit,
                        )
                        for group in groups
                    )

        return jsonify(result), 200

    except mysql.connector.Error as err:
        print(f"❌ Database Error: {err}")
        traceback.print_exc()
        return jsonify({"error": f"Database error: {str(err)}"}), 500

    except Exception as e:
        print(f"❌ Error: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
            
# @product_bp.route('/variations', methods=['GET'])
# @jwt_required()
# @role_required('admin')
# def view_all_variation():
#     conn = get_db_connection()
#     if conn is None:
#         return jsonify({"error": "Database connection failed"}), 500

#     try:
#         cursor = conn.cursor(dictionary=True)

#         query = """
#             SELECT 
#                 p.product_name AS product_name,
#                 p.sku AS product_sku,
#                 pv.variation_name AS variation_name,
#                 pv.variation_sku AS variation_sku
#             FROM product_variations pv
#             JOIN products p ON pv.product_id = p.id
#             ORDER BY p.product_name, pv.variation_name
#         """
#         cursor.execute(query)
#         results = cursor.fetchall()

#         return jsonify(results), 200

#     except Exception as e:
#         print("Error:", e)
#         traceback.print_exc()
#         return jsonify({"error": "Failed to fetch variations"}), 500
#     finally:
#         cursor.close()
#         conn.close()



