{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
    "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5"
   },
   "outputs": [],
   "source": [
    "import numpy as np # linear algebra\n",
    "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.metrics import accuracy_score, confusion_matrix, classification_report \n",
    "from sklearn.model_selection import train_test_split\n",
    "from glob import glob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras import Sequential\n",
    "from tensorflow.keras.layers import Dense, MaxPooling2D, GlobalAveragePooling2D, Conv2D, Input, Flatten,BatchNormalization, Dropout, Activation,  Add, LeakyReLU, ELU\n",
    "import tensorflow\n",
    "import tensorflow as tf\n",
    "from keras.optimizers import Adam, RMSprop\n",
    "from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau\n",
    "from keras.engine.topology import Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1323 962\n"
     ]
    }
   ],
   "source": [
    "TRAIN_PATH = '../datos/train/*/*'\n",
    "TEST_PATH  = '../datos/test/*/*'\n",
    "\n",
    "TRAIN_IMGS = glob(TRAIN_PATH)\n",
    "TEST_IMGS  = glob(TEST_PATH)\n",
    "\n",
    "print (len(TRAIN_IMGS) , len (TEST_IMGS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(['../input/tfg-covid-data/train/NORMAL/NORMAL (159).png',\n",
       "  '../input/tfg-covid-data/train/NORMAL/NORMAL (28).png'],\n",
       " ['../input/tfg-covid-data/test/NORMAL/NORMAL (391).png',\n",
       "  '../input/tfg-covid-data/test/NORMAL/NORMAL (695).png'])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TRAIN_IMGS[: 2], TEST_IMGS[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "## converts name into number\n",
    "class_dict = {'NORMAL': 0, 'COVID-19':1, 'Viral Pneumonia':2}\n",
    "num2label = {0:'NORMAL', 1:'COVID-19', 2:'Viral Pneumonia'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Generate labels\n",
    "\n",
    "TRAIN_LABELS = []\n",
    "TEST_LABELS = []\n",
    "\n",
    "for filename in TRAIN_IMGS:\n",
    "    label = str(filename.split('/')[4])\n",
    "    #print (label)\n",
    "    assert label in ['NORMAL', 'COVID-19', 'Viral Pneumonia']\n",
    "    label = int (class_dict[label])\n",
    "    TRAIN_LABELS.append (label)\n",
    "    \n",
    "for filename in TEST_IMGS:\n",
    "    label = str(filename.split('/')[4])\n",
    "    #print (label)\n",
    "    assert label in ['NORMAL', 'COVID-19', 'Viral Pneumonia']\n",
    "    label = int (class_dict[label])\n",
    "    TEST_LABELS.append (label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1323, 962, [0, 0, 0], [0, 0, 0])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(TRAIN_LABELS) , len(TEST_LABELS), TRAIN_LABELS[:3], TEST_LABELS[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "NCLASSES = 3\n",
    "SIZE = (1024,1024)\n",
    "BATCH_SIZE = 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_image(filename, label = None, image_size= SIZE ):\n",
    "    \n",
    "    bits = tf.io.read_file(filename)\n",
    "    image = tf.image.decode_png(bits, channels=1)\n",
    "    image = tf.cast(image, tf.float32) / 255.0\n",
    "    image = tf.image.resize(image, image_size)\n",
    "    \n",
    "    if label is None:\n",
    "        return image\n",
    "    else:\n",
    "        return image, label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image, label  = decode_image(TRAIN_IMGS[0], TRAIN_LABELS[0])\n",
    "plt.imshow(image)\n",
    "plt.title (num2label[label] + str(label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img = cv2.imread('../input/tfg-covid-data/train/Viral Pneumonia/Viral Pneumonia (1000).png')\n",
    "img.shape, img.min(), img.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "## DATALOADERS\n",
    "\n",
    "AUTO = tf.data.experimental.AUTOTUNE\n",
    "\n",
    "train_dataset = (\n",
    "    tf.data.TFRecordDataset\n",
    "    .from_tensor_slices((TRAIN_IMGS, TRAIN_LABELS))\n",
    "    .map(decode_image, num_parallel_calls=AUTO)\n",
    "    .repeat()\n",
    "    .shuffle(1024)\n",
    "    .batch(44)\n",
    "    #.batch(BATCH_SIZE)\n",
    "    )\n",
    "\n",
    "\n",
    "test_dataset = (\n",
    "    tf.data.TFRecordDataset\n",
    "    .from_tensor_slices((TEST_IMGS, TEST_LABELS))\n",
    "    .map(decode_image, num_parallel_calls=AUTO)\n",
    "    .batch(19)\n",
    "    #.batch(BATCH_SIZE)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for element in train_dataset:\n",
    "    \n",
    "    image = element[0]\n",
    "    label = element[1]\n",
    "\n",
    "    print (image.shape, label.shape)\n",
    "    for img, lbl in zip(image, label): #batch\n",
    "        \n",
    "        plt.imshow(img)\n",
    "        plt.show()\n",
    "        \n",
    "        break\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_1\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d_10 (Conv2D)           (None, 1022, 1022, 8)     80        \n",
      "_________________________________________________________________\n",
      "conv2d_11 (Conv2D)           (None, 1020, 1020, 8)     584       \n",
      "_________________________________________________________________\n",
      "max_pooling2d_5 (MaxPooling2 (None, 510, 510, 8)       0         \n",
      "_________________________________________________________________\n",
      "conv2d_12 (Conv2D)           (None, 508, 508, 16)      1168      \n",
      "_________________________________________________________________\n",
      "conv2d_13 (Conv2D)           (None, 506, 506, 16)      2320      \n",
      "_________________________________________________________________\n",
      "max_pooling2d_6 (MaxPooling2 (None, 253, 253, 16)      0         \n",
      "_________________________________________________________________\n",
      "conv2d_14 (Conv2D)           (None, 251, 251, 32)      4640      \n",
      "_________________________________________________________________\n",
      "conv2d_15 (Conv2D)           (None, 249, 249, 32)      9248      \n",
      "_________________________________________________________________\n",
      "max_pooling2d_7 (MaxPooling2 (None, 124, 124, 32)      0         \n",
      "_________________________________________________________________\n",
      "conv2d_16 (Conv2D)           (None, 122, 122, 64)      18496     \n",
      "_________________________________________________________________\n",
      "conv2d_17 (Conv2D)           (None, 120, 120, 64)      36928     \n",
      "_________________________________________________________________\n",
      "max_pooling2d_8 (MaxPooling2 (None, 60, 60, 64)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_18 (Conv2D)           (None, 58, 58, 128)       73856     \n",
      "_________________________________________________________________\n",
      "conv2d_19 (Conv2D)           (None, 56, 56, 128)       147584    \n",
      "_________________________________________________________________\n",
      "max_pooling2d_9 (MaxPooling2 (None, 28, 28, 128)       0         \n",
      "_________________________________________________________________\n",
      "flatten_1 (Flatten)          (None, 100352)            0         \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 256)               25690368  \n",
      "_________________________________________________________________\n",
      "dense_3 (Dense)              (None, 3)                 771       \n",
      "=================================================================\n",
      "Total params: 25,986,043\n",
      "Trainable params: 25,986,043\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "#arquitectura Teodoro\n",
    "EPOCHS = 100\n",
    "\n",
    "model = Sequential()\n",
    "\n",
    "\n",
    "\n",
    "model.add(Conv2D(filters=8 , (3 , activation='relu', input_shape=(1024,1024,1)))\n",
    "model.add(Conv2D(filters=8 , kernel_size=3 , activation='relu'))\n",
    "model.add(MaxPooling2D((3,3)))\n",
    "model.add(Conv2D(filters=16 , kernel_size=3 , activation='relu'))\n",
    "model.add(Conv2D(filters=16 , kernel_size=3 , activation='relu'))\n",
    "model.add(MaxPooling2D((3,3)))\n",
    "model.add(Conv2D(filters=32 , kernel_size=3 , activation='relu'))\n",
    "model.add(Conv2D(filters=32 , kernel_size=3 , activation='relu'))\n",
    "model.add(MaxPooling2D((3,3)))\n",
    "model.add(Conv2D(filters=64 , kernel_size=3 , activation='relu'))\n",
    "model.add(Conv2D(filters=64 , kernel_size=3 , activation='relu'))\n",
    "model.add(MaxPooling2D((3,3)))\n",
    "model.add(Conv2D(filters=128 , kernel_size=3 , activation='relu'))\n",
    "model.add(Conv2D(filters=128 , kernel_size=3 , activation='relu'))\n",
    "model.add(MaxPooling2D((3,3)))\n",
    "\n",
    "#model.add(GlobalAveragePooling2D()) # mas eficiente que Flatten\n",
    "model.add(Flatten()) # o GAP = Global Average Pooling\n",
    "model.add(Dense(256,activation=\"relu\"))\n",
    "#model.add(Dropout (0.5))\n",
    "model.add(Dense(NCLASSES, activation='softmax'))\n",
    "\n",
    "print (model.summary())\n",
    "model.compile(optimizer= RMSprop(0.0001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LAYERS = 4\n",
    "EPOCHS = 5\n",
    "\n",
    "model = Sequential()\n",
    "model.add(Input((SIZE[0],SIZE[0],3))) # (16, 512, 512, 3)\n",
    "\n",
    "for i in range(1, LAYERS+1):\n",
    "    print ('block' , i)\n",
    "    model.add(Conv2D(filters=16* (2**i) , kernel_size=3 , padding='same', activation='relu'))\n",
    "    model.add(Conv2D(filters=16* (2**i) , kernel_size=3 , padding='same', activation='relu'))\n",
    "    #model.add(Conv2D(filters=16* (2**i) , kernel_size=3 , padding='same', activation='relu'))\n",
    "    model.add(MaxPooling2D((2,2)))\n",
    "\n",
    "#model.add(GlobalAveragePooling2D()) # mas eficiente que Flatten\n",
    "model.add(Flatten())\n",
    "model.add(Dense(512,activation=\"relu\"))\n",
    "model.add(Dropout (0.5)) #regularizacion\n",
    "model.add(Dense(NCLASSES, activation='softmax'))\n",
    "\n",
    "print (model.summary())\n",
    "model.compile(optimizer= RMSprop(0.001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "STEPS_PER_EPOCH = len(TRAIN_IMGS) // 44#BATCH_SIZE\n",
    "BATCH_SIZE, EPOCHS, STEPS_PER_EPOCH\n",
    "VALIDATION_STEPS = len(TEST_IMGS) // 19"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      "30/30 [==============================] - 89s 2s/step - loss: 0.9596 - accuracy: 0.5794 - val_loss: 0.7236 - val_accuracy: 0.5632\n",
      "Epoch 2/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.5108 - accuracy: 0.8070 - val_loss: 0.3412 - val_accuracy: 0.8726\n",
      "Epoch 3/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.3205 - accuracy: 0.8718 - val_loss: 0.4332 - val_accuracy: 0.8126\n",
      "Epoch 4/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.2543 - accuracy: 0.9051 - val_loss: 0.3484 - val_accuracy: 0.8663\n",
      "Epoch 5/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.2787 - accuracy: 0.9041 - val_loss: 0.2671 - val_accuracy: 0.9063\n",
      "Epoch 6/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.2398 - accuracy: 0.9090 - val_loss: 0.2619 - val_accuracy: 0.9074\n",
      "Epoch 7/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.2304 - accuracy: 0.9126 - val_loss: 0.2845 - val_accuracy: 0.8895\n",
      "Epoch 8/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.1868 - accuracy: 0.9378 - val_loss: 0.3402 - val_accuracy: 0.8800\n",
      "Epoch 9/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.1531 - accuracy: 0.9444 - val_loss: 0.2378 - val_accuracy: 0.9242\n",
      "Epoch 10/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.1818 - accuracy: 0.9302 - val_loss: 0.2493 - val_accuracy: 0.9126\n",
      "Epoch 11/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.1632 - accuracy: 0.9474 - val_loss: 0.3234 - val_accuracy: 0.8905\n",
      "Epoch 12/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.1516 - accuracy: 0.9480 - val_loss: 0.3337 - val_accuracy: 0.8916\n",
      "Epoch 13/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.1084 - accuracy: 0.9574 - val_loss: 0.3399 - val_accuracy: 0.8937\n",
      "Epoch 14/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.0936 - accuracy: 0.9671 - val_loss: 0.2425 - val_accuracy: 0.9284\n",
      "Epoch 15/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.0817 - accuracy: 0.9687 - val_loss: 0.2315 - val_accuracy: 0.9263\n",
      "Epoch 16/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.0639 - accuracy: 0.9789 - val_loss: 0.4438 - val_accuracy: 0.8411\n",
      "Epoch 17/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.1003 - accuracy: 0.9647 - val_loss: 0.2368 - val_accuracy: 0.9316\n",
      "Epoch 18/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.0768 - accuracy: 0.9680 - val_loss: 0.2742 - val_accuracy: 0.9274\n",
      "Epoch 19/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.0739 - accuracy: 0.9697 - val_loss: 0.2523 - val_accuracy: 0.9316\n",
      "Epoch 20/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.0566 - accuracy: 0.9786 - val_loss: 0.3835 - val_accuracy: 0.8947\n",
      "Epoch 21/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.0379 - accuracy: 0.9860 - val_loss: 0.3615 - val_accuracy: 0.9189\n",
      "Epoch 22/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.0537 - accuracy: 0.9818 - val_loss: 0.2796 - val_accuracy: 0.9326\n",
      "Epoch 23/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.0420 - accuracy: 0.9833 - val_loss: 0.5393 - val_accuracy: 0.8537\n",
      "Epoch 24/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.0332 - accuracy: 0.9866 - val_loss: 0.5634 - val_accuracy: 0.8632\n",
      "Epoch 25/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.0691 - accuracy: 0.9784 - val_loss: 0.3830 - val_accuracy: 0.9179\n",
      "Epoch 26/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.1033 - accuracy: 0.9678 - val_loss: 0.3829 - val_accuracy: 0.9263\n",
      "Epoch 27/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.0151 - accuracy: 0.9965 - val_loss: 0.3094 - val_accuracy: 0.9442\n",
      "Epoch 28/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.0354 - accuracy: 0.9892 - val_loss: 0.2608 - val_accuracy: 0.9474\n",
      "Epoch 29/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.0180 - accuracy: 0.9974 - val_loss: 0.3035 - val_accuracy: 0.9347\n",
      "Epoch 30/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.0122 - accuracy: 0.9972 - val_loss: 0.3220 - val_accuracy: 0.9474\n",
      "Epoch 31/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.0369 - accuracy: 0.9884 - val_loss: 0.3080 - val_accuracy: 0.9347\n",
      "Epoch 32/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.0021 - accuracy: 1.0000 - val_loss: 0.4127 - val_accuracy: 0.9295\n",
      "Epoch 33/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.0405 - accuracy: 0.9883 - val_loss: 0.3328 - val_accuracy: 0.9389\n",
      "Epoch 34/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 6.7841e-04 - accuracy: 1.0000 - val_loss: 0.4397 - val_accuracy: 0.9295\n",
      "Epoch 35/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.0207 - accuracy: 0.9958 - val_loss: 0.4035 - val_accuracy: 0.9347\n",
      "Epoch 36/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.0039 - accuracy: 0.9990 - val_loss: 0.4232 - val_accuracy: 0.9274\n",
      "Epoch 37/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.0071 - accuracy: 0.9970 - val_loss: 0.4524 - val_accuracy: 0.9379\n",
      "Epoch 38/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.0127 - accuracy: 0.9962 - val_loss: 0.4557 - val_accuracy: 0.9316\n",
      "Epoch 39/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 5.2700e-04 - accuracy: 1.0000 - val_loss: 0.4592 - val_accuracy: 0.9358\n",
      "Epoch 40/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.0053 - accuracy: 0.9990 - val_loss: 0.4466 - val_accuracy: 0.9284\n",
      "Epoch 41/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.0012 - accuracy: 1.0000 - val_loss: 0.4651 - val_accuracy: 0.9411\n",
      "Epoch 42/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 9.5511e-05 - accuracy: 1.0000 - val_loss: 0.5212 - val_accuracy: 0.9389\n",
      "Epoch 43/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 0.0421 - accuracy: 0.9905 - val_loss: 0.3996 - val_accuracy: 0.9347\n",
      "Epoch 44/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 5.2430e-04 - accuracy: 1.0000 - val_loss: 0.4369 - val_accuracy: 0.9432\n",
      "Epoch 45/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 8.4447e-05 - accuracy: 1.0000 - val_loss: 0.5082 - val_accuracy: 0.9411\n",
      "Epoch 46/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.0101 - accuracy: 0.9989 - val_loss: 0.5193 - val_accuracy: 0.9147\n",
      "Epoch 47/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 0.0043 - accuracy: 0.9978 - val_loss: 0.5534 - val_accuracy: 0.9263\n",
      "Epoch 48/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 2.3153e-04 - accuracy: 1.0000 - val_loss: 0.5945 - val_accuracy: 0.9295\n",
      "Epoch 49/100\n",
      "30/30 [==============================] - 52s 2s/step - loss: 4.6704e-05 - accuracy: 1.0000 - val_loss: 0.6543 - val_accuracy: 0.9295\n",
      "Epoch 50/100\n",
      "30/30 [==============================] - 51s 2s/step - loss: 1.4655e-05 - accuracy: 1.0000 - val_loss: 0.6759 - val_accuracy: 0.9316\n",
      "Epoch 51/100\n",
      "25/30 [========================>.....] - ETA: 6s - loss: 0.0328 - accuracy: 0.9955"
     ]
    }
   ],
   "source": [
    "hist = model.fit(train_dataset, validation_data=test_dataset, batch_size=BATCH_SIZE, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, validation_steps=VALIDATION_STEPS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(hist.history['loss'])\n",
    "plt.plot(hist.history['val_loss'])\n",
    "plt.title(\"model accuracy\")\n",
    "plt.ylabel(\"Accuracy\")\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.legend([\"loss\",\"Validation Loss\"])\n",
    "#plt.savefig(\"history_vgg16.png\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(hist.history[\"accuracy\"])\n",
    "plt.plot(hist.history['val_accuracy'])\n",
    "plt.title(\"model accuracy\")\n",
    "plt.ylabel(\"Accuracy\")\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.legend([\"Accuracy\",\"Validation Accuracy\",\"loss\",\"Validation Loss\"])\n",
    "#plt.savefig(\"history_vgg16.png\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.models import load_model\n",
    "\n",
    "model.save('model.h5')#136M de params - 3 canales - \n",
    "#model = load_model('modelo136M3C.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Predict test data\n",
    "\n",
    "predictions = model.predict(test_dataset)\n",
    "print (predictions.shape)\n",
    "predictions = np.argmax(predictions, axis=1).astype(np.int)\n",
    "\n",
    "print (accuracy_score(TEST_LABELS,predictions ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_names = {'COVID-19','NORMAL', 'Viral pneumonia'}\n",
    "print('Confussion matrix:')\n",
    "c = confusion_matrix(TEST_LABELS,predictions)\n",
    "print(c)\n",
    "print('Classification report')\n",
    "print(classification_report(TEST_LABELS,predictions,target_names=target_names))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(c,range(3),range(3))\n",
    "plt.figure(figsize=(6,4))\n",
    "sns.set(font_scale=1.4)\n",
    "sns.heatmap(df, annot=True, annot_kws={\"size\":16}, cmap=\"Blues\")\n",
    "plt.savefig(\"conf_mat_modelo136M3C.png\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d (Conv2D)              (None, 1022, 1022, 8)     80        \n",
      "_________________________________________________________________\n",
      "conv2d_1 (Conv2D)            (None, 1020, 1020, 8)     584       \n",
      "_________________________________________________________________\n",
      "max_pooling2d (MaxPooling2D) (None, 340, 340, 8)       0         \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 338, 338, 16)      1168      \n",
      "_________________________________________________________________\n",
      "conv2d_3 (Conv2D)            (None, 336, 336, 16)      2320      \n",
      "_________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2 (None, 168, 168, 16)      0         \n",
      "_________________________________________________________________\n",
      "conv2d_4 (Conv2D)            (None, 166, 166, 32)      4640      \n",
      "_________________________________________________________________\n",
      "conv2d_5 (Conv2D)            (None, 164, 164, 32)      9248      \n",
      "_________________________________________________________________\n",
      "max_pooling2d_2 (MaxPooling2 (None, 82, 82, 32)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_6 (Conv2D)            (None, 80, 80, 64)        18496     \n",
      "_________________________________________________________________\n",
      "conv2d_7 (Conv2D)            (None, 78, 78, 64)        36928     \n",
      "_________________________________________________________________\n",
      "max_pooling2d_3 (MaxPooling2 (None, 39, 39, 64)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_8 (Conv2D)            (None, 37, 37, 128)       73856     \n",
      "_________________________________________________________________\n",
      "conv2d_9 (Conv2D)            (None, 35, 35, 128)       147584    \n",
      "_________________________________________________________________\n",
      "max_pooling2d_4 (MaxPooling2 (None, 17, 17, 128)       0         \n",
      "_________________________________________________________________\n",
      "flatten (Flatten)            (None, 36992)             0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 256)               9470208   \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 3)                 771       \n",
      "=================================================================\n",
      "Total params: 9,765,883\n",
      "Trainable params: 9,765,883\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
     "# Print the layer table of whichever network is currently bound to `model`.\n",
     "# NOTE(review): the saved output below (3x3 first pooling, then 2x2) matches neither model cell; re-run to refresh.\n",
     "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1024, 1024)\n",
      "(1, 1024, 1024, 1)\n",
      "[[9.9999976e-01 7.2375179e-15 2.0856763e-07]]\n",
      "[0]\n",
      "NORMAL\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from keras.preprocessing import image\n",
    "img = image.load_img(\"../datos/test/COVID-19/COVID-19(135).png\",target_size=(1024,1024), color_mode=\"grayscale\")\n",
    "img = np.asarray(img)/255.0\n",
    "plt.imshow(img)\n",
    "print(img.shape)\n",
    "img = np.expand_dims(img, axis=0)\n",
    "img = np.expand_dims(img, axis=3)\n",
    "print(img.shape)\n",
    "from keras.models import load_model\n",
    "#saved_model = load_model(\"./modelo136M3C.h5\")\n",
    "#output = saved_model.predict(img)\n",
    "output = model.predict(img)\n",
    "print(output)\n",
    "prediction = np.argmax(output, axis=1).astype(np.int)\n",
    "print(prediction)\n",
    "print(num2label[prediction[0]])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}