The goal is to convert a model trained on a server with PyTorch into a TFLite model that can run on mobile devices.
The most straightforward idea is to convert the PyTorch model to a TensorFlow model and then to TFLite, but we have not found a reliable tool for that conversion.
After some research, we found that recent versions of TFLite support converting directly from a Keras model. Keras can therefore serve as the intermediate bridge, which lets us take full advantage of the convenience of the Keras high-level API.
The basic idea is to take the weights of each layer of the PyTorch network and assign them directly to the corresponding layer of an equivalent Keras network, transposing them where the two frameworks use different tensor layouts.
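PyTorch stores `Conv2d` weights as (out_channels, in_channels, kH, kW) and works on NCHW inputs, while Keras `Conv2D` kernels are (kH, kW, in_channels, out_channels) on NHWC inputs, so "assigning directly" needs a transpose with axes `(2, 3, 1, 0)`. The NumPy-only sketch below checks this with two naive convolution helpers; `conv2d_nchw_oihw` and `conv2d_nhwc_hwio` are hypothetical names written just for this check, not library functions.

```python
import numpy as np


def conv2d_nchw_oihw(x, w, b):
    """Naive 'valid' convolution in PyTorch layout.
    x: (C_in, H, W), w: (C_out, C_in, kH, kW), b: (C_out,)"""
    c_out, _, kh, kw = w.shape
    _, h, wd = x.shape
    out = np.zeros((c_out, h - kh + 1, wd - kw + 1))
    for o in range(c_out):
        for i in range(h - kh + 1):
            for j in range(wd - kw + 1):
                out[o, i, j] = np.sum(x[:, i:i + kh, j:j + kw] * w[o]) + b[o]
    return out


def conv2d_nhwc_hwio(x, w, b):
    """The same operation in Keras layout.
    x: (H, W, C_in), w: (kH, kW, C_in, C_out), b: (C_out,)"""
    kh, kw, _, c_out = w.shape
    h, wd, _ = x.shape
    out = np.zeros((h - kh + 1, wd - kw + 1, c_out))
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[i:i + kh, j:j + kw, :]  # (kH, kW, C_in)
            # contract over kH, kW, C_in -> one value per output channel
            out[i, j, :] = np.tensordot(patch, w, axes=([0, 1, 2], [0, 1, 2])) + b
    return out


rng = np.random.default_rng(0)
x_chw = rng.standard_normal((3, 8, 8))
w_oihw = rng.standard_normal((4, 3, 3, 3))
bias = rng.standard_normal(4)

# Weight: (out, in, kH, kW) -> (kH, kW, in, out); input: (C, H, W) -> (H, W, C)
w_hwio = w_oihw.transpose(2, 3, 1, 0)
x_hwc = x_chw.transpose(1, 2, 0)

y_torch_layout = conv2d_nchw_oihw(x_chw, w_oihw, bias)
y_keras_layout = conv2d_nhwc_hwio(x_hwc, w_hwio, bias)
print(np.allclose(y_torch_layout.transpose(1, 2, 0), y_keras_layout))  # True
```

The bias needs no transpose because both frameworks store it as a 1-D vector of length out_channels.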
Once the Keras model holds the PyTorch weights, it can be converted directly to TFLite with the TFLite converter.
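The example in this article uses the TF 1.x call `tf.lite.TFLiteConverter.from_keras_model_file`. In TensorFlow 2.x that file-based entry point lives under `tf.compat.v1.lite.TFLiteConverter`, and the in-memory route is `tf.lite.TFLiteConverter.from_keras_model`. The sketch below assumes TF 2.x and uses a small stand-in Keras model (not the article's `KerasNet`); it converts the model and compares the TFLite interpreter's output with Keras on one random input.

```python
import numpy as np
import tensorflow as tf

# Stand-in Keras model (assumption): one conv block similar to the article's network.
inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(8, (3, 3), strides=(2, 2), padding="valid")(inputs)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Activation("relu")(x)
outputs = tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# TF 2.x: convert the in-memory Keras model to a TFLite flatbuffer (bytes).
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Run the converted model with the TFLite interpreter.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

sample = np.random.rand(1, 32, 32, 3).astype(np.float32)
interpreter.set_tensor(input_index, sample)
interpreter.invoke()
tflite_out = interpreter.get_tensor(output_index)

# Compare with the Keras model in inference mode.
keras_out = model(sample, training=False).numpy()
print(tflite_out.shape, np.max(np.abs(tflite_out - keras_out)))
```

Writing `tflite_model` to a `.tflite` file gives the artifact that is shipped to the mobile app.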
Here is an example that converts a two-layer CNN.
import tensorflow as tf
from tensorflow import keras
import numpy as np

import torch
import torch.nn as nn
# import torch.nn.functional as F


class PytorchNet(nn.Module):
    def __init__(self):
        super(PytorchNet, self).__init__()
        conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))
        conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, 1, groups=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))
        self.feature = nn.Sequential(conv1, conv2)
        self.init_weights()

    def forward(self, x):
        return self.feature(x)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight.data, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


def KerasNet(input_shape=(224, 224, 3)):
    image_input = keras.layers.Input(shape=input_shape)
    # conv1
    network = keras.layers.Conv2D(
        32, (3, 3), strides=(2, 2), padding="valid")(image_input)
    # epsilon=1e-5 matches nn.BatchNorm2d's default eps (Keras defaults to 1e-3)
    network = keras.layers.BatchNormalization(
        trainable=False, fused=False, epsilon=1e-5)(network)
    network = keras.layers.Activation("relu")(network)
    network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)
    # conv2
    network = keras.layers.Conv2D(
        64, (3, 3), strides=(1, 1), padding="valid")(network)
    network = keras.layers.BatchNormalization(
        trainable=False, fused=True, epsilon=1e-5)(network)
    network = keras.layers.Activation("relu")(network)
    network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)
    model = keras.Model(inputs=image_input, outputs=network)
    return model


class PytorchToKeras(object):
    def __init__(self, pModel, kModel):
        super(PytorchToKeras, self).__init__()
        self.__source_layers = []
        self.__target_layers = []
        self.pModel = pModel
        self.kModel = kModel
        # TF 1.x: put Keras into inference mode
        tf.keras.backend.set_learning_phase(0)

    def __retrieve_k_layers(self):
        # indices of the Keras layers that own weights (Conv2D, BatchNormalization)
        for i, layer in enumerate(self.kModel.layers):
            if len(layer.weights) > 0:
                self.__target_layers.append(i)

    def __retrieve_p_layers(self, input_size):
        # dummy input with a batch dimension: (1, C, H, W)
        input = torch.randn(input_size).unsqueeze(0)
        hooks = []

        def add_hooks(module):
            def hook(module, input, output):
                # record modules that have weights, in forward execution order
                if hasattr(module, "weight"):
                    self.__source_layers.append(module)

            if not isinstance(module, nn.ModuleList) and not isinstance(module, nn.Sequential) and module != self.pModel:
                hooks.append(module.register_forward_hook(hook))

        self.pModel.apply(add_hooks)
        # eval mode so the dummy pass does not update the BatchNorm running stats
        self.pModel.eval()
        with torch.no_grad():
            self.pModel(input)
        for hook in hooks:
            hook.remove()

    def convert(self, input_size):
        self.__retrieve_k_layers()
        self.__retrieve_p_layers(input_size)

        for source_layer, target_layer in zip(self.__source_layers, self.__target_layers):
            print(source_layer)
            if isinstance(source_layer, nn.Conv2d):
                # PyTorch (out, in, kH, kW) -> Keras (kH, kW, in, out)
                transpose_dims = [2, 3, 1, 0]
                self.kModel.layers[target_layer].set_weights([
                    source_layer.weight.data.numpy().transpose(transpose_dims),
                    source_layer.bias.data.numpy()])
            elif isinstance(source_layer, nn.BatchNorm2d):
                # Keras order: gamma, beta, moving_mean, moving_variance
                self.kModel.layers[target_layer].set_weights([
                    source_layer.weight.data.numpy(),
                    source_layer.bias.data.numpy(),
                    source_layer.running_mean.data.numpy(),
                    source_layer.running_var.data.numpy()])

    def save_model(self, output_file):
        self.kModel.save(output_file)

    def save_weights(self, output_file):
        self.kModel.save_weights(output_file, save_format='h5')


pytorch_model = PytorchNet()
keras_model = KerasNet(input_shape=(224, 224, 3))

# placeholder file name: the original path was not preserved
MODEL_PATH = 'pytorch_model.pth'
torch.save(pytorch_model, MODEL_PATH)

# Load the pretrained model (PyTorch >= 2.6 needs weights_only=False to load a full pickled model)
pytorch_model = torch.load(MODEL_PATH)

# Time to transfer weights; the input size is (C, H, W)
converter = PytorchToKeras(pytorch_model, keras_model)
converter.convert((3, 224, 224))

# Save the converted keras model for later use
# converter.save_weights("keras.h5")
converter.save_model("keras_model.h5")

# Convert the keras model to a tflite model (TF 1.x API)
converter = tf.lite.TFLiteConverter.from_keras_model_file("keras_model.h5")
tflite_model = converter.convert()
with open("convert_model.tflite", "wb") as f:
    f.write(tflite_model)
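The BatchNorm transfer only works if both layers compute the same inference formula, y = gamma * (x - mean) / sqrt(var + eps) + beta, with the same eps. PyTorch's `nn.BatchNorm2d` defaults to `eps=1e-5`, whereas Keras `BatchNormalization` defaults to `epsilon=1e-3`, so identical weights can still give different outputs, especially for channels with small running variance. The NumPy sketch below illustrates this; `batchnorm_inference` is a hypothetical helper and the statistics are random stand-ins, not values from a real model.

```python
import numpy as np


def batchnorm_inference(x, gamma, beta, mean, var, eps):
    """Inference-mode batch norm over the last (channel) axis."""
    return gamma * (x - mean) / np.sqrt(var + eps) + beta


rng = np.random.default_rng(0)
channels = 4
x = rng.standard_normal((2, 5, 5, channels))  # NHWC activations
gamma = rng.uniform(0.5, 1.5, channels)       # PyTorch: bn.weight
beta = rng.standard_normal(channels)          # PyTorch: bn.bias
mean = rng.standard_normal(channels)          # PyTorch: bn.running_mean
var = rng.uniform(0.001, 0.01, channels)      # PyTorch: bn.running_var (small on purpose)

# Keras BatchNormalization.set_weights takes [gamma, beta, moving_mean, moving_variance].
y_pytorch = batchnorm_inference(x, gamma, beta, mean, var, eps=1e-5)        # nn.BatchNorm2d default
y_keras_default = batchnorm_inference(x, gamma, beta, mean, var, eps=1e-3)  # Keras default epsilon
y_keras_matched = batchnorm_inference(x, gamma, beta, mean, var, eps=1e-5)  # epsilon set to 1e-5

print(np.allclose(y_pytorch, y_keras_matched))  # True
print(np.allclose(y_pytorch, y_keras_default))  # False
```

Passing `epsilon=1e-5` when building the Keras layer keeps the two formulas identical.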
Additional knowledge: converting TensorFlow models to TensorFlow Lite models
1. Freeze the graph definition and the checkpoint weights into a single file
bazel build tensorflow/python/tools:freeze_graph && \
bazel-bin/tensorflow/python/tools/freeze_graph \
  --input_graph=eval_graph_def.pb \
  --input_checkpoint=checkpoint \
  --output_graph=frozen_eval_graph.pb \
  --output_node_names=outputs
For example:
bazel-bin/tensorflow/python/tools/freeze_graph \
  --input_graph=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_eval.pbtxt \
  --input_checkpoint=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \
  --output_graph=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \
  --output_node_names=MobilenetV1/Predictions/Reshape_1
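Freezing means replacing every variable that the output node depends on with a constant that holds its checkpointed value; `freeze_graph` does this internally through `graph_util.convert_variables_to_constants`. The sketch below calls that function directly from Python through `tf.compat.v1`, assuming TensorFlow 2.x. It uses a hypothetical toy graph whose variable is initialized to ones instead of restored from a checkpoint; only the node names `input` and `outputs` mirror the flags above.

```python
import tensorflow as tf

# Hypothetical toy graph; a real model would be the eval graph from your training code.
graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3], name="input")
    w = tf.compat.v1.get_variable(
        "w", shape=[3, 2], initializer=tf.compat.v1.ones_initializer())
    tf.identity(tf.matmul(x, w), name="outputs")
    init = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    # A real model would restore its checkpoint here (tf.compat.v1.train.Saver().restore);
    # the toy graph just initializes its variable to ones.
    sess.run(init)
    # Replace every variable reachable from "outputs" with a constant holding its value.
    frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), ["outputs"])

with open("frozen_eval_graph.pb", "wb") as f:
    f.write(frozen_graph_def.SerializeToString())

# Reload the frozen graph into a fresh graph: it runs without any variable initialization.
check = tf.Graph()
with check.as_default():
    tf.compat.v1.import_graph_def(frozen_graph_def, name="")
with tf.compat.v1.Session(graph=check) as sess:
    result = sess.run("outputs:0", feed_dict={"input:0": [[1.0, 2.0, 3.0]]})
print(result)  # [[6. 6.]]
```

Because the frozen file needs no checkpoint, it is the single self-contained input that the TFLite converter expects.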
2. Convert the frozen TensorFlow .pb model from step 1 to a TFLite model
The conversion tool (toco) must be built first:
bazel build tensorflow/contrib/lite/toco:toco
There are two kinds of conversion: one produces a floating-point TFLite model, and the other produces a quantized uint8 model. Both are shown below.
Non-quantized (float) conversion:
Note: the official website gives the path ./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco, which is wrong; the correct path is ./bazel-bin/tensorflow/contrib/lite/toco/toco:

./bazel-bin/tensorflow/contrib/lite/toco/toco \
  --input_file=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \
  --output_file=./mobilenet_v1_1.0_224/tflite_model_test.tflite \
  --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
  --inference_type=FLOAT \
  --input_shape="1,224,224,3" \
  --input_array=input \
  --output_array=MobilenetV1/Predictions/Reshape_1
Quantized conversion (note that only models trained with quantization can be converted to a quantized TFLite model this way):
./bazel-bin/tensorflow/contrib/lite/toco/toco \
  --input_file=frozen_eval_graph.pb \
  --output_file=tflite_model.tflite \
  --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
  --inference_type=QUANTIZED_UINT8 \
  --input_shape="1,224,224,3" \
  --input_array=input \
  --output_array=outputs \
  --std_value=127.5 --mean_value=127.5
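The mean/std flags tell the converter how the uint8 input relates to the float values the network was trained on; TOCO documents the relation as real_value = (quantized_value - mean_value) / std_value. With mean = std = 127.5, pixel values 0..255 map onto [-1, 1]. The sketch below is a plain NumPy illustration of that mapping; `dequantize_input` and `quantize_input` are hypothetical helpers, not part of the TFLite API.

```python
import numpy as np


def dequantize_input(q, mean_value, std_value):
    """real_value = (quantized_value - mean_value) / std_value"""
    return (np.asarray(q, dtype=np.float64) - mean_value) / std_value


def quantize_input(real, mean_value, std_value):
    """Inverse mapping, rounded and clipped to the uint8 range [0, 255]."""
    q = np.round(np.asarray(real) * std_value + mean_value)
    return np.clip(q, 0, 255).astype(np.uint8)


pixels = np.array([0, 127, 128, 255], dtype=np.uint8)
real = dequantize_input(pixels, mean_value=127.5, std_value=127.5)
print(real[0], real[-1])  # -1.0 1.0
print(quantize_input(real, mean_value=127.5, std_value=127.5))  # [  0 127 128 255]
```

Whatever values are used here should match the input normalization applied during training, otherwise the quantized model sees differently scaled inputs.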
That is all I have to share about converting PyTorch models to TFLite.