Initial commit: Korean voice-cloning TTS prototype
FastAPI backend, web UI, CosyVoice3/F5-TTS setup scripts, and handoff docs for GPU PC continuation. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
170
backend/app/main.py
Normal file
170
backend/app/main.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.app.config import get_settings, project_root
|
||||
from backend.app.text_preprocess import preprocess_korean
|
||||
from backend.app.tts.service import TTSService
|
||||
|
||||
ROOT = project_root()
|
||||
WEB_DIR = ROOT / "web"
|
||||
|
||||
app = FastAPI(
|
||||
title="Korean Voice Cloning TTS",
|
||||
description="CosyVoice / F5-TTS 기반 한국어 보이스 클로닝 API",
|
||||
version="0.1.0",
|
||||
)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
_tts: TTSService | None = None
|
||||
|
||||
|
||||
def get_tts() -> TTSService:
|
||||
global _tts
|
||||
if _tts is None:
|
||||
_tts = TTSService()
|
||||
return _tts
|
||||
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
text: str = Field(..., min_length=1, max_length=5000)
|
||||
ref_audio: str | None = Field(
|
||||
default=None, description="samples/ 또는 uploads/ 기준 상대/절대 경로"
|
||||
)
|
||||
ref_text: str | None = None
|
||||
preprocess: bool = True
|
||||
|
||||
|
||||
class TTSResponse(BaseModel):
|
||||
job_id: str
|
||||
audio_url: str
|
||||
model: str
|
||||
text_preview: str
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
status: str
|
||||
model: str
|
||||
samples_count: int
|
||||
|
||||
|
||||
@app.get("/api/health", response_model=HealthResponse)
|
||||
def health() -> HealthResponse:
|
||||
s = get_settings()
|
||||
samples = list(s.samples_dir.glob("*.wav"))
|
||||
return HealthResponse(
|
||||
status="ok",
|
||||
model=s.tts_model,
|
||||
samples_count=len(samples),
|
||||
)
|
||||
|
||||
|
||||
@app.post("/api/tts", response_model=TTSResponse)
|
||||
def create_tts(body: TTSRequest) -> TTSResponse:
|
||||
text = preprocess_korean(body.text) if body.preprocess else body.text.strip()
|
||||
if not text:
|
||||
raise HTTPException(400, "text is empty")
|
||||
|
||||
ref_path: Path | None = None
|
||||
if body.ref_audio:
|
||||
p = Path(body.ref_audio)
|
||||
if not p.is_absolute():
|
||||
for base in (get_settings().samples_dir, get_settings().uploads_dir):
|
||||
candidate = base / p
|
||||
if candidate.is_file():
|
||||
p = candidate
|
||||
break
|
||||
if not p.is_file():
|
||||
raise HTTPException(404, f"ref_audio not found: {body.ref_audio}")
|
||||
ref_path = p
|
||||
|
||||
try:
|
||||
job_id, _ = get_tts().synthesize_to_file(
|
||||
text, ref_audio=ref_path, ref_text=body.ref_text
|
||||
)
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(404, str(e)) from e
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(503, str(e)) from e
|
||||
|
||||
return TTSResponse(
|
||||
job_id=job_id,
|
||||
audio_url=f"/api/audio/{job_id}",
|
||||
model=get_settings().tts_model,
|
||||
text_preview=text[:80] + ("…" if len(text) > 80 else ""),
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/audio/{job_id}")
|
||||
def get_audio(job_id: str) -> FileResponse:
|
||||
path = get_settings().outputs_dir / job_id / "output.wav"
|
||||
if not path.is_file():
|
||||
alt = get_settings().outputs_dir / job_id / "part_000.wav"
|
||||
path = alt if alt.is_file() else path
|
||||
if not path.is_file():
|
||||
raise HTTPException(404, "audio not found")
|
||||
return FileResponse(path, media_type="audio/wav", filename=f"{job_id}.wav")
|
||||
|
||||
|
||||
@app.get("/api/voice-samples")
|
||||
def list_voice_samples() -> dict:
|
||||
s = get_settings()
|
||||
samples = []
|
||||
for d, label in ((s.samples_dir, "samples"), (s.uploads_dir, "uploads")):
|
||||
for wav in sorted(d.glob("*.wav")):
|
||||
txt = wav.with_suffix(".txt")
|
||||
samples.append(
|
||||
{
|
||||
"id": wav.stem,
|
||||
"path": str(wav),
|
||||
"label": label,
|
||||
"has_transcript": txt.is_file(),
|
||||
}
|
||||
)
|
||||
return {"samples": samples, "default_model": s.tts_model}
|
||||
|
||||
|
||||
@app.post("/api/voice-sample")
|
||||
async def upload_voice_sample(
|
||||
file: UploadFile = File(...),
|
||||
ref_text: str = Form(""),
|
||||
) -> dict:
|
||||
if not file.filename or not file.filename.lower().endswith(".wav"):
|
||||
raise HTTPException(400, "WAV 파일만 업로드 가능합니다")
|
||||
|
||||
sample_id = uuid.uuid4().hex[:10]
|
||||
dest = get_settings().uploads_dir / f"{sample_id}.wav"
|
||||
with open(dest, "wb") as f:
|
||||
shutil.copyfileobj(file.file, f)
|
||||
|
||||
if ref_text.strip():
|
||||
(dest.with_suffix(".txt")).write_text(ref_text.strip(), encoding="utf-8")
|
||||
|
||||
return {
|
||||
"id": sample_id,
|
||||
"path": str(dest),
|
||||
"message": "업로드 완료. TTS 요청 시 ref_audio에 이 path를 사용하세요.",
|
||||
}
|
||||
|
||||
|
||||
if WEB_DIR.is_dir():
|
||||
app.mount("/", StaticFiles(directory=str(WEB_DIR), html=True), name="web")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def startup() -> None:
|
||||
get_settings().outputs_dir.mkdir(parents=True, exist_ok=True)
|
||||
get_settings().uploads_dir.mkdir(parents=True, exist_ok=True)
|
||||
Reference in New Issue
Block a user