AWS Lambda Proxy For AI Services

AI service providers like OpenAI and AWS offer large language models, speech-to-text, and text-to-speech services.

I prefer to build AI-based applications that run strictly in-browser and that interact directly with the APIs of these services.

But sometimes in-browser isn't enough, and I need a lightweight backend.

To keep it as lightweight and as easy to manage as possible, I currently use a single AWS Lambda function with an exposed function URL.

It has the following features:

- Proxies the OpenAI models, chat completions, transcriptions, and speech endpoints, keeping the API key server-side
- Adapts Amazon Bedrock's Converse API and model listing to OpenAI's request and response formats
- Synthesizes speech with Amazon Polly
- Streams responses, both SSE chat completion chunks and raw audio
- Protects every endpoint with a shared password sent as a bearer token

Prerequisites & Lambda Configuration

Code


// Import AWS SDK clients and dependencies, which should already be available in the Lambda environment.
import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand } from "@aws-sdk/client-bedrock-runtime";
import { ListFoundationModelsCommand, BedrockClient } from "@aws-sdk/client-bedrock";
import { PollyClient, SynthesizeSpeechCommand } from "@aws-sdk/client-polly";

// Import Readable for more easily streaming responses (currently only used for Polly audio streaming)
import { Readable } from 'stream';

// Import fetch for handling HTTP/2, which built-in Node.js fetch does not yet fully support
// For simplicity's sake, I built node-fetch locally and literally uploaded the file to the lambda
import fetch from './node-fetch.mjs';

// Pull in PASSWORD from environment variable to protect API endpoints with a faux bearer token
const PASSWORD = process?.env?.PASSWORD;

// Pull in OPENAI_API_KEY from environment variable to use OpenAI API
const OPENAI_API_KEY = process?.env?.OPENAI_API_KEY;

// Define API endpoints in the style of OpenAI
// Prefix with /provider/ to differentiate between different providers
const apiEndpoints = {
    '/open-ai/models': {
        method: ['GET'],
        handler: proxyTo('https://api.openai.com/v1/models', OPENAI_API_KEY),
        protection: 'password'
    },
    '/open-ai/chat/completions': {
        method: ['POST'],
        handler: proxyTo('https://api.openai.com/v1/chat/completions', OPENAI_API_KEY),
        protection: 'password'
    },
    '/open-ai/transcriptions': {
        method: ['POST'],
        handler: proxyTo('https://api.openai.com/v1/audio/transcriptions', OPENAI_API_KEY),
        protection: 'password'
    },
    '/open-ai/audio/speech': {
        method: ['POST'],
        handler: proxyTo('https://api.openai.com/v1/audio/speech', OPENAI_API_KEY),
        protection: 'password'
    },
    '/bedrock/chat/completions': {
        method: ['POST'],
        handler: bedrockConverse(),
        protection: 'password'
    },
    '/bedrock/models': {
        method: ['GET'],
        handler: bedrockListModels(),
        protection: 'password'
    },
    '/bedrock/audio/speech': {
        method: ['POST'],
        handler: pollySpeak(),
        protection: 'password'
    },
};

// Use the arcane streamifyResponse function to handle streaming responses from the Lambda
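// Note: this only works when the function URL's invoke mode is set to RESPONSE_STREAM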
export const handler = awslambda.streamifyResponse(async (event, responseStream, _context) => {
    try {
        await processAPIRequest(event, responseStream);
    } catch (error) {
        errorResponse(responseStream, 500, error.message);
    }
});

// Route to the appropriate API endpoint based on the path and method
// Enforce password protection if necessary
async function processAPIRequest(event, responseStream) {
    const path = event?.rawPath;
    const method = event?.requestContext?.http?.method;
    const endpoint = apiEndpoints[path];
    if (!endpoint || !endpoint.method.includes(method)) {
        errorResponse(responseStream, 404, 'Not Found');
        return;
    }
    if (endpoint.protection === 'password') {
        const authHeader = event.headers.authorization;
        if (authHeader !== 'Bearer ' + PASSWORD) {
            errorResponse(responseStream, 401, 'Unauthorized - Invalid password');
            return;
        }
    }
    await endpoint.handler(event, responseStream);
}

// Handle interactions with Bedrock LLMs via AWS SDKs, reformatting responses to match OpenAI API specs
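// Example request body (the model ID is illustrative):
// { "model": "anthropic.claude-3-haiku-20240307-v1:0",
//   "messages": [{ "role": "user", "content": "Hello" }],
//   "stream": true }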
function bedrockConverse() {
    return async (event, responseStream) => {
        const client = new BedrockRuntimeClient({ region: "us-east-1" });
        const { model, messages, stream, inferenceConfig } = JSON.parse(event.body);
        const formattedMessages = messages.map(({ role, content }) => ({ role, content: [{ text: content }] }));

        try {
            if (stream) {
                const command = new ConverseStreamCommand({
                    modelId: model,
                    messages: formattedMessages,
                    inferenceConfig: inferenceConfig || { maxTokens: 4096, temperature: 0.5, topP: 0.9 },
                });
                const response = await client.send(command);
                for await (const item of response.stream) {
                    if (item.contentBlockDelta) {
                        const responseText = item.contentBlockDelta.delta?.text;
                        const partialCompletion = reformatPlainTextToOpenAIPartialCompletion(responseText, model);
                        responseStream.write(partialCompletion);
                    }
                }
                responseStream.write('data: [DONE]\n\n');
            } else {
                const command = new ConverseCommand({
                    modelId: model,
                    messages: formattedMessages,
                    inferenceConfig: inferenceConfig || { maxTokens: 4096, temperature: 0.5, topP: 0.9 },
                });
                const response = await client.send(command);
                const reformatted = reformatBedrockCompletionToOpenAI(response, model);
                successResponse(responseStream, reformatted);
            }
        } catch (err) {
            console.log(`ERROR: Can't invoke '${model}'. Reason: ${err}`);
            errorResponse(responseStream, 500, `Error invoking model: ${err.message}`);
        } finally {
            // errorResponse and successResponse already end the stream, so avoid double-ending it
            if (!responseStream.writableEnded) responseStream.end();
        }
    };
}

// Handle listing of Bedrock models via AWS SDKs, reformatting responses to match OpenAI API specs
function bedrockListModels() {
    return async (event, responseStream) => {
        const client = new BedrockClient({ region: "us-east-1" });
        const input = { byOutputModality: "TEXT" };
        try {
            const response = await client.send(new ListFoundationModelsCommand(input));
            const reformatted = reformatBedrockListToOpenAI(response);
            successResponse(responseStream, reformatted);
        } catch (err) {
            errorResponse(responseStream, 500, `Error listing models: ${err.message}`);
        }
    };
}

// Handle speech synthesis via AWS SDKs, streaming the audio response
function pollySpeak() {
    return async (event, responseStream) => {
        const body = event.body ? JSON.parse(event.body) : event;
        const { input } = body;
        const client = new PollyClient({ region: "us-east-1" });
        const parameters = {
            Engine: 'generative',
            OutputFormat: "mp3",
            Text: input,
            TextType: "text",
            VoiceId: 'Ruth'
        };
        const command = new SynthesizeSpeechCommand(parameters);
        const pollyResponse = await client.send(command);
        const pollyStream = pollyResponse.AudioStream;
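        // Write a one-line JSON "header" ahead of the raw MP3 bytes;
        // clients must strip this first line before playing the audio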
        responseStream.write(JSON.stringify({
            statusCode: 200,
            headers: {
                'Content-Type': 'audio/mpeg',
            }
        }));
        responseStream.write('\n');
        Readable.fromWeb(Readable.toWeb(pollyStream)).pipe(responseStream);
    };
}

// Reformat Bedrock LLM response to match OpenAI API specs - non-streaming
function reformatBedrockCompletionToOpenAI(bedrockResponse, modelId) {
    return {
      id: bedrockResponse.$metadata.requestId,
      object: "chat.completion",
      created: Math.floor(Date.now() / 1000),
      model: modelId,
      choices: [
        {
          index: 0,
          message: {
            role: bedrockResponse.output.message.role,
            content: bedrockResponse.output.message.content[0].text
          },
          finish_reason: bedrockResponse.stopReason === "end_turn" ? "stop" : bedrockResponse.stopReason
        }
      ],
      usage: {
        prompt_tokens: bedrockResponse.usage.inputTokens,
        completion_tokens: bedrockResponse.usage.outputTokens,
        total_tokens: bedrockResponse.usage.totalTokens
      }
    };
}

// Reformat Bedrock LLM response to match OpenAI API specs - streaming
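// Emits SSE lines like:
// data: {"id":"partial-completion","object":"chat.completion.chunk","created":1700000000,"model":"...","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}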
function reformatPlainTextToOpenAIPartialCompletion(text, modelId){
    const json = {
        id: "partial-completion",
        object: "chat.completion.chunk",
        created: Math.floor(Date.now() / 1000),
        model: modelId,
        choices: [
            {
                index: 0,
                delta: {
                    content: text
                },
                finish_reason: null
            }
        ]
    };
    const jsonString = JSON.stringify(json);
    // SSE events are terminated by a blank line, hence the double newline
    return 'data: ' + jsonString + '\n\n';
}

// Reformat Bedrock LLM model list to match OpenAI API specs
function reformatBedrockListToOpenAI(bedrockResponse) {
    return {
      object: "list",
      data: bedrockResponse.modelSummaries.map(model => ({
        id: model.modelId,
        object: "model",
        created: Math.floor(Date.now() / 1000),
        owned_by: model.providerName || "system"
      }))
    };
}

// When proxying, filter out headers that could cause issues
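// Function URL events deliver header names in lowercase, hence the lowercase allow-list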
function filterHeaders(headers) {
    const allowedHeaders = ['content-type', 'content-length', 'accept', 'accept-encoding', 'accept-language'];
    return Object.fromEntries(
        Object.entries(headers)
            .filter(([key]) => allowedHeaders.includes(key))
    );
}

// Proxy requests to external APIs like OpenAI, adding an API key if necessary
// A bit of a mess, honestly, but it works
function proxyTo(url, apiKey) {
    return async (event, responseStream) => {
        const allowedHeaders = filterHeaders(event.headers);
        const { method } = event.requestContext.http;
        const isMultipart = allowedHeaders['content-type']?.includes('multipart');
        const requestBody = isMultipart ? handleMultipartRequest(event.body) : event.body;

        try {
            const response = await fetch(url, {
                method,
                headers: {
                    ...allowedHeaders,
                    Authorization: `Bearer ${apiKey}`,
                },
                body: requestBody,
            });
            let stream = false;
            // Multipart bodies (transcription uploads) aren't JSON, so only parse JSON payloads
            if (method === 'POST' && requestBody && !isMultipart) {
                const requestJson = JSON.parse(requestBody);
                stream = requestJson.stream;
            }
            // If url includes audio, stream is true
            if(url.includes('audio')){
                stream = true;
            }

            if (stream) {
                console.log('streaming response');
                for await (const chunk of response.body) {
                    responseStream.write(chunk);
                }
                responseStream.end();
            } else {
                console.log('non-streaming response');
                const responseContentType = response.headers.get('content-type');

                if (responseContentType?.includes('audio/mpeg')) {
                    const audioBody = await response.arrayBuffer();
                    responseStream.write(Buffer.from(audioBody).toString('base64'));
                    responseStream.end();
                } else if (!response.ok) {
                    errorResponse(responseStream, response.status, response.statusText);
                } else {
                    const jsonResponse = await response.json();
                    successResponse(responseStream, jsonResponse);
                }
            }
        } catch (err) {
            errorResponse(responseStream, 500, `Error proxying request: ${err.message}`);
        }
    };
}

// For transcription requests: function URLs base64-encode binary bodies, so decode multipart payloads back into raw bytes
function handleMultipartRequest(body) {
    return typeof Buffer !== 'undefined' ? Buffer.from(body, 'base64') : body;
}

// Provide an error response
function errorResponse(responseStream, code, message) {
    responseStream.setContentType('application/json');
    responseStream.write(JSON.stringify({
        statusCode: code,
        body: {
            "error": message
        }
    }));
    responseStream.end();
}

// Provide a successful response
function successResponse(responseStream, body) {
    responseStream.setContentType('application/json');
    responseStream.write(JSON.stringify({
        ...body
    }));
    responseStream.end();
}