#=======================================================================
# Image filtering by convolution & kernels
#
# Cesare Brizio, 19 January 2023
#
# I did something similar around 15 years ago in Visual Basic,
# with the heavy burden of a Visual Studio installation.
# Thanks to Python, I can provide a working example in a
# much lighter environment.
#
# The main purpose here is to illustrate the inner mechanics of
# a kernel convolution algorithm (see the nested loops) to allow
# a better understanding of the underlying logic.
#
# Inspired by a post by Dario Radečić
# (https://medium.com/@radecicdario)
# (https://towardsdatascience.com/tensorflow-for-computer-vision-how-to-implement-convolutions-from-scratch-in-python-609158c24f82)
# With some help from Mark Setchell
# (https://github.com/MarkSetchell)
#
# The "cat image" 1.jpg is available as part of the
# Cats vs. Dogs dataset from Kaggle
# (https://www.kaggle.com/datasets/pybear/cats-vs-dogs?select=PetImages)
#
# Includes kernels from https://stackoverflow.com/questions/58383477/how-to-create-a-python-convolution-kernel
# Only a few of the kernels listed are used in the code; feel free
# to edit it as needed.
#=======================================================================

import numpy as np
from PIL import Image, ImageOps
from matplotlib import pyplot as plt
from matplotlib import image as mpimg
from matplotlib import colors as mcolors
from numpy import asarray
import cv2

def plot_image(img: np.ndarray, title: str = "Cat Image"):
    """Display a single image in its own 6x6-inch figure.

    Args:
        img: image array handed straight to ``plt.imshow``.
        title: figure title. Defaults to the previously hard-coded
            "Cat Image", so existing calls behave exactly as before.
    """
    plt.figure(figsize=(6, 6), dpi=96)
    plt.title(title)
    plt.xlabel("X pixel scaling")
    plt.ylabel("Y pixels scaling")
    # No cmap argument: matplotlib ignores it for RGB(A) data anyway.
    plt.imshow(img)
    plt.show()

    
def plot_two_images(img1: np.ndarray, img2: np.ndarray, imm_name: str):
    """Show two images side by side, e.g. original (left) vs. filtered (right).

    Args:
        img1: image drawn in the left panel.
        img2: image drawn in the right panel.
        imm_name: title shown above the whole figure.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 6), dpi=96)
    # plt.title / plt.xlabel / plt.ylabel act only on the *current* axes,
    # which after plt.subplots is the last panel created (the right one).
    # Use a figure-level title and label each panel explicitly instead.
    fig.suptitle(imm_name)
    for ax, picture in zip(axes, (img1, img2)):
        ax.set_xlabel("X pixel scaling")
        ax.set_ylabel("Y pixels scaling")
        ax.imshow(picture)
    plt.show()

# ---------------------------------------------------------------
# 3x3 filtering kernels. All of them are integer arrays except
# `blur`, whose float weights sum to 1 (so flat areas keep their value).
# ---------------------------------------------------------------
sharpen = np.array([[0, -1, 0],
                    [-1, 5, -1],
                    [0, -1, 0]])

# Binomial row [1, 2, 1] times itself, divided by its total (16):
# 1/16 = 0.0625, 2/16 = 0.125, 4/16 = 0.25 -- all exact in binary.
blur = np.outer([1, 2, 1], [1, 2, 1]) / 16

outline = np.array([[-1, -1, -1],
                    [-1, 8, -1],
                    [-1, -1, -1]])

laplacian = np.array([[0, 1, 0],
                      [1, -4, 1],
                      [0, 1, 0]])

emboss = np.array([[-2, -1, 0],
                   [-1, 1, 1],
                   [0, 1, 2]])

# Sobel pairs: each kernel is the negation of the opposite-facing one.
top_sobel = np.array([[1, 2, 1],
                      [0, 0, 0],
                      [-1, -2, -1]])
bottom_sobel = -top_sobel

right_sobel = np.array([[-1, 0, 1],
                        [-2, 0, 2],
                        [-1, 0, 1]])
left_sobel = -right_sobel


def calculate_target_size(img_size: int, kernel_size: int) -> int:
    """Return how many positions a kernel can occupy along one image axis.

    A kernel of ``kernel_size`` pixels placed at offset ``i`` fits entirely
    inside an axis of ``img_size`` pixels when ``i + kernel_size <= img_size``
    ("valid" filtering, no padding). For example, ``img_size = 224`` and
    ``kernel_size = 3`` allow offsets 0..221, i.e. 222 positions.

    Args:
        img_size: number of pixels along the axis.
        kernel_size: number of kernel taps along the same axis.

    Returns:
        The number of valid offsets: never negative (0 when the kernel is
        larger than the axis) and never more than ``img_size``.
    """
    print(f'calculate_target_size({img_size}, {kernel_size})')
    # Closed form of counting the offsets i in range(img_size) that satisfy
    # i + kernel_size <= img_size.
    num_pixels = max(0, min(img_size, img_size - kernel_size + 1))
    print(f'calculate_target_size returns {num_pixels}')
    return num_pixels


def convolve(img: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Filter an RGB image with a 2-D kernel using explicit nested loops.

    The kernel is slid over every position where it fits entirely inside
    the image ("valid" mode, no padding). The output therefore has
    ``kernel_height - 1`` fewer rows and ``kernel_width - 1`` fewer columns
    than the input. Channels 0, 1 and 2 are filtered independently.

    Note: the kernel is not flipped before the element-wise product, so
    strictly speaking this is a cross-correlation. For symmetric kernels
    (sharpen, blur, outline, laplacian) the two are identical. For
    asymmetric ones (emboss, Sobel) the result is the mirror of a textbook
    convolution.

    Args:
        img: array of shape (height, width, channels) with at least 3
            channels; only channels 0..2 are used. Height and width may differ.
        kernel: 2-D array of shape (kernel_height, kernel_width).

    Returns:
        uint8 array of shape (out_height, out_width, 3) holding the filtered
        values clipped to the range 0..255.

    Raises:
        ValueError: if the kernel does not fit inside the image.
    """
    # Size rows and columns independently: most photos are not square, and
    # using img.shape[0] for both axes cropped wide images and crashed on
    # tall ones.
    out_h = calculate_target_size(
        img_size=img.shape[0],
        kernel_size=kernel.shape[0]
    )
    out_w = calculate_target_size(
        img_size=img.shape[1],
        kernel_size=kernel.shape[1]
    )
    if out_h == 0 or out_w == 0:
        raise ValueError(
            f'kernel of shape {kernel.shape} does not fit inside '
            f'image of shape {img.shape[:2]}'
        )
    kh, kw = kernel.shape[0], kernel.shape[1]

    # This will hold our 3-channel RGB result (float, before clipping)
    convolved = np.zeros(shape=(out_h, out_w, 3))

    # Iterate over the rows
    for i in range(out_h):
        # Iterate over the columns
        for j in range(out_w):
            # Iterate over channels
            for c in range(3):
                mat = img[i:i+kh, j:j+kw, c]
                # Element-wise multiplication and summation of the result,
                # stored at the i-th row and j-th column of the output
                convolved[i, j, c] = np.sum(np.multiply(mat, kernel))

    # Clip result array to range 0..255 and make into uint8
    result = np.clip(convolved, 0, 255).astype(np.uint8)
    print(f'{convolved.dtype}, {convolved.shape}')
    print(f'Rmax: {np.max(result[...,0])}, Rmin: {np.min(result[...,0])}')
    print(f'Gmax: {np.max(result[...,1])}, Gmin: {np.min(result[...,1])}')
    print(f'Bmax: {np.max(result[...,2])}, Bmin: {np.min(result[...,2])}')

    return result

# ----------------------------------------------------
# The following function is no longer used and is kept for
# reference only (np.clip in convolve() now takes care of clipping)
# ----------------------------------------------------
#def negative_to_zero(img: np.array) -> np.array:
#    img = img.copy()
#    img[img < 0] = 0
#    return img

#===========================================================
# Open image as PIL Image and make Numpy array version too
#===========================================================
# Location of the sample picture; edit it to point at your own copy.
IMAGE_PATH = 'C:/Conv_Python/images/1.jpg'

# The `with` block closes the underlying file once the pixels are read.
# convert('RGB') guarantees the 3-channel layout that convolve() indexes
# (channels 0..2), in case the file is a grayscale, palette or CMYK image.
with Image.open(IMAGE_PATH) as pI:
    img = np.array(pI.convert('RGB'))

plot_image(img=img)
#------------------> don't use a cmap such as cmap='gray_r' as 3rd parameter
plt.imsave(fname='_original.png', arr=img, format='png')

#===================================
#  S H A R P E N E D
#===================================
Curr_Title="Cat Image - Sharpened"
img_sharpened = convolve(img=img, kernel=sharpen)
plt.imsave(fname='_sharpened.png', arr=img_sharpened, format='png')

plot_two_images(
    img1=img,
    img2=img_sharpened,
    imm_name=Curr_Title
)

#===================================
#  S H A R P E N E D
#        vs.
#  SHARPENED AND NORMALIZED
#===================================
# No longer needed: convolve() already clips its output to 0..255
#
# NORMALIZE
#img_shar_nor = cv2.normalize(img_sharpened,  None, 0, 255, cv2.NORM_MINMAX)

#plot_two_images(
#    img1=img_sharpened,
#    img2=img_shar_nor
#)

#===================================
#  B L U R R E D
#===================================
Curr_Title="Cat Image - Blurred"
img_blurred = convolve(img=img, kernel=blur)
plt.imsave(fname='_blurred.png', arr=img_blurred, format='png')

plot_two_images(
    img1=img,
    img2=img_blurred,
    imm_name=Curr_Title
)

#===================================
#  O U T L I N E D
#===================================
Curr_Title="Cat Image - Outlined"
img_outlined = convolve(img=img, kernel=outline)
#plt.imsave(fname='_outlined.png', arr=img_outlined, cmap='gray_r', format='png')
plt.imsave(fname='_outlined.png', arr=img_outlined, format='png')

plot_two_images(
    img1=img,
    img2=img_outlined,
    imm_name=Curr_Title
)

#===================================
#  NEG_TO_ZERO OUTLINED
#===================================
#img_neg_to_z_OUT = negative_to_zero(img=img_outlined)
#plt.imsave(fname='_neg_to_z_OUT.png', arr=img_neg_to_z_OUT, format='png')
#
#plot_two_images(
#    img1=img,
#    img2=img_neg_to_z_OUT
#)
