#=======================================================================
# Non-convolutional image filtering by bitmasks
#
# Cesare Brizio, 20 January 2023
#
# Attempt to split an image into three separate images:
# - one with the darkest colors
# - one with midtone colors
# - one with the brightest colors
#
# "cat image" 1.jpg is available as a part of the
#  Cats vs. Dogs dataset from Kaggle 
# (https://www.kaggle.com/datasets/pybear/cats-vs-dogs?select=PetImages)
#=======================================================================

import numpy as np
from PIL import Image, ImageOps
from matplotlib import pyplot as plt
from matplotlib import image as mpimg
from matplotlib import colors as mcolors
from numpy import asarray
import cv2

def plot_image(img: np.array):
    """Display a single image in its own 6x6 inch, 96 dpi figure.

    The figure is titled "Cat Image", both axes are labelled, and the call
    ends with ``plt.show()``.

    Args:
        img: Image data handed unchanged to ``plt.imshow``.
    """
    plt.figure(figsize=(6, 6), dpi=96)

    # Draw first, then decorate the axes that now hold the image.
    plt.imshow(img)
    current_axes = plt.gca()
    current_axes.set_title("Cat Image")
    current_axes.set_xlabel("X pixel scaling")
    current_axes.set_ylabel("Y pixels scaling")

    plt.show()
    
    
def plot_two_images(img1: np.array, img2: np.array, imm_name):
    """Display two images side by side in one 12x6 inch, 96 dpi figure.

    Args:
        img1: Image shown in the first (left) subplot.
        img2: Image shown in the second (right) subplot.
        imm_name: Text passed to ``plt.title``.

    NOTE(review): the title and axis labels are set through the pyplot
    state-machine calls, so they are applied to the current axes only
    (presumably the right-hand subplot, the last one created) -- confirm
    whether a figure-wide title was intended.
    """
    _, (left_axes, right_axes) = plt.subplots(1, 2, figsize=(12, 6), dpi=96)

    left_axes.imshow(img1)
    right_axes.imshow(img2)

    plt.title(imm_name)
    plt.xlabel("X pixel scaling")
    plt.ylabel("Y pixels scaling")
    plt.show()

# Single-channel (8-bit) masks. Each one is combined with a channel value
# through bitwise AND to test which of the selected bits are set.
# NOTE(review): Mask_Darkest_Pass and Mask_HIGH_3_BIT are not referenced by
# the filter functions visible in this file.
Mask_Lightest_Pass = 0b10000000  # 0x80: most significant bit only
Mask_Darkest_Pass  = 0b00000111  # 0x07: three least significant bits
Mask_HIGH_3_BIT    = 0b11100000  # 0xE0: three most significant bits
Mask_HIGH_2_BIT    = 0b11000000  # 0xC0: two most significant bits


def BITMASK_LIGHTEST(img: np.ndarray) -> np.ndarray:
    """Keep only the pixels whose three colour channels all have the top bit set.

    A pixel counts as a highlight when each of its first three channels
    satisfies ``value & Mask_Lightest_Pass == Mask_Lightest_Pass``.  Highlight
    pixels are copied to the output; every other sample (including any
    channel beyond the third, e.g. alpha) is left at 0xFF, i.e. white.

    Args:
        img: Integer array of shape (height, width, channels) with at least
            three channels.

    Returns:
        A ``uint8`` array with the same shape as ``img``.

    Raises:
        ValueError: If ``img`` is not 3-dimensional or has fewer than three
            channels.
    """
    img = np.asarray(img)
    if img.ndim != 3 or img.shape[2] < 3:
        raise ValueError(
            "expected an image of shape (height, width, >=3 channels), "
            "got shape {}".format(img.shape)
        )

    colour = img[..., :3]

    # One boolean per pixel, evaluated for the whole image at once instead of
    # an interpreted loop over every pixel.
    top_bit_set = (colour & Mask_Lightest_Pass) == Mask_Lightest_Pass
    is_highlight = top_bit_set.all(axis=-1)

    # Start from an all-white canvas and copy only the selected pixels.
    light_img = np.full(img.shape, 0xFF)
    light_img[is_highlight, :3] = colour[is_highlight]

    # Clip to 0..255 so wider integer inputs cannot wrap when narrowed.
    return np.clip(light_img, 0, 255).astype(np.uint8)

def BITMASK_DARKEST(img: np.ndarray) -> np.ndarray:
    """Keep only the pixels whose three colour channels all have both top bits clear.

    A pixel counts as a shadow when each of its first three channels
    satisfies ``value & Mask_HIGH_2_BIT == 0``.  Shadow pixels are copied to
    the output; every other sample (including any channel beyond the third)
    is left at 0xFF, i.e. white.

    Each channel is tested on its own.  Summing the three masked values and
    comparing the sum with 0 is not equivalent for ``uint8`` data, because
    the sum can wrap around to 0 (e.g. 64 + 192 + 0) and flag a non-dark
    pixel as dark.

    Args:
        img: Integer array of shape (height, width, channels) with at least
            three channels.

    Returns:
        A ``uint8`` array with the same shape as ``img``.

    Raises:
        ValueError: If ``img`` is not 3-dimensional or has fewer than three
            channels.
    """
    img = np.asarray(img)
    if img.ndim != 3 or img.shape[2] < 3:
        raise ValueError(
            "expected an image of shape (height, width, >=3 channels), "
            "got shape {}".format(img.shape)
        )

    colour = img[..., :3]

    # One boolean per pixel, evaluated for the whole image at once.
    high_bits_clear = (colour & Mask_HIGH_2_BIT) == 0
    is_shadow = high_bits_clear.all(axis=-1)

    # Start from an all-white canvas and copy only the selected pixels.
    dark_img = np.full(img.shape, 0xFF)
    dark_img[is_shadow, :3] = colour[is_shadow]

    # Clip to 0..255 so wider integer inputs cannot wrap when narrowed.
    return np.clip(dark_img, 0, 255).astype(np.uint8)

def BITMASK_MIDTONES(img: np.ndarray) -> np.ndarray:
    """Keep only the pixels that are neither shadows nor highlights.

    Using the first three channels of each pixel:

    * shadow    -- every channel satisfies ``value & Mask_HIGH_2_BIT == 0``;
    * highlight -- every channel satisfies
      ``value & Mask_Lightest_Pass == Mask_Lightest_Pass``.

    Pixels that are neither are copied to the output; every other sample
    (including any channel beyond the third) is left at 0xFF, i.e. white.

    The shadow test is done per channel.  Summing the three masked values
    and comparing the sum with 0 is not equivalent for ``uint8`` data,
    because the sum can wrap around to 0 (e.g. 64 + 192 + 0) and wrongly
    exclude a midtone pixel.

    Args:
        img: Integer array of shape (height, width, channels) with at least
            three channels.

    Returns:
        A ``uint8`` array with the same shape as ``img``.

    Raises:
        ValueError: If ``img`` is not 3-dimensional or has fewer than three
            channels.
    """
    img = np.asarray(img)
    if img.ndim != 3 or img.shape[2] < 3:
        raise ValueError(
            "expected an image of shape (height, width, >=3 channels), "
            "got shape {}".format(img.shape)
        )

    colour = img[..., :3]

    # Per-pixel classification, evaluated for the whole image at once.
    is_shadow = ((colour & Mask_HIGH_2_BIT) == 0).all(axis=-1)
    is_highlight = (
        (colour & Mask_Lightest_Pass) == Mask_Lightest_Pass
    ).all(axis=-1)
    is_midtone = ~(is_shadow | is_highlight)

    # Start from an all-white canvas and copy only the selected pixels.
    mid_img = np.full(img.shape, 0xFF)
    mid_img[is_midtone, :3] = colour[is_midtone]

    # Clip to 0..255 so wider integer inputs cannot wrap when narrowed.
    return np.clip(mid_img, 0, 255).astype(np.uint8)


#=======================================================
# LOAD THE ORIGINAL IMAGE by Image.open method
#=======================================================
# The context manager releases the file handle once the pixel data has been
# copied into the array.  Converting to RGB guarantees the three colour
# channels that the BITMASK_* filters index into, whatever mode the file is
# stored in.
with Image.open('C:/Conv_Python/images/1.jpg') as pI:
    # create numpy array from image
    img = np.array(pI.convert('RGB'))


#============================================
# Save only lightest pixels with the
# most significant bit high in each
# channel (upper half of range)
# AND succeeds with Mask_Lightest_Pass = 0x80
#============================================
Curr_Title="Cat Image - LIGHTEST PIXELS (HIGHLIGHTS)"
# The filters build a new array and never write to their input, so the
# image can be passed directly without a defensive copy.
img_lightest_pix = BITMASK_LIGHTEST(img=img)


plot_two_images(
    img1=img,
    img2=img_lightest_pix,
    imm_name=Curr_Title
)

# Written to the current working directory.
plt.imsave(fname='lightest_pixels.png', arr=img_lightest_pix, format='png')


#=======================================
# Save only pixels with all three
# channels having the two most
# significant bits low
# AND fails with Mask_HIGH_2_BIT = 0xC0
#=======================================
Curr_Title="Cat Image - DARKEST PIXELS (SHADOWS)"
img_darkest_pix = BITMASK_DARKEST(img=img)


plot_two_images(
    img1=img,
    img2=img_darkest_pix,
    imm_name=Curr_Title
)

plt.imsave(fname='darkest_pixels.png', arr=img_darkest_pix, format='png')



#==========================================
#  Save the pixels not selected previously
#==========================================
Curr_Title="Cat Image - MIDTONE PIXELS (MIDTONES)"
img_midtone_pix = BITMASK_MIDTONES(img=img)


plot_two_images(
    img1=img,
    img2=img_midtone_pix,
    imm_name=Curr_Title
)

plt.imsave(fname='midtone_pixels.png', arr=img_midtone_pix, format='png')

