Refactor: Migrate Settings from electron-store to Centralized Database #113
This commit is contained in:
parent
9e0c74eed4
commit
9f29fa5873
29
package-lock.json
generated
29
package-lock.json
generated
@ -14,7 +14,7 @@
|
||||
"@google/genai": "^1.8.0",
|
||||
"@google/generative-ai": "^0.24.1",
|
||||
"axios": "^1.10.0",
|
||||
"better-sqlite3": "^9.4.3",
|
||||
"better-sqlite3": "^9.6.0",
|
||||
"cors": "^2.8.5",
|
||||
"dotenv": "^17.0.0",
|
||||
"electron-squirrel-startup": "^1.0.1",
|
||||
@ -27,6 +27,7 @@
|
||||
"keytar": "^7.9.0",
|
||||
"node-fetch": "^2.7.0",
|
||||
"openai": "^4.70.0",
|
||||
"portkey-ai": "^1.10.1",
|
||||
"react-hot-toast": "^2.5.2",
|
||||
"sharp": "^0.34.2",
|
||||
"validator": "^13.11.0",
|
||||
@ -2283,6 +2284,8 @@
|
||||
},
|
||||
"node_modules/better-sqlite3": {
|
||||
"version": "9.6.0",
|
||||
"resolved": "https://registry.npmjs.org/better-sqlite3/-/better-sqlite3-9.6.0.tgz",
|
||||
"integrity": "sha512-yR5HATnqeYNVnkaUTf4bOP2dJSnyhP4puJN/QPRyx4YkBEEUxib422n2XzPqDEHjQQqazoYoADdAm5vE15+dAQ==",
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@ -5922,6 +5925,30 @@
|
||||
"node": ">=10.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/portkey-ai": {
|
||||
"version": "1.10.1",
|
||||
"resolved": "https://registry.npmjs.org/portkey-ai/-/portkey-ai-1.10.1.tgz",
|
||||
"integrity": "sha512-mRGDxm4xBMexYlk/bS8i+G5C/Ww+KaXcKlHtzzsmh0X4Awd1bPBGq5dlUmCrHGgN/umLpphxcOcLHsDa9NbjrQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"agentkeepalive": "^4.6.0",
|
||||
"dotenv": "^16.5.0",
|
||||
"openai": "4.104.0",
|
||||
"ws": "^8.18.2"
|
||||
}
|
||||
},
|
||||
"node_modules/portkey-ai/node_modules/dotenv": {
|
||||
"version": "16.6.1",
|
||||
"resolved": "https://registry.npmjs.org/dotenv/-/dotenv-16.6.1.tgz",
|
||||
"integrity": "sha512-uBq4egWHTcTt33a72vpSG0z3HnPuIl6NqYcTrKEg2azoEyl2hpW0zqlxysq2pK9HlDIHyHyakeYaYnSAwd8bow==",
|
||||
"license": "BSD-2-Clause",
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://dotenvx.com"
|
||||
}
|
||||
},
|
||||
"node_modules/postject": {
|
||||
"version": "1.0.0-alpha.6",
|
||||
"dev": true,
|
||||
|
@ -36,7 +36,7 @@
|
||||
"@google/genai": "^1.8.0",
|
||||
"@google/generative-ai": "^0.24.1",
|
||||
"axios": "^1.10.0",
|
||||
"better-sqlite3": "^9.4.3",
|
||||
"better-sqlite3": "^9.6.0",
|
||||
"cors": "^2.8.5",
|
||||
"dotenv": "^17.0.0",
|
||||
"electron-squirrel-startup": "^1.0.1",
|
||||
@ -49,6 +49,7 @@
|
||||
"keytar": "^7.9.0",
|
||||
"node-fetch": "^2.7.0",
|
||||
"openai": "^4.70.0",
|
||||
"portkey-ai": "^1.10.1",
|
||||
"react-hot-toast": "^2.5.2",
|
||||
"sharp": "^0.34.2",
|
||||
"validator": "^13.11.0",
|
||||
|
@ -46,7 +46,8 @@ router.post('/find-or-create', async (req, res) => {
|
||||
|
||||
router.post('/api-key', async (req, res) => {
|
||||
try {
|
||||
await ipcRequest(req, 'save-api-key', req.body.apiKey);
|
||||
const { apiKey, provider = 'openai' } = req.body;
|
||||
await ipcRequest(req, 'save-api-key', { apiKey, provider });
|
||||
res.json({ message: 'API key saved successfully' });
|
||||
} catch (error) {
|
||||
console.error('Failed to save API key via IPC:', error);
|
||||
|
@ -66,15 +66,14 @@ const PROVIDERS = {
|
||||
'whisper': {
|
||||
name: 'Whisper (Local)',
|
||||
handler: () => {
|
||||
// Only load in main process
|
||||
// This needs to remain a function due to its conditional logic for renderer/main process
|
||||
if (typeof window === 'undefined') {
|
||||
return require("./providers/whisper");
|
||||
}
|
||||
// Return dummy for renderer
|
||||
// Return a dummy object for the renderer process
|
||||
return {
|
||||
validateApiKey: async () => ({ success: true }), // Mock validate for renderer
|
||||
createSTT: () => { throw new Error('Whisper STT is only available in main process'); },
|
||||
createLLM: () => { throw new Error('Whisper does not support LLM'); },
|
||||
createStreamingLLM: () => { throw new Error('Whisper does not support LLM'); }
|
||||
};
|
||||
},
|
||||
llmModels: [],
|
||||
@ -130,6 +129,32 @@ function createStreamingLLM(provider, opts) {
|
||||
return handler.createStreamingLLM(opts);
|
||||
}
|
||||
|
||||
function getProviderClass(providerId) {
|
||||
const providerConfig = PROVIDERS[providerId];
|
||||
if (!providerConfig) return null;
|
||||
|
||||
// Handle special cases for glass providers
|
||||
let actualProviderId = providerId;
|
||||
if (providerId === 'openai-glass') {
|
||||
actualProviderId = 'openai';
|
||||
}
|
||||
|
||||
// The handler function returns the module, from which we get the class.
|
||||
const module = providerConfig.handler();
|
||||
|
||||
// Map provider IDs to their actual exported class names
|
||||
const classNameMap = {
|
||||
'openai': 'OpenAIProvider',
|
||||
'anthropic': 'AnthropicProvider',
|
||||
'gemini': 'GeminiProvider',
|
||||
'ollama': 'OllamaProvider',
|
||||
'whisper': 'WhisperProvider'
|
||||
};
|
||||
|
||||
const className = classNameMap[actualProviderId];
|
||||
return className ? module[className] : null;
|
||||
}
|
||||
|
||||
function getAvailableProviders() {
|
||||
const stt = [];
|
||||
const llm = [];
|
||||
@ -145,5 +170,6 @@ module.exports = {
|
||||
createSTT,
|
||||
createLLM,
|
||||
createStreamingLLM,
|
||||
getProviderClass,
|
||||
getAvailableProviders,
|
||||
};
|
@ -1,4 +1,38 @@
|
||||
const Anthropic = require("@anthropic-ai/sdk")
|
||||
const { Anthropic } = require("@anthropic-ai/sdk")
|
||||
|
||||
class AnthropicProvider {
|
||||
static async validateApiKey(key) {
|
||||
if (!key || typeof key !== 'string' || !key.startsWith('sk-ant-')) {
|
||||
return { success: false, error: 'Invalid Anthropic API key format.' };
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch("https://api.anthropic.com/v1/messages", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": key,
|
||||
"anthropic-version": "2023-06-01",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: "claude-3-haiku-20240307",
|
||||
max_tokens: 1,
|
||||
messages: [{ role: "user", content: "Hi" }],
|
||||
}),
|
||||
});
|
||||
|
||||
if (response.ok || response.status === 400) { // 400 is a valid response for a bad request, not a bad key
|
||||
return { success: true };
|
||||
} else {
|
||||
const errorData = await response.json().catch(() => ({}));
|
||||
return { success: false, error: errorData.error?.message || `Validation failed with status: ${response.status}` };
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`[AnthropicProvider] Network error during key validation:`, error);
|
||||
return { success: false, error: 'A network error occurred during validation.' };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an Anthropic STT session
|
||||
@ -286,7 +320,8 @@ function createStreamingLLM({
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
AnthropicProvider,
|
||||
createSTT,
|
||||
createLLM,
|
||||
createStreamingLLM,
|
||||
}
|
||||
createStreamingLLM
|
||||
};
|
||||
|
@ -1,6 +1,31 @@
|
||||
const { GoogleGenerativeAI } = require("@google/generative-ai")
|
||||
const { GoogleGenAI } = require("@google/genai")
|
||||
|
||||
class GeminiProvider {
|
||||
static async validateApiKey(key) {
|
||||
if (!key || typeof key !== 'string') {
|
||||
return { success: false, error: 'Invalid Gemini API key format.' };
|
||||
}
|
||||
|
||||
try {
|
||||
const validationUrl = `https://generativelanguage.googleapis.com/v1beta/models?key=${key}`;
|
||||
const response = await fetch(validationUrl);
|
||||
|
||||
if (response.ok) {
|
||||
return { success: true };
|
||||
} else {
|
||||
const errorData = await response.json().catch(() => ({}));
|
||||
const message = errorData.error?.message || `Validation failed with status: ${response.status}`;
|
||||
return { success: false, error: message };
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`[GeminiProvider] Network error during key validation:`, error);
|
||||
return { success: false, error: 'A network error occurred during validation.' };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Creates a Gemini STT session
|
||||
* @param {object} opts - Configuration options
|
||||
@ -296,7 +321,8 @@ function createStreamingLLM({ apiKey, model = "gemini-2.5-flash", temperature =
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
GeminiProvider,
|
||||
createSTT,
|
||||
createLLM,
|
||||
createStreamingLLM,
|
||||
}
|
||||
createStreamingLLM
|
||||
};
|
||||
|
@ -1,6 +1,22 @@
|
||||
const http = require('http');
|
||||
const fetch = require('node-fetch');
|
||||
|
||||
class OllamaProvider {
|
||||
static async validateApiKey() {
|
||||
try {
|
||||
const response = await fetch('http://localhost:11434/api/tags');
|
||||
if (response.ok) {
|
||||
return { success: true };
|
||||
} else {
|
||||
return { success: false, error: 'Ollama service is not running. Please start Ollama first.' };
|
||||
}
|
||||
} catch (error) {
|
||||
return { success: false, error: 'Cannot connect to Ollama. Please ensure Ollama is installed and running.' };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function convertMessagesToOllamaFormat(messages) {
|
||||
return messages.map(msg => {
|
||||
if (Array.isArray(msg.content)) {
|
||||
@ -237,6 +253,8 @@ function createStreamingLLM({
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
OllamaProvider,
|
||||
createLLM,
|
||||
createStreamingLLM
|
||||
createStreamingLLM,
|
||||
convertMessagesToOllamaFormat
|
||||
};
|
@ -1,5 +1,35 @@
|
||||
const OpenAI = require('openai');
|
||||
const WebSocket = require('ws');
|
||||
const { Portkey } = require('portkey-ai');
|
||||
const { Readable } = require('stream');
|
||||
const { getProviderForModel } = require('../factory.js');
|
||||
|
||||
|
||||
class OpenAIProvider {
|
||||
static async validateApiKey(key) {
|
||||
if (!key || typeof key !== 'string' || !key.startsWith('sk-')) {
|
||||
return { success: false, error: 'Invalid OpenAI API key format.' };
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch('https://api.openai.com/v1/models', {
|
||||
headers: { 'Authorization': `Bearer ${key}` }
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
return { success: true };
|
||||
} else {
|
||||
const errorData = await response.json().catch(() => ({}));
|
||||
const message = errorData.error?.message || `Validation failed with status: ${response.status}`;
|
||||
return { success: false, error: message };
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`[OpenAIProvider] Network error during key validation:`, error);
|
||||
return { success: false, error: 'A network error occurred during validation.' };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Creates an OpenAI STT session
|
||||
@ -257,6 +287,7 @@ function createStreamingLLM({ apiKey, model = 'gpt-4.1', temperature = 0.7, maxT
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
OpenAIProvider,
|
||||
createSTT,
|
||||
createLLM,
|
||||
createStreamingLLM
|
||||
|
@ -173,6 +173,11 @@ class WhisperSTTSession extends EventEmitter {
|
||||
}
|
||||
|
||||
class WhisperProvider {
|
||||
static async validateApiKey() {
|
||||
// Whisper is a local service, no API key validation needed.
|
||||
return { success: true };
|
||||
}
|
||||
|
||||
constructor() {
|
||||
this.whisperService = null;
|
||||
}
|
||||
@ -224,8 +229,12 @@ class WhisperProvider {
|
||||
}
|
||||
|
||||
async createStreamingLLM() {
|
||||
throw new Error('Whisper provider does not support streaming LLM functionality');
|
||||
console.warn('[WhisperProvider] Streaming LLM is not supported by Whisper.');
|
||||
throw new Error('Whisper does not support LLM.');
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = new WhisperProvider();
|
||||
module.exports = {
|
||||
WhisperProvider,
|
||||
WhisperSTTSession
|
||||
};
|
@ -5,8 +5,6 @@ const LATEST_SCHEMA = {
|
||||
{ name: 'display_name', type: 'TEXT NOT NULL' },
|
||||
{ name: 'email', type: 'TEXT NOT NULL' },
|
||||
{ name: 'created_at', type: 'INTEGER' },
|
||||
{ name: 'api_key', type: 'TEXT' },
|
||||
{ name: 'provider', type: 'TEXT DEFAULT \'openai\'' },
|
||||
{ name: 'auto_update_enabled', type: 'INTEGER DEFAULT 1' },
|
||||
{ name: 'has_migrated_to_firebase', type: 'INTEGER DEFAULT 0' }
|
||||
]
|
||||
@ -90,6 +88,28 @@ const LATEST_SCHEMA = {
|
||||
{ name: 'installed', type: 'INTEGER DEFAULT 0' },
|
||||
{ name: 'installing', type: 'INTEGER DEFAULT 0' }
|
||||
]
|
||||
},
|
||||
provider_settings: {
|
||||
columns: [
|
||||
{ name: 'uid', type: 'TEXT NOT NULL' },
|
||||
{ name: 'provider', type: 'TEXT NOT NULL' },
|
||||
{ name: 'api_key', type: 'TEXT' },
|
||||
{ name: 'selected_llm_model', type: 'TEXT' },
|
||||
{ name: 'selected_stt_model', type: 'TEXT' },
|
||||
{ name: 'created_at', type: 'INTEGER' },
|
||||
{ name: 'updated_at', type: 'INTEGER' }
|
||||
],
|
||||
constraints: ['PRIMARY KEY (uid, provider)']
|
||||
},
|
||||
user_model_selections: {
|
||||
columns: [
|
||||
{ name: 'uid', type: 'TEXT PRIMARY KEY' },
|
||||
{ name: 'selected_llm_provider', type: 'TEXT' },
|
||||
{ name: 'selected_llm_model', type: 'TEXT' },
|
||||
{ name: 'selected_stt_provider', type: 'TEXT' },
|
||||
{ name: 'selected_stt_model', type: 'TEXT' },
|
||||
{ name: 'updated_at', type: 'INTEGER' }
|
||||
]
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -0,0 +1,83 @@
|
||||
const { collection, doc, getDoc, getDocs, setDoc, deleteDoc, query, where } = require('firebase/firestore');
|
||||
const { getFirestoreInstance: getFirestore } = require('../../services/firebaseClient');
|
||||
const { createEncryptedConverter } = require('../firestoreConverter');
|
||||
|
||||
// Create encrypted converter for provider settings
|
||||
const providerSettingsConverter = createEncryptedConverter([
|
||||
'api_key', // Encrypt API keys
|
||||
'selected_llm_model', // Encrypt model selections for privacy
|
||||
'selected_stt_model'
|
||||
]);
|
||||
|
||||
function providerSettingsCol() {
|
||||
const db = getFirestore();
|
||||
return collection(db, 'provider_settings').withConverter(providerSettingsConverter);
|
||||
}
|
||||
|
||||
async function getByProvider(uid, provider) {
|
||||
try {
|
||||
const docRef = doc(providerSettingsCol(), `${uid}_${provider}`);
|
||||
const docSnap = await getDoc(docRef);
|
||||
return docSnap.exists() ? { id: docSnap.id, ...docSnap.data() } : null;
|
||||
} catch (error) {
|
||||
console.error('[ProviderSettings Firebase] Error getting provider settings:', error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
async function getAllByUid(uid) {
|
||||
try {
|
||||
const q = query(providerSettingsCol(), where('uid', '==', uid));
|
||||
const querySnapshot = await getDocs(q);
|
||||
return querySnapshot.docs.map(doc => ({ id: doc.id, ...doc.data() }));
|
||||
} catch (error) {
|
||||
console.error('[ProviderSettings Firebase] Error getting all provider settings:', error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
async function upsert(uid, provider, settings) {
|
||||
try {
|
||||
const docRef = doc(providerSettingsCol(), `${uid}_${provider}`);
|
||||
await setDoc(docRef, settings, { merge: true });
|
||||
return { changes: 1 };
|
||||
} catch (error) {
|
||||
console.error('[ProviderSettings Firebase] Error upserting provider settings:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async function remove(uid, provider) {
|
||||
try {
|
||||
const docRef = doc(providerSettingsCol(), `${uid}_${provider}`);
|
||||
await deleteDoc(docRef);
|
||||
return { changes: 1 };
|
||||
} catch (error) {
|
||||
console.error('[ProviderSettings Firebase] Error removing provider settings:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async function removeAllByUid(uid) {
|
||||
try {
|
||||
const settings = await getAllByUid(uid);
|
||||
const deletePromises = settings.map(setting => {
|
||||
const docRef = doc(providerSettingsCol(), setting.id);
|
||||
return deleteDoc(docRef);
|
||||
});
|
||||
|
||||
await Promise.all(deletePromises);
|
||||
return { changes: settings.length };
|
||||
} catch (error) {
|
||||
console.error('[ProviderSettings Firebase] Error removing all provider settings:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getByProvider,
|
||||
getAllByUid,
|
||||
upsert,
|
||||
remove,
|
||||
removeAllByUid
|
||||
};
|
65
src/common/repositories/providerSettings/index.js
Normal file
65
src/common/repositories/providerSettings/index.js
Normal file
@ -0,0 +1,65 @@
|
||||
const firebaseRepository = require('./firebase.repository');
|
||||
const sqliteRepository = require('./sqlite.repository');
|
||||
|
||||
let authService = null;
|
||||
|
||||
function setAuthService(service) {
|
||||
authService = service;
|
||||
}
|
||||
|
||||
function getBaseRepository() {
|
||||
if (!authService) {
|
||||
throw new Error('AuthService not set for providerSettings repository');
|
||||
}
|
||||
|
||||
const user = authService.getCurrentUser();
|
||||
return user.isLoggedIn ? firebaseRepository : sqliteRepository;
|
||||
}
|
||||
|
||||
const providerSettingsRepositoryAdapter = {
|
||||
// Core CRUD operations
|
||||
async getByProvider(provider) {
|
||||
const repo = getBaseRepository();
|
||||
const uid = authService.getCurrentUserId();
|
||||
return await repo.getByProvider(uid, provider);
|
||||
},
|
||||
|
||||
async getAllByUid() {
|
||||
const repo = getBaseRepository();
|
||||
const uid = authService.getCurrentUserId();
|
||||
return await repo.getAllByUid(uid);
|
||||
},
|
||||
|
||||
async upsert(provider, settings) {
|
||||
const repo = getBaseRepository();
|
||||
const uid = authService.getCurrentUserId();
|
||||
const now = Date.now();
|
||||
|
||||
const settingsWithMeta = {
|
||||
...settings,
|
||||
uid,
|
||||
provider,
|
||||
updated_at: now,
|
||||
created_at: settings.created_at || now
|
||||
};
|
||||
|
||||
return await repo.upsert(uid, provider, settingsWithMeta);
|
||||
},
|
||||
|
||||
async remove(provider) {
|
||||
const repo = getBaseRepository();
|
||||
const uid = authService.getCurrentUserId();
|
||||
return await repo.remove(uid, provider);
|
||||
},
|
||||
|
||||
async removeAllByUid() {
|
||||
const repo = getBaseRepository();
|
||||
const uid = authService.getCurrentUserId();
|
||||
return await repo.removeAllByUid(uid);
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
...providerSettingsRepositoryAdapter,
|
||||
setAuthService
|
||||
};
|
@ -0,0 +1,62 @@
|
||||
const sqliteClient = require('../../services/sqliteClient');
|
||||
|
||||
function getByProvider(uid, provider) {
|
||||
const db = sqliteClient.getDb();
|
||||
const stmt = db.prepare('SELECT * FROM provider_settings WHERE uid = ? AND provider = ?');
|
||||
return stmt.get(uid, provider) || null;
|
||||
}
|
||||
|
||||
function getAllByUid(uid) {
|
||||
const db = sqliteClient.getDb();
|
||||
const stmt = db.prepare('SELECT * FROM provider_settings WHERE uid = ? ORDER BY provider');
|
||||
return stmt.all(uid);
|
||||
}
|
||||
|
||||
function upsert(uid, provider, settings) {
|
||||
const db = sqliteClient.getDb();
|
||||
|
||||
// Use SQLite's UPSERT syntax (INSERT ... ON CONFLICT ... DO UPDATE)
|
||||
const stmt = db.prepare(`
|
||||
INSERT INTO provider_settings (uid, provider, api_key, selected_llm_model, selected_stt_model, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(uid, provider) DO UPDATE SET
|
||||
api_key = excluded.api_key,
|
||||
selected_llm_model = excluded.selected_llm_model,
|
||||
selected_stt_model = excluded.selected_stt_model,
|
||||
updated_at = excluded.updated_at
|
||||
`);
|
||||
|
||||
const result = stmt.run(
|
||||
uid,
|
||||
provider,
|
||||
settings.api_key || null,
|
||||
settings.selected_llm_model || null,
|
||||
settings.selected_stt_model || null,
|
||||
settings.created_at || Date.now(),
|
||||
settings.updated_at
|
||||
);
|
||||
|
||||
return { changes: result.changes };
|
||||
}
|
||||
|
||||
function remove(uid, provider) {
|
||||
const db = sqliteClient.getDb();
|
||||
const stmt = db.prepare('DELETE FROM provider_settings WHERE uid = ? AND provider = ?');
|
||||
const result = stmt.run(uid, provider);
|
||||
return { changes: result.changes };
|
||||
}
|
||||
|
||||
function removeAllByUid(uid) {
|
||||
const db = sqliteClient.getDb();
|
||||
const stmt = db.prepare('DELETE FROM provider_settings WHERE uid = ?');
|
||||
const result = stmt.run(uid);
|
||||
return { changes: result.changes };
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getByProvider,
|
||||
getAllByUid,
|
||||
upsert,
|
||||
remove,
|
||||
removeAllByUid
|
||||
};
|
@ -3,7 +3,7 @@ const { getFirestoreInstance } = require('../../services/firebaseClient');
|
||||
const { createEncryptedConverter } = require('../firestoreConverter');
|
||||
const encryptionService = require('../../services/encryptionService');
|
||||
|
||||
const userConverter = createEncryptedConverter(['api_key']);
|
||||
const userConverter = createEncryptedConverter([]);
|
||||
|
||||
function usersCol() {
|
||||
const db = getFirestoreInstance();
|
||||
@ -38,11 +38,7 @@ async function getById(uid) {
|
||||
return docSnap.exists() ? docSnap.data() : null;
|
||||
}
|
||||
|
||||
async function saveApiKey(uid, apiKey, provider = 'openai') {
|
||||
const docRef = doc(usersCol(), uid);
|
||||
await setDoc(docRef, { api_key: apiKey, provider }, { merge: true });
|
||||
return { changes: 1 };
|
||||
}
|
||||
|
||||
|
||||
async function update({ uid, displayName }) {
|
||||
const docRef = doc(usersCol(), uid);
|
||||
@ -85,7 +81,6 @@ async function deleteById(uid) {
|
||||
module.exports = {
|
||||
findOrCreate,
|
||||
getById,
|
||||
saveApiKey,
|
||||
update,
|
||||
deleteById,
|
||||
};
|
@ -1,8 +1,16 @@
|
||||
const sqliteRepository = require('./sqlite.repository');
|
||||
const firebaseRepository = require('./firebase.repository');
|
||||
const authService = require('../../services/authService');
|
||||
|
||||
let authService = null;
|
||||
|
||||
function setAuthService(service) {
|
||||
authService = service;
|
||||
}
|
||||
|
||||
function getBaseRepository() {
|
||||
if (!authService) {
|
||||
throw new Error('AuthService has not been set for the user repository.');
|
||||
}
|
||||
const user = authService.getCurrentUser();
|
||||
if (user && user.isLoggedIn) {
|
||||
return firebaseRepository;
|
||||
@ -21,10 +29,7 @@ const userRepositoryAdapter = {
|
||||
return getBaseRepository().getById(uid);
|
||||
},
|
||||
|
||||
saveApiKey: (apiKey, provider) => {
|
||||
const uid = authService.getCurrentUserId();
|
||||
return getBaseRepository().saveApiKey(uid, apiKey, provider);
|
||||
},
|
||||
|
||||
|
||||
update: (updateData) => {
|
||||
const uid = authService.getCurrentUserId();
|
||||
@ -37,4 +42,7 @@ const userRepositoryAdapter = {
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = userRepositoryAdapter;
|
||||
module.exports = {
|
||||
...userRepositoryAdapter,
|
||||
setAuthService
|
||||
};
|
@ -40,17 +40,7 @@ function getById(uid) {
|
||||
return db.prepare('SELECT * FROM users WHERE uid = ?').get(uid);
|
||||
}
|
||||
|
||||
function saveApiKey(uid, apiKey, provider = 'openai') {
|
||||
const db = sqliteClient.getDb();
|
||||
try {
|
||||
const result = db.prepare('UPDATE users SET api_key = ?, provider = ? WHERE uid = ?').run(apiKey, provider, uid);
|
||||
console.log(`SQLite: API key saved for user ${uid} with provider ${provider}.`);
|
||||
return { changes: result.changes };
|
||||
} catch (err) {
|
||||
console.error('SQLite: Failed to save API key:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function update({ uid, displayName }) {
|
||||
const db = sqliteClient.getDb();
|
||||
@ -96,7 +86,6 @@ function deleteById(uid) {
|
||||
module.exports = {
|
||||
findOrCreate,
|
||||
getById,
|
||||
saveApiKey,
|
||||
update,
|
||||
setMigrationComplete,
|
||||
deleteById
|
||||
|
@ -0,0 +1,55 @@
|
||||
const { collection, doc, getDoc, setDoc, deleteDoc } = require('firebase/firestore');
|
||||
const { getFirestoreInstance: getFirestore } = require('../../services/firebaseClient');
|
||||
const { createEncryptedConverter } = require('../firestoreConverter');
|
||||
|
||||
// Create encrypted converter for user model selections
|
||||
const userModelSelectionsConverter = createEncryptedConverter([
|
||||
'selected_llm_provider',
|
||||
'selected_llm_model',
|
||||
'selected_stt_provider',
|
||||
'selected_stt_model'
|
||||
]);
|
||||
|
||||
function userModelSelectionsCol() {
|
||||
const db = getFirestore();
|
||||
return collection(db, 'user_model_selections').withConverter(userModelSelectionsConverter);
|
||||
}
|
||||
|
||||
async function get(uid) {
|
||||
try {
|
||||
const docRef = doc(userModelSelectionsCol(), uid);
|
||||
const docSnap = await getDoc(docRef);
|
||||
return docSnap.exists() ? { id: docSnap.id, ...docSnap.data() } : null;
|
||||
} catch (error) {
|
||||
console.error('[UserModelSelections Firebase] Error getting user model selections:', error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
async function upsert(uid, selections) {
|
||||
try {
|
||||
const docRef = doc(userModelSelectionsCol(), uid);
|
||||
await setDoc(docRef, selections, { merge: true });
|
||||
return { changes: 1 };
|
||||
} catch (error) {
|
||||
console.error('[UserModelSelections Firebase] Error upserting user model selections:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async function remove(uid) {
|
||||
try {
|
||||
const docRef = doc(userModelSelectionsCol(), uid);
|
||||
await deleteDoc(docRef);
|
||||
return { changes: 1 };
|
||||
} catch (error) {
|
||||
console.error('[UserModelSelections Firebase] Error removing user model selections:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
get,
|
||||
upsert,
|
||||
remove
|
||||
};
|
50
src/common/repositories/userModelSelections/index.js
Normal file
50
src/common/repositories/userModelSelections/index.js
Normal file
@ -0,0 +1,50 @@
|
||||
const firebaseRepository = require('./firebase.repository');
|
||||
const sqliteRepository = require('./sqlite.repository');
|
||||
|
||||
let authService = null;
|
||||
|
||||
function setAuthService(service) {
|
||||
authService = service;
|
||||
}
|
||||
|
||||
function getBaseRepository() {
|
||||
if (!authService) {
|
||||
throw new Error('AuthService not set for userModelSelections repository');
|
||||
}
|
||||
|
||||
const user = authService.getCurrentUser();
|
||||
return user.isLoggedIn ? firebaseRepository : sqliteRepository;
|
||||
}
|
||||
|
||||
const userModelSelectionsRepositoryAdapter = {
|
||||
async get() {
|
||||
const repo = getBaseRepository();
|
||||
const uid = authService.getCurrentUserId();
|
||||
return await repo.get(uid);
|
||||
},
|
||||
|
||||
async upsert(selections) {
|
||||
const repo = getBaseRepository();
|
||||
const uid = authService.getCurrentUserId();
|
||||
const now = Date.now();
|
||||
|
||||
const selectionsWithMeta = {
|
||||
...selections,
|
||||
uid,
|
||||
updated_at: now
|
||||
};
|
||||
|
||||
return await repo.upsert(uid, selectionsWithMeta);
|
||||
},
|
||||
|
||||
async remove() {
|
||||
const repo = getBaseRepository();
|
||||
const uid = authService.getCurrentUserId();
|
||||
return await repo.remove(uid);
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
...userModelSelectionsRepositoryAdapter,
|
||||
setAuthService
|
||||
};
|
@ -0,0 +1,48 @@
|
||||
const sqliteClient = require('../../services/sqliteClient');
|
||||
|
||||
function get(uid) {
|
||||
const db = sqliteClient.getDb();
|
||||
const stmt = db.prepare('SELECT * FROM user_model_selections WHERE uid = ?');
|
||||
return stmt.get(uid) || null;
|
||||
}
|
||||
|
||||
function upsert(uid, selections) {
|
||||
const db = sqliteClient.getDb();
|
||||
|
||||
// Use SQLite's UPSERT syntax (INSERT ... ON CONFLICT ... DO UPDATE)
|
||||
const stmt = db.prepare(`
|
||||
INSERT INTO user_model_selections (uid, selected_llm_provider, selected_llm_model,
|
||||
selected_stt_provider, selected_stt_model, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(uid) DO UPDATE SET
|
||||
selected_llm_provider = excluded.selected_llm_provider,
|
||||
selected_llm_model = excluded.selected_llm_model,
|
||||
selected_stt_provider = excluded.selected_stt_provider,
|
||||
selected_stt_model = excluded.selected_stt_model,
|
||||
updated_at = excluded.updated_at
|
||||
`);
|
||||
|
||||
const result = stmt.run(
|
||||
uid,
|
||||
selections.selected_llm_provider || null,
|
||||
selections.selected_llm_model || null,
|
||||
selections.selected_stt_provider || null,
|
||||
selections.selected_stt_model || null,
|
||||
selections.updated_at
|
||||
);
|
||||
|
||||
return { changes: result.changes };
|
||||
}
|
||||
|
||||
function remove(uid) {
|
||||
const db = sqliteClient.getDb();
|
||||
const stmt = db.prepare('DELETE FROM user_model_selections WHERE uid = ?');
|
||||
const result = stmt.run(uid);
|
||||
return { changes: result.changes };
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
get,
|
||||
upsert,
|
||||
remove
|
||||
};
|
@ -5,6 +5,8 @@ const fetch = require('node-fetch');
|
||||
const encryptionService = require('./encryptionService');
|
||||
const migrationService = require('./migrationService');
|
||||
const sessionRepository = require('../repositories/session');
|
||||
const providerSettingsRepository = require('../repositories/providerSettings');
|
||||
const userModelSelectionsRepository = require('../repositories/userModelSelections');
|
||||
|
||||
async function getVirtualKeyByEmail(email, idToken) {
|
||||
if (!idToken) {
|
||||
@ -40,10 +42,13 @@ class AuthService {
|
||||
this.currentUser = null;
|
||||
this.isInitialized = false;
|
||||
|
||||
// Initialize immediately for the default local user on startup.
|
||||
// This ensures the key is ready before any login/logout state change.
|
||||
encryptionService.initializeKey(this.currentUserId);
|
||||
this.initializationPromise = null;
|
||||
|
||||
sessionRepository.setAuthService(this);
|
||||
providerSettingsRepository.setAuthService(this);
|
||||
userModelSelectionsRepository.setAuthService(this);
|
||||
}
|
||||
|
||||
initialize() {
|
||||
|
@ -1,21 +1,28 @@
|
||||
const Store = require('electron-store');
|
||||
const fetch = require('node-fetch');
|
||||
const { ipcMain, webContents } = require('electron');
|
||||
const { PROVIDERS } = require('../ai/factory');
|
||||
const { PROVIDERS, getProviderClass } = require('../ai/factory');
|
||||
const encryptionService = require('./encryptionService');
|
||||
const providerSettingsRepository = require('../repositories/providerSettings');
|
||||
const userModelSelectionsRepository = require('../repositories/userModelSelections');
|
||||
|
||||
class ModelStateService {
|
||||
constructor(authService) {
|
||||
this.authService = authService;
|
||||
this.store = new Store({ name: 'pickle-glass-model-state' });
|
||||
this.state = {};
|
||||
this.hasMigrated = false;
|
||||
|
||||
// Set auth service for repositories
|
||||
providerSettingsRepository.setAuthService(authService);
|
||||
userModelSelectionsRepository.setAuthService(authService);
|
||||
}
|
||||
|
||||
async initialize() {
|
||||
console.log('[ModelStateService] Initializing...');
|
||||
await this._loadStateForCurrentUser();
|
||||
|
||||
this.setupIpcHandlers();
|
||||
console.log('[ModelStateService] Initialized.');
|
||||
console.log('[ModelStateService] Initialization complete');
|
||||
}
|
||||
|
||||
_logCurrentSelection() {
|
||||
@ -64,44 +71,186 @@ class ModelStateService {
|
||||
});
|
||||
}
|
||||
|
||||
async _migrateFromElectronStore() {
|
||||
console.log('[ModelStateService] Starting migration from electron-store to database...');
|
||||
const userId = this.authService.getCurrentUserId();
|
||||
|
||||
try {
|
||||
// Get data from electron-store
|
||||
const legacyData = this.store.get(`users.${userId}`, null);
|
||||
|
||||
if (!legacyData) {
|
||||
console.log('[ModelStateService] No legacy data to migrate');
|
||||
return;
|
||||
}
|
||||
|
||||
console.log('[ModelStateService] Found legacy data, migrating...');
|
||||
|
||||
// Migrate provider settings (API keys and selected models per provider)
|
||||
const { apiKeys = {}, selectedModels = {} } = legacyData;
|
||||
|
||||
for (const [provider, apiKey] of Object.entries(apiKeys)) {
|
||||
if (apiKey && PROVIDERS[provider]) {
|
||||
// For encrypted keys, they are already decrypted in _loadStateForCurrentUser
|
||||
await providerSettingsRepository.upsert(provider, {
|
||||
api_key: apiKey
|
||||
});
|
||||
console.log(`[ModelStateService] Migrated API key for ${provider}`);
|
||||
}
|
||||
}
|
||||
|
||||
// Migrate global model selections
|
||||
if (selectedModels.llm || selectedModels.stt) {
|
||||
const llmProvider = selectedModels.llm ? this.getProviderForModel('llm', selectedModels.llm) : null;
|
||||
const sttProvider = selectedModels.stt ? this.getProviderForModel('stt', selectedModels.stt) : null;
|
||||
|
||||
await userModelSelectionsRepository.upsert({
|
||||
selected_llm_provider: llmProvider,
|
||||
selected_llm_model: selectedModels.llm,
|
||||
selected_stt_provider: sttProvider,
|
||||
selected_stt_model: selectedModels.stt
|
||||
});
|
||||
console.log('[ModelStateService] Migrated global model selections');
|
||||
}
|
||||
|
||||
// Mark migration as complete by removing legacy data
|
||||
this.store.delete(`users.${userId}`);
|
||||
console.log('[ModelStateService] Migration completed and legacy data cleaned up');
|
||||
|
||||
} catch (error) {
|
||||
console.error('[ModelStateService] Migration failed:', error);
|
||||
// Don't throw - continue with normal operation
|
||||
}
|
||||
}
|
||||
|
||||
async _loadStateFromDatabase() {
|
||||
console.log('[ModelStateService] Loading state from database...');
|
||||
const userId = this.authService.getCurrentUserId();
|
||||
|
||||
try {
|
||||
// Load provider settings
|
||||
const providerSettings = await providerSettingsRepository.getAllByUid();
|
||||
const apiKeys = {};
|
||||
|
||||
// Reconstruct apiKeys object
|
||||
Object.keys(PROVIDERS).forEach(provider => {
|
||||
apiKeys[provider] = null;
|
||||
});
|
||||
|
||||
for (const setting of providerSettings) {
|
||||
if (setting.api_key) {
|
||||
// API keys are stored encrypted in database, decrypt them
|
||||
if (setting.provider !== 'ollama' && setting.provider !== 'whisper') {
|
||||
try {
|
||||
apiKeys[setting.provider] = encryptionService.decrypt(setting.api_key);
|
||||
} catch (error) {
|
||||
console.error(`[ModelStateService] Failed to decrypt API key for ${setting.provider}, resetting`);
|
||||
apiKeys[setting.provider] = null;
|
||||
}
|
||||
} else {
|
||||
apiKeys[setting.provider] = setting.api_key;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load global model selections
|
||||
const modelSelections = await userModelSelectionsRepository.get();
|
||||
const selectedModels = {
|
||||
llm: modelSelections?.selected_llm_model || null,
|
||||
stt: modelSelections?.selected_stt_model || null
|
||||
};
|
||||
|
||||
this.state = {
|
||||
apiKeys,
|
||||
selectedModels
|
||||
};
|
||||
|
||||
console.log(`[ModelStateService] State loaded from database for user: ${userId}`);
|
||||
|
||||
} catch (error) {
|
||||
console.error('[ModelStateService] Failed to load state from database:', error);
|
||||
// Fall back to default state
|
||||
const initialApiKeys = Object.keys(PROVIDERS).reduce((acc, key) => {
|
||||
acc[key] = null;
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
this.state = {
|
||||
apiKeys: initialApiKeys,
|
||||
selectedModels: { llm: null, stt: null },
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async _loadStateForCurrentUser() {
|
||||
const userId = this.authService.getCurrentUserId();
|
||||
|
||||
// Initialize encryption service for current user
|
||||
await encryptionService.initializeKey(userId);
|
||||
|
||||
const initialApiKeys = Object.keys(PROVIDERS).reduce((acc, key) => {
|
||||
acc[key] = null;
|
||||
return acc;
|
||||
}, {});
|
||||
// Try to load from database first
|
||||
await this._loadStateFromDatabase();
|
||||
|
||||
const defaultState = {
|
||||
apiKeys: initialApiKeys,
|
||||
selectedModels: { llm: null, stt: null },
|
||||
};
|
||||
this.state = this.store.get(`users.${userId}`, defaultState);
|
||||
console.log(`[ModelStateService] State loaded for user: ${userId}`);
|
||||
|
||||
for (const p of Object.keys(PROVIDERS)) {
|
||||
if (!(p in this.state.apiKeys)) {
|
||||
this.state.apiKeys[p] = null;
|
||||
} else if (this.state.apiKeys[p] && p !== 'ollama' && p !== 'whisper') {
|
||||
try {
|
||||
this.state.apiKeys[p] = encryptionService.decrypt(this.state.apiKeys[p]);
|
||||
} catch (error) {
|
||||
console.error(`[ModelStateService] Failed to decrypt API key for ${p}, resetting`);
|
||||
this.state.apiKeys[p] = null;
|
||||
}
|
||||
}
|
||||
// Check if we need to migrate from electron-store
|
||||
const legacyData = this.store.get(`users.${userId}`, null);
|
||||
if (legacyData && !this.hasMigrated) {
|
||||
await this._migrateFromElectronStore();
|
||||
// Reload state after migration
|
||||
await this._loadStateFromDatabase();
|
||||
this.hasMigrated = true;
|
||||
}
|
||||
|
||||
this._autoSelectAvailableModels();
|
||||
this._saveState();
|
||||
await this._saveState();
|
||||
this._logCurrentSelection();
|
||||
}
|
||||
|
||||
async _saveState() {
|
||||
console.log('[ModelStateService] Saving state to database...');
|
||||
const userId = this.authService.getCurrentUserId();
|
||||
|
||||
_saveState() {
|
||||
try {
|
||||
// Save provider settings (API keys)
|
||||
for (const [provider, apiKey] of Object.entries(this.state.apiKeys)) {
|
||||
if (apiKey) {
|
||||
const encryptedKey = (provider !== 'ollama' && provider !== 'whisper')
|
||||
? encryptionService.encrypt(apiKey)
|
||||
: apiKey;
|
||||
|
||||
await providerSettingsRepository.upsert(provider, {
|
||||
api_key: encryptedKey
|
||||
});
|
||||
} else {
|
||||
// Remove empty API keys
|
||||
await providerSettingsRepository.remove(provider);
|
||||
}
|
||||
}
|
||||
|
||||
// Save global model selections
|
||||
const llmProvider = this.state.selectedModels.llm ? this.getProviderForModel('llm', this.state.selectedModels.llm) : null;
|
||||
const sttProvider = this.state.selectedModels.stt ? this.getProviderForModel('stt', this.state.selectedModels.stt) : null;
|
||||
|
||||
if (llmProvider || sttProvider || this.state.selectedModels.llm || this.state.selectedModels.stt) {
|
||||
await userModelSelectionsRepository.upsert({
|
||||
selected_llm_provider: llmProvider,
|
||||
selected_llm_model: this.state.selectedModels.llm,
|
||||
selected_stt_provider: sttProvider,
|
||||
selected_stt_model: this.state.selectedModels.stt
|
||||
});
|
||||
}
|
||||
|
||||
console.log(`[ModelStateService] State saved to database for user: ${userId}`);
|
||||
this._logCurrentSelection();
|
||||
|
||||
} catch (error) {
|
||||
console.error('[ModelStateService] Failed to save state to database:', error);
|
||||
// Fall back to electron-store for now
|
||||
this._saveStateToElectronStore();
|
||||
}
|
||||
}
|
||||
|
||||
_saveStateToElectronStore() {
|
||||
console.log('[ModelStateService] Falling back to electron-store...');
|
||||
const userId = this.authService.getCurrentUserId();
|
||||
const stateToSave = {
|
||||
...this.state,
|
||||
@ -120,93 +269,34 @@ class ModelStateService {
|
||||
}
|
||||
|
||||
this.store.set(`users.${userId}`, stateToSave);
|
||||
console.log(`[ModelStateService] State saved for user: ${userId}`);
|
||||
console.log(`[ModelStateService] State saved to electron-store for user: ${userId}`);
|
||||
this._logCurrentSelection();
|
||||
}
|
||||
|
||||
async validateApiKey(provider, key) {
|
||||
if (!key || key.trim() === '') {
|
||||
if (!key || (key.trim() === '' && provider !== 'ollama' && provider !== 'whisper')) {
|
||||
return { success: false, error: 'API key cannot be empty.' };
|
||||
}
|
||||
|
||||
let validationUrl, headers;
|
||||
const body = undefined;
|
||||
const ProviderClass = getProviderClass(provider);
|
||||
|
||||
switch (provider) {
|
||||
case 'ollama':
|
||||
// Ollama doesn't need API key validation
|
||||
// Just check if the service is running
|
||||
try {
|
||||
const response = await fetch('http://localhost:11434/api/tags');
|
||||
if (response.ok) {
|
||||
console.log(`[ModelStateService] Ollama service is accessible.`);
|
||||
this.setApiKey(provider, 'local'); // Use 'local' as a placeholder
|
||||
if (!ProviderClass || typeof ProviderClass.validateApiKey !== 'function') {
|
||||
// Default to success if no specific validator is found
|
||||
console.warn(`[ModelStateService] No validateApiKey function for provider: ${provider}. Assuming valid.`);
|
||||
return { success: true };
|
||||
} else {
|
||||
return { success: false, error: 'Ollama service is not running. Please start Ollama first.' };
|
||||
}
|
||||
} catch (error) {
|
||||
return { success: false, error: 'Cannot connect to Ollama. Please ensure Ollama is installed and running.' };
|
||||
}
|
||||
case 'whisper':
|
||||
// Whisper is a local service, no API key validation needed
|
||||
console.log(`[ModelStateService] Whisper is a local service.`);
|
||||
this.setApiKey(provider, 'local'); // Use 'local' as a placeholder
|
||||
return { success: true };
|
||||
case 'openai':
|
||||
validationUrl = 'https://api.openai.com/v1/models';
|
||||
headers = { 'Authorization': `Bearer ${key}` };
|
||||
break;
|
||||
case 'gemini':
|
||||
validationUrl = `https://generativelanguage.googleapis.com/v1beta/models?key=${key}`;
|
||||
headers = {};
|
||||
break;
|
||||
case 'anthropic': {
|
||||
if (!key.startsWith('sk-ant-')) {
|
||||
throw new Error('Invalid Anthropic key format.');
|
||||
}
|
||||
const response = await fetch("https://api.anthropic.com/v1/messages", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": key,
|
||||
"anthropic-version": "2023-06-01",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: "claude-3-haiku-20240307",
|
||||
max_tokens: 1,
|
||||
messages: [{ role: "user", content: "Hi" }],
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok && response.status !== 400) {
|
||||
const errorData = await response.json().catch(() => ({}));
|
||||
return { success: false, error: errorData.error?.message || `Validation failed with status: ${response.status}` };
|
||||
}
|
||||
|
||||
console.log(`[ModelStateService] API key for ${provider} is valid.`);
|
||||
this.setApiKey(provider, key);
|
||||
return { success: true };
|
||||
}
|
||||
default:
|
||||
return { success: false, error: 'Unknown provider.' };
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(validationUrl, { headers, body });
|
||||
if (response.ok) {
|
||||
const result = await ProviderClass.validateApiKey(key);
|
||||
if (result.success) {
|
||||
console.log(`[ModelStateService] API key for ${provider} is valid.`);
|
||||
this.setApiKey(provider, key);
|
||||
return { success: true };
|
||||
} else {
|
||||
const errorData = await response.json().catch(() => ({}));
|
||||
const message = errorData.error?.message || `Validation failed with status: ${response.status}`;
|
||||
console.log(`[ModelStateService] API key for ${provider} is invalid: ${message}`);
|
||||
return { success: false, error: message };
|
||||
console.log(`[ModelStateService] API key for ${provider} is invalid: ${result.error}`);
|
||||
}
|
||||
return result;
|
||||
} catch (error) {
|
||||
console.error(`[ModelStateService] Network error during ${provider} key validation:`, error);
|
||||
return { success: false, error: 'A network error occurred during validation.' };
|
||||
console.error(`[ModelStateService] Error during ${provider} key validation:`, error);
|
||||
return { success: false, error: 'An unexpected error occurred during validation.' };
|
||||
}
|
||||
}
|
||||
|
||||
@ -239,33 +329,14 @@ class ModelStateService {
|
||||
setApiKey(provider, key) {
|
||||
if (provider in this.state.apiKeys) {
|
||||
this.state.apiKeys[provider] = key;
|
||||
|
||||
const llmModels = PROVIDERS[provider]?.llmModels;
|
||||
const sttModels = PROVIDERS[provider]?.sttModels;
|
||||
|
||||
// Prioritize newly set API key provider over existing selections
|
||||
// Only for non-local providers or if no model is currently selected
|
||||
if (llmModels?.length > 0) {
|
||||
if (!this.state.selectedModels.llm || provider !== 'ollama') {
|
||||
this.state.selectedModels.llm = llmModels[0].id;
|
||||
console.log(`[ModelStateService] Selected LLM model from newly configured provider ${provider}: ${llmModels[0].id}`);
|
||||
}
|
||||
}
|
||||
if (sttModels?.length > 0) {
|
||||
if (!this.state.selectedModels.stt || provider !== 'whisper') {
|
||||
this.state.selectedModels.stt = sttModels[0].id;
|
||||
console.log(`[ModelStateService] Selected STT model from newly configured provider ${provider}: ${sttModels[0].id}`);
|
||||
}
|
||||
}
|
||||
this._saveState();
|
||||
this._logCurrentSelection();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
getApiKey(provider) {
|
||||
return this.state.apiKeys[provider] || null;
|
||||
return this.state.apiKeys[provider];
|
||||
}
|
||||
|
||||
getAllApiKeys() {
|
||||
@ -351,6 +422,18 @@ class ModelStateService {
|
||||
return result;
|
||||
}
|
||||
|
||||
hasValidApiKey() {
|
||||
if (this.isLoggedInWithFirebase()) return true;
|
||||
|
||||
// Check if any provider has a valid API key
|
||||
return Object.entries(this.state.apiKeys).some(([provider, key]) => {
|
||||
if (provider === 'ollama' || provider === 'whisper') {
|
||||
return key === 'local';
|
||||
}
|
||||
return key && key.trim().length > 0;
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
getAvailableModels(type) {
|
||||
const available = [];
|
||||
@ -445,10 +528,28 @@ class ModelStateService {
|
||||
}
|
||||
|
||||
setupIpcHandlers() {
|
||||
ipcMain.handle('model:validate-key', (e, { provider, key }) => this.validateApiKey(provider, key));
|
||||
ipcMain.handle('model:validate-key', async (e, { provider, key }) => {
|
||||
const result = await this.validateApiKey(provider, key);
|
||||
if (result.success) {
|
||||
// Use 'local' as placeholder for local services
|
||||
const finalKey = (provider === 'ollama' || provider === 'whisper') ? 'local' : key;
|
||||
this.setApiKey(provider, finalKey);
|
||||
// After setting the key, auto-select models
|
||||
this._autoSelectAvailableModels();
|
||||
this._saveState(); // Ensure state is saved after model selection
|
||||
}
|
||||
return result;
|
||||
});
|
||||
ipcMain.handle('model:get-all-keys', () => this.getAllApiKeys());
|
||||
ipcMain.handle('model:set-api-key', (e, { provider, key }) => this.setApiKey(provider, key));
|
||||
ipcMain.handle('model:remove-api-key', (e, { provider }) => {
|
||||
ipcMain.handle('model:set-api-key', async (e, { provider, key }) => {
|
||||
const success = this.setApiKey(provider, key);
|
||||
if (success) {
|
||||
this._autoSelectAvailableModels();
|
||||
await this._saveState();
|
||||
}
|
||||
return success;
|
||||
});
|
||||
ipcMain.handle('model:remove-api-key', async (e, { provider }) => {
|
||||
const success = this.removeApiKey(provider);
|
||||
if (success) {
|
||||
const selectedModels = this.getSelectedModels();
|
||||
@ -461,7 +562,7 @@ class ModelStateService {
|
||||
return success;
|
||||
});
|
||||
ipcMain.handle('model:get-selected-models', () => this.getSelectedModels());
|
||||
ipcMain.handle('model:set-selected-model', (e, { type, modelId }) => this.setSelectedModel(type, modelId));
|
||||
ipcMain.handle('model:set-selected-model', async (e, { type, modelId }) => this.setSelectedModel(type, modelId));
|
||||
ipcMain.handle('model:get-available-models', (e, { type }) => this.getAvailableModels(type));
|
||||
ipcMain.handle('model:are-providers-configured', () => this.areProvidersConfigured());
|
||||
ipcMain.handle('model:get-current-model-info', (e, { type }) => this.getCurrentModelInfo(type));
|
||||
|
@ -33,6 +33,13 @@ class SQLiteClient {
|
||||
return this.db;
|
||||
}
|
||||
|
||||
_validateAndQuoteIdentifier(identifier) {
|
||||
if (!/^[a-zA-Z0-9_]+$/.test(identifier)) {
|
||||
throw new Error(`Invalid database identifier used: ${identifier}. Only alphanumeric characters and underscores are allowed.`);
|
||||
}
|
||||
return `"${identifier}"`;
|
||||
}
|
||||
|
||||
synchronizeSchema() {
|
||||
console.log('[DB Sync] Starting schema synchronization...');
|
||||
const tablesInDb = this.getTablesFromDb();
|
||||
@ -57,25 +64,42 @@ class SQLiteClient {
|
||||
}
|
||||
|
||||
createTable(tableName, tableSchema) {
|
||||
const columnDefs = tableSchema.columns.map(col => `"${col.name}" ${col.type}`).join(', ');
|
||||
const query = `CREATE TABLE IF NOT EXISTS "${tableName}" (${columnDefs})`;
|
||||
const safeTableName = this._validateAndQuoteIdentifier(tableName);
|
||||
const columnDefs = tableSchema.columns
|
||||
.map(col => `${this._validateAndQuoteIdentifier(col.name)} ${col.type}`)
|
||||
.join(', ');
|
||||
|
||||
const constraints = tableSchema.constraints || [];
|
||||
const constraintsDef = constraints.length > 0 ? ', ' + constraints.join(', ') : '';
|
||||
|
||||
const query = `CREATE TABLE IF NOT EXISTS ${safeTableName} (${columnDefs}${constraintsDef})`;
|
||||
console.log(`[DB Sync] Creating table: ${tableName}`);
|
||||
this.db.prepare(query).run();
|
||||
this.db.exec(query);
|
||||
}
|
||||
|
||||
updateTable(tableName, tableSchema) {
|
||||
const existingColumns = this.db.prepare(`PRAGMA table_info("${tableName}")`).all();
|
||||
const existingColumnNames = existingColumns.map(c => c.name);
|
||||
const columnsToAdd = tableSchema.columns.filter(col => !existingColumnNames.includes(col.name));
|
||||
const safeTableName = this._validateAndQuoteIdentifier(tableName);
|
||||
|
||||
if (columnsToAdd.length > 0) {
|
||||
console.log(`[DB Sync] Updating table: ${tableName}. Adding columns: ${columnsToAdd.map(c=>c.name).join(', ')}`);
|
||||
for (const column of columnsToAdd) {
|
||||
const addColumnQuery = `ALTER TABLE "${tableName}" ADD COLUMN "${column.name}" ${column.type}`;
|
||||
this.db.prepare(addColumnQuery).run();
|
||||
// Get current columns
|
||||
const currentColumns = this.db.prepare(`PRAGMA table_info(${safeTableName})`).all();
|
||||
const currentColumnNames = currentColumns.map(col => col.name);
|
||||
|
||||
// Check for new columns to add
|
||||
const newColumns = tableSchema.columns.filter(col => !currentColumnNames.includes(col.name));
|
||||
|
||||
if (newColumns.length > 0) {
|
||||
console.log(`[DB Sync] Adding ${newColumns.length} new column(s) to ${tableName}`);
|
||||
for (const col of newColumns) {
|
||||
const safeColName = this._validateAndQuoteIdentifier(col.name);
|
||||
const addColumnQuery = `ALTER TABLE ${safeTableName} ADD COLUMN ${safeColName} ${col.type}`;
|
||||
this.db.exec(addColumnQuery);
|
||||
console.log(`[DB Sync] Added column ${col.name} to ${tableName}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (tableSchema.constraints && tableSchema.constraints.length > 0) {
|
||||
console.log(`[DB Sync] Note: Constraints for ${tableName} can only be set during table creation`);
|
||||
}
|
||||
}
|
||||
|
||||
runQuery(query, params = []) {
|
||||
|
@ -1,7 +1,6 @@
|
||||
const { ipcMain, BrowserWindow } = require('electron');
|
||||
const Store = require('electron-store');
|
||||
const authService = require('../../common/services/authService');
|
||||
const userRepository = require('../../common/repositories/user');
|
||||
const settingsRepository = require('./repositories');
|
||||
const { getStoredApiKey, getStoredProvider, windowPool } = require('../../electron/windowManager');
|
||||
|
||||
@ -282,26 +281,13 @@ async function deletePreset(id) {
|
||||
|
||||
async function saveApiKey(apiKey, provider = 'openai') {
|
||||
try {
|
||||
const user = authService.getCurrentUser();
|
||||
if (!user.isLoggedIn) {
|
||||
// For non-logged-in users, save to local storage
|
||||
const Store = require('electron-store');
|
||||
const store = new Store();
|
||||
store.set('apiKey', apiKey);
|
||||
store.set('provider', provider);
|
||||
|
||||
// Notify windows
|
||||
BrowserWindow.getAllWindows().forEach(win => {
|
||||
if (!win.isDestroyed()) {
|
||||
win.webContents.send('api-key-validated', apiKey);
|
||||
}
|
||||
});
|
||||
|
||||
return { success: true };
|
||||
// Use ModelStateService as the single source of truth for API key management
|
||||
const modelStateService = global.modelStateService;
|
||||
if (!modelStateService) {
|
||||
throw new Error('ModelStateService not initialized');
|
||||
}
|
||||
|
||||
// For logged-in users, use the repository adapter which injects the UID.
|
||||
await userRepository.saveApiKey(apiKey, provider);
|
||||
await modelStateService.setApiKey(provider, apiKey);
|
||||
|
||||
// Notify windows
|
||||
BrowserWindow.getAllWindows().forEach(win => {
|
||||
@ -319,16 +305,16 @@ async function saveApiKey(apiKey, provider = 'openai') {
|
||||
|
||||
async function removeApiKey() {
|
||||
try {
|
||||
const user = authService.getCurrentUser();
|
||||
if (!user.isLoggedIn) {
|
||||
// For non-logged-in users, remove from local storage
|
||||
const Store = require('electron-store');
|
||||
const store = new Store();
|
||||
store.delete('apiKey');
|
||||
store.delete('provider');
|
||||
} else {
|
||||
// For logged-in users, use the repository adapter.
|
||||
await userRepository.saveApiKey(null, null);
|
||||
// Use ModelStateService as the single source of truth for API key management
|
||||
const modelStateService = global.modelStateService;
|
||||
if (!modelStateService) {
|
||||
throw new Error('ModelStateService not initialized');
|
||||
}
|
||||
|
||||
// Remove all API keys for all providers
|
||||
const providers = ['openai', 'anthropic', 'gemini', 'ollama', 'whisper'];
|
||||
for (const provider of providers) {
|
||||
await modelStateService.removeApiKey(provider);
|
||||
}
|
||||
|
||||
// Notify windows
|
||||
|
10
src/index.js
10
src/index.js
@ -680,13 +680,13 @@ function setupWebDataHandlers() {
|
||||
result = await userRepository.findOrCreate(payload);
|
||||
break;
|
||||
case 'save-api-key':
|
||||
// Assuming payload is { apiKey, provider }
|
||||
result = await userRepository.saveApiKey(payload.apiKey, payload.provider);
|
||||
// Use ModelStateService as the single source of truth for API key management
|
||||
result = await modelStateService.setApiKey(payload.provider, payload.apiKey);
|
||||
break;
|
||||
case 'check-api-key-status':
|
||||
// Adapter injects UID
|
||||
const user = await userRepository.getById();
|
||||
result = { hasApiKey: !!user?.api_key && user.api_key.length > 0 };
|
||||
// Use ModelStateService to check API key status
|
||||
const hasApiKey = await modelStateService.hasValidApiKey();
|
||||
result = { hasApiKey };
|
||||
break;
|
||||
case 'delete-account':
|
||||
// Adapter injects UID
|
||||
|
Loading…
x
Reference in New Issue
Block a user