# 【2.1.3】散点图线性拟合（Scatter plot with linear regression line of best fit, Pearson r）

## 案例一：

import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats

# Toy data: y is nearly, but not exactly, linear in x.
x = [1, 2, 3, 4]
y = [3, 5, 7, 10]  # 10, not 9, so the fit isn't perfect

# Remove NaN: drop every row where x or y is missing so that only complete
# (x, y) pairs are passed to linregress.
dt = {'x': x, 'y': y}
df_1 = pd.DataFrame(dt)
df_2 = df_1.dropna()
x = df_2['x']
y = df_2['y']

# Ordinary least-squares fit; r_value is Pearson's correlation coefficient.
slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
# Fitted y for every x. Series arithmetic is element-wise (a plain Python
# list + float would raise TypeError).
line = slope * x + intercept

# Raw points as circles, fitted line on top, r printed in the upper-left corner.
plt.plot(x, y, 'o', x, line)
plt.annotate('R=%.2f\n' % (r_value), xy=(0.05, 0.9), xycoords='axes fraction', color='red')

plt.xlim(0, 5)
plt.ylim(0, 12)
plt.show()


# Colour each point: rows whose 'Seq' is in `inhouse_seq` are red, the rest black.
# NOTE(review): `df` (columns 'Seq', 'coding_5_3', 'in_vitro') and `inhouse_seq`
# are defined outside this file; presumably `inhouse_seq` is a collection of
# Seq values — confirm. `in` is applied per value exactly as before.
df['category'] = df['Seq'].map(lambda seq: 'red' if seq in inhouse_seq else 'black')

data = df  # alias kept from the original; not used further in this section

fig, ax = plt.subplots(figsize=(12, 8))

xx = 'coding_5_3'
yy = 'in_vitro'

# Fit only on complete (x, y) pairs, consistent with the NaN removal in case one.
fit_df = df[[xx, yy]].dropna()
slope, intercept, r_value, p_value, std_err = stats.linregress(fit_df[xx], fit_df[yy])
# Fitted y per point; Series arithmetic is element-wise (list + float would raise TypeError).
line = slope * fit_df[xx] + intercept

ax.scatter(xx, yy, c=df['category'], data=df)
ax.plot(fit_df[xx], line, color='red')
ax.annotate('R=%.2f\n' % (r_value), xy=(0.05, 0.9), xycoords='axes fraction', color='red', fontsize=20)

# Label each point with its Seq, nudged 0.5 data units above the marker.
for x_val, y_val, label in zip(df[xx], df[yy], df['Seq']):
    ax.text(x_val, y_val + 0.5, label, ha="center", va="center", size=10)

ax.set_xlabel('5-MFE (kcal/mol)', size=20)
ax.set_ylabel('In_vitro_expression(mg/ml)', size=20)
ax.set_title('Correlation between in vitro expression and 5-MFE', size=24)

plt.show()


## 案例二：

# Keep only the 4- and 8-cylinder cars.
# NOTE(review): the data-loading step is not in this file; this assumes `df`
# already holds a table with 'cyl', 'displ' and 'hwy' columns.
keep_rows = df['cyl'].isin([4, 8])
df_select = df.loc[keep_rows, :]

# One robust regression line per cylinder group, overlaid on the same axes.
marker_style = {'s': 60, 'linewidths': .7, 'edgecolors': 'black'}
sns.set_style("white")
gridobj = sns.lmplot(
    data=df_select,
    x="displ",
    y="hwy",
    hue="cyl",
    height=7,
    aspect=1.6,
    robust=True,
    palette='tab10',
    scatter_kws=marker_style,
)

# Decorations: fixed axis ranges and a title on the current (only) axes.
gridobj.set(xlim=(0.5, 7.5), ylim=(0, 50))
plt.title("Scatterplot with line of best fit grouped by number of cylinders", fontsize=20)
plt.show()


# Keep only the 4- and 8-cylinder cars.
# NOTE(review): assumes `df` already holds a table with 'cyl', 'displ' and
# 'hwy' columns; the load step is not shown in this file.
df_select = df.loc[df.cyl.isin([4, 8]), :]

# Each cylinder group in its own column (facet), each with its own fitted line.
sns.set_style("white")
gridobj = sns.lmplot(x="displ", y="hwy",
                     data=df_select,
                     height=7,
                     robust=True,
                     # `palette` is ignored unless `hue` is set; mapping hue to the
                     # same variable as `col` gives each facet its own Set1 colour.
                     hue="cyl",
                     palette='Set1',
                     col="cyl",
                     scatter_kws=dict(s=60, linewidths=.7, edgecolors='black'))

# Decorations: apply the same axis limits to every facet.
gridobj.set(xlim=(0.5, 7.5), ylim=(0, 50))
plt.show()