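I'm currently trying to train a CNN (EfficientNet) on .FITS image files, though I think this applies to other image types as well. These images need the astropy library to be opened. In my case, to access the image data I just type: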
from astropy.io import fits
path = "path/to/file.fits"
hdul = fits.open(path)
image = hdul[1].data
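The variable image will then be of type numpy.ndarray. I first tried keras's image_dataset_from_directory but, as expected, it didn't work. I then looked into tf.data here: https://www.tensorflow.org/tutorials/load_data/images#using_tfdata_for_finer_control. I tried to build a similar pipeline, and everything went fine up to the decode_img function. Since I'm not dealing with jpeg, I tried a workaround and ended up with: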
import os
import pathlib
import numpy as np
import tensorflow as tf

data_dir = pathlib.Path("home/astro/train")
class_names = np.array(sorted([item.name for item in data_dir.glob('*')]))
# class_names = ["stars", "galaxies"]

def get_label(file_path):
    # The class name is the second-to-last path component
    parts = tf.strings.split(file_path, os.path.sep)
    one_hot = parts[-2] == class_names
    return tf.argmax(one_hot)

def decode_img(img):
    hdul = fits.open(img)
    data = hdul[1].data
    data = data.reshape((data.shape[0], data.shape[1], 1))
    data = np.pad(data, [(0, 0), (0, 0), (0, 2)], 'constant')  # padding to create 3 channels
    img = tf.convert_to_tensor(data, np.float32)
    return tf.image.resize(img, [img_height, img_width])

def process_path(file_path):
    label = get_label(file_path)
    img = decode_img(file_path)
    return img, label
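In a way it actually works quite well: when I print the output of process_path, I get two tensors, one for the image and one for the label, with the correct shapes and values.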
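As a side note on the padding step in decode_img, here is a small NumPy-only sketch of what it produces, next to one possible alternative. The array `data` is a made-up stand-in for hdul[1].data, and replicating the channel is only a suggestion, not something from the tutorial; which variant suits a given model is left open.

```python
import numpy as np

# Hypothetical stand-in for a 2-D FITS image (hdul[1].data in the question)
data = np.arange(6, dtype=np.float32).reshape(2, 3)

# Approach from the question: add a channel axis, then append two all-zero channels
padded = np.pad(data[..., np.newaxis], [(0, 0), (0, 0), (0, 2)], "constant")

# Possible alternative: copy the single channel into all three channels
repeated = np.repeat(data[..., np.newaxis], 3, axis=-1)

print(padded.shape, repeated.shape)
```

Both arrays have shape (height, width, 3). With padding, only the first channel carries the image, while with repetition all three channels are identical.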
The problem:
Following the tutorial, I then run:
AUTOTUNE = tf.data.experimental.AUTOTUNE
train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)
and I get the following error:
TypeError Traceback (most recent call last)
in
1 AUTOTUNE = tf.data.experimental.AUTOTUNE
2
----> 3 train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
4 val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in map(self, map_func, num_parallel_calls, deterministic)
1700 num_parallel_calls,
1701 deterministic,
-> 1702 preserve_cardinality=True)
1703
1704 def flat_map(self, map_func):
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, input_dataset, map_func, num_parallel_calls, deterministic, use_inter_op_parallelism, preserve_cardinality, use_legacy_function)
4082 self._transformation_name(),
4083 dataset=input_dataset,
-> 4084 use_legacy_function=use_legacy_function)
4085 if deterministic is None:
4086 self._deterministic = "default"
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)
3369 with tracking.resource_tracker_scope(resource_tracker):
3370 # TODO(b/141462134): Switch to using garbage collection.
-> 3371 self._function = wrapper_fn.get_concrete_function()
3372 if add_to_graph:
3373 self._function.add_to_graph(ops.get_default_graph())
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in get_concrete_function(self, *args, **kwargs)
2937 """
2938 graph_function = self._get_concrete_function_garbage_collected(
-> 2939 *args, **kwargs)
2940 graph_function._garbage_collector.release() # pylint: disable=protected-access
2941 return graph_function
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
2904 args, kwargs = None, None
2905 with self._lock:
-> 2906 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
2907 seen_names = set()
2908 captured = object_identity.ObjectIdentitySet(
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3211
3212 self._function_cache.missed.add(call_context_key)
-> 3213 graph_function = self._create_graph_function(args, kwargs)
3214 self._function_cache.primary[cache_key] = graph_function
3215 return graph_function, args, kwargs
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3073 arg_names=arg_names,
3074 override_flat_arg_shapes=override_flat_arg_shapes,
-> 3075 capture_by_value=self._capture_by_value),
3076 self._function_attributes,
3077 function_spec=self.function_spec,
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
984 _, original_func = tf_decorator.unwrap(python_func)
985
--> 986 func_outputs = python_func(*func_args, **func_kwargs)
987
988 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in wrapper_fn(*args)
3362 attributes=defun_kwargs)
3363 def wrapper_fn(*args): # pylint: disable=missing-docstring
-> 3364 ret = _wrapper_helper(*args)
3365 ret = structure.to_tensor_list(self._output_structure, ret)
3366 return [ops.convert_to_tensor(t) for t in ret]
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in _wrapper_helper(*args)
3297 nested_args = (nested_args,)
3298
-> 3299 ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
3300 # If `func` returns a list of tensors, `nest.flatten()` and
3301 # `ops.convert_to_tensor()` would conspire to attempt to stack
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
256 except Exception as e: # pylint:disable=broad-except
257 if hasattr(e, 'ag_error_metadata'):
--> 258 raise e.ag_error_metadata.to_exception(e)
259 else:
260 raise
TypeError: in user code:
:17 process_path *
img = decode_img(file_path)
:7 decode_img *
hdul = fits.open(img)
/home/marcostidball/anaconda3/lib/python3.7/site-packages/astropy/io/fits/hdu/hdulist.py:154 fitsopen *
if not name:
/home/marcostidball/anaconda3/lib/python3.7/site-packages/tensorflow/python/autograph/operators/logical.py:29 not_
return _tf_not(a)
/home/marcostidball/anaconda3/lib/python3.7/site-packages/tensorflow/python/autograph/operators/logical.py:35 _tf_not
return gen_math_ops.logical_not(a)
/home/marcostidball/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py:5481 logical_not
"LogicalNot", x=x, name=name)
/home/marcostidball/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:493 _apply_op_helper
(prefix, dtypes.as_dtype(input_arg.type).name))
TypeError: Input 'x' of 'LogicalNot' Op has type string that does not match expected type of bool.
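Does anyone know a way to fix this? I've searched for ways to train a CNN directly on numpy arrays, such as the one I get before the tensor conversion, and found some examples that use MNIST with standalone keras. However, I'd like to apply the usual data augmentation and batched training, and I'm not sure whether that is possible with the approaches I've seen.
Thanks a lot!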
Answer 0 (score: 0)
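I ran into the same problem. I think the issue (though I'm not sure) is that when you build train_ds, you are only saying "this is what will happen when we actually need the next batch". astropy.io.fits, on the other hand, runs as soon as it is called, so it complains when the file name it is given doesn't exist (or is just a placeholder).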
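One commonly used way around this kind of deferred execution is to wrap the astropy call in tf.py_function, so that it runs as ordinary Python on a real path string each time an element is requested. The sketch below is not from the original post: load_fits, IMG_HEIGHT and IMG_WIDTH are names chosen for illustration, the 224x224 size is an assumption, and it assumes (as in the question) that the image lives in HDU 1.

```python
import numpy as np
import tensorflow as tf
from astropy.io import fits

IMG_HEIGHT, IMG_WIDTH = 224, 224  # assumed target size, not from the question


def load_fits(path):
    """Plain-Python loader; `path` arrives as an eager string tensor."""
    filename = path.numpy().decode("utf-8")
    with fits.open(filename) as hdul:
        # Copy into a native-endian float32 array before the file is closed
        data = np.array(hdul[1].data, dtype=np.float32)
    # Same zero-padding to 3 channels as in the question
    return np.pad(data[..., np.newaxis], [(0, 0), (0, 0), (0, 2)], "constant")


def decode_img(file_path):
    # tf.py_function defers load_fits until a real file path is available
    img = tf.py_function(load_fits, [file_path], tf.float32)
    # py_function outputs carry no static shape; resize needs at least the rank
    img.set_shape([None, None, 3])
    return tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])
```

With a decode_img like this, process_path from the question could be passed to train_ds.map unchanged. Note that the Python function runs under the interpreter lock, so num_parallel_calls may help less than it does for native TensorFlow ops.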
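The solution I came up with was to write code that loads FITS files without using astropy. You can get it from this GitHub repo. It may be buggier than you'd like, but it works for me for opening FITS images (I haven't figured out tables yet).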
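For illustration only, here is a minimal sketch of what reading a FITS image without astropy can involve. It is not the code from the repository mentioned above. The names read_fits_image and _read_header are hypothetical, and the sketch assumes uncompressed image HDUs whose relevant header values are plain numbers; tables, string values containing "/", and Fortran-style "D" exponents are not handled.

```python
import numpy as np

BLOCK = 2880  # FITS files are organised in 2880-byte blocks
CARD = 80     # each header record ("card") is 80 ASCII characters

# BITPIX value -> big-endian numpy dtype, following the FITS standard
BITPIX_DTYPES = {8: ">u1", 16: ">i2", 32: ">i4", 64: ">i8", -32: ">f4", -64: ">f8"}


def _read_header(f):
    """Read header blocks up to the END card; return keyword -> raw value string."""
    header = {}
    while True:
        block = f.read(BLOCK)
        if len(block) < BLOCK:
            raise EOFError("reached end of file while reading a header")
        for i in range(0, BLOCK, CARD):
            card = block[i:i + CARD].decode("ascii")
            key = card[:8].strip()
            if key == "END":
                return header
            if card[8:10] == "= ":
                # Drop the optional "/ comment" part (fine for numeric values)
                header[key] = card[10:].split("/")[0].strip()


def read_fits_image(path, hdu=0):
    """Return the image stored in HDU number `hdu` as a float32 array."""
    with open(path, "rb") as f:
        for index in range(hdu + 1):
            header = _read_header(f)
            bitpix = int(header["BITPIX"])
            naxis = int(header["NAXIS"])
            # NAXIS1 varies fastest, so the numpy shape is the reversed axis list
            shape = tuple(int(header[f"NAXIS{n}"]) for n in range(naxis, 0, -1))
            n_elements = int(np.prod(shape)) if naxis > 0 else 0
            item_size = abs(bitpix) // 8
            pcount = int(header.get("PCOUNT", 0))
            gcount = int(header.get("GCOUNT", 1))
            n_bytes = item_size * gcount * (pcount + n_elements) if naxis > 0 else 0
            if index < hdu:
                # Skip this HDU's data, rounded up to whole blocks
                f.seek(-(-n_bytes // BLOCK) * BLOCK, 1)
                continue
            if n_elements == 0:
                raise ValueError(f"HDU {hdu} contains no image data")
            raw = f.read(item_size * n_elements)
            data = np.frombuffer(raw, dtype=BITPIX_DTYPES[bitpix]).reshape(shape)
            # Apply the optional linear scaling stored in the header
            bscale = float(header.get("BSCALE", 1))
            bzero = float(header.get("BZERO", 0))
            return (data * bscale + bzero).astype(np.float32)
```

Under these assumptions, read_fits_image(path, hdu=1) would play the role of hdul[1].data in the question, returning a native-endian float32 array.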