#!/usr/bin/env python
# coding: utf-8

# In[2]:


import pandas as pd
import numpy as np
import os

from matplotlib import pyplot as plt, rcParams


# In[34]:


def hinton(pivot, max_weight=None, ax=None, negtive_set_zero=False,
           save_path='hint_plot.png', figsize=(8, 8), dpi=300):
    """Draw a Hinton diagram of ``pivot`` and save it to disk.

    Each cell is drawn as a square centred on its (column position, row
    position) grid point, with side ``sqrt(abs(value) / max_weight)`` so that
    the square's area is proportional to ``abs(value)``. Positive values are
    filled gray; zero/negative values are white squares with a gray outline.
    Row 0 is drawn at the bottom (the y axis is not inverted).

    Parameters
    ----------
    pivot : pandas.DataFrame
        Numeric matrix to draw. Its columns label the x axis and its index
        labels the y axis.
    max_weight : float, optional
        Value drawn as a full 1 x 1 square. When ``None`` (or 0), the smallest
        power of two >= the largest absolute value in ``pivot`` is used.
    ax : matplotlib.axes.Axes, optional
        Axes to draw into. When ``None``, a new figure of size ``figsize`` is
        created.
    negtive_set_zero : bool, default False
        If True, every strictly negative value is replaced by the smallest
        absolute value in ``pivot`` before drawing, so negatives render as the
        smallest filled square. Despite the name, values are NOT set to 0.
    save_path : str, path-like or None, default 'hint_plot.png'
        File the figure is saved to; ``None`` skips saving.
    figsize : tuple, default (8, 8)
        Size of the newly created figure; ignored when ``ax`` is given.
    dpi : int, default 300
        Resolution used when saving.
    """
    if negtive_set_zero:
        min_abs = np.nanmin(pivot.abs().values)
        # Replace only strictly-negative entries; NaNs are left untouched.
        pivot = pivot.mask(pivot < 0, min_abs)

    matrix = pivot.values
    n_rows, n_cols = matrix.shape

    # Only create a new figure when the caller did not supply an Axes.
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        fig = ax.figure

    if not max_weight:
        max_abs = np.nanmax(np.abs(matrix))
        # Round up to a power of two; fall back to 1 when the matrix is all
        # zeros/NaN, where log2 would yield a zero (or NaN) max_weight.
        max_weight = 2 ** np.ceil(np.log2(max_abs)) if max_abs > 0 else 1.0

    ax.patch.set_facecolor('white')
    ax.set_aspect('equal', 'box')

    for (row, col), w in np.ndenumerate(matrix):
        facecolor, edgecolor = ('gray', 'gray') if w > 0 else ('white', 'gray')
        size = np.sqrt(abs(w) / max_weight)
        rect = plt.Rectangle([col - size / 2, row - size / 2], size, size,
                             facecolor=facecolor, edgecolor=edgecolor)
        ax.add_patch(rect)

    ax.set_xlabel('Income Categories (￥)', fontsize=18)
    ax.set_ylabel('Hours Categories', fontsize=18)
    # Push every other x label onto a second line so neighbouring labels
    # do not overlap.
    xlabels = [v if i % 2 == 0 else f"\n{v}" for i, v in enumerate(pivot.columns)]
    ax.set_xticks(np.arange(n_cols), labels=xlabels, rotation=0)
    ax.set_yticks(np.arange(n_rows), labels=pivot.index, rotation=0)
    ax.autoscale_view()
    ax.grid(alpha=0.1)
    fig.tight_layout()

    if save_path is not None:
        fig.savefig(save_path, dpi=dpi)


# In[3]:


# Make 'Songti SC' the sans-serif font, presumably so CJK text (e.g. the
# labels read from 交互项.xlsx and the full-width '￥' in the x-axis title)
# renders. NOTE(review): 'Songti SC' is a macOS system font — other
# platforms need a substitute.
rcParams.update({'font.sans-serif': ['Songti SC']})


# In[42]:


# Lookup table of variable names -> display labels, read from the second
# sheet (sheet_name=1 is zero-based) of 交互项.xlsx ("interaction terms").
# The sheet has no header row, so the two columns are named explicitly.
label = pd.read_excel(
    r'交互项.xlsx',
    sheet_name=1,
    header=None,
    names=['var', 'label'],
)


# In[51]:


# Names of the income x duration interaction terms; the two factors can
# appear in either order, so both orderings are matched.
pattern1 = r'c\.group_accincome(?P<accincome>\d+)#c\.group_accduration(?P<accduration>\d+)'
pattern2 = r'c\.group_accduration(?P<accduration>\d+)#c\.group_accincome(?P<accincome>\d+)'


def _load_interaction_pivot(path, labels_table):
    """Parse one regression-result workbook into a labelled duration x income grid.

    Parameters
    ----------
    path : str
        Excel file whose second row holds the column names, including
        'VARIABLES' (term names) and 'whether_stop' (values, presumably
        coefficients with significance stars — TODO confirm).
    labels_table : pandas.DataFrame
        Table with 'var' and 'label' columns; rows whose 'var' contains
        'income' / 'duration' supply the axis labels, in order.

    Returns
    -------
    (data, pivot)
        ``data`` is the parsed table with integer 'accincome' and
        'accduration' columns added; ``pivot`` has one row per duration
        group and one column per income group, holding the 'size' values.

    Raises
    ------
    ValueError
        If a remaining row's term name matches neither interaction pattern.
    """
    # dropna() keeps only complete rows; iloc[:-3] drops the last three.
    # NOTE(review): presumably footer rows (N, R², ...) — confirm.
    data = (pd.read_excel(path, header=1).dropna().iloc[:-3]
            .rename(columns={'VARIABLES': 'coordinates', 'whether_stop': 'size'}))
    # Strip the 'base_' prefix (presumably marking reference levels) so those
    # rows match the same patterns as the other levels.
    data['coordinates'] = data['coordinates'].str.replace('base_', '', regex=False)

    codes = (data['coordinates'].str.extract(pattern1)
             .combine_first(data['coordinates'].str.extract(pattern2)))
    unmatched = codes.isna().any(axis=1)
    if unmatched.any():
        # Fail with a readable message instead of an opaque NaN->int cast error.
        raise ValueError(
            f"{path}: cannot parse accincome/accduration from "
            f"{data.loc[unmatched, 'coordinates'].tolist()}"
        )
    data = pd.concat([data, codes.astype('int')], axis=1)

    # Drop significance stars ("0.12***" -> 0.12); exact zeros become 1e-7.
    data['size'] = (data['size'].astype('str')
                    .str.replace(r'\*+', '', regex=True)
                    .astype('float')
                    .replace(0, 1e-7))

    pivot = data.pivot(index=['accduration'], columns=['accincome'],
                       values='size').fillna(1e-7)
    income_labels = labels_table.loc[labels_table['var'].str.contains('income'), 'label']
    duration_labels = labels_table.loc[labels_table['var'].str.contains('duration'), 'label']
    # Equal-width binning with one bin per label: when the group codes are
    # consecutive integers, each code falls into its own bin, i.e. the sorted
    # codes are relabelled in order.
    pivot.columns = pd.cut(pivot.columns, len(income_labels), labels=income_labels)
    pivot.index = pd.cut(pivot.index, len(duration_labels), labels=duration_labels)
    return data, pivot


data, pivot = _load_interaction_pivot('result5_1.xlsx', label)
hinton(pivot, negtive_set_zero=True)
# hinton() writes its figure to hint_plot.png; move the file aside so the
# result5_2 plot below does not overwrite it.
os.replace('hint_plot.png', 'hint_plot_result5_1.png')


# In[52]:


data, pivot = _load_interaction_pivot('result5_2.xlsx', label)
hinton(pivot, negtive_set_zero=True)
os.replace('hint_plot.png', 'hint_plot_result5_2.png')

