chezahuynya1/app/zadachi/zadacha3.py

55 lines
1.8 KiB
Python
Raw Normal View History

2024-10-15 00:02:41 +03:00
import numpy as np
import matplotlib.pyplot as plt
class zadacha3:
def __init__(self, X, Y, b0, b1):
self.X = X
self.Y = Y
self.b0 = b0
self.b1 = b1
def calculate_sse(self):
"""
Функция для расчета суммы квадратов ошибок (SSE)
:return: SSE
"""
Y_pred = self.b0 + self.b1 * self.X
sse = np.sum((self.Y - Y_pred) ** 2)
return sse
def calculate_r_squared(self):
"""
Функция для расчета коэффициента детерминации R^2
:return: R^2
"""
y_mean = np.mean(self.Y)
ss_total = np.sum((self.Y - y_mean) ** 2)
ss_residual = self.calculate_sse()
r_squared = 1 - (ss_residual / ss_total)
return r_squared
def plot_regression(self):
"""
Функция для построения графика регрессии и отображения метрик SSE и R^2
"""
Y_pred = self.b0 + self.b1 * self.X
plt.figure(figsize=(10, 6))
# График рассеяния и регрессионная прямая
plt.subplot(2, 1, 1)
plt.scatter(self.X, self.Y, label='Данные')
plt.plot(self.X, Y_pred, color='red', label='Регрессия')
plt.title("Линейная регрессия")
plt.xlabel("X")
plt.ylabel("Y")
plt.legend()
# Расчет метрик и отображение на графике
sse = self.calculate_sse()
r_squared = self.calculate_r_squared()
plt.text(0.05, 0.95, f"SSE: {sse:.2f}\nR^2: {r_squared:.2f}",
transform=plt.gca().transAxes, verticalalignment='top')
plt.show()