
Pytorch to tflite method

The goal is to convert a model trained with PyTorch on a server into a tflite model that can run on mobile devices.

The most straightforward idea is to convert the PyTorch model to a TensorFlow model and then to tflite, but in practice this route turned out to be unreliable.

After some research, we found that the latest tflite already supports converting directly from a Keras model, so Keras can serve as the bridge for the intermediate conversion, which also lets us take full advantage of the convenience of the Keras high-level API.

The basic idea of the conversion is to take the weights of each layer of the PyTorch network and assign them directly to the corresponding layer of the Keras network.

Once the Keras model has been built and the weights copied over, it can be converted directly to tflite.
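One detail worth calling out before the full example: PyTorch stores Conv2d weights as (out_channels, in_channels, height, width), while Keras Conv2D expects (height, width, in_channels, out_channels), so every convolution kernel must be transposed on the way over. A minimal sketch (the shapes here are illustrative, not from the article):

import numpy as np

# hypothetical 3x3 kernel mapping 3 input channels to 32 output channels
w_pytorch = np.zeros((32, 3, 3, 3))           # PyTorch layout: (out, in, H, W)
w_keras = w_pytorch.transpose([2, 3, 1, 0])   # Keras layout:   (H, W, in, out)
print(w_keras.shape)                          # (3, 3, 3, 32)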

Here is a full example that converts a two-layer CNN.

import tensorflow as tf
from tensorflow import keras
import numpy as np

import torch
from torchvision import models
import torch.nn as nn
# import torch.nn.functional as F
from torch.autograd import Variable

class PytorchNet(nn.Module):
    def __init__(self):
        super(PytorchNet, self).__init__()
        conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))
        conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, 1, groups=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))
        self.feature = nn.Sequential(conv1, conv2)
        self.init_weights()

    def forward(self, x):
        return self.feature(x)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

def KerasNet(input_shape=(224, 224, 3)):
    image_input = keras.layers.Input(shape=input_shape)
    # conv1
    network = keras.layers.Conv2D(
        32, (3, 3), strides=(2, 2), padding="valid")(image_input)
    network = keras.layers.BatchNormalization(
        trainable=False, fused=False)(network)
    network = keras.layers.Activation("relu")(network)
    network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)

    # conv2
    network = keras.layers.Conv2D(
        64, (3, 3), strides=(1, 1), padding="valid")(network)
    network = keras.layers.BatchNormalization(
        trainable=False, fused=True)(network)
    network = keras.layers.Activation("relu")(network)
    network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)

    model = keras.Model(inputs=image_input, outputs=network)

    return model

class PytorchToKeras(object):
    def __init__(self, pModel, kModel):
        super(PytorchToKeras, self)
        self.__source_layers = []
        self.__target_layers = []
        self.pModel = pModel
        self.kModel = kModel
        keras.backend.set_learning_phase(0)

    def __retrieve_k_layers(self):
        # record the indices of the Keras layers that actually hold weights
        for i, layer in enumerate(self.kModel.layers):
            if len(layer.weights) > 0:
                self.__target_layers.append(i)

    def __retrieve_p_layers(self, input_size):

        input = torch.randn(input_size)
        input = Variable(input.unsqueeze(0))
        hooks = []

        def add_hooks(module):

            def hook(module, input, output):
                if hasattr(module, "weight"):
                    # print(module)
                    self.__source_layers.append(module)

            # skip containers; only hook leaf modules that carry weights
            if not isinstance(module, nn.ModuleList) \
                    and not isinstance(module, nn.Sequential) \
                    and module != self.pModel:
                hooks.append(module.register_forward_hook(hook))

        self.pModel.apply(add_hooks)

        # one forward pass fires the hooks in execution order
        self.pModel(input)
        for hook in hooks:
            hook.remove()

    def convert(self, input_size):
        self.__retrieve_k_layers()
        self.__retrieve_p_layers(input_size)

        for i, (source_layer, target_layer) in enumerate(zip(self.__source_layers, self.__target_layers)):
            print(source_layer)
            weight_size = len(source_layer.weight.data.size())
            transpose_dims = []
            for i in range(weight_size):
                transpose_dims.append(weight_size - i - 1)
            if isinstance(source_layer, nn.Conv2d):
                # PyTorch conv weights are (out, in, H, W); Keras wants (H, W, in, out)
                transpose_dims = [2, 3, 1, 0]
                self.kModel.layers[target_layer].set_weights([
                    source_layer.weight.data.numpy().transpose(transpose_dims),
                    source_layer.bias.data.numpy()])
            elif isinstance(source_layer, nn.BatchNorm2d):
                self.kModel.layers[target_layer].set_weights([
                    source_layer.weight.data.numpy(),
                    source_layer.bias.data.numpy(),
                    source_layer.running_mean.data.numpy(),
                    source_layer.running_var.data.numpy()])

    def save_model(self, output_file):
        self.kModel.save(output_file)

    def save_weights(self, output_file):
        self.kModel.save_weights(output_file, save_format='h5')

pytorch_model = PytorchNet()
keras_model = KerasNet(input_shape=(224, 224, 3))

# the file name was dropped from the original; 'pytorch_model.pth' is a placeholder
torch.save(pytorch_model, 'pytorch_model.pth')

# Load the pretrained model
pytorch_model = torch.load('pytorch_model.pth')

# Time to transfer weights
converter = PytorchToKeras(pytorch_model, keras_model)
converter.convert((3, 224, 224))

# Save the converted keras model for later use
# converter.save_weights("keras.h5")
converter.save_model("keras_model.h5")

# convert the keras model to a tflite model (TF 1.x converter API;
# in TF 2.x use tf.lite.TFLiteConverter.from_keras_model instead)
converter = tf.lite.TFLiteConverter.from_keras_model_file(
    "keras_model.h5")
tflite_model = converter.convert()
open("convert_model.tflite", "wb").write(tflite_model)

Additional knowledge: converting TensorFlow models to TensorFlow Lite models

1. Package the graph definition and the trained weights into a single frozen file

 bazel build tensorflow/python/tools:freeze_graph && \
 bazel-bin/tensorflow/python/tools/freeze_graph \
 --input_graph=eval_graph_def.pb \
 --input_checkpoint=checkpoint \
 --output_graph=frozen_eval_graph.pb \
 --output_node_names=outputs

For example:

 bazel-bin/tensorflow/python/tools/freeze_graph \ 
 --input_graph=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_eval.pbtxt \
 --input_checkpoint=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \
 --output_graph=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \
 --output_node_names=MobilenetV1/Predictions/Reshape_1
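If you would rather not build freeze_graph with bazel, the same freezing can be done from Python with the TF 1.x API. A sketch, assuming the checkpoint has a matching .meta file (an assumption; the paths and output node name are taken from the example above):

import tensorflow as tf

with tf.Session() as sess:
    # rebuild the graph from the checkpoint's meta file and restore the weights
    saver = tf.train.import_meta_graph(
        './mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt.meta')
    saver.restore(sess, './mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt')
    # bake the variables into constants, keeping only the listed output node
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['MobilenetV1/Predictions/Reshape_1'])
    with tf.gfile.GFile('./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb', 'wb') as f:
        f.write(frozen.SerializeToString())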

2. Convert the tensorflow pb model generated in the first step to a tf lite model

The conversion tool (toco) needs to be compiled before it can be used:

bazel build tensorflow/contrib/lite/toco:toco

There are two kinds of conversion: one produces a float tflite model, and the other produces a quantized uint8 version of the model. The two are as follows:

Non-quantized (float) conversion (note: the official site gives the path ./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco, which is wrong; the path below is correct):

./bazel-bin/tensorflow/contrib/lite/toco/toco \
 --input_file=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \
 --output_file=./mobilenet_v1_1.0_224/tflite_model_test.tflite \
 --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
 --inference_type=FLOAT \
 --input_shape="1,224,224,3" \
 --input_array=input \
 --output_array=MobilenetV1/Predictions/Reshape_1

Quantized conversion (note that only models trained with quantization can be converted to a quantized tflite model):

./bazel-bin/tensorflow/contrib/lite/toco/toco \
 --input_file=frozen_eval_graph.pb \
 --output_file=tflite_model.tflite \
 --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
 --inference_type=QUANTIZED_UINT8 \
 --input_shape="1,224,224,3" \
 --input_array=input \
 --output_array=outputs \
 --std_value=127.5 --mean_value=127.5
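Whichever mode you use, the resulting .tflite file can be sanity-checked from Python with the TFLite interpreter (tf.lite.Interpreter is the TF 1.14+ name; the file name and input shape follow the float example above):

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(
    model_path='./mobilenet_v1_1.0_224/tflite_model_test.tflite')
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]

# feed one random NHWC image and run a forward pass
interpreter.set_tensor(inp['index'], np.random.rand(1, 224, 224, 3).astype(np.float32))
interpreter.invoke()
print(interpreter.get_tensor(out['index']).shape)  # e.g. (1, 1001) for MobilenetV1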

That is all I have to share about converting PyTorch models to tflite.