55 lines
1.8 KiB
Python
55 lines
1.8 KiB
Python
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
class zadacha3:
|
|
def __init__(self, X, Y, b0, b1):
|
|
self.X = X
|
|
self.Y = Y
|
|
self.b0 = b0
|
|
self.b1 = b1
|
|
|
|
def calculate_sse(self):
|
|
"""
|
|
Функция для расчета суммы квадратов ошибок (SSE)
|
|
:return: SSE
|
|
"""
|
|
Y_pred = self.b0 + self.b1 * self.X
|
|
sse = np.sum((self.Y - Y_pred) ** 2)
|
|
return sse
|
|
|
|
def calculate_r_squared(self):
|
|
"""
|
|
Функция для расчета коэффициента детерминации R^2
|
|
:return: R^2
|
|
"""
|
|
y_mean = np.mean(self.Y)
|
|
ss_total = np.sum((self.Y - y_mean) ** 2)
|
|
ss_residual = self.calculate_sse()
|
|
r_squared = 1 - (ss_residual / ss_total)
|
|
return r_squared
|
|
|
|
def plot_regression(self):
|
|
"""
|
|
Функция для построения графика регрессии и отображения метрик SSE и R^2
|
|
"""
|
|
Y_pred = self.b0 + self.b1 * self.X
|
|
|
|
plt.figure(figsize=(10, 6))
|
|
|
|
# График рассеяния и регрессионная прямая
|
|
plt.subplot(2, 1, 1)
|
|
plt.scatter(self.X, self.Y, label='Данные')
|
|
plt.plot(self.X, Y_pred, color='red', label='Регрессия')
|
|
plt.title("Линейная регрессия")
|
|
plt.xlabel("X")
|
|
plt.ylabel("Y")
|
|
plt.legend()
|
|
|
|
# Расчет метрик и отображение на графике
|
|
sse = self.calculate_sse()
|
|
r_squared = self.calculate_r_squared()
|
|
plt.text(0.05, 0.95, f"SSE: {sse:.2f}\nR^2: {r_squared:.2f}",
|
|
transform=plt.gca().transAxes, verticalalignment='top')
|
|
|
|
plt.show()
|