Notifications
Clear all
Topic starter 31/08/2025 5:52 pm
# One-Shot Cat Recognition with Siamese Networks

Here's a complete TensorFlow implementation for one-shot cat recognition that learns from 10 cat images and can match new cat images to the learned category.

## File Structure:

```
cat_recognition/
├── main.py
├── model.py
├── data_loader.py
├── utils.py
├── train_data/
│   ├── cat_1.jpg
│   ├── cat_2.jpg
│   ├── cat_3.jpg
│   ├── cat_4.jpg
│   ├── cat_5.jpg
│   ├── cat_6.jpg
│   ├── cat_7.jpg
│   ├── cat_8.jpg
│   ├── cat_9.jpg
│   └── cat_10.jpg
└── test_data/
    └── new_cat.jpg
```

## Code Implementation:

### 1. main.py

```python
import os
import numpy as np
from model import SiameseNetwork
from data_loader import load_training_data, load_test_image
from utils import save_model, load_model

# Minimum similarity score for the best reference image to count as a match.
MATCH_THRESHOLD = 0.7


def main():
    """Train on train_data/ and match test_data/new_cat.jpg against it.

    Prints progress and the match result. Returns early, printing the reason,
    when the training folder is missing, fewer than two training images load,
    or the test image is missing or cannot be decoded.
    """
    # Set paths
    train_folder = "train_data"
    test_folder = "test_data"

    # The loader calls os.listdir() on this folder; check first so a missing
    # folder gives a clear message instead of a traceback.
    if not os.path.isdir(train_folder):
        print(f"Training folder not found at {train_folder}")
        return

    # Create model
    model = SiameseNetwork(input_shape=(128, 128, 3))

    # Load training data (10 cat images)
    print("Loading training data...")
    reference_images, reference_labels = load_training_data(train_folder)
    # A Siamese network learns from pairs of *different* images, so at least
    # two readable images are needed (zero would also break matching).
    if len(reference_images) < 2:
        print(f"Need at least 2 training images in {train_folder}, "
              f"found {len(reference_images)}")
        return

    # Train the model
    print("Training model...")
    model.train(reference_images, reference_labels)

    # Save trained model
    save_model(model, "cat_recognition_model.h5")
    print("Model saved successfully!")

    # Load test image (11th cat image to match)
    print("Loading test image...")
    test_image_path = os.path.join(test_folder, "new_cat.jpg")
    if not os.path.exists(test_image_path):
        print(f"Test image not found at {test_image_path}")
        return
    try:
        test_image = load_test_image(test_image_path)
    except ValueError as err:  # load_test_image raises this when imread fails
        print(err)
        return

    # Match test image against reference images
    print("Matching test image...")
    match_result = model.match(test_image, reference_images, reference_labels)

    if match_result["confidence"] > MATCH_THRESHOLD:
        print(f"Matched with: {match_result['label']} "
              f"(Confidence: {match_result['confidence']:.2f})")
    else:
        print("No confident match found")


if __name__ == "__main__":
    main()
```

### 2.
model.py

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np


class SiameseNetwork:
    """One-shot image matcher built on a Siamese CNN.

    Both branches of the Siamese model share a single base CNN. ``train`` fits
    the model on augmented same-label / different-label image pairs. ``match``
    embeds images with that shared CNN and compares the embeddings by cosine
    similarity.
    """

    def __init__(self, input_shape):
        self.input_shape = input_shape
        self.model = None           # two-input Siamese model, built by train()
        self.base_network = None    # shared embedding CNN inside self.model
        self.reference_embeddings = []
        self.reference_labels = []

    def _create_base_network(self):
        """Create the shared base network for feature extraction"""
        inputs = layers.Input(shape=self.input_shape)

        # Convolutional layers
        x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
        x = layers.BatchNormalization()(x)
        x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)
        x = layers.MaxPooling2D((2, 2))(x)
        x = layers.Dropout(0.25)(x)

        x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
        x = layers.MaxPooling2D((2, 2))(x)
        x = layers.Dropout(0.25)(x)

        x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
        x = layers.MaxPooling2D((2, 2))(x)
        x = layers.Dropout(0.25)(x)

        # Global average pooling
        x = layers.GlobalAveragePooling2D()(x)
        x = layers.Dense(256, activation='relu')(x)
        x = layers.Dropout(0.5)(x)

        return keras.Model(inputs, x)

    def _create_siamese_model(self):
        """Create the complete Siamese network and remember its shared CNN."""
        # Input layers
        input_a = layers.Input(shape=self.input_shape)
        input_b = layers.Input(shape=self.input_shape)

        # Shared base network. Keep a handle on it: embeddings must come from
        # exactly the weights trained inside the Siamese model.
        base_network = self._create_base_network()
        self.base_network = base_network

        # Get embeddings for both inputs
        embedding_a = base_network(input_a)
        embedding_b = base_network(input_b)

        # Element-wise |a - b| computed as max(a, b) - min(a, b). Built-in layers
        # only (no Python lambda), so the model saves and reloads cleanly as .h5.
        distance = layers.Subtract()([
            layers.Maximum()([embedding_a, embedding_b]),
            layers.Minimum()([embedding_a, embedding_b]),
        ])

        # Dense layers for similarity prediction
        dense = layers.Dense(128, activation='relu')(distance)
        dense = layers.Dropout(0.5)(dense)
        output = layers.Dense(1, activation='sigmoid')(dense)

        return keras.Model([input_a, input_b], output)

    @staticmethod
    def _augment(image, rng):
        """Return a randomly flipped, brightness-jittered, shifted copy of *image*.

        Assumes an (height, width, channels) image with values in [0, 1], which
        is what data_loader produces.
        """
        out = image[:, ::-1, :] if rng.random() < 0.5 else image  # h-flip
        out = out * rng.uniform(0.8, 1.2)                          # brightness
        dy, dx = (int(s) for s in rng.integers(-8, 9, size=2))
        out = np.roll(out, shift=(dy, dx), axis=(0, 1))            # small shift
        return np.clip(out, 0.0, 1.0).astype(np.float32)

    def _make_pairs(self, images, labels, pairs_per_image, rng):
        """Build augmented (left, right, target) training pairs.

        In each round, every image contributes one positive pair (target 1.0)
        and one negative pair (target 0.0). The positive pairs it with an image
        that has the same label, possibly itself. The negative pairs it with an
        image that has a different label. Augmentation makes the two sides of
        a self-pair differ.

        Raises:
            ValueError: if some image has no differently-labelled partner.
        """
        n = len(images)
        same = [[j for j in range(n) if labels[j] == labels[i]] for i in range(n)]
        diff = [[j for j in range(n) if labels[j] != labels[i]] for i in range(n)]
        if any(not partners for partners in diff):
            raise ValueError("Training needs at least two distinct labels "
                             "to form negative pairs")

        left, right, targets = [], [], []
        for i in range(n):
            for _ in range(pairs_per_image):
                for partners, target in ((same[i], 1.0), (diff[i], 0.0)):
                    j = partners[int(rng.integers(len(partners)))]
                    left.append(self._augment(images[i], rng))
                    right.append(self._augment(images[j], rng))
                    targets.append(target)
        return (np.stack(left), np.stack(right),
                np.asarray(targets, dtype=np.float32))

    def train(self, reference_images, reference_labels, epochs=20,
              pairs_per_image=8, batch_size=16, seed=0):
        """Train the Siamese network, then cache the reference embeddings.

        Args:
            reference_images: images shaped like ``input_shape``. The bundled
                loader yields float32 RGB values in [0, 1].
            reference_labels: one label per image. Images that share a label
                are treated as the same identity.
            epochs: number of passes over the generated pairs.
            pairs_per_image: rounds of (positive, negative) pairs per image.
            batch_size: Keras batch size for fitting.
            seed: seed for pair sampling and augmentation, so runs are
                reproducible.

        Raises:
            ValueError: if there are no images, the label count differs from
                the image count, or every image has the same label.
        """
        images = np.asarray(reference_images, dtype=np.float32)
        if len(images) == 0:
            raise ValueError("No reference images to train on")
        if len(images) != len(reference_labels):
            raise ValueError("reference_images and reference_labels differ in length")

        # Build the pairs before the model so bad input fails fast.
        rng = np.random.default_rng(seed)
        left, right, targets = self._make_pairs(
            images, list(reference_labels), pairs_per_image, rng)

        self.model = self._create_siamese_model()
        self.model.compile(optimizer='adam',
                           loss='binary_crossentropy',
                           metrics=['accuracy'])
        self.model.fit([left, right], targets,
                       epochs=epochs, batch_size=batch_size, shuffle=True)

        self.reference_labels = reference_labels
        self.reference_embeddings = list(self._embed_batch(images))

    def _embed_batch(self, images):
        """Return one embedding row per image, computed by the trained shared CNN."""
        if self.model is None or self.base_network is None:
            raise ValueError("Model not trained yet")
        batch = np.asarray(images, dtype=np.float32)
        return self.base_network.predict(batch, verbose=0)

    @staticmethod
    def _cosine_similarity(a, b):
        """Return the cosine similarity of two vectors, or 0.0 if either is all zeros."""
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0.0:
            # The embedding ends in ReLU + Dense, so it can be all zeros; that
            # would otherwise give 0/0 = nan.
            return 0.0
        return float(np.dot(a, b) / denom)

    def _calculate_similarity(self, image1, image2):
        """Calculate the cosine similarity between the embeddings of two images."""
        embeddings = self._embed_batch(np.stack([image1, image2]))
        return self._cosine_similarity(embeddings[0], embeddings[1])

    def _get_embedding(self, image):
        """Get the embedding for a single image."""
        return self._embed_batch(np.expand_dims(image, axis=0))[0]

    def match(self, test_image, reference_images, reference_labels):
        """Match *test_image* against reference images by cosine similarity.

        Returns:
            dict with keys ``label`` (label of the most similar reference),
            ``confidence`` (that similarity, a float) and ``all_similarities``
            (one float per reference, in input order).

        Raises:
            ValueError: if there are no reference images or the model is untrained.
        """
        if len(reference_images) == 0:
            raise ValueError("No reference images to match against")

        # Embed the test image once and all references in a single batch.
        test_embedding = self._get_embedding(test_image)
        ref_embeddings = self._embed_batch(reference_images)
        similarities = [self._cosine_similarity(test_embedding, ref)
                        for ref in ref_embeddings]

        best = int(np.argmax(similarities))
        return {
            "label": reference_labels[best],
            "confidence": float(similarities[best]),
            "all_similarities": [float(s) for s in similarities],
        }
```

### 3. data_loader.py

```python
import os
import cv2
import numpy as np

# Supported image formats
VALID_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')


def _preprocess(img, image_size):
    """Convert an OpenCV BGR image to float32 RGB resized to *image_size*, scaled to [0, 1].

    *image_size* is (width, height), the order cv2.resize expects.
    """
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, image_size)
    return img.astype('float32') / 255.0


def load_training_data(folder_path, image_size=(128, 128)):
    """Load every supported image in *folder_path*, in sorted filename order.

    Returns:
        (images, labels): a NumPy array of preprocessed images, and a list with
        one label per image (the filename without its extension).
    Unreadable files are reported and skipped.
    """
    images = []
    labels = []

    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(VALID_EXTENSIONS):
            continue
        img_path = os.path.join(folder_path, filename)
        img = cv2.imread(img_path)
        if img is None:  # imread signals failure by returning None, not raising
            print(f"Error loading {filename}: unreadable or unsupported image")
            continue
        try:
            images.append(_preprocess(img, image_size))
        except cv2.error as e:
            print(f"Error loading {filename}: {e}")
            continue
        # splitext keeps dotted names intact ("cat.v2.jpg" -> "cat.v2").
        labels.append(os.path.splitext(filename)[0])

    return np.array(images), labels


def load_test_image(image_path, image_size=(128, 128)):
    """Load and preprocess a single test image.

    Raises:
        ValueError: if OpenCV cannot read the file.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Could not load image from {image_path}")
    return _preprocess(img, image_size)
```

### 4. utils.py

```python
import os
import pickle
from model import SiameseNetwork


def save_model(model, filepath):
    """Save the trained Keras model wrapped by *model* to *filepath*.

    Raises:
        ValueError: if *model* has not been trained yet.
    """
    if model.model is None:
        raise ValueError("Model not trained yet")
    model.model.save(filepath)


def load_model(filepath):
    """Load a model written by save_model and wrap it in a SiameseNetwork.

    Reference embeddings are not stored in the file. Pass the reference images
    to ``match`` as usual.

    Raises:
        ValueError: if the file does not contain a nested embedding network.
    """
    from tensorflow import keras

    keras_model = keras.models.load_model(filepath, compile=False)
    network = SiameseNetwork(input_shape=tuple(keras_model.inputs[0].shape[1:]))
    network.model = keras_model
    # The shared embedding CNN is the nested Model layer inside the Siamese model.
    network.base_network = next(
        (layer for layer in keras_model.layers if isinstance(layer, keras.Model)),
        None,
    )
    if network.base_network is None:
        raise ValueError(f"No embedding network found in {filepath}")
    return network
```

### 5. Requirements (requirements.txt)

```txt
tensorflow==2.13.0
opencv-python==4.8.1.78
numpy==1.24.3
scikit-learn==1.3.0
```

## How to Run:

1. **Install dependencies**:
   ```bash
   pip install -r requirements.txt
   ```
2. **Prepare your data**:
   - Create a `train_data` folder with 10 cat images (named `cat_1.jpg`, `cat_2.jpg`, etc.)
   - Create a `test_data` folder with one test image named `new_cat.jpg`
3. **Run the program**:
   ```bash
   python main.py
   ```

## Features:

- **Siamese Network Architecture**: Uses shared convolutional layers for feature extraction
- **Cosine Similarity**: Measures similarity between image embeddings
- **Confidence Scoring**: Reports the cosine similarity of the best-matching reference image as the match confidence
- **Flexible Input**: Accepts `.jpg`, `.jpeg`, `.png` and `.bmp` images of any size
- **Modular Design**: Easy to extend and modify

## Notes:

1. This is a simplified implementation. A production-quality Siamese network would be trained on many more labelled image pairs than 10 images can provide.
2. For better performance, you might want to use pre-trained models like VGG or ResNet.
3. The confidence threshold (0.7 in `main.py`) can be adjusted based on your requirements.
4. Images are converted to RGB and resized to 128×128 automatically. Only `.jpg`, `.jpeg`, `.png` and `.bmp` files are loaded, so convert other formats first.

This implementation provides a complete framework for one-shot image matching with Siamese networks. It can tell whether a new cat image matches any of the 10 reference cat images.
s
This topic was modified 6 days ago by josh