"""
Labels for Soft Ordinal Regression (SORD)
"""

import numpy as np


def soft_label_theta(theta_true, n_bins=100, K=1):
    """Build SORD soft labels for angles ``theta`` with period ``pi``.

    The bin centres are ``-pi/2 + i * pi / n_bins`` for ``i = 0 .. n_bins-1``,
    i.e. ``n_bins`` evenly spaced points covering ``[-pi/2, pi/2)``.
    Each label is ``exp(-(K * d)**2)``, where ``d`` is the wrap-around
    distance between the true angle and a bin centre. Each row is then
    normalised to sum to one.

    Args:
        theta_true: scalar or array-like of true angles, flattened to one row
            per value. The distance wraps by only a single period of ``pi``
            in each direction, so values far outside ``[-pi/2, pi/2]`` are
            not fully wrapped.
        n_bins: number of angle bins.
        K: sharpness factor; larger values give narrower labels.

    Returns:
        ``float32`` array of shape ``(theta_true.size, n_bins)`` whose rows
        sum to 1.
    """
    theta_bins = np.linspace(-np.pi/2, np.pi/2, n_bins, endpoint=False)
    th = np.array(theta_true).reshape(-1, 1)
    # Distance with period pi: compare against the bin centres shifted by one
    # period in either direction and keep the smallest gap.
    dist = np.minimum.reduce([
        np.abs(theta_bins - th),
        np.abs(theta_bins + np.pi - th),
        np.abs(theta_bins - np.pi - th),
    ])
    sq = (K * dist) ** 2
    # Shift each row so its closest bin has exponent 0. After normalisation
    # the distribution is the same, but a sharp K can no longer underflow
    # every entry to 0 (which would make the row 0/0 = NaN).
    sq -= sq.min(axis=1, keepdims=True)
    labels = np.exp(-sq, dtype=np.float32)
    return labels / labels.sum(axis=1, keepdims=True)


def soft_label_rho(rho_true, n_bins=100, K=1, K_range=1):
    """Build SORD soft labels for ``rho`` over tangent-spaced bins.

    The bin centres are ``tan(linspace(-pi/2, pi/2, n_bins)) * K_range``.
    Both endpoints are included, so the first and last bins sit at about
    ``-/+1.6e16 * K_range`` and act as effectively infinite catch-all bins.
    Each label is ``exp(-((rho - centre) * K)**2)``, normalised to sum to
    one per row.

    Args:
        rho_true: scalar or array-like of true rho values, flattened to one
            row per value.
        n_bins: number of rho bins.
        K: sharpness factor; larger values give narrower labels.
        K_range: scale applied to the tangent-spaced bin centres.

    Returns:
        Array of shape ``(rho_true.size, n_bins)`` whose rows sum to 1.
    """
    rho_bins = np.tan(np.linspace(-np.pi/2, np.pi/2, n_bins, endpoint=True)) * K_range
    rho = np.array(rho_true).reshape(-1, 1)
    sq = ((rho - rho_bins) * K) ** 2
    # Shift each row so its closest bin has exponent 0. After normalisation
    # the distribution is the same, but a rho far from every finite bin no
    # longer underflows all labels to 0 (which produced 0/0 = NaN rows).
    sq = sq - sq.min(axis=1, keepdims=True)
    labels = np.exp(-sq)
    return labels / labels.sum(axis=1, keepdims=True)


def theta_from_soft_label(theta_pred, method="weighted", offset=0):
    """Decode angles from per-bin scores laid out like ``soft_label_theta`` bins.

    Bin ``i`` corresponds to the angle ``-pi/2 + i * pi / n_bins``.

    Args:
        theta_pred: 2-D array ``(samples, n_bins)`` of per-bin scores.
        method: ``"weighted"`` takes a circular weighted mean over the bins.
            ``"argmax"`` returns the centre of the highest-scoring bin.
        offset: used only by ``"weighted"``. It is subtracted from every
            score, and negative results are clipped to 0, so only scores
            above ``offset`` contribute.

    Returns:
        1-D float array of angles, one per row, in ``[-pi/2, pi/2]``. With
        ``"weighted"``, a row whose clipped weights are all zero decodes to 0,
        because ``arctan2(0, 0) == 0``.

    Raises:
        ValueError: if ``method`` is not ``"weighted"`` or ``"argmax"``.
    """
    n_bins = theta_pred.shape[1]

    if method == "weighted":
        # Map bin i to the doubled angle -pi + 2*pi*i/n_bins. This turns the
        # period-pi angle into a full circle, so the circular mean handles
        # wrap-around between the first and last bins correctly.
        doubled = np.linspace(-np.pi, np.pi, n_bins, endpoint=False)
        w = np.maximum(theta_pred - offset, 0)
        x = (w * np.cos(doubled).reshape(1, -1)).sum(-1)
        y = (w * np.sin(doubled).reshape(1, -1)).sum(-1)
        theta_bin = (np.arctan2(y, x) + np.pi) / (2*np.pi) * n_bins
    elif method == "argmax":
        theta_bin = np.argmax(theta_pred, axis=1)
    else:
        # Previously fell through and raised UnboundLocalError on theta_bin.
        raise ValueError(
            f"unknown method {method!r}; expected 'weighted' or 'argmax'"
        )

    return theta_bin / n_bins * np.pi - (np.pi/2)


def rho_from_soft_label(rho_pred, K_range=1):
    """Decode rho values from per-bin scores over tangent-spaced bins.

    The bins are the same centres ``soft_label_rho`` uses:
    ``tan(linspace(-pi/2, pi/2, n_bins)) * K_range``. Scores below 10% of
    each row's maximum are zeroed and the rest are renormalised. The result
    is the weighted mean of the bin centres.

    Args:
        rho_pred: 2-D array ``(samples, n_bins)`` of per-bin scores
            (presumably non-negative probabilities; confirm against callers).
            Integer arrays are accepted.
        K_range: scale applied to the bin centres; should match the value
            used to build the labels.

    Returns:
        1-D float array with one rho per row. A row whose kept scores sum to
        zero gives NaN (division by zero).
    """
    n_bins = rho_pred.shape[1]
    # Endpoints are included, so the outer bins are tan(+-pi/2) ~ +-1.6e16:
    # any mass kept on them dominates the weighted mean.
    rho_bins = np.tan(np.linspace(-np.pi/2, np.pi/2, n_bins, endpoint=True)) * K_range
    rho_bins = rho_bins.reshape(1, -1)
    r = np.where(rho_pred < 0.1*rho_pred.max(-1, keepdims=True), 0, rho_pred)
    # Out-of-place division: np.where keeps an integer dtype for integer
    # input, and in-place true division into an int array raises.
    r = r / r.sum(axis=-1, keepdims=True)
    return (r * rho_bins).sum(-1)