{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "glfcr-title",
   "metadata": {},
   "source": [
    "# GLF-CR cloud removal: short training run on SEN12MS-CR\n",
    "\n",
    "Trains a small two-branch generator (`GLF_Generator`) that fuses a Sentinel-1 radar patch (`s1`) with a cloudy Sentinel-2 patch (`cloud`) and predicts the cloud-free Sentinel-2 patch (`s2`), using an MSE loss.\n",
    "\n",
    "- Data: one season (spring by default) of SEN12MS-CR under `~/cloud_train/dataset`, read through the local module `sen12ms_cr_dataLoader.py`.\n",
    "- Training: a short smoke-test run of `MAX_STEPS` batches; increase it for real training.\n",
    "- Output: the generator `state_dict` is saved to `~/cloud_train/saved_models/GLFCR_generator.pth`.\n",
    "\n",
    "Run top to bottom (Restart Kernel → Run All); every cell depends only on the cells above it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a148f700-b14e-4dc3-a203-9e2f93955587",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "# Project root: holds sen12ms_cr_dataLoader.py, the dataset and the saved models.\n",
    "PROJECT_DIR = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\")\n",
    "\n",
    "# The loader is a local module: make both the working directory and the project root importable.\n",
    "sys.path.append(os.getcwd())\n",
    "sys.path.append(PROJECT_DIR)\n",
    "\n",
    "try:\n",
    "    from sen12ms_cr_dataLoader import SEN12MSCRDataset, Seasons, S1Bands, S2Bands\n",
    "except ImportError as exc:\n",
    "    raise RuntimeError(\"❌ Không tìm thấy file 'sen12ms_cr_dataLoader.py'\") from exc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "glfcr-config",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths, all derived from PROJECT_DIR (no hard-coded absolute paths).\n",
    "base_dir = os.path.join(PROJECT_DIR, \"dataset\")  # root directory given to SEN12MSCRDataset\n",
    "SAVE_DIR = os.path.join(PROJECT_DIR, \"saved_models\")\n",
    "SAVE_PATH = os.path.join(SAVE_DIR, \"GLFCR_generator.pth\")\n",
    "\n",
    "# Hyperparameters for this short run.\n",
    "SEED = 42\n",
    "BATCH_SIZE = 4\n",
    "LEARNING_RATE = 2e-4\n",
    "MAX_STEPS = 50  # optimizer steps (batches); raise for a real training run\n",
    "LOG_EVERY = 10  # print the loss every LOG_EVERY steps\n",
    "\n",
    "# Seed NumPy and PyTorch so weight initialisation and the DataLoader shuffle are reproducible.\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "glfcr-data-md",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "`SEN12MSCR_TorchDataset` flattens every `(scene_id, patch_id)` pair of one season into a single index and returns `(s1, s2, cloud)` float tensors:\n",
    "\n",
    "- `s1`: clipped to $[-25, 0]$ and divided by 25, i.e. mapped to $[-1, 0]$.\n",
    "- `s2` (cloud-free target) and `cloud` (cloudy input): clipped to $[0, 10000]$ and mapped to $[-1, 1]$, the output range of the generator's final `tanh`.\n",
    "- NaN pixels are replaced with the lower bound of each range before clipping, because `np.clip` keeps NaN and a single NaN would make the MSE loss NaN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "glfcr-dataset",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SEN12MSCR_TorchDataset(Dataset):\n",
    "    \"\"\"One SEN12MS-CR season exposed as a map-style PyTorch dataset.\n",
    "\n",
    "    Args:\n",
    "        base_dir: root directory passed to ``SEN12MSCRDataset``.\n",
    "        season: season whose patches are indexed (default ``Seasons.SPRING``).\n",
    "\n",
    "    Each item is ``(s1, s2, cloudy)`` as float32 tensors, scaled as:\n",
    "      * s1:            clip to [-25, 0], divide by 25   -> [-1, 0]\n",
    "      * s2 and cloudy: clip to [0, 10000], x/5000 - 1   -> [-1, 1]\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, base_dir, season=Seasons.SPRING):\n",
    "        self.loader = SEN12MSCRDataset(base_dir)\n",
    "        self.season = season\n",
    "        # presumably {scene_id: patch_ids}; flattened into one list of (scene_id, patch_id)\n",
    "        self.season_ids = self.loader.get_season_ids(season)\n",
    "        self.samples = [(sid, pid) for sid, pids in self.season_ids.items() for pid in pids]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.samples)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        sid, pid = self.samples[idx]\n",
    "        s1, s2, c, _ = self.loader.get_s1s2s2cloudy_triplet(self.season, sid, pid)\n",
    "        # np.clip keeps NaN, and one NaN pixel would turn the MSE loss into NaN, so replace\n",
    "        # NaN with the lower bound of each range first (no-op for NaN-free or integer arrays).\n",
    "        s1 = np.clip(np.nan_to_num(s1, nan=-25.0), -25, 0) / 25.0\n",
    "        s2 = (np.clip(np.nan_to_num(s2, nan=0.0), 0, 10000) / 5000.0) - 1.0\n",
    "        c = (np.clip(np.nan_to_num(c, nan=0.0), 0, 10000) / 5000.0) - 1.0\n",
    "        return torch.from_numpy(s1).float(), torch.from_numpy(s2).float(), torch.from_numpy(c).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "glfcr-loader",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = SEN12MSCR_TorchDataset(base_dir)\n",
    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)\n",
    "print(f\"✅ GLF-CR Data Ready: {len(train_dataset)} samples\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "glfcr-model-md",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "Two convolutional encoders (radar `s1`: 2 channels; cloudy optical: 13 channels) each produce a 64-channel feature map. `FusionBlock` merges them (concatenate → 1×1 conv → BatchNorm → ReLU), and a three-layer convolutional decoder maps the fused features back to 13 channels, bounded to $[-1, 1]$ by `tanh`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdb8ff0f-ea65-4ed4-8f5b-b99b05c602a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FusionBlock(nn.Module):\n",
    "    \"\"\"Merge two feature maps: channel concat -> 1x1 conv -> BatchNorm -> ReLU.\n",
    "\n",
    "    Expects ``x1`` and ``x2`` to have ``channels`` channels each and the same spatial size;\n",
    "    returns a map with ``channels`` channels.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, channels):\n",
    "        super().__init__()\n",
    "        self.conv = nn.Sequential(\n",
    "            nn.Conv2d(channels * 2, channels, 1),\n",
    "            nn.BatchNorm2d(channels), nn.ReLU()\n",
    "        )\n",
    "\n",
    "    def forward(self, x1, x2):\n",
    "        return self.conv(torch.cat([x1, x2], dim=1))\n",
    "\n",
    "\n",
    "class GLF_Generator(nn.Module):\n",
    "    \"\"\"Two-branch generator: encode S1 (2 ch) and cloudy S2 (13 ch) separately, fuse, decode to 13 ch.\n",
    "\n",
    "    The final tanh bounds the output to [-1, 1], the range the dataset maps S2 into.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Two separate encoder branches\n",
    "        self.enc_s1 = nn.Sequential(nn.Conv2d(2, 64, 3, 1, 1), nn.ReLU())   # radar (S1)\n",
    "        self.enc_s2 = nn.Sequential(nn.Conv2d(13, 64, 3, 1, 1), nn.ReLU())  # cloudy optical (S2)\n",
    "        self.fuse = FusionBlock(64)\n",
    "        self.dec = nn.Sequential(\n",
    "            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(),\n",
    "            nn.Conv2d(128, 64, 3, 1, 1), nn.ReLU(),\n",
    "            nn.Conv2d(64, 13, 3, 1, 1)\n",
    "        )\n",
    "        self.tanh = nn.Tanh()\n",
    "\n",
    "    def forward(self, s1, s2_cloudy):\n",
    "        f1 = self.enc_s1(s1)\n",
    "        f2 = self.enc_s2(s2_cloudy)\n",
    "        f_fused = self.fuse(f1, f2)\n",
    "        return self.tanh(self.dec(f_fused))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "glfcr-train-md",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "Adam on the MSE between the generator output and the cloud-free `s2`, for exactly `MAX_STEPS` batches, logging every `LOG_EVERY` steps. The cell after the loop saves the weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b124a548-a7fe-46e7-939d-ae3ff5e5b83c",
   "metadata": {},
   "outputs": [],
   "source": [
    "gen = GLF_Generator().to(device)\n",
    "opt = torch.optim.Adam(gen.parameters(), lr=LEARNING_RATE)\n",
    "crit = nn.MSELoss()\n",
    "\n",
    "gen.train()  # explicit: FusionBlock's BatchNorm must use batch statistics while training\n",
    "print(\"🚀 Bắt đầu train GLF-CR...\")\n",
    "# islice stops after exactly MAX_STEPS batches without loading an extra one.\n",
    "for i, (s1, s2, cloud) in enumerate(itertools.islice(train_loader, MAX_STEPS)):\n",
    "    s1, s2, cloud = s1.to(device), s2.to(device), cloud.to(device)\n",
    "\n",
    "    opt.zero_grad()\n",
    "    out = gen(s1, cloud)  # predict the cloud-free S2 from SAR + cloudy S2\n",
    "    loss = crit(out, s2)\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "\n",
    "    if i % LOG_EVERY == 0:\n",
    "        print(f\"Step {i} | MSE Loss: {loss.item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "glfcr-save",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(SAVE_DIR, exist_ok=True)\n",
    "torch.save(gen.state_dict(), SAVE_PATH)\n",
    "print(f\"\\n💾 Đã lưu model tại: {SAVE_PATH}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}