huggingface/candle

Feature Request: Implement a Keras-Like API

Open

#1153 opened on Oct 22, 2023

help wanted

Description:

I would like to propose the addition of a high-level Keras-like API to our Rust crate. This API would provide a more intuitive and user-friendly way to define, compile, and train neural network models.

Motivation:

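Currently, candle provides a low-level API for building neural network models: layers are constructed and composed by hand, and training loops are written manually. A Keras-like API on top of it would make the library easier to use and more accessible to a wider audience, including users already familiar with Keras.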

Proposed Changes:

  • Introduce a new module or structure, such as Sequential, that allows users to define models in a sequential manner by adding layers one after the other.
  • Implement methods for adding different types of layers (e.g., add_linear, add_conv2d, add_lstm, etc.) to the sequential model.
  • Add a compile method to Sequential (or an equivalent structure) that configures the model for training by specifying the loss function, optimizer, and evaluation metrics.
  • Provide user-friendly methods such as fit for training and evaluate for evaluation.
  • Ensure that the Keras-like API integrates seamlessly with the existing low-level API for maximum flexibility.
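
To make the proposed compile / fit / evaluate split more concrete, here is a minimal std-only sketch of how those calls could fit together. All names in it (`Loss`, `Optimizer`, `Model`, `CompiledModel`) are hypothetical and not part of candle. The model is a single scalar y = w * x + b so the MSE gradients can be written by hand; a candle version would presumably rely on its autograd and the `candle_nn` optimizers instead.

```rust
// Hypothetical, std-only illustration of compile/fit/evaluate; not candle's API.

/// Loss selected at compile time (only MSE in this sketch).
#[derive(Clone, Copy)]
enum Loss {
    MeanSquaredError,
}

/// Optimizer selected at compile time (only plain SGD in this sketch).
#[derive(Clone, Copy)]
enum Optimizer {
    Sgd { learning_rate: f64 },
}

/// A one-input, one-output linear model: y = w * x + b.
struct Model {
    w: f64,
    b: f64,
}

/// A model bundled with the loss and optimizer chosen by `compile`.
struct CompiledModel {
    model: Model,
    loss: Loss,
    optimizer: Optimizer,
}

impl Model {
    fn forward(&self, x: f64) -> f64 {
        self.w * x + self.b
    }

    /// Keras-style `compile`: fixes the loss and optimizer before training.
    fn compile(self, loss: Loss, optimizer: Optimizer) -> CompiledModel {
        CompiledModel { model: self, loss, optimizer }
    }
}

impl CompiledModel {
    /// Keras-style `fit`: full-batch gradient descent, one step per epoch.
    /// Returns the loss measured after each epoch.
    fn fit(&mut self, xs: &[f64], ys: &[f64], epochs: usize) -> Vec<f64> {
        let n = xs.len() as f64;
        let mut history = Vec::with_capacity(epochs);
        for _ in 0..epochs {
            // Hand-derived gradients of the MSE with respect to w and b.
            let (mut grad_w, mut grad_b) = (0.0, 0.0);
            for (&x, &y) in xs.iter().zip(ys) {
                let err = self.model.forward(x) - y;
                grad_w += 2.0 * err * x / n;
                grad_b += 2.0 * err / n;
            }
            match self.optimizer {
                Optimizer::Sgd { learning_rate } => {
                    self.model.w -= learning_rate * grad_w;
                    self.model.b -= learning_rate * grad_b;
                }
            }
            history.push(self.evaluate(xs, ys));
        }
        history
    }

    /// Keras-style `evaluate`: computes the configured loss on the given data.
    fn evaluate(&self, xs: &[f64], ys: &[f64]) -> f64 {
        match self.loss {
            Loss::MeanSquaredError => {
                let sum: f64 = xs
                    .iter()
                    .zip(ys)
                    .map(|(&x, &y)| (self.model.forward(x) - y).powi(2))
                    .sum();
                sum / xs.len() as f64
            }
        }
    }
}

fn main() {
    // Toy data sampled from y = 2x + 1.
    let xs = [0.0, 1.0, 2.0, 3.0];
    let ys = [1.0, 3.0, 5.0, 7.0];

    let mut model = Model { w: 0.0, b: 0.0 }
        .compile(Loss::MeanSquaredError, Optimizer::Sgd { learning_rate: 0.05 });
    let history = model.fit(&xs, &ys, 500);
    println!("final loss: {:.6}", history.last().unwrap());
    println!("w = {:.3}, b = {:.3}", model.model.w, model.model.b);
}
```

Separating compile from fit mirrors Keras: the loss and optimizer are chosen once and then reused by fit and evaluate, so a list of metrics could be stored on `CompiledModel` in the same way.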

Example: I've implemented a sketch based on the MNIST forward-pass example:

use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

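/// Collects layers in the order they are added.
struct Sequential {
    layers: Vec<Linear>,
}

impl Sequential {
    fn new() -> Self {
        Sequential {
            layers: Vec::new(),
        }
    }

    /// Appends a layer to the end of the stack.
    fn add(&mut self, layer: Linear) {
        self.layers.push(layer);
    }

    /// Builds a runnable Model from a copy of the layer stack.
    fn compile(&self) -> Model {
        Model::new(&self.layers)
    }
}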

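struct Model {
    layers: Vec<Linear>,
}

impl Model {
    fn new(layers: &[Linear]) -> Model {
        Model { layers: layers.to_vec() }
    }

    /// Applies each layer in order, with a ReLU between layers but not after
    /// the last one, so the output stays as raw logits.
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let mut x = image.clone();
        let last = self.layers.len().saturating_sub(1);
        for (i, layer) in self.layers.iter().enumerate() {
            x = layer.forward(&x)?;
            if i < last {
                x = x.relu()?;
            }
        }
        Ok(x)
    }
}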

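fn main() -> Result<()> {
    let device = Device::Cpu;

    let mut model = Sequential::new();
    // Hidden layer, 784 -> 100. Linear weights are shaped (out_features, in_features).
    model.add(Linear::new(
        Tensor::randn(0f32, 1.0, (100, 784), &device)?,
        Some(Tensor::randn(0f32, 1.0, (100,), &device)?),
    ));
    // Output layer, 100 -> 10: one logit per digit class.
    model.add(Linear::new(
        Tensor::randn(0f32, 1.0, (10, 100), &device)?,
        Some(Tensor::randn(0f32, 1.0, (10,), &device)?),
    ));

    let compiled_model = model.compile();

    // A batch of one random "image", flattened from 28x28 to 784 values.
    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

    let logits = compiled_model.forward(&dummy_image)?;
    // The predicted digit is the index of the largest logit along dim 1.
    let digit = logits.argmax(1)?;
    println!("Logits: {logits}");
    println!("Digit: {digit}");
    Ok(())
}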
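
The sketch above stores only `Linear` values in a `Vec<Linear>`, so it has no room for the other layer types the proposal mentions (add_conv2d, add_lstm). One option, assuming layers only need to share a forward method, is to store boxed trait objects and treat activations as layers too, which also lets the user decide where ReLU goes. To stay self-contained, this sketch uses a hypothetical `Layer` trait over `Vec<f32>` in place of candle's `Module` trait and `Tensor`; `Dense` and `Relu` are illustrative types, not candle's.

```rust
/// Hypothetical stand-in for a layer trait: maps an input vector to an output vector.
trait Layer {
    fn forward(&self, x: &[f32]) -> Vec<f32>;
}

/// Fully connected layer; `weights` is shaped (out, in), like candle's Linear.
struct Dense {
    weights: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl Layer for Dense {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        // Each output is the dot product of one weight row with x, plus its bias.
        self.weights
            .iter()
            .zip(&self.bias)
            .map(|(row, b)| row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>() + b)
            .collect()
    }
}

/// Activation as its own layer, so it is only applied where the user adds it.
struct Relu;

impl Layer for Relu {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        x.iter().map(|&v| v.max(0.0)).collect()
    }
}

/// Holds heterogeneous layers behind trait objects.
struct Sequential {
    layers: Vec<Box<dyn Layer>>,
}

impl Sequential {
    fn new() -> Self {
        Sequential { layers: Vec::new() }
    }

    /// Builder-style `add` that accepts any layer type.
    fn add(mut self, layer: impl Layer + 'static) -> Self {
        self.layers.push(Box::new(layer));
        self
    }

    /// Threads the input through every layer in insertion order.
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        self.layers
            .iter()
            .fold(x.to_vec(), |acc, layer| layer.forward(&acc))
    }
}

fn main() {
    let model = Sequential::new()
        .add(Dense { weights: vec![vec![1.0, -1.0], vec![0.5, 0.5]], bias: vec![0.0, 1.0] })
        .add(Relu)
        .add(Dense { weights: vec![vec![2.0, 1.0]], bias: vec![-1.0] });

    let out = model.forward(&[1.0, 3.0]);
    println!("{out:?}");
}
```

In candle, the same pattern would presumably store `Box<dyn Module>`, since `Module::forward` takes `&self` and a `&Tensor` and so can be called through a trait object.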

I believe a Keras-like API would make candle easier to adopt by letting users define and train neural network models with less boilerplate. Feedback and discussion on this proposal are welcome.
