FastAPI backend, web UI, CosyVoice3/F5-TTS setup scripts, and handoff docs for GPU PC continuation. Co-authored-by: Cursor <cursoragent@cursor.com>
171 lines
4.8 KiB
Python
171 lines
4.8 KiB
Python
from __future__ import annotations
|
|
|
|
import shutil
|
|
import uuid
|
|
from pathlib import Path
|
|
|
|
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import FileResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel, Field
|
|
|
|
from backend.app.config import get_settings, project_root
|
|
from backend.app.text_preprocess import preprocess_korean
|
|
from backend.app.tts.service import TTSService
|
|
|
|
# Project root as reported by backend.app.config; the paths below hang off it.
ROOT = project_root()
# Location of the static web UI bundle (only mounted later if it exists on disk).
# NOTE(review): assumes project_root() returns a pathlib.Path — confirm.
WEB_DIR = ROOT.joinpath("web")
|
|
|
|
# The ASGI application object that every route below registers against.
app = FastAPI(
    version="0.1.0",
    title="Korean Voice Cloning TTS",
    description="CosyVoice / F5-TTS 기반 한국어 보이스 클로닝 API",
)

# CORS is fully open: any origin, any method, any header.
# NOTE(review): presumably meant for local/dev use — confirm before exposing
# this service beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
|
|
|
|
# Module-level cache for the lazily constructed TTS service (None until first use).
_tts: TTSService | None = None


def get_tts() -> TTSService:
    """Return the shared ``TTSService``, constructing it on the first call.

    The instance is stored in the module-level ``_tts`` variable and reused by
    every later call.

    NOTE(review): nothing guards the first construction with a lock, so
    concurrent first calls could each build a service — confirm whether that
    matters for the deployed model.
    """
    global _tts
    if _tts is not None:
        return _tts
    _tts = TTSService()
    return _tts
|
|
|
|
|
|
class TTSRequest(BaseModel):
    # JSON body accepted by POST /api/tts.

    # Text to synthesize; validation requires 1 to 5000 characters.
    text: str = Field(..., min_length=1, max_length=5000)

    # Optional reference clip for the voice to imitate, given as a path.
    ref_audio: str | None = Field(
        None,
        description="samples/ 또는 uploads/ 기준 상대/절대 경로",
    )

    # Forwarded unchanged to the synthesizer as ``ref_text``.
    # NOTE(review): presumably the transcript of the reference clip — confirm.
    ref_text: str | None = Field(default=None)

    # When true the text goes through preprocess_korean(); otherwise it is
    # only stripped of surrounding whitespace.
    preprocess: bool = Field(default=True)
|
|
|
|
|
|
class TTSResponse(BaseModel):
    # JSON body returned by POST /api/tts after a successful synthesis.

    # Identifier returned by the synthesizer for this job.
    job_id: str
    # Relative URL (``/api/audio/<job_id>``) where the audio can be fetched.
    audio_url: str
    # Name of the TTS model taken from the application settings.
    model: str
    # First 80 characters of the synthesized text, with "…" appended if cut.
    text_preview: str
|
|
|
|
|
|
class HealthResponse(BaseModel):
    # JSON body returned by GET /api/health.

    # Service status string; the health endpoint always reports "ok".
    status: str
    # Name of the TTS model taken from the application settings.
    model: str
    # Number of ``*.wav`` files found directly in the samples directory.
    samples_count: int
|
|
|
|
|
|
@app.get("/api/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Report service status, the configured model and the sample count.

    Returns:
        A ``HealthResponse`` whose ``status`` is always ``"ok"``, whose
        ``model`` is the settings' ``tts_model``, and whose ``samples_count``
        is the number of ``*.wav`` files directly inside the samples
        directory (the glob is not recursive).
    """
    settings = get_settings()
    # Count matches lazily; only the total is needed, not the paths themselves.
    wav_total = sum(1 for _ in settings.samples_dir.glob("*.wav"))
    return HealthResponse(
        status="ok",
        model=settings.tts_model,
        samples_count=wav_total,
    )
|
|
|
|
|
|
def _resolve_ref_audio(ref_audio: str) -> Path:
    """Turn a client-supplied ``ref_audio`` value into an existing, permitted file.

    Lookup order: an absolute path is used as given; a relative path is tried
    under the samples directory, then the uploads directory, then as given
    (relative to the process working directory).

    The value comes straight from the request body, so the file that is found
    must, after resolving ``..`` segments and symlinks, live inside the
    samples or uploads directory. Anything else is reported exactly like a
    missing file, so the endpoint cannot be used to probe which files exist
    elsewhere on the host.

    Args:
        ref_audio: Non-empty path string from the request.

    Returns:
        The first candidate path that exists and passes the containment check.

    Raises:
        HTTPException: 404 when no candidate exists or the match lies outside
            the permitted directories.
    """
    settings = get_settings()
    allowed_dirs = (settings.samples_dir, settings.uploads_dir)
    requested = Path(ref_audio)

    if requested.is_absolute():
        candidates = [requested]
    else:
        candidates = [base / requested for base in allowed_dirs] + [requested]

    detail = f"ref_audio not found: {ref_audio}"
    found = next((c for c in candidates if c.is_file()), None)
    if found is None:
        raise HTTPException(404, detail)

    real = found.resolve()
    if not any(real.is_relative_to(base.resolve()) for base in allowed_dirs):
        raise HTTPException(404, detail)
    return found


@app.post("/api/tts", response_model=TTSResponse)
def create_tts(body: TTSRequest) -> TTSResponse:
    """Synthesize speech for the requested text and return where to fetch it.

    The text is passed through ``preprocess_korean`` when ``body.preprocess``
    is true; otherwise it is only stripped of surrounding whitespace. When
    ``body.ref_audio`` is given it is resolved to a file inside the samples or
    uploads directory and handed to the synthesizer together with
    ``body.ref_text``.

    Args:
        body: Validated request payload.

    Returns:
        A ``TTSResponse`` carrying the job id, the ``/api/audio/<job_id>``
        URL, the configured model name and a preview of the text (first 80
        characters, with "…" appended when truncated).

    Raises:
        HTTPException: 400 if the text is empty after preprocessing/stripping;
            404 if ``ref_audio`` cannot be resolved or the synthesizer raises
            ``FileNotFoundError``; 503 if the synthesizer raises
            ``RuntimeError``.
    """
    text = preprocess_korean(body.text) if body.preprocess else body.text.strip()
    if not text:
        raise HTTPException(400, "text is empty")

    ref_path: Path | None = (
        _resolve_ref_audio(body.ref_audio) if body.ref_audio else None
    )

    try:
        # get_tts() stays inside the try so a RuntimeError raised while the
        # service is first constructed is also reported as 503.
        job_id, _ = get_tts().synthesize_to_file(
            text, ref_audio=ref_path, ref_text=body.ref_text
        )
    except FileNotFoundError as e:
        raise HTTPException(404, str(e)) from e
    except RuntimeError as e:
        raise HTTPException(503, str(e)) from e

    return TTSResponse(
        job_id=job_id,
        audio_url=f"/api/audio/{job_id}",
        model=get_settings().tts_model,
        text_preview=text[:80] + ("…" if len(text) > 80 else ""),
    )
|
|
|
|
|
|
@app.get("/api/audio/{job_id}")
def get_audio(job_id: str) -> FileResponse:
    """Serve the WAV file produced for a synthesis job.

    Looks in ``<outputs_dir>/<job_id>/`` for ``output.wav`` first and falls
    back to ``part_000.wav``.

    Args:
        job_id: Job identifier taken from the URL path.

    Returns:
        A ``FileResponse`` with media type ``audio/wav`` and download name
        ``<job_id>.wav``.

    Raises:
        HTTPException: 404 when ``job_id`` is not a plain directory name or
            when neither candidate file exists.
    """
    # job_id is client-controlled and is joined onto the outputs directory,
    # so accept only a single plain path component. "." and ".." are listed
    # explicitly because Path("..").name is still "..".
    if job_id in {"", ".", ".."} or Path(job_id).name != job_id:
        raise HTTPException(404, "audio not found")

    job_dir = get_settings().outputs_dir / job_id
    for filename in ("output.wav", "part_000.wav"):
        candidate = job_dir / filename
        if candidate.is_file():
            return FileResponse(
                candidate, media_type="audio/wav", filename=f"{job_id}.wav"
            )
    raise HTTPException(404, "audio not found")
|
|
|
|
|
|
@app.get("/api/voice-samples")
def list_voice_samples() -> dict:
    """List the reference clips available in the samples and uploads directories.

    Returns:
        A dict with ``"samples"``, a list holding one entry per ``*.wav``
        file (samples directory first, then uploads, each sorted by path),
        and ``"default_model"``, the configured TTS model name. Each entry
        has ``id`` (file stem), ``path`` (stringified path), ``label``
        (``"samples"`` or ``"uploads"``) and ``has_transcript`` (whether a
        sibling ``.txt`` file exists).
    """
    settings = get_settings()
    sources = (
        ("samples", settings.samples_dir),
        ("uploads", settings.uploads_dir),
    )
    entries = [
        {
            "id": clip.stem,
            "path": str(clip),
            "label": origin,
            "has_transcript": clip.with_suffix(".txt").is_file(),
        }
        for origin, folder in sources
        for clip in sorted(folder.glob("*.wav"))
    ]
    return {"samples": entries, "default_model": settings.tts_model}
|
|
|
|
|
|
def _store_voice_sample(source, dest: Path, transcript: str) -> None:
    """Write an uploaded clip, and its transcript if any, to disk (blocking).

    Args:
        source: Readable binary file object holding the uploaded audio.
        dest: Target ``.wav`` path; its parent directory is created if missing.
        transcript: Already-stripped transcript text; nothing is written when
            it is empty.

    Raises:
        OSError: Propagated after cleanup if either write fails.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    transcript_path = dest.with_suffix(".txt")
    try:
        with open(dest, "wb") as out:
            shutil.copyfileobj(source, out)
        if transcript:
            transcript_path.write_text(transcript, encoding="utf-8")
    except OSError:
        # Do not leave a truncated clip behind: the uploads directory is
        # globbed for *.wav by the listing endpoint, which would offer it
        # as a usable sample.
        dest.unlink(missing_ok=True)
        transcript_path.unlink(missing_ok=True)
        raise


@app.post("/api/voice-sample")
async def upload_voice_sample(
    file: UploadFile = File(...),
    ref_text: str = Form(""),
) -> dict:
    """Store an uploaded WAV reference clip and an optional transcript.

    The clip is saved as ``<uploads_dir>/<id>.wav`` where ``<id>`` is the
    first 10 hex characters of a fresh UUID4. If ``ref_text`` is non-blank it
    is written next to the clip as ``<id>.txt`` (UTF-8).

    Args:
        file: Uploaded file; its filename must end in ``.wav`` (case-insensitive).
        ref_text: Optional transcript form field.

    Returns:
        A dict with the new sample ``id``, the stored ``path`` and a
        user-facing ``message``.

    Raises:
        HTTPException: 400 when the filename is missing or not ``.wav``.
    """
    # Local import: asyncio is only needed by this handler.
    import asyncio

    if not file.filename or not file.filename.lower().endswith(".wav"):
        raise HTTPException(400, "WAV 파일만 업로드 가능합니다")

    sample_id = uuid.uuid4().hex[:10]
    dest = get_settings().uploads_dir / f"{sample_id}.wav"

    # This is an async handler, so the disk copy runs in a worker thread
    # rather than blocking the event loop for the duration of the write.
    await asyncio.to_thread(_store_voice_sample, file.file, dest, ref_text.strip())

    return {
        "id": sample_id,
        "path": str(dest),
        "message": "업로드 완료. TTS 요청 시 ref_audio에 이 path를 사용하세요.",
    }
|
|
|
|
|
|
# Serve the static web UI from the site root, but only when the bundle is
# actually present on disk. The mount is added after the /api routes above
# were registered.
# NOTE(review): html=True presumably makes StaticFiles serve index.html for
# directory requests — confirm against the Starlette docs.
if WEB_DIR.is_dir():
    app.mount(
        "/",
        StaticFiles(html=True, directory=str(WEB_DIR)),
        name="web",
    )
|
|
|
|
|
|
@app.on_event("startup")
def startup() -> None:
    """Create the outputs and uploads directories if they do not exist yet.

    Missing parent directories are created as well; directories that already
    exist are left alone.

    NOTE(review): recent FastAPI releases deprecate ``on_event`` in favour of
    lifespan handlers — consider migrating when the app setup is next touched.
    """
    settings = get_settings()
    for directory in (settings.outputs_dir, settings.uploads_dir):
        directory.mkdir(parents=True, exist_ok=True)
|