from netCDF4 import Dataset
import numpy as np
from matplotlib.colors import BoundaryNorm
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib.colors import LinearSegmentedColormap
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.io.shapereader import Reader
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
from cartopy.mpl.gridliner import LATITUDE_FORMATTER, LONGITUDE_FORMATTER
from wrf import (getvar, to_np, latlon_coords, get_cartopy,
                 ALL_TIMES)
from mypath import lambertticks as lt
from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter, AutoMinorLocator)

# Province outlines from a shapefile, given in lon/lat (PlateCarree) and
# drawn as unfilled black edges on the map.
reader = Reader("D:/Dhy/MAP/China_provinces/China_provinces")
provinces = cfeature.ShapelyFeature(
    geometries=reader.geometries(),
    crs=ccrs.PlateCarree(),
    edgecolor='k',
    facecolor='none',
)
# Experiment / period labels (not referenced elsewhere in the visible script).
kzmin = ['0.2', '0.5']
days = ['08-10', '11-16', '20-21', '22-24']

# Two-digit days of the month whose wrfout files are read: 07-14 and 19-23.
day = [f'{d:02d}' for d in (*range(7, 15), *range(19, 24))]

pbl = ['AC', 'BL']

# Output directories of the two runs; the plot shows pathf2 minus pathf1.
# Other directories tried for pathf2: 'F:/wrfout/xiangmu/wrfoutAC0.5/',
# 'J:/wrfout10/'.
pathf1 = 'F:/wrfout/xiangmu/wrfoutACORG/'
pathf2 = 'F:/wrfout/xiangmu/wrfoutkf/'

# Split the stacked time series into daytime and nighttime indices.
# Each day contributes 24 consecutive records (assumed hourly), so record h of
# day i sits at index 24*i + h.
hours_per_day = 24
n_days = 13  # must match the number of days stacked by readvar2d (len(day))
daytime_hours = np.arange(8, 17)  # records 8..16 of each day

# Row i holds day i's daytime indices; ravel() keeps them in day-by-day order.
daytime = (hours_per_day * np.arange(n_days)[:, None] + daytime_hours).ravel()

# Every index of the series (0 .. 13*24 - 1). The full range is needed so that
# the final record (index 311) is classified as nighttime too.
alltime = np.arange(n_days * hours_per_day)
nighttime = np.setdiff1d(alltime, daytime)


def setticks(lat, lon, ax):
    """Zoom a map to the lat/lon grid and label it with 1-degree ticks.

    Sets the extent of ``ax`` to the bounding box of ``lat``/``lon``, draws
    gridlines at every whole degree (from one degree below to two degrees
    above the box), and labels them with the project-local Lambert tick
    helpers.

    Parameters
    ----------
    lat, lon : array-like
        Latitude and longitude grids in degrees (anything ``wrf.to_np``
        accepts, e.g. ``XLAT``/``XLONG`` from ``wrf.getvar``).
    ax : cartopy GeoAxes
        Map axes to configure; modified in place.
    """
    lat_np = to_np(lat)
    lon_np = to_np(lon)
    minlon, maxlon = lon_np.min(), lon_np.max()
    minlat, maxlat = lat_np.min(), lat_np.max()
    # Domain bounds printed for a quick sanity check.
    print([maxlat, maxlon, minlat, minlon])

    ax.set_extent([minlon, maxlon, minlat, maxlat], crs=ccrs.PlateCarree())
    # Render once so the axes layout is final before the Lambert tick helpers
    # run (presumably they inspect the drawn outline). Use the axes' own
    # figure instead of depending on a module-level ``fig`` global.
    ax.figure.canvas.draw()

    xticks = list(range(int(minlon) - 1, int(maxlon) + 3))
    yticks = list(range(int(minlat) - 1, int(maxlat) + 3))
    ax.gridlines(xlocs=xticks, ylocs=yticks)
    ax.xaxis.set_major_formatter(LONGITUDE_FORMATTER)
    ax.yaxis.set_major_formatter(LATITUDE_FORMATTER)
    lt.lambert_xticks(ax, xticks)
    lt.lambert_yticks(ax, yticks)
    ax.tick_params(labelsize=15)


def readvar2d(pathf1, day, var1d):
    """Read one WRF field from consecutive daily d03 wrfout files.

    For every entry ``dd`` of ``day`` the file
    ``pathf1 + 'wrfout_d03_2014-01-' + dd + '_00_00_00'`` is opened. Records
    16..39 (24 consecutive steps along the first axis) of ``var1d`` are kept,
    and the pieces are stacked along that axis in the order given by ``day``.

    Parameters
    ----------
    pathf1 : str
        Directory prefix, including the trailing ``/``.
    day : sequence of str
        Two-digit day-of-month strings, e.g. ``['07', '08']``.
    var1d : str
        netCDF variable name (e.g. ``'TSK'``). It must be 3-D; the first
        axis is sliced as time.

    Returns
    -------
    t2 : numpy.ndarray
        ``len(day) * 24`` records stacked on axis 0 (fewer if a file holds
        fewer than 40 records).
    lat, lon :
        ``XLAT`` / ``XLONG`` of the last file read, as returned by
        ``wrf.getvar``.

    Raises
    ------
    ValueError
        If ``day`` is empty.
    """
    if len(day) == 0:
        raise ValueError("readvar2d: 'day' must contain at least one date")

    pieces = []
    last = len(day) - 1
    for idx, dd in enumerate(day):
        fname = pathf1 + 'wrfout_d03_2014-01-' + dd + '_00_00_00'
        # Close each file once its data is in memory (np.array copies it).
        with Dataset(fname) as ncfile:
            # NOTE(review): records 16..39 presumably shift UTC output to one
            # local (UTC+8) calendar day -- confirm against the wrfout Times.
            pieces.append(np.array(ncfile.variables[var1d])[16:16 + 24, :, :])
            if idx == last:
                # Grid coordinates are taken from the last file read.
                lat = getvar(ncfile, "XLAT")
                lon = getvar(ncfile, "XLONG")

    # A single concatenation avoids re-copying the growing array per file.
    t2 = np.concatenate(pieces, axis=0)
    return t2, lat, lon


# ---------------------------------------------------------------------------
# Data: the same field from both runs, differenced record by record.
# ---------------------------------------------------------------------------
varname = 'TSK'  # one of the keys of cbar_labels below
# Colour-bar title per field, so label and unit always follow the plotted data.
cbar_labels = {
    'T2': 'T2 diff (K)',
    'TSK': 'TSK diff (K)',
    'HFX': 'HFX diff (W/' + r'$\mathregular{m^2}$)',
}

# Terrain height; not plotted by default.
hgt, lat, lon = readvar2d(pathf1, day, 'HGT')
t202, lat, lon = readvar2d(pathf1, day, varname)
t205, lat, lon = readvar2d(pathf2, day, varname)
t2d = t205 - t202  # pathf2 run minus pathf1 run

t2m = t2d[daytime, :, :].mean(axis=0)  # select daytime or nighttime

# ---------------------------------------------------------------------------
# Map of the time-mean difference.
# ---------------------------------------------------------------------------
cart_proj = get_cartopy(lon)

fig = plt.figure(figsize=(6, 5.5))
ax = plt.axes(projection=cart_proj)
ax.add_feature(provinces, linewidth=0.5, edgecolor="black")

# 0.1-wide colour classes from -3.1 to 3.1; with clip=True, values outside
# that range are drawn in the end colours.
levels = np.arange(-3.1, 3.15, 0.1)
cmap = plt.get_cmap('bwr')
norm = BoundaryNorm(levels, ncolors=cmap.N, clip=True)
mesh = ax.pcolormesh(to_np(lon), to_np(lat), to_np(t2m),
                     cmap=cmap, norm=norm,
                     transform=ccrs.PlateCarree())

setticks(lat, lon, ax)
fig.subplots_adjust(top=1, bottom=0.05, left=0.1, right=1)

# ---------------------------------------------------------------------------
# Colour bar drawn on its own blank figure (frame and ticks hidden).
# ---------------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(12, 10))
for side in ('top', 'right', 'bottom', 'left'):
    ax.spines[side].set_visible(False)
fig.subplots_adjust(top=0.985, bottom=0.045, left=0.013, right=0.988)
ax.set_xticks([])
ax.set_yticks([])
cbar = plt.colorbar(mesh, ax=ax, orientation="horizontal")
cbar.ax.set_xlabel(cbar_labels[varname], fontsize=25)
cbar.set_ticks(np.arange(-3, 3.1, 1))
cbar.ax.tick_params(labelsize=25)  # colour-bar tick label font size

plt.show()
