from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity
from db.db import get_db_connection
from config.auth import role_required
import mysql.connector
from datetime import datetime
from typing import Optional, List

lab_bill_bp = Blueprint('lab_bill', __name__)


# ═══════════════════════════════════════════════════════════
# SERIALIZERS
# ═══════════════════════════════════════════════════════════
def _serialize_bill(row: dict) -> dict:
    """Build the JSON-safe representation of a single lab_bills row.

    Missing columns are tolerated (everything is read with ``row.get``).
    Date/time columns are stringified, money columns become floats and
    ``is_deleted`` becomes a bool. ``received_amount`` / ``balance_amount``
    stay ``None`` when the row holds no value for them, whereas the other
    money columns fall back to ``0.0``.
    """
    def as_text(column):
        # Falsy or missing values are reported as None instead of "None" / "".
        value = row.get(column)
        return str(value) if value else None

    def as_money(column):
        # NULL / missing amounts collapse to 0.0.
        return float(row.get(column) or 0)

    def as_optional_money(column):
        # Unlike as_money, an absent amount is kept as None.
        value = row.get(column)
        return None if value is None else float(value)

    bill = {
        column: row.get(column)
        for column in (
            "id", "bill_number", "store_id", "patient_id",
            "patient_name", "patient_contact", "patient_nic", "patient_age",
        )
    }
    bill["bill_date"] = as_text("bill_date")
    bill["billed_by"] = row.get("billed_by")
    # Only populated when the query selected a billed_by_name column.
    bill["billed_by_name"] = row.get("billed_by_name")
    for column in ("payment_method", "payment_status", "discount_type"):
        bill[column] = row.get(column)
    for column in ("discount_value", "discount_amount", "subtotal", "grand_total"):
        bill[column] = as_money(column)
    for column in ("received_amount", "balance_amount"):
        bill[column] = as_optional_money(column)
    bill["notes"] = row.get("notes")
    # The "active" default applies only when the column is absent from the row.
    bill["status"] = row.get("status", "active")
    bill["is_deleted"] = bool(row.get("is_deleted", 0))
    for column in ("created_at", "updated_at"):
        bill[column] = as_text(column)
    return bill
    
    
def _serialize_bill_item(row: dict) -> dict:
    """Build the JSON-safe representation of a single lab_bill_items row.

    Identifier and name columns are passed through unchanged; prices are
    coerced to float (defaulting to 0.0) and ``qty`` to int (defaulting to 1).
    """
    item = {
        column: row.get(column)
        for column in ("id", "bill_id", "charge_id", "charge_name")
    }
    item["unit_price"] = float(row.get("unit_price") or 0)
    # A NULL / zero quantity is reported as 1.
    item["qty"] = int(row.get("qty") or 1)
    item["line_total"] = float(row.get("line_total") or 0)
    return item


# ═══════════════════════════════════════════════════════════
# VALIDATION
# ═══════════════════════════════════════════════════════════

def _validate_bill_payload(data: dict) -> Optional[str]:
    """
    Validate the JSON payload used to create a lab bill.

    Checks the required header fields (patient_name, patient_contact,
    bill_date, payment_method, store_id) and every entry of ``items``
    (charge_id, charge_name, unit_price, qty).

    Wrongly-typed input (a non-object body, a number where text is expected,
    a non-object item, NaN/Infinity amounts) is reported as a validation
    error instead of raising, so callers can always answer with a 400.

    Args:
        data: The decoded request body.

    Returns:
        An error message string, or None if the payload is valid.
    """
    import math  # function-scope import; only needed for the finite-price check

    def text_error(container, key, label):
        # Required text: missing / blank -> "is required", wrong type -> "must be a string".
        value = container.get(key)
        if not value:
            return f"{label} is required."
        if not isinstance(value, str):
            return f"{label} must be a string."
        if not value.strip():
            return f"{label} is required."
        return None

    if not isinstance(data, dict):
        return "Request body must be a JSON object."

    for key in ("patient_name", "patient_contact"):
        problem = text_error(data, key, key)
        if problem:
            return problem

    bill_date = data.get("bill_date")
    if not bill_date:
        return "bill_date is required."
    if not isinstance(bill_date, str):
        # Date parsing downstream needs a string.
        return "bill_date must be in YYYY-MM-DD format."

    problem = text_error(data, "payment_method", "payment_method")
    if problem:
        return problem
    if not data.get("store_id"):
        return "store_id is required."

    items = data.get("items")
    if not items or not isinstance(items, list):
        return "At least one bill item is required."

    for i, item in enumerate(items):
        if not isinstance(item, dict):
            return f"items[{i}] must be an object."

        if not item.get("charge_id"):
            return f"items[{i}].charge_id is required."
        try:
            int(item["charge_id"])
        except (TypeError, ValueError, OverflowError):
            return f"items[{i}].charge_id must be a valid integer."

        problem = text_error(item, "charge_name", f"items[{i}].charge_name")
        if problem:
            return problem

        if item.get("unit_price") is None:
            return f"items[{i}].unit_price is required."
        if item.get("qty") is None:
            return f"items[{i}].qty is required."

        try:
            qty = int(item["qty"])
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int(float("inf")) -- JSON "Infinity" is accepted by the parser.
            return f"items[{i}].qty must be a valid integer."
        if qty < 1 or qty > 99:
            return f"items[{i}].qty must be between 1 and 99."

        try:
            price = float(item["unit_price"])
        except (TypeError, ValueError, OverflowError):
            return f"items[{i}].unit_price must be a valid number."
        if not math.isfinite(price):
            # NaN compares False against 0, so it would slip past the sign check.
            return f"items[{i}].unit_price must be a valid number."
        if price < 0:
            return f"items[{i}].unit_price must be non-negative."

    return None


# ═══════════════════════════════════════════════════════════
# HELPER: Generate unique bill number
# Format: LB-YYYYMMDD-XXXX  (e.g. LB-20240421-0007)
# ═══════════════════════════════════════════════════════════

def _generate_bill_number(cursor, bill_date: datetime) -> str:
    """Return the next ``LB-YYYYMMDD-XXXX`` bill number for ``bill_date``.

    Reads (``SELECT ... FOR UPDATE``) the bill number with the highest id
    that shares the day's prefix and returns its numeric suffix plus one,
    zero-padded to at least four digits. When no bill exists for the day,
    or the stored suffix is not numeric, numbering starts at 0001.

    Args:
        cursor: DB cursor whose ``fetchone()`` yields a mapping with a
            ``bill_number`` key (or a falsy value when nothing matched).
        bill_date: Date the bill is issued for.
    """
    day_prefix = "LB-" + bill_date.strftime("%Y%m%d") + "-"
    cursor.execute(
        """
        SELECT bill_number FROM lab_bills
        WHERE bill_number LIKE %s
        ORDER BY id DESC
        LIMIT 1
        FOR UPDATE
        """,
        (day_prefix + "%",)
    )
    latest = cursor.fetchone()

    previous = 0
    if latest:
        try:
            previous = int(latest["bill_number"].split("-")[-1])
        except (ValueError, IndexError):
            # Unparseable suffix: restart the sequence for this prefix.
            previous = 0
    return f"{day_prefix}{previous + 1:04d}"


# ═══════════════════════════════════════════════════════════
# POST /lab_bills
# Create a new lab bill with line items.
# ═══════════════════════════════════════════════════════════

@lab_bill_bp.route('/lab_bills', methods=['POST'])
@jwt_required()
@role_required('admin', 'receptionist', 'cashier')
def create_lab_bill():
    """Create a lab bill header and its line items in one transaction.

    Request body (JSON object): patient_name, patient_contact, bill_date
    (YYYY-MM-DD), payment_method, store_id and a non-empty ``items`` list are
    required; patient_id, patient_nic, patient_age, billed_by, notes,
    payment_status (default "Paid"), discount_type (default "flat"),
    discount_value, received_amount and balance_amount are optional.

    Subtotal, discount amount and grand total are recomputed server-side from
    the line items; the client-supplied totals are only checked for being
    numeric.

    Returns:
        201 with bill_id / bill_number / grand_total on success,
        400 for invalid input, 404 for an unknown store or patient,
        409 on an integrity error, 500 on other failures.
    """
    current_user = get_jwt_identity()
    data         = request.get_json(silent=True)

    if not data:
        return jsonify({"success": False, "error": "Request body is missing or not JSON."}), 400
    if not isinstance(data, dict):
        # A JSON array / scalar would otherwise crash on data.get(...).
        return jsonify({"success": False, "error": "Request body must be a JSON object."}), 400

    # ── Validate ──────────────────────────────────────────
    err = _validate_bill_payload(data)
    if err:
        return jsonify({"success": False, "error": err}), 400

    # Optional text fields are .strip()-ed below; reject non-strings up front
    # so a stray number yields a 400 instead of an unhandled AttributeError.
    for field in ("patient_nic", "notes", "payment_status", "discount_type"):
        value = data.get(field)
        if value and not isinstance(value, str):
            return jsonify({"success": False, "error": f"{field} must be a string."}), 400

    # ── Parse & sanitise fields ───────────────────────────
    patient_name    = data["patient_name"].strip()
    patient_contact = data["patient_contact"].strip()
    patient_nic     = (data.get("patient_nic") or "").strip() or None
    patient_id      = data.get("patient_id") or None
    notes           = (data.get("notes") or "").strip() or None

    payment_method  = data["payment_method"].strip()
    payment_status  = (data.get("payment_status") or "Paid").strip()
    discount_type   = (data.get("discount_type")  or "flat").strip()

    if payment_method not in ("Cash", "Card", "Online"):
        return jsonify({"success": False, "error": "payment_method must be 'Cash', 'Card', or 'Online'."}), 400
    if payment_status not in ("Paid", "Pending"):
        return jsonify({"success": False, "error": "payment_status must be 'Paid' or 'Pending'."}), 400
    if discount_type not in ("flat", "pct"):
        return jsonify({"success": False, "error": "discount_type must be 'flat' or 'pct'."}), 400

    try:
        store_id        = int(data["store_id"])
        patient_age     = int(data["patient_age"])     if data.get("patient_age") not in (None, "", 0) else None
        # Falls back to the JWT identity only when that identity is an int.
        billed_by       = int(data["billed_by"])       if data.get("billed_by")   not in (None, "")   else (current_user if isinstance(current_user, int) else None)
        patient_id      = int(patient_id)              if patient_id is not None else None
        discount_value  = float(data.get("discount_value")  or 0)
        # Client totals are parsed only to reject malformed input; the values
        # stored are the server-computed ones further down.
        float(data.get("discount_amount") or 0)
        float(data.get("subtotal")        or 0)
        float(data.get("grand_total")     or 0)
    except (TypeError, ValueError, OverflowError) as ex:
        return jsonify({"success": False, "error": f"Invalid numeric field: {ex}"}), 400

    # A negative discount would raise the grand total; NaN/Infinity fail this
    # range test as well because every comparison with NaN is False.
    if not (0 <= discount_value < float("inf")):
        return jsonify({"success": False, "error": "discount_value must be a non-negative number."}), 400

    received_amount = None
    balance_amount  = None
    if payment_method == "Cash":
        try:
            received_amount = float(data["received_amount"]) if data.get("received_amount") is not None else None
            balance_amount  = float(data["balance_amount"])  if data.get("balance_amount")  is not None else None
        except (TypeError, ValueError, OverflowError):
            return jsonify({"success": False, "error": "received_amount and balance_amount must be valid numbers."}), 400

    # Parse bill date (TypeError covers a non-string value such as a number)
    try:
        bill_date = datetime.strptime(data["bill_date"], "%Y-%m-%d")
    except (TypeError, ValueError):
        return jsonify({"success": False, "error": "bill_date must be in YYYY-MM-DD format."}), 400

    items: List[dict] = data["items"]

    # Normalise every line once, before touching the database, so that a bad
    # charge_id is a 400 and the header subtotal is derived from exactly the
    # line totals that get stored.
    # NOTE(review): unit_price is taken from the client as-is; it is not
    # compared with hospital_charges — confirm this is intended.
    parsed_items = []
    try:
        for item in items:
            qty        = max(1, min(int(item["qty"]), 99))
            unit_price = round(float(item["unit_price"]), 2)
            parsed_items.append((
                int(item["charge_id"]),
                item["charge_name"].strip(),
                unit_price,
                qty,
                round(unit_price * qty, 2),
            ))
    except (TypeError, ValueError, OverflowError):
        return jsonify({
            "success": False,
            "error": "Each item needs an integer charge_id, a numeric unit_price and an integer qty."
        }), 400

    conn   = None
    cursor = None
    try:
        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        # ── Verify store exists & is active ───────────────
        cursor.execute(
            "SELECT id FROM stores WHERE id = %s AND is_active = 1",
            (store_id,)
        )
        if not cursor.fetchone():
            return jsonify({"success": False, "error": f"Store ID {store_id} not found or not active."}), 404

        # ── Verify patient exists (if provided) ───────────
        # NOTE(review): this lookup is not scoped to store_id.
        if patient_id is not None:
            cursor.execute(
                "SELECT id FROM customers WHERE id = %s LIMIT 1",
                (patient_id,)
            )
            if not cursor.fetchone():
                return jsonify({"success": False, "error": f"Patient ID {patient_id} not found."}), 404

        # ── Verify all charge_ids are active & belong to this store ──
        charge_ids = [parsed[0] for parsed in parsed_items]
        placeholders = ", ".join(["%s"] * len(charge_ids))
        cursor.execute(
            f"""
            SELECT id FROM hospital_charges
            WHERE id IN ({placeholders})
              AND store_id   = %s
              AND status     = 'Active'
              AND is_deleted = 0
            """,
            (*charge_ids, store_id)
        )
        valid_ids = {r["id"] for r in cursor.fetchall()}
        invalid   = [cid for cid in charge_ids if cid not in valid_ids]
        if invalid:
            return jsonify({
                "success": False,
                "error": f"Charge ID(s) {invalid} are not active or do not belong to this branch."
            }), 400

        # ── Server-side totals (client totals are never stored) ──
        # Rounded after summing to drop binary float noise (0.1 + 0.2 ...).
        subtotal = round(sum(parsed[4] for parsed in parsed_items), 2)
        if discount_type == "pct":
            discount_amount = round(min(subtotal, subtotal * discount_value / 100), 2)
        else:
            discount_amount = round(min(subtotal, discount_value), 2)
        grand_total = round(max(0.0, subtotal - discount_amount), 2)

        # ── Generate bill number (inside transaction) ──────
        bill_number = _generate_bill_number(cursor, bill_date)

        # ── Insert bill header ─────────────────────────────
        cursor.execute(
            """
            INSERT INTO lab_bills (
                bill_number,
                store_id,
                patient_id,      patient_name,    patient_contact,
                patient_nic,     patient_age,
                bill_date,       billed_by,
                payment_method,  payment_status,
                discount_type,   discount_value,  discount_amount,
                subtotal,        grand_total,
                received_amount, balance_amount,
                notes,           status,          is_deleted,
                created_at
            ) VALUES (
                %s,
                %s,
                %s, %s, %s,
                %s, %s,
                %s, %s,
                %s, %s,
                %s, %s, %s,
                %s, %s,
                %s, %s,
                %s, 'active', 0,
                %s
            )
            """,
            (
                bill_number,
                store_id,
                patient_id,      patient_name,    patient_contact,
                patient_nic,     patient_age,
                bill_date.date(), billed_by,
                payment_method,  payment_status,
                discount_type,   discount_value,  discount_amount,
                subtotal,        grand_total,
                received_amount, balance_amount,
                notes,
                datetime.now(),
            )
        )
        bill_id = cursor.lastrowid

        # ── Insert line items ──────────────────────────────
        item_rows = [(bill_id, *parsed) for parsed in parsed_items]
        cursor.executemany(
            """
            INSERT INTO lab_bill_items
                (bill_id, charge_id, charge_name, unit_price, qty, line_total)
            VALUES (%s, %s, %s, %s, %s, %s)
            """,
            item_rows
        )

        conn.commit()

        return jsonify({
            "success":     True,
            "message":     f"Lab bill {bill_number} created successfully.",
            "bill_id":     bill_id,
            "bill_number": bill_number,
            "grand_total": grand_total,
        }), 201

    except mysql.connector.IntegrityError as e:
        if conn: conn.rollback()
        print(f"[create_lab_bill] Integrity error: {e}")
        return jsonify({"success": False, "error": "Bill number conflict. Please retry."}), 409
    except mysql.connector.Error as e:
        if conn: conn.rollback()
        print(f"[create_lab_bill] DB error: {e}")
        return jsonify({"success": False, "error": "Database error while creating lab bill."}), 500
    except Exception as e:
        if conn: conn.rollback()
        print(f"[create_lab_bill] Unexpected error: {e}")
        return jsonify({"success": False, "error": "Unexpected server error."}), 500
    finally:
        if cursor: cursor.close()
        if conn:   conn.close()


# ═══════════════════════════════════════════════════════════
# GET /lab_bills
# List bills with filtering & pagination.
# ═══════════════════════════════════════════════════════════

@lab_bill_bp.route('/lab_bills', methods=['GET'])
@jwt_required()
@role_required('admin', 'receptionist', 'cashier')
def get_lab_bills():
    """List non-deleted lab bills of one store, newest first, with paging.

    Query parameters:
        store_id (required), date OR from_date / to_date, payment_method,
        payment_status, status, search (matched against patient name,
        contact, NIC and bill number), page (default 1), per_page
        (default 20, clamped to 1..100).

    Filter values outside the known sets are ignored rather than rejected.

    Returns:
        200 with bills / count / total / page / per_page / pages,
        400 when store_id is missing, 500 on failures.
    """
    store_id       = request.args.get('store_id',       type=int)
    date           = request.args.get('date',           type=str)
    from_date      = request.args.get('from_date',      type=str)
    to_date        = request.args.get('to_date',        type=str)
    payment_method = request.args.get('payment_method', type=str)
    payment_status = request.args.get('payment_status', type=str)
    status         = request.args.get('status',         type=str)
    search         = request.args.get('search',         type=str)
    page           = request.args.get('page',     default=1,  type=int)
    per_page       = request.args.get('per_page', default=20, type=int)

    if not store_id:
        return jsonify({"success": False, "error": "store_id is required."}), 400

    per_page = max(1, min(per_page, 100))
    page     = max(1, page)
    offset   = (page - 1) * per_page

    conn   = None
    cursor = None
    try:
        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        base_where = " WHERE lb.store_id = %s AND lb.is_deleted = 0 "
        params: List = [store_id]

        # An exact date takes precedence over the from/to range.
        if date:
            base_where += " AND lb.bill_date = %s "
            params.append(date)
        else:
            if from_date:
                base_where += " AND lb.bill_date >= %s "
                params.append(from_date)
            if to_date:
                base_where += " AND lb.bill_date <= %s "
                params.append(to_date)

        # Enumerated filters: applied only when the value is one of the known options.
        for column, value, allowed in (
            ("payment_method", payment_method, ("Cash", "Card", "Online")),
            ("payment_status", payment_status, ("Paid", "Pending")),
            ("status",         status,         ("active", "cancelled")),
        ):
            if value in allowed:
                base_where += f" AND lb.{column} = %s "
                params.append(value)

        if search:
            base_where += """
                AND (
                    lb.patient_name       LIKE %s
                    OR lb.patient_contact LIKE %s
                    OR lb.bill_number     LIKE %s
                    OR lb.patient_nic     LIKE %s
                )
            """
            like = f"%{search}%"
            params.extend([like, like, like, like])

        # Total count
        cursor.execute(
            f"SELECT COUNT(*) AS total FROM lab_bills lb {base_where}",
            params
        )
        total = cursor.fetchone()["total"]

        # Paginated rows, including billed_by_name from the users join.
        # lb.id is a tie-breaker: without it, bills sharing a created_at
        # value may come back in any order, so LIMIT/OFFSET pages could
        # repeat or skip rows.
        cursor.execute(
            f"""
            SELECT
                lb.id,              lb.bill_number,
                lb.store_id,        lb.patient_id,
                lb.patient_name,    lb.patient_contact,
                lb.patient_nic,     lb.patient_age,
                lb.bill_date,       lb.billed_by,
                lb.payment_method,  lb.payment_status,
                lb.discount_type,   lb.discount_value,  lb.discount_amount,
                lb.subtotal,        lb.grand_total,
                lb.received_amount, lb.balance_amount,
                lb.notes,           lb.status,
                lb.created_at,      lb.updated_at,
                u.name AS billed_by_name
            FROM lab_bills lb
            LEFT JOIN users u ON u.id = lb.billed_by
            {base_where}
            ORDER BY lb.created_at DESC, lb.id DESC
            LIMIT %s OFFSET %s
            """,
            params + [per_page, offset]
        )
        bills = [_serialize_bill(r) for r in cursor.fetchall()]

        return jsonify({
            "success":  True,
            "bills":    bills,
            "count":    len(bills),
            "total":    total,
            "page":     page,
            "per_page": per_page,
            # Ceiling division; an empty result still reports one page.
            "pages":    max(1, (total + per_page - 1) // per_page),
        }), 200

    except mysql.connector.Error as e:
        print(f"[get_lab_bills] DB error: {e}")
        return jsonify({"success": False, "error": "Database error while fetching bills."}), 500
    except Exception as e:
        print(f"[get_lab_bills] Unexpected error: {e}")
        return jsonify({"success": False, "error": "Unexpected server error."}), 500
    finally:
        if cursor: cursor.close()
        if conn:   conn.close()


# ═══════════════════════════════════════════════════════════
# GET /lab_bills/summary
# Daily / range revenue summary for a store.
# The detail route uses the <int:bill_id> converter, which only matches
# digits, so "summary" is never captured as a bill id.
# ═══════════════════════════════════════════════════════════

@lab_bill_bp.route('/lab_bills/summary', methods=['GET'])
@jwt_required()
@role_required('admin', 'receptionist', 'cashier')
def get_lab_bills_summary():
    """Return aggregate counts and revenue for a store's active lab bills.

    Query parameters: store_id (required), and either ``date`` or a
    ``from_date`` / ``to_date`` range. Deleted and non-active bills are
    excluded. Amounts are returned as floats, counts as ints.

    Returns:
        200 with a ``summary`` object, 400 when store_id is missing,
        500 on failures.
    """
    args      = request.args
    store_id  = args.get('store_id', type=int)
    date      = args.get('date',      type=str)
    from_date = args.get('from_date', type=str)
    to_date   = args.get('to_date',   type=str)

    if not store_id:
        return jsonify({"success": False, "error": "store_id is required."}), 400

    # Assemble the WHERE clause; an exact date overrides the range bounds.
    clauses = [" WHERE store_id = %s AND is_deleted = 0 AND status = 'active' "]
    params: List = [store_id]
    if date:
        date_filters = [(" AND bill_date = %s ", date)]
    else:
        date_filters = [
            (" AND bill_date >= %s ", from_date),
            (" AND bill_date <= %s ", to_date),
        ]
    for fragment, bound in date_filters:
        if bound:
            clauses.append(fragment)
            params.append(bound)
    where = "".join(clauses)

    amount_keys = (
        "total_revenue", "total_discounts", "total_subtotal",
        "cash_revenue",  "card_revenue",     "online_revenue",
    )
    count_keys = (
        "total_bills", "paid_count", "pending_count",
        "cash_count",  "card_count", "online_count",
    )

    conn   = None
    cursor = None
    try:
        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute(
            f"""
            SELECT
                COUNT(*)                                  AS total_bills,
                COALESCE(SUM(grand_total),     0)         AS total_revenue,
                COALESCE(SUM(discount_amount), 0)         AS total_discounts,
                COALESCE(SUM(subtotal),        0)         AS total_subtotal,
                SUM(CASE WHEN payment_status = 'Paid'    THEN 1 ELSE 0 END)  AS paid_count,
                SUM(CASE WHEN payment_status = 'Pending' THEN 1 ELSE 0 END)  AS pending_count,
                SUM(CASE WHEN payment_method = 'Cash'    THEN 1 ELSE 0 END)  AS cash_count,
                SUM(CASE WHEN payment_method = 'Card'    THEN 1 ELSE 0 END)  AS card_count,
                SUM(CASE WHEN payment_method = 'Online'  THEN 1 ELSE 0 END)  AS online_count,
                COALESCE(SUM(CASE WHEN payment_method = 'Cash'   THEN grand_total ELSE 0 END), 0) AS cash_revenue,
                COALESCE(SUM(CASE WHEN payment_method = 'Card'   THEN grand_total ELSE 0 END), 0) AS card_revenue,
                COALESCE(SUM(CASE WHEN payment_method = 'Online' THEN grand_total ELSE 0 END), 0) AS online_revenue
            FROM lab_bills
            {where}
            """,
            params
        )
        summary = cursor.fetchone()

        # Coerce DB numerics to plain JSON numbers; NULL sums count as zero.
        summary.update({key: float(summary[key] or 0) for key in amount_keys})
        summary.update({key: int(summary[key] or 0) for key in count_keys})

        return jsonify({"success": True, "summary": summary}), 200

    except mysql.connector.Error as e:
        print(f"[get_lab_bills_summary] DB error: {e}")
        return jsonify({"success": False, "error": "Database error."}), 500
    except Exception as e:
        print(f"[get_lab_bills_summary] Unexpected error: {e}")
        return jsonify({"success": False, "error": "Unexpected server error."}), 500
    finally:
        if cursor: cursor.close()
        if conn:   conn.close()


# ═══════════════════════════════════════════════════════════
# GET /lab_bills/<id>
# Returns a single bill with all its line items.
# LEFT JOINs the users table so the response includes billed_by_name.
# ═══════════════════════════════════════════════════════════

@lab_bill_bp.route('/lab_bills/<int:bill_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'receptionist', 'cashier')
def get_lab_bill(bill_id: int):
    """Return one non-deleted lab bill together with its line items.

    Args:
        bill_id: Primary key of the bill, taken from the URL.

    Returns:
        200 with the serialized bill (including ``items`` and
        ``billed_by_name``), 404 when no such bill exists, 500 on failures.
    """
    conn   = None
    cursor = None
    try:
        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        # Header row; the users table is joined in to obtain billed_by_name.
        cursor.execute(
            """
            SELECT
                lb.id,              lb.bill_number,
                lb.store_id,        lb.patient_id,
                lb.patient_name,    lb.patient_contact,
                lb.patient_nic,     lb.patient_age,
                lb.bill_date,       lb.billed_by,
                lb.payment_method,  lb.payment_status,
                lb.discount_type,   lb.discount_value,  lb.discount_amount,
                lb.subtotal,        lb.grand_total,
                lb.received_amount, lb.balance_amount,
                lb.notes,           lb.status,          lb.is_deleted,
                lb.created_at,      lb.updated_at,
                u.name AS billed_by_name
            FROM lab_bills lb
            LEFT JOIN users u ON u.id = lb.billed_by
            WHERE lb.id = %s AND lb.is_deleted = 0
            """,
            (bill_id,)
        )
        header_row = cursor.fetchone()
        if not header_row:
            return jsonify({"success": False, "error": "Bill not found."}), 404

        # Line items in insertion order.
        cursor.execute(
            """
            SELECT id, bill_id, charge_id, charge_name, unit_price, qty, line_total
            FROM lab_bill_items
            WHERE bill_id = %s
            ORDER BY id ASC
            """,
            (bill_id,)
        )
        line_items = list(map(_serialize_bill_item, cursor.fetchall()))

        payload = _serialize_bill(header_row)
        payload["items"] = line_items

        return jsonify({"success": True, "bill": payload}), 200

    except mysql.connector.Error as e:
        print(f"[get_lab_bill] DB error: {e}")
        return jsonify({"success": False, "error": "Database error."}), 500
    except Exception as e:
        print(f"[get_lab_bill] Unexpected error: {e}")
        return jsonify({"success": False, "error": "Unexpected server error."}), 500
    finally:
        if cursor: cursor.close()
        if conn:   conn.close()


# ═══════════════════════════════════════════════════════════
# PUT /lab_bills/<id>
# Update mutable fields only — payment status, payment method,
# received/balance amounts, and notes.
# ═══════════════════════════════════════════════════════════

@lab_bill_bp.route('/lab_bills/<int:bill_id>', methods=['PUT'])
@jwt_required()
@role_required('admin', 'receptionist', 'cashier')
def update_lab_bill(bill_id: int):
    """Update the mutable fields of a lab bill.

    Only payment_status, payment_method, received_amount, balance_amount and
    notes can be changed; ``updated_at`` is always refreshed. Sending
    ``null`` for received_amount / balance_amount / notes clears the column.

    Args:
        bill_id: Primary key of the bill, taken from the URL.

    Returns:
        200 on success, 400 for invalid input or when nothing updatable was
        sent, 404 when the bill does not exist, 409 when it is cancelled,
        500 on failures.
    """
    data = request.get_json(silent=True)
    if not data:
        return jsonify({"success": False, "error": "Request body is missing or not JSON."}), 400
    if not isinstance(data, dict):
        # A JSON array / scalar would otherwise crash on data.get(...).
        return jsonify({"success": False, "error": "Request body must be a JSON object."}), 400

    # These fields are .strip()-ed; a non-string (e.g. a number) must be a
    # client error, not an AttributeError surfacing as a 500.
    for field in ("payment_status", "payment_method", "notes"):
        value = data.get(field)
        if value and not isinstance(value, str):
            return jsonify({"success": False, "error": f"{field} must be a string."}), 400

    payment_status = (data.get("payment_status") or "").strip() or None
    payment_method = (data.get("payment_method") or "").strip() or None
    notes          = data.get("notes")

    if payment_status and payment_status not in ("Paid", "Pending"):
        return jsonify({"success": False, "error": "payment_status must be 'Paid' or 'Pending'."}), 400
    if payment_method and payment_method not in ("Cash", "Card", "Online"):
        return jsonify({"success": False, "error": "payment_method must be 'Cash', 'Card', or 'Online'."}), 400

    received_amount = None
    balance_amount  = None
    # Presence is tracked separately from value so an explicit null clears the column.
    has_received    = "received_amount" in data
    has_balance     = "balance_amount"  in data
    try:
        if has_received and data["received_amount"] is not None:
            received_amount = float(data["received_amount"])
        if has_balance and data["balance_amount"] is not None:
            balance_amount  = float(data["balance_amount"])
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() of an integer too large for a double.
        return jsonify({"success": False, "error": "received_amount and balance_amount must be valid numbers."}), 400

    # NaN / Infinity fail this range test (comparisons with NaN are False).
    for amount in (received_amount, balance_amount):
        if amount is not None and not (float("-inf") < amount < float("inf")):
            return jsonify({"success": False, "error": "received_amount and balance_amount must be valid numbers."}), 400

    if not any([payment_status, payment_method, has_received, has_balance, "notes" in data]):
        return jsonify({"success": False, "error": "No updatable fields provided."}), 400

    conn   = None
    cursor = None
    try:
        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute(
            "SELECT id, status FROM lab_bills WHERE id = %s AND is_deleted = 0",
            (bill_id,)
        )
        existing = cursor.fetchone()
        if not existing:
            return jsonify({"success": False, "error": "Bill not found."}), 404
        if existing["status"] == "cancelled":
            return jsonify({"success": False, "error": "Cannot update a cancelled bill."}), 409

        # Build the SET clause from fixed column fragments only; every value
        # goes through a placeholder.
        set_parts: List[str] = ["updated_at = %s"]
        set_vals:  List      = [datetime.now()]

        if payment_status:
            set_parts.append("payment_status = %s")
            set_vals.append(payment_status)
        if payment_method:
            set_parts.append("payment_method = %s")
            set_vals.append(payment_method)
        if has_received:
            set_parts.append("received_amount = %s")
            set_vals.append(received_amount)
        if has_balance:
            set_parts.append("balance_amount = %s")
            set_vals.append(balance_amount)
        if "notes" in data:
            set_parts.append("notes = %s")
            # Blank notes are stored as NULL.
            set_vals.append((notes or "").strip() or None)

        set_vals.append(bill_id)
        cursor.execute(
            f"UPDATE lab_bills SET {', '.join(set_parts)} WHERE id = %s",
            set_vals
        )
        conn.commit()

        return jsonify({
            "success": True,
            "message": f"Lab bill #{bill_id} updated successfully.",
        }), 200

    except mysql.connector.Error as e:
        if conn: conn.rollback()
        print(f"[update_lab_bill] DB error: {e}")
        return jsonify({"success": False, "error": "Database error while updating bill."}), 500
    except Exception as e:
        if conn: conn.rollback()
        print(f"[update_lab_bill] Unexpected error: {e}")
        return jsonify({"success": False, "error": "Unexpected server error."}), 500
    finally:
        if cursor: cursor.close()
        if conn:   conn.close()


# ═══════════════════════════════════════════════════════════
# DELETE /lab_bills/<id>
# Soft-delete: sets is_deleted = 1, status = 'cancelled'.
# Admin only.
# ═══════════════════════════════════════════════════════════

@lab_bill_bp.route('/lab_bills/<int:bill_id>', methods=['DELETE'])
@jwt_required()
@role_required('admin')
def delete_lab_bill(bill_id: int):
    """Soft-delete a lab bill (admin only).

    The row is kept: ``is_deleted`` is set to 1, ``status`` to 'cancelled'
    and ``updated_at`` to the current time.

    Args:
        bill_id: Primary key of the bill, taken from the URL.

    Returns:
        200 on success, 404 when the bill does not exist or is already
        deleted, 409 when it is already cancelled, 500 on failures.
    """
    conn   = None
    cursor = None
    try:
        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        # Look the bill up first so the response can name its bill number.
        cursor.execute(
            "SELECT id, bill_number, status FROM lab_bills WHERE id = %s AND is_deleted = 0",
            (bill_id,)
        )
        target = cursor.fetchone()
        if not target:
            return jsonify({"success": False, "error": "Bill not found."}), 404
        if target["status"] == "cancelled":
            return jsonify({"success": False, "error": "Bill is already cancelled."}), 409

        now = datetime.now()
        cursor.execute(
            """
            UPDATE lab_bills
            SET is_deleted = 1,
                status     = 'cancelled',
                updated_at = %s
            WHERE id = %s
            """,
            (now, bill_id)
        )
        conn.commit()

        confirmation = f"Lab bill {target['bill_number']} cancelled and deleted successfully."
        return jsonify({"success": True, "message": confirmation}), 200

    except mysql.connector.Error as e:
        if conn: conn.rollback()
        print(f"[delete_lab_bill] DB error: {e}")
        return jsonify({"success": False, "error": "Database error while deleting bill."}), 500
    except Exception as e:
        if conn: conn.rollback()
        print(f"[delete_lab_bill] Unexpected error: {e}")
        return jsonify({"success": False, "error": "Unexpected server error."}), 500
    finally:
        if cursor: cursor.close()
        if conn:   conn.close()