{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0467dac7",
   "metadata": {},
   "source": [
    "# DE2 — Assignment 3: Graphs or Clustering\n",
    "\n",
    "**Author:** Badr TAJINI - Data Engineering II (Data-Intensive Workloads) - ESIEE 2025-2026\n",
    "\n",
    "**Track: C — Citi Bike CSV**\n",
    "\n",
    "**Path chosen: Clustering (KMeans)**\n",
    "\n",
    "**Names:** Yannick PRAT & Sara AISSAOUI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2bf36bdd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Using incubator modules: jdk.incubator.vector\n",
      "Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties\n",
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
      "26/05/22 17:59:32 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Spark version: 4.0.0\n",
      "Spark UI: http://127.0.0.1:4040\n",
      "Spark UI (WSL/Windows browser): http://localhost:4040\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from urllib.parse import urlparse\n",
    "from pyspark.sql import SparkSession, functions as F\n",
    "import time, pathlib\n",
    "\n",
    "# Network settings come from the environment so the notebook can be re-pointed\n",
    "# without editing code; the fallbacks below are used when a variable is unset.\n",
    "DE2_SPARK_DRIVER_HOST = os.getenv(\"DE2_SPARK_DRIVER_HOST\", \"127.0.0.1\")\n",
    "DE2_SPARK_BIND_ADDRESS = os.getenv(\"DE2_SPARK_BIND_ADDRESS\", \"0.0.0.0\")\n",
    "# Keep a SPARK_LOCAL_IP that is already exported; otherwise use the driver host.\n",
    "if \"SPARK_LOCAL_IP\" not in os.environ:\n",
    "    os.environ[\"SPARK_LOCAL_IP\"] = DE2_SPARK_DRIVER_HOST\n",
    "\n",
    "\n",
    "def show_spark_ui(spark_session):\n",
    "    \"\"\"Print the Spark version and where the web UI can be reached.\n",
    "\n",
    "    If the session reports a UI URL, it is printed together with a\n",
    "    ``localhost`` link on the same port (4040 when the URL has no explicit\n",
    "    port); otherwise a \"not available\" line is printed.\n",
    "    \"\"\"\n",
    "    web_url = spark_session.sparkContext.uiWebUrl\n",
    "    print(\"Spark version:\", spark_session.version)\n",
    "    if not web_url:\n",
    "        print(\"Spark UI: not available\")\n",
    "        return\n",
    "    port = urlparse(web_url).port or 4040\n",
    "    print(\"Spark UI:\", web_url)\n",
    "    print(\"Spark UI (WSL/Windows browser):\", f\"http://localhost:{port}\")\n",
    "\n",
    "\n",
    "# Local session on all cores; driver host and bind addresses come from the\n",
    "# DE2_* settings defined above.\n",
    "spark = (\n",
    "    SparkSession.builder\n",
    "    .appName(\"de2-assignment3\")\n",
    "    .master(\"local[*]\")\n",
    "    .config(\"spark.driver.host\", DE2_SPARK_DRIVER_HOST)\n",
    "    .config(\"spark.driver.bindAddress\", DE2_SPARK_BIND_ADDRESS)\n",
    "    .config(\"spark.ui.bindAddress\", DE2_SPARK_BIND_ADDRESS)\n",
    "    .getOrCreate()\n",
    ")\n",
    "\n",
    "show_spark_ui(spark)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85676862",
   "metadata": {},
   "source": [
    "## Step 1 — Data & Features\n",
    "\n",
    "Load the Citi Bike CSV (Track C), clean data, and prepare feature vectors for clustering.\n",
    "We will cluster **stations** based on their usage patterns:\n",
    "- Average trip duration\n",
    "- Number of departures\n",
    "- Number of arrivals\n",
    "- Geographic coordinates (mean of start_lat, start_lng per station)\n",
    "- Member ratio (share of a station's departures made by members)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d26cd1ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Raw rows: 82,272  |  Columns: 13\n",
      "root\n",
      " |-- ride_id: string (nullable = true)\n",
      " |-- rideable_type: string (nullable = true)\n",
      " |-- started_at: timestamp (nullable = true)\n",
      " |-- ended_at: timestamp (nullable = true)\n",
      " |-- start_station_name: string (nullable = true)\n",
      " |-- start_station_id: string (nullable = true)\n",
      " |-- end_station_name: string (nullable = true)\n",
      " |-- end_station_id: string (nullable = true)\n",
      " |-- start_lat: double (nullable = true)\n",
      " |-- start_lng: double (nullable = true)\n",
      " |-- end_lat: double (nullable = true)\n",
      " |-- end_lng: double (nullable = true)\n",
      " |-- member_casual: string (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import pathlib\n",
    "from pyspark.ml.feature import VectorAssembler, StandardScaler\n",
    "from pyspark.ml.clustering import KMeans, BisectingKMeans\n",
    "from pyspark.ml.evaluation import ClusteringEvaluator\n",
    "from pyspark.sql import functions as F\n",
    "import time\n",
    "\n",
    "# ── Paths ───────────────────────────────────────────────────────────────────\n",
    "# NOTE: despite the name, DATA_DIR holds the path of the trip CSV file itself.\n",
    "DATA_DIR = \"data/JC-202604-citibike-tripdata.csv\"\n",
    "OUT_DIR = pathlib.Path(\"outputs/lab3\")\n",
    "PROOF_DIR = pathlib.Path(\"proof\")\n",
    "for target_dir in (OUT_DIR, PROOF_DIR):\n",
    "    target_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# ── Load raw CSV ─────────────────────────────────────────────────────────────\n",
    "# The first row is the header; column types are inferred by Spark.\n",
    "raw = spark.read.csv(path=DATA_DIR, inferSchema=True, header=True)\n",
    "row_count = raw.count()\n",
    "column_count = len(raw.columns)\n",
    "print(f\"Raw rows: {row_count:,}  |  Columns: {column_count}\")\n",
    "raw.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4d61ca74",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------+-------------+-----------------------+-----------------------+------------------+----------------+------------------------------------------+--------------+------------------+------------------+-----------------+------------------+-------------+\n",
      "|ride_id         |rideable_type|started_at             |ended_at               |start_station_name|start_station_id|end_station_name                          |end_station_id|start_lat         |start_lng         |end_lat          |end_lng           |member_casual|\n",
      "+----------------+-------------+-----------------------+-----------------------+------------------+----------------+------------------------------------------+--------------+------------------+------------------+-----------------+------------------+-------------+\n",
      "|558250BE9BDDEF62|classic_bike |2026-04-08 11:01:58.516|2026-04-08 11:15:28.078|City Hall         |JC003           |Southwest Park - Jackson St & Observer Hwy|HB401         |40.7177325        |-74.043845        |40.73755127245804|-74.04166370630264|member       |\n",
      "|DE08A3A0DC829851|electric_bike|2026-04-04 14:28:31.751|2026-04-04 14:31:34.921|6 St & Grand St   |HB302           |Willow Ave & 12 St                        |HB505         |40.744397833095604|-74.03450086712837|40.7518674823282 |-74.03037697076797|member       |\n",
      "|B0434D0A2865B3E2|electric_bike|2026-04-27 18:25:42.87 |2026-04-27 18:30:55.446|6 St & Grand St   |HB302           |Southwest Park - Jackson St & Observer Hwy|HB401         |40.744397833095604|-74.03450086712837|40.73755127245804|-74.04166370630264|casual       |\n",
      "+----------------+-------------+-----------------------+-----------------------+------------------+----------------+------------------------------------------+--------------+------------------+------------------+-----------------+------------------+-------------+\n",
      "only showing top 3 rows\n"
     ]
    }
   ],
   "source": [
    "# ── Peek at the first raw rows, with column values left untruncated ──────────\n",
    "raw.show(n=3, truncate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f6932a4c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Filtered trips: 81,975\n",
      "Stations: 108\n",
      "+----------------+------------------+------------------+--------------+------------------+------------+------------+------------------+------------+\n",
      "|start_station_id|           avg_lat|           avg_lng|num_departures|      avg_duration|member_trips|casual_trips|      member_ratio|num_arrivals|\n",
      "+----------------+------------------+------------------+--------------+------------------+------------+------------+------------------+------------+\n",
      "|           HB103| 40.73698221818679|-74.02778059244156|          1556| 678.1953727506427|        1058|         498| 0.679948586117815|        1510|\n",
      "|           HB201|40.750604142369276|-74.02402013540268|          1317|  592.870159453303|         943|         374|0.7160212604398511|        1343|\n",
      "|           JC072| 40.71241882375689|-74.03852552175522|           760| 548.3763157894737|         643|         117|0.8460526315778342|         729|\n",
      "|           JC142| 40.71005000000001|-74.08575999999998|            40|           1077.35|          24|          16|0.5999999999850001|          33|\n",
      "|           HB105|40.737359999999754|-74.03096999999983|          1260|464.16825396825396|         978|         282|0.7761904761898601|        1262|\n",
      "+----------------+------------------+------------------+--------------+------------------+------------+------------+------------------+------------+\n",
      "only showing top 5 rows\n"
     ]
    }
   ],
   "source": [
    "# ── Feature engineering: aggregate per start_station_id ──────────────────────\n",
    "# Citi Bike columns: ride_id, rideable_type, started_at, ended_at,\n",
    "#                    start_station_name, start_station_id,\n",
    "#                    end_station_name, end_station_id,\n",
    "#                    start_lat, start_lng, end_lat, end_lng, member_casual\n",
    "\n",
    "# Trip duration in seconds: difference of the two Unix timestamps, as a double.\n",
    "duration_expr = (\n",
    "    F.unix_timestamp(\"ended_at\") - F.unix_timestamp(\"started_at\")\n",
    ").cast(\"double\")\n",
    "\n",
    "# Keep trips strictly longer than 1 minute and strictly shorter than 2 hours ...\n",
    "plausible_duration = (F.col(\"duration_sec\") > 60) & (F.col(\"duration_sec\") < 7200)\n",
    "# ... that also carry a start station id and start coordinates.\n",
    "has_start_info = (\n",
    "    F.col(\"start_station_id\").isNotNull()\n",
    "    & F.col(\"start_lat\").isNotNull()\n",
    "    & F.col(\"start_lng\").isNotNull()\n",
    ")\n",
    "\n",
    "trips = (\n",
    "    raw\n",
    "    .withColumn(\"duration_sec\", duration_expr)\n",
    "    .filter(plausible_duration & has_start_info)\n",
    ")\n",
    "\n",
    "print(f\"Filtered trips: {trips.count():,}\")\n",
    "\n",
    "# Aggregate features per start station: mean coordinates, departure count,\n",
    "# mean trip duration and the member/casual split of the departures.\n",
    "station_features = trips.groupBy(\"start_station_id\").agg(\n",
    "    F.avg(\"start_lat\").alias(\"avg_lat\"),\n",
    "    F.avg(\"start_lng\").alias(\"avg_lng\"),\n",
    "    F.count(\"*\").alias(\"num_departures\"),\n",
    "    F.avg(\"duration_sec\").alias(\"avg_duration\"),\n",
    "    # Conditional sums count the two rider types separately.\n",
    "    F.sum(F.when(F.col(\"member_casual\") == \"member\", 1).otherwise(0)).alias(\"member_trips\"),\n",
    "    F.sum(F.when(F.col(\"member_casual\") == \"casual\", 1).otherwise(0)).alias(\"casual_trips\")\n",
    ").withColumn(\n",
    "    # Each group holds at least one trip, so num_departures >= 1 and a plain\n",
    "    # division is safe. No epsilon in the denominator: it only skewed the\n",
    "    # ratio (e.g. 0.5999999999850001 instead of 0.6 for 24 of 40 trips).\n",
    "    \"member_ratio\", F.col(\"member_trips\") / F.col(\"num_departures\")\n",
    ")\n",
    "\n",
    "# Arrivals per station: count the filtered trips by end station, ignoring\n",
    "# trips whose end station id is null.\n",
    "arrivals = (\n",
    "    trips\n",
    "    .where(F.col(\"end_station_id\").isNotNull())\n",
    "    .groupBy(\"end_station_id\")\n",
    "    .agg(F.count(\"*\").alias(\"num_arrivals\"))\n",
    ")\n",
    "\n",
    "# Left join on the station id keeps every departure station; stations without\n",
    "# a matching arrival row get num_arrivals = 0.\n",
    "arrivals_by_station = arrivals.withColumnRenamed(\"end_station_id\", \"start_station_id\")\n",
    "station_features = station_features.join(\n",
    "    arrivals_by_station, on=\"start_station_id\", how=\"left\"\n",
    ").fillna({\"num_arrivals\": 0})\n",
    "\n",
    "print(f\"Stations: {station_features.count():,}\")\n",
    "station_features.show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "16034163",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 35:===========================================>         (166 + 13) / 200]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Feature DataFrame ready and cached.\n",
      "+----------------+--------------------------------------------------------------------------------------------------------------------------+\n",
      "|start_station_id|features                                                                                                                  |\n",
      "+----------------+--------------------------------------------------------------------------------------------------------------------------+\n",
      "|HB103           |[0.7117122859992427,1.2087819691135173,1.2867095663838983,0.15036766408706023,-0.5637394344697388,1.200704252908227]      |\n",
      "|HB201           |[1.5828119089939947,1.3880784434143996,0.900844692062088,-0.20549681006091636,-0.24237650928805451,0.9357808101423631]    |\n",
      "|JC072           |[-0.859076344204948,0.6964697893351047,0.0015696502325969594,-0.39106655745871843,0.9160422267677948,-0.03824909320039416]|\n",
      "+----------------+--------------------------------------------------------------------------------------------------------------------------+\n",
      "only showing top 3 rows\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                "
     ]
    }
   ],
   "source": [
    "# ── Assemble & standardise the feature vector ────────────────────────────────\n",
    "FEATURE_COLS = [\n",
    "    \"avg_lat\", \"avg_lng\",\n",
    "    \"num_departures\", \"avg_duration\",\n",
    "    \"member_ratio\", \"num_arrivals\",\n",
    "]\n",
    "\n",
    "# handleInvalid=\"skip\": rows with an invalid feature value are dropped.\n",
    "assembler = VectorAssembler(\n",
    "    inputCols=FEATURE_COLS, outputCol=\"raw_features\", handleInvalid=\"skip\"\n",
    ")\n",
    "# Centre each feature and scale it to unit standard deviation.\n",
    "scaler = StandardScaler(\n",
    "    inputCol=\"raw_features\", outputCol=\"features\", withMean=True, withStd=True\n",
    ")\n",
    "\n",
    "assembled = assembler.transform(station_features)\n",
    "scaler_model = scaler.fit(assembled)\n",
    "\n",
    "# Cached because every later KMeans run reads it; count() materializes the cache.\n",
    "feature_df = scaler_model.transform(assembled).cache()\n",
    "feature_df.count()\n",
    "print(\"Feature DataFrame ready and cached.\")\n",
    "feature_df.select(\"start_station_id\", \"features\").show(3, truncate=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8fa2178d",
   "metadata": {},
   "source": [
    "## Step 2 — Iterative Algorithm\n",
    "\n",
    "We sweep KMeans over **k = 2 to 8** (7 runs, satisfying the ≥5-iteration requirement) and log per-iteration metrics:\n",
    "- Silhouette score\n",
    "- Training time\n",
    "- WSSSE (within-cluster sum of squared errors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "db76ea15",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   k |   Silhouette |          WSSSE |   Time(s)\n",
      "------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "26/05/22 18:00:17 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS\n",
      "                                                                                "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   2 |     0.676834 |         326.04 |     14.58\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   3 |     0.503423 |         234.08 |     10.77\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   4 |     0.377225 |         205.69 |      9.81\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   5 |     0.463861 |         164.34 |      7.72\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   6 |     0.480858 |         135.86 |      8.81\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   7 |     0.424769 |         125.12 |     13.65\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   8 |     0.407141 |         132.39 |     10.18\n",
      "\n",
      "→ Best k = 2  |  Best silhouette = 0.676834\n"
     ]
    }
   ],
   "source": [
    "import csv\n",
    "\n",
    "# Silhouette (squared Euclidean) on the scaled feature vectors; the same\n",
    "# evaluator is reused by the later experiment cells.\n",
    "evaluator = ClusteringEvaluator(featuresCol=\"features\", metricName=\"silhouette\",\n",
    "                                distanceMeasure=\"squaredEuclidean\")\n",
    "\n",
    "metrics_rows = []   # will be written to lab3_metrics_log.csv\n",
    "k_range = range(2, 9)   # k = 2..8  → 7 iterations\n",
    "\n",
    "# spark.conf.get() returns a string. Read it once (it does not change during\n",
    "# the sweep) and keep it as an int, so the \"partitions\" column has the same\n",
    "# type as the integer partition counts logged by the later steps.\n",
    "sweep_partitions = int(spark.conf.get(\"spark.sql.shuffle.partitions\"))\n",
    "\n",
    "print(f\"{'k':>4} | {'Silhouette':>12} | {'WSSSE':>14} | {'Time(s)':>9}\")\n",
    "print(\"-\" * 48)\n",
    "\n",
    "best_k, best_sil, best_model, best_predictions = 2, -1, None, None\n",
    "\n",
    "for k in k_range:\n",
    "    # The timer covers fit + transform + silhouette evaluation.\n",
    "    t0 = time.time()\n",
    "    km = KMeans(k=k, seed=42, featuresCol=\"features\", predictionCol=\"prediction\",\n",
    "                maxIter=20, tol=1e-4)\n",
    "    model = km.fit(feature_df)\n",
    "    preds = model.transform(feature_df)\n",
    "    sil   = evaluator.evaluate(preds)\n",
    "    wssse = model.summary.trainingCost\n",
    "    elapsed = time.time() - t0\n",
    "\n",
    "    print(f\"{k:>4} | {sil:>12.6f} | {wssse:>14.2f} | {elapsed:>9.2f}\")\n",
    "    metrics_rows.append({\n",
    "        \"step\": \"kmeans_sweep\", \"partitions\": sweep_partitions,\n",
    "        \"k\": k, \"seed\": 42, \"silhouette\": round(sil, 6),\n",
    "        \"wssse\": round(wssse, 4), \"elapsed_sec\": round(elapsed, 3), \"strategy\": \"default\"\n",
    "    })\n",
    "\n",
    "    # Keep the model with the highest silhouette seen so far.\n",
    "    if sil > best_sil:\n",
    "        best_sil, best_k = sil, k\n",
    "        best_model, best_predictions = model, preds\n",
    "\n",
    "print(f\"\\n→ Best k = {best_k}  |  Best silhouette = {best_sil:.6f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4cd975ce",
   "metadata": {},
   "source": [
    "## Step 3 — Partitioning Experiment\n",
    "\n",
    "We compare two strategies for `k = best_k`:\n",
    "\n",
    "| Strategy | `spark.sql.shuffle.partitions` | Notes |\n",
    "|----------|-------------------------------|-------|\n",
    "| **Before** | 200 (default) | No explicit repartition |\n",
    "| **After**  | `max(defaultParallelism, 8)` (at least 8; more on machines with more cores) | Repartition on station id |\n",
    "\n",
    "We capture query plans and measure shuffle bytes via Spark metrics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f20a8061",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[BEFORE] k=2 | shuffle_partitions=200 | sil=0.676834 | wssse=326.04 | time=5.92s\n",
      "plan_before.txt saved.\n"
     ]
    }
   ],
   "source": [
    "# ── BEFORE: default partitioning (200 shuffle partitions) ────────────────────\n",
    "spark.conf.set(\"spark.sql.shuffle.partitions\", \"200\")\n",
    "\n",
    "# Timed section: fit, score the stations and evaluate, all on the cached feature_df.\n",
    "t0 = time.time()\n",
    "km_before = KMeans(\n",
    "    k=best_k, seed=42, maxIter=20,\n",
    "    featuresCol=\"features\", predictionCol=\"prediction\",\n",
    ")\n",
    "model_before = km_before.fit(feature_df)\n",
    "preds_before = model_before.transform(feature_df)\n",
    "sil_before = evaluator.evaluate(preds_before)\n",
    "wssse_before = model_before.summary.trainingCost\n",
    "t_before = time.time() - t0\n",
    "\n",
    "print(\n",
    "    f\"[BEFORE] k={best_k} | shuffle_partitions=200 | \"\n",
    "    f\"sil={sil_before:.6f} | wssse={wssse_before:.2f} | time={t_before:.2f}s\"\n",
    ")\n",
    "\n",
    "# Keep the query plan of the prediction DataFrame as evidence.\n",
    "plan_before = preds_before._jdf.queryExecution().toString()\n",
    "PROOF_DIR.joinpath(\"plan_before.txt\").write_text(plan_before)\n",
    "print(\"plan_before.txt saved.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8538b18c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[AFTER]  k=2 | shuffle_partitions=12 | sil=0.676834 | wssse=326.04 | time=1.13s\n",
      "Speed-up: 5.23x\n",
      "plan_after.txt saved.\n"
     ]
    }
   ],
   "source": [
    "# ── AFTER: tuned partitioning ────────────────────────────────────────────────\n",
    "# Partition count: the context's default parallelism, but never fewer than 8.\n",
    "num_cores = spark.sparkContext.defaultParallelism\n",
    "tuned_partitions = max(num_cores, 8)\n",
    "spark.conf.set(\"spark.sql.shuffle.partitions\", str(tuned_partitions))\n",
    "\n",
    "# Repartition the features on the station id and cache the result;\n",
    "# count() materializes the cache before the timer starts.\n",
    "feature_df_repartitioned = (\n",
    "    feature_df\n",
    "    .repartition(tuned_partitions, \"start_station_id\")\n",
    "    .cache()\n",
    ")\n",
    "feature_df_repartitioned.count()\n",
    "\n",
    "# Timed section mirrors the BEFORE cell: fit, transform, evaluate.\n",
    "t0 = time.time()\n",
    "km_after = KMeans(\n",
    "    k=best_k, seed=42, maxIter=20,\n",
    "    featuresCol=\"features\", predictionCol=\"prediction\",\n",
    ")\n",
    "model_after = km_after.fit(feature_df_repartitioned)\n",
    "preds_after = model_after.transform(feature_df_repartitioned)\n",
    "sil_after = evaluator.evaluate(preds_after)\n",
    "wssse_after = model_after.summary.trainingCost\n",
    "t_after = time.time() - t0\n",
    "\n",
    "print(\n",
    "    f\"[AFTER]  k={best_k} | shuffle_partitions={tuned_partitions} | \"\n",
    "    f\"sil={sil_after:.6f} | wssse={wssse_after:.2f} | time={t_after:.2f}s\"\n",
    ")\n",
    "print(f\"Speed-up: {t_before/t_after:.2f}x\")\n",
    "\n",
    "# Keep the query plan of the prediction DataFrame as evidence.\n",
    "plan_after = preds_after._jdf.queryExecution().toString()\n",
    "PROOF_DIR.joinpath(\"plan_after.txt\").write_text(plan_after)\n",
    "print(\"plan_after.txt saved.\")\n",
    "\n",
    "# Add to metrics log. Rows left by an earlier run of this cell are removed\n",
    "# first (in place, same list object), so re-running the cell does not\n",
    "# duplicate the partition_exp rows.\n",
    "metrics_rows[:] = [row for row in metrics_rows if row[\"step\"] != \"partition_exp\"]\n",
    "for tag, sil, wssse, t, parts in [\n",
    "        (\"before\", sil_before, wssse_before, t_before, 200),\n",
    "        (\"after\",  sil_after,  wssse_after,  t_after,  tuned_partitions)]:\n",
    "    metrics_rows.append({\n",
    "        \"step\": \"partition_exp\", \"partitions\": parts,\n",
    "        \"k\": best_k, \"seed\": 42, \"silhouette\": round(sil, 6),\n",
    "        \"wssse\": round(wssse, 4), \"elapsed_sec\": round(t, 3), \"strategy\": tag\n",
    "    })"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1583a551",
   "metadata": {},
   "source": [
    "## Step 4 — Convergence / Stability Analysis\n",
    "\n",
    "Seed stability analysis over **≥5 seeds** for the best k.\n",
    "We record mean ± std of silhouette to verify robustness."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8563a188",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Seed stability analysis for k=2  (7 seeds)\n",
      "  Seed |   Silhouette |          WSSSE |   Time(s)\n",
      "--------------------------------------------------\n",
      "     0 |     0.676834 |         326.04 |      1.18\n",
      "     7 |     0.676834 |         326.04 |      1.20\n",
      "    42 |     0.676834 |         326.04 |      0.98\n",
      "    99 |     0.676834 |         326.04 |      1.07\n",
      "   123 |     0.676834 |         326.04 |      1.09\n",
      "   256 |     0.676834 |         326.04 |      0.97\n",
      "   512 |     0.676834 |         326.04 |      0.94\n",
      "\n",
      "Silhouette — mean: 0.676834  std: 0.000000\n",
      "WSSSE      — mean: 326.04  std: 0.00\n"
     ]
    }
   ],
   "source": [
    "import statistics\n",
    "\n",
    "# Run with the tuned shuffle-partition count from the partitioning experiment.\n",
    "spark.conf.set(\"spark.sql.shuffle.partitions\", str(tuned_partitions))\n",
    "SEEDS = [0, 7, 42, 99, 123, 256, 512]\n",
    "\n",
    "# Remove rows left by an earlier run of this cell (in place, same list object),\n",
    "# so re-running it does not duplicate the seed rows in the metrics log.\n",
    "metrics_rows[:] = [row for row in metrics_rows if row[\"step\"] != \"seed_stability\"]\n",
    "\n",
    "sil_scores = []\n",
    "wssse_scores = []\n",
    "print(f\"Seed stability analysis for k={best_k}  ({len(SEEDS)} seeds)\")\n",
    "print(f\"{'Seed':>6} | {'Silhouette':>12} | {'WSSSE':>14} | {'Time(s)':>9}\")\n",
    "print(\"-\" * 50)\n",
    "\n",
    "for seed in SEEDS:\n",
    "    # Same k and data for every run; only the KMeans seed changes.\n",
    "    t0 = time.time()\n",
    "    km_s = KMeans(k=best_k, seed=seed, featuresCol=\"features\",\n",
    "                  predictionCol=\"prediction\", maxIter=20)\n",
    "    m_s  = km_s.fit(feature_df_repartitioned)\n",
    "    p_s  = m_s.transform(feature_df_repartitioned)\n",
    "    sil_s  = evaluator.evaluate(p_s)\n",
    "    wssse_s = m_s.summary.trainingCost\n",
    "    elapsed = time.time() - t0\n",
    "    sil_scores.append(sil_s)\n",
    "    wssse_scores.append(wssse_s)\n",
    "    print(f\"{seed:>6} | {sil_s:>12.6f} | {wssse_s:>14.2f} | {elapsed:>9.2f}\")\n",
    "    metrics_rows.append({\n",
    "        \"step\": \"seed_stability\", \"partitions\": tuned_partitions,\n",
    "        \"k\": best_k, \"seed\": seed, \"silhouette\": round(sil_s, 6),\n",
    "        \"wssse\": round(wssse_s, 4), \"elapsed_sec\": round(elapsed, 3), \"strategy\": \"after\"\n",
    "    })\n",
    "\n",
    "# Sample mean and standard deviation across the seeds.\n",
    "mean_sil = statistics.mean(sil_scores)\n",
    "std_sil  = statistics.stdev(sil_scores)\n",
    "print(f\"\\nSilhouette — mean: {mean_sil:.6f}  std: {std_sil:.6f}\")\n",
    "print(f\"WSSSE      — mean: {statistics.mean(wssse_scores):.2f}  \"\n",
    "      f\"std: {statistics.stdev(wssse_scores):.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f374da3c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Elbow summary (from k-sweep):\n",
      "   k |          WSSSE |   Silhouette\n",
      "-----------------------------------\n",
      "   2 |       326.0415 |     0.676834\n",
      "   3 |       234.0813 |     0.503423\n",
      "   4 |       205.6939 |     0.377225\n",
      "   5 |       164.3434 |     0.463861\n",
      "   6 |       135.8636 |     0.480858\n",
      "   7 |       125.1229 |     0.424769\n",
      "   8 |       132.3893 |     0.407141\n"
     ]
    }
   ],
   "source": [
    "# ── Elbow table: WSSSE and silhouette for each k of the sweep ────────────────\n",
    "print(\"\\nElbow summary (from k-sweep):\")\n",
    "print(f\"{'k':>4} | {'WSSSE':>14} | {'Silhouette':>12}\")\n",
    "print(\"-\" * 35)\n",
    "for r in metrics_rows:\n",
    "    # Only the k-sweep rows belong in this table.\n",
    "    if r[\"step\"] != \"kmeans_sweep\":\n",
    "        continue\n",
    "    print(f\"{r['k']:>4} | {r['wssse']:>14.4f} | {r['silhouette']:>12.6f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f5be6052",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cluster assignments written to outputs/lab3/cluster_assignments/\n",
      "+----------+-----+\n",
      "|prediction|count|\n",
      "+----------+-----+\n",
      "|         0|   80|\n",
      "|         1|   28|\n",
      "+----------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# ── Save the cluster assignments of the tuned (AFTER) run ────────────────────\n",
    "assignment_cols = [\n",
    "    \"start_station_id\", \"avg_lat\", \"avg_lng\",\n",
    "    \"num_departures\", \"num_arrivals\",\n",
    "    \"avg_duration\", \"member_ratio\", \"prediction\",\n",
    "]\n",
    "cluster_out = preds_after.select(*assignment_cols)\n",
    "\n",
    "# Spark writes a directory of CSV part files; an existing one is overwritten.\n",
    "assignments_path = str(OUT_DIR / \"cluster_assignments\")\n",
    "cluster_out.write.mode(\"overwrite\").csv(assignments_path, header=True)\n",
    "print(f\"Cluster assignments written to {OUT_DIR}/cluster_assignments/\")\n",
    "\n",
    "# Number of stations per cluster\n",
    "cluster_out.groupBy(\"prediction\").count().orderBy(\"prediction\").show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80142537",
   "metadata": {},
   "source": [
    "## Step 5 — Evidence & Metrics\n",
    "\n",
    "Save plans, fill `lab3_metrics_log.csv`, and capture an iteration-level plan."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0ea75972",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "plan_iteration.txt saved.\n",
      "lab3_metrics_log.csv written (16 rows).\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>step</th>\n",
       "      <th>partitions</th>\n",
       "      <th>k</th>\n",
       "      <th>seed</th>\n",
       "      <th>silhouette</th>\n",
       "      <th>wssse</th>\n",
       "      <th>elapsed_sec</th>\n",
       "      <th>strategy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>kmeans_sweep</td>\n",
       "      <td>200</td>\n",
       "      <td>2</td>\n",
       "      <td>42</td>\n",
       "      <td>0.676834</td>\n",
       "      <td>326.0415</td>\n",
       "      <td>14.575</td>\n",
       "      <td>default</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>kmeans_sweep</td>\n",
       "      <td>200</td>\n",
       "      <td>3</td>\n",
       "      <td>42</td>\n",
       "      <td>0.503423</td>\n",
       "      <td>234.0813</td>\n",
       "      <td>10.775</td>\n",
       "      <td>default</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>kmeans_sweep</td>\n",
       "      <td>200</td>\n",
       "      <td>4</td>\n",
       "      <td>42</td>\n",
       "      <td>0.377225</td>\n",
       "      <td>205.6939</td>\n",
       "      <td>9.807</td>\n",
       "      <td>default</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>kmeans_sweep</td>\n",
       "      <td>200</td>\n",
       "      <td>5</td>\n",
       "      <td>42</td>\n",
       "      <td>0.463861</td>\n",
       "      <td>164.3434</td>\n",
       "      <td>7.717</td>\n",
       "      <td>default</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>kmeans_sweep</td>\n",
       "      <td>200</td>\n",
       "      <td>6</td>\n",
       "      <td>42</td>\n",
       "      <td>0.480858</td>\n",
       "      <td>135.8636</td>\n",
       "      <td>8.811</td>\n",
       "      <td>default</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>kmeans_sweep</td>\n",
       "      <td>200</td>\n",
       "      <td>7</td>\n",
       "      <td>42</td>\n",
       "      <td>0.424769</td>\n",
       "      <td>125.1229</td>\n",
       "      <td>13.645</td>\n",
       "      <td>default</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>kmeans_sweep</td>\n",
       "      <td>200</td>\n",
       "      <td>8</td>\n",
       "      <td>42</td>\n",
       "      <td>0.407141</td>\n",
       "      <td>132.3893</td>\n",
       "      <td>10.184</td>\n",
       "      <td>default</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>partition_exp</td>\n",
       "      <td>200</td>\n",
       "      <td>2</td>\n",
       "      <td>42</td>\n",
       "      <td>0.676834</td>\n",
       "      <td>326.0415</td>\n",
       "      <td>5.920</td>\n",
       "      <td>before</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>partition_exp</td>\n",
       "      <td>12</td>\n",
       "      <td>2</td>\n",
       "      <td>42</td>\n",
       "      <td>0.676834</td>\n",
       "      <td>326.0415</td>\n",
       "      <td>1.133</td>\n",
       "      <td>after</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>seed_stability</td>\n",
       "      <td>12</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0.676834</td>\n",
       "      <td>326.0415</td>\n",
       "      <td>1.183</td>\n",
       "      <td>after</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>seed_stability</td>\n",
       "      <td>12</td>\n",
       "      <td>2</td>\n",
       "      <td>7</td>\n",
       "      <td>0.676834</td>\n",
       "      <td>326.0415</td>\n",
       "      <td>1.199</td>\n",
       "      <td>after</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>seed_stability</td>\n",
       "      <td>12</td>\n",
       "      <td>2</td>\n",
       "      <td>42</td>\n",
       "      <td>0.676834</td>\n",
       "      <td>326.0415</td>\n",
       "      <td>0.977</td>\n",
       "      <td>after</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>seed_stability</td>\n",
       "      <td>12</td>\n",
       "      <td>2</td>\n",
       "      <td>99</td>\n",
       "      <td>0.676834</td>\n",
       "      <td>326.0415</td>\n",
       "      <td>1.073</td>\n",
       "      <td>after</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>seed_stability</td>\n",
       "      <td>12</td>\n",
       "      <td>2</td>\n",
       "      <td>123</td>\n",
       "      <td>0.676834</td>\n",
       "      <td>326.0415</td>\n",
       "      <td>1.093</td>\n",
       "      <td>after</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>seed_stability</td>\n",
       "      <td>12</td>\n",
       "      <td>2</td>\n",
       "      <td>256</td>\n",
       "      <td>0.676834</td>\n",
       "      <td>326.0415</td>\n",
       "      <td>0.969</td>\n",
       "      <td>after</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>seed_stability</td>\n",
       "      <td>12</td>\n",
       "      <td>2</td>\n",
       "      <td>512</td>\n",
       "      <td>0.676834</td>\n",
       "      <td>326.0415</td>\n",
       "      <td>0.944</td>\n",
       "      <td>after</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              step  partitions  k  seed  silhouette     wssse  elapsed_sec  \\\n",
       "0     kmeans_sweep         200  2    42    0.676834  326.0415       14.575   \n",
       "1     kmeans_sweep         200  3    42    0.503423  234.0813       10.775   \n",
       "2     kmeans_sweep         200  4    42    0.377225  205.6939        9.807   \n",
       "3     kmeans_sweep         200  5    42    0.463861  164.3434        7.717   \n",
       "4     kmeans_sweep         200  6    42    0.480858  135.8636        8.811   \n",
       "5     kmeans_sweep         200  7    42    0.424769  125.1229       13.645   \n",
       "6     kmeans_sweep         200  8    42    0.407141  132.3893       10.184   \n",
       "7    partition_exp         200  2    42    0.676834  326.0415        5.920   \n",
       "8    partition_exp          12  2    42    0.676834  326.0415        1.133   \n",
       "9   seed_stability          12  2     0    0.676834  326.0415        1.183   \n",
       "10  seed_stability          12  2     7    0.676834  326.0415        1.199   \n",
       "11  seed_stability          12  2    42    0.676834  326.0415        0.977   \n",
       "12  seed_stability          12  2    99    0.676834  326.0415        1.073   \n",
       "13  seed_stability          12  2   123    0.676834  326.0415        1.093   \n",
       "14  seed_stability          12  2   256    0.676834  326.0415        0.969   \n",
       "15  seed_stability          12  2   512    0.676834  326.0415        0.944   \n",
       "\n",
       "   strategy  \n",
       "0   default  \n",
       "1   default  \n",
       "2   default  \n",
       "3   default  \n",
       "4   default  \n",
       "5   default  \n",
       "6   default  \n",
       "7    before  \n",
       "8     after  \n",
       "9     after  \n",
       "10    after  \n",
       "11    after  \n",
       "12    after  \n",
       "13    after  \n",
       "14    after  \n",
       "15    after  "
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# pandas is only used for the preview at the end of this cell.\n",
    "import pandas as pd\n",
    "\n",
    "# ── Save per-iteration plan (one KMeans iteration plan as proxy) ──────────────\n",
    "# Fit with maxIter=1 and record the plan of the resulting transform; this is\n",
    "# the stand-in for a single KMeans iteration in the proof folder.\n",
    "km_single = KMeans(k=best_k, seed=42, featuresCol=\"features\",\n",
    "                   predictionCol=\"prediction\", maxIter=1)\n",
    "m_single  = km_single.fit(feature_df_repartitioned)\n",
    "p_single  = m_single.transform(feature_df_repartitioned)\n",
    "# NOTE(review): `_jdf` is a private PySpark attribute — re-check after Spark upgrades.\n",
    "plan_iter = p_single._jdf.queryExecution().toString()\n",
    "# Explicit UTF-8 so the proof file does not depend on the platform's locale encoding.\n",
    "(PROOF_DIR / \"plan_iteration.txt\").write_text(plan_iter, encoding=\"utf-8\")\n",
    "print(\"plan_iteration.txt saved.\")\n",
    "\n",
    "# ── Write lab3_metrics_log.csv ────────────────────────────────────────────────\n",
    "# One name for the log file so the writer, the message and the preview cannot drift apart.\n",
    "metrics_csv_path = \"lab3_metrics_log.csv\"\n",
    "fieldnames = [\"step\", \"partitions\", \"k\", \"seed\", \"silhouette\",\n",
    "              \"wssse\", \"elapsed_sec\", \"strategy\"]\n",
    "\n",
    "# newline=\"\" lets the csv module control line endings; UTF-8 matches the\n",
    "# default encoding pd.read_csv uses when the file is read back below.\n",
    "with open(metrics_csv_path, \"w\", newline=\"\", encoding=\"utf-8\") as f:\n",
    "    writer = csv.DictWriter(f, fieldnames=fieldnames)\n",
    "    writer.writeheader()\n",
    "    writer.writerows(metrics_rows)\n",
    "\n",
    "print(f\"{metrics_csv_path} written ({len(metrics_rows)} rows).\")\n",
    "\n",
    "# Preview: the trailing expression renders the saved log as the cell output.\n",
    "pd.read_csv(metrics_csv_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7ab567f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "ASSIGNMENT 3 — SUMMARY\n",
      "============================================================\n",
      "Track          : C — Citi Bike CSV\n",
      "Path           : Clustering (KMeans)\n",
      "Dataset        : data/JC-202604-citibike-tripdata.csv\n",
      "Best k         : 2\n",
      "Best silhouette: 0.676834\n",
      "\n",
      "Partitioning comparison:\n",
      "  BEFORE (200 partitions) — sil=0.676834  time=5.92s\n",
      "  AFTER  (12 partitions)   — sil=0.676834  time=1.13s\n",
      "  Speed-up: 5.23x\n",
      "\n",
      "Seed stability (k=2, 7 seeds):\n",
      "  Silhouette mean=0.676834  std=0.000000\n",
      "\n",
      "Deliverables:\n",
      "  outputs/lab3/cluster_assignments/\n",
      "  proof/plan_before.txt\n",
      "  proof/plan_after.txt\n",
      "  proof/plan_iteration.txt\n",
      "  lab3_metrics_log.csv\n",
      "============================================================\n"
     ]
    }
   ],
   "source": [
    "# ── Summary printout ──────────────────────────────────────────────────────────\n",
    "# Assemble the whole report as a list of lines first, then emit it line by line.\n",
    "banner = \"=\" * 60\n",
    "speed_up = t_before / t_after  # wall-clock ratio of the two partitioning runs\n",
    "\n",
    "summary_lines = [\n",
    "    banner,\n",
    "    \"ASSIGNMENT 3 — SUMMARY\",\n",
    "    banner,\n",
    "    \"Track          : C — Citi Bike CSV\",\n",
    "    \"Path           : Clustering (KMeans)\",\n",
    "    f\"Dataset        : {DATA_DIR}\",\n",
    "    f\"Best k         : {best_k}\",\n",
    "    f\"Best silhouette: {best_sil:.6f}\",\n",
    "    \"\",\n",
    "    \"Partitioning comparison:\",\n",
    "    f\"  BEFORE (200 partitions) — sil={sil_before:.6f}  time={t_before:.2f}s\",\n",
    "    f\"  AFTER  ({tuned_partitions} partitions)   — sil={sil_after:.6f}  time={t_after:.2f}s\",\n",
    "    f\"  Speed-up: {speed_up:.2f}x\",\n",
    "    \"\",\n",
    "    f\"Seed stability (k={best_k}, {len(SEEDS)} seeds):\",\n",
    "    f\"  Silhouette mean={mean_sil:.6f}  std={std_sil:.6f}\",\n",
    "    \"\",\n",
    "    \"Deliverables:\",\n",
    "    \"  outputs/lab3/cluster_assignments/\",\n",
    "    \"  proof/plan_before.txt\",\n",
    "    \"  proof/plan_after.txt\",\n",
    "    \"  proof/plan_iteration.txt\",\n",
    "    \"  lab3_metrics_log.csv\",\n",
    "    banner,\n",
    "]\n",
    "\n",
    "for summary_line in summary_lines:\n",
    "    print(summary_line)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0ef05f95",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done.\n"
     ]
    }
   ],
   "source": [
    "# Stop the `spark` session at the end of the notebook, then print a completion marker.\n",
    "spark.stop()\n",
    "print(\"Done.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a893f71-29c2-49b6-9890-23d8b2198213",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (de2-env)",
   "language": "python",
   "name": "de2-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
