import os
import numpy as np
import json
import cv2
import pickle
import detectron2
from detectron2.engine import DefaultPredictor
import numpy as np
import sys

"""Predict segmentation masks with a trained Mask R-CNN convolutional network.

   Usage:
   > python PredictMaskRCNN.py <maskRCNNfile> <inputImage> <outputProductMaskDir> <outputPackageMaskDir>

   Input arguments:
   * <maskRCNNfile> - Pickled detectron2 configuration of the trained Mask R-CNN
                      network (it is unpickled and passed to DefaultPredictor).
                      Only use trusted files: unpickling can execute arbitrary code.
   * <inputImage> - Image to be segmented.
   * <outputProductMaskDir> - Folder for the output product mask.
   * <outputPackageMaskDir> - Folder for the output package mask.

   Each mask is saved under the same file name as the input image.
"""

# Module metadata: authorship, licensing and version information.
__author__ = "Ondrej Klima"
__copyright__ = "Copyright 2020"
__credits__ = ["Ondrej Klima"]
__email__ = "iklima@fit.vutbr.cz"
__license__ = "BUT"
__version__ = "1.0"
__maintainer__ = "Ondrej Klima"


def _require_arg(argv, index, message):
    """Return ``argv[index]``; raise ``IndexError(message)`` if it is missing."""
    try:
        return argv[index]
    except IndexError:
        raise IndexError(message) from None


def _require_file(fileName):
    """Return *fileName* unchanged; raise ``ValueError`` if it is not an existing file."""
    if not os.path.isfile(fileName):
        raise ValueError('File "%s" does not exist!' % fileName)
    return fileName


def _best_instance_per_class(classes, scores, numClasses):
    """For each class id in ``range(numClasses)`` return the index of its
    highest-scoring instance, or ``None`` when that class was not detected.

    ``classes`` and ``scores`` are 1-D NumPy arrays holding one entry per
    predicted instance (class id and confidence score respectively).
    """
    best = []
    for classId in range(numClasses):
        candidates = np.flatnonzero(classes == classId)
        if candidates.size == 0:
            best.append(None)
        else:
            best.append(int(candidates[np.argmax(scores[candidates])]))
    return best


def main(argv=None):
    """Segment one image with Mask R-CNN and save the product and package masks.

    Command-line arguments (``argv[0]`` is the program name, as in ``sys.argv``):
        argv[1] -- pickled detectron2 config accepted by ``DefaultPredictor``
        argv[2] -- image to be segmented
        argv[3] -- output folder for the product mask (predicted class 0)
        argv[4] -- output folder for the package mask (predicted class 1)

    For each of the two classes, the highest-scoring predicted instance is
    written as an 8-bit image (0 = background, 255 = object) under the input
    image's file name inside the corresponding output folder, which is
    created if needed.

    Args:
        argv: Argument list to parse; defaults to ``sys.argv``.

    Raises:
        IndexError: A required argument is missing.
        ValueError: An input file does not exist or the image cannot be read.
        RuntimeError: No instance of the product or package class was detected.
        OSError: A mask could not be written.
    """
    if argv is None:
        argv = sys.argv

    cnnFileName = _require_file(_require_arg(
        argv, 1, 'File with Mask-RCNN must be supplied as an argument'))
    imageFileName = _require_file(_require_arg(
        argv, 2, 'Input image must be supplied as an argument'))
    outputProductPath = _require_arg(
        argv, 3, 'Output directory for product mask must be supplied as an argument')
    outputPackagePath = _require_arg(
        argv, 4, 'Output directory for package mask must be supplied as an argument')

    # Output folder indexed by predicted class id: 0 -> product, 1 -> package.
    folders = [outputProductPath, outputPackagePath]
    labels = ['product', 'package']

    # Read the image before the (slow) model load so a bad image fails fast;
    # cv2.imread signals an unreadable/undecodable file by returning None.
    im = cv2.imread(imageFileName)
    if im is None:
        raise ValueError('File "%s" could not be read as an image!' % imageFileName)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only pass configuration files that come from a trusted source.
    with open(cnnFileName, 'rb') as f:
        cfg = pickle.load(f)

    predictor = DefaultPredictor(cfg)
    instances = predictor(im)['instances']
    classes = instances.pred_classes.cpu().numpy()
    scores = instances.scores.cpu().numpy()

    # Pick one instance per class instead of assuming exactly two detections
    # that happen to have distinct classes.
    best = _best_instance_per_class(classes, scores, len(folders))
    missing = [labels[classId] for classId, index in enumerate(best) if index is None]
    if missing:
        raise RuntimeError('No instance of class(es) %s detected in "%s"'
                           % (', '.join(missing), imageFileName))

    base = os.path.basename(imageFileName)
    for folder, index in zip(folders, best):
        if folder:
            os.makedirs(folder, exist_ok=True)
        # Boolean mask -> explicit uint8 image with values 0/255.
        mask = instances.pred_masks[index].cpu().numpy().astype(np.uint8) * 255
        outFileName = os.path.join(folder, base)
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(outFileName, mask):
            raise OSError('Could not write mask to "%s"' % outFileName)

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
