from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity
from db.db import get_db_connection
from config.auth import role_required
from datetime import datetime
from decimal import Decimal
import json
import mysql.connector
import traceback

# Blueprint that groups every stock-transfer route defined in this module.
stock_transfer_bp = Blueprint(name='stock_transfer', import_name=__name__)


# ============================================
# CREATE STOCK TRANSFER
# ============================================
def _parse_transfer_item(idx, item):
    """Validate one entry of the create payload's ``items`` list.

    Args:
        idx: 1-based position of the entry, used only in error messages.
        item: the raw decoded JSON value for the entry.

    Returns:
        ``(product_id, variation_id, batch_id, quantity)`` where ``quantity``
        is a positive, finite ``Decimal`` and ``variation_id`` is ``None`` when
        the entry omits it.

    Raises:
        ValueError: if the entry is not a JSON object, misses a required key,
            or its quantity is not a positive finite number.
    """
    if not isinstance(item, dict):
        raise ValueError(f'Item {idx} must be an object')
    if 'product_id' not in item or 'quantity' not in item or 'batch_id' not in item:
        raise ValueError('Each item must have product_id, batch_id, and quantity')

    try:
        quantity = Decimal(str(item['quantity']))
    except ArithmeticError:
        # decimal.InvalidOperation (raised for input such as "abc") derives
        # from ArithmeticError; without this it would surface as a 500.
        raise ValueError(
            f'Item {idx} has an invalid quantity: {item["quantity"]!r}'
        ) from None

    # "NaN"/"Infinity" parse successfully, but NaN cannot be ordered
    # (comparing it raises InvalidOperation), so reject non-finite values.
    if not quantity.is_finite():
        raise ValueError(f'Item {idx} has an invalid quantity: {item["quantity"]!r}')
    if quantity <= 0:
        raise ValueError('Quantity must be greater than 0')

    return item['product_id'], item.get('variation_id'), item['batch_id'], quantity


@stock_transfer_bp.route('/create', methods=['POST'])
@jwt_required()
@role_required('admin')
def create_stock_transfer():
    """Create a stock transfer and move its batch quantities.

    Expects a JSON object with ``from_store_id``, ``to_store_id``,
    ``transfer_date``, ``status`` (pending | in_transit | completed |
    cancelled) and a non-empty ``items`` list. ``from_warehouse_id``,
    ``to_warehouse_id`` and ``note`` are optional. Each item needs
    ``product_id``, ``batch_id`` and ``quantity``, and may carry
    ``variation_id``.

    All database work runs in one SERIALIZABLE transaction:
      * the transfer header and one ``stock_transfer_items`` row per item are
        inserted;
      * each item's batch is checked against its product/variation;
      * the item's source ``warehouse_stock`` row is locked (FOR UPDATE),
        checked for sufficient quantity and decremented;
      * when ``status == 'completed'`` the quantity is also added to the
        destination ``warehouse_stock`` row (inserted if missing).

    Returns:
        201 with the created transfer and its items; 400 for invalid input or
        insufficient stock; 500 for database or unexpected errors. The
        transaction is rolled back on every error path.
    """
    connection = None
    cursor = None

    try:
        current_user_id = get_jwt_identity()
        # silent=True: a missing or non-JSON body yields None instead of
        # raising an HTTPException, which the generic handler below would
        # otherwise report as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({
                'success': False,
                'message': 'Request body must be a JSON object'
            }), 400

        # Validate required fields
        required_fields = ['from_store_id', 'to_store_id', 'transfer_date', 'status', 'items']
        for field in required_fields:
            if field not in data:
                return jsonify({
                    'success': False,
                    'message': f'Missing required field: {field}'
                }), 400

        if not isinstance(data['items'], list) or len(data['items']) == 0:
            return jsonify({
                'success': False,
                'message': 'Items array must contain at least one item'
            }), 400

        from_store_id = data['from_store_id']
        from_warehouse_id = data.get('from_warehouse_id')
        to_store_id = data['to_store_id']
        to_warehouse_id = data.get('to_warehouse_id')
        transfer_date = data['transfer_date']
        status = data['status']
        note = data.get('note', '')
        items = data['items']

        valid_statuses = ['pending', 'in_transit', 'completed', 'cancelled']
        if status not in valid_statuses:
            return jsonify({
                'success': False,
                'message': f'Invalid status. Must be one of: {", ".join(valid_statuses)}'
            }), 400

        if from_store_id == to_store_id and from_warehouse_id == to_warehouse_id:
            return jsonify({
                'success': False,
                'message': 'Source and destination cannot be the same'
            }), 400

        # Validate every item up front so malformed input never opens (and
        # then has to roll back) a database transaction.
        parsed_items = [
            _parse_transfer_item(idx, item) for idx, item in enumerate(items, 1)
        ]

        connection = get_db_connection()
        cursor = connection.cursor(dictionary=True, buffered=True)

        connection.start_transaction(isolation_level='SERIALIZABLE')

        transfer_code = generate_transfer_code(cursor, transfer_date)

        # Insert Transfer Header
        insert_transfer_query = """
            INSERT INTO stock_transfers (
                transfer_code,
                from_store_id,
                from_warehouse_id,
                to_store_id,
                to_warehouse_id,
                transfer_date,
                note,
                total_items,
                status,
                created_by
            ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """

        cursor.execute(insert_transfer_query, (
            transfer_code,
            from_store_id,
            from_warehouse_id,
            to_store_id,
            to_warehouse_id,
            transfer_date,
            note,
            len(items),
            status,
            current_user_id
        ))

        transfer_id = cursor.lastrowid

        print(f"\n{'='*60}")
        print(f"🔄 CREATING STOCK TRANSFER - ID: {transfer_id}")
        print(f"Code: {transfer_code}")
        print(f"{'='*60}")

        # Process Each Transfer Item
        for idx, (product_id, variation_id, batch_id, quantity) in enumerate(parsed_items, 1):
            print(f"\n--- Item {idx} ---")
            print(f"Product ID: {product_id}")
            print(f"Variation ID: {variation_id} ({'SINGLE' if variation_id is None else 'VARIATION'})")
            print(f"Batch ID: {batch_id}")
            print(f"Quantity: {quantity}")

            # Verify Batch exists and matches product/variation
            verify_batch_query = """
                SELECT 
                    pb.batch_id,
                    pb.batch_number,
                    pb.product_id,
                    pb.variation_id,
                    pb.remaining_quantity,
                    pb.cost,
                    pb.price,
                    pb.expiration_date,
                    pb.grn_id,
                    g.grn_code,
                    s.supplier_name
                FROM product_batches pb
                LEFT JOIN grn g ON pb.grn_id = g.grn_id
                LEFT JOIN suppliers s ON g.supplier_id = s.id
                WHERE pb.batch_id = %s
                  AND pb.product_id = %s
            """

            cursor.execute(verify_batch_query, (batch_id, product_id))
            batch_info = cursor.fetchone()
            cursor.fetchall()

            if not batch_info:
                raise ValueError(f'Batch {batch_id} not found for product {product_id}')

            # Verify variation_id matches
            if variation_id is None:
                if batch_info['variation_id'] is not None:
                    raise ValueError(
                        f'Batch {batch_id} belongs to a variation product, but no variation_id provided'
                    )
            elif batch_info['variation_id'] != variation_id:
                raise ValueError(
                    f'Batch {batch_id} variation mismatch. Expected: {variation_id}, Got: {batch_info["variation_id"]}'
                )

            print(f"✓ Batch verified: {batch_info['batch_number']}")
            print(f"  GRN: {batch_info.get('grn_code') or 'N/A'}")
            print(f"  Supplier: {batch_info.get('supplier_name') or 'N/A'}")

            # Lock the source row. `<=>` is MySQL's NULL-safe equality: a
            # plain `=` never matches when the optional from_warehouse_id
            # (or variation_id) is NULL.
            check_stock_query = """
                SELECT 
                    ws.id as stock_id,
                    ws.quantity, 
                    p.product_name, 
                    p.product_type,
                    pv.variation_name,
                    pv.variation_type
                FROM warehouse_stock ws
                INNER JOIN products p ON ws.product_id = p.id
                LEFT JOIN product_variations pv ON ws.variation_id = pv.id
                WHERE ws.store_id = %s 
                  AND ws.warehouse_id <=> %s
                  AND ws.product_id = %s 
                  AND ws.batch_id = %s
                  AND ws.variation_id <=> %s
                FOR UPDATE
            """

            cursor.execute(check_stock_query, (
                from_store_id,
                from_warehouse_id,
                product_id,
                batch_id,
                variation_id
            ))

            source_stock = cursor.fetchone()
            cursor.fetchall()

            if not source_stock:
                raise ValueError(
                    f'Product {product_id}, Batch {batch_id} not found in source warehouse '
                    f'(Store: {from_store_id}, Warehouse: {from_warehouse_id})'
                )

            print(f"✓ Source warehouse stock: {source_stock['quantity']} units (LOCKED)")

            # Verify sufficient warehouse-specific stock
            if source_stock['quantity'] < quantity:
                product_name = source_stock['product_name']
                if source_stock.get('variation_name'):
                    product_name += f' ({source_stock["variation_name"]} - {source_stock["variation_type"]})'
                raise ValueError(
                    f'Insufficient stock in source warehouse for {product_name}. '
                    f'Warehouse has: {source_stock["quantity"]}, Requested: {quantity}'
                )

            # Insert Transfer Item
            insert_item_query = """
                INSERT INTO stock_transfer_items (
                    transfer_id,
                    product_id,
                    variation_id,
                    batch_id,
                    quantity,
                    received_quantity
                ) VALUES (%s, %s, %s, %s, %s, %s)
            """

            received_quantity = quantity if status == 'completed' else 0

            cursor.execute(insert_item_query, (
                transfer_id,
                product_id,
                variation_id,
                batch_id,
                quantity,
                received_quantity
            ))

            print("✓ Transfer item record created")

            # Decrement the locked source row; the quantity guard is a second
            # line of defence against going negative.
            update_source_query = """
                UPDATE warehouse_stock 
                SET quantity = quantity - %s
                WHERE id = %s
                  AND quantity >= %s
            """

            cursor.execute(update_source_query, (
                quantity,
                source_stock['stock_id'],
                quantity
            ))

            if cursor.rowcount == 0:
                raise ValueError(
                    'Failed to update source stock. Stock may have changed during transaction.'
                )

            print(f"✓ Source warehouse stock reduced by {quantity}")

            if status == 'completed':
                print("➤ Adding to destination warehouse...")

                # Locked lookup + UPDATE/INSERT instead of ON DUPLICATE KEY
                # UPDATE: MySQL unique indexes allow repeated NULLs, so an
                # upsert would create a duplicate row (instead of adding to
                # the existing one) whenever variation_id or warehouse_id
                # is NULL.
                cursor.execute("""
                    SELECT id
                    FROM warehouse_stock
                    WHERE store_id = %s
                      AND warehouse_id <=> %s
                      AND product_id = %s
                      AND batch_id = %s
                      AND variation_id <=> %s
                    FOR UPDATE
                """, (
                    to_store_id,
                    to_warehouse_id,
                    product_id,
                    batch_id,
                    variation_id
                ))

                dest_stock = cursor.fetchone()
                cursor.fetchall()

                if dest_stock:
                    cursor.execute("""
                        UPDATE warehouse_stock
                        SET quantity = quantity + %s
                        WHERE id = %s
                    """, (quantity, dest_stock['id']))
                else:
                    cursor.execute("""
                        INSERT INTO warehouse_stock (
                            store_id,
                            warehouse_id,
                            product_id,
                            variation_id,
                            batch_id,
                            quantity
                        ) VALUES (%s, %s, %s, %s, %s, %s)
                    """, (
                        to_store_id,
                        to_warehouse_id,
                        product_id,
                        variation_id,
                        batch_id,
                        quantity
                    ))

                if cursor.rowcount == 0:
                    raise ValueError(
                        f'Failed to update destination stock for product {product_id}, batch {batch_id}'
                    )

                if dest_stock:
                    print(f"  ✓ EXISTING destination stock updated by adding: {quantity}")
                else:
                    print(f"  ✓ NEW destination stock created with quantity: {quantity}")

        # Commit Transaction
        connection.commit()
        print(f"\n{'='*60}")
        print("✅ TRANSFER COMMITTED SUCCESSFULLY!")
        print(f"{'='*60}\n")

        # Fetch Created Transfer Details
        cursor.execute("""
            SELECT 
                st.*,
                fs.store_name as from_store_name,
                ts.store_name as to_store_name,
                fw.warehouse_name as from_warehouse_name,
                tw.warehouse_name as to_warehouse_name,
                u.name as created_by_name
            FROM stock_transfers st
            LEFT JOIN stores fs ON st.from_store_id = fs.id
            LEFT JOIN stores ts ON st.to_store_id = ts.id
            LEFT JOIN warehouses fw ON st.from_warehouse_id = fw.id
            LEFT JOIN warehouses tw ON st.to_warehouse_id = tw.id
            LEFT JOIN users u ON st.created_by = u.id
            WHERE st.transfer_id = %s
        """, (transfer_id,))

        transfer_details = cursor.fetchone()
        cursor.fetchall()

        # Fetch transfer items
        cursor.execute("""
            SELECT 
                sti.*,
                p.product_name,
                p.sku,
                p.product_type,
                pv.variation_name,
                pv.variation_type,
                pb.batch_number,
                pb.cost,
                pb.price,
                pb.expiration_date,
                g.grn_code,
                s.supplier_name
            FROM stock_transfer_items sti
            LEFT JOIN products p ON sti.product_id = p.id
            LEFT JOIN product_variations pv ON sti.variation_id = pv.id
            LEFT JOIN product_batches pb ON sti.batch_id = pb.batch_id
            LEFT JOIN grn g ON pb.grn_id = g.grn_id
            LEFT JOIN suppliers s ON g.supplier_id = s.id
            WHERE sti.transfer_id = %s
        """, (transfer_id,))

        transfer_items = cursor.fetchall()

        # Convert datetime objects to ISO-8601 strings for the JSON response.
        for row in ([transfer_details] if transfer_details else []) + list(transfer_items):
            for key, value in row.items():
                if isinstance(value, datetime):
                    row[key] = value.isoformat()

        return jsonify({
            'success': True,
            'message': 'Stock transfer created successfully',
            'data': {
                'transfer': transfer_details,
                'items': transfer_items
            }
        }), 201

    except ValueError as ve:
        if connection:
            connection.rollback()
        print(f"\n❌ VALIDATION ERROR: {ve}\n")
        return jsonify({
            'success': False,
            'message': str(ve)
        }), 400

    except mysql.connector.Error as err:
        if connection:
            connection.rollback()
        print(f"❌ Database Error: {err}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'message': f'Database error: {str(err)}'
        }), 500

    except Exception as e:
        if connection:
            connection.rollback()
        print(f"❌ Unexpected Error: {e}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'message': f'An unexpected error occurred: {str(e)}'
        }), 500

    finally:
        if cursor:
            cursor.close()
        if connection:
            connection.close()


# ============================================
# GENERATE TRANSFER CODE
# ============================================
def generate_transfer_code(cursor, transfer_date):
    """Return the next transfer code for *transfer_date*: ``TR-YYYYMMDD-NNNN``.

    The sequence number is one more than the highest existing code for that
    day. It is zero-padded to at least four digits, and the first code of a
    day uses ``0001``.

    Fallback: a timestamp code ``TR-YYYYMMDDHHMMSS`` (current local time) is
    returned in either of these cases:

    * *transfer_date* is not a ``YYYY-MM-DD`` string;
    * the latest stored code for the day has a non-numeric suffix.

    Args:
        cursor: an open cursor whose rows are accessed by column name
            (a ``dictionary=True`` cursor); it is left ready for reuse.
        transfer_date: the transfer date as a ``YYYY-MM-DD`` string.

    Raises:
        Any database error from ``cursor.execute``. It is not swallowed, so
        the caller that owns the transaction can roll it back.
    """
    def _timestamp_code():
        return 'TR-' + datetime.now().strftime('%Y%m%d%H%M%S')

    try:
        date_str = datetime.strptime(transfer_date, '%Y-%m-%d').strftime('%Y%m%d')
    except (TypeError, ValueError):
        return _timestamp_code()

    # Order by length first: a plain string sort ranks '...-9999' above
    # '...-10000', which would make sequence 10000 be generated twice.
    cursor.execute("""
        SELECT transfer_code
        FROM stock_transfers
        WHERE transfer_code LIKE %s
        ORDER BY CHAR_LENGTH(transfer_code) DESC, transfer_code DESC
        LIMIT 1
    """, (f'TR-{date_str}-%',))
    result = cursor.fetchone()
    cursor.fetchall()  # drain any remaining buffered rows before the cursor is reused

    if not result:
        return f'TR-{date_str}-0001'

    try:
        last_seq = int(result['transfer_code'].rsplit('-', 1)[-1])
    except (AttributeError, KeyError, TypeError, ValueError):
        return _timestamp_code()

    return f'TR-{date_str}-{last_seq + 1:04d}'


# ============================================
# GET ALL STOCK TRANSFERS
# ============================================
@stock_transfer_bp.route('/list', methods=['GET'])
@jwt_required()
def get_all_transfers():
    """Return every stock transfer, newest ``created_at`` first.

    Optional query-string filters are applied only when present and
    non-empty:

    * ``status``
    * ``from_store_id``
    * ``to_store_id``
    * ``from_date`` (``transfer_date >=``)
    * ``to_date`` (``transfer_date <=``)

    Each row holds the ``stock_transfers`` columns plus the joined store names,
    warehouse names and creator name. ``datetime`` values are rendered as
    ISO-8601 strings.

    Returns:
        ``(response, 200)`` with ``data`` and ``count``, or ``(response, 500)``
        carrying the exception text when anything fails.
    """
    # (query-string key, SQL fragment), in the order fragments are appended;
    # the order also fixes the order of the bound parameters.
    optional_filters = (
        ('status', " AND st.status = %s"),
        ('from_store_id', " AND st.from_store_id = %s"),
        ('to_store_id', " AND st.to_store_id = %s"),
        ('from_date', " AND st.transfer_date >= %s"),
        ('to_date', " AND st.transfer_date <= %s"),
    )

    conn = None
    cur = None

    try:
        supplied = [
            (fragment, request.args.get(arg_name))
            for arg_name, fragment in optional_filters
        ]

        conn = get_db_connection()
        cur = conn.cursor(dictionary=True, buffered=True)

        sql = """
            SELECT 
                st.*,
                fs.store_name as from_store_name,
                ts.store_name as to_store_name,
                fw.warehouse_name as from_warehouse_name,
                tw.warehouse_name as to_warehouse_name,
                u.name as created_by_name
            FROM stock_transfers st
            LEFT JOIN stores fs ON st.from_store_id = fs.id
            LEFT JOIN stores ts ON st.to_store_id = ts.id
            LEFT JOIN warehouses fw ON st.from_warehouse_id = fw.id
            LEFT JOIN warehouses tw ON st.to_warehouse_id = tw.id
            LEFT JOIN users u ON st.created_by = u.id
            WHERE 1=1
        """

        bind_values = []
        for fragment, arg_value in supplied:
            if not arg_value:
                continue
            sql += fragment
            bind_values.append(arg_value)

        sql += " ORDER BY st.created_at DESC"

        cur.execute(sql, bind_values)
        rows = [
            {
                column: cell.isoformat() if isinstance(cell, datetime) else cell
                for column, cell in record.items()
            }
            for record in cur.fetchall()
        ]

        return jsonify({
            'success': True,
            'data': rows,
            'count': len(rows)
        }), 200

    except Exception as exc:
        print(f"Error: {exc}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'message': str(exc)
        }), 500

    finally:
        if cur:
            cur.close()
        if conn:
            conn.close()


# ============================================
# GET SINGLE STOCK TRANSFER
# ============================================
@stock_transfer_bp.route('/<int:transfer_id>', methods=['GET'])
@jwt_required()
def get_transfer(transfer_id):
    """Return one transfer header together with its line items.

    The header is joined with the store names, warehouse names, creator name
    and receiver name. Each item is joined with its product, variation, batch,
    GRN code and supplier name. ``datetime`` values are rendered as ISO-8601
    strings.

    Returns:
        ``(response, 200)`` with ``data.transfer`` and ``data.items``,
        ``(response, 404)`` when no transfer has this id, or
        ``(response, 500)`` carrying the exception text on failure.
    """
    def iso_row(record):
        # Copy of *record* with datetime values replaced by ISO strings.
        return {
            column: cell.isoformat() if isinstance(cell, datetime) else cell
            for column, cell in record.items()
        }

    conn = None
    cur = None

    try:
        conn = get_db_connection()
        cur = conn.cursor(dictionary=True, buffered=True)

        header_sql = """
            SELECT 
                st.*,
                fs.store_name as from_store_name,
                ts.store_name as to_store_name,
                fw.warehouse_name as from_warehouse_name,
                tw.warehouse_name as to_warehouse_name,
                u.name as created_by_name,
                ru.name as received_by_name
            FROM stock_transfers st
            LEFT JOIN stores fs ON st.from_store_id = fs.id
            LEFT JOIN stores ts ON st.to_store_id = ts.id
            LEFT JOIN warehouses fw ON st.from_warehouse_id = fw.id
            LEFT JOIN warehouses tw ON st.to_warehouse_id = tw.id
            LEFT JOIN users u ON st.created_by = u.id
            LEFT JOIN users ru ON st.received_by = ru.id
            WHERE st.transfer_id = %s
        """
        cur.execute(header_sql, (transfer_id,))
        header = cur.fetchone()
        cur.fetchall()

        if not header:
            return jsonify({
                'success': False,
                'message': 'Transfer not found'
            }), 404

        items_sql = """
            SELECT 
                sti.*,
                p.product_name,
                p.sku,
                p.product_type,
                pv.variation_name,
                pv.variation_type,
                pb.batch_number,
                pb.cost,
                pb.price,
                pb.expiration_date,
                g.grn_code,
                s.supplier_name
            FROM stock_transfer_items sti
            LEFT JOIN products p ON sti.product_id = p.id
            LEFT JOIN product_variations pv ON sti.variation_id = pv.id
            LEFT JOIN product_batches pb ON sti.batch_id = pb.batch_id
            LEFT JOIN grn g ON pb.grn_id = g.grn_id
            LEFT JOIN suppliers s ON g.supplier_id = s.id
            WHERE sti.transfer_id = %s
        """
        cur.execute(items_sql, (transfer_id,))
        line_items = [iso_row(record) for record in cur.fetchall()]

        return jsonify({
            'success': True,
            'data': {
                'transfer': iso_row(header),
                'items': line_items
            }
        }), 200

    except Exception as exc:
        print(f"Error: {exc}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'message': str(exc)
        }), 500

    finally:
        if cur:
            cur.close()
        if conn:
            conn.close()


# ============================================
# UPDATE TRANSFER STATUS
# ============================================
@stock_transfer_bp.route('/<int:transfer_id>/status', methods=['PATCH'])
@jwt_required()
@role_required('admin')
def update_transfer_status(transfer_id):
    """Change the status of an existing transfer.

    Expects a JSON body ``{"status": <pending|in_transit|completed|cancelled>}``.

    Moving a ``pending``/``in_transit`` transfer to ``completed`` receives
    every item. Its quantity is added to the destination ``warehouse_stock``
    row (inserted if missing) and recorded as ``received_quantity``. Other
    transitions only rewrite the status. ``received_by`` is set to the caller
    when the new status is ``completed`` and cleared otherwise.

    A completed transfer cannot be moved to another status. Doing so would
    leave the received stock in the destination, and completing it again
    would add the same quantities a second time.

    Returns:
        200 on success; 400 for a missing/invalid status or a forbidden
        transition; 404 if the transfer does not exist; 500 on other errors.
        The transaction is rolled back on every error path.
    """
    connection = None
    cursor = None

    try:
        current_user_id = get_jwt_identity()
        # silent=True: a missing or non-JSON body yields None instead of
        # raising, so it is reported as "Status is required" (400) below.
        data = request.get_json(silent=True)
        new_status = data.get('status') if isinstance(data, dict) else None

        if not new_status:
            return jsonify({
                'success': False,
                'message': 'Status is required'
            }), 400

        valid_statuses = ['pending', 'in_transit', 'completed', 'cancelled']
        if new_status not in valid_statuses:
            return jsonify({
                'success': False,
                'message': f'Invalid status. Must be one of: {", ".join(valid_statuses)}'
            }), 400

        connection = get_db_connection()
        cursor = connection.cursor(dictionary=True, buffered=True)

        connection.start_transaction()

        # Lock the header row so two concurrent requests cannot both observe
        # 'pending' and both add the items to the destination stock.
        cursor.execute("""
            SELECT * FROM stock_transfers WHERE transfer_id = %s FOR UPDATE
        """, (transfer_id,))

        transfer = cursor.fetchone()
        cursor.fetchall()

        if not transfer:
            connection.rollback()
            return jsonify({
                'success': False,
                'message': 'Transfer not found'
            }), 404

        old_status = transfer['status']

        if old_status == 'completed' and new_status != 'completed':
            raise ValueError(
                'Cannot change the status of a completed transfer; '
                'delete it to reverse its stock movements'
            )

        # NOTE(review): nothing visible here restores source stock when a
        # transfer becomes 'cancelled', and cancelled -> completed does not
        # receive the items (only pending/in_transit do). Confirm the
        # intended semantics.
        if old_status in ['pending', 'in_transit'] and new_status == 'completed':
            cursor.execute("""
                SELECT * FROM stock_transfer_items WHERE transfer_id = %s
            """, (transfer_id,))

            items = cursor.fetchall()

            # `<=>` is MySQL's NULL-safe equality: a plain `=` never matches a
            # NULL warehouse_id, which would insert a duplicate row instead
            # of adding to the existing one. FOR UPDATE keeps a concurrent
            # writer from changing or inserting the row between the check
            # and the write.
            check_dest_query = """
                SELECT id, quantity 
                FROM warehouse_stock 
                WHERE store_id = %s 
                  AND warehouse_id <=> %s
                  AND product_id = %s 
                  AND batch_id = %s
                  AND variation_id <=> %s
                FOR UPDATE
            """

            update_dest_query = """
                UPDATE warehouse_stock 
                SET quantity = quantity + %s
                WHERE id = %s
            """

            insert_dest_query = """
                INSERT INTO warehouse_stock (
                    store_id,
                    warehouse_id,
                    product_id,
                    variation_id,
                    batch_id,
                    quantity
                ) VALUES (%s, %s, %s, %s, %s, %s)
            """

            for item in items:
                variation_id = item['variation_id']

                cursor.execute(check_dest_query, (
                    transfer['to_store_id'],
                    transfer['to_warehouse_id'],
                    item['product_id'],
                    item['batch_id'],
                    variation_id
                ))

                existing_dest = cursor.fetchone()
                cursor.fetchall()

                if existing_dest:
                    cursor.execute(update_dest_query, (item['quantity'], existing_dest['id']))

                    if cursor.rowcount == 0:
                        raise ValueError(f'Failed to update destination stock for product {item["product_id"]}')
                else:
                    cursor.execute(insert_dest_query, (
                        transfer['to_store_id'],
                        transfer['to_warehouse_id'],
                        item['product_id'],
                        variation_id,
                        item['batch_id'],
                        item['quantity']
                    ))

                    if cursor.rowcount == 0:
                        raise ValueError(f'Failed to insert destination stock for product {item["product_id"]}')

                cursor.execute("""
                    UPDATE stock_transfer_items 
                    SET received_quantity = %s
                    WHERE item_id = %s
                """, (item['quantity'], item['item_id']))

        update_query = """
            UPDATE stock_transfers 
            SET status = %s,
                received_by = %s
            WHERE transfer_id = %s
        """

        cursor.execute(update_query, (
            new_status,
            current_user_id if new_status == 'completed' else None,
            transfer_id
        ))

        connection.commit()

        return jsonify({
            'success': True,
            'message': f'Transfer status updated to {new_status}'
        }), 200

    except ValueError as ve:
        if connection:
            connection.rollback()
        return jsonify({
            'success': False,
            'message': str(ve)
        }), 400

    except Exception as e:
        if connection:
            connection.rollback()
        print(f"Error: {e}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'message': str(e)
        }), 500

    finally:
        if cursor:
            cursor.close()
        if connection:
            connection.close()


# ============================================
# ✅ FIXED DELETE TRANSFER
# ============================================
@stock_transfer_bp.route('/<int:transfer_id>', methods=['DELETE'])
@jwt_required()
@role_required('admin')
def delete_transfer(transfer_id):
    """Delete a stock transfer and reverse its stock movements.

    Everything runs in one SERIALIZABLE transaction. The transfer row is
    locked with ``SELECT ... FOR UPDATE`` before any stock is touched.

    Reversal depends on the transfer's status:

    * ``completed``: each item's quantity is added back to the source
      warehouse and removed from the destination warehouse. The destination
      row is deleted when it would drop to zero or below.
    * ``pending`` / ``in_transit``: each item's quantity is added back to
      the source warehouse only.
    * ``cancelled`` (or any other status): no stock changes.

    The transfer's items and the transfer header are then deleted.

    Args:
        transfer_id: ``stock_transfers.transfer_id`` taken from the URL.

    Returns:
        A ``(json, status)`` tuple with one of these statuses:

        * 200 on success.
        * 404 if the transfer does not exist.
        * 400 on a ``ValueError`` (validation or stock-update failure).
        * 500 on database or unexpected errors.

        Every non-200 path rolls the transaction back.
    """
    connection = None
    cursor = None

    # Adds a quantity to a warehouse_stock row, creating the row if needed.
    # NOTE(review): if the table's unique key includes variation_id, MySQL treats
    # NULLs as distinct, so for non-variation products this would insert a new
    # row instead of incrementing the existing one. Confirm the unique key.
    upsert_source_query = """
        INSERT INTO warehouse_stock (
            store_id,
            warehouse_id,
            product_id,
            variation_id,
            batch_id,
            quantity
        ) VALUES (%s, %s, %s, %s, %s, %s)
        ON DUPLICATE KEY UPDATE
            quantity = quantity + VALUES(quantity)
    """

    def restore_to_source(cur, transfer, product_id, variation_id, batch_id, quantity):
        """Add ``quantity`` back to the transfer's source warehouse.

        Returns True if a new warehouse_stock row was inserted (affected rows
        == 1) and False if an existing row was updated (affected rows == 2).

        Raises:
            ValueError: any other affected-row count is treated as a failure.
        """
        cur.execute(upsert_source_query, (
            transfer['from_store_id'],
            transfer['from_warehouse_id'],
            product_id,
            variation_id,
            batch_id,
            quantity
        ))
        if cur.rowcount not in (1, 2):
            raise ValueError('Failed to restore stock to source warehouse')
        return cur.rowcount == 1

    try:
        connection = get_db_connection()
        cursor = connection.cursor(dictionary=True, buffered=True)

        connection.start_transaction(isolation_level='SERIALIZABLE')

        # Lock the transfer row for the rest of the transaction
        cursor.execute("""
            SELECT * FROM stock_transfers WHERE transfer_id = %s FOR UPDATE
        """, (transfer_id,))

        transfer = cursor.fetchone()
        cursor.fetchall()  # drain the buffered result set

        if not transfer:
            # End the transaction explicitly so its locks are released now.
            # Otherwise we depend on how connection.close() treats an open
            # transaction (e.g. a pooled connection without session reset).
            connection.rollback()
            return jsonify({
                'success': False,
                'message': 'Transfer not found'
            }), 404

        print(f"\n{'='*60}")
        print(f"🗑️ DELETING TRANSFER - ID: {transfer_id}")
        print(f"Status: {transfer['status']}")
        print(f"Code: {transfer.get('transfer_code', 'N/A')}")
        print(f"{'='*60}")

        # Get all transfer items
        cursor.execute("""
            SELECT * FROM stock_transfer_items WHERE transfer_id = %s
        """, (transfer_id,))

        items = cursor.fetchall()

        print(f"Found {len(items)} items to reverse")

        # Process based on status
        if transfer['status'] == 'completed':
            print("\n⚠️ Transfer is COMPLETED - Reversing stock movements...")

            for idx, item in enumerate(items, 1):
                product_id = item['product_id']
                variation_id = item['variation_id']
                batch_id = item['batch_id']
                quantity = Decimal(str(item['quantity']))

                print(f"\n--- Reversing Item {idx} ---")
                print(f"Product: {product_id}, Variation: {variation_id}, Batch: {batch_id}, Qty: {quantity}")

                # 1) Credit the source warehouse back
                if restore_to_source(cursor, transfer, product_id, variation_id, batch_id, quantity):
                    print(f"  ✓ Source stock record CREATED with {quantity} units")
                else:
                    print(f"  ✓ Source stock UPDATED - added {quantity} units")

                # 2) Debit the destination warehouse (row locked while we compute)
                select_dest_query = """
                    SELECT id, quantity 
                    FROM warehouse_stock 
                    WHERE store_id = %s 
                      AND warehouse_id = %s
                      AND product_id = %s 
                      AND batch_id = %s
                      AND ((variation_id IS NULL AND %s IS NULL) OR (variation_id = %s))
                    FOR UPDATE
                """

                cursor.execute(select_dest_query, (
                    transfer['to_store_id'],
                    transfer['to_warehouse_id'],
                    product_id,
                    batch_id,
                    variation_id,
                    variation_id
                ))

                dest_stock = cursor.fetchone()
                cursor.fetchall()

                # NOTE(review): the two best-effort paths below credit the source
                # in full even when the destination no longer holds the full
                # quantity. The reversal can therefore increase total stock.
                # Confirm this is the intended policy.
                if not dest_stock:
                    print(f"  ⚠️ WARNING: Destination stock record not found")
                    print(f"     This may indicate the stock was already moved or deleted")
                    print(f"     Continuing deletion without reducing destination stock")
                    continue

                current_dest_qty = Decimal(str(dest_stock['quantity']))
                print(f"  📊 Destination has: {current_dest_qty} units")

                if current_dest_qty < quantity:
                    print(f"  ⚠️ WARNING: Destination has less stock than transferred amount")
                    print(f"     Requested to reduce: {quantity}, Available: {current_dest_qty}")
                    print(f"     Reducing available amount only")
                    quantity_to_reduce = current_dest_qty
                else:
                    quantity_to_reduce = quantity

                new_dest_qty = current_dest_qty - quantity_to_reduce

                if new_dest_qty <= 0:
                    # Drop the row rather than keep a zero/negative balance
                    cursor.execute("""
                        DELETE FROM warehouse_stock 
                        WHERE id = %s
                    """, (dest_stock['id'],))
                    print(f"  ✓ Destination stock record DELETED (quantity was {quantity_to_reduce})")
                else:
                    cursor.execute("""
                        UPDATE warehouse_stock 
                        SET quantity = %s
                        WHERE id = %s
                    """, (new_dest_qty, dest_stock['id']))
                    print(f"  ✓ Destination stock REDUCED by {quantity_to_reduce} (remaining: {new_dest_qty})")

                if cursor.rowcount == 0:
                    raise ValueError(f'Failed to update destination stock for product {product_id}')

        elif transfer['status'] in ['pending', 'in_transit']:
            print(f"\n⚠️ Transfer is {transfer['status'].upper()} - Returning stock to source...")

            # Only the source was debited for these statuses, so only credit it back
            for idx, item in enumerate(items, 1):
                product_id = item['product_id']
                variation_id = item['variation_id']
                batch_id = item['batch_id']
                quantity = Decimal(str(item['quantity']))

                print(f"\n--- Returning Item {idx} to Source ---")
                print(f"Product: {product_id}, Variation: {variation_id}, Batch: {batch_id}, Qty: {quantity}")

                if restore_to_source(cursor, transfer, product_id, variation_id, batch_id, quantity):
                    print(f"  ✓ Stock record CREATED with {quantity} units")
                else:
                    print(f"  ✓ Stock UPDATED - added {quantity} units")

        elif transfer['status'] == 'cancelled':
            print("\n⚠️ Transfer is CANCELLED - Stock should already be correct")
            print("   No stock movements to reverse")

        # Delete transfer items
        cursor.execute("""
            DELETE FROM stock_transfer_items WHERE transfer_id = %s
        """, (transfer_id,))
        deleted_items = cursor.rowcount
        print(f"\n✓ Deleted {deleted_items} transfer items")

        # Delete transfer record
        cursor.execute("""
            DELETE FROM stock_transfers WHERE transfer_id = %s
        """, (transfer_id,))
        print(f"✓ Deleted transfer record")

        connection.commit()

        print(f"\n{'='*60}")
        print(f"✅ TRANSFER DELETED SUCCESSFULLY!")
        print(f"{'='*60}\n")

        return jsonify({
            'success': True,
            'message': 'Transfer deleted successfully and stock movements reversed'
        }), 200

    except ValueError as ve:
        if connection:
            connection.rollback()
        print(f"\n❌ VALIDATION ERROR: {ve}\n")
        return jsonify({
            'success': False,
            'message': str(ve)
        }), 400

    except mysql.connector.Error as err:
        if connection:
            connection.rollback()
        print(f"❌ Database Error: {err}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'message': f'Database error: {str(err)}'
        }), 500

    except Exception as e:
        if connection:
            connection.rollback()
        print(f"❌ Unexpected Error: {e}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'message': f'An unexpected error occurred: {str(e)}'
        }), 500

    finally:
        if cursor:
            cursor.close()
        if connection:
            connection.close()


# ============================================
# GET BATCHES BY WAREHOUSE - WITH GRN INFO
# ============================================
@stock_transfer_bp.route('/batches/warehouse', methods=['GET'])
@jwt_required()
def get_batches_by_warehouse():
    """
    Get batches for a specific product in a specific warehouse.
    Includes GRN and supplier information.

    Query params:
    - product_id (required)
    - store_id (required)
    - warehouse_id (required)
    - variation_id (optional): only for variation products. An empty value
      or the literals "null" / "undefined" / "none" mean "no variation".

    Returns only batches whose warehouse_stock quantity in the given
    store/warehouse is > 0. They are ordered by expiration_date, then
    created_on. Each batch's ``remaining_quantity`` is overwritten with the
    warehouse-level quantity (not the batch-wide remaining quantity).

    Responses: 200 with ``batches`` and ``count``; 400 when a required
    parameter is missing; 500 on database or unexpected errors.
    """
    connection = None
    cursor = None

    try:
        product_id = request.args.get('product_id')
        store_id = request.args.get('store_id')
        warehouse_id = request.args.get('warehouse_id')
        variation_id = request.args.get('variation_id')

        # Clients often send "variation_id=", "=null" or "=undefined" for
        # single (non-variation) products. Passed through as strings, those
        # never satisfy the "%s IS NULL" branch of the query, so the request
        # would silently return no batches. Treat them as "no variation".
        if variation_id is not None and variation_id.strip().lower() in ('', 'null', 'undefined', 'none'):
            variation_id = None

        if not product_id or not store_id or not warehouse_id:
            return jsonify({
                'success': False,
                'message': 'Missing required parameters: product_id, store_id, warehouse_id'
            }), 400

        connection = get_db_connection()
        cursor = connection.cursor(dictionary=True, buffered=True)

        query = """
            SELECT 
                pb.batch_id,
                pb.batch_number,
                pb.cost,
                pb.price,
                pb.expiration_date,
                pb.remaining_quantity as batch_remaining_quantity,
                ws.quantity as warehouse_stock_quantity,
                ws.warehouse_id,
                w.warehouse_name,
                p.product_name,
                p.sku,
                p.product_type,
                pv.variation_name,
                pv.variation_type,
                pv.variation_sku,
                g.grn_code,
                g.grn_id,
                s.supplier_name,
                s.supplier_code,
                pb.created_on
            FROM warehouse_stock ws
            INNER JOIN product_batches pb ON ws.batch_id = pb.batch_id
            INNER JOIN warehouses w ON ws.warehouse_id = w.id
            INNER JOIN products p ON ws.product_id = p.id
            LEFT JOIN product_variations pv ON ws.variation_id = pv.id
            LEFT JOIN grn g ON pb.grn_id = g.grn_id
            LEFT JOIN suppliers s ON g.supplier_id = s.id
            WHERE ws.product_id = %s
              AND ws.warehouse_id = %s
              AND ws.store_id = %s
              AND ws.quantity > 0
              AND (
                  (ws.variation_id IS NULL AND %s IS NULL) OR
                  (ws.variation_id = %s)
              )
            ORDER BY pb.expiration_date ASC, pb.created_on ASC
        """

        cursor.execute(query, (
            product_id,
            warehouse_id,
            store_id,
            variation_id,
            variation_id
        ))

        batches = cursor.fetchall()

        print(f"\n{'='*60}")
        print(f"📦 BATCH QUERY RESULTS (WITH GRN)")
        print(f"{'='*60}")
        print(f"Product ID: {product_id}")
        print(f"Variation ID: {variation_id} ({'SINGLE' if variation_id is None else 'VARIATION'})")
        print(f"Warehouse ID: {warehouse_id}")
        print(f"Store ID: {store_id}")
        print(f"Found: {len(batches)} batch(es)")

        for batch in batches:
            # GRN/supplier come from LEFT JOINs: the keys are always present but
            # may be None, so fall back with `or` rather than a .get() default.
            print(f"  - Batch {batch['batch_number']}: {batch['warehouse_stock_quantity']} units")
            print(f"    GRN: {batch.get('grn_code') or 'N/A'}, Supplier: {batch.get('supplier_name') or 'N/A'}")

            # Expose the warehouse-level quantity as remaining_quantity
            batch['remaining_quantity'] = batch['warehouse_stock_quantity']

            # ISO-format datetime values for the JSON response.
            # NOTE(review): plain `date` values (e.g. if expiration_date is a
            # DATE column) are not datetime instances and are left to
            # jsonify's own date handling. Confirm the expected format.
            batch.update({
                key: value.isoformat()
                for key, value in batch.items()
                if isinstance(value, datetime)
            })
        print(f"{'='*60}\n")

        return jsonify({
            'success': True,
            'batches': batches,
            'count': len(batches)
        }), 200

    except mysql.connector.Error as err:
        print(f"Database Error: {err}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'message': f'Database error: {str(err)}'
        }), 500

    except Exception as e:
        print(f"Error: {e}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'message': str(e)
        }), 500

    finally:
        if cursor:
            cursor.close()
        if connection:
            connection.close()
            
            
            
            
# ============================================
# UPDATE STOCK TRANSFER
# ============================================
@stock_transfer_bp.route('/<int:transfer_id>/update', methods=['PUT'])
@jwt_required()
@role_required('admin')
def update_stock_transfer(transfer_id):
    """
    ✅ FULLY CORRECTED: Properly handles same warehouse edits
    
    When editing transfer with SAME warehouses:
    - Old: 10 units transferred (stock: 140)
    - New: 20 units transferred
    - Should only take additional 10 units (not 20 again!)
    - Final stock: 130 ✅
    """
    
    connection = None
    cursor = None
    
    try:
        current_user_id = get_jwt_identity()
        data = request.get_json()
        
        # Validate required fields
        required_fields = ['from_store_id', 'to_store_id', 'transfer_date', 'status', 'items']
        for field in required_fields:
            if field not in data:
                return jsonify({
                    'success': False,
                    'message': f'Missing required field: {field}'
                }), 400
        
        if not isinstance(data['items'], list) or len(data['items']) == 0:
            return jsonify({
                'success': False,
                'message': 'Items array must contain at least one item'
            }), 400
        
        connection = get_db_connection()
        cursor = connection.cursor(dictionary=True, buffered=True)
        
        connection.start_transaction(isolation_level='SERIALIZABLE')
        
        # Get old transfer details
        cursor.execute("""
            SELECT * FROM stock_transfers WHERE transfer_id = %s FOR UPDATE
        """, (transfer_id,))
        
        old_transfer = cursor.fetchone()
        cursor.fetchall()
        
        if not old_transfer:
            return jsonify({
                'success': False,
                'message': 'Transfer not found'
            }), 404
        
        old_status = old_transfer['status']
        new_status = data['status']
        
        old_from_store = old_transfer['from_store_id']
        old_from_warehouse = old_transfer['from_warehouse_id']
        old_to_store = old_transfer['to_store_id']
        old_to_warehouse = old_transfer['to_warehouse_id']
        
        new_from_store = data['from_store_id']
        new_from_warehouse = data.get('from_warehouse_id')
        new_to_store = data['to_store_id']
        new_to_warehouse = data.get('to_warehouse_id')
        
        print(f"\n{'='*80}")
        print(f"📝 UPDATING TRANSFER #{transfer_id}")
        print(f"{'='*80}")
        print(f"Status: {old_status} → {new_status}")
        print(f"Old Route: Store {old_from_store}/WH {old_from_warehouse} → Store {old_to_store}/WH {old_to_warehouse}")
        print(f"New Route: Store {new_from_store}/WH {new_from_warehouse} → Store {new_to_store}/WH {new_to_warehouse}")
        
        # Check if warehouses are the same
        same_warehouses = (
            old_from_store == new_from_store and 
            old_from_warehouse == new_from_warehouse and
            old_to_store == new_to_store and 
            old_to_warehouse == new_to_warehouse
        )
        
        # Check if it's a simple quantity edit (same warehouses, both pending/in_transit)
        is_simple_edit = (
            same_warehouses and 
            old_status in ['pending', 'in_transit'] and 
            new_status in ['pending', 'in_transit']
        )
        
        if is_simple_edit:
            print("✓ SIMPLE EDIT MODE: Same warehouses, adjusting quantities only")
        else:
            print("⚠️ FULL REVERSAL MODE: Different warehouses or status change")
        
        print(f"{'='*80}\n")
        
        # Get old transfer items
        cursor.execute("""
            SELECT * FROM stock_transfer_items WHERE transfer_id = %s ORDER BY item_id
        """, (transfer_id,))
        
        old_items = cursor.fetchall()
        cursor.fetchall()
        
        # ============================================
        # MODE 1: SIMPLE QUANTITY EDIT
        # ============================================
        if is_simple_edit:
            print("🔄 SIMPLE EDIT: Adjusting quantities by difference...\n")
            
            # Build maps for comparison
            old_items_map = {}
            for item in old_items:
                key = (item['product_id'], item['variation_id'], item['batch_id'])
                old_items_map[key] = Decimal(str(item['quantity']))
            
            new_items_map = {}
            for item in data['items']:
                key = (item['product_id'], item.get('variation_id'), item['batch_id'])
                new_items_map[key] = Decimal(str(item['quantity']))
            
            # Process each item
            all_keys = set(old_items_map.keys()) | set(new_items_map.keys())
            
            for key in all_keys:
                product_id, variation_id, batch_id = key
                old_qty = old_items_map.get(key, Decimal('0'))
                new_qty = new_items_map.get(key, Decimal('0'))
                diff = new_qty - old_qty
                
                print(f"Product {product_id}, Batch {batch_id}:")
                print(f"  Old: {old_qty}, New: {new_qty}, Diff: {diff}")
                
                if diff == 0:
                    print(f"  → No change\n")
                    continue
                
                if diff > 0:
                    # Need to take MORE from warehouse
                    print(f"  → Taking additional {diff} from source warehouse")
                    
                    # Check if enough stock available
                    cursor.execute("""
                        SELECT id as stock_id, quantity 
                        FROM warehouse_stock
                        WHERE store_id = %s 
                          AND warehouse_id = %s
                          AND product_id = %s 
                          AND batch_id = %s
                          AND ((variation_id IS NULL AND %s IS NULL) OR (variation_id = %s))
                        FOR UPDATE
                    """, (old_from_store, old_from_warehouse, product_id, batch_id, variation_id, variation_id))
                    
                    source_stock = cursor.fetchone()
                    cursor.fetchall()
                    
                    if not source_stock:
                        raise ValueError(f'Stock not found for Product {product_id}, Batch {batch_id}')
                    
                    available = Decimal(str(source_stock['quantity']))
                    print(f"  → Available in warehouse: {available}")
                    
                    if available < diff:
                        raise ValueError(
                            f'Insufficient stock for Product {product_id}, Batch {batch_id}. '
                            f'Need {diff} more, but only {available} available'
                        )
                    
                    # Take the difference from warehouse
                    cursor.execute("""
                        UPDATE warehouse_stock 
                        SET quantity = quantity - %s
                        WHERE id = %s
                    """, (diff, source_stock['stock_id']))
                    
                    new_warehouse_qty = available - diff
                    print(f"  ✓ Reduced warehouse stock by {diff} (now {new_warehouse_qty})\n")
                    
                elif diff < 0:
                    # Need to RETURN to warehouse
                    return_qty = abs(diff)
                    print(f"  → Returning {return_qty} to source warehouse")
                    
                    # Add back to warehouse
                    cursor.execute("""
                        INSERT INTO warehouse_stock (
                            store_id, warehouse_id, product_id, variation_id, batch_id, quantity
                        ) VALUES (%s, %s, %s, %s, %s, %s)
                        ON DUPLICATE KEY UPDATE
                            quantity = quantity + VALUES(quantity)
                    """, (old_from_store, old_from_warehouse, product_id, variation_id, batch_id, return_qty))
                    
                    print(f"  ✓ Returned {return_qty} to warehouse\n")
        
        # ============================================
        # MODE 2: FULL REVERSAL + NEW TRANSFER
        # ============================================
        else:
            print("🔄 FULL REVERSAL: Complete stock movement reversal...\n")
            
            # STEP 1: Reverse old transfer completely
            if old_status == 'completed':
                print("Reversing COMPLETED transfer:\n")
                
                for idx, item in enumerate(old_items, 1):
                    product_id = item['product_id']
                    variation_id = item['variation_id']
                    batch_id = item['batch_id']
                    qty = Decimal(str(item['quantity']))
                    
                    print(f"  [{idx}] Product {product_id}, Batch {batch_id}, Qty {qty}")
                    
                    # Return to OLD source
                    cursor.execute("""
                        INSERT INTO warehouse_stock (
                            store_id, warehouse_id, product_id, variation_id, batch_id, quantity
                        ) VALUES (%s, %s, %s, %s, %s, %s)
                        ON DUPLICATE KEY UPDATE
                            quantity = quantity + VALUES(quantity)
                    """, (old_from_store, old_from_warehouse, product_id, variation_id, batch_id, qty))
                    print(f"      ✓ Returned {qty} to OLD source")
                    
                    # Remove from OLD destination
                    cursor.execute("""
                        SELECT id, quantity FROM warehouse_stock
                        WHERE store_id = %s AND warehouse_id = %s
                          AND product_id = %s AND batch_id = %s
                          AND ((variation_id IS NULL AND %s IS NULL) OR (variation_id = %s))
                        FOR UPDATE
                    """, (old_to_store, old_to_warehouse, product_id, batch_id, variation_id, variation_id))
                    
                    dest_stock = cursor.fetchone()
                    cursor.fetchall()
                    
                    if dest_stock:
                        dest_qty = Decimal(str(dest_stock['quantity']))
                        new_dest_qty = dest_qty - qty
                        
                        if new_dest_qty <= 0:
                            cursor.execute("DELETE FROM warehouse_stock WHERE id = %s", (dest_stock['id'],))
                            print(f"      ✓ Removed {qty} from OLD dest (deleted record)\n")
                        else:
                            cursor.execute("UPDATE warehouse_stock SET quantity = %s WHERE id = %s",
                                         (new_dest_qty, dest_stock['id']))
                            print(f"      ✓ Removed {qty} from OLD dest (now {new_dest_qty})\n")
                    else:
                        print(f"      ⚠️ OLD dest stock not found\n")
                        
            elif old_status in ['pending', 'in_transit']:
                print("Reversing PENDING/IN_TRANSIT transfer:\n")
                
                for idx, item in enumerate(old_items, 1):
                    product_id = item['product_id']
                    variation_id = item['variation_id']
                    batch_id = item['batch_id']
                    qty = Decimal(str(item['quantity']))
                    
                    print(f"  [{idx}] Product {product_id}, Batch {batch_id}, Qty {qty}")
                    
                    # Return to OLD source
                    cursor.execute("""
                        INSERT INTO warehouse_stock (
                            store_id, warehouse_id, product_id, variation_id, batch_id, quantity
                        ) VALUES (%s, %s, %s, %s, %s, %s)
                        ON DUPLICATE KEY UPDATE
                            quantity = quantity + VALUES(quantity)
                    """, (old_from_store, old_from_warehouse, product_id, variation_id, batch_id, qty))
                    print(f"      ✓ Returned {qty} to source\n")
            
            # STEP 2: Apply NEW transfer
            print("Applying NEW transfer:\n")
            
            for idx, item in enumerate(data['items'], 1):
                product_id = item['product_id']
                variation_id = item.get('variation_id')
                batch_id = item['batch_id']
                qty = Decimal(str(item['quantity']))
                
                print(f"  [{idx}] Product {product_id}, Batch {batch_id}, Qty {qty}")
                
                # Verify batch exists
                cursor.execute("""
                    SELECT batch_id, batch_number FROM product_batches
                    WHERE batch_id = %s AND product_id = %s
                """, (batch_id, product_id))
                
                batch_info = cursor.fetchone()
                cursor.fetchall()
                
                if not batch_info:
                    raise ValueError(f'Batch {batch_id} not found for product {product_id}')
                
                # Get NEW source stock (with lock)
                cursor.execute("""
                    SELECT id as stock_id, quantity FROM warehouse_stock
                    WHERE store_id = %s AND warehouse_id = %s
                      AND product_id = %s AND batch_id = %s
                      AND ((variation_id IS NULL AND %s IS NULL) OR (variation_id = %s))
                    FOR UPDATE
                """, (new_from_store, new_from_warehouse, product_id, batch_id, variation_id, variation_id))
                
                source_stock = cursor.fetchone()
                cursor.fetchall()
                
                if not source_stock:
                    raise ValueError(f'Stock not found for Product {product_id}, Batch {batch_id}')
                
                available = Decimal(str(source_stock['quantity']))
                print(f"      Available: {available}")
                
                if available < qty:
                    raise ValueError(f'Insufficient stock. Need {qty}, have {available}')
                
                # Take from NEW source
                cursor.execute("""
                    UPDATE warehouse_stock 
                    SET quantity = quantity - %s
                    WHERE id = %s
                """, (qty, source_stock['stock_id']))
                
                new_qty = available - qty
                print(f"      ✓ Took {qty} from NEW source (now {new_qty})")
                
                # If completed, add to NEW destination
                if new_status == 'completed':
                    cursor.execute("""
                        INSERT INTO warehouse_stock (
                            store_id, warehouse_id, product_id, variation_id, batch_id, quantity
                        ) VALUES (%s, %s, %s, %s, %s, %s)
                        ON DUPLICATE KEY UPDATE
                            quantity = quantity + VALUES(quantity)
                    """, (new_to_store, new_to_warehouse, product_id, variation_id, batch_id, qty))
                    print(f"      ✓ Added {qty} to NEW dest\n")
                else:
                    print(f"      (not adding to dest - status is {new_status})\n")
        
        # ============================================
        # UPDATE DATABASE RECORDS
        # ============================================
        print("📝 Updating transfer records...\n")
        
        # Delete old items
        cursor.execute("DELETE FROM stock_transfer_items WHERE transfer_id = %s", (transfer_id,))
        print(f"  ✓ Deleted {cursor.rowcount} old items")
        
        # Update transfer header
        cursor.execute("""
            UPDATE stock_transfers 
            SET from_store_id = %s, from_warehouse_id = %s,
                to_store_id = %s, to_warehouse_id = %s,
                transfer_date = %s, note = %s, total_items = %s,
                status = %s, received_by = %s
            WHERE transfer_id = %s
        """, (new_from_store, new_from_warehouse, new_to_store, new_to_warehouse,
              data['transfer_date'], data.get('note', ''), len(data['items']),
              new_status, current_user_id if new_status == 'completed' else None,
              transfer_id))
        print(f"  ✓ Updated transfer header")
        
        # Insert new items
        for item in data['items']:
            cursor.execute("""
                INSERT INTO stock_transfer_items (
                    transfer_id, product_id, variation_id, batch_id,
                    quantity, received_quantity
                ) VALUES (%s, %s, %s, %s, %s, %s)
            """, (transfer_id, item['product_id'], item.get('variation_id'),
                  item['batch_id'], item['quantity'],
                  item['quantity'] if new_status == 'completed' else 0))
        print(f"  ✓ Inserted {len(data['items'])} new items")
        
        # Commit transaction
        connection.commit()
        
        print(f"\n{'='*80}")
        print(f"✅ TRANSFER UPDATED SUCCESSFULLY!")
        print(f"{'='*80}\n")
        
        # Fetch and return updated data
        cursor.execute("""
            SELECT st.*, 
                   fs.store_name as from_store_name, 
                   ts.store_name as to_store_name,
                   fw.warehouse_name as from_warehouse_name, 
                   tw.warehouse_name as to_warehouse_name,
                   u.name as created_by_name
            FROM stock_transfers st
            LEFT JOIN stores fs ON st.from_store_id = fs.id
            LEFT JOIN stores ts ON st.to_store_id = ts.id
            LEFT JOIN warehouses fw ON st.from_warehouse_id = fw.id
            LEFT JOIN warehouses tw ON st.to_warehouse_id = tw.id
            LEFT JOIN users u ON st.created_by = u.id
            WHERE st.transfer_id = %s
        """, (transfer_id,))
        
        transfer_details = cursor.fetchone()
        cursor.fetchall()
        
        cursor.execute("""
            SELECT sti.*, 
                   p.product_name, p.sku, p.product_type,
                   pv.variation_name, pv.variation_type,
                   pb.batch_number, pb.cost, pb.price, pb.expiration_date,
                   g.grn_code, s.supplier_name
            FROM stock_transfer_items sti
            LEFT JOIN products p ON sti.product_id = p.id
            LEFT JOIN product_variations pv ON sti.variation_id = pv.id
            LEFT JOIN product_batches pb ON sti.batch_id = pb.batch_id
            LEFT JOIN grn g ON pb.grn_id = g.grn_id
            LEFT JOIN suppliers s ON g.supplier_id = s.id
            WHERE sti.transfer_id = %s
        """, (transfer_id,))
        
        transfer_items = cursor.fetchall()
        
        # Convert datetime objects
        if transfer_details:
            for key, value in transfer_details.items():
                if isinstance(value, datetime):
                    transfer_details[key] = value.isoformat()
        
        for item in transfer_items:
            for key, value in item.items():
                if isinstance(value, datetime):
                    item[key] = value.isoformat()
        
        return jsonify({
            'success': True,
            'message': 'Stock transfer updated successfully',
            'data': {
                'transfer': transfer_details,
                'items': transfer_items
            }
        }), 200
        
    except ValueError as ve:
        if connection:
            connection.rollback()
        print(f"\n❌ VALIDATION ERROR: {ve}\n")
        return jsonify({
            'success': False,
            'message': str(ve)
        }), 400
        
    except mysql.connector.Error as err:
        if connection:
            connection.rollback()
        print(f"❌ DATABASE ERROR: {err}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'message': f'Database error: {str(err)}'
        }), 500
        
    except Exception as e:
        if connection:
            connection.rollback()
        print(f"❌ UNEXPECTED ERROR: {e}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'message': f'An unexpected error occurred: {str(e)}'
        }), 500
        
    finally:
        if cursor:
            cursor.close()
        if connection:
            connection.close()