diff --git a/Train_CR-GAN.ipynb b/Train_CR-GAN.ipynb new file mode 100644 index 0000000..0df367a --- /dev/null +++ b/Train_CR-GAN.ipynb @@ -0,0 +1,568 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "c886b44b-4288-4260-8b5b-bb78a87e172f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "📂 Working Directory: /home/jovyan/cloud_train/dataset\n", + "✅ Đã import thành công (dùng đường dẫn tuyệt đối)!\n", + "Using device: cpu\n", + "✅ Data Ready: 29117 samples\n" + ] + } + ], + "source": [ + "# ==========================================\n", + "# CELL 1: SETUP & LOAD DỮ LIỆU (ĐÃ SỬA LỖI IMPORT)\n", + "# ==========================================\n", + "import os\n", + "import sys\n", + "import torch\n", + "import numpy as np\n", + "from torch.utils.data import Dataset, DataLoader\n", + "import torch.nn as nn\n", + "\n", + "# --- KHẮC PHỤC LỖI KHÔNG TÌM THẤY MODULE ---\n", + "# Thêm đường dẫn thư mục hiện tại vào hệ thống để Python tìm thấy file loader\n", + "current_dir = os.getcwd()\n", + "sys.path.append(current_dir)\n", + "print(f\"📂 Working Directory: {current_dir}\")\n", + "\n", + "try:\n", + " from sen12ms_cr_dataLoader import SEN12MSCRDataset, Seasons, S1Bands, S2Bands\n", + " print(\"✅ Đã import thành công sen12ms_cr_dataLoader!\")\n", + "except ImportError:\n", + " # Nếu vẫn lỗi, thử trỏ cứng vào thư mục cloud_train\n", + " sys.path.append(\"/home/jovyan/cloud_train\")\n", + " try:\n", + " from sen12ms_cr_dataLoader import SEN12MSCRDataset, Seasons, S1Bands, S2Bands\n", + " print(\"✅ Đã import thành công (dùng đường dẫn tuyệt đối)!\")\n", + " except ImportError:\n", + " raise RuntimeError(\"❌ Vẫn không tìm thấy file 'sen12ms_cr_dataLoader.py'. 
Hãy đảm bảo file này nằm cùng thư mục với Notebook!\")\n", + "\n", + "# Cấu hình thiết bị\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using device: {device}\")\n", + "\n", + "# Đường dẫn dataset\n", + "base_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\", \"dataset\")\n", + "\n", + "# Dataset Wrapper\n", + "class SEN12MSCR_TorchDataset(Dataset):\n", + " def __init__(self, base_dir, season=Seasons.SPRING):\n", + " self.loader = SEN12MSCRDataset(base_dir)\n", + " self.season = season\n", + " self.season_ids = self.loader.get_season_ids(season)\n", + " self.samples = [(sid, pid) for sid, pids in self.season_ids.items() for pid in pids]\n", + " def __len__(self): return len(self.samples)\n", + " def __getitem__(self, idx):\n", + " sid, pid = self.samples[idx]\n", + " s1, s2, c, _ = self.loader.get_s1s2s2cloudy_triplet(self.season, sid, pid)\n", + " s1 = np.clip(s1, -25, 0) / 25.0\n", + " s2 = (np.clip(s2, 0, 10000) / 5000.0) - 1.0\n", + " c = (np.clip(c, 0, 10000) / 5000.0) - 1.0\n", + " return torch.from_numpy(s1).float(), torch.from_numpy(s2).float(), torch.from_numpy(c).float()\n", + "\n", + "try:\n", + " train_dataset = SEN12MSCR_TorchDataset(base_dir)\n", + " train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)\n", + " print(f\"✅ Data Ready: {len(train_dataset)} samples\")\n", + "except Exception as e:\n", + " print(f\"❌ Lỗi Data: {e}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a6357434-8b1c-48ef-ac84-7afadc4ce04a", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# ==========================================\n", + "# CELL 2: ĐỊNH NGHĨA MODEL CR-GAN\n", + "# ==========================================\n", + "class CR_Generator(nn.Module):\n", + " def __init__(self):\n", + " super(CR_Generator, self).__init__()\n", + " # Input: 13 kênh (Mây) + 2 kênh (Radar) = 15\n", + " self.enc1 = nn.Sequential(nn.Conv2d(15, 64, 4, 2, 1), 
nn.LeakyReLU(0.2))\n", + " self.enc2 = nn.Sequential(nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2))\n", + " self.enc3 = nn.Sequential(nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2))\n", + " \n", + " self.dec1 = nn.Sequential(nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU())\n", + " self.dec2 = nn.Sequential(nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU())\n", + " self.dec3 = nn.ConvTranspose2d(64, 13, 4, 2, 1) # Output 13 kênh sạch\n", + " self.tanh = nn.Tanh()\n", + "\n", + " def forward(self, s1, s2_cloudy):\n", + " x = torch.cat([s1, s2_cloudy], dim=1)\n", + " e1 = self.enc1(x)\n", + " e2 = self.enc2(e1)\n", + " e3 = self.enc3(e2)\n", + " \n", + " d1 = self.dec1(e3)\n", + " d2 = self.dec2(d1 + e2) # Skip connection\n", + " d3 = self.dec3(d2 + e1)\n", + " return self.tanh(d3)\n", + "\n", + "class Discriminator(nn.Module):\n", + " def __init__(self):\n", + " super(Discriminator, self).__init__()\n", + " # Input: 13 (Ảnh) + 15 (Điều kiện) = 28\n", + " self.model = nn.Sequential(\n", + " nn.Conv2d(28, 64, 4, 2, 1), nn.LeakyReLU(0.2),\n", + " nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2),\n", + " nn.Conv2d(128, 1, 4, 1, 0), nn.Sigmoid()\n", + " )\n", + " def forward(self, img, cond1, cond2):\n", + " return self.model(torch.cat([img, cond1, cond2], 1))\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09276904-a4c9-4774-b171-32d2e577afa0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🚀 Bắt đầu huấn luyện CR-GAN...\n", + "Epoch 0 | Step 0 | Loss G: 72.3185 | Loss D: 0.7102\n", + "Epoch 0 | Step 10 | Loss G: 37.1491 | Loss D: 0.5784\n", + "Epoch 0 | Step 20 | Loss G: 19.9918 | Loss D: 0.5008\n", + "Epoch 0 | Step 30 | Loss G: 21.9959 | Loss D: 0.6595\n", + "Epoch 0 | Step 40 | Loss G: 10.0374 | Loss D: 0.5188\n", + "Epoch 0 | Step 50 | Loss G: 7.7058 | Loss D: 0.5300\n", + "Epoch 
0 | Step 60 | Loss G: 6.9481 | Loss D: 0.5223\n", + "Epoch 0 | Step 70 | Loss G: 12.7031 | Loss D: 0.6349\n", + "Epoch 0 | Step 80 | Loss G: 6.1611 | Loss D: 0.6098\n", + "Epoch 0 | Step 90 | Loss G: 8.4944 | Loss D: 0.4761\n", + "Epoch 0 | Step 100 | Loss G: 16.4352 | Loss D: 1.2795\n", + "Epoch 0 | Step 110 | Loss G: 6.7580 | Loss D: 0.6131\n", + "Epoch 0 | Step 120 | Loss G: 6.7856 | Loss D: 0.5403\n", + "Epoch 0 | Step 130 | Loss G: 6.2629 | Loss D: 0.6605\n", + "Epoch 0 | Step 140 | Loss G: 5.4081 | Loss D: 0.5667\n", + "Epoch 0 | Step 150 | Loss G: 5.5551 | Loss D: 0.5392\n", + "Epoch 0 | Step 160 | Loss G: 5.4969 | Loss D: 0.5394\n", + "Epoch 0 | Step 170 | Loss G: 4.6524 | Loss D: 0.5536\n", + "Epoch 0 | Step 180 | Loss G: 9.2369 | Loss D: 0.5098\n", + "Epoch 0 | Step 190 | Loss G: 10.5043 | Loss D: 0.4320\n", + "Epoch 0 | Step 200 | Loss G: 5.1155 | Loss D: 0.5958\n", + "Epoch 0 | Step 210 | Loss G: 5.1848 | Loss D: 0.5596\n", + "Epoch 0 | Step 220 | Loss G: 6.5014 | Loss D: 0.5307\n", + "Epoch 0 | Step 230 | Loss G: 6.4747 | Loss D: 0.6416\n", + "Epoch 0 | Step 240 | Loss G: 4.6687 | Loss D: 0.5795\n", + "Epoch 0 | Step 250 | Loss G: 3.8810 | Loss D: 0.6212\n", + "Epoch 0 | Step 260 | Loss G: 4.5834 | Loss D: 0.5906\n", + "Epoch 0 | Step 270 | Loss G: 6.8081 | Loss D: 0.4587\n", + "Epoch 0 | Step 280 | Loss G: 6.4570 | Loss D: 0.5759\n", + "Epoch 0 | Step 290 | Loss G: 4.5583 | Loss D: 0.5872\n", + "Epoch 0 | Step 300 | Loss G: 5.8205 | Loss D: 0.6410\n", + "Epoch 0 | Step 310 | Loss G: 4.7019 | Loss D: 0.6298\n", + "Epoch 0 | Step 320 | Loss G: 4.3386 | Loss D: 0.5507\n", + "Epoch 0 | Step 330 | Loss G: 4.2705 | Loss D: 0.5533\n", + "Epoch 0 | Step 340 | Loss G: 6.9023 | Loss D: 0.8489\n", + "Epoch 0 | Step 350 | Loss G: 8.0638 | Loss D: 0.7825\n", + "Epoch 0 | Step 360 | Loss G: 4.4931 | Loss D: 0.6148\n", + "Epoch 0 | Step 370 | Loss G: 3.7342 | Loss D: 0.6964\n", + "Epoch 0 | Step 380 | Loss G: 4.2160 | Loss D: 0.6782\n", + "Epoch 0 | Step 390 | Loss 
G: 5.1540 | Loss D: 0.6306\n", + "Epoch 0 | Step 400 | Loss G: 4.5333 | Loss D: 0.6755\n", + "Epoch 0 | Step 410 | Loss G: 4.0220 | Loss D: 0.6751\n", + "Epoch 0 | Step 420 | Loss G: 5.7766 | Loss D: 0.7027\n", + "Epoch 0 | Step 430 | Loss G: 4.0071 | Loss D: 0.5669\n", + "Epoch 0 | Step 440 | Loss G: 4.1375 | Loss D: 0.4732\n", + "Epoch 0 | Step 450 | Loss G: 3.8729 | Loss D: 0.5686\n", + "Epoch 0 | Step 460 | Loss G: 3.5877 | Loss D: 0.6608\n", + "Epoch 0 | Step 470 | Loss G: 5.6308 | Loss D: 0.7109\n", + "Epoch 0 | Step 480 | Loss G: 5.3286 | Loss D: 0.7127\n", + "Epoch 0 | Step 490 | Loss G: 2.9878 | Loss D: 0.6668\n", + "Epoch 0 | Step 500 | Loss G: 4.0083 | Loss D: 0.7079\n", + "Epoch 0 | Step 510 | Loss G: 3.6710 | Loss D: 0.7004\n", + "Epoch 0 | Step 520 | Loss G: 5.5583 | Loss D: 0.6757\n", + "Epoch 0 | Step 530 | Loss G: 4.2446 | Loss D: 0.6811\n", + "Epoch 0 | Step 540 | Loss G: 3.7496 | Loss D: 0.6590\n", + "Epoch 0 | Step 550 | Loss G: 7.4803 | Loss D: 0.8142\n", + "Epoch 0 | Step 560 | Loss G: 3.7535 | Loss D: 0.6432\n", + "Epoch 0 | Step 570 | Loss G: 2.8306 | Loss D: 0.5701\n", + "Epoch 0 | Step 580 | Loss G: 2.9589 | Loss D: 0.6700\n", + "Epoch 0 | Step 590 | Loss G: 4.7672 | Loss D: 0.7502\n", + "Epoch 0 | Step 600 | Loss G: 3.1664 | Loss D: 0.6552\n", + "Epoch 0 | Step 610 | Loss G: 3.2172 | Loss D: 0.6563\n", + "Epoch 0 | Step 620 | Loss G: 3.1097 | Loss D: 0.7513\n", + "Epoch 0 | Step 630 | Loss G: 3.8371 | Loss D: 0.7282\n", + "Epoch 0 | Step 640 | Loss G: 3.1871 | Loss D: 0.6991\n", + "Epoch 0 | Step 650 | Loss G: 3.4674 | Loss D: 0.6747\n", + "Epoch 0 | Step 660 | Loss G: 3.5930 | Loss D: 0.6720\n", + "Epoch 0 | Step 670 | Loss G: 2.9213 | Loss D: 0.6792\n", + "Epoch 0 | Step 680 | Loss G: 3.5470 | Loss D: 0.6509\n", + "Epoch 0 | Step 690 | Loss G: 3.6894 | Loss D: 0.6192\n", + "Epoch 0 | Step 700 | Loss G: 3.6680 | Loss D: 0.5443\n", + "Epoch 0 | Step 710 | Loss G: 3.6357 | Loss D: 0.5633\n", + "Epoch 0 | Step 720 | Loss G: 3.2239 | Loss D: 
0.6394\n", + "Epoch 0 | Step 730 | Loss G: 2.5703 | Loss D: 0.6727\n", + "Epoch 0 | Step 740 | Loss G: 3.2920 | Loss D: 0.6691\n", + "Epoch 0 | Step 750 | Loss G: 3.1487 | Loss D: 0.7395\n", + "Epoch 0 | Step 760 | Loss G: 3.2808 | Loss D: 0.6791\n", + "Epoch 0 | Step 770 | Loss G: 3.6725 | Loss D: 0.7175\n", + "Epoch 0 | Step 780 | Loss G: 2.8584 | Loss D: 0.6647\n", + "Epoch 0 | Step 790 | Loss G: 3.5167 | Loss D: 0.7073\n", + "Epoch 0 | Step 800 | Loss G: 2.9515 | Loss D: 0.6595\n", + "Epoch 0 | Step 810 | Loss G: 3.0861 | Loss D: 0.7068\n", + "Epoch 0 | Step 820 | Loss G: 5.8470 | Loss D: 0.9232\n", + "Epoch 0 | Step 830 | Loss G: 3.4973 | Loss D: 0.6484\n", + "Epoch 0 | Step 840 | Loss G: 3.3402 | Loss D: 0.6405\n", + "Epoch 0 | Step 850 | Loss G: 2.9324 | Loss D: 0.6842\n", + "Epoch 0 | Step 860 | Loss G: 3.3249 | Loss D: 0.5885\n", + "Epoch 0 | Step 870 | Loss G: 3.2350 | Loss D: 0.7483\n", + "Epoch 0 | Step 880 | Loss G: 2.9857 | Loss D: 0.6709\n", + "Epoch 0 | Step 890 | Loss G: 3.1763 | Loss D: 0.5673\n", + "Epoch 0 | Step 900 | Loss G: 3.0522 | Loss D: 0.7522\n", + "Epoch 0 | Step 910 | Loss G: 2.5046 | Loss D: 0.7584\n", + "Epoch 0 | Step 920 | Loss G: 3.5277 | Loss D: 0.7532\n", + "Epoch 0 | Step 930 | Loss G: 2.7921 | Loss D: 0.6860\n", + "Epoch 0 | Step 940 | Loss G: 4.3079 | Loss D: 0.7080\n", + "Epoch 0 | Step 950 | Loss G: 2.7965 | Loss D: 0.6728\n", + "Epoch 0 | Step 960 | Loss G: 2.7622 | Loss D: 0.6773\n", + "Epoch 0 | Step 970 | Loss G: 2.8015 | Loss D: 0.6507\n", + "Epoch 0 | Step 980 | Loss G: 2.0900 | Loss D: 0.6775\n", + "Epoch 0 | Step 990 | Loss G: 2.9263 | Loss D: 0.6361\n", + "Epoch 0 | Step 1000 | Loss G: 3.5840 | Loss D: 0.6793\n", + "Epoch 0 | Step 1010 | Loss G: 3.0541 | Loss D: 0.7161\n", + "Epoch 0 | Step 1020 | Loss G: 2.6563 | Loss D: 0.6252\n", + "Epoch 0 | Step 1030 | Loss G: 2.9624 | Loss D: 0.6389\n", + "Epoch 0 | Step 1040 | Loss G: 4.0138 | Loss D: 0.5921\n", + "Epoch 0 | Step 1050 | Loss G: 3.4586 | Loss D: 0.6120\n", + 
"Epoch 0 | Step 1060 | Loss G: 3.6436 | Loss D: 0.5100\n", + "Epoch 0 | Step 1070 | Loss G: 2.5240 | Loss D: 0.7917\n", + "Epoch 0 | Step 1080 | Loss G: 2.9808 | Loss D: 0.5955\n", + "Epoch 0 | Step 1090 | Loss G: 3.5152 | Loss D: 0.6790\n", + "Epoch 0 | Step 1100 | Loss G: 4.8162 | Loss D: 0.9497\n", + "Epoch 0 | Step 1110 | Loss G: 4.0998 | Loss D: 0.6457\n", + "Epoch 0 | Step 1120 | Loss G: 3.7265 | Loss D: 0.7151\n", + "Epoch 0 | Step 1130 | Loss G: 2.8072 | Loss D: 0.7019\n", + "Epoch 0 | Step 1140 | Loss G: 3.8286 | Loss D: 0.7170\n", + "Epoch 0 | Step 1150 | Loss G: 2.3177 | Loss D: 0.6822\n", + "Epoch 0 | Step 1160 | Loss G: 2.3291 | Loss D: 0.6801\n", + "Epoch 0 | Step 1170 | Loss G: 2.5152 | Loss D: 0.6733\n", + "Epoch 0 | Step 1180 | Loss G: 2.1710 | Loss D: 0.6602\n", + "Epoch 0 | Step 1190 | Loss G: 3.0357 | Loss D: 0.6695\n", + "Epoch 0 | Step 1200 | Loss G: 2.9576 | Loss D: 0.6414\n", + "Epoch 0 | Step 1210 | Loss G: 3.0184 | Loss D: 0.6573\n", + "Epoch 0 | Step 1220 | Loss G: 2.2493 | Loss D: 0.7150\n", + "Epoch 0 | Step 1230 | Loss G: 2.7642 | Loss D: 0.7177\n", + "Epoch 0 | Step 1240 | Loss G: 2.4553 | Loss D: 0.6698\n", + "Epoch 0 | Step 1250 | Loss G: 2.5607 | Loss D: 0.6953\n", + "Epoch 0 | Step 1260 | Loss G: 2.9071 | Loss D: 0.6687\n", + "Epoch 0 | Step 1270 | Loss G: 2.6420 | Loss D: 0.6780\n", + "Epoch 0 | Step 1280 | Loss G: 2.6417 | Loss D: 0.6567\n", + "Epoch 0 | Step 1290 | Loss G: 2.5739 | Loss D: 0.6899\n", + "Epoch 0 | Step 1300 | Loss G: 2.9631 | Loss D: 0.6819\n", + "Epoch 0 | Step 1310 | Loss G: 2.5568 | Loss D: 0.6575\n", + "Epoch 0 | Step 1320 | Loss G: 2.2896 | Loss D: 0.6679\n", + "Epoch 0 | Step 1330 | Loss G: 2.9179 | Loss D: 0.6886\n", + "Epoch 0 | Step 1340 | Loss G: 2.8625 | Loss D: 0.6136\n", + "Epoch 0 | Step 1350 | Loss G: 2.6456 | Loss D: 0.6368\n", + "Epoch 0 | Step 1360 | Loss G: 2.9531 | Loss D: 0.6491\n", + "Epoch 0 | Step 1370 | Loss G: 2.5252 | Loss D: 0.7159\n", + "Epoch 0 | Step 1380 | Loss G: 3.9087 | Loss D: 
0.7079\n", + "Epoch 0 | Step 1390 | Loss G: 2.3737 | Loss D: 0.6529\n", + "Epoch 0 | Step 1400 | Loss G: 2.3664 | Loss D: 0.6481\n", + "Epoch 0 | Step 1410 | Loss G: 2.7657 | Loss D: 0.5882\n", + "Epoch 0 | Step 1420 | Loss G: 2.7599 | Loss D: 0.6327\n", + "Epoch 0 | Step 1430 | Loss G: 2.6549 | Loss D: 0.8037\n", + "Epoch 0 | Step 1440 | Loss G: 3.3464 | Loss D: 0.5944\n", + "Epoch 0 | Step 1450 | Loss G: 2.6873 | Loss D: 0.7476\n", + "Epoch 0 | Step 1460 | Loss G: 2.5144 | Loss D: 0.6655\n", + "Epoch 0 | Step 1470 | Loss G: 3.1850 | Loss D: 0.7270\n", + "Epoch 0 | Step 1480 | Loss G: 5.0466 | Loss D: 0.6669\n", + "Epoch 0 | Step 1490 | Loss G: 2.7675 | Loss D: 0.6640\n", + "Epoch 0 | Step 1500 | Loss G: 2.3371 | Loss D: 0.6604\n", + "Epoch 0 | Step 1510 | Loss G: 2.6685 | Loss D: 0.6510\n", + "Epoch 0 | Step 1520 | Loss G: 3.0199 | Loss D: 0.6744\n", + "Epoch 0 | Step 1530 | Loss G: 2.9492 | Loss D: 0.7057\n", + "Epoch 0 | Step 1540 | Loss G: 2.8013 | Loss D: 0.7118\n", + "Epoch 0 | Step 1550 | Loss G: 2.8135 | Loss D: 0.6324\n", + "Epoch 0 | Step 1560 | Loss G: 3.1775 | Loss D: 0.6969\n", + "Epoch 0 | Step 1570 | Loss G: 3.1974 | Loss D: 0.6341\n", + "Epoch 0 | Step 1580 | Loss G: 2.4508 | Loss D: 0.6909\n", + "Epoch 0 | Step 1590 | Loss G: 3.2501 | Loss D: 0.7138\n", + "Epoch 0 | Step 1600 | Loss G: 3.5432 | Loss D: 0.7108\n", + "Epoch 0 | Step 1610 | Loss G: 2.5303 | Loss D: 0.6710\n", + "Epoch 0 | Step 1620 | Loss G: 2.8239 | Loss D: 0.6950\n", + "Epoch 0 | Step 1630 | Loss G: 2.2003 | Loss D: 0.6568\n", + "Epoch 0 | Step 1640 | Loss G: 2.6544 | Loss D: 0.7159\n", + "Epoch 0 | Step 1650 | Loss G: 2.3602 | Loss D: 0.6818\n", + "Epoch 0 | Step 1660 | Loss G: 3.1921 | Loss D: 0.6072\n", + "Epoch 0 | Step 1670 | Loss G: 2.5231 | Loss D: 0.7427\n", + "Epoch 0 | Step 1680 | Loss G: 2.9751 | Loss D: 0.6027\n", + "Epoch 0 | Step 1690 | Loss G: 2.3957 | Loss D: 0.6853\n", + "Epoch 0 | Step 1700 | Loss G: 2.5812 | Loss D: 0.6981\n", + "Epoch 0 | Step 1710 | Loss G: 
2.7025 | Loss D: 0.7519\n", + "Epoch 0 | Step 1720 | Loss G: 2.9266 | Loss D: 0.6202\n", + "Epoch 0 | Step 1730 | Loss G: 2.7670 | Loss D: 0.6953\n", + "Epoch 0 | Step 1740 | Loss G: 2.7067 | Loss D: 0.5941\n", + "Epoch 0 | Step 1750 | Loss G: 3.5343 | Loss D: 0.7225\n", + "Epoch 0 | Step 1760 | Loss G: 2.4662 | Loss D: 0.6810\n", + "Epoch 0 | Step 1770 | Loss G: 2.6970 | Loss D: 0.6763\n", + "Epoch 0 | Step 1780 | Loss G: 3.1359 | Loss D: 0.6382\n", + "Epoch 0 | Step 1790 | Loss G: 2.6589 | Loss D: 0.6323\n", + "Epoch 0 | Step 1800 | Loss G: 2.6840 | Loss D: 0.6116\n", + "Epoch 0 | Step 1810 | Loss G: 3.0749 | Loss D: 0.7142\n", + "Epoch 0 | Step 1820 | Loss G: 2.4837 | Loss D: 0.7109\n", + "Epoch 0 | Step 1830 | Loss G: 2.9848 | Loss D: 0.6832\n", + "Epoch 0 | Step 1840 | Loss G: 3.1044 | Loss D: 0.5613\n", + "Epoch 0 | Step 1850 | Loss G: 2.6458 | Loss D: 1.0100\n", + "Epoch 0 | Step 1860 | Loss G: 2.4200 | Loss D: 0.7013\n", + "Epoch 0 | Step 1870 | Loss G: 2.4586 | Loss D: 0.7051\n", + "Epoch 0 | Step 1880 | Loss G: 2.3607 | Loss D: 0.7131\n", + "Epoch 0 | Step 1890 | Loss G: 2.6252 | Loss D: 0.7187\n", + "Epoch 0 | Step 1900 | Loss G: 2.4684 | Loss D: 0.6961\n", + "Epoch 0 | Step 1910 | Loss G: 2.2537 | Loss D: 0.6876\n", + "Epoch 0 | Step 1920 | Loss G: 2.4095 | Loss D: 0.6838\n", + "Epoch 0 | Step 1930 | Loss G: 2.8157 | Loss D: 0.6736\n", + "Epoch 0 | Step 1940 | Loss G: 2.4893 | Loss D: 0.6775\n", + "Epoch 0 | Step 1950 | Loss G: 3.3445 | Loss D: 0.6962\n", + "Epoch 0 | Step 1960 | Loss G: 2.2544 | Loss D: 0.6767\n", + "Epoch 0 | Step 1970 | Loss G: 2.6515 | Loss D: 0.6322\n", + "Epoch 0 | Step 1980 | Loss G: 2.2835 | Loss D: 0.6343\n", + "Epoch 0 | Step 1990 | Loss G: 2.4262 | Loss D: 0.6298\n", + "Epoch 0 | Step 2000 | Loss G: 2.7222 | Loss D: 0.6135\n", + "Epoch 0 | Step 2010 | Loss G: 2.3119 | Loss D: 0.6223\n", + "Epoch 0 | Step 2020 | Loss G: 2.8205 | Loss D: 0.6000\n", + "Epoch 0 | Step 2030 | Loss G: 2.8469 | Loss D: 0.5928\n", + "Epoch 0 | Step 
2040 | Loss G: 3.0901 | Loss D: 0.6076\n", + "Epoch 0 | Step 2050 | Loss G: 2.4228 | Loss D: 0.6983\n", + "Epoch 0 | Step 2060 | Loss G: 2.7112 | Loss D: 0.7587\n", + "Epoch 0 | Step 2070 | Loss G: 2.6442 | Loss D: 0.7504\n", + "Epoch 0 | Step 2080 | Loss G: 2.4068 | Loss D: 0.7240\n", + "Epoch 0 | Step 2090 | Loss G: 2.9278 | Loss D: 0.6676\n", + "Epoch 0 | Step 2100 | Loss G: 2.2134 | Loss D: 0.7018\n", + "Epoch 0 | Step 2110 | Loss G: 2.8740 | Loss D: 0.6760\n", + "Epoch 0 | Step 2120 | Loss G: 2.4062 | Loss D: 0.6768\n", + "Epoch 0 | Step 2130 | Loss G: 2.2714 | Loss D: 0.6880\n", + "Epoch 0 | Step 2140 | Loss G: 2.3448 | Loss D: 0.6871\n", + "Epoch 0 | Step 2150 | Loss G: 2.1828 | Loss D: 0.6292\n", + "Epoch 0 | Step 2160 | Loss G: 2.4968 | Loss D: 0.6260\n", + "Epoch 0 | Step 2170 | Loss G: 2.2925 | Loss D: 0.6720\n", + "Epoch 0 | Step 2180 | Loss G: 2.1419 | Loss D: 0.6476\n", + "Epoch 0 | Step 2190 | Loss G: 2.4403 | Loss D: 0.6648\n", + "Epoch 0 | Step 2200 | Loss G: 3.3253 | Loss D: 0.5541\n", + "Epoch 0 | Step 2210 | Loss G: 2.3137 | Loss D: 0.6822\n", + "Epoch 0 | Step 2220 | Loss G: 3.3092 | Loss D: 0.6984\n", + "Epoch 0 | Step 2230 | Loss G: 2.7817 | Loss D: 0.6168\n", + "Epoch 0 | Step 2240 | Loss G: 2.4387 | Loss D: 0.6215\n", + "Epoch 0 | Step 2250 | Loss G: 2.7103 | Loss D: 0.5489\n", + "Epoch 0 | Step 2260 | Loss G: 2.8750 | Loss D: 0.7436\n", + "Epoch 0 | Step 2270 | Loss G: 3.2172 | Loss D: 0.7419\n", + "Epoch 0 | Step 2280 | Loss G: 2.2944 | Loss D: 0.7735\n", + "Epoch 0 | Step 2290 | Loss G: 2.9406 | Loss D: 0.6687\n", + "Epoch 0 | Step 2300 | Loss G: 2.3666 | Loss D: 0.7192\n", + "Epoch 0 | Step 2310 | Loss G: 2.0075 | Loss D: 0.7016\n", + "Epoch 0 | Step 2320 | Loss G: 2.5500 | Loss D: 0.7104\n", + "Epoch 0 | Step 2330 | Loss G: 2.2471 | Loss D: 0.6831\n", + "Epoch 0 | Step 2340 | Loss G: 2.2565 | Loss D: 0.6743\n", + "Epoch 0 | Step 2350 | Loss G: 2.1140 | Loss D: 0.6982\n", + "Epoch 0 | Step 2360 | Loss G: 2.2506 | Loss D: 0.6729\n", + 
"Epoch 0 | Step 2370 | Loss G: 2.4593 | Loss D: 0.6737\n", + "Epoch 0 | Step 2380 | Loss G: 2.3511 | Loss D: 0.6809\n", + "Epoch 0 | Step 2390 | Loss G: 2.4674 | Loss D: 0.6723\n", + "Epoch 0 | Step 2400 | Loss G: 2.7037 | Loss D: 0.6395\n", + "Epoch 0 | Step 2410 | Loss G: 2.3319 | Loss D: 0.7071\n", + "Epoch 0 | Step 2420 | Loss G: 2.4786 | Loss D: 0.6931\n", + "Epoch 0 | Step 2430 | Loss G: 3.5919 | Loss D: 0.7343\n", + "Epoch 0 | Step 2440 | Loss G: 2.3883 | Loss D: 0.7144\n", + "Epoch 0 | Step 2450 | Loss G: 2.1943 | Loss D: 0.6796\n", + "Epoch 0 | Step 2460 | Loss G: 2.6514 | Loss D: 0.7204\n", + "Epoch 0 | Step 2470 | Loss G: 2.0497 | Loss D: 0.6569\n", + "Epoch 0 | Step 2480 | Loss G: 2.3691 | Loss D: 0.7213\n", + "Epoch 0 | Step 2490 | Loss G: 2.7880 | Loss D: 0.6707\n", + "Epoch 0 | Step 2500 | Loss G: 2.9353 | Loss D: 0.7046\n", + "Epoch 0 | Step 2510 | Loss G: 2.0710 | Loss D: 0.7094\n", + "Epoch 0 | Step 2520 | Loss G: 2.1715 | Loss D: 0.6868\n", + "Epoch 0 | Step 2530 | Loss G: 2.9305 | Loss D: 0.7008\n", + "Epoch 0 | Step 2540 | Loss G: 3.2625 | Loss D: 0.7100\n", + "Epoch 0 | Step 2550 | Loss G: 2.4526 | Loss D: 0.6615\n", + "Epoch 0 | Step 2560 | Loss G: 2.1840 | Loss D: 0.6649\n", + "Epoch 0 | Step 2570 | Loss G: 2.3944 | Loss D: 0.6434\n", + "Epoch 0 | Step 2580 | Loss G: 2.3077 | Loss D: 0.6637\n", + "Epoch 0 | Step 2590 | Loss G: 2.2480 | Loss D: 0.6865\n", + "Epoch 0 | Step 2600 | Loss G: 2.2794 | Loss D: 0.6606\n", + "Epoch 0 | Step 2610 | Loss G: 2.2931 | Loss D: 0.6991\n", + "Epoch 0 | Step 2620 | Loss G: 2.1852 | Loss D: 0.6694\n", + "Epoch 0 | Step 2630 | Loss G: 2.4422 | Loss D: 0.6534\n", + "Epoch 0 | Step 2640 | Loss G: 2.4345 | Loss D: 0.7486\n", + "Epoch 0 | Step 2650 | Loss G: 2.1540 | Loss D: 0.7449\n", + "Epoch 0 | Step 2660 | Loss G: 2.2600 | Loss D: 0.6953\n", + "Epoch 0 | Step 2670 | Loss G: 2.6546 | Loss D: 0.6010\n", + "Epoch 0 | Step 2680 | Loss G: 2.1378 | Loss D: 0.7238\n", + "Epoch 0 | Step 2690 | Loss G: 2.4992 | Loss D: 
0.6710\n", + "Epoch 0 | Step 2700 | Loss G: 2.4387 | Loss D: 0.6570\n", + "Epoch 0 | Step 2710 | Loss G: 2.7910 | Loss D: 0.7033\n", + "Epoch 0 | Step 2720 | Loss G: 2.0291 | Loss D: 0.6740\n", + "Epoch 0 | Step 2730 | Loss G: 2.3530 | Loss D: 0.6880\n", + "Epoch 0 | Step 2740 | Loss G: 2.4699 | Loss D: 0.6142\n", + "Epoch 0 | Step 2750 | Loss G: 2.2944 | Loss D: 0.7012\n", + "Epoch 0 | Step 2760 | Loss G: 2.3153 | Loss D: 0.6831\n", + "Epoch 0 | Step 2770 | Loss G: 1.9557 | Loss D: 0.6957\n", + "Epoch 0 | Step 2780 | Loss G: 1.9761 | Loss D: 0.7643\n", + "Epoch 0 | Step 2790 | Loss G: 2.0589 | Loss D: 0.7031\n", + "Epoch 0 | Step 2800 | Loss G: 2.1808 | Loss D: 0.6616\n", + "Epoch 0 | Step 2810 | Loss G: 3.0882 | Loss D: 0.6470\n", + "Epoch 0 | Step 2820 | Loss G: 2.2611 | Loss D: 0.7233\n", + "Epoch 0 | Step 2830 | Loss G: 2.2900 | Loss D: 0.6960\n", + "Epoch 0 | Step 2840 | Loss G: 2.8689 | Loss D: 0.6929\n", + "Epoch 0 | Step 2850 | Loss G: 2.6262 | Loss D: 0.6793\n", + "Epoch 0 | Step 2860 | Loss G: 2.0969 | Loss D: 0.6568\n", + "Epoch 0 | Step 2870 | Loss G: 2.0100 | Loss D: 0.6813\n", + "Epoch 0 | Step 2880 | Loss G: 2.4225 | Loss D: 0.6647\n", + "Epoch 0 | Step 2890 | Loss G: 2.3839 | Loss D: 0.6293\n", + "Epoch 0 | Step 2900 | Loss G: 2.2553 | Loss D: 0.6671\n", + "Epoch 0 | Step 2910 | Loss G: 2.2506 | Loss D: 0.6316\n", + "Epoch 0 | Step 2920 | Loss G: 2.2017 | Loss D: 0.6144\n", + "Epoch 0 | Step 2930 | Loss G: 2.3828 | Loss D: 0.7388\n", + "Epoch 0 | Step 2940 | Loss G: 3.6152 | Loss D: 0.5791\n", + "Epoch 0 | Step 2950 | Loss G: 2.4575 | Loss D: 0.6609\n", + "Epoch 0 | Step 2960 | Loss G: 2.5841 | Loss D: 0.8060\n", + "Epoch 0 | Step 2970 | Loss G: 2.6008 | Loss D: 0.6548\n", + "Epoch 0 | Step 2980 | Loss G: 2.9643 | Loss D: 0.6375\n", + "Epoch 0 | Step 2990 | Loss G: 2.0817 | Loss D: 0.6882\n", + "Epoch 0 | Step 3000 | Loss G: 2.4347 | Loss D: 0.7490\n", + "Epoch 0 | Step 3010 | Loss G: 2.7559 | Loss D: 0.7129\n", + "Epoch 0 | Step 3020 | Loss G: 
2.1537 | Loss D: 0.6953\n", + "Epoch 0 | Step 3030 | Loss G: 2.2344 | Loss D: 0.7369\n", + "Epoch 0 | Step 3040 | Loss G: 2.3036 | Loss D: 0.6224\n", + "Epoch 0 | Step 3050 | Loss G: 2.3935 | Loss D: 0.7120\n", + "Epoch 0 | Step 3060 | Loss G: 2.3280 | Loss D: 0.6134\n", + "Epoch 0 | Step 3070 | Loss G: 2.5429 | Loss D: 0.7636\n", + "Epoch 0 | Step 3080 | Loss G: 2.3605 | Loss D: 0.6986\n", + "Epoch 0 | Step 3090 | Loss G: 2.1876 | Loss D: 0.7954\n", + "Epoch 0 | Step 3100 | Loss G: 2.3063 | Loss D: 0.6299\n", + "Epoch 0 | Step 3110 | Loss G: 2.3226 | Loss D: 0.8436\n", + "Epoch 0 | Step 3120 | Loss G: 2.9968 | Loss D: 0.7128\n", + "Epoch 0 | Step 3130 | Loss G: 2.0452 | Loss D: 0.7180\n", + "Epoch 0 | Step 3140 | Loss G: 1.9460 | Loss D: 0.7010\n", + "Epoch 0 | Step 3150 | Loss G: 2.7439 | Loss D: 0.6920\n", + "Epoch 0 | Step 3160 | Loss G: 1.8548 | Loss D: 0.7069\n", + "Epoch 0 | Step 3170 | Loss G: 2.6291 | Loss D: 0.6260\n", + "Epoch 0 | Step 3180 | Loss G: 3.1567 | Loss D: 0.6356\n", + "Epoch 0 | Step 3190 | Loss G: 1.9601 | Loss D: 0.6841\n", + "Epoch 0 | Step 3200 | Loss G: 2.3887 | Loss D: 0.6708\n", + "Epoch 0 | Step 3210 | Loss G: 2.0190 | Loss D: 0.6558\n", + "Epoch 0 | Step 3220 | Loss G: 2.5288 | Loss D: 0.6406\n", + "Epoch 0 | Step 3230 | Loss G: 2.0141 | Loss D: 0.6553\n", + "Epoch 0 | Step 3240 | Loss G: 3.5275 | Loss D: 0.6934\n", + "Epoch 0 | Step 3250 | Loss G: 2.5778 | Loss D: 0.5621\n", + "Epoch 0 | Step 3260 | Loss G: 2.4638 | Loss D: 0.6106\n", + "Epoch 0 | Step 3270 | Loss G: 3.1237 | Loss D: 0.6939\n", + "Epoch 0 | Step 3280 | Loss G: 2.8655 | Loss D: 0.6382\n", + "Epoch 0 | Step 3290 | Loss G: 2.6247 | Loss D: 0.6818\n", + "Epoch 0 | Step 3300 | Loss G: 2.2132 | Loss D: 0.6349\n", + "Epoch 0 | Step 3310 | Loss G: 2.3956 | Loss D: 0.7407\n", + "Epoch 0 | Step 3320 | Loss G: 2.8687 | Loss D: 0.4859\n", + "Epoch 0 | Step 3330 | Loss G: 2.4718 | Loss D: 0.7202\n", + "Epoch 0 | Step 3340 | Loss G: 1.9500 | Loss D: 0.6743\n", + "Epoch 0 | Step 
3350 | Loss G: 2.4359 | Loss D: 0.6221\n", + "Epoch 0 | Step 3360 | Loss G: 2.1875 | Loss D: 0.7105\n", + "Epoch 0 | Step 3370 | Loss G: 2.2411 | Loss D: 0.5774\n", + "Epoch 0 | Step 3380 | Loss G: 2.2889 | Loss D: 0.6689\n", + "Epoch 0 | Step 3390 | Loss G: 2.5175 | Loss D: 0.7022\n", + "Epoch 0 | Step 3400 | Loss G: 2.1678 | Loss D: 0.6610\n", + "Epoch 0 | Step 3410 | Loss G: 1.8933 | Loss D: 0.7071\n", + "Epoch 0 | Step 3420 | Loss G: 2.1478 | Loss D: 0.6712\n", + "Epoch 0 | Step 3430 | Loss G: 2.5614 | Loss D: 0.6450\n", + "Epoch 0 | Step 3440 | Loss G: 2.5616 | Loss D: 0.6267\n", + "Epoch 0 | Step 3450 | Loss G: 2.2356 | Loss D: 0.6358\n", + "Epoch 0 | Step 3460 | Loss G: 3.6788 | Loss D: 0.6079\n", + "Epoch 0 | Step 3470 | Loss G: 2.4702 | Loss D: 0.6131\n" + ] + } + ], + "source": [ + "# ==========================================\n", + "# CELL 3: HUẤN LUYỆN & LƯU MODEL\n", + "# ==========================================\n", + "if 'train_loader' in locals():\n", + " generator = CR_Generator().to(device)\n", + " discriminator = Discriminator().to(device)\n", + " \n", + " criterion_GAN = nn.BCELoss()\n", + " criterion_L1 = nn.L1Loss()\n", + " opt_g = torch.optim.Adam(generator.parameters(), lr=0.0002)\n", + " opt_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002)\n", + "\n", + " print(\"🚀 Bắt đầu huấn luyện CR-GAN...\")\n", + " epochs = 5 # Demo chạy 5 vòng (tăng lên nếu muốn train thật)\n", + " \n", + " for epoch in range(epochs):\n", + " for i, (s1, s2, cloud) in enumerate(train_loader):\n", + " s1, s2, cloud = s1.to(device), s2.to(device), cloud.to(device)\n", + " \n", + " # --- Train Generator ---\n", + " opt_g.zero_grad()\n", + " fake = generator(s1, cloud)\n", + " pred_fake = discriminator(fake, s1, cloud)\n", + " loss_g = criterion_GAN(pred_fake, torch.ones_like(pred_fake)) + 100 * criterion_L1(fake, s2)\n", + " loss_g.backward()\n", + " opt_g.step()\n", + " \n", + " # --- Train Discriminator ---\n", + " opt_d.zero_grad()\n", + " pred_real = 
discriminator(s2, s1, cloud)\n", + " pred_fake_d = discriminator(fake.detach(), s1, cloud)\n", + " loss_d = 0.5 * (criterion_GAN(pred_real, torch.ones_like(pred_real)) + \n", + " criterion_GAN(pred_fake_d, torch.zeros_like(pred_fake_d)))\n", + " loss_d.backward()\n", + " opt_d.step()\n", + " \n", + " if i % 10 == 0: \n", + " print(f\"Epoch {epoch} | Step {i} | Loss G: {loss_g.item():.4f} | Loss D: {loss_d.item():.4f}\")\n", + "\n", + " # --- LƯU MODEL ---\n", + " save_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\", \"saved_models\")\n", + " os.makedirs(save_dir, exist_ok=True)\n", + " torch.save(generator.state_dict(), os.path.join(save_dir, \"CRGAN_generator.pth\"))\n", + " print(f\"\\n💾 Đã lưu model tại: {save_dir}/CRGAN_generator.pth\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87873ca2-6a64-4fc9-8687-8c8555b99ed4", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Train_GLF-CR.ipynb b/Train_GLF-CR.ipynb new file mode 100644 index 0000000..2702c3c --- /dev/null +++ b/Train_GLF-CR.ipynb @@ -0,0 +1,183 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 5, + "id": "a148f700-b14e-4dc3-a203-9e2f93955587", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "📂 Working Directory: /home/jovyan/cloud_train/dataset\n", + "✅ Đã import thành công (dùng đường dẫn tuyệt đối)!\n", + "Using device: cpu\n", + "✅ Data Ready: 29117 samples\n" + ] + } + ], + "source": [ + "# ==========================================\n", + "# CELL 1: 
SETUP & LOAD DỮ LIỆU\n", + "# ==========================================\n", + "import os, sys, torch, numpy as np\n", + "import torch.nn as nn\n", + "from torch.utils.data import Dataset, DataLoader\n", + "\n", + "# Fix Import Path\n", + "sys.path.append(os.getcwd())\n", + "sys.path.append(\"/home/jovyan/cloud_train\")\n", + "\n", + "try:\n", + " from sen12ms_cr_dataLoader import SEN12MSCRDataset, Seasons, S1Bands, S2Bands\n", + "except ImportError:\n", + " raise RuntimeError(\"❌ Không tìm thấy file 'sen12ms_cr_dataLoader.py'\")\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "base_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\", \"dataset\")\n", + "\n", + "class SEN12MSCR_TorchDataset(Dataset):\n", + " def __init__(self, base_dir, season=Seasons.SPRING):\n", + " self.loader = SEN12MSCRDataset(base_dir)\n", + " self.season = season\n", + " self.season_ids = self.loader.get_season_ids(season)\n", + " self.samples = [(sid, pid) for sid, pids in self.season_ids.items() for pid in pids]\n", + " def __len__(self): return len(self.samples)\n", + " def __getitem__(self, idx):\n", + " sid, pid = self.samples[idx]\n", + " s1, s2, c, _ = self.loader.get_s1s2s2cloudy_triplet(self.season, sid, pid)\n", + " s1 = np.clip(s1, -25, 0) / 25.0\n", + " s2 = (np.clip(s2, 0, 10000) / 5000.0) - 1.0\n", + " c = (np.clip(c, 0, 10000) / 5000.0) - 1.0\n", + " return torch.from_numpy(s1).float(), torch.from_numpy(s2).float(), torch.from_numpy(c).float()\n", + "\n", + "train_dataset = SEN12MSCR_TorchDataset(base_dir)\n", + "train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)\n", + "print(f\"✅ GLF-CR Data Ready: {len(train_dataset)} samples\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "cdb8ff0f-ea65-4ed4-8f5b-b99b05c602a8", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# ==========================================\n", + "# CELL 2: MODEL GLF-CR\n", + "# 
# ==========================================
# CELL 2: GLF-CR MODEL
# ==========================================
class FusionBlock(nn.Module):
    """Fuse two equal-width feature maps with a 1x1 conv + BN + ReLU."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels * 2, channels, 1),
            nn.BatchNorm2d(channels), nn.ReLU()
        )

    def forward(self, x1, x2):
        stacked = torch.cat((x1, x2), dim=1)
        return self.conv(stacked)


class GLF_Generator(nn.Module):
    """Two-branch generator: separate SAR / cloudy-optical encoders,
    a fusion block, and a shared decoder.

    Inputs:  s1 (B, 2, H, W) and s2_cloudy (B, 13, H, W)
    Output:  cloud-free estimate (B, 13, H, W) in [-1, 1] (tanh).
    """

    def __init__(self):
        super(GLF_Generator, self).__init__()
        # one encoder per modality
        self.enc_s1 = nn.Sequential(nn.Conv2d(2, 64, 3, 1, 1), nn.ReLU())   # radar
        self.enc_s2 = nn.Sequential(nn.Conv2d(13, 64, 3, 1, 1), nn.ReLU())  # cloudy optical
        self.fuse = FusionBlock(64)
        self.dec = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(128, 64, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(64, 13, 3, 1, 1)
        )
        self.tanh = nn.Tanh()

    def forward(self, s1, s2_cloudy):
        radar_feat = self.enc_s1(s1)
        optical_feat = self.enc_s2(s2_cloudy)
        fused = self.fuse(radar_feat, optical_feat)
        return self.tanh(self.dec(fused))


# ==========================================
# CELL 3: TRAINING & SAVE
# ==========================================
if 'train_loader' in locals():
    gen = GLF_Generator().to(device)
    opt = torch.optim.Adam(gen.parameters(), lr=0.0002)
    crit = nn.MSELoss()

    print("🚀 Bắt đầu train GLF-CR...")
    for step, batch in enumerate(train_loader):
        s1, s2, cloud = (t.to(device) for t in batch)

        opt.zero_grad()
        loss = crit(gen(s1, cloud), s2)
        loss.backward()
        opt.step()

        if step % 10 == 0:
            print(f"Step {step} | MSE Loss: {loss.item():.4f}")
        if step > 50:
            break

    # persist generator weights
    save_dir = os.path.join(os.path.expanduser("~"), "cloud_train", "saved_models")
    os.makedirs(save_dir, exist_ok=True)
    torch.save(gen.state_dict(), os.path.join(save_dir, "GLFCR_generator.pth"))
    print(f"\n💾 Đã lưu model tại: {save_dir}/GLFCR_generator.pth")
# ==========================================
# CELL 1: SETUP & LOAD DATA
# ==========================================
import os, sys, torch, numpy as np
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Make the loader module importable from either launch directory.
sys.path.append(os.getcwd())
sys.path.append("/home/jovyan/cloud_train")

try:
    from sen12ms_cr_dataLoader import SEN12MSCRDataset, Seasons, S1Bands, S2Bands
except ImportError as err:
    # chain the original error so the real import failure is not hidden
    raise RuntimeError("❌ Không tìm thấy file 'sen12ms_cr_dataLoader.py'") from err

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_dir = os.path.join(os.path.expanduser("~"), "cloud_train", "dataset")


class SEN12MSCR_TorchDataset(Dataset):
    """torch Dataset wrapper around SEN12MSCRDataset.

    Yields one (s1, s2, s2_cloudy) float-tensor triplet per patch:
      * s1 (SAR, 2 bands): clipped to [-25, 0] dB, scaled to [-1, 0]
      * s2 / s2_cloudy (optical, 13 bands): clipped to [0, 10000],
        scaled to [-1, 1]
    """

    def __init__(self, base_dir, season=Seasons.SPRING):
        self.loader = SEN12MSCRDataset(base_dir)
        self.season = season
        self.season_ids = self.loader.get_season_ids(season)
        # flatten {scene_id: [patch_ids]} into a flat (scene, patch) index
        self.samples = [(sid, pid) for sid, pids in self.season_ids.items() for pid in pids]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sid, pid = self.samples[idx]
        s1, s2, c, _ = self.loader.get_s1s2s2cloudy_triplet(self.season, sid, pid)
        s1 = np.clip(s1, -25, 0) / 25.0
        s2 = (np.clip(s2, 0, 10000) / 5000.0) - 1.0
        c = (np.clip(c, 0, 10000) / 5000.0) - 1.0
        return torch.from_numpy(s1).float(), torch.from_numpy(s2).float(), torch.from_numpy(c).float()


train_dataset = SEN12MSCR_TorchDataset(base_dir)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)
print(f"✅ SpA-GAN Data Ready: {len(train_dataset)} samples")
# ==========================================
# CELL 2: SpA-GAN MODEL
# ==========================================
class SpatialAttention(nn.Module):
    """Per-pixel gate: a 1x1 conv collapses the channels to one map, a
    sigmoid turns it into [0, 1] weights, and the input is re-weighted."""

    def __init__(self, in_channels):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        attention_map = self.sigmoid(self.conv(x))
        return x * attention_map


class SpA_Generator(nn.Module):
    """Encoder/decoder generator with spatial attention after stage one.

    Inputs:  s1 (B, 2, H, W) SAR and s2_cloudy (B, 13, H, W) cloudy optical;
             H and W must be divisible by 4 (two stride-2 stages).
    Output:  cloud-free estimate (B, 13, H, W) in [-1, 1] (tanh).
    """

    def __init__(self):
        super(SpA_Generator, self).__init__()
        self.enc1 = nn.Sequential(nn.Conv2d(15, 64, 4, 2, 1), nn.LeakyReLU(0.2))
        self.attn = SpatialAttention(64)  # attention module
        self.enc2 = nn.Sequential(nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2))

        self.dec1 = nn.Sequential(nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU())
        self.dec2 = nn.ConvTranspose2d(64, 13, 4, 2, 1)
        self.tanh = nn.Tanh()

    def forward(self, s1, s2_cloudy):
        stacked = torch.cat((s1, s2_cloudy), dim=1)  # 2 + 13 = 15 channels
        e1 = self.enc1(stacked)
        e2 = self.enc2(self.attn(e1))  # attend before the second downsampling
        d1 = self.dec1(e2)
        # skip connection from the (pre-attention) first encoder stage
        return self.tanh(self.dec2(d1 + e1))


# ==========================================
# CELL 3: TRAINING & SAVE
# ==========================================
if 'train_loader' in locals():
    gen = SpA_Generator().to(device)
    opt_g = torch.optim.Adam(gen.parameters(), lr=0.0002)
    criterion = nn.L1Loss()

    print("🚀 Bắt đầu train SpA-GAN...")
    for step, batch in enumerate(train_loader):
        s1, s2, cloud = (t.to(device) for t in batch)

        opt_g.zero_grad()
        fake = gen(s1, cloud)
        loss_g = criterion(fake, s2)
        loss_g.backward()
        opt_g.step()

        if step % 10 == 0:
            print(f"Step {step} | Loss G: {loss_g.item():.4f}")
        if step > 50:
            break  # short demo run

    # persist generator weights
    save_dir = os.path.join(os.path.expanduser("~"), "cloud_train", "saved_models")
    os.makedirs(save_dir, exist_ok=True)
    torch.save(gen.state_dict(), os.path.join(save_dir, "SpAGAN_generator.pth"))
    print(f"\n💾 Đã lưu model tại: {save_dir}/SpAGAN_generator.pth")
# --- CELL 1: SETUP ---
import os
import time

# 1. Configuration: rsync credentials and output folder for the SEN12MS-CR
#    dataset hosted at dataserv.ub.tum.de ('m1554803' is the password
#    published with the dataset record).
os.environ['RSYNC_PASSWORD'] = 'm1554803'
output_dir = "./dataset/"
os.makedirs(output_dir, exist_ok=True)
base_url = "rsync://m1554803@dataserv.ub.tum.de/m1554803/"


# 2. Shared download helper used by the cells below.
def download_files(file_list):
    """Download each file in *file_list* from the dataset mirror via rsync.

    Uses -P so interrupted transfers resume, and --no-o/--no-g/--no-p/--no-t
    to skip ownership/permission/time metadata the local FS may reject.
    """
    print(f"📂 Lưu tại: {output_dir}")
    print("-" * 50)
    for filename in file_list:
        # FIX: the {filename} placeholders had been lost (literal "(unknown)"),
        # which left rsync without a source path; restore them here.
        print(f"\n🚀 Đang tải: {filename}")

        command = f"rsync -rvP --no-o --no-g --no-p --no-t {base_url}{filename} {output_dir}"

        exit_code = os.system(command)
        if exit_code == 0:
            print(f"✅ XONG: {filename}")
        else:
            print(f"❌ LỖI: {filename} (Code: {exit_code})")
        time.sleep(1)


print("✅ Đã cấu hình xong! Hãy chạy các cell bên dưới để tải từng phần.")
# --- CELL 2: SMALL FILES ---
files = [
    "checksums.sha512",
    "sen12ms_cr_dataLoader.py"
]
download_files(files)

# --- CELL 3: WINTER (~40.6 GB total) ---
files = [
    "ROIs2017_winter_s1.tar.gz",
    "ROIs2017_winter_s2.tar.gz",
    "ROIs2017_winter_s2_cloudy.tar.gz"
]
download_files(files)

# --- CELL 4: SPRING (~65.3 GB total) ---
files = [
    "ROIs1158_spring_s1.tar.gz",
    "ROIs1158_spring_s2.tar.gz",
    "ROIs1158_spring_s2_cloudy.tar.gz"
]
download_files(files)

# --- CELL 5: SUMMER (~74.7 GB total) ---
files = [
    "ROIs1868_summer_s1.tar.gz",
    "ROIs1868_summer_s2.tar.gz",
    "ROIs1868_summer_s2_cloudy.tar.gz"
]
download_files(files)

# --- CELL 6: FALL (~91.2 GB total) ---
files = [
    "ROIs1970_fall_s1.tar.gz",
    "ROIs1970_fall_s2.tar.gz",
    "ROIs1970_fall_s2_cloudy.tar.gz"
]
download_files(files)
"name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/format_lại_đataset.ipynb b/format_lại_đataset.ipynb new file mode 100644 index 0000000..3d53a33 --- /dev/null +++ b/format_lại_đataset.ipynb @@ -0,0 +1,241 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "415c3125-320a-4164-a9ad-4007ac750ff9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "📂 Đang thực hiện sắp xếp tại: /home/jovyan/cloud_train/dataset\n", + "➕ Đã tạo folder mẹ: ROIs1158_spring\n", + " ➡ Di chuyển: ROIs1158_spring_s1 -> ROIs1158_spring/s1\n", + " ➡ Di chuyển: ROIs1158_spring_s2 -> ROIs1158_spring/s2\n", + " ➡ Di chuyển: ROIs1158_spring_s2_cloudy -> ROIs1158_spring/s2_cloudy\n", + "➕ Đã tạo folder mẹ: ROIs1868_summer\n", + " ➡ Di chuyển: ROIs1868_summer_s1 -> ROIs1868_summer/s1\n", + " ➡ Di chuyển: ROIs1868_summer_s2 -> ROIs1868_summer/s2\n", + " ➡ Di chuyển: ROIs1868_summer_s2_cloudy -> ROIs1868_summer/s2_cloudy\n", + "➕ Đã tạo folder mẹ: ROIs1970_fall\n", + " ➡ Di chuyển: ROIs1970_fall_s1 -> ROIs1970_fall/s1\n", + " ➡ Di chuyển: ROIs1970_fall_s2 -> ROIs1970_fall/s2\n", + " ➡ Di chuyển: ROIs1970_fall_s2_cloudy -> ROIs1970_fall/s2_cloudy\n", + "➕ Đã tạo folder mẹ: ROIs2017_winter\n", + " ➡ Di chuyển: ROIs2017_winter_s1 -> ROIs2017_winter/s1\n", + " ➡ Di chuyển: ROIs2017_winter_s2 -> ROIs2017_winter/s2\n", + " ➡ Di chuyển: ROIs2017_winter_s2_cloudy -> ROIs2017_winter/s2_cloudy\n", + "\n", + "✅ ĐÃ XONG! 
import os
import shutil

# ==========================================
# ONE-OFF SCRIPT: GROUP FLAT SEASON FOLDERS
# Turns the extracted "ROIs<id>_<season>_<type>" folders into the
# "ROIs<id>_<season>/<type>" layout the SEN12MS-CR data loader expects.
# ==========================================

# key   = season parent folder to create
# value = flat folders (as produced by extraction) to move underneath it
SEASONS_MAP = {
    "ROIs1158_spring": ["ROIs1158_spring_s1", "ROIs1158_spring_s2", "ROIs1158_spring_s2_cloudy"],
    "ROIs1868_summer": ["ROIs1868_summer_s1", "ROIs1868_summer_s2", "ROIs1868_summer_s2_cloudy"],
    "ROIs1970_fall": ["ROIs1970_fall_s1", "ROIs1970_fall_s2", "ROIs1970_fall_s2_cloudy"],
    "ROIs2017_winter": ["ROIs2017_winter_s1", "ROIs2017_winter_s2", "ROIs2017_winter_s2_cloudy"]
}


def _subfolder_name(flat_name):
    """Map a flat folder name to its canonical subfolder name."""
    # check the cloudy variant first: "_s2" is a substring of "_s2_cloudy"
    if "_s2_cloudy" in flat_name:
        return "s2_cloudy"
    if "_s1" in flat_name:
        return "s1"
    return "s2"


def reorganize_dataset(base_dir):
    """Move flat season folders under their season parent folder.

    Returns the number of folders moved; 0 when base_dir is missing or
    everything is already in place.
    """
    print(f"📂 Đang thực hiện sắp xếp tại: {base_dir}")
    if not os.path.exists(base_dir):
        print("❌ Lỗi: Không tìm thấy thư mục dataset!")
        return 0

    count_moved = 0
    for target_season, sub_folders in SEASONS_MAP.items():
        target_path = os.path.join(base_dir, target_season)
        if not os.path.exists(target_path):
            os.makedirs(target_path)
            print(f"➕ Đã tạo folder mẹ: {target_season}")

        for sub in sub_folders:
            source_path = os.path.join(base_dir, sub)
            if not os.path.exists(source_path):
                continue  # already moved, or this season was never downloaded

            new_folder_name = _subfolder_name(sub)
            final_dest = os.path.join(target_path, new_folder_name)
            print(f"   ➡ Di chuyển: {sub} -> {target_season}/{new_folder_name}")

            try:
                if not os.path.exists(final_dest):
                    shutil.move(source_path, final_dest)
                    count_moved += 1
                else:
                    print(f"   ⚠️ Thư mục {new_folder_name} đã tồn tại trong {target_season}, bỏ qua.")
            except Exception as e:  # best-effort: report and keep going
                print(f"   ❌ Lỗi: {e}")

    return count_moved


base_dir = os.path.join(os.path.expanduser("~"), "cloud_train", "dataset")
count_moved = reorganize_dataset(base_dir)

if count_moved > 0:
    print(f"\n✅ ĐÃ XONG! Đã di chuyển {count_moved} thư mục.")
    print("👉 Bây giờ hãy chạy lại Cell load dữ liệu (Train code), nó sẽ hoạt động.")
elif os.path.exists(base_dir):
    print("\nℹ️ Không có gì thay đổi. Có thể cấu trúc đã đúng hoặc tên file không khớp.")
    print("Hãy kiểm tra thủ công bằng lệnh: !ls -F /home/jovyan/cloud_train/dataset/")
import os
import shutil

# ==========================================
# FIX STRUCTURE, PHASE 2: FLATTEN NESTED SUBFOLDERS
# After phase 1 the scenes live in "<season>/s1/s1_<id>"; the data loader
# expects "<season>/s1_<id>", so move every scene folder up one level and
# drop the now-empty "s1" / "s2" / "s2_cloudy" intermediate folders.
# ==========================================

SEASON_DIRS = [
    "ROIs1158_spring",
    "ROIs1868_summer",
    "ROIs1970_fall",
    "ROIs2017_winter"
]

# intermediate folders to dissolve
SUB_TYPES = ["s1", "s2", "s2_cloudy"]


def flatten_season_dirs(base_dir):
    """Move scene folders out of the intermediate s1/s2/s2_cloudy folders."""
    for season in SEASON_DIRS:
        season_path = os.path.join(base_dir, season)
        if not os.path.exists(season_path):
            continue

        print(f"🔍 Đang quét mùa: {season}")

        for sub in SUB_TYPES:
            nested_path = os.path.join(season_path, sub)
            if not os.path.exists(nested_path):
                continue

            print(f"   found nested folder: {sub} -> Đang di chuyển nội dung ra ngoài...")

            # move every scene (s1_1, s1_2, ...) up into the season folder
            for scene in os.listdir(nested_path):
                src = os.path.join(nested_path, scene)
                dst = os.path.join(season_path, scene)
                if not os.path.exists(dst):
                    shutil.move(src, dst)

            # the intermediate folder should now be empty
            try:
                os.rmdir(nested_path)
                print(f"   ✅ Đã xóa folder rỗng: {sub}")
            except OSError:
                print(f"   ⚠️ Không xóa được {sub} (có thể còn file rác)")


base_dir = os.path.join(os.path.expanduser("~"), "cloud_train", "dataset")
print(f"📂 Đang xử lý tại: {base_dir}")
flatten_season_dirs(base_dir)
print("\n✅ ĐÃ SỬA XONG CẤU TRÚC!")
"name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sen12ms_cr_dataLoader.py b/sen12ms_cr_dataLoader.py new file mode 100755 index 0000000..66a6c4d --- /dev/null +++ b/sen12ms_cr_dataLoader.py @@ -0,0 +1,270 @@ +""" + Generic data loading routines for the SEN12MS-CR dataset of corresponding Sentinel 1, + Sentinel 2 and cloudy Sentinel 2 data. + + The SEN12MS-CR class is meant to provide a set of helper routines for loading individual + image patches as well as triplets of patches from the dataset. These routines can easily + be wrapped or extended for use with many deep learning frameworks or as standalone helper + methods. For an example use case please see the "main" routine at the end of this file. + + NOTE: Some folder/file existence and validity checks are implemented but it is + by no means complete. + + Authors: Patrick Ebel (patrick.ebel@tum.de), Lloyd Hughes (lloyd.hughes@tum.de), + based on the exemplary data loader code of https://mediatum.ub.tum.de/1474000, with minimal modifications applied. 
import os

import numpy as np

from enum import Enum
from glob import glob

# rasterio (GDAL) is only required for actually reading raster patches in
# get_patch(); importing it lazily keeps the id/metadata helpers usable in
# environments where it is not installed.
try:
    import rasterio
except ImportError:
    rasterio = None


class S1Bands(Enum):
    """Sentinel-1 band indices (1-based, matching the GeoTIFF band order)."""
    VV = 1
    VH = 2
    ALL = [VV, VH]
    NONE = []


class S2Bands(Enum):
    """Sentinel-2 band indices (1-based) with spectral-name aliases."""
    B01 = aerosol = 1
    B02 = blue = 2
    B03 = green = 3
    B04 = red = 4
    B05 = re1 = 5
    B06 = re2 = 6
    B07 = re3 = 7
    B08 = nir1 = 8
    B08A = nir2 = 9
    B09 = vapor = 10
    B10 = cirrus = 11
    B11 = swir1 = 12
    B12 = swir2 = 13
    ALL = [B01, B02, B03, B04, B05, B06, B07, B08, B08A, B09, B10, B11, B12]
    RGB = [B04, B03, B02]
    NONE = []


class Seasons(Enum):
    """Season folder names of the SEN12MS-CR release."""
    SPRING = "ROIs1158_spring"
    SUMMER = "ROIs1868_summer"
    FALL = "ROIs1970_fall"
    WINTER = "ROIs2017_winter"
    ALL = [SPRING, SUMMER, FALL, WINTER]


class Sensor(Enum):
    """Sensor sub-folder prefixes."""
    s1 = "s1"
    s2 = "s2"
    s2cloudy = "s2cloudy"

# Note: The order in which you request the bands is the same order they will be returned in.


class SEN12MSCRDataset:
    """Helper routines for loading patches/triplets from a SEN12MS-CR base directory."""

    def __init__(self, base_dir):
        # base_dir must contain one folder per season (see Seasons values)
        self.base_dir = base_dir

        if not os.path.exists(self.base_dir):
            # kept as generic Exception for backward compatibility with callers
            raise Exception(
                "The specified base_dir for SEN12MS-CR dataset does not exist")

    def get_scene_ids(self, season):
        """Return the set of scene ids available for *season*."""
        season = Seasons(season).value
        path = os.path.join(self.base_dir, season)

        if not os.path.exists(path):
            raise NameError("Could not find season {} in base directory {}".format(
                season, self.base_dir))

        # add all dirs except "s2_cloudy" (which messes with subsequent string splits)
        scene_list = [os.path.basename(s)
                      for s in glob(os.path.join(path, "*")) if "s2_cloudy" not in s]
        scene_list = [int(s.split("_")[1]) for s in scene_list]
        return set(scene_list)

    def get_patch_ids(self, season, scene_id):
        """Return the list of patch ids for *scene_id* within *season*."""
        season = Seasons(season).value
        path = os.path.join(self.base_dir, season, f"s1_{scene_id}")

        if not os.path.exists(path):
            raise NameError(
                "Could not find scene {} within season {}".format(scene_id, season))

        patch_ids = [os.path.splitext(os.path.basename(p))[0]
                     for p in glob(os.path.join(path, "*"))]
        # filenames end in "..._p<id>"; keep only the numeric patch id
        patch_ids = [int(p.rsplit("_", 1)[1].split("p")[1]) for p in patch_ids]

        return patch_ids

    def get_season_ids(self, season):
        """Return {scene_id: [patch_ids]} for every scene of *season*."""
        season = Seasons(season).value
        ids = {}
        scene_ids = self.get_scene_ids(season)

        for sid in scene_ids:
            ids[sid] = self.get_patch_ids(season, sid)

        return ids

    def get_patch(self, season, scene_id, patch_id, bands):
        """Return (data, bounds) for the requested bands of a single patch.

        Loads a single patch from a single sensor; the sensor folder is
        inferred from the band enum type (S1Bands -> s1, S2Bands -> s2).
        """
        if rasterio is None:
            raise ImportError(
                "rasterio is required to read raster patches but is not installed")

        season = Seasons(season).value

        # inspect the first band to determine which sensor folder to read from
        if isinstance(bands, (list, tuple)):
            b = bands[0]
        else:
            b = bands

        if isinstance(b, S1Bands):
            sensor = Sensor.s1.value
        elif isinstance(b, S2Bands):
            sensor = Sensor.s2.value
        else:
            raise Exception("Invalid bands specified")

        if isinstance(bands, (list, tuple)):
            bands = [b.value for b in bands]
        else:
            bands = bands.value

        scene = "{}_{}".format(sensor, scene_id)
        filename = "{}_{}_p{}.tif".format(season, scene, patch_id)
        patch_path = os.path.join(self.base_dir, season, scene, filename)

        with rasterio.open(patch_path) as patch:
            data = patch.read(bands)
            bounds = patch.bounds

        # a single-band read comes back 2-D; normalise to (bands, H, W)
        if len(data.shape) == 2:
            data = np.expand_dims(data, axis=0)

        return data, bounds

    def get_s1s2s2cloudy_triplet(self, season, scene_id, patch_id, s1_bands=S1Bands.ALL, s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL):
        """Return (s1, s2, s2cloudy, bounds) for one patch.

        NOTE(review): s2cloudy_bands are S2Bands, so get_patch() resolves the
        cloudy read to the "s2_" folder as well — confirm against the official
        loader that the cloudy copy is meant to come from the same path.
        """
        s1, bounds = self.get_patch(season, scene_id, patch_id, s1_bands)
        s2, _ = self.get_patch(season, scene_id, patch_id, s2_bands)
        s2cloudy, _ = self.get_patch(season, scene_id, patch_id, s2cloudy_bands)

        return s1, s2, s2cloudy, bounds

    def get_triplets(self, season, scene_ids=None, patch_ids=None, s1_bands=S1Bands.ALL, s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL):
        """Return stacked (D, B, W, H) triplet arrays plus their bounds.

        D is the number of requested patches and B the per-sensor band count.
        scene_ids and patch_ids may each be a single id or a list, but not
        both lists at once (patch availability differs per scene).
        """
        season = Seasons(season)
        scene_list = []
        patch_list = []
        bounds = []
        s1_data = []
        s2_data = []
        s2cloudy_data = []

        # Not all patch ids are available in all scenes, and not all scenes
        # exist in all seasons, so only one of the two may be a list.
        if isinstance(scene_ids, list) and isinstance(patch_ids, list):
            raise Exception("Only scene_ids or patch_ids can be a list, not both.")

        if scene_ids is None:
            scene_list = self.get_scene_ids(season)
        else:
            try:
                scene_list.extend(scene_ids)
            except TypeError:
                scene_list.append(scene_ids)

        if patch_ids is not None:
            try:
                patch_list.extend(patch_ids)
            except TypeError:
                patch_list.append(patch_ids)

        for sid in scene_list:
            if patch_ids is None:
                patch_list = self.get_patch_ids(season, sid)

            for pid in patch_list:
                s1, s2, s2cloudy, bound = self.get_s1s2s2cloudy_triplet(
                    season, sid, pid, s1_bands, s2_bands, s2cloudy_bands)
                s1_data.append(s1)
                s2_data.append(s2)
                s2cloudy_data.append(s2cloudy)
                bounds.append(bound)

        return np.stack(s1_data, axis=0), np.stack(s2_data, axis=0), np.stack(s2cloudy_data, axis=0), bounds


if __name__ == "__main__":
    import time

    # Load the dataset from the current directory as the base directory.
    sen12mscr = SEN12MSCRDataset(".")

    spring_ids = sen12mscr.get_season_ids(Seasons.SPRING)
    cnt_patches = sum([len(pids) for pids in spring_ids.values()])
    print("Spring: {} scenes with a total of {} patches".format(
        len(spring_ids), cnt_patches))

    start = time.time()
    # Load the RGB bands of the first S2 patch in scene 8
    SCENE_ID = 8
    s2_rgb_patch, bounds = sen12mscr.get_patch(Seasons.SPRING, SCENE_ID,
                                               spring_ids[SCENE_ID][0], bands=S2Bands.RGB)
    print("Time Taken {}s".format(time.time() - start))

    print("S2 RGB: {} Bounds: {}".format(s2_rgb_patch.shape, bounds))

    print("\n")

    # Load a triplet from the first three Spring scenes: all S1 bands plus
    # the red/NIR (NDVI) bands for both clear and cloudy S2.
    i = 0
    start = time.time()
    for scene_id, patch_ids in spring_ids.items():
        if i >= 3:
            break

        s1, s2, s2cloudy, bounds = sen12mscr.get_s1s2s2cloudy_triplet(Seasons.SPRING, scene_id, patch_ids[0], s1_bands=S1Bands.ALL,
                                                                      s2_bands=[S2Bands.red, S2Bands.nir1], s2cloudy_bands=[S2Bands.red, S2Bands.nir1])
        print(
            f"Scene: {scene_id}, S1: {s1.shape}, S2: {s2.shape}, cloudy S2: {s2cloudy.shape}, Bounds: {bounds}")
        i += 1

    print("Time Taken {}s".format(time.time() - start))
    print("\n")

    start = time.time()
    # Load all bands of all patches in a specified scene (scene 106)
    s1, s2, s2cloudy, _ = sen12mscr.get_triplets(Seasons.SPRING, 106, s1_bands=S1Bands.ALL,
                                                 s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL)

    print(f"Scene: 106, S1: {s1.shape}, S2: {s2.shape}, cloudy S2: {s2cloudy.shape}")
    print("Time Taken {}s".format(time.time() - start))