query
Function
Retrieves images related to the given text, then sends the retrieved images together with the generation prompt to an LLM to generate a new image.
Prototype
def query(text, llm_config, *args, **kwargs)
Parameters
| Parameter | Data Type | Required/Optional | Description |
|---|---|---|---|
| text | String | Required | Description of the image to search for. The length range is (0, 1 × 1000 × 1000]. |
| llm_config | LLMParameterConfig | Optional | Inherited from the parent class. Not used here. |
| args | List | Optional | Not used. |
| kwargs["prompt"] | String | Required | Image generation prompt, passed through kwargs. The length range is (0, 1 × 1024 × 1024]. |
| kwargs["size"] | String | Optional | Image size in the format "height*width", passed through kwargs. The supported sizes depend on the LLM in use. The value must match the regular expression `^\d{1,5}\*\d{1,5}$`. The default value is "512*512". |
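The constraints in the parameter table can be checked before calling query. The following is a minimal sketch, not part of mx_rag: the helper name check_query_args is hypothetical, and it assumes that the default size is written as "512*512" so that it matches the documented regular expression.

```python
import re

# Limits taken from the parameter table above.
MAX_TEXT_LEN = 1 * 1000 * 1000      # upper bound for len(text)
MAX_PROMPT_LEN = 1 * 1024 * 1024    # upper bound for len(kwargs["prompt"])
SIZE_PATTERN = re.compile(r"^\d{1,5}\*\d{1,5}$")


def check_query_args(text, prompt, size="512*512"):
    """Hypothetical pre-check. Returns a list of problems (empty if none)."""
    problems = []
    if not isinstance(text, str) or not 0 < len(text) <= MAX_TEXT_LEN:
        problems.append("text length must be in (0, 1000000]")
    if not isinstance(prompt, str) or not 0 < len(prompt) <= MAX_PROMPT_LEN:
        problems.append("prompt length must be in (0, 1048576]")
    # fullmatch also rejects a trailing newline, which "$" alone would allow.
    if not isinstance(size, str) or SIZE_PATTERN.fullmatch(size) is None:
        problems.append("size must look like '512*512'")
    return problems
```

An empty list means the arguments satisfy the documented ranges. Whether a particular size is actually supported still depends on the LLM behind the chain.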
Return Value
| Data Type | Description |
|---|---|
| Dict, {"prompt": prompt, "result": data} | data is the Base64-encoded image data. |
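Because the result field holds Base64 text rather than raw bytes, it usually has to be decoded before the image can be viewed or stored. The sketch below shows one way to do that with the Python standard library. The helper name save_result_image is hypothetical, and the image format (and therefore the file extension to use) is an assumption, since it depends on the LLM service.

```python
import base64
from pathlib import Path


def save_result_image(llm_data, out_path):
    """Decode llm_data["result"] and write the bytes to out_path.

    Returns the number of bytes written. Hypothetical helper, not part of mx_rag.
    """
    data = llm_data.get("result")
    if not data:
        raise ValueError("no image data in result")
    image_bytes = base64.b64decode(data)
    Path(out_path).write_bytes(image_bytes)
    return len(image_bytes)
```

With the return value of query stored in llm_data, a call such as save_result_image(llm_data, "/path/to/output.png") writes the decoded image to disk.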
Example
# In this example, a new image is generated by the LLM based on the retrieved images and generation prompts.
from paddle.base import libpaddle
from mx_rag.chain import Img2ImgChain
from mx_rag.llm import Img2ImgMultiModel
from mx_rag.retrievers import Retriever
from mx_rag.storage.vectorstore import MindFAISS
from mx_rag.storage.document_store import SQLiteDocstore
from mx_rag.embedding.local import ImageEmbedding
from mx_rag.utils import ClientParam
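# Device ID shared by the embedding model and the vector store
dev = 0
# Embedding model used to vectorize the query for image retrieval
img_emb = ImageEmbedding(model_name="ViT-B-16", model_path="/path/to/chinese-clip-vit-base-patch16", dev_id=dev)
# Vector store loaded from a local index file
img_vector_store = MindFAISS(x_dim=512,
                             devs=[dev],
                             load_local_index="/path/to/image_faiss.index",
                             auto_save=True)
# Document store backing the retrieved entries
chunk_store = SQLiteDocstore(db_path="/path/to/sql.db")
img_retriever = Retriever(vector_store=img_vector_store, document_store=chunk_store,
                          embed_func=img_emb.embed_documents, k=1, score_threshold=0.5)
# Image-to-image model client; replace the url placeholder with the service address
multi_model = Img2ImgMultiModel(model_name="sd",
                                url="img to image url",
                                client_param=ClientParam(ca_file="/path/to/ca.crt"))
img2img_chain = Img2ImgChain(multi_model=multi_model, retriever=img_retriever)
# The first argument is the retrieval text; prompt is passed through kwargs
llm_data = img2img_chain.query("Search for an image of the little boy",
                               prompt="he is a knight, wearing armor, big sword in right hand. Blur the background, focus on the knight")
# llm_data is a dict: {"prompt": prompt, "result": Base64-encoded image data}
print(llm_data)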